Skip to content

Commit

Permalink
Merge pull request #127 from ganesh-k13/enh_py_version
Browse files Browse the repository at this point in the history
ENH: Added Python version to `vetiver_pin_write`
  • Loading branch information
isabelizimm authored Jan 24, 2023
2 parents 84d03ff + 9325d35 commit 7a2eaf8
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 1 deletion.
5 changes: 4 additions & 1 deletion vetiver/meta.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from dataclasses import dataclass, asdict, field
from typing import Mapping

Expand All @@ -10,6 +11,7 @@ class VetiverMeta:
version: "str | None" = None
url: "str | None" = None
required_pkgs: "list | None" = field(default_factory=list)
python_version: "tuple | None" = None

def to_dict(self) -> Mapping:
data = asdict(self)
Expand All @@ -25,9 +27,10 @@ def from_dict(cls, metadata, pip_name=None) -> "VetiverMeta":
version = metadata.get("version", None)
url = metadata.get("url", None)
required_pkgs = metadata.get("required_pkgs", [])
python_version = tuple(metadata.get("python_version", sys.version_info))

if pip_name:
if not list(filter(lambda x: pip_name in x, required_pkgs)):
required_pkgs = required_pkgs + [f"{pip_name}"]

return cls(user, version, url, required_pkgs)
return cls(user, version, url, required_pkgs, python_version)
1 change: 1 addition & 0 deletions vetiver/pin_read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def vetiver_pin_write(board, model: VetiverModel, versioned: bool = True):
"vetiver_meta": {
"required_pkgs": model.metadata.required_pkgs,
"prototype": None if not model.prototype else model.prototype().json(),
"python_version": list(model.metadata.python_version),
},
},
versioned=versioned,
Expand Down
38 changes: 38 additions & 0 deletions vetiver/tests/test_build_vetiver_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sklearn
import sys

import vetiver as vt
from vetiver.meta import VetiverMeta
Expand Down Expand Up @@ -113,6 +114,7 @@ def test_vetiver_model_use_ptype():
version=None,
url=None,
required_pkgs=["scikit-learn"],
python_version=tuple(sys.version_info),
)


Expand All @@ -137,5 +139,41 @@ def test_vetiver_model_from_pin():
assert v2.metadata.user == {"test": 123}
assert v2.metadata.version is not None
assert v2.metadata.required_pkgs == ["scikit-learn"]
assert v2.metadata.python_version == tuple(sys.version_info)

board.pin_delete("model")


def test_vetiver_model_from_pin_user_metadata():
"""
Test if standard keys as part of :dataclass:`VetiverMeta` are picked
"""
custom_meta = {
"test": 123,
"required_pkgs": ["foo", "bar"],
"python_version": [3, 10, 6, "final", 0],
}
loaded_pkgs = custom_meta["required_pkgs"] + ["scikit-learn"]

v = vt.VetiverModel(
model=model,
prototype_data=X_df,
model_name="model",
versioned=None,
description=None,
metadata=custom_meta,
)

board = pins.board_temp(allow_pickle_read=True)
vt.vetiver_pin_write(board=board, model=v)
v2 = vt.VetiverModel.from_pin(board, "model")

assert isinstance(v2, vt.VetiverModel)
assert isinstance(v2.model, sklearn.base.BaseEstimator)
assert isinstance(v2.prototype.construct(), pydantic.BaseModel)
assert v2.metadata.user == custom_meta
assert v2.metadata.version is not None
assert v2.metadata.required_pkgs == loaded_pkgs
assert v2.metadata.python_version == tuple(custom_meta["python_version"])

board.pin_delete("model")
3 changes: 3 additions & 0 deletions vetiver/vetiver_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def from_pin(cls, board, name: str, version: str = None):
if "vetiver_meta" in meta.user:
get_prototype = meta.user.get("vetiver_meta").get("prototype", None)
required_pkgs = meta.user.get("vetiver_meta").get("required_pkgs", None)
python_version = meta.user.get("vetiver_meta").get("python_version", None)
meta.user.pop("vetiver_meta")
else:
# ptype = meta.user.get("ptype", None)
Expand All @@ -113,6 +114,7 @@ def from_pin(cls, board, name: str, version: str = None):
# get_prototype = None

required_pkgs = meta.user.get("required_pkgs")
python_version = meta.user.get("python_version")

return cls(
model=model,
Expand All @@ -123,6 +125,7 @@ def from_pin(cls, board, name: str, version: str = None):
"version": meta.version.version,
"url": meta.local.get("url"), # None all the time, besides Connect,
"required_pkgs": required_pkgs,
"python_version": python_version,
},
prototype_data=json.loads(get_prototype) if get_prototype else None,
versioned=True,
Expand Down

0 comments on commit 7a2eaf8

Please sign in to comment.