Skip to content

Commit

Permalink
Merge pull request #82 from JuliaTrustworthyAI/81-fix-interoperabilit…
Browse files Browse the repository at this point in the history
…y-with-counterfactualexplanationsjl

yup, this has done it
  • Loading branch information
pat-alt authored Mar 21, 2024
2 parents cb2af89 + e880e0c commit 3e8115c
Show file tree
Hide file tree
Showing 7 changed files with 531 additions and 28 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "LaplaceRedux"
uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
authors = ["Patrick Altmeyer"]
version = "0.1.5"
version = "0.1.6"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Expand All @@ -21,6 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Aqua = "0.8"
ChainRulesCore = "1.23.0"
Compat = "4.7.0"
ComputationalResources = "0.3.2"
Flux = "0.12, 0.13, 0.14"
Expand Down
2 changes: 1 addition & 1 deletion src/baselaplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ end

# Posterior predictions:
"""
predict(la::BaseLaplace, X::AbstractArray; link_approx=:probit)
predict(la::BaseLaplace, X::AbstractArray; link_approx=:probit, predict_proba::Bool=true)
Computes predictions from Bayesian neural network.
Expand Down
16 changes: 14 additions & 2 deletions src/curvature/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using ChainRulesCore

"""
jacobians(curvature::CurvatureInterface, X::AbstractArray; batched::Bool=false)
Expand Down Expand Up @@ -25,7 +27,12 @@ function jacobians_unbatched(curvature::CurvatureInterface, X::AbstractArray)
= vec(ŷ)
# Jacobian:
# Differentiate f with regards to the model parameters
𝐉 = jacobian(() -> nn(X), Flux.params(nn))
J = []
ChainRulesCore.ignore_derivatives() do
𝐉 = jacobian(() -> nn(X), Flux.params(nn))
push!(J, 𝐉)
end
𝐉 = J[1]
# Concatenate Jacobians for the selected parameters, to produce a matrix (K, P), where P is the total number of parameter scalars.
𝐉 = reduce(hcat, [𝐉[θ] for θ in curvature.params])
if curvature.subset_of_weights == :subnetwork
Expand All @@ -46,7 +53,12 @@ function jacobians_batched(curvature::CurvatureInterface, X::AbstractArray)
batch_size = size(X)[end]
out_size = outdim(nn)
# Jacobian:
grads = jacobian(() -> nn(X), Flux.params(nn))
grads = []
ChainRulesCore.ignore_derivatives() do
g = jacobian(() -> nn(X), Flux.params(nn))
push!(grads, g)
end
grads = grads[1]
grads_joint = reduce(hcat, [grads[θ] for θ in curvature.params])
views = [
@view grads_joint[batch_start:(batch_start + out_size - 1), :] for
Expand Down
Loading

2 comments on commit 3e8115c

@pat-alt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

  • Fixes interoperability with CE.jl

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/103358

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.6 -m "<description of version>" 3e8115cb8914be04621f971c8ff6bcfb36f1262e
git push origin v0.1.6

Please sign in to comment.