Commit
Update ULMFiT model
AdarshKumar712 committed Apr 30, 2021
1 parent 42a0e06 commit 2977425
Showing 10 changed files with 91 additions and 81 deletions.
3 changes: 3 additions & 0 deletions Project.toml
@@ -6,12 +6,15 @@ version = "0.1.1"

[deps]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
CorpusLoaders = "214a0ac2-f95b-54f7-a80b-442ed9c2c9e8"
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
Languages = "8ef0a80b-9436-5d2c-a485-80b904378c43"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TextAnalysis = "a2db99b7-8b79-58f8-94bf-bbc811eef33d"
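The rendered diff does not mark which three `[deps]` entries are the additions, but `CorpusLoaders` is clearly among the new requirements: the revived ULMFiT module below does `using CorpusLoaders`. A usage sketch, taking the `IMDB("train_unsup")` call and the `filepaths` field from `data_loaders.jl` further down:

```julia
# Sketch based on calls visible in data_loaders.jl below; the split name
# "train_unsup" and the filepaths field are taken from that diff.
using CorpusLoaders

imdb = IMDB("train_unsup")   # unsupervised IMDB split used for LM fine-tuning
paths = imdb.filepaths       # one file per review, iterated by the data loader
```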
33 changes: 17 additions & 16 deletions src/TextModels.jl
@@ -39,27 +39,28 @@ module TextModels


# ULMFiT
#module ULMFiT
# using ..TextAnalysis
# using DataDeps
# using Flux
# using Tracker
# using BSON
# include("ULMFiT/utils.jl")
# include("ULMFiT/datadeps.jl")
# include("ULMFiT/data_loaders.jl")
# include("ULMFiT/custom_layers.jl")
# include("ULMFiT/pretrain_lm.jl")
# include("ULMFiT/fine_tune_lm.jl")
# include("ULMFiT/train_text_classifier.jl")
#end
#export ULMFiT
module ULMFiT
using TextAnalysis
using DataDeps
using Flux
using Zygote
using BSON
using CorpusLoaders
include("ULMFiT/utils.jl")
include("ULMFiT/datadeps.jl")
include("ULMFiT/data_loaders.jl")
include("ULMFiT/custom_layers.jl")
include("ULMFiT/pretrain_lm.jl")
include("ULMFiT/fine_tune_lm.jl")
include("ULMFiT/train_text_classifier.jl")
end
export ULMFiT

function __init__()
pos_tagger_datadep_register()
ner_datadep_register()
pos_datadep_register()
#ULMFiT.ulmfit_datadep_register()
ULMFiT.ulmfit_datadep_register()

global sentiment_model = artifact"sentiment_model"
end
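With the `module ULMFiT ... end` block uncommented, the submodule now compiles with the package, uses Zygote in place of the retired Tracker AD, and registers its data dependency in `__init__`. A minimal sketch of what this makes reachable, assuming only the `LanguageModel` constructor shown in `pretrain_lm.jl` below:

```julia
# Minimal sketch; LanguageModel() is the default constructor from pretrain_lm.jl.
using TextModels

lm = TextModels.ULMFiT.LanguageModel()   # fresh, untrained ULMFiT language model
```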
53 changes: 30 additions & 23 deletions src/ULMFiT/custom_layers.jl
@@ -8,7 +8,7 @@ This file contains the custom layers defined for this model:
PooledDense
"""

import Flux: gate, _testmode!, _dropout_kernel
import Flux: gate, testmode!, _dropout_kernel

reset_masks!(entity) = nothing
reset_probability!(entity) = nothing
@@ -44,12 +44,12 @@ Moreover this also follows the Vartional DropOut citeria, that is,
the drop mask is remains same for a whole training pass.
This is done by saving the masks in 'maskWi' and 'maskWh' fields
"""
mutable struct WeightDroppedLSTMCell{A, V, M}
mutable struct WeightDroppedLSTMCell{A, V, S, M}
Wi::A
Wh::A
b::V
h::V
c::V
h::S
c::S
p::Float64
maskWi::M
maskWh::M
@@ -60,17 +60,17 @@ function WeightDroppedLSTMCell(in::Integer, out::Integer, p::Float64=0.0;
init = Flux.glorot_uniform)
@assert 0 ≤ p ≤ 1
cell = WeightDroppedLSTMCell(
param(init(out*4, in)),
param(init(out*4, out)),
param(init(out*4)),
param(zeros(Float32, out)),
param(zeros(Float32, out)),
init(out*4, in),
init(out*4, out),
init(out*4),
reshape(zeros(Float32, out),out, 1),
reshape(zeros(Float32, out), out, 1),
p,
drop_mask((out*4, in), p),
drop_mask((out*4, out), p),
true
)
cell.b.data[gate(out, 2)] .= 1
cell.b[gate(out, 2)] .= 1
return cell
end

@@ -88,9 +88,12 @@ function (m::WeightDroppedLSTMCell)((h, c), x)
return (h′, c), h′
end

Flux.@treelike WeightDroppedLSTMCell
Flux.@functor WeightDroppedLSTMCell

_testmode!(m::WeightDroppedLSTMCell, test) = (m.active = !test)
Flux.trainable(m::WeightDroppedLSTMCell) = (m.Wi, m.Wh, m.b, m.h, m.c)

testmode!(m::WeightDroppedLSTMCell, mode=true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)

"""
WeightDroppedLSTM(in::Integer, out::Integer, p::Float64=0.0)
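The replaced `_testmode!` hook follows the three-state `testmode!` convention of newer Flux versions: `true` forces evaluation mode, `false` forces training mode, and `nothing`/`:auto` defers to Flux's global mode. A standalone sketch of the same pattern on a made-up layer:

```julia
# Illustration only; MyDrop is not part of this commit.
using Flux

mutable struct MyDrop
    p::Float64
    active::Union{Bool, Nothing}   # nothing means "auto": follow the global mode
end

Flux.testmode!(m::MyDrop, mode = true) =
    (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
```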
@@ -106,7 +109,7 @@ julia> wd = WeightDroppedLSTM(4, 5, 0.3);
function WeightDroppedLSTM(a...; kw...)
cell = WeightDroppedLSTMCell(a...;kw...)
hidden = (cell.h, cell.c)
return Flux.Recur(cell, hidden, hidden)
return Flux.Recur(cell, hidden)
end

"""
@@ -155,7 +158,9 @@ end

AWD_LSTM(in::Integer, out::Integer, p::Float64=0.0; kw...) = AWD_LSTM(WeightDroppedLSTM(in, out, p; kw...), -1, [])

Flux.@treelike AWD_LSTM
Flux.@functor AWD_LSTM

Flux.trainable(m::AWD_LSTM) = (m.layer,)

(m::AWD_LSTM)(in) = m.layer(in)

@@ -184,12 +189,12 @@ function asgd_step!(iter::Integer, layer::AWD_LSTM)
p = get_trainable_params([layer])
avg_fact = 1/max(iter - layer.T + 1, 1)
if avg_fact != 1
layer.accum = layer.accum .+ Tracker.data.(p)
layer.accum = layer.accum .+ p
for (ps, accum) in zip(p, layer.accum)
Tracker.data(ps) .= avg_fact*accum
ps .= avg_fact*accum
end
else
layer.accum = deepcopy(Tracker.data.(p)) # Accumulator for ASGD
layer.accum = deepcopy(p) # Accumulator for ASGD
end
end
return
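After the trigger iteration `T`, `asgd_step!` replaces each weight with the running mean of its values since `T`; the `Tracker.data` unwrapping disappears because Zygote differentiates plain arrays. A toy illustration of the averaging rule, assuming a single parameter and `T = 1`:

```julia
# Toy numbers, not from the diff: parameter values observed at iterations 1..3.
w_history = ([1.0, 1.0], [2.0, 2.0], [3.0, 3.0])
accum = zero(w_history[1])
for (iter, w) in enumerate(w_history)
    accum .+= w
    avg_fact = 1 / max(iter - 1 + 1, 1)   # as in asgd_step! with T = 1
    @show avg_fact .* accum               # running mean: [1,1], [1.5,1.5], [2,2]
end
```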
@@ -230,7 +235,8 @@ function (vd::VarDrop)(x)
return (x .* vd.mask)
end

_testmode!(vd::VarDrop, test) = (vd.active = !test)
testmode!(m::VarDrop, mode=true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)

# method for reseting mask of VarDrop
reset_masks!(vd::VarDrop) = (vd.reset = true)
@@ -270,7 +276,7 @@ end
function DroppedEmbeddings(in::Integer, embed_size::Integer, p::Float64=0.0;
init = Flux.glorot_uniform)
de = DroppedEmbeddings{AbstractArray, typeof(p)}(
param(init(in, embed_size)),
init(in, embed_size),
p,
drop_mask((in,), p),
true
@@ -283,9 +289,10 @@ function (de::DroppedEmbeddings)(x::AbstractArray, tying::Bool=false)
return tying ? dropped * x : transpose(dropped[x, :])
end

Flux.@treelike DroppedEmbeddings
Flux.@functor DroppedEmbeddings (emb,)

_testmode!(de::DroppedEmbeddings, test) = (de.active = !test)
testmode!(m::DroppedEmbeddings, mode=true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)

function reset_masks!(de::DroppedEmbeddings)
de.mask = drop_mask(de.mask, de.p)
@@ -324,10 +331,10 @@ PooledDense(W, b) = PooledDense(W, b, identity)

function PooledDense(hidden_sz::Integer, out::Integer, σ = identity;
initW = Flux.glorot_uniform, initb = (dims...) -> zeros(Float32, dims...))
return PooledDense(param(initW(out, hidden_sz*3)), param(initb(out)), σ)
return PooledDense(initW(out, hidden_sz*3), initb(out), σ)
end

Flux.@treelike PooledDense
Flux.@functor PooledDense

function (a::PooledDense)(x)
W, b, σ = a.W, a.b, a.σ
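Across this file the pattern is uniform: `Flux.@treelike` becomes `Flux.@functor`, `param(...)` wrappers drop away, and explicit `Flux.trainable` methods pick out the optimisable fields so that dropout masks and `active` flags travel with `gpu`/`cpu` but stay out of the optimiser. A minimal sketch of the pair on a made-up layer:

```julia
# Affine is illustrative only; the pattern mirrors WeightDroppedLSTMCell above.
using Flux

struct Affine
    W
    b
    mask   # moved by gpu/cpu along with W and b, but never trained
end

Flux.@functor Affine
Flux.trainable(a::Affine) = (a.W, a.b)
```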
32 changes: 16 additions & 16 deletions src/ULMFiT/data_loaders.jl
@@ -27,29 +27,29 @@ function imdb_preprocess(doc::AbstractDocument)
length(word) == 1 && return [word]
return split(word, symbol)
end
text = text(doc)
remove_corrupt_utf8!(text)
remove_case!(text)
prepare!(text, strip_html_tags)
tokens = tokens(text)
text_ = doc
remove_corrupt_utf8!(text_)
remove_case!(text_)
prepare!(text_, strip_html_tags)
tokens_ = tokens(text_)
for symbol in [',', '.', '-', '/', "'s"]
tokens = split_word.(tokens, symbol)
tokens_ = split_word.(tokens_, symbol)
temp = []
for token in tokens
for token_ in tokens_
try
append!(temp, put(token, symbol))
append!(temp, put(token_, symbol))
catch
append!(temp, token)
append!(temp, token_)
end
end
tokens = temp
tokens_ = temp
end
deleteat!(tokens, findall(x -> isequal(x, "")||isequal(x, " "), tokens))
return tokens
deleteat!(tokens_, findall(x -> isequal(x, "")||isequal(x, " "), tokens_))
return tokens_
end

# Loads WikiText-103 corpus and output a Channel to give a mini-batch at each call
function load_wikitext_103(batchsize::Integer, bptt::Integer; type = "train")
function load_wikitext_103(batchsize::Integer=16, bptt::Integer=70; type = "train")
corpuspath = joinpath(datadep"WikiText-103", "wiki.$(type).tokens")
corpus = read(open(corpuspath, "r"), String)
corpus = tokenize(corpus)
@@ -58,13 +58,13 @@ end

# IMDB Data loaders for Sentiment Analysis specifically
# IMDB data loader for fine-tuning Language Model
function imdb_fine_tune_data(batchsize::Integer, bptt::Integer, num_examples::Integer=50000)
function imdb_fine_tune_data(batchsize::Integer=16, bptt::Integer=70, num_examples::Integer=50000)
imdb_dataset = IMDB("train_unsup")
dataset = []
for path in imdb_dataset.filepaths #extract data from the files in directory and put into channel
for path in imdb_dataset.filepaths[1:num_examples] #extract data from the files in directory and put into channel
open(path) do fileio
cur_text = read(fileio, String)
append!(dataset, imdb_preprocess(cur_text))
append!(dataset, imdb_preprocess(StringDocument(cur_text)))
end #open
end #for
return Channel(x -> generator(x, dataset; batchsize=batchsize, bptt=bptt))
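Both loaders gain defaults (`batchsize=16`, `bptt=70`), which is what lets `fine_tune_lm!` below call `data_loader()` with no arguments, and `imdb_preprocess` now takes an `AbstractDocument`, hence the `StringDocument` wrapper. A hedged usage sketch; that the channel yields the iteration count before any batch is an assumption read off the `take!` pattern in `fine_tune_lm.jl`:

```julia
# Usage sketch; batch structure is assumed, not confirmed by this diff.
gen = imdb_fine_tune_data()   # Channel with the new defaults batchsize=16, bptt=70
num_iters = take!(gen)        # first take!: iteration count (per fine_tune_lm.jl)
batch = take!(gen)            # subsequent take!s: mini-batches for training
```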
15 changes: 7 additions & 8 deletions src/ULMFiT/fine_tune_lm.jl
@@ -27,14 +27,14 @@ NOTE: length(opts) == length(layers)
function discriminative_step!(layers, ηL::Float64, l, opts::Vector)
@assert length(opts) == length(layers)
# Gradient calculation
grads = Tracker.gradient(() -> l, get_trainable_params(layers))
grads = Zygote.gradient(() -> l, get_trainable_params(layers))

# discriminative step
ηl = ηL/(2.6^(length(layers)-1))
for (layer, opt) in zip(layers, opts)
opt.eta = ηl
for ps in get_trainable_params([layer])
Tracker.update!(opt, ps, grads[ps])
Flux.Optimise.update!(opt, ps, grads)
end
ηl *= 2.6
end
@@ -50,18 +50,17 @@ This function contains main training loops for fine-tuning the language model.
To use this funciton, an instance of LanguageModel and a data loader is needed.
Read the docs for more info about arguments
"""
function fine_tune_lm!(lm::LanguageModel, data_loader::Channel=imdb_fine_tune_data,
stlr_cut_frac::Float64=0.1, stlr_ratio::Float32=32, stlr_η_max::Float64=4e-3;
function fine_tune_lm!(lm=LanguageModel(), data_loader=imdb_fine_tune_data,
stlr_cut_frac::Float64=0.1, stlr_ratio::Float32=Float32(32), stlr_η_max::Float64=4e-3;
epochs::Integer=1, checkpoint_itvl::Integer=5000)

opts = [ADAM(0.001, (0.7, 0.99)) for i=1:4]
gen = data_loader()
num_of_iters = take!(gen)
cut = num_of_iters * epochs * stlr_cut_frac

# Fine-Tuning loops
for epoch=1:epochs
println("\nEpoch: $epoch")
gen = data_loader()
num_of_iters = take!(gen)
T = num_of_iters-Int(floor((num_of_iters*2)/100))
set_trigger!.(T, lm.layers)
for i=1:num_of_iters
@@ -121,7 +120,7 @@ julia> insert!(vocab, 2, "_pad_")
function set_vocab!(lm::LanguageModel, vocab::Vector)
idxs = indices(vocab, lm.vocab)
lm.vocab = vocab
lm.layers[1].emb = param(Tracker.data(lm.layers[1].emb)[idxs, :])
lm.layers[1].emb = param(lm.layers[1].emb[idxs, :])
lm.layers[1].mask = gpu(drop_mask((length(vocab),), lm.layers[1].p))
return
end
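`discriminative_step!` assigns each layer group its own learning rate, shrinking by a factor of 2.6 per layer below the top rate `ηL`, and the gradient/update calls move from Tracker to Zygote and `Flux.Optimise`. A short sketch of the schedule itself, using this file's defaults:

```julia
# Mirrors the ηl arithmetic in discriminative_step! above.
ηL = 4e-3                                        # stlr_η_max default
n_layers = 4                                     # one optimiser per layer, as in opts
rates = [ηL / 2.6^(n_layers - k) for k in 1:n_layers]
# rates ≈ [2.3e-4, 5.9e-4, 1.5e-3, 4.0e-3], lowest layer first
```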
10 changes: 5 additions & 5 deletions src/ULMFiT/pretrain_lm.jl
@@ -49,7 +49,7 @@ function LanguageModel(load_pretrained::Bool=false, vocabpath::String=joinpath(@
return lm
end

Flux.@treelike LanguageModel
Flux.@functor LanguageModel

"""
test_lm(lm::LanguageModel, data_gen, num_of_iters::Integer; unknown_token::String="_unk_")
@@ -63,7 +63,7 @@ It returns loss, accuracy, precsion, recall and F1 score.
julia> test_lm(lm, data_gen, 200, "<unk")
"""
function test_lm(lm::LanguageModel, data_gen, num_of_iters::Integer; unknown_token::String="_unk_")
model_layers = mapleaves(Tracker.data, lm.layers)
model_layers = lm.layers
testmode!(model_layers)
loss = 0
len = length(vocab)
@@ -110,8 +110,8 @@ end
function backward!(layers, l, opt)
# Calulating gradients and weights updation
p = get_trainable_params(layers)
grads = Tracker.gradient(() -> l, p)
Tracker.update!(opt, p, grads)
grads = Zygote.gradient(() -> l, p)
Flux.Optimise.update!(opt, p, grads)
return
end

@@ -182,7 +182,7 @@ SAMPLING...
"""
function sample(starting_text::AbstractDocument, lm::LanguageModel)
testmode!(lm.layers)
model_layers = mapleaves(Tracker.data, lm.layers)
model_layers = lm.layers
tokens = tokens(starting_text)
word_indices = map(x -> indices([x], lm.vocab, "_unk_"), tokens)
h = (model_layers.(word_indices))[end]
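`backward!` and the evaluation helpers now follow the standard Zygote workflow: gradients come from `Zygote.gradient` over implicit parameters and are applied with `Flux.Optimise.update!`, and `mapleaves(Tracker.data, ...)` is gone because Zygote works on plain arrays. A self-contained sketch of the equivalent training step, with placeholder model and data:

```julia
# Generic Flux/Zygote step; the Dense model, loss and data are placeholders.
using Flux, Zygote

model = Dense(10, 2)
loss(x, y) = Flux.mse(model(x), y)
opt = ADAM(0.001, (0.7, 0.99))        # same betas as the opts in fine_tune_lm.jl

ps = Flux.params(model)               # analogous to get_trainable_params
gs = Zygote.gradient(() -> loss(rand(Float32, 10), rand(Float32, 2)), ps)
Flux.Optimise.update!(opt, ps, gs)
```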
4 changes: 2 additions & 2 deletions src/ULMFiT/sentiment.jl
@@ -48,12 +48,12 @@ function BinSentimentClassifier()
)
)
Flux.loadparams!(sc, weights)
sc = mapleaves(Tracker.data, sc)
sc = sc
Flux.testmode!(sc)
return sc
end

Flux.@treelike BinSentimentClassifier
Flux.@functor BinSentimentClassifier

function (sc::BinSentimentClassifier)(x::TokenDocument)
remove_case!(x)
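`BinSentimentClassifier` gets the same treatment (`@functor`, no Tracker unwrap; the leftover `sc = sc` line is the vestige of the removed `mapleaves` call). A hedged usage sketch, based on the `TokenDocument` call path defined above:

```julia
# Usage sketch; TokenDocument comes from TextAnalysis, as used in this file.
using TextAnalysis

sc = BinSentimentClassifier()                  # loads the pretrained weights
doc = TokenDocument("what a wonderful film")
score = sc(doc)                                # remove_case! is applied internally
```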