Skip to content
This repository was archived by the owner on Feb 10, 2023. It is now read-only.
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 99 additions & 61 deletions athena_cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#!/usr/bin/env python3
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of adding a shebang here you should add "python_requires" option to setup.py. See https://packaging.python.org/guides/distributing-packages-using-setuptools/#python-requires

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. I'll check that in in a few

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I set the python version requirement to 3.5 and also set a cmd2 version requirement

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line can be removed. It's ignored anyway as this is a python console script not a shell script.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK I removed it


import argparse
import atexit
Expand All @@ -9,25 +10,69 @@
import sys
import time
import uuid
import itertools

import boto3
import botocore
import cmd2 as cmd
from botocore.exceptions import ClientError, ParamValidationError
from tabulate import tabulate

LESS = "less -FXRSn"
LESS = "less -FXRn"
LESS_TRUNC = "less -FXRSn"
HISTORY_FILE_SIZE = 500

__version__ = '0.1.8'
__version__ = '0.2.0'

def output_results(athena, format, execution_id, output, is_shell):
results = athena.get_query_results(execution_id)
headers = results[0]
counter = itertools.count(start=1)
rows = zip(results[1], counter)

try:
if format in ['CSV', 'CSV_HEADER', 'TSV', 'TSV_HEADER']:
if format in ['TSV', 'TSV_HEADER']:
delim = '\t'
quote = csv.QUOTE_NONE
esc = '\\'
else:
delim = ','
quote = csv.QUOTE_ALL
esc = None

csv_writer = csv.writer(output, delimiter=delim, quoting=quote, escapechar=esc)

if format in ['CSV_HEADER', 'TSV_HEADER']:
csv_writer.writerow(headers)
csv_writer.writerows([row for row, count in rows])

elif format == 'VERTICAL':
for row, count in rows:
output.write('--[RECORD {}]--'.format(count))
output.write('\n')
output.write(tabulate(zip(*[headers, row]), tablefmt='presto'))
output.write('\n')

else: # ALIGNED
output.write(tabulate([row for row, count in rows], headers=headers, tablefmt='presto'))
output.write('\n')

output.flush()
except IOError as x:
# quitting the less process in shell causes an IOError, so ignore
if not is_shell:
raise x

return next(counter) - 1


class AthenaBatch(object):

def __init__(self, athena, db=None, format='CSV'):
def __init__(self, athena, db=None, format=None):
self.athena = athena
self.dbname = db
self.format = format
self.format = 'CSV' if format is None else format
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A more idiomatic way of writing this is ...

self.format = format or 'CSV'

However, why did you remove the default format from the method signature?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default wasn't being used anymore, it is always set, so it felt redundant having the default

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the code as you suggested...


def execute(self, statement):
execution_id = self.athena.start_query_execution(self.dbname, statement)
Expand All @@ -42,58 +87,37 @@ def execute(self, statement):
time.sleep(0.2) # 200ms

if status == 'SUCCEEDED':
results = self.athena.get_query_results(execution_id)
headers = [h['Name'].encode("utf-8") for h in results['ResultSet']['ResultSetMetadata']['ColumnInfo']]

if self.format in ['CSV', 'CSV_HEADER']:
csv_writer = csv.writer(sys.stdout, quoting=csv.QUOTE_ALL)
if self.format == 'CSV_HEADER':
csv_writer.writerow(headers)
csv_writer.writerows([[text.encode("utf-8") for text in row] for row in self.athena.yield_rows(results, headers)])
elif self.format == 'TSV':
print(tabulate([row for row in self.athena.yield_rows(results, headers)], tablefmt='tsv'))
elif self.format == 'TSV_HEADER':
print(tabulate([row for row in self.athena.yield_rows(results, headers)], headers=headers, tablefmt='tsv'))
elif self.format == 'VERTICAL':
for num, row in enumerate(self.athena.yield_rows(results, headers)):
print('--[RECORD {}]--'.format(num+1))
print(tabulate(zip(*[headers, row]), tablefmt='presto'))
else: # ALIGNED
print(tabulate([x for x in self.athena.yield_rows(results, headers)], headers=headers, tablefmt='presto'))
output_results(self.athena, self.format, execution_id, sys.stdout, is_shell=False)

if status == 'FAILED':
print(stats['QueryExecution']['Status']['StateChangeReason'])

try:
del cmd.Cmd.do_show # "show" is an Athena command
except AttributeError:
# "show" was removed from Cmd2 0.8.0
pass


class AthenaShell(cmd.Cmd, object):

multilineCommands = ['WITH', 'SELECT', 'ALTER', 'CREATE', 'DESCRIBE', 'DROP', 'MSCK', 'SHOW', 'USE', 'VALUES']
multiline_commands = ['WITH', 'SELECT', 'ALTER', 'CREATE', 'DESCRIBE', 'DROP', 'MSCK', 'SHOW', 'USE', 'VALUES', 'with', 'select', 'alter', 'create', 'describe', 'drop', 'msck', 'show', 'use', 'values']
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it necessary to repeat the list of valid commands as both upper- and lowercase?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cmd2 seems to be case-sensitive with this feature, i.e. without the lower case commands in the list, then lower case statements aren't multiline

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be rewritten as ...

    multiline_commands = ['WITH', 'SELECT', 'ALTER', 'CREATE', 'DESCRIBE', 'DROP', 'MSCK', 'SHOW', 'USE', 'VALUES']
    multiline_commands.extend([c.lower() for c in multiline_commands])

... to avoid the possibility of missing the fact you need to add the command twice if new commands are added.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that's better. Checked in.

allow_cli_args = False
service_name = 'athena'

def __init__(self, athena, db=None):
def __init__(self, athena, db=None, format=None):
cmd.Cmd.__init__(self)

# allow setting of the output format interactivately
self.settable['format'] = 'Output format';

self.athena = athena
self.dbname = db
self.format = 'ALIGNED' if format is None else format

self.execution_id = None

self.row_count = 0

self.set_prompt()
self.pager = os.environ.get('ATHENA_CLI_PAGER', LESS).split(' ')

self.hist_file = os.path.join(os.path.expanduser("~"), ".athena_history")
self.init_history()

def set_prompt(self):
self.prompt = 'athena:%s> ' % self.dbname if self.dbname else 'athena> '
self.prompt = '%s:%s> ' % (self.service_name, self.dbname) if self.dbname else '%s> ' % self.service_name

def cmdloop_with_cancel(self, intro=None):
try:
Expand Down Expand Up @@ -134,6 +158,7 @@ def do_help(self, arg):
help_output = """
Supported commands:
QUIT
EXIT
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do EXIT and QUIT do exactly the same thing?

Copy link
Copy Markdown
Author

@bryanck bryanck Sep 5, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, "exit" is supported by other CLIs (like the presto-cli) so some users liked to have it

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SELECT
ALTER DATABASE <schema>
ALTER TABLE <table>
Expand All @@ -157,9 +182,11 @@ def do_help(self, arg):
print(help_output)

def do_quit(self, arg):
print()
return -1

def do_exit(self, arg):
return self.do_quit(arg)

def do_EOF(self, arg):
return self.do_quit(arg)

Expand All @@ -169,17 +196,19 @@ def do_use(self, schema):

def do_set(self, arg):
try:
statement, param_name, val = arg.parsed.raw.split(None, 2)
statement, param_name, val = arg.raw.split(None, 2)
val = val.strip()
param_name = param_name.strip().lower()
if param_name == 'debug':
self.athena.debug = cmd.cast(True, val)
elif param_name == 'format':
arg = "format " + val.upper()
except (ValueError, AttributeError):
self.do_show(arg)
pass
super(AthenaShell, self).do_set(arg)

def default(self, line):
self.execution_id = self.athena.start_query_execution(self.dbname, line.full_parsed_statement())
self.execution_id = self.athena.start_query_execution(self.dbname, line.raw)
if not self.execution_id:
return

Expand All @@ -197,15 +226,10 @@ def default(self, line):
sys.stdout.flush()

if status == 'SUCCEEDED':
results = self.athena.get_query_results(self.execution_id)
headers = [h['Name'] for h in results['ResultSet']['ResultSetMetadata']['ColumnInfo']]
row_count = len(results['ResultSet']['Rows'])

if headers and len(results['ResultSet']['Rows']) and results['ResultSet']['Rows'][0]['Data'][0].get('VarCharValue', None) == headers[0]:
row_count -= 1 # don't count header

process = subprocess.Popen(self.pager, stdin=subprocess.PIPE)
process.stdin.write(tabulate([x for x in self.athena.yield_rows(results, headers)], headers=headers, tablefmt='presto').encode('utf-8'))
less = LESS_TRUNC if self.format == 'TRUNCATE' else LESS
pager = os.environ.get('ATHENA_CLI_PAGER', less).split(' ')
process = subprocess.Popen(pager, stdin=subprocess.PIPE, universal_newlines=True)
row_count = output_results(self.athena, self.format, self.execution_id, process.stdin, is_shell=True)
process.communicate()
print('(%s rows)\n' % row_count)

Expand Down Expand Up @@ -233,7 +257,8 @@ class Athena(object):
def __init__(self, profile, region=None, bucket=None, debug=False, encryption=False):

self.session = boto3.Session(profile_name=profile, region_name=region)
self.athena = self.session.client('athena')
session_config = botocore.config.Config(user_agent='athena-cli')
self.athena = self.session.client('athena', config=session_config)

self.region = region or os.environ.get('AWS_DEFAULT_REGION', None) or self.session.region_name

Expand Down Expand Up @@ -284,14 +309,23 @@ def get_query_results(self, execution_id):
results = None
paginator = self.athena.get_paginator('get_query_results')
page_iterator = paginator.paginate(
QueryExecutionId=execution_id
QueryExecutionId=execution_id,
PaginationConfig={'PageSize':1000}
)

for page in page_iterator:
if results is None:
results = page
else:
results['ResultSet']['Rows'].extend(page['ResultSet']['Rows'])
pages = iter(page_iterator)
first_page = next(pages) # get first page so we can retrieve metadata for header

headers = list(h['Name'] for h in first_page['ResultSet']['ResultSetMetadata']['ColumnInfo'])
first_row = None if len(first_page['ResultSet']['Rows']) == 0 else list(self.get_col_value(col) for col in first_page['ResultSet']['Rows'][0]['Data'])
rows = self.yield_rows(first_page, pages)

# certain requests return the header as the first row, so skip it
if first_row == headers:
next(rows)

return (headers, rows)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This return has made the code after this try/exception block unreachable.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops. This should be fixed now


except ClientError as e:
sys.exit(e)

Expand All @@ -309,12 +343,16 @@ def stop_query_execution(self, execution_id):
sys.exit(e)

@staticmethod
def yield_rows(results, headers):
for row in results['ResultSet']['Rows']:
# https://forums.aws.amazon.com/thread.jspa?threadID=256505
if headers and row['Data'][0].get('VarCharValue', None) == headers[0]:
continue # skip header
yield [d.get('VarCharValue', 'NULL') for d in row['Data']]
def get_col_value(col):
return col.get('VarCharValue', 'NULL')

@staticmethod
def yield_rows(first_page, pages):
for row in first_page['ResultSet']['Rows']:
yield [Athena.get_col_value(col) for col in row['Data']]
for page in pages:
for row in page['ResultSet']['Rows']:
yield [Athena.get_col_value(col) for col in row['Data']]

def console_link(self, execution_id):
return 'https://{0}.console.aws.amazon.com/athena/home?force&region={0}#query/history/{1}'.format(self.region, execution_id)
Expand Down Expand Up @@ -404,7 +442,7 @@ def main():
batch = AthenaBatch(athena, db=args.schema, format=args.format)
batch.execute(statement=args.execute)
else:
shell = AthenaShell(athena, db=args.schema)
shell = AthenaShell(athena, db=args.schema, format=args.format)
shell.cmdloop_with_cancel()

if __name__ == '__main__':
Expand Down
9 changes: 5 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

from setuptools import setup, find_packages

version = '0.1.8'
version = '0.2.0'

setup(
name="athena-cli",
Expand All @@ -17,8 +17,8 @@
],
install_requires=[
'boto3',
'cmd2',
'tabulate>=0.8.1'
'cmd2>=0.9.4',
'tabulate>=0.8.2'
],
include_package_data=True,
zip_safe=True,
Expand All @@ -30,5 +30,6 @@
keywords='aws athena presto cli',
classifiers=[
'Topic :: Utilities'
]
],
python_requires='>=3.5'
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests are failing because they are being run against Python 2.7 and 3.4 but this python_requires restricts installation to Python 3.5+.

You need to update the .travis.yml file like so ...

python:
  - "3.5"
  - "3.6"

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked in a fix

)