Merge pull request #162 from Ayushk4/CRF_patch
Conditional Random Fields
Showing 14 changed files with 3,000 additions and 13 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
# Conditional Random Fields | ||
|
||
This package currently provides support for Linear Chain Conditional Random Fields. | ||
|
||
Let us first load the dependencies: | ||
|
||
using Flux | ||
using Flux: onehot, train!, Params, gradient, LSTM, Dense, reset! | ||
using TextAnalysis: CRF, viterbi_decode, crf_loss | ||
|
||
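A Conditional Random Field layer is essentially like a softmax that operates on the topmost layer, except that it normalizes over entire label sequences instead of over the labels at a single position. | ||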
|
||
Let us suppose the following input sequence to the CRF with `NUM_LABELS = 2`: | ||
|
||
```julia | ||
julia> NUM_LABELS = 2 | ||
julia> SEQUENCE_LENGTH = 2 # CRFs can handle variable-length input sequences | ||
julia> input_seq = [rand(NUM_LABELS + 2) for i in 1:SEQUENCE_LENGTH] # NUM_LABELS + 2, where the two extra features correspond to the :START and :END labels. | ||
2-element Array{Array{Float64,1},1}: | ||
[0.523462, 0.455434, 0.274347, 0.755279] | ||
[0.610991, 0.315381, 0.0863632, 0.693031] | ||
|
||
``` | ||
|
||
We define our CRF layer as: | ||
|
||
CRF(NUM_LABELS::Integer) | ||
|
||
```julia | ||
julia> c = CRF(NUM_LABELS) # The API internally appends the START and END tags to NUM_LABELS. | ||
CRF with 4 distinct tags (including START and STOP tags). | ||
``` | ||
|
||
For the initial variable of the Viterbi decoder or the forward algorithm, | ||
we define our input as: | ||
|
||
```julia | ||
julia> init_α = fill(-10000, (c.n + 2, 1)) | ||
julia> init_α[c.n + 1] = 0 | ||
``` | ||
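One way to read these values, as a sketch rather than a statement about the package internals: the scores live in log space, so `-10000` acts as "effectively impossible" and `0` as `log(1)`. The snippet below uses plain Julia with no Flux, hard-codes `n = 2` in place of `c.n`, and uses `Float64` instead of the integer fill shown above; all of these are assumptions made for illustration.

```julia
n = 2                                  # number of real labels (stand-in for c.n)
init_α = fill(-10000.0, (n + 2, 1))    # every tag starts as "impossible" in log space
init_α[n + 1] = 0.0                    # START tag (second-to-last index) gets log(1) = 0

# Back in ordinary probability space, all of the initial mass sits on START,
# because exp(-10000.0) underflows to exactly 0.0 in Float64.
mass = exp.(init_α)
```

Here `vec(mass)` is `[0.0, 0.0, 1.0, 0.0]`, so the first real tag can only be reached through a transition out of START.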
|
||
Optionally, this can be moved to the GPU with `init_α = gpu(init_α)`, | ||
provided the input sequence is a CuArray in this case. | ||
To move a CRF `c` to the GPU, use `c = gpu(c)`. | ||
|
||
To compute the CRF loss, we use the following function: | ||
|
||
crf_loss(c::CRF, input_seq, label_sequence, init_α) | ||
|
||
```julia | ||
julia> label_seq1 = [onehot(1, 1:2), onehot(1, 1:2)] | ||
julia> label_seq2 = [onehot(1, 1:2), onehot(2, 1:2)] | ||
julia> label_seq3 = [onehot(2, 1:2), onehot(1, 1:2)] | ||
julia> label_seq4 = [onehot(2, 1:2), onehot(2, 1:2)] | ||
julia> crf_loss(c, input_seq, label_seq1, init_α) | ||
1.9206894963901504 (tracked) | ||
julia> crf_loss(c, input_seq, label_seq2, init_α) | ||
1.4972745472075206 (tracked) | ||
julia> crf_loss(c, input_seq, label_seq3, init_α) | ||
1.543210471592448 (tracked) | ||
julia> crf_loss(c, input_seq, label_seq4, init_α) | ||
0.876923329893466 (tracked) | ||
``` | ||
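The loss is the forward (log partition) score minus the score of the given label sequence, i.e. a negative log likelihood. The following is a minimal plain-Julia sketch of that idea that does not use Flux or TextAnalysis. The helper names `logsumexp_cols`, `log_partition`, and `path_score` are hypothetical, labels are plain integers instead of one-hot vectors, and the masking of `W` is assumed to follow the constructor in this commit. It checks that `exp(-loss)` summed over every possible labelling comes out at about 1.

```julia
# Column-wise log-sum-exp, stabilised by subtracting each column's maximum.
function logsumexp_cols(z)
    m = maximum(z, dims = 1)
    return log.(sum(exp.(z .- m), dims = 1)) .+ m
end

n = 2                            # real labels; START = n + 1, STOP = n + 2
W = rand(n + 2, n + 2)           # W[i, j]: score of moving from tag i to tag j
W[:, n + 1] .= -10000            # nothing may transition into START
W[n + 2, :] .= -10000            # nothing may transition out of STOP
x = [rand(n + 2) for _ in 1:3]   # emission scores for a length-3 sequence

init_α = fill(-10000.0, (n + 2, 1))
init_α[n + 1] = 0.0

# Forward algorithm: log of the summed exponentiated scores of all paths.
function log_partition(W, x, init_α, n)
    α = logsumexp_cols((W .+ x[1]') .+ init_α)   # 1 × (n + 2) row
    for t in 2:length(x)
        α = logsumexp_cols((W .+ x[t]') .+ α')
    end
    return logsumexp_cols(W[:, n + 2] .+ α')[1]
end

# Unnormalised log score of one labelling (integer labels in 1:n).
function path_score(W, x, labels, n)
    s = W[n + 1, labels[1]] + x[1][labels[1]]
    for t in 2:length(labels)
        s += W[labels[t - 1], labels[t]] + x[t][labels[t]]
    end
    return s + W[labels[end], n + 2]
end

loss(labels) = log_partition(W, x, init_α, n) - path_score(W, x, labels, n)

# exp(-loss) is the probability of a labelling, so the total should be ≈ 1.
total = sum(exp(-loss(p)) for p in Iterators.product(1:n, 1:n, 1:n))
```

Because the log partition term is the same for every labelling of a given input, the labelling with the highest path score is also the one with the lowest loss.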
|
||
We can decode the most likely label sequence using Viterbi decoding: | ||
|
||
viterbi_decode(c::CRF, input_seq, init_α) | ||
|
||
```julia | ||
julia> viterbi_decode(c, input_seq, init_α) # Gives the label sequence with the least loss | ||
2-element Array{Flux.OneHotVector,1}: | ||
[false, true] | ||
[false, true] | ||
|
||
``` | ||
|
||
This algorithm finds the label sequence with the lowest loss value in polynomial time (linear in the sequence length and quadratic in the number of tags). | ||
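To make the dynamic programme concrete, here is a small plain-Julia sketch of Viterbi decoding over the same kind of transition matrix, compared against a brute-force search. It is not the package's `viterbi_decode`: the names `viterbi_sketch` and `path_score` are hypothetical, labels are plain integers, and the START/STOP layout is assumed to match the constructor in this commit.

```julia
n = 2                            # real labels; START = n + 1, STOP = n + 2
W = rand(n + 2, n + 2)           # W[i, j]: score of moving from tag i to tag j
W[:, n + 1] .= -10000            # no transitions into START
W[n + 2, :] .= -10000            # no transitions out of STOP
x = [rand(n + 2) for _ in 1:3]   # emission scores

function viterbi_sketch(W, x, n)
    T = length(x)
    δ = fill(-10000.0, n + 2)    # best log score of any path ending in each tag
    δ[n + 1] = 0.0               # every path begins at START
    back = zeros(Int, n + 2, T)  # back[j, t]: best predecessor of tag j at step t
    for t in 1:T
        scores = W .+ x[t]' .+ δ             # scores[i, j]: arrive at j from i
        vals, idxs = findmax(scores, dims = 1)
        δ = vec(vals)
        back[:, t] = [I[1] for I in vec(idxs)]
    end
    labels = zeros(Int, T)
    labels[T] = argmax(δ .+ W[:, n + 2])     # best final tag before STOP
    for t in T:-1:2
        labels[t - 1] = back[labels[t], t]   # follow the back-pointers
    end
    return labels
end

# Unnormalised log score of one labelling, used for the brute-force check.
function path_score(W, x, labels, n)
    s = W[n + 1, labels[1]] + x[1][labels[1]]
    for t in 2:length(labels)
        s += W[labels[t - 1], labels[t]] + x[t][labels[t]]
    end
    return s + W[labels[end], n + 2]
end

decoded = viterbi_sketch(W, x, n)
all_scores = [path_score(W, x, p, n) for p in Iterators.product(1:n, 1:n, 1:n)]
```

The brute-force search visits `n^T` labellings, while the sketch above does `T` passes over an `(n + 2) × (n + 2)` matrix, which is where the polynomial running time comes from.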
|
||
Currently, Viterbi decoding only supports CPU arrays. | ||
When working with the GPU, use `viterbi_decode` as follows: | ||
|
||
viterbi_decode(cpu(c), cpu.(input_seq), cpu(init_α)) | ||
|
||
### Working with Flux layers | ||
|
||
CRFs work smoothly on top of Flux layers: | ||
|
||
```julia | ||
julia> NUM_FEATURES = 20 | ||
|
||
julia> input_seq = [rand(NUM_FEATURES) for i in 1:SEQUENCE_LENGTH] | ||
2-element Array{Array{Float64,1},1}: | ||
[0.948219, 0.719964, 0.352734, 0.0677656, 0.570564, 0.187673, 0.525125, 0.787807, 0.262452, 0.472472, 0.573259, 0.643369, 0.00592054, 0.945258, 0.951466, 0.323156, 0.679573, 0.663285, 0.218595, 0.152846] | ||
[0.433295, 0.11998, 0.99615, 0.530107, 0.188887, 0.897213, 0.993726, 0.0799431, 0.953333, 0.941808, 0.982638, 0.0919345, 0.27504, 0.894169, 0.66818, 0.449537, 0.93063, 0.384957, 0.415114, 0.212203] | ||
|
||
julia> m1 = Dense(NUM_FEATURES, NUM_LABELS + 2) | ||
|
||
julia> loss1(input_seq, label_seq) = crf_loss(c, m1.(input_seq), label_seq, init_α) # loss for model m1 | ||
|
||
julia> loss1(input_seq, [onehot(1, 1:2), onehot(1, 1:2)]) | ||
4.6620379898687485 (tracked) | ||
|
||
``` | ||
|
||
|
||
Here is an example of a CRF with an LSTM and a Dense layer: | ||
|
||
```julia | ||
julia> LSTM_SIZE = 10 | ||
|
||
julia> lstm = LSTM(NUM_FEATURES, LSTM_SIZE) | ||
|
||
julia> dense_out = Dense(LSTM_SIZE, NUM_LABELS + 2) | ||
|
||
julia> m2(x) = dense_out.(lstm.(x)) | ||
|
||
julia> loss2(input_seq, label_seq) = crf_loss(c, m2(input_seq), label_seq, init_α) # loss for model m2 | ||
|
||
julia> loss2(input_seq, [onehot(1, 1:2), onehot(1, 1:2)]) | ||
1.6501050910529504 (tracked) | ||
|
||
julia> reset!(lstm) | ||
``` |
@@ -0,0 +1,36 @@ | ||
""" | ||
Linear Chain - CRF Layer. | ||
For input sequence `x`, | ||
predicts the most probable tag sequence `y`, | ||
over the set of all possible tagging sequences `Y`. | ||
In this CRF, two kinds of potentials are defined, | ||
emission and transition. | ||
""" | ||
mutable struct CRF{S} | ||
W::S # Transition Scores | ||
n::Int # Num Labels | ||
end | ||
|
||
""" | ||
The second-to-last index is the START tag, | ||
and the last one is the STOP tag. | ||
""" | ||
function CRF(n::Integer) | ||
W = rand(Float32, n + 2, n + 2) | ||
W[:, n + 1] .= -10000 | ||
W[n + 2, :] .= -10000 | ||
|
||
return CRF(param(W), n) | ||
end | ||
|
||
@treelike CRF | ||
|
||
function Base.show(io::IO, c::CRF) | ||
print(io, "CRF with ", c.n + 2, " distinct tags (including START and STOP tags).") | ||
end | ||
|
||
function (a::CRF)(x_seq, init_α) | ||
viterbi_decode(a, x_seq, init_α) | ||
end |
@@ -0,0 +1,8 @@ | ||
""" | ||
log_sum_exp(z::Array) | ||
A numerically stable implementation of f(x) = (log ∘ sum ∘ exp)(x). | ||
The maximum is subtracted before exponentiating, since exponentiation can produce very large numbers. | ||
""" | ||
log_sum_exp(z) = log_sum_exp(z, maximum(z, dims = 1)) | ||
log_sum_exp(z, m) = log.(sum(exp.(z .- m), dims = 1)) .+ m |
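As a quick illustration of why the maximum is subtracted first, the sketch below compares a naive log-sum-exp with a stabilised one on large inputs. The names `naive_lse` and `stable_lse` are hypothetical and exist only for this example; the stabilised version is assumed to follow the same idea as `log_sum_exp` above.

```julia
# Naive version: exponentiates directly, so large inputs overflow.
naive_lse(z) = log.(sum(exp.(z), dims = 1))

# Stabilised version: shift by the column maximum, then add it back.
function stable_lse(z)
    m = maximum(z, dims = 1)
    return log.(sum(exp.(z .- m), dims = 1)) .+ m
end

z = reshape([1000.0, 1000.0], 2, 1)   # one column with two large scores

naive_result = naive_lse(z)[1]     # exp(1000.0) overflows to Inf
stable_result = stable_lse(z)[1]   # equals 1000 + log(2)
```

The shift leaves the mathematical result unchanged, because `log(sum(exp.(z)))` equals `m + log(sum(exp.(z .- m)))` for any constant `m`.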
@@ -0,0 +1,39 @@ | ||
""" | ||
forward_score(c::CRF, x, init_α) | ||
Compute the Normalization / partition function | ||
or the Forward Algorithm score - `Z` | ||
""" | ||
function forward_score(c::CRF, x, init_α) | ||
forward_var = log_sum_exp((c.W .+ x[1]') .+ init_α) | ||
|
||
for i in 2:length(x) | ||
forward_var = log_sum_exp((c.W .+ x[i]') .+ forward_var') | ||
end | ||
|
||
return log_sum_exp(c.W[:, c.n + 2] + forward_var')[1] | ||
end | ||
|
||
""" | ||
score_sequence(c::CRF, xs, label_seq) | ||
Calculate the score of the desired `label_seq` against the sequence `xs`. | ||
The score is not exponentiated, since the negative log likelihood needs it in log space, | ||
which avoids an extra operation. | ||
`label_seq`<:Array/ CuArray | ||
eltype(label_seq) = Flux.OneHotVector | ||
""" | ||
function score_sequence(c::CRF, x, label_seq) | ||
score = preds_first(c, label_seq[1]) + onecold(label_seq[1], x[1]) | ||
|
||
for i in 2:length(label_seq) | ||
score += preds_single(c, label_seq[i], label_seq[i-1]) + | ||
onecold(label_seq[i], x[i]) | ||
end | ||
|
||
return score + preds_last(c, label_seq[end]) | ||
end | ||
|
||
# REGULARIZATION TERM | ||
crf_loss(c::CRF, x, label_seq, init_α) = forward_score(c, x, init_α) - score_sequence(c, x, label_seq) |
@@ -0,0 +1,69 @@ | ||
# Decoding is done by using Viterbi Algorithm | ||
# Computes in polynomial time | ||
|
||
""" | ||
Scores for the first tag in the tagging sequence. | ||
""" | ||
function preds_first(c::CRF, y) | ||
c.W[c.n + 1, onecold(y, 1:length(y))] | ||
end | ||
|
||
""" | ||
Scores for the last tag in the tagging sequence. | ||
""" | ||
function preds_last(c::CRF, y) | ||
c.W[onecold(y, 1:length(y)), c.n + 2] | ||
end | ||
|
||
""" | ||
Scores for the tags other than the starting one. | ||
""" | ||
function preds_single(c::CRF, y, y_prev) | ||
c.W[onecold(y_prev, 1:length(y_prev)), onecold(y, 1:length(y))] | ||
end | ||
|
||
# Helper for forward pass, returns max_probs and corresponding arg_max for all the labels | ||
function forward_pass_unit(k) | ||
α_idx = [i[1] for i in argmax(k, dims=1)] | ||
α = [k[j, i] for (i,j) in enumerate(α_idx)] | ||
return α, α_idx | ||
end | ||
|
||
""" | ||
Computes the forward pass for viterbi algorithm. | ||
""" | ||
function _decode(c::CRF, x, init_vit_vars) | ||
α_idx = zeros(Int, c.n + 2, length(x)) | ||
|
||
forward_var, α_idx[:, 1] = forward_pass_unit(Tracker.data((c.W .+ x[1]') .+ init_vit_vars)) | ||
|
||
for i in 2:length(x) | ||
forward_var, α_idx[:, i] = forward_pass_unit(Tracker.data((c.W .+ x[i]') .+ forward_var')) | ||
end | ||
|
||
labels = zeros(Int, length(x)) | ||
labels[end] = argmax(forward_var + Tracker.data(c.W[:, c.n + 2])')[2] | ||
|
||
for i in reverse(2:length(x)) | ||
labels[i - 1] = α_idx[labels[i], i] | ||
end | ||
|
||
@assert α_idx[labels[1], 1] == c.n + 1 # Check for START Tag | ||
return onehotseq(labels, c.n) | ||
end | ||
|
||
onehotseq(seq, num_labels) = [onehot(i, 1:num_labels) for i in seq] | ||
|
||
""" | ||
viterbi_decode(c::CRF, input_sequence, init_vit_vars) | ||
Predicts the most probable label sequence of `input_sequence`. | ||
""" | ||
function viterbi_decode(c::CRF, x_seq, init_vit_vars) | ||
length(x_seq) == 0 && throw("Input sequence is empty") | ||
return _decode(cpu(c), cpu.(x_seq), cpu(init_vit_vars)) | ||
end | ||
|
||
# function predict(c::CRF, x_seq) | ||
# viterbi_decode(c, x_seq) | ||
# end |