-
Notifications
You must be signed in to change notification settings - Fork 1
/
loss.py
173 lines (130 loc) · 7.78 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
from typing import List, NamedTuple, Optional
import torch
from torch.nn.functional import cross_entropy
# noinspection PyProtectedMember
from torch.nn.modules.loss import _Loss as Loss # The preceding comment line is to suppress a spurious warning message.
from features import Alphabet
from morpheme import Morpheme
class Dimensions(NamedTuple):
    """Named sizes of the tensors manipulated by the unbinding loss.

    Attributes:
        a: Number of symbols in the alphabet.
        b: Number of morphemes in a batch.
        c: Size of a symbol vector.
        m: Number of symbols in a morpheme.
    """

    a: int
    b: int
    c: int
    m: int
class UnbindingLoss(Loss):
    """
    This criterion generalizes cross-entropy loss
    over predicted symbol vectors in a tensor product representation of a morpheme
    and the corresponding gold standard vectors, using cosine similarity.

    For each symbol position, the cosine similarity between the predicted vector
    and every alphabet symbol's gold standard vector is computed; these
    similarities are then treated as the logits of a standard cross-entropy loss.
    """

    def __init__(
            self,
            alphabet: Alphabet,
            device: torch.device,
            weight: Optional[torch.Tensor] = None,
            reduction: str = "mean"
    ):
        """
        Args:
            alphabet: Symbol inventory; each symbol supplies a gold standard vector.
            device: Device on which the gold standard symbol tensor is allocated.
            weight: Optional per-symbol rescaling weights. See torch.nn.CrossEntropyLoss.
            reduction: Reduction applied to the loss: "mean", "sum", or "none".
        """
        super().__init__(reduction=reduction)
        # Gold standard vector representations for each symbol in the alphabet.
        # Registered as a buffer so it is included in state_dict and moved by
        # Module.to()/cuda()/cpu() along with the rest of the module.
        self.register_buffer(
            "alpha_tensor",
            torch.stack([torch.tensor(symbol.vector, dtype=torch.float, device=device)
                         for symbol in alphabet]))
        # Padding should not contribute to the loss. See torch.nn.CrossEntropyLoss.
        self.ignore_index: int = alphabet.index_of(alphabet.pad)
        # See torch.nn.CrossEntropyLoss for details.
        self.weight = weight
        self.alphabet = alphabet
        self.a: int = len(alphabet)  # Number of symbols in the alphabet
        self.c: int = len(alphabet.pad.vector)  # Size of a symbol vector

    def check_dimensions(self, predicted: torch.Tensor, label: Optional[torch.Tensor] = None) -> Dimensions:
        """Validate the shape(s) of the given tensor(s).

        Args:
            predicted: Tensor expected to have shape (b, m, c).
            label: Optional tensor expected to have shape (b, m, c).

        Returns:
            The named dimension sizes (a, b, c, m) on success.

        Raises:
            ValueError: If either tensor is not 3-dimensional, if predicted and
                label disagree in batch size or morpheme length, or if a final
                dimension does not match the expected symbol vector size.
        """
        errors: List[str] = list()
        if len(predicted.shape) != 3:
            errors.append(f"Predicted tensor should have 3 dimensions but actually has {len(predicted.shape)}.")
        if label is not None and len(label.shape) != 3:
            errors.append(f"Label tensor should have 3 dimensions but actually has {len(label.shape)}.")
        # Fail fast on rank errors: the shape indexing below assumes 3 dimensions.
        if len(errors) > 0:
            raise ValueError("\n".join(errors))
        if label is not None and predicted.shape[0] != label.shape[0]:
            errors.append(f"Initial dimension (representing batch size)"
                          f" of predicted and label tensors must be the same, but is not:"
                          f" {predicted.shape[0]} != {label.shape[0]}.")
        if label is not None and predicted.shape[1] != label.shape[1]:
            errors.append(f"Second dimension (representing number of symbols per morpheme)"
                          f" of predicted and label tensors must be the same, but is not:"
                          f" {predicted.shape[1]} != {label.shape[1]}.")
        if predicted.shape[2] != self.c:
            errors.append(f"Final dimension of predicted tensor must match"
                          f" expected size of symbol vector, but does not:"
                          f" {predicted.shape[2]} != {self.c}.")
        if label is not None and label.shape[2] != self.c:
            errors.append(f"Final dimension of label tensor must match"
                          f" expected size of symbol vector, but does not:"
                          f" {label.shape[2]} != {self.c}.")
        if len(errors) > 0:
            raise ValueError("\n".join(errors))
        return Dimensions(a=self.a, b=predicted.shape[0], c=self.c, m=predicted.shape[1])

    def calculate_cosine_similarity(self, morpheme_tpr: torch.Tensor, dimensions: Dimensions) -> torch.Tensor:
        """
        Computes the cosine similarity between predicted symbol vectors and gold standard symbol vectors.

        Args:
            morpheme_tpr: Pytorch tensor with shape (b,m,c) representing a batch (b) of morphemes,
                where each morpheme consists of a sequence of m symbols,
                and each symbol is represented by a vector of length c.
            dimensions: Named dimension sizes previously validated by check_dimensions.

        The gold standard vectors come from self.alpha_tensor, a tensor of shape (a,c)
        holding one vector per symbol in the alphabet. The alphabet consists of a symbols.

        The meaning of the dimension names:
          * a -> the number of symbols in the alphabet
          * b -> batch size; the number of morphemes in the predicted tensor
          * c -> the size of an individual symbol vector
          * m -> the maximum number of symbols in a morpheme

        Returns:
            A tensor of shape (b,m,a) where
            cosine_similarity[b][m][a] = (pred • gold) / (||pred|| * ||gold||),
            the cosine similarity between the predicted m^th symbol vector in the
            b^th batch and the gold vector of the a^th symbol in the alphabet.
        """
        # Dot product between morpheme_tpr[y][z] and self.alpha_tensor[x]
        # for each batch y in range(0, b),
        # each character position z in range(0, m),
        # and each symbol index x in range(0, a).
        dot_product = torch.einsum("bmc,ac->bma", morpheme_tpr, self.alpha_tensor)
        # gold_norm[x] is the Euclidean norm of the x^th symbol vector in the alphabet. Shape: (a)
        gold_norm = torch.norm(self.alpha_tensor, p=2, dim=-1)
        # pred_norm[y][z] is the Euclidean norm of the predicted symbol vector
        # at position z of batch y. Shape: (b,m)
        pred_norm = torch.norm(morpheme_tpr, p=2, dim=-1)
        # Expand both norm tensors to the shape of dot_product: (b,m,a)
        reshaped_gold_norm = gold_norm.unsqueeze(0).unsqueeze(0).expand(dimensions.b, dimensions.m, -1)
        reshaped_pred_norm = pred_norm.unsqueeze(-1).expand(-1, -1, dimensions.a)
        # cosine_similarity = (pred • gold) / (||pred|| * ||gold||).
        # The denominator is clamped away from zero so that an all-zero predicted
        # vector yields similarity 0 rather than NaN/inf (which would previously
        # have poisoned the loss).
        denominator = (reshaped_gold_norm * reshaped_pred_norm).clamp_min(1e-8)
        cosine_similarity = dot_product / denominator
        return cosine_similarity  # Shape: (b,m,a)

    def forward(self, predicted_tpr: torch.Tensor, label_tpr: torch.Tensor) -> torch.Tensor:
        """Compute the unbinding loss between a predicted and a gold standard TPR batch.

        Both arguments must have shape (b, m, c); see check_dimensions.
        """
        # Verify that the arguments have the expected shape, and capture the value of each named dimension
        dimensions = self.check_dimensions(predicted_tpr, label_tpr)
        # Calculate and reshape cosine_similarity into shape (b*m, a)
        cosine_similarity = self.calculate_cosine_similarity(predicted_tpr, dimensions).view(-1, self.a)
        # Get the index of the correct symbol for each morpheme position in each batch, resulting in shape (b*m)
        gold_labels = self.calculate_cosine_similarity(label_tpr, dimensions).view(-1, self.a).argmax(dim=-1)
        return cross_entropy(
            cosine_similarity,
            gold_labels,
            weight=self.weight,
            ignore_index=self.ignore_index,
            # Bug fix: the reduction mode accepted by __init__ (and stored by the
            # _Loss base class) was previously never forwarded, so cross_entropy
            # always used its default ("mean") regardless of the caller's choice.
            reduction=self.reduction,
        )

    def unbind(self, predicted_tpr: torch.Tensor) -> List[Morpheme]:
        """Decode a batch of predicted TPRs back into Morpheme objects.

        Each symbol position is mapped to the alphabet symbol whose gold standard
        vector has the highest cosine similarity with the predicted vector.
        """
        dimensions: Dimensions = self.check_dimensions(predicted_tpr)
        cosine_similarity = self.calculate_cosine_similarity(predicted_tpr, dimensions)
        # Best-matching alphabet index for every symbol position. Shape: (b, m)
        predicted_labels = cosine_similarity.view(-1, dimensions.a).argmax(dim=-1).view(dimensions.b, dimensions.m)
        return [Morpheme(graphemes=[self.alphabet[i] for i in predicted_labels[b].tolist()],
                         tpr=predicted_tpr[b].tolist())
                for b in range(dimensions.b)]