Skip to content

Commit

Permalink
Merge branch 'main' into add_pytorch_DLRM_example
Browse files Browse the repository at this point in the history
  • Loading branch information
radekosmulski committed Jul 5, 2023
2 parents 489c705 + eddc2ab commit 7dfa115
Show file tree
Hide file tree
Showing 12 changed files with 96 additions and 83 deletions.
9 changes: 9 additions & 0 deletions merlin/models/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@
from merlin.models.torch.router import RouterBlock
from merlin.models.torch.transforms.agg import Concat, Stack

input_schema = schema.input_schema
output_schema = schema.output_schema
target_schema = schema.target_schema
feature_schema = schema.feature_schema

__all__ = [
"Batch",
"BinaryOutput",
Expand All @@ -55,6 +60,10 @@
"Concat",
"Stack",
"schema",
"input_schema",
"output_schema",
"feature_schema",
"target_schema",
"DLRMBlock",
"DLRMModel",
]
6 changes: 3 additions & 3 deletions merlin/models/torch/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,10 +375,10 @@ def sample_features(
return sample_batch(data, batch_size, shuffle).features


@schema.output.register_tensor(Batch)
@schema.output_schema.register_tensor(Batch)
def _(input):
output_schema = Schema()
output_schema += schema.output.tensors(input.features)
output_schema += schema.output.tensors(input.targets)
output_schema += schema.output_schema.tensors(input.features)
output_schema += schema.output_schema.tensors(input.targets)

return output_schema
26 changes: 13 additions & 13 deletions merlin/models/torch/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,31 +588,31 @@ def set_pre(module: nn.Module, pre: BlockContainer):
return set_pre(module[0], pre)


@schema.input.register(BlockContainer)
@schema.input_schema.register(BlockContainer)
def _(module: BlockContainer, input: Schema):
return schema.input(module[0], input) if module else input
return schema.input_schema(module[0], input) if module else input


@schema.input.register(ParallelBlock)
@schema.input_schema.register(ParallelBlock)
def _(module: ParallelBlock, input: Schema):
if module.pre:
return schema.input(module.pre)
return schema.input_schema(module.pre)

out_schema = Schema()
for branch in module.branches.values():
out_schema += schema.input(branch, input)
out_schema += schema.input_schema(branch, input)

return out_schema


@schema.output.register(ParallelBlock)
@schema.output_schema.register(ParallelBlock)
def _(module: ParallelBlock, input: Schema):
if module.post:
return schema.output(module.post, input)
return schema.output_schema(module.post, input)

output = Schema()
for name, branch in module.branches.items():
branch_schema = schema.output(branch, input)
branch_schema = schema.output_schema(branch, input)

if len(branch_schema) == 1 and branch_schema.first.name == "output":
branch_schema = Schema([branch_schema.first.with_name(name)])
Expand All @@ -622,9 +622,9 @@ def _(module: ParallelBlock, input: Schema):
return output


@schema.output.register(BlockContainer)
@schema.output_schema.register(BlockContainer)
def _(module: BlockContainer, input: Schema):
return schema.output(module[-1], input) if module else input
return schema.output_schema(module[-1], input) if module else input


BlockT = TypeVar("BlockT", bound=BlockContainer)
Expand Down Expand Up @@ -720,13 +720,13 @@ def _extract_block(main, selection, route, name=None):
if isinstance(main, ParallelBlock):
return _extract_parallel(main, selection, route=route, name=name)

main_schema = schema.input(main)
route_schema = schema.input(route)
main_schema = schema.input_schema(main)
route_schema = schema.input_schema(route)

if main_schema == route_schema:
from merlin.models.torch.inputs.select import SelectFeatures

out_schema = schema.output(main, main_schema)
out_schema = schema.output_schema(main, main_schema)
if len(out_schema) == 1 and out_schema.first.name == "output":
out_schema = Schema([out_schema.first.with_name(name)])

Expand Down
12 changes: 6 additions & 6 deletions merlin/models/torch/blocks/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import nn

from merlin.models.torch.block import Block
from merlin.models.torch.schema import Schema, output
from merlin.models.torch.schema import Schema, output_schema
from merlin.models.torch.transforms.agg import Concat, MaybeAgg


Expand Down Expand Up @@ -84,8 +84,8 @@ def __init__(
super().__init__(*modules)


@output.register(nn.LazyLinear)
@output.register(nn.Linear)
@output.register(MLPBlock)
def _output_schema_block(module: nn.LazyLinear, input: Schema):
return output.tensors(torch.ones((1, module.out_features), dtype=float))
@output_schema.register(nn.LazyLinear)
@output_schema.register(nn.Linear)
@output_schema.register(MLPBlock)
def _output_schema_block(module: nn.LazyLinear, inputs: Schema):
return output_schema.tensors(torch.ones((1, module.out_features), dtype=float))
4 changes: 2 additions & 2 deletions merlin/models/torch/inputs/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ def forward(self, inputs, batch: Batch) -> Dict[str, torch.Tensor]:

@schema.extract.register(SelectKeys)
def _(main, selection, route, name=None):
main_schema = schema.input(main)
route_schema = schema.input(route)
main_schema = schema.input_schema(main)
route_schema = schema.input_schema(route)

diff = main_schema.excluding_by_name(route_schema.column_names)

Expand Down
74 changes: 39 additions & 35 deletions merlin/models/torch/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __call__(self, module: nn.Module, inputs: Optional[Schema] = None) -> Schema
return super().__call__(module, inputs)
except NotImplementedError:
raise ValueError(
f"Could not get output schema of {module} " "please call mm.trace_schema first."
f"Could not get output schema of {module} " "please call `mm.schema.trace` first."
)

def trace(
Expand Down Expand Up @@ -127,7 +127,7 @@ def _func(module: nn.Module, input: Schema) -> Schema:

def __call__(self, module: nn.Module, inputs: Optional[Schema] = None) -> Schema:
try:
_inputs = input(module)
_inputs = input_schema(module)
inputs = _inputs
except ValueError:
pass
Expand Down Expand Up @@ -156,7 +156,7 @@ def __call__(self, module: nn.Module, inputs: Optional[Schema] = None) -> Schema
return super().__call__(module, inputs)
except NotImplementedError:
raise ValueError(
f"Could not get output schema of {module} " "please call mm.trace_schema first."
f"Could not get output schema of {module} " "please call `mm.schema.trace` first."
)

def trace(
Expand All @@ -165,7 +165,7 @@ def trace(
inputs: Union[torch.Tensor, Dict[str, torch.Tensor], Schema],
outputs: Union[torch.Tensor, Dict[str, torch.Tensor], Schema],
) -> Schema:
_input_schema = input.get_schema(inputs)
_input_schema = input_schema.get_schema(inputs)
_output_schema = self.get_schema(outputs)

try:
Expand Down Expand Up @@ -207,8 +207,8 @@ def extract(self, module: nn.Module, selection: Selection, route: nn.Module, nam
return fn(module, selection, route, name=name)


input = _InputSchemaDispatch("input_schema")
output = _OutputSchemaDispatch("output_schema")
input_schema = _InputSchemaDispatch("input_schema")
output_schema = _OutputSchemaDispatch("output_schema")
select = _SelectDispatch("selection")
extract = _ExtractDispatch("extract")

Expand Down Expand Up @@ -240,13 +240,13 @@ def _hook(mod: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor):
mod.__input_schemas = ()
mod.__output_schemas = ()

_input_schema = input.trace(mod, inputs[0])
_input_schema = input_schema.trace(mod, inputs[0])
if _input_schema not in mod.__input_schemas:
mod.__input_schemas += (_input_schema,)
mod.__output_schemas += (output.trace(mod, _input_schema, outputs),)
mod.__output_schemas += (output_schema.trace(mod, _input_schema, outputs),)

def add_hook(m):
custom_modules = list(output.dispatcher.registry.keys())
custom_modules = list(output_schema.dispatcher.registry.keys())
if m and isinstance(m, tuple(custom_modules[1:])):
return

Expand All @@ -261,7 +261,7 @@ def add_hook(m):
return module_out


def features(module: nn.Module) -> Schema:
def feature_schema(module: nn.Module) -> Schema:
"""Extract the feature schema from a PyTorch Module.
This function operates by applying the `get_feature_schema` method
Expand Down Expand Up @@ -293,7 +293,7 @@ def get_feature_schema(module):
return feature_schema


def targets(module: nn.Module) -> Schema:
def target_schema(module: nn.Module) -> Schema:
"""
Extract the target schema from a PyTorch Module.
Expand Down Expand Up @@ -484,7 +484,7 @@ def select(self, selection: Selection) -> "Selectable":
raise NotImplementedError()


@output.register_tensor(torch.Tensor)
@output_schema.register_tensor(torch.Tensor)
def _tensor_to_schema(input, name="output"):
kwargs = dict(dims=input.shape[1:], dtype=input.dtype)

Expand All @@ -494,13 +494,13 @@ def _tensor_to_schema(input, name="output"):
return Schema([ColumnSchema(name, **kwargs)])


@input.register_tensor(torch.Tensor)
@input_schema.register_tensor(torch.Tensor)
def _(input):
return _tensor_to_schema(input, "input")


@input.register_tensor(Dict[str, torch.Tensor])
@output.register_tensor(Dict[str, torch.Tensor])
@input_schema.register_tensor(Dict[str, torch.Tensor])
@output_schema.register_tensor(Dict[str, torch.Tensor])
def _(input):
output = Schema()
for k, v in sorted(input.items()):
Expand All @@ -509,23 +509,27 @@ def _(input):
return output


@input.register_tensor(Tuple[torch.Tensor])
@output.register_tensor(Tuple[torch.Tensor])
@input.register_tensor(Tuple[torch.Tensor, torch.Tensor])
@output.register_tensor(Tuple[torch.Tensor, torch.Tensor])
@input.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor])
@output.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor])
@input.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor])
@output.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor])
@input.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor])
@output.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor])
@input.register_tensor(
@input_schema.register_tensor(Tuple[torch.Tensor])
@output_schema.register_tensor(Tuple[torch.Tensor])
@input_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor])
@output_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor])
@input_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor])
@output_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor])
@input_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor])
@output_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor])
@input_schema.register_tensor(
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
)
@output_schema.register_tensor(
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
)
@input_schema.register_tensor(
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
)
@output.register_tensor(
@output_schema.register_tensor(
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
)
@input.register_tensor(
@input_schema.register_tensor(
Tuple[
torch.Tensor,
torch.Tensor,
Expand All @@ -536,7 +540,7 @@ def _(input):
torch.Tensor,
]
)
@output.register_tensor(
@output_schema.register_tensor(
Tuple[
torch.Tensor,
torch.Tensor,
Expand All @@ -547,7 +551,7 @@ def _(input):
torch.Tensor,
]
)
@input.register_tensor(
@input_schema.register_tensor(
Tuple[
torch.Tensor,
torch.Tensor,
Expand All @@ -559,7 +563,7 @@ def _(input):
torch.Tensor,
]
)
@output.register_tensor(
@output_schema.register_tensor(
Tuple[
torch.Tensor,
torch.Tensor,
Expand All @@ -571,7 +575,7 @@ def _(input):
torch.Tensor,
]
)
@input.register_tensor(
@input_schema.register_tensor(
Tuple[
torch.Tensor,
torch.Tensor,
Expand All @@ -584,7 +588,7 @@ def _(input):
torch.Tensor,
]
)
@output.register_tensor(
@output_schema.register_tensor(
Tuple[
torch.Tensor,
torch.Tensor,
Expand All @@ -597,7 +601,7 @@ def _(input):
torch.Tensor,
]
)
@input.register_tensor(
@input_schema.register_tensor(
Tuple[
torch.Tensor,
torch.Tensor,
Expand All @@ -611,7 +615,7 @@ def _(input):
torch.Tensor,
]
)
@output.register_tensor(
@output_schema.register_tensor(
Tuple[
torch.Tensor,
torch.Tensor,
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/torch/inputs/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def test_forward(self):

outputs = mm.schema.trace(block, self.batch.features["session_id"], batch=self.batch)
assert len(outputs) == 5
assert mm.schema.input(block).column_names == ["input"]
assert mm.schema.features(block).column_names == [
assert mm.input_schema(block).column_names == ["input"]
assert mm.feature_schema(block).column_names == [
"user_id",
"country",
"user_age",
Expand Down
Loading

0 comments on commit 7dfa115

Please sign in to comment.