Conditional Random Fields #162

Merged
merged 40 commits on Jul 29, 2019
Commits
4a728cf
Create CRF layer
Ayushk4 Jun 8, 2019
20eee89
Switch over to Flux functions
Ayushk4 Jun 10, 2019
56f1868
Add show function for CRF
Ayushk4 Jun 10, 2019
a4d6eae
Add decoding functions for CRF
Ayushk4 Jun 12, 2019
736c1d2
Fix show function for CRF and add functions
Ayushk4 Jun 14, 2019
16c5a21
Add forward pass and backward pass
Ayushk4 Jun 18, 2019
af391ac
Fix viterbi
Ayushk4 Jun 19, 2019
a0a0ea2
Fix CRF and logsumexp
Ayushk4 Jun 19, 2019
48aa08d
Fix bugs in predict.jl
Ayushk4 Jun 19, 2019
c491dc2
Viterbi decode working
Ayushk4 Jun 19, 2019
6f5c580
Merge branch 'master' of https://github.com/JuliaText/TextAnalysis.jl…
Ayushk4 Jun 19, 2019
7853a6c
Fix minor bug in viterbi
Ayushk4 Jun 19, 2019
5fff857
Export CRF
Ayushk4 Jun 22, 2019
c105d0c
Adding tests
Ayushk4 Jun 22, 2019
c4511c9
Add crf loss function
Ayushk4 Jun 23, 2019
3fcb172
Basic tests for CRF
Ayushk4 Jun 23, 2019
d35dc9d
Temp change
Ayushk4 Jun 23, 2019
9ffeb97
Get CRF running
Ayushk4 Jun 24, 2019
a21d0d7
Minor changes to CRF structure
Ayushk4 Jun 24, 2019
84efe6b
Minor code fix
Ayushk4 Jun 24, 2019
fecfe1f
Add more tests
Ayushk4 Jun 24, 2019
b303b9c
Merge branch 'master' of https://github.com/JuliaText/TextAnalysis.jl…
Ayushk4 Jun 28, 2019
5f98acb
minor typo
Ayushk4 Jul 11, 2019
fac2698
Faster loss for CRF
Ayushk4 Jul 14, 2019
0c06fac
More tests for the loss function
Ayushk4 Jul 14, 2019
29e2faa
Stabler implementation for forward algorithm
Ayushk4 Jul 14, 2019
6f6106c
fix loss function
Ayushk4 Jul 14, 2019
6dbcb8c
CRF GPU support
Ayushk4 Jul 16, 2019
1fefd06
Re Structure CRF
Ayushk4 Jul 18, 2019
3e83427
Docstring for stable log_sum_exp
Ayushk4 Jul 18, 2019
d3f92ef
Update loss and forward algo as per new CRF struct
Ayushk4 Jul 18, 2019
5a38fb9
Fix viterbi algo as per CRF struct
Ayushk4 Jul 19, 2019
a652308
Edit TextAnalysis.jl
Ayushk4 Jul 19, 2019
30bb596
Tests passing for CRF
Ayushk4 Jul 19, 2019
0501001
Docs for CRF
Ayushk4 Jul 19, 2019
a18d97f
Missed out a comma in documentation
Ayushk4 Jul 26, 2019
f3b3a49
Default to Float32 and add @treelike
Ayushk4 Jul 28, 2019
0286a23
Minor changed is docs and tests
Ayushk4 Jul 28, 2019
cb16418
Fix docstrings warning in travis
Ayushk4 Jul 28, 2019
2467ae2
Minor fixes in tests
Ayushk4 Jul 28, 2019
3 changes: 2 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
@@ -15,7 +15,8 @@ makedocs(
"Features" => "features.md",
"Semantic Analysis" => "semantic.md",
"Classifier" => "classify.md",
"Extended Example" => "example.md"
"Extended Example" => "example.md",
"Conditional Random Fields" => "crf.md"
],
)

130 changes: 130 additions & 0 deletions docs/src/crf.md
@@ -0,0 +1,130 @@
# Conditional Random Fields

This package currently provides support for Linear Chain Conditional Random Fields.

Let us first load the dependencies:

```julia
using Flux
using Flux: onehot, train!, Params, gradient, LSTM, Dense, reset!
using TextAnalysis: CRF, viterbi_decode, crf_loss
```

The Conditional Random Field layer is essentially a structured alternative to a softmax that operates on the topmost layer: rather than scoring each position independently, it scores entire label sequences.

Let us suppose the following input sequence to the CRF, with `NUM_LABELS = 2`:

```julia
julia> NUM_LABELS = 2

julia> SEQUENCE_LENGTH = 2 # CRFs can handle variable-length input sequences

julia> input_seq = [rand(NUM_LABELS + 2) for i in 1:SEQUENCE_LENGTH] # NUM_LABELS + 2: the two extra features correspond to the :START and :END labels
2-element Array{Array{Float64,1},1}:
[0.523462, 0.455434, 0.274347, 0.755279]
[0.610991, 0.315381, 0.0863632, 0.693031]

```

We define our CRF layer as:

CRF(NUM_LABELS::Integer)

```julia
julia> c = CRF(NUM_LABELS) # The API internally appends the START and END tags to NUM_LABELS.
CRF with 4 distinct tags (including START and STOP tags).
```

Now, for the initial variable used by the Viterbi decode and forward algorithms,
we define our input as

```julia
julia> init_α = fill(-10000, (c.n + 2, 1))
julia> init_α[c.n + 1] = 0
```

Optionally, this can be moved to the GPU with `init_α = gpu(init_α)`,
assuming the input sequences are `CuArray`s in this case.
To move a CRF `c` to the GPU, use `c = gpu(c)`.

To compute the CRF loss, we use the following function:

crf_loss(c::CRF, input_seq, label_sequence, init_α)

```julia
julia> label_seq1 = [onehot(1, 1:2), onehot(1, 1:2)]

julia> label_seq2 = [onehot(1, 1:2), onehot(2, 1:2)]

julia> label_seq3 = [onehot(2, 1:2), onehot(1, 1:2)]

julia> label_seq4 = [onehot(2, 1:2), onehot(2, 1:2)]

julia> crf_loss(c, input_seq, label_seq1, init_α)
1.9206894963901504 (tracked)

julia> crf_loss(c, input_seq, label_seq2, init_α)
1.4972745472075206 (tracked)

julia> crf_loss(c, input_seq, label_seq3, init_α)
1.543210471592448 (tracked)

julia> crf_loss(c, input_seq, label_seq4, init_α)
0.876923329893466 (tracked)

```

We can decode the most probable label sequence with Viterbi decoding:

viterbi_decode(c::CRF, input_seq, init_α)

```julia
julia> viterbi_decode(c, input_seq, init_α) # Gives the label_sequence with least loss
2-element Array{Flux.OneHotVector,1}:
[false, true]
[false, true]

```

This algorithm finds the label sequence with the lowest loss in polynomial time.

Currently, Viterbi decoding only supports CPU arrays.
When working with the GPU, use `viterbi_decode` as follows:

viterbi_decode(cpu(c), cpu.(input_seq), cpu(init_α))
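For intuition, the dynamic programming behind Viterbi decoding can be sketched in plain Julia, independent of the CRF API (the function name `viterbi_sketch` and the numbers below are made up for illustration):

```julia
# Viterbi over raw scores: `emissions[t][j]` is the score of tag j at step t,
# `W[i, j]` the transition score from tag i to tag j. At each step we keep,
# for every tag, the best path score so far and the predecessor achieving it,
# then backtrack from the best final tag.
function viterbi_sketch(emissions::Vector{Vector{Float64}}, W::Matrix{Float64})
    n = length(emissions[1])
    score = copy(emissions[1])       # best score of any path ending in each tag
    back = Vector{Vector{Int}}()     # backpointers, one vector per step
    for t in 2:length(emissions)
        new_score = similar(score)
        ptr = zeros(Int, n)
        for j in 1:n
            best, arg = findmax([score[i] + W[i, j] for i in 1:n])
            new_score[j] = best + emissions[t][j]
            ptr[j] = arg
        end
        score = new_score
        push!(back, ptr)
    end
    path = [argmax(score)]
    for ptr in reverse(back)         # follow backpointers from the best final tag
        pushfirst!(path, ptr[path[1]])
    end
    return path
end

# With uniform (zero) transitions, the best path simply follows the emissions:
viterbi_sketch([[1.0, 0.0], [0.0, 1.0]], zeros(2, 2))  # returns [1, 2]
```

Each step costs `O(n^2)` for `n` tags, so the whole decode is linear in the sequence length, instead of enumerating all `n^T` labelings.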

### Working with Flux layers

CRFs work smoothly on top of Flux layers:

```julia
julia> NUM_FEATURES = 20

julia> input_seq = [rand(NUM_FEATURES) for i in 1:SEQUENCE_LENGTH]
2-element Array{Array{Float64,1},1}:
[0.948219, 0.719964, 0.352734, 0.0677656, 0.570564, 0.187673, 0.525125, 0.787807, 0.262452, 0.472472, 0.573259, 0.643369, 0.00592054, 0.945258, 0.951466, 0.323156, 0.679573, 0.663285, 0.218595, 0.152846]
[0.433295, 0.11998, 0.99615, 0.530107, 0.188887, 0.897213, 0.993726, 0.0799431, 0.953333, 0.941808, 0.982638, 0.0919345, 0.27504, 0.894169, 0.66818, 0.449537, 0.93063, 0.384957, 0.415114, 0.212203]

julia> m1 = Dense(NUM_FEATURES, NUM_LABELS + 2)

julia> loss1(input_seq, label_seq) = crf_loss(c, m1.(input_seq), label_seq, init_α) # loss for model m1

julia> loss1(input_seq, [onehot(1, 1:2), onehot(1, 1:2)])
4.6620379898687485 (tracked)

```


Here is an example of a CRF working over LSTM and Dense layers:

```julia
julia> LSTM_SIZE = 10

julia> lstm = LSTM(NUM_FEATURES, LSTM_SIZE)

julia> dense_out = Dense(LSTM_SIZE, NUM_LABELS + 2)

julia> m2(x) = dense_out.(lstm.(x))

julia> loss2(input_seq, label_seq) = crf_loss(c, m2(input_seq), label_seq, init_α) # loss for model m2

julia> loss2(input_seq, [onehot(1, 1:2), onehot(1, 1:2)])
1.6501050910529504 (tracked)

julia> reset!(lstm)
```
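Such a model can then be trained end to end with the usual Flux loop. Below is a minimal sketch continuing the session above (`xs` and `ys` are hypothetical toy data; the optimiser and epoch count are arbitrary choices, not part of the API):

```julia
using Flux
using Flux: train!, params, onehot, reset!

# Hypothetical toy data: 5 input sequences with matching one-hot label sequences.
xs = [[rand(NUM_FEATURES) for _ in 1:SEQUENCE_LENGTH] for _ in 1:5]
ys = [[onehot(rand(1:NUM_LABELS), 1:NUM_LABELS) for _ in 1:SEQUENCE_LENGTH] for _ in 1:5]

ps = params(lstm, dense_out, c)   # all trainable parameters, CRF transitions included
opt = Descent(0.01)

for epoch in 1:10
    train!(loss2, ps, zip(xs, ys), opt)
    reset!(lstm)                  # clear the LSTM state between epochs
end
```

After training, `viterbi_decode(c, m2(input_seq), init_α)` gives the predicted label sequence for a new input.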
36 changes: 36 additions & 0 deletions src/CRF/crf.jl
@@ -0,0 +1,36 @@
"""
Linear chain CRF layer.

For an input sequence `x`,
predicts the most probable tag sequence `y`
over the set of all possible tag sequences `Y`.

Two kinds of potentials are defined in this CRF:
emission and transition.
"""
mutable struct CRF{S}
W::S # Transition Scores
n::Int # Num Labels
end

"""
The second-to-last index is reserved for the START tag,
the last one for the STOP tag.
"""
function CRF(n::Integer)
W = rand(Float32, n + 2, n + 2)
W[:, n + 1] .= -10000
W[n + 2, :] .= -10000

return CRF(param(W), n)
end

@treelike CRF

function Base.show(io::IO, c::CRF)
print(io, "CRF with ", c.n + 2, " distinct tags (including START and STOP tags).")
end

function (a::CRF)(x_seq, init_α)
viterbi_decode(a, x_seq, init_α)
end
8 changes: 8 additions & 0 deletions src/CRF/crf_utils.jl
@@ -0,0 +1,8 @@
"""
log_sum_exp(z::Array)

A numerically stable implementation of `f(z) = (log ∘ sum ∘ exp)(z)`,
since direct exponentiation can produce very large numbers that overflow.
"""
log_sum_exp(z) = log_sum_exp(z, maximum(z, dims = 1))
log_sum_exp(z, m) = log.(sum(exp.(z .- m), dims = 1)) .+ m
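To see why subtracting the column maximum matters, here is an illustrative comparison (the score `1000.0` is a made-up number chosen to force overflow):

```julia
# Stable log-sum-exp, as above: shift each column by its maximum before exp.
log_sum_exp(z) = log_sum_exp(z, maximum(z, dims = 1))
log_sum_exp(z, m) = log.(sum(exp.(z .- m), dims = 1)) .+ m

z = [1000.0 0.0;
     1000.0 0.0]

naive = log.(sum(exp.(z), dims = 1))   # exp(1000) overflows to Inf
stable = log_sum_exp(z)                # ≈ [1000 + log(2)  log(2)]
```

The shifted version only ever exponentiates non-positive numbers, so `exp` stays in `(0, 1]` and the result is exact up to floating-point rounding.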
39 changes: 39 additions & 0 deletions src/CRF/loss.jl
@@ -0,0 +1,39 @@
"""
    forward_score(c::CRF, x, init_α)

Computes the normalization (partition) function `Z`,
i.e. the forward-algorithm score.
"""
function forward_score(c::CRF, x, init_α)
forward_var = log_sum_exp((c.W .+ x[1]') .+ init_α)

for i in 2:length(x)
forward_var = log_sum_exp((c.W .+ x[i]') .+ forward_var')
end

return log_sum_exp(c.W[:, c.n + 2] + forward_var')[1]
end

"""
    score_sequence(c::CRF, x, label_seq)

Calculates the score of the desired `label_seq` against the input sequence `x`.
The score is left in log space, as required for the negative log likelihood,
thereby avoiding the `exp` operation.

`label_seq` <: Array / CuArray,
with `eltype(label_seq) == Flux.OneHotVector`.
"""
function score_sequence(c::CRF, x, label_seq)
score = preds_first(c, label_seq[1]) + onecold(label_seq[1], x[1])

for i in 2:length(label_seq)
score += preds_single(c, label_seq[i], label_seq[i-1]) +
onecold(label_seq[i], x[i])
end

return score + preds_last(c, label_seq[end])
end

# Negative log likelihood: partition (forward) score minus the score of the given label sequence
crf_loss(c::CRF, x, label_seq, init_α) = forward_score(c, x, init_α) - score_sequence(c, x, label_seq)
69 changes: 69 additions & 0 deletions src/CRF/predict.jl
@@ -0,0 +1,69 @@
# Decoding is done by using Viterbi Algorithm
# Computes in polynomial time

"""
Scores for the first tag in the tagging sequence.
"""
function preds_first(c::CRF, y)
c.W[c.n + 1, onecold(y, 1:length(y))]
end

"""
Scores for the last tag in the tagging sequence.
"""
function preds_last(c::CRF, y)
c.W[onecold(y, 1:length(y)), c.n + 2]
end

"""
Scores for the tags other than the starting one.
"""
function preds_single(c::CRF, y, y_prev)
c.W[onecold(y_prev, 1:length(y_prev)), onecold(y, 1:length(y))]
end

# Helper for the forward pass; returns the max scores and corresponding argmaxes over all labels
function forward_pass_unit(k)
α_idx = [i[1] for i in argmax(k, dims=1)]
α = [k[j, i] for (i,j) in enumerate(α_idx)]
return α, α_idx
end

"""
Computes the forward pass for viterbi algorithm.
"""
function _decode(c::CRF, x, init_vit_vars)
α_idx = zeros(Int, c.n + 2, length(x))

forward_var, α_idx[:, 1] = forward_pass_unit(Tracker.data((c.W .+ x[1]') .+ init_vit_vars))

for i in 2:length(x)
forward_var, α_idx[:, i] = forward_pass_unit(Tracker.data((c.W .+ x[i]') .+ forward_var'))
end

labels = zeros(Int, length(x))
labels[end] = argmax(forward_var + Tracker.data(c.W[:, c.n + 2])')[2]

for i in reverse(2:length(x))
labels[i - 1] = α_idx[labels[i], i]
end

@assert α_idx[labels[1], 1] == c.n + 1 # Check for START Tag
return onehotseq(labels, c.n)
end

onehotseq(seq, num_labels) = [onehot(i, 1:num_labels) for i in seq]

"""
    viterbi_decode(c::CRF, x_seq, init_vit_vars)

Predicts the most probable label sequence for `x_seq`.
"""
function viterbi_decode(c::CRF, x_seq, init_vit_vars)
    isempty(x_seq) && throw(ArgumentError("Input sequence is empty"))
return _decode(cpu(c), cpu.(x_seq), cpu(init_vit_vars))
end

# function predict(c::CRF, x_seq)
# viterbi_decode(c, x_seq)
# end
12 changes: 12 additions & 0 deletions src/TextAnalysis.jl
Expand Up @@ -7,6 +7,9 @@ module TextAnalysis
using DataFrames
using WordTokenizers

using Flux
using Flux: identity, onehot, onecold, @treelike

import DataFrames.DataFrame
import Base.depwarn

@@ -55,6 +58,8 @@ module TextAnalysis
export rouge_l_summary, rouge_l_sentence, rouge_n
export PerceptronTagger, fit!, predict

export CRF, viterbi_decode, crf_loss

include("tokenizer.jl")
include("ngramizer.jl")
include("document.jl")
@@ -83,4 +88,11 @@ module TextAnalysis
include("utils.jl")
include("rouge.jl")
include("averagePerceptronTagger.jl")

# CRF
include("CRF/crf.jl")
include("CRF/predict.jl")
include("CRF/crf_utils.jl")
include("CRF/loss.jl")

end
2 changes: 0 additions & 2 deletions src/averagePerceptronTagger.jl
@@ -15,7 +15,6 @@ AVERAGE PERCEPTRON MODEL

This struct contains the actual Average Perceptron Model
"""

mutable struct AveragePerceptron
classes :: Set
weights :: Dict
@@ -129,7 +128,6 @@ tagger = PerceptronTagger(true)
To predict tag:
predict(tagger, ["today", "is"])
"""

mutable struct PerceptronTagger
model :: AveragePerceptron
tagdict :: Dict
4 changes: 2 additions & 2 deletions src/document.jl
@@ -171,9 +171,9 @@ A NGramDocument{AbstractString}
* Title: Untitled Document
* Author: Unknown Author
* Timestamp: Unknown Time
* Snippet: ***SAMPLE TEXT NOT AVAILABLE***```
* Snippet: ***SAMPLE TEXT NOT AVAILABLE***
```
"""

function NGramDocument(txt::AbstractString, dm::DocumentMetadata, n::Integer...=1)
NGramDocument(ngramize(dm.language, tokenize(dm.language, String(txt)), n...), (length(n) == 1) ? Int(first(n)) : Int[n...], dm)
end
1 change: 0 additions & 1 deletion src/hash.jl
@@ -70,7 +70,6 @@ julia> index_hash("b", h)
7
```
"""

function index_hash(s::AbstractString, h::TextHashFunction)
return Int(rem(h.hash_function(s), h.cardinality)) + 1
end
7 changes: 0 additions & 7 deletions src/sentiment.jl
@@ -1,8 +1,6 @@
using JSON
using BSON

Flux = nothing # Will be filled once we actually use sentiment analysis

function pad_sequences(l, maxlen=500)
if length(l) <= maxlen
res = zeros(maxlen - length(l))
@@ -64,10 +62,6 @@ struct SentimentModel
words

function SentimentModel()
# Only load Flux once it is actually needed
global Flux
Flux = Base.require(TextAnalysis, :Flux)

new(read_weights(), read_word_ids())
end
end
@@ -98,7 +92,6 @@ Predict sentiment of the input doc in range 0 to 1, 0 being least sentiment scor
- doc = Input Document for calculating document (`AbstractDocument` type)
- handle_unknown = A function for handling unknown words. Should return an array (default x->tuple())
"""

function(m::SentimentAnalyzer)(d::AbstractDocument, handle_unknown = x->tuple())
m.model(handle_unknown, tokens(d))
end