Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split Dimension into fixed and ragged #234

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 89 additions & 20 deletions merlin/dtypes/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,41 @@


class DefaultShapes(Enum):
LIST = (None, None)
SCALAR = (None,)
LIST = (-1, None)
SCALAR = (-1,)


def Dimension(size=None, index=None):
    """Build the dimension object that corresponds to ``size``.

    Accepted forms of ``size``:
    - an existing UniformDimension or RaggedDimension, returned unchanged
    - a 2-tuple ``(min, max)``: uniform when both bounds are equal or when
      ``index`` is 0, ragged otherwise
    - an int: a uniform dimension with min == max == size, except that -1
      produces a uniform dimension built with the default bounds
    - None: a dimension built with the default bounds; uniform when
      ``index`` is 0, ragged otherwise

    Raises ValueError for any other value.
    """
    # Already a dimension object: hand it back untouched.
    if isinstance(size, (UniformDimension, RaggedDimension)):
        return size

    # Explicit (min, max) bounds.
    if isinstance(size, tuple) and len(size) == 2:
        lower, upper = size
        dim_type = UniformDimension if (lower == upper or index == 0) else RaggedDimension
        return dim_type(lower, upper)

    # A single integer; -1 is the "unknown size" marker.
    if isinstance(size, int):
        return UniformDimension() if size == -1 else UniformDimension(size, size)

    # No size information at all.
    if size is None:
        return UniformDimension() if index == 0 else RaggedDimension()

    raise ValueError(
        f"Invalid dimension format: {size}. Each dimension is expected "
        " to be None, a single integer, or a tuple with length 2."
    )


@dataclass(frozen=True)
class Dimension:
class BaseDimension:
"""
The range of potential sizes for a single dimension of a field or column
"""
Expand All @@ -37,6 +66,12 @@ def __post_init__(self):
if self.min is None:
raise ValueError("The minimum size of a dimension cannot be None. ")

if not isinstance(self.min, int):
raise ValueError("The minimmum size must be an integer. " f"Provided min: {self.min}")

if self.max and not isinstance(self.max, int):
raise ValueError("The maximum size must be an integer. " f"Provided max: {self.max}")

if self.min < 0:
raise ValueError(
"The minimum size of a dimension must be non-negative. " f"Provided min: {self.min}"
Expand Down Expand Up @@ -72,14 +107,17 @@ def __int__(self):

@property
def is_bounded(self):
    """Report whether this dimension has an upper limit (``max`` is set)."""
    upper = self.max
    return upper is not None

@property
def is_fixed(self):
    """Report whether the dimension is bounded with identical min and max."""
    # An unbounded dimension can never be fixed.
    if not self.is_bounded:
        return False
    return self.min == self.max

@property
def is_variable(self):
    """Report whether the dimension is anything other than fixed in size."""
    fixed = self.is_fixed
    return not fixed

@property
Expand All @@ -93,6 +131,37 @@ def with_max(self, value):
return replace(self, max=value)


class RaggedDimension(BaseDimension):
    """BaseDimension variant that identifies itself as ragged."""

    @property
    def is_ragged(self):
        """Always True for this class."""
        return True

    @property
    def is_uniform(self):
        """Always False for this class."""
        return False

    @property
    def size(self):
        """Always None: this class reports no single size value."""
        return None


class UniformDimension(BaseDimension):
    """BaseDimension variant that identifies itself as uniform."""

    @property
    def is_ragged(self):
        """Always False for this class."""
        return False

    @property
    def is_uniform(self):
        """Always True for this class."""
        return True

    @property
    def size(self):
        """Return ``max`` when the dimension is fixed, otherwise -1."""
        return self.max if self.is_fixed else -1


@dataclass(frozen=True)
class Shape:
"""
Expand All @@ -111,19 +180,7 @@ def __post_init__(self):
if self.dims is not None:
new_dims = []
for i, dim in enumerate(self.dims):
if isinstance(dim, Dimension):
new_dim = dim
elif isinstance(dim, tuple) and len(dim) == 2:
new_dim = Dimension(dim[0], dim[1])
elif isinstance(dim, int):
new_dim = Dimension(dim, dim)
elif dim is None:
new_dim = Dimension()
else:
raise ValueError(
f"Invalid shape tuple format: {self.dims}. Each dimension is expected "
" to be None, a single integer, or a tuple with length 2."
)
new_dim = Dimension(dim, index=i)
new_dims.append(new_dim)

object.__setattr__(self, "dims", tuple(new_dims))
Expand Down Expand Up @@ -155,10 +212,16 @@ def with_dim(self, index, value):
return replace(self, dims=tuple(new_dims))

def with_dim_min(self, index, value):
return self.with_dim(index, self.dims[index].with_min(value))
new_dim = self.dims[index].with_min(value)
if new_dim.is_uniform:
new_dim = Dimension(value)
return self.with_dim(index, new_dim)

def with_dim_max(self, index, value):
return self.with_dim(index, self.dims[index].with_max(value))
new_dim = self.dims[index].with_max(value)
if new_dim.is_uniform:
new_dim = Dimension(value)
return self.with_dim(index, new_dim)

@property
def min(self) -> Tuple:
Expand All @@ -176,6 +239,10 @@ def is_bounded(self):
def is_fixed(self):
return all(dim.is_fixed for dim in self.dims)

@property
def is_uniform(self):
    """Report whether every dimension in ``dims`` is uniform."""
    for entry in self.dims:
        if not entry.is_uniform:
            return False
    # Reached when no dimension objected (including when dims is empty).
    return True

@property
def is_variable(self):
    """Report whether the shape is not fixed (the negation of ``is_fixed``)."""
    fixed = self.is_fixed
    return not fixed
Expand All @@ -186,14 +253,16 @@ def is_list(self):

@property
def is_ragged(self):
return self.is_list and any(dim.min != dim.max for dim in self.dims[1:])
return self.is_list and any(dim.is_ragged for dim in self.dims[1:])

@property
def as_tuple(self):
if not self.dims:
return None

return tuple(((dim.min, dim.max) if dim.min != dim.max else dim.max for dim in self.dims))
return tuple(
dim.size if dim.is_fixed or dim.is_uniform else (dim.min, dim.max) for dim in self.dims
)

@property
def is_unknown(self):
Expand Down
29 changes: 21 additions & 8 deletions merlin/schema/io/tensorflow_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import fsspec

import merlin.dtypes as md
from merlin.dtypes.shape import RaggedDimension, UniformDimension
from merlin.schema.io import proto_utils, schema_bp
from merlin.schema.io.schema_bp import Feature, FeatureType, FloatDomain, IntDomain
from merlin.schema.io.schema_bp import Schema as ProtoSchema
Expand Down Expand Up @@ -273,9 +274,15 @@ def _pb_extra_metadata(column_schema):
properties = {
k: v for k, v in column_schema.properties.items() if k not in ("domain", "value_count")
}
properties["_dims"] = list(
list(dim) if isinstance(dim, tuple) else dim for dim in column_schema.shape.as_tuple or []
_dims = (
list(
{"min": dim.min, "max": dim.max, "is_uniform": dim.is_uniform}
for dim in column_schema.shape.dims
)
if column_schema.shape.dims
else []
)
properties["_dims"] = _dims
properties["is_list"] = column_schema.is_list
properties["is_ragged"] = column_schema.is_ragged
if column_schema.dtype.element_size:
Expand Down Expand Up @@ -423,10 +430,18 @@ def _merlin_dtype(feature, properties):
for dim in dims_list:
if isinstance(dim, list):
dims.append(tuple(int(d) if isinstance(d, float) else d for d in dim))
elif dim is not None:
elif isinstance(dim, (int, float)):
dims.append(int(dim))
else:
elif dim is None:
dims.append(dim)
elif isinstance(dim, dict):
_min = int(dim["min"]) if isinstance(dim["min"], float) else dim["min"]
_max = int(dim["max"]) if isinstance(dim["max"], float) else dim["max"]
if dim["is_uniform"]:
dims.append(UniformDimension(_min, _max))
else:
dims.append(RaggedDimension(_min, _max))

dtype = dtype.with_shape(tuple(dims))

# If we found dims, avoid overwriting that shape with one inferred from counts or flags
Expand All @@ -452,10 +467,8 @@ def _merlin_column(feature):
if Tags.CATEGORICAL not in tags:
tags.append(Tags.CATEGORICAL)

dims = dtype.shape.as_tuple

if dims:
return ColumnSchema(name, tags, properties, dtype, dims=dims)
if dtype.shape.dims:
return ColumnSchema(name, tags, properties, dtype, dims=dtype.shape.dims)
else:
return ColumnSchema(name, tags, properties, dtype, is_list=is_list, is_ragged=is_ragged)

Expand Down
9 changes: 4 additions & 5 deletions merlin/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,10 @@ def __post_init__(self, dims):
new_shape = dtype.shape
elif value_counts:
new_shape = self._shape_from_counts(Domain(**value_counts))
elif self.is_list and self.is_ragged is False:
new_shape = Shape((-1, -1))
elif self.is_list:
new_shape = self._shape_from_flags(self.is_list)
new_shape = Shape((-1, None))
else:
new_shape = Shape()

Expand All @@ -115,11 +117,8 @@ def __post_init__(self, dims):

object.__setattr__(self, "properties", properties)

def _shape_from_flags(self, is_list):
return Shape(((0, None), (0, None))) if is_list else None

def _shape_from_counts(self, value_count):
return Shape(((0, None), (value_count.min or 0, value_count.max)))
return Shape((-1, (value_count.min or 0, value_count.max)))

@property
def shape(self):
Expand Down
41 changes: 19 additions & 22 deletions tests/unit/dtypes/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,56 +28,53 @@ def test_empty_dimension():


def test_min_max_val_dimension():
dim = Dimension(2, 3)
dim = Dimension((2, 3))
assert dim.min == 2
assert dim.max == 3


def test_fixed_min_with_unbounded_max():
dim = Dimension(2)
assert dim.min == 2
assert dim.max is None
assert dim.max == 2

dim = Dimension(2, None)
dim = Dimension((2, None))
assert dim.min == 2
assert dim.max is None


def test_min_is_none_raises_error():
with pytest.raises(ValueError):
Dimension(None)

with pytest.raises(ValueError):
Dimension(None, 1)
Dimension((None, 1))


def test_bounds_must_be_non_negative():
with pytest.raises(ValueError):
Dimension(-1, 2)
Dimension((-1, 2))

with pytest.raises(ValueError):
Dimension(2, -1)
Dimension((2, -1))


def test_max_less_than_min():
with pytest.raises(ValueError):
Dimension(2, 1)
Dimension((2, 1))


def test_is_bounded():
dim = Dimension()
assert dim.is_bounded is False

dim = Dimension(2)
assert dim.is_bounded is False
assert dim.is_bounded is True

dim = Dimension(2, 2)
dim = Dimension((2, 2))
assert dim.is_bounded is True

dim = Dimension(2, 4)
dim = Dimension((2, 4))
assert dim.is_bounded is True

dim = Dimension(2, None)
dim = Dimension((2, None))
assert dim.is_bounded is False


Expand All @@ -86,15 +83,15 @@ def test_is_fixed():
assert dim.is_fixed is False

dim = Dimension(2)
assert dim.is_fixed is False
assert dim.is_fixed is True

dim = Dimension(2, 2)
dim = Dimension((2, 2))
assert dim.is_fixed is True

dim = Dimension(2, 4)
dim = Dimension((2, 4))
assert dim.is_fixed is False

dim = Dimension(2, None)
dim = Dimension((2, None))
assert dim.is_fixed is False


Expand All @@ -103,15 +100,15 @@ def test_is_variable():
assert dim.is_variable is True

dim = Dimension(2)
assert dim.is_variable is True
assert dim.is_variable is False

dim = Dimension(2, 2)
dim = Dimension((2, 2))
assert dim.is_variable is False

dim = Dimension(2, 4)
dim = Dimension((2, 4))
assert dim.is_variable is True

dim = Dimension(2, None)
dim = Dimension((2, None))
assert dim.is_variable is True


Expand Down
Loading