Commit

fix errors in training
Correction in code for Text Classifier
AdarshKumar712 committed May 1, 2021
1 parent 2977425 commit f10678d
Showing 4 changed files with 40 additions and 27 deletions.
1 change: 1 addition & 0 deletions Project.toml
@@ -6,6 +6,7 @@ version = "0.1.1"

[deps]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CorpusLoaders = "214a0ac2-f95b-54f7-a80b-442ed9c2c9e8"
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
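The only change to Project.toml is the new CUDA dependency, which Flux relies on for GPU execution. A minimal sketch of the usual pattern for moving a model and a batch to the GPU, assuming Flux's gpu helper and an illustrative model (none of these names come from the commit):

using Flux, CUDA

model = Chain(Dense(10, 4), softmax) |> gpu    # gpu is a no-op when no CUDA device is available
x = rand(Float32, 10, 8) |> gpu
y = model(x)                                   # forward pass, on the GPU when one is present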
19 changes: 8 additions & 11 deletions src/ULMFiT/fine_tune_lm.jl
@@ -24,17 +24,17 @@ opts : `Vector` of optimizers used to update weights for corresponding layers
NOTE: length(opts) == length(layers)
"""
function discriminative_step!(layers, ηL::Float64, l, opts::Vector)
function discriminative_step!(layers, lm::LanguageModel, gen, ηL::Float64, opts::Vector)
@assert length(opts) == length(layers)
# Gradient calculation
grads = Zygote.gradient(() -> l, get_trainable_params(layers))
grads = Zygote.gradient(() -> loss(lm, gen), get_trainable_params(layers))

# discriminative step
ηl = ηL/(2.6^(length(layers)-1))
for (layer, opt) in zip(layers, opts)
opt.eta = ηl
for ps in get_trainable_params([layer])
Flux.Optimise.update!(opt, ps, grads)
Flux.Optimise.update!(opt, ps, grads[ps])
end
ηl *= 2.6
end
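
The loop above implements discriminative fine-tuning: the top layer is updated with the full rate ηL and each layer below it with a rate 2.6 times smaller. A minimal sketch of the per-layer rates this produces, using illustrative values rather than the package's defaults:

n_layers = 4                      # e.g. the four layer groups in lm.layers[[1, 3, 5, 7]]
ηL = 0.01                         # learning rate of the top (last) layer
ηs = [ηL / 2.6^(n_layers - l) for l in 1:n_layers]
# ηs[end] == ηL; ηs[1] is the smallest rate and goes to the first layer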
@@ -55,26 +55,23 @@ function fine_tune_lm!(lm=LanguageModel(), data_loader=imdb_fine_tune_data,
epochs::Integer=1, checkpoint_itvl::Integer=5000)

opts = [ADAM(0.001, (0.7, 0.99)) for i=1:4]
gen = data_loader()
num_of_iters = take!(gen)
cut = num_of_iters * epochs * stlr_cut_frac

# Fine-Tuning loops
for epoch=1:epochs
println("\nEpoch: $epoch")
gen = data_loader()
num_of_iters = take!(gen)
cut = num_of_iters * epochs * stlr_cut_frac
T = num_of_iters-Int(floor((num_of_iters*2)/100))
set_trigger!.(T, lm.layers)
for i=1:num_of_iters

# FORWARD
l = loss(lm, gen)

# Slanted triangular learning rate step
t = i + (epoch-1)*num_of_iters
p_frac = (i < cut) ? i/cut : (1 - ((i-cut)/(cut*(1/stlr_cut_frac-1))))
ηL = stlr_η_max*((1+p_frac*(stlr_ratio-1))/stlr_ratio)

# Backprop with discriminative fine-tuning step
discriminative_step!(lm.layers[[1, 3, 5, 7]], ηL, l, opts)
discriminative_step!(lm.layers[[1, 3, 5, 7]], lm, gen, ηL, opts)

# Resets dropout masks for all the layers with DropOut or DropConnect
reset_masks!.(lm.layers)
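The p_frac and ηL lines above are ULMFiT's slanted triangular learning rate (STLR) schedule: the rate climbs linearly for the first stlr_cut_frac fraction of all iterations and then decays linearly back down. A self-contained sketch of the same formula, with illustrative defaults:

function stlr(t, T; cut_frac = 0.1, ratio = 32, η_max = 0.01)
    cut = cut_frac * T
    p = t < cut ? t / cut : 1 - (t - cut) / (cut * (1 / cut_frac - 1))
    return η_max * (1 + p * (ratio - 1)) / ratio
end

ηs = [stlr(t, 1000) for t in 1:1000]   # peaks at η_max near t == 100, then decays to η_max/ratio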
11 changes: 4 additions & 7 deletions src/ULMFiT/pretrain_lm.jl
@@ -107,10 +107,10 @@ function loss(lm, gen)
end

# Backpropagation step while training
function backward!(layers, l, opt)
function backward!(layers, lm, gen, opt)
# Calculate gradients and update the weights
p = get_trainable_params(layers)
grads = Zygote.gradient(() -> l, p)
grads = Zygote.gradient(() -> loss(lm, gen), p)
Flux.Optimise.update!(opt, p, grads)
return
end
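
This is the central fix of the commit: Zygote.gradient differentiates the code that runs inside the closure, so handing it a loss value computed beforehand yields no gradient at all; the closure must call loss(lm, gen) itself. A toy illustration of the difference, written with explicit gradients rather than the implicit params style used here (the names are made up, not part of the package):

using Zygote

W = rand(2, 2)
x = rand(2)

l = sum(W * x)                                   # computed ahead of time
g_bad  = Zygote.gradient(W -> l, W)[1]           # nothing: l does not depend on the closure argument
g_good = Zygote.gradient(W -> sum(W * x), W)[1]  # a 2×2 matrix of real derivatives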
@@ -138,11 +138,8 @@ function pretrain_lm!(lm::LanguageModel=LanguageModel(), data_loader::Channel=lo
set_trigger!.(T, lm.layers) # Setting triggers for AWD_LSTM layers
for i=1:num_of_batches

# FORWARD PASS
l = loss(lm, gen)

# REVERSE PASS
backward!(lm.layers, l, opt)
backward!(lm.layers, lm, gen, opt)

# ASGD Step, works after Triggering
asgd_step!.(i, lm.layers)
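
asgd_step! is the averaged-SGD part of the AWD-LSTM recipe: once the iteration counter passes the trigger set by set_trigger!, the weights used by the layer become a running average of the SGD iterates instead of the latest ones. A rough sketch of the idea with made-up names (the package's internals may differ):

mutable struct ToyASGD
    trigger::Int
    accum::Vector{Float64}
    n::Int
end

function toy_asgd_step!(st::ToyASGD, i::Int, w::Vector{Float64})
    i < st.trigger && return w       # before the trigger: keep using the plain SGD weights
    st.accum .+= w
    st.n += 1
    return st.accum ./ st.n          # after the trigger: report the running average
end

st = ToyASGD(100, zeros(3), 0)
w_avg = toy_asgd_step!(st, 150, randn(3))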
@@ -158,7 +155,7 @@ end

# To save model
function save_model!(m::LanguageModel, filepath::String)
weights = cpu.(Tracker.data.(params(m)))
weights = cpu.(params(m))
BSON.@save filepath weights
end
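
save_model! now stores plain Flux parameters; the Tracker.data call it replaces was a leftover from the old Tracker-based Flux. A hypothetical counterpart for loading the weights back, not part of this commit, could look like:

using BSON, Flux

function load_model!(m, filepath::String)
    BSON.@load filepath weights
    Flux.loadparams!(m, weights)
    return m
end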

36 changes: 27 additions & 9 deletions src/ULMFiT/train_text_classifier.jl
@@ -95,11 +95,13 @@ function forward(tc::TextClassifier, gen::Channel, tracked_steps::Integer=32)
X = take!(gen)
l = length(X)
# Truncated Backprop through time
for i=1:ceil(l/now_per_pass)-1 # Tracking is swiched off inside this loop
(i == 1 && l%now_per_pass != 0) ? (last_idx = l%now_per_pass) : (last_idx = now_per_pass)
H = broadcast(x -> indices(x, classifier.vocab, "_unk_"), X[1:last_idx])
H = classifier.rnn_layers.(H)
X = X[last_idx+1:end]
Zygote.ignore() do
for i=1:ceil(l/tracked_steps)-1 # Tracking is switched off inside this loop
(i == 1 && l%tracked_steps != 0) ? (last_idx = l%tracked_steps) : (last_idx = tracked_steps)
H = broadcast(x -> indices(x, classifier.vocab, "_unk_"), X[1:last_idx])
H = classifier.rnn_layers.(H)
X = X[last_idx+1:end]
end
end
# set the latest hidden states on the original model
for (t_layer, unt_layer) in zip(tc.rnn_layers[2:end], classifier.rnn_layers[2:end])
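
Wrapping the early part of the sequence in Zygote.ignore() do ... end is what makes the backpropagation truncated: those time steps still advance the RNN's hidden state on the forward pass, but they contribute nothing to the gradient, so only the last tracked_steps steps are differentiated. A toy version of the pattern (illustrative names, no RNN involved):

using Zygote

function toy_loss(w, xs; tracked_steps = 2)
    h = Zygote.ignore() do
        sum(w .* xs[1:end-tracked_steps])             # early steps: forward only, no gradient
    end
    return h + sum(w .* xs[end-tracked_steps+1:end])  # only the last steps are tracked
end

g = Zygote.gradient(w -> toy_loss(w, collect(1.0:10.0)), 0.5)[1]
# g == 9.0 + 10.0 == 19.0: the ignored prefix does not contribute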
@@ -130,7 +132,7 @@ Arguments:
classifier : Instance of TextClassifier
gen : 'Channel' [data loader], to give a mini-batch
tracked_words : specifies the number of time-steps for which tracking is on
tracked_steps : specifies the number of time-steps for which tracking is on
"""
function loss(classifier::TextClassifier, gen::Channel, tracked_steps::Integer=32)
H = forward(classifier, gen, tracked_steps)
@@ -140,6 +142,23 @@ function loss(classifier::TextClassifier, gen::Channel, tracked_steps::Integer=32)
return l
end
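
The collapsed lines between the forward pass and return l compute the classification loss from H. The exact code is not shown in this hunk; a typical Flux pattern for that step (purely an assumption, not the package's implementation) is cross entropy between the softmax output and one-hot labels:

using Flux

ŷ = softmax(randn(2, 8))                        # stand-in for the classifier's output scores
y = Flux.onehotbatch(rand(1:2, 8), 1:2)         # stand-in labels for a batch of 8
l = Flux.Losses.crossentropy(ŷ, y)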

function discriminative_step!(layers, classifier::TextClassifier, gen::Channel, tracked_steps::Integer, ηL::Float64, opts::Vector)
@assert length(opts) == length(layers)
# Gradient calculation
grads = Zygote.gradient(() -> loss(classifier, gen, tracked_steps), get_trainable_params(layers))

# discriminative step
ηl = ηL/(2.6^(length(layers)-1))
for (layer, opt) in zip(layers, opts)
opt.eta = ηl
for ps in get_trainable_params([layer])
Flux.Optimise.update!(opt, ps, grads[ps])
end
ηl *= 2.6
end
return
end

"""
train_classifier!(classifier::TextClassifier=TextClassifier(), classes::Integer=1,
data_loader::Channel=imdb_classifier_data, hidden_layer_size::Integer=50;kw...)
@@ -151,7 +170,7 @@ function train_classifier!(classifier::TextClassifier=TextClassifier(), classes:
data_loader::Channel=imdb_classifier_data, hidden_layer_size::Integer=50;
stlr_cut_frac::Float64=0.1, stlr_ratio::Number=32, stlr_η_max::Float64=0.01,
val_loader::Channel=nothing, cross_val_batches::Union{Colon, Integer}=:,
epochs::Integer=1, checkpoint_itvl=5000)
epochs::Integer=1, checkpoint_itvl=5000, tracked_steps::Integer=32)

trainable = []
append!(trainable, [classifier.rnn_layers[[1, 3, 5, 7]]...])
@@ -166,7 +185,6 @@ function train_classifier!(classifier::TextClassifier=TextClassifier(), classes:
num_of_iters = take!(gen)
cut = num_of_iters * epochs * stlr_cut_frac
for iter=1:num_of_iters
l = loss(classifier, gen, now_per_pass = now_per_pass)

# Slanted triangular learning rates
t = iter + (epoch-1)*num_of_iters
@@ -175,7 +193,7 @@ function train_classifier!(classifier::TextClassifier=TextClassifier(), classes:

# Gradual-unfreezing Step with discriminative fine-tuning
unfreezed_layers, cur_opts = (epoch < length(trainable)) ? (trainable[end-epoch+1:end], opts[end-epoch+1:end]) : (trainable, opts)
discriminative_step!(unfreezed_layers, ηL, l, cur_opts)
discriminative_step!(unfreezed_layers, classifier, gen, tracked_steps, ηL, cur_opts)

reset_masks!.(classifier.rnn_layers) # reset all dropout masks
end
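
The unfreezed_layers line above is ULMFiT's gradual unfreezing: epoch 1 updates only the top layer group, epoch 2 the top two, and so on until every group is trainable. A small stand-alone illustration of the same selection logic (placeholder layer names, not the classifier's real layers):

trainable = [:embedding, :lstm1, :lstm2, :lstm3, :pooled_dense]

for epoch in 1:length(trainable)
    unfrozen = epoch < length(trainable) ? trainable[end-epoch+1:end] : trainable
    println("epoch $epoch updates: ", unfrozen)
end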
