Training Neural SDEs with Mutating Arrays #785

Closed
Alexio-Phytides opened this issue Dec 19, 2022 · 1 comment

Comments


Alexio-Phytides commented Dec 19, 2022

I am writing a training function modeled on the method-of-moments example in the neural SDE tutorial of the DiffEqFlux.jl documentation. My loss is the difference between call prices, which I need to compute for several strikes, so the function contains a loop. However, filling in the prices inside that loop mutates an array, which Zygote cannot differentiate. How do I get around this?
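Written out, the loss described here compares a Monte Carlo estimate of discounted call prices against Black-Scholes prices, strike by strike. A sketch in formula form, assuming $n$ simulated paths with terminal prices $S_T^{(j)}(p)$, strikes $K_i$, risk-free rate $r$, maturity $T$ (`tspan[2]` in the code), and reference prices $C_i^{\mathrm{BS}}$ (`BS_Price`):

```math
C_i(p) = e^{-rT} \, \frac{1}{n} \sum_{j=1}^{n} \max\!\left(S_T^{(j)}(p) - K_i,\ 0\right),
\qquad
L(p) = \sum_{i} \left(C_i(p) - C_i^{\mathrm{BS}}\right)^2
```

Each $C_i$ depends on the whole ensemble of simulated terminal prices, and the loss is a plain sum of squared differences over strikes.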

```julia
using DiffEqFlux, Optimization, OptimizationFlux, Statistics

function predict_neuralsde(p, u = S₀)
    return Array(Pre_NSDE(u, p)) # Returns an array of stock prices/SDE solutions at each time point.
end

function loss_neuralsde(p; n = 10000)
    u = repeat(reshape(S₀, :, 1), 1, n)
    terminal_values = predict_neuralsde(p, u)[1, :, end] # simulated terminal stock prices

    # Discounted Monte Carlo call price for each strike. NSDE_call is defined outside
    # this function (not shown), so each assignment below is an in-place write (setindex!).
    for i = 1:length(K)
        NSDE_call[i] = exp(-r*tspan[2]) .* mean(maximum([terminal_values .- K[i] zeros(size(terminal_values))], dims = 2))
    end

    loss = sum(abs2, NSDE_call - BS_Price')
    return loss, NSDE_call
end

opt = ADAM(0.025)

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralsde(x, n = 10000), adtype)
optprob = Optimization.OptimizationProblem(optf, Pre_NSDE.p)
result1 = Optimization.solve(optprob, opt, callback = callback, maxiters = 200)  # 200 iterations of training.
```
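A common way to avoid the mutation (offered as a sketch, not as the answer given in the linked Discourse thread) is to build the price vector with a comprehension, which Zygote can differentiate. The helper `mc_call_prices`, the toy terminal-price model, and all numeric values below are illustrative assumptions, not taken from the issue:

```julia
using Statistics, Zygote

# Hypothetical helper: discounted Monte Carlo call price for each strike in K,
# built with a comprehension so that no existing array is written into.
function mc_call_prices(terminal_values, K, r, T)
    disc = exp(-r * T)
    return [disc * mean(max.(terminal_values .- k, 0.0)) for k in K]
end

# Illustrative stand-ins for the issue's data (assumed values)
r, T = 0.01, 1.0
K = [90.0, 100.0, 110.0]
BS_Price = [12.0, 5.0, 1.5]
Z = [-1.0, 0.0, 1.0, 2.0]  # fixed draws standing in for simulated noise

# Toy model: terminal price is 100 * exp(σ * Z), with the parameter vector p = [σ]
function toy_loss(p)
    terminal_values = 100.0 .* exp.(p[1] .* Z)
    prices = mc_call_prices(terminal_values, K, r, T)
    return sum(abs2, prices .- BS_Price)
end

g = Zygote.gradient(toy_loss, [0.2])[1]  # differentiates through the comprehension
```

Inside `loss_neuralsde`, the same idea would replace the `for` loop with `NSDE_call = [exp(-r*tspan[2]) * mean(max.(terminal_values .- k, 0.0)) for k in K]`, so the function returns a freshly allocated vector instead of writing into an existing one.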

The error message is as follows:

```
Mutating arrays is not supported -- called setindex!(Matrix{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)
```
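The restriction is about in-place writes, not loops as such. A minimal illustration, assuming only Zygote.jl (the functions `f_mutating` and `f_comprehension` are hypothetical): the first fills a preallocated vector with `setindex!` and triggers the same kind of error, while the second computes the same sum with a comprehension and differentiates normally.

```julia
using Zygote

# Fills a preallocated vector in place: every `out[i] = ...` is a setindex! call
function f_mutating(x)
    out = zeros(length(x))
    for i in eachindex(x)
        out[i] = x[i]^2
    end
    return sum(out)
end

# Same computation, but the comprehension allocates a new array instead
f_comprehension(x) = sum([xi^2 for xi in x])

x = [1.0, 2.0, 3.0]

# Zygote.gradient(f_mutating, x)     # errors: "Mutating arrays is not supported"
Zygote.gradient(f_comprehension, x)  # returns ([2.0, 4.0, 6.0],)
```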

ChrisRackauckas (Member) commented:

Quadruplicate of https://discourse.julialang.org/t/training-with-mutating-arrays/91876/2, answered there. Please don't quadruple post.
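If an explicit loop over strikes is preferred, Zygote.jl also provides `Zygote.Buffer`, an array-like object that accepts `setindex!` inside differentiated code; `copy(buf)` turns it back into an ordinary array. A sketch, with the helper name `mc_call_prices_loop` and the numbers in the usage lines being hypothetical:

```julia
using Statistics, Zygote

# Hypothetical loop-based helper: discounted Monte Carlo call price per strike.
# Zygote.Buffer(terminal_values, length(K)) allocates storage like
# similar(terminal_values, length(K)) that Zygote allows writing into.
function mc_call_prices_loop(terminal_values, K, r, T)
    buf = Zygote.Buffer(terminal_values, length(K))
    for i in eachindex(K)
        buf[i] = exp(-r * T) * mean(max.(terminal_values .- K[i], 0.0))
    end
    return copy(buf)  # copy the Buffer to get a regular array back
end

# Usage with made-up terminal prices and strikes (r = 0, so no discounting)
terminal_values = [81.0, 95.0, 120.0]
K = [90.0, 100.0]
prices = mc_call_prices_loop(terminal_values, K, 0.0, 1.0)
grad = Zygote.gradient(t -> sum(mc_call_prices_loop(t, K, 0.0, 1.0)), terminal_values)[1]
```

The comprehension form is usually shorter, while `Buffer` keeps code that is naturally written as an indexed loop close to its original shape.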
