Skip to content
This repository was archived by the owner on Feb 10, 2023. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
language: python
python:
- "2.7"
- "3.4"
- "3.5"
- "3.6"

install:
- "pip install ."
Expand Down
163 changes: 101 additions & 62 deletions athena_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,69 @@
import sys
import time
import uuid
import itertools

import boto3
import botocore
import cmd2 as cmd
from botocore.exceptions import ClientError, ParamValidationError
from tabulate import tabulate

LESS = "less -FXRSn"
LESS = "less -FXRn"
LESS_TRUNC = "less -FXRSn"
HISTORY_FILE_SIZE = 500

__version__ = '0.1.8'
__version__ = '0.2.0'

def output_results(athena, format, execution_id, output, is_shell):
results = athena.get_query_results(execution_id)
headers = results[0]
counter = itertools.count(start=1)
rows = zip(results[1], counter)

try:
if format in ['CSV', 'CSV_HEADER', 'TSV', 'TSV_HEADER']:
if format in ['TSV', 'TSV_HEADER']:
delim = '\t'
quote = csv.QUOTE_NONE
esc = '\\'
else:
delim = ','
quote = csv.QUOTE_ALL
esc = None

csv_writer = csv.writer(output, delimiter=delim, quoting=quote, escapechar=esc)

if format in ['CSV_HEADER', 'TSV_HEADER']:
csv_writer.writerow(headers)
csv_writer.writerows([row for row, count in rows])

elif format == 'VERTICAL':
for row, count in rows:
output.write('--[RECORD {}]--'.format(count))
output.write('\n')
output.write(tabulate(zip(*[headers, row]), tablefmt='presto'))
output.write('\n')

else: # ALIGNED
output.write(tabulate([row for row, count in rows], headers=headers, tablefmt='presto'))
output.write('\n')

output.flush()
except IOError as x:
# quitting the less process in shell causes an IOError, so ignore
if not is_shell:
raise x

return next(counter) - 1


class AthenaBatch(object):

def __init__(self, athena, db=None, format='CSV'):
def __init__(self, athena, db=None, format=None):
self.athena = athena
self.dbname = db
self.format = format
self.format = format or 'CSV'

def execute(self, statement):
execution_id = self.athena.start_query_execution(self.dbname, statement)
Expand All @@ -42,58 +86,38 @@ def execute(self, statement):
time.sleep(0.2) # 200ms

if status == 'SUCCEEDED':
results = self.athena.get_query_results(execution_id)
headers = [h['Name'].encode("utf-8") for h in results['ResultSet']['ResultSetMetadata']['ColumnInfo']]

if self.format in ['CSV', 'CSV_HEADER']:
csv_writer = csv.writer(sys.stdout, quoting=csv.QUOTE_ALL)
if self.format == 'CSV_HEADER':
csv_writer.writerow(headers)
csv_writer.writerows([[text.encode("utf-8") for text in row] for row in self.athena.yield_rows(results, headers)])
elif self.format == 'TSV':
print(tabulate([row for row in self.athena.yield_rows(results, headers)], tablefmt='tsv'))
elif self.format == 'TSV_HEADER':
print(tabulate([row for row in self.athena.yield_rows(results, headers)], headers=headers, tablefmt='tsv'))
elif self.format == 'VERTICAL':
for num, row in enumerate(self.athena.yield_rows(results, headers)):
print('--[RECORD {}]--'.format(num+1))
print(tabulate(zip(*[headers, row]), tablefmt='presto'))
else: # ALIGNED
print(tabulate([x for x in self.athena.yield_rows(results, headers)], headers=headers, tablefmt='presto'))
output_results(self.athena, self.format, execution_id, sys.stdout, is_shell=False)

if status == 'FAILED':
print(stats['QueryExecution']['Status']['StateChangeReason'])

try:
del cmd.Cmd.do_show # "show" is an Athena command
except AttributeError:
# "show" was removed from Cmd2 0.8.0
pass


class AthenaShell(cmd.Cmd, object):

multilineCommands = ['WITH', 'SELECT', 'ALTER', 'CREATE', 'DESCRIBE', 'DROP', 'MSCK', 'SHOW', 'USE', 'VALUES']
multiline_commands = ['WITH', 'SELECT', 'ALTER', 'CREATE', 'DESCRIBE', 'DROP', 'MSCK', 'SHOW', 'USE', 'VALUES']
multiline_commands.extend([c.lower() for c in multiline_commands])
allow_cli_args = False
service_name = 'athena'

def __init__(self, athena, db=None):
def __init__(self, athena, db=None, format=None):
cmd.Cmd.__init__(self)

# allow setting of the output format interactivately
self.settable['format'] = 'Output format';

self.athena = athena
self.dbname = db
self.format = format or 'ALIGNED'

self.execution_id = None

self.row_count = 0

self.set_prompt()
self.pager = os.environ.get('ATHENA_CLI_PAGER', LESS).split(' ')

self.hist_file = os.path.join(os.path.expanduser("~"), ".athena_history")
self.init_history()

def set_prompt(self):
self.prompt = 'athena:%s> ' % self.dbname if self.dbname else 'athena> '
self.prompt = '%s:%s> ' % (self.service_name, self.dbname) if self.dbname else '%s> ' % self.service_name

def cmdloop_with_cancel(self, intro=None):
try:
Expand Down Expand Up @@ -134,6 +158,7 @@ def do_help(self, arg):
help_output = """
Supported commands:
QUIT
EXIT
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do EXIT and QUIT do exactly the same thing?

Copy link
Copy Markdown
Author

@bryanck bryanck Sep 5, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, "exit" is supported by other CLIs (like the presto-cli) so some users liked to have it

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SELECT
ALTER DATABASE <schema>
ALTER TABLE <table>
Expand All @@ -157,9 +182,11 @@ def do_help(self, arg):
print(help_output)

def do_quit(self, arg):
print()
return -1

def do_exit(self, arg):
return self.do_quit(arg)

def do_EOF(self, arg):
return self.do_quit(arg)

Expand All @@ -169,17 +196,19 @@ def do_use(self, schema):

def do_set(self, arg):
try:
statement, param_name, val = arg.parsed.raw.split(None, 2)
statement, param_name, val = arg.raw.split(None, 2)
val = val.strip()
param_name = param_name.strip().lower()
if param_name == 'debug':
self.athena.debug = cmd.cast(True, val)
elif param_name == 'format':
arg = "format " + val.upper()
except (ValueError, AttributeError):
self.do_show(arg)
pass
super(AthenaShell, self).do_set(arg)

def default(self, line):
self.execution_id = self.athena.start_query_execution(self.dbname, line.full_parsed_statement())
self.execution_id = self.athena.start_query_execution(self.dbname, line.raw)
if not self.execution_id:
return

Expand All @@ -197,15 +226,10 @@ def default(self, line):
sys.stdout.flush()

if status == 'SUCCEEDED':
results = self.athena.get_query_results(self.execution_id)
headers = [h['Name'] for h in results['ResultSet']['ResultSetMetadata']['ColumnInfo']]
row_count = len(results['ResultSet']['Rows'])

if headers and len(results['ResultSet']['Rows']) and results['ResultSet']['Rows'][0]['Data'][0].get('VarCharValue', None) == headers[0]:
row_count -= 1 # don't count header

process = subprocess.Popen(self.pager, stdin=subprocess.PIPE)
process.stdin.write(tabulate([x for x in self.athena.yield_rows(results, headers)], headers=headers, tablefmt='presto').encode('utf-8'))
less = LESS_TRUNC if self.format == 'TRUNCATE' else LESS
pager = os.environ.get('ATHENA_CLI_PAGER', less).split(' ')
process = subprocess.Popen(pager, stdin=subprocess.PIPE, universal_newlines=True)
row_count = output_results(self.athena, self.format, self.execution_id, process.stdin, is_shell=True)
process.communicate()
print('(%s rows)\n' % row_count)

Expand Down Expand Up @@ -233,7 +257,8 @@ class Athena(object):
def __init__(self, profile, region=None, bucket=None, debug=False, encryption=False):

self.session = boto3.Session(profile_name=profile, region_name=region)
self.athena = self.session.client('athena')
session_config = botocore.config.Config(user_agent='athena-cli')
self.athena = self.session.client('athena', config=session_config)

self.region = region or os.environ.get('AWS_DEFAULT_REGION', None) or self.session.region_name

Expand Down Expand Up @@ -280,18 +305,27 @@ def get_query_execution(self, execution_id):
print(e)

def get_query_results(self, execution_id):
results = None
try:
results = None
paginator = self.athena.get_paginator('get_query_results')
page_iterator = paginator.paginate(
QueryExecutionId=execution_id
QueryExecutionId=execution_id,
PaginationConfig={'PageSize':1000}
)

for page in page_iterator:
if results is None:
results = page
else:
results['ResultSet']['Rows'].extend(page['ResultSet']['Rows'])
pages = iter(page_iterator)
first_page = next(pages) # get first page so we can retrieve metadata for header

headers = list(h['Name'] for h in first_page['ResultSet']['ResultSetMetadata']['ColumnInfo'])
first_row = None if len(first_page['ResultSet']['Rows']) == 0 else list(self.get_col_value(col) for col in first_page['ResultSet']['Rows'][0]['Data'])
rows = self.yield_rows(first_page, pages)

# certain requests return the header as the first row, so skip it
if first_row == headers:
next(rows)

results = (headers, rows)

except ClientError as e:
sys.exit(e)

Expand All @@ -300,6 +334,7 @@ def get_query_results(self, execution_id):

return results


def stop_query_execution(self, execution_id):
try:
return self.athena.stop_query_execution(
Expand All @@ -309,12 +344,16 @@ def stop_query_execution(self, execution_id):
sys.exit(e)

@staticmethod
def yield_rows(results, headers):
for row in results['ResultSet']['Rows']:
# https://forums.aws.amazon.com/thread.jspa?threadID=256505
if headers and row['Data'][0].get('VarCharValue', None) == headers[0]:
continue # skip header
yield [d.get('VarCharValue', 'NULL') for d in row['Data']]
def get_col_value(col):
return col.get('VarCharValue', 'NULL')

@staticmethod
def yield_rows(first_page, pages):
for row in first_page['ResultSet']['Rows']:
yield [Athena.get_col_value(col) for col in row['Data']]
for page in pages:
for row in page['ResultSet']['Rows']:
yield [Athena.get_col_value(col) for col in row['Data']]

def console_link(self, execution_id):
return 'https://{0}.console.aws.amazon.com/athena/home?force&region={0}#query/history/{1}'.format(self.region, execution_id)
Expand Down Expand Up @@ -404,7 +443,7 @@ def main():
batch = AthenaBatch(athena, db=args.schema, format=args.format)
batch.execute(statement=args.execute)
else:
shell = AthenaShell(athena, db=args.schema)
shell = AthenaShell(athena, db=args.schema, format=args.format)
shell.cmdloop_with_cancel()

if __name__ == '__main__':
Expand Down
9 changes: 5 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

from setuptools import setup, find_packages

version = '0.1.8'
version = '0.2.0'

setup(
name="athena-cli",
Expand All @@ -17,8 +17,8 @@
],
install_requires=[
'boto3',
'cmd2',
'tabulate>=0.8.1'
'cmd2>=0.9.4',
'tabulate>=0.8.2'
],
include_package_data=True,
zip_safe=True,
Expand All @@ -30,5 +30,6 @@
keywords='aws athena presto cli',
classifiers=[
'Topic :: Utilities'
]
],
python_requires='>=3.5'
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests are failing because they are being run against Python 2.7 and 3.4 but this python_requires restricts installation to Python 3.5+.

You need to update the .travis.yml file like so ...

python:
  - "3.5"
  - "3.6"

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked in a fix

)