-
-
Notifications
You must be signed in to change notification settings - Fork 543
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable pickling of models in "python" format #1298
Enable pickling of models in "python" format #1298
Conversation
So that `inputs_list` is defined
Codecov Report
@@ Coverage Diff @@
## issue-849-parallel-processing #1298 +/- ##
==============================================================
Coverage 98.09% 98.09%
==============================================================
Files 270 270
Lines 15176 15189 +13
==============================================================
+ Hits 14887 14900 +13
Misses 289 289
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, neat fix :) I know you didn't introduce this but I just noticed lots of the comments in test_scipy_solver contradict the associated code ...
# Create model | ||
model = pybamm.BaseModel() | ||
# Covert to casadi instead of python to avoid pickling of | ||
# "EvaluatorPython" objects. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
comment out of date
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, good catch
model.rhs = {var: -pybamm.InputParameter("rate") * var} | ||
model.initial_conditions = {var: 1} | ||
# No need to set parameters; can use base discretisation (no spatial | ||
# operators) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
comment out of date
@@ -258,7 +298,7 @@ def test_model_solver_multiple_inputs_discontinuity_error(self): | |||
domain = ["negative electrode", "separator", "positive electrode"] | |||
var = pybamm.Variable("var", domain=domain) | |||
model.rhs = {var: -pybamm.InputParameter("rate") * var} | |||
model.initial_conditions = {var: 1} | |||
model.initial_conditions = {var: 2 * pybamm.InputParameter("rate")} | |||
# No need to set parameters; can use base discretisation (no spatial | |||
# operators) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
comment out of date
): | ||
solver.solve(model, t_eval, inputs=inputs_list, nproc=2) | ||
|
||
def test_model_solver_multiple_inputs_initial_conditions_error(self): | ||
def test_model_solver_multiple_inputs_jax_format_error(self): | ||
# Create model | ||
model = pybamm.BaseModel() | ||
# Covert to casadi instead of python to avoid pickling of | ||
# "EvaluatorPython" objects. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
comment out of date
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks great @tlestang , happy to merge. Had a couple of points below.
@@ -507,6 +508,23 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None): | |||
else: | |||
return result | |||
|
|||
def __getstate__(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks great @tlestang, do you think we can use the same trick for the Jax evaluator?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I naively tried the same trick with the Jax evaluator but couldn't get it to work as smoothly.
Something in jax.jit
can't be pickle, but I don't know what.
AttributeError: Can't pickle local object 'jit.<locals>.f_jitted'
I tried not pickling attributes jit_evaluate
and jac_evaluate
in the same way, i.e.
def __getstate__(self):
state = self.__dict__.copy()
del state["_evaluate_jax"]
del state["_jit_evaluate"]
del state["_jac_evaluate"]
return state
but still get the AttributeError
..
@@ -589,6 +589,13 @@ def solve( | |||
for inputs in inputs_list | |||
] | |||
|
|||
# Cannot use multiprocessing with model in "jax" format |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that Jax is designed to support parallel solution of models with different inputs, so this should be doable. But it can be left for a PR later on
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay good, would be better to rely on Jax's native behavior anyway :)
Thanks @tinosulzer and @martinjrobins for the review. I will update the comments but before we merge that PR I'd like to investigate why the content of |
This brought coverage to 100%. Now waiting to merge #1319 to fix CI on MacOS |
Description
This defines the two methods
__getstate__
and__setstate__
that can be used to control what is pickled when pickling a classinstance, as well as what happens when the object in unpickled.
In this case this is useful to avoid pickling the
_evaluate
attribute, which is a reference to a function that cannot be pickled.This attribute is restored when unpicking by re-compiling the definition of the function
evaluate
.Note that this does not make it possible to pickle model in "jax" format.
Fixes #1283 #734
Type of change
Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.
Key checklist:
$ flake8
$ python run-tests.py --unit
$ cd docs
and then$ make clean; make html
You can run all three at once, using
$ python run-tests.py --quick
.Further checks: