Skip to content

Commit 48b23a0

Browse files
Add special commands and tests (dbcli#155)
* Special commands create source files * Special commands add help special command * Add list databases command * Add list schemas special command * Add list tables * Add list views and list indexes * Add special command tests * Update tests to use dynamically generated resource names
1 parent 87685ea commit 48b23a0

File tree

6 files changed

+312
-14
lines changed

6 files changed

+312
-14
lines changed

build.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def unit_test():
9797
utility.exec_command(
9898
'pytest --cov mssqlcli tests/test_mssqlcliclient.py tests/test_main.py tests/test_fuzzy_completion.py '
9999
'tests/test_rowlimit.py tests/test_sqlcompletion.py tests/test_prioritization.py mssqlcli/jsonrpc/contracts/tests '
100-
'tests/test_telemetry.py',
100+
'tests/test_telemetry.py tests/test_special.py',
101101
utility.ROOT_DIR,
102102
continue_on_error=False)
103103

mssqlcli/mssqlcliclient.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from mssqlcli.jsonrpc.contracts import connectionservice, queryexecutestringservice as queryservice
1212
from mssqlcli.packages.parseutils.meta import ForeignKey
1313
from mssqlcli.sqltoolsclient import SqlToolsClient
14+
from packages import special
1415

1516
logger = logging.getLogger(u'mssqlcli.mssqlcliclient')
1617
time_wait_if_no_response = 0.05
@@ -121,19 +122,25 @@ def connect(self):
121122
return self.owner_uri
122123

123124
def execute_multi_statement_single_batch(self, query):
124-
# Remove spaces, EOL and semi-colons from end
125-
query = query.strip()
126-
if not query:
127-
yield None, None, None, query, False
128-
else:
129-
for sql in sqlparse.split(query):
130-
# Remove spaces, EOL and semi-colons from end
131-
sql = sql.strip().rstrip(';')
132-
if not sql:
133-
yield None, None, None, sql, False
134-
continue
135-
for rows, columns, status, statement, is_error in self.execute_single_batch_query(sql):
136-
yield rows, columns, status, statement, is_error
125+
# Try to run first as special command
126+
try:
127+
for rows, columns, status, statement, is_error in special.execute(self, query):
128+
yield rows, columns, status, statement, is_error
129+
except special.CommandNotFound:
130+
# Execute as normal sql
131+
# Remove spaces, EOL and semi-colons from end
132+
query = query.strip()
133+
if not query:
134+
yield None, None, None, query, False
135+
else:
136+
for sql in sqlparse.split(query):
137+
# Remove spaces, EOL and semi-colons from end
138+
sql = sql.strip().rstrip(';')
139+
if not sql:
140+
yield None, None, None, sql, False
141+
continue
142+
for rows, columns, status, statement, is_error in self.execute_single_batch_query(sql):
143+
yield rows, columns, status, statement, is_error
137144

138145
def execute_single_batch_query(self, query):
139146
if not self.is_connected:

mssqlcli/packages/special/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
__all__ = []
2+
3+
4+
def export(defn):
5+
"""Decorator to explicitly mark functions that are exposed in a lib."""
6+
globals()[defn.__name__] = defn
7+
__all__.append(defn.__name__)
8+
return defn
9+
10+
from . import main
11+
from . import commands

mssqlcli/packages/special/commands.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import logging
2+
from .main import special_command, RAW_QUERY, PARSED_QUERY, NO_QUERY
3+
4+
logger = logging.getLogger('mssqlcli.commands')
5+
6+
7+
@special_command('\\l', '\\l[+] [pattern]', 'List databases.', aliases=('\\list',))
8+
def list_databases(mssqlcliclient, pattern, verbose):
9+
base_query = u'select {0} from sys.databases'
10+
if verbose:
11+
base_query = base_query.format('name, create_date, compatibility_level, collation_name')
12+
else:
13+
base_query = base_query.format('name')
14+
if pattern:
15+
base_query += " where name like '%{0}%'".format(pattern)
16+
17+
return mssqlcliclient.execute_multi_statement_single_batch(base_query)
18+
19+
20+
@special_command('\\dn', '\\dn[+] [pattern]', 'List schemas.')
21+
def list_schemas(mssqlcliclient, pattern, verbose):
22+
base_query = u'select {0} from sys.schemas'
23+
if verbose:
24+
base_query = base_query.format('name, schema_id, principal_id')
25+
else:
26+
base_query = base_query.format('name')
27+
if pattern:
28+
base_query += " where name like '%{0}%'".format(pattern)
29+
30+
return mssqlcliclient.execute_multi_statement_single_batch(base_query)
31+
32+
33+
@special_command('\\dt', '\\dt[+] [pattern]', 'List tables.')
34+
def list_tables(mssqlcliclient, pattern, verbose):
35+
base_query = u'select {0} from INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE=\'BASE TABLE\''
36+
if verbose:
37+
base_query = base_query.format('*')
38+
else:
39+
base_query = base_query.format('table_schema, table_name')
40+
if pattern:
41+
base_query += "and table_name like '%{0}%'".format(pattern)
42+
43+
return mssqlcliclient.execute_multi_statement_single_batch(base_query)
44+
45+
46+
@special_command('\\dv', '\\dv[+] [pattern]', 'List views.')
47+
def list_tables(mssqlcliclient, pattern, verbose):
48+
base_query = u'select {0} from INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE=\'VIEW\''
49+
if verbose:
50+
base_query = base_query.format('table_catalog as catalog, table_schema as schema_name, '
51+
'table_name as view_name')
52+
else:
53+
base_query = base_query.format('table_schema as schema_name, table_name as view_name')
54+
if pattern:
55+
base_query += "and table_name like '%{0}%'".format(pattern)
56+
57+
return mssqlcliclient.execute_multi_statement_single_batch(base_query)
58+
59+
60+
@special_command('\\di', '\\di[+] [pattern]', 'List indexes.')
61+
def list_indexes(mssqlcliclient, pattern, verbose):
62+
base_query = '''
63+
SELECT
64+
TableName = t.name,
65+
IndexName = ind.name,
66+
ColumnName = col.name {verbose}
67+
FROM
68+
sys.indexes ind
69+
INNER JOIN
70+
sys.index_columns ic ON ind.object_id = ic.object_id and ind.index_id = ic.index_id
71+
INNER JOIN
72+
sys.columns col ON ic.object_id = col.object_id and ic.column_id = col.column_id
73+
INNER JOIN
74+
sys.tables t ON ind.object_id = t.object_id
75+
WHERE
76+
ind.is_primary_key = 0
77+
AND ind.is_unique = 0
78+
AND ind.is_unique_constraint = 0
79+
AND t.is_ms_shipped = 0
80+
AND ind.name like '%{pattern}%'
81+
ORDER BY
82+
t.name, ind.name, ind.index_id, ic.index_column_id;
83+
'''
84+
85+
if verbose:
86+
base_query = base_query.format(verbose=',IndexId = ind.index_id, ColumnId = ic.index_column_id',
87+
pattern='{pattern}')
88+
else:
89+
base_query = base_query.format(verbose='', pattern='{pattern}')
90+
if pattern:
91+
base_query = base_query.format(pattern=pattern)
92+
else:
93+
base_query = base_query.format(pattern='')
94+
95+
return mssqlcliclient.execute_multi_statement_single_batch(base_query)

mssqlcli/packages/special/main.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import logging
2+
from collections import namedtuple
3+
4+
from . import export
5+
6+
logger = logging.getLogger('mssqlcli.special')
7+
8+
NO_QUERY = 0
9+
PARSED_QUERY = 1
10+
RAW_QUERY = 2
11+
12+
SpecialCommand = namedtuple('SpecialCommand',
13+
['handler', 'command', 'shortcut', 'description', 'arg_type', 'hidden',
14+
'case_sensitive'])
15+
16+
COMMANDS = {}
17+
18+
19+
@export
20+
class CommandNotFound(Exception):
21+
pass
22+
23+
24+
@export
25+
def parse_special_command(sql):
26+
command, _, arg = sql.partition(' ')
27+
verbose = '+' in command
28+
command = command.strip().replace('+', '')
29+
return (command, verbose, arg.strip())
30+
31+
32+
@export
33+
def special_command(command, shortcut, description, arg_type=PARSED_QUERY,
34+
hidden=False, case_sensitive=False, aliases=()):
35+
def wrapper(wrapped):
36+
register_special_command(wrapped, command, shortcut, description,
37+
arg_type, hidden, case_sensitive, aliases)
38+
return wrapped
39+
return wrapper
40+
41+
42+
@export
43+
def register_special_command(handler, command, shortcut, description,
44+
arg_type=PARSED_QUERY, hidden=False, case_sensitive=False, aliases=()):
45+
cmd = command.lower() if not case_sensitive else command
46+
COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description,
47+
arg_type, hidden, case_sensitive)
48+
for alias in aliases:
49+
cmd = alias.lower() if not case_sensitive else alias
50+
COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description,
51+
arg_type, case_sensitive=case_sensitive,
52+
hidden=True)
53+
54+
55+
@export
56+
def execute(mssqlcliclient, sql):
57+
"""Execute a special command and return the results. If the special command
58+
is not supported a KeyError will be raised.
59+
"""
60+
command, verbose, pattern = parse_special_command(sql)
61+
62+
if (command not in COMMANDS) and (command.lower() not in COMMANDS):
63+
raise CommandNotFound
64+
65+
try:
66+
special_cmd = COMMANDS[command]
67+
except KeyError:
68+
special_cmd = COMMANDS[command.lower()]
69+
if special_cmd.case_sensitive:
70+
raise CommandNotFound('Command not found: %s' % command)
71+
72+
logger.debug(u'Executing special command {0} with argument {1}.'.format(command, pattern))
73+
74+
if special_cmd.arg_type == NO_QUERY:
75+
return special_cmd.handler()
76+
elif special_cmd.arg_type == PARSED_QUERY:
77+
return special_cmd.handler(mssqlcliclient, pattern=pattern, verbose=verbose)
78+
elif special_cmd.arg_type == RAW_QUERY:
79+
return special_cmd.handler(mssqlcliclient, query=sql)
80+
81+
82+
@special_command('help', '\\?', 'Show this help.', arg_type=NO_QUERY, aliases=('\\?', '?'))
83+
def show_help(): # All the parameters are ignored.
84+
headers = ['Command', 'Shortcut', 'Description']
85+
result = []
86+
87+
for _, value in sorted(COMMANDS.items()):
88+
if not value.hidden:
89+
result.append((value.command, value.shortcut, value.description))
90+
return [(result, headers, None, None, False)]

tests/test_special.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import unittest
2+
import uuid
3+
from mssqlutils import create_mssql_cli_client, shutdown
4+
from mssqlcli.packages.special.main import special_command, execute, NO_QUERY
5+
6+
7+
# All tests require a live connection to a SQL Server database
8+
class SpecialCommandsTests(unittest.TestCase):
9+
session_guid = str(uuid.uuid4().hex)
10+
table1 = 'mssql_cli_table1_{0}'.format(session_guid)
11+
table2 = 'mssql_cli_table2_{0}'.format(session_guid)
12+
view = 'mssql_cli_view_{0}'.format(session_guid)
13+
database = 'mssql_cli_db_{0}'.format(session_guid)
14+
schema = 'mssql_cli_schema_{0}'.format(session_guid)
15+
index = 'mssql_cli_index_{0}'.format(session_guid)
16+
17+
@classmethod
18+
def setUpClass(cls):
19+
try:
20+
# create the database objects to test upon
21+
client = create_mssql_cli_client()
22+
list(client.execute_single_batch_query('CREATE DATABASE {0};'.format(cls.database)))
23+
list(client.execute_single_batch_query('CREATE TABLE {0} (a int, b varchar(25));'.format(cls.table1)))
24+
list(client.execute_single_batch_query('CREATE TABLE {0} (x int, y varchar(25), z bit);'.format(cls.table2)))
25+
list(client.execute_single_batch_query('CREATE VIEW {0} as SELECT a from {1};'.format(cls.view, cls.table1)))
26+
list(client.execute_single_batch_query('CREATE SCHEMA {0};'.format(cls.schema)))
27+
list(client.execute_single_batch_query('CREATE INDEX {0} ON {1} (x);'.format(cls.index, cls.table2)))
28+
finally:
29+
shutdown(client)
30+
31+
@classmethod
32+
def tearDownClass(cls):
33+
try:
34+
# delete the database objects created
35+
client = create_mssql_cli_client()
36+
list(client.execute_single_batch_query('DROP DATABASE {0};'.format(cls.database)))
37+
list(client.execute_single_batch_query('DROP INDEX {0} ON {1};'.format(cls.index, cls.table2)))
38+
list(client.execute_single_batch_query('DROP TABLE {0};'.format(cls.table1)))
39+
list(client.execute_single_batch_query('DROP TABLE {0};'.format(cls.table2)))
40+
list(client.execute_single_batch_query('DROP VIEW {0} IF EXISTS;'.format(cls.view)))
41+
list(client.execute_single_batch_query('DROP SCHEMA {0};'.format(cls.schema)))
42+
finally:
43+
shutdown(client)
44+
45+
def test_list_tables_command(self):
46+
self.command('\\dt', self.table1, min_rows_expected=2, rows_expected_pattern_query=1, cols_expected=2,
47+
cols_expected_verbose=4)
48+
49+
def test_list_views_command(self):
50+
self.command('\\dv', self.view, min_rows_expected=1, rows_expected_pattern_query=1, cols_expected=2,
51+
cols_expected_verbose=3)
52+
53+
def test_list_schemas_command(self):
54+
self.command('\\dn', self.schema, min_rows_expected=1, rows_expected_pattern_query=1, cols_expected=1,
55+
cols_expected_verbose=3)
56+
57+
def test_list_indices_command(self):
58+
self.command('\\di', self.index, min_rows_expected=1, rows_expected_pattern_query=1, cols_expected=3,
59+
cols_expected_verbose=5)
60+
61+
def test_list_databases_command(self):
62+
self.command('\\l', self.database, min_rows_expected=1, rows_expected_pattern_query=1, cols_expected=1,
63+
cols_expected_verbose=4)
64+
65+
def test_add_new_special_command(self):
66+
@special_command('\\empty', '\\empty[+]', 'returns an empty list', arg_type=NO_QUERY)
67+
def empty_list_special_command():
68+
return []
69+
70+
ret = execute(None, '\\empty')
71+
self.assertTrue(len(ret) == 0)
72+
73+
def command(self, command, pattern, min_rows_expected, rows_expected_pattern_query,
74+
cols_expected, cols_expected_verbose):
75+
try:
76+
client = create_mssql_cli_client()
77+
78+
for rows, col, message, query, is_error in \
79+
client.execute_multi_statement_single_batch(command):
80+
self.assertTrue(len(rows) >= min_rows_expected)
81+
self.assertTrue(len(col) == cols_expected)
82+
83+
# execute with pattern and verbose
84+
command = command + "+ " + pattern
85+
for rows, col, message, query, is_error in \
86+
client.execute_multi_statement_single_batch(command):
87+
self.assertTrue(len(rows) == rows_expected_pattern_query)
88+
self.assertTrue(len(col) == cols_expected_verbose)
89+
finally:
90+
shutdown(client)
91+
92+
93+
94+
if __name__ == u'__main__':
95+
unittest.main()

0 commit comments

Comments
 (0)