diff --git a/examples/smtp_envelope_example.py b/examples/smtp_envelope_example.py new file mode 100644 index 00000000..9c1c6b82 --- /dev/null +++ b/examples/smtp_envelope_example.py @@ -0,0 +1,30 @@ +import envelopes +import hug + + +@hug.directive() +class SMTP(object): + + def __init__(self, *args, **kwargs): + self.smtp = envelopes.SMTP(host='127.0.0.1') + self.envelopes_to_send = list() + + def send_envelope(self, envelope): + self.envelopes_to_send.append(envelope) + + def cleanup(self, exception=None): + if exception: + return + for envelope in self.envelopes_to_send: + self.smtp.send(envelope) + + +@hug.get('/hello') +def send_hello_email(smtp: SMTP): + envelope = envelopes.Envelope( + from_addr=(u'me@example.com', u'From me'), + to_addr=(u'world@example.com', u'To World'), + subject=u'Hello', + text_body=u"World!" + ) + smtp.send_envelope(envelope) diff --git a/examples/sqlalchemy_example.py b/examples/sqlalchemy_example.py new file mode 100644 index 00000000..08c62f34 --- /dev/null +++ b/examples/sqlalchemy_example.py @@ -0,0 +1,58 @@ +import hug + +from sqlalchemy import create_engine, Column, Integer, String +from sqlalchemy.ext.declarative.api import declarative_base +from sqlalchemy.orm.session import Session +from sqlalchemy.orm import scoped_session +from sqlalchemy.orm import sessionmaker + + +engine = create_engine("sqlite:///:memory:") + +session_factory = scoped_session(sessionmaker(bind=engine)) + + +Base = declarative_base() + + +class TestModel(Base): + __tablename__ = 'test_model' + id = Column(Integer, primary_key=True) + name = Column(String) + + +Base.metadata.create_all(bind=engine) + + +@hug.directive() +class Resource(object): + + def __init__(self, *args, **kwargs): + self._db = session_factory() + self.autocommit = True + + @property + def db(self) -> Session: + return self._db + + def cleanup(self, exception=None): + if exception: + self.db.rollback() + return + if self.autocommit: + self.db.commit() + + +@hug.directive() +def return_session() -> Session: + return session_factory() + + +@hug.get('/hello') +def make_simple_query(resource: Resource): + for word in ["hello", "world", ":)"]: + test_model = TestModel() + test_model.name = word + resource.db.add(test_model) + resource.db.flush() + return " ".join([obj.name for obj in resource.db.query(TestModel).all()]) diff --git a/hug/interface.py b/hug/interface.py index c85a9490..2ae3684d 100644 --- a/hug/interface.py +++ b/hug/interface.py @@ -270,6 +270,13 @@ def _rewrite_params(self, params): if interface_name in params: params[internal_name] = params.pop(interface_name) + @staticmethod + def cleanup_parameters(parameters, exception=None): + for parameter, directive in parameters.items(): + if hasattr(directive, 'cleanup'): + directive.cleanup(exception=exception) + + class Local(Interface): """Defines the Interface responsible for exposing functions locally""" __slots__ = ('skip_directives', 'skip_validation', 'version') @@ -325,7 +332,12 @@ def __call__(self, *args, **kwargs): if getattr(self, 'map_params', None): self._rewrite_params(kwargs) - result = self.interface(**kwargs) + try: + result = self.interface(**kwargs) + self.cleanup_parameters(kwargs) + except Exception as exception: + self.cleanup_parameters(kwargs, exception=exception) + raise exception if self.transform: result = self.transform(result) return self.outputs(result) if self.outputs else result @@ -480,11 +492,15 @@ def __call__(self): if getattr(self, 'map_params', None): self._rewrite_params(pass_to_function) - if args: - result = self.interface(*args, **pass_to_function) - else: - result = self.interface(**pass_to_function) - + try: + if args: + result = self.interface(*args, **pass_to_function) + else: + result = self.interface(**pass_to_function) + self.cleanup_parameters(pass_to_function) + except Exception as exception: + self.cleanup_parameters(pass_to_function, exception=exception) + raise exception return self.output(result) @@ -559,7 +575,6 @@ def gather_parameters(self, request, response, api_version=None, **input_paramet arguments = (self.defaults[parameter], ) if parameter in self.defaults else () input_parameters[parameter] = directive(*arguments, response=response, request=request, api=self.api, api_version=api_version, interface=self) - return input_parameters @property @@ -676,6 +691,7 @@ def __call__(self, request, response, api_version=None, **kwargs): else: exception_types = self.api.http.exception_handlers(api_version) exception_types = tuple(exception_types.keys()) if exception_types else () + input_parameters = {} try: self.set_response_defaults(response, request) @@ -691,9 +707,12 @@ def __call__(self, request, response, api_version=None, **kwargs): return self.render_errors(errors, request, response) self.render_content(self.call_function(input_parameters), request, response, **kwargs) - except falcon.HTTPNotFound: + self.cleanup_parameters(input_parameters) + except falcon.HTTPNotFound as exception: + self.cleanup_parameters(input_parameters, exception=exception) return self.api.http.not_found(request, response, **kwargs) except exception_types as exception: + self.cleanup_parameters(input_parameters, exception=exception) handler = None exception_type = type(exception) if exception_type in exception_types: @@ -710,6 +729,10 @@ def __call__(self, request, response, api_version=None, **kwargs): raise exception handler(request=request, response=response, exception=exception, **kwargs) + except Exception as exception: + self.cleanup_parameters(input_parameters, exception=exception) + raise exception + def documentation(self, add_to=None, version=None, prefix="", base_url="", url=""): """Returns the documentation specific to an HTTP interface""" diff --git a/tests/test_decorators.py b/tests/test_decorators.py index 58635ac5..a64053c7 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -934,6 +934,69 @@ def test(hug_timer): assert isinstance(hug.test.cli(test), float) +def test_cli_with_class_directives(): + + @hug.directive() + class ClassDirective(object): + + def __init__(self, *args, **kwargs): + self.test = 1 + + @hug.cli() + @hug.local(skip_directives=False) + def test(class_directive: ClassDirective): + return class_directive.test + + assert test() == 1 + assert hug.test.cli(test) == 1 + + class TestObject(object): + is_cleanup_launched = False + last_exception = None + + @hug.directive() + class ClassDirectiveWithCleanUp(object): + + def __init__(self, *args, **kwargs): + self.test_object = TestObject + + def cleanup(self, exception): + self.test_object.is_cleanup_launched = True + self.test_object.last_exception = exception + + @hug.cli() + @hug.local(skip_directives=False) + def test2(class_directive: ClassDirectiveWithCleanUp): + return class_directive.test_object.is_cleanup_launched + + assert not hug.test.cli(test2) # cleanup should be launched after running command + assert TestObject.is_cleanup_launched + assert TestObject.last_exception is None + TestObject.is_cleanup_launched = False + TestObject.last_exception = None + assert not test2() + assert TestObject.is_cleanup_launched + assert TestObject.last_exception is None + + @hug.cli() + @hug.local(skip_directives=False) + def test_with_attribute_error(class_directive: ClassDirectiveWithCleanUp): + raise class_directive.test_object2 + + hug.test.cli(test_with_attribute_error) + assert TestObject.is_cleanup_launched + assert isinstance(TestObject.last_exception, AttributeError) + TestObject.is_cleanup_launched = False + TestObject.last_exception = None + try: + test_with_attribute_error() + assert False + except AttributeError: + assert True + assert TestObject.is_cleanup_launched + assert isinstance(TestObject.last_exception, AttributeError) + + def test_cli_with_named_directives(): """Test to ensure you can pass named directives into the cli""" @hug.cli()