Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable pickling of models in "python" format #1298

Merged

Conversation

tlestang
Copy link
Contributor

Description

This defines the two methods __getstate__ and __setstate__ that can be used to control what is pickled when pickling a class
instance, as well as what happens when the object in unpickled.

In this case this is useful to avoid pickling the _evaluate attribute, which is a reference to a function that cannot be pickled.
This attribute is restored when unpicking by re-compiling the definition of the function evaluate.

Note that this does not make it possible to pickle model in "jax" format.

Fixes #1283 #734

Type of change

Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.

  • New feature (non-breaking change which adds functionality)
  • Optimization (back-end change that speeds up the code)
  • Bug fix (non-breaking change which fixes an issue)

Key checklist:

  • No style issues: $ flake8
  • All tests pass: $ python run-tests.py --unit
  • The documentation builds: $ cd docs and then $ make clean; make html

You can run all three at once, using $ python run-tests.py --quick.

Further checks:

  • Code is commented, particularly in hard-to-understand areas
  • Tests added that prove fix is effective or that feature works

So that `inputs_list` is defined
@codecov
Copy link

codecov bot commented Dec 14, 2020

Codecov Report

Merging #1298 (e5ae0a1) into issue-849-parallel-processing (35491ba) will increase coverage by 0.00%.
The diff coverage is 100.00%.

Impacted file tree graph

@@                      Coverage Diff                       @@
##           issue-849-parallel-processing    #1298   +/-   ##
==============================================================
  Coverage                          98.09%   98.09%           
==============================================================
  Files                                270      270           
  Lines                              15176    15189   +13     
==============================================================
+ Hits                               14887    14900   +13     
  Misses                               289      289           
Impacted Files Coverage Δ
pybamm/expression_tree/operations/evaluate.py 98.15% <100.00%> (+0.06%) ⬆️
pybamm/solvers/base_solver.py 99.07% <100.00%> (+<0.01%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 35491ba...e5ae0a1. Read the comment docs.

Copy link
Member

@valentinsulzer valentinsulzer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, neat fix :) I know you didn't introduce this but I just noticed lots of the comments in test_scipy_solver contradict the associated code ...

# Create model
model = pybamm.BaseModel()
# Covert to casadi instead of python to avoid pickling of
# "EvaluatorPython" objects.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment out of date

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, good catch

model.rhs = {var: -pybamm.InputParameter("rate") * var}
model.initial_conditions = {var: 1}
# No need to set parameters; can use base discretisation (no spatial
# operators)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment out of date

@@ -258,7 +298,7 @@ def test_model_solver_multiple_inputs_discontinuity_error(self):
domain = ["negative electrode", "separator", "positive electrode"]
var = pybamm.Variable("var", domain=domain)
model.rhs = {var: -pybamm.InputParameter("rate") * var}
model.initial_conditions = {var: 1}
model.initial_conditions = {var: 2 * pybamm.InputParameter("rate")}
# No need to set parameters; can use base discretisation (no spatial
# operators)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment out of date

):
solver.solve(model, t_eval, inputs=inputs_list, nproc=2)

def test_model_solver_multiple_inputs_initial_conditions_error(self):
def test_model_solver_multiple_inputs_jax_format_error(self):
# Create model
model = pybamm.BaseModel()
# Covert to casadi instead of python to avoid pickling of
# "EvaluatorPython" objects.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment out of date

pybamm/expression_tree/operations/evaluate.py Show resolved Hide resolved
Copy link
Contributor

@martinjrobins martinjrobins left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks great @tlestang , happy to merge. Had a couple of points below.

@@ -507,6 +508,23 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
else:
return result

def __getstate__(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great @tlestang, do you think we can use the same trick for the Jax evaluator?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I naively tried the same trick with the Jax evaluator but couldn't get it to work as smoothly.
Something in jax.jit can't be pickle, but I don't know what.

AttributeError: Can't pickle local object 'jit.<locals>.f_jitted'

I tried not pickling attributes jit_evaluate and jac_evaluate in the same way, i.e.

def __getstate__(self):
    state = self.__dict__.copy()
    del state["_evaluate_jax"]
    del state["_jit_evaluate"]
    del state["_jac_evaluate"]
    return state

but still get the AttributeError..

@@ -589,6 +589,13 @@ def solve(
for inputs in inputs_list
]

# Cannot use multiprocessing with model in "jax" format
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that Jax is designed to support parallel solution of models with different inputs, so this should be doable. But it can be left for a PR later on

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay good, would be better to rely on Jax's native behavior anyway :)

@tlestang
Copy link
Contributor Author

Thanks @tinosulzer and @martinjrobins for the review. I will update the comments but before we merge that PR I'd like to investigate why the content of __getstate__ isn't tested..

@tlestang
Copy link
Contributor Author

tlestang commented Jan 6, 2021

This brought coverage to 100%. Now waiting to merge #1319 to fix CI on MacOS

@tlestang tlestang merged commit b273887 into issue-849-parallel-processing Jan 6, 2021
@valentinsulzer valentinsulzer deleted the issue-1283-pickle-python-format branch September 13, 2021 15:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants