Skip to content

Commit

Permalink
Fix #99 - Corrected Python dtypes for union types
Browse files Browse the repository at this point in the history
  • Loading branch information
naegelejd committed Nov 11, 2023
1 parent ee10587 commit 64e8edd
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 32 deletions.
8 changes: 4 additions & 4 deletions python/test_model/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,10 +1048,10 @@ def _write_aliased_open_generic(self, value: AliasedOpenGeneric[AliasedString, A
def _write_aliased_closed_generic(self, value: AliasedClosedGeneric) -> None:
_MyTupleSerializer(_binary.string_serializer, _binary.EnumSerializer(_binary.int32_serializer, Fruits)).write(self._stream, value)

def _write_aliased_optional(self, value: typing.Optional[AliasedOptional]) -> None:
def _write_aliased_optional(self, value: AliasedOptional) -> None:
_binary.OptionalSerializer(_binary.int32_serializer).write(self._stream, value)

def _write_aliased_generic_optional(self, value: typing.Optional[AliasedGenericOptional[yardl.Float32]]) -> None:
def _write_aliased_generic_optional(self, value: AliasedGenericOptional[yardl.Float32]) -> None:
_binary.OptionalSerializer(_binary.float32_serializer).write(self._stream, value)

def _write_aliased_generic_union_2(self, value: AliasedGenericUnion2[AliasedString, AliasedEnum]) -> None:
Expand Down Expand Up @@ -1087,10 +1087,10 @@ def _read_aliased_open_generic(self) -> AliasedOpenGeneric[AliasedString, Aliase
def _read_aliased_closed_generic(self) -> AliasedClosedGeneric:
return _MyTupleSerializer(_binary.string_serializer, _binary.EnumSerializer(_binary.int32_serializer, Fruits)).read(self._stream)

def _read_aliased_optional(self) -> typing.Optional[AliasedOptional]:
def _read_aliased_optional(self) -> AliasedOptional:
return _binary.OptionalSerializer(_binary.int32_serializer).read(self._stream)

def _read_aliased_generic_optional(self) -> typing.Optional[AliasedGenericOptional[yardl.Float32]]:
def _read_aliased_generic_optional(self) -> AliasedGenericOptional[yardl.Float32]:
return _binary.OptionalSerializer(_binary.float32_serializer).read(self._stream)

def _read_aliased_generic_union_2(self) -> AliasedGenericUnion2[AliasedString, AliasedEnum]:
Expand Down
8 changes: 4 additions & 4 deletions python/test_model/ndjson.py
Original file line number Diff line number Diff line change
Expand Up @@ -3842,12 +3842,12 @@ def _write_aliased_closed_generic(self, value: AliasedClosedGeneric) -> None:
json_value = converter.to_json(value)
self._write_json_line({"aliasedClosedGeneric": json_value})

def _write_aliased_optional(self, value: typing.Optional[AliasedOptional]) -> None:
def _write_aliased_optional(self, value: AliasedOptional) -> None:
converter = _ndjson.OptionalConverter(_ndjson.int32_converter)
json_value = converter.to_json(value)
self._write_json_line({"aliasedOptional": json_value})

def _write_aliased_generic_optional(self, value: typing.Optional[AliasedGenericOptional[yardl.Float32]]) -> None:
def _write_aliased_generic_optional(self, value: AliasedGenericOptional[yardl.Float32]) -> None:
converter = _ndjson.OptionalConverter(_ndjson.float32_converter)
json_value = converter.to_json(value)
self._write_json_line({"aliasedGenericOptional": json_value})
Expand Down Expand Up @@ -3902,12 +3902,12 @@ def _read_aliased_closed_generic(self) -> AliasedClosedGeneric:
converter = _MyTupleConverter(_ndjson.string_converter, _ndjson.EnumConverter(Fruits, np.int32, _fruits_name_to_value_map, _fruits_value_to_name_map))
return converter.from_json(json_object)

def _read_aliased_optional(self) -> typing.Optional[AliasedOptional]:
def _read_aliased_optional(self) -> AliasedOptional:
json_object = self._read_json_line("aliasedOptional", True)
converter = _ndjson.OptionalConverter(_ndjson.int32_converter)
return converter.from_json(json_object)

def _read_aliased_generic_optional(self) -> typing.Optional[AliasedGenericOptional[yardl.Float32]]:
def _read_aliased_generic_optional(self) -> AliasedGenericOptional[yardl.Float32]:
json_object = self._read_json_line("aliasedGenericOptional", True)
converter = _ndjson.OptionalConverter(_ndjson.float32_converter)
return converter.from_json(json_object)
Expand Down
16 changes: 8 additions & 8 deletions python/test_model/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -5240,7 +5240,7 @@ def write_aliased_closed_generic(self, value: AliasedClosedGeneric) -> None:
self._write_aliased_closed_generic(value)
self._state = 8

def write_aliased_optional(self, value: typing.Optional[AliasedOptional]) -> None:
def write_aliased_optional(self, value: AliasedOptional) -> None:
"""Ordinal 4"""

if self._state != 8:
Expand All @@ -5249,7 +5249,7 @@ def write_aliased_optional(self, value: typing.Optional[AliasedOptional]) -> Non
self._write_aliased_optional(value)
self._state = 10

def write_aliased_generic_optional(self, value: typing.Optional[AliasedGenericOptional[yardl.Float32]]) -> None:
def write_aliased_generic_optional(self, value: AliasedGenericOptional[yardl.Float32]) -> None:
"""Ordinal 5"""

if self._state != 10:
Expand Down Expand Up @@ -5311,11 +5311,11 @@ def _write_aliased_closed_generic(self, value: AliasedClosedGeneric) -> None:
raise NotImplementedError()

@abc.abstractmethod
def _write_aliased_optional(self, value: typing.Optional[AliasedOptional]) -> None:
def _write_aliased_optional(self, value: AliasedOptional) -> None:
raise NotImplementedError()

@abc.abstractmethod
def _write_aliased_generic_optional(self, value: typing.Optional[AliasedGenericOptional[yardl.Float32]]) -> None:
def _write_aliased_generic_optional(self, value: AliasedGenericOptional[yardl.Float32]) -> None:
raise NotImplementedError()

@abc.abstractmethod
Expand Down Expand Up @@ -5437,7 +5437,7 @@ def read_aliased_closed_generic(self) -> AliasedClosedGeneric:
self._state = 8
return value

def read_aliased_optional(self) -> typing.Optional[AliasedOptional]:
def read_aliased_optional(self) -> AliasedOptional:
"""Ordinal 4"""

if self._state != 8:
Expand All @@ -5447,7 +5447,7 @@ def read_aliased_optional(self) -> typing.Optional[AliasedOptional]:
self._state = 10
return value

def read_aliased_generic_optional(self) -> typing.Optional[AliasedGenericOptional[yardl.Float32]]:
def read_aliased_generic_optional(self) -> AliasedGenericOptional[yardl.Float32]:
"""Ordinal 5"""

if self._state != 10:
Expand Down Expand Up @@ -5526,11 +5526,11 @@ def _read_aliased_closed_generic(self) -> AliasedClosedGeneric:
raise NotImplementedError()

@abc.abstractmethod
def _read_aliased_optional(self) -> typing.Optional[AliasedOptional]:
def _read_aliased_optional(self) -> AliasedOptional:
raise NotImplementedError()

@abc.abstractmethod
def _read_aliased_generic_optional(self) -> typing.Optional[AliasedGenericOptional[yardl.Float32]]:
def _read_aliased_generic_optional(self) -> AliasedGenericOptional[yardl.Float32]:
raise NotImplementedError()

@abc.abstractmethod
Expand Down
19 changes: 16 additions & 3 deletions python/test_model/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,10 +1067,10 @@ def __repr__(self) -> str:


class RecordWithAliasedOptionalGenericField(typing.Generic[T]):
v: typing.Optional[AliasedGenericOptional[T]]
v: AliasedGenericOptional[T]

def __init__(self, *,
v: typing.Optional[AliasedGenericOptional[T]] = None,
v: AliasedGenericOptional[T] = None,
):
self.v = v

Expand Down Expand Up @@ -1996,6 +1996,8 @@ def _mk_get_dtype():
dtype_map.setdefault(RecordWithVlenCollections, np.dtype([('vector', np.dtype(np.object_)), ('array', np.dtype(np.object_))], align=True))
dtype_map.setdefault(NamedNDArray, np.dtype(np.object_))
dtype_map.setdefault(AliasedMap, lambda type_args: np.dtype(np.object_))
dtype_map.setdefault(Int32OrString, np.dtype(np.object_))
dtype_map.setdefault(TimeOrDatetime, np.dtype(np.object_))
dtype_map.setdefault(RecordWithUnions, np.dtype([('null_or_int_or_string', np.dtype(np.object_)), ('date_or_datetime', np.dtype(np.object_))], align=True))
dtype_map.setdefault(Fruits, np.dtype(np.int32))
dtype_map.setdefault(UInt64Enum, np.dtype(np.uint64))
Expand All @@ -2014,9 +2016,10 @@ def _mk_get_dtype():
dtype_map.setdefault(AliasedEnum, get_dtype(Fruits))
dtype_map.setdefault(AliasedOpenGeneric, lambda type_args: get_dtype(types.GenericAlias(AliasedTuple, (type_args[0], type_args[1],))))
dtype_map.setdefault(AliasedClosedGeneric, get_dtype(types.GenericAlias(AliasedTuple, (AliasedString, AliasedEnum,))))
dtype_map.setdefault(typing.Optional[AliasedOptional], np.dtype([('has_value', np.dtype(np.bool_)), ('value', np.dtype(np.int32))], align=True))
dtype_map.setdefault(AliasedOptional, np.dtype([('has_value', np.dtype(np.bool_)), ('value', np.dtype(np.int32))], align=True))
dtype_map.setdefault(AliasedGenericOptional, lambda type_args: np.dtype([('has_value', np.dtype(np.bool_)), ('value', get_dtype(type_args[0]))], align=True))
dtype_map.setdefault(AliasedMultiGenericOptional, lambda type_args: np.dtype(np.object_))
dtype_map.setdefault(AliasedMultiGenericOptional, lambda type_args: np.dtype(np.object_))
dtype_map.setdefault(AliasedGenericUnion2, lambda type_args: get_dtype(types.GenericAlias(GenericUnion2, (type_args[0], type_args[1],))))
dtype_map.setdefault(AliasedGenericVector, lambda type_args: np.dtype(np.object_))
dtype_map.setdefault(AliasedGenericFixedVector, lambda type_args: get_dtype(type_args[0]))
Expand All @@ -2025,6 +2028,7 @@ def _mk_get_dtype():
dtype_map.setdefault(AliasedGenericDynamicArray, lambda type_args: np.dtype(np.object_))
dtype_map.setdefault(RecordWithOptionalGenericField, lambda type_args: np.dtype([('v', np.dtype([('has_value', np.dtype(np.bool_)), ('value', get_dtype(type_args[0]))], align=True))], align=True))
dtype_map.setdefault(RecordWithAliasedOptionalGenericField, lambda type_args: np.dtype([('v', get_dtype(types.GenericAlias(AliasedGenericOptional, (type_args[0],))))], align=True))
dtype_map.setdefault(UOrV, np.dtype(np.object_))
dtype_map.setdefault(RecordWithOptionalGenericUnionField, lambda type_args: np.dtype([('v', np.dtype(np.object_))], align=True))
dtype_map.setdefault(RecordWithAliasedOptionalGenericUnionField, lambda type_args: np.dtype([('v', get_dtype(types.GenericAlias(AliasedMultiGenericOptional, (type_args[0], type_args[1],))))], align=True))
dtype_map.setdefault(RecordWithGenericVectors, lambda type_args: np.dtype([('v', np.dtype(np.object_)), ('av', get_dtype(types.GenericAlias(AliasedGenericVector, (type_args[0],))))], align=True))
Expand All @@ -2034,15 +2038,24 @@ def _mk_get_dtype():
dtype_map.setdefault(RecordContainingGenericRecords, lambda type_args: np.dtype([('g1', get_dtype(types.GenericAlias(RecordWithOptionalGenericField, (type_args[0],)))), ('g1a', get_dtype(types.GenericAlias(RecordWithAliasedOptionalGenericField, (type_args[0],)))), ('g2', get_dtype(types.GenericAlias(RecordWithOptionalGenericUnionField, (type_args[0], type_args[1],)))), ('g2a', get_dtype(types.GenericAlias(RecordWithAliasedOptionalGenericUnionField, (type_args[0], type_args[1],)))), ('g3', get_dtype(types.GenericAlias(MyTuple, (type_args[0], type_args[1],)))), ('g3a', get_dtype(types.GenericAlias(AliasedTuple, (type_args[0], type_args[1],)))), ('g4', get_dtype(types.GenericAlias(RecordWithGenericVectors, (type_args[1],)))), ('g5', get_dtype(types.GenericAlias(RecordWithGenericFixedVectors, (type_args[1],)))), ('g6', get_dtype(types.GenericAlias(RecordWithGenericArrays, (type_args[1],)))), ('g7', get_dtype(types.GenericAlias(RecordWithGenericMaps, (type_args[0], type_args[1],))))], align=True))
dtype_map.setdefault(RecordContainingNestedGenericRecords, np.dtype([('f1', get_dtype(types.GenericAlias(RecordWithOptionalGenericField, (str,)))), ('f1a', get_dtype(types.GenericAlias(RecordWithAliasedOptionalGenericField, (str,)))), ('f2', get_dtype(types.GenericAlias(RecordWithOptionalGenericUnionField, (str, yardl.Int32,)))), ('f2a', get_dtype(types.GenericAlias(RecordWithAliasedOptionalGenericUnionField, (str, yardl.Int32,)))), ('nested', get_dtype(types.GenericAlias(RecordContainingGenericRecords, (str, yardl.Int32,))))], align=True))
dtype_map.setdefault(AliasedIntOrSimpleRecord, np.dtype(np.object_))
dtype_map.setdefault(AliasedNullableIntSimpleRecord, np.dtype(np.object_))
dtype_map.setdefault(typing.Optional[AliasedNullableIntSimpleRecord], np.dtype(np.object_))
dtype_map.setdefault(T0OrT1, np.dtype(np.object_))
dtype_map.setdefault(GenericRecordWithComputedFields, lambda type_args: np.dtype([('f1', np.dtype(np.object_))], align=True))
dtype_map.setdefault(Int32OrFloat32, np.dtype(np.object_))
dtype_map.setdefault(IntOrGenericRecordWithComputedFields, np.dtype(np.object_))
dtype_map.setdefault(RecordWithComputedFields, np.dtype([('array_field', np.dtype(np.object_)), ('array_field_map_dimensions', np.dtype(np.object_)), ('dynamic_array_field', np.dtype(np.object_)), ('fixed_array_field', np.dtype(np.int32), (3, 4,)), ('int_field', np.dtype(np.int32)), ('int8_field', np.dtype(np.int8)), ('uint8_field', np.dtype(np.uint8)), ('int16_field', np.dtype(np.int16)), ('uint16_field', np.dtype(np.uint16)), ('uint32_field', np.dtype(np.uint32)), ('int64_field', np.dtype(np.int64)), ('uint64_field', np.dtype(np.uint64)), ('size_field', np.dtype(np.uint64)), ('float32_field', np.dtype(np.float32)), ('float64_field', np.dtype(np.float64)), ('complexfloat32_field', np.dtype(np.complex64)), ('complexfloat64_field', np.dtype(np.complex128)), ('string_field', np.dtype(np.object_)), ('tuple_field', get_dtype(types.GenericAlias(MyTuple, (yardl.Int32, yardl.Int32,)))), ('vector_field', np.dtype(np.object_)), ('vector_of_vectors_field', np.dtype(np.object_)), ('fixed_vector_field', np.dtype(np.int32), (3,)), ('optional_named_array', np.dtype([('has_value', np.dtype(np.bool_)), ('value', get_dtype(NamedNDArray))], align=True)), ('int_float_union', np.dtype(np.object_)), ('nullable_int_float_union', np.dtype(np.object_)), ('union_with_nested_generic_union', np.dtype(np.object_)), ('map_field', np.dtype(np.object_))], align=True))
dtype_map.setdefault(GenericUnion3, lambda type_args: np.dtype(np.object_))
dtype_map.setdefault(GenericUnion3Alternate, lambda type_args: np.dtype(np.object_))
dtype_map.setdefault(RecordNotUsedInProtocol, np.dtype([('u1', get_dtype(types.GenericAlias(GenericUnion3, (yardl.Int32, yardl.Float32, str,)))), ('u2', get_dtype(types.GenericAlias(GenericUnion3Alternate, (yardl.Int32, yardl.Float32, str,))))], align=True))
dtype_map.setdefault(ArrayWithKeywordDimensionNames, np.dtype(np.object_))
dtype_map.setdefault(EnumWithKeywordSymbols, np.dtype(np.int32))
dtype_map.setdefault(RecordWithKeywordFields, np.dtype([('int_', np.dtype(np.object_)), ('sizeof', get_dtype(ArrayWithKeywordDimensionNames)), ('if_', get_dtype(EnumWithKeywordSymbols))], align=True))
dtype_map.setdefault(AcquisitionOrImage, np.dtype(np.object_))
dtype_map.setdefault(StringOrInt32, np.dtype(np.object_))
dtype_map.setdefault(Int32OrSimpleRecord, np.dtype(np.object_))
dtype_map.setdefault(Int32OrRecordWithVlens, np.dtype(np.object_))
dtype_map.setdefault(ImageFloatOrImageDouble, np.dtype(np.object_))

return get_dtype

Expand Down
27 changes: 27 additions & 0 deletions python/tests/test_generated_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,30 @@ def test_get_dtype():
)

assert tm.get_dtype(typing.Union[tm.Int32, tm.Float32]) == np.object_

assert tm.get_dtype(tm.Int32OrString) == np.object_

assert tm.get_dtype(tm.TimeOrDatetime) == np.object_
assert tm.get_dtype(tm.Int32OrSimpleRecord) == np.object_

assert tm.get_dtype(tm.AliasedOptional) == np.dtype(
[("has_value", "?"), ("value", np.int32)], align=True
)

assert tm.get_dtype(tm.AliasedMultiGenericOptional[str, int]) == np.object_
assert tm.get_dtype(
tm.RecordWithAliasedOptionalGenericUnionField[str, int]
) == np.dtype(
[
(
"v",
np.object_,
)
],
align=True,
)

assert tm.get_dtype(tm.AliasedNullableIntSimpleRecord) == np.object_
assert (
tm.get_dtype(typing.Optional[tm.AliasedNullableIntSimpleRecord]) == np.object_
)
3 changes: 1 addition & 2 deletions tooling/internal/python/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,7 @@ var TypeSyntaxWriter dsl.TypeSyntaxWriter[string] = func(self dsl.TypeSyntaxWrit
}

if nt, ok := t.(*dsl.NamedType); ok {
underlyingType := dsl.GetUnderlyingType(nt.Type)
if gt, ok := underlyingType.(*dsl.GeneralizedType); ok && gt.Cases.HasNullOption() {
if gt, ok := nt.Type.(*dsl.GeneralizedType); ok && gt.Cases.HasNullOption() && !gt.Cases.IsOptional() {
typeSyntax = fmt.Sprintf("typing.Optional[%s]", typeSyntax)
}
}
Expand Down
16 changes: 7 additions & 9 deletions tooling/internal/python/static_files/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,20 @@ def get_dtype_impl(
],
t: Union[type, GenericAlias],
) -> np.dtype[Any]:
# type_args = list(filter(lambda t: type(t) != TypeVar, get_args(t)))
# Check dtype map for this type first
if (res := dtype_map.get(t, None)) is not None:
if callable(res):
raise RuntimeError(f"Generic type arguments not provided for {t}")
else:
return res

origin = get_origin(t)

if origin == Union or (
sys.version_info >= (3, 10) and isinstance(t, UnionType)
):
return _get_union_dtype(get_args(t))

# If t is found in dtype_map here, t is either a Python type
# or t is a types.GenericAlias with missing type arguments
if (res := dtype_map.get(t, None)) is not None:
if callable(res):
raise RuntimeError(f"Generic type arguments not provided for {t}")
else:
return res

# Here, t is either invalid (no dtype registered)
# or t is a types.GenericAlias with type arguments specified
if origin is not None and (res := dtype_map.get(origin, None)) is not None:
Expand Down
48 changes: 46 additions & 2 deletions tooling/internal/python/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -948,8 +948,52 @@ func writeGetDTypeFunc(w *formatting.IndentedWriter, ns *dsl.Namespace) {
root: true,
}

for _, t := range ns.TypeDefinitions {
fmt.Fprintf(w, "dtype_map.setdefault(%s, %s)\n", common.TypeSyntaxWithoutTypeParameters(t, ns.Name), typeDefinitionDTypeExpression(t, context))
unions := make(map[string]any)

writeUnionDtypeIfNeeded := func(td dsl.Node) {
dsl.Visit(td, func(self dsl.Visitor, node dsl.Node) {
switch node := node.(type) {
case *dsl.NamedType:
if gt, ok := node.Type.(*dsl.GeneralizedType); ok {
if gt.Cases.IsUnion() {
// Special handling for dtype entry of nullable aliased unions
if gt.Cases.HasNullOption() {
// This is an aliased union, where null is one of the options, e.g. X = [null, int, float]
// register X: ... instead of typing.Optional[X]: ...
// by stripping away the null option
gtClone := *gt
gtClone.Cases = gtClone.Cases[1:]
ntClone := *node
ntClone.Type = &gtClone
td := &ntClone
fmt.Fprintf(w, "dtype_map.setdefault(%s, %s)\n", common.TypeSyntaxWithoutTypeParameters(td, ns.Name), typeDefinitionDTypeExpression(td, context))

}
// Return early - we use the alias name for this union type over the yardl-generated UnionClassName
return
}
}
case *dsl.GeneralizedType:
if node.Cases.IsUnion() {
unionClassName, _ := common.UnionClassName(node)
if _, ok := unions[unionClassName]; !ok {
unions[unionClassName] = nil
fmt.Fprintf(w, "dtype_map.setdefault(%s, %s)\n", unionClassName, typeDTypeExpression(node, context))
}
}

}
self.VisitChildren(node)
})
}

for _, td := range ns.TypeDefinitions {
writeUnionDtypeIfNeeded(td)
fmt.Fprintf(w, "dtype_map.setdefault(%s, %s)\n", common.TypeSyntaxWithoutTypeParameters(td, ns.Name), typeDefinitionDTypeExpression(td, context))
}

for _, td := range ns.Protocols {
writeUnionDtypeIfNeeded(td)
}

w.WriteStringln("\nreturn get_dtype")
Expand Down

0 comments on commit 64e8edd

Please sign in to comment.