Skip to content

Commit

Permalink
Merge pull request #603 from gtxm/class_directives
Browse files Browse the repository at this point in the history
Class-based directives
  • Loading branch information
timothycrosley authored Jan 4, 2018
2 parents 0272fda + 7ebcb60 commit 528a0de
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 8 deletions.
30 changes: 30 additions & 0 deletions examples/smtp_envelope_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import envelopes
import hug


@hug.directive()
class SMTP(object):

def __init__(self, *args, **kwargs):
self.smtp = envelopes.SMTP(host='127.0.0.1')
self.envelopes_to_send = list()

def send_envelope(self, envelope):
self.envelopes_to_send.append(envelope)

def cleanup(self, exception=None):
if exception:
return
for envelope in self.envelopes_to_send:
self.smtp.send(envelope)


@hug.get('/hello')
def send_hello_email(smtp: SMTP):
envelope = envelopes.Envelope(
from_addr=(u'[email protected]', u'From me'),
to_addr=(u'[email protected]', u'To World'),
subject=u'Hello',
text_body=u"World!"
)
smtp.send_envelope(envelope)
58 changes: 58 additions & 0 deletions examples/sqlalchemy_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import hug

from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.ext.declarative.api import declarative_base
from sqlalchemy.orm.session import Session
from sqlalchemy.orm import scoped_session
from sqlalchemy.orm import sessionmaker


engine = create_engine("sqlite:///:memory:")

session_factory = scoped_session(sessionmaker(bind=engine))


Base = declarative_base()


class TestModel(Base):
__tablename__ = 'test_model'
id = Column(Integer, primary_key=True)
name = Column(String)


Base.metadata.create_all(bind=engine)


@hug.directive()
class Resource(object):

def __init__(self, *args, **kwargs):
self._db = session_factory()
self.autocommit = True

@property
def db(self) -> Session:
return self._db

def cleanup(self, exception=None):
if exception:
self.db.rollback()
return
if self.autocommit:
self.db.commit()


@hug.directive()
def return_session() -> Session:
return session_factory()


@hug.get('/hello')
def make_simple_query(resource: Resource):
for word in ["hello", "world", ":)"]:
test_model = TestModel()
test_model.name = word
resource.db.add(test_model)
resource.db.flush()
return " ".join([obj.name for obj in resource.db.query(TestModel).all()])
39 changes: 31 additions & 8 deletions hug/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,13 @@ def _rewrite_params(self, params):
if interface_name in params:
params[internal_name] = params.pop(interface_name)

@staticmethod
def cleanup_parameters(parameters, exception=None):
for parameter, directive in parameters.items():
if hasattr(directive, 'cleanup'):
directive.cleanup(exception=exception)


class Local(Interface):
"""Defines the Interface responsible for exposing functions locally"""
__slots__ = ('skip_directives', 'skip_validation', 'version')
Expand Down Expand Up @@ -325,7 +332,12 @@ def __call__(self, *args, **kwargs):

if getattr(self, 'map_params', None):
self._rewrite_params(kwargs)
result = self.interface(**kwargs)
try:
result = self.interface(**kwargs)
self.cleanup_parameters(kwargs)
except Exception as exception:
self.cleanup_parameters(kwargs, exception=exception)
raise exception
if self.transform:
result = self.transform(result)
return self.outputs(result) if self.outputs else result
Expand Down Expand Up @@ -480,11 +492,15 @@ def __call__(self):
if getattr(self, 'map_params', None):
self._rewrite_params(pass_to_function)

if args:
result = self.interface(*args, **pass_to_function)
else:
result = self.interface(**pass_to_function)

try:
if args:
result = self.interface(*args, **pass_to_function)
else:
result = self.interface(**pass_to_function)
self.cleanup_parameters(pass_to_function)
except Exception as exception:
self.cleanup_parameters(pass_to_function, exception=exception)
raise exception
return self.output(result)


Expand Down Expand Up @@ -559,7 +575,6 @@ def gather_parameters(self, request, response, api_version=None, **input_paramet
arguments = (self.defaults[parameter], ) if parameter in self.defaults else ()
input_parameters[parameter] = directive(*arguments, response=response, request=request,
api=self.api, api_version=api_version, interface=self)

return input_parameters

@property
Expand Down Expand Up @@ -676,6 +691,7 @@ def __call__(self, request, response, api_version=None, **kwargs):
else:
exception_types = self.api.http.exception_handlers(api_version)
exception_types = tuple(exception_types.keys()) if exception_types else ()
input_parameters = {}
try:
self.set_response_defaults(response, request)

Expand All @@ -691,9 +707,12 @@ def __call__(self, request, response, api_version=None, **kwargs):
return self.render_errors(errors, request, response)

self.render_content(self.call_function(input_parameters), request, response, **kwargs)
except falcon.HTTPNotFound:
self.cleanup_parameters(input_parameters)
except falcon.HTTPNotFound as exception:
self.cleanup_parameters(input_parameters, exception=exception)
return self.api.http.not_found(request, response, **kwargs)
except exception_types as exception:
self.cleanup_parameters(input_parameters, exception=exception)
handler = None
exception_type = type(exception)
if exception_type in exception_types:
Expand All @@ -710,6 +729,10 @@ def __call__(self, request, response, api_version=None, **kwargs):
raise exception

handler(request=request, response=response, exception=exception, **kwargs)
except Exception as exception:
self.cleanup_parameters(input_parameters, exception=exception)
raise exception


def documentation(self, add_to=None, version=None, prefix="", base_url="", url=""):
"""Returns the documentation specific to an HTTP interface"""
Expand Down
63 changes: 63 additions & 0 deletions tests/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,69 @@ def test(hug_timer):
assert isinstance(hug.test.cli(test), float)


def test_cli_with_class_directives():

@hug.directive()
class ClassDirective(object):

def __init__(self, *args, **kwargs):
self.test = 1

@hug.cli()
@hug.local(skip_directives=False)
def test(class_directive: ClassDirective):
return class_directive.test

assert test() == 1
assert hug.test.cli(test) == 1

class TestObject(object):
is_cleanup_launched = False
last_exception = None

@hug.directive()
class ClassDirectiveWithCleanUp(object):

def __init__(self, *args, **kwargs):
self.test_object = TestObject

def cleanup(self, exception):
self.test_object.is_cleanup_launched = True
self.test_object.last_exception = exception

@hug.cli()
@hug.local(skip_directives=False)
def test2(class_directive: ClassDirectiveWithCleanUp):
return class_directive.test_object.is_cleanup_launched

assert not hug.test.cli(test2) # cleanup should be launched after running command
assert TestObject.is_cleanup_launched
assert TestObject.last_exception is None
TestObject.is_cleanup_launched = False
TestObject.last_exception = None
assert not test2()
assert TestObject.is_cleanup_launched
assert TestObject.last_exception is None

@hug.cli()
@hug.local(skip_directives=False)
def test_with_attribute_error(class_directive: ClassDirectiveWithCleanUp):
raise class_directive.test_object2

hug.test.cli(test_with_attribute_error)
assert TestObject.is_cleanup_launched
assert isinstance(TestObject.last_exception, AttributeError)
TestObject.is_cleanup_launched = False
TestObject.last_exception = None
try:
test_with_attribute_error()
assert False
except AttributeError:
assert True
assert TestObject.is_cleanup_launched
assert isinstance(TestObject.last_exception, AttributeError)


def test_cli_with_named_directives():
"""Test to ensure you can pass named directives into the cli"""
@hug.cli()
Expand Down

0 comments on commit 528a0de

Please sign in to comment.