Continuous normalizing flows #46
using DifferentialEquations
using Distributions
using Flux, DiffEqFlux, ForwardDiff
using Flux.Tracker
# Neural Network
function f(z, p)
α, β = p
tanh.(α.*z .+ β)
end
tspan = (0.0, 10.0)
function cnf(du,u,p,t)
z = @view u[1:end-1]
du[1:end-1] = f(z, p)
du[end] = -sum(ForwardDiff.jacobian((z)->f(z, p), z))
end
prob = ODEProblem(cnf,nothing,tspan,nothing)
p = param([0.0, 0.0]) # Initial Parameter Vector
params = Params([p])
function predict_adjoint(x)
diffeq_adjoint(p,prob,Tsit5(),u0=[x,false],
saveat=0.0:0.1:10.0,
sensealg=DiffEqFlux.SensitivityAlg(quad=false,
backsolve=true,autojacvec=false))
end
function loss_adjoint(xs)
pz = Normal(0.0, 1.0)
preds = [predict_adjoint(x)[:,end] for x in xs]
z = [pred[1] for pred in preds] # TODO better slicing
delta_logp = [pred[2] for pred in preds]
logpz = logpdf.(pz, z)
logpx = logpz - delta_logp
loss = -mean(logpx)
end
opt = ADAM(0.1)
raw_data = [[rand(Normal(2.0, 0.1)) for i in 1:100]]
data = Iterators.repeated(raw_data, 100);
Flux.train!(loss_adjoint, params, data, opt)
# check whether it looks standard normal
using Plots
preds = [predict_adjoint(r)[:,end] for r in raw_data[1]];
histogram([p[1].data for p in preds])
That's a version that should be suitable for library use. It still needs to be made to support batching and tested with Flux models, though. |
I am getting this error
My status of DiffEqFlux is
|
You need OrdinaryDiffEq, DiffEqFlux, and DiffEqSensitivity master. If anyone could help generate the Project.toml files I will register |
If I replace that "neural network" with a Flux model the above code fails.
|
using OrdinaryDiffEq
using Distributions
using Flux, DiffEqFlux, ForwardDiff, Tracker
# Neural Network
nn = Dense(1,1,tanh)
p = DiffEqFlux.destructure(nn)
function f(z, p)
m = DiffEqFlux.restructure(nn,p)
return m(z)
end
tspan = Float32.((0.0, 10.0))
function cnf(du,u,p,t)
z = @view u[1:end-1]
du[1:end-1] = f(z, p)
du[end] = -sum(Tracker.jacobian((z)->f(z, p), z))
end
prob = ODEProblem(cnf,nothing,tspan,nothing)
p = param(Float32[0.0, 0.0]) # Initial Parameter Vector
params = Params([p])
function predict_adjoint(x)
diffeq_adjoint(p,prob,Tsit5(),u0=[x;false],
saveat=0.0:0.1:10.0,
sensealg=DiffEqFlux.SensitivityAlg(quad=false,
backsolve=true,autojacvec=true))
end
function loss_adjoint(xs)
pz = Normal(0.0, 1.0)
preds = [predict_adjoint(x)[:,end] for x in xs]
z = [pred[1] for pred in preds] # TODO better slicing
delta_logp = [pred[2] for pred in preds]
logpz = logpdf.(pz, z)
logpx = logpz - delta_logp
loss = -mean(logpx)
end
opt = ADAM(0.1)
raw_data = [Float32[rand(Normal(2.0, 0.1)) for i in 1:100]]
data = Iterators.repeated(raw_data, 100);
Flux.train!(loss_adjoint, params, data, opt)
# check whether it looks standard normal
using Plots
preds = [predict_adjoint(r)[:,end] for r in raw_data[1]];
This is probably the closest we've gotten; it just needs the nesting of Tracker fixed, i.e. finding out why the gradient is undefined. |
This is a nice simplifying example. It works until the backward pass, where DiffEqSensitivity's autojacvec uses Flux to do the vjps. In that case:
using OrdinaryDiffEq
using Distributions
using Flux, DiffEqFlux, ForwardDiff, Tracker
# Neural Network
nn = Dense(1,1,tanh)
p = Tracker.data(DiffEqFlux.destructure(nn))
DiffEqFlux.restructure(nn,p)([1.0])
tspan = Float32.((0.0, 10.0))
function cnf(du,u,p,t)
z = @view u[1:end-1]
m = DiffEqFlux.restructure(nn,p)
du[1:end-1] = m(z)
@show z
du[end] = -sum(Tracker.jacobian((z)->log.(z), z))
end
prob = ODEProblem(cnf,nothing,tspan,nothing)
p = param(Float32[0.0, 0.0]) # Initial Parameter Vector
params = Params([p])
function predict_adjoint(x)
diffeq_adjoint(p,prob,Tsit5(),u0=[x;false],
saveat=0.0:0.1:10.0,
sensealg=DiffEqFlux.SensitivityAlg(quad=false,
backsolve=true,autojacvec=true))
end
function loss_adjoint(xs)
pz = Normal(0.0, 1.0)
preds = [predict_adjoint(x)[:,end] for x in xs]
z = [pred[1] for pred in preds] # TODO better slicing
delta_logp = [pred[2] for pred in preds]
logpz = logpdf.(pz, z)
logpx = logpz - delta_logp
loss = -mean(logpx)
end
Tracker.jacobian((z)->log.(3 .* z.+z.^2), [10.0])
opt = ADAM(0.1)
raw_data = [Float32[rand(Normal(2.0, 0.1)) for i in 1:100]]
data = Iterators.repeated(raw_data, 100);
Flux.train!(loss_adjoint, params, data, opt)
The problem is that Tracker.jacobian can't nest. How can we work around this, @MikeInnes? |
using OrdinaryDiffEq
using Distributions
using Flux, DiffEqFlux, ForwardDiff, Tracker
# Neural Network
nn = Chain(Dense(1,1,tanh))
p = DiffEqFlux.destructure(nn)
tspan = Float32.((0.0, 10.0))
function cnf(u,p,t)
z = @view u[1:end-1]
m = DiffEqFlux.restructure(nn,p)
jac = -sum(Tracker.jacobian((z)->log.(z), z))
if u isa TrackedArray
res = Tracker.collect([m(z);jac])
else
res = Tracker.data([m(z);jac])
end
res
end
prob = ODEProblem(cnf,nothing,tspan,nothing)
params = Params([p])
function predict_adjoint(x)
diffeq_adjoint(p,prob,Tsit5(),u0=[x;false],
saveat=0.0:0.1:10.0,
sensealg=DiffEqFlux.SensitivityAlg(quad=false,
backsolve=true,autojacvec=true))
end
function loss_adjoint(xs)
pz = Normal(0.0, 1.0)
preds = [predict_adjoint(x)[:,end] for x in xs]
z = [pred[1] for pred in preds] # TODO better slicing
delta_logp = [pred[2] for pred in preds]
logpz = logpdf.(pz, z)
logpx = logpz - delta_logp
loss = -mean(logpx)
end
opt = ADAM(0.1)
raw_data = [Float32[rand(Normal(2.0, 0.1)) for i in 1:100]]
data = Iterators.repeated(raw_data, 1);
loss_adjoint(raw_data[1])
Flux.train!(loss_adjoint, params, data, opt)
iszero(Tracker.grad(nn[1].W))
This works with FluxML/Tracker.jl#24 |
using OrdinaryDiffEq
using Distributions
using Flux, DiffEqFlux, Tracker
using LinearAlgebra: tr
using Plots
using Flux: @epochs, throttle
using Tracker: forward
# Neural Network
nn = Chain(Dense(1,1,swish), Dense(1,1,identity))
# We track the parameters.
p = Flux.data(DiffEqFlux.destructure(nn))
params = param(p)
ps = Flux.params(params)
tspan = Float32.((0.0, 10.0))
# We define tr(J) to support batching.
# But it's possible to use tr(Tracker.jacobian(m, z)), it works perfectly.
function divergence(f, x::AbstractArray)
y::AbstractArray, back = forward(f, x)
D, N = size(x)
ȳ(i) = [i == j for j = 1:D]
reduce(+, transpose([back(ȳ(i))[1][i, :] for i = 1:D]))
end
# Dynamics of the CNF
function cnf_dudt_(u::TrackedArray,p,t)
z = @view u[1:end-1, :]
m = DiffEqFlux.restructure(nn, p)
jac = -divergence(m, z)
Tracker.collect([m(z);jac])
end
function cnf_dudt_(u::AbstractArray,p,t)
z = @view u[1:end-1, :]
m = DiffEqFlux.restructure(nn, p)
jac = -divergence(m, z)
Tracker.data([m(z);jac])
end
function predict_adjoint(x)
diffeq_adjoint(params,prob,Tsit5(),u0=vcat(x, zeros(Float32, (1, size(x, 2)))),
saveat=0.0:0.1:10.0,
sensealg=DiffEqFlux.SensitivityAlg(quad=false,
backsolve=false,autojacvec=true))
end
# We want to be able to sample according to x = f^(-1)(z)
# We don't need the dynamics of log P here.
function f(u, p, t)
m = DiffEqFlux.restructure(nn,p)
return Tracker.data(m(u))
end
prob = ODEProblem(cnf_dudt_,nothing,tspan,nothing)
function invsample(x::AbstractArray)
# Remember that to train with respect to the NLL we actually went in the x -> u direction
# We want to go in the u -> x direction so we just solve the ODE backward in time.
invprob = ODEProblem(f,x,(10.0, 0.), params)
solve(invprob, Tsit5(),save_everystep=false)[2]
end
function invsample(x::Real)
invsample([x])[1]
end
opt = ADAM(0.1)
model = Normal(5., 0.1)
raw_data = [Float32.(rand(model, (1, 100)))]
data = Iterators.repeated(raw_data, 10)
function loss_adjoint(xs)
pz = Normal(0.0, 1.0)
preds = predict_adjoint(xs)[:, :,end]
z = preds[1, :]
delta_logp = preds[2, :]
logpz = logpdf.(pz, z)
logpx = logpz - delta_logp
loss = -mean(logpx)
end
cb = function()
# You can schedule the learning rate if you want.
opt.eta *= 0.95
pz = Normal(0.0, 1.0)
preds = Tracker.data(predict_adjoint(raw_data[1])[:, :,end])
zs = preds[1, :]
delta_logp = preds[2, :]
logpz = logpdf.(pz, zs)
logpx = logpz - delta_logp
loss = -mean(logpx)
perm = sortperm(raw_data[1][1, :])
pl = plot(raw_data[1][1, perm], Tracker.data.(exp.(logpx[perm])), xlims=(4, 6), ylims=(0, 5), title="Loss = $(loss)", label="Learned density")
plot!(t -> pdf(model, t), label="Real density")
samples = invsample(Float32.(rand(pz, (1, 500))))[1, :]
histogram!(samples, normalize=:pdf, alpha=.3, fillalpha=.3, label="Model samples")
display(pl)
end
# The very first invocation of predict_adjoint is very slow because of the JIT overhead.
# I don't think this is normal.
# Just be patient.
cb()
@epochs 100 Flux.train!(loss_adjoint, ps, data, opt; cb = throttle(cb, 100)) |
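A quick way to sanity-check a batched `divergence` like the one above is to compare it against a finite-difference estimate on dynamics whose divergence is known in closed form. The following is only a sketch under assumptions: `fd_divergence` and `toy_f` are hypothetical names that do not come from DiffEqFlux or Tracker, and the dynamics are assumed to act on each column (sample) independently, as a `Dense` layer does.

```julia
# Hypothetical helper (not part of any library): per-column divergence of `f`
# at the columns of the D×N matrix `x`, via central finite differences.
# Assumes `f` maps each column independently of the others.
function fd_divergence(f, x::AbstractMatrix; h = 1e-5)
    D, N = size(x)
    out = zeros(eltype(x), 1, N)
    for i in 1:D
        e = zeros(eltype(x), D, 1)
        e[i] = 1
        # Accumulate ∂f_i/∂x_i for every column at once.
        out .+= (f(x .+ h .* e)[i:i, :] .- f(x .- h .* e)[i:i, :]) ./ (2h)
    end
    return out
end

# Toy dynamics with a known divergence: toy_f(x) = tanh.(A*x .+ b),
# whose divergence is sum_i A[i,i] * (1 - tanh(...)_i^2) per column.
A = [0.5 -1.0; 2.0 0.3]
b = [0.1, -0.2]
toy_f(x) = tanh.(A * x .+ b)
x = [0.2 1.0 -0.7; -0.4 0.5 0.9]

exact  = sum([A[i, i] for i in 1:2] .* (1 .- toy_f(x) .^ 2), dims = 1)
approx = fd_divergence(toy_f, x)
```

Both `exact` and `approx` are `1×N` row vectors, matching the shape that gets concatenated under `m(z)` in `cnf_dudt_`.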
@jessebett so how should we "libraryitize" CNF? Clearly the layer should be given a nice function, but what about the loss function? Is that just specific to normal distributions, should we generalize it? Is this common or does it change depending on application? |
I tried to make it GPU-compatible:
using OrdinaryDiffEq
using Distributions
using Flux, DiffEqFlux, Tracker
using LinearAlgebra: tr
using Plots
using Flux: @epochs, throttle
using Tracker: forward
using Adapt
#using CuArrays
using LinearAlgebra
# Neural Network
nn = Chain(Dense(1,1,swish), Dense(1,1,identity)) #|> gpu
# We track the parameters.
p = Flux.data(DiffEqFlux.destructure(nn))
params = param(p)
ps = Flux.params(params)
tspan = Float32.((0.0, 10.0))
function divergence(f, x::AbstractArray)
y::AbstractArray, back = forward(f, x)
D, N = size(x)
T = DiffEqFlux.gpu_or_cpu(x)
ȳ(i) = adapt(T,[i == j for j = 1:D])
tmp = [back(ȳ(i))[1][i, :] for i = 1:D]
adapt(T,reduce(+, transpose(tmp)))
end
Tracker.@grad function divergence(f, x::TrackedArray)
y::AbstractArray, back = forward(f, x)
D, N = size(x)
T = DiffEqFlux.gpu_or_cpu(x)
ȳ(i) = T([i == j for j = 1:D])
out = reduce(+, transpose([back(ȳ(i))[1][i, :] for i = 1:D]))
out, Δ -> begin
nothing, back(Δ)
end
end
# Dynamics of the CNF
function cnf_dudt_(u::TrackedArray,p,t)
z = u[1:end-1, :]
m = DiffEqFlux.restructure(nn, p)
jac = -divergence(m, z)
Tracker.collect([m(z);jac])
end
function cnf_dudt_(u::AbstractArray,p,t)
z = u[1:end-1, :]
m = DiffEqFlux.restructure(nn, p)
jac = -divergence(m, z)
Tracker.data([m(z);jac])
end
function predict_adjoint(x)
diffeq_adjoint(params,prob,Tsit5(),u0=vcat(x, zeros(Float32, (1, size(x, 2)))),
saveat=0.0:0.1:10.0,
sensealg=DiffEqFlux.SensitivityAlg(quad=false,
backsolve=false,autojacvec=true))
end
# We want to be able to sample according to x = f^(-1)(z)
# We don't need the dynamics of log P here.
function f(u, p, t)
m = DiffEqFlux.restructure(nn,p)
return Tracker.data(m(u))
end
prob = ODEProblem(cnf_dudt_,nothing,tspan,nothing)
function invsample(x::AbstractArray)
# Remember that to train with respect to the NLL we actually went in the x -> u direction
# We want to go in the u -> x direction so we just solve the ODE backward in time.
invprob = ODEProblem(f,x,(10.0, 0.), params)
solve(invprob, Tsit5(),save_everystep=false)[2]
end
function invsample(x::Real)
invsample([x])[1]
end
opt = ADAM(0.1)
model = Normal(5., 0.1)
raw_data = [Float32.(rand(model, (1, 100)))] #.|> gpu
data = Iterators.repeated(raw_data, 10)
function loss_adjoint(xs)
preds = predict_adjoint(xs)[:, :,end]
z = preds[1, :]
delta_logp = preds[2, :]
μ = 0.0
σ = 1.0
pz = Normal(μ, σ)
#logpz = logpdf.(pz, z)
# Closed-form Normal(μ, σ) log-density: -((z-μ)/σ)^2/2 - log(2π)/2 - log(σ)
logpz = -((((z .- μ) ./ σ ).^2 .+ log(2π))./2 .+ log(σ))
logpx = logpz - delta_logp
loss = -mean(logpx)
end
cb = function()
# You can schedule the learning rate if you want.
opt.eta *= 0.95
preds = Tracker.data(predict_adjoint(raw_data[1])[:, :,end])
zs = preds[1, :]
delta_logp = preds[2, :]
μ = 0.0
σ = 1.0
pz = Normal(μ, σ)
#logpz = logpdf.(pz, zs)
# Closed-form Normal(μ, σ) log-density: -((zs-μ)/σ)^2/2 - log(2π)/2 - log(σ)
logpz = -((((zs .- μ) ./ σ ).^2 .+ log(2π))./2 .+ log(σ))
logpx = Array(logpz - delta_logp)
loss = -mean(logpx)
_raw_data = Array.(raw_data)
perm = sortperm(_raw_data[1][1, :])
pl = plot(_raw_data[1][1, perm], Tracker.data.(exp.(logpx[perm])), xlims=(4, 6), ylims=(0, 5), title="Loss = $(loss)", label="Learned density")
plot!(t -> pdf(model, t), label="Real density")
gendata = Float32.(rand(pz, (1, 500))) #|> gpu
samples = Array(invsample(gendata)[1, :])
histogram!(pl,samples, normalize=:pdf, alpha=.3, fillalpha=.3, label="Model samples")
display(pl)
end
#CuArrays.allowscalar(false) # enable together with `using CuArrays` above
# The very first invocation of predict_adjoint is very slow because of the JIT overhead.
# I don't think this is normal.
# Just be patient.
cb()
@epochs 100 Flux.train!(loss_adjoint, ps, data, opt; cb = throttle(cb, 100))
The issue is that for
function divergence(f, x::AbstractArray)
y::AbstractArray, back = forward(f, x)
D, N = size(x)
T = DiffEqFlux.gpu_or_cpu(x)
ȳ(i) = adapt(T,[i == j for j = 1:D])
tmp = [back(ȳ(i))[1][i, :] for i = 1:D]
adapt(T,reduce(+, transpose(tmp)))
end
the definition
Tracker.@grad function divergence(f, x::TrackedArray)
y::AbstractArray, back = forward(f, x)
D, N = size(x)
T = DiffEqFlux.gpu_or_cpu(x)
ȳ(i) = T([i == j for j = 1:D])
out = reduce(+, transpose([back(ȳ(i))[1][i, :] for i = 1:D]))
out, Δ -> begin
nothing, back(Δ)
end
end
is the wrong adjoint. The adjoint of the divergence is the gradient, so I just took a stab at it but @MikeInnes might know how to fix this. |
I think the loss should be left to the user, as it's really problem dependent; on the toy problem I'm playing with, the log-likelihood isn't even computed that way. The base distribution doesn't have to be a normal distribution, even if I'm pretty sure 99% of people use a normal distribution for normalizing flows or VAEs. If we think of it as a prior, then I'm sure other, more sensible and problem-specific distributions are useful.
I started writing a NormalizingFlow library that would implement the API of Distributions.jl (because we are just fitting distributions by MLE after all; we even have access to the PDF and CDF and can sample, so we can implement all the methods), but I stopped when you said the API would change with Zygote. It would be useful to play with NF + Turing, though :)
Otherwise, the only thing I would change in the API would be the ability to pass hyper-parameters to the ODE / predict_adjoint and all related functions. Right now we can only pass an AbstractArray p that will then be tracked. We would need to be able to pass a tracked p1 and an untracked p2. Here p2 would just be the structure of the neural net. Having the net in the global scope is just dirty. It breaks for all kinds of reasons when playing in the REPL too. So I'd love it if ODEs took parameters of the type dudt(u, p::Tuple{AbstractArray, Any}, t) or dudt(u, p::AbstractArray, t; hyperparams=nothing).
Implementing CNFs that way, as Distributions.jl objects, would enable the CNFs to be mixed into other things, like Turing, where they have Bijectors.jl |
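One way to get a similar effect with the existing three-argument `dudt(u, p, t)` signature is to capture the untracked part in a closure instead of a global. This is a minimal sketch, not an existing DiffEqFlux API: `make_dudt` and `hyper` are hypothetical names, and a toy elementwise dynamics stands in for the restructured network.

```julia
# Hypothetical factory: `hyper` holds untracked settings (here just an
# activation function), while `p` stays the only argument that the solver
# and the AD system treat as trainable parameters.
function make_dudt(hyper)
    return (u, p, t) -> hyper.activation.(p[1] .* u .+ p[2])
end

# The structure is bound once, locally, so nothing lives in global scope.
dudt = make_dudt((activation = tanh,))
du = dudt([0.5, -0.5], [2.0, 0.0], 0.0)
```

The returned function can be handed to `ODEProblem` like any other right-hand side; whether this interacts well with a given sensitivity algorithm would still need to be checked.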
@ChrisRackauckas @aussetg is right re loss function and requiring base distribution to be normal. However, to libraryize I think a convenience implementation could definitely be made in the normal distribution case, and then generalized for other distributions. A CNF (and FFJORD?) library implementation only needs the dynamics function
So as @aussetg says, the loss should definitely be left to the user. The e.g. |
using OrdinaryDiffEq
using Distributions, FiniteDiff
using Flux, DiffEqFlux, DistributionsAD
# Hack to fix Zygote AD of Normal
Base.Irrational{:log2π}(x::Int64) = Base.Irrational{:log2π}()
function f(z, p)
α, β = p
tanh.(α.*z .+ β)
end
u0 = Float32[0.0, 0.0]
tspan = (0f0, 10f0)
function cnf(u,p,t)
z, logpz = u
α, β = p
[f(z, p),-sum(FiniteDiff.finite_difference_jacobian((z)->f(z, p), [z]))]
end
prob = ODEProblem{false}(cnf,u0,tspan,nothing)
θinit = Float32[0.0, 0.0] # Initial Parameter Vector
function predict_adjoint(x,θ)
concrete_solve(prob,Tsit5(),[x,0f0],θ,
saveat=0f0:0.1f0:10f0)
end
function loss_adjoint(θ,xs)
pz = Normal(0.0, 1.0)
preds = [predict_adjoint(x,θ)[:,end] for x in xs]
z = [pred[1] for pred in preds] # TODO better slicing
delta_logp = [pred[2] for pred in preds]
logpz = logpdf.(pz, z)
logpx = logpz - delta_logp
loss = -mean(logpx)
end
function cb(θ,l)
@show l
false
end
opt = ADAM(0.1)
raw_data = Iterators.cycle(([Float32[rand(Normal(2.0, 0.1)) for i in 1:100]],))
DiffEqFlux.sciml_train(loss_adjoint, θinit, opt, raw_data, cb=cb, maxiters=100)
# check whether it looks standard normal
using Plots
preds = [predict_adjoint(r)[:,end] for r in raw_data[1]];
histogram([p[1].data for p in preds])
is a working finite difference version, and the following uses Tracker over Zygote to nest the reverse mode:
using OrdinaryDiffEq, Zygote
using Distributions, FiniteDiff, DiffEqSensitivity
using Flux, DiffEqFlux, DistributionsAD
function jacobian(f, x::AbstractVector)
y::AbstractVector, back = Zygote.pullback(f, x)
ȳ(i) = [i == j for j = 1:length(y)]
vcat([transpose(back(ȳ(i))[1]) for i = 1:length(y)]...)
end
# Hack to fix Zygote AD of Normal
Base.Irrational{:log2π}(x::Int64) = Base.Irrational{:log2π}()
function f(z, p)
α, β = p
tanh.(α.*z .+ β)
end
u0 = Float32[0.0, 0.0]
tspan = (0f0, 10f0)
function cnf(u,p,t)
z, logpz = u
α, β = p
[f(z, p),-sum(jacobian((z)->f(z, p), [z]))]
end
prob = ODEProblem{false}(cnf,u0,tspan,nothing)
θinit = Float32[0.0, 0.0] # Initial Parameter Vector
function predict_adjoint(x,θ)
concrete_solve(prob,Tsit5(),[x,0f0],θ,
saveat=0f0:0.1f0:10f0,sensealg=InterpolatingAdjoint(
autojacvec=DiffEqSensitivity.TrackerVJP()))
end
function loss_adjoint(θ,xs)
pz = Normal(0.0, 1.0)
preds = [predict_adjoint(x,θ)[:,end] for x in xs]
z = [pred[1] for pred in preds] # TODO better slicing
delta_logp = [pred[2] for pred in preds]
logpz = logpdf.(pz, z)
logpx = logpz - delta_logp
loss = -mean(logpx)
end
function cb(θ,l)
@show l
false
end
opt = ADAM(0.01)
raw_data = Iterators.cycle(([Float32[rand(Normal(2.0, 0.1)) for i in 1:100]],))
DiffEqFlux.sciml_train(loss_adjoint, θinit, opt, raw_data, cb=cb, maxiters=100)
# check whether it looks standard normal
using Plots
preds = [predict_adjoint(r)[:,end] for r in raw_data[1]];
histogram([p[1].data for p in preds])
Now it's time to library-itize it. |
@abhigupta768 here is a summary: So @ChrisRackauckas wrote this after a few iterations but I’ll just quickly give an overview as far as I can tell.
This CNF is the dynamics we describe in the Neural ODE paper. I can get the equation number for you, but I highly recommend you read both the variational inference normalizing flows and the CNF section of Neural ODE paper to understand what this is doing. The TL;DR: the complete state The idea is we start with some The loss we are trying to optimize is to maximize the likelihood of
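For reference, the state, dynamics, and loss that the code in this thread implements can be written out as follows. This is a reconstruction from the code (`u0 = [x, 0]`, `du[end] = -tr(J)`, `logpx = logpz - delta_logp`) and from the instantaneous change of variables formula in the Neural ODE paper, not a quote from the original comment; the symbols $f_\theta$, $\Delta$, $T$, and $p_z$ are notation assumed here, with $T = 10$ and $p_z$ a standard normal in the examples.

```latex
\begin{aligned}
u(t) &= \begin{bmatrix} z(t) \\ \Delta(t) \end{bmatrix}, \qquad
u(0) = \begin{bmatrix} x \\ 0 \end{bmatrix}, \\
\frac{dz}{dt} &= f_\theta(z(t)), \qquad
\frac{d\Delta}{dt} = -\operatorname{tr}\!\left(\frac{\partial f_\theta}{\partial z}\right), \\
\log p(x) &= \log p_z\bigl(z(T)\bigr) - \Delta(T), \\
\mathcal{L}(\theta) &= -\frac{1}{N} \sum_{n=1}^{N} \log p(x_n).
\end{aligned}
```

Minimizing $\mathcal{L}$ is the same as maximizing the average log-likelihood of the data points $x_n$, which is what `loss_adjoint` returns.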
|