-
Notifications
You must be signed in to change notification settings - Fork 14
Enhancements to athena-cli #36
base: master
Are you sure you want to change the base?
Changes from 7 commits
8b2ff9f
d49639e
8272554
41f70bd
af26147
b5e774f
a808d4d
0e8dbc3
279b22d
d0fbaad
98fc251
7f90919
a01ffc8
72f119d
b69c280
4fd817c
7ec5534
33a2f63
81d8ce9
865fe2f
395fe97
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,25 +9,74 @@ | |
| import sys | ||
| import time | ||
| import uuid | ||
| import itertools | ||
|
|
||
| import boto3 | ||
| import botocore | ||
| import cmd2 as cmd | ||
| from botocore.exceptions import ClientError, ParamValidationError | ||
| from tabulate import tabulate | ||
|
|
||
| LESS = "less -FXRSn" | ||
| LESS = "less -FXRn" | ||
| LESS_TRUNC = "less -FXRSn" | ||
| HISTORY_FILE_SIZE = 500 | ||
|
|
||
| __version__ = '0.1.8' | ||
| __version__ = '0.2.0' | ||
|
|
||
| def output_results(athena, format, execution_id, output, is_shell): | ||
| results = athena.get_query_results(execution_id) | ||
| headers = results[0] | ||
| counter = itertools.count(start=1) | ||
| rows = itertools.izip(results[1], counter) | ||
| count = 0 | ||
|
|
||
| try: | ||
| if format in ['CSV', 'CSV_HEADER', 'TSV', 'TSV_HEADER']: | ||
| if format in ['TSV', 'TSV_HEADER']: | ||
| delim = '\t' | ||
| quote = csv.QUOTE_NONE | ||
| esc = '\\' | ||
| else: | ||
| delim = ',' | ||
| quote = csv.QUOTE_ALL | ||
| esc = None | ||
|
|
||
| csv_writer = csv.writer(output, delimiter=delim, quoting=quote, escapechar=esc) | ||
|
|
||
| if format in ['CSV_HEADER', 'TSV_HEADER']: | ||
| csv_writer.writerow(encode(headers, 'utf-8')) | ||
| csv_writer.writerows([encode(row, 'utf-8') for row, count in rows]) | ||
|
|
||
| elif format == 'VERTICAL': | ||
| for row, count in rows: | ||
| output.write('--[RECORD {}]--'.format(count)) | ||
| output.write('\n') | ||
| output.write(tabulate(zip(*[headers, row]), tablefmt='presto').encode('utf-8')) | ||
| output.write('\n') | ||
|
|
||
| else: # ALIGNED | ||
| output.write(tabulate([row for row, count in rows], headers=headers, tablefmt='presto').encode('utf-8')) | ||
| output.write('\n') | ||
|
|
||
| output.flush() | ||
| except IOError as x: | ||
| # quitting the less process in shell causes an IOError, so ignore | ||
| if not is_shell: | ||
| raise x | ||
|
|
||
| return count | ||
|
|
||
|
|
||
| def encode(row, charset): | ||
| return [val.encode(charset) for val in row] | ||
|
|
||
|
|
||
| class AthenaBatch(object): | ||
|
|
||
| def __init__(self, athena, db=None, format='CSV'): | ||
| def __init__(self, athena, db=None, format=None): | ||
| self.athena = athena | ||
| self.dbname = db | ||
| self.format = format | ||
| self.format = 'CSV' if format is None else format | ||
|
|
||
| def execute(self, statement): | ||
| execution_id = self.athena.start_query_execution(self.dbname, statement) | ||
|
|
@@ -42,24 +91,7 @@ def execute(self, statement): | |
| time.sleep(0.2) # 200ms | ||
|
|
||
| if status == 'SUCCEEDED': | ||
| results = self.athena.get_query_results(execution_id) | ||
| headers = [h['Name'].encode("utf-8") for h in results['ResultSet']['ResultSetMetadata']['ColumnInfo']] | ||
|
|
||
| if self.format in ['CSV', 'CSV_HEADER']: | ||
| csv_writer = csv.writer(sys.stdout, quoting=csv.QUOTE_ALL) | ||
| if self.format == 'CSV_HEADER': | ||
| csv_writer.writerow(headers) | ||
| csv_writer.writerows([[text.encode("utf-8") for text in row] for row in self.athena.yield_rows(results, headers)]) | ||
| elif self.format == 'TSV': | ||
| print(tabulate([row for row in self.athena.yield_rows(results, headers)], tablefmt='tsv')) | ||
| elif self.format == 'TSV_HEADER': | ||
| print(tabulate([row for row in self.athena.yield_rows(results, headers)], headers=headers, tablefmt='tsv')) | ||
| elif self.format == 'VERTICAL': | ||
| for num, row in enumerate(self.athena.yield_rows(results, headers)): | ||
| print('--[RECORD {}]--'.format(num+1)) | ||
| print(tabulate(zip(*[headers, row]), tablefmt='presto')) | ||
| else: # ALIGNED | ||
| print(tabulate([x for x in self.athena.yield_rows(results, headers)], headers=headers, tablefmt='presto')) | ||
| output_results(self.athena, self.format, execution_id, sys.stdout, False) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For boolean parameters I would include the parameter name, e.g. `is_shell=False`.
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, agreed. Done. |
||
|
|
||
| if status == 'FAILED': | ||
| print(stats['QueryExecution']['Status']['StateChangeReason']) | ||
|
|
@@ -73,27 +105,29 @@ def execute(self, statement): | |
|
|
||
| class AthenaShell(cmd.Cmd, object): | ||
|
|
||
| multilineCommands = ['WITH', 'SELECT', 'ALTER', 'CREATE', 'DESCRIBE', 'DROP', 'MSCK', 'SHOW', 'USE', 'VALUES'] | ||
| multilineCommands = ['WITH', 'SELECT', 'ALTER', 'CREATE', 'DESCRIBE', 'DROP', 'MSCK', 'SHOW', 'USE', 'VALUES', 'with', 'select', 'alter', 'create', 'describe', 'drop', 'msck', 'show', 'use', 'values'] | ||
| allow_cli_args = False | ||
| service_name = 'athena' | ||
|
|
||
| def __init__(self, athena, db=None): | ||
| def __init__(self, athena, db=None, format=None): | ||
| cmd.Cmd.__init__(self) | ||
|
|
||
| # allow setting of the output format interactively | ||
| self.settable['format'] = 'Output format'; | ||
|
|
||
| self.athena = athena | ||
| self.dbname = db | ||
| self.format = 'ALIGNED' if format is None else format | ||
|
|
||
| self.execution_id = None | ||
|
|
||
| self.row_count = 0 | ||
|
|
||
| self.set_prompt() | ||
| self.pager = os.environ.get('ATHENA_CLI_PAGER', LESS).split(' ') | ||
|
|
||
| self.hist_file = os.path.join(os.path.expanduser("~"), ".athena_history") | ||
| self.init_history() | ||
|
|
||
| def set_prompt(self): | ||
| self.prompt = 'athena:%s> ' % self.dbname if self.dbname else 'athena> ' | ||
| self.prompt = '%s:%s> ' % (self.service_name, self.dbname) if self.dbname else '%s> ' % self.service_name | ||
|
|
||
| def cmdloop_with_cancel(self, intro=None): | ||
| try: | ||
|
|
@@ -134,6 +168,7 @@ def do_help(self, arg): | |
| help_output = """ | ||
| Supported commands: | ||
| QUIT | ||
| EXIT | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, "exit" is supported by other CLIs (like the presto-cli), so some users liked to have it.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| SELECT | ||
| ALTER DATABASE <schema> | ||
| ALTER TABLE <table> | ||
|
|
@@ -157,9 +192,11 @@ def do_help(self, arg): | |
| print(help_output) | ||
|
|
||
| def do_quit(self, arg): | ||
| print() | ||
| return -1 | ||
|
|
||
| def do_exit(self, arg): | ||
| return self.do_quit(arg) | ||
|
|
||
| def do_EOF(self, arg): | ||
| return self.do_quit(arg) | ||
|
|
||
|
|
@@ -174,6 +211,8 @@ def do_set(self, arg): | |
| param_name = param_name.strip().lower() | ||
| if param_name == 'debug': | ||
| self.athena.debug = cmd.cast(True, val) | ||
| elif param_name == 'format': | ||
| arg = "format " + val.upper() | ||
| except (ValueError, AttributeError): | ||
| self.do_show(arg) | ||
| super(AthenaShell, self).do_set(arg) | ||
|
|
@@ -197,15 +236,10 @@ def default(self, line): | |
| sys.stdout.flush() | ||
|
|
||
| if status == 'SUCCEEDED': | ||
| results = self.athena.get_query_results(self.execution_id) | ||
| headers = [h['Name'] for h in results['ResultSet']['ResultSetMetadata']['ColumnInfo']] | ||
| row_count = len(results['ResultSet']['Rows']) | ||
|
|
||
| if headers and len(results['ResultSet']['Rows']) and results['ResultSet']['Rows'][0]['Data'][0].get('VarCharValue', None) == headers[0]: | ||
| row_count -= 1 # don't count header | ||
|
|
||
| process = subprocess.Popen(self.pager, stdin=subprocess.PIPE) | ||
| process.stdin.write(tabulate([x for x in self.athena.yield_rows(results, headers)], headers=headers, tablefmt='presto').encode('utf-8')) | ||
| less = LESS_TRUNC if self.format == 'TRUNCATE' else LESS | ||
| pager = os.environ.get('ATHENA_CLI_PAGER', less).split(' ') | ||
| process = subprocess.Popen(pager, stdin=subprocess.PIPE) | ||
| row_count = output_results(self.athena, self.format, self.execution_id, process.stdin, True) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ibid. |
||
| process.communicate() | ||
| print('(%s rows)\n' % row_count) | ||
|
|
||
|
|
@@ -233,7 +267,8 @@ class Athena(object): | |
| def __init__(self, profile, region=None, bucket=None, debug=False, encryption=False): | ||
|
|
||
| self.session = boto3.Session(profile_name=profile, region_name=region) | ||
| self.athena = self.session.client('athena') | ||
| session_config = botocore.config.Config(user_agent='athena-cli') | ||
| self.athena = self.session.client('athena', config=session_config) | ||
|
|
||
| self.region = region or os.environ.get('AWS_DEFAULT_REGION', None) or self.session.region_name | ||
|
|
||
|
|
@@ -284,14 +319,23 @@ def get_query_results(self, execution_id): | |
| results = None | ||
| paginator = self.athena.get_paginator('get_query_results') | ||
| page_iterator = paginator.paginate( | ||
| QueryExecutionId=execution_id | ||
| QueryExecutionId=execution_id, | ||
| PaginationConfig={'PageSize':1000} | ||
| ) | ||
|
|
||
| for page in page_iterator: | ||
| if results is None: | ||
| results = page | ||
| else: | ||
| results['ResultSet']['Rows'].extend(page['ResultSet']['Rows']) | ||
| pages = iter(page_iterator) | ||
| first_page = pages.next() # get first page so we can retrieve metadata for header | ||
|
|
||
| headers = list(h['Name'] for h in first_page['ResultSet']['ResultSetMetadata']['ColumnInfo']) | ||
| first_row = None if len(first_page['ResultSet']['Rows']) == 0 else list(self.get_col_value(col) for col in first_page['ResultSet']['Rows'][0]['Data']) | ||
| rows = self.yield_rows(first_page, pages) | ||
|
|
||
| # certain requests return the header as the first row, so skip it | ||
| if first_row == headers: | ||
| rows.next() | ||
|
|
||
| return (headers, rows) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Whoops. This should be fixed now. |
||
|
|
||
| except ClientError as e: | ||
| sys.exit(e) | ||
|
|
||
|
|
@@ -309,12 +353,16 @@ def stop_query_execution(self, execution_id): | |
| sys.exit(e) | ||
|
|
||
| @staticmethod | ||
| def yield_rows(results, headers): | ||
| for row in results['ResultSet']['Rows']: | ||
| # https://forums.aws.amazon.com/thread.jspa?threadID=256505 | ||
| if headers and row['Data'][0].get('VarCharValue', None) == headers[0]: | ||
| continue # skip header | ||
| yield [d.get('VarCharValue', 'NULL') for d in row['Data']] | ||
| def get_col_value(col): | ||
| return col.get('VarCharValue', 'NULL') | ||
|
|
||
| @staticmethod | ||
| def yield_rows(first_page, pages): | ||
| for row in first_page['ResultSet']['Rows']: | ||
| yield [Athena.get_col_value(col) for col in row['Data']] | ||
| for page in pages: | ||
| for row in page['ResultSet']['Rows']: | ||
| yield [Athena.get_col_value(col) for col in row['Data']] | ||
|
|
||
| def console_link(self, execution_id): | ||
| return 'https://{0}.console.aws.amazon.com/athena/home?force&region={0}#query/history/{1}'.format(self.region, execution_id) | ||
|
|
@@ -404,7 +452,7 @@ def main(): | |
| batch = AthenaBatch(athena, db=args.schema, format=args.format) | ||
| batch.execute(statement=args.execute) | ||
| else: | ||
| shell = AthenaShell(athena, db=args.schema) | ||
| shell = AthenaShell(athena, db=args.schema, format=args.format) | ||
| shell.cmdloop_with_cancel() | ||
|
|
||
| if __name__ == '__main__': | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A more idiomatic way of writing this is ...
However, why did you remove the default format from the method signature?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The default wasn't being used anymore; it is always set, so having the default felt redundant.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I updated the code as you suggested...