Skip to content

Commit

Permalink
Casting modes (ivy-llc#26681)
Browse files Browse the repository at this point in the history
added support for casting modes to give the original expected output type, also modified promote_types to accept None
  • Loading branch information
RickSanchezStoic authored Oct 6, 2023
1 parent 3dd98b9 commit 2afd872
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 6 deletions.
49 changes: 43 additions & 6 deletions ivy/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,10 @@ def cross_caster(intersect):
valid_float = sorted(ivy.valid_float_dtypes)
valid_int = sorted(ivy.valid_int_dtypes)
intersect = sorted(intersect)
if intersect == valid_int:
if set(valid_int).issubset(intersect):
# make dtype equal to default float
dtype = ivy.default_float_dtype()
elif intersect == valid_float:
elif set(valid_float).issubset(intersect):
# make dtype equal to default int
dtype = ivy.default_int_dtype()

Expand Down Expand Up @@ -1160,9 +1160,13 @@ def _wrap_function(
return to_wrap


def casting_modes_ops(fn):
def casting_modes_ops(fn, ret_dtype_target=None):
@functools.wraps(fn)
def method(*args, **kwargs):
# Get the function signature
signature = inspect.signature(fn)
# Extract argument names
arg_names = [param.name for param in signature.parameters.values()]
# we first check if it has unsupported/supported dtypes uniquely added to it
intersect = set(ivy.function_unsupported_dtypes(fn)).difference(
set(ivy.invalid_dtypes)
Expand All @@ -1179,7 +1183,10 @@ def method(*args, **kwargs):
# no unsupported dtype specified
return fn(*args, **kwargs)

# specifies which dtype to cast the output to
to_cast = None
if "dtype" in kwargs and kwargs["dtype"] is not None:
to_cast = kwargs["dtype"]
dtype = caster(kwargs["dtype"], intersect)
if dtype:
kwargs["dtype"] = ivy.as_native_dtype(dtype)
Expand All @@ -1194,7 +1201,36 @@ def mini_helper(x):

args = ivy.nested_map(mini_helper, args, include_derived=True)
kwargs = ivy.nested_map(mini_helper, kwargs)
return fn(*args, **kwargs)

if not to_cast and ret_dtype_target:
for arg in ret_dtype_target:
if arg:
to_cast, arg_mod = ivy.promote_types_of_inputs(
to_cast,
(
args[arg_names.index(arg)]
if arg not in kwargs
else kwargs[arg]
),
)
if arg not in kwargs:
args[arg_names.index(arg)] = (
arg_mod
if not ivy.is_array(args[arg_names.index(arg)])
else args[arg_names.index(arg)]
)
else:
kwargs[arg] = (
arg_mod
if not ivy.is_array(args[arg_names.index(arg)])
else kwargs[arg]
)

return (
ivy.astype(fn(*args, **kwargs), ivy.to_native(to_cast))
if to_cast
else fn(*args, **kwargs)
)

return method

Expand Down Expand Up @@ -1284,7 +1320,7 @@ def _dtype_device_wrapper_creator(attrib, t):
A wrapper function for the attribute.
"""

def _wrapper_outer(version_dict, version, exclusive=True):
def _wrapper_outer(version_dict, version, exclusive=True, ret_dtype_target=None):
def _wrapped(func):
val = _versioned_attribute_factory(
lambda: _dtype_from_version(version_dict, version), t
Expand Down Expand Up @@ -1335,7 +1371,8 @@ def _wrapped(func):
if "frontends" in func.__module__:
# it's a frontend func, no casting modes for this
return func
return casting_modes_ops(func)

return casting_modes_ops(func, ret_dtype_target=ret_dtype_target)

return _wrapped

Expand Down
3 changes: 3 additions & 0 deletions ivy/functional/ivy/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -2109,6 +2109,9 @@ def promote_types(
ret
The type that both input types promote to
"""
# in case either is of none type
if not (type1 and type2):
return type1 if type1 else type2
query = [ivy.as_ivy_dtype(type1), ivy.as_ivy_dtype(type2)]
query = tuple(query)
if query not in ivy.promotion_table:
Expand Down

0 comments on commit 2afd872

Please sign in to comment.