
Change irim block - add invertible UNET #57

Open · wants to merge 11 commits into base: master
Conversation

rafaelorozco
Collaborator

-Change the i-RIM block to generate multiple RBs with different dilations and different hidden channel counts. This is the proper Welling implementation.
-This will break some examples. If we are okay with this new block, I will go through and change all the examples so they run properly. It is as easy as changing NetworkIRIM(n_in, n_hidden, ...) -> NetworkIRIM(n_in, [n_hidden], [4]), which defines a single U-Net layer with conv dilation 4, i.e., the current i-RIM implementation.

-Add a new network, the invertible U-Net. This is basically a single unrolled loop iteration of i-RIM; the name comes from the Welling code.
-It directly takes in a precomputed gradient.
-It is meant for inverse problems where your forward operator is too expensive to apply online.

-It doesn't have aggressive memory savings such as the in-place Conv1x1 yet, but it should work well with moderately sized 3D problems. I will be testing this.
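As a rough migration sketch (assuming, per the description above, that the new constructor takes a vector of hidden-channel counts and a matching vector of dilations, one entry per residual block):

```julia
# Hypothetical migration sketch; the actual constructor calls require
# InvertibleNetworks.jl and are shown as comments only.
n_in, n_hidden = 2, 32

# Old: NetworkIRIM(n_in, n_hidden, ...)          one residual block
# New: NetworkIRIM(n_in, [n_hidden], [4], ...)   one layer with conv dilation 4,
#      reproducing the current i-RIM behavior.

# The vectors must pair up: one dilation per hidden-channel count.
n_hiddens = [n_hidden, 64, 128]   # hidden channels, one entry per layer
ds        = [4, 2, 1]             # conv dilation, one entry per layer
@assert length(n_hiddens) == length(ds)
```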

@codecov

codecov bot commented Apr 25, 2022

Codecov Report

Base: 88.11% // Head: 87.94% // Project coverage decreases by 0.17% ⚠️

Coverage data is based on head (9fb0572) compared to base (304c778).
Patch coverage: 91.74% of modified lines in pull request are covered.

❗ Current head 9fb0572 differs from pull request most recent head 34a67dd. Consider uploading reports for the commit 34a67dd to get more accurate results

Additional details and impacted files
@@            Coverage Diff             @@
##           master      #57      +/-   ##
==========================================
- Coverage   88.11%   87.94%   -0.17%     
==========================================
  Files          31       32       +1     
  Lines        2330     2390      +60     
==========================================
+ Hits         2053     2102      +49     
- Misses        277      288      +11     
Impacted Files Coverage Δ
src/InvertibleNetworks.jl 60.00% <ø> (ø)
src/layers/invertible_layer_conv1x1.jl 89.37% <ø> (-0.63%) ⬇️
src/networks/invertible_network_irim.jl 91.78% <50.00%> (ø)
src/layers/invertible_layer_irim.jl 88.46% <91.42%> (-9.69%) ⬇️
src/networks/invertible_network_unet.jl 93.75% <93.75%> (ø)
src/layers/layer_residual_block.jl 98.75% <100.00%> (+0.06%) ⬆️


@alisiahkoohi (Collaborator) left a comment

It would be good to include all the necessary modifications due to this change in one commit.

-C::Conv1x1
-RB::Union{ResidualBlock, FluxBlock}
+C::AbstractArray{Conv1x1, 1}
+RB::AbstractArray{ResidualBlock, 1}
 end

Why are we removing FluxBlock from allowed types?

@@ -75,10 +75,18 @@ end
# Constructors

# Constructor
-function ResidualBlock(n_in, n_hidden; k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2)
+function ResidualBlock(n_in, n_hidden; d=nothing, k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2)

Not clear what d is. Maybe add a docstring.


# Check if downsampling factor d is defined
if !isnothing(d)
k1 = d

Why k1 = s1 = d? Please add a reference for this choice.

@@ -66,7 +66,7 @@ end
@Flux.functor NetworkLoop

# 2D Constructor
-function NetworkLoop(n_in, n_hidden, maxiter, Ψ; k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, type="additive", ndims=2)
+function NetworkLoop(n_in, n_hidden, maxiter, Ψ; n_hiddens=nothing, ds=nothing, k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, type="additive", ndims=2)

The variable names n_hidden and n_hiddens are very similar. Maybe add a docstring explaining these different inputs, and think of a clearer variable name.
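A hedged docstring sketch for the two keywords, under the assumption (taken from the PR description, not the merged code) that n_hiddens and ds each hold one entry per residual block:

```julia
# Hypothetical docstring sketch; the keyword semantics are inferred from the
# PR description above.
doc = """
    NetworkLoop(n_in, n_hidden, maxiter, Ψ; n_hiddens=nothing, ds=nothing, ...)

# Keywords
- `n_hidden`: hidden channels of the single residual block (legacy interface).
- `n_hiddens`: vector of hidden-channel counts, one per residual block.
- `ds`: vector of conv dilations, one per residual block; must have the same
  length as `n_hiddens`.
"""
```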

@@ -0,0 +1,133 @@
# Invertible network layer from Putzky and Welling (2019): https://arxiv.org/abs/1911.10914
# Author: Philipp Witte, [email protected]

Is this information correct?

@@ -64,6 +64,7 @@ include("layers/invertible_layer_hint.jl")
# Invertible network architectures
include("networks/invertible_network_hint_multiscale.jl")
include("networks/invertible_network_irim.jl") # i-RIM: Putzky and Welling (2019)
include("networks/invertible_network_unet.jl") # single loop i-RIM: Putzky and Welling (2019)

C = Conv1x1(n_in)
RB = ResidualBlock(n_in÷2, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims)
if length(n_hiddens) != length(ds)
throw("Number of downsampling factors in ds must be the same defined hidden channels in n_hidden")

n_hiddens and ds must have equal length but the provided lengths are .... and ....
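A minimal sketch of the suggested check with interpolated lengths (the helper name check_lengths is hypothetical; in the PR this logic lives inline in the constructor):

```julia
# Hypothetical sketch of the suggested length check, reporting the actual
# lengths in the error message as the reviewer proposes.
function check_lengths(n_hiddens, ds)
    if length(n_hiddens) != length(ds)
        throw(ArgumentError("n_hiddens and ds must have equal length, " *
                            "but the provided lengths are $(length(n_hiddens)) " *
                            "and $(length(ds))"))
    end
end
```

Throwing an ArgumentError (rather than a bare string) also gives callers a concrete exception type to catch.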

else
ΔX, Δθ_C2 = L.C.inverse((ΔX_, X_); set_grad=set_grad)[1:2]
# Initialize layer parameters
!set_grad && (p1 = Array{Parameter, 1}(undef, 0))

Why do we need to initialize?

end

set_grad ? (return ΔX, X) : (return ΔX, cat(Δθ_C1+Δθ_C2, Δθ_RB; dims=1), X)
set_grad ? (return ΔY, Y) : (ΔY, cat(p1, p2; dims=1), Y)

Would be good to keep the naming convention the same, i.e., this should return ΔX.

end

@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0)
@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0)

Missing jacobian tests.

Add a very simple invertible U-Net. No gradient input as in i-RIM, just a plain input, for a clear comparison with traditional U-Nets.