1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
"""Integration between SQLAlchemy and Presto.
Some code based on
https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py
which is released under the MIT license.
"""
from __future__ import absolute_import
from __future__ import unicode_literals
import re
from sqlalchemy import exc
from sqlalchemy import types
from sqlalchemy import util
# TODO shouldn't use mysql type
from sqlalchemy.databases import mysql
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler
from sqlalchemy.sql.compiler import SQLCompiler
from pyhive import presto
from pyhive.common import UniversalSet
class PrestoIdentifierPreparer(compiler.IdentifierPreparer):
    """Identifier preparer used by the Presto dialect.

    ``reserved_words`` is set to a ``UniversalSet`` instance instead of an explicit word list.
    """

    # Quote every identifier: simpler than tracking Presto's reserved words, and easier to
    # keep working across upgrades.
    reserved_words = UniversalSet()
# Maps a Presto type name (the ``Type`` value of a ``SHOW COLUMNS`` row, as consumed by
# ``PrestoDialect.get_columns``) to the SQLAlchemy type reported during reflection.
# Type names that are not listed here are reflected as ``types.NullType`` with a warning.
_type_map = {
    'boolean': types.Boolean,
    # Borrowed from the MySQL dialect; see the TODO next to the ``mysql`` import.
    'tinyint': mysql.MSTinyInteger,
    'smallint': types.SmallInteger,
    'integer': types.Integer,
    'bigint': types.BigInteger,
    # Both floating point names map to the same SQLAlchemy type.
    'real': types.Float,
    'double': types.Float,
    'varchar': types.String,
    'timestamp': types.TIMESTAMP,
    'date': types.DATE,
    'varbinary': types.VARBINARY,
}
class PrestoCompiler(SQLCompiler):
    """Statement compiler for the Presto dialect."""

    def visit_char_length_func(self, fn, **kw):
        """Render ``char_length(...)`` using the function name ``length``.

        The argument list is produced by ``self.function_argspec`` and appended unchanged.
        """
        argspec = self.function_argspec(fn, **kw)
        return 'length{}'.format(argspec)
class PrestoTypeCompiler(compiler.GenericTypeCompiler):
    """Type compiler for the Presto dialect.

    CLOB, NCLOB and DATETIME are rejected with ``ValueError``; FLOAT is rendered as ``DOUBLE``
    and TEXT as ``VARCHAR`` (including the length when one is set).
    """

    def visit_CLOB(self, type_, **kw):
        """Reject CLOB columns."""
        raise ValueError("Presto does not support the CLOB column type.")

    def visit_NCLOB(self, type_, **kw):
        """Reject NCLOB columns."""
        raise ValueError("Presto does not support the NCLOB column type.")

    def visit_DATETIME(self, type_, **kw):
        """Reject DATETIME columns."""
        raise ValueError("Presto does not support the DATETIME column type.")

    def visit_FLOAT(self, type_, **kw):
        """Render FLOAT as ``DOUBLE``."""
        return 'DOUBLE'

    def visit_TEXT(self, type_, **kw):
        """Render TEXT as ``VARCHAR``, adding ``(length)`` when the type has a truthy length."""
        length = type_.length
        if not length:
            return 'VARCHAR'
        return 'VARCHAR({:d})'.format(length)
class PrestoDialect(default.DefaultDialect):
    """SQLAlchemy dialect that uses ``pyhive.presto`` as its DB-API module.

    Reflection is implemented on top of ``SHOW SCHEMAS``, ``SHOW TABLES`` and ``SHOW COLUMNS``.
    """

    name = 'presto'
    driver = 'rest'
    paramstyle = 'pyformat'
    preparer = PrestoIdentifierPreparer
    statement_compiler = PrestoCompiler
    supports_alter = False
    supports_pk_autoincrement = False
    supports_default_values = False
    supports_empty_insert = False
    supports_unicode_statements = True
    supports_unicode_binds = True
    returns_unicode_strings = True
    description_encoding = None
    supports_native_boolean = True
    type_compiler = PrestoTypeCompiler

    @classmethod
    def dbapi(cls):
        """Return the DB-API module backing this dialect (``pyhive.presto``)."""
        return presto

    def create_connect_args(self, url):
        """Translate a SQLAlchemy URL into ``([], kwargs)`` for the DB-API ``connect`` call.

        The database part of the URL is read as ``catalog`` or ``catalog/schema`` and defaults
        to ``hive``; the port defaults to 8080. Query-string entries are merged into the keyword
        arguments first, so ``catalog``/``schema`` taken from the path override same-named
        query parameters.

        :raises ValueError: if the database part has more than two ``/``-separated components.
        """
        db_parts = (url.database or 'hive').split('/')
        kwargs = {
            'host': url.host,
            'port': url.port or 8080,
            'username': url.username,
            'password': url.password
        }
        kwargs.update(url.query)
        if len(db_parts) == 1:
            kwargs['catalog'] = db_parts[0]
        elif len(db_parts) == 2:
            kwargs['catalog'] = db_parts[0]
            kwargs['schema'] = db_parts[1]
        else:
            raise ValueError("Unexpected database format {}".format(url.database))
        return [], kwargs

    def get_schema_names(self, connection, **kw):
        """Return the ``Schema`` value of every row produced by ``SHOW SCHEMAS``."""
        return [row.Schema for row in connection.execute('SHOW SCHEMAS')]

    def _get_table_columns(self, connection, table_name, schema):
        """Execute ``SHOW COLUMNS`` for the (optionally schema-qualified) table.

        :returns: whatever ``connection.execute`` returns for the statement.
        :raises sqlalchemy.exc.NoSuchTableError: if the database error message says the table
            does not exist. Any other database error is re-raised unchanged.
        """
        full_table = self.identifier_preparer.quote_identifier(table_name)
        if schema:
            full_table = self.identifier_preparer.quote_identifier(schema) + '.' + full_table
        try:
            return connection.execute('SHOW COLUMNS FROM {}'.format(full_table))
        except (presto.DatabaseError, exc.DatabaseError) as e:
            # Normally SQLAlchemy should wrap this exception in sqlalchemy.exc.DatabaseError, which
            # it successfully does in the Hive version. The difference with Presto is that this
            # error is raised when fetching the cursor's description rather than the initial execute
            # call. SQLAlchemy doesn't handle this. Thus, we catch the unwrapped
            # presto.DatabaseError here.
            # Does the table exist?
            # The first exception argument is either a dict carrying a 'message' entry or a
            # plain string; anything else yields no message to inspect.
            first_arg = e.args[0] if e.args else None
            if isinstance(first_arg, dict):
                msg = first_arg.get('message')
            elif isinstance(first_arg, str):
                msg = first_arg
            else:
                msg = None
            regex = r"Table\ \'.*{}\'\ does\ not\ exist".format(re.escape(table_name))
            if msg and re.search(regex, msg):
                raise exc.NoSuchTableError(table_name)
            else:
                raise

    def has_table(self, connection, table_name, schema=None):
        """Return True if ``SHOW COLUMNS`` succeeds for the table, False if it does not exist."""
        try:
            self._get_table_columns(connection, table_name, schema)
            return True
        except exc.NoSuchTableError:
            return False

    def get_columns(self, connection, table_name, schema=None, **kw):
        """Return one dict per column with the keys ``name``, ``type``, ``nullable``, ``default``.

        The type is resolved through ``_type_map``. The exact reported type name is tried first;
        if that is unknown, a trailing parenthesised parameter list is dropped (``varchar(255)``
        is looked up as ``varchar``). Types that are still unknown produce a warning and are
        reported as ``types.NullType``.
        """
        rows = self._get_table_columns(connection, table_name, schema)
        result = []
        for row in rows:
            coltype = _type_map.get(row.Type)
            if coltype is None:
                # Retry without the trailing "(...)" so parameterised types resolve to their
                # base entry instead of degrading to NullType.
                coltype = _type_map.get(re.sub(r'\(.*\)$', '', row.Type))
            if coltype is None:
                util.warn("Did not recognize type '%s' of column '%s'" % (row.Type, row.Column))
                coltype = types.NullType
            result.append({
                'name': row.Column,
                'type': coltype,
                # newer Presto no longer includes this column
                'nullable': getattr(row, 'Null', True),
                'default': None,
            })
        return result

    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        """Return an empty list; this dialect reports no foreign keys."""
        return []

    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
        """Return an empty list; this dialect reports no primary key.

        NOTE(review): SQLAlchemy documents a dict (``constrained_columns``/``name``) as the
        return value of this hook; the empty list is kept for backward compatibility - confirm
        before changing.
        """
        return []

    def get_indexes(self, connection, table_name, schema=None, **kw):
        """Report the table's partition columns as a single non-unique index named ``partition``.

        :returns: a one-element list describing the partition columns, or ``[]`` when no column
            is flagged as a partition key.
        """
        rows = self._get_table_columns(connection, table_name, schema)
        part_key = 'Partition Key'
        col_names = []
        for row in rows:
            # Presto puts this information in one of 3 places depending on version
            # - a boolean column named "Partition Key"
            # - a string in the "Comment" column
            # - a string in the "Extra" column
            # "Comment"/"Extra" may be missing or NULL, so treat both cases as an empty string
            # rather than failing on None.
            is_partition_key = (
                (part_key in row and row[part_key])
                or ('Comment' in row and (row['Comment'] or '').startswith(part_key))
                or ('Extra' in row and 'partition key' in (row['Extra'] or ''))
            )
            if is_partition_key:
                col_names.append(row['Column'])
        if col_names:
            return [{'name': 'partition', 'column_names': col_names, 'unique': False}]
        return []

    def get_table_names(self, connection, schema=None, **kw):
        """Return the ``Table`` value of every row of ``SHOW TABLES`` (``FROM schema`` if given)."""
        query = 'SHOW TABLES'
        if schema:
            query += ' FROM ' + self.identifier_preparer.quote_identifier(schema)
        return [row.Table for row in connection.execute(query)]

    def do_rollback(self, dbapi_connection):
        """Do nothing on rollback."""
        # No transactions for Presto
        pass

    def _check_unicode_returns(self, connection, additional_tests=None):
        """Always report that returned strings are Unicode, without probing the connection."""
        # requests gives back Unicode strings
        return True

    def _check_unicode_description(self, connection):
        """Always report that cursor descriptions are Unicode, without probing the connection."""
        # requests gives back Unicode strings
        return True