diff --git a/daft/api_annotations.py b/daft/api_annotations.py index 0fd71c584e..85237de8cd 100644 --- a/daft/api_annotations.py +++ b/daft/api_annotations.py @@ -3,21 +3,29 @@ import functools import inspect import sys -from typing import Any, Callable, ForwardRef, Union +from typing import Any, Callable, ForwardRef, TypeVar, Union if sys.version_info < (3, 8): from typing_extensions import get_args, get_origin else: from typing import get_args, get_origin +if sys.version_info < (3, 10): + from typing_extensions import ParamSpec +else: + from typing import ParamSpec + from daft.analytics import time_df_method, time_func +T = TypeVar("T") +P = ParamSpec("P") + -def DataframePublicAPI(func: Callable[..., Any]) -> Callable[..., Any]: +def DataframePublicAPI(func: Callable[P, T]) -> Callable[P, T]: """A decorator to mark a function as part of the Daft DataFrame's public API.""" @functools.wraps(func) - def _wrap(*args, **kwargs): + def _wrap(*args: P.args, **kwargs: P.kwargs) -> T: type_check_function(func, *args, **kwargs) timed_method = time_df_method(func) return timed_method(*args, **kwargs) @@ -25,11 +33,11 @@ def _wrap(*args, **kwargs): return _wrap -def PublicAPI(func: Callable[..., Any]) -> Callable[..., Any]: +def PublicAPI(func: Callable[P, T]) -> Callable[P, T]: """A decorator to mark a function as part of the Daft public API.""" @functools.wraps(func) - def _wrap(*args, **kwargs): + def _wrap(*args: P.args, **kwargs: P.kwargs) -> T: type_check_function(func, *args, **kwargs) timed_func = time_func(func) return timed_func(*args, **kwargs)