Skip to content

Commit

Permalink
explain +1 in scatter_logits
Browse files Browse the repository at this point in the history
  • Loading branch information
donglihe-hub committed Apr 5, 2024
1 parent 48de2ed commit 4a2a9c4
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 14 deletions.
3 changes: 0 additions & 3 deletions docs/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ Tutorials
* `Feature Generation and Parameter Selection for Linear Methods <./auto_examples/plot_linear_gridsearch_tutorial.html>`_
* `Parameter Selection for Neural Networks <tutorials/Parameter_Selection_for_Neural_Networks.html>`_
* `Handling Data with Many Labels <./auto_examples/plot_linear_tree_tutorial.html>`_
* `Implement Extreme Multi-Label Text Classification with AttentionXML <./auto_examples/plot_AttentionXML_tutorial.html>`_


.. toctree::
Expand All @@ -16,5 +15,3 @@ Tutorials
../auto_examples/plot_linear_gridsearch_tutorial
tutorials/Parameter_Selection_for_Neural_Networks
../auto_examples/plot_linear_tree_tutorial
../auto_examples/plot_AttentionXML_tutorial

29 changes: 18 additions & 11 deletions libmultilabel/nn/attentionxml.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def label2cluster(self, cluster_mapping, *labels) -> Generator[csr_matrix, ...]:
Given the ground-truth labels, [0, 1, 4], the resulting clusters are [0, 2].
Args:
cluster_mapping (np.ndarray): mapping from clusters to labels generated by build_label_tree.
cluster_mapping (np.ndarray): mapping from clusters generated by build_label_tree to labels.
*labels (csr_matrix): labels in CSR sparse format.
Returns:
Expand Down Expand Up @@ -169,7 +169,7 @@ def cluster2label(cluster_mapping, clusters, cluster_scores=None):
Also notice that this function deals with DENSE matrix.
Args:
cluster_mapping (np.ndarray): mapping from clusters to labels generated by build_label_tree.
cluster_mapping (np.ndarray): mapping from clusters generated by build_label_tree to labels.
clusters (np.ndarray): predicted clusters from model 0.
cluster_scores (Optional: np.ndarray): predicted scores of each cluster from model 0.
Expand Down Expand Up @@ -234,7 +234,7 @@ def fit(self, datasets):

clusters = np.load(self.get_cluster_path(), allow_pickle=True)

# each y has been mapped to the cluster indices of its parent
# map each y to the parent cluster indices
train_y_clustered, val_y_clustered = self.label2cluster(clusters, train_y, val_y)

trainer = init_trainer(
Expand Down Expand Up @@ -288,8 +288,8 @@ def fit(self, datasets):
model_0 = Model.load_from_checkpoint(best_model_path)

logger.info(
f"Predicting clusters by level-0 model. We then select {self.beam_width} clusters and "
f"extract labels from them for level 1 training."
f"Predicting clusters by level-0 model. We then select {self.beam_width} clusters for each instance and "
f"extract labels from these clusters for level 1 training."
)
# load training and validation data and predict corresponding level 0 clusters
train_dataloader = self.dataloader(PlainDataset(train_x))
Expand Down Expand Up @@ -546,11 +546,14 @@ def scatter_logits(
"""For each instance, we only have predictions on selected labels. This subroutine maps these predictions to
the whole label space. The scores of unsampled labels are set to 0."""
src = torch.sigmoid(logits.detach()) * label_scores
# During validation/testing, many fake labels might exist in a batch for the purpose of padding.
# A fake label has index len(classes) and does not belong to the real label space.
preds = torch.zeros(
labels_selected.size(0), len(self.classes) + 1, device=labels_selected.device, dtype=src.dtype
)
preds.scatter_(dim=1, index=labels_selected, src=src)
# remove dummy labels
# slicing removes fake labels whose index is exactly len(self.classes)
# afterwards, preds is restored to the real label space
preds = preds[:, :-1]
return preds

Expand Down Expand Up @@ -586,6 +589,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):


###################################### Dataset ######################################


class PlainDataset(Dataset):
"""Plain (compared to nn.data_utils.TextDataset) dataset class for multi-label dataset.
WHY EXISTS: The reason why this class is necessary is that it can process labels in sparse format, while TextDataset
Expand Down Expand Up @@ -634,7 +639,7 @@ class PLTDataset(PlainDataset):
x: texts.
y: labels.
num_classes: number of classes.
num_labels_selected: the number of selected labels. Pad any labels that fail to reach this number.
num_labels_selected: the number of selected labels.
labels_selected: sampled predicted labels from model_0. Shape: (len(x), predict_top_k).
label_scores: scores for each label. Shape: (len(x), predict_top_k).
"""
Expand All @@ -661,10 +666,12 @@ def __getitem__(self, idx: int):
if self.y is not None:
item["label"] = self.y[idx].toarray().squeeze(0).astype(np.int32)

# PyTorch requires inputs to be of the same shape. Pad any instances whose length is below num_labels_selected
# train
# PyTorch requires inputs to be of the same shape. Pad any instance with length below num_labels_selected by
# randomly selecting labels.
# training
if self.label_scores is None:
# add real labels when the number is below num_labels_selected
# randomly add real labels when the number is below num_labels_selected
# some labels might be selected more than once
if len(item["labels_selected"]) < self.num_labels_selected:
samples = np.random.randint(
self.num_classes,
Expand All @@ -675,7 +682,7 @@ def __getitem__(self, idx: int):
# val/test/pred
else:
item["label_scores"] = self.label_scores[idx]
# add dummy labels when the number is below num_labels_selected
# add fake labels when the number of labels is below num_labels_selected
if len(item["labels_selected"]) < self.num_labels_selected:
item["label_scores"] = np.concatenate(
[
Expand Down

0 comments on commit 4a2a9c4

Please sign in to comment.