Skip to content

Commit

Permalink
refactor: Extract Duplicate Code into Helper Function in `ivy\func_…
Browse files Browse the repository at this point in the history
…wrapper.py` (ivy-llc#26572)

Co-authored-by: vedpatwardhan <[email protected]>
  • Loading branch information
Sai-Suraj-27 and vedpatwardhan authored Nov 30, 2023
1 parent 33e2be4 commit 99f184d
Showing 1 changed file with 21 additions and 71 deletions.
92 changes: 21 additions & 71 deletions ivy/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,91 +91,41 @@ def caster(dtype, intersect):
return ret_dtype


def cast_helper(arg, dtype, intersect, is_upcast=True):
step = 1 if is_upcast else -1
index = casting_modes_dict[arg]().index(dtype) + step
result = ""
while 0 <= index < len(casting_modes_dict[arg]()):
if casting_modes_dict[arg]()[index] not in intersect:
result = casting_modes_dict[arg]()[index]
break
index += step

return result


def upcaster(dtype, intersect):
# upcasting is enabled, we upcast to the highest
if "uint" in str(dtype):
index = casting_modes_dict["uint"]().index(dtype) + 1
result = ""
while index < len(casting_modes_dict["uint"]()):
if casting_modes_dict["uint"]()[index] not in intersect:
result = casting_modes_dict["uint"]()[index]
break
index += 1
return result

return cast_helper("uint", dtype, intersect, is_upcast=True)
if "int" in dtype:
index = casting_modes_dict["int"]().index(dtype) + 1
result = ""
while index < len(casting_modes_dict["int"]()):
if casting_modes_dict["int"]()[index] not in intersect:
result = casting_modes_dict["int"]()[index]
break
index += 1
return result

return cast_helper("int", dtype, intersect, is_upcast=True)
if "float" in dtype:
index = casting_modes_dict["float"]().index(dtype) + 1
result = ""
while index < len(casting_modes_dict["float"]()):
if casting_modes_dict["float"]()[index] not in intersect:
result = casting_modes_dict["float"]()[index]
break
index += 1
return result

return cast_helper("float", dtype, intersect, is_upcast=True)
if "complex" in dtype:
index = casting_modes_dict["complex"]().index(dtype) + 1
result = ""
while index < len(casting_modes_dict["complex"]()):
if casting_modes_dict["complex"]()[index] not in intersect:
result = casting_modes_dict["complex"]()[index]
break
index += 1
return result
return cast_helper("complex", dtype, intersect, is_upcast=True)


def downcaster(dtype, intersect):
# downcasting is enabled, we upcast to the highest
if "uint" in str(dtype):
index = casting_modes_dict["uint"]().index(dtype) - 1
result = ""
while index >= 0:
if casting_modes_dict["int"]()[index] not in intersect:
result = casting_modes_dict["uint"]()[index]
break
index -= 1
return result

return cast_helper("uint", dtype, intersect, is_upcast=False)
if "int" in dtype:
index = casting_modes_dict["int"]().index(dtype) - 1
result = ""
while index >= 0:
if casting_modes_dict["int"]()[index] not in intersect:
result = casting_modes_dict["int"]()[index]
break
index -= 1
return result

return cast_helper("int", dtype, intersect, is_upcast=False)
if "float" in dtype:
index = casting_modes_dict["float"]().index(dtype) - 1

result = ""
while index >= 0:
if casting_modes_dict["float"]()[index] not in intersect:
result = casting_modes_dict["float"]()[index]
break
index -= 1
return result

return cast_helper("float", dtype, intersect, is_upcast=False)
if "complex" in dtype:
index = casting_modes_dict["complex"]().index(dtype) - 1
result = ""
while index >= 0:
if casting_modes_dict["complex"]()[index] not in intersect:
result = casting_modes_dict["complex"]()[index]
break
index -= 1
return result
return cast_helper("complex", dtype, intersect, is_upcast=False)


def cross_caster(intersect):
Expand Down

0 comments on commit 99f184d

Please sign in to comment.