# Copyright (c) 2016 Uber Technologies, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from __future__ import absolute_import, print_function, unicode_literals
import sys
import types
from functools import partial
import thriftrw
from tornado import gen
from tornado.util import raise_exc_info
from tchannel.status import OK, FAILED
from tchannel.errors import OneWayNotSupportedError
from tchannel.errors import ValueExpectedError
from tchannel.response import Response, response_from_mixed
from tchannel.serializer.thrift import ThriftRWSerializer
from .module import ThriftRequest
def load(path, service=None, hostport=None, module_name=None):
    """Loads the Thrift file at the specified path.

    The file is compiled in-memory and a Python module containing the result
    is returned. It may be used with ``TChannel.thrift``. For example,

    .. code-block:: python

        from tchannel import TChannel, thrift

        # Load our server's interface definition.
        donuts = thrift.load(path='donuts.thrift')

        # We need to specify a service name or hostport because this is a
        # downstream service we'll be calling.
        coffee = thrift.load(path='coffee.thrift', service='coffee')

        tchannel = TChannel('donuts')

        @tchannel.thrift.register(donuts.DonutsService)
        @tornado.gen.coroutine
        def submitOrder(request):
            args = request.body

            if args.coffee:
                yield tchannel.thrift(
                    coffee.CoffeeService.order(args.coffee)
                )

            # ...

    The returned module contains one top-level type for each struct, enum,
    union, exception, and service defined in the Thrift file. Each service
    is exposed as an attribute that provides one callable per function
    defined in that service (including inherited functions). Each callable
    accepts the arguments for that function and returns a ``ThriftRequest``
    capable of being sent via ``TChannel.thrift``.

    For more information on what gets generated by ``load``, see `thriftrw
    <http://thriftrw.readthedocs.org/en/latest/>`_.

    Note that the ``path`` accepted by ``load`` must be either an absolute
    path or a path relative to *the current directory*. If you need to
    refer to Thrift files relative to the Python module in which ``load`` was
    called, use the ``__file__`` magic variable.

    .. code-block:: python

        # Given,
        #
        #   foo/
        #     myservice.thrift
        #     bar/
        #       x.py
        #
        # Inside foo/bar/x.py,

        path = os.path.join(
            os.path.dirname(__file__), '../myservice.thrift'
        )

    The returned value is a valid Python module. You can install the module by
    adding it to the ``sys.modules`` dictionary. This will allow importing
    items from this module directly. You can use the ``__name__`` magic
    variable to make the generated module a submodule of the current module.
    For example,

    .. code-block:: python

        # foo/bar.py

        import sys
        from tchannel import thrift

        donuts = thrift.load('donuts.thrift')
        sys.modules[__name__ + '.donuts'] = donuts

    This installs the module generated for ``donuts.thrift`` as the module
    ``foo.bar.donuts``. Callers can then import items from that module
    directly. For example,

    .. code-block:: python

        # foo/baz.py

        from foo.bar.donuts import DonutsService, Order

        def baz(tchannel):
            return tchannel.thrift(
                DonutsService.submitOrder(Order(..))
            )

    For backwards compatibility, the legacy positional form
    ``load(service_name, path)`` is still accepted. It is detected by the
    first argument lacking a ``.thrift`` suffix while a second argument is
    present.

    :param str path:
        Path to the Thrift file. If this is a relative path, it must be
        relative to the current directory.
    :param str service:
        Name of the service that the Thrift file represents. This name will be
        used to route requests through Hyperbahn.
    :param str hostport:
        Clients can use this to specify the hostport at which the service can
        be found. If omitted, TChannel will route the requests through known
        peers. This value is ignored by servers.
    :param str module_name:
        Name used for the generated Python module. Defaults to the name of the
        Thrift file.
    :returns:
        A ``TChannelThriftModule`` wrapping the module generated by
        ``thriftrw``.
    """
    # Backwards compatibility for callers passing in the service name as the
    # first positional argument and the path second, i.e.
    # ``load('coffee', 'coffee.thrift')``. Only swap when a second value was
    # actually supplied. Otherwise a lone path with a different extension
    # would be swapped out for ``None`` and fail opaquely inside thriftrw.
    if service is not None and not path.endswith('.thrift'):
        service, path = path, service

    module = thriftrw.load(path=path, name=module_name)
    return TChannelThriftModule(service, module, hostport)
class TChannelThriftModule(types.ModuleType):
    """Wraps the ``thriftrw``-generated module.

    Each service class in the wrapped module is exposed as a ``Service``
    attribute named after the service. Every other attribute lookup that
    fails on this object is delegated to the wrapped module as-is.
    """

    def __init__(self, service, module, hostport=None):
        """Initialize a TChannelThriftModule.

        :param str service:
            Name of the service this module represents. This name will be used
            for routing over Hyperbahn.
        :param module:
            Module generated by ``thriftrw`` for a Thrift file.
        :param str hostport:
            This may be specified if the caller is a client and wants all
            requests sent to a specific address.
        """
        # NOTE(review): ``types.ModuleType.__init__`` is not called here, so
        # module attributes such as ``__name__`` are not set on this object.
        # Lookups for them fall through ``__getattr__`` to the wrapped module.
        self.service = service
        self.hostport = hostport
        self._module = module

        services = getattr(self._module, '__services__', None)
        if services is None:
            # thriftrw <1.0 exposed the service list as ``services``.
            services = getattr(self._module, 'services')

        for service_cls in services:
            name = service_cls.service_spec.name
            setattr(self, name, Service(service_cls, self))

    def __getattr__(self, name):
        """Delegate failed attribute lookups to the wrapped module."""
        # ``__getattr__`` only runs when normal lookup fails. If ``_module``
        # itself is missing (e.g. the instance was created without running
        # ``__init__``), delegating would re-enter ``__getattr__`` forever and
        # end in RecursionError. Report a plain AttributeError instead.
        if name == '_module':
            raise AttributeError(name)
        return getattr(self._module, name)

    def __str__(self):
        return 'TChannelThriftModule(%s, %s)' % (self.service, self._module)

    __repr__ = __str__
class Service(object):
    """Thin wrapper around a service class generated by thriftrw.

    Every function of the service, including those inherited from parent
    services, is exposed as an attribute holding a ``Function``.
    """

    def __init__(self, cls, module):
        self._module = module
        self._cls = cls
        self._spec = cls.service_spec
        self._setup_functions(self._spec)

    def _setup_functions(self, spec):
        # Collect the spec and its ancestors (nearest first), then install
        # functions starting from the root ancestor. A child's function
        # therefore overwrites an inherited function of the same name.
        lineage = [spec]
        ancestor = spec.parent
        while ancestor:
            lineage.append(ancestor)
            ancestor = ancestor.parent

        for current in reversed(lineage):
            for func_spec in current.functions:
                setattr(self, func_spec.name, Function(func_spec, self))

    @property
    def name(self):
        """Name of the Thrift service this object represents."""
        return self._spec.name

    def __str__(self):
        return 'Service(%s)' % self.name

    __repr__ = __str__
class Function(object):
    """Wraps a ServiceFunction generated by thriftrw.

    Calling an instance builds a ``ThriftRWRequest`` for this function from
    the given call arguments.
    """

    __slots__ = ('spec', 'service', '_func', '_request_cls', '_response_cls')

    def __init__(self, func_spec, service):
        surface = func_spec.surface
        self.spec = func_spec
        self.service = service
        self._func = surface
        self._request_cls = surface.request
        self._response_cls = surface.response

    @property
    def endpoint(self):
        """Endpoint name for this function, formatted ``Service::function``."""
        return '%s::%s' % (self.service.name, self._func.name)

    @property
    def oneway(self):
        """Whether this function is oneway."""
        return self.spec.oneway

    def __call__(self, *args, **kwargs):
        if self.oneway:
            raise OneWayNotSupportedError(
                'TChannel+Thrift does not currently support oneway '
                'procedures.'
            )

        # A request needs somewhere to go: a service name for routing or an
        # explicit hostport.
        module = self.service._module
        if not module.hostport and not module.service:
            raise ValueError(
                "No 'service' or 'hostport' provided to " + str(self)
            )

        call_args = self._request_cls(*args, **kwargs)
        return ThriftRWRequest(
            module=module,
            service=module.service,
            endpoint=self.endpoint,
            result_type=self._response_cls,
            call_args=call_args,
            hostport=module.hostport,
        )

    def __str__(self):
        return 'Function(%s)' % self.endpoint

    __repr__ = __str__
def register(dispatcher, service, handler=None, method=None):
    """Register ``handler`` as the implementation of a Thrift function.

    May be called directly or used as a decorator factory::

        register(dispatcher, service, handler)
        register(dispatcher, service)(handler)
        register(dispatcher, service, method='foo')(handler)

    :param dispatcher:
        RequestDispatcher against which the new endpoint will be registered.
    :param Service service:
        Service object representing the service whose endpoint is being
        registered.
    :param handler:
        A function implementing the given Thrift function. If omitted, a
        decorator accepting the handler is returned instead.
    :param method:
        If specified, name of the method being registered. Defaults to the
        name of the ``handler`` function.
    :returns:
        ``handler`` itself, or a decorator when ``handler`` is None.
    :raises AssertionError:
        If ``service`` does not define the method, or the method is oneway.
    """
    def decorator(method, handler):
        if not method:
            method = handler.__name__

        function = getattr(service, method, None)

        # Raise explicitly rather than using ``assert`` so these checks are
        # not stripped under ``python -O``. AssertionError is kept so callers
        # that catch it keep working.
        if not function:
            raise AssertionError(
                'Service "%s" does not define method "%s"'
                % (service.name, method)
            )
        if function.oneway:
            raise AssertionError(
                'Service "%s" method "%s" is oneway; TChannel+Thrift does '
                'not currently support oneway procedures.'
                % (service.name, method)
            )

        dispatcher.register(
            function.endpoint,
            build_handler(function, handler),
            ThriftRWSerializer(service._module, function._request_cls),
            ThriftRWSerializer(service._module, function._response_cls),
        )
        return handler

    if handler is None:
        return partial(decorator, method)
    else:
        return decorator(method, handler)
def build_handler(function, handler):
    """Wrap ``handler`` so its outcome is packed into the response union.

    :param function:
        ``Function`` being served. Its ``_response_cls`` (the thriftrw
        response type) and ``spec.name`` are used here.
    :param handler:
        User callable invoked as ``handler(request)``. Its return value is
        passed through ``gen.maybe_future``, so it may be a plain value or a
        future. The resolved value is then normalized by
        ``response_from_mixed``.
    :returns:
        A Tornado coroutine ``handle(request)``, named after the Thrift
        function, that resolves to a ``Response``. The response's ``body`` is
        an instance of ``function._response_cls``. Its ``status`` is ``OK`` on
        success, or ``FAILED`` when the handler raised an exception declared
        by the Thrift function. Undeclared exceptions are re-raised.
    """
    # response_cls is a class that represents the response union for this
    # function. It accepts one parameter for each exception defined on the
    # method and another parameter 'success' for the result of the call. The
    # success kwarg is absent if the function doesn't return anything.
    response_cls = function._response_cls
    response_spec = response_cls.type_spec

    @gen.coroutine
    def handle(request):
        # kwargs for this function's response_cls constructor
        response_kwargs = {}
        status = OK

        try:
            response = yield gen.maybe_future(handler(request))
        except Exception as e:
            # Start from an empty Response; its status and body are set below.
            response = Response()

            for exc_spec in response_spec.exception_specs:
                # Each exc_spec is a thriftrw.spec.FieldSpec. The spec
                # attribute on that is the TypeSpec for the Exception class
                # and the surface on the TypeSpec is the exception class.
                exc_cls = exc_spec.spec.surface
                if isinstance(e, exc_cls):
                    status = FAILED
                    response_kwargs[exc_spec.name] = e
                    break
            else:
                # No declared exception matched: re-raise the original
                # exception with its traceback. This must stay inside the
                # ``except`` clause so sys.exc_info() still refers to it.
                raise_exc_info(sys.exc_info())
        else:
            response = response_from_mixed(response)

            if response_spec.return_spec is not None:
                # Non-void function: the handler must produce a value.
                # NOTE(review): this is an ``assert``, so the check is
                # skipped under ``python -O``.
                assert response.body is not None, (
                    'Expected a value to be returned for %s, '
                    'but recieved None - only void procedures can '
                    'return None.' % function.endpoint
                )
                response_kwargs['success'] = response.body

        # Unconditionally overwrite the status and wrap the outcome in the
        # response union (empty kwargs for a void function that succeeded).
        response.status = status
        response.body = response_cls(**response_kwargs)
        raise gen.Return(response)

    # Give the coroutine the Thrift function's name.
    handle.__name__ = function.spec.name
    return handle
class ThriftRWRequest(ThriftRequest):
    """``ThriftRequest`` for functions generated by thriftrw.

    Attaches a ``ThriftRWSerializer`` for the function's response type and
    unpacks the response union in ``read_body``.
    """

    def __init__(self, module, **kwargs):
        """Build the request.

        :param module:
            The ``TChannelThriftModule`` the function belongs to; handed to
            the serializer.
        :param kwargs:
            Forwarded to ``ThriftRequest``. Must include ``result_type``; a
            ``serializer`` entry is added (overwriting any given one).
        """
        result_type = kwargs['result_type']
        kwargs['serializer'] = ThriftRWSerializer(module, result_type)
        super(ThriftRWRequest, self).__init__(**kwargs)

    def read_body(self, body):
        """Unpack the response union ``body``.

        :raises:
            The first declared exception that is set on ``body``.
        :raises ValueExpectedError:
            If a non-void function came back with no value.
        :returns:
            ``body.success`` for a non-void function, otherwise None.
        """
        union_spec = self.result_type.type_spec

        declared_errors = (
            getattr(body, field.name) for field in union_spec.exception_specs
        )
        for error in declared_errors:
            if error is not None:
                raise error

        if union_spec.return_spec is None:
            # Void function: nothing to hand back.
            return None

        if body.success is None:
            raise ValueExpectedError(
                'Expected a value to be returned for %s, '
                'but recieved None - only void procedures can '
                'return None.' % self.endpoint
            )
        return body.success