diff --git a/merlin/models/torch/__init__.py b/merlin/models/torch/__init__.py index 025c8ba0dc..603d6e2892 100644 --- a/merlin/models/torch/__init__.py +++ b/merlin/models/torch/__init__.py @@ -31,6 +31,11 @@ from merlin.models.torch.router import RouterBlock from merlin.models.torch.transforms.agg import Concat, Stack +input_schema = schema.input_schema +output_schema = schema.output_schema +target_schema = schema.target_schema +feature_schema = schema.feature_schema + __all__ = [ "Batch", "BinaryOutput", @@ -55,6 +60,10 @@ "Concat", "Stack", "schema", + "input_schema", + "output_schema", + "feature_schema", + "target_schema", "DLRMBlock", "DLRMModel", ] diff --git a/merlin/models/torch/batch.py b/merlin/models/torch/batch.py index 72d813c00c..14e21ec5af 100644 --- a/merlin/models/torch/batch.py +++ b/merlin/models/torch/batch.py @@ -375,10 +375,10 @@ def sample_features( return sample_batch(data, batch_size, shuffle).features -@schema.output.register_tensor(Batch) +@schema.output_schema.register_tensor(Batch) def _(input): output_schema = Schema() - output_schema += schema.output.tensors(input.features) - output_schema += schema.output.tensors(input.targets) + output_schema += schema.output_schema.tensors(input.features) + output_schema += schema.output_schema.tensors(input.targets) return output_schema diff --git a/merlin/models/torch/block.py b/merlin/models/torch/block.py index 42dede5b9b..cf5bea6f29 100644 --- a/merlin/models/torch/block.py +++ b/merlin/models/torch/block.py @@ -588,31 +588,31 @@ def set_pre(module: nn.Module, pre: BlockContainer): return set_pre(module[0], pre) -@schema.input.register(BlockContainer) +@schema.input_schema.register(BlockContainer) def _(module: BlockContainer, input: Schema): - return schema.input(module[0], input) if module else input + return schema.input_schema(module[0], input) if module else input -@schema.input.register(ParallelBlock) +@schema.input_schema.register(ParallelBlock) def _(module: ParallelBlock, input: Schema): if module.pre: - return schema.input(module.pre) + return schema.input_schema(module.pre) out_schema = Schema() for branch in module.branches.values(): - out_schema += schema.input(branch, input) + out_schema += schema.input_schema(branch, input) return out_schema -@schema.output.register(ParallelBlock) +@schema.output_schema.register(ParallelBlock) def _(module: ParallelBlock, input: Schema): if module.post: - return schema.output(module.post, input) + return schema.output_schema(module.post, input) output = Schema() for name, branch in module.branches.items(): - branch_schema = schema.output(branch, input) + branch_schema = schema.output_schema(branch, input) if len(branch_schema) == 1 and branch_schema.first.name == "output": branch_schema = Schema([branch_schema.first.with_name(name)]) @@ -622,9 +622,9 @@ def _(module: ParallelBlock, input: Schema): return output -@schema.output.register(BlockContainer) +@schema.output_schema.register(BlockContainer) def _(module: BlockContainer, input: Schema): - return schema.output(module[-1], input) if module else input + return schema.output_schema(module[-1], input) if module else input BlockT = TypeVar("BlockT", bound=BlockContainer) @@ -720,13 +720,13 @@ def _extract_block(main, selection, route, name=None): if isinstance(main, ParallelBlock): return _extract_parallel(main, selection, route=route, name=name) - main_schema = schema.input(main) - route_schema = schema.input(route) + main_schema = schema.input_schema(main) + route_schema = schema.input_schema(route) if main_schema == route_schema: from merlin.models.torch.inputs.select import SelectFeatures - out_schema = schema.output(main, main_schema) + out_schema = schema.output_schema(main, main_schema) if len(out_schema) == 1 and out_schema.first.name == "output": out_schema = Schema([out_schema.first.with_name(name)]) diff --git a/merlin/models/torch/blocks/mlp.py b/merlin/models/torch/blocks/mlp.py index e7e9c1334d..8038dc89f7 100644 --- a/merlin/models/torch/blocks/mlp.py +++ b/merlin/models/torch/blocks/mlp.py @@ -4,7 +4,7 @@ from torch import nn from merlin.models.torch.block import Block -from merlin.models.torch.schema import Schema, output +from merlin.models.torch.schema import Schema, output_schema from merlin.models.torch.transforms.agg import Concat, MaybeAgg @@ -84,8 +84,8 @@ def __init__( super().__init__(*modules) -@output.register(nn.LazyLinear) -@output.register(nn.Linear) -@output.register(MLPBlock) -def _output_schema_block(module: nn.LazyLinear, input: Schema): - return output.tensors(torch.ones((1, module.out_features), dtype=float)) +@output_schema.register(nn.LazyLinear) +@output_schema.register(nn.Linear) +@output_schema.register(MLPBlock) +def _output_schema_block(module: nn.LazyLinear, inputs: Schema): + return output_schema.tensors(torch.ones((1, module.out_features), dtype=float)) diff --git a/merlin/models/torch/inputs/select.py b/merlin/models/torch/inputs/select.py index 6456e807a6..a2c04261af 100644 --- a/merlin/models/torch/inputs/select.py +++ b/merlin/models/torch/inputs/select.py @@ -201,8 +201,8 @@ def forward(self, inputs, batch: Batch) -> Dict[str, torch.Tensor]: @schema.extract.register(SelectKeys) def _(main, selection, route, name=None): - main_schema = schema.input(main) - route_schema = schema.input(route) + main_schema = schema.input_schema(main) + route_schema = schema.input_schema(route) diff = main_schema.excluding_by_name(route_schema.column_names) diff --git a/merlin/models/torch/schema.py b/merlin/models/torch/schema.py index d140eff670..f76e99f91d 100644 --- a/merlin/models/torch/schema.py +++ b/merlin/models/torch/schema.py @@ -97,7 +97,7 @@ def __call__(self, module: nn.Module, inputs: Optional[Schema] = None) -> Schema return super().__call__(module, inputs) except NotImplementedError: raise ValueError( - f"Could not get output schema of {module} " "please call mm.trace_schema first." + f"Could not get output schema of {module} " "please call `mm.schema.trace` first." ) def trace( @@ -127,7 +127,7 @@ def _func(module: nn.Module, input: Schema) -> Schema: def __call__(self, module: nn.Module, inputs: Optional[Schema] = None) -> Schema: try: - _inputs = input(module) + _inputs = input_schema(module) inputs = _inputs except ValueError: pass @@ -156,7 +156,7 @@ def __call__(self, module: nn.Module, inputs: Optional[Schema] = None) -> Schema return super().__call__(module, inputs) except NotImplementedError: raise ValueError( - f"Could not get output schema of {module} " "please call mm.trace_schema first." + f"Could not get output schema of {module} " "please call `mm.schema.trace` first." ) def trace( @@ -165,7 +165,7 @@ def trace( inputs: Union[torch.Tensor, Dict[str, torch.Tensor], Schema], outputs: Union[torch.Tensor, Dict[str, torch.Tensor], Schema], ) -> Schema: - _input_schema = input.get_schema(inputs) + _input_schema = input_schema.get_schema(inputs) _output_schema = self.get_schema(outputs) try: @@ -207,8 +207,8 @@ def extract(self, module: nn.Module, selection: Selection, route: nn.Module, nam return fn(module, selection, route, name=name) -input = _InputSchemaDispatch("input_schema") -output = _OutputSchemaDispatch("output_schema") +input_schema = _InputSchemaDispatch("input_schema") +output_schema = _OutputSchemaDispatch("output_schema") select = _SelectDispatch("selection") extract = _ExtractDispatch("extract") @@ -240,13 +240,13 @@ def _hook(mod: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor): mod.__input_schemas = () mod.__output_schemas = () - _input_schema = input.trace(mod, inputs[0]) + _input_schema = input_schema.trace(mod, inputs[0]) if _input_schema not in mod.__input_schemas: mod.__input_schemas += (_input_schema,) - mod.__output_schemas += (output.trace(mod, _input_schema, outputs),) + mod.__output_schemas += (output_schema.trace(mod, _input_schema, outputs),) def add_hook(m): - custom_modules = list(output.dispatcher.registry.keys()) + custom_modules = list(output_schema.dispatcher.registry.keys()) if m and isinstance(m, tuple(custom_modules[1:])): return @@ -261,7 +261,7 @@ def add_hook(m): return module_out -def features(module: nn.Module) -> Schema: +def feature_schema(module: nn.Module) -> Schema: """Extract the feature schema from a PyTorch Module. This function operates by applying the `get_feature_schema` method @@ -293,7 +293,7 @@ def get_feature_schema(module): return feature_schema -def targets(module: nn.Module) -> Schema: +def target_schema(module: nn.Module) -> Schema: """ Extract the target schema from a PyTorch Module. @@ -484,7 +484,7 @@ def select(self, selection: Selection) -> "Selectable": raise NotImplementedError() -@output.register_tensor(torch.Tensor) +@output_schema.register_tensor(torch.Tensor) def _tensor_to_schema(input, name="output"): kwargs = dict(dims=input.shape[1:], dtype=input.dtype) @@ -494,13 +494,13 @@ def _tensor_to_schema(input, name="output"): return Schema([ColumnSchema(name, **kwargs)]) -@input.register_tensor(torch.Tensor) +@input_schema.register_tensor(torch.Tensor) def _(input): return _tensor_to_schema(input, "input") -@input.register_tensor(Dict[str, torch.Tensor]) -@output.register_tensor(Dict[str, torch.Tensor]) +@input_schema.register_tensor(Dict[str, torch.Tensor]) +@output_schema.register_tensor(Dict[str, torch.Tensor]) def _(input): output = Schema() for k, v in sorted(input.items()): @@ -509,23 +509,27 @@ def _(input): return output -@input.register_tensor(Tuple[torch.Tensor]) -@output.register_tensor(Tuple[torch.Tensor]) -@input.register_tensor(Tuple[torch.Tensor, torch.Tensor]) -@output.register_tensor(Tuple[torch.Tensor, torch.Tensor]) -@input.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -@output.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -@input.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -@output.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -@input.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -@output.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -@input.register_tensor( +@input_schema.register_tensor(Tuple[torch.Tensor]) +@output_schema.register_tensor(Tuple[torch.Tensor]) +@input_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor]) +@output_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor]) +@input_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) +@output_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) +@input_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) +@output_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) +@input_schema.register_tensor( + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] +) +@output_schema.register_tensor( + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] +) +@input_schema.register_tensor( Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] ) -@output.register_tensor( +@output_schema.register_tensor( Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] ) -@input.register_tensor( +@input_schema.register_tensor( Tuple[ torch.Tensor, torch.Tensor, @@ -536,7 +540,7 @@ def _(input): torch.Tensor, ] ) -@output.register_tensor( +@output_schema.register_tensor( Tuple[ torch.Tensor, torch.Tensor, @@ -547,7 +551,7 @@ def _(input): torch.Tensor, ] ) -@input.register_tensor( +@input_schema.register_tensor( Tuple[ torch.Tensor, torch.Tensor, @@ -559,7 +563,7 @@ def _(input): torch.Tensor, ] ) -@output.register_tensor( +@output_schema.register_tensor( Tuple[ torch.Tensor, torch.Tensor, @@ -571,7 +575,7 @@ def _(input): torch.Tensor, ] ) -@input.register_tensor( +@input_schema.register_tensor( Tuple[ torch.Tensor, torch.Tensor, @@ -584,7 +588,7 @@ def _(input): torch.Tensor, ] ) -@output.register_tensor( +@output_schema.register_tensor( Tuple[ torch.Tensor, torch.Tensor, @@ -597,7 +601,7 @@ def _(input): torch.Tensor, ] ) -@input.register_tensor( +@input_schema.register_tensor( Tuple[ torch.Tensor, torch.Tensor, @@ -611,7 +615,7 @@ def _(input): torch.Tensor, ] ) -@output.register_tensor( +@output_schema.register_tensor( Tuple[ torch.Tensor, torch.Tensor, diff --git a/tests/unit/torch/inputs/test_select.py b/tests/unit/torch/inputs/test_select.py index 37c15c784f..7dcac2114e 100644 --- a/tests/unit/torch/inputs/test_select.py +++ b/tests/unit/torch/inputs/test_select.py @@ -72,8 +72,8 @@ def test_forward(self): outputs = mm.schema.trace(block, self.batch.features["session_id"], batch=self.batch) assert len(outputs) == 5 - assert mm.schema.input(block).column_names == ["input"] - assert mm.schema.features(block).column_names == [ + assert mm.input_schema(block).column_names == ["input"] + assert mm.feature_schema(block).column_names == [ "user_id", "country", "user_age", diff --git a/tests/unit/torch/inputs/test_tabular.py b/tests/unit/torch/inputs/test_tabular.py index a3ee2fb0a3..e81fe44ce4 100644 --- a/tests/unit/torch/inputs/test_tabular.py +++ b/tests/unit/torch/inputs/test_tabular.py @@ -68,27 +68,27 @@ def test_extract_route_two_tower(self): "item_recency", "item_genres", } - assert set(mm.schema.input(towers).column_names) == input_cols - assert mm.schema.output(towers).column_names == ["user", "item"] + assert set(mm.input_schema(towers).column_names) == input_cols + assert mm.output_schema(towers).column_names == ["user", "item"] categorical = towers.select(Tags.CATEGORICAL) outputs = module_utils.module_test(towers, self.batch) assert mm.schema.extract(towers, Tags.CATEGORICAL)[1] == categorical - assert set(mm.schema.input(towers).column_names) == input_cols - assert mm.schema.output(towers).column_names == ["user", "item"] + assert set(mm.input_schema(towers).column_names) == input_cols + assert mm.output_schema(towers).column_names == ["user", "item"] outputs = towers(self.batch.features) assert outputs["user"].shape == (10, 10) assert outputs["item"].shape == (10, 10) new_inputs, route = mm.schema.extract(towers, Tags.USER) - assert mm.schema.output(new_inputs).column_names == ["user", "item"] + assert mm.output_schema(new_inputs).column_names == ["user", "item"] assert "user" in new_inputs.branches assert new_inputs.branches["user"][0].select_keys.column_names == ["user"] assert "user" in route.branches - assert mm.schema.output(route).select_by_tag(Tags.EMBEDDING).column_names == ["user"] + assert mm.output_schema(route).select_by_tag(Tags.EMBEDDING).column_names == ["user"] def test_extract_route_embeddings(self): input_block = mm.TabularInputBlock(self.schema, init="defaults", agg="concat") @@ -97,7 +97,7 @@ def test_extract_route_embeddings(self): assert outputs.shape == (10, 107) no_embs, emb_route = mm.schema.extract(input_block, Tags.CATEGORICAL) - output_schema = mm.schema.output(emb_route) + output_schema = mm.output_schema(emb_route) assert len(output_schema.select_by_tag(Tags.USER)) == 3 assert len(output_schema.select_by_tag(Tags.ITEM)) == 3 diff --git a/tests/unit/torch/models/test_base.py b/tests/unit/torch/models/test_base.py index 2ee931989d..c67c65f2f8 100644 --- a/tests/unit/torch/models/test_base.py +++ b/tests/unit/torch/models/test_base.py @@ -197,7 +197,7 @@ def test_output_schema(self): "b": torch.tensor([[5.0, 6.0], [7.0, 8.0]]), } outputs = mm.schema.trace(model, inputs) - schema = mm.schema.output(model) + schema = mm.output_schema(model) for name in outputs: assert name in schema.column_names assert schema[name].dtype.name == str(outputs[name].dtype).split(".")[-1] @@ -205,7 +205,7 @@ def test_output_schema(self): def test_no_output_schema(self): model = mm.Model(PlusOne()) with pytest.raises(ValueError, match="Could not get output schema of PlusOne()"): - mm.schema.output(model) + mm.output_schema(model) def test_train_classification_with_lightning_trainer(self, music_streaming_data, batch_size=16): schema = music_streaming_data.schema.select_by_name( diff --git a/tests/unit/torch/test_block.py b/tests/unit/torch/test_block.py index ea36aaa412..02eb342f61 100644 --- a/tests/unit/torch/test_block.py +++ b/tests/unit/torch/test_block.py @@ -57,7 +57,7 @@ def test_identity(self): outputs = module_utils.module_test(block, inputs, batch=Batch(inputs)) assert torch.equal(inputs, outputs) - assert mm.schema.output(block) == mm.schema.output.tensors(inputs) + assert mm.output_schema(block) == mm.output_schema.tensors(inputs) def test_insertion(self): block = Block() @@ -158,7 +158,7 @@ def test_schema_tracking(self): inputs = torch.randn(1, 3) outputs = mm.schema.trace(pb, inputs) - schema = mm.schema.output(pb) + schema = mm.output_schema(pb) for name in outputs: assert name in schema.column_names @@ -258,9 +258,9 @@ def test_set_pre(self): def test_input_schema_pre(self): pb = ParallelBlock({"a": PlusOne(), "b": PlusOne()}) outputs = mm.schema.trace(pb, torch.randn(1, 3)) - input_schema = mm.schema.input(pb) + input_schema = mm.input_schema(pb) assert len(input_schema) == 1 - assert len(mm.schema.output(pb)) == 2 + assert len(mm.output_schema(pb)) == 2 assert len(outputs) == 2 pb2 = ParallelBlock({"a": PlusOne(), "b": PlusOne()}) @@ -270,8 +270,8 @@ def test_input_schema_pre(self): assert get_pre(pb2)[0] == pb pb2.append(pb) - assert input_schema == mm.schema.input(pb2) - assert mm.schema.output(pb2) == mm.schema.output(pb) + assert input_schema == mm.input_schema(pb2) + assert mm.output_schema(pb2) == mm.output_schema(pb) def test_leaf(self): block = ParallelBlock({"a": PlusOne()}) diff --git a/tests/unit/torch/test_router.py b/tests/unit/torch/test_router.py index 89f9292f6c..76459ea8db 100644 --- a/tests/unit/torch/test_router.py +++ b/tests/unit/torch/test_router.py @@ -162,4 +162,4 @@ def test_nested(self): outputs = module_utils.module_test(nested, self.batch.features) assert list(outputs.keys()) == ["user_age"] - assert "user_age" in mm.schema.output(nested).column_names + assert "user_age" in mm.output_schema(nested).column_names diff --git a/tests/unit/torch/test_schema.py b/tests/unit/torch/test_schema.py index d6919c2617..78ca1811ec 100644 --- a/tests/unit/torch/test_schema.py +++ b/tests/unit/torch/test_schema.py @@ -19,11 +19,11 @@ from merlin.models.torch.schema import ( Selectable, - features, + feature_schema, select, select_schema, selection_name, - targets, + target_schema, ) from merlin.schema import ColumnSchema, Schema, Tags @@ -122,8 +122,8 @@ def test_features(self): schema = Schema([ColumnSchema("a"), ColumnSchema("b")]) module = MockModule(feature_schema=schema) - assert features(module) == schema - assert targets(module) == Schema() + assert feature_schema(module) == schema + assert target_schema(module) == Schema() class TestTargets: @@ -131,5 +131,5 @@ def test_targets(self): schema = Schema([ColumnSchema("a"), ColumnSchema("b")]) module = MockModule(target_schema=schema) - assert targets(module) == schema - assert features(module) == Schema() + assert target_schema(module) == schema + assert feature_schema(module) == Schema()