Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added conditional property to expose time scale predictions #327

Closed
wants to merge 12 commits into from
8 changes: 8 additions & 0 deletions sksurv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from sklearn.pipeline import Pipeline, _final_estimator_has
from sklearn.utils.metaestimators import available_if

from .util import conditionalAvailableProperty


def _get_version(name):
try:
Expand Down Expand Up @@ -128,9 +130,15 @@ def predict_survival_function(self, X, **kwargs):
return self.steps[-1][-1].predict_survival_function(Xt, **kwargs)


@conditionalAvailableProperty(_final_estimator_has('_predict_risk_score'))
def _predict_risk_score(self):
    """Forward ``_predict_risk_score`` from the last step of the pipeline.

    The property is only exposed when the availability check built by
    ``_final_estimator_has('_predict_risk_score')`` passes.
    """
    # Each entry of ``steps`` is indexed twice in the original code:
    # the last entry, then its last element.
    last_step = self.steps[-1]
    final_estimator = last_step[-1]
    return final_estimator._predict_risk_score


def patch_pipeline():
    """Attach the survival-analysis members to ``sklearn.pipeline.Pipeline``.

    ``Pipeline`` is modified in place: each listed object is set as a class
    attribute under the given name.
    """
    additions = (
        ("predict_survival_function", predict_survival_function),
        ("predict_cumulative_hazard_function", predict_cumulative_hazard_function),
        ("_predict_risk_score", _predict_risk_score),
    )
    for attr_name, member in additions:
        setattr(Pipeline, attr_name, member)


try:
Expand Down
80 changes: 80 additions & 0 deletions sksurv/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,83 @@ def safe_concat(objs, *args, **kwargs):
concatenated[name] = pd.Categorical(concatenated[name], **params)

return concatenated


class _ConditionalAvailableProperty:
    """Property descriptor that is only available when a check passes.

    scikit-learn's ``available_if`` only supports callables (methods). This
    class offers the same behaviour for properties and follows the pure-Python
    ``property`` equivalent presented by the descriptor guide in the official
    documentation.

    ``check`` is called with the owning instance on every get, set and delete.
    When it returns a falsy value an :class:`AttributeError` is raised, so
    ``hasattr`` reports the attribute as missing. Setter and deleter also need
    to be defined on the exposing class if they are supposed to be available.

    Unlike the built-in :class:`property`, ``fget``, ``fset`` and ``fdel`` are
    plain attributes that can be reassigned after construction, which is why
    this class does not inherit from ``property``.

    Parameters
    ----------
    check : callable
        Called as ``check(obj)``; a truthy result makes the property available.
    fget, fset, fdel : callable, optional
        Getter, setter and deleter, with the same signatures as for
        :class:`property`.
    doc : str, optional
        Docstring of the property. Defaults to the docstring of ``fget``.

    See Also
    --------
    sklearn.utils.metaestimators.available_if :
        The scikit-learn counterpart for methods.
    """

    def __init__(self, check, fget=None, fset=None, fdel=None, doc=None):
        self._check = check
        self.fget = fget
        self.fset = fset
        self.fdel = fdel
        if doc is None and fget is not None:
            doc = fget.__doc__
        self.__doc__ = doc
        # __set_name__ only runs for attributes present in the class body at
        # class creation. A descriptor attached afterwards (monkey-patching)
        # would keep an empty name, so fall back to the getter's name.
        self._name = getattr(fget, "__name__", "")

    def _run_check(self, obj):
        """Raise AttributeError if the property is unavailable on ``obj``."""
        # Build the message only on failure: repr(obj) can be expensive and
        # must not be evaluated on every successful access.
        if not self.check(obj):
            raise AttributeError(
                f"This {obj!r} has no attribute {self._name!r}"
            )

    def __set_name__(self, owner, name):
        self._name = name

    def __get__(self, obj, objtype=None):
        if obj is None:
            # Accessed on the class: return the descriptor itself.
            return self
        self._run_check(obj)
        if self.fget is None:
            raise AttributeError(f"property '{self._name}' has no getter")
        return self.fget(obj)

    def __set__(self, obj, value):
        self._run_check(obj)
        if self.fset is None:
            raise AttributeError(f"property '{self._name}' has no setter")
        self.fset(obj, value)

    def __delete__(self, obj):
        self._run_check(obj)
        if self.fdel is None:
            raise AttributeError(f"property '{self._name}' has no deleter")
        self.fdel(obj)

    @property
    def check(self):
        """The availability check passed to the constructor."""
        return self._check

    def _copy_with(self, **overrides):
        """Return a new descriptor with some accessor functions replaced."""
        accessors = {"fget": self.fget, "fset": self.fset, "fdel": self.fdel}
        accessors.update(overrides)
        prop = type(self)(self.check, doc=self.__doc__, **accessors)
        if self._name:
            # A name that is already known wins over the one derived from the
            # getter in __init__.
            prop._name = self._name
        return prop

    def getter(self, fget):
        """Return a copy of this descriptor that uses ``fget`` as getter."""
        return self._copy_with(fget=fget)

    def setter(self, fset):
        """Return a copy of this descriptor that uses ``fset`` as setter."""
        return self._copy_with(fset=fset)

    def deleter(self, fdel):
        """Return a copy of this descriptor that uses ``fdel`` as deleter."""
        return self._copy_with(fdel=fdel)


def conditionalAvailableProperty(check):
    """Decorator factory for a property that exists only if ``check`` passes.

    Parameters
    ----------
    check : callable
        Called as ``check(obj)`` with the instance the property is accessed
        on. A falsy result makes every access raise ``AttributeError``.

    Returns
    -------
    callable
        A decorator that turns the decorated function into the getter of a
        ``_ConditionalAvailableProperty``. The resulting object offers
        ``.setter`` and ``.deleter`` like the built-in ``property``.
    """
    return _ConditionalAvailableProperty(check=check).getter
18 changes: 18 additions & 0 deletions tests/test_aft.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numpy as np
from numpy.testing import assert_array_almost_equal
import pytest
from sklearn.pipeline import make_pipeline

from sksurv.base import SurvivalAnalysisMixin
from sksurv.linear_model import IPCRidge
from sksurv.testing import assert_cindex_almost_equal

Expand Down Expand Up @@ -37,3 +39,19 @@ def test_predict(make_whas500):
)

assert model.score(x_test, y_test) == 0.66925817946226107

@staticmethod
def test_pipeline_score(make_whas500):
    """Check prediction and scoring of IPCRidge wrapped in a pipeline.

    The model is fitted on the first 400 samples and evaluated on the
    remainder. The expected concordance index is the same value asserted
    for ``model.score`` in ``test_predict``.
    """
    data = make_whas500()
    x_train, y_train = data.x[:400], data.y[:400]
    x_test, y_test = data.x[400:], data.y[400:]

    pipeline = make_pipeline(IPCRidge())
    pipeline.fit(x_train, y_train)

    predicted = pipeline.predict(x_test)
    expected_cindex = 0.66925817946226107
    # NOTE(review): predictions are negated before the concordance check,
    # presumably because they are on the time scale -- confirm.
    assert_cindex_almost_equal(
        y_test['fstat'], y_test['lenfol'], -predicted,
        (expected_cindex, 2066, 1021, 0, 1),
    )

    # Call the mixin's score directly with the pipeline as ``self``.
    assert SurvivalAnalysisMixin.score(pipeline, x_test, y_test) == expected_cindex
58 changes: 57 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest

from sksurv.testing import FixtureParameterFactory
from sksurv.util import Surv, safe_concat
from sksurv.util import Surv, conditionalAvailableProperty, safe_concat


class ConcatCasesFactory(FixtureParameterFactory):
Expand Down Expand Up @@ -378,3 +378,59 @@ def test_from_dataframe(args, expected, expected_error):

if expected is not None:
assert_array_equal(y, expected)

def test_cond_avail_property():
    """Exercise get, set and delete on a conditionally available property.

    While the check fails, every access must raise ``AttributeError`` with
    the "has no attribute" message. Once it passes, getter, setter and
    deleter behave like a regular property. Removing an accessor function
    afterwards yields the matching "has no getter/setter/deleter" error.
    """

    class WithCondProp:
        def __init__(self, val):
            self.avail = False
            self._prop = val

        @conditionalAvailableProperty(lambda self: self.avail)
        def prop(self):
            return self._prop

        @prop.setter
        def prop(self, new):
            self._prop = new

        @prop.deleter
        def prop(self):
            self.avail = False

    testval = 43
    msg = "has no attribute 'prop'"

    def assert_unavailable(instance, new_value):
        # Get, set and delete must all be rejected, in this order.
        with pytest.raises(AttributeError, match=msg):
            _ = instance.prop
        with pytest.raises(AttributeError, match=msg):
            instance.prop = new_value
        with pytest.raises(AttributeError, match=msg):
            del instance.prop

    # Access on the class returns the descriptor object itself.
    descriptor = WithCondProp.prop
    assert descriptor is not None

    test_obj = WithCondProp(testval)
    assert_unavailable(test_obj, testval - 1)

    test_obj.avail = True
    assert test_obj.prop == testval
    test_obj.prop = testval - 2
    assert test_obj.prop == testval - 2
    # The deleter of the test class switches availability off again.
    del test_obj.prop
    assert test_obj.avail is False
    assert_unavailable(test_obj, testval - 3)

    test_obj.avail = True
    descriptor.fget = None
    with pytest.raises(AttributeError, match='has no getter'):
        _ = test_obj.prop
    descriptor.fset = None
    with pytest.raises(AttributeError, match='has no setter'):
        test_obj.prop = testval - 4
    descriptor.fdel = None
    with pytest.raises(AttributeError, match='has no deleter'):
        del test_obj.prop