Skip to content
This repository was archived by the owner on Feb 10, 2023. It is now read-only.
Open
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 100 additions & 52 deletions athena_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,74 @@
import sys
import time
import uuid
import itertools

import boto3
import botocore
import cmd2 as cmd
from botocore.exceptions import ClientError, ParamValidationError
from tabulate import tabulate

LESS = "less -FXRSn"
LESS = "less -FXRn"
LESS_TRUNC = "less -FXRSn"
HISTORY_FILE_SIZE = 500

__version__ = '0.1.8'
__version__ = '0.2.0'

def output_results(athena, format, execution_id, output, is_shell):
results = athena.get_query_results(execution_id)
headers = results[0]
counter = itertools.count(start=1)
rows = itertools.izip(results[1], counter)
count = 0

try:
if format in ['CSV', 'CSV_HEADER', 'TSV', 'TSV_HEADER']:
if format in ['TSV', 'TSV_HEADER']:
delim = '\t'
quote = csv.QUOTE_NONE
esc = '\\'
else:
delim = ','
quote = csv.QUOTE_ALL
esc = None

csv_writer = csv.writer(output, delimiter=delim, quoting=quote, escapechar=esc)

if format in ['CSV_HEADER', 'TSV_HEADER']:
csv_writer.writerow(encode(headers, 'utf-8'))
csv_writer.writerows([encode(row, 'utf-8') for row, count in rows])

elif format == 'VERTICAL':
for row, count in rows:
output.write('--[RECORD {}]--'.format(count))
output.write('\n')
output.write(tabulate(zip(*[headers, row]), tablefmt='presto').encode('utf-8'))
output.write('\n')

else: # ALIGNED
output.write(tabulate([row for row, count in rows], headers=headers, tablefmt='presto').encode('utf-8'))
output.write('\n')

output.flush()
except IOError as x:
# quitting the less process in shell causes an IOError, so ignore
if not is_shell:
raise x

return count


def encode(row, charset):
return [val.encode(charset) for val in row]


class AthenaBatch(object):

def __init__(self, athena, db=None, format='CSV'):
def __init__(self, athena, db=None, format=None):
self.athena = athena
self.dbname = db
self.format = format
self.format = 'CSV' if format is None else format
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A more idiomatic way of writing this is ...

self.format = format or 'CSV'

However, why did you remove the default format from the method signature?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default wasn't being used anymore, it is always set, so it felt redundant having the default

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the code as you suggested...


def execute(self, statement):
execution_id = self.athena.start_query_execution(self.dbname, statement)
Expand All @@ -42,24 +91,7 @@ def execute(self, statement):
time.sleep(0.2) # 200ms

if status == 'SUCCEEDED':
results = self.athena.get_query_results(execution_id)
headers = [h['Name'].encode("utf-8") for h in results['ResultSet']['ResultSetMetadata']['ColumnInfo']]

if self.format in ['CSV', 'CSV_HEADER']:
csv_writer = csv.writer(sys.stdout, quoting=csv.QUOTE_ALL)
if self.format == 'CSV_HEADER':
csv_writer.writerow(headers)
csv_writer.writerows([[text.encode("utf-8") for text in row] for row in self.athena.yield_rows(results, headers)])
elif self.format == 'TSV':
print(tabulate([row for row in self.athena.yield_rows(results, headers)], tablefmt='tsv'))
elif self.format == 'TSV_HEADER':
print(tabulate([row for row in self.athena.yield_rows(results, headers)], headers=headers, tablefmt='tsv'))
elif self.format == 'VERTICAL':
for num, row in enumerate(self.athena.yield_rows(results, headers)):
print('--[RECORD {}]--'.format(num+1))
print(tabulate(zip(*[headers, row]), tablefmt='presto'))
else: # ALIGNED
print(tabulate([x for x in self.athena.yield_rows(results, headers)], headers=headers, tablefmt='presto'))
output_results(self.athena, self.format, execution_id, sys.stdout, False)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For boolean parameters I would include the paramater name eg ..., is_shell=False).

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, agreed. Done.


if status == 'FAILED':
print(stats['QueryExecution']['Status']['StateChangeReason'])
Expand All @@ -73,27 +105,29 @@ def execute(self, statement):

class AthenaShell(cmd.Cmd, object):

multilineCommands = ['WITH', 'SELECT', 'ALTER', 'CREATE', 'DESCRIBE', 'DROP', 'MSCK', 'SHOW', 'USE', 'VALUES']
multilineCommands = ['WITH', 'SELECT', 'ALTER', 'CREATE', 'DESCRIBE', 'DROP', 'MSCK', 'SHOW', 'USE', 'VALUES', 'with', 'select', 'alter', 'create', 'describe', 'drop', 'msck', 'show', 'use', 'values']
allow_cli_args = False
service_name = 'athena'

def __init__(self, athena, db=None):
def __init__(self, athena, db=None, format=None):
cmd.Cmd.__init__(self)

# allow setting of the output format interactivately
self.settable['format'] = 'Output format';

self.athena = athena
self.dbname = db
self.format = 'ALIGNED' if format is None else format

self.execution_id = None

self.row_count = 0

self.set_prompt()
self.pager = os.environ.get('ATHENA_CLI_PAGER', LESS).split(' ')

self.hist_file = os.path.join(os.path.expanduser("~"), ".athena_history")
self.init_history()

def set_prompt(self):
self.prompt = 'athena:%s> ' % self.dbname if self.dbname else 'athena> '
self.prompt = '%s:%s> ' % (self.service_name, self.dbname) if self.dbname else '%s> ' % self.service_name

def cmdloop_with_cancel(self, intro=None):
try:
Expand Down Expand Up @@ -134,6 +168,7 @@ def do_help(self, arg):
help_output = """
Supported commands:
QUIT
EXIT
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do EXIT and QUIT do exactly the same thing?

Copy link
Copy Markdown
Author

@bryanck bryanck Sep 5, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, "exit" is supported by other CLIs (like the presto-cli) so some users liked to have it

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SELECT
ALTER DATABASE <schema>
ALTER TABLE <table>
Expand All @@ -157,9 +192,11 @@ def do_help(self, arg):
print(help_output)

def do_quit(self, arg):
print()
return -1

def do_exit(self, arg):
return self.do_quit(arg)

def do_EOF(self, arg):
return self.do_quit(arg)

Expand All @@ -174,6 +211,8 @@ def do_set(self, arg):
param_name = param_name.strip().lower()
if param_name == 'debug':
self.athena.debug = cmd.cast(True, val)
elif param_name == 'format':
arg = "format " + val.upper()
except (ValueError, AttributeError):
self.do_show(arg)
super(AthenaShell, self).do_set(arg)
Expand All @@ -197,15 +236,10 @@ def default(self, line):
sys.stdout.flush()

if status == 'SUCCEEDED':
results = self.athena.get_query_results(self.execution_id)
headers = [h['Name'] for h in results['ResultSet']['ResultSetMetadata']['ColumnInfo']]
row_count = len(results['ResultSet']['Rows'])

if headers and len(results['ResultSet']['Rows']) and results['ResultSet']['Rows'][0]['Data'][0].get('VarCharValue', None) == headers[0]:
row_count -= 1 # don't count header

process = subprocess.Popen(self.pager, stdin=subprocess.PIPE)
process.stdin.write(tabulate([x for x in self.athena.yield_rows(results, headers)], headers=headers, tablefmt='presto').encode('utf-8'))
less = LESS_TRUNC if self.format == 'TRUNCATE' else LESS
pager = os.environ.get('ATHENA_CLI_PAGER', less).split(' ')
process = subprocess.Popen(pager, stdin=subprocess.PIPE)
row_count = output_results(self.athena, self.format, self.execution_id, process.stdin, True)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ibid.

process.communicate()
print('(%s rows)\n' % row_count)

Expand Down Expand Up @@ -233,7 +267,8 @@ class Athena(object):
def __init__(self, profile, region=None, bucket=None, debug=False, encryption=False):

self.session = boto3.Session(profile_name=profile, region_name=region)
self.athena = self.session.client('athena')
session_config = botocore.config.Config(user_agent='athena-cli')
self.athena = self.session.client('athena', config=session_config)

self.region = region or os.environ.get('AWS_DEFAULT_REGION', None) or self.session.region_name

Expand Down Expand Up @@ -284,14 +319,23 @@ def get_query_results(self, execution_id):
results = None
paginator = self.athena.get_paginator('get_query_results')
page_iterator = paginator.paginate(
QueryExecutionId=execution_id
QueryExecutionId=execution_id,
PaginationConfig={'PageSize':1000}
)

for page in page_iterator:
if results is None:
results = page
else:
results['ResultSet']['Rows'].extend(page['ResultSet']['Rows'])
pages = iter(page_iterator)
first_page = pages.next() # get first page so we can retrieve metadata for header

headers = list(h['Name'] for h in first_page['ResultSet']['ResultSetMetadata']['ColumnInfo'])
first_row = None if len(first_page['ResultSet']['Rows']) == 0 else list(self.get_col_value(col) for col in first_page['ResultSet']['Rows'][0]['Data'])
rows = self.yield_rows(first_page, pages)

# certain requests return the header as the first row, so skip it
if first_row == headers:
rows.next()

return (headers, rows)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This return has made the code after this try/exception block unreachable.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops. This should be fixed now


except ClientError as e:
sys.exit(e)

Expand All @@ -309,12 +353,16 @@ def stop_query_execution(self, execution_id):
sys.exit(e)

@staticmethod
def yield_rows(results, headers):
for row in results['ResultSet']['Rows']:
# https://forums.aws.amazon.com/thread.jspa?threadID=256505
if headers and row['Data'][0].get('VarCharValue', None) == headers[0]:
continue # skip header
yield [d.get('VarCharValue', 'NULL') for d in row['Data']]
def get_col_value(col):
return col.get('VarCharValue', 'NULL')

@staticmethod
def yield_rows(first_page, pages):
for row in first_page['ResultSet']['Rows']:
yield [Athena.get_col_value(col) for col in row['Data']]
for page in pages:
for row in page['ResultSet']['Rows']:
yield [Athena.get_col_value(col) for col in row['Data']]

def console_link(self, execution_id):
return 'https://{0}.console.aws.amazon.com/athena/home?force&region={0}#query/history/{1}'.format(self.region, execution_id)
Expand Down Expand Up @@ -404,7 +452,7 @@ def main():
batch = AthenaBatch(athena, db=args.schema, format=args.format)
batch.execute(statement=args.execute)
else:
shell = AthenaShell(athena, db=args.schema)
shell = AthenaShell(athena, db=args.schema, format=args.format)
shell.cmdloop_with_cancel()

if __name__ == '__main__':
Expand Down