diff --git a/.travis.yml b/.travis.yml index 3fa8aab..3179654 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,7 @@ language: python python: - - "2.7" - - "3.4" - "3.5" + - "3.6" install: - "pip install ." diff --git a/athena_cli.py b/athena_cli.py index b9f1bca..13cb167 100755 --- a/athena_cli.py +++ b/athena_cli.py @@ -9,6 +9,7 @@ import sys import time import uuid +import itertools import boto3 import botocore @@ -16,18 +17,61 @@ from botocore.exceptions import ClientError, ParamValidationError from tabulate import tabulate -LESS = "less -FXRSn" +LESS = "less -FXRn" +LESS_TRUNC = "less -FXRSn" HISTORY_FILE_SIZE = 500 -__version__ = '0.1.8' +__version__ = '0.2.0' + +def output_results(athena, format, execution_id, output, is_shell): + results = athena.get_query_results(execution_id) + headers = results[0] + counter = itertools.count(start=1) + rows = zip(results[1], counter) + + try: + if format in ['CSV', 'CSV_HEADER', 'TSV', 'TSV_HEADER']: + if format in ['TSV', 'TSV_HEADER']: + delim = '\t' + quote = csv.QUOTE_NONE + esc = '\\' + else: + delim = ',' + quote = csv.QUOTE_ALL + esc = None + + csv_writer = csv.writer(output, delimiter=delim, quoting=quote, escapechar=esc) + + if format in ['CSV_HEADER', 'TSV_HEADER']: + csv_writer.writerow(headers) + csv_writer.writerows([row for row, count in rows]) + + elif format == 'VERTICAL': + for row, count in rows: + output.write('--[RECORD {}]--'.format(count)) + output.write('\n') + output.write(tabulate(zip(*[headers, row]), tablefmt='presto')) + output.write('\n') + + else: # ALIGNED + output.write(tabulate([row for row, count in rows], headers=headers, tablefmt='presto')) + output.write('\n') + + output.flush() + except IOError as x: + # quitting the less process in shell causes an IOError, so ignore + if not is_shell: + raise x + + return next(counter) - 1 class AthenaBatch(object): - def __init__(self, athena, db=None, format='CSV'): + def __init__(self, athena, db=None, format=None): self.athena = athena 
self.dbname = db - self.format = format + self.format = format or 'CSV' def execute(self, statement): execution_id = self.athena.start_query_execution(self.dbname, statement) @@ -42,58 +86,38 @@ def execute(self, statement): time.sleep(0.2) # 200ms if status == 'SUCCEEDED': - results = self.athena.get_query_results(execution_id) - headers = [h['Name'].encode("utf-8") for h in results['ResultSet']['ResultSetMetadata']['ColumnInfo']] - - if self.format in ['CSV', 'CSV_HEADER']: - csv_writer = csv.writer(sys.stdout, quoting=csv.QUOTE_ALL) - if self.format == 'CSV_HEADER': - csv_writer.writerow(headers) - csv_writer.writerows([[text.encode("utf-8") for text in row] for row in self.athena.yield_rows(results, headers)]) - elif self.format == 'TSV': - print(tabulate([row for row in self.athena.yield_rows(results, headers)], tablefmt='tsv')) - elif self.format == 'TSV_HEADER': - print(tabulate([row for row in self.athena.yield_rows(results, headers)], headers=headers, tablefmt='tsv')) - elif self.format == 'VERTICAL': - for num, row in enumerate(self.athena.yield_rows(results, headers)): - print('--[RECORD {}]--'.format(num+1)) - print(tabulate(zip(*[headers, row]), tablefmt='presto')) - else: # ALIGNED - print(tabulate([x for x in self.athena.yield_rows(results, headers)], headers=headers, tablefmt='presto')) + output_results(self.athena, self.format, execution_id, sys.stdout, is_shell=False) if status == 'FAILED': print(stats['QueryExecution']['Status']['StateChangeReason']) -try: - del cmd.Cmd.do_show # "show" is an Athena command -except AttributeError: - # "show" was removed from Cmd2 0.8.0 - pass - class AthenaShell(cmd.Cmd, object): - multilineCommands = ['WITH', 'SELECT', 'ALTER', 'CREATE', 'DESCRIBE', 'DROP', 'MSCK', 'SHOW', 'USE', 'VALUES'] + multiline_commands = ['WITH', 'SELECT', 'ALTER', 'CREATE', 'DESCRIBE', 'DROP', 'MSCK', 'SHOW', 'USE', 'VALUES'] + multiline_commands.extend([c.lower() for c in multiline_commands]) allow_cli_args = False + service_name = 
'athena' - def __init__(self, athena, db=None): + def __init__(self, athena, db=None, format=None): cmd.Cmd.__init__(self) + # allow setting of the output format interactively + self.settable['format'] = 'Output format'; + self.athena = athena self.dbname = db + self.format = format or 'ALIGNED' self.execution_id = None - self.row_count = 0 - self.set_prompt() - self.pager = os.environ.get('ATHENA_CLI_PAGER', LESS).split(' ') self.hist_file = os.path.join(os.path.expanduser("~"), ".athena_history") self.init_history() def set_prompt(self): - self.prompt = 'athena:%s> ' % self.dbname if self.dbname else 'athena> ' + self.prompt = '%s:%s> ' % (self.service_name, self.dbname) if self.dbname else '%s> ' % self.service_name def cmdloop_with_cancel(self, intro=None): try: @@ -134,6 +158,7 @@ def do_help(self, arg): help_output = """ Supported commands: QUIT +EXIT SELECT ALTER DATABASE ALTER TABLE @@ -157,9 +182,11 @@ def do_help(self, arg): print(help_output) def do_quit(self, arg): - print() return -1 + def do_exit(self, arg): + return self.do_quit(arg) + def do_EOF(self, arg): return self.do_quit(arg) @@ -169,17 +196,19 @@ def do_use(self, schema): def do_set(self, arg): try: - statement, param_name, val = arg.parsed.raw.split(None, 2) + statement, param_name, val = arg.raw.split(None, 2) val = val.strip() param_name = param_name.strip().lower() if param_name == 'debug': self.athena.debug = cmd.cast(True, val) + elif param_name == 'format': + arg = "format " + val.upper() except (ValueError, AttributeError): - self.do_show(arg) + pass super(AthenaShell, self).do_set(arg) def default(self, line): - self.execution_id = self.athena.start_query_execution(self.dbname, line.full_parsed_statement()) + self.execution_id = self.athena.start_query_execution(self.dbname, line.raw) if not self.execution_id: return @@ -197,15 +226,10 @@ def default(self, line): sys.stdout.flush() if status == 'SUCCEEDED': - results = self.athena.get_query_results(self.execution_id) - headers =
[h['Name'] for h in results['ResultSet']['ResultSetMetadata']['ColumnInfo']] - row_count = len(results['ResultSet']['Rows']) - - if headers and len(results['ResultSet']['Rows']) and results['ResultSet']['Rows'][0]['Data'][0].get('VarCharValue', None) == headers[0]: - row_count -= 1 # don't count header - - process = subprocess.Popen(self.pager, stdin=subprocess.PIPE) - process.stdin.write(tabulate([x for x in self.athena.yield_rows(results, headers)], headers=headers, tablefmt='presto').encode('utf-8')) + less = LESS_TRUNC if self.format == 'TRUNCATE' else LESS + pager = os.environ.get('ATHENA_CLI_PAGER', less).split(' ') + process = subprocess.Popen(pager, stdin=subprocess.PIPE, universal_newlines=True) + row_count = output_results(self.athena, self.format, self.execution_id, process.stdin, is_shell=True) process.communicate() print('(%s rows)\n' % row_count) @@ -233,7 +257,8 @@ class Athena(object): def __init__(self, profile, region=None, bucket=None, debug=False, encryption=False): self.session = boto3.Session(profile_name=profile, region_name=region) - self.athena = self.session.client('athena') + session_config = botocore.config.Config(user_agent='athena-cli') + self.athena = self.session.client('athena', config=session_config) self.region = region or os.environ.get('AWS_DEFAULT_REGION', None) or self.session.region_name @@ -280,18 +305,27 @@ def get_query_execution(self, execution_id): print(e) def get_query_results(self, execution_id): + results = None try: - results = None paginator = self.athena.get_paginator('get_query_results') page_iterator = paginator.paginate( - QueryExecutionId=execution_id + QueryExecutionId=execution_id, + PaginationConfig={'PageSize':1000} ) - for page in page_iterator: - if results is None: - results = page - else: - results['ResultSet']['Rows'].extend(page['ResultSet']['Rows']) + pages = iter(page_iterator) + first_page = next(pages) # get first page so we can retrieve metadata for header + + headers = list(h['Name'] for h in 
first_page['ResultSet']['ResultSetMetadata']['ColumnInfo']) + first_row = None if len(first_page['ResultSet']['Rows']) == 0 else list(self.get_col_value(col) for col in first_page['ResultSet']['Rows'][0]['Data']) + rows = self.yield_rows(first_page, pages) + + # certain requests return the header as the first row, so skip it + if first_row == headers: + next(rows) + + results = (headers, rows) + except ClientError as e: sys.exit(e) @@ -300,6 +334,7 @@ def get_query_results(self, execution_id): return results + def stop_query_execution(self, execution_id): try: return self.athena.stop_query_execution( @@ -309,12 +344,16 @@ def stop_query_execution(self, execution_id): sys.exit(e) @staticmethod - def yield_rows(results, headers): - for row in results['ResultSet']['Rows']: - # https://forums.aws.amazon.com/thread.jspa?threadID=256505 - if headers and row['Data'][0].get('VarCharValue', None) == headers[0]: - continue # skip header - yield [d.get('VarCharValue', 'NULL') for d in row['Data']] + def get_col_value(col): + return col.get('VarCharValue', 'NULL') + + @staticmethod + def yield_rows(first_page, pages): + for row in first_page['ResultSet']['Rows']: + yield [Athena.get_col_value(col) for col in row['Data']] + for page in pages: + for row in page['ResultSet']['Rows']: + yield [Athena.get_col_value(col) for col in row['Data']] def console_link(self, execution_id): return 'https://{0}.console.aws.amazon.com/athena/home?force&region={0}#query/history/{1}'.format(self.region, execution_id) @@ -404,7 +443,7 @@ def main(): batch = AthenaBatch(athena, db=args.schema, format=args.format) batch.execute(statement=args.execute) else: - shell = AthenaShell(athena, db=args.schema) + shell = AthenaShell(athena, db=args.schema, format=args.format) shell.cmdloop_with_cancel() if __name__ == '__main__': diff --git a/setup.py b/setup.py index 29334d9..bce3f12 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ from setuptools import setup, find_packages -version =
'0.2.0' setup( name="athena-cli", @@ -17,8 +17,8 @@ ], install_requires=[ 'boto3', - 'cmd2', - 'tabulate>=0.8.1' + 'cmd2>=0.9.4', + 'tabulate>=0.8.2' ], include_package_data=True, zip_safe=True, @@ -30,5 +30,6 @@ keywords='aws athena presto cli', classifiers=[ 'Topic :: Utilities' - ] + ], + python_requires='>=3.5' )