Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH - Gets SciKeras script working #394

Open
wants to merge 57 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
925c960
Gets SciKeras script working
lazarust Oct 10, 2023
3500592
Fixes test_metainfo
lazarust Oct 10, 2023
492a1ca
Updates changes.rst
lazarust Oct 11, 2023
4c91a07
Merge branch 'main' into enh-get-scikeras-working
lazarust Oct 11, 2023
101b90c
Merge branch 'main' into enh-get-scikeras-working
lazarust Oct 13, 2023
08f5ba7
Merge branch 'main' into enh-get-scikeras-working
lazarust Oct 14, 2023
0491a7b
Merge branch 'main' into enh-get-scikeras-working
lazarust Oct 29, 2023
f1b93fe
Update changes.rst
lazarust Oct 29, 2023
bbe6b34
Adds test
lazarust Nov 14, 2023
3b274b4
Loads dumped model in test and checks output
lazarust Nov 18, 2023
e7ab34e
Refactor test_external.py to use dumps instead of dump
lazarust Nov 21, 2023
a1a92cc
Add TensorFlow as a dependent package
lazarust Nov 23, 2023
7b0f21e
WIP Still running into a recursion error
lazarust Nov 24, 2023
3397b17
Refactor imports and update test method
lazarust Nov 24, 2023
49a16f0
WIP Fix get_state function to include module and class
lazarust Nov 29, 2023
d4469f9
Merge branch 'main' into enh-get-scikeras-working
lazarust Dec 5, 2023
3b55471
Merge branch 'main' into enh-get-scikeras-working
lazarust Dec 14, 2023
b51ca0e
Reverts changes from previous implementation
lazarust Jan 4, 2024
dd8d6e1
Merge branch 'main' into enh-get-scikeras-working
lazarust Jan 11, 2024
f68eea6
WIP Fix saving of scikeras models in zip file
lazarust Jan 12, 2024
f304dc7
Fixes lines I missed when reverting previous commits
lazarust Jan 12, 2024
c131ebd
Switch to saving as a `.keras` file
lazarust Jan 17, 2024
846e72e
Updates to use TempFile to save the model
lazarust Jan 24, 2024
7f8593a
Fixes typo in comment
lazarust Jan 24, 2024
208839b
Merge branch 'main' into enh-get-scikeras-working
lazarust Feb 12, 2024
28e9f15
Merge branch 'main' into enh-get-scikeras-working
lazarust Feb 19, 2024
d1d260e
Merge branch 'main' into enh-get-scikeras-working
lazarust Mar 31, 2024
ac11b46
Add support for saving and loading scikeras models by adding _scikera…
lazarust Apr 1, 2024
a9b9dbf
Removes comment that isn't necessary now
lazarust Apr 1, 2024
e5eb579
Adds SciKerasNode
lazarust Apr 1, 2024
0eca68a
Update Keras import to TensorFlow and fix model loading
lazarust Apr 1, 2024
30f8993
Update dependencies for TensorFlow to be included in docs
lazarust Apr 1, 2024
1b1cdff
Adds scikeras to docs dependencies
lazarust Apr 1, 2024
5033112
Updates scikeras to version 0.13
lazarust Apr 17, 2024
6a8e821
Removes default trusted types
lazarust Apr 17, 2024
f119713
Merge branch 'main' into enh-get-scikeras-working
lazarust Apr 24, 2024
ab530f3
Add importing __future__ annotations in _scikeras.py
lazarust Apr 24, 2024
0d8efca
Update TensorFlow version to 2.16.0 in _min_dependencies.py
lazarust Apr 24, 2024
07e0d5a
Update scikeras version to 0.12.0 in _min_dependencies.py
lazarust Apr 24, 2024
cc08530
Update TensorFlow version to 2.13.0 in _min_dependencies.py
lazarust Apr 24, 2024
cc2aead
Update TensorFlow version to 2.12.0 in _min_dependencies.py
lazarust Apr 24, 2024
c96a50a
Merge branch 'main' into enh-get-scikeras-working
lazarust May 3, 2024
91983e8
Moves changes to the correct version
lazarust May 5, 2024
2f3ae7a
Merge branch 'main' into enh-get-scikeras-working
lazarust May 14, 2024
3739f1f
Ignores deprecation warning from protobuf
lazarust May 25, 2024
ce11bf0
Fixes deprecation warning from matplotlib
lazarust May 25, 2024
4c47aaf
Merge branch 'main' into enh-get-scikeras-working
lazarust Jun 9, 2024
12e2108
Merge branch 'main' into enh-get-scikeras-working
lazarust Jul 2, 2024
d677476
Fixes making scikears a hard dependency
lazarust Jul 2, 2024
6d10f71
Adds test for error on untrusted types
lazarust Jul 2, 2024
e307560
Cleans up unneeded ()
lazarust Jul 2, 2024
bb82961
Merge branch 'main' into enh-get-scikeras-working
lazarust Jul 14, 2024
76341fd
Merge branch 'main' into enh-get-scikeras-working
lazarust Aug 1, 2024
fa6b208
use TF directly
adrinjalali Aug 28, 2024
83891ed
Merge remote-tracking branch 'upstream/main' into enh-get-scikeras-wo…
adrinjalali Aug 28, 2024
e9b2dd0
move changelog
adrinjalali Aug 28, 2024
2d92168
add missing file
adrinjalali Aug 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ v0.9
:pr:`386` by :user:`Reid Johnson <reidjohnson>`.
- :func:`skops.hub_utils.get_model_output` and :func:`skops.hub_utils.push` are
deprecated and will be removed in version 0.10. :pr:`396` by `Adrin Jalali`_.
- Fix dumping Scikeras model failing because of maximum recursion depth. :pr:`388`
by :user:`Thomas Lazarus <lazarust>`.

v0.8
----
Expand Down
1 change: 1 addition & 0 deletions skops/_min_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"catboost": ("1.0", "tests", None),
"fairlearn": ("0.7.0", "docs, tests", None),
"rich": ("12", "tests, rich", None),
"scikeras": ("0.4.0", "tests", None),
}


Expand Down
5 changes: 4 additions & 1 deletion skops/io/_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import operator
import uuid
import weakref
from functools import partial
from reprlib import Repr
from types import FunctionType, MethodType
Expand Down Expand Up @@ -46,10 +47,12 @@ def dict_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
for key, value in obj.items():
if isinstance(value, property):
continue
if isinstance(key, weakref.ref):
key = getattr(key(), "_name")
if np.isscalar(key) and hasattr(key, "item"):
# convert numpy value to python object
key = key.item() # type: ignore
content[key] = get_state(value, save_context)
content[str(key)] = get_state(value, save_context)
res["content"] = content
res["key_types"] = key_types
return res
Expand Down
15 changes: 7 additions & 8 deletions skops/io/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,13 @@ def get_state(value, save_context: SaveContext) -> dict[str, Any]:
# fails with `get_state`, we try with json.dumps, if that fails, we raise
# the original error alongside the json error.

# TODO: This should help with fixing recursive references.
# if id(value) in save_context.memo:
# return {
# "__module__": None,
# "__class__": None,
# "__id__": id(value),
# "__loader__": "CachedNode",
# }
if id(value) in save_context.memo:
return {
"__module__": None,
"__class__": None,
"__id__": id(value),
"__loader__": "CachedNode",
}

__id__ = save_context.memoize(obj=value)

Expand Down
54 changes: 54 additions & 0 deletions skops/io/tests/test_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import pytest
from sklearn.datasets import make_classification, make_regression
from sklearn.pipeline import Pipeline

from skops.io import dumps, loads, visualize
from skops.io.tests._utils import assert_method_outputs_equal, assert_params_equal
Expand Down Expand Up @@ -427,3 +428,56 @@ def test_quantile_forest(self, quantile_forest, regr_data, trusted, tree_method)
assert_method_outputs_equal(estimator, loaded, X)

visualize(dumped, trusted=trusted)


class TestSciKeras:
"""Tests for SciKerasRegressor and SciKerasClassifier"""

@pytest.fixture(autouse=True)
def capture_stdout(self):
# Mock print and rich.print so that running these tests with pytest -s
# does not spam stdout. Other, more common methods of suppressing
# printing to stdout don't seem to work, perhaps because of pytest.
with patch("builtins.print", Mock()), patch("rich.print", Mock()):
yield

@pytest.fixture(autouse=True)
def keras(self):
scikeras = pytest.importorskip("scikeras")
return scikeras

@pytest.fixture
def trusted(self):
return [
"scikeras.wrappers.KerasClassifier",
"keras.models.Sequential",
"keras.layers.core.Dense",
"keras.layers.core.Input",
]

@pytest.fixture
def test_dumping_model(self, keras, trusted):
from scikeras.wrappers import KerasClassifier

# This simplifies the basic usage tutorial from https://adriangb.com/scikeras/stable/notebooks/Basic_Usage.html

def get_clf(meta):
n_features_in_ = meta["n_features_in_"]
model = keras.models.Sequential()
model.add(keras.layers.Input(shape=(n_features_in_,)))
model.add(keras.layers.Dense(1, activation="sigmoid"))
return model

clf = KerasClassifier(model=get_clf, loss="binary_crossentropy")

pipeline = Pipeline([("classifier", clf)])

dumps(clf, "keras-test.skops")
dumps(pipeline, "keras-test.skops")

X, y = make_classification(1000, 20, n_informative=10, random_state=0)
clf.fit(X, y)
dumped = dumps(clf, "keras-test.skops")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dumped = dumps(clf, "keras-test.skops")
dumped = dumps(clf)

2nd argument to dumps is the compression level. Honestly, I'm surprised that this didn't raise an error.


loaded = loads(dumped, trusted=trusted)
assert_method_outputs_equal(clf, loaded, X)
5 changes: 5 additions & 0 deletions skops/io/tests/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,11 @@ def fit(self, X, y=None, **fit_params):
for key, val_expected in expected.items():
for state in states:
val_state = state[key]

# skipping all the state values that are a cached node
if val_state["__loader__"] == "CachedNode":
continue

# check presence of "content"/"file" but not exact values
assert ("content" in val_state) or ("file" in val_state)
assert val_state["__class__"] == val_expected["__class__"]
Expand Down
Loading