Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Type Hints for Python Package #7742

Merged
merged 26 commits into from
May 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions python-package/xgboost/_typing.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,32 @@
"""Shared typing definition."""
import ctypes
import os
from typing import Optional, Any, TypeVar, Union, Sequence
from typing import Any, TypeVar, Union, Type, Sequence, Callable, List, Dict

# os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame/
# cudf.DataFrame/cupy.array/dlpack
import numpy as np

DataType = Any

# xgboost accepts some other possible types in practice for historical reasons, but they
# are less well tested. For now we encourage users to pass a simple list of strings.
FeatureNames = Optional[Sequence[str]]
FeatureTypes = Optional[Sequence[str]]
FeatureInfo = Sequence[str]
FeatureNames = FeatureInfo
FeatureTypes = FeatureInfo
BoosterParam = Union[List, Dict] # better be sequence

ArrayLike = Any
PathLike = Union[str, os.PathLike]
CupyT = ArrayLike # maybe need a stub for cupy arrays
NumpyOrCupy = Any
NumpyDType = Union[str, Type[np.number]]
PandasDType = Any # real type is pandas.core.dtypes.base.ExtensionDtype

FloatCompatible = Union[float, np.float32, np.float64]

# callables
FPreProcCallable = Callable

# ctypes
# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
Expand Down Expand Up @@ -59,3 +70,4 @@

# template parameter
_T = TypeVar("_T")
_F = TypeVar("_F", bound=Callable[..., Any])
38 changes: 20 additions & 18 deletions python-package/xgboost/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
import collections
import os
import pickle
from typing import Callable, List, Optional, Union, Dict, Tuple, TypeVar, cast
from typing import Sequence
from typing import Callable, List, Optional, Union, Dict, Tuple, TypeVar, cast, Sequence, Any
import numpy

from . import rabit
Expand All @@ -24,11 +23,14 @@
"EarlyStopping",
"EvaluationMonitor",
"TrainingCheckPoint",
"CallbackContainer"
]

_Score = Union[float, Tuple[float, float]]
_ScoreList = Union[List[float], List[Tuple[float, float]]]

_Model = Any # real type is Union[Booster, CVPack]; need more work


# pylint: disable=unused-argument
class TrainingCallback(ABC):
Expand All @@ -43,19 +45,19 @@ class TrainingCallback(ABC):
def __init__(self) -> None:
pass

def before_training(self, model):
def before_training(self, model: _Model) -> _Model:
'''Run before training starts.'''
return model

def after_training(self, model):
def after_training(self, model: _Model) -> _Model:
'''Run after training is finished.'''
return model

def before_iteration(self, model, epoch: int, evals_log: EvalsLog) -> bool:
def before_iteration(self, model: _Model, epoch: int, evals_log: EvalsLog) -> bool:
'''Run before each iteration. Return True when training should stop.'''
return False

def after_iteration(self, model, epoch: int, evals_log: EvalsLog) -> bool:
def after_iteration(self, model: _Model, epoch: int, evals_log: EvalsLog) -> bool:
'''Run after each iteration. Return True when training should stop.'''
return False

Expand Down Expand Up @@ -140,7 +142,7 @@ def __init__(
if self.is_cv:
self.aggregated_cv = None

def before_training(self, model):
def before_training(self, model: _Model) -> _Model:
'''Function called before training.'''
for c in self.callbacks:
model = c.before_training(model=model)
Expand All @@ -151,7 +153,7 @@ def before_training(self, model):
assert isinstance(model, Booster), msg
return model

def after_training(self, model):
def after_training(self, model: _Model) -> _Model:
'''Function called after training.'''
for c in self.callbacks:
model = c.after_training(model=model)
Expand Down Expand Up @@ -182,7 +184,7 @@ def after_training(self, model):
return model

def before_iteration(
self, model, epoch: int, dtrain: DMatrix, evals: List[Tuple[DMatrix, str]]
self, model: _Model, epoch: int, dtrain: DMatrix, evals: Optional[List[Tuple[DMatrix, str]]]
) -> bool:
'''Function called before training iteration.'''
return any(c.before_iteration(model, epoch, self.history)
Expand Down Expand Up @@ -220,7 +222,7 @@ def _update_history(

def after_iteration(
self,
model,
model: _Model,
epoch: int,
dtrain: DMatrix,
evals: Optional[List[Tuple[DMatrix, str]]],
Expand Down Expand Up @@ -276,7 +278,7 @@ def __init__(
super().__init__()

def after_iteration(
self, model, epoch: int, evals_log: TrainingCallback.EvalsLog
self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
) -> bool:
model.set_param("learning_rate", self.learning_rates(epoch))
return False
Expand Down Expand Up @@ -344,12 +346,12 @@ def __init__(
self.starting_round: int = 0
super().__init__()

def before_training(self, model):
def before_training(self, model: _Model) -> _Model:
self.starting_round = model.num_boosted_rounds()
return model

def _update_rounds(
self, score: _Score, name: str, metric: str, model, epoch: int
self, score: _Score, name: str, metric: str, model: _Model, epoch: int
) -> bool:
def get_s(x: _Score) -> float:
"""get score if it's cross validation history."""
Expand Down Expand Up @@ -403,7 +405,7 @@ def minimize(new: _Score, best: _Score) -> bool:
return True
return False

def after_iteration(self, model, epoch: int,
def after_iteration(self, model: _Model, epoch: int,
evals_log: TrainingCallback.EvalsLog) -> bool:
epoch += self.starting_round # training continuation
msg = 'Must have at least 1 validation dataset for early stopping.'
Expand Down Expand Up @@ -431,7 +433,7 @@ def after_iteration(self, model, epoch: int,
score = data_log[metric_name][-1]
return self._update_rounds(score, data_name, metric_name, model, epoch)

def after_training(self, model):
def after_training(self, model: _Model) -> _Model:
try:
if self.save_best:
model = model[: int(model.attr("best_iteration")) + 1]
Expand Down Expand Up @@ -477,7 +479,7 @@ def _fmt_metric(
msg = f"\t{data + '-' + metric}:{score:.5f}"
return msg

def after_iteration(self, model, epoch: int,
def after_iteration(self, model: _Model, epoch: int,
evals_log: TrainingCallback.EvalsLog) -> bool:
if not evals_log:
return False
Expand All @@ -503,7 +505,7 @@ def after_iteration(self, model, epoch: int,
self._latest = msg
return False

def after_training(self, model):
def after_training(self, model: _Model) -> _Model:
if rabit.get_rank() == self.printer_rank and self._latest is not None:
rabit.tracker_print(self._latest)
return model
Expand Down Expand Up @@ -544,7 +546,7 @@ def __init__(
self._epoch = 0
super().__init__()

def after_iteration(self, model, epoch: int,
def after_iteration(self, model: _Model, epoch: int,
evals_log: TrainingCallback.EvalsLog) -> bool:
if self._epoch == self._iterations:
path = os.path.join(self._path, self._name + '_' + str(epoch) +
Expand Down
108 changes: 60 additions & 48 deletions python-package/xgboost/compat.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,32 @@
# coding: utf-8
# pylint: disable= invalid-name, unused-import
"""For compatibility and optional dependencies."""
from typing import Any
from typing import Any, Type, Dict, Optional, List
import sys
import types
import importlib.util
import logging
import numpy as np

from xgboost._typing import CStrPtr

assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.'


def py_str(x):
def py_str(x: CStrPtr) -> str:
"""convert c string back to python string"""
return x.decode('utf-8')
return x.decode('utf-8') # type: ignore


def lazy_isinstance(instance, module, name):
def lazy_isinstance(instance: Type[object], module: str, name: str) -> bool:
"""Use string representation to identify a type."""

# Notice, we use .__class__ as opposed to type() in order
# to support object proxies such as weakref.proxy
cls = instance.__class__
module = cls.__module__ == module
name = cls.__name__ == name
return module and name
is_same_module = cls.__module__ == module
has_same_name = cls.__name__ == name
return is_same_module and has_same_name


# pandas
Expand All @@ -37,64 +39,68 @@ def lazy_isinstance(instance, module, name):
except ImportError:

MultiIndex = object
DataFrame: Any = object
DataFrame = object
Series = object
pandas_concat = None
PANDAS_INSTALLED = False

# sklearn
try:
from sklearn.base import BaseEstimator
from sklearn.base import RegressorMixin, ClassifierMixin
from sklearn.base import (
BaseEstimator as XGBModelBase,
RegressorMixin as XGBRegressorBase,
ClassifierMixin as XGBClassifierBase
)
from sklearn.preprocessing import LabelEncoder

try:
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.model_selection import (
KFold as XGBKFold,
StratifiedKFold as XGBStratifiedKFold
)
except ImportError:
from sklearn.cross_validation import KFold, StratifiedKFold
from sklearn.cross_validation import (
KFold as XGBKFold,
StratifiedKFold as XGBStratifiedKFold
)

SKLEARN_INSTALLED = True

XGBModelBase = BaseEstimator
XGBRegressorBase = RegressorMixin
XGBClassifierBase = ClassifierMixin

XGBKFold = KFold
XGBStratifiedKFold = StratifiedKFold

class XGBoostLabelEncoder(LabelEncoder):
'''Label encoder with JSON serialization methods.'''
def to_json(self):
'''Returns a JSON compatible dictionary'''
meta = {}
for k, v in self.__dict__.items():
if isinstance(v, np.ndarray):
meta[k] = v.tolist()
else:
meta[k] = v
return meta

def from_json(self, doc):
# pylint: disable=attribute-defined-outside-init
'''Load the encoder back from a JSON compatible dict.'''
meta = {}
for k, v in doc.items():
if k == 'classes_':
self.classes_ = np.array(v)
continue
meta[k] = v
self.__dict__.update(meta)
except ImportError:
SKLEARN_INSTALLED = False

# used for compatibility without sklearn
XGBModelBase = object
XGBClassifierBase = object
XGBRegressorBase = object
LabelEncoder = object

XGBKFold = None
XGBStratifiedKFold = None
XGBoostLabelEncoder = None


class XGBoostLabelEncoder(LabelEncoder):
    """LabelEncoder variant that can round-trip its state through JSON."""

    def to_json(self) -> Dict:
        """Return the instance attributes as a JSON compatible dictionary.

        Attributes holding a numpy array are converted to plain lists; every
        other attribute is stored unchanged.
        """
        return {
            key: (value.tolist() if isinstance(value, np.ndarray) else value)
            for key, value in self.__dict__.items()
        }

    def from_json(self, doc: Dict) -> None:
        """Restore the encoder state from a JSON compatible dictionary.

        The ``classes_`` entry is turned back into a numpy array; all other
        entries are copied onto the instance as-is.
        """
        # pylint: disable=attribute-defined-outside-init
        if 'classes_' in doc:
            self.classes_ = np.array(doc['classes_'])
        remaining = {
            key: value for key, value in doc.items() if key != 'classes_'
        }
        self.__dict__.update(remaining)


# dask
Expand All @@ -113,7 +119,7 @@ def from_json(self, doc):
SCIPY_INSTALLED = True
except ImportError:
scipy_sparse = False
scipy_csr: Any = object
scipy_csr = object
SCIPY_INSTALLED = False


Expand All @@ -136,15 +142,21 @@ class LazyLoader(types.ModuleType):
"""Lazily import a module, mainly to avoid pulling in large dependencies.
"""

def __init__(self, local_name, parent_module_globals, name, warning=None):
def __init__(
self,
local_name: str,
parent_module_globals: Dict,
name: str,
warning: Optional[str] = None
) -> None:
self._local_name = local_name
self._parent_module_globals = parent_module_globals
self._warning = warning
self.module = None
self.module: Optional[types.ModuleType] = None

super().__init__(name)

def _load(self):
def _load(self) -> types.ModuleType:
"""Load the module and insert it into the parent's globals."""
# Import the target module and insert it into the parent's namespace
module = importlib.import_module(self.__name__)
Expand All @@ -163,12 +175,12 @@ def _load(self):

return module

def __getattr__(self, item):
def __getattr__(self, item: str) -> Any:
if not self.module:
self.module = self._load()
return getattr(self.module, item)

def __dir__(self):
def __dir__(self) -> List[str]:
if not self.module:
self.module = self._load()
return dir(self.module)
Loading