
Continuous normalizing flows #46

Closed · ChrisRackauckas opened this issue Apr 18, 2019 · 14 comments

@ChrisRackauckas (Member):

using DifferentialEquations
using Distributions
using Flux, DiffEqFlux, ForwardDiff
using Flux.Tracker


function f(z, p)
  α, β = p
  tanh.(α.*z .+ β)
end

u0 = [0.0, 0.0]
tspan = (0.0, 10.0)
function cnf(du,u,p,t)
  z, logpz = u
  α, β = p
  du[1] = f(z, p)                                       # dz/dt
  du[2] = -sum(ForwardDiff.jacobian((z)->f(z, p), [z])) # dlogp/dt = -tr(J)
end
prob = ODEProblem(cnf,u0,tspan,nothing)

p = param([0.0, 0.0]) # Initial Parameter Vector
params = Params([p])

function predict_adjoint(x)
    diffeq_adjoint(p,prob,Tsit5(),u0=[x,0.0],
                   saveat=0.0:0.1:10.0,
                   sensealg=DiffEqFlux.SensitivityAlg(quad=false,
                                backsolve=true,autojacvec=false))
end

function loss_adjoint(xs)
    pz = Normal(0.0, 1.0)
    preds = [predict_adjoint(x)[:,end] for x in xs]
    z = [pred[1] for pred in preds] # TODO better slicing
    delta_logp = [pred[2] for pred in preds]

    logpz = logpdf.(pz, z)
    logpx = logpz - delta_logp
    loss = -mean(logpx)
end


opt = ADAM(0.1)

raw_data = [[rand(Normal(2.0, 0.1)) for i in 1:100]]
data = Iterators.repeated(raw_data, 100);

Flux.train!(loss_adjoint, params, data, opt)



# check whether it looks standard normal
using Plots

preds = [predict_adjoint(r)[:,end] for r in raw_data[1]];

histogram([p[1].data for p in preds])

@ChrisRackauckas (Member Author):

using DifferentialEquations
using Distributions
using Flux, DiffEqFlux, ForwardDiff
using Flux.Tracker

# Neural Network
function f(z, p)
  α, β = p
  tanh.(α.*z .+ β)
end

tspan = (0.0, 10.0)
function cnf(du,u,p,t)
  z = @view u[1:end-1]
  du[1:end-1] = f(z, p)
  du[end] = -sum(ForwardDiff.jacobian((z)->f(z, p), z))
end
prob = ODEProblem(cnf,nothing,tspan,nothing)

p = param([0.0, 0.0]) # Initial Parameter Vector
params = Params([p])

function predict_adjoint(x)
    diffeq_adjoint(p,prob,Tsit5(),u0=[x,false],
                   saveat=0.0:0.1:10.0,
                   sensealg=DiffEqFlux.SensitivityAlg(quad=false,
                                backsolve=true,autojacvec=false))
end

function loss_adjoint(xs)
    pz = Normal(0.0, 1.0)
    preds = [predict_adjoint(x)[:,end] for x in xs]
    z = [pred[1] for pred in preds] # TODO better slicing
    delta_logp = [pred[2] for pred in preds]

    logpz = logpdf.(pz, z)
    logpx = logpz - delta_logp
    loss = -mean(logpx)
end


opt = ADAM(0.1)

raw_data = [[rand(Normal(2.0, 0.1)) for i in 1:100]]
data = Iterators.repeated(raw_data, 100);

Flux.train!(loss_adjoint, params, data, opt)

# check whether it looks standard normal
using Plots

preds = [predict_adjoint(r)[:,end] for r in raw_data[1]];

histogram([p[1].data for p in preds])

That's a version that should be suitable for library use. It still needs batching support and some testing with Flux models, though.

@pevnak commented Apr 18, 2019:

I am getting this error

julia> Flux.train!(loss_adjoint, params, data, opt)
ERROR: UndefVarError: uf not defined
Stacktrace:
 [1] #ODEAdjointProblem#18(::Array{Float64,1}, ::CallbackSet{Tuple{},Tuple{}}, ::LinearAlgebra.UniformScaling{Bool}, ::Function, ::ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(cnf),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(cnf),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}}}, ::getfield(DiffEqFlux, Symbol("#df#27")){Bool}, ::Array{Float64,1}, ::Nothing, ::DiffEqSensitivity.SensitivityAlg{0,true,Val{:central}}) at /Users/tpevny/.julia/packages/DiffEqSensitivity/DI6VG/src/adjoint_sensitivity.jl:177
 [2] #ODEAdjointProblem at ./none:0 [inlined]

My status of DiffEqFlux is:

(v1.1) pkg> st DiffEqFlux
    Status `~/.julia/environments/v1.1/Project.toml`
  [79e6a3ab] Adapt v0.4.2
  [aae7a2af] DiffEqFlux v0.4.0+ #master (https://github.com/JuliaDiffEq/DiffEqFlux.jl.git)
  [587475ba] Flux v0.8.2+ [`~/.julia/dev/Flux`]
  [f6369f11] ForwardDiff v0.10.3
  [10745b16] Statistics

@ChrisRackauckas (Member Author):

You need OrdinaryDiffEq, DiffEqFlux, and DiffEqSensitivity master. If anyone could help generate the Project.toml files, I will register.
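
For reference, getting those master branches from the Pkg REPL would look something like this (my note, not from the thread):

(v1.1) pkg> add OrdinaryDiffEq#master DiffEqFlux#master DiffEqSensitivity#master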

@jessebett (Contributor):

If I replace that "neural network" with a Flux model, the above code fails:

# Neural Network
nn = Dense(1,1,tanh)
pp = destructure(nn)
function f(z, p)
    m = restructure(nn,p)
    return m(z)
end


@ChrisRackauckas (Member Author):

using OrdinaryDiffEq
using Distributions
using Flux, DiffEqFlux, ForwardDiff, Tracker

# Neural Network
nn = Dense(1,1,tanh)
p = DiffEqFlux.destructure(nn)
function f(z, p)
    m = DiffEqFlux.restructure(nn,p)
    return m(z)
end

tspan = Float32.((0.0, 10.0))
function cnf(du,u,p,t)
  z = @view u[1:end-1]
  du[1:end-1] = f(z, p)
  du[end] = -sum(Tracker.jacobian((z)->f(z, p), z))
end
prob = ODEProblem(cnf,nothing,tspan,nothing)

p = param(Float32[0.0, 0.0]) # Initial Parameter Vector
params = Params([p])

function predict_adjoint(x)
    diffeq_adjoint(p,prob,Tsit5(),u0=[x;false],
                   saveat=0.0:0.1:10.0,
                   sensealg=DiffEqFlux.SensitivityAlg(quad=false,
                                backsolve=true,autojacvec=true))
end

function loss_adjoint(xs)
    pz = Normal(0.0, 1.0)
    preds = [predict_adjoint(x)[:,end] for x in xs]
    z = [pred[1] for pred in preds] # TODO better slicing
    delta_logp = [pred[2] for pred in preds]

    logpz = logpdf.(pz, z)
    logpx = logpz - delta_logp
    loss = -mean(logpx)
end

opt = ADAM(0.1)

raw_data = [Float32[rand(Normal(2.0, 0.1)) for i in 1:100]]
data = Iterators.repeated(raw_data, 100);

Flux.train!(loss_adjoint, params, data, opt)

# check whether it looks standard normal
using Plots

preds = [predict_adjoint(r)[:,end] for r in raw_data[1]];

The above is probably the closest we've gotten; it just needs the Tracker nesting fixed, i.e. figuring out why the gradient comes back undefined.

@ChrisRackauckas (Member Author):

This is a nice simplifying example. It works until the backwards pass, where DiffEqSensitivity's autojacvec uses Flux to do the vjps; there, u is a TrackedArray and the call fails.

using OrdinaryDiffEq
using Distributions
using Flux, DiffEqFlux, ForwardDiff, Tracker

# Neural Network
nn = Dense(1,1,tanh)
p = Tracker.data(DiffEqFlux.destructure(nn))
DiffEqFlux.restructure(nn,p)([1.0])
tspan = Float32.((0.0, 10.0))
function cnf(du,u,p,t)
  z = @view u[1:end-1]
  m = DiffEqFlux.restructure(nn,p)
  du[1:end-1] = m(z)
  @show z
  du[end] = -sum(Tracker.jacobian((z)->log.(z), z))
end
prob = ODEProblem(cnf,nothing,tspan,nothing)

p = param(Float32[0.0, 0.0]) # Initial Parameter Vector
params = Params([p])

function predict_adjoint(x)
    diffeq_adjoint(p,prob,Tsit5(),u0=[x;false],
                   saveat=0.0:0.1:10.0,
                   sensealg=DiffEqFlux.SensitivityAlg(quad=false,
                                backsolve=true,autojacvec=true))
end

function loss_adjoint(xs)
    pz = Normal(0.0, 1.0)
    preds = [predict_adjoint(x)[:,end] for x in xs]
    z = [pred[1] for pred in preds] # TODO better slicing
    delta_logp = [pred[2] for pred in preds]

    logpz = logpdf.(pz, z)
    logpx = logpz - delta_logp
    loss = -mean(logpx)
end


Tracker.jacobian((z)->log.(3 .* z.+z.^2), [10.0])


opt = ADAM(0.1)

raw_data = [Float32[rand(Normal(2.0, 0.1)) for i in 1:100]]
data = Iterators.repeated(raw_data, 100);

Flux.train!(loss_adjoint, params, data, opt)

The problem is that Tracker.jacobian can't nest. How do we work around this, @MikeInnes?

@ChrisRackauckas (Member Author) commented Apr 29, 2019:

using OrdinaryDiffEq
using Distributions
using Flux, DiffEqFlux, ForwardDiff, Tracker

# Neural Network
nn = Chain(Dense(1,1,tanh))
p = DiffEqFlux.destructure(nn)
tspan = Float32.((0.0, 10.0))
function cnf(u,p,t)
  z = @view u[1:end-1]
  m = DiffEqFlux.restructure(nn,p)
  jac = -sum(Tracker.jacobian((z)->log.(z), z))
  if u isa TrackedArray
      res = Tracker.collect([m(z);jac])
  else
      res = Tracker.data([m(z);jac])
  end
  res
end

prob = ODEProblem(cnf,nothing,tspan,nothing)
params = Params([p])

function predict_adjoint(x)
    diffeq_adjoint(p,prob,Tsit5(),u0=[x;false],
                   saveat=0.0:0.1:10.0,
                   sensealg=DiffEqFlux.SensitivityAlg(quad=false,
                                backsolve=true,autojacvec=true))
end

function loss_adjoint(xs)
    pz = Normal(0.0, 1.0)
    preds = [predict_adjoint(x)[:,end] for x in xs]
    z = [pred[1] for pred in preds] # TODO better slicing
    delta_logp = [pred[2] for pred in preds]

    logpz = logpdf.(pz, z)
    logpx = logpz - delta_logp
    loss = -mean(logpx)
end

opt = ADAM(0.1)

raw_data = [Float32[rand(Normal(2.0, 0.1)) for i in 1:100]]
data = Iterators.repeated(raw_data, 1);

loss_adjoint(raw_data[1])

Flux.train!(loss_adjoint, params, data, opt)
iszero(Tracker.grad(nn[1].W))

This works with FluxML/Tracker.jl#24.

@aussetg (Contributor) commented Aug 30, 2019:

using OrdinaryDiffEq
using Distributions
using Flux, DiffEqFlux, Tracker
using LinearAlgebra: tr
using Plots
using Flux: @epochs, throttle
using Tracker: forward

# Neural Network
nn = Chain(Dense(1,1,swish), Dense(1,1,identity))

# We track the parameters.
p = Flux.data(DiffEqFlux.destructure(nn))
params = param(p)
ps = Flux.params(params)

tspan = Float32.((0.0, 10.0))

# We define tr(J) to support batching. 
# But it's possible to use tr(Tracker.jacobian(m, z)), it works perfectly.
function divergence(f, x::AbstractArray)
  y::AbstractArray, back = forward(f, x)
  D, N = size(x) 
  ē(i) = [i == j for j = 1:D]
  reduce(+, transpose([back(ē(i))[1][i, :] for i = 1:D]))
end

# Dynamics of the CNF
function cnf_dudt_(u::TrackedArray,p,t)
    z = @view u[1:end-1, :]
    m = DiffEqFlux.restructure(nn, p)
    jac = -divergence(m, z)
    Tracker.collect([m(z);jac])
end

function cnf_dudt_(u::AbstractArray,p,t)
    z = @view u[1:end-1, :]
    m = DiffEqFlux.restructure(nn, p)
    jac = -divergence(m, z)
    Tracker.data([m(z);jac])
end

function predict_adjoint(x)
    diffeq_adjoint(params,prob,Tsit5(),u0=vcat(x, zeros(Float32, (1, size(x, 2)))),
                   saveat=0.0:0.1:10.0,
                   sensealg=DiffEqFlux.SensitivityAlg(quad=false,
                                backsolve=false,autojacvec=true))
end

# We want to be able to sample according to x = f^(-1)(z)
# We don't need the dynamics of log P here.
function f(u, p, t)
  m = DiffEqFlux.restructure(nn,p)
  return Tracker.data(m(u))
end

prob = ODEProblem(cnf_dudt_,nothing,tspan,nothing)

function invsample(x::AbstractArray)
    # Remember that to train with respect to the NLL we actually went in the x -> u direction
    # We want to go in the u -> x direction so we just solve the ODE backward in time.
    invprob = ODEProblem(f,x,(10.0, 0.), params)
    solve(invprob, Tsit5(),save_everystep=false)[2]
end

function invsample(x::Real)
    invsample([x])[1]
end

opt = ADAM(0.1)

model = Normal(5., 0.1)

raw_data = [Float32.(rand(model, (1, 100)))]
data = Iterators.repeated(raw_data, 10)

function loss_adjoint(xs)
    pz = Normal(0.0, 1.0)
    preds = predict_adjoint(xs)[:, :,end]
    z = preds[1, :]
    delta_logp = preds[2, :]

    logpz = logpdf.(pz, z)
    logpx = logpz - delta_logp
    loss = -mean(logpx)
end

cb = function()
    # You can schedule the learning rate if you want.
    opt.eta *= 0.95
    
    pz = Normal(0.0, 1.0)
    
    preds = Tracker.data(predict_adjoint(raw_data[1])[:, :,end])
    zs = preds[1, :]
    delta_logp = preds[2, :]
    
    logpz = logpdf.(pz, zs)
    logpx = logpz - delta_logp
    loss = -mean(logpx)
    
    perm = sortperm(raw_data[1][1, :])
    pl = plot(raw_data[1][1, perm], Tracker.data.(exp.(logpx[perm])), xlims=(4, 6), ylims=(0, 5), title="Loss = $(loss)", label="Learned density")
    plot!(t -> pdf(model, t), label="Real density")
    samples = invsample(Float32.(rand(pz, (1, 500))))[1, :]
    histogram!(samples, normalize=:pdf, alpha=.3, fillalpha=.3, label="Model samples")
    display(pl)
end

# The very first invocation of predict_adjoint is very slow because of the JIT overhead. 
# I don't think this is normal.
# Just be patient.
cb()

@epochs 100 Flux.train!(loss_adjoint, ps, data, opt; cb = throttle(cb, 100))

This one works for me.
[Plot: learned density vs. real density, with a histogram of model samples]


@ChrisRackauckas (Member Author):

@jessebett so how should we "libraryitize" CNF? Clearly the layer should be given a nice function, but what about the loss function? Is that just specific to normal distributions, or should we generalize it? Is this common, or does it change depending on the application?

@ChrisRackauckas (Member Author):

I tried to make it GPU-compatible:

using OrdinaryDiffEq
using Distributions
using Flux, DiffEqFlux, Tracker
using LinearAlgebra: tr
using Plots
using Flux: @epochs, throttle
using Tracker: forward
using Adapt
#using CuArrays
using LinearAlgebra

# Neural Network
nn = Chain(Dense(1,1,swish), Dense(1,1,identity)) #|> gpu

# We track the parameters.
p = Flux.data(DiffEqFlux.destructure(nn))
params = param(p)
ps = Flux.params(params)

tspan = Float32.((0.0, 10.0))

function divergence(f, x::AbstractArray)
  y::AbstractArray, back = forward(f, x)
  D, N = size(x)
  T = DiffEqFlux.gpu_or_cpu(x)
  ē(i) = adapt(T,[i == j for j = 1:D])
  tmp = [back(ē(i))[1][i, :] for i = 1:D]
  adapt(T,reduce(+, transpose(tmp)))
end

Tracker.@grad function divergence(f, x::TrackedArray)
  y::AbstractArray, back = forward(f, x)
  D, N = size(x)
  T = DiffEqFlux.gpu_or_cpu(x)
  ē(i) = T([i == j for j = 1:D])
  out = reduce(+, transpose([back(ē(i))[1][i, :] for i = 1:D]))
  out, Δ -> begin
    nothing, back(Δ)
  end
end

# Dynamics of the CNF
function cnf_dudt_(u::TrackedArray,p,t)
    z = u[1:end-1, :]
    m = DiffEqFlux.restructure(nn, p)
    jac = -divergence(m, z)
    Tracker.collect([m(z);jac])
end

function cnf_dudt_(u::AbstractArray,p,t)
    z = u[1:end-1, :]
    m = DiffEqFlux.restructure(nn, p)
    jac = -divergence(m, z)
    Tracker.data([m(z);jac])
end

function predict_adjoint(x)
    diffeq_adjoint(params,prob,Tsit5(),u0=vcat(x, zeros(Float32, (1, size(x, 2)))),
                   saveat=0.0:0.1:10.0,
                   sensealg=DiffEqFlux.SensitivityAlg(quad=false,
                                backsolve=false,autojacvec=true))
end

# We want to be able to sample according to x = f^(-1)(z)
# We don't need the dynamics of log P here.
function f(u, p, t)
  m = DiffEqFlux.restructure(nn,p)
  return Tracker.data(m(u))
end

prob = ODEProblem(cnf_dudt_,nothing,tspan,nothing)

function invsample(x::AbstractArray)
    # Remember that to train with respect to the NLL we actually went in the x -> u direction
    # We want to go in the u -> x direction so we just solve the ODE backward in time.
    invprob = ODEProblem(f,x,(10.0, 0.), params)
    solve(invprob, Tsit5(),save_everystep=false)[2]
end

function invsample(x::Real)
    invsample([x])[1]
end

opt = ADAM(0.1)

model = Normal(5., 0.1)

raw_data = [Float32.(rand(model, (1, 100)))] #.|> gpu
data = Iterators.repeated(raw_data, 10)

function loss_adjoint(xs)
    preds = predict_adjoint(xs)[:, :,end]
    z = preds[1, :]
    delta_logp = preds[2, :]

    μ = 0.0
    σ = 1.0
    pz = Normal(μ, σ)

    #logpz = logpdf.(pz, z)
    logpz = -((((z .- μ) ./ σ ).^2 .+ log(2π))./2 .- log(σ))
    logpx = logpz - delta_logp
    loss = -mean(logpx)
end

cb = function()
    # You can schedule the learning rate if you want.
    opt.eta *= 0.95

    preds = Tracker.data(predict_adjoint(raw_data[1])[:, :,end])
    zs = preds[1, :]
    delta_logp = preds[2, :]

    μ = 0.0
    σ = 1.0
    pz = Normal(μ, σ)

    #logpz = logpdf.(pz, zs)
    logpz = -((((zs .- μ) ./ σ ).^2 .+ log(2π))./2 .- log(σ))
    logpx = Array(logpz - delta_logp)
    loss = -mean(logpx)

    _raw_data = Array.(raw_data)
    perm = sortperm(_raw_data[1][1, :])
    pl = plot(_raw_data[1][1, perm], Tracker.data.(exp.(logpx[perm])), xlims=(4, 6), ylims=(0, 5), title="Loss = $(loss)", label="Learned density")
    plot!(t -> pdf(model, t), label="Real density")
    gendata = Float32.(rand(pz, (1, 500))) #|> gpu
    samples = Array(invsample(gendata)[1, :])
    histogram!(pl,samples, normalize=:pdf, alpha=.3, fillalpha=.3, label="Model samples")
    display(pl)
end

CuArrays.allowscalar(false)
# The very first invocation of predict_adjoint is very slow because of the JIT overhead.
# I don't think this is normal.
# Just be patient.
cb()

@epochs 100 Flux.train!(loss_adjoint, ps, data, opt; cb = throttle(cb, 100))

The issue is that

function divergence(f, x::AbstractArray)
  y::AbstractArray, back = forward(f, x)
  D, N = size(x)
  T = DiffEqFlux.gpu_or_cpu(x)
  ē(i) = adapt(T,[i == j for j = 1:D])
  tmp = [back(ē(i))[1][i, :] for i = 1:D]
  adapt(T,reduce(+, transpose(tmp)))
end

the adapt calls for some reason break the gradient, and then

Tracker.@grad function divergence(f, x::TrackedArray)
  y::AbstractArray, back = forward(f, x)
  D, N = size(x)
  T = DiffEqFlux.gpu_or_cpu(x)
  ē(i) = T([i == j for j = 1:D])
  out = reduce(+, transpose([back(ē(i))[1][i, :] for i = 1:D]))
  out, Δ -> begin
    nothing, back(Δ)
  end
end

is the wrong adjoint: the adjoint of the divergence is its gradient, so I just took a stab at it, but @MikeInnes might know how to fix this.
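
For reference (my note, not from the thread), a vjp of f alone cannot give this adjoint, because the divergence is

d(x) = \operatorname{tr} J_f(x) = \sum_i \frac{\partial f_i}{\partial x_i}(x),
\qquad
\frac{\partial d}{\partial x_j}(x) = \sum_i \frac{\partial^2 f_i}{\partial x_j \, \partial x_i}(x),

so the pullback of the divergence with respect to x involves second derivatives of f, i.e. it needs another differentiation pass through the vjp rather than the first-order back(Δ) used above.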

@aussetg (Contributor) commented Sep 2, 2019:

> @jessebett so how should we "libraryitize" CNF? Clearly the layer should be given a nice function, but what about the loss function? Is that just specific to normal distributions, or should we generalize it? Is this common, or does it change depending on the application?

I think the loss should be left to the user, as it's really problem dependent; on the toy problem I'm playing with, the log-likelihood isn't even computed that way. The base distribution doesn't have to be a normal distribution either. I'm pretty sure 99% of people use a normal distribution for normalizing flows or VAEs, but if we think of it as a prior, then other, more sensible, problem-specific distributions are surely useful.

I started writing a NormalizingFlow library that would implement the API of Distributions.jl (because we are just fitting distributions by MLE after all; we even have access to the PDF and CDF and can sample, so we can implement all the methods), but I stopped when you said the API would change with Zygote. It would be useful to play with NF + Turing :)

Otherwise, the only thing I would change in the API would be the ability to pass hyperparameters to the ODE / predict_adjoint and all related functions. Right now we can only pass an AbstractArray p that will then be tracked. We would need to be able to pass a tracked p1 and an untracked p2. Here p2 would just be the structure of the neural net. Having the net in global scope is just dirty, and it breaks for all kinds of reasons when playing in the REPL too.

So I'd love it if ODEs took parameters of the form dudt(u, p::Tuple{AbstractArray, Any}, t) or dudt(u, p::AbstractArray, t; hyperparams=nothing).
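
A sketch of what that could look like (purely hypothetical; as noted above, the adjoints currently only accept a flat AbstractArray p):

# Hypothetical signature: p carries a tracked parameter vector plus untracked structure.
function dudt(u, p::Tuple{<:AbstractArray,Any}, t)
    θ, nn = p                          # θ: tracked weights, nn: model structure (not tracked)
    m = DiffEqFlux.restructure(nn, θ)
    m(u)
end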

Implementing CNFs that way, as Distributions.jl objects, would enable them to be mixed into other things, for example in Turing, where they have Bijectors.jl.

@jessebett (Contributor):

@ChrisRackauckas, @aussetg is right re the loss function and the base distribution not needing to be normal.

However, to libraryize it, I think a convenience implementation could definitely be made for the normal-distribution case and then generalized to other distributions. A CNF (and FFJORD?) library implementation only needs the dynamics function m to have jacobian(m) working and composable (with higher-order AD). In the case of FFJORD, m only needs gradient to work. Then a library function would take the usual dynamics given by dudt(u) = m(u) and instead give something like

function dudt(u)
  z, deltapz = u[1:end-1], u[end]   # unpack state
  dz = m(z)                         # original dynamics
  ddeltapz = -tr(jacobian(m, z))    # or a Hutchinson estimate in FFJORD
  return [dz; ddeltapz]
end
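
As a concrete sketch of such a helper (hypothetical names cnf_dynamics and hutchinson_trace; exact trace via ForwardDiff, Hutchinson variant via a Zygote vjp):

using ForwardDiff, Zygote
using LinearAlgebra: tr

# Wrap user dynamics m(z) into the augmented CNF dynamics [dz/dt; dΔlogp/dt] = [m(z); -tr(∂m/∂z)].
# (Here m closes over its parameters; a real library helper would thread p through.)
function cnf_dynamics(m)
    (u, p, t) -> begin
        z = u[1:end-1]
        vcat(m(z), -tr(ForwardDiff.jacobian(m, z)))
    end
end

# FFJORD-style Hutchinson estimator: tr(J) ≈ εᵀ J ε for a random probe ε, using one vjp.
function hutchinson_trace(m, z, ε)
    _, back = Zygote.pullback(m, z)
    sum(back(ε)[1] .* ε)   # (Jᵀε) ⋅ ε == εᵀ J ε
end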

So as @aussetg says, the loss should definitely be left to the user. The u0 for the above dynamics takes an initial sample from the base distribution and its log-likelihood under that base distribution, and then transforms both. Ideally, a library version of all this should not have to worry about where these samples and log-probabilities come from, so they could come from Distributions.jl. There's work to be done in both directions to make Distributions.jl more composable with our stuff and the other way around. I like the idea of supplying a Distributions-style API that allows the user to sample from/evaluate likelihoods under a flow-defined distribution just like any other. However, it would also be nice if Distributions worked for us internally.

e.g. logpz = -((((z .- μ) ./ σ ).^2 .+ log(2π))./2 .- log(σ)) is hand-coded to compute logpdf(Normal(μ,σ),z), because calling that from Distributions breaks autodiff in a few ways. @willtebbutt has ideas/work on improving this.
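
A quick sanity check (mine, not from the thread) that the hand-coded expression matches the standard-normal logpdf used here:

using Distributions
μ, σ = 0.0, 1.0
z = 0.7
handcoded = -(((z - μ) / σ)^2 + log(2π)) / 2 - log(σ)
handcoded ≈ logpdf(Normal(μ, σ), z)   # true (with σ = 1 the log σ term vanishes anyway)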

@ChrisRackauckas (Member Author) commented Mar 6, 2020:

using OrdinaryDiffEq
using Distributions, FiniteDiff
using Flux, DiffEqFlux, DistributionsAD

# Hack to fix Zygote AD of Normal
Base.Irrational{:log2π}(x::Int64) = Base.Irrational{:log2π}()

function f(z, p)
  α, β = p
  tanh.(α.*z .+ β)
end

u0 = Float32[0.0, 0.0]
tspan = (0f0, 10f0)
function cnf(u,p,t)
  z, logpz = u
  α, β = p
  [f(z, p),-sum(FiniteDiff.finite_difference_jacobian((z)->f(z, p), [z]))]
end
prob = ODEProblem{false}(cnf,u0,tspan,nothing)

θinit = Float32[0.0, 0.0] # Initial Parameter Vector

function predict_adjoint(x,θ)
    concrete_solve(prob,Tsit5(),[x,0f0],θ,
                   saveat=0f0:0.1f0:10f0)
end

function loss_adjoint(θ,xs)
    pz = Normal(0.0, 1.0)
    preds = [predict_adjoint(x,θ)[:,end] for x in xs]
    z = [pred[1] for pred in preds] # TODO better slicing
    delta_logp = [pred[2] for pred in preds]

    logpz = logpdf.(pz, z)
    logpx = logpz - delta_logp
    loss = -mean(logpx)
end

function cb(θ,l)
    @show l
    false
end

opt = ADAM(0.1)
raw_data = Iterators.cycle(([Float32[rand(Normal(2.0, 0.1)) for i in 1:100]],))
res = DiffEqFlux.sciml_train(loss_adjoint, θinit, opt, raw_data, cb=cb, maxiters=100)

# check whether it looks standard normal
using Plots

xs = first(raw_data)[1]
preds = [predict_adjoint(x, res.minimizer)[:,end] for x in xs];

histogram([pred[1] for pred in preds])

That is a working finite-difference version, and the following uses Tracker over Zygote to nest the reverse mode:

using OrdinaryDiffEq, Zygote
using Distributions, FiniteDiff, DiffEqSensitivity
using Flux, DiffEqFlux, DistributionsAD

function jacobian(f, x::AbstractVector)
  y::AbstractVector, back = Zygote.pullback(f, x)
  ē(i) = [i == j for j = 1:length(y)]
  vcat([transpose(back(ē(i))[1]) for i = 1:length(y)]...)
end

# Hack to fix Zygote AD of Normal
Base.Irrational{:log2π}(x::Int64) = Base.Irrational{:log2π}()

function f(z, p)
  α, β = p
  tanh.(α.*z .+ β)
end

u0 = Float32[0.0, 0.0]
tspan = (0f0, 10f0)
function cnf(u,p,t)
  z, logpz = u
  α, β = p
  [f(z, p),-sum(jacobian((z)->f(z, p), [z]))]
end
prob = ODEProblem{false}(cnf,u0,tspan,nothing)

θinit = Float32[0.0, 0.0] # Initial Parameter Vector

function predict_adjoint(x,θ)
    concrete_solve(prob,Tsit5(),[x,0f0],θ,
                   saveat=0f0:0.1f0:10f0,sensealg=InterpolatingAdjoint(
                   autojacvec=DiffEqSensitivity.TrackerVJP()))
end

function loss_adjoint(θ,xs)
    pz = Normal(0.0, 1.0)
    preds = [predict_adjoint(x,θ)[:,end] for x in xs]
    z = [pred[1] for pred in preds] # TODO better slicing
    delta_logp = [pred[2] for pred in preds]

    logpz = logpdf.(pz, z)
    logpx = logpz - delta_logp
    loss = -mean(logpx)
end

function cb(θ,l)
    @show l
    false
end

opt = ADAM(0.01)
raw_data = Iterators.cycle(([Float32[rand(Normal(2.0, 0.1)) for i in 1:100]],))
res = DiffEqFlux.sciml_train(loss_adjoint, θinit, opt, raw_data, cb=cb, maxiters=100)

# check whether it looks standard normal
using Plots

xs = first(raw_data)[1]
preds = [predict_adjoint(x, res.minimizer)[:,end] for x in xs];

histogram([pred[1] for pred in preds])

Now it's time to library-itize it.

@jessebett (Contributor) commented Mar 12, 2020:

@abhigupta768 here is a summary:

So @ChrisRackauckas wrote this after a few iterations but I’ll just quickly give an overview as far as I can tell.

f is a very simple Dense neural-network layer. It should be replaced by something like what Chris is calling FastDense.
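
For reference, a sketch of the FastDense style being suggested (my reading of the DiffEqFlux API around that time):

using DiffEqFlux
fnn = FastDense(1, 1, tanh)   # explicit-parameter dense layer
θ = initial_params(fnn)       # flat parameter vector for the layer
fnn([0.5f0], θ)               # fast layers are called as layer(x, p)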

function cnf(u,p,t)
  z, logpz = u
  α, β = p
  [f(z, p),-sum(FiniteDiff.finite_difference_jacobian((z)->f(z, p), [z]))]
end

This CNF is the dynamics we describe in the Neural ODE paper. I can get the equation number for you, but I highly recommend you read both the variational-inference normalizing flows paper and the CNF section of the Neural ODE paper to understand what this is doing.

The TL;DR: the complete state u represents both the original state z, which evolves according to dz/dt = f(z,p), AND the log-probability of that state z at time t, which evolves according to -trace(J).

The idea is that we start with some x whose log-likelihood logpx we can't evaluate easily. So instead, we put it through these CNF dynamics, and the result will hopefully follow a distribution we can easily evaluate the likelihood under, e.g. a standard Gaussian.
So we put x in, integrate it, and get z out. We also get delta_logp, which is the change in log-likelihood during the integration, coming from that second term in the dynamics.
Now that we have a z, we can evaluate its likelihood under that standard Gaussian, giving logpz.
To get logpx we just undo the change in log-likelihood from the integration: logpx = logpz - delta_logp.
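
In equations (the instantaneous change of variables result from the Neural ODE paper, as I read it):

\frac{dz}{dt} = f(z(t), p), \qquad
\frac{d \log p(z(t))}{dt} = -\operatorname{tr}\!\left(\frac{\partial f}{\partial z(t)}\right),

so integrating the augmented state from t_0 (where z(t_0) = x) to t_1 gives

\log p(x) = \log p_z(z(t_1)) - \Delta \log p, \qquad
\Delta \log p = -\int_{t_0}^{t_1} \operatorname{tr}\!\left(\frac{\partial f}{\partial z(t)}\right) dt,

which is exactly logpx = logpz - delta_logp in the code.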

The loss we are trying to optimize maximizes the likelihood of x, i.e. it minimizes the negative log-likelihood, loss = -mean(logpx):

function loss_adjoint(θ,xs)
    pz = Normal(0.0, 1.0)
    preds = [predict_adjoint(x,θ)[:,end] for x in xs]
    z = [pred[1] for pred in preds] # TODO better slicing
    delta_logp = [pred[2] for pred in preds]

    logpz = logpdf.(pz, z)
    logpx = logpz - delta_logp
    loss = -mean(logpx)
end
