From 08f2e696160303702b2df64193c888223e3e4363 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy <1216955+oliverholworthy@users.noreply.github.com> Date: Wed, 4 Oct 2023 13:26:12 +0100 Subject: [PATCH] Handle ColumnSchema `target` in serialization of `SequenceTransform` --- merlin/models/tf/transforms/sequence.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/merlin/models/tf/transforms/sequence.py b/merlin/models/tf/transforms/sequence.py index 312019cd92..b1090f6605 100644 --- a/merlin/models/tf/transforms/sequence.py +++ b/merlin/models/tf/transforms/sequence.py @@ -184,7 +184,10 @@ def compute_output_shape(self, input_shape): def get_config(self): """Returns the config of the layer as a Python dictionary.""" config = super().get_config() - config["target"] = self.target + target = self.target + if isinstance(target, ColumnSchema): + target = schema_utils.schema_to_tensorflow_metadata_json(Schema([target])) + config["target"] = target return config @@ -193,6 +196,10 @@ def from_config(cls, config): """Creates layer from its config. Returning the instance.""" config = tf_utils.maybe_deserialize_keras_objects(config, ["pre", "post", "aggregation"]) config["schema"] = schema_utils.tensorflow_metadata_json_to_schema(config["schema"]) + if config["target"].startswith("{"): # we have a schema + config["target"] = [ + col for col in schema_utils.tensorflow_metadata_json_to_schema(config["target"]) + ][0] schema = config.pop("schema") target = config.pop("target") return cls(schema, target, **config)