-
Notifications
You must be signed in to change notification settings - Fork 75
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* switch to huggingfaceapi.jl * rewrite config * organize code, put implementation in separate folder * add hgf tokenizer * update bert text encoder * update env * remove old Artifacts.toml * tokenizer warning * new base model define with attenlib * update env * fix text tokenizer test * move new attenlib model to experimental for its breaking changes * fix display and weight loading * hgf_cola example
- Loading branch information
1 parent
f8c9044
commit 629132e
Showing
31 changed files
with
1,042 additions
and
498 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
using Transformers.Basic | ||
using Transformers.HuggingFace | ||
using Transformers.Datasets | ||
using Transformers.Datasets: GLUE | ||
using Transformers.BidirectionalEncoder | ||
|
||
using Flux | ||
using Flux: pullback, params | ||
import Flux.Optimise: update! | ||
using WordTokenizers | ||
|
||
const Epoch = 2 | ||
const Batch = 4 | ||
|
||
const cola = GLUE.CoLA() | ||
|
||
function preprocess(batch) | ||
global bertenc, labels | ||
data = encode(bertenc, batch[1]) | ||
label = lookup(OneHot, labels, batch[2]) | ||
return merge(data, (label = label,)) | ||
end | ||
|
||
const labels = Basic.Vocab([get_labels(cola)...]) | ||
|
||
const _bert_model = hgf"bert-base-uncased:forsequenceclassification" | ||
const bertenc = hgf"bert-base-uncased:tokenizer" | ||
|
||
const bert_model = todevice(_bert_model) | ||
|
||
const ps = params(bert_model) | ||
const opt = ADAM(1e-6) | ||
|
||
function acc(p, label) | ||
pred = Flux.onecold(p) | ||
truth = Flux.onecold(label) | ||
sum(pred .== truth) / length(truth) | ||
end | ||
|
||
function loss(model, data) | ||
e = model.embed(data.input) | ||
t = model.transformers(e, data.mask) | ||
|
||
p = model.classifier.clf( | ||
model.classifier.pooler( | ||
t[:,1,:] | ||
) | ||
) | ||
|
||
l = Basic.logcrossentropy(data.label, p) | ||
return l, p | ||
end | ||
|
||
function train!() | ||
global Batch | ||
global Epoch | ||
@info "start training: $(args["task"])" | ||
for e = 1:Epoch | ||
@info "epoch: $e" | ||
Flux.trainmode!(bert_model) | ||
datas = dataset(Train, cola) | ||
|
||
i = 1 | ||
al = zero(Float64) | ||
while (batch = get_batch(datas, Batch)) !== nothing | ||
data = todevice(preprocess(batch)) | ||
(l, p), back = pullback(ps) do | ||
y = bert_model(data.input.tok, data.label; token_type_ids = data.input.segment) | ||
(y.loss, y.logits) | ||
end | ||
a = acc(p, data.label) | ||
al += a | ||
grad = back((Flux.Zygote.sensitivity(l), nothing)) | ||
i+=1 | ||
update!(opt, ps, grad) | ||
mod1(i, 16) == 1 && @info "training" loss=l accuracy=al/i | ||
end | ||
|
||
test() | ||
end | ||
end | ||
|
||
function test() | ||
@info "testing" | ||
Flux.testmode!(bert_model) | ||
i = 1 | ||
al = zero(Float64) | ||
datas = dataset(Dev, cola) | ||
while (batch = get_batch(datas, Batch)) !== nothing | ||
data = todevice(preprocess(batch)) | ||
p = bert_model(data.input.tok, data.label; token_type_ids = data.input.segment).logits | ||
a = acc(p, data.label) | ||
al += a | ||
i+=1 | ||
end | ||
al /= i | ||
@info "testing" accuracy = al | ||
return al | ||
end | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -89,5 +89,6 @@ using .BidirectionalEncoder | |
|
||
using .HuggingFace | ||
|
||
include("./experimental/experimental.jl") | ||
|
||
end # module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
module Experimental | ||
|
||
include("./model.jl") | ||
|
||
end |
Oops, something went wrong.