Skip to content

Commit

Permalink
Address reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt711 committed Jul 24, 2024
1 parent 7191b74 commit 48501be
Showing 1 changed file with 9 additions and 13 deletions.
22 changes: 9 additions & 13 deletions python/cudf/cudf/pandas/fast_slow_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,6 @@ def _fast_slow_function_call(
func: Callable,
/,
*args,
**kwargs,
) -> Any:
"""
Call `func` with all `args` and `kwargs` converted to their
Expand All @@ -893,8 +892,8 @@ def _fast_slow_function_call(
color=_CUDF_PANDAS_NVTX_COLORS["EXECUTE_FAST"],
domain="cudf_pandas",
):
fast_args, fast_kwargs = _fast_arg(args), _fast_arg(kwargs)
result = func(*fast_args, **fast_kwargs)
fast_args = _fast_arg(args)
result = func(*fast_args)
if result is NotImplemented:
# try slow path
raise Exception()
Expand All @@ -906,12 +905,9 @@ def _fast_slow_function_call(
color=_CUDF_PANDAS_NVTX_COLORS["EXECUTE_SLOW"],
domain="cudf_pandas",
):
slow_args, slow_kwargs = (
_slow_arg(args),
_slow_arg(kwargs),
)
slow_args = (_slow_arg(args),)
with disable_module_accelerator():
slow_result = func(*slow_args, **slow_kwargs)
slow_result = func(*slow_args)
except Exception as e:
warnings.warn(
"The result from pandas could not be computed. "
Expand All @@ -936,10 +932,10 @@ def _fast_slow_function_call(
color=_CUDF_PANDAS_NVTX_COLORS["EXECUTE_SLOW"],
domain="cudf_pandas",
):
slow_args, slow_kwargs = _slow_arg(args), _slow_arg(kwargs)
slow_args = _slow_arg(args)
with disable_module_accelerator():
result = func(*slow_args, **slow_kwargs)
return _maybe_wrap_result(result, func, *args, **kwargs), fast
result = func(*slow_args)
return _maybe_wrap_result(result, func, *args), fast


def _transform_arg(
Expand Down Expand Up @@ -1054,7 +1050,7 @@ def _slow_arg(arg: Any) -> Any:
return _transform_arg(arg, "_fsproxy_slow", seen)


def _maybe_wrap_result(result: Any, func: Callable, /, *args, **kwargs) -> Any:
def _maybe_wrap_result(result: Any, func: Callable, /, *args) -> Any:
"""
Wraps "result" in a fast-slow proxy if is a "proxiable" object.
"""
Expand All @@ -1063,7 +1059,7 @@ def _maybe_wrap_result(result: Any, func: Callable, /, *args, **kwargs) -> Any:
return typ._fsproxy_wrap(result, func)
elif _is_intermediate_type(result):
typ = get_intermediate_type_map()[type(result)]
return typ._fsproxy_wrap(result, method_chain=(func, args, kwargs))
return typ._fsproxy_wrap(result, method_chain=(func, args))
elif _is_final_class(result):
return get_final_type_map()[result]
elif isinstance(result, list):
Expand Down

0 comments on commit 48501be

Please sign in to comment.