Commit 6c3b89ad authored by litaolemo's avatar litaolemo

update

parent 8405f1f9
# meta_base_code
metabase 数据开发相关代码
\ No newline at end of file
metabase 数据开发相关代码
服务器 airflow002
1. 切换权限 sudo su - gmuser
2. source /srv/envs/esmm/bin/activate
3. python crawler/crawler_sys/utils/get_query_result.py
/opt/spark/bin/spark-submit --master yarn --deploy-mode client --queue root.strategy --driver-memory 16g --executor-memory 1g --executor-cores 1 --num-executors 70 --conf spark.default.parallelism=100 --conf spark.storage.memoryFraction=0.5 --conf spark.shuffle.memoryFraction=0.3 --conf spark.executorEnv.LD_LIBRARY_PATH="/opt/java/jdk1.8.0_181/jre/lib/amd64/server:/opt/cloudera/parcels/CDH-5.16.1-1.cdh5.16.1.p0.3/lib64" --conf spark.locality.wait=0 --jars /srv/apps/tispark-core-2.1-SNAPSHOT-jar-with-dependencies.jar,/srv/apps/spark-connector_2.11-1.9.0-rc2.jar,/srv/apps/mysql-connector-java-5.1.38.jar /srv/apps/meta_base_code/task/conent_detail_page_grayscale_ctr.py
# -*- coding:UTF-8 -*-
# @Time : 2020/8/21 9:34
# @File : __init__.py.py
# @email : litao@igengmei.com
# @author : litao
\ No newline at end of file
http://www.opensource.org/licenses/mit-license.php
Copyright 2007-2011 David Alan Cridland
Copyright 2011 Lance Stout
Copyright 2012 Tyler L Hobbs
Permission is hereby granted, free of charge, to any person obtaining a copy of this
software and associated documentation files (the "Software"), to deal in the Software
without restriction, including without limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons
to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or
substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
Metadata-Version: 2.1
Name: pure-sasl
Version: 0.6.1
Summary: Pure Python client SASL implementation
Home-page: http://github.com/thobbs/pure-sasl
Author: Tyler Hobbs
Author-email: tylerlhobbs@gmail.com
Maintainer: Alex Shafer
Maintainer-email: ashafer01@gmail.com
License: MIT
Keywords: sasl
Platform: UNKNOWN
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Natural Language :: English
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 2.6
Classifier: Programming Language :: Python :: 2.7
Classifier: Programming Language :: Python :: 3.3
Classifier: Programming Language :: Python :: 3.4
Classifier: Programming Language :: Python :: 3.5
Classifier: Programming Language :: Python :: 3.6
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Provides-Extra: gssapi
Requires-Dist: kerberos (>=1.3.0) ; extra == 'gssapi'
This package provides a reasonably high-level SASL client written
in pure Python. New mechanisms may be integrated easily, but by default,
support for PLAIN, ANONYMOUS, EXTERNAL, CRAM-MD5, DIGEST-MD5, and GSSAPI are
provided.
pure_sasl-0.6.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
pure_sasl-0.6.1.dist-info/LICENSE,sha256=NpiBGKBhU7n98ItgAMlDXwQw_NH-rDj9hPscwmf0pT0,1172
pure_sasl-0.6.1.dist-info/METADATA,sha256=BVuBACWERcyUyMiB1TOkLRmwd2U6J9E43aDUaevVl8M,1215
pure_sasl-0.6.1.dist-info/RECORD,,
pure_sasl-0.6.1.dist-info/WHEEL,sha256=wy6I0RkXf1SHQI9WgMGMZtkrPqc-c49UWrTH1_nKFnQ,93
pure_sasl-0.6.1.dist-info/top_level.txt,sha256=2acq8MK3X4dZrUf_tIPo23sT7L9lDb9BnXcpeiazxZk,9
puresasl/__init__.py,sha256=PLhXLobzrQmCWjZvQvbx4yZicVvRuUg8Pxt-QKgpIq8,1026
puresasl/__init__.pyc,,
puresasl/client.py,sha256=uVZDRtlvKYDFa5qNWz0hgHtrpRRXUWqKWNiu73KZMn4,9929
puresasl/client.pyc,,
puresasl/mechanisms.py,sha256=h5l7cMkFwyh9hhzhqzE09jWXPAsBmI6bXe_jAz122vE,19319
puresasl/mechanisms.pyc,,
Wheel-Version: 1.0
Generator: bdist_wheel (0.33.1)
Root-Is-Purelib: true
Tag: cp27-none-any
from __future__ import absolute_import
from __future__ import unicode_literals
__version__ = '0.6.1'
"""Package private common utilities. Do not use directly.
Many docstrings in this file are based on PEP-249, which is in the public domain.
"""
from __future__ import absolute_import
from __future__ import unicode_literals
from builtins import bytes
from builtins import int
from builtins import object
from builtins import str
from past.builtins import basestring
from pyhive import exc
import abc
import collections
import time
from future.utils import with_metaclass
from itertools import islice
class DBAPICursor(with_metaclass(abc.ABCMeta, object)):
"""Base class for some common DB-API logic"""
_STATE_NONE = 0
_STATE_RUNNING = 1
_STATE_FINISHED = 2
def __init__(self, poll_interval=1):
self._poll_interval = poll_interval
self._reset_state()
self.lastrowid = None
def _reset_state(self):
"""Reset state about the previous query in preparation for running another query"""
# State to return as part of DB-API
self._rownumber = 0
# Internal helper state
self._state = self._STATE_NONE
self._data = collections.deque()
self._columns = None
def _fetch_while(self, fn):
while fn():
self._fetch_more()
if fn():
time.sleep(self._poll_interval)
@abc.abstractproperty
def description(self):
raise NotImplementedError # pragma: no cover
def close(self):
"""By default, do nothing"""
pass
@abc.abstractmethod
def _fetch_more(self):
"""Get more results, append it to ``self._data``, and update ``self._state``."""
raise NotImplementedError # pragma: no cover
@property
def rowcount(self):
"""By default, return -1 to indicate that this is not supported."""
return -1
@abc.abstractmethod
def execute(self, operation, parameters=None):
"""Prepare and execute a database operation (query or command).
Parameters may be provided as sequence or mapping and will be bound to variables in the
operation. Variables are specified in a database-specific notation (see the module's
``paramstyle`` attribute for details).
Return values are not defined.
"""
raise NotImplementedError # pragma: no cover
def executemany(self, operation, seq_of_parameters):
"""Prepare a database operation (query or command) and then execute it against all parameter
sequences or mappings found in the sequence ``seq_of_parameters``.
Only the final result set is retained.
Return values are not defined.
"""
for parameters in seq_of_parameters[:-1]:
self.execute(operation, parameters)
while self._state != self._STATE_FINISHED:
self._fetch_more()
if seq_of_parameters:
self.execute(operation, seq_of_parameters[-1])
def fetchone(self):
"""Fetch the next row of a query result set, returning a single sequence, or ``None`` when
no more data is available.
An :py:class:`~pyhive.exc.Error` (or subclass) exception is raised if the previous call to
:py:meth:`execute` did not produce any result set or no call was issued yet.
"""
if self._state == self._STATE_NONE:
raise exc.ProgrammingError("No query yet")
# Sleep until we're done or we have some data to return
self._fetch_while(lambda: not self._data and self._state != self._STATE_FINISHED)
if not self._data:
return None
else:
self._rownumber += 1
return self._data.popleft()
def fetchmany(self, size=None):
"""Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a
list of tuples). An empty sequence is returned when no more rows are available.
The number of rows to fetch per call is specified by the parameter. If it is not given, the
cursor's arraysize determines the number of rows to be fetched. The method should try to
fetch as many rows as indicated by the size parameter. If this is not possible due to the
specified number of rows not being available, fewer rows may be returned.
An :py:class:`~pyhive.exc.Error` (or subclass) exception is raised if the previous call to
:py:meth:`execute` did not produce any result set or no call was issued yet.
"""
if size is None:
size = self.arraysize
return list(islice(iter(self.fetchone, None), size))
def fetchall(self):
"""Fetch all (remaining) rows of a query result, returning them as a sequence of sequences
(e.g. a list of tuples).
An :py:class:`~pyhive.exc.Error` (or subclass) exception is raised if the previous call to
:py:meth:`execute` did not produce any result set or no call was issued yet.
"""
return list(iter(self.fetchone, None))
@property
def arraysize(self):
"""This read/write attribute specifies the number of rows to fetch at a time with
:py:meth:`fetchmany`. It defaults to 1 meaning to fetch a single row at a time.
"""
return self._arraysize
@arraysize.setter
def arraysize(self, value):
self._arraysize = value
def setinputsizes(self, sizes):
"""Does nothing by default"""
pass
def setoutputsize(self, size, column=None):
"""Does nothing by default"""
pass
#
# Optional DB API Extensions
#
@property
def rownumber(self):
"""This read-only attribute should provide the current 0-based index of the cursor in the
result set.
The index can be seen as index of the cursor in a sequence (the result set). The next fetch
operation will fetch the row indexed by ``rownumber`` in that sequence.
"""
return self._rownumber
def __next__(self):
"""Return the next row from the currently executing SQL statement using the same semantics
as :py:meth:`fetchone`. A ``StopIteration`` exception is raised when the result set is
exhausted.
"""
one = self.fetchone()
if one is None:
raise StopIteration
else:
return one
next = __next__
def __iter__(self):
"""Return self to make cursors compatible to the iteration protocol."""
return self
class DBAPITypeObject(object):
# Taken from http://www.python.org/dev/peps/pep-0249/#implementation-hints
def __init__(self, *values):
self.values = values
def __cmp__(self, other):
if other in self.values:
return 0
if other < self.values:
return 1
else:
return -1
class ParamEscaper(object):
def escape_args(self, parameters):
if isinstance(parameters, dict):
return {k: self.escape_item(v) for k, v in parameters.items()}
elif isinstance(parameters, (list, tuple)):
return tuple(self.escape_item(x) for x in parameters)
else:
raise exc.ProgrammingError("Unsupported param format: {}".format(parameters))
def escape_number(self, item):
return item
def escape_string(self, item):
# Need to decode UTF-8 because of old sqlalchemy.
# Newer SQLAlchemy checks dialect.supports_unicode_binds before encoding Unicode strings
# as byte strings. The old version always encodes Unicode as byte strings, which breaks
# string formatting here.
if isinstance(item, bytes):
item = item.decode('utf-8')
# This is good enough when backslashes are literal, newlines are just followed, and the way
# to escape a single quote is to put two single quotes.
# (i.e. only special character is single quote)
return "'{}'".format(item.replace("'", "''"))
def escape_sequence(self, item):
l = map(str, map(self.escape_item, item))
return '(' + ','.join(l) + ')'
def escape_item(self, item):
if item is None:
return 'NULL'
elif isinstance(item, (int, float)):
return self.escape_number(item)
elif isinstance(item, basestring):
return self.escape_string(item)
elif isinstance(item, collections.Iterable):
return self.escape_sequence(item)
else:
raise exc.ProgrammingError("Unsupported object {}".format(item))
class UniversalSet(object):
"""set containing everything"""
def __contains__(self, item):
return True
"""
Package private common utilities. Do not use directly.
"""
from __future__ import absolute_import
from __future__ import unicode_literals
__all__ = [
'Error', 'Warning', 'InterfaceError', 'DatabaseError', 'InternalError', 'OperationalError',
'ProgrammingError', 'DataError', 'NotSupportedError',
]
class Error(Exception):
"""Exception that is the base class of all other error exceptions.
You can use this to catch all errors with one single except statement.
"""
pass
class Warning(Exception):
"""Exception raised for important warnings like data truncations while inserting, etc."""
pass
class InterfaceError(Error):
"""Exception raised for errors that are related to the database interface rather than the
database itself.
"""
pass
class DatabaseError(Error):
"""Exception raised for errors that are related to the database."""
pass
class InternalError(DatabaseError):
"""Exception raised when the database encounters an internal error, e.g. the cursor is not valid
anymore, the transaction is out of sync, etc."""
pass
class OperationalError(DatabaseError):
"""Exception raised for errors that are related to the database's operation and not necessarily
under the control of the programmer, e.g. an unexpected disconnect occurs, the data source name
is not found, a transaction could not be processed, a memory allocation error occurred during
processing, etc.
"""
pass
class ProgrammingError(DatabaseError):
"""Exception raised for programming errors, e.g. table not found or already exists, syntax error
in the SQL statement, wrong number of parameters specified, etc.
"""
pass
class DataError(DatabaseError):
"""Exception raised for errors that are due to problems with the processed data like division by
zero, numeric value out of range, etc.
"""
pass
class NotSupportedError(DatabaseError):
"""Exception raised in case a method or database API was used which is not supported by the
database, e.g. requesting a ``.rollback()`` on a connection that does not support transaction or
has transactions turned off.
"""
pass
"""DB-API implementation backed by HiveServer2 (Thrift API)
See http://www.python.org/dev/peps/pep-0249/
Many docstrings in this file are based on the PEP, which is in the public domain.
"""
from __future__ import absolute_import
from __future__ import unicode_literals
import datetime
import re
from decimal import Decimal
from TCLIService import TCLIService
from TCLIService import constants
from TCLIService import ttypes
from pyhive import common
from pyhive.common import DBAPITypeObject
# Make all exceptions visible in this module per DB-API
from pyhive.exc import * # noqa
from builtins import range
import contextlib
from future.utils import iteritems
import getpass
import logging
import sys
import thrift.protocol.TBinaryProtocol
import thrift.transport.TSocket
import thrift.transport.TTransport
# PEP 249 module globals
apilevel = '2.0'
threadsafety = 2 # Threads may share the module and connections.
paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s
_logger = logging.getLogger(__name__)
_TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)')
def _parse_timestamp(value):
if value:
match = _TIMESTAMP_PATTERN.match(value)
if match:
if match.group(2):
format = '%Y-%m-%d %H:%M:%S.%f'
# use the pattern to truncate the value
value = match.group()
else:
format = '%Y-%m-%d %H:%M:%S'
value = datetime.datetime.strptime(value, format)
else:
raise Exception(
'Cannot convert "{}" into a datetime'.format(value))
else:
value = None
return value
TYPES_CONVERTER = {"DECIMAL_TYPE": Decimal,
"TIMESTAMP_TYPE": _parse_timestamp}
class HiveParamEscaper(common.ParamEscaper):
def escape_string(self, item):
# backslashes and single quotes need to be escaped
# TODO verify against parser
# Need to decode UTF-8 because of old sqlalchemy.
# Newer SQLAlchemy checks dialect.supports_unicode_binds before encoding Unicode strings
# as byte strings. The old version always encodes Unicode as byte strings, which breaks
# string formatting here.
if isinstance(item, bytes):
item = item.decode('utf-8')
return "'{}'".format(
item
.replace('\\', '\\\\')
.replace("'", "\\'")
.replace('\r', '\\r')
.replace('\n', '\\n')
.replace('\t', '\\t')
)
_escaper = HiveParamEscaper()
def connect(*args, **kwargs):
"""Constructor for creating a connection to the database. See class :py:class:`Connection` for
arguments.
:returns: a :py:class:`Connection` object.
"""
return Connection(*args, **kwargs)
class Connection(object):
"""Wraps a Thrift session"""
def __init__(self, host=None, port=None, username=None, database='default', auth=None,
configuration=None, kerberos_service_name=None, password=None,
thrift_transport=None):
"""Connect to HiveServer2
:param host: What host HiveServer2 runs on
:param port: What port HiveServer2 runs on. Defaults to 10000.
:param auth: The value of hive.server2.authentication used by HiveServer2.
Defaults to ``NONE``.
:param configuration: A dictionary of Hive settings (functionally same as the `set` command)
:param kerberos_service_name: Use with auth='KERBEROS' only
:param password: Use with auth='LDAP' or auth='CUSTOM' only
:param thrift_transport: A ``TTransportBase`` for custom advanced usage.
Incompatible with host, port, auth, kerberos_service_name, and password.
The way to support LDAP and GSSAPI is originated from cloudera/Impyla:
https://github.com/cloudera/impyla/blob/255b07ed973d47a3395214ed92d35ec0615ebf62
/impala/_thrift_api.py#L152-L160
"""
username = username or getpass.getuser()
configuration = configuration or {}
if (password is not None) != (auth in ('LDAP', 'CUSTOM')):
raise ValueError("Password should be set if and only if in LDAP or CUSTOM mode; "
"Remove password or use one of those modes")
if (kerberos_service_name is not None) != (auth == 'KERBEROS'):
raise ValueError("kerberos_service_name should be set if and only if in KERBEROS mode")
if thrift_transport is not None:
has_incompatible_arg = (
host is not None
or port is not None
or auth is not None
or kerberos_service_name is not None
or password is not None
)
if has_incompatible_arg:
raise ValueError("thrift_transport cannot be used with "
"host/port/auth/kerberos_service_name/password")
if thrift_transport is not None:
self._transport = thrift_transport
else:
if port is None:
port = 10000
if auth is None:
auth = 'NONE'
socket = thrift.transport.TSocket.TSocket(host, port)
if auth == 'NOSASL':
# NOSASL corresponds to hive.server2.authentication=NOSASL in hive-site.xml
self._transport = thrift.transport.TTransport.TBufferedTransport(socket)
elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'):
# Defer import so package dependency is optional
import sasl
import thrift_sasl
if auth == 'KERBEROS':
# KERBEROS mode in hive.server2.authentication is GSSAPI in sasl library
sasl_auth = 'GSSAPI'
else:
sasl_auth = 'PLAIN'
if password is None:
# Password doesn't matter in NONE mode, just needs to be nonempty.
password = 'x'
def sasl_factory():
sasl_client = sasl.Client()
sasl_client.setAttr('host', host)
if sasl_auth == 'GSSAPI':
sasl_client.setAttr('service', kerberos_service_name)
elif sasl_auth == 'PLAIN':
sasl_client.setAttr('username', username)
sasl_client.setAttr('password', password)
else:
raise AssertionError
sasl_client.init()
return sasl_client
self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket)
else:
# All HS2 config options:
# https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2#SettingUpHiveServer2-Configuration
# PAM currently left to end user via thrift_transport option.
raise NotImplementedError(
"Only NONE, NOSASL, LDAP, KERBEROS, CUSTOM "
"authentication are supported, got {}".format(auth))
protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(self._transport)
self._client = TCLIService.Client(protocol)
# oldest version that still contains features we care about
# "V6 uses binary type for binary payload (was string) and uses columnar result set"
protocol_version = ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6
try:
self._transport.open()
open_session_req = ttypes.TOpenSessionReq(
client_protocol=protocol_version,
configuration=configuration,
username=username,
)
response = self._client.OpenSession(open_session_req)
_check_status(response)
assert response.sessionHandle is not None, "Expected a session from OpenSession"
self._sessionHandle = response.sessionHandle
assert response.serverProtocolVersion == protocol_version, \
"Unable to handle protocol version {}".format(response.serverProtocolVersion)
with contextlib.closing(self.cursor()) as cursor:
cursor.execute('USE `{}`'.format(database))
except:
self._transport.close()
raise
def __enter__(self):
"""Transport should already be opened by __init__"""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Call close"""
self.close()
def close(self):
"""Close the underlying session and Thrift transport"""
req = ttypes.TCloseSessionReq(sessionHandle=self._sessionHandle)
response = self._client.CloseSession(req)
self._transport.close()
_check_status(response)
def commit(self):
"""Hive does not support transactions, so this does nothing."""
pass
def cursor(self, *args, **kwargs):
"""Return a new :py:class:`Cursor` object using the connection."""
return Cursor(self, *args, **kwargs)
@property
def client(self):
return self._client
@property
def sessionHandle(self):
return self._sessionHandle
def rollback(self):
raise NotSupportedError("Hive does not have transactions") # pragma: no cover
class Cursor(common.DBAPICursor):
"""These objects represent a database cursor, which is used to manage the context of a fetch
operation.
Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately
visible by other cursors or connections.
"""
def __init__(self, connection, arraysize=1000):
self._operationHandle = None
super(Cursor, self).__init__()
self._arraysize = arraysize
self._connection = connection
def _reset_state(self):
"""Reset state about the previous query in preparation for running another query"""
super(Cursor, self)._reset_state()
self._description = None
if self._operationHandle is not None:
request = ttypes.TCloseOperationReq(self._operationHandle)
try:
response = self._connection.client.CloseOperation(request)
_check_status(response)
finally:
self._operationHandle = None
@property
def arraysize(self):
return self._arraysize
@arraysize.setter
def arraysize(self, value):
"""Array size cannot be None, and should be an integer"""
default_arraysize = 1000
try:
self._arraysize = int(value) or default_arraysize
except TypeError:
self._arraysize = default_arraysize
@property
def description(self):
"""This read-only attribute is a sequence of 7-item sequences.
Each of these sequences contains information describing one result column:
- name
- type_code
- display_size (None in current implementation)
- internal_size (None in current implementation)
- precision (None in current implementation)
- scale (None in current implementation)
- null_ok (always True in current implementation)
This attribute will be ``None`` for operations that do not return rows or if the cursor has
not had an operation invoked via the :py:meth:`execute` method yet.
The ``type_code`` can be interpreted by comparing it to the Type Objects specified in the
section below.
"""
if self._operationHandle is None or not self._operationHandle.hasResultSet:
return None
if self._description is None:
req = ttypes.TGetResultSetMetadataReq(self._operationHandle)
response = self._connection.client.GetResultSetMetadata(req)
_check_status(response)
columns = response.schema.columns
self._description = []
for col in columns:
primary_type_entry = col.typeDesc.types[0]
if primary_type_entry.primitiveEntry is None:
# All fancy stuff maps to string
type_code = ttypes.TTypeId._VALUES_TO_NAMES[ttypes.TTypeId.STRING_TYPE]
else:
type_id = primary_type_entry.primitiveEntry.type
type_code = ttypes.TTypeId._VALUES_TO_NAMES[type_id]
self._description.append((
col.columnName.decode('utf-8') if sys.version_info[0] == 2 else col.columnName,
type_code.decode('utf-8') if sys.version_info[0] == 2 else type_code,
None, None, None, None, True
))
return self._description
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def close(self):
"""Close the operation handle"""
self._reset_state()
def execute(self, operation, parameters=None, **kwargs):
"""Prepare and execute a database operation (query or command).
Return values are not defined.
"""
# backward compatibility with Python < 3.7
for kw in ['async', 'async_']:
if kw in kwargs:
async_ = kwargs[kw]
break
else:
async_ = False
# Prepare statement
if parameters is None:
sql = operation
else:
sql = operation % _escaper.escape_args(parameters)
self._reset_state()
self._state = self._STATE_RUNNING
_logger.info('%s', sql)
req = ttypes.TExecuteStatementReq(self._connection.sessionHandle,
sql, runAsync=async_)
_logger.debug(req)
response = self._connection.client.ExecuteStatement(req)
_check_status(response)
self._operationHandle = response.operationHandle
def cancel(self):
req = ttypes.TCancelOperationReq(
operationHandle=self._operationHandle,
)
response = self._connection.client.CancelOperation(req)
_check_status(response)
def _fetch_more(self):
"""Send another TFetchResultsReq and update state"""
assert(self._state == self._STATE_RUNNING), "Should be running when in _fetch_more"
assert(self._operationHandle is not None), "Should have an op handle in _fetch_more"
if not self._operationHandle.hasResultSet:
raise ProgrammingError("No result set")
req = ttypes.TFetchResultsReq(
operationHandle=self._operationHandle,
orientation=ttypes.TFetchOrientation.FETCH_NEXT,
maxRows=self.arraysize,
)
response = self._connection.client.FetchResults(req)
_check_status(response)
schema = self.description
assert not response.results.rows, 'expected data in columnar format'
columns = [_unwrap_column(col, col_schema[1]) for col, col_schema in
zip(response.results.columns, schema)]
new_data = list(zip(*columns))
self._data += new_data
# response.hasMoreRows seems to always be False, so we instead check the number of rows
# https://github.com/apache/hive/blob/release-1.2.1/service/src/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java#L678
# if not response.hasMoreRows:
if not new_data:
self._state = self._STATE_FINISHED
def poll(self, get_progress_update=True):
"""Poll for and return the raw status data provided by the Hive Thrift REST API.
:returns: ``ttypes.TGetOperationStatusResp``
:raises: ``ProgrammingError`` when no query has been started
.. note::
This is not a part of DB-API.
"""
if self._state == self._STATE_NONE:
raise ProgrammingError("No query yet")
req = ttypes.TGetOperationStatusReq(
operationHandle=self._operationHandle,
getProgressUpdate=get_progress_update,
)
response = self._connection.client.GetOperationStatus(req)
_check_status(response)
return response
def fetch_logs(self):
"""Retrieve the logs produced by the execution of the query.
Can be called multiple times to fetch the logs produced after the previous call.
:returns: list<str>
:raises: ``ProgrammingError`` when no query has been started
.. note::
This is not a part of DB-API.
"""
if self._state == self._STATE_NONE:
raise ProgrammingError("No query yet")
try: # Older Hive instances require logs to be retrieved using GetLog
req = ttypes.TGetLogReq(operationHandle=self._operationHandle)
logs = self._connection.client.GetLog(req).log.splitlines()
except ttypes.TApplicationException as e: # Otherwise, retrieve logs using newer method
if e.type != ttypes.TApplicationException.UNKNOWN_METHOD:
raise
logs = []
while True:
req = ttypes.TFetchResultsReq(
operationHandle=self._operationHandle,
orientation=ttypes.TFetchOrientation.FETCH_NEXT,
maxRows=self.arraysize,
fetchType=1 # 0: results, 1: logs
)
response = self._connection.client.FetchResults(req)
_check_status(response)
assert not response.results.rows, 'expected data in columnar format'
assert len(response.results.columns) == 1, response.results.columns
new_logs = _unwrap_column(response.results.columns[0])
logs += new_logs
if not new_logs:
break
return logs
#
# Type Objects and Constructors
#
for type_id in constants.PRIMITIVE_TYPES:
name = ttypes.TTypeId._VALUES_TO_NAMES[type_id]
setattr(sys.modules[__name__], name, DBAPITypeObject([name]))
#
# Private utilities
#
def _unwrap_column(col, type_=None):
"""Return a list of raw values from a TColumn instance."""
for attr, wrapper in iteritems(col.__dict__):
if wrapper is not None:
result = wrapper.values
nulls = wrapper.nulls # bit set describing what's null
assert isinstance(nulls, bytes)
for i, char in enumerate(nulls):
byte = ord(char) if sys.version_info[0] == 2 else char
for b in range(8):
if byte & (1 << b):
result[i * 8 + b] = None
converter = TYPES_CONVERTER.get(type_, None)
if converter and type_:
result = [converter(row) if row else row for row in result]
return result
raise DataError("Got empty column value {}".format(col)) # pragma: no cover
def _check_status(response):
"""Raise an OperationalError if the status is not success"""
_logger.debug(response)
if response.status.statusCode != ttypes.TStatusCode.SUCCESS_STATUS:
raise OperationalError(response)
"""DB-API implementation backed by Presto
See http://www.python.org/dev/peps/pep-0249/
Many docstrings in this file are based on the PEP, which is in the public domain.
"""
from __future__ import absolute_import
from __future__ import unicode_literals
from builtins import object
from pyhive import common
from pyhive.common import DBAPITypeObject
# Make all exceptions visible in this module per DB-API
from pyhive.exc import * # noqa
import base64
import getpass
import logging
import requests
from requests.auth import HTTPBasicAuth
try: # Python 3
import urllib.parse as urlparse
except ImportError: # Python 2
import urlparse
# PEP 249 module globals
apilevel = '2.0'
threadsafety = 2 # Threads may share the module and connections.
paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s
_logger = logging.getLogger(__name__)
_escaper = common.ParamEscaper()
def connect(*args, **kwargs):
"""Constructor for creating a connection to the database. See class :py:class:`Connection` for
arguments.
:returns: a :py:class:`Connection` object.
"""
return Connection(*args, **kwargs)
class Connection(object):
"""Presto does not have a notion of a persistent connection.
Thus, these objects are small stateless factories for cursors, which do all the real work.
"""
def __init__(self, *args, **kwargs):
self._args = args
self._kwargs = kwargs
def close(self):
"""Presto does not have anything to close"""
# TODO cancel outstanding queries?
pass
def commit(self):
"""Presto does not support transactions"""
pass
def cursor(self):
"""Return a new :py:class:`Cursor` object using the connection."""
return Cursor(*self._args, **self._kwargs)
def rollback(self):
raise NotSupportedError("Presto does not have transactions") # pragma: no cover
class Cursor(common.DBAPICursor):
"""These objects represent a database cursor, which is used to manage the context of a fetch
operation.
Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately
visible by other cursors or connections.
"""
def __init__(self, host, port='8080', username=None, catalog='hive',
schema='default', poll_interval=1, source='pyhive', session_props=None,
protocol='http', password=None, requests_session=None, requests_kwargs=None):
"""
:param host: hostname to connect to, e.g. ``presto.example.com``
:param port: int -- port, defaults to 8080
:param username: string -- defaults to system user name
:param catalog: string -- defaults to ``hive``
:param schema: string -- defaults to ``default``
:param poll_interval: int -- how often to ask the Presto REST interface for a progress
update, defaults to a second
:param source: string -- arbitrary identifier (shows up in the Presto monitoring page)
:param protocol: string -- network protocol, valid options are ``http`` and ``https``.
defaults to ``http``
:param password: string -- Deprecated. Defaults to ``None``.
Using BasicAuth, requires ``https``.
Prefer ``requests_kwargs={'auth': HTTPBasicAuth(username, password)}``.
May not be specified with ``requests_kwargs['auth']``.
:param requests_session: a ``requests.Session`` object for advanced usage. If absent, this
class will use the default requests behavior of making a new session per HTTP request.
Caller is responsible for closing session.
:param requests_kwargs: Additional ``**kwargs`` to pass to requests
"""
super(Cursor, self).__init__(poll_interval)
# Config
self._host = host
self._port = port
self._username = username or getpass.getuser()
self._catalog = catalog
self._schema = schema
self._arraysize = 1
self._poll_interval = poll_interval
self._source = source
self._session_props = session_props if session_props is not None else {}
if protocol not in ('http', 'https'):
raise ValueError("Protocol must be http/https, was {!r}".format(protocol))
self._protocol = protocol
self._requests_session = requests_session or requests
requests_kwargs = dict(requests_kwargs) if requests_kwargs is not None else {}
if password is not None and 'auth' in requests_kwargs:
raise ValueError("Cannot use both password and requests_kwargs authentication")
for k in ('method', 'url', 'data', 'headers'):
if k in requests_kwargs:
raise ValueError("Cannot override requests argument {}".format(k))
if password is not None:
requests_kwargs['auth'] = HTTPBasicAuth(username, password)
if protocol != 'https':
raise ValueError("Protocol must be https when passing a password")
self._requests_kwargs = requests_kwargs
self._reset_state()
def _reset_state(self):
"""Reset state about the previous query in preparation for running another query"""
super(Cursor, self)._reset_state()
self._nextUri = None
self._columns = None
@property
def description(self):
"""This read-only attribute is a sequence of 7-item sequences.
Each of these sequences contains information describing one result column:
- name
- type_code
- display_size (None in current implementation)
- internal_size (None in current implementation)
- precision (None in current implementation)
- scale (None in current implementation)
- null_ok (always True in current implementation)
The ``type_code`` can be interpreted by comparing it to the Type Objects specified in the
section below.
"""
# Sleep until we're done or we got the columns
self._fetch_while(
lambda: self._columns is None and
self._state not in (self._STATE_NONE, self._STATE_FINISHED)
)
if self._columns is None:
return None
return [
# name, type_code, display_size, internal_size, precision, scale, null_ok
(col['name'], col['type'], None, None, None, None, True)
for col in self._columns
]
def execute(self, operation, parameters=None):
"""Prepare and execute a database operation (query or command).
Return values are not defined.
"""
headers = {
'X-Presto-Catalog': self._catalog,
'X-Presto-Schema': self._schema,
'X-Presto-Source': self._source,
'X-Presto-User': self._username,
}
if self._session_props:
headers['X-Presto-Session'] = ','.join(
'{}={}'.format(propname, propval)
for propname, propval in self._session_props.items()
)
# Prepare statement
if parameters is None:
sql = operation
else:
sql = operation % _escaper.escape_args(parameters)
self._reset_state()
self._state = self._STATE_RUNNING
url = urlparse.urlunparse((
self._protocol,
'{}:{}'.format(self._host, self._port), '/v1/statement', None, None, None))
_logger.info('%s', sql)
_logger.debug("Headers: %s", headers)
response = self._requests_session.post(
url, data=sql.encode('utf-8'), headers=headers, **self._requests_kwargs)
self._process_response(response)
def cancel(self):
if self._state == self._STATE_NONE:
raise ProgrammingError("No query yet")
if self._nextUri is None:
assert self._state == self._STATE_FINISHED, "Should be finished if nextUri is None"
return
response = self._requests_session.delete(self._nextUri, **self._requests_kwargs)
if response.status_code != requests.codes.no_content:
fmt = "Unexpected status code after cancel {}\n{}"
raise OperationalError(fmt.format(response.status_code, response.content))
self._state = self._STATE_FINISHED
self._nextUri = None
def poll(self):
"""Poll for and return the raw status data provided by the Presto REST API.
:returns: dict -- JSON status information or ``None`` if the query is done
:raises: ``ProgrammingError`` when no query has been started
.. note::
This is not a part of DB-API.
"""
if self._state == self._STATE_NONE:
raise ProgrammingError("No query yet")
if self._nextUri is None:
assert self._state == self._STATE_FINISHED, "Should be finished if nextUri is None"
return None
response = self._requests_session.get(self._nextUri, **self._requests_kwargs)
self._process_response(response)
return response.json()
def _fetch_more(self):
"""Fetch the next URI and update state"""
self._process_response(self._requests_session.get(self._nextUri, **self._requests_kwargs))
def _decode_binary(self, rows):
# As of Presto 0.69, binary data is returned as the varbinary type in base64 format
# This function decodes base64 data in place
for i, col in enumerate(self.description):
if col[1] == 'varbinary':
for row in rows:
if row[i] is not None:
row[i] = base64.b64decode(row[i])
def _process_response(self, response):
"""Given the JSON response from Presto's REST API, update the internal state with the next
URI and any data from the response
"""
# TODO handle HTTP 503
if response.status_code != requests.codes.ok:
fmt = "Unexpected status code {}\n{}"
raise OperationalError(fmt.format(response.status_code, response.content))
response_json = response.json()
_logger.debug("Got response %s", response_json)
assert self._state == self._STATE_RUNNING, "Should be running if processing response"
self._nextUri = response_json.get('nextUri')
self._columns = response_json.get('columns')
if 'X-Presto-Clear-Session' in response.headers:
propname = response.headers['X-Presto-Clear-Session']
self._session_props.pop(propname, None)
if 'X-Presto-Set-Session' in response.headers:
propname, propval = response.headers['X-Presto-Set-Session'].split('=', 1)
self._session_props[propname] = propval
if 'data' in response_json:
assert self._columns
new_data = response_json['data']
self._decode_binary(new_data)
self._data += map(tuple, new_data)
if 'nextUri' not in response_json:
self._state = self._STATE_FINISHED
if 'error' in response_json:
raise DatabaseError(response_json['error'])
#
# Type Objects and Constructors
#
# See types in presto-main/src/main/java/com/facebook/presto/tuple/TupleInfo.java
FIXED_INT_64 = DBAPITypeObject(['bigint'])
VARIABLE_BINARY = DBAPITypeObject(['varchar'])
DOUBLE = DBAPITypeObject(['double'])
BOOLEAN = DBAPITypeObject(['boolean'])
"""Integration between SQLAlchemy and Hive.
Some code based on
https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py
which is released under the MIT license.
"""
from __future__ import absolute_import
from __future__ import unicode_literals
import datetime
import decimal
import re
from sqlalchemy import exc
from sqlalchemy import processors
from sqlalchemy import types
from sqlalchemy import util
# TODO shouldn't use mysql type
from sqlalchemy.databases import mysql
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler
from sqlalchemy.sql.compiler import SQLCompiler
from pyhive import hive
from pyhive.common import UniversalSet
from dateutil.parser import parse
from decimal import Decimal
class HiveStringTypeBase(types.TypeDecorator):
"""Translates strings returned by Thrift into something else"""
impl = types.String
def process_bind_param(self, value, dialect):
raise NotImplementedError("Writing to Hive not supported")
class HiveDate(HiveStringTypeBase):
"""Translates date strings to date objects"""
impl = types.DATE
def process_result_value(self, value, dialect):
return processors.str_to_date(value)
def result_processor(self, dialect, coltype):
def process(value):
if isinstance(value, datetime.datetime):
return value.date()
elif isinstance(value, datetime.date):
return value
elif value is not None:
return parse(value).date()
else:
return None
return process
def adapt(self, impltype, **kwargs):
return self.impl
class HiveTimestamp(HiveStringTypeBase):
"""Translates timestamp strings to datetime objects"""
impl = types.TIMESTAMP
def process_result_value(self, value, dialect):
return processors.str_to_datetime(value)
def result_processor(self, dialect, coltype):
def process(value):
if isinstance(value, datetime.datetime):
return value
elif value is not None:
return parse(value)
else:
return None
return process
def adapt(self, impltype, **kwargs):
return self.impl
class HiveDecimal(HiveStringTypeBase):
"""Translates strings to decimals"""
impl = types.DECIMAL
def process_result_value(self, value, dialect):
if value is not None:
return decimal.Decimal(value)
else:
return None
def result_processor(self, dialect, coltype):
def process(value):
if isinstance(value, Decimal):
return value
elif value is not None:
return Decimal(value)
else:
return None
return process
def adapt(self, impltype, **kwargs):
return self.impl
class HiveIdentifierPreparer(compiler.IdentifierPreparer):
# Just quote everything to make things simpler / easier to upgrade
reserved_words = UniversalSet()
def __init__(self, dialect):
super(HiveIdentifierPreparer, self).__init__(
dialect,
initial_quote='`',
)
_type_map = {
'boolean': types.Boolean,
'tinyint': mysql.MSTinyInteger,
'smallint': types.SmallInteger,
'int': types.Integer,
'bigint': types.BigInteger,
'float': types.Float,
'double': types.Float,
'string': types.String,
'date': HiveDate,
'timestamp': HiveTimestamp,
'binary': types.String,
'array': types.String,
'map': types.String,
'struct': types.String,
'uniontype': types.String,
'decimal': HiveDecimal,
}
class HiveCompiler(SQLCompiler):
def visit_concat_op_binary(self, binary, operator, **kw):
return "concat(%s, %s)" % (self.process(binary.left), self.process(binary.right))
def visit_insert(self, *args, **kwargs):
result = super(HiveCompiler, self).visit_insert(*args, **kwargs)
# Massage the result into Hive's format
# INSERT INTO `pyhive_test_database`.`test_table` (`a`) SELECT ...
# =>
# INSERT INTO TABLE `pyhive_test_database`.`test_table` SELECT ...
regex = r'^(INSERT INTO) ([^\s]+) \([^\)]*\)'
assert re.search(regex, result), "Unexpected visit_insert result: {}".format(result)
return re.sub(regex, r'\1 TABLE \2', result)
def visit_column(self, *args, **kwargs):
result = super(HiveCompiler, self).visit_column(*args, **kwargs)
dot_count = result.count('.')
assert dot_count in (0, 1, 2), "Unexpected visit_column result {}".format(result)
if dot_count == 2:
# we have something of the form schema.table.column
# hive doesn't like the schema in front, so chop it out
result = result[result.index('.') + 1:]
return result
def visit_char_length_func(self, fn, **kw):
return 'length{}'.format(self.function_argspec(fn, **kw))
class HiveTypeCompiler(compiler.GenericTypeCompiler):
def visit_INTEGER(self, type_):
return 'INT'
def visit_NUMERIC(self, type_):
return 'DECIMAL'
def visit_CHAR(self, type_):
return 'STRING'
def visit_VARCHAR(self, type_):
return 'STRING'
def visit_NCHAR(self, type_):
return 'STRING'
def visit_TEXT(self, type_):
return 'STRING'
def visit_CLOB(self, type_):
return 'STRING'
def visit_BLOB(self, type_):
return 'BINARY'
def visit_TIME(self, type_):
return 'TIMESTAMP'
def visit_DATE(self, type_):
return 'TIMESTAMP'
def visit_DATETIME(self, type_):
return 'TIMESTAMP'
class HiveExecutionContext(default.DefaultExecutionContext):
"""This is pretty much the same as SQLiteExecutionContext to work around the same issue.
http://docs.sqlalchemy.org/en/latest/dialects/sqlite.html#dotted-column-names
engine = create_engine('hive://...', execution_options={'hive_raw_colnames': True})
"""
@util.memoized_property
def _preserve_raw_colnames(self):
# Ideally, this would also gate on hive.resultset.use.unique.column.names
return self.execution_options.get('hive_raw_colnames', False)
def _translate_colname(self, colname):
# Adjust for dotted column names.
# When hive.resultset.use.unique.column.names is true (the default), Hive returns column
# names as "tablename.colname" in cursor.description.
if not self._preserve_raw_colnames and '.' in colname:
return colname.split('.')[-1], colname
else:
return colname, None
class HiveDialect(default.DefaultDialect):
name = b'hive'
driver = b'thrift'
execution_ctx_cls = HiveExecutionContext
preparer = HiveIdentifierPreparer
statement_compiler = HiveCompiler
supports_views = True
supports_alter = True
supports_pk_autoincrement = False
supports_default_values = False
supports_empty_insert = False
supports_native_decimal = True
supports_native_boolean = True
supports_unicode_statements = True
supports_unicode_binds = True
returns_unicode_strings = True
description_encoding = None
supports_multivalues_insert = True
type_compiler = HiveTypeCompiler
@classmethod
def dbapi(cls):
return hive
def create_connect_args(self, url):
kwargs = {
'host': url.host,
'port': url.port or 10000,
'username': url.username,
'password': url.password,
'database': url.database or 'default',
}
kwargs.update(url.query)
return [], kwargs
def get_schema_names(self, connection, **kw):
# Equivalent to SHOW DATABASES
return [row[0] for row in connection.execute('SHOW SCHEMAS')]
def get_view_names(self, connection, schema=None, **kw):
# Hive does not provide functionality to query tableType
# This allows reflection to not crash at the cost of being inaccurate
return self.get_table_names(connection, schema, **kw)
def _get_table_columns(self, connection, table_name, schema):
full_table = table_name
if schema:
full_table = schema + '.' + table_name
# TODO using TGetColumnsReq hangs after sending TFetchResultsReq.
# Using DESCRIBE works but is uglier.
try:
# This needs the table name to be unescaped (no backticks).
rows = connection.execute('DESCRIBE {}'.format(full_table)).fetchall()
except exc.OperationalError as e:
# Does the table exist?
regex_fmt = r'TExecuteStatementResp.*SemanticException.*Table not found {}'
regex = regex_fmt.format(re.escape(full_table))
if re.search(regex, e.args[0]):
raise exc.NoSuchTableError(full_table)
else:
raise
else:
# Hive is stupid: this is what I get from DESCRIBE some_schema.does_not_exist
regex = r'Table .* does not exist'
if len(rows) == 1 and re.match(regex, rows[0].col_name):
raise exc.NoSuchTableError(full_table)
return rows
def has_table(self, connection, table_name, schema=None):
try:
self._get_table_columns(connection, table_name, schema)
return True
except exc.NoSuchTableError:
return False
def get_columns(self, connection, table_name, schema=None, **kw):
rows = self._get_table_columns(connection, table_name, schema)
# Strip whitespace
rows = [[col.strip() if col else None for col in row] for row in rows]
# Filter out empty rows and comment
rows = [row for row in rows if row[0] and row[0] != '# col_name']
result = []
for (col_name, col_type, _comment) in rows:
if col_name == '# Partition Information':
break
# Take out the more detailed type information
# e.g. 'map<int,int>' -> 'map'
# 'decimal(10,1)' -> decimal
col_type = re.search(r'^\w+', col_type).group(0)
try:
coltype = _type_map[col_type]
except KeyError:
util.warn("Did not recognize type '%s' of column '%s'" % (col_type, col_name))
coltype = types.NullType
result.append({
'name': col_name,
'type': coltype,
'nullable': True,
'default': None,
})
return result
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
# Hive has no support for foreign keys.
return []
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
# Hive has no support for primary keys.
return []
def get_indexes(self, connection, table_name, schema=None, **kw):
rows = self._get_table_columns(connection, table_name, schema)
# Strip whitespace
rows = [[col.strip() if col else None for col in row] for row in rows]
# Filter out empty rows and comment
rows = [row for row in rows if row[0] and row[0] != '# col_name']
for i, (col_name, _col_type, _comment) in enumerate(rows):
if col_name == '# Partition Information':
break
# Handle partition columns
col_names = []
for col_name, _col_type, _comment in rows[i + 1:]:
col_names.append(col_name)
if col_names:
return [{'name': 'partition', 'column_names': col_names, 'unique': False}]
else:
return []
def get_table_names(self, connection, schema=None, **kw):
query = 'SHOW TABLES'
if schema:
query += ' IN ' + self.identifier_preparer.quote_identifier(schema)
return [row[0] for row in connection.execute(query)]
def do_rollback(self, dbapi_connection):
# No transactions for Hive
pass
def _check_unicode_returns(self, connection, additional_tests=None):
# We decode everything as UTF-8
return True
def _check_unicode_description(self, connection):
# We decode everything as UTF-8
return True
"""Integration between SQLAlchemy and Presto.
Some code based on
https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py
which is released under the MIT license.
"""
from __future__ import absolute_import
from __future__ import unicode_literals
import re
from sqlalchemy import exc
from sqlalchemy import types
from sqlalchemy import util
# TODO shouldn't use mysql type
from sqlalchemy.databases import mysql
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler
from sqlalchemy.sql.compiler import SQLCompiler
from pyhive import presto
from pyhive.common import UniversalSet
class PrestoIdentifierPreparer(compiler.IdentifierPreparer):
# Just quote everything to make things simpler / easier to upgrade
reserved_words = UniversalSet()
_type_map = {
'boolean': types.Boolean,
'tinyint': mysql.MSTinyInteger,
'smallint': types.SmallInteger,
'integer': types.Integer,
'bigint': types.BigInteger,
'real': types.Float,
'double': types.Float,
'varchar': types.String,
'timestamp': types.TIMESTAMP,
'date': types.DATE,
'varbinary': types.VARBINARY,
}
class PrestoCompiler(SQLCompiler):
def visit_char_length_func(self, fn, **kw):
return 'length{}'.format(self.function_argspec(fn, **kw))
class PrestoTypeCompiler(compiler.GenericTypeCompiler):
def visit_CLOB(self, type_, **kw):
raise ValueError("Presto does not support the CLOB column type.")
def visit_NCLOB(self, type_, **kw):
raise ValueError("Presto does not support the NCLOB column type.")
def visit_DATETIME(self, type_, **kw):
raise ValueError("Presto does not support the DATETIME column type.")
def visit_FLOAT(self, type_, **kw):
return 'DOUBLE'
def visit_TEXT(self, type_, **kw):
if type_.length:
return 'VARCHAR({:d})'.format(type_.length)
else:
return 'VARCHAR'
class PrestoDialect(default.DefaultDialect):
name = 'presto'
driver = 'rest'
paramstyle = 'pyformat'
preparer = PrestoIdentifierPreparer
statement_compiler = PrestoCompiler
supports_alter = False
supports_pk_autoincrement = False
supports_default_values = False
supports_empty_insert = False
supports_unicode_statements = True
supports_unicode_binds = True
returns_unicode_strings = True
description_encoding = None
supports_native_boolean = True
type_compiler = PrestoTypeCompiler
@classmethod
def dbapi(cls):
return presto
def create_connect_args(self, url):
db_parts = (url.database or 'hive').split('/')
kwargs = {
'host': url.host,
'port': url.port or 8080,
'username': url.username,
'password': url.password
}
kwargs.update(url.query)
if len(db_parts) == 1:
kwargs['catalog'] = db_parts[0]
elif len(db_parts) == 2:
kwargs['catalog'] = db_parts[0]
kwargs['schema'] = db_parts[1]
else:
raise ValueError("Unexpected database format {}".format(url.database))
return [], kwargs
def get_schema_names(self, connection, **kw):
return [row.Schema for row in connection.execute('SHOW SCHEMAS')]
def _get_table_columns(self, connection, table_name, schema):
full_table = self.identifier_preparer.quote_identifier(table_name)
if schema:
full_table = self.identifier_preparer.quote_identifier(schema) + '.' + full_table
try:
return connection.execute('SHOW COLUMNS FROM {}'.format(full_table))
except (presto.DatabaseError, exc.DatabaseError) as e:
# Normally SQLAlchemy should wrap this exception in sqlalchemy.exc.DatabaseError, which
# it successfully does in the Hive version. The difference with Presto is that this
# error is raised when fetching the cursor's description rather than the initial execute
# call. SQLAlchemy doesn't handle this. Thus, we catch the unwrapped
# presto.DatabaseError here.
# Does the table exist?
msg = (
e.args[0].get('message') if e.args and isinstance(e.args[0], dict)
else e.args[0] if e.args and isinstance(e.args[0], str)
else None
)
regex = r"Table\ \'.*{}\'\ does\ not\ exist".format(re.escape(table_name))
if msg and re.search(regex, msg):
raise exc.NoSuchTableError(table_name)
else:
raise
def has_table(self, connection, table_name, schema=None):
try:
self._get_table_columns(connection, table_name, schema)
return True
except exc.NoSuchTableError:
return False
def get_columns(self, connection, table_name, schema=None, **kw):
rows = self._get_table_columns(connection, table_name, schema)
result = []
for row in rows:
try:
coltype = _type_map[row.Type]
except KeyError:
util.warn("Did not recognize type '%s' of column '%s'" % (row.Type, row.Column))
coltype = types.NullType
result.append({
'name': row.Column,
'type': coltype,
# newer Presto no longer includes this column
'nullable': getattr(row, 'Null', True),
'default': None,
})
return result
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
# Hive has no support for foreign keys.
return []
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
# Hive has no support for primary keys.
return []
def get_indexes(self, connection, table_name, schema=None, **kw):
rows = self._get_table_columns(connection, table_name, schema)
col_names = []
for row in rows:
part_key = 'Partition Key'
# Presto puts this information in one of 3 places depending on version
# - a boolean column named "Partition Key"
# - a string in the "Comment" column
# - a string in the "Extra" column
is_partition_key = (
(part_key in row and row[part_key])
or row['Comment'].startswith(part_key)
or ('Extra' in row and 'partition key' in row['Extra'])
)
if is_partition_key:
col_names.append(row['Column'])
if col_names:
return [{'name': 'partition', 'column_names': col_names, 'unique': False}]
else:
return []
def get_table_names(self, connection, schema=None, **kw):
query = 'SHOW TABLES'
if schema:
query += ' FROM ' + self.identifier_preparer.quote_identifier(schema)
return [row.Table for row in connection.execute(query)]
def do_rollback(self, dbapi_connection):
# No transactions for Presto
pass
def _check_unicode_returns(self, connection, additional_tests=None):
# requests gives back Unicode strings
return True
def _check_unicode_description(self, connection):
# requests gives back Unicode strings
return True
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from thrift.Thrift import TProcessor, TMessageType
from thrift.protocol import TProtocolDecorator, TMultiplexedProtocol
from thrift.protocol.TProtocol import TProtocolException
class TMultiplexedProcessor(TProcessor):
def __init__(self):
self.defaultProcessor = None
self.services = {}
def registerDefault(self, processor):
"""
If a non-multiplexed processor connects to the server and wants to
communicate, use the given processor to handle it. This mechanism
allows servers to upgrade from non-multiplexed to multiplexed in a
backwards-compatible way and still handle old clients.
"""
self.defaultProcessor = processor
def registerProcessor(self, serviceName, processor):
self.services[serviceName] = processor
def on_message_begin(self, func):
for key in self.services.keys():
self.services[key].on_message_begin(func)
def process(self, iprot, oprot):
(name, type, seqid) = iprot.readMessageBegin()
if type != TMessageType.CALL and type != TMessageType.ONEWAY:
raise TProtocolException(
TProtocolException.NOT_IMPLEMENTED,
"TMultiplexedProtocol only supports CALL & ONEWAY")
index = name.find(TMultiplexedProtocol.SEPARATOR)
if index < 0:
if self.defaultProcessor:
return self.defaultProcessor.process(
StoredMessageProtocol(iprot, (name, type, seqid)), oprot)
else:
raise TProtocolException(
TProtocolException.NOT_IMPLEMENTED,
"Service name not found in message name: " + name + ". " +
"Did you forget to use TMultiplexedProtocol in your client?")
serviceName = name[0:index]
call = name[index + len(TMultiplexedProtocol.SEPARATOR):]
if serviceName not in self.services:
raise TProtocolException(
TProtocolException.NOT_IMPLEMENTED,
"Service name not found: " + serviceName + ". " +
"Did you forget to call registerProcessor()?")
standardMessage = (call, type, seqid)
return self.services[serviceName].process(
StoredMessageProtocol(iprot, standardMessage), oprot)
class StoredMessageProtocol(TProtocolDecorator.TProtocolDecorator):
def __init__(self, protocol, messageBegin):
self.messageBegin = messageBegin
def readMessageBegin(self):
return self.messageBegin
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from thrift.Thrift import TType
TYPE_IDX = 1
SPEC_ARGS_IDX = 3
SPEC_ARGS_CLASS_REF_IDX = 0
SPEC_ARGS_THRIFT_SPEC_IDX = 1
def fix_spec(all_structs):
"""Wire up recursive references for all TStruct definitions inside of each thrift_spec."""
for struc in all_structs:
spec = struc.thrift_spec
for thrift_spec in spec:
if thrift_spec is None:
continue
elif thrift_spec[TYPE_IDX] == TType.STRUCT:
other = thrift_spec[SPEC_ARGS_IDX][SPEC_ARGS_CLASS_REF_IDX].thrift_spec
thrift_spec[SPEC_ARGS_IDX][SPEC_ARGS_THRIFT_SPEC_IDX] = other
elif thrift_spec[TYPE_IDX] in (TType.LIST, TType.SET):
_fix_list_or_set(thrift_spec[SPEC_ARGS_IDX])
elif thrift_spec[TYPE_IDX] == TType.MAP:
_fix_map(thrift_spec[SPEC_ARGS_IDX])
def _fix_list_or_set(element_type):
# For a list or set, the thrift_spec entry looks like,
# (1, TType.LIST, 'lister', (TType.STRUCT, [RecList, None], False), None, ), # 1
# so ``element_type`` will be,
# (TType.STRUCT, [RecList, None], False)
if element_type[0] == TType.STRUCT:
element_type[1][1] = element_type[1][0].thrift_spec
elif element_type[0] in (TType.LIST, TType.SET):
_fix_list_or_set(element_type[1])
elif element_type[0] == TType.MAP:
_fix_map(element_type[1])
def _fix_map(element_type):
# For a map of key -> value type, ``element_type`` will be,
# (TType.I16, None, TType.STRUCT, [RecMapBasic, None], False), None, )
# which is just a normal struct definition.
#
# For a map of key -> list / set, ``element_type`` will be,
# (TType.I16, None, TType.LIST, (TType.STRUCT, [RecMapList, None], False), False)
# and we need to process the 3rd element as a list.
#
# For a map of key -> map, ``element_type`` will be,
# (TType.I16, None, TType.MAP, (TType.I16, None, TType.STRUCT,
# [RecMapMap, None], False), False)
# and need to process 3rd element as a map.
# Is the map key a struct?
if element_type[0] == TType.STRUCT:
element_type[1][1] = element_type[1][0].thrift_spec
elif element_type[0] in (TType.LIST, TType.SET):
_fix_list_or_set(element_type[1])
elif element_type[0] == TType.MAP:
_fix_map(element_type[1])
# Is the map value a struct?
if element_type[2] == TType.STRUCT:
element_type[3][1] = element_type[3][0].thrift_spec
elif element_type[2] in (TType.LIST, TType.SET):
_fix_list_or_set(element_type[3])
elif element_type[2] == TType.MAP:
_fix_map(element_type[3])
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from os import path
from SCons.Builder import Builder
from six.moves import map
def scons_env(env, add=''):
opath = path.dirname(path.abspath('$TARGET'))
lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE'
cppbuild = Builder(action=lstr)
env.Append(BUILDERS={'ThriftCpp': cppbuild})
def gen_cpp(env, dir, file):
scons_env(env)
suffixes = ['_types.h', '_types.cpp']
targets = map(lambda s: 'gen-cpp/' + file + s, suffixes)
return env.ThriftCpp(targets, dir + file + '.thrift')
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from .protocol import TBinaryProtocol
from .transport import TTransport
def serialize(thrift_object,
protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()):
transport = TTransport.TMemoryBuffer()
protocol = protocol_factory.getProtocol(transport)
thrift_object.write(protocol)
return transport.getvalue()
def deserialize(base,
buf,
protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()):
transport = TTransport.TMemoryBuffer(buf)
protocol = protocol_factory.getProtocol(transport)
base.read(protocol)
return base
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from __future__ import absolute_import
import logging
import socket
import struct
from .transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer
from io import BytesIO
from collections import deque
from contextlib import contextmanager
from tornado import gen, iostream, ioloop, tcpserver, concurrent
__all__ = ['TTornadoServer', 'TTornadoStreamTransport']
logger = logging.getLogger(__name__)
class _Lock(object):
def __init__(self):
self._waiters = deque()
def acquired(self):
return len(self._waiters) > 0
@gen.coroutine
def acquire(self):
blocker = self._waiters[-1] if self.acquired() else None
future = concurrent.Future()
self._waiters.append(future)
if blocker:
yield blocker
raise gen.Return(self._lock_context())
def release(self):
assert self.acquired(), 'Lock not aquired'
future = self._waiters.popleft()
future.set_result(None)
@contextmanager
def _lock_context(self):
try:
yield
finally:
self.release()
class TTornadoStreamTransport(TTransportBase):
"""a framed, buffered transport over a Tornado stream"""
def __init__(self, host, port, stream=None, io_loop=None):
self.host = host
self.port = port
self.io_loop = io_loop or ioloop.IOLoop.current()
self.__wbuf = BytesIO()
self._read_lock = _Lock()
# servers provide a ready-to-go stream
self.stream = stream
def with_timeout(self, timeout, future):
return gen.with_timeout(timeout, future, self.io_loop)
@gen.coroutine
def open(self, timeout=None):
logger.debug('socket connecting')
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
self.stream = iostream.IOStream(sock)
try:
connect = self.stream.connect((self.host, self.port))
if timeout is not None:
yield self.with_timeout(timeout, connect)
else:
yield connect
except (socket.error, IOError, ioloop.TimeoutError) as e:
message = 'could not connect to {}:{} ({})'.format(self.host, self.port, e)
raise TTransportException(
type=TTransportException.NOT_OPEN,
message=message)
raise gen.Return(self)
def set_close_callback(self, callback):
"""
Should be called only after open() returns
"""
self.stream.set_close_callback(callback)
def close(self):
# don't raise if we intend to close
self.stream.set_close_callback(None)
self.stream.close()
def read(self, _):
# The generated code for Tornado shouldn't do individual reads -- only
# frames at a time
assert False, "you're doing it wrong"
@contextmanager
def io_exception_context(self):
try:
yield
except (socket.error, IOError) as e:
raise TTransportException(
type=TTransportException.END_OF_FILE,
message=str(e))
except iostream.StreamBufferFullError as e:
raise TTransportException(
type=TTransportException.UNKNOWN,
message=str(e))
@gen.coroutine
def readFrame(self):
# IOStream processes reads one at a time
with (yield self._read_lock.acquire()):
with self.io_exception_context():
frame_header = yield self.stream.read_bytes(4)
if len(frame_header) == 0:
raise iostream.StreamClosedError('Read zero bytes from stream')
frame_length, = struct.unpack('!i', frame_header)
frame = yield self.stream.read_bytes(frame_length)
raise gen.Return(frame)
def write(self, buf):
self.__wbuf.write(buf)
def flush(self):
frame = self.__wbuf.getvalue()
# reset wbuf before write/flush to preserve state on underlying failure
frame_length = struct.pack('!i', len(frame))
self.__wbuf = BytesIO()
with self.io_exception_context():
return self.stream.write(frame_length + frame)
class TTornadoServer(tcpserver.TCPServer):
def __init__(self, processor, iprot_factory, oprot_factory=None,
*args, **kwargs):
super(TTornadoServer, self).__init__(*args, **kwargs)
self._processor = processor
self._iprot_factory = iprot_factory
self._oprot_factory = (oprot_factory if oprot_factory is not None
else iprot_factory)
@gen.coroutine
def handle_stream(self, stream, address):
host, port = address[:2]
trans = TTornadoStreamTransport(host=host, port=port, stream=stream,
io_loop=self.io_loop)
oprot = self._oprot_factory.getProtocol(trans)
try:
while not trans.stream.closed():
try:
frame = yield trans.readFrame()
except TTransportException as e:
if e.type == TTransportException.END_OF_FILE:
break
else:
raise
tr = TMemoryBuffer(frame)
iprot = self._iprot_factory.getProtocol(tr)
yield self._processor.process(iprot, oprot)
except Exception:
logger.exception('thrift exception in handle_stream')
trans.close()
logger.info('client disconnected %s:%d', host, port)
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import sys
class TType(object):
STOP = 0
VOID = 1
BOOL = 2
BYTE = 3
I08 = 3
DOUBLE = 4
I16 = 6
I32 = 8
I64 = 10
STRING = 11
UTF7 = 11
STRUCT = 12
MAP = 13
SET = 14
LIST = 15
UTF8 = 16
UTF16 = 17
_VALUES_TO_NAMES = (
'STOP',
'VOID',
'BOOL',
'BYTE',
'DOUBLE',
None,
'I16',
None,
'I32',
None,
'I64',
'STRING',
'STRUCT',
'MAP',
'SET',
'LIST',
'UTF8',
'UTF16',
)
class TMessageType(object):
CALL = 1
REPLY = 2
EXCEPTION = 3
ONEWAY = 4
class TProcessor(object):
"""Base class for processor, which works on two streams."""
def process(self, iprot, oprot):
"""
Process a request. The normal behvaior is to have the
processor invoke the correct handler and then it is the
server's responsibility to write the response to oprot.
"""
pass
def on_message_begin(self, func):
"""
Install a callback that receives (name, type, seqid)
after the message header is read.
"""
pass
class TException(Exception):
"""Base class for all thrift exceptions."""
# BaseException.message is deprecated in Python v[2.6,3.0)
if (2, 6, 0) <= sys.version_info < (3, 0):
def _get_message(self):
return self._message
def _set_message(self, message):
self._message = message
message = property(_get_message, _set_message)
def __init__(self, message=None):
Exception.__init__(self, message)
self.message = message
class TApplicationException(TException):
"""Application level thrift exceptions."""
UNKNOWN = 0
UNKNOWN_METHOD = 1
INVALID_MESSAGE_TYPE = 2
WRONG_METHOD_NAME = 3
BAD_SEQUENCE_ID = 4
MISSING_RESULT = 5
INTERNAL_ERROR = 6
PROTOCOL_ERROR = 7
INVALID_TRANSFORM = 8
INVALID_PROTOCOL = 9
UNSUPPORTED_CLIENT_TYPE = 10
def __init__(self, type=UNKNOWN, message=None):
TException.__init__(self, message)
self.type = type
def __str__(self):
if self.message:
return self.message
elif self.type == self.UNKNOWN_METHOD:
return 'Unknown method'
elif self.type == self.INVALID_MESSAGE_TYPE:
return 'Invalid message type'
elif self.type == self.WRONG_METHOD_NAME:
return 'Wrong method name'
elif self.type == self.BAD_SEQUENCE_ID:
return 'Bad sequence ID'
elif self.type == self.MISSING_RESULT:
return 'Missing result'
elif self.type == self.INTERNAL_ERROR:
return 'Internal error'
elif self.type == self.PROTOCOL_ERROR:
return 'Protocol error'
elif self.type == self.INVALID_TRANSFORM:
return 'Invalid transform'
elif self.type == self.INVALID_PROTOCOL:
return 'Invalid protocol'
elif self.type == self.UNSUPPORTED_CLIENT_TYPE:
return 'Unsupported client type'
else:
return 'Default (unknown) TApplicationException'
def read(self, iprot):
iprot.readStructBegin()
while True:
(fname, ftype, fid) = iprot.readFieldBegin()
if ftype == TType.STOP:
break
if fid == 1:
if ftype == TType.STRING:
self.message = iprot.readString()
else:
iprot.skip(ftype)
elif fid == 2:
if ftype == TType.I32:
self.type = iprot.readI32()
else:
iprot.skip(ftype)
else:
iprot.skip(ftype)
iprot.readFieldEnd()
iprot.readStructEnd()
def write(self, oprot):
oprot.writeStructBegin('TApplicationException')
if self.message is not None:
oprot.writeFieldBegin('message', TType.STRING, 1)
oprot.writeString(self.message)
oprot.writeFieldEnd()
if self.type is not None:
oprot.writeFieldBegin('type', TType.I32, 2)
oprot.writeI32(self.type)
oprot.writeFieldEnd()
oprot.writeFieldStop()
oprot.writeStructEnd()
class TFrozenDict(dict):
"""A dictionary that is "frozen" like a frozenset"""
def __init__(self, *args, **kwargs):
super(TFrozenDict, self).__init__(*args, **kwargs)
# Sort the items so they will be in a consistent order.
# XOR in the hash of the class so we don't collide with
# the hash of a list of tuples.
self.__hashval = hash(TFrozenDict) ^ hash(tuple(sorted(self.items())))
def __setitem__(self, *args):
raise TypeError("Can't modify frozen TFreezableDict")
def __delitem__(self, *args):
raise TypeError("Can't modify frozen TFreezableDict")
def __hash__(self):
return self.__hashval
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
__all__ = ['Thrift', 'TSCons']
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import sys
if sys.version_info[0] == 2:
from cStringIO import StringIO as BufferIO
def binary_to_str(bin_val):
return bin_val
def str_to_binary(str_val):
return str_val
def byte_index(bytes_val, i):
return ord(bytes_val[i])
else:
from io import BytesIO as BufferIO # noqa
def binary_to_str(bin_val):
return bin_val.decode('utf8')
def str_to_binary(str_val):
return bytes(str_val, 'utf8')
def byte_index(bytes_val, i):
return bytes_val[i]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from thrift.transport import TTransport
class TBase(object):
__slots__ = ()
def __repr__(self):
L = ['%s=%r' % (key, getattr(self, key)) for key in self.__slots__]
return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
for attr in self.__slots__:
my_val = getattr(self, attr)
other_val = getattr(other, attr)
if my_val != other_val:
return False
return True
def __ne__(self, other):
return not (self == other)
def read(self, iprot):
if (iprot._fast_decode is not None and
isinstance(iprot.trans, TTransport.CReadableTransport) and
self.thrift_spec is not None):
iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec])
else:
iprot.readStruct(self, self.thrift_spec)
def write(self, oprot):
if (oprot._fast_encode is not None and self.thrift_spec is not None):
oprot.trans.write(
oprot._fast_encode(self, [self.__class__, self.thrift_spec]))
else:
oprot.writeStruct(self, self.thrift_spec)
class TExceptionBase(TBase, Exception):
pass
class TFrozenBase(TBase):
def __setitem__(self, *args):
raise TypeError("Can't modify frozen struct")
def __delitem__(self, *args):
raise TypeError("Can't modify frozen struct")
def __hash__(self, *args):
return hash(self.__class__) ^ hash(self.__slots__)
@classmethod
def read(cls, iprot):
if (iprot._fast_decode is not None and
isinstance(iprot.trans, TTransport.CReadableTransport) and
cls.thrift_spec is not None):
self = cls()
return iprot._fast_decode(None, iprot,
[self.__class__, self.thrift_spec])
else:
return iprot.readStruct(cls, cls.thrift_spec, True)
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from .TProtocol import TType, TProtocolBase, TProtocolException, TProtocolFactory
from struct import pack, unpack
class TBinaryProtocol(TProtocolBase):
"""Binary implementation of the Thrift protocol driver."""
# NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be
# positive, converting this into a long. If we hardcode the int value
# instead it'll stay in 32 bit-land.
# VERSION_MASK = 0xffff0000
VERSION_MASK = -65536
# VERSION_1 = 0x80010000
VERSION_1 = -2147418112
TYPE_MASK = 0x000000ff
def __init__(self, trans, strictRead=False, strictWrite=True, **kwargs):
TProtocolBase.__init__(self, trans)
self.strictRead = strictRead
self.strictWrite = strictWrite
self.string_length_limit = kwargs.get('string_length_limit', None)
self.container_length_limit = kwargs.get('container_length_limit', None)
def _check_string_length(self, length):
self._check_length(self.string_length_limit, length)
def _check_container_length(self, length):
self._check_length(self.container_length_limit, length)
def writeMessageBegin(self, name, type, seqid):
if self.strictWrite:
self.writeI32(TBinaryProtocol.VERSION_1 | type)
self.writeString(name)
self.writeI32(seqid)
else:
self.writeString(name)
self.writeByte(type)
self.writeI32(seqid)
def writeMessageEnd(self):
pass
def writeStructBegin(self, name):
pass
def writeStructEnd(self):
pass
def writeFieldBegin(self, name, type, id):
self.writeByte(type)
self.writeI16(id)
def writeFieldEnd(self):
pass
def writeFieldStop(self):
self.writeByte(TType.STOP)
def writeMapBegin(self, ktype, vtype, size):
self.writeByte(ktype)
self.writeByte(vtype)
self.writeI32(size)
def writeMapEnd(self):
pass
def writeListBegin(self, etype, size):
self.writeByte(etype)
self.writeI32(size)
def writeListEnd(self):
pass
def writeSetBegin(self, etype, size):
self.writeByte(etype)
self.writeI32(size)
def writeSetEnd(self):
pass
def writeBool(self, bool):
if bool:
self.writeByte(1)
else:
self.writeByte(0)
def writeByte(self, byte):
buff = pack("!b", byte)
self.trans.write(buff)
def writeI16(self, i16):
buff = pack("!h", i16)
self.trans.write(buff)
def writeI32(self, i32):
buff = pack("!i", i32)
self.trans.write(buff)
def writeI64(self, i64):
buff = pack("!q", i64)
self.trans.write(buff)
def writeDouble(self, dub):
buff = pack("!d", dub)
self.trans.write(buff)
def writeBinary(self, str):
self.writeI32(len(str))
self.trans.write(str)
def readMessageBegin(self):
sz = self.readI32()
if sz < 0:
version = sz & TBinaryProtocol.VERSION_MASK
if version != TBinaryProtocol.VERSION_1:
raise TProtocolException(
type=TProtocolException.BAD_VERSION,
message='Bad version in readMessageBegin: %d' % (sz))
type = sz & TBinaryProtocol.TYPE_MASK
name = self.readString()
seqid = self.readI32()
else:
if self.strictRead:
raise TProtocolException(type=TProtocolException.BAD_VERSION,
message='No protocol version header')
name = self.trans.readAll(sz)
type = self.readByte()
seqid = self.readI32()
return (name, type, seqid)
def readMessageEnd(self):
pass
def readStructBegin(self):
pass
def readStructEnd(self):
pass
def readFieldBegin(self):
type = self.readByte()
if type == TType.STOP:
return (None, type, 0)
id = self.readI16()
return (None, type, id)
def readFieldEnd(self):
pass
def readMapBegin(self):
ktype = self.readByte()
vtype = self.readByte()
size = self.readI32()
self._check_container_length(size)
return (ktype, vtype, size)
def readMapEnd(self):
pass
def readListBegin(self):
etype = self.readByte()
size = self.readI32()
self._check_container_length(size)
return (etype, size)
def readListEnd(self):
pass
def readSetBegin(self):
etype = self.readByte()
size = self.readI32()
self._check_container_length(size)
return (etype, size)
def readSetEnd(self):
pass
def readBool(self):
byte = self.readByte()
if byte == 0:
return False
return True
def readByte(self):
buff = self.trans.readAll(1)
val, = unpack('!b', buff)
return val
def readI16(self):
buff = self.trans.readAll(2)
val, = unpack('!h', buff)
return val
def readI32(self):
buff = self.trans.readAll(4)
val, = unpack('!i', buff)
return val
def readI64(self):
buff = self.trans.readAll(8)
val, = unpack('!q', buff)
return val
def readDouble(self):
buff = self.trans.readAll(8)
val, = unpack('!d', buff)
return val
def readBinary(self):
size = self.readI32()
self._check_string_length(size)
s = self.trans.readAll(size)
return s
class TBinaryProtocolFactory(TProtocolFactory):
def __init__(self, strictRead=False, strictWrite=True, **kwargs):
self.strictRead = strictRead
self.strictWrite = strictWrite
self.string_length_limit = kwargs.get('string_length_limit', None)
self.container_length_limit = kwargs.get('container_length_limit', None)
def getProtocol(self, trans):
prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite,
string_length_limit=self.string_length_limit,
container_length_limit=self.container_length_limit)
return prot
class TBinaryProtocolAccelerated(TBinaryProtocol):
"""C-Accelerated version of TBinaryProtocol.
This class does not override any of TBinaryProtocol's methods,
but the generated code recognizes it directly and will call into
our C module to do the encoding, bypassing this object entirely.
We inherit from TBinaryProtocol so that the normal TBinaryProtocol
encoding can happen if the fastbinary module doesn't work for some
reason. (TODO(dreiss): Make this happen sanely in more cases.)
To disable this behavior, pass fallback=False constructor argument.
In order to take advantage of the C module, just use
TBinaryProtocolAccelerated instead of TBinaryProtocol.
NOTE: This code was contributed by an external developer.
The internal Thrift team has reviewed and tested it,
but we cannot guarantee that it is production-ready.
Please feel free to report bugs and/or success stories
to the public mailing list.
"""
pass
def __init__(self, *args, **kwargs):
fallback = kwargs.pop('fallback', True)
super(TBinaryProtocolAccelerated, self).__init__(*args, **kwargs)
try:
from thrift.protocol import fastbinary
except ImportError:
if not fallback:
raise
else:
self._fast_decode = fastbinary.decode_binary
self._fast_encode = fastbinary.encode_binary
class TBinaryProtocolAcceleratedFactory(TProtocolFactory):
def __init__(self,
string_length_limit=None,
container_length_limit=None,
fallback=True):
self.string_length_limit = string_length_limit
self.container_length_limit = container_length_limit
self._fallback = fallback
def getProtocol(self, trans):
return TBinaryProtocolAccelerated(
trans,
string_length_limit=self.string_length_limit,
container_length_limit=self.container_length_limit,
fallback=self._fallback)
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from .TProtocol import TType, TProtocolBase, TProtocolException, TProtocolFactory, checkIntegerLimits
from struct import pack, unpack
from ..compat import binary_to_str, str_to_binary
__all__ = ['TCompactProtocol', 'TCompactProtocolFactory']
CLEAR = 0
FIELD_WRITE = 1
VALUE_WRITE = 2
CONTAINER_WRITE = 3
BOOL_WRITE = 4
FIELD_READ = 5
CONTAINER_READ = 6
VALUE_READ = 7
BOOL_READ = 8
def make_helper(v_from, container):
def helper(func):
def nested(self, *args, **kwargs):
assert self.state in (v_from, container), (self.state, v_from, container)
return func(self, *args, **kwargs)
return nested
return helper
writer = make_helper(VALUE_WRITE, CONTAINER_WRITE)
reader = make_helper(VALUE_READ, CONTAINER_READ)
def makeZigZag(n, bits):
checkIntegerLimits(n, bits)
return (n << 1) ^ (n >> (bits - 1))
def fromZigZag(n):
return (n >> 1) ^ -(n & 1)
def writeVarint(trans, n):
assert n >= 0, "Input to TCompactProtocol writeVarint cannot be negative!"
out = bytearray()
while True:
if n & ~0x7f == 0:
out.append(n)
break
else:
out.append((n & 0xff) | 0x80)
n = n >> 7
trans.write(bytes(out))
def readVarint(trans):
result = 0
shift = 0
while True:
x = trans.readAll(1)
byte = ord(x)
result |= (byte & 0x7f) << shift
if byte >> 7 == 0:
return result
shift += 7
class CompactType(object):
STOP = 0x00
TRUE = 0x01
FALSE = 0x02
BYTE = 0x03
I16 = 0x04
I32 = 0x05
I64 = 0x06
DOUBLE = 0x07
BINARY = 0x08
LIST = 0x09
SET = 0x0A
MAP = 0x0B
STRUCT = 0x0C
CTYPES = {
TType.STOP: CompactType.STOP,
TType.BOOL: CompactType.TRUE, # used for collection
TType.BYTE: CompactType.BYTE,
TType.I16: CompactType.I16,
TType.I32: CompactType.I32,
TType.I64: CompactType.I64,
TType.DOUBLE: CompactType.DOUBLE,
TType.STRING: CompactType.BINARY,
TType.STRUCT: CompactType.STRUCT,
TType.LIST: CompactType.LIST,
TType.SET: CompactType.SET,
TType.MAP: CompactType.MAP,
}
TTYPES = {}
for k, v in CTYPES.items():
TTYPES[v] = k
TTYPES[CompactType.FALSE] = TType.BOOL
del k
del v
class TCompactProtocol(TProtocolBase):
"""Compact implementation of the Thrift protocol driver."""
PROTOCOL_ID = 0x82
VERSION = 1
VERSION_MASK = 0x1f
TYPE_MASK = 0xe0
TYPE_BITS = 0x07
TYPE_SHIFT_AMOUNT = 5
def __init__(self, trans,
string_length_limit=None,
container_length_limit=None):
TProtocolBase.__init__(self, trans)
self.state = CLEAR
self.__last_fid = 0
self.__bool_fid = None
self.__bool_value = None
self.__structs = []
self.__containers = []
self.string_length_limit = string_length_limit
self.container_length_limit = container_length_limit
def _check_string_length(self, length):
self._check_length(self.string_length_limit, length)
def _check_container_length(self, length):
self._check_length(self.container_length_limit, length)
def __writeVarint(self, n):
writeVarint(self.trans, n)
def writeMessageBegin(self, name, type, seqid):
assert self.state == CLEAR
self.__writeUByte(self.PROTOCOL_ID)
self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
# The sequence id is a signed 32-bit integer but the compact protocol
# writes this out as a "var int" which is always positive, and attempting
# to write a negative number results in an infinite loop, so we may
# need to do some conversion here...
tseqid = seqid
if tseqid < 0:
tseqid = 2147483648 + (2147483648 + tseqid)
self.__writeVarint(tseqid)
self.__writeBinary(str_to_binary(name))
self.state = VALUE_WRITE
def writeMessageEnd(self):
assert self.state == VALUE_WRITE
self.state = CLEAR
def writeStructBegin(self, name):
assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state
self.__structs.append((self.state, self.__last_fid))
self.state = FIELD_WRITE
self.__last_fid = 0
def writeStructEnd(self):
assert self.state == FIELD_WRITE
self.state, self.__last_fid = self.__structs.pop()
def writeFieldStop(self):
self.__writeByte(0)
def __writeFieldHeader(self, type, fid):
delta = fid - self.__last_fid
if 0 < delta <= 15:
self.__writeUByte(delta << 4 | type)
else:
self.__writeByte(type)
self.__writeI16(fid)
self.__last_fid = fid
def writeFieldBegin(self, name, type, fid):
assert self.state == FIELD_WRITE, self.state
if type == TType.BOOL:
self.state = BOOL_WRITE
self.__bool_fid = fid
else:
self.state = VALUE_WRITE
self.__writeFieldHeader(CTYPES[type], fid)
def writeFieldEnd(self):
assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state
self.state = FIELD_WRITE
def __writeUByte(self, byte):
self.trans.write(pack('!B', byte))
def __writeByte(self, byte):
self.trans.write(pack('!b', byte))
def __writeI16(self, i16):
self.__writeVarint(makeZigZag(i16, 16))
def __writeSize(self, i32):
self.__writeVarint(i32)
def writeCollectionBegin(self, etype, size):
assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
if size <= 14:
self.__writeUByte(size << 4 | CTYPES[etype])
else:
self.__writeUByte(0xf0 | CTYPES[etype])
self.__writeSize(size)
self.__containers.append(self.state)
self.state = CONTAINER_WRITE
writeSetBegin = writeCollectionBegin
writeListBegin = writeCollectionBegin
def writeMapBegin(self, ktype, vtype, size):
assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
if size == 0:
self.__writeByte(0)
else:
self.__writeSize(size)
self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype])
self.__containers.append(self.state)
self.state = CONTAINER_WRITE
def writeCollectionEnd(self):
assert self.state == CONTAINER_WRITE, self.state
self.state = self.__containers.pop()
writeMapEnd = writeCollectionEnd
writeSetEnd = writeCollectionEnd
writeListEnd = writeCollectionEnd
def writeBool(self, bool):
if self.state == BOOL_WRITE:
if bool:
ctype = CompactType.TRUE
else:
ctype = CompactType.FALSE
self.__writeFieldHeader(ctype, self.__bool_fid)
elif self.state == CONTAINER_WRITE:
if bool:
self.__writeByte(CompactType.TRUE)
else:
self.__writeByte(CompactType.FALSE)
else:
raise AssertionError("Invalid state in compact protocol")
writeByte = writer(__writeByte)
writeI16 = writer(__writeI16)
@writer
def writeI32(self, i32):
self.__writeVarint(makeZigZag(i32, 32))
@writer
def writeI64(self, i64):
self.__writeVarint(makeZigZag(i64, 64))
@writer
def writeDouble(self, dub):
self.trans.write(pack('<d', dub))
def __writeBinary(self, s):
self.__writeSize(len(s))
self.trans.write(s)
writeBinary = writer(__writeBinary)
def readFieldBegin(self):
assert self.state == FIELD_READ, self.state
type = self.__readUByte()
if type & 0x0f == TType.STOP:
return (None, 0, 0)
delta = type >> 4
if delta == 0:
fid = self.__readI16()
else:
fid = self.__last_fid + delta
self.__last_fid = fid
type = type & 0x0f
if type == CompactType.TRUE:
self.state = BOOL_READ
self.__bool_value = True
elif type == CompactType.FALSE:
self.state = BOOL_READ
self.__bool_value = False
else:
self.state = VALUE_READ
return (None, self.__getTType(type), fid)
def readFieldEnd(self):
assert self.state in (VALUE_READ, BOOL_READ), self.state
self.state = FIELD_READ
def __readUByte(self):
result, = unpack('!B', self.trans.readAll(1))
return result
def __readByte(self):
result, = unpack('!b', self.trans.readAll(1))
return result
def __readVarint(self):
return readVarint(self.trans)
def __readZigZag(self):
return fromZigZag(self.__readVarint())
def __readSize(self):
result = self.__readVarint()
if result < 0:
raise TProtocolException("Length < 0")
return result
def readMessageBegin(self):
assert self.state == CLEAR
proto_id = self.__readUByte()
if proto_id != self.PROTOCOL_ID:
raise TProtocolException(TProtocolException.BAD_VERSION,
'Bad protocol id in the message: %d' % proto_id)
ver_type = self.__readUByte()
type = (ver_type >> self.TYPE_SHIFT_AMOUNT) & self.TYPE_BITS
version = ver_type & self.VERSION_MASK
if version != self.VERSION:
raise TProtocolException(TProtocolException.BAD_VERSION,
'Bad version: %d (expect %d)' % (version, self.VERSION))
seqid = self.__readVarint()
# the sequence is a compact "var int" which is treaded as unsigned,
# however the sequence is actually signed...
if seqid > 2147483647:
seqid = -2147483648 - (2147483648 - seqid)
name = binary_to_str(self.__readBinary())
return (name, type, seqid)
def readMessageEnd(self):
assert self.state == CLEAR
assert len(self.__structs) == 0
def readStructBegin(self):
assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state
self.__structs.append((self.state, self.__last_fid))
self.state = FIELD_READ
self.__last_fid = 0
def readStructEnd(self):
assert self.state == FIELD_READ
self.state, self.__last_fid = self.__structs.pop()
def readCollectionBegin(self):
assert self.state in (VALUE_READ, CONTAINER_READ), self.state
size_type = self.__readUByte()
size = size_type >> 4
type = self.__getTType(size_type)
if size == 15:
size = self.__readSize()
self._check_container_length(size)
self.__containers.append(self.state)
self.state = CONTAINER_READ
return type, size
readSetBegin = readCollectionBegin
readListBegin = readCollectionBegin
def readMapBegin(self):
assert self.state in (VALUE_READ, CONTAINER_READ), self.state
size = self.__readSize()
self._check_container_length(size)
types = 0
if size > 0:
types = self.__readUByte()
vtype = self.__getTType(types)
ktype = self.__getTType(types >> 4)
self.__containers.append(self.state)
self.state = CONTAINER_READ
return (ktype, vtype, size)
def readCollectionEnd(self):
assert self.state == CONTAINER_READ, self.state
self.state = self.__containers.pop()
readSetEnd = readCollectionEnd
readListEnd = readCollectionEnd
readMapEnd = readCollectionEnd
def readBool(self):
if self.state == BOOL_READ:
return self.__bool_value == CompactType.TRUE
elif self.state == CONTAINER_READ:
return self.__readByte() == CompactType.TRUE
else:
raise AssertionError("Invalid state in compact protocol: %d" %
self.state)
readByte = reader(__readByte)
__readI16 = __readZigZag
readI16 = reader(__readZigZag)
readI32 = reader(__readZigZag)
readI64 = reader(__readZigZag)
@reader
def readDouble(self):
buff = self.trans.readAll(8)
val, = unpack('<d', buff)
return val
def __readBinary(self):
size = self.__readSize()
self._check_string_length(size)
return self.trans.readAll(size)
readBinary = reader(__readBinary)
def __getTType(self, byte):
return TTYPES[byte & 0x0f]
class TCompactProtocolFactory(TProtocolFactory):
def __init__(self,
string_length_limit=None,
container_length_limit=None):
self.string_length_limit = string_length_limit
self.container_length_limit = container_length_limit
def getProtocol(self, trans):
return TCompactProtocol(trans,
self.string_length_limit,
self.container_length_limit)
class TCompactProtocolAccelerated(TCompactProtocol):
"""C-Accelerated version of TCompactProtocol.
This class does not override any of TCompactProtocol's methods,
but the generated code recognizes it directly and will call into
our C module to do the encoding, bypassing this object entirely.
We inherit from TCompactProtocol so that the normal TCompactProtocol
encoding can happen if the fastbinary module doesn't work for some
reason.
To disable this behavior, pass fallback=False constructor argument.
In order to take advantage of the C module, just use
TCompactProtocolAccelerated instead of TCompactProtocol.
"""
pass
def __init__(self, *args, **kwargs):
fallback = kwargs.pop('fallback', True)
super(TCompactProtocolAccelerated, self).__init__(*args, **kwargs)
try:
from thrift.protocol import fastbinary
except ImportError:
if not fallback:
raise
else:
self._fast_decode = fastbinary.decode_compact
self._fast_encode = fastbinary.encode_compact
class TCompactProtocolAcceleratedFactory(TProtocolFactory):
def __init__(self,
string_length_limit=None,
container_length_limit=None,
fallback=True):
self.string_length_limit = string_length_limit
self.container_length_limit = container_length_limit
self._fallback = fallback
def getProtocol(self, trans):
return TCompactProtocolAccelerated(
trans,
string_length_limit=self.string_length_limit,
container_length_limit=self.container_length_limit,
fallback=self._fallback)
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from thrift.protocol.TBinaryProtocol import TBinaryProtocolAccelerated
from thrift.protocol.TCompactProtocol import TCompactProtocolAccelerated
from thrift.protocol.TProtocol import TProtocolBase, TProtocolException, TProtocolFactory
from thrift.Thrift import TApplicationException, TMessageType
from thrift.transport.THeaderTransport import THeaderTransport, THeaderSubprotocolID, THeaderClientType
PROTOCOLS_BY_ID = {
THeaderSubprotocolID.BINARY: TBinaryProtocolAccelerated,
THeaderSubprotocolID.COMPACT: TCompactProtocolAccelerated,
}
class THeaderProtocol(TProtocolBase):
"""A framed protocol with headers and payload transforms.
THeaderProtocol frames other Thrift protocols and adds support for optional
out-of-band headers. The currently supported subprotocols are
TBinaryProtocol and TCompactProtocol.
It's also possible to apply transforms to the encoded message payload. The
only transform currently supported is to gzip.
When used in a server, THeaderProtocol can accept messages from
non-THeaderProtocol clients if allowed (see `allowed_client_types`). This
includes framed and unframed transports and both TBinaryProtocol and
TCompactProtocol. The server will respond in the appropriate dialect for
the connected client. HTTP clients are not currently supported.
THeaderProtocol does not currently support THTTPServer, TNonblockingServer,
or TProcessPoolServer.
See doc/specs/HeaderFormat.md for details of the wire format.
"""
def __init__(self, transport, allowed_client_types):
# much of the actual work for THeaderProtocol happens down in
# THeaderTransport since we need to do low-level shenanigans to detect
# if the client is sending us headers or one of the headerless formats
# we support. this wraps the real transport with the one that does all
# the magic.
if not isinstance(transport, THeaderTransport):
transport = THeaderTransport(transport, allowed_client_types)
super(THeaderProtocol, self).__init__(transport)
self._set_protocol()
def get_headers(self):
return self.trans.get_headers()
def set_header(self, key, value):
self.trans.set_header(key, value)
def clear_headers(self):
self.trans.clear_headers()
def add_transform(self, transform_id):
self.trans.add_transform(transform_id)
def writeMessageBegin(self, name, ttype, seqid):
self.trans.sequence_id = seqid
return self._protocol.writeMessageBegin(name, ttype, seqid)
def writeMessageEnd(self):
return self._protocol.writeMessageEnd()
def writeStructBegin(self, name):
return self._protocol.writeStructBegin(name)
def writeStructEnd(self):
return self._protocol.writeStructEnd()
def writeFieldBegin(self, name, ttype, fid):
return self._protocol.writeFieldBegin(name, ttype, fid)
def writeFieldEnd(self):
return self._protocol.writeFieldEnd()
def writeFieldStop(self):
return self._protocol.writeFieldStop()
def writeMapBegin(self, ktype, vtype, size):
return self._protocol.writeMapBegin(ktype, vtype, size)
def writeMapEnd(self):
return self._protocol.writeMapEnd()
def writeListBegin(self, etype, size):
return self._protocol.writeListBegin(etype, size)
def writeListEnd(self):
return self._protocol.writeListEnd()
def writeSetBegin(self, etype, size):
return self._protocol.writeSetBegin(etype, size)
def writeSetEnd(self):
return self._protocol.writeSetEnd()
def writeBool(self, bool_val):
return self._protocol.writeBool(bool_val)
def writeByte(self, byte):
return self._protocol.writeByte(byte)
def writeI16(self, i16):
return self._protocol.writeI16(i16)
def writeI32(self, i32):
return self._protocol.writeI32(i32)
def writeI64(self, i64):
return self._protocol.writeI64(i64)
def writeDouble(self, dub):
return self._protocol.writeDouble(dub)
def writeBinary(self, str_val):
return self._protocol.writeBinary(str_val)
def _set_protocol(self):
try:
protocol_cls = PROTOCOLS_BY_ID[self.trans.protocol_id]
except KeyError:
raise TApplicationException(
TProtocolException.INVALID_PROTOCOL,
"Unknown protocol requested.",
)
self._protocol = protocol_cls(self.trans)
self._fast_encode = self._protocol._fast_encode
self._fast_decode = self._protocol._fast_decode
def readMessageBegin(self):
try:
self.trans.readFrame(0)
self._set_protocol()
except TApplicationException as exc:
self._protocol.writeMessageBegin(b"", TMessageType.EXCEPTION, 0)
exc.write(self._protocol)
self._protocol.writeMessageEnd()
self.trans.flush()
return self._protocol.readMessageBegin()
def readMessageEnd(self):
return self._protocol.readMessageEnd()
def readStructBegin(self):
return self._protocol.readStructBegin()
def readStructEnd(self):
return self._protocol.readStructEnd()
def readFieldBegin(self):
return self._protocol.readFieldBegin()
def readFieldEnd(self):
return self._protocol.readFieldEnd()
def readMapBegin(self):
return self._protocol.readMapBegin()
def readMapEnd(self):
return self._protocol.readMapEnd()
def readListBegin(self):
return self._protocol.readListBegin()
def readListEnd(self):
return self._protocol.readListEnd()
def readSetBegin(self):
return self._protocol.readSetBegin()
def readSetEnd(self):
return self._protocol.readSetEnd()
def readBool(self):
return self._protocol.readBool()
def readByte(self):
return self._protocol.readByte()
def readI16(self):
return self._protocol.readI16()
def readI32(self):
return self._protocol.readI32()
def readI64(self):
return self._protocol.readI64()
def readDouble(self):
return self._protocol.readDouble()
def readBinary(self):
return self._protocol.readBinary()
class THeaderProtocolFactory(TProtocolFactory):
def __init__(self, allowed_client_types=(THeaderClientType.HEADERS,)):
self.allowed_client_types = allowed_client_types
def getProtocol(self, trans):
return THeaderProtocol(trans, self.allowed_client_types)
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from .TProtocol import (TType, TProtocolBase, TProtocolException,
TProtocolFactory, checkIntegerLimits)
import base64
import math
import sys
from ..compat import str_to_binary
__all__ = ['TJSONProtocol',
'TJSONProtocolFactory',
'TSimpleJSONProtocol',
'TSimpleJSONProtocolFactory']
VERSION = 1
COMMA = b','
COLON = b':'
LBRACE = b'{'
RBRACE = b'}'
LBRACKET = b'['
RBRACKET = b']'
QUOTE = b'"'
BACKSLASH = b'\\'
ZERO = b'0'
ESCSEQ0 = ord('\\')
ESCSEQ1 = ord('u')
ESCAPE_CHAR_VALS = {
'"': '\\"',
'\\': '\\\\',
'\b': '\\b',
'\f': '\\f',
'\n': '\\n',
'\r': '\\r',
'\t': '\\t',
# '/': '\\/',
}
ESCAPE_CHARS = {
b'"': '"',
b'\\': '\\',
b'b': '\b',
b'f': '\f',
b'n': '\n',
b'r': '\r',
b't': '\t',
b'/': '/',
}
NUMERIC_CHAR = b'+-.0123456789Ee'
CTYPES = {
TType.BOOL: 'tf',
TType.BYTE: 'i8',
TType.I16: 'i16',
TType.I32: 'i32',
TType.I64: 'i64',
TType.DOUBLE: 'dbl',
TType.STRING: 'str',
TType.STRUCT: 'rec',
TType.LIST: 'lst',
TType.SET: 'set',
TType.MAP: 'map',
}
JTYPES = {}
for key in CTYPES.keys():
JTYPES[CTYPES[key]] = key
class JSONBaseContext(object):
def __init__(self, protocol):
self.protocol = protocol
self.first = True
def doIO(self, function):
pass
def write(self):
pass
def read(self):
pass
def escapeNum(self):
return False
def __str__(self):
return self.__class__.__name__
class JSONListContext(JSONBaseContext):
def doIO(self, function):
if self.first is True:
self.first = False
else:
function(COMMA)
def write(self):
self.doIO(self.protocol.trans.write)
def read(self):
self.doIO(self.protocol.readJSONSyntaxChar)
class JSONPairContext(JSONBaseContext):
def __init__(self, protocol):
super(JSONPairContext, self).__init__(protocol)
self.colon = True
def doIO(self, function):
if self.first:
self.first = False
self.colon = True
else:
function(COLON if self.colon else COMMA)
self.colon = not self.colon
def write(self):
self.doIO(self.protocol.trans.write)
def read(self):
self.doIO(self.protocol.readJSONSyntaxChar)
def escapeNum(self):
return self.colon
def __str__(self):
return '%s, colon=%s' % (self.__class__.__name__, self.colon)
class LookaheadReader():
hasData = False
data = ''
def __init__(self, protocol):
self.protocol = protocol
def read(self):
if self.hasData is True:
self.hasData = False
else:
self.data = self.protocol.trans.read(1)
return self.data
def peek(self):
if self.hasData is False:
self.data = self.protocol.trans.read(1)
self.hasData = True
return self.data
class TJSONProtocolBase(TProtocolBase):
def __init__(self, trans):
TProtocolBase.__init__(self, trans)
self.resetWriteContext()
self.resetReadContext()
# We don't have length limit implementation for JSON protocols
@property
def string_length_limit(senf):
return None
@property
def container_length_limit(senf):
return None
def resetWriteContext(self):
self.context = JSONBaseContext(self)
self.contextStack = [self.context]
def resetReadContext(self):
self.resetWriteContext()
self.reader = LookaheadReader(self)
def pushContext(self, ctx):
self.contextStack.append(ctx)
self.context = ctx
def popContext(self):
self.contextStack.pop()
if self.contextStack:
self.context = self.contextStack[-1]
else:
self.context = JSONBaseContext(self)
def writeJSONString(self, string):
self.context.write()
json_str = ['"']
for s in string:
escaped = ESCAPE_CHAR_VALS.get(s, s)
json_str.append(escaped)
json_str.append('"')
self.trans.write(str_to_binary(''.join(json_str)))
def writeJSONNumber(self, number, formatter='{0}'):
self.context.write()
jsNumber = str(formatter.format(number)).encode('ascii')
if self.context.escapeNum():
self.trans.write(QUOTE)
self.trans.write(jsNumber)
self.trans.write(QUOTE)
else:
self.trans.write(jsNumber)
def writeJSONBase64(self, binary):
self.context.write()
self.trans.write(QUOTE)
self.trans.write(base64.b64encode(binary))
self.trans.write(QUOTE)
def writeJSONObjectStart(self):
self.context.write()
self.trans.write(LBRACE)
self.pushContext(JSONPairContext(self))
def writeJSONObjectEnd(self):
self.popContext()
self.trans.write(RBRACE)
def writeJSONArrayStart(self):
self.context.write()
self.trans.write(LBRACKET)
self.pushContext(JSONListContext(self))
def writeJSONArrayEnd(self):
self.popContext()
self.trans.write(RBRACKET)
def readJSONSyntaxChar(self, character):
current = self.reader.read()
if character != current:
raise TProtocolException(TProtocolException.INVALID_DATA,
"Unexpected character: %s" % current)
def _isHighSurrogate(self, codeunit):
return codeunit >= 0xd800 and codeunit <= 0xdbff
def _isLowSurrogate(self, codeunit):
return codeunit >= 0xdc00 and codeunit <= 0xdfff
def _toChar(self, high, low=None):
if not low:
if sys.version_info[0] == 2:
return ("\\u%04x" % high).decode('unicode-escape') \
.encode('utf-8')
else:
return chr(high)
else:
codepoint = (1 << 16) + ((high & 0x3ff) << 10)
codepoint += low & 0x3ff
if sys.version_info[0] == 2:
s = "\\U%08x" % codepoint
return s.decode('unicode-escape').encode('utf-8')
else:
return chr(codepoint)
def readJSONString(self, skipContext):
highSurrogate = None
string = []
if skipContext is False:
self.context.read()
self.readJSONSyntaxChar(QUOTE)
while True:
character = self.reader.read()
if character == QUOTE:
break
if ord(character) == ESCSEQ0:
character = self.reader.read()
if ord(character) == ESCSEQ1:
character = self.trans.read(4).decode('ascii')
codeunit = int(character, 16)
if self._isHighSurrogate(codeunit):
if highSurrogate:
raise TProtocolException(
TProtocolException.INVALID_DATA,
"Expected low surrogate char")
highSurrogate = codeunit
continue
elif self._isLowSurrogate(codeunit):
if not highSurrogate:
raise TProtocolException(
TProtocolException.INVALID_DATA,
"Expected high surrogate char")
character = self._toChar(highSurrogate, codeunit)
highSurrogate = None
else:
character = self._toChar(codeunit)
else:
if character not in ESCAPE_CHARS:
raise TProtocolException(
TProtocolException.INVALID_DATA,
"Expected control char")
character = ESCAPE_CHARS[character]
elif character in ESCAPE_CHAR_VALS:
raise TProtocolException(TProtocolException.INVALID_DATA,
"Unescaped control char")
elif sys.version_info[0] > 2:
utf8_bytes = bytearray([ord(character)])
while ord(self.reader.peek()) >= 0x80:
utf8_bytes.append(ord(self.reader.read()))
character = utf8_bytes.decode('utf8')
string.append(character)
if highSurrogate:
raise TProtocolException(TProtocolException.INVALID_DATA,
"Expected low surrogate char")
return ''.join(string)
def isJSONNumeric(self, character):
return (True if NUMERIC_CHAR.find(character) != - 1 else False)
def readJSONQuotes(self):
if (self.context.escapeNum()):
self.readJSONSyntaxChar(QUOTE)
def readJSONNumericChars(self):
numeric = []
while True:
character = self.reader.peek()
if self.isJSONNumeric(character) is False:
break
numeric.append(self.reader.read())
return b''.join(numeric).decode('ascii')
def readJSONInteger(self):
self.context.read()
self.readJSONQuotes()
numeric = self.readJSONNumericChars()
self.readJSONQuotes()
try:
return int(numeric)
except ValueError:
raise TProtocolException(TProtocolException.INVALID_DATA,
"Bad data encounted in numeric data")
def readJSONDouble(self):
self.context.read()
if self.reader.peek() == QUOTE:
string = self.readJSONString(True)
try:
double = float(string)
if (self.context.escapeNum is False and
not math.isinf(double) and
not math.isnan(double)):
raise TProtocolException(
TProtocolException.INVALID_DATA,
"Numeric data unexpectedly quoted")
return double
except ValueError:
raise TProtocolException(TProtocolException.INVALID_DATA,
"Bad data encounted in numeric data")
else:
if self.context.escapeNum() is True:
self.readJSONSyntaxChar(QUOTE)
try:
return float(self.readJSONNumericChars())
except ValueError:
raise TProtocolException(TProtocolException.INVALID_DATA,
"Bad data encounted in numeric data")
def readJSONBase64(self):
string = self.readJSONString(False)
size = len(string)
m = size % 4
# Force padding since b64encode method does not allow it
if m != 0:
for i in range(4 - m):
string += '='
return base64.b64decode(string)
def readJSONObjectStart(self):
self.context.read()
self.readJSONSyntaxChar(LBRACE)
self.pushContext(JSONPairContext(self))
def readJSONObjectEnd(self):
self.readJSONSyntaxChar(RBRACE)
self.popContext()
def readJSONArrayStart(self):
self.context.read()
self.readJSONSyntaxChar(LBRACKET)
self.pushContext(JSONListContext(self))
def readJSONArrayEnd(self):
self.readJSONSyntaxChar(RBRACKET)
self.popContext()
class TJSONProtocol(TJSONProtocolBase):
def readMessageBegin(self):
self.resetReadContext()
self.readJSONArrayStart()
if self.readJSONInteger() != VERSION:
raise TProtocolException(TProtocolException.BAD_VERSION,
"Message contained bad version.")
name = self.readJSONString(False)
typen = self.readJSONInteger()
seqid = self.readJSONInteger()
return (name, typen, seqid)
def readMessageEnd(self):
self.readJSONArrayEnd()
def readStructBegin(self):
self.readJSONObjectStart()
def readStructEnd(self):
self.readJSONObjectEnd()
def readFieldBegin(self):
character = self.reader.peek()
ttype = 0
id = 0
if character == RBRACE:
ttype = TType.STOP
else:
id = self.readJSONInteger()
self.readJSONObjectStart()
ttype = JTYPES[self.readJSONString(False)]
return (None, ttype, id)
def readFieldEnd(self):
self.readJSONObjectEnd()
def readMapBegin(self):
self.readJSONArrayStart()
keyType = JTYPES[self.readJSONString(False)]
valueType = JTYPES[self.readJSONString(False)]
size = self.readJSONInteger()
self.readJSONObjectStart()
return (keyType, valueType, size)
def readMapEnd(self):
self.readJSONObjectEnd()
self.readJSONArrayEnd()
def readCollectionBegin(self):
self.readJSONArrayStart()
elemType = JTYPES[self.readJSONString(False)]
size = self.readJSONInteger()
return (elemType, size)
readListBegin = readCollectionBegin
readSetBegin = readCollectionBegin
def readCollectionEnd(self):
self.readJSONArrayEnd()
readSetEnd = readCollectionEnd
readListEnd = readCollectionEnd
def readBool(self):
return (False if self.readJSONInteger() == 0 else True)
def readNumber(self):
return self.readJSONInteger()
readByte = readNumber
readI16 = readNumber
readI32 = readNumber
readI64 = readNumber
def readDouble(self):
return self.readJSONDouble()
def readString(self):
return self.readJSONString(False)
def readBinary(self):
return self.readJSONBase64()
def writeMessageBegin(self, name, request_type, seqid):
self.resetWriteContext()
self.writeJSONArrayStart()
self.writeJSONNumber(VERSION)
self.writeJSONString(name)
self.writeJSONNumber(request_type)
self.writeJSONNumber(seqid)
def writeMessageEnd(self):
self.writeJSONArrayEnd()
def writeStructBegin(self, name):
self.writeJSONObjectStart()
def writeStructEnd(self):
self.writeJSONObjectEnd()
def writeFieldBegin(self, name, ttype, id):
self.writeJSONNumber(id)
self.writeJSONObjectStart()
self.writeJSONString(CTYPES[ttype])
def writeFieldEnd(self):
self.writeJSONObjectEnd()
def writeFieldStop(self):
pass
def writeMapBegin(self, ktype, vtype, size):
self.writeJSONArrayStart()
self.writeJSONString(CTYPES[ktype])
self.writeJSONString(CTYPES[vtype])
self.writeJSONNumber(size)
self.writeJSONObjectStart()
def writeMapEnd(self):
self.writeJSONObjectEnd()
self.writeJSONArrayEnd()
def writeListBegin(self, etype, size):
self.writeJSONArrayStart()
self.writeJSONString(CTYPES[etype])
self.writeJSONNumber(size)
def writeListEnd(self):
self.writeJSONArrayEnd()
def writeSetBegin(self, etype, size):
self.writeJSONArrayStart()
self.writeJSONString(CTYPES[etype])
self.writeJSONNumber(size)
def writeSetEnd(self):
self.writeJSONArrayEnd()
def writeBool(self, boolean):
self.writeJSONNumber(1 if boolean is True else 0)
def writeByte(self, byte):
checkIntegerLimits(byte, 8)
self.writeJSONNumber(byte)
def writeI16(self, i16):
checkIntegerLimits(i16, 16)
self.writeJSONNumber(i16)
def writeI32(self, i32):
checkIntegerLimits(i32, 32)
self.writeJSONNumber(i32)
def writeI64(self, i64):
checkIntegerLimits(i64, 64)
self.writeJSONNumber(i64)
def writeDouble(self, dbl):
# 17 significant digits should be just enough for any double precision
# value.
self.writeJSONNumber(dbl, '{0:.17g}')
def writeString(self, string):
self.writeJSONString(string)
def writeBinary(self, binary):
self.writeJSONBase64(binary)
class TJSONProtocolFactory(TProtocolFactory):
def getProtocol(self, trans):
return TJSONProtocol(trans)
@property
def string_length_limit(senf):
return None
@property
def container_length_limit(senf):
return None
class TSimpleJSONProtocol(TJSONProtocolBase):
"""Simple, readable, write-only JSON protocol.
Useful for interacting with scripting languages.
"""
def readMessageBegin(self):
raise NotImplementedError()
def readMessageEnd(self):
raise NotImplementedError()
def readStructBegin(self):
raise NotImplementedError()
def readStructEnd(self):
raise NotImplementedError()
def writeMessageBegin(self, name, request_type, seqid):
self.resetWriteContext()
def writeMessageEnd(self):
pass
def writeStructBegin(self, name):
self.writeJSONObjectStart()
def writeStructEnd(self):
self.writeJSONObjectEnd()
def writeFieldBegin(self, name, ttype, fid):
self.writeJSONString(name)
def writeFieldEnd(self):
pass
def writeMapBegin(self, ktype, vtype, size):
self.writeJSONObjectStart()
def writeMapEnd(self):
self.writeJSONObjectEnd()
def _writeCollectionBegin(self, etype, size):
self.writeJSONArrayStart()
def _writeCollectionEnd(self):
self.writeJSONArrayEnd()
writeListBegin = _writeCollectionBegin
writeListEnd = _writeCollectionEnd
writeSetBegin = _writeCollectionBegin
writeSetEnd = _writeCollectionEnd
def writeByte(self, byte):
checkIntegerLimits(byte, 8)
self.writeJSONNumber(byte)
def writeI16(self, i16):
checkIntegerLimits(i16, 16)
self.writeJSONNumber(i16)
def writeI32(self, i32):
checkIntegerLimits(i32, 32)
self.writeJSONNumber(i32)
def writeI64(self, i64):
checkIntegerLimits(i64, 64)
self.writeJSONNumber(i64)
def writeBool(self, boolean):
self.writeJSONNumber(1 if boolean is True else 0)
def writeDouble(self, dbl):
self.writeJSONNumber(dbl)
def writeString(self, string):
self.writeJSONString(string)
def writeBinary(self, binary):
self.writeJSONBase64(binary)
class TSimpleJSONProtocolFactory(TProtocolFactory):
def getProtocol(self, trans):
return TSimpleJSONProtocol(trans)
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from thrift.Thrift import TMessageType
from thrift.protocol import TProtocolDecorator
SEPARATOR = ":"
class TMultiplexedProtocol(TProtocolDecorator.TProtocolDecorator):
def __init__(self, protocol, serviceName):
self.serviceName = serviceName
def writeMessageBegin(self, name, type, seqid):
if (type == TMessageType.CALL or
type == TMessageType.ONEWAY):
super(TMultiplexedProtocol, self).writeMessageBegin(
self.serviceName + SEPARATOR + name,
type,
seqid
)
else:
super(TMultiplexedProtocol, self).writeMessageBegin(name, type, seqid)
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from thrift.Thrift import TException, TType, TFrozenDict
from thrift.transport.TTransport import TTransportException
from ..compat import binary_to_str, str_to_binary
import six
import sys
from itertools import islice
from six.moves import zip
class TProtocolException(TException):
"""Custom Protocol Exception class"""
UNKNOWN = 0
INVALID_DATA = 1
NEGATIVE_SIZE = 2
SIZE_LIMIT = 3
BAD_VERSION = 4
NOT_IMPLEMENTED = 5
DEPTH_LIMIT = 6
INVALID_PROTOCOL = 7
def __init__(self, type=UNKNOWN, message=None):
TException.__init__(self, message)
self.type = type
class TProtocolBase(object):
"""Base class for Thrift protocol driver."""
def __init__(self, trans):
self.trans = trans
self._fast_decode = None
self._fast_encode = None
@staticmethod
def _check_length(limit, length):
if length < 0:
raise TTransportException(TTransportException.NEGATIVE_SIZE,
'Negative length: %d' % length)
if limit is not None and length > limit:
raise TTransportException(TTransportException.SIZE_LIMIT,
'Length exceeded max allowed: %d' % limit)
def writeMessageBegin(self, name, ttype, seqid):
pass
def writeMessageEnd(self):
pass
def writeStructBegin(self, name):
pass
def writeStructEnd(self):
pass
def writeFieldBegin(self, name, ttype, fid):
pass
def writeFieldEnd(self):
pass
def writeFieldStop(self):
pass
def writeMapBegin(self, ktype, vtype, size):
pass
def writeMapEnd(self):
pass
def writeListBegin(self, etype, size):
pass
def writeListEnd(self):
pass
def writeSetBegin(self, etype, size):
pass
def writeSetEnd(self):
pass
def writeBool(self, bool_val):
pass
def writeByte(self, byte):
pass
def writeI16(self, i16):
pass
def writeI32(self, i32):
pass
def writeI64(self, i64):
pass
def writeDouble(self, dub):
pass
def writeString(self, str_val):
self.writeBinary(str_to_binary(str_val))
def writeBinary(self, str_val):
pass
def writeUtf8(self, str_val):
self.writeString(str_val.encode('utf8'))
def readMessageBegin(self):
pass
def readMessageEnd(self):
pass
def readStructBegin(self):
pass
def readStructEnd(self):
pass
def readFieldBegin(self):
pass
def readFieldEnd(self):
pass
def readMapBegin(self):
pass
def readMapEnd(self):
pass
def readListBegin(self):
pass
def readListEnd(self):
pass
def readSetBegin(self):
pass
def readSetEnd(self):
pass
def readBool(self):
pass
def readByte(self):
pass
def readI16(self):
pass
def readI32(self):
pass
def readI64(self):
pass
def readDouble(self):
pass
def readString(self):
return binary_to_str(self.readBinary())
def readBinary(self):
pass
def readUtf8(self):
return self.readString().decode('utf8')
def skip(self, ttype):
if ttype == TType.BOOL:
self.readBool()
elif ttype == TType.BYTE:
self.readByte()
elif ttype == TType.I16:
self.readI16()
elif ttype == TType.I32:
self.readI32()
elif ttype == TType.I64:
self.readI64()
elif ttype == TType.DOUBLE:
self.readDouble()
elif ttype == TType.STRING:
self.readString()
elif ttype == TType.STRUCT:
name = self.readStructBegin()
while True:
(name, ttype, id) = self.readFieldBegin()
if ttype == TType.STOP:
break
self.skip(ttype)
self.readFieldEnd()
self.readStructEnd()
elif ttype == TType.MAP:
(ktype, vtype, size) = self.readMapBegin()
for i in range(size):
self.skip(ktype)
self.skip(vtype)
self.readMapEnd()
elif ttype == TType.SET:
(etype, size) = self.readSetBegin()
for i in range(size):
self.skip(etype)
self.readSetEnd()
elif ttype == TType.LIST:
(etype, size) = self.readListBegin()
for i in range(size):
self.skip(etype)
self.readListEnd()
else:
raise TProtocolException(
TProtocolException.INVALID_DATA,
"invalid TType")
# tuple of: ( 'reader method' name, is_container bool, 'writer_method' name )
_TTYPE_HANDLERS = (
(None, None, False), # 0 TType.STOP
(None, None, False), # 1 TType.VOID # TODO: handle void?
('readBool', 'writeBool', False), # 2 TType.BOOL
('readByte', 'writeByte', False), # 3 TType.BYTE and I08
('readDouble', 'writeDouble', False), # 4 TType.DOUBLE
(None, None, False), # 5 undefined
('readI16', 'writeI16', False), # 6 TType.I16
(None, None, False), # 7 undefined
('readI32', 'writeI32', False), # 8 TType.I32
(None, None, False), # 9 undefined
('readI64', 'writeI64', False), # 10 TType.I64
('readString', 'writeString', False), # 11 TType.STRING and UTF7
('readContainerStruct', 'writeContainerStruct', True), # 12 *.STRUCT
('readContainerMap', 'writeContainerMap', True), # 13 TType.MAP
('readContainerSet', 'writeContainerSet', True), # 14 TType.SET
('readContainerList', 'writeContainerList', True), # 15 TType.LIST
(None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types?
(None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types?
)
def _ttype_handlers(self, ttype, spec):
if spec == 'BINARY':
if ttype != TType.STRING:
raise TProtocolException(type=TProtocolException.INVALID_DATA,
message='Invalid binary field type %d' % ttype)
return ('readBinary', 'writeBinary', False)
if sys.version_info[0] == 2 and spec == 'UTF8':
if ttype != TType.STRING:
raise TProtocolException(type=TProtocolException.INVALID_DATA,
message='Invalid string field type %d' % ttype)
return ('readUtf8', 'writeUtf8', False)
return self._TTYPE_HANDLERS[ttype] if ttype < len(self._TTYPE_HANDLERS) else (None, None, False)
def _read_by_ttype(self, ttype, spec, espec):
reader_name, _, is_container = self._ttype_handlers(ttype, espec)
if reader_name is None:
raise TProtocolException(type=TProtocolException.INVALID_DATA,
message='Invalid type %d' % (ttype))
reader_func = getattr(self, reader_name)
read = (lambda: reader_func(espec)) if is_container else reader_func
while True:
yield read()
def readFieldByTType(self, ttype, spec):
return next(self._read_by_ttype(ttype, spec, spec))
def readContainerList(self, spec):
ttype, tspec, is_immutable = spec
(list_type, list_len) = self.readListBegin()
# TODO: compare types we just decoded with thrift_spec
elems = islice(self._read_by_ttype(ttype, spec, tspec), list_len)
results = (tuple if is_immutable else list)(elems)
self.readListEnd()
return results
def readContainerSet(self, spec):
ttype, tspec, is_immutable = spec
(set_type, set_len) = self.readSetBegin()
# TODO: compare types we just decoded with thrift_spec
elems = islice(self._read_by_ttype(ttype, spec, tspec), set_len)
results = (frozenset if is_immutable else set)(elems)
self.readSetEnd()
return results
def readContainerStruct(self, spec):
(obj_class, obj_spec) = spec
obj = obj_class()
obj.read(self)
return obj
def readContainerMap(self, spec):
ktype, kspec, vtype, vspec, is_immutable = spec
(map_ktype, map_vtype, map_len) = self.readMapBegin()
# TODO: compare types we just decoded with thrift_spec and
# abort/skip if types disagree
keys = self._read_by_ttype(ktype, spec, kspec)
vals = self._read_by_ttype(vtype, spec, vspec)
keyvals = islice(zip(keys, vals), map_len)
results = (TFrozenDict if is_immutable else dict)(keyvals)
self.readMapEnd()
return results
def readStruct(self, obj, thrift_spec, is_immutable=False):
if is_immutable:
fields = {}
self.readStructBegin()
while True:
(fname, ftype, fid) = self.readFieldBegin()
if ftype == TType.STOP:
break
try:
field = thrift_spec[fid]
except IndexError:
self.skip(ftype)
else:
if field is not None and ftype == field[1]:
fname = field[2]
fspec = field[3]
val = self.readFieldByTType(ftype, fspec)
if is_immutable:
fields[fname] = val
else:
setattr(obj, fname, val)
else:
self.skip(ftype)
self.readFieldEnd()
self.readStructEnd()
if is_immutable:
return obj(**fields)
def writeContainerStruct(self, val, spec):
val.write(self)
def writeContainerList(self, val, spec):
ttype, tspec, _ = spec
self.writeListBegin(ttype, len(val))
for _ in self._write_by_ttype(ttype, val, spec, tspec):
pass
self.writeListEnd()
def writeContainerSet(self, val, spec):
ttype, tspec, _ = spec
self.writeSetBegin(ttype, len(val))
for _ in self._write_by_ttype(ttype, val, spec, tspec):
pass
self.writeSetEnd()
def writeContainerMap(self, val, spec):
ktype, kspec, vtype, vspec, _ = spec
self.writeMapBegin(ktype, vtype, len(val))
for _ in zip(self._write_by_ttype(ktype, six.iterkeys(val), spec, kspec),
self._write_by_ttype(vtype, six.itervalues(val), spec, vspec)):
pass
self.writeMapEnd()
def writeStruct(self, obj, thrift_spec):
self.writeStructBegin(obj.__class__.__name__)
for field in thrift_spec:
if field is None:
continue
fname = field[2]
val = getattr(obj, fname)
if val is None:
# skip writing out unset fields
continue
fid = field[0]
ftype = field[1]
fspec = field[3]
self.writeFieldBegin(fname, ftype, fid)
self.writeFieldByTType(ftype, val, fspec)
self.writeFieldEnd()
self.writeFieldStop()
self.writeStructEnd()
def _write_by_ttype(self, ttype, vals, spec, espec):
_, writer_name, is_container = self._ttype_handlers(ttype, espec)
writer_func = getattr(self, writer_name)
write = (lambda v: writer_func(v, espec)) if is_container else writer_func
for v in vals:
yield write(v)
def writeFieldByTType(self, ttype, val, spec):
next(self._write_by_ttype(ttype, [val], spec, spec))
def checkIntegerLimits(i, bits):
if bits == 8 and (i < -128 or i > 127):
raise TProtocolException(TProtocolException.INVALID_DATA,
"i8 requires -128 <= number <= 127")
elif bits == 16 and (i < -32768 or i > 32767):
raise TProtocolException(TProtocolException.INVALID_DATA,
"i16 requires -32768 <= number <= 32767")
elif bits == 32 and (i < -2147483648 or i > 2147483647):
raise TProtocolException(TProtocolException.INVALID_DATA,
"i32 requires -2147483648 <= number <= 2147483647")
elif bits == 64 and (i < -9223372036854775808 or i > 9223372036854775807):
raise TProtocolException(TProtocolException.INVALID_DATA,
"i64 requires -9223372036854775808 <= number <= 9223372036854775807")
class TProtocolFactory(object):
def getProtocol(self, trans):
pass
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
class TProtocolDecorator(object):
def __new__(cls, protocol, *args, **kwargs):
decorated_cls = type(''.join(['Decorated', protocol.__class__.__name__]),
(cls, protocol.__class__),
protocol.__dict__)
return object.__new__(decorated_cls)
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
__all__ = ['fastbinary', 'TBase', 'TBinaryProtocol', 'TCompactProtocol',
'TJSONProtocol', 'TProtocol', 'TProtocolDecorator']
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import ssl
from six.moves import BaseHTTPServer
from thrift.Thrift import TMessageType
from thrift.server import TServer
from thrift.transport import TTransport
class ResponseException(Exception):
"""Allows handlers to override the HTTP response
Normally, THttpServer always sends a 200 response. If a handler wants
to override this behavior (e.g., to simulate a misconfigured or
overloaded web server during testing), it can raise a ResponseException.
The function passed to the constructor will be called with the
RequestHandler as its only argument. Note that this is irrelevant
for ONEWAY requests, as the HTTP response must be sent before the
RPC is processed.
"""
def __init__(self, handler):
self.handler = handler
class THttpServer(TServer.TServer):
"""A simple HTTP-based Thrift server
This class is not very performant, but it is useful (for example) for
acting as a mock version of an Apache-based PHP Thrift endpoint.
Also important to note the HTTP implementation pretty much violates the
transport/protocol/processor/server layering, by performing the transport
functions here. This means things like oneway handling are oddly exposed.
"""
def __init__(self,
processor,
server_address,
inputProtocolFactory,
outputProtocolFactory=None,
server_class=BaseHTTPServer.HTTPServer,
**kwargs):
"""Set up protocol factories and HTTP (or HTTPS) server.
See BaseHTTPServer for server_address.
See TServer for protocol factories.
To make a secure server, provide the named arguments:
* cafile - to validate clients [optional]
* cert_file - the server cert
* key_file - the server's key
"""
if outputProtocolFactory is None:
outputProtocolFactory = inputProtocolFactory
TServer.TServer.__init__(self, processor, None, None, None,
inputProtocolFactory, outputProtocolFactory)
thttpserver = self
self._replied = None
class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler):
def do_POST(self):
# Don't care about the request path.
thttpserver._replied = False
iftrans = TTransport.TFileObjectTransport(self.rfile)
itrans = TTransport.TBufferedTransport(
iftrans, int(self.headers['Content-Length']))
otrans = TTransport.TMemoryBuffer()
iprot = thttpserver.inputProtocolFactory.getProtocol(itrans)
oprot = thttpserver.outputProtocolFactory.getProtocol(otrans)
try:
thttpserver.processor.on_message_begin(self.on_begin)
thttpserver.processor.process(iprot, oprot)
except ResponseException as exn:
exn.handler(self)
else:
if not thttpserver._replied:
# If the request was ONEWAY we would have replied already
data = otrans.getvalue()
self.send_response(200)
self.send_header("Content-Length", len(data))
self.send_header("Content-Type", "application/x-thrift")
self.end_headers()
self.wfile.write(data)
def on_begin(self, name, type, seqid):
"""
Inspect the message header.
This allows us to post an immediate transport response
if the request is a ONEWAY message type.
"""
if type == TMessageType.ONEWAY:
self.send_response(200)
self.send_header("Content-Type", "application/x-thrift")
self.end_headers()
thttpserver._replied = True
self.httpd = server_class(server_address, RequestHander)
if (kwargs.get('cafile') or kwargs.get('cert_file') or kwargs.get('key_file')):
context = ssl.create_default_context(cafile=kwargs.get('cafile'))
context.check_hostname = False
context.load_cert_chain(kwargs.get('cert_file'), kwargs.get('key_file'))
context.verify_mode = ssl.CERT_REQUIRED if kwargs.get('cafile') else ssl.CERT_NONE
self.httpd.socket = context.wrap_socket(self.httpd.socket, server_side=True)
def serve(self):
self.httpd.serve_forever()
def shutdown(self):
self.httpd.socket.close()
# self.httpd.shutdown() # hangs forever, python doesn't handle POLLNVAL properly!
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
"""Implementation of non-blocking server.
The main idea of the server is to receive and send requests
only from the main thread.
The thread poool should be sized for concurrent tasks, not
maximum connections
"""
import logging
import select
import socket
import struct
import threading
from collections import deque
from six.moves import queue
from thrift.transport import TTransport
from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory
__all__ = ['TNonblockingServer']
logger = logging.getLogger(__name__)
class Worker(threading.Thread):
"""Worker is a small helper to process incoming connection."""
def __init__(self, queue):
threading.Thread.__init__(self)
self.queue = queue
def run(self):
"""Process queries from task queue, stop if processor is None."""
while True:
try:
processor, iprot, oprot, otrans, callback = self.queue.get()
if processor is None:
break
processor.process(iprot, oprot)
callback(True, otrans.getvalue())
except Exception:
logger.exception("Exception while processing request", exc_info=True)
callback(False, b'')
WAIT_LEN = 0
WAIT_MESSAGE = 1
WAIT_PROCESS = 2
SEND_ANSWER = 3
CLOSED = 4
def locked(func):
"""Decorator which locks self.lock."""
def nested(self, *args, **kwargs):
self.lock.acquire()
try:
return func(self, *args, **kwargs)
finally:
self.lock.release()
return nested
def socket_exception(func):
"""Decorator close object on socket.error."""
def read(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except socket.error:
logger.debug('ignoring socket exception', exc_info=True)
self.close()
return read
class Message(object):
def __init__(self, offset, len_, header):
self.offset = offset
self.len = len_
self.buffer = None
self.is_header = header
@property
def end(self):
return self.offset + self.len
class Connection(object):
"""Basic class is represented connection.
It can be in state:
WAIT_LEN --- connection is reading request len.
WAIT_MESSAGE --- connection is reading request.
WAIT_PROCESS --- connection has just read whole request and
waits for call ready routine.
SEND_ANSWER --- connection is sending answer string (including length
of answer).
CLOSED --- socket was closed and connection should be deleted.
"""
def __init__(self, new_socket, wake_up):
self.socket = new_socket
self.socket.setblocking(False)
self.status = WAIT_LEN
self.len = 0
self.received = deque()
self._reading = Message(0, 4, True)
self._rbuf = b''
self._wbuf = b''
self.lock = threading.Lock()
self.wake_up = wake_up
self.remaining = False
@socket_exception
def read(self):
"""Reads data from stream and switch state."""
assert self.status in (WAIT_LEN, WAIT_MESSAGE)
assert not self.received
buf_size = 8192
first = True
done = False
while not done:
read = self.socket.recv(buf_size)
rlen = len(read)
done = rlen < buf_size
self._rbuf += read
if first and rlen == 0:
if self.status != WAIT_LEN or self._rbuf:
logger.error('could not read frame from socket')
else:
logger.debug('read zero length. client might have disconnected')
self.close()
while len(self._rbuf) >= self._reading.end:
if self._reading.is_header:
mlen, = struct.unpack('!i', self._rbuf[:4])
self._reading = Message(self._reading.end, mlen, False)
self.status = WAIT_MESSAGE
else:
self._reading.buffer = self._rbuf
self.received.append(self._reading)
self._rbuf = self._rbuf[self._reading.end:]
self._reading = Message(0, 4, True)
first = False
if self.received:
self.status = WAIT_PROCESS
break
self.remaining = not done
@socket_exception
def write(self):
"""Writes data from socket and switch state."""
assert self.status == SEND_ANSWER
sent = self.socket.send(self._wbuf)
if sent == len(self._wbuf):
self.status = WAIT_LEN
self._wbuf = b''
self.len = 0
else:
self._wbuf = self._wbuf[sent:]
@locked
def ready(self, all_ok, message):
"""Callback function for switching state and waking up main thread.
This function is the only function witch can be called asynchronous.
The ready can switch Connection to three states:
WAIT_LEN if request was oneway.
SEND_ANSWER if request was processed in normal way.
CLOSED if request throws unexpected exception.
The one wakes up main thread.
"""
assert self.status == WAIT_PROCESS
if not all_ok:
self.close()
self.wake_up()
return
self.len = 0
if len(message) == 0:
# it was a oneway request, do not write answer
self._wbuf = b''
self.status = WAIT_LEN
else:
self._wbuf = struct.pack('!i', len(message)) + message
self.status = SEND_ANSWER
self.wake_up()
@locked
def is_writeable(self):
"""Return True if connection should be added to write list of select"""
return self.status == SEND_ANSWER
# it's not necessary, but...
@locked
def is_readable(self):
"""Return True if connection should be added to read list of select"""
return self.status in (WAIT_LEN, WAIT_MESSAGE)
@locked
def is_closed(self):
"""Returns True if connection is closed."""
return self.status == CLOSED
def fileno(self):
"""Returns the file descriptor of the associated socket."""
return self.socket.fileno()
def close(self):
"""Closes connection"""
self.status = CLOSED
self.socket.close()
class TNonblockingServer(object):
"""Non-blocking server."""
def __init__(self,
processor,
lsocket,
inputProtocolFactory=None,
outputProtocolFactory=None,
threads=10):
self.processor = processor
self.socket = lsocket
self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory()
self.out_protocol = outputProtocolFactory or self.in_protocol
self.threads = int(threads)
self.clients = {}
self.tasks = queue.Queue()
self._read, self._write = socket.socketpair()
self.prepared = False
self._stop = False
def setNumThreads(self, num):
"""Set the number of worker threads that should be created."""
# implement ThreadPool interface
assert not self.prepared, "Can't change number of threads after start"
self.threads = num
def prepare(self):
"""Prepares server for serve requests."""
if self.prepared:
return
self.socket.listen()
for _ in range(self.threads):
thread = Worker(self.tasks)
thread.setDaemon(True)
thread.start()
self.prepared = True
def wake_up(self):
"""Wake up main thread.
The server usually waits in select call in we should terminate one.
The simplest way is using socketpair.
Select always wait to read from the first socket of socketpair.
In this case, we can just write anything to the second socket from
socketpair.
"""
self._write.send(b'1')
def stop(self):
"""Stop the server.
This method causes the serve() method to return. stop() may be invoked
from within your handler, or from another thread.
After stop() is called, serve() will return but the server will still
be listening on the socket. serve() may then be called again to resume
processing requests. Alternatively, close() may be called after
serve() returns to close the server socket and shutdown all worker
threads.
"""
self._stop = True
self.wake_up()
def _select(self):
"""Does select on open connections."""
readable = [self.socket.handle.fileno(), self._read.fileno()]
writable = []
remaining = []
for i, connection in list(self.clients.items()):
if connection.is_readable():
readable.append(connection.fileno())
if connection.remaining or connection.received:
remaining.append(connection.fileno())
if connection.is_writeable():
writable.append(connection.fileno())
if connection.is_closed():
del self.clients[i]
if remaining:
return remaining, [], [], False
else:
return select.select(readable, writable, readable) + (True,)
def handle(self):
"""Handle requests.
WARNING! You must call prepare() BEFORE calling handle()
"""
assert self.prepared, "You have to call prepare before handle"
rset, wset, xset, selected = self._select()
for readable in rset:
if readable == self._read.fileno():
# don't care i just need to clean readable flag
self._read.recv(1024)
elif readable == self.socket.handle.fileno():
try:
client = self.socket.accept()
if client:
self.clients[client.handle.fileno()] = Connection(client.handle,
self.wake_up)
except socket.error:
logger.debug('error while accepting', exc_info=True)
else:
connection = self.clients[readable]
if selected:
connection.read()
if connection.received:
connection.status = WAIT_PROCESS
msg = connection.received.popleft()
itransport = TTransport.TMemoryBuffer(msg.buffer, msg.offset)
otransport = TTransport.TMemoryBuffer()
iprot = self.in_protocol.getProtocol(itransport)
oprot = self.out_protocol.getProtocol(otransport)
self.tasks.put([self.processor, iprot, oprot,
otransport, connection.ready])
for writeable in wset:
self.clients[writeable].write()
for oob in xset:
self.clients[oob].close()
del self.clients[oob]
def close(self):
"""Closes the server."""
for _ in range(self.threads):
self.tasks.put([None, None, None, None, None])
self.socket.close()
self.prepared = False
def serve(self):
"""Serve requests.
Serve requests forever, or until stop() is called.
"""
self._stop = False
self.prepare()
while not self._stop:
self.handle()
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import logging
from multiprocessing import Process, Value, Condition
from .TServer import TServer
from thrift.transport.TTransport import TTransportException
logger = logging.getLogger(__name__)
class TProcessPoolServer(TServer):
"""Server with a fixed size pool of worker subprocesses to service requests
Note that if you need shared state between the handlers - it's up to you!
Written by Dvir Volk, doat.com
"""
def __init__(self, *args):
TServer.__init__(self, *args)
self.numWorkers = 10
self.workers = []
self.isRunning = Value('b', False)
self.stopCondition = Condition()
self.postForkCallback = None
def setPostForkCallback(self, callback):
if not callable(callback):
raise TypeError("This is not a callback!")
self.postForkCallback = callback
def setNumWorkers(self, num):
"""Set the number of worker threads that should be created"""
self.numWorkers = num
def workerProcess(self):
"""Loop getting clients from the shared queue and process them"""
if self.postForkCallback:
self.postForkCallback()
while self.isRunning.value:
try:
client = self.serverTransport.accept()
if not client:
continue
self.serveClient(client)
except (KeyboardInterrupt, SystemExit):
return 0
except Exception as x:
logger.exception(x)
def serveClient(self, client):
"""Process input/output from a client for as long as possible"""
itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
oprot = self.outputProtocolFactory.getProtocol(otrans)
try:
while True:
self.processor.process(iprot, oprot)
except TTransportException:
pass
except Exception as x:
logger.exception(x)
itrans.close()
otrans.close()
def serve(self):
"""Start workers and put into queue"""
# this is a shared state that can tell the workers to exit when False
self.isRunning.value = True
# first bind and listen to the port
self.serverTransport.listen()
# fork the children
for i in range(self.numWorkers):
try:
w = Process(target=self.workerProcess)
w.daemon = True
w.start()
self.workers.append(w)
except Exception as x:
logger.exception(x)
# wait until the condition is set by stop()
while True:
self.stopCondition.acquire()
try:
self.stopCondition.wait()
break
except (SystemExit, KeyboardInterrupt):
break
except Exception as x:
logger.exception(x)
self.isRunning.value = False
def stop(self):
self.isRunning.value = False
self.stopCondition.acquire()
self.stopCondition.notify()
self.stopCondition.release()
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from six.moves import queue
import logging
import os
import threading
from thrift.protocol import TBinaryProtocol
from thrift.protocol.THeaderProtocol import THeaderProtocolFactory
from thrift.transport import TTransport
logger = logging.getLogger(__name__)
class TServer(object):
"""Base interface for a server, which must have a serve() method.
Three constructors for all servers:
1) (processor, serverTransport)
2) (processor, serverTransport, transportFactory, protocolFactory)
3) (processor, serverTransport,
inputTransportFactory, outputTransportFactory,
inputProtocolFactory, outputProtocolFactory)
"""
def __init__(self, *args):
if (len(args) == 2):
self.__initArgs__(args[0], args[1],
TTransport.TTransportFactoryBase(),
TTransport.TTransportFactoryBase(),
TBinaryProtocol.TBinaryProtocolFactory(),
TBinaryProtocol.TBinaryProtocolFactory())
elif (len(args) == 4):
self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3])
elif (len(args) == 6):
self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5])
def __initArgs__(self, processor, serverTransport,
inputTransportFactory, outputTransportFactory,
inputProtocolFactory, outputProtocolFactory):
self.processor = processor
self.serverTransport = serverTransport
self.inputTransportFactory = inputTransportFactory
self.outputTransportFactory = outputTransportFactory
self.inputProtocolFactory = inputProtocolFactory
self.outputProtocolFactory = outputProtocolFactory
input_is_header = isinstance(self.inputProtocolFactory, THeaderProtocolFactory)
output_is_header = isinstance(self.outputProtocolFactory, THeaderProtocolFactory)
if any((input_is_header, output_is_header)) and input_is_header != output_is_header:
raise ValueError("THeaderProtocol servers require that both the input and "
"output protocols are THeaderProtocol.")
def serve(self):
pass
class TSimpleServer(TServer):
"""Simple single-threaded server that just pumps around one transport."""
def __init__(self, *args):
TServer.__init__(self, *args)
def serve(self):
self.serverTransport.listen()
while True:
client = self.serverTransport.accept()
if not client:
continue
itrans = self.inputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
# for THeaderProtocol, we must use the same protocol instance for
# input and output so that the response is in the same dialect that
# the server detected the request was in.
if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
otrans = None
oprot = iprot
else:
otrans = self.outputTransportFactory.getTransport(client)
oprot = self.outputProtocolFactory.getProtocol(otrans)
try:
while True:
self.processor.process(iprot, oprot)
except TTransport.TTransportException:
pass
except Exception as x:
logger.exception(x)
itrans.close()
if otrans:
otrans.close()
class TThreadedServer(TServer):
"""Threaded server that spawns a new thread per each connection."""
def __init__(self, *args, **kwargs):
TServer.__init__(self, *args)
self.daemon = kwargs.get("daemon", False)
def serve(self):
self.serverTransport.listen()
while True:
try:
client = self.serverTransport.accept()
if not client:
continue
t = threading.Thread(target=self.handle, args=(client,))
t.setDaemon(self.daemon)
t.start()
except KeyboardInterrupt:
raise
except Exception as x:
logger.exception(x)
def handle(self, client):
itrans = self.inputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
# for THeaderProtocol, we must use the same protocol instance for input
# and output so that the response is in the same dialect that the
# server detected the request was in.
if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
otrans = None
oprot = iprot
else:
otrans = self.outputTransportFactory.getTransport(client)
oprot = self.outputProtocolFactory.getProtocol(otrans)
try:
while True:
self.processor.process(iprot, oprot)
except TTransport.TTransportException:
pass
except Exception as x:
logger.exception(x)
itrans.close()
if otrans:
otrans.close()
class TThreadPoolServer(TServer):
"""Server with a fixed size pool of threads which service requests."""
def __init__(self, *args, **kwargs):
TServer.__init__(self, *args)
self.clients = queue.Queue()
self.threads = 10
self.daemon = kwargs.get("daemon", False)
def setNumThreads(self, num):
"""Set the number of worker threads that should be created"""
self.threads = num
def serveThread(self):
"""Loop around getting clients from the shared queue and process them."""
while True:
try:
client = self.clients.get()
self.serveClient(client)
except Exception as x:
logger.exception(x)
def serveClient(self, client):
"""Process input/output from a client for as long as possible"""
itrans = self.inputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
# for THeaderProtocol, we must use the same protocol instance for input
# and output so that the response is in the same dialect that the
# server detected the request was in.
if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
otrans = None
oprot = iprot
else:
otrans = self.outputTransportFactory.getTransport(client)
oprot = self.outputProtocolFactory.getProtocol(otrans)
try:
while True:
self.processor.process(iprot, oprot)
except TTransport.TTransportException:
pass
except Exception as x:
logger.exception(x)
itrans.close()
if otrans:
otrans.close()
def serve(self):
"""Start a fixed number of worker threads and put client into a queue"""
for i in range(self.threads):
try:
t = threading.Thread(target=self.serveThread)
t.setDaemon(self.daemon)
t.start()
except Exception as x:
logger.exception(x)
# Pump the socket for clients
self.serverTransport.listen()
while True:
try:
client = self.serverTransport.accept()
if not client:
continue
self.clients.put(client)
except Exception as x:
logger.exception(x)
class TForkingServer(TServer):
"""A Thrift server that forks a new process for each request
This is more scalable than the threaded server as it does not cause
GIL contention.
Note that this has different semantics from the threading server.
Specifically, updates to shared variables will no longer be shared.
It will also not work on windows.
This code is heavily inspired by SocketServer.ForkingMixIn in the
Python stdlib.
"""
def __init__(self, *args):
TServer.__init__(self, *args)
self.children = []
def serve(self):
def try_close(file):
try:
file.close()
except IOError as e:
logger.warning(e, exc_info=True)
self.serverTransport.listen()
while True:
client = self.serverTransport.accept()
if not client:
continue
try:
pid = os.fork()
if pid: # parent
# add before collect, otherwise you race w/ waitpid
self.children.append(pid)
self.collect_children()
# Parent must close socket or the connection may not get
# closed promptly
itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client)
try_close(itrans)
try_close(otrans)
else:
itrans = self.inputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
# for THeaderProtocol, we must use the same protocol
# instance for input and output so that the response is in
# the same dialect that the server detected the request was
# in.
if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
otrans = None
oprot = iprot
else:
otrans = self.outputTransportFactory.getTransport(client)
oprot = self.outputProtocolFactory.getProtocol(otrans)
ecode = 0
try:
try:
while True:
self.processor.process(iprot, oprot)
except TTransport.TTransportException:
pass
except Exception as e:
logger.exception(e)
ecode = 1
finally:
try_close(itrans)
if otrans:
try_close(otrans)
os._exit(ecode)
except TTransport.TTransportException:
pass
except Exception as x:
logger.exception(x)
def collect_children(self):
while self.children:
try:
pid, status = os.waitpid(0, os.WNOHANG)
except os.error:
pid = None
if pid:
self.children.remove(pid)
else:
break
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
__all__ = ['TServer', 'TNonblockingServer']
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import struct
import zlib
from thrift.compat import BufferIO, byte_index
from thrift.protocol.TBinaryProtocol import TBinaryProtocol
from thrift.protocol.TCompactProtocol import TCompactProtocol, readVarint, writeVarint
from thrift.Thrift import TApplicationException
from thrift.transport.TTransport import (
CReadableTransport,
TMemoryBuffer,
TTransportBase,
TTransportException,
)
U16 = struct.Struct("!H")
I32 = struct.Struct("!i")
HEADER_MAGIC = 0x0FFF
HARD_MAX_FRAME_SIZE = 0x3FFFFFFF
class THeaderClientType(object):
HEADERS = 0x00
FRAMED_BINARY = 0x01
UNFRAMED_BINARY = 0x02
FRAMED_COMPACT = 0x03
UNFRAMED_COMPACT = 0x04
class THeaderSubprotocolID(object):
BINARY = 0x00
COMPACT = 0x02
class TInfoHeaderType(object):
KEY_VALUE = 0x01
class THeaderTransformID(object):
ZLIB = 0x01
READ_TRANSFORMS_BY_ID = {
THeaderTransformID.ZLIB: zlib.decompress,
}
WRITE_TRANSFORMS_BY_ID = {
THeaderTransformID.ZLIB: zlib.compress,
}
def _readString(trans):
size = readVarint(trans)
if size < 0:
raise TTransportException(
TTransportException.NEGATIVE_SIZE,
"Negative length"
)
return trans.read(size)
def _writeString(trans, value):
writeVarint(trans, len(value))
trans.write(value)
class THeaderTransport(TTransportBase, CReadableTransport):
def __init__(self, transport, allowed_client_types):
self._transport = transport
self._client_type = THeaderClientType.HEADERS
self._allowed_client_types = allowed_client_types
self._read_buffer = BufferIO(b"")
self._read_headers = {}
self._write_buffer = BufferIO()
self._write_headers = {}
self._write_transforms = []
self.flags = 0
self.sequence_id = 0
self._protocol_id = THeaderSubprotocolID.BINARY
self._max_frame_size = HARD_MAX_FRAME_SIZE
def isOpen(self):
return self._transport.isOpen()
def open(self):
return self._transport.open()
def close(self):
return self._transport.close()
def get_headers(self):
return self._read_headers
def set_header(self, key, value):
if not isinstance(key, bytes):
raise ValueError("header names must be bytes")
if not isinstance(value, bytes):
raise ValueError("header values must be bytes")
self._write_headers[key] = value
def clear_headers(self):
self._write_headers.clear()
def add_transform(self, transform_id):
if transform_id not in WRITE_TRANSFORMS_BY_ID:
raise ValueError("unknown transform")
self._write_transforms.append(transform_id)
def set_max_frame_size(self, size):
if not 0 < size < HARD_MAX_FRAME_SIZE:
raise ValueError("maximum frame size should be < %d and > 0" % HARD_MAX_FRAME_SIZE)
self._max_frame_size = size
@property
def protocol_id(self):
if self._client_type == THeaderClientType.HEADERS:
return self._protocol_id
elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.UNFRAMED_BINARY):
return THeaderSubprotocolID.BINARY
elif self._client_type in (THeaderClientType.FRAMED_COMPACT, THeaderClientType.UNFRAMED_COMPACT):
return THeaderSubprotocolID.COMPACT
else:
raise TTransportException(
TTransportException.INVALID_CLIENT_TYPE,
"Protocol ID not know for client type %d" % self._client_type,
)
def read(self, sz):
# if there are bytes left in the buffer, produce those first.
bytes_read = self._read_buffer.read(sz)
bytes_left_to_read = sz - len(bytes_read)
if bytes_left_to_read == 0:
return bytes_read
# if we've determined this is an unframed client, just pass the read
# through to the underlying transport until we're reset again at the
# beginning of the next message.
if self._client_type in (THeaderClientType.UNFRAMED_BINARY, THeaderClientType.UNFRAMED_COMPACT):
return bytes_read + self._transport.read(bytes_left_to_read)
# we're empty and (maybe) framed. fill the buffers with the next frame.
self.readFrame(bytes_left_to_read)
return bytes_read + self._read_buffer.read(bytes_left_to_read)
def _set_client_type(self, client_type):
if client_type not in self._allowed_client_types:
raise TTransportException(
TTransportException.INVALID_CLIENT_TYPE,
"Client type %d not allowed by server." % client_type,
)
self._client_type = client_type
def readFrame(self, req_sz):
# the first word could either be the length field of a framed message
# or the first bytes of an unframed message.
first_word = self._transport.readAll(I32.size)
frame_size, = I32.unpack(first_word)
is_unframed = False
if frame_size & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1:
self._set_client_type(THeaderClientType.UNFRAMED_BINARY)
is_unframed = True
elif (byte_index(first_word, 0) == TCompactProtocol.PROTOCOL_ID and
byte_index(first_word, 1) & TCompactProtocol.VERSION_MASK == TCompactProtocol.VERSION):
self._set_client_type(THeaderClientType.UNFRAMED_COMPACT)
is_unframed = True
if is_unframed:
bytes_left_to_read = req_sz - I32.size
if bytes_left_to_read > 0:
rest = self._transport.read(bytes_left_to_read)
else:
rest = b""
self._read_buffer = BufferIO(first_word + rest)
return
# ok, we're still here so we're framed.
if frame_size > self._max_frame_size:
raise TTransportException(
TTransportException.SIZE_LIMIT,
"Frame was too large.",
)
read_buffer = BufferIO(self._transport.readAll(frame_size))
# the next word is either going to be the version field of a
# binary/compact protocol message or the magic value + flags of a
# header protocol message.
second_word = read_buffer.read(I32.size)
version, = I32.unpack(second_word)
read_buffer.seek(0)
if version >> 16 == HEADER_MAGIC:
self._set_client_type(THeaderClientType.HEADERS)
self._read_buffer = self._parse_header_format(read_buffer)
elif version & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1:
self._set_client_type(THeaderClientType.FRAMED_BINARY)
self._read_buffer = read_buffer
elif (byte_index(second_word, 0) == TCompactProtocol.PROTOCOL_ID and
byte_index(second_word, 1) & TCompactProtocol.VERSION_MASK == TCompactProtocol.VERSION):
self._set_client_type(THeaderClientType.FRAMED_COMPACT)
self._read_buffer = read_buffer
else:
raise TTransportException(
TTransportException.INVALID_CLIENT_TYPE,
"Could not detect client transport type.",
)
def _parse_header_format(self, buffer):
# make BufferIO look like TTransport for varint helpers
buffer_transport = TMemoryBuffer()
buffer_transport._buffer = buffer
buffer.read(2) # discard the magic bytes
self.flags, = U16.unpack(buffer.read(U16.size))
self.sequence_id, = I32.unpack(buffer.read(I32.size))
header_length = U16.unpack(buffer.read(U16.size))[0] * 4
end_of_headers = buffer.tell() + header_length
if end_of_headers > len(buffer.getvalue()):
raise TTransportException(
TTransportException.SIZE_LIMIT,
"Header size is larger than whole frame.",
)
self._protocol_id = readVarint(buffer_transport)
transforms = []
transform_count = readVarint(buffer_transport)
for _ in range(transform_count):
transform_id = readVarint(buffer_transport)
if transform_id not in READ_TRANSFORMS_BY_ID:
raise TApplicationException(
TApplicationException.INVALID_TRANSFORM,
"Unknown transform: %d" % transform_id,
)
transforms.append(transform_id)
transforms.reverse()
headers = {}
while buffer.tell() < end_of_headers:
header_type = readVarint(buffer_transport)
if header_type == TInfoHeaderType.KEY_VALUE:
count = readVarint(buffer_transport)
for _ in range(count):
key = _readString(buffer_transport)
value = _readString(buffer_transport)
headers[key] = value
else:
break # ignore unknown headers
self._read_headers = headers
# skip padding / anything we didn't understand
buffer.seek(end_of_headers)
payload = buffer.read()
for transform_id in transforms:
transform_fn = READ_TRANSFORMS_BY_ID[transform_id]
payload = transform_fn(payload)
return BufferIO(payload)
def write(self, buf):
self._write_buffer.write(buf)
def flush(self):
payload = self._write_buffer.getvalue()
self._write_buffer = BufferIO()
buffer = BufferIO()
if self._client_type == THeaderClientType.HEADERS:
for transform_id in self._write_transforms:
transform_fn = WRITE_TRANSFORMS_BY_ID[transform_id]
payload = transform_fn(payload)
headers = BufferIO()
writeVarint(headers, self._protocol_id)
writeVarint(headers, len(self._write_transforms))
for transform_id in self._write_transforms:
writeVarint(headers, transform_id)
if self._write_headers:
writeVarint(headers, TInfoHeaderType.KEY_VALUE)
writeVarint(headers, len(self._write_headers))
for key, value in self._write_headers.items():
_writeString(headers, key)
_writeString(headers, value)
self._write_headers = {}
padding_needed = (4 - (len(headers.getvalue()) % 4)) % 4
headers.write(b"\x00" * padding_needed)
header_bytes = headers.getvalue()
buffer.write(I32.pack(10 + len(header_bytes) + len(payload)))
buffer.write(U16.pack(HEADER_MAGIC))
buffer.write(U16.pack(self.flags))
buffer.write(I32.pack(self.sequence_id))
buffer.write(U16.pack(len(header_bytes) // 4))
buffer.write(header_bytes)
buffer.write(payload)
elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.FRAMED_COMPACT):
buffer.write(I32.pack(len(payload)))
buffer.write(payload)
elif self._client_type in (THeaderClientType.UNFRAMED_BINARY, THeaderClientType.UNFRAMED_COMPACT):
buffer.write(payload)
else:
raise TTransportException(
TTransportException.INVALID_CLIENT_TYPE,
"Unknown client type.",
)
# the frame length field doesn't count towards the frame payload size
frame_bytes = buffer.getvalue()
frame_payload_size = len(frame_bytes) - 4
if frame_payload_size > self._max_frame_size:
raise TTransportException(
TTransportException.SIZE_LIMIT,
"Attempting to send frame that is too large.",
)
self._transport.write(frame_bytes)
self._transport.flush()
@property
def cstringio_buf(self):
return self._read_buffer
def cstringio_refill(self, partialread, reqlen):
result = bytearray(partialread)
while len(result) < reqlen:
result += self.read(reqlen - len(result))
self._read_buffer = BufferIO(result)
return self._read_buffer
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from io import BytesIO
import os
import ssl
import sys
import warnings
import base64
from six.moves import urllib
from six.moves import http_client
from .TTransport import TTransportBase
import six
class THttpClient(TTransportBase):
"""Http implementation of TTransport base."""
def __init__(self, uri_or_host, port=None, path=None, cafile=None, cert_file=None, key_file=None, ssl_context=None):
"""THttpClient supports two different types of construction:
THttpClient(host, port, path) - deprecated
THttpClient(uri, [port=<n>, path=<s>, cafile=<filename>, cert_file=<filename>, key_file=<filename>, ssl_context=<context>])
Only the second supports https. To properly authenticate against the server,
provide the client's identity by specifying cert_file and key_file. To properly
authenticate the server, specify either cafile or ssl_context with a CA defined.
NOTE: if both cafile and ssl_context are defined, ssl_context will override cafile.
"""
if port is not None:
warnings.warn(
"Please use the THttpClient('http{s}://host:port/path') constructor",
DeprecationWarning,
stacklevel=2)
self.host = uri_or_host
self.port = port
assert path
self.path = path
self.scheme = 'http'
else:
parsed = urllib.parse.urlparse(uri_or_host)
self.scheme = parsed.scheme
assert self.scheme in ('http', 'https')
if self.scheme == 'http':
self.port = parsed.port or http_client.HTTP_PORT
elif self.scheme == 'https':
self.port = parsed.port or http_client.HTTPS_PORT
self.certfile = cert_file
self.keyfile = key_file
self.context = ssl.create_default_context(cafile=cafile) if (cafile and not ssl_context) else ssl_context
self.host = parsed.hostname
self.path = parsed.path
if parsed.query:
self.path += '?%s' % parsed.query
try:
proxy = urllib.request.getproxies()[self.scheme]
except KeyError:
proxy = None
else:
if urllib.request.proxy_bypass(self.host):
proxy = None
if proxy:
parsed = urllib.parse.urlparse(proxy)
self.realhost = self.host
self.realport = self.port
self.host = parsed.hostname
self.port = parsed.port
self.proxy_auth = self.basic_proxy_auth_header(parsed)
else:
self.realhost = self.realport = self.proxy_auth = None
self.__wbuf = BytesIO()
self.__http = None
self.__http_response = None
self.__timeout = None
self.__custom_headers = None
@staticmethod
def basic_proxy_auth_header(proxy):
if proxy is None or not proxy.username:
return None
ap = "%s:%s" % (urllib.parse.unquote(proxy.username),
urllib.parse.unquote(proxy.password))
cr = base64.b64encode(ap).strip()
return "Basic " + cr
def using_proxy(self):
return self.realhost is not None
def open(self):
if self.scheme == 'http':
self.__http = http_client.HTTPConnection(self.host, self.port,
timeout=self.__timeout)
elif self.scheme == 'https':
self.__http = http_client.HTTPSConnection(self.host, self.port,
key_file=self.keyfile,
cert_file=self.certfile,
timeout=self.__timeout,
context=self.context)
if self.using_proxy():
self.__http.set_tunnel(self.realhost, self.realport,
{"Proxy-Authorization": self.proxy_auth})
def close(self):
self.__http.close()
self.__http = None
self.__http_response = None
def isOpen(self):
return self.__http is not None
def setTimeout(self, ms):
if ms is None:
self.__timeout = None
else:
self.__timeout = ms / 1000.0
def setCustomHeaders(self, headers):
self.__custom_headers = headers
def read(self, sz):
return self.__http_response.read(sz)
def write(self, buf):
self.__wbuf.write(buf)
def flush(self):
if self.isOpen():
self.close()
self.open()
# Pull data out of buffer
data = self.__wbuf.getvalue()
self.__wbuf = BytesIO()
# HTTP request
if self.using_proxy() and self.scheme == "http":
# need full URL of real host for HTTP proxy here (HTTPS uses CONNECT tunnel)
self.__http.putrequest('POST', "http://%s:%s%s" %
(self.realhost, self.realport, self.path))
else:
self.__http.putrequest('POST', self.path)
# Write headers
self.__http.putheader('Content-Type', 'application/x-thrift')
self.__http.putheader('Content-Length', str(len(data)))
if self.using_proxy() and self.scheme == "http" and self.proxy_auth is not None:
self.__http.putheader("Proxy-Authorization", self.proxy_auth)
if not self.__custom_headers or 'User-Agent' not in self.__custom_headers:
user_agent = 'Python/THttpClient'
script = os.path.basename(sys.argv[0])
if script:
user_agent = '%s (%s)' % (user_agent, urllib.parse.quote(script))
self.__http.putheader('User-Agent', user_agent)
if self.__custom_headers:
for key, val in six.iteritems(self.__custom_headers):
self.__http.putheader(key, val)
self.__http.endheaders()
# Write payload
self.__http.send(data)
# Get reply to flush the request
self.__http_response = self.__http.getresponse()
self.code = self.__http_response.status
self.message = self.__http_response.reason
self.headers = self.__http_response.msg
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import logging
import os
import socket
import ssl
import sys
import warnings
from .sslcompat import _match_hostname, _match_has_ipaddress
from thrift.transport import TSocket
from thrift.transport.TTransport import TTransportException
logger = logging.getLogger(__name__)
warnings.filterwarnings(
'default', category=DeprecationWarning, module=__name__)
class TSSLBase(object):
# SSLContext is not available for Python < 2.7.9
_has_ssl_context = sys.hexversion >= 0x020709F0
# ciphers argument is not available for Python < 2.7.0
_has_ciphers = sys.hexversion >= 0x020700F0
# For python >= 2.7.9, use latest TLS that both client and server
# supports.
# SSL 2.0 and 3.0 are disabled via ssl.OP_NO_SSLv2 and ssl.OP_NO_SSLv3.
# For python < 2.7.9, use TLS 1.0 since TLSv1_X nor OP_NO_SSLvX is
# unavailable.
_default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else \
ssl.PROTOCOL_TLSv1
def _init_context(self, ssl_version):
if self._has_ssl_context:
self._context = ssl.SSLContext(ssl_version)
if self._context.protocol == ssl.PROTOCOL_SSLv23:
self._context.options |= ssl.OP_NO_SSLv2
self._context.options |= ssl.OP_NO_SSLv3
else:
self._context = None
self._ssl_version = ssl_version
@property
def _should_verify(self):
if self._has_ssl_context:
return self._context.verify_mode != ssl.CERT_NONE
else:
return self.cert_reqs != ssl.CERT_NONE
@property
def ssl_version(self):
if self._has_ssl_context:
return self.ssl_context.protocol
else:
return self._ssl_version
@property
def ssl_context(self):
return self._context
SSL_VERSION = _default_protocol
"""
Default SSL version.
For backwards compatibility, it can be modified.
Use __init__ keyword argument "ssl_version" instead.
"""
def _deprecated_arg(self, args, kwargs, pos, key):
if len(args) <= pos:
return
real_pos = pos + 3
warnings.warn(
'%dth positional argument is deprecated.'
'please use keyword argument instead.'
% real_pos, DeprecationWarning, stacklevel=3)
if key in kwargs:
raise TypeError(
'Duplicate argument: %dth argument and %s keyword argument.'
% (real_pos, key))
kwargs[key] = args[pos]
def _unix_socket_arg(self, host, port, args, kwargs):
key = 'unix_socket'
if host is None and port is None and len(args) == 1 and key not in kwargs:
kwargs[key] = args[0]
return True
return False
def __getattr__(self, key):
if key == 'SSL_VERSION':
warnings.warn(
'SSL_VERSION is deprecated.'
'please use ssl_version attribute instead.',
DeprecationWarning, stacklevel=2)
return self.ssl_version
def __init__(self, server_side, host, ssl_opts):
self._server_side = server_side
if TSSLBase.SSL_VERSION != self._default_protocol:
warnings.warn(
'SSL_VERSION is deprecated.'
'please use ssl_version keyword argument instead.',
DeprecationWarning, stacklevel=2)
self._context = ssl_opts.pop('ssl_context', None)
self._server_hostname = None
if not self._server_side:
self._server_hostname = ssl_opts.pop('server_hostname', host)
if self._context:
self._custom_context = True
if ssl_opts:
raise ValueError(
'Incompatible arguments: ssl_context and %s'
% ' '.join(ssl_opts.keys()))
if not self._has_ssl_context:
raise ValueError(
'ssl_context is not available for this version of Python')
else:
self._custom_context = False
ssl_version = ssl_opts.pop('ssl_version', TSSLBase.SSL_VERSION)
self._init_context(ssl_version)
self.cert_reqs = ssl_opts.pop('cert_reqs', ssl.CERT_REQUIRED)
self.ca_certs = ssl_opts.pop('ca_certs', None)
self.keyfile = ssl_opts.pop('keyfile', None)
self.certfile = ssl_opts.pop('certfile', None)
self.ciphers = ssl_opts.pop('ciphers', None)
if ssl_opts:
raise ValueError(
'Unknown keyword arguments: ', ' '.join(ssl_opts.keys()))
if self._should_verify:
if not self.ca_certs:
raise ValueError(
'ca_certs is needed when cert_reqs is not ssl.CERT_NONE')
if not os.access(self.ca_certs, os.R_OK):
raise IOError('Certificate Authority ca_certs file "%s" '
'is not readable, cannot validate SSL '
'certificates.' % (self.ca_certs))
@property
def certfile(self):
return self._certfile
@certfile.setter
def certfile(self, certfile):
if self._server_side and not certfile:
raise ValueError('certfile is needed for server-side')
if certfile and not os.access(certfile, os.R_OK):
raise IOError('No such certfile found: %s' % (certfile))
self._certfile = certfile
def _wrap_socket(self, sock):
if self._has_ssl_context:
if not self._custom_context:
self.ssl_context.verify_mode = self.cert_reqs
if self.certfile:
self.ssl_context.load_cert_chain(self.certfile,
self.keyfile)
if self.ciphers:
self.ssl_context.set_ciphers(self.ciphers)
if self.ca_certs:
self.ssl_context.load_verify_locations(self.ca_certs)
return self.ssl_context.wrap_socket(
sock, server_side=self._server_side,
server_hostname=self._server_hostname)
else:
ssl_opts = {
'ssl_version': self._ssl_version,
'server_side': self._server_side,
'ca_certs': self.ca_certs,
'keyfile': self.keyfile,
'certfile': self.certfile,
'cert_reqs': self.cert_reqs,
}
if self.ciphers:
if self._has_ciphers:
ssl_opts['ciphers'] = self.ciphers
else:
logger.warning(
'ciphers is specified but ignored due to old Python version')
return ssl.wrap_socket(sock, **ssl_opts)
class TSSLSocket(TSocket.TSocket, TSSLBase):
"""
SSL implementation of TSocket
This class creates outbound sockets wrapped using the
python standard ssl module for encrypted connections.
"""
# New signature
# def __init__(self, host='localhost', port=9090, unix_socket=None,
# **ssl_args):
# Deprecated signature
# def __init__(self, host='localhost', port=9090, validate=True,
# ca_certs=None, keyfile=None, certfile=None,
# unix_socket=None, ciphers=None):
def __init__(self, host='localhost', port=9090, *args, **kwargs):
"""Positional arguments: ``host``, ``port``, ``unix_socket``
Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``,
``ssl_version``, ``ca_certs``,
``ciphers`` (Python 2.7.0 or later),
``server_hostname`` (Python 2.7.9 or later)
Passed to ssl.wrap_socket. See ssl.wrap_socket documentation.
Alternative keyword arguments: (Python 2.7.9 or later)
``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
``server_hostname``: Passed to SSLContext.wrap_socket
Common keyword argument:
``validate_callback`` (cert, hostname) -> None:
Called after SSL handshake. Can raise when hostname does not
match the cert.
``socket_keepalive`` enable TCP keepalive, default off.
"""
self.is_valid = False
self.peercert = None
if args:
if len(args) > 6:
raise TypeError('Too many positional argument')
if not self._unix_socket_arg(host, port, args, kwargs):
self._deprecated_arg(args, kwargs, 0, 'validate')
self._deprecated_arg(args, kwargs, 1, 'ca_certs')
self._deprecated_arg(args, kwargs, 2, 'keyfile')
self._deprecated_arg(args, kwargs, 3, 'certfile')
self._deprecated_arg(args, kwargs, 4, 'unix_socket')
self._deprecated_arg(args, kwargs, 5, 'ciphers')
validate = kwargs.pop('validate', None)
if validate is not None:
cert_reqs_name = 'CERT_REQUIRED' if validate else 'CERT_NONE'
warnings.warn(
'validate is deprecated. please use cert_reqs=ssl.%s instead'
% cert_reqs_name,
DeprecationWarning, stacklevel=2)
if 'cert_reqs' in kwargs:
raise TypeError('Cannot specify both validate and cert_reqs')
kwargs['cert_reqs'] = ssl.CERT_REQUIRED if validate else ssl.CERT_NONE
unix_socket = kwargs.pop('unix_socket', None)
socket_keepalive = kwargs.pop('socket_keepalive', False)
self._validate_callback = kwargs.pop('validate_callback', _match_hostname)
TSSLBase.__init__(self, False, host, kwargs)
TSocket.TSocket.__init__(self, host, port, unix_socket,
socket_keepalive=socket_keepalive)
def close(self):
try:
self.handle.settimeout(0.001)
self.handle = self.handle.unwrap()
except (ssl.SSLError, socket.error, OSError):
# could not complete shutdown in a reasonable amount of time. bail.
pass
TSocket.TSocket.close(self)
@property
def validate(self):
warnings.warn('validate is deprecated. please use cert_reqs instead',
DeprecationWarning, stacklevel=2)
return self.cert_reqs != ssl.CERT_NONE
@validate.setter
def validate(self, value):
warnings.warn('validate is deprecated. please use cert_reqs instead',
DeprecationWarning, stacklevel=2)
self.cert_reqs = ssl.CERT_REQUIRED if value else ssl.CERT_NONE
def _do_open(self, family, socktype):
plain_sock = socket.socket(family, socktype)
try:
return self._wrap_socket(plain_sock)
except Exception as ex:
plain_sock.close()
msg = 'failed to initialize SSL'
logger.exception(msg)
raise TTransportException(type=TTransportException.NOT_OPEN, message=msg, inner=ex)
def open(self):
super(TSSLSocket, self).open()
if self._should_verify:
self.peercert = self.handle.getpeercert()
try:
self._validate_callback(self.peercert, self._server_hostname)
self.is_valid = True
except TTransportException:
raise
except Exception as ex:
raise TTransportException(message=str(ex), inner=ex)
class TSSLServerSocket(TSocket.TServerSocket, TSSLBase):
"""SSL implementation of TServerSocket
This uses the ssl module's wrap_socket() method to provide SSL
negotiated encryption.
"""
# New signature
# def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args):
# Deprecated signature
# def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
def __init__(self, host=None, port=9090, *args, **kwargs):
"""Positional arguments: ``host``, ``port``, ``unix_socket``
Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
``ca_certs``, ``ciphers`` (Python 2.7.0 or later)
See ssl.wrap_socket documentation.
Alternative keyword arguments: (Python 2.7.9 or later)
``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
``server_hostname``: Passed to SSLContext.wrap_socket
Common keyword argument:
``validate_callback`` (cert, hostname) -> None:
Called after SSL handshake. Can raise when hostname does not
match the cert.
"""
if args:
if len(args) > 3:
raise TypeError('Too many positional argument')
if not self._unix_socket_arg(host, port, args, kwargs):
self._deprecated_arg(args, kwargs, 0, 'certfile')
self._deprecated_arg(args, kwargs, 1, 'unix_socket')
self._deprecated_arg(args, kwargs, 2, 'ciphers')
if 'ssl_context' not in kwargs:
# Preserve existing behaviors for default values
if 'cert_reqs' not in kwargs:
kwargs['cert_reqs'] = ssl.CERT_NONE
if'certfile' not in kwargs:
kwargs['certfile'] = 'cert.pem'
unix_socket = kwargs.pop('unix_socket', None)
self._validate_callback = \
kwargs.pop('validate_callback', _match_hostname)
TSSLBase.__init__(self, True, None, kwargs)
TSocket.TServerSocket.__init__(self, host, port, unix_socket)
if self._should_verify and not _match_has_ipaddress:
raise ValueError('Need ipaddress and backports.ssl_match_hostname '
'module to verify client certificate')
def setCertfile(self, certfile):
"""Set or change the server certificate file used to wrap new
connections.
@param certfile: The filename of the server certificate,
i.e. '/etc/certs/server.pem'
@type certfile: str
Raises an IOError exception if the certfile is not present or unreadable.
"""
warnings.warn(
'setCertfile is deprecated. please use certfile property instead.',
DeprecationWarning, stacklevel=2)
self.certfile = certfile
def accept(self):
plain_client, addr = self.handle.accept()
try:
client = self._wrap_socket(plain_client)
except (ssl.SSLError, socket.error, OSError):
logger.exception('Error while accepting from %s', addr)
# failed handshake/ssl wrap, close socket to client
plain_client.close()
# raise
# We can't raise the exception, because it kills most TServer derived
# serve() methods.
# Instead, return None, and let the TServer instance deal with it in
# other exception handling. (but TSimpleServer dies anyway)
return None
if self._should_verify:
client.peercert = client.getpeercert()
try:
self._validate_callback(client.peercert, addr[0])
client.is_valid = True
except Exception:
logger.warn('Failed to validate client certificate address: %s',
addr[0], exc_info=True)
client.close()
plain_client.close()
return None
result = TSocket.TSocket()
result.handle = client
return result
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import errno
import logging
import os
import socket
import sys
from .TTransport import TTransportBase, TTransportException, TServerTransportBase
logger = logging.getLogger(__name__)
class TSocketBase(TTransportBase):
def _resolveAddr(self):
if self._unix_socket is not None:
return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None,
self._unix_socket)]
else:
return socket.getaddrinfo(self.host,
self.port,
self._socket_family,
socket.SOCK_STREAM,
0,
socket.AI_PASSIVE | socket.AI_ADDRCONFIG)
def close(self):
if self.handle:
self.handle.close()
self.handle = None
class TSocket(TSocketBase):
"""Socket implementation of TTransport base."""
def __init__(self, host='localhost', port=9090, unix_socket=None,
socket_family=socket.AF_UNSPEC,
socket_keepalive=False):
"""Initialize a TSocket
@param host(str) The host to connect to.
@param port(int) The (TCP) port to connect to.
@param unix_socket(str) The filename of a unix socket to connect to.
(host and port will be ignored.)
@param socket_family(int) The socket family to use with this socket.
@param socket_keepalive(bool) enable TCP keepalive, default off.
"""
self.host = host
self.port = port
self.handle = None
self._unix_socket = unix_socket
self._timeout = None
self._socket_family = socket_family
self._socket_keepalive = socket_keepalive
def setHandle(self, h):
self.handle = h
def isOpen(self):
return self.handle is not None
def setTimeout(self, ms):
if ms is None:
self._timeout = None
else:
self._timeout = ms / 1000.0
if self.handle is not None:
self.handle.settimeout(self._timeout)
def _do_open(self, family, socktype):
return socket.socket(family, socktype)
@property
def _address(self):
return self._unix_socket if self._unix_socket else '%s:%d' % (self.host, self.port)
def open(self):
if self.handle:
raise TTransportException(type=TTransportException.ALREADY_OPEN, message="already open")
try:
addrs = self._resolveAddr()
except socket.gaierror as gai:
msg = 'failed to resolve sockaddr for ' + str(self._address)
logger.exception(msg)
raise TTransportException(type=TTransportException.NOT_OPEN, message=msg, inner=gai)
for family, socktype, _, _, sockaddr in addrs:
handle = self._do_open(family, socktype)
# TCP_KEEPALIVE
if self._socket_keepalive:
handle.setsockopt(socket.IPPROTO_TCP, socket.SO_KEEPALIVE, 1)
handle.settimeout(self._timeout)
try:
handle.connect(sockaddr)
self.handle = handle
return
except socket.error:
handle.close()
logger.info('Could not connect to %s', sockaddr, exc_info=True)
msg = 'Could not connect to any of %s' % list(map(lambda a: a[4],
addrs))
logger.error(msg)
raise TTransportException(type=TTransportException.NOT_OPEN, message=msg)
def read(self, sz):
try:
buff = self.handle.recv(sz)
except socket.error as e:
if (e.args[0] == errno.ECONNRESET and
(sys.platform == 'darwin' or sys.platform.startswith('freebsd'))):
# freebsd and Mach don't follow POSIX semantic of recv
# and fail with ECONNRESET if peer performed shutdown.
# See corresponding comment and code in TSocket::read()
# in lib/cpp/src/transport/TSocket.cpp.
self.close()
# Trigger the check to raise the END_OF_FILE exception below.
buff = ''
elif e.args[0] == errno.ETIMEDOUT:
raise TTransportException(type=TTransportException.TIMED_OUT, message="read timeout", inner=e)
else:
raise TTransportException(message="unexpected exception", inner=e)
if len(buff) == 0:
raise TTransportException(type=TTransportException.END_OF_FILE,
message='TSocket read 0 bytes')
return buff
def write(self, buff):
if not self.handle:
raise TTransportException(type=TTransportException.NOT_OPEN,
message='Transport not open')
sent = 0
have = len(buff)
while sent < have:
try:
plus = self.handle.send(buff)
if plus == 0:
raise TTransportException(type=TTransportException.END_OF_FILE,
message='TSocket sent 0 bytes')
sent += plus
buff = buff[plus:]
except socket.error as e:
raise TTransportException(message="unexpected exception", inner=e)
def flush(self):
pass
class TServerSocket(TSocketBase, TServerTransportBase):
"""Socket implementation of TServerTransport base."""
def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
self.host = host
self.port = port
self._unix_socket = unix_socket
self._socket_family = socket_family
self.handle = None
self._backlog = 128
def setBacklog(self, backlog=None):
if not self.handle:
self._backlog = backlog
else:
# We cann't update backlog when it is already listening, since the
# handle has been created.
logger.warn('You have to set backlog before listen.')
def listen(self):
res0 = self._resolveAddr()
socket_family = self._socket_family == socket.AF_UNSPEC and socket.AF_INET6 or self._socket_family
for res in res0:
if res[0] is socket_family or res is res0[-1]:
break
# We need remove the old unix socket if the file exists and
# nobody is listening on it.
if self._unix_socket:
tmp = socket.socket(res[0], res[1])
try:
tmp.connect(res[4])
except socket.error as err:
eno, message = err.args
if eno == errno.ECONNREFUSED:
os.unlink(res[4])
self.handle = socket.socket(res[0], res[1])
self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if hasattr(self.handle, 'settimeout'):
self.handle.settimeout(None)
self.handle.bind(res[4])
self.handle.listen(self._backlog)
def accept(self):
client, addr = self.handle.accept()
result = TSocket()
result.setHandle(client)
return result
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from struct import pack, unpack
from thrift.Thrift import TException
from ..compat import BufferIO
class TTransportException(TException):
"""Custom Transport Exception class"""
UNKNOWN = 0
NOT_OPEN = 1
ALREADY_OPEN = 2
TIMED_OUT = 3
END_OF_FILE = 4
NEGATIVE_SIZE = 5
SIZE_LIMIT = 6
INVALID_CLIENT_TYPE = 7
def __init__(self, type=UNKNOWN, message=None, inner=None):
TException.__init__(self, message)
self.type = type
self.inner = inner
class TTransportBase(object):
"""Base class for Thrift transport layer."""
def isOpen(self):
pass
def open(self):
pass
def close(self):
pass
def read(self, sz):
pass
def readAll(self, sz):
buff = b''
have = 0
while (have < sz):
chunk = self.read(sz - have)
chunkLen = len(chunk)
have += chunkLen
buff += chunk
if chunkLen == 0:
raise EOFError()
return buff
def write(self, buf):
pass
def flush(self):
pass
# This class should be thought of as an interface.
class CReadableTransport(object):
"""base class for transports that are readable from C"""
# TODO(dreiss): Think about changing this interface to allow us to use
# a (Python, not c) StringIO instead, because it allows
# you to write after reading.
# NOTE: This is a classic class, so properties will NOT work
# correctly for setting.
@property
def cstringio_buf(self):
"""A cStringIO buffer that contains the current chunk we are reading."""
pass
def cstringio_refill(self, partialread, reqlen):
"""Refills cstringio_buf.
Returns the currently used buffer (which can but need not be the same as
the old cstringio_buf). partialread is what the C code has read from the
buffer, and should be inserted into the buffer before any more reads. The
return value must be a new, not borrowed reference. Something along the
lines of self._buf should be fine.
If reqlen bytes can't be read, throw EOFError.
"""
pass
class TServerTransportBase(object):
"""Base class for Thrift server transports."""
def listen(self):
pass
def accept(self):
pass
def close(self):
pass
class TTransportFactoryBase(object):
"""Base class for a Transport Factory"""
def getTransport(self, trans):
return trans
class TBufferedTransportFactory(object):
"""Factory transport that builds buffered transports"""
def getTransport(self, trans):
buffered = TBufferedTransport(trans)
return buffered
class TBufferedTransport(TTransportBase, CReadableTransport):
"""Class that wraps another transport and buffers its I/O.
The implementation uses a (configurable) fixed-size read buffer
but buffers all writes until a flush is performed.
"""
DEFAULT_BUFFER = 4096
def __init__(self, trans, rbuf_size=DEFAULT_BUFFER):
self.__trans = trans
self.__wbuf = BufferIO()
# Pass string argument to initialize read buffer as cStringIO.InputType
self.__rbuf = BufferIO(b'')
self.__rbuf_size = rbuf_size
def isOpen(self):
return self.__trans.isOpen()
def open(self):
return self.__trans.open()
def close(self):
return self.__trans.close()
def read(self, sz):
ret = self.__rbuf.read(sz)
if len(ret) != 0:
return ret
self.__rbuf = BufferIO(self.__trans.read(max(sz, self.__rbuf_size)))
return self.__rbuf.read(sz)
def write(self, buf):
try:
self.__wbuf.write(buf)
except Exception as e:
# on exception reset wbuf so it doesn't contain a partial function call
self.__wbuf = BufferIO()
raise e
def flush(self):
out = self.__wbuf.getvalue()
# reset wbuf before write/flush to preserve state on underlying failure
self.__wbuf = BufferIO()
self.__trans.write(out)
self.__trans.flush()
# Implement the CReadableTransport interface.
@property
def cstringio_buf(self):
return self.__rbuf
def cstringio_refill(self, partialread, reqlen):
retstring = partialread
if reqlen < self.__rbuf_size:
# try to make a read of as much as we can.
retstring += self.__trans.read(self.__rbuf_size)
# but make sure we do read reqlen bytes.
if len(retstring) < reqlen:
retstring += self.__trans.readAll(reqlen - len(retstring))
self.__rbuf = BufferIO(retstring)
return self.__rbuf
class TMemoryBuffer(TTransportBase, CReadableTransport):
"""Wraps a cBytesIO object as a TTransport.
NOTE: Unlike the C++ version of this class, you cannot write to it
then immediately read from it. If you want to read from a
TMemoryBuffer, you must either pass a string to the constructor.
TODO(dreiss): Make this work like the C++ version.
"""
def __init__(self, value=None, offset=0):
"""value -- a value to read from for stringio
If value is set, this will be a transport for reading,
otherwise, it is for writing"""
if value is not None:
self._buffer = BufferIO(value)
else:
self._buffer = BufferIO()
if offset:
self._buffer.seek(offset)
def isOpen(self):
return not self._buffer.closed
def open(self):
pass
def close(self):
self._buffer.close()
def read(self, sz):
return self._buffer.read(sz)
def write(self, buf):
self._buffer.write(buf)
def flush(self):
pass
def getvalue(self):
return self._buffer.getvalue()
# Implement the CReadableTransport interface.
@property
def cstringio_buf(self):
return self._buffer
def cstringio_refill(self, partialread, reqlen):
# only one shot at reading...
raise EOFError()
class TFramedTransportFactory(object):
"""Factory transport that builds framed transports"""
def getTransport(self, trans):
framed = TFramedTransport(trans)
return framed
class TFramedTransport(TTransportBase, CReadableTransport):
"""Class that wraps another transport and frames its I/O when writing."""
def __init__(self, trans,):
self.__trans = trans
self.__rbuf = BufferIO(b'')
self.__wbuf = BufferIO()
def isOpen(self):
return self.__trans.isOpen()
def open(self):
return self.__trans.open()
def close(self):
return self.__trans.close()
def read(self, sz):
ret = self.__rbuf.read(sz)
if len(ret) != 0:
return ret
self.readFrame()
return self.__rbuf.read(sz)
def readFrame(self):
buff = self.__trans.readAll(4)
sz, = unpack('!i', buff)
self.__rbuf = BufferIO(self.__trans.readAll(sz))
def write(self, buf):
self.__wbuf.write(buf)
def flush(self):
wout = self.__wbuf.getvalue()
wsz = len(wout)
# reset wbuf before write/flush to preserve state on underlying failure
self.__wbuf = BufferIO()
# N.B.: Doing this string concatenation is WAY cheaper than making
# two separate calls to the underlying socket object. Socket writes in
# Python turn out to be REALLY expensive, but it seems to do a pretty
# good job of managing string buffer operations without excessive copies
buf = pack("!i", wsz) + wout
self.__trans.write(buf)
self.__trans.flush()
# Implement the CReadableTransport interface.
@property
def cstringio_buf(self):
return self.__rbuf
def cstringio_refill(self, prefix, reqlen):
# self.__rbuf will already be empty here because fastbinary doesn't
# ask for a refill until the previous buffer is empty. Therefore,
# we can start reading new frames immediately.
while len(prefix) < reqlen:
self.readFrame()
prefix += self.__rbuf.getvalue()
self.__rbuf = BufferIO(prefix)
return self.__rbuf
class TFileObjectTransport(TTransportBase):
"""Wraps a file-like object to make it work as a Thrift transport."""
def __init__(self, fileobj):
self.fileobj = fileobj
def isOpen(self):
return True
def close(self):
self.fileobj.close()
def read(self, sz):
return self.fileobj.read(sz)
def write(self, buf):
self.fileobj.write(buf)
def flush(self):
self.fileobj.flush()
class TSaslClientTransport(TTransportBase, CReadableTransport):
"""
SASL transport
"""
START = 1
OK = 2
BAD = 3
ERROR = 4
COMPLETE = 5
def __init__(self, transport, host, service, mechanism='GSSAPI',
**sasl_kwargs):
"""
transport: an underlying transport to use, typically just a TSocket
host: the name of the server, from a SASL perspective
service: the name of the server's service, from a SASL perspective
mechanism: the name of the preferred mechanism to use
All other kwargs will be passed to the puresasl.client.SASLClient
constructor.
"""
from puresasl.client import SASLClient
self.transport = transport
self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
self.__wbuf = BufferIO()
self.__rbuf = BufferIO(b'')
def open(self):
if not self.transport.isOpen():
self.transport.open()
self.send_sasl_msg(self.START, bytes(self.sasl.mechanism, 'ascii'))
self.send_sasl_msg(self.OK, self.sasl.process())
while True:
status, challenge = self.recv_sasl_msg()
if status == self.OK:
self.send_sasl_msg(self.OK, self.sasl.process(challenge))
elif status == self.COMPLETE:
if not self.sasl.complete:
raise TTransportException(
TTransportException.NOT_OPEN,
"The server erroneously indicated "
"that SASL negotiation was complete")
else:
break
else:
raise TTransportException(
TTransportException.NOT_OPEN,
"Bad SASL negotiation status: %d (%s)"
% (status, challenge))
def send_sasl_msg(self, status, body):
header = pack(">BI", status, len(body))
self.transport.write(header + body)
self.transport.flush()
def recv_sasl_msg(self):
header = self.transport.readAll(5)
status, length = unpack(">BI", header)
if length > 0:
payload = self.transport.readAll(length)
else:
payload = ""
return status, payload
def write(self, data):
self.__wbuf.write(data)
def flush(self):
data = self.__wbuf.getvalue()
encoded = self.sasl.wrap(data)
self.transport.write(pack("!i", len(encoded)) + encoded)
self.transport.flush()
self.__wbuf = BufferIO()
def read(self, sz):
ret = self.__rbuf.read(sz)
if len(ret) != 0:
return ret
self._read_frame()
return self.__rbuf.read(sz)
def _read_frame(self):
header = self.transport.readAll(4)
length, = unpack('!i', header)
encoded = self.transport.readAll(length)
self.__rbuf = BufferIO(self.sasl.unwrap(encoded))
def close(self):
self.sasl.dispose()
self.transport.close()
# based on TFramedTransport
@property
def cstringio_buf(self):
return self.__rbuf
def cstringio_refill(self, prefix, reqlen):
# self.__rbuf will already be empty here because fastbinary doesn't
# ask for a refill until the previous buffer is empty. Therefore,
# we can start reading new frames immediately.
while len(prefix) < reqlen:
self._read_frame()
prefix += self.__rbuf.getvalue()
self.__rbuf = BufferIO(prefix)
return self.__rbuf
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from io import BytesIO
import struct
from zope.interface import implementer, Interface, Attribute
from twisted.internet.protocol import ServerFactory, ClientFactory, \
connectionDone
from twisted.internet import defer
from twisted.internet.threads import deferToThread
from twisted.protocols import basic
from twisted.web import server, resource, http
from thrift.transport import TTransport
class TMessageSenderTransport(TTransport.TTransportBase):
def __init__(self):
self.__wbuf = BytesIO()
def write(self, buf):
self.__wbuf.write(buf)
def flush(self):
msg = self.__wbuf.getvalue()
self.__wbuf = BytesIO()
return self.sendMessage(msg)
def sendMessage(self, message):
raise NotImplementedError
class TCallbackTransport(TMessageSenderTransport):
def __init__(self, func):
TMessageSenderTransport.__init__(self)
self.func = func
def sendMessage(self, message):
return self.func(message)
class ThriftClientProtocol(basic.Int32StringReceiver):
MAX_LENGTH = 2 ** 31 - 1
def __init__(self, client_class, iprot_factory, oprot_factory=None):
self._client_class = client_class
self._iprot_factory = iprot_factory
if oprot_factory is None:
self._oprot_factory = iprot_factory
else:
self._oprot_factory = oprot_factory
self.recv_map = {}
self.started = defer.Deferred()
def dispatch(self, msg):
self.sendString(msg)
def connectionMade(self):
tmo = TCallbackTransport(self.dispatch)
self.client = self._client_class(tmo, self._oprot_factory)
self.started.callback(self.client)
def connectionLost(self, reason=connectionDone):
# the called errbacks can add items to our client's _reqs,
# so we need to use a tmp, and iterate until no more requests
# are added during errbacks
if self.client:
tex = TTransport.TTransportException(
type=TTransport.TTransportException.END_OF_FILE,
message='Connection closed (%s)' % reason)
while self.client._reqs:
_, v = self.client._reqs.popitem()
v.errback(tex)
del self.client._reqs
self.client = None
def stringReceived(self, frame):
tr = TTransport.TMemoryBuffer(frame)
iprot = self._iprot_factory.getProtocol(tr)
(fname, mtype, rseqid) = iprot.readMessageBegin()
try:
method = self.recv_map[fname]
except KeyError:
method = getattr(self.client, 'recv_' + fname)
self.recv_map[fname] = method
method(iprot, mtype, rseqid)
class ThriftSASLClientProtocol(ThriftClientProtocol):
START = 1
OK = 2
BAD = 3
ERROR = 4
COMPLETE = 5
MAX_LENGTH = 2 ** 31 - 1
def __init__(self, client_class, iprot_factory, oprot_factory=None,
host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
"""
host: the name of the server, from a SASL perspective
service: the name of the server's service, from a SASL perspective
mechanism: the name of the preferred mechanism to use
All other kwargs will be passed to the puresasl.client.SASLClient
constructor.
"""
from puresasl.client import SASLClient
self.SASLCLient = SASLClient
ThriftClientProtocol.__init__(self, client_class, iprot_factory, oprot_factory)
self._sasl_negotiation_deferred = None
self._sasl_negotiation_status = None
self.client = None
if host is not None:
self.createSASLClient(host, service, mechanism, **sasl_kwargs)
def createSASLClient(self, host, service, mechanism, **kwargs):
self.sasl = self.SASLClient(host, service, mechanism, **kwargs)
def dispatch(self, msg):
encoded = self.sasl.wrap(msg)
len_and_encoded = ''.join((struct.pack('!i', len(encoded)), encoded))
ThriftClientProtocol.dispatch(self, len_and_encoded)
@defer.inlineCallbacks
def connectionMade(self):
self._sendSASLMessage(self.START, self.sasl.mechanism)
initial_message = yield deferToThread(self.sasl.process)
self._sendSASLMessage(self.OK, initial_message)
while True:
status, challenge = yield self._receiveSASLMessage()
if status == self.OK:
response = yield deferToThread(self.sasl.process, challenge)
self._sendSASLMessage(self.OK, response)
elif status == self.COMPLETE:
if not self.sasl.complete:
msg = "The server erroneously indicated that SASL " \
"negotiation was complete"
raise TTransport.TTransportException(msg, message=msg)
else:
break
else:
msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge)
raise TTransport.TTransportException(msg, message=msg)
self._sasl_negotiation_deferred = None
ThriftClientProtocol.connectionMade(self)
def _sendSASLMessage(self, status, body):
if body is None:
body = ""
header = struct.pack(">BI", status, len(body))
self.transport.write(header + body)
def _receiveSASLMessage(self):
self._sasl_negotiation_deferred = defer.Deferred()
self._sasl_negotiation_status = None
return self._sasl_negotiation_deferred
def connectionLost(self, reason=connectionDone):
if self.client:
ThriftClientProtocol.connectionLost(self, reason)
def dataReceived(self, data):
if self._sasl_negotiation_deferred:
# we got a sasl challenge in the format (status, length, challenge)
# save the status, let IntNStringReceiver piece the challenge data together
self._sasl_negotiation_status, = struct.unpack("B", data[0])
ThriftClientProtocol.dataReceived(self, data[1:])
else:
# normal frame, let IntNStringReceiver piece it together
ThriftClientProtocol.dataReceived(self, data)
def stringReceived(self, frame):
if self._sasl_negotiation_deferred:
# the frame is just a SASL challenge
response = (self._sasl_negotiation_status, frame)
self._sasl_negotiation_deferred.callback(response)
else:
# there's a second 4 byte length prefix inside the frame
decoded_frame = self.sasl.unwrap(frame[4:])
ThriftClientProtocol.stringReceived(self, decoded_frame)
class ThriftServerProtocol(basic.Int32StringReceiver):
MAX_LENGTH = 2 ** 31 - 1
def dispatch(self, msg):
self.sendString(msg)
def processError(self, error):
self.transport.loseConnection()
def processOk(self, _, tmo):
msg = tmo.getvalue()
if len(msg) > 0:
self.dispatch(msg)
def stringReceived(self, frame):
tmi = TTransport.TMemoryBuffer(frame)
tmo = TTransport.TMemoryBuffer()
iprot = self.factory.iprot_factory.getProtocol(tmi)
oprot = self.factory.oprot_factory.getProtocol(tmo)
d = self.factory.processor.process(iprot, oprot)
d.addCallbacks(self.processOk, self.processError,
callbackArgs=(tmo,))
class IThriftServerFactory(Interface):
processor = Attribute("Thrift processor")
iprot_factory = Attribute("Input protocol factory")
oprot_factory = Attribute("Output protocol factory")
class IThriftClientFactory(Interface):
client_class = Attribute("Thrift client class")
iprot_factory = Attribute("Input protocol factory")
oprot_factory = Attribute("Output protocol factory")
@implementer(IThriftServerFactory)
class ThriftServerFactory(ServerFactory):
protocol = ThriftServerProtocol
def __init__(self, processor, iprot_factory, oprot_factory=None):
self.processor = processor
self.iprot_factory = iprot_factory
if oprot_factory is None:
self.oprot_factory = iprot_factory
else:
self.oprot_factory = oprot_factory
@implementer(IThriftClientFactory)
class ThriftClientFactory(ClientFactory):
protocol = ThriftClientProtocol
def __init__(self, client_class, iprot_factory, oprot_factory=None):
self.client_class = client_class
self.iprot_factory = iprot_factory
if oprot_factory is None:
self.oprot_factory = iprot_factory
else:
self.oprot_factory = oprot_factory
def buildProtocol(self, addr):
p = self.protocol(self.client_class, self.iprot_factory,
self.oprot_factory)
p.factory = self
return p
class ThriftResource(resource.Resource):
allowedMethods = ('POST',)
def __init__(self, processor, inputProtocolFactory,
outputProtocolFactory=None):
resource.Resource.__init__(self)
self.inputProtocolFactory = inputProtocolFactory
if outputProtocolFactory is None:
self.outputProtocolFactory = inputProtocolFactory
else:
self.outputProtocolFactory = outputProtocolFactory
self.processor = processor
def getChild(self, path, request):
return self
def _cbProcess(self, _, request, tmo):
msg = tmo.getvalue()
request.setResponseCode(http.OK)
request.setHeader("content-type", "application/x-thrift")
request.write(msg)
request.finish()
def render_POST(self, request):
request.content.seek(0, 0)
data = request.content.read()
tmi = TTransport.TMemoryBuffer(data)
tmo = TTransport.TMemoryBuffer()
iprot = self.inputProtocolFactory.getProtocol(tmi)
oprot = self.outputProtocolFactory.getProtocol(tmo)
d = self.processor.process(iprot, oprot)
d.addCallback(self._cbProcess, request, tmo)
return server.NOT_DONE_YET
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
"""TZlibTransport provides a compressed transport and transport factory
class, using the python standard library zlib module to implement
data compression.
"""
from __future__ import division
import zlib
from .TTransport import TTransportBase, CReadableTransport
from ..compat import BufferIO
class TZlibTransportFactory(object):
"""Factory transport that builds zlib compressed transports.
This factory caches the last single client/transport that it was passed
and returns the same TZlibTransport object that was created.
This caching means the TServer class will get the _same_ transport
object for both input and output transports from this factory.
(For non-threaded scenarios only, since the cache only holds one object)
The purpose of this caching is to allocate only one TZlibTransport where
only one is really needed (since it must have separate read/write buffers),
and makes the statistics from getCompSavings() and getCompRatio()
easier to understand.
"""
# class scoped cache of last transport given and zlibtransport returned
_last_trans = None
_last_z = None
def getTransport(self, trans, compresslevel=9):
"""Wrap a transport, trans, with the TZlibTransport
compressed transport class, returning a new
transport to the caller.
@param compresslevel: The zlib compression level, ranging
from 0 (no compression) to 9 (best compression). Defaults to 9.
@type compresslevel: int
This method returns a TZlibTransport which wraps the
passed C{trans} TTransport derived instance.
"""
if trans == self._last_trans:
return self._last_z
ztrans = TZlibTransport(trans, compresslevel)
self._last_trans = trans
self._last_z = ztrans
return ztrans
class TZlibTransport(TTransportBase, CReadableTransport):
"""Class that wraps a transport with zlib, compressing writes
and decompresses reads, using the python standard
library zlib module.
"""
# Read buffer size for the python fastbinary C extension,
# the TBinaryProtocolAccelerated class.
DEFAULT_BUFFSIZE = 4096
def __init__(self, trans, compresslevel=9):
"""Create a new TZlibTransport, wrapping C{trans}, another
TTransport derived object.
@param trans: A thrift transport object, i.e. a TSocket() object.
@type trans: TTransport
@param compresslevel: The zlib compression level, ranging
from 0 (no compression) to 9 (best compression). Default is 9.
@type compresslevel: int
"""
self.__trans = trans
self.compresslevel = compresslevel
self.__rbuf = BufferIO()
self.__wbuf = BufferIO()
self._init_zlib()
self._init_stats()
def _reinit_buffers(self):
"""Internal method to initialize/reset the internal StringIO objects
for read and write buffers.
"""
self.__rbuf = BufferIO()
self.__wbuf = BufferIO()
def _init_stats(self):
"""Internal method to reset the internal statistics counters
for compression ratios and bandwidth savings.
"""
self.bytes_in = 0
self.bytes_out = 0
self.bytes_in_comp = 0
self.bytes_out_comp = 0
def _init_zlib(self):
"""Internal method for setting up the zlib compression and
decompression objects.
"""
self._zcomp_read = zlib.decompressobj()
self._zcomp_write = zlib.compressobj(self.compresslevel)
def getCompRatio(self):
"""Get the current measured compression ratios (in,out) from
this transport.
Returns a tuple of:
(inbound_compression_ratio, outbound_compression_ratio)
The compression ratios are computed as:
compressed / uncompressed
E.g., data that compresses by 10x will have a ratio of: 0.10
and data that compresses to half of ts original size will
have a ratio of 0.5
None is returned if no bytes have yet been processed in
a particular direction.
"""
r_percent, w_percent = (None, None)
if self.bytes_in > 0:
r_percent = self.bytes_in_comp / self.bytes_in
if self.bytes_out > 0:
w_percent = self.bytes_out_comp / self.bytes_out
return (r_percent, w_percent)
def getCompSavings(self):
"""Get the current count of saved bytes due to data
compression.
Returns a tuple of:
(inbound_saved_bytes, outbound_saved_bytes)
Note: if compression is actually expanding your
data (only likely with very tiny thrift objects), then
the values returned will be negative.
"""
r_saved = self.bytes_in - self.bytes_in_comp
w_saved = self.bytes_out - self.bytes_out_comp
return (r_saved, w_saved)
def isOpen(self):
"""Return the underlying transport's open status"""
return self.__trans.isOpen()
def open(self):
"""Open the underlying transport"""
self._init_stats()
return self.__trans.open()
def listen(self):
"""Invoke the underlying transport's listen() method"""
self.__trans.listen()
def accept(self):
"""Accept connections on the underlying transport"""
return self.__trans.accept()
def close(self):
"""Close the underlying transport,"""
self._reinit_buffers()
self._init_zlib()
return self.__trans.close()
def read(self, sz):
"""Read up to sz bytes from the decompressed bytes buffer, and
read from the underlying transport if the decompression
buffer is empty.
"""
ret = self.__rbuf.read(sz)
if len(ret) > 0:
return ret
# keep reading from transport until something comes back
while True:
if self.readComp(sz):
break
ret = self.__rbuf.read(sz)
return ret
def readComp(self, sz):
"""Read compressed data from the underlying transport, then
decompress it and append it to the internal StringIO read buffer
"""
zbuf = self.__trans.read(sz)
zbuf = self._zcomp_read.unconsumed_tail + zbuf
buf = self._zcomp_read.decompress(zbuf)
self.bytes_in += len(zbuf)
self.bytes_in_comp += len(buf)
old = self.__rbuf.read()
self.__rbuf = BufferIO(old + buf)
if len(old) + len(buf) == 0:
return False
return True
def write(self, buf):
"""Write some bytes, putting them into the internal write
buffer for eventual compression.
"""
self.__wbuf.write(buf)
def flush(self):
"""Flush any queued up data in the write buffer and ensure the
compression buffer is flushed out to the underlying transport
"""
wout = self.__wbuf.getvalue()
if len(wout) > 0:
zbuf = self._zcomp_write.compress(wout)
self.bytes_out += len(wout)
self.bytes_out_comp += len(zbuf)
else:
zbuf = ''
ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH)
self.bytes_out_comp += len(ztail)
if (len(zbuf) + len(ztail)) > 0:
self.__wbuf = BufferIO()
self.__trans.write(zbuf + ztail)
self.__trans.flush()
@property
def cstringio_buf(self):
"""Implement the CReadableTransport interface"""
return self.__rbuf
def cstringio_refill(self, partialread, reqlen):
"""Implement the CReadableTransport interface for refill"""
retstring = partialread
if reqlen < self.DEFAULT_BUFFSIZE:
retstring += self.read(self.DEFAULT_BUFFSIZE)
while len(retstring) < reqlen:
retstring += self.read(reqlen - len(retstring))
self.__rbuf = BufferIO(retstring)
return self.__rbuf
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
__all__ = ['TTransport', 'TSocket', 'THttpClient', 'TZlibTransport']
#
# licensed to the apache software foundation (asf) under one
# or more contributor license agreements. see the notice file
# distributed with this work for additional information
# regarding copyright ownership. the asf licenses this file
# to you under the apache license, version 2.0 (the
# "license"); you may not use this file except in compliance
# with the license. you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing,
# software distributed under the license is distributed on an
# "as is" basis, without warranties or conditions of any
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import logging
import sys
from thrift.transport.TTransport import TTransportException
logger = logging.getLogger(__name__)
def legacy_validate_callback(cert, hostname):
"""legacy method to validate the peer's SSL certificate, and to check
the commonName of the certificate to ensure it matches the hostname we
used to make this connection. Does not support subjectAltName records
in certificates.
raises TTransportException if the certificate fails validation.
"""
if 'subject' not in cert:
raise TTransportException(
TTransportException.NOT_OPEN,
'No SSL certificate found from %s' % hostname)
fields = cert['subject']
for field in fields:
# ensure structure we get back is what we expect
if not isinstance(field, tuple):
continue
cert_pair = field[0]
if len(cert_pair) < 2:
continue
cert_key, cert_value = cert_pair[0:2]
if cert_key != 'commonName':
continue
certhost = cert_value
# this check should be performed by some sort of Access Manager
if certhost == hostname:
# success, cert commonName matches desired hostname
return
else:
raise TTransportException(
TTransportException.UNKNOWN,
'Hostname we connected to "%s" doesn\'t match certificate '
'provided commonName "%s"' % (hostname, certhost))
raise TTransportException(
TTransportException.UNKNOWN,
'Could not validate SSL certificate from host "%s". Cert=%s'
% (hostname, cert))
def _optional_dependencies():
try:
import ipaddress # noqa
logger.debug('ipaddress module is available')
ipaddr = True
except ImportError:
logger.warn('ipaddress module is unavailable')
ipaddr = False
if sys.hexversion < 0x030500F0:
try:
from backports.ssl_match_hostname import match_hostname, __version__ as ver
ver = list(map(int, ver.split('.')))
logger.debug('backports.ssl_match_hostname module is available')
match = match_hostname
if ver[0] * 10 + ver[1] >= 35:
return ipaddr, match
else:
logger.warn('backports.ssl_match_hostname module is too old')
ipaddr = False
except ImportError:
logger.warn('backports.ssl_match_hostname is unavailable')
ipaddr = False
try:
from ssl import match_hostname
logger.debug('ssl.match_hostname is available')
match = match_hostname
except ImportError:
logger.warn('using legacy validation callback')
match = legacy_validate_callback
return ipaddr, match
_match_has_ipaddress, _match_hostname = _optional_dependencies()
Metadata-Version: 1.2
Name: thrift-sasl
Version: 0.3.0
Summary: Thrift SASL Python module that implements SASL transports for Thrift (`TSaslClientTransport`).
Home-page: https://github.com/cloudera/thrift_sasl
Author: Uri Laserson
Author-email: laserson@cloudera.com
Maintainer: Wes McKinney
Maintainer-email: wes@cloudera.com
License: Apache License, Version 2.0
Description: Thrift SASL Python module that implements SASL transports for Thrift (`TSaslClientTransport`).
Keywords: thrift sasl transport
Platform: UNKNOWN
Classifier: Programming Language :: Python :: 2
Classifier: Programming Language :: Python :: 2.6
Classifier: Programming Language :: Python :: 2.7
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.3
Classifier: Programming Language :: Python :: 3.4
Classifier: Programming Language :: Python :: 3.5
Classifier: Programming Language :: Python :: 3.6
setup.cfg
setup.py
thrift_sasl/__init__.py
thrift_sasl.egg-info/PKG-INFO
thrift_sasl.egg-info/SOURCES.txt
thrift_sasl.egg-info/dependency_links.txt
thrift_sasl.egg-info/requires.txt
thrift_sasl.egg-info/top_level.txt
\ No newline at end of file
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
""" SASL transports for Thrift. """
# Initially copied from the Impala repo
from __future__ import absolute_import
import sys
import struct
from thrift.transport.TTransport import (TTransportException, TTransportBase, CReadableTransport)
# TODO: Check whether the following distinction is necessary. Does not appear to
# break anything when `io.BytesIO` is used everywhere, but there may be some edge
# cases where things break down.
if sys.version_info[0] == 3:
from io import BytesIO as BufferIO
else:
from cStringIO import StringIO as BufferIO
class TSaslClientTransport(TTransportBase, CReadableTransport):
START = 1
OK = 2
BAD = 3
ERROR = 4
COMPLETE = 5
def __init__(self, sasl_client_factory, mechanism, trans):
"""
@param sasl_client_factory: a callable that returns a new sasl.Client object
@param mechanism: the SASL mechanism (e.g. "GSSAPI")
@param trans: the underlying transport over which to communicate.
"""
self._trans = trans
self.sasl_client_factory = sasl_client_factory
self.sasl = None
self.mechanism = mechanism
self.__wbuf = BufferIO()
self.__rbuf = BufferIO()
self.opened = False
self.encode = None
def isOpen(self):
return self._trans.isOpen()
def is_open(self):
return self.isOpen()
def open(self):
if not self._trans.isOpen():
self._trans.open()
if self.sasl is not None:
raise TTransportException(
type=TTransportException.NOT_OPEN,
message="Already open!")
self.sasl = self.sasl_client_factory()
ret, chosen_mech, initial_response = self.sasl.start(self.mechanism)
if not ret:
raise TTransportException(type=TTransportException.NOT_OPEN,
message=("Could not start SASL: %s" % self.sasl.getError()))
# Send initial response
self._send_message(self.START, chosen_mech)
self._send_message(self.OK, initial_response)
# SASL negotiation loop
while True:
status, payload = self._recv_sasl_message()
if status not in (self.OK, self.COMPLETE):
raise TTransportException(type=TTransportException.NOT_OPEN,
message=("Bad status: %d (%s)" % (status, payload)))
if status == self.COMPLETE:
break
ret, response = self.sasl.step(payload)
if not ret:
raise TTransportException(type=TTransportException.NOT_OPEN,
message=("Bad SASL result: %s" % (self.sasl.getError())))
self._send_message(self.OK, response)
def _send_message(self, status, body):
header = struct.pack(">BI", status, len(body))
self._trans.write(header + body)
self._trans.flush()
def _recv_sasl_message(self):
header = self._trans.readAll(5)
status, length = struct.unpack(">BI", header)
if length > 0:
payload = self._trans.readAll(length)
else:
payload = ""
return status, payload
def write(self, data):
self.__wbuf.write(data)
def flush(self):
buffer = self.__wbuf.getvalue()
# The first time we flush data, we send it to sasl.encode()
# If the length doesn't change, then we must be using a QOP
# of auth and we should no longer call sasl.encode(), otherwise
# we encode every time.
if self.encode == None:
success, encoded = self.sasl.encode(buffer)
if not success:
raise TTransportException(type=TTransportException.UNKNOWN,
message=self.sasl.getError())
if (len(encoded)==len(buffer)):
self.encode = False
self._flushPlain(buffer)
else:
self.encode = True
self._trans.write(encoded)
elif self.encode:
self._flushEncoded(buffer)
else:
self._flushPlain(buffer)
self._trans.flush()
self.__wbuf = BufferIO()
def _flushEncoded(self, buffer):
# sasl.ecnode() does the encoding and adds the length header, so nothing
# to do but call it and write the result.
success, encoded = self.sasl.encode(buffer)
if not success:
raise TTransportException(type=TTransportException.UNKNOWN,
message=self.sasl.getError())
self._trans.write(encoded)
def _flushPlain(self, buffer):
# When we have QOP of auth, sasl.encode() will pass the input to the output
# but won't put a length header, so we have to do that.
# Note stolen from TFramedTransport:
# N.B.: Doing this string concatenation is WAY cheaper than making
# two separate calls to the underlying socket object. Socket writes in
# Python turn out to be REALLY expensive, but it seems to do a pretty
# good job of managing string buffer operations without excessive copies
self._trans.write(struct.pack(">I", len(buffer)) + buffer)
def read(self, sz):
ret = self.__rbuf.read(sz)
if len(ret) == sz:
return ret
self._read_frame()
return ret + self.__rbuf.read(sz - len(ret))
def _read_frame(self):
header = self._trans.readAll(4)
(length,) = struct.unpack(">I", header)
if self.encode:
# If the frames are encoded (i.e. you're using a QOP of auth-int or
# auth-conf), then make sure to include the header in the bytes you send to
# sasl.decode()
encoded = header + self._trans.readAll(length)
success, decoded = self.sasl.decode(encoded)
if not success:
raise TTransportException(type=TTransportException.UNKNOWN,
message=self.sasl.getError())
else:
# If the frames are not encoded, just pass it through
decoded = self._trans.readAll(length)
self.__rbuf = BufferIO(decoded)
def close(self):
self._trans.close()
self.sasl = None
# Implement the CReadableTransport interface.
# Stolen shamelessly from TFramedTransport
@property
def cstringio_buf(self):
return self.__rbuf
def cstringio_refill(self, prefix, reqlen):
# self.__rbuf will already be empty here because fastbinary doesn't
# ask for a refill until the previous buffer is empty. Therefore,
# we can start reading new frames immediately.
while len(prefix) < reqlen:
self._read_frame()
prefix += self.__rbuf.getvalue()
self.__rbuf = BufferIO(prefix)
return self.__rbuf
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment