Skip to content

Commit

Permalink
Improve argument name checks for build/compute_output_signature
Browse files Browse the repository at this point in the history
We want to support the case where a layer looks like this:

```python
class SomeLayer(keras_core.layers.Layer):
    def build(self, shape):
        ...

    def call(
        self,
        inputs,
        optional_tensor_arg_one=None,
        optional_tensor_arg_two=None,
        optional_tensor_arg_tree=None,
    ):
        ...
```

However, due to a bug, we would forward a dictionary of all passed
tensor shapes to `build()` even when `build()` only accepted one argument.

This tries to clean up the logic a bit.
  • Loading branch information
mattdangerw committed Jun 22, 2023
1 parent 471f02f commit e3c90e3
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 80 deletions.
141 changes: 73 additions & 68 deletions keras_core/layers/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,16 +778,14 @@ def compute_output_spec(self, *args, **kwargs):
else:
# Use compute_output_shape() to return the right output spec
call_spec = CallSpec(self.call, args, kwargs)
shapes_dict = get_shapes_dict(
self.compute_output_shape, call_spec, self.__class__
shapes_dict = get_shapes_dict(call_spec)
shapes_dict = update_shapes_dict_for_target_fn(
self.compute_output_shape,
shapes_dict=shapes_dict,
call_spec=call_spec,
class_name=self.__class__.__name__,
)
if len(shapes_dict) == 1:
# Single arg: pass it positionally
input_shape = tuple(shapes_dict.values())[0]
output_shape = self.compute_output_shape(input_shape)
else:
# More than one shape: pass them by name.
output_shape = self.compute_output_shape(**shapes_dict)
output_shape = self.compute_output_shape(**shapes_dict)

if (
isinstance(output_shape, list)
Expand Down Expand Up @@ -948,47 +946,40 @@ def count_params(self):

def _maybe_build(self, call_spec):
if not self.built:
shapes_dict = get_shapes_dict(self.build, call_spec, self.__class__)
shapes_dict = get_shapes_dict(call_spec)
self._build_shapes_dict = shapes_dict
failure = False

if len(shapes_dict) == 1:
# Single arg: pass it positionally
input_shape = tuple(shapes_dict.values())[0]
with backend.name_scope(self.name):
if utils.is_default(
self.build
) and might_have_unbuilt_state(self):
status = self._build_by_run_for_single_pos_arg(
input_shape

with backend.name_scope(self.name):
if not utils.is_default(self.build):
shapes_dict = update_shapes_dict_for_target_fn(
self.build,
shapes_dict=shapes_dict,
call_spec=call_spec,
class_name=self.__class__.__name__,
)
self.build(**shapes_dict)
elif might_have_unbuilt_state(self):
if len(shapes_dict) == 1:
# Single arg: pass it positionally
success = self._build_by_run_for_single_pos_arg(
tuple(shapes_dict.values())[0]
)
if not status:
failure = True
else:
self.build(input_shape)
else:
with backend.name_scope(self.name):
if utils.is_default(self.build):
if might_have_unbuilt_state(self):
status = self._build_by_run_for_kwargs(shapes_dict)
if not status:
failure = True
else:
self.build(**shapes_dict)
if failure:
if call_spec.eager:
# Will let the actual eager call do the state-building
return
raise ValueError(
f"Layer '{self.name}' looks like it has "
"unbuilt state, but Keras is not able to "
"trace the layer `call()` in order to "
"build it automatically. You must implement "
"the `def build(self, input_shape)` method on your "
"layer. It should create all variables used by the "
"layer (e.g. by calling `layer.build()` on all its "
"children layers)."
)
success = self._build_by_run_for_kwargs(shapes_dict)
if not success:
if call_spec.eager:
# Will let the actual eager call do state-building
return
raise ValueError(
f"Layer '{self.name}' looks like it has "
"unbuilt state, but Keras is not able to "
"trace the layer `call()` in order to "
"build it automatically. You must implement "
"the `def build(self, input_shape)` method on your "
"layer. It should create all variables used by the "
"layer (e.g. by calling `layer.build()` on all its "
"children layers)."
)
self.built = True

# Check input spec again (after build, since self.input_spec
Expand Down Expand Up @@ -1216,17 +1207,16 @@ def get_arguments_dict(fn, args, kwargs):
return arg_dict


def get_shapes_dict(target_fn, call_spec, cls):
def get_shapes_dict(call_spec):
"""Convert the call() arguments dict into a dict of input shape arguments.
Example:
```
>>> get_shapes_dict(self.build, call_spec, cls)
>>> get_shapes_dict(call_spec)
{"input_a_shape": (2, 3)}
```
"""
expected_names = check_shapes_signature(target_fn, call_spec, cls)
shapes_dict = {}
for k, v in call_spec.tensor_arguments_dict.items():
if k == "mask" or k.startswith("mask_"):
Expand All @@ -1235,8 +1225,6 @@ def get_shapes_dict(target_fn, call_spec, cls):
if k == "kwargs" or k == "args":
# Do not include catch-alls in shapes dict
continue
if expected_names is not None and f"{k}_shape" not in expected_names:
continue
if k in call_spec.nested_tensor_argument_names:
shapes_dict[f"{k}_shape"] = nest.map_structure(
lambda x: backend.standardize_shape(x.shape), v
Expand All @@ -1246,22 +1234,30 @@ def get_shapes_dict(target_fn, call_spec, cls):
return shapes_dict


def check_shapes_signature(target_fn, call_spec, cls):
"""Asserts that the argument names in `target_fn` match arguments in `call`.
def update_shapes_dict_for_target_fn(
target_fn,
shapes_dict,
call_spec,
class_name,
):
"""Updates a `shapes_dict` for `build()` or `compute_output_shape()`.
We use this to check that `build()` and `compute_output_shape()` arguments
align with `call()` arguments.
This function will align a dictionary of the shapes of all tensors
passed to `call` with the signatures of `build()` or
`compute_output_shape()`.
For instance if `build()` has the signature
`def build(self, a_shape, b_shape)` we expect `call()` to accept the
arguments `a` and `b`.
The alignment is as follows:
When there is a single argument accepted by `target_fn`, we do allow any
name and do not check the call signature.
- If `build()` or `compute_output_shape()` accept only one argument,
forward the shape of the first positional argument from call without
checking any argument names.
- If `build()` or `compute_output_shape()` accept multiple arguments,
enforce that all argument names match a call argument name, e.g.
`foo_shape` would match call argument `foo`.
Returns:
The list of arguments names expected by the `target_fn` or
`None` if any passed name is acceptable.
An updated `shapes_dict` that can be used to invoke
`target_fn(**shapes_dict)`.
"""
if utils.is_default(target_fn):
return None
Expand All @@ -1274,31 +1270,40 @@ def check_shapes_signature(target_fn, call_spec, cls):
param.KEYWORD_ONLY,
):
expected_names.append(name)

# Single arg: don't check names, pass first shape.
if len(expected_names) == 1:
return None
key = expected_names[0]
input_shape = tuple(shapes_dict.values())[0]
return {key: input_shape}

# Multiple args: check that all names line up.
kwargs = {}
for name in expected_names:
method_name = target_fn.__name__
error_preamble = (
f"For a `{method_name}()` method with more than one argument, all "
"arguments should have a `_shape` suffix and match an argument "
f"from `call()`. E.g. `{method_name}(self, foo_shape, bar_shape)` "
"would match `call(self, foo, bar)`."
)
if not name.endswith("_shape"):
raise ValueError(
f"{error_preamble} For layer '{cls.__name__}', "
f"{error_preamble} For layer '{class_name}', "
f"Received `{method_name}()` argument "
f"`{name}`, which does not end in `_shape`."
)
expected_call_arg = utils.removesuffix(name, "_shape")
if expected_call_arg not in call_spec.arguments_dict:
raise ValueError(
f"{error_preamble} For layer '{cls.__name__}', "
f"{error_preamble} For layer '{class_name}', "
f"received `{method_name}()` argument "
f"`{name}`, but `call()` does not have argument "
f"`{expected_call_arg}`."
)
return expected_names
if name in shapes_dict:
kwargs[name] = shapes_dict[name]

return kwargs


class CallContext:
Expand Down
61 changes: 49 additions & 12 deletions keras_core/layers/layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,40 +598,77 @@ def call(self, inputs):
def test_build_signature_errors(self):
class NoShapeSuffix(layers.Layer):
def build(self, foo_shape, bar):
self._built = True
self.built = True

def call(self, foo, bar):
return foo + bar

class NonMatchingArgument(layers.Layer):
def build(self, foo_shape, baz_shape):
self._built = True
self.built = True

def call(self, foo, bar):
return foo + bar
return foo[:, 0] + bar[:, 0]

class MatchingArguments(layers.Layer):
def build(self, foo_shape, bar_shape):
self._built = True
def build(self, bar_shape, foo_shape):
self.foo_shape = foo_shape
self.bar_shape = bar_shape
self.built = True

def call(self, foo, bar):
return foo + bar
return foo[:, 0] + bar[:, 0]

class SubsetArguments(layers.Layer):
def build(self, baz_shape, foo_shape):
self.foo_shape = foo_shape
self.baz_shape = baz_shape
self.built = True

foo = backend.numpy.ones((4, 4))
bar = backend.numpy.ones((4, 4))
def call(self, foo, bar=None, baz=None):
return foo[:, 0] + bar[:, 0] + baz[:, 0]

class SingleArgument(layers.Layer):
def build(self, anything_whatsoever):
self.foo_shape = anything_whatsoever
self.built = True

def call(self, foo, bar):
return foo[:, 0] + bar[:, 0]

foo = backend.numpy.ones((4, 1))
bar = backend.numpy.ones((4, 2))
baz = backend.numpy.ones((4, 3))
with self.assertRaisesRegex(
ValueError,
r"argument `bar`, which does not end in `_shape`",
):
NoShapeSuffix()(foo, bar)
layer = NoShapeSuffix()
layer(foo, bar)

with self.assertRaisesRegex(
ValueError,
r"`baz_shape`, but `call\(\)` does not have argument `baz`",
):
NonMatchingArgument()(foo, bar)

MatchingArguments()(foo, bar)
layer = NonMatchingArgument()
layer(foo, bar)

# Align by name when build and call arguments match.
layer = MatchingArguments()
layer(foo, bar)
self.assertEqual(layer.foo_shape, foo.shape)
self.assertEqual(layer.bar_shape, bar.shape)

# Align by name when build supports a subset of call arguments.
layer = SubsetArguments()
layer(foo, bar, baz)
self.assertEqual(layer.foo_shape, foo.shape)
self.assertEqual(layer.baz_shape, baz.shape)

# When build has only one argument, match the first call argument.
layer = SingleArgument()
layer(foo, bar)
self.assertEqual(layer.foo_shape, foo.shape)

def test_training_arg_not_specified(self):
class NoTrainingSpecified(layers.Layer):
Expand Down

0 comments on commit e3c90e3

Please sign in to comment.