Example/Implementation of Neural Controlled Differential Equations #408
Yeah, that could be a nice model to implement. Let me know if you need any help optimizing it.
@ChrisRackauckas I was able to get a version of this working, although it is very slow. If there are any optimizations that stand out, please let me know. Thanks! Here is the code for it:
Hey, here's an updated version with comments on what was done and timings. That little training step improved by about 6.6x:

```julia
using Statistics
using DataInterpolations
using Flux
using OrdinaryDiffEq
using DiffEqSensitivity
using Zygote
using ProgressBars
using Random
T = Float32
bs = 512
X = [rand(T, 10, 50) for _ in 1:bs*10]
function create_spline(i)
    x = X[i]
    t = x[end, :]
    t = (t .- minimum(t)) ./ (maximum(t) - minimum(t))
    spline = QuadraticInterpolation(x, t)
end
splines = [create_spline(i) for i in tqdm(1:length(X))]
rand_inds = randperm(length(X))
i_sz = size(X[1], 1)
h_sz = 16
use_gpu = true
batches = [[splines[rand_inds[(i-1)*bs+1:i*bs]]] for i in tqdm(1:length(X)÷bs)]
data_ = Iterators.cycle(batches)
function call_and_cat(splines, t)
    vals = Zygote.ignore() do
        vals = reduce(hcat, [spline(t) for spline in splines])
    end
    vals |> (use_gpu ? gpu : cpu)
end
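# `derivative` below evaluates d/dt of the quadratic Lagrange polynomial through the three
# knots (t_{i₀}, u_{i₀}), (t_{i₁}, u_{i₁}), (t_{i₂}, u_{i₂}) surrounding t:
#   p(t)  = u_{i₀} l₀(t) + u_{i₁} l₁(t) + u_{i₂} l₂(t),
#   l₀(t) = (t - t_{i₁})(t - t_{i₂}) / ((t_{i₀} - t_{i₁})(t_{i₀} - t_{i₂})), and similarly for l₁, l₂,
# so p'(t) is the linear combination of the knot values with the dl₀, dl₁, dl₂ weights below.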
function derivative(A::QuadraticInterpolation, t::Number)
    idx = findfirst(x -> x >= t, A.t) - 1
    idx == 0 ? idx += 1 : nothing
    if idx == length(A.t) - 1
        i₀ = idx - 1; i₁ = idx; i₂ = i₁ + 1;
    else
        i₀ = idx; i₁ = i₀ + 1; i₂ = i₁ + 1;
    end
    dl₀ = (2t - A.t[i₁] - A.t[i₂]) / ((A.t[i₀] - A.t[i₁]) * (A.t[i₀] - A.t[i₂]))
    dl₁ = (2t - A.t[i₀] - A.t[i₂]) / ((A.t[i₁] - A.t[i₀]) * (A.t[i₁] - A.t[i₂]))
    dl₂ = (2t - A.t[i₀] - A.t[i₁]) / ((A.t[i₂] - A.t[i₀]) * (A.t[i₂] - A.t[i₁]))
    @views @. A.u[:, i₀] * dl₀ + A.u[:, i₁] * dl₁ + A.u[:, i₂] * dl₂
end
function derivative_call_and_cat(splines, t)
    vals = Zygote.ignore() do
        reduce(hcat, [derivative(spline, t) for spline in splines]) |> (use_gpu ? gpu : cpu)
    end
end
cde = Chain(
    Dense(h_sz, h_sz, relu),
    Dense(h_sz, h_sz*i_sz, tanh),
) |> (use_gpu ? gpu : cpu)
h_to_out = Dense(h_sz, 2) |> (use_gpu ? gpu : cpu)
initial = Dense(i_sz, h_sz) |> (use_gpu ? gpu : cpu)
cde_p, cde_re = Flux.destructure(cde)
initial_p, initial_re = Flux.destructure(initial)
h_to_out_p, h_to_out_re = Flux.destructure(h_to_out)
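# basic_tgrad below supplies a placeholder time-gradient for the ODEFunction; Tsit5 is an
# explicit Runge-Kutta method and never queries tgrad, so returning zero is harmless here.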
basic_tgrad(u,p,t) = zero(u)
function predict_func(p, BX)
    By = call_and_cat(BX, 1)
    x0 = call_and_cat(BX, 0)
    i = 1
    j = (i-1)+length(initial_p)
    h0 = initial_re(p[i:j])(x0)
    function dhdt(h,p,t)
        x = derivative_call_and_cat(BX, t)
        bs = size(h, 2)
        a = reshape(cde_re(p)(h), (i_sz, h_sz, bs))
        b = reshape(x, (1, i_sz, bs))
        dh = batched_mul(b,a)[1,:,:]
    end
    i = j+1
    j = (i-1)+length(cde_p)
    tspan = (0.0f0, 0.8f0)
    ff = ODEFunction{false}(dhdt,tgrad=basic_tgrad)
    prob = ODEProblem{false}(ff,h0,tspan,p[i:j])
    sense = InterpolatingAdjoint(autojacvec=ZygoteVJP())
    solver = Tsit5()
    sol = solve(prob,solver,u0=h0,saveat=tspan[end], save_start=false, sensealg=sense)
    #@show sol.destats
    i = j+1
    j = (i-1)+length(h_to_out_p)
    y_hat = h_to_out_re(p[i:j])(sol[end])
    y_hat, By[1:2, :]
end
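# loss_func below computes mean(sum(|y - ŷ|, dims=1)): sqrt.((y .- y_hat).^2) is just the
# elementwise absolute error, so this is the mean L1 error per sample in the batch.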
function loss_func(p, BX)
    y_hat, y = predict_func(p, BX)
    mean(sum(sqrt.((y .- y_hat).^2), dims=1))
end
p = vcat(initial_p, cde_p, h_to_out_p)
callback = function (p, l)
    display(l)
    return false
end
using DiffEqFlux
Zygote.gradient((p)->loss_func(p, first(data_)...),p)
@time Zygote.gradient((p)->loss_func(p, first(data_)...),p)
@time result_neuralode = DiffEqFlux.sciml_train(loss_func, p, ADAM(0.05),
                                                data_,
                                                cb = callback,
                                                maxiters = 10)
# Start
# 13.178288 seconds (40.04 M allocations: 3.362 GiB, 5.01% gc time)
# Reduce(hcat)
# 6.273153 seconds (21.84 M allocations: 2.443 GiB, 7.86% gc time)
# @views @.
# 3.315527 seconds (5.05 M allocations: 495.695 MiB, 3.81% gc time)
# gpu in do
# 2.652512 seconds (5.11 M allocations: 466.430 MiB, 2.64% gc time)
# Training time before:
# 199.442675 seconds (218.19 M allocations: 23.603 GiB, 66.46% gc time)
# Training time after:
# 30.587359 seconds (58.69 M allocations: 5.210 GiB, 3.74% gc time)
```

The rate-limiting step here is that the spline data is on the CPU while your computations are on the GPU, so the most costly portion now is simply moving the spline output to the GPU. That's like 90% of the cost or something ridiculous now, so you'd have to tackle that problem, and I was only giving myself 30 minutes to play with this. One way you could do this would be to make your quadratic spline asynchronously pre-cache some of the next time points onto the GPU while other computations are taking place. Or, even cooler, train a neural network to mimic the spline but be all on the GPU, and then use that in place of the spline. But shipping that much data every step is going to dominate the computation, so it's got to be dealt with somehow. What's the baseline you want to beat here? Do you have that code around to time?
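One way to attack that copy cost would be to keep the spline knot values on the GPU and evaluate the whole batch in a single broadcast. A minimal sketch of that idea, assuming every batch element shares the same rescaled time grid; `BatchedQuadSpline` and `batched_derivative` are hypothetical names, and the math is the same quadratic Lagrange derivative as in the `derivative` function above:

```julia
using CUDA

# Hypothetical batched spline: `u` holds all knot values for the batch on the GPU,
# `t` is the shared time grid (kept on the CPU, since only scalars are read from it).
struct BatchedQuadSpline{A,B}
    u::A   # CuArray of size (features, batch, length(t))
    t::B   # Vector of knot times, assumed identical across the batch
end

function batched_derivative(S::BatchedQuadSpline, t::Number)
    idx = findfirst(>=(t), S.t) - 1
    idx == 0 && (idx += 1)
    if idx == length(S.t) - 1
        i₀, i₁, i₂ = idx - 1, idx, idx + 1
    else
        i₀, i₁, i₂ = idx, idx + 1, idx + 2
    end
    t₀, t₁, t₂ = S.t[i₀], S.t[i₁], S.t[i₂]
    dl₀ = (2t - t₁ - t₂) / ((t₀ - t₁) * (t₀ - t₂))
    dl₁ = (2t - t₀ - t₂) / ((t₁ - t₀) * (t₁ - t₂))
    dl₂ = (2t - t₀ - t₁) / ((t₂ - t₀) * (t₂ - t₁))
    # the three slices already live on the GPU, so this is one fused GPU broadcast that
    # produces the (features, batch) matrix of dX/dt values for the whole batch at once
    @views S.u[:, :, i₀] .* dl₀ .+ S.u[:, :, i₁] .* dl₁ .+ S.u[:, :, i₂] .* dl₂
end
```

With something like this, the `dhdt` closure could call `batched_derivative(batched_spline, t)` in place of `derivative_call_and_cat(BX, t)`, so no per-step CPU-to-GPU transfer is needed.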
Awesome, thanks for reviewing the code and giving those optimizations! That's quite a speedup. I will try to figure out how to avoid moving the data from CPU to GPU every step. Btw, the author of the paper / repo above recently released a new repo for this type of model, so we would probably want to show it outperforming some examples in that repo. The baseline I was comparing against was based on the code in https://github.com/patrick-kidger/NeuralCDE/blob/master/example/example.py. I'll do a speed comparison against that and against the optimized code you posted.
Here is the code I am using as the efficiency baseline: on my system the baseline code takes about ~6 seconds to run, whereas the optimized version you posted takes about ~18 seconds. I think once we fix moving the data to the GPU at each step it should be a lot faster, though.
I don't see any shuttling to GPUs there: are the spline coefficients on the GPU in that implementation? If they are, that would make a massive difference. Also, doing this as 1 spline instead of 5000 splines probably makes a decent difference.
This is the file where the spline code is defined: It does seem like the spline operations are being done as a batch on the GPU.
I wrapped the PyTorch spline functions using PyCall and CUDA.jl, and I was able to get a speedup of ~5.5x over the torch version: ~92.7 seconds for the 10 batches with the torch version and ~17 seconds for the 10 batches with the DiffEqFlux version. Btw, thanks for the awesome library! Here's what the calls looked like:
Awesome. So yeah, it would really be nice to get that directly implemented in Julia as a library function for people who want to use this method. It's the rate-limiting step.
Just been pointed at this. Three quick comments:
Anyway, I'm not Julia-proficient, but let me know if I can help out over here.
I think his latest version (the version that was timed) is using your spline code, IIUC.
In the Julia code I used for the benchmark, I was calling out to your spline code so that I could do a consistent comparison. I did have some errors in the Python code I posted above, which I fixed in the script I used for the benchmarking. Here is the updated script:
I wouldn't set the default device in the way that you're doing; in particular, this doesn't perform the usual CPU-to-GPU copy you'll usually see when training a model. There is also an issue with how you're using the grid_points. I suggest using https://github.com/patrick-kidger/torchcde/blob/master/example/example.py as a reference point.
Thanks for the clarification about the grid_points; I'll fix that. I am using fastai for creating the dataloaders, and it automatically puts the batches on the GPU. I'll post a minimal runnable script shortly that takes out the fastai code.
Here are updated Python and Julia scripts based on the example you linked to.
EDIT: the loss for the Julia version isn't decreasing like the Python version's, so I may have a bug in the model. I will try to fix that.
Python (time to run the training part at the bottom: ~58 seconds):
Julia code using the spline from Python
I suspect the gap between the Julia and Python code may widen in Julia's favor with larger batch sizes, hidden sizes, etc., but I still think doing something with DiffEqGPU will probably be the way to get a major speedup. I'll be trying some of the ideas Chris mentioned in the other thread to see if I can get a speedup with that.
I wouldn't get your hopes up there.
That's for a different use case, i.e. ensembles of small ODEs.
On the contrary, it would probably shrink, as the rate-limiting step is sooner or later going to be the cost of the GPU kernels, which, if both are calling into CuBLAS, will be the same. So if they are taking the same number of steps (which they likely aren't, due to some stabilizing tricks, but those are like 50% gains), then you'd expect the cost to be the same. It's when the kernels aren't fully saturated that the extra codegen and fusion matter.
Okay, thanks for keeping me from going down that route. I was mainly thinking that if each trajectory had its own solver, they should require many fewer function calls, whereas (if I'm not mistaken) the batched version has to stop whenever any of the trajectories needs to stop, so many more function calls are done in the current batched version. Is there a good way to reduce the number of function calls?
Oh, I read you wrong. If your goal is to make use of, say, 4 GPUs by running 4 trajectories at a time on different GPUs, yeah, DiffEqGPU isn't the right tool, but EnsembleDistributed where each Julia process has a different GPU will do this. You could also try to pack multiple onto the same GPU, but sooner or later you'll get memory-limited.
That's a great research question. In the general context, that's just developing a "better" differential equation solver, which can be hard work given how much they've been optimized, but there are still some tricks no one has done and we will have some new methods coming out soonish. But in the context of training a CDE, there are some other tricks one can employ. For example, you don't necessarily need to fit the ODE solves themselves: you can regularize to find solutions that are fast to solve, you can drop accuracy and only increase it again after descending a bit, etc.
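A minimal sketch of the "drop accuracy, then tighten it" idea, assuming a hypothetical `loss_func(p, BX; abstol, reltol)` that forwards the tolerances to the `solve` call inside `predict_func` (names and hyperparameters are illustrative, not from this thread):

```julia
# train first with loose tolerances (cheap solves, fewer solver steps), then continue
# from the resulting parameters with tight tolerances (accurate solves)
loose = (abstol = 1f-3, reltol = 1f-2)
tight = (abstol = 1f-6, reltol = 1f-5)

res_loose = DiffEqFlux.sciml_train((p, BX) -> loss_func(p, BX; loose...), p,
                                   ADAM(0.05), data_; cb = callback, maxiters = 200)
res_tight = DiffEqFlux.sciml_train((p, BX) -> loss_func(p, BX; tight...),
                                   res_loose.minimizer,
                                   ADAM(0.01), data_; cb = callback, maxiters = 200)
```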
Some discrepancies:
@ChrisRackauckas: You mention that the number of solver steps can be reduced via some stabilising tricks - I'm curious what you're referring to specifically?
@patrick-kidger @ChrisRackauckas I created a non-GPU, non-batch version that performs fairly well. I wasn't able to figure out the issue with the GPU version; I believe there is some non-trivial issue with the gradient calculation. Below is the script for training on the data from your example. It could probably easily incorporate multiprocessing to speed it up.

It has a use_linear variable for choosing between linear interpolation and natural cubic interpolation. The linear interpolation is quite slow since it needs the additional tstops (the solver is forced to stop at every knot, where the derivative of the control path is discontinuous), but the cubic interpolation version is quite fast: it takes ~83 seconds to run 10 epochs on the data of your example. The ODE solves get faster as the loss goes down, so the time taken depends on the starting parameters and the order of the examples seen, and there is therefore some variance in the timings.

I benchmarked against your example with the changes that you mentioned and found that, on average, that code took ~60 seconds to run. The script for that is included at the bottom.

EDIT: I was accidentally using h_sz=16 instead of 8 for the Julia version, and I forgot to include the compilation warmup for sciml_train; I updated the script with those values.
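For reference, the "additional tstops" mean forcing the solver to step exactly onto the interpolation knots. A hedged sketch of how that would look with the `solve` call used earlier in this thread (it assumes `prob`, `tspan`, and a `spline` whose `.t` field holds the knot times are already defined as above):

```julia
# with linear interpolation, pass the knot times as tstops so the adaptive solver
# does not try to step across the kinks in dX/dt; a smooth cubic spline does not
# need this, which is one reason the cubic version runs faster
sol = solve(prob, Tsit5();
            tstops = spline.t,          # knot times of the (linear) interpolation
            saveat = tspan[end], save_start = false,
            sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()))
```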
What's the latest status on this project? Seems useful.
I don't think anyone has picked it up. In terms of differentiable interpolations, DataInterpolations.jl got some nice stable differentiability overloads, so this should be easy pickings, but someone needs to package it all up.
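As a rough illustration of the pure-Julia building blocks involved (a sketch based on the public DataInterpolations.jl API, not code from this thread): the CDE only needs to evaluate the interpolant and its time derivative inside the vector field.

```julia
using DataInterpolations

t = collect(0.0f0:0.1f0:1.0f0)         # observation times for one data channel
u = rand(Float32, length(t))           # observed values

A = CubicSpline(u, t)                  # smooth interpolant of the channel

x  = A(0.37f0)                                  # X(t): evaluate the control path
dx = DataInterpolations.derivative(A, 0.37f0)   # dX/dt: what f_θ(z) gets multiplied by
```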
@ChrisRackauckas I refactored the CPU version and changed it to only use Julia code. I created a simple .md example file for it. Could you point me to a guide for how to open a pull request for it?
https://www.youtube.com/watch?v=QVmU29rCjaA is a tutorial for all of that kind of stuff.
The paper provides a good method for encoding data with an ODE. This would be useful as the encoder in an encoder/decoder architecture, as an alternative to an RNN encoder.
https://arxiv.org/abs/2005.08926
https://github.com/patrick-kidger/NeuralCDE
I have played around with the example in the code repo, but it is very slow and could probably be significantly faster if written with DiffEqFlux.
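For context, the model in the paper is the neural controlled differential equation

$$ z(t) = z(t_0) + \int_{t_0}^{t} f_\theta(z(s))\,\mathrm{d}X(s), $$

solved in practice as the ODE $\dot z(t) = f_\theta(z(t))\,\dot X(t)$, where $X$ is a differentiable interpolation (e.g. a natural cubic spline) of the observed time series and $f_\theta$ is a neural network whose matrix-valued output multiplies $\dot X(t)$.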