From 4f1241b7a953983f2769e00327e3c50bf5451a2c Mon Sep 17 00:00:00 2001 From: clarkzinzow Date: Wed, 20 Sep 2023 16:22:04 -0700 Subject: [PATCH] Fix public API decorator type annotations. --- daft/api_annotations.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/daft/api_annotations.py b/daft/api_annotations.py index 0fd71c584e..85237de8cd 100644 --- a/daft/api_annotations.py +++ b/daft/api_annotations.py @@ -3,21 +3,29 @@ import functools import inspect import sys -from typing import Any, Callable, ForwardRef, Union +from typing import Any, Callable, ForwardRef, TypeVar, Union if sys.version_info < (3, 8): from typing_extensions import get_args, get_origin else: from typing import get_args, get_origin +if sys.version_info < (3, 10): + from typing_extensions import ParamSpec +else: + from typing import ParamSpec + from daft.analytics import time_df_method, time_func +T = TypeVar("T") +P = ParamSpec("P") + -def DataframePublicAPI(func: Callable[..., Any]) -> Callable[..., Any]: +def DataframePublicAPI(func: Callable[P, T]) -> Callable[P, T]: """A decorator to mark a function as part of the Daft DataFrame's public API.""" @functools.wraps(func) - def _wrap(*args, **kwargs): + def _wrap(*args: P.args, **kwargs: P.kwargs) -> T: type_check_function(func, *args, **kwargs) timed_method = time_df_method(func) return timed_method(*args, **kwargs) @@ -25,11 +33,11 @@ def _wrap(*args, **kwargs): return _wrap -def PublicAPI(func: Callable[..., Any]) -> Callable[..., Any]: +def PublicAPI(func: Callable[P, T]) -> Callable[P, T]: """A decorator to mark a function as part of the Daft public API.""" @functools.wraps(func) - def _wrap(*args, **kwargs): + def _wrap(*args: P.args, **kwargs: P.kwargs) -> T: type_check_function(func, *args, **kwargs) timed_func = time_func(func) return timed_func(*args, **kwargs)