Skip to content

Commit

Permalink
hmm
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Mar 21, 2024
1 parent 18070a9 commit e880e0c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Aqua = "0.8"
ChainRulesCore = "1.23.0"
Compat = "4.7.0"
ComputationalResources = "0.3.2"
Flux = "0.12, 0.13, 0.14"
Expand Down
8 changes: 7 additions & 1 deletion src/curvature/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ function jacobians_unbatched(curvature::CurvatureInterface, X::AbstractArray)
= vec(ŷ)
# Jacobian:
# Differentiate f with regards to the model parameters
J = []
ChainRulesCore.ignore_derivatives() do
𝐉 = jacobian(() -> nn(X), Flux.params(nn))
push!(J, 𝐉)
end
𝐉 = J[1]
# Concatenate Jacobians for the selected parameters, to produce a matrix (K, P), where P is the total number of parameter scalars.
𝐉 = reduce(hcat, [𝐉[θ] for θ in curvature.params])
if curvature.subset_of_weights == :subnetwork
Expand All @@ -50,9 +53,12 @@ function jacobians_batched(curvature::CurvatureInterface, X::AbstractArray)
batch_size = size(X)[end]
out_size = outdim(nn)
# Jacobian:
grads = []
ChainRulesCore.ignore_derivatives() do
grads = jacobian(() -> nn(X), Flux.params(nn))
g = jacobian(() -> nn(X), Flux.params(nn))
push!(grads, g)
end
grads = grads[1]
grads_joint = reduce(hcat, [grads[θ] for θ in curvature.params])
views = [
@view grads_joint[batch_start:(batch_start + out_size - 1), :] for
Expand Down

0 comments on commit e880e0c

Please sign in to comment.