Enable lazy module initialization from schema for any module (#1186)
* Rename `setup_schema` to `initialize_from_schema`

* Add test for tracing with initialize_from_schema

* Infer feature schema from batch argument signature

* Add missing self

* Rename `setup_schema` in input and output blocks

---------

Co-authored-by: Marc Romeyn <[email protected]>
oliverholworthy and marcromeyn committed Jul 11, 2023
1 parent d2113e8 commit 219bedb
Showing 17 changed files with 213 additions and 85 deletions.
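
For context, a minimal sketch of the pattern this commit enables (illustrative only, not code taken from the diff; the class name is made up, and depending on the module the new `schema.LazySchemaModuleMixin` may also be involved, as in `SelectKeys` below). A module can either receive a schema at construction time or implement `initialize_from_schema` and be filled in later, when the schema becomes available during schema tracing of the model:

```python
from typing import List, Optional

from torch import nn

from merlin.schema import Schema


class MyLazyBlock(nn.Module):
    """Toy module following the lazy-initialization pattern from this commit."""

    def __init__(self, schema: Optional[Schema] = None):
        super().__init__()
        self.column_names: List[str] = []
        if schema:
            # Eager path: a schema was supplied, so initialize right away.
            self.initialize_from_schema(schema)
            self._initialized_from_schema = True

    def initialize_from_schema(self, schema: Schema):
        # Lazy path: called later (e.g. during schema tracing) when no
        # schema was passed to the constructor.
        self.schema = schema
        self.column_names = schema.column_names
```

With this in place, `MyLazyBlock()` can be added to a model before the dataset schema is known, and the tracing machinery calls `initialize_from_schema` on it once a schema is available.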
11 changes: 0 additions & 11 deletions merlin/models/torch/batch.py
@@ -20,8 +20,6 @@

from merlin.dataloader.torch import Loader
from merlin.io import Dataset
from merlin.models.torch import schema
from merlin.schema import Schema


@torch.jit.script
@@ -373,12 +371,3 @@ def sample_features(
"""

return sample_batch(data, batch_size, shuffle).features


@schema.output_schema.register_tensor(Batch)
def _(input):
output_schema = Schema()
output_schema += schema.output_schema.tensors(input.features)
output_schema += schema.output_schema.tensors(input.targets)

return output_schema
25 changes: 14 additions & 11 deletions merlin/models/torch/inputs/embedding.py
@@ -82,9 +82,12 @@ def __init__(
self.has_combiner = self.seq_combiner is not None
self.has_module_combiner = isinstance(self.seq_combiner, nn.Module)
self.num_embeddings = 0
self.setup_schema(schema or Schema())
self.input_schema = None
if schema:
self.initialize_from_schema(schema or Schema())
self._initialized_from_schema = True

def setup_schema(self, schema: Schema):
def initialize_from_schema(self, schema: Schema):
"""
Sets up the schema for the embedding table.
@@ -462,17 +465,17 @@ def __init__(
self.seq_combiner = seq_combiner
self.kwargs = kwargs
if isinstance(schema, Schema):
self.setup_schema(schema)
self.initialize_from_schema(schema)
self._initialized_from_schema = True

def setup_schema(self, schema: Schema):
"""
Sets up the schema for the embedding tables.
def initialize_from_schema(self, schema: Schema):
"""Initializes the module from a schema.
Called during the schema tracing of the model.
Args:
schema (Schema): The schema to setup.
Returns:
EmbeddingTables: The updated EmbeddingTables instance with the setup schema.
Parameters
----------
schema : Schema
The schema to initialize with
"""
self.schema = schema

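
A usage sketch of the lazy path above for embeddings (illustrative; the `dim` keyword, the dtype, and the domain properties are assumptions about the surrounding API and are not part of this diff):

```python
from merlin.models.torch.inputs.embedding import EmbeddingTables
from merlin.schema import ColumnSchema, Schema, Tags

user_id = ColumnSchema(
    "user_id",
    dtype="int64",
    tags=[Tags.CATEGORICAL],
    # Domain properties are assumed here so the table can infer its cardinality.
    properties={"domain": {"min": 0, "max": 100, "name": "user_id"}},
)
cat_schema = Schema([user_id])

# Eager: the schema is passed to the constructor, so `initialize_from_schema`
# runs immediately and `_initialized_from_schema` is set.
eager_tables = EmbeddingTables(dim=8, schema=cat_schema)

# Lazy: no schema yet; it is attached later, either explicitly as below or
# during schema tracing of the enclosing model.
lazy_tables = EmbeddingTables(dim=8)
lazy_tables.initialize_from_schema(cat_schema)
```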
41 changes: 24 additions & 17 deletions merlin/models/torch/inputs/select.py
@@ -14,7 +14,7 @@
# limitations under the License.
#

from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

import torch
from torch import nn
@@ -24,7 +24,7 @@
from merlin.schema import ColumnSchema, Schema, Tags


class SelectKeys(nn.Module, schema.Selectable):
class SelectKeys(nn.Module, schema.Selectable, schema.LazySchemaModuleMixin):
"""Filter tabular data based on a defined schema.
Example usage::
@@ -50,20 +50,24 @@ class SelectKeys(nn.Module, schema.Selectable):
List of column names in the schema.
"""

def __init__(self, schema: Optional[Schema] = None):
def __init__(self, schema: Optional[Union[Schema, ColumnSchema]] = None):
super().__init__()
self.column_names: List[str] = []
if schema:
self.setup_schema(schema)

def setup_schema(self, schema: Schema):
if isinstance(schema, ColumnSchema):
schema = Schema([schema])

if schema:
self.initialize_from_schema(schema)
self._initialized_from_schema = True
else:
schema = Schema()
self.schema = schema
self.column_names: List[str] = schema.column_names

def initialize_from_schema(self, schema: Schema):
super().initialize_from_schema(schema)
self.schema = schema
self.column_names = schema.column_names
self.input_schema = schema
self.output_schema = schema
self.column_names = schema.column_names

def select(self, selection: schema.Selection) -> "SelectKeys":
"""Select a subset of the schema based on the provided selection.
@@ -125,7 +129,7 @@ def __eq__(self, other) -> bool:
return set(self.column_names) == set(other.column_names)


class SelectFeatures(nn.Module):
class SelectFeatures(nn.Module, schema.LazySchemaModuleMixin):
"""Filter tabular data based on a defined schema.
It operates similarly to SelectKeys, but it uses the features from Batch.
@@ -154,21 +158,21 @@ def __init__(self, schema: Optional[Schema] = None):
super().__init__()
self.select_keys = SelectKeys(schema=schema)
if schema:
self.setup_schema(schema)
self.initialize_from_schema(schema)

def setup_schema(self, schema: Schema):
def initialize_from_schema(self, schema: Schema):
"""Set up the schema for the SelectFeatures.
Parameters
----------
schema : Schema
The schema to use for selection.
"""
self.select_keys.setup_schema(schema)
super().initialize_from_schema(schema)
self.select_keys.initialize_from_schema(schema)
self.embedding_names = schema.select_by_tag(Tags.EMBEDDING).column_names
self.input_schema = self.select_keys.input_schema
self.feature_schema = self.input_schema
self.output_schema = self.select_keys.output_schema
self.input_schema = Schema()
self.output_schema = schema

def select(self, selection: schema.Selection) -> "SelectFeatures":
"""Select a subset of the schema based on the provided selection.
@@ -187,6 +191,9 @@ def select(self, selection: schema.Selection) -> "SelectFeatures":

return SelectFeatures(schema)

def compute_feature_schema(self, feature_schema: Schema) -> Schema:
return feature_schema[self.select_keys.column_names]

def forward(self, inputs, batch: Batch) -> Dict[str, torch.Tensor]:
outputs = {}
selected = self.select_keys(batch.features)
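
A sketch of the new `SelectKeys` behaviour (illustrative column names; the attribute checks follow what the hunk above sets):

```python
from merlin.models.torch.inputs.select import SelectKeys
from merlin.schema import ColumnSchema, Schema, Tags

user_id = ColumnSchema("user_id", tags=[Tags.CATEGORICAL])
age = ColumnSchema("age", tags=[Tags.CONTINUOUS])

# A single ColumnSchema is now accepted and wrapped into a Schema.
select_user = SelectKeys(user_id)
assert select_user.column_names == ["user_id"]

# Lazy: build without a schema and initialize from one later; this is what
# schema tracing does for modules nested inside a model.
select_lazy = SelectKeys()
select_lazy.initialize_from_schema(Schema([user_id, age]))
assert select_lazy.column_names == ["user_id", "age"]
```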
4 changes: 2 additions & 2 deletions merlin/models/torch/inputs/tabular.py
@@ -59,8 +59,8 @@ def __init__(
self.agg = agg
super().__init__(schema)

def setup_schema(self, schema: Schema):
super().setup_schema(schema)
def initialize_from_schema(self, schema: Schema):
super().initialize_from_schema(schema)
self.schema: Schema = self.selectable.schema
if self.init:
if isinstance(self.init, str):
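
A usage sketch for `TabularInputBlock` (illustrative; `init="defaults"` is assumed to be a registered initializer name, and the column dtypes and domain properties are made up, none of which appear in this diff):

```python
from merlin.models.torch.inputs.tabular import TabularInputBlock
from merlin.schema import ColumnSchema, Schema, Tags

user_id = ColumnSchema(
    "user_id",
    dtype="int64",
    tags=[Tags.CATEGORICAL],
    properties={"domain": {"min": 0, "max": 100, "name": "user_id"}},
)
age = ColumnSchema("age", dtype="float32", tags=[Tags.CONTINUOUS])

# `initialize_from_schema` runs from the constructor here; because an `init`
# is given, it is applied to the selected schema as shown in the hunk above.
inputs = TabularInputBlock(Schema([user_id, age]), init="defaults")
```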
10 changes: 6 additions & 4 deletions merlin/models/torch/outputs/classification.py
@@ -58,12 +58,13 @@ def __init__(
metrics=metrics or [m(task="binary") for m in self.DEFAULT_METRICS_CLS],
)
if schema:
self.setup_schema(schema)
self.initialize_from_schema(schema)
self._initialized_from_schema = True

if not self.metrics:
self.metrics = self.default_metrics()

def setup_schema(self, target: Optional[Union[ColumnSchema, Schema]]):
def initialize_from_schema(self, target: Optional[Union[ColumnSchema, Schema]]):
"""Set up the schema for the output.
Parameters
@@ -138,7 +139,8 @@ def __init__(
)

if schema:
self.setup_schema(schema)
self.initialize_from_schema(schema)
self._initialized_from_schema = True

@classmethod
def with_weight_tying(
@@ -168,7 +170,7 @@ def tie_weights(

return self

def setup_schema(self, target: Optional[Union[ColumnSchema, Schema]]):
def initialize_from_schema(self, target: Optional[Union[ColumnSchema, Schema]]):
"""Set up the schema for the output.
Parameters
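
A sketch of the output-block pattern (illustrative target column; the `schema` constructor keyword follows the hunk above, and the same eager/lazy split applies to `CategoricalOutput` here and to `RegressionOutput` in the next file):

```python
from merlin.models.torch.outputs.classification import BinaryOutput
from merlin.schema import ColumnSchema, Tags

click = ColumnSchema("click", tags=[Tags.BINARY_CLASSIFICATION, Tags.TARGET])

# Eager: the target is known when the output block is created.
output = BinaryOutput(schema=click)

# Lazy: create the block first and attach the target later, e.g. when the
# model is traced against a dataset schema.
lazy_output = BinaryOutput()
lazy_output.initialize_from_schema(click)
```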
5 changes: 3 additions & 2 deletions merlin/models/torch/outputs/regression.py
@@ -52,12 +52,13 @@ def __init__(
metrics=metrics or [m() for m in self.DEFAULT_METRICS_CLS],
)
if schema:
self.setup_schema(schema)
self.initialize_from_schema(schema)
self._initialized_from_schema = True

if not self.metrics:
self.metrics = self.default_metrics()

def setup_schema(self, target: Optional[Union[ColumnSchema, Schema]]):
def initialize_from_schema(self, target: Optional[Union[ColumnSchema, Schema]]):
"""Set up the schema for the output.
Parameters
4 changes: 2 additions & 2 deletions merlin/models/torch/outputs/tabular.py
@@ -57,10 +57,10 @@ def __init__(
self.init = init
super().__init__(schema, prepend_routing_module=False)

def setup_schema(self, schema: Schema):
def initialize_from_schema(self, schema: Schema):
if self.selection:
schema = select(schema, self.selection)
super().setup_schema(schema)
super().initialize_from_schema(schema)
self.schema: Schema = self.selectable.schema
if self.init:
if isinstance(self.init, str):
6 changes: 3 additions & 3 deletions merlin/models/torch/router.py
@@ -51,11 +51,11 @@ def __init__(self, selectable: schema.Selectable, prepend_routing_module: bool =
super().__init__()
self.prepend_routing_module = prepend_routing_module
if isinstance(selectable, Schema):
self.setup_schema(selectable)
self.initialize_from_schema(selectable)
else:
self.selectable: schema.Selectable = selectable

def setup_schema(self, schema: Schema):
def initialize_from_schema(self, schema):
from merlin.models.torch.inputs.select import SelectKeys

self.selectable = SelectKeys(schema)
@@ -99,7 +99,7 @@ def add_route(
return self

if module is not None:
schema.setup_schema(module, routing_module.schema)
schema.initialize_from_schema(module, routing_module.schema)

if self.prepend_routing_module:
if isinstance(module, ParallelBlock):
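
A sketch of how `add_route` now wires up a lazily constructed module (illustrative; `add_route`'s full signature beyond what this hunk shows is assumed):

```python
from merlin.models.torch.inputs.select import SelectKeys
from merlin.models.torch.router import RouterBlock
from merlin.schema import ColumnSchema, Schema, Tags

schema = Schema(
    [
        ColumnSchema("user_id", tags=[Tags.CATEGORICAL]),
        ColumnSchema("age", tags=[Tags.CONTINUOUS]),
    ]
)

router = RouterBlock(schema)

# The route module starts without a schema; `add_route` selects the continuous
# columns and calls `schema.initialize_from_schema(module, routing_module.schema)`,
# so the SelectKeys instance is initialized lazily with just those columns.
router.add_route(Tags.CONTINUOUS, SelectKeys())
```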