ResNet18: incompatible architecture and pretrained parameters #18

Closed
andreuvall opened this issue Oct 17, 2023 · 2 comments

Comments


andreuvall commented Oct 17, 2023

ResNets from Metalhead are transformed into Lux by the resnet function. transform yields a Chain with two Chains in it, each containing a number of layers. We can also see this by calling Lux.setup on the model.

using Metalhead
using Lux
using Random

model = transform(ResNet(18).layers);
ps, st = Lux.setup(Random.default_rng(), model);
@show keys(ps)
> keys(ps) = (:layer_1, :layer_2)
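
To make the nesting concrete, one can look one level deeper. This is a small sketch that was not part of the original report, and the exact sub-layer names depend on the Metalhead and Lux versions in use:

@show keys(ps.layer_1)  # presumably the convolutional backbone, itself a Chain of Chains
@show keys(ps.layer_2)  # presumably the pooling + Dense classifier head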

There is the option to pass pretrained = true to the resnet function. However, the pretrained parameters loaded by _initialize_model are a "flattened" named tuple of 14 layers.

using Boltz

_, ps_prime, st_prime = resnet(:resnet18; pretrained = true);
@show keys(ps_prime)
> keys(ps_prime) = (:layer_1, :layer_2, :layer_3, :layer_4, :layer_5, :layer_6, :layer_7, :layer_8, :layer_9, :layer_10, :layer_11, :layer_12, :layer_13, :layer_14)
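
One quick way to check that ps_prime is just a flattened version of the nested ps is to compare total parameter counts. This is a sketch rather than something from the original report, and it assumes Lux.parameterlength accepts parameter named tuples (as defined in LuxCore):

@show Lux.parameterlength(ps)        # nested named tuple
@show Lux.parameterlength(ps_prime)  # flat named tuple; the totals should agree if only the grouping differs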

Therefore, the model architecture and the pretrained parameters are not compatible.

x = randn(Float32, 224, 224, 3, 1);
model(x, ps, st)  # this works
model(x, ps_prime, st_prime)  # but this doesn't
andreuvall (Author) commented:

This other approach with preserve_ps_st = true seems to work. Assuming the code above has been executed already:

model_pp = transform(
    ResNet(18; pretrain = true).layers; 
    preserve_ps_st = true
);
> ┌ Warning: Preserving the state of `Flux.BatchNorm` is currently not supported. Ignoring the state.
> └ @ LuxFluxTransformExt ~/.julia/packages/Lux/1Iulg/ext/LuxFluxTransformExt.jl:269

ps_pp, st_pp = Lux.setup(Random.default_rng(), model_pp);
@show keys(ps_pp)
> keys(ps_pp) = (:layer_1, :layer_2)

@assert ps_pp.layer_1.layer_1.layer_1.weight == ps_prime.layer_1.weight
@assert ps_pp.layer_1.layer_1.layer_2.scale == ps_prime.layer_2.scale
model(x, ps_pp, st_pp)  # this works

Here I only checked that two of the pretrained parameter arrays are equal. I am also unsure what effect ignoring the state has when loading the model and the pretrained parameters.
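
A fuller check could walk both named tuples and compare every parameter array in depth-first order. The sketch below is not from the issue; collect_leaves is a hypothetical helper, and it assumes the flat named tuple preserves the depth-first order of the nested one:

# Hypothetical helper: gather all parameter arrays depth-first.
collect_leaves(x::AbstractArray) = Any[x]
collect_leaves(nt::Union{NamedTuple, Tuple}) =
    reduce(vcat, (collect_leaves(v) for v in values(nt)); init = Any[])

leaves_pp = collect_leaves(ps_pp)
leaves_prime = collect_leaves(ps_prime)
@assert length(leaves_pp) == length(leaves_prime)
@assert all(a == b for (a, b) in zip(leaves_pp, leaves_prime))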

andreuvall changed the title from "ResNet18: incompatible architecture and pretrained weights" to "ResNet18: incompatible architecture and pretrained parameters" on Oct 17, 2023
avik-pal (Member) commented:

I see, that is the problem. The initial weights were imported with Lux 0.4, and since some defaults have changed since then, it led to this breakage.

> Here I only checked that two of the pretrained parameter arrays are equal. I am also unsure what effect ignoring the state has when loading the model and the pretrained parameters.

States not being preserved means that your predictions won't be correct. Specify force_preserve (https://lux.csail.mit.edu/dev/api/Lux/flux_to_lux#Lux.transform) and that should do it for now.
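
For reference, here is a sketch of what that might look like, following the transform documentation linked above (not verified against this issue; with force_preserve = true the BatchNorm state should be carried over instead of being dropped with the warning shown earlier):

model_fp = transform(
    ResNet(18; pretrain = true).layers;
    preserve_ps_st = true, force_preserve = true
);
ps_fp, st_fp = Lux.setup(Random.default_rng(), model_fp);
model_fp(x, ps_fp, st_fp)  # predictions should now use the pretrained running statistics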
