Adapting HuggingFaceApi (#103)
* switch to HuggingFaceApi.jl

* rewrite config

* organize code, put implementation in separate folder

* add hgf tokenizer

* update bert text encoder

* update env

* remove old Artifacts.toml

* tokenizer warning

* new base model defined with attenlib

* update env

* fix text tokenizer test

* move new attenlib model to experimental since it's breaking

* fix display and weight loading

* hgf_cola example
chengchingwen authored Jul 31, 2022
1 parent f8c9044 commit 629132e
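
The bullet list above notes the switch to HuggingFaceApi.jl and the new hgf_cola example. For orientation, a minimal sketch of the loading pattern that example (added below) relies on; the model name is simply the one used in the example:

    using Transformers.HuggingFace

    # weights, config, and tokenizer files are fetched through HuggingFaceApi.jl on first use
    model   = hgf"bert-base-uncased:forsequenceclassification"
    bertenc = hgf"bert-base-uncased:tokenizer"   # the new hgf tokenizer, exposed as a BERT text encoder
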
Showing 31 changed files with 1,042 additions and 498 deletions.
14 changes: 13 additions & 1 deletion Project.toml
@@ -16,8 +16,10 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Fetch = "bb354801-46f6-40b6-9c3d-d42d7a74c775"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
FuncPipelines = "9ed96fbb-10b6-44d4-99a6-7e2a3dc8861b"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
HuggingFaceApi = "3cc741c3-0c9d-4fbe-84fa-cdec264173de"
InternedStrings = "7d512f48-7fb1-5a58-b986-67e6dc259f01"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
LightXML = "9c8b4983-aa76-5018-a973-4c85ecc9e179"
@@ -26,15 +28,19 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
NeuralAttentionlib = "12afc1b8-fad6-47e1-9132-84abc478905f"
Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PrimitiveOneHot = "13d12f88-f12b-451e-9b9f-13b97e01cc85"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructWalk = "31cdf514-beb7-4750-89db-dda9d2eb8d3d"
TextEncodeBase = "f92c20c0-9f2a-4705-8116-881385faba05"
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
ValSplit = "0625e100-946b-11ec-09cd-6328dd093154"
WordTokenizers = "796a5d58-b03d-544a-977e-18100b691f6e"
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"

@@ -49,18 +55,24 @@ DataDeps = "0.7"
DataStructures = "0.18"
Fetch = "0.1.3"
Flux = "0.13.4"
FuncPipelines = "0.2.1"
Functors = "0.2"
HTTP = "0.9, 1"
HuggingFaceApi = "0.1"
InternedStrings = "0.7"
JSON = "0.21"
LightXML = "0.9"
MacroTools = "0.5"
NNlib = "0.8"
NNlibCUDA = "0.2"
NeuralAttentionlib = "0.1"
Pickle = "0.3"
PrimitiveOneHot = "0.1"
Requires = "1"
TextEncodeBase = "0.5"
Static = "0.7"
StructWalk = "0.2"
TextEncodeBase = "0.5.4"
ValSplit = "0.1"
WordTokenizers = "0.5.6"
ZipFile = "0.9"
julia = "1.6"
100 changes: 100 additions & 0 deletions example/BERT/hgf_cola/train.jl
@@ -0,0 +1,100 @@
using Transformers.Basic
using Transformers.HuggingFace
using Transformers.Datasets
using Transformers.Datasets: GLUE
using Transformers.BidirectionalEncoder

using Flux
using Flux: pullback, params
import Flux.Optimise: update!
using WordTokenizers

const Epoch = 2
const Batch = 4

const cola = GLUE.CoLA()

function preprocess(batch)
global bertenc, labels
data = encode(bertenc, batch[1])
label = lookup(OneHot, labels, batch[2])
return merge(data, (label = label,))
end

const labels = Basic.Vocab([get_labels(cola)...])

const _bert_model = hgf"bert-base-uncased:forsequenceclassification"
const bertenc = hgf"bert-base-uncased:tokenizer"

const bert_model = todevice(_bert_model)

const ps = params(bert_model)
const opt = ADAM(1e-6)

function acc(p, label)
pred = Flux.onecold(p)
truth = Flux.onecold(label)
sum(pred .== truth) / length(truth)
end

function loss(model, data)
e = model.embed(data.input)
t = model.transformers(e, data.mask)

p = model.classifier.clf(
model.classifier.pooler(
t[:,1,:]
)
)

l = Basic.logcrossentropy(data.label, p)
return l, p
end

function train!()
global Batch
global Epoch
@info "start training: $(args["task"])"
for e = 1:Epoch
@info "epoch: $e"
Flux.trainmode!(bert_model)
datas = dataset(Train, cola)

i = 1
al = zero(Float64)
while (batch = get_batch(datas, Batch)) !== nothing
data = todevice(preprocess(batch))
(l, p), back = pullback(ps) do
y = bert_model(data.input.tok, data.label; token_type_ids = data.input.segment)
(y.loss, y.logits)
end
a = acc(p, data.label)
al += a
grad = back((Flux.Zygote.sensitivity(l), nothing))
i+=1
update!(opt, ps, grad)
mod1(i, 16) == 1 && @info "training" loss=l accuracy=al/i
end

test()
end
end

function test()
@info "testing"
Flux.testmode!(bert_model)
i = 1
al = zero(Float64)
datas = dataset(Dev, cola)
while (batch = get_batch(datas, Batch)) !== nothing
data = todevice(preprocess(batch))
p = bert_model(data.input.tok, data.label; token_type_ids = data.input.segment).logits
a = acc(p, data.label)
al += a
i+=1
end
al /= i
@info "testing" accuracy = al
return al
end
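
Using the definitions from train.jl above, a sketch of what one preprocessed batch looks like; the field names follow the encoder pipeline changed in src/bert/textencoder.jl further down, and the per-field comments are assumptions for illustration rather than something this diff pins down:

    batch = get_batch(dataset(Train, cola), Batch)   # (sentences, labels)
    data  = todevice(preprocess(batch))
    data.input.tok       # token encodings over the WordPiece vocabulary
    data.input.segment   # segment ids (a single segment for CoLA's one-sentence inputs)
    data.mask            # attention mask built by getmask in the encoder pipeline
    data.label           # one-hot label produced by lookup(OneHot, labels, ...)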

2 changes: 1 addition & 1 deletion example/BERT/main.jl
@@ -13,7 +13,7 @@ function parse_commandline()
"task"
help = "task name"
required = true
range_tester = x-> x ∈ ["cola", "mnli", "mrpc"]
range_tester = x-> x ∈ ["cola", "mnli", "mrpc", "hgf_cola"]
end

return parse_args(ARGS, s)
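
With this entry added, the new example can presumably be launched the same way as the existing tasks, by passing hgf_cola as the positional task argument to example/BERT/main.jl.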
1 change: 1 addition & 0 deletions src/Transformers.jl
@@ -89,5 +89,6 @@ using .BidirectionalEncoder

using .HuggingFace

include("./experimental/experimental.jl")

end # module
98 changes: 65 additions & 33 deletions src/bert/textencoder.jl
@@ -1,7 +1,8 @@
using ..Basic: string_getvalue, check_vocab, TextTokenizer
using FuncPipelines
using TextEncodeBase
using TextEncodeBase: trunc_and_pad, nested2batch, nestedcall
using TextEncodeBase: BaseTokenization, WrappedTokenization, Splittable,
using TextEncodeBase: trunc_and_pad, trunc_or_pad, nested2batch, nestedcall
using TextEncodeBase: BaseTokenization, WrappedTokenization, MatchTokenization, Splittable,
ParentStages, TokenStages, SentenceStage, WordStage, Batch, Sentence, getvalue, getmeta

# bert tokenizer
@@ -70,11 +71,11 @@ BertTextEncoder(
╰─ target[tok] := nestedcall(string_getvalue, source)
╰─ target[tok] := with_firsthead_tail([CLS], [SEP])(target.tok)
╰─ target[(tok, segment)] := segment_and_concat(target.tok)
╰─ target[trunc_tok] := trunc_and_pad(nothing, [UNK])(target.tok)
╰─ target[trunc_tok] := trunc_and_pad(5, [UNK])(target.tok)
╰─ target[trunc_len] := nestedmaxlength(target.trunc_tok)
╰─ target[mask] := getmask(target.tok, target.trunc_len)
╰─ target[tok] := nested2batch(target.trunc_tok)
╰─ target[segment] := (nested2batch ∘ trunc_and_pad(nothing, 1))(target.segment)
╰─ target[segment] := (nested2batch ∘ trunc_and_pad(5, 1))(target.segment)
╰─ target[input] := (NamedTuple{(:tok, :segment)} ∘ tuple)(target.tok, target.segment)
╰─ target := (target.input, target.mask)
)
@@ -88,6 +89,7 @@ BertTextEncoder(
├─ vocab = Vocab{String, SizedArray}(size = 28996, unk = [UNK], unki = 101),
├─ startsym = [CLS],
├─ endsym = [SEP],
├─ trunc = 5,
└─ process = Pipelines:
╰─ target[tok] := nestedcall(string_getvalue, source)
╰─ target[tok] := with_firsthead_tail([CLS], [SEP])(target.tok)
@@ -105,6 +107,7 @@ struct BertTextEncoder{T<:AbstractTokenizer,
process::P
startsym::String
endsym::String
padsym::String
trunc::Union{Nothing, Int}
end

@@ -116,47 +119,75 @@ BertTextEncoder(::typeof(bert_uncased_tokenizer), args...; kws...) =
BertTextEncoder(BertUnCasedPreTokenization(), args...; kws...)
BertTextEncoder(bt::BertTokenization, wordpiece::WordPiece, args...; kws...) =
BertTextEncoder(WordPieceTokenization(bt, wordpiece), args...; kws...)
BertTextEncoder(t::WordPieceTokenization, args...; kws...) =
BertTextEncoder(TextTokenizer(t), Vocab(t.wordpiece), args...; kws...)
BertTextEncoder(t::AbstractTokenization, vocab::AbstractVocabulary, args...; kws...) =
BertTextEncoder(TextTokenizer(t), vocab, args...; kws...)
function BertTextEncoder(t::WordPieceTokenization, args...; match_tokens = nothing, kws...)
if isnothing(match_tokens)
return BertTextEncoder(TextTokenizer(t), Vocab(t.wordpiece), args...; kws...)
else
match_tokens = match_tokens isa AbstractVector ? match_tokens : [match_tokens]
return BertTextEncoder(TextTokenizer(MatchTokenization(t, match_tokens)), Vocab(t.wordpiece), args...; kws...)
end
end
function BertTextEncoder(t::AbstractTokenization, vocab::AbstractVocabulary, args...; match_tokens = nothing, kws...)
if isnothing(match_tokens)
return BertTextEncoder(TextTokenizer(t), vocab, args...; kws...)
else
match_tokens = match_tokens isa AbstractVector ? match_tokens : [match_tokens]
return BertTextEncoder(TextTokenizer(MatchTokenization(t, match_tokens)), vocab, args...; kws...)
end
end

function BertTextEncoder(tkr::AbstractTokenizer, vocab::AbstractVocabulary, process;
startsym = "[CLS]", endsym = "[SEP]", trunc = nothing)
startsym = "[CLS]", endsym = "[SEP]", padsym = "[PAD]", trunc = nothing)
check_vocab(vocab, startsym) || @warn "startsym $startsym not in vocabulary, this might cause problem."
check_vocab(vocab, endsym) || @warn "endsym $endsym not in vocabulary, this might cause problem."
return BertTextEncoder(tkr, vocab, process, startsym, endsym, trunc)
return BertTextEncoder(tkr, vocab, process, startsym, endsym, padsym, trunc)
end

function BertTextEncoder(tkr::AbstractTokenizer, vocab::AbstractVocabulary; kws...)
function BertTextEncoder(tkr::AbstractTokenizer, vocab::AbstractVocabulary; fixedsize = false, kws...)
enc = BertTextEncoder(tkr, vocab, TextEncodeBase.process(AbstractTextEncoder); kws...)
# default processing pipelines for bert encoder
return BertTextEncoder(enc) do e
# get token and convert to string
Pipeline{:tok}(nestedcall(string_getvalue), 1) |>
# add start & end symbol
Pipeline{:tok}(with_firsthead_tail(e.startsym, e.endsym), :tok) |>
# compute segment and merge sentences
Pipeline{(:tok, :segment)}(segment_and_concat, :tok) |>
# truncate input that exceed length limit and pad them to have equal length
Pipeline{:trunc_tok}(trunc_and_pad(e.trunc, e.vocab.unk), :tok) |>
# get the truncated length
Pipeline{:trunc_len}(TextEncodeBase.nestedmaxlength, :trunc_tok) |>
# get mask with specific length
Pipeline{:mask}(getmask, (:tok, :trunc_len)) |>
# convert to dense array
Pipeline{:tok}(nested2batch, :trunc_tok) |>
# truncate & pad segment
Pipeline{:segment}(nested2batch∘trunc_and_pad(e.trunc, 1), :segment) |>
# input namedtuple
Pipeline{:input}(NamedTuple{(:tok, :segment)}∘tuple, (:tok, :segment)) |>
# return input and mask
PipeGet{(:input, :mask)}()
bert_default_preprocess(; trunc = e.trunc, startsym = e.startsym, endsym = e.endsym, padsym = e.padsym, fixedsize)
end
end

BertTextEncoder(builder, e::BertTextEncoder) =
BertTextEncoder(e.tokenizer, e.vocab, builder(e), e.startsym, e.endsym, e.trunc)
BertTextEncoder(e.tokenizer, e.vocab, builder(e), e.startsym, e.endsym, e.padsym, e.trunc)

# preprocess

function bert_default_preprocess(; trunc = nothing, startsym = "[CLS]", endsym = "[SEP]", padsym = "[PAD]", fixedsize = false)
if fixedsize
@assert !isnothing(trunc) "`fixedsize=true` but `trunc` is not set."
truncf = trunc_or_pad
else
truncf = trunc_and_pad
end
# get token and convert to string
return Pipeline{:tok}(nestedcall(string_getvalue), 1) |>
# add start & end symbol
Pipeline{:tok}(with_firsthead_tail(startsym, endsym), :tok) |>
# compute segment and merge sentences
Pipeline{(:tok, :segment)}(segment_and_concat, :tok) |>
# truncate input that exceed length limit and pad them to have equal length
Pipeline{:trunc_tok}(truncf(trunc, padsym), :tok) |>
# get the truncated length
(fixedsize ?
Pipeline{:trunc_len}(FuncPipelines.FixRest(identity, trunc), 0) :
Pipeline{:trunc_len}(TextEncodeBase.nestedmaxlength, :trunc_tok)
) |>
# get mask with specific length
Pipeline{:mask}(getmask, (:tok, :trunc_len)) |>
# convert to dense array
Pipeline{:tok}(nested2batch, :trunc_tok) |>
# truncate & pad segment
Pipeline{:segment}(truncf(trunc, 1), :segment) |>
Pipeline{:segment}(nested2batch, :segment) |>
# input namedtuple
Pipeline{:input}(NamedTuple{(:tok, :segment)}∘tuple, (:tok, :segment)) |>
# return input and mask
PipeGet{(:input, :mask)}()
end

# encoder behavior

@@ -253,7 +284,8 @@ function Base.show(io::IO, e::BertTextEncoder)
print(io, e.tokenizer, ",\n├─ ")
print(io, "vocab = ", e.vocab, ",\n├─ ")
print(io, "startsym = ", e.startsym, ",\n├─ ")
print(io, "endsym = ", e.endsym)
print(io, "endsym = ", e.endsym, ",\n├─ ")
print(io, "padsym = ", e.padsym)
isnothing(e.trunc) || print(io, ",\n├─ trunc = ", e.trunc)
print(IOContext(io, :pipeline_display_prefix => " ╰─ "), ",\n└─ process = ", e.process, "\n)")
end
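
Two small usage sketches for the encoder changes above. First, how the new match_tokens keyword might be used when assembling an encoder by hand; wordpiece stands for a hypothetical WordPiece built from a BERT vocabulary, and the keyword values are illustrative rather than anything this commit sets:

    # match_tokens keeps the listed tokens from being split by the wordpiece tokenizer
    enc = BertTextEncoder(bert_uncased_tokenizer, wordpiece;
                          match_tokens = ["[MASK]"],
                          startsym = "[CLS]", endsym = "[SEP]", padsym = "[PAD]",
                          trunc = 128)

Second, the difference between the two truncation helpers selected by the fixedsize flag, assuming trunc_or_pad shares trunc_and_pad's calling convention as the pipeline comments suggest:

    using TextEncodeBase: trunc_and_pad, trunc_or_pad

    toks = [["[CLS]", "hello", "world", "[SEP]"], ["[CLS]", "hi", "[SEP]"]]
    trunc_and_pad(8, "[PAD]")(toks)  # pad to the longest sentence in the batch (here 4), truncate only past 8
    trunc_or_pad(8, "[PAD]")(toks)   # pad or truncate every sentence to exactly 8, for fixed-size batches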
5 changes: 5 additions & 0 deletions src/experimental/experimental.jl
@@ -0,0 +1,5 @@
module Experimental

include("./model.jl")

end