diff --git a/Project.toml b/Project.toml index 474d9ab..b0738c6 100644 --- a/Project.toml +++ b/Project.toml @@ -22,6 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Aqua = "0.8" +ChainRulesCore = "1.23.0" Compat = "4.7.0" ComputationalResources = "0.3.2" Flux = "0.12, 0.13, 0.14" diff --git a/src/curvature/utils.jl b/src/curvature/utils.jl index 16edab4..f71ffea 100644 --- a/src/curvature/utils.jl +++ b/src/curvature/utils.jl @@ -27,9 +27,12 @@ function jacobians_unbatched(curvature::CurvatureInterface, X::AbstractArray) ŷ = vec(ŷ) # Jacobian: # Differentiate f with regards to the model parameters + J = [] ChainRulesCore.ignore_derivatives() do 𝐉 = jacobian(() -> nn(X), Flux.params(nn)) + push!(J, 𝐉) end + 𝐉 = J[1] # Concatenate Jacobians for the selected parameters, to produce a matrix (K, P), where P is the total number of parameter scalars. 𝐉 = reduce(hcat, [𝐉[θ] for θ in curvature.params]) if curvature.subset_of_weights == :subnetwork @@ -50,9 +53,12 @@ function jacobians_batched(curvature::CurvatureInterface, X::AbstractArray) batch_size = size(X)[end] out_size = outdim(nn) # Jacobian: + grads = [] ChainRulesCore.ignore_derivatives() do - grads = jacobian(() -> nn(X), Flux.params(nn)) + g = jacobian(() -> nn(X), Flux.params(nn)) + push!(grads, g) end + grads = grads[1] grads_joint = reduce(hcat, [grads[θ] for θ in curvature.params]) views = [ @view grads_joint[batch_start:(batch_start + out_size - 1), :] for