From 233aade18bbf0d2f8c8a9bc8de173efa95ad6ccc Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 6 Mar 2023 18:44:06 +0000 Subject: [PATCH 01/11] Split Dimension into ragged and fixed --- merlin/dtypes/shape.py | 78 ++++++++++++++++++++++++++++++++---------- 1 file changed, 60 insertions(+), 18 deletions(-) diff --git a/merlin/dtypes/shape.py b/merlin/dtypes/shape.py index 1222aef7d..eac096a54 100644 --- a/merlin/dtypes/shape.py +++ b/merlin/dtypes/shape.py @@ -24,8 +24,35 @@ class DefaultShapes(Enum): SCALAR = (None,) +def Dimension(size=None): + """Create a dimension from a size. + + A size can be one of: + - None : a ragged dimension of unknown size + - int : a fixed dimension of some size (-1 = unknown) + - 2-tuple : the bounds of a ragged dimension (fixed if min == max) + """ + if isinstance(size, BaseDimension): + return size + elif isinstance(size, tuple) and len(size) == 2: + if size[0] == size[1]: + return FixedDimension(size[0], size[1]) + return RaggedDimension(size[0], size[1]) + elif isinstance(size, int): + if size == -1: + return FixedDimension() + return FixedDimension(size, size) + elif size is None: + return RaggedDimension() + else: + raise ValueError( + f"Invalid dimension format: {size}. Each dimension is expected " + " to be None, a single integer, or a tuple with length 2." + ) + + @dataclass(frozen=True) -class Dimension: +class BaseDimension: """ The range of potential sizes for a single dimension of a field or column """ @@ -58,12 +85,12 @@ def is_bounded(self): return self.max is not None @property - def is_fixed(self): + def is_uniform(self): return self.is_bounded and self.min == self.max @property def is_variable(self): - return not self.is_fixed + return not self.is_uniform @property def is_unknown(self): @@ -76,6 +103,29 @@ def with_max(self, value): return replace(self, max=value) +class RaggedDimension(BaseDimension): + @property + def is_fixed(self): + return False + + @property + def size(self): + return None + + +class FixedDimension(BaseDimension): + @property + def is_fixed(self): + return True + + @property + def size(self): + if self.is_uniform: + return self.max + else: + return -1 + + @dataclass(frozen=True) class Shape: """ @@ -94,19 +144,7 @@ def __post_init__(self): if self.dims is not None: new_dims = [] for i, dim in enumerate(self.dims): - if isinstance(dim, Dimension): - new_dim = dim - elif isinstance(dim, tuple) and len(dim) == 2: - new_dim = Dimension(dim[0], dim[1]) - elif isinstance(dim, int): - new_dim = Dimension(dim, dim) - elif dim is None: - new_dim = Dimension() - else: - raise ValueError( - f"Invalid shape tuple format: {self.dims}. Each dimension is expected " - " to be None, a single integer, or a tuple with length 2." - ) + new_dim = Dimension(dim) new_dims.append(new_dim) object.__setattr__(self, "dims", tuple(new_dims)) @@ -156,9 +194,13 @@ def is_bounded(self): def is_fixed(self): return all(dim.is_fixed for dim in self.dims) + @property + def is_uniform(self): + return all(dim.is_uniform for dim in self.dims) + @property def is_variable(self): - return not self.is_fixed + return not self.is_uniform @property def is_list(self): @@ -166,7 +208,7 @@ def is_list(self): @property def is_ragged(self): - return self.is_list and any(dim.min != dim.max for dim in self.dims[1:]) + return self.is_list and any(not dim.is_fixed for dim in self.dims[1:]) @property def as_tuple(self): From 51c6ea3731a8734f4b117d093d3941215523e73a Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 6 Mar 2023 18:44:26 +0000 Subject: [PATCH 02/11] Coerce float dims saved to TensorFlowMetadata back to int --- merlin/dtypes/shape.py | 6 +++++- merlin/schema/io/tensorflow_metadata.py | 15 +++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/merlin/dtypes/shape.py b/merlin/dtypes/shape.py index eac096a54..c09885ed4 100644 --- a/merlin/dtypes/shape.py +++ b/merlin/dtypes/shape.py @@ -212,7 +212,11 @@ def is_ragged(self): @property def as_tuple(self): - return tuple(((dim.min, dim.max) for dim in self.dims)) if self.dims else None + return ( + tuple((dim.size if dim.is_fixed else (dim.min, dim.max) for dim in self.dims)) + if self.dims + else None + ) @property def is_unknown(self): diff --git a/merlin/schema/io/tensorflow_metadata.py b/merlin/schema/io/tensorflow_metadata.py index 012615c8b..b500f639a 100644 --- a/merlin/schema/io/tensorflow_metadata.py +++ b/merlin/schema/io/tensorflow_metadata.py @@ -252,7 +252,9 @@ def _pb_extra_metadata(column_schema): properties = { k: v for k, v in column_schema.properties.items() if k not in ("domain", "value_count") } - properties["_dims"] = list(list(dim) for dim in column_schema.shape.as_tuple or []) + properties["_dims"] = list( + list(dim) if isinstance(dim, tuple) else dim for dim in column_schema.shape.as_tuple or [] + ) properties["is_list"] = column_schema.is_list properties["is_ragged"] = column_schema.is_ragged if column_schema.dtype.element_size: @@ -377,6 +379,12 @@ def _merlin_properties(feature): } +def _coerce_int(v): + if isinstance(v, float): + return int(v) + return v + + def _merlin_dtype(feature, properties): dtype = md.unknown item_size = int(properties.get("dtype_item_size", 0)) or None @@ -396,7 +404,10 @@ def _merlin_dtype(feature, properties): dims_list = properties.pop("_dims", None) if dims_list: - dims_tuple = tuple(tuple(dim) for dim in dims_list) + dims_tuple = tuple( + tuple(_coerce_int(d) for d in dim) if isinstance(dim, list) else _coerce_int(dim) + for dim in dims_list + ) dtype = dtype.with_shape(dims_tuple) # If we found dims, avoid overwriting that shape with one inferred from counts or flags From c315fc0be9da728ad6265787ecaa800b6fd32ebc Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 6 Mar 2023 18:45:06 +0000 Subject: [PATCH 03/11] Support shape dim lookup with getitem --- merlin/dtypes/shape.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/merlin/dtypes/shape.py b/merlin/dtypes/shape.py index c09885ed4..a26607097 100644 --- a/merlin/dtypes/shape.py +++ b/merlin/dtypes/shape.py @@ -164,6 +164,9 @@ def __eq__(self, other): return self.dims == other.dims + def __getitem__(self, idx): + return self.dims[idx] + def __iter__(self): return self.dims From 566ff61e90d94c95f77c73f27455f18c1ca4516f Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 6 Mar 2023 18:45:30 +0000 Subject: [PATCH 04/11] Restore support for unknown-size fixed shapes to ColumnSchema --- merlin/schema/schema.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/merlin/schema/schema.py b/merlin/schema/schema.py index 413f20111..2830c6208 100644 --- a/merlin/schema/schema.py +++ b/merlin/schema/schema.py @@ -96,8 +96,10 @@ def __post_init__(self, dims): new_shape = dtype.shape elif value_counts: new_shape = self._shape_from_counts(Domain(**value_counts)) + elif self.is_list and self.is_ragged is False: + new_shape = Shape((-1, -1)) elif self.is_list: - new_shape = self._shape_from_flags(self.is_list) + new_shape = Shape((-1, None)) else: new_shape = Shape() @@ -112,11 +114,8 @@ def __post_init__(self, dims): properties = {**self.properties, **{"value_count": value_counts}} object.__setattr__(self, "properties", properties) - def _shape_from_flags(self, is_list): - return Shape(((0, None), (0, None))) if is_list else None - def _shape_from_counts(self, value_count): - return Shape(((0, None), (value_count.min or 0, value_count.max))) + return Shape((-1, (value_count.min or 0, value_count.max))) @property def shape(self): From f0007774d78bedb7aa819895c4aea6f613f40154 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 6 Mar 2023 18:46:12 +0000 Subject: [PATCH 05/11] Add checks for float dims in Dimension post init --- merlin/dtypes/shape.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/merlin/dtypes/shape.py b/merlin/dtypes/shape.py index a26607097..d3038c67c 100644 --- a/merlin/dtypes/shape.py +++ b/merlin/dtypes/shape.py @@ -20,8 +20,8 @@ class DefaultShapes(Enum): - LIST = (None, None) - SCALAR = (None,) + LIST = (-1, None) + SCALAR = (-1,) def Dimension(size=None): @@ -32,7 +32,7 @@ def Dimension(size=None): - int : a fixed dimension of some size (-1 = unknown) - 2-tuple : the bounds of a ragged dimension (fixed if min == max) """ - if isinstance(size, BaseDimension): + if isinstance(size, (FixedDimension, RaggedDimension)): return size elif isinstance(size, tuple) and len(size) == 2: if size[0] == size[1]: @@ -64,6 +64,12 @@ def __post_init__(self): if self.min is None: raise ValueError("The minimum size of a dimension cannot be None. ") + if not isinstance(self.min, int): + raise ValueError("The minimmum size must be an integer. " f"Provided min: {self.min}") + + if self.max and not isinstance(self.max, int): + raise ValueError("The maximum size must be an integer. " f"Provided max: {self.max}") + if self.min < 0: raise ValueError( "The minimum size of a dimension must be non-negative. " f"Provided min: {self.min}" @@ -82,14 +88,17 @@ def __post_init__(self): @property def is_bounded(self): + """Is the dimension bounded in size?""" return self.max is not None @property def is_uniform(self): + """Is the dimension uniform in size?""" return self.is_bounded and self.min == self.max @property def is_variable(self): + """Can the size of the dimension vary between instances of tensors.""" return not self.is_uniform @property From 18c8d97b9fa5e675ccce89d34b0a9fa5cc8bc19f Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 6 Mar 2023 18:46:28 +0000 Subject: [PATCH 06/11] Update tests reflecting changes to fixed/ragged dimensions --- tests/unit/dtypes/test_shape.py | 41 +++++++++++------------- tests/unit/schema/test_column_schemas.py | 16 ++++----- 2 files changed, 26 insertions(+), 31 deletions(-) diff --git a/tests/unit/dtypes/test_shape.py b/tests/unit/dtypes/test_shape.py index 5dd9a164a..0bbeed068 100644 --- a/tests/unit/dtypes/test_shape.py +++ b/tests/unit/dtypes/test_shape.py @@ -28,7 +28,7 @@ def test_empty_dimension(): def test_min_max_val_dimension(): - dim = Dimension(2, 3) + dim = Dimension((2, 3)) assert dim.min == 2 assert dim.max == 3 @@ -36,32 +36,29 @@ def test_min_max_val_dimension(): def test_fixed_min_with_unbounded_max(): dim = Dimension(2) assert dim.min == 2 - assert dim.max is None + assert dim.max == 2 - dim = Dimension(2, None) + dim = Dimension((2, None)) assert dim.min == 2 assert dim.max is None def test_min_is_none_raises_error(): with pytest.raises(ValueError): - Dimension(None) - - with pytest.raises(ValueError): - Dimension(None, 1) + Dimension((None, 1)) def test_bounds_must_be_non_negative(): with pytest.raises(ValueError): - Dimension(-1, 2) + Dimension((-1, 2)) with pytest.raises(ValueError): - Dimension(2, -1) + Dimension((2, -1)) def test_max_less_than_min(): with pytest.raises(ValueError): - Dimension(2, 1) + Dimension((2, 1)) def test_is_bounded(): @@ -69,15 +66,15 @@ def test_is_bounded(): assert dim.is_bounded is False dim = Dimension(2) - assert dim.is_bounded is False + assert dim.is_bounded is True - dim = Dimension(2, 2) + dim = Dimension((2, 2)) assert dim.is_bounded is True - dim = Dimension(2, 4) + dim = Dimension((2, 4)) assert dim.is_bounded is True - dim = Dimension(2, None) + dim = Dimension((2, None)) assert dim.is_bounded is False @@ -86,15 +83,15 @@ def test_is_fixed(): assert dim.is_fixed is False dim = Dimension(2) - assert dim.is_fixed is False + assert dim.is_fixed is True - dim = Dimension(2, 2) + dim = Dimension((2, 2)) assert dim.is_fixed is True - dim = Dimension(2, 4) + dim = Dimension((2, 4)) assert dim.is_fixed is False - dim = Dimension(2, None) + dim = Dimension((2, None)) assert dim.is_fixed is False @@ -103,15 +100,15 @@ def test_is_variable(): assert dim.is_variable is True dim = Dimension(2) - assert dim.is_variable is True + assert dim.is_variable is False - dim = Dimension(2, 2) + dim = Dimension((2, 2)) assert dim.is_variable is False - dim = Dimension(2, 4) + dim = Dimension((2, 4)) assert dim.is_variable is True - dim = Dimension(2, None) + dim = Dimension((2, None)) assert dim.is_variable is True diff --git a/tests/unit/schema/test_column_schemas.py b/tests/unit/schema/test_column_schemas.py index daac73fdd..5b8bb4ebb 100644 --- a/tests/unit/schema/test_column_schemas.py +++ b/tests/unit/schema/test_column_schemas.py @@ -192,11 +192,9 @@ def test_list_column_attributes(): assert col3_schema.is_list assert col3_schema.is_ragged - # TODO: Re-enable this test case once we've addressed cases - # like this in downstream libraries - - # with pytest.raises(ValueError): - # ColumnSchema("col4", is_list=True, is_ragged=False) + col4_schema = ColumnSchema("col4", is_list=True, is_ragged=False) + assert col4_schema.is_list + assert col4_schema.is_ragged is False with pytest.raises(ValueError): ColumnSchema("col5", is_list=False, is_ragged=True) @@ -257,18 +255,18 @@ def test_setting_partial_value_count(value_count): ) assert col_schema.is_list assert not col_schema.is_ragged - assert col_schema.shape == Shape((None, 10)) + assert col_schema.shape == Shape((-1, 10)) assert col_schema.properties["value_count"] == {"min": 10, "max": 10} def test_setting_value_counts_updates_shape_and_flags(): - col_schema = ColumnSchema("col", dims=(None,)) + col_schema = ColumnSchema("col", dims=(-1,)) counts = {"min": 4, "max": 5} updated_schema = col_schema.with_properties({"value_count": counts}) assert updated_schema.properties["value_count"] == counts - assert updated_schema.shape == Shape((None, (4, 5))) + assert updated_schema.shape == Shape((-1, (4, 5))) assert updated_schema.is_list assert updated_schema.is_ragged @@ -287,7 +285,7 @@ def test_setting_flags_updates_shape_and_value_counts(): col_schema = ColumnSchema("col") updated_schema = col_schema.with_dtype(md.int64, is_list=True, is_ragged=True) - assert updated_schema.shape == Shape((None, None)) + assert updated_schema.shape == Shape((-1, None)) assert updated_schema.properties["value_count"] == {"min": 0, "max": None} assert updated_schema.is_list assert updated_schema.is_ragged From b3290017c9ac9303f2190682d2a8ac5cd0872efb Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Tue, 7 Mar 2023 19:31:22 +0000 Subject: [PATCH 07/11] If None provided for first dimension, change to fixed dim --- merlin/dtypes/shape.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/merlin/dtypes/shape.py b/merlin/dtypes/shape.py index d3038c67c..7cde2e52f 100644 --- a/merlin/dtypes/shape.py +++ b/merlin/dtypes/shape.py @@ -153,6 +153,8 @@ def __post_init__(self): if self.dims is not None: new_dims = [] for i, dim in enumerate(self.dims): + if i == 0 and dim is None: + dim = -1 new_dim = Dimension(dim) new_dims.append(new_dim) From 5494898e9d6cdea615801064a141d902ee25787e Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 8 Mar 2023 10:31:50 +0000 Subject: [PATCH 08/11] Update with min and with_max to handle change from ragged to fixed --- merlin/dtypes/shape.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/merlin/dtypes/shape.py b/merlin/dtypes/shape.py index 7cde2e52f..db37b643f 100644 --- a/merlin/dtypes/shape.py +++ b/merlin/dtypes/shape.py @@ -187,10 +187,16 @@ def with_dim(self, index, value): return replace(self, dims=tuple(new_dims)) def with_dim_min(self, index, value): - return self.with_dim(index, self.dims[index].with_min(value)) + new_dim = self.dims[index].with_min(value) + if new_dim.is_uniform: + new_dim = Dimension(value) + return self.with_dim(index, new_dim) def with_dim_max(self, index, value): - return self.with_dim(index, self.dims[index].with_max(value)) + new_dim = self.dims[index].with_max(value) + if new_dim.is_uniform: + new_dim = Dimension(value) + return self.with_dim(index, new_dim) @property def min(self) -> Tuple: From 9781d0c52b9d0b2c6be7a4ed98ef7730d3d0e77c Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 8 Mar 2023 10:53:58 +0000 Subject: [PATCH 09/11] Treat (0, None) as fixed dim if is the first dimension --- merlin/dtypes/shape.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/merlin/dtypes/shape.py b/merlin/dtypes/shape.py index db37b643f..c7fec05a7 100644 --- a/merlin/dtypes/shape.py +++ b/merlin/dtypes/shape.py @@ -153,7 +153,7 @@ def __post_init__(self): if self.dims is not None: new_dims = [] for i, dim in enumerate(self.dims): - if i == 0 and dim is None: + if i == 0 and (dim is None or dim == (0, None)): dim = -1 new_dim = Dimension(dim) new_dims.append(new_dim) From b9fbd52b46af9ab67e068da3b4209c14e08bf615 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Fri, 5 May 2023 13:51:46 +0100 Subject: [PATCH 10/11] Rename FixedDimension to UniformDimension --- merlin/dtypes/shape.py | 54 ++++++++++++++++++++++++------------------ 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/merlin/dtypes/shape.py b/merlin/dtypes/shape.py index 9f2f6782c..8b382ddbd 100644 --- a/merlin/dtypes/shape.py +++ b/merlin/dtypes/shape.py @@ -32,16 +32,16 @@ def Dimension(size=None): - int : a fixed dimension of some size (-1 = unknown) - 2-tuple : the bounds of a ragged dimension (fixed if min == max) """ - if isinstance(size, (FixedDimension, RaggedDimension)): + if isinstance(size, (UniformDimension, RaggedDimension)): return size elif isinstance(size, tuple) and len(size) == 2: if size[0] == size[1]: - return FixedDimension(size[0], size[1]) + return UniformDimension(size[0], size[1]) return RaggedDimension(size[0], size[1]) elif isinstance(size, int): if size == -1: - return FixedDimension() - return FixedDimension(size, size) + return UniformDimension() + return UniformDimension(size, size) elif size is None: return RaggedDimension() else: @@ -109,14 +109,14 @@ def is_bounded(self): return self.max is not None @property - def is_uniform(self): - """Is the dimension uniform in size?""" + def is_fixed(self): + """Is the dimension fixed in size?""" return self.is_bounded and self.min == self.max @property def is_variable(self): """Can the size of the dimension vary between instances of tensors.""" - return not self.is_uniform + return not self.is_fixed @property def is_unknown(self): @@ -131,22 +131,30 @@ def with_max(self, value): class RaggedDimension(BaseDimension): @property - def is_fixed(self): + def is_uniform(self): return False + @property + def is_ragged(self): + return True + @property def size(self): return None -class FixedDimension(BaseDimension): +class UniformDimension(BaseDimension): @property - def is_fixed(self): + def is_uniform(self): return True + @property + def is_ragged(self): + return False + @property def size(self): - if self.is_uniform: + if self.is_fixed: return self.max else: return -1 @@ -170,9 +178,10 @@ def __post_init__(self): if self.dims is not None: new_dims = [] for i, dim in enumerate(self.dims): - if i == 0 and (dim is None or dim == (0, None)): - dim = -1 - new_dim = Dimension(dim) + if i == 0: + new_dim = UniformDimension(dim) + else: + new_dim = Dimension(dim) new_dims.append(new_dim) object.__setattr__(self, "dims", tuple(new_dims)) @@ -192,9 +201,6 @@ def __eq__(self, other): return self.dims == other.dims - def __getitem__(self, idx): - return self.dims[idx] - def __iter__(self): return self.dims @@ -240,7 +246,7 @@ def is_uniform(self): @property def is_variable(self): - return not self.is_uniform + return not self.is_fixed @property def is_list(self): @@ -248,14 +254,16 @@ def is_list(self): @property def is_ragged(self): - return self.is_list and any(not dim.is_fixed for dim in self.dims[1:]) + return self.is_list and any(dim.is_uniform for dim in self.dims[1:]) @property def as_tuple(self): - return ( - tuple((dim.size if dim.is_fixed else (dim.min, dim.max) for dim in self.dims)) - if self.dims - else None + if not self.dims: + return None + + return tuple( + (dim.size if dim.is_fixed or self.dim.is_uniform else (dim.min, dim.max) + for dim in self.dims) ) @property From f4904911b397925f4e4bb2b466e72f997a9c298e Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Fri, 5 May 2023 15:20:41 +0100 Subject: [PATCH 11/11] Handle variable sized batch dimension --- merlin/dtypes/shape.py | 16 ++++++-------- merlin/schema/io/tensorflow_metadata.py | 29 ++++++++++++++++++------- tests/unit/schema/test_schema_io.py | 2 +- 3 files changed, 29 insertions(+), 18 deletions(-) diff --git a/merlin/dtypes/shape.py b/merlin/dtypes/shape.py index 8b382ddbd..5503a3486 100644 --- a/merlin/dtypes/shape.py +++ b/merlin/dtypes/shape.py @@ -24,7 +24,7 @@ class DefaultShapes(Enum): SCALAR = (-1,) -def Dimension(size=None): +def Dimension(size=None, index=None): """Create a dimension from a size. A size can be one of: @@ -35,7 +35,7 @@ def Dimension(size=None): if isinstance(size, (UniformDimension, RaggedDimension)): return size elif isinstance(size, tuple) and len(size) == 2: - if size[0] == size[1]: + if size[0] == size[1] or index == 0: return UniformDimension(size[0], size[1]) return RaggedDimension(size[0], size[1]) elif isinstance(size, int): @@ -43,6 +43,8 @@ def Dimension(size=None): return UniformDimension() return UniformDimension(size, size) elif size is None: + if index == 0: + return UniformDimension() return RaggedDimension() else: raise ValueError( @@ -178,10 +180,7 @@ def __post_init__(self): if self.dims is not None: new_dims = [] for i, dim in enumerate(self.dims): - if i == 0: - new_dim = UniformDimension(dim) - else: - new_dim = Dimension(dim) + new_dim = Dimension(dim, index=i) new_dims.append(new_dim) object.__setattr__(self, "dims", tuple(new_dims)) @@ -254,7 +253,7 @@ def is_list(self): @property def is_ragged(self): - return self.is_list and any(dim.is_uniform for dim in self.dims[1:]) + return self.is_list and any(dim.is_ragged for dim in self.dims[1:]) @property def as_tuple(self): @@ -262,8 +261,7 @@ def as_tuple(self): return None return tuple( - (dim.size if dim.is_fixed or self.dim.is_uniform else (dim.min, dim.max) - for dim in self.dims) + dim.size if dim.is_fixed or dim.is_uniform else (dim.min, dim.max) for dim in self.dims ) @property diff --git a/merlin/schema/io/tensorflow_metadata.py b/merlin/schema/io/tensorflow_metadata.py index 43f6e42c7..3361956f4 100644 --- a/merlin/schema/io/tensorflow_metadata.py +++ b/merlin/schema/io/tensorflow_metadata.py @@ -19,6 +19,7 @@ import fsspec import merlin.dtypes as md +from merlin.dtypes.shape import RaggedDimension, UniformDimension from merlin.schema.io import proto_utils, schema_bp from merlin.schema.io.schema_bp import Feature, FeatureType, FloatDomain, IntDomain from merlin.schema.io.schema_bp import Schema as ProtoSchema @@ -273,9 +274,15 @@ def _pb_extra_metadata(column_schema): properties = { k: v for k, v in column_schema.properties.items() if k not in ("domain", "value_count") } - properties["_dims"] = list( - list(dim) if isinstance(dim, tuple) else dim for dim in column_schema.shape.as_tuple or [] + _dims = ( + list( + {"min": dim.min, "max": dim.max, "is_uniform": dim.is_uniform} + for dim in column_schema.shape.dims + ) + if column_schema.shape.dims + else [] ) + properties["_dims"] = _dims properties["is_list"] = column_schema.is_list properties["is_ragged"] = column_schema.is_ragged if column_schema.dtype.element_size: @@ -423,10 +430,18 @@ def _merlin_dtype(feature, properties): for dim in dims_list: if isinstance(dim, list): dims.append(tuple(int(d) if isinstance(d, float) else d for d in dim)) - elif dim is not None: + elif isinstance(dim, (int, float)): dims.append(int(dim)) - else: + elif dim is None: dims.append(dim) + elif isinstance(dim, dict): + _min = int(dim["min"]) if isinstance(dim["min"], float) else dim["min"] + _max = int(dim["max"]) if isinstance(dim["max"], float) else dim["max"] + if dim["is_uniform"]: + dims.append(UniformDimension(_min, _max)) + else: + dims.append(RaggedDimension(_min, _max)) + dtype = dtype.with_shape(tuple(dims)) # If we found dims, avoid overwriting that shape with one inferred from counts or flags @@ -452,10 +467,8 @@ def _merlin_column(feature): if Tags.CATEGORICAL not in tags: tags.append(Tags.CATEGORICAL) - dims = dtype.shape.as_tuple - - if dims: - return ColumnSchema(name, tags, properties, dtype, dims=dims) + if dtype.shape.dims: + return ColumnSchema(name, tags, properties, dtype, dims=dtype.shape.dims) else: return ColumnSchema(name, tags, properties, dtype, is_list=is_list, is_ragged=is_ragged) diff --git a/tests/unit/schema/test_schema_io.py b/tests/unit/schema/test_schema_io.py index 3770f4b37..618948bf6 100644 --- a/tests/unit/schema/test_schema_io.py +++ b/tests/unit/schema/test_schema_io.py @@ -205,7 +205,7 @@ def test_schema_with_shape_to_tensorflow_metadata_json(): ragged_dim = loaded_schema["col"].shape[1] assert isinstance(ragged_dim.max, int) assert isinstance(ragged_dim.min, int) - assert ragged_dim == Dimension(min=1, max=5) + assert ragged_dim == Dimension((1, 5)) def test_tensorflow_metadata_from_json():