Skip to content

Commit

Permalink
Fix issue tracking recursive references to custom annotations (#197)
Browse files Browse the repository at this point in the history
* Handle recursive custom annotations dependencies

* Fix tests and lint

Co-authored-by: Brad Girardeau <[email protected]>
  • Loading branch information
bgirardeau and bradg-dbx authored Oct 29, 2020
1 parent ff84331 commit fdf1761
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 91 deletions.
1 change: 0 additions & 1 deletion stone/backends/python_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@


class PythonClientBackend(CodeBackend):
# pylint: disable=attribute-defined-outside-init

cmdline_parser = _cmdline_parser
supported_auth_types = None
Expand Down
12 changes: 8 additions & 4 deletions stone/backends/python_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@

from stone.ir import AnnotationType, ApiNamespace
from stone.ir import (
get_custom_annotations_for_alias,
get_custom_annotations_recursive,
is_alias,
is_boolean_type,
is_composite_type,
is_bytes_type,
is_list_type,
is_map_type,
Expand Down Expand Up @@ -642,7 +641,7 @@ def _generate_custom_annotation_processors(self, ns, data_type, extra_annotation
dt, _, _ = unwrap(data_type)
if is_struct_type(dt) or is_union_type(dt):
annotation_types_seen = set()
for annotation in get_custom_annotations_recursive(dt):
for _, annotation in dt.recursive_custom_annotations:
if annotation.annotation_type not in annotation_types_seen:
yield (annotation.annotation_type,
generate_func_call(
Expand Down Expand Up @@ -672,7 +671,12 @@ def _generate_custom_annotation_processors(self, ns, data_type, extra_annotation

# annotations applied directly to this type (through aliases or
# passed in from the caller)
for annotation in itertools.chain(get_custom_annotations_for_alias(data_type),
indirect_annotations = dt.recursive_custom_annotations if is_composite_type(dt) else set()
all_annotations = (data_type.recursive_custom_annotations
if is_composite_type(data_type) else set())
remaining_annotations = [annotation for _, annotation in
all_annotations.difference(indirect_annotations)]
for annotation in itertools.chain(remaining_annotations,
extra_annotations):
yield (annotation.annotation_type,
generate_func_call(
Expand Down
71 changes: 71 additions & 0 deletions stone/frontend/ir_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
Int32,
Int64,
is_alias,
is_composite_type,
is_field_type,
is_list_type,
is_map_type,
Expand Down Expand Up @@ -297,6 +298,7 @@ def generate_IR(self):
self._populate_field_defaults()
self._populate_enumerated_subtypes()
self._populate_route_attributes()
self._populate_recursive_custom_annotations()
self._populate_examples()
self._validate_doc_refs()
self._validate_annotations()
Expand Down Expand Up @@ -802,6 +804,75 @@ def _populate_union_type_attributes(self, env, data_type):
data_type.set_attributes(
data_type._ast_node.doc, api_type_fields, parent_type, catch_all_field)

def _populate_recursive_custom_annotations(self):
"""
Populates custom annotations applied to fields recursively. This is done in
a separate pass because it requires all fields and routes to be defined so that
recursive chains can be followed accurately.
"""
data_types_seen = set()

def recurse(data_type):
# primitive types do not have annotations
if not is_composite_type(data_type):
return set()

# if we have already analyzed data type, just return result
if data_type.recursive_custom_annotations is not None:
return data_type.recursive_custom_annotations

# handle cycles safely (annotations will be found first time at top level)
if data_type in data_types_seen:
return set()
data_types_seen.add(data_type)

annotations = set()

# collect data types from subtypes recursively
if is_struct_type(data_type) or is_union_type(data_type):
for field in data_type.fields:
annotations.update(recurse(field.data_type))
# annotations can be defined directly on fields
annotations.update([(field, annotation)
for annotation in field.custom_annotations])
elif is_alias(data_type):
annotations.update(recurse(data_type.data_type))
# annotations can be defined directly on aliases
annotations.update([(data_type, annotation)
for annotation in data_type.custom_annotations])
elif is_list_type(data_type):
annotations.update(recurse(data_type.data_type))
elif is_map_type(data_type):
# only map values support annotations for now
annotations.update(recurse(data_type.value_data_type))
elif is_nullable_type(data_type):
annotations.update(recurse(data_type.data_type))

data_type.recursive_custom_annotations = annotations
return annotations

for namespace in self.api.namespaces.values():
namespace_annotations = set()
for data_type in namespace.data_types:
namespace_annotations.update(recurse(data_type))

for alias in namespace.aliases:
namespace_annotations.update(recurse(alias))

for route in namespace.routes:
namespace_annotations.update(recurse(route.arg_data_type))
namespace_annotations.update(recurse(route.result_data_type))
namespace_annotations.update(recurse(route.error_data_type))

# record annotation types as dependencies of the namespace. this allows for
# an optimization when processing custom annotations to ignore annotation
# types that are not applied to the data type, rather than recursing into it
for _, annotation in namespace_annotations:
if annotation.annotation_type.namespace.name != namespace.name:
namespace.add_imported_namespace(
annotation.annotation_type.namespace,
imported_annotation_type=True)

def _populate_field_defaults(self):
"""
Populate the defaults of each field. This is done in a separate pass
Expand Down
116 changes: 32 additions & 84 deletions stone/ir/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,30 @@ def generic_type_name(v):
return type(v).__name__


def record_custom_annotation_imports(annotation, namespace):
"""
Records imports for custom annotations in the given namespace.
"""
# first, check the annotation *type*
if annotation.annotation_type.namespace.name != namespace.name:
namespace.add_imported_namespace(
annotation.annotation_type.namespace,
imported_annotation_type=True)

# second, check if we need to import the annotation itself

# the annotation namespace is currently not actually used in the
# backends, which reconstruct the annotation from the annotation
# type directly. This could be changed in the future, and at
# the IR level it makes sense to include the dependency

if annotation.namespace.name != namespace.name:
namespace.add_imported_namespace(
annotation.namespace,
imported_annotation=True)


class DataType(object):
"""
Abstract class representing a data type.
Expand Down Expand Up @@ -118,6 +142,12 @@ class Composite(DataType): # pylint: disable=abstract-method
Composite types are any data type which can be constructed using primitive
data types and other composite types.
"""
def __init__(self):
super(Composite, self).__init__()
# contains custom annotations that apply to any containing data types (recursively)
# format is (location, CustomAnnotation) to indicate a custom annotation is applied
# to a location (Field or Alias)
self.recursive_custom_annotations = None


class Nullable(Composite):
Expand Down Expand Up @@ -781,22 +811,7 @@ def set_attributes(self, doc, fields, parent_type=None):
# they are treated as globals at the IR level
for field in self.fields:
for annotation in field.custom_annotations:
# first, check the annotation *type*
if annotation.annotation_type.namespace.name != self.namespace.name:
self.namespace.add_imported_namespace(
annotation.annotation_type.namespace,
imported_annotation_type=True)

# second, check if we need to import the annotation itself

# the annotation namespace is currently not actually used in the
# backends, which reconstruct the annotation from the annotation
# type directly. This could be changed in the future, and at
# the IR level it makes sense to include the dependency
if annotation.namespace.name != self.namespace.name:
self.namespace.add_imported_namespace(
annotation.namespace,
imported_annotation=True)
record_custom_annotation_imports(annotation, self.namespace)

# Indicate that the attributes of the type have been populated.
self._is_forward_ref = False
Expand Down Expand Up @@ -901,7 +916,6 @@ class Struct(UserDefined):
"""
Defines a product type: Composed of other primitive and/or struct types.
"""
# pylint: disable=attribute-defined-outside-init

composite_type = 'struct'

Expand Down Expand Up @@ -1359,7 +1373,6 @@ def __repr__(self):

class Union(UserDefined):
"""Defines a tagged union. Fields are variants."""
# pylint: disable=attribute-defined-outside-init

composite_type = 'union'

Expand Down Expand Up @@ -1830,25 +1843,7 @@ def set_annotations(self, annotations):
elif isinstance(annotation, CustomAnnotation):
# Note: we don't need to do this for builtin annotations because
# they are treated as globals at the IR level

# first, check the annotation *type*
if annotation.annotation_type.namespace.name != self.namespace.name:
self.namespace.add_imported_namespace(
annotation.annotation_type.namespace,
imported_annotation_type=True)

# second, check if we need to import the annotation itself

# the annotation namespace is currently not actually used in the
# backends, which reconstruct the annotation from the annotation
# type directly. This could be changed in the future, and at
# the IR level it makes sense to include the dependency

if annotation.namespace.name != self.namespace.name:
self.namespace.add_imported_namespace(
annotation.namespace,
imported_annotation=True)

record_custom_annotation_imports(annotation, self.namespace)
self.custom_annotations.append(annotation)
else:
raise InvalidSpec("Aliases only support 'Redacted' and custom annotations, not %r" %
Expand Down Expand Up @@ -2002,53 +1997,6 @@ def unwrap(data_type):
data_type = data_type.data_type
return data_type, unwrapped_nullable, unwrapped_alias

def get_custom_annotations_for_alias(data_type):
"""
Given a Stone data type, returns all custom annotations applied to it.
"""
# annotations can only be applied to Aliases, but they can be wrapped in
# Nullable. also, Aliases pointing to other Aliases don't automatically
# inherit their custom annotations, so we might have to traverse.
result = []
data_type, _ = unwrap_nullable(data_type)
while is_alias(data_type):
result.extend(data_type.custom_annotations)
data_type, _ = unwrap_nullable(data_type.data_type)
return result

def get_custom_annotations_recursive(data_type):
"""
Given a Stone data type, returns all custom annotations applied to any of
its memebers, as well as submembers, ..., to an arbitrary depth.
"""
# because Stone structs can contain references to themselves (or otherwise
# be cyclical), we need ot keep track of the data types we've already seen
data_types_seen = set()

def recurse(data_type):
if data_type in data_types_seen:
return
data_types_seen.add(data_type)

dt, _, _ = unwrap(data_type)
if is_struct_type(dt) or is_union_type(dt):
for field in dt.fields:
for annotation in recurse(field.data_type):
yield annotation
for annotation in field.custom_annotations:
yield annotation
elif is_list_type(dt):
for annotation in recurse(dt.data_type):
yield annotation
elif is_map_type(dt):
for annotation in recurse(dt.value_data_type):
yield annotation

for annotation in get_custom_annotations_for_alias(data_type):
yield annotation

return recurse(data_type)


def is_alias(data_type):
return isinstance(data_type, Alias)
Expand Down
2 changes: 0 additions & 2 deletions test/test_python_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,6 @@ def test_json_encoder(self):
self.assertEqual(json_encode(bv.Nullable(bv.String()), u'abc'), json.dumps('abc'))

def test_json_encoder_union(self):
# pylint: disable=attribute-defined-outside-init
class S(object):
_all_field_names_ = {'f'}
_all_fields_ = [('f', bv.String())]
Expand Down Expand Up @@ -331,7 +330,6 @@ def _get_val_data_type(cls, tag, cp):
self.assertEqual(json_encode(bv.Union(U), u, old_style=True), json.dumps({'g': m}))

def test_json_encoder_error_messages(self):
# pylint: disable=attribute-defined-outside-init
class S3(object):
_all_field_names_ = {'j'}
_all_fields_ = [('j', bv.UInt64(max_value=10))]
Expand Down
1 change: 1 addition & 0 deletions test/test_python_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def test_struct_with_custom_annotations(self):
StructField('unannotated_field', Int32(), None, None),
])
struct.fields[0].set_annotations([annotation])
struct.recursive_custom_annotations = set([annotation])

result = self._evaluate_struct(ns, struct)

Expand Down
46 changes: 46 additions & 0 deletions test/test_stone.py
Original file line number Diff line number Diff line change
Expand Up @@ -4886,6 +4886,52 @@ def test_custom_annotations(self):

struct = api.namespaces['test'].data_type_by_name['TestStruct']
self.assertEqual(struct.fields[0].custom_annotations[0], annotation)
self.assertEqual(struct.recursive_custom_annotations, set([
(alias, api.namespaces['test'].annotation_by_name['VeryImportant']),
(struct.fields[0], api.namespaces['test'].annotation_by_name['SortaImportant']),
]))

# Test recursive references are captured
ns2 = textwrap.dedent("""\
namespace testchain
import test
alias TestAliasChain = String
@test.SortaImportant
struct TestStructChain
f test.TestStruct
g List(TestAliasChain)
""")
ns3 = textwrap.dedent("""\
namespace teststruct
import testchain
struct TestStructToStruct
f testchain.TestStructChain
""")
ns4 = textwrap.dedent("""\
namespace testalias
import testchain
struct TestStructToAlias
f testchain.TestAliasChain
""")

api = specs_to_ir([('test.stone', text), ('testchain.stone', ns2),
('teststruct.stone', ns3), ('testalias.stone', ns4)])

struct_namespaces = [ns.name for ns in
api.namespaces['teststruct'].get_imported_namespaces(
consider_annotation_types=True)]
self.assertTrue('test' in struct_namespaces)
alias_namespaces = [ns.name for ns in
api.namespaces['testalias'].get_imported_namespaces(
consider_annotation_types=True)]
self.assertTrue('test' in alias_namespaces)


if __name__ == '__main__':
Expand Down

0 comments on commit fdf1761

Please sign in to comment.