Question on the use of the Update! method and is_same_except() #212

Open
pasq-cat opened this issue Sep 24, 2024 · 11 comments

@pasq-cat

Hi, I was trying to implement the update method for LaplaceRedux but I am having a problem.

This is the model:

MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
    model::Flux.Chain = nothing
    flux_loss = Flux.Losses.mse
    optimiser = Adam()
    epochs::Integer = 1000::(_ > 0)
    batch_size::Integer = 32::(_ > 0)
    subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
    subnetwork_indices = nothing
    hessian_structure::Union{HessianStructure,Symbol,String} =
        :full::(_ in (:full, :diagonal))
    backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))
    σ::Float64 = 1.0
    μ₀::Float64 = 0.0
    P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
    fit_prior_nsteps::Int = 100::(_ > 0)
end

This is the fit function that I have written:

function MMI.fit(m::LaplaceRegressor, verbosity, X, y)

    #X = MLJBase.matrix(X) |> permutedims
    #y = reshape(y, 1, :)

    if Tables.istable(X)
        X = Tables.matrix(X) |> permutedims
    end

    # Reshape y if necessary
    y = reshape(y, 1, :)

    data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
    opt_state = Flux.setup(m.optimiser, m.model)
    loss_history = []
    push!(loss_history, m.flux_loss(m.model(X), y))

    for epoch in 1:(m.epochs)

        loss_per_epoch= 0.0


        for (X_batch, y_batch) in data_loader
            # Forward pass: compute predictions
            y_pred = m.model(X_batch)

            # Compute loss
            loss = m.flux_loss(y_pred, y_batch)

            # Compute gradients 
            grads = gradient(m.model) do model
                # Recompute predictions inside gradient context
                y_pred = model(X_batch)
                m.flux_loss(y_pred, y_batch)
            end
            
            # Update parameters using the optimizer and computed gradients
            Flux.Optimise.update!(opt_state, m.model, grads[1])

            # Accumulate the loss for this batch
            loss_per_epoch += sum(loss)  # Summing the batch loss
            
        end

        push!(loss_history, loss_per_epoch)

        # Print loss every 100 epochs if verbosity is 1 or more
        if verbosity >= 1 && epoch % 100 == 0
            println("Epoch $epoch: Loss: $loss_per_epoch ")
        end
    end

    la = LaplaceRedux.Laplace(
        m.model;
        likelihood=:regression,
        subset_of_weights=m.subset_of_weights,
        subnetwork_indices=m.subnetwork_indices,
        hessian_structure=m.hessian_structure,
        backend=m.backend,
        σ=m.σ,
        μ₀=m.μ₀,
        P₀=m.P₀,
    )

    # fit the Laplace model:
    LaplaceRedux.fit!(la, data_loader)
    optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)

    fitresult = la
    report = (loss_history = loss_history,)
    cache = (deepcopy(m),opt_state, loss_history)
    return fitresult, cache, report
end

And now follows the incomplete update function that I was trying. I have removed the training-loop part since it's not important.

function MMI.update(m::LaplaceRegressor, verbosity, old_fitresult, old_cache, X, y)

    println("running MMI.update")

    old_model = old_cache[1]

    if Tables.istable(X)
        X = Tables.matrix(X) |> permutedims
    end

    # Reshape y if necessary
    y = reshape(y, 1, :)

    println(MMI.is_same_except(m, old_model, :epochs))

    cache = ()
    report = ()
    return old_fitresult, cache, report
end



The issue is that if I try to rerun the model after changing only the number of epochs, is_same_except still gives me

false

even though :epochs is listed as an exception:

using MLJ
flux_model = Chain(
    Dense(4, 10, relu),
    Dense(10, 10, relu),
    Dense(10, 1)
)
model = LaplaceRegressor(model=flux_model)

X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
mach = machine(model, X, y) 
MLJBase.fit!(mach)



model.epochs=2000

MLJBase.fit!(mach)

So what is the correct way to implement is_same_except? Thank you.

@ablaom
Member

ablaom commented Sep 25, 2024

Not sure what the problem might be. Can you provide an MWE demonstrating that is_same_except is not working as you expect? I.e., some variation of this (which is working for me):

using MLJModelInterface
import MLJModelInterface as MMI

mutable struct Classifier <: Probabilistic
    x::Int
    y::Int
end

model = Classifier(1, 2)
model2 = deepcopy(model)
model2.y = 7

@assert MMI.is_same_except(model, model2, :y)

Or, if you suspect some other problem, a more self-contained MWE would be helpful.

@pasq-cat
Author

For example, using this:

flux_model = Chain(
    Dense(4, 10, relu),
    Dense(10, 10, relu),
    Dense(10, 1)
)
model = LaplaceRegressor(model=flux_model)
copy_model= deepcopy(model)
copy_model.epochs= 2000
MLJBase.is_same_except(model , copy_model, :epochs)

gives false, but I have only changed epochs.

For a simpler example that does not need LaplaceRedux, consider this:

using MLJModelInterface
import MLJModelInterface as MMI

mutable struct LaplaceRegr <: MLJBase.Probabilistic
    model::Flux.Chain
    epochs::Integer 
    batch_size::Integer
    
end

model = LaplaceRegr(flux_model,1000,2)

model2= deepcopy(model)

model2.epochs = 2000

MMI.is_same_except(model, model2, :epochs)

It's due to the fact that one of the fields has a Flux chain in it. If I remove it, I get true.

@ablaom
Member

ablaom commented Sep 25, 2024

Thanks, this helps me see the problem:

julia> c = Flux.Chain(Dense(2,3))
julia> c == deepcopy(c)
false

Unfortunately, MLJ was not designed with this kind of behaviour in mind for hyperparameter values. This has occurred once before, and a hack was introduced: the trait MMI.deep_properties. However, this does not fix your issue, because we need equality all the way down, and the hack only goes down one level (read the docstring if you are interested). We could try to fix the hack, but it's such a corner case, and technically breaking, that I'm not a big fan. (In fact, for orthogonal reasons, the hack is no longer used anyway.)

Another possible resolution is for you to explicitly add an overload is_same_except(model1::LaplaceRegr, model2::LaplaceRegr, ...) = ... to get the actual behaviour you want (this will automatically carry over to == because the latter is defined in terms of the former). Feel free to lift most of the code you need from here. Or, you could just manually check in update whether the only hyperparameter that has "truly changed" is :epochs. We used to do that before is_same_except was added.

In any case, make sure neither fit nor update actually mutates any hyperparameter value, as this is definitely disallowed (except for RNGs). So if you use :model to build the learned parameters (aka the fitresult), be sure to create a deep copy first. (This is an instance of a wider limitation of Flux, due to its conflation of hyperparameters and learned parameters, which is fixed in Lux.jl.)
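
For illustration only, a minimal sketch of the second option (the manual check in update) might look like the following. The overall shape and, in particular, the special handling of the :model field via Flux.params are assumptions of this sketch, not a prescription:

# Sketch only: fall back to `fit` unless the only "true" change is an
# increase in `epochs`. Comparing the Flux chain via `Flux.params`
# (rather than `==`) is an assumption for this sketch.
function MMI.update(m::LaplaceRegressor, verbosity, old_fitresult, old_cache, X, y)
    old_model = old_cache[1]

    function same_field(name)
        name === :epochs && return true
        if name === :model
            ps, ps_old = Flux.params(m.model), Flux.params(old_model.model)
            return length(ps) == length(ps_old) &&
                all(isequal(p1, p2) for (p1, p2) in zip(ps, ps_old))
        end
        return getproperty(m, name) == getproperty(old_model, name)
    end

    if all(same_field, propertynames(m)) && m.epochs > old_model.epochs
        # warm restart: train only the additional m.epochs - old_model.epochs
        # epochs here, reusing old_fitresult and the stored optimiser state
        # (training loop omitted in this sketch)
        return old_fitresult, old_cache, (loss_history = old_cache[3],)
    end

    # anything else changed: retrain from scratch
    return MMI.fit(m, verbosity, X, y)
end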

@pasq-cat
Author

Couldn't something like this replace the default is_same_except function?

# Define the function is_same_except
function is_same_except(m1::M1, m2::M2, exceptions::Symbol...) where {M1<:MLJType, M2<:MLJType}
    typeof(m1) === typeof(m2) || return false
    names = propertynames(m1)
    propertynames(m2) === names || return false

    for name in names
        if !(name in exceptions)
            if !_isdefined(m1, name)
               !_isdefined(m2, name) || return false
            elseif _isdefined(m2, name)
                if name in deep_properties(M1)
                    _equal_to_depth_one(
                        getproperty(m1,name),
                        getproperty(m2, name)
                    ) || return false
                else
                    (
                        is_same_except(
                            getproperty(m1, name),
                            getproperty(m2, name)
                        ) ||
                        getproperty(m1, name) isa AbstractRNG ||
                        getproperty(m2, name) isa AbstractRNG ||
                        (getproperty(m1, name) isa Flux.Chain && getproperty(m2, name) isa Flux.Chain && _equal_flux_chain(getproperty(m1, name), getproperty(m2, name)))
                    ) || return false
                end
            else
                return false
            end
        end
    end
    return true
end

with a helper function:

function _equal_flux_chain(chain1::Flux.Chain, chain2::Flux.Chain)
    if length(chain1.layers) != length(chain2.layers)
        return false
    end
    for (layer1, layer2) in zip(chain1.layers, chain2.layers)
        if typeof(layer1) != typeof(layer2)
            return false
        end
        params1 = Flux.params(layer1)
        params2 = Flux.params(layer2)
        if length(params1) != length(params2)
            return false
        end
        for (p1, p2) in zip(params1, params2)
            if !isequal(p1, p2)
                return false
            end
        end
    end
    return true
end

It should work for every MLJ model that wraps a Flux model.
If not, I will use it only for LaplaceRegressor.

@ablaom
Member

ablaom commented Sep 26, 2024

Great progress.

I think your test for equality of Chains is not correct, for it will not behave as expected for nested chains, like Chain(Chain(...), ...). Rather, just apply Flux.params directly to the whole of chain1 and chain2, i.e., there's no need to deconstruct. You may want to add a test, as this kind of logic is a bug-magnet.

I suggest you just overload locally and that we not add complexity to MLJModelInterface for this one corner case. There is probably a more generic way to handle this, maybe by fixing deep_properties, but I don't have time to look at it just now.
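
For concreteness, a sketch of that simplification (reusing the helper name above purely for illustration), plus a small test on a nested chain, assuming Flux and Test are loaded:

# Flux.params already collects the parameters of nested layers, so there is
# no need to walk the chain layer by layer. Note this compares parameter
# values only, not the layer structure itself.
function _equal_flux_chain(chain1::Flux.Chain, chain2::Flux.Chain)
    ps1, ps2 = Flux.params(chain1), Flux.params(chain2)
    length(ps1) == length(ps2) || return false
    return all(isequal(p1, p2) for (p1, p2) in zip(ps1, ps2))
end

using Test
nested = Chain(Chain(Dense(2, 3), Dense(3, 3)), Dense(3, 1))
@test _equal_flux_chain(nested, deepcopy(nested))
# a freshly built chain of the same architecture has different random weights
@test !_equal_flux_chain(nested, Chain(Chain(Dense(2, 3), Dense(3, 3)), Dense(3, 1)))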

@pasq-cat
Author

You may want to add a test, as this kind of logic is a bug-magnet.

Indeed, I just found out that the models don't pass the test if optimiser = Adam() is included in the struct. How should I handle this case? Should I always add it to the exceptions?

@ablaom
Member

ablaom commented Sep 29, 2024

Can you please provide some more detail? I don't see any problem at my end:

julia> using Optimisers, Flux

julia> import MLJModelInterface as MMI

julia> model = NeuralNetworkClassifier();

julia> model2 = deepcopy(model);

julia> MMI.is_same_except(model, model2)
true

julia> model2.optimiser = Adam(42)
Adam(42.0, (0.9, 0.999), 1.0e-8)

julia> MMI.is_same_except(model, model2)
false

julia> model.optimiser = Adam(42)
Adam(42.0, (0.9, 0.999), 1.0e-8)

julia> MMI.is_same_except(model, model2)
true

Are you perhaps using Flux.jl optimisers instead of Optimisers.jl optimisers?

@pasq-cat
Author

pasq-cat commented Sep 30, 2024

Are you perhaps using Flux.jl optimisers instead of Optimisers.jl optimisers?

Yes, I think this is the issue, because

using MLJModelInterface
import MLJModelInterface as MMI

mutable struct LaplaceRegr <: MLJBase.Probabilistic
    model::Flux.Chain
    epochs::Integer 
    batch_size::Integer
    optimiser
    
end

model = LaplaceRegr(flux_model,1000,2,Flux.Adam())



model2= deepcopy(model)

model2.epochs = 2000

MMI.is_same_except(model, model2)

gives me false.
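
For reference, this seems to be the underlying difference (assuming a Flux version that still ships the legacy Flux.Optimise module): the legacy optimiser is a mutable struct, so == falls back to identity comparison, while the Optimisers.jl rule is an immutable struct that compares by value:

julia> import Flux, Optimisers

julia> Flux.Optimise.Adam() == Flux.Optimise.Adam()   # legacy mutable optimiser
false

julia> Optimisers.Adam() == Optimisers.Adam()          # Optimisers.jl rule
true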

Looks like the tutorial I read to write the training loop is outdated, and Flux now prefers optimisers from the Optimisers.jl package, but the documentation available online is a confusing mix of old and new rules...

@ablaom
Member

ablaom commented Sep 30, 2024

Looks like the tutorial I read to write the training loop is outdated, and Flux now prefers optimisers from the Optimisers.jl package, but the documentation available online is a confusing mix of old and new rules...

Well, MLJFlux now definitely requires only Optimisers.jl optimisers. If any of the MLJ/MLJFlux docs are out-of-date in this respect, please point them out.

@pasq-cat
Author

pasq-cat commented Oct 1, 2024

Looks like the tutorial I read to write the training loop is outdated, and Flux now prefers optimisers from the Optimisers.jl package, but the documentation available online is a confusing mix of old and new rules...

Well, MLJFlux now definitely requires only Optimisers.jl optimisers. If any of the MLJ/MLJFlux docs are out-of-date in this respect, please point them out.

Ah, but it was not the official documentation; it was, I think, a Medium page or something like that. Anyway, I think I have fixed the update loop. If you don't mind, I would like to keep this issue open for a bit longer, just in case I encounter another problem. Otherwise I will close it myself. OK? Thank you.

@ablaom
Member

ablaom commented Oct 1, 2024

Happy to support your work on an MLJ interface, and thanks for your persistence.
