Skip to content

Commit

Permalink
update readme to switch from partially broken instructions for how ot…
Browse files Browse the repository at this point in the history
… use numba to partially broken instructions for how to use de.jit
  • Loading branch information
Lilith Hafner authored and Lilith Hafner committed Oct 10, 2023
1 parent c4ff28a commit fad6b6f
Showing 1 changed file with 21 additions and 33 deletions.
54 changes: 21 additions & 33 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,7 @@ interpreter then run:
>>> diffeqpy.install()
```

and you're good! In addition, to improve the performance of your code it is
recommended that you use Numba to JIT compile your derivative functions. To
install Numba, use:

```
pip install numba
```
and you're good!

## General Flow

Expand Down Expand Up @@ -150,32 +144,30 @@ sol = de.solve(prob,de.Vern9(),saveat=0.1,abstol=1e-10,reltol=1e-10)
The set of algorithms for ODEs is described
[at the ODE solvers page](http://diffeq.sciml.ai/dev/solvers/ode_solve).

### Compilation with Numba and Julia
### Compilation with `de.jit` and Julia

When solving a differential equation, it's pertinent that your derivative
function `f` is fast since it occurs in the inner loop of the solver. We can
utilize Numba to JIT compile our derivative functions to improve the efficiency
of the solver:
convert the entire ode problem to symbolic form, optimize that symbolic form,
and emit efficient native code to simulate it using `de.jit` to improve the
efficiency of the solver at the expense of added setup time:

```py
import numba
numba_f = numba.jit(f)

prob = de.ODEProblem(numba_f, u0, tspan)
sol = de.solve(prob) # ERROR
fast_prob = de.jit(prob)
sol = de.solve(fast_prob)
```

Additionally, you can directly define the functions in Julia. This will allow
for more specialization and could be helpful to increase the efficiency over
the Numba version for repeat or long calls. This is done via `seval`:
Additionally, you can directly define the functions in Julia. This will also
allow for specialization and could be helpful to increase the efficiency for
repeat or long calls. This is done via `seval`:

```py
jul_f = de.seval("(u,p,t)->-u") # Define the anonymous function in Julia
prob = de.ODEProblem(jul_f, u0, tspan)
sol = de.solve(prob)
```

#### Note that when using Numba, one must avoid Python lists and pass state and parameters as NumPy arrays!
#### Note that when using `de.jit`, certain undocumented restrictions apply!!

### Systems of ODEs: Lorenz Equations

Expand Down Expand Up @@ -228,12 +220,12 @@ def f(du,u,p,t):
du[1] = x * (rho - z) - y
du[2] = x * y - beta * z

numba_f = numba.jit(f)
u0 = [1.0,0.0,0.0]
tspan = (0., 100.)
p = [10.0,28.0,2.66]
prob = de.ODEProblem(numba_f, u0, tspan, p)
sol = de.solve(prob)
prob = de.ODEProblem(f, u0, tspan, p)
jit_prob = de.jit(prob)
sol = de.solve(jit_prob)
```

or using a Julia function:
Expand Down Expand Up @@ -299,12 +291,10 @@ def g(du,u,p,t):
du[1] = 0.3*u[1]
du[2] = 0.3*u[2]

numba_f = numba.jit(f)
numba_g = numba.jit(g)
u0 = [1.0,0.0,0.0]
tspan = (0., 100.)
p = [10.0,28.0,2.66]
prob = de.SDEProblem(numba_f, numba_g, u0, tspan, p)
prob = de.jit(de.SDEProblem(f, g, u0, tspan, p))
sol = de.solve(prob)

# Now let's draw a phase plot
Expand Down Expand Up @@ -351,10 +341,9 @@ u0 = [1.0,0.0,0.0]
tspan = (0.0,100.0)
p = [10.0,28.0,2.66]
nrp = numpy.zeros((3,2))
numba_f = numba.jit(f)
numba_g = numba.jit(g)
prob = de.SDEProblem(numba_f,numba_g,u0,tspan,p,noise_rate_prototype=nrp)
sol = de.solve(prob,saveat=0.005)
prob = de.SDEProblem(f,g,u0,tspan,p,noise_rate_prototype=nrp)
jit_prob = de.jit(prob)
sol = de.solve(jit_prob,saveat=0.005)

# Now let's draw a phase plot

Expand Down Expand Up @@ -409,9 +398,9 @@ def f(resid,du,u,p,t):
resid[1] = + 0.04*u[0] - 3e7*u[1]**2 - 1e4*u[1]*u[2] - du[1]
resid[2] = u[0] + u[1] + u[2] - 1.0

numba_f = numba.jit(f)
prob = de.DAEProblem(numba_f,du0,u0,tspan,differential_vars=differential_vars)
sol = de.solve(prob) # ERROR
prob = de.DAEProblem(f,du0,u0,tspan,differential_vars=differential_vars)
jit_prob = de.jit(prob) # Error: no method matching matching modelingtoolkitize(::SciMLBase.DAEProblem{...})
sol = de.solve(jit_prob)
```

## Delay Differential Equations
Expand Down Expand Up @@ -476,7 +465,6 @@ Unit tests can be run by [`tox`](http://tox.readthedocs.io).

```sh
tox
tox -e py3-numba # test with Numba
```

### Troubleshooting
Expand Down

0 comments on commit fad6b6f

Please sign in to comment.