diff --git a/README.md b/README.md index 7b4d7da..9486e5c 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,15 @@ The model is taken from the P2D model outlined in the [paper](http://web.mit.edu ## Getting started +### pip method (may not work on recent linux versions) + * `pip install numpy scipy` (or get these from your OS) * `pip install jax[cpu]` * `pip install scikit-umfpack` -Then run `examples/main_reorder.py`. +### conda method (tested on linux) + +* install conda according to your system : +* use the environment.yml to create a suitable conda environment : `conda env create -f environment.yml` + +Then run `examples/main_reorder.py` or `examples/reorder_test.py` diff --git a/decoupled/p2d_main_fast_fn.py b/decoupled/p2d_main_fast_fn.py index 4387c23..903b47c 100644 --- a/decoupled/p2d_main_fast_fn.py +++ b/decoupled/p2d_main_fast_fn.py @@ -23,7 +23,7 @@ video_name = 'video.avi' from utils.precompute_c import precompute from model.p2d_param import get_battery_sections - +from functools import partial @@ -71,7 +71,7 @@ def form_c2_n_jit(temp, j, T): val=vmap(fn,(0,None,0),1)(j,temp,Deff_vec) return val - @jax.partial(jax.jit, static_argnums=(2, 3,)) + @partial(jax.jit, static_argnums=(2, 3,)) def combine_c(cII, cI_vec, M,N): return np.reshape(cII, [M * (N + 2)], order="F") + cI_vec diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..fd432f7 --- /dev/null +++ b/environment.yml @@ -0,0 +1,16 @@ +name: p2d_fast_solver-env +channels: + - defaults + - conda-forge +dependencies: + - python=3.10 + - jax=0.3.0 + - scikit-umfpack=0.3.3 + - gcc=12.1.0 + - suitesparse=5.10.1 + - numpy=1.22.3 + - matplotlib=3.7.1 + - numba=0.57.1 +## Optionals +# - nose=1.3.7 +# - jupyter=1.0.0 diff --git a/examples/reorder_test.py b/examples/reorder_test.py index 92ed446..ab590a4 100644 --- a/examples/reorder_test.py +++ b/examples/reorder_test.py @@ -753,7 +753,7 @@ def compute_der(U, Uold, cs_pe1, cs_ne1): diff0 = abs(Joutput - Jab) diff0_matrix = np.zeros_like(Joutput) -diff0_matrix[Jab.nonzero()] = diff0[Jab.nonzero()]/abs(Jab[Jab.nonzero()]) +diff0_matrix.at[Jab.nonzero()].set(diff0[Jab.nonzero()]/abs(Jab[Jab.nonzero()])) plt.figure() plt.imshow(diff0_matrix); plt.colorbar()