Skip to content

Commit

Permalink
Reduce the use of SymPy lambdify (#579)
Browse files Browse the repository at this point in the history
* evaluate SymPy parameter without lambdify

* clear linecache after use of lambdify

* comments

* formatting

* changelog, linting
  • Loading branch information
antalszava authored May 13, 2021
1 parent 127085e commit 39b0265
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
6 changes: 5 additions & 1 deletion .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@

<h3>Bug fixes</h3>

* Fixed an unexpected behaviour that can result in increasing memory usage due
to ``sympy.lambdify`` caching too much data using ``linecache``.
[(#579)](https://github.com/XanaduAI/strawberryfields/pull/579)

<h3>Documentation</h3>

* References to the ``simulon`` simulator target have been rewritten to
Expand All @@ -25,7 +29,7 @@

This release contains contributions from (in alphabetical order):

Aaron Robertson, Jeremy Swinarton.
Aaron Robertson, Jeremy Swinarton, Antal Száva.

# Release 0.18.0 (current release)

Expand Down
13 changes: 13 additions & 0 deletions strawberryfields/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
import collections.abc
import functools
import types
import linecache

import numpy as np
import sympy
Expand Down Expand Up @@ -191,14 +192,26 @@ def do_evaluate(p):

# using lambdify we can also substitute np.ndarrays and tf.Tensors for the atoms
atoms = list(p.atoms(MeasuredParameter, FreeParameter))

if not atoms:
# If there are not atomic values, we just convert to elementary
# Python types
return float(p) if p.is_real else complex(p)

# evaluate the atoms of the expression
vals = [k._eval_evalf(None) for k in atoms]
# use the tensorflow printer if any of the symbolic parameter values are TF objects
# (we do it like this to avoid importing tensorflow if it's not needed)
is_tf = (type(v).__module__.startswith("tensorflow") for v in vals)
printer = "tensorflow" if any(is_tf) else "numpy"

func = sympy.lambdify(atoms, p, printer)

# sympy.lambdify caches data using linecache, if called many times this
# can make up for a lot of memory used. We clear the cache here to
# avoid that.
linecache.clearcache()

if dtype is not None:
# cast the input values
if printer == "tensorflow":
Expand Down

0 comments on commit 39b0265

Please sign in to comment.