refactor attentionxml
donglihe-hub committed Jan 10, 2024
1 parent 40d3033 commit 8b54fa3
Showing 11 changed files with 361 additions and 622 deletions.
7 changes: 0 additions & 7 deletions libmultilabel/common_utils.py
@@ -5,13 +5,6 @@
import time

import numpy as np
from pytorch_lightning.plugins.environments import SLURMEnvironment, LightningEnvironment

if os.environ.get("SLURM_PROCID") is not None:
env = SLURMEnvironment()
else:
env = LightningEnvironment()
GLOBAL_RANK = env.global_rank()


class AttributeDict(dict):
207 changes: 26 additions & 181 deletions libmultilabel/nn/data_utils.py
@@ -8,28 +8,24 @@
from collections import defaultdict
from functools import partial
from pathlib import Path
from typing import Sequence

import numpy as np
import pandas as pd
import torch
import transformers
import nltk
from nltk.tokenize import RegexpTokenizer, word_tokenize
from numpy import ndarray
from scipy.sparse import csr_matrix, issparse
from scipy.sparse import issparse
from sklearn.datasets import load_svmlight_file
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer, normalize
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from torchtext.vocab import build_vocab_from_iterator, pretrained_aliases, Vocab, Vectors
from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_info
from gensim.models import KeyedVectors
from tqdm import tqdm

from ..common_utils import GLOBAL_RANK

transformers.logging.set_verbosity_error()
warnings.simplefilter(action="ignore", category=FutureWarning)

@@ -95,151 +91,11 @@ def __getitem__(self, index):
}


class MultiLabelDataset(Dataset):
"""Basic class for multi-label dataset."""

def __init__(
self,
x: list[list[int]],
y: csr_matrix | ndarray | None = None,
):
"""General dataset class for multi-label dataset.
Args:
x: text.
y: labels.
"""
if y is not None:
assert len(x) == y.shape[0], "Size mismatch between x and y"
self.x = x
self.y = y

def __getitem__(self, idx: int) -> tuple[Sequence, ndarray] | tuple[Sequence]:
x = self.x[idx]

# train/valid/test
if self.y is not None:
if issparse(self.y):
y = self.y[idx].toarray().squeeze(0).astype(np.float32)
else:
y = self.y[idx].astype(np.float32)
return x, y
# predict
return x

def __len__(self):
return len(self.x)
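
# Illustration only (not part of this commit): a toy instantiation of the
# MultiLabelDataset defined above, with made-up token ids and labels.
import numpy as np
from scipy.sparse import csr_matrix

x = [[3, 17, 5], [42, 8]]                                # two documents as token-id lists
y = csr_matrix(np.array([[0, 1, 0, 1], [1, 0, 0, 0]]))   # sparse labels over four classes
dataset = MultiLabelDataset(x, y)
text, labels = dataset[0]                                # labels -> float32 array [0., 1., 0., 1.]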


class PLTDataset(MultiLabelDataset):
"""Dataset class for AttentionXML."""

def __init__(
self,
x,
y: csr_matrix | ndarray | None = None,
*,
num_nodes: int,
mapping: ndarray,
node_label: ndarray | Tensor,
node_score: ndarray | Tensor | None = None,
):
"""Dataset for FastAttentionXML.
~ means variable length.
Args:
x: text
y: labels
num_nodes: number of nodes at the current level.
mapping: [[0,..., 7], [8,..., 15], ...]. shape: (len(nodes), ~cluster_size). parent nodes to child nodes.
Cluster size will only vary at the last level.
node_label: [[7, 1, 128, 6], [21, 85, 64, 103], ...]. shape: (len(x), top_k). numbers are predicted nodes
from last level.
node_score: corresponding scores. shape: (len(x), top_k)
"""
super().__init__(x, y)
self.num_nodes = num_nodes
self.mapping = mapping
self.node_label = node_label
self.node_score = node_score
self.candidate_scores = None

# candidates are positive nodes at the current level. shape: (len(x), ~cluster_size * top_k)
# look like [[0, 1, 2, 4, 5, 18, 19,...], ...]
prog = tqdm(self.node_label, leave=False, desc="Candidates") if GLOBAL_RANK == 0 else self.node_label
self.candidates = [np.concatenate(self.mapping[labels]) for labels in prog]
if self.node_score is not None:
# candidate_scores are corresponding scores for candidates and
# look like [[0.1, 0.1, 0.1, 0.4, 0.4, 0.5, 0.5,...], ...]. shape: (len(x), ~cluster_size * top_k)
# notice how scores repeat for each cluster.
self.candidate_scores = [
np.repeat(scores, [len(i) for i in self.mapping[labels]])
for labels, scores in zip(self.node_label, self.node_score)
]

# top_k * n (n <= cluster_size). maximum possible number of candidates at the current level.
self.num_candidates = self.node_label.shape[1] * max(len(node) for node in self.mapping)

def __getitem__(self, idx: int):
x = self.x[idx]
candidates = np.asarray(self.candidates[idx], dtype=np.int64)

# train/valid/test
if self.y is not None:
# squeezing is necessary here because csr_matrix.toarray() always returns a 2d array
# e.g., np.ndarray([[0, 1, 2]])
y = self.y[idx].toarray().squeeze(0).astype(np.float32)

# train
if self.candidate_scores is None:
# randomly select nodes as candidates when less than required
if len(candidates) < self.num_candidates:
sample = np.random.randint(self.num_nodes, size=self.num_candidates - len(candidates))
candidates = np.concatenate([candidates, sample])
# randomly select a subset of candidates when more than required
elif len(candidates) > self.num_candidates:
# candidates = np.random.choice(candidates, self.num_candidates, replace=False)
raise ValueError("Too many candidates. Which shouldn't happen.")
return x, y, candidates

# valid/test
else:
candidate_scores = self.candidate_scores[idx]
offset = (self.num_nodes, self.num_candidates - len(candidates))

# add dummy elements when less than required
if len(candidates) < self.num_candidates:
candidate_scores = np.concatenate(
[candidate_scores, [-np.inf] * (self.num_candidates - len(candidates))]
)
candidates = np.concatenate(
[candidates, [self.num_nodes] * (self.num_candidates - len(candidates))]
)

candidate_scores = np.asarray(candidate_scores, dtype=np.float32)
return x, y, candidates, candidate_scores

# predict
else:
candidate_scores = self.candidate_scores[idx]

# add dummy elements when less than required
if len(candidates) < self.num_candidates:
candidate_scores = np.concatenate(
[candidate_scores, [-np.inf] * (self.num_candidates - len(candidates))]
)
candidates = np.concatenate([candidates, [self.num_nodes] * (self.num_candidates - len(candidates))])

candidate_scores = np.asarray(candidate_scores, dtype=np.float32)
return x, candidates, candidate_scores
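
# Illustration only (not part of this commit): tracing the candidate construction
# described in the PLTDataset docstring with made-up values.
import numpy as np

mapping = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])   # 3 parent nodes -> child labels, cluster_size = 3
node_label = np.array([[0, 2]])                         # top_k = 2 parents predicted for one sample
node_score = np.array([[0.9, 0.4]])                     # scores of those parents

candidates = [np.concatenate(mapping[labels]) for labels in node_label]
# -> [array([0, 1, 2, 6, 7, 8])]
candidate_scores = [
    np.repeat(scores, [len(i) for i in mapping[labels]])
    for labels, scores in zip(node_label, node_score)
]
# -> [array([0.9, 0.9, 0.9, 0.4, 0.4, 0.4])], i.e. each parent score repeats once per child in its cluster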


def tokenize(text: str, lowercase: bool = True, tokenizer: str = "regex") -> list[str]:
def tokenize(text: str, tokenizer: str = "regex") -> list[str]:
"""Tokenize text.
Args:
text (str): Text to tokenize.
lowercase: Whether to convert all characters to lowercase.
tokenizer: The tokenizer from nltk to use. Can be one of ["regex", "punkt"]
Returns:
@@ -249,14 +105,15 @@ def tokenize(text: str, lowercase: bool = True, tokenizer: str = "regex") -> list[str]:
tokenizer = RegexpTokenizer(r"\w+").tokenize
pattern = r"^\d+$"
elif tokenizer == "punkt":
nltk.download("punkt")
tokenizer = word_tokenize
pattern = r"\W"
elif tokenizer == "split":
tokenizer = lambda x: x.split()
pattern = r""
else:
raise ValueError(f"unsupported tokenizer {tokenizer}")
return [t.lower() if lowercase and t != "/SEP/" else t for t in tokenizer(text) if re.sub(pattern, "", t)]
return [t.lower() if t != "/SEP/" else t for t in tokenizer(text) if re.sub(pattern, "", t)]
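
# Illustration only (not part of this commit): rough behaviour of tokenize as defined above.
tokenize("The quick brown Fox jumped 42 times")
# -> ['the', 'quick', 'brown', 'fox', 'jumped', 'times']   (purely numeric tokens are dropped)
tokenize("Alice /SEP/ Bob", tokenizer="split")
# -> ['alice', '/SEP/', 'bob']   (plain whitespace split; the /SEP/ marker keeps its case)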


def generate_batch(data_batch):
@@ -321,7 +178,6 @@ def _load_raw_data(
is_test=False,
tokenize_text=True,
remove_no_label_data=False,
lowercase=True,
tokenizer="regex",
) -> list[dict[str, list[str]]]:
"""Load and tokenize raw data in file or dataframe.
@@ -335,24 +191,26 @@
Returns:
dict: [{(optional: "index": ..., ), "label": ..., "text": ...}, ...]
"""
assert isinstance(data, str) or isinstance(data, pd.DataFrame), "Data must be from a file or pandas dataframe."
if isinstance(data, str):
if GLOBAL_RANK == 0:
logging.info(f"Loading data from {data}.")
logging.info(f"Loading data from {data}.")
data = pd.read_csv(data, sep="\t", header=None, on_bad_lines="warn", quoting=csv.QUOTE_NONE).fillna("")
data = data.astype(str)
if data.shape[1] == 2:
data.columns = ["label", "text"]
data = data.reset_index()
elif data.shape[1] == 3:
data.columns = ["index", "label", "text"]
data = data.astype(str)
elif isinstance(data, pd.DataFrame):
logging.info(f"Loading data from DataFrame.")
if data.shape[1] == 2:
data.columns = ["label", "text"]
data = data.reset_index()
elif data.shape[1] == 3:
data.columns = ["index", "label", "text"]
else:
raise ValueError(f"Expected 2 or 3 columns, got {data.shape[1]}.")
else:
raise ValueError(f"Expected 2 or 3 columns, got {data.shape[1]}.")
raise ValueError("Data must be from a file or pandas dataframe.")

data["label"] = data["label"].astype(str).map(lambda s: s.split())
if tokenize_text:
tqdm.pandas()
data["text"] = data["text"].progress_map(lambda t: tokenize(t, lowercase=lowercase, tokenizer=tokenizer))
data["text"] = data["text"].progress_map(lambda t: tokenize(t, tokenizer=tokenizer))
# TODO: Can we change to "list"?
data = data.to_dict("records")
if not is_test:
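
# Illustration only (not part of this diff): for a two-column TSV row such as
#   "L1 L7\tGenomic analysis of influenza"
# the record produced by the processing above would look roughly like
#   {"index": 0, "label": ["L1", "L7"], "text": ["genomic", "analysis", "of", "influenza"]}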
@@ -379,7 +237,6 @@ def load_datasets(
merge_train_val=False,
tokenize_text=True,
remove_no_label_data=False,
lowercase=True,
random_state=42,
tokenizer="regex",
) -> dict:
@@ -396,20 +253,15 @@
merge_train_val (bool, optional): Whether to merge the training and validation data.
Defaults to False.
tokenize_text (bool, optional): Whether to tokenize text. Defaults to True.
lowercase: Whether to lowercase text. Defaults to True.
remove_no_label_data (bool, optional): Whether to remove training/validation instances that have no labels.
Defaults to False.
random_state: Random seed. Defaults to 42.
Returns:
dict: A dictionary of datasets.
"""
if isinstance(training_data, str) or isinstance(test_data, str):
assert training_data or test_data, "At least one of `training_data` and `test_data` must be specified."
elif isinstance(training_data, pd.DataFrame) or isinstance(test_data, pd.DataFrame):
assert (
not training_data.empty or not test_data.empty
), "At least one of `training_data` and `test_data` must be specified."
if training_data is None and test_data is None:
raise ValueError("At least one of 'training_data' and 'test_data' must be specified.")

datasets = {}
if training_data is not None:
@@ -419,23 +271,19 @@
datasets["train"] = _load_raw_data(
training_data,
tokenize_text=tokenize_text,
lowercase=lowercase,
tokenizer=tokenizer,
remove_no_label_data=remove_no_label_data,
)
if GLOBAL_RANK == 0:
np.save(Path(training_data).with_suffix(".npy"), datasets["train"])
rank_zero_only(np.save(Path(training_data).with_suffix(".npy"), datasets["train"]))

if training_sparse_data is not None:
if GLOBAL_RANK == 0:
logging.info(f"Loading sparse training data from {training_sparse_data}.")
rank_zero_info(f"Loading sparse training data from {training_sparse_data}.")
datasets["train_sparse_x"] = normalize(load_svmlight_file(training_sparse_data, multilabel=True)[0])

if val_data is not None:
datasets["val"] = _load_raw_data(
val_data,
tokenize_text=tokenize_text,
lowercase=lowercase,
tokenizer=tokenizer,
remove_no_label_data=remove_no_label_data,
)
@@ -452,12 +300,10 @@
test_data,
is_test=True,
tokenize_text=tokenize_text,
lowercase=lowercase,
tokenizer=tokenizer,
remove_no_label_data=remove_no_label_data,
)
if GLOBAL_RANK == 0:
np.save(Path(test_data).with_suffix(".npy"), datasets["test"])
rank_zero_only(np.save(Path(test_data).with_suffix(".npy"), datasets["test"]))

if merge_train_val:
try:
@@ -473,11 +319,11 @@
gc.collect()

msg = " / ".join(f"{k}: {v.shape[0] if issparse(v) else len(v)}" for k, v in datasets.items())
if GLOBAL_RANK == 0:
logging.info(f"Finish loading dataset ({msg})")
rank_zero_info(f"Finish loading dataset ({msg})")
return datasets
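
# Illustration only (not part of this commit): a rough usage sketch of load_datasets.
# The file paths are hypothetical; only arguments visible in this diff are used.
datasets = load_datasets(
    training_data="data/train.txt",
    test_data="data/test.txt",
    tokenize_text=True,
    tokenizer="regex",
)
train_records = datasets["train"]   # list of {"label": [...], "text": [...]} records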


@lambda fn: rank_zero_only(fn, defaults=(None, None))
def load_or_build_text_dict(
dataset,
vocab_file=None,
@@ -582,8 +428,7 @@ def load_or_build_label(datasets, label_file=None, include_test_labels=False):
for instance in data:
classes.update(instance["label"])
classes = sorted(classes)
if GLOBAL_RANK == 0:
logging.info(f"Read {len(classes)} labels.")
rank_zero_info(f"Read {len(classes)} labels.")

return classes
