diff --git a/libmultilabel/common_utils.py b/libmultilabel/common_utils.py index 9ddd9ebb..80bc4b64 100644 --- a/libmultilabel/common_utils.py +++ b/libmultilabel/common_utils.py @@ -5,13 +5,6 @@ import time import numpy as np -from pytorch_lightning.plugins.environments import SLURMEnvironment, LightningEnvironment - -if os.environ.get("SLURM_PROCID") is not None: - env = SLURMEnvironment() -else: - env = LightningEnvironment() -GLOBAL_RANK = env.global_rank() class AttributeDict(dict): diff --git a/libmultilabel/nn/data_utils.py b/libmultilabel/nn/data_utils.py index bf69030a..957ba8b4 100644 --- a/libmultilabel/nn/data_utils.py +++ b/libmultilabel/nn/data_utils.py @@ -8,7 +8,6 @@ from collections import defaultdict from functools import partial from pathlib import Path -from typing import Sequence import numpy as np import pandas as pd @@ -16,20 +15,17 @@ import transformers import nltk from nltk.tokenize import RegexpTokenizer, word_tokenize -from numpy import ndarray -from scipy.sparse import csr_matrix, issparse +from scipy.sparse import issparse from sklearn.datasets import load_svmlight_file from sklearn.model_selection import train_test_split from sklearn.preprocessing import MultiLabelBinarizer, normalize -from torch import Tensor from torch.nn.utils.rnn import pad_sequence from torch.utils.data import Dataset from torchtext.vocab import build_vocab_from_iterator, pretrained_aliases, Vocab, Vectors +from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_info from gensim.models import KeyedVectors from tqdm import tqdm -from ..common_utils import GLOBAL_RANK - transformers.logging.set_verbosity_error() warnings.simplefilter(action="ignore", category=FutureWarning) @@ -95,151 +91,11 @@ def __getitem__(self, index): } -class MultiLabelDataset(Dataset): - """Basic class for multi-label dataset.""" - - def __init__( - self, - x: list[list[int]], - y: csr_matrix | ndarray | None = None, - ): - """General dataset class for multi-label dataset. - - Args: - x: text. - y: labels. - """ - if y is not None: - assert len(x) == y.shape[0], "Size mismatch between x and y" - self.x = x - self.y = y - - def __getitem__(self, idx: int) -> tuple[Sequence, ndarray] | tuple[Sequence]: - x = self.x[idx] - - # train/valid/test - if self.y is not None: - if issparse(self.y): - y = self.y[idx].toarray().squeeze(0).astype(np.float32) - else: - y = self.y[idx].astype(np.float32) - return x, y - # predict - return x - - def __len__(self): - return len(self.x) - - -class PLTDataset(MultiLabelDataset): - """Dataset class for AttentionXML.""" - - def __init__( - self, - x, - y: csr_matrix | ndarray | None = None, - *, - num_nodes: int, - mapping: ndarray, - node_label: ndarray | Tensor, - node_score: ndarray | Tensor | None = None, - ): - """Dataset for FastAttentionXML. - ~ means variable length. - - Args: - x: text - y: labels - num_nodes: number of nodes at the current level. - mapping: [[0,..., 7], [8,..., 15], ...]. shape: (len(nodes), ~cluster_size). parent nodes to child nodes. - Cluster size will only vary at the last level. - node_label: [[7, 1, 128, 6], [21, 85, 64, 103], ...]. shape: (len(x), top_k). numbers are predicted nodes - from last level. - node_score: corresponding scores. shape: (len(x), top_k) - """ - super().__init__(x, y) - self.num_nodes = num_nodes - self.mapping = mapping - self.node_label = node_label - self.node_score = node_score - self.candidate_scores = None - - # candidate are positive nodes at the current level. 
shape: (len(x), ~cluster_size * top_k) - # look like [[0, 1, 2, 4, 5, 18, 19,...], ...] - prog = tqdm(self.node_label, leave=False, desc="Candidates") if GLOBAL_RANK == 0 else self.node_label - self.candidates = [np.concatenate(self.mapping[labels]) for labels in prog] - if self.node_score is not None: - # candidate_scores are corresponding scores for candidates and - # look like [[0.1, 0.1, 0.1, 0.4, 0.4, 0.5, 0.5,...], ...]. shape: (len(x), ~cluster_size * top_k) - # notice how scores repeat for each cluster. - self.candidate_scores = [ - np.repeat(scores, [len(i) for i in self.mapping[labels]]) - for labels, scores in zip(self.node_label, self.node_score) - ] - - # top_k * n (n <= cluster_size). number of maximum possible number candidates at the current level. - self.num_candidates = self.node_label.shape[1] * max(len(node) for node in self.mapping) - - def __getitem__(self, idx: int): - x = self.x[idx] - candidates = np.asarray(self.candidates[idx], dtype=np.int64) - - # train/valid/test - if self.y is not None: - # squeezing is necessary here because csr_matrix.toarray() always returns a 2d array - # e.g., np.ndarray([[0, 1, 2]]) - y = self.y[idx].toarray().squeeze(0).astype(np.float32) - - # train - if self.candidate_scores is None: - # randomly select nodes as candidates when less than required - if len(candidates) < self.num_candidates: - sample = np.random.randint(self.num_nodes, size=self.num_candidates - len(candidates)) - candidates = np.concatenate([candidates, sample]) - # randomly select a subset of candidates when more than required - elif len(candidates) > self.num_candidates: - # candidates = np.random.choice(candidates, self.num_candidates, replace=False) - raise ValueError("Too many candidates. Which shouldn't happen.") - return x, y, candidates - - # valid/test - else: - candidate_scores = self.candidate_scores[idx] - offset = (self.num_nodes, self.num_candidates - len(candidates)) - - # add dummy elements when less than required - if len(candidates) < self.num_candidates: - candidate_scores = np.concatenate( - [candidate_scores, [-np.inf] * (self.num_candidates - len(candidates))] - ) - candidates = np.concatenate( - [candidates, [self.num_nodes] * (self.num_candidates - len(candidates))] - ) - - candidate_scores = np.asarray(candidate_scores, dtype=np.float32) - return x, y, candidates, candidate_scores - - # predict - else: - candidate_scores = self.candidate_scores[idx] - - # add dummy elements when less than required - if len(candidates) < self.num_candidates: - candidate_scores = np.concatenate( - [candidate_scores, [-np.inf] * (self.num_candidates - len(candidates))] - ) - candidates = np.concatenate([candidates, [self.num_nodes] * (self.num_candidates - len(candidates))]) - - candidate_scores = np.asarray(candidate_scores, dtype=np.float32) - return x, candidates, candidate_scores - - -def tokenize(text: str, lowercase: bool = True, tokenizer: str = "regex") -> list[str]: +def tokenize(text: str, tokenizer: str = "regex") -> list[str]: """Tokenize text. Args: text (str): Text to tokenize. - lowercase: Whether to convert all characters to lowercase. tokenizer: The tokenizer from nltk to use. 
Can be one of ["regex", "punkt"] Returns: @@ -249,6 +105,7 @@ def tokenize(text: str, lowercase: bool = True, tokenizer: str = "regex") -> lis tokenizer = RegexpTokenizer(r"\w+").tokenize pattern = r"^\d+$" elif tokenizer == "punkt": + nltk.download("punkt") tokenizer = word_tokenize pattern = r"\W" elif tokenizer == "split": @@ -256,7 +113,7 @@ def tokenize(text: str, lowercase: bool = True, tokenizer: str = "regex") -> lis pattern = r"" else: raise ValueError(f"unsupported tokenizer {tokenizer}") - return [t.lower() if lowercase and t != "/SEP/" else t for t in tokenizer(text) if re.sub(pattern, "", t)] + return [t.lower() if t != "/SEP/" else t for t in tokenizer(text) if re.sub(pattern, "", t)] def generate_batch(data_batch): @@ -321,7 +178,6 @@ def _load_raw_data( is_test=False, tokenize_text=True, remove_no_label_data=False, - lowercase=True, tokenizer="regex", ) -> list[dict[str, list[str]]]: """Load and tokenize raw data in file or dataframe. @@ -335,24 +191,26 @@ def _load_raw_data( Returns: dict: [{(optional: "index": ..., ), "label": ..., "text": ...}, ...] """ - assert isinstance(data, str) or isinstance(data, pd.DataFrame), "Data must be from a file or pandas dataframe." if isinstance(data, str): - if GLOBAL_RANK == 0: - logging.info(f"Loading data from {data}.") + logging.info(f"Loading data from {data}.") data = pd.read_csv(data, sep="\t", header=None, on_bad_lines="warn", quoting=csv.QUOTE_NONE).fillna("") - data = data.astype(str) - if data.shape[1] == 2: - data.columns = ["label", "text"] - data = data.reset_index() - elif data.shape[1] == 3: - data.columns = ["index", "label", "text"] + data = data.astype(str) + elif isinstance(data, pd.DataFrame): + logging.info(f"Loading data from DataFrame.") + if data.shape[1] == 2: + data.columns = ["label", "text"] + data = data.reset_index() + elif data.shape[1] == 3: + data.columns = ["index", "label", "text"] + else: + raise ValueError(f"Expected 2 or 3 columns, got {data.shape[1]}.") else: - raise ValueError(f"Expected 2 or 3 columns, got {data.shape[1]}.") + raise ValueError("Data must be from a file or pandas dataframe.") data["label"] = data["label"].astype(str).map(lambda s: s.split()) if tokenize_text: tqdm.pandas() - data["text"] = data["text"].progress_map(lambda t: tokenize(t, lowercase=lowercase, tokenizer=tokenizer)) + data["text"] = data["text"].progress_map(lambda t: tokenize(t, tokenizer=tokenizer)) # TODO: Can we change to "list"? data = data.to_dict("records") if not is_test: @@ -379,7 +237,6 @@ def load_datasets( merge_train_val=False, tokenize_text=True, remove_no_label_data=False, - lowercase=True, random_state=42, tokenizer="regex", ) -> dict: @@ -396,7 +253,6 @@ def load_datasets( merge_train_val (bool, optional): Whether to merge the training and validation data. Defaults to False. tokenize_text (bool, optional): Whether to tokenize text. Defaults to True. - lowercase: Whether to lowercase text. Defaults to True. remove_no_label_data (bool, optional): Whether to remove training/validation instances that have no labels. Defaults to False. random_state: @@ -404,12 +260,8 @@ def load_datasets( Returns: dict: A dictionary of datasets. """ - if isinstance(training_data, str) or isinstance(test_data, str): - assert training_data or test_data, "At least one of `training_data` and `test_data` must be specified." 
- elif isinstance(training_data, pd.DataFrame) or isinstance(test_data, pd.DataFrame): - assert ( - not training_data.empty or not test_data.empty - ), "At least one of `training_data` and `test_data` must be specified." + if training_data is None and test_data is None: + raise ValueError("At least one of 'training_data' and 'test_data' must be specified.") datasets = {} if training_data is not None: @@ -419,23 +271,19 @@ def load_datasets( datasets["train"] = _load_raw_data( training_data, tokenize_text=tokenize_text, - lowercase=lowercase, tokenizer=tokenizer, remove_no_label_data=remove_no_label_data, ) - if GLOBAL_RANK == 0: - np.save(Path(training_data).with_suffix(".npy"), datasets["train"]) + rank_zero_only(np.save(Path(training_data).with_suffix(".npy"), datasets["train"])) if training_sparse_data is not None: - if GLOBAL_RANK == 0: - logging.info(f"Loading sparse training data from {training_sparse_data}.") + rank_zero_info(f"Loading sparse training data from {training_sparse_data}.") datasets["train_sparse_x"] = normalize(load_svmlight_file(training_sparse_data, multilabel=True)[0]) if val_data is not None: datasets["val"] = _load_raw_data( val_data, tokenize_text=tokenize_text, - lowercase=lowercase, tokenizer=tokenizer, remove_no_label_data=remove_no_label_data, ) @@ -452,12 +300,10 @@ def load_datasets( test_data, is_test=True, tokenize_text=tokenize_text, - lowercase=lowercase, tokenizer=tokenizer, remove_no_label_data=remove_no_label_data, ) - if GLOBAL_RANK == 0: - np.save(Path(test_data).with_suffix(".npy"), datasets["test"]) + rank_zero_only(np.save(Path(test_data).with_suffix(".npy"), datasets["test"])) if merge_train_val: try: @@ -473,11 +319,11 @@ def load_datasets( gc.collect() msg = " / ".join(f"{k}: {v.shape[0] if issparse(v) else len(v)}" for k, v in datasets.items()) - if GLOBAL_RANK == 0: - logging.info(f"Finish loading dataset ({msg})") + rank_zero_info(f"Finish loading dataset ({msg})") return datasets +@lambda fn: rank_zero_only(fn, defaults=(None, None)) def load_or_build_text_dict( dataset, vocab_file=None, @@ -582,8 +428,7 @@ def load_or_build_label(datasets, label_file=None, include_test_labels=False): for instance in data: classes.update(instance["label"]) classes = sorted(classes) - if GLOBAL_RANK == 0: - logging.info(f"Read {len(classes)} labels.") + rank_zero_info(f"Read {len(classes)} labels.") return classes diff --git a/libmultilabel/nn/datasets.py b/libmultilabel/nn/datasets.py new file mode 100644 index 00000000..1200fa78 --- /dev/null +++ b/libmultilabel/nn/datasets.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from typing import Sequence, Optional + +import numpy as np +from lightning.pytorch.utilities.rank_zero import rank_zero_only +from numpy import ndarray +from scipy.sparse import csr_matrix, issparse +from torch import Tensor +from torch.utils.data import Dataset +from tqdm import tqdm + + +class MultiLabelDataset(Dataset): + """Basic class for multi-label dataset.""" + + def __init__(self, x: list[list[int]], y: Optional[csr_matrix | ndarray] = None): + """General dataset class for multi-label dataset. + + Args: + x: text. + y: labels. 
+ """ + if y is not None: + assert len(x) == y.shape[0], "Sizes mismatch between x and y" + self.x = x + self.y = y + + def __getitem__(self, idx: int) -> tuple[Sequence, ndarray] | tuple[Sequence]: + x = self.x[idx] + + # train/valid/test + if self.y is not None: + if issparse(self.y): + y = self.y[idx].toarray().squeeze(0).astype(np.float32) + else: + y = self.y[idx].astype(np.float32) + return x, y + # predict + return x + + def __len__(self): + return len(self.x) + + +class PLTDataset(MultiLabelDataset): + """Dataset class for AttentionXML.""" + + def __init__( + self, + x, + y: Optional[csr_matrix | ndarray] = None, + *, + num_nodes: int, + mapping: ndarray, + node_label: ndarray | Tensor, + node_score: Optional[ndarray | Tensor] = None, + ): + """Dataset for FastAttentionXML. + ~ means variable length. + + Args: + x: text + y: labels + num_nodes: number of nodes at the current level. + mapping: [[0,..., 7], [8,..., 15], ...]. shape: (len(nodes), ~cluster_size). parent nodes to child nodes. + Cluster size will only vary at the last level. + node_label: [[7, 1, 128, 6], [21, 85, 64, 103], ...]. shape: (len(x), top_k). numbers are predicted nodes + from last level. + node_score: corresponding scores. shape: (len(x), top_k) + """ + super().__init__(x, y) + self.num_nodes = num_nodes + self.mapping = mapping + self.node_label = node_label + self.node_score = node_score + self.candidate_scores = None + + # candidate are positive nodes at the current level. shape: (len(x), ~cluster_size * top_k) + # look like [[0, 1, 2, 4, 5, 18, 19,...], ...] + prog = rank_zero_only(tqdm(self.node_label, leave=False, desc="Candidates")) + if prog is None: + prog = self.node_label + self.candidates = [np.concatenate(self.mapping[labels]) for labels in prog] + if self.node_score is not None: + # candidate_scores are corresponding scores for candidates and + # look like [[0.1, 0.1, 0.1, 0.4, 0.4, 0.5, 0.5,...], ...]. shape: (len(x), ~cluster_size * top_k) + # notice how scores repeat for each cluster. + self.candidate_scores = [ + np.repeat(scores, [len(i) for i in self.mapping[labels]]) + for labels, scores in zip(self.node_label, self.node_score) + ] + + # top_k * n (n <= cluster_size). number of maximum possible number candidates at the current level. + self.num_candidates = self.node_label.shape[1] * max(len(node) for node in self.mapping) + + def __getitem__(self, idx: int): + x = self.x[idx] + candidates = np.asarray(self.candidates[idx], dtype=np.int64) + + # train/valid/test + if self.y is not None: + # squeezing is necessary here because csr_matrix.toarray() always returns a 2d array + # e.g., np.ndarray([[0, 1, 2]]) + y = self.y[idx].toarray().squeeze(0).astype(np.float32) + + # train + if self.candidate_scores is None: + # randomly select nodes as candidates when less than required + if len(candidates) < self.num_candidates: + sample = np.random.randint(self.num_nodes, size=self.num_candidates - len(candidates)) + candidates = np.concatenate([candidates, sample]) + # randomly select a subset of candidates when more than required + elif len(candidates) > self.num_candidates: + # candidates = np.random.choice(candidates, self.num_candidates, replace=False) + raise ValueError("Too many candidates. 
Which shouldn't happen.") + return x, y, candidates + + # valid/test + else: + candidate_scores = self.candidate_scores[idx] + offset = (self.num_nodes, self.num_candidates - len(candidates)) + + # add dummy elements when less than required + if len(candidates) < self.num_candidates: + candidate_scores = np.concatenate( + [candidate_scores, [-np.inf] * (self.num_candidates - len(candidates))] + ) + candidates = np.concatenate( + [candidates, [self.num_nodes] * (self.num_candidates - len(candidates))] + ) + + candidate_scores = np.asarray(candidate_scores, dtype=np.float32) + return x, y, candidates, candidate_scores + + # predict + else: + candidate_scores = self.candidate_scores[idx] + + # add dummy elements when less than required + if len(candidates) < self.num_candidates: + candidate_scores = np.concatenate( + [candidate_scores, [-np.inf] * (self.num_candidates - len(candidates))] + ) + candidates = np.concatenate([candidates, [self.num_nodes] * (self.num_candidates - len(candidates))]) + + candidate_scores = np.asarray(candidate_scores, dtype=np.float32) + return x, candidates, candidate_scores diff --git a/libmultilabel/nn/metrics.py b/libmultilabel/nn/metrics.py index 7ed7a842..1d376cea 100644 --- a/libmultilabel/nn/metrics.py +++ b/libmultilabel/nn/metrics.py @@ -1,15 +1,12 @@ from __future__ import annotations -import logging import re import numpy as np import torch import torchmetrics.classification -from torch import Tensor -from torchmetrics import Metric, MetricCollection, Precision, Recall, RetrievalPrecision -from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg -from torchmetrics.utilities.data import select_topk, dim_zero_cat +from torchmetrics import Metric, MetricCollection, Precision, Recall +from torchmetrics.utilities.data import select_topk AVAILABLE_METRICS_ABBR = ["p", "r", "rp", "ndcg"] @@ -50,18 +47,14 @@ def _str2metric( if metric_abbr == "p": return Precision(num_labels, average="samples", top_k=top_k) - # elif metric_abbr == "tp": - # metric = RetrievalPrecision(k=top_k) elif metric_abbr == "rp": return RPrecision(top_k=top_k) elif metric_abbr == "ndcg": return NDCG(top_k=top_k) - elif metric_abbr == "ndcgnew": - return NDCGnew(top_k=top_k) else: raise ValueError(f"Invalid metric: {metric}.") - return MetricCollection({m.lower(): _str2metric(m) for m in metrics}) + return MetricCollection({m.lower(): _str2metric(m) for m in metrics}, compute_groups=False) class Loss(Metric): @@ -99,6 +92,7 @@ class NDCG(Metric): https://scikit-learn.org/stable/modules/generated/sklearn.metrics.ndcg_score.html Please find the formal definition here: https://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-ranked-retrieval-results-1.html + Args: top_k (int): the top k relevant labels to evaluate. """ @@ -113,41 +107,37 @@ class NDCG(Metric): def __init__(self, top_k): super().__init__() self.top_k = top_k - self.add_state("ndcg", default=[], dist_reduce_fx="cat") + self.add_state("score", default=torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum") + self.add_state("num_sample", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum") def update(self, preds, target): assert preds.shape == target.shape - # implement batch-wise calculations instead of storing results of all batches - self.ndcg += [self._metric(p, t) for p, t in zip(preds, target)] + discount = 1.0 / torch.log2(torch.arange(self.top_k, device=target.device) + 2.0) + dcg = self._dcg(preds, target, discount) + # Instances without labels will have incorrect idcg. 
However, their dcg will be 0. + # As a result, the ndcg will still be correct. + idcg = self._idcg(target, discount) + ndcg = dcg / idcg + self.score += ndcg.sum() + self.num_sample += preds.shape[0] def compute(self): - """Performs stacking on ndcg if neccesary""" - ndcg = torch.stack(self.ndcg) if isinstance(self.ndcg, list) else self.ndcg - return ndcg.mean() - - def _metric(self, preds, target): - return retrieval_normalized_dcg(preds, target, k=self.top_k) - - -class NDCGnew(Metric): - is_differentiable = False - higher_is_better = True - full_state_update = False - - def __init__(self, top_k: int): - super().__init__() - self.top_k = top_k - # the range of ndcg is from 0 to 1 - self.add_state("ndcg", default=torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum") - self.add_state("n", default=torch.tensor(0), dist_reduce_fx="sum") + return self.score / self.num_sample - def update(self, preds: Tensor, target: Tensor): - assert preds.shape == target.shape - self.ndcg += torch.stack([retrieval_normalized_dcg(p, t, k=self.top_k) for p, t in zip(preds, target)]).sum() - self.n += len(preds) + def _dcg(self, preds, target, discount): + _, sorted_top_k_idx = torch.topk(preds, k=self.top_k) + gains = target.take_along_dim(sorted_top_k_idx, dim=1) + # best practice for batch dot product: https://discuss.pytorch.org/t/dot-product-batch-wise/9746/11 + return (gains * discount).sum(dim=1) - def compute(self): - return self.ndcg / self.n + def _idcg(self, target, discount): + """Computes IDCG@k for a 0/1 target tensor. + A 0/1 target is a special case that doesn't require sorting. + """ + cum_discount = discount.cumsum(dim=0) + idx = target.sum(dim=1) - 1 + idx = idx.clamp(min=0, max=self.top_k - 1) + return cum_discount[idx] class RPrecision(Metric): @@ -169,7 +159,6 @@ class RPrecision(Metric): def __init__(self, top_k): super().__init__() self.top_k = top_k - # self.add_state("score", default=[], dist_reduce_fx="cat") self.add_state("score", default=torch.tensor(0.0, dtype=torch.double), dist_reduce_fx="sum") self.add_state("num_sample", default=torch.tensor(0), dist_reduce_fx="sum") @@ -325,7 +314,3 @@ def tabulate_metrics(metric_dict: dict[str, float], split: str) -> str: ) msg += f"|{header}|\n|{'-----------------:|' * len(metric_dict)}\n|{values}|\n" return msg - - -def get_precision(preds, target, classes=None, top=5): - return preds.multiply(target).sum() / (top * target.shape[0]) diff --git a/libmultilabel/nn/model.py b/libmultilabel/nn/model.py index 66bca3dc..8dfd603f 100644 --- a/libmultilabel/nn/model.py +++ b/libmultilabel/nn/model.py @@ -3,6 +3,7 @@ import logging from abc import abstractmethod from collections import deque +from typing import Optional import numpy as np import pytorch_lightning as pl @@ -15,7 +16,6 @@ from ..common_utils import dump_log, argsort_top_k from ..nn.metrics import get_metrics, tabulate_metrics, list2metrics -from .networks.labelwise_attention_networks import AttentionRNN from libmultilabel.nn import networks @@ -266,19 +266,18 @@ def __init__( metrics: list[str], top_k: int, loss_fn: str = "binary_cross_entropy_with_logits", - optimizer_params: dict | None = None, - swa_epoch_start: int | None = None, - train_mlb=None, - test_mlb=None, + optimizer_params: Optional[dict] = None, + swa_epoch_start: Optional[int] = None, ): super().__init__() - self.save_hyperparameters() + self.save_hyperparameters(ignore=["embed_vecs"]) try: self.network = getattr(networks, network)(embed_vecs=embed_vecs, num_classes=num_labels, **network_config) except 
AttributeError as e: logging.warning(e) raise AttributeError(f"Invalid network name: {network}") + self.loss_fn = self.configure_loss_fn(loss_fn) # optimizer config @@ -292,14 +291,9 @@ def __init__( self.num_labels = num_labels self.state = {} - # mlbs are needed for testing purposes - self.train_mlb = train_mlb - self.test_mlb = test_mlb - self.metric_list = metrics - self.metrics = list2metrics(metrics, self.num_labels) - self.valid_metrics = None - self.test_metrics = None + self.valid_metrics = list2metrics(["ndcg@5"], self.num_labels) + self.test_metrics = list2metrics(metrics, self.num_labels) self._clip_grad = 5.0 self._grad_norm_queue = deque([torch.tensor(float("inf"))], maxlen=5) @@ -316,12 +310,7 @@ def configure_loss_fn(loss_fn: str) -> Module: def configure_optimizers(self) -> Optimizer: try: - if self.optimizer == "DenseSparseAdam": - from .optimizer import DenseSparseAdam - - optimizer = DenseSparseAdam(self.parameters(), **self.optimizer_params) - else: - optimizer = getattr(torch.optim, self.optimizer)(self.parameters(), **self.optimizer_params) + optimizer = getattr(torch.optim, self.optimizer)(self.parameters(), **self.optimizer_params) except AttributeError: raise AttributeError(f"Invalid optimizer name: {self.optimizer}") except TypeError: @@ -337,7 +326,6 @@ def training_step(self, batch: Tensor, batch_idx: int): x, y = batch logits = self.network(x) loss = self.loss_fn(logits, y) - # self.log("loss", loss, prog_bar=True, sync_dist=True) return loss def swa_init(self): @@ -364,35 +352,25 @@ def swap_swa_params(self): p.data, swa_state[n] = swa_state[n], p.data def on_validation_start(self): - self.valid_metrics = self.metrics.clone(prefix="valid_") self.swa_step() self.swap_swa_params() def validation_step(self, batch: Tensor, batch_idx: int): - """log metrics on epoch""" x, y = batch logits = self.network(x) - # why detach? 
see: https://github.com/Lightning-AI/lightning/issues/9441 - self.valid_metrics.update(torch.sigmoid(logits).detach(), y.long()) + self.valid_metrics.update(torch.sigmoid(logits), y.long()) def on_validation_epoch_end(self): - self.log_dict(self.valid_metrics.compute(), prog_bar=True, sync_dist=True) + self.log_dict(self.valid_metrics.compute(), prog_bar=True) self.valid_metrics.reset() def on_validation_end(self): self.swap_swa_params() - def on_test_start(self): - self.test_metrics = self.metrics.clone(prefix="test_") - # self.test_metrics = list2metrics(self.metric_list, len(self.test_mlb.classes_)).cuda() - def test_step(self, batch: Tensor, batch_idx: int): - """log metrics on epoch""" x, y = batch logits = self.network(x) - # preds = self.multilabel_binarize(logits.detach()) - # self.test_metrics.update(preds, y.long()) - self.test_metrics.update(torch.sigmoid(logits), y.long()) + self.test_metrics.update(torch.sigmoid(logits).detach(), y.long()) def on_test_epoch_end(self): self.log_dict(self.test_metrics.compute()) @@ -404,7 +382,15 @@ def predict_step(self, batch: Tensor, batch_idx: int, dataloader_idx: int = 0): x = batch logits = self.network(x) scores, labels = torch.topk(torch.sigmoid(logits), self.top_k) - return scores.cpu(), labels.cpu() + return scores.detach().cpu(), labels.detach().cpu() + + def forward(self, x): + return self.network(x) + + def on_save_checkpoint(self, checkpoint): + for k in list(checkpoint): + if k not in self.CHECKPOINT_KEYS: + checkpoint.pop(k) def on_after_backward(self): if self._clip_grad is not None: @@ -413,44 +399,7 @@ def on_after_backward(self): self._grad_norm_queue += [min(total_norm, max_norm * 2.0, torch.tensor(1.0))] if total_norm > max_norm * self._clip_grad: if self.trainer.is_global_zero: - logging.warning(f"Clipping gradients with total norm {total_norm} and max norm {max_norm}") - - def multilabel_binarize(self, logits: Tensor) -> Tensor: - """self-implemented MultiLabelBinarizer for AttentionXML using Tensor""" - # find the top k labels - scores, labels = torch.topk(logits, self.top_k) - # transform them back to string form with training Multi-label Binarizer - labels = self.train_mlb.classes_[labels.cpu()] - # calculate the masks where labels do not appear in testing dataset - mask = torch.tensor( - [[la not in self.test_mlb.classes_ for la in label] for label in labels], device=logits.device - ) - # fill scores with 0 where mask is True - scores.masked_fill_(mask, 0) - # get the indices of logits in new - index = torch.tensor( - [ - np.stack( - [ - np.asarray(self.test_mlb.classes_ == l).nonzero()[0][0] - if np.asarray(self.test_mlb.classes_ == l).any() - else self.test_mlb.classes_.shape[0] - for l in label - ] - ) - for label in labels - ], - device=logits.device, - ) - - # make sure preds and src use the same precision, e.g., either float16 or float32 - preds = torch.zeros( - logits.shape[0], self.test_mlb.classes_.shape[0] + 1, device=logits.device, dtype=logits.dtype - ) - preds.scatter_(dim=1, index=index, src=torch.sigmoid(scores)) - # remove dummy unknown labels - preds = preds[:, :-1] - return preds + logging.warning(f"Clipping gradients with total norm {total_norm:.4f} and max norm {max_norm:.4f}") class PLTModel(BaseModel): @@ -465,10 +414,8 @@ def __init__( top_k: int, eval_metric: str, loss_fn: str = "binary_cross_entropy_with_logits", - optimizer_params: dict | None = None, - swa_epoch_start: int | None = None, - train_mlb=None, - test_mlb=None, + optimizer_params: Optional[dict] = None, + swa_epoch_start: 
Optional[int] = None, ): super().__init__( network=network, @@ -481,8 +428,6 @@ def __init__( loss_fn=loss_fn, optimizer_params=optimizer_params, swa_epoch_start=swa_epoch_start, - train_mlb=train_mlb, - test_mlb=test_mlb, ) self.state["best"] = {} self.eval_metric = "_".join(["valid", eval_metric.lower()]) diff --git a/libmultilabel/nn/nn_utils.py b/libmultilabel/nn/nn_utils.py index 939f7728..2c8707f4 100644 --- a/libmultilabel/nn/nn_utils.py +++ b/libmultilabel/nn/nn_utils.py @@ -2,11 +2,11 @@ import os from pathlib import Path -import pytorch_lightning as pl +import lightning as L import torch -from pytorch_lightning.callbacks.early_stopping import EarlyStopping -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from pytorch_lightning.utilities.seed import seed_everything +from lightning import seed_everything +from lightning.pytorch.callbacks.early_stopping import EarlyStopping +from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from ..nn import networks from ..nn.model import Model, BaseModel @@ -28,8 +28,6 @@ def init_device(use_cpu=False): # https://docs.nvidia.com/cuda/cublas/index.html os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" device = torch.device("cuda") - if GLOBAL_RANK == 0: - logging.info(f"Available GPUs: {torch.cuda.device_count()}") # Sets the internal precision of float32 matrix multiplications. # https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html torch.set_float32_matmul_precision("high") @@ -37,8 +35,7 @@ def init_device(use_cpu=False): device = torch.device("cpu") # https://github.com/pytorch/pytorch/issues/11201 torch.multiprocessing.set_sharing_strategy("file_system") - if GLOBAL_RANK == 0: - logging.info(f"Using device: {device}") + logging.info(f"Using device: {device}") return device @@ -181,7 +178,7 @@ def init_trainer( model_name=None Returns: - pl.Trainer: A torch lightning trainer. + L.Trainer: A torch lightning trainer. """ # The value of `mode` equals to 'min' only when the metric is 'Loss' @@ -233,7 +230,7 @@ def init_trainer( mode="max", ) ] - trainer = pl.Trainer( + trainer = L.Trainer( num_nodes=config.num_nodes, devices=config.devices, max_epochs=config.max_epochs, @@ -251,7 +248,7 @@ def init_trainer( # precision=16, ) else: - trainer = pl.Trainer( + trainer = L.Trainer( logger=False, num_sanity_val_steps=0, accelerator="cpu" if use_cpu else "gpu", @@ -268,23 +265,5 @@ def init_trainer( def set_seed(seed): - """Set seeds for numpy and pytorch. - # Dongli Suggestion: Rename to setup_reproducibility - # random seed if not specified - Args: - seed (int): Random seed. 
- """ - - if seed is not None: - if seed >= 0: - seed_everything(seed=seed, workers=True) - else: - logging.warning("the random seed should be a non-negative integer") - - -def is_global_zero(func): - def wrapper(*args, **kwargs): - if GLOBAL_RANK == 0: - return func(*args, **kwargs) - - return wrapper + """Wrapper of lightning.seed_everything""" + seed_everything(seed=seed, workers=True) diff --git a/libmultilabel/nn/optimizer.py b/libmultilabel/nn/optimizer.py deleted file mode 100644 index 61b7ece9..00000000 --- a/libmultilabel/nn/optimizer.py +++ /dev/null @@ -1,119 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -""" -Created on 2019/3/7 -@author yrh - -""" - -import math -import torch -from torch.optim.optimizer import Optimizer - - -__all__ = ["DenseSparseAdam"] - - -class DenseSparseAdam(Optimizer): - """ """ - - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0): - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - super(DenseSparseAdam, self).__init__(params, defaults) - - def step(self, closure=None): - """ - Performs a single optimization step. - - Parameters - ---------- - closure : ``callable``, optional. - A closure that reevaluates the model and returns the loss. - """ - loss = None - if closure is not None: - loss = closure() - - for group in self.param_groups: - for p in group["params"]: - if p.grad is None: - continue - grad = p.grad.data - - state = self.state[p] - - # State initialization - if "step" not in state: - state["step"] = 0 - if "exp_avg" not in state: - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like(p.data) - if "exp_avg_sq" not in state: - # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like(p.data) - - state["step"] += 1 - - exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] - beta1, beta2 = group["betas"] - - weight_decay = group["weight_decay"] - - if grad.is_sparse: - grad = grad.coalesce() # the update is non-linear so indices must be unique - grad_indices = grad._indices() - grad_values = grad._values() - size = grad.size() - - def make_sparse(values): - constructor = grad.new - if grad_indices.dim() == 0 or values.dim() == 0: - return constructor().resize_as_(grad) - return constructor(grad_indices, values, size) - - # Decay the first and second moment running average coefficient - # old <- b * old + (1 - b) * new - # <==> old += (1 - b) * (new - old) - old_exp_avg_values = exp_avg.sparse_mask(grad)._values() - exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1) - exp_avg.add_(make_sparse(exp_avg_update_values)) - old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values() - exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2) - exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values)) - - # Dense addition again is intended, avoiding another sparse_mask - numer = exp_avg_update_values.add_(old_exp_avg_values) - exp_avg_sq_update_values.add_(old_exp_avg_sq_values) - denom = exp_avg_sq_update_values.sqrt_().add_(group["eps"]) - del exp_avg_update_values, 
exp_avg_sq_update_values - - bias_correction1 = 1 - beta1 ** state["step"] - bias_correction2 = 1 - beta2 ** state["step"] - step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 - - p.data.add_(make_sparse(-step_size * numer.div_(denom))) - if weight_decay > 0.0: - p.data.add_(-group["lr"] * weight_decay, p.data.sparse_mask(grad)) - else: - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(1 - beta1, grad) - exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) - denom = exp_avg_sq.sqrt().add_(group["eps"]) - - bias_correction1 = 1 - beta1 ** state["step"] - bias_correction2 = 1 - beta2 ** state["step"] - step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 - - p.data.addcdiv_(-step_size, exp_avg, denom) - if weight_decay > 0.0: - p.data.add_(-group["lr"] * weight_decay, p.data) - - return loss diff --git a/libmultilabel/nn/plt.py b/libmultilabel/nn/plt.py index 2747691e..8865fc9c 100644 --- a/libmultilabel/nn/plt.py +++ b/libmultilabel/nn/plt.py @@ -6,15 +6,15 @@ from functools import reduce from pathlib import Path from typing import Generator -from datetime import datetime import numpy as np import torch +import torch.distributed as dist from multiprocessing import Process -import torch.distributed as dist -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging +from lightning import Trainer +from lightning.pytorch.callbacks import ModelCheckpoint, StochasticWeightAveraging +from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_info, rank_zero_warn from scipy.sparse import csr_matrix, csc_matrix from sklearn.preprocessing import normalize, MultiLabelBinarizer from torch import Tensor @@ -23,10 +23,8 @@ from tqdm import tqdm from pytorch_lightning.callbacks.early_stopping import EarlyStopping -from .data_utils import MultiLabelDataset, PLTDataset +from .datasets import MultiLabelDataset, PLTDataset from .model import PLTModel, BaseModel -from .nn_utils import is_global_zero -from ..common_utils import GLOBAL_RANK __all__ = ["PLTTrainer"] @@ -100,11 +98,10 @@ def __init__( self.max_epochs = self.config["epochs"] self.strategy = "ddp" if (self.num_nodes > 1 or self.devices > 1) and self.accelerator == "gpu" else "auto" - if GLOBAL_RANK == 0: - logger.info( - f"Accelerator: {self.accelerator}, devices: {self.devices}, num_nodes: {self.num_nodes} " - f"max_epochs: {self.max_epochs}, strategy: {self.strategy}" - ) + rank_zero_info( + f"Accelerator: {self.accelerator}, devices: {self.devices}, num_nodes: {self.num_nodes} " + f"max_epochs: {self.max_epochs}, strategy: {self.strategy}" + ) # dataloader parameters self.batch_size = self.config["batch_size"] @@ -251,8 +248,7 @@ def train_level( # model if best_model_path.exists(): # load existing best model - if trainer.is_global_zero: - logger.info(f"Best model loaded from {best_model_path}") + rank_zero_info(f"Best model loaded from {best_model_path}") model = BaseModel.load_from_checkpoint(best_model_path, top_k=self.top_k) else: # train & valid dataloaders for training @@ -283,19 +279,16 @@ def train_level( optimizer_params=self.config.get("optimizer_config"), swa_epoch_start=self.config["swa_epoch_start"][level], ) - if GLOBAL_RANK == 0: - logger.info(f"Training level-{level}. Number of labels: {num_nodes}") + rank_zero_info(f"Training level-{level}. 
Number of labels: {num_nodes}") trainer.fit(model, train_dataloader, valid_dataloader) # torch.cuda.empty_cache() - if GLOBAL_RANK == 0: - logger.info(f"Best last loaded from {best_model_path}") + rank_zero_info(f"Best last loaded from {best_model_path}") # FIXME: I met a bug while experimenting with ModelCheckpoint trainer.strategy.barrier() model = BaseModel.load_from_checkpoint(best_model_path) # TODO: figure out why model.optimizer = None - if GLOBAL_RANK == 0: - logger.info(f"Finish Training Level-{level}") + rank_zero_info(f"Finish Training Level-{level}") # Utilize single GPU to predict trainer = Trainer( @@ -303,12 +296,10 @@ def train_level( devices=1, accelerator=self.accelerator, ) - print(f"trainer.is_global_zero: {GLOBAL_RANK} {trainer.is_global_zero}") - if GLOBAL_RANK == 0: - logger.info( - f"Generating predictions for Level-{level + 1}. " - f"Number of possible predictions: {num_nodes}. Top k: {self.top_k}" - ) + rank_zero_info( + f"Generating predictions for Level-{level + 1}. " + f"Number of possible predictions: {num_nodes}. Top k: {self.top_k}" + ) # train & valid dataloaders for prediction (without labels) train_dataloader = DataLoader( MultiLabelDataset(train_x), @@ -335,11 +326,12 @@ def train_level( valid_node_score_pred, valid_node_y_pred = map(torch.vstack, list(zip(*valid_node_pred))) torch.cuda.empty_cache() - if GLOBAL_RANK == 0: - logger.info("Getting Candidates") + rank_zero_info("Getting Candidates") node_candidates = np.empty((len(train_x), self.top_k), dtype=np.int32) - prog = tqdm(train_node_y_pred, leave=False, desc="Parents") if GLOBAL_RANK == 0 else train_node_y_pred + prog = rank_zero_only(tqdm(train_node_y_pred, leave=False, desc="Parents")) + if prog is None: + prog = train_node_y_pred for i, ys in enumerate(prog): # true nodes/labels are positive positive = set(train_node_y.indices[train_node_y.indptr[i] : train_node_y.indptr[i + 1]]) @@ -390,11 +382,8 @@ def train_level( trainer.strategy.barrier() - print(f"{GLOBAL_RANK}: ") - if best_model_path.exists(): - if trainer.is_global_zero: - logger.info(f"Best model loaded from {best_model_path}") + rank_zero_info(f"Best model loaded from {best_model_path}") model = PLTModel.load_from_checkpoint(best_model_path, top_k=self.top_k) else: # train & valid dataloaders for training @@ -441,25 +430,21 @@ def train_level( ) # initialize current layer with weights from last layer - if GLOBAL_RANK == 0: - logger.info(f"Loading parameters of Level-{level} from Level-{level - 1}") + rank_zero_info(f"Loading parameters of Level-{level} from Level-{level - 1}") # remove the name prefix in state_dict starting with "network" model.load_from_pretrained(torch.load(self.get_best_model_path(level - 1))["state_dict"]) - if GLOBAL_RANK == 0: - logger.info( - f"Training Level-{level}, " - f"Number of nodes: {num_nodes}, " - f"Number of candidates: {train_dataloader.dataset.num_candidates}" - ) + rank_zero_info( + f"Training Level-{level}, " + f"Number of nodes: {num_nodes}, " + f"Number of candidates: {train_dataloader.dataset.num_candidates}" + ) trainer.fit(model, train_dataloader, valid_dataloader) trainer.save_checkpoint(best_model_path) # FIXME: I met a bug while experimenting with ModelCheckpoint - if GLOBAL_RANK == 0: - logger.info(f"Best model loaded from {best_model_path}") + rank_zero_info(f"Best model loaded from {best_model_path}") trainer.strategy.barrier() model = PLTModel.load_from_checkpoint(best_model_path) - if GLOBAL_RANK == 0: - logger.info(f"Finish training Level-{level}") + rank_zero_info(f"Finish 
training Level-{level}") # Utilize single GPU to predict trainer = Trainer( num_nodes=1, @@ -468,14 +453,12 @@ def train_level( ) # end training if it is the last level if level == self.num_levels - 1: - if GLOBAL_RANK == 0: - logger.info("Training process finished.") + rank_zero_info("Training process finished.") return - if GLOBAL_RANK == 0: - logger.info( - f"Generating predictions for Level-{level + 1}. " - f"Number of possible predictions: {num_nodes}, Top k: {self.top_k}" - ) + rank_zero_info( + f"Generating predictions for Level-{level + 1}. " + f"Number of possible predictions: {num_nodes}, Top k: {self.top_k}" + ) # train & valid dataloaders for prediction train_dataloader = DataLoader( @@ -534,11 +517,9 @@ def predict_level(self, level, test_x, num_nodes, test_y=None): pin_memory=self.pin_memory, ) - if GLOBAL_RANK == 0: - logger.info(f"Predicting Level-{level}, Top: {self.top_k}") + rank_zero_info(f"Predicting Level-{level}, Top: {self.top_k}") node_pred = trainer.predict(model, test_dataloader) - if GLOBAL_RANK == 0: - logger.info(f"node_pred: {len(node_pred)}") + rank_zero_info(f"node_pred: {len(node_pred)}") node_score_pred, node_label_pred = map(torch.vstack, list(zip(*node_pred))) return node_score_pred, node_label_pred @@ -571,11 +552,9 @@ def predict_level(self, level, test_x, num_nodes, test_y=None): trainer.test(model, test_dataloader) return - if GLOBAL_RANK == 0: - logger.info(f"Predicting Level-{level}, Top: {self.top_k}") + rank_zero_info(f"Predicting Level-{level}, Top: {self.top_k}") node_pred = trainer.predict(model, test_dataloader) - if GLOBAL_RANK == 0: - logger.info(f"node_pred: {len(node_pred)}") + rank_zero_info(f"node_pred: {len(node_pred)}") node_score_pred, node_label_pred = map(torch.vstack, list(zip(*node_pred))) return node_score_pred, node_label_pred @@ -593,10 +572,6 @@ def fit(self, datasets): # sparse training labels # TODO: remove workaround train_sparse_y_full = self.mlb.transform((i["label"] for i in train_data_full)) - cluster_process = Process( - target=build_shallow_and_wide_plt, - args=(train_sparse_x, train_sparse_y_full, self.levels, self.cluster_size, self.dir_path), - ) # TODO: remove workaround # TODO: we assume is 0. Remove this assumption or not? @@ -631,30 +606,32 @@ def fit(self, datasets): train_data = (train_x, self.mlb.transform((i["label"] for i in datasets["train"]))) valid_data = (valid_x, self.mlb.transform((i["label"] for i in datasets["val"]))) - # only do clustering on the main process - if GLOBAL_RANK == 0: - try: - cluster_process.start() - self.train_level(self.num_levels - 1, train_data, valid_data) - cluster_process.join() - finally: - # TODO: How to close process properly? - cluster_process.terminate() - cluster_process.close() - else: - self.train_level(self.num_levels - 1, train_data, valid_data) + # @rank_zero_only + # def start_cluster(): + # cluster_process = Process( + # target=build_shallow_and_wide_plt, + # args=(train_sparse_x, train_sparse_y_full, self.levels, self.cluster_size, self.dir_path), + # ) + # try: + # cluster_process.start() + # cluster_process.join() + # finally: + # # TODO: How to close process properly? 
+ # cluster_process.terminate() + # cluster_process.close() + + # start_cluster() + build_shallow_and_wide_plt(train_sparse_x, train_sparse_y_full, self.levels, self.cluster_size, self.dir_path) + self.train_level(self.num_levels - 1, train_data, valid_data) if dist.is_initialized(): dist.destroy_process_group() torch.cuda.empty_cache() - @is_global_zero + # Here are explanations that why we want to test on a single GPU? + # https://lightning.ai/docs/pytorch/stable/common/evaluation_intermediate.html + @rank_zero_only def test(self, dataset): - # why we want to test on a single GPU? - # https://lightning.ai/docs/pytorch/stable/common/evaluation_intermediate.html - if GLOBAL_RANK != 0: - return - test_x = list( map( lambda x: torch.tensor([self.word_dict[i] for i in x], dtype=torch.int), @@ -680,6 +657,7 @@ def get_cluster_path(self, level: int) -> Path: return self.dir_path / f"Level-{len(self.levels)}-{level}.npy" +@rank_zero_only def build_shallow_and_wide_plt( sparse_x, sparse_y: np.ndarray, @@ -697,8 +675,6 @@ def build_shallow_and_wide_plt( cluster_path: """ - if GLOBAL_RANK != 0: - return logger.info(f"Number of levels: {len(levels) + 1}") logger.info(f"Internal level at depth: {levels}") logger.info(f"Cluster size: {cluster_size}") diff --git a/main.py b/main.py index 5953a0b0..803aaf3f 100644 --- a/main.py +++ b/main.py @@ -6,7 +6,8 @@ import yaml -from libmultilabel.common_utils import Timer, AttributeDict, GLOBAL_RANK +from libmultilabel.common_utils import Timer, AttributeDict +from lightning.pytorch.utilities.rank_zero import rank_zero_info from libmultilabel.logging import add_stream_handler, add_collect_handler @@ -285,8 +286,6 @@ def check_config(config): def main(): - logging.info(f"Global rank: {GLOBAL_RANK}") - # Get config config = get_config() check_config(config) @@ -295,8 +294,8 @@ def main(): log_level = logging.WARNING if config.silent else logging.INFO stream_handler = add_stream_handler(log_level) collect_handler = add_collect_handler(logging.NOTSET) - if GLOBAL_RANK == 0: - logging.info(f"Run name: {config.run_name}") + + rank_zero_info(f"Run name: {config.run_name}") if config.linear: from linear_trainer import linear_run diff --git a/requirements.txt b/requirements.txt index 538c7974..6702676f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,13 @@ -liblinear-multicore nltk -numba -numpy pandas>1.3.0 -pytorch-lightning==1.7.7 PyYAML scikit-learn -scipy -torch>=2.0.0 -torchmetrics>=1.0 +torch>=1.13.1 +torchmetrics==0.10.3 torchtext>=0.13.0 +pytorch-lightning==1.7.7 tqdm +liblinear-multicore +numba +scipy transformers \ No newline at end of file diff --git a/torch_trainer.py b/torch_trainer.py index 80ef3098..9499d612 100644 --- a/torch_trainer.py +++ b/torch_trainer.py @@ -10,6 +10,7 @@ import torch.distributed as dist import pytorch_lightning as pl from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_info, rank_zero_warn from sklearn.preprocessing import MultiLabelBinarizer from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader @@ -17,11 +18,11 @@ from libmultilabel.common_utils import dump_log, is_multiclass_dataset, AttributeDict from libmultilabel.nn import data_utils -from libmultilabel.nn.data_utils import MultiLabelDataset, UNK +from libmultilabel.nn.data_utils import UNK +from libmultilabel.nn.datasets import MultiLabelDataset from libmultilabel.nn.model import Model, BaseModel -from libmultilabel.nn.nn_utils import 
init_device, init_model, init_trainer, set_seed, is_global_zero +from libmultilabel.nn.nn_utils import init_device, init_model, init_trainer, set_seed from libmultilabel.nn.plt import PLTTrainer -from libmultilabel.common_utils import GLOBAL_RANK class TorchTrainer: @@ -48,13 +49,11 @@ def __init__( embed_vecs=None, search_params: bool = False, save_checkpoints: bool = True, - ensemble_id: int = 0, ): self.run_name = config.run_name self.checkpoint_dir = config.checkpoint_dir self.log_path = config.log_path self.classes = classes - self.ensemble_id = ensemble_id os.makedirs(self.checkpoint_dir, exist_ok=True) # Set up seed & device @@ -62,11 +61,7 @@ def __init__( self.device = init_device(use_cpu=config.cpu) self.config = config - if config.model_name == "FastAttentionXML": - # TODO: Remove - import re - - self.model_path = re.search(r".+_.+", config.checkpoint_dir) + if self.config.model_name.lower() == "FastAttentionXML": if datasets is None: self.datasets = data_utils.load_datasets( training_data=config.training_file, @@ -75,48 +70,42 @@ def __init__( val_data=config.val_file, val_size=config.val_size, merge_train_val=config.merge_train_val, - # TODO: move to config tokenize_text=True, - lowercase=config.lowercase, tokenizer=config.get("tokenizer", "regex"), remove_no_label_data=config.remove_no_label_data, - random_state=config.get("random_state", 42), + random_state=config.get("random_state", 1270), ) if not (Path(config.test_file).parent / "word_dict.vocab").exists(): - if GLOBAL_RANK == 0: - logging.info("Calculating word dictionary and embeddings.") - word_dict, embed_vecs = data_utils.load_or_build_text_dict( - dataset=self.datasets["train"] + self.datasets["val"], - vocab_file=self.config.vocab_file, - # TODO: move to config - min_vocab_freq=self.config.min_vocab_freq, - embed_file=self.config.embed_file, - silent=self.config.silent, - normalize_embed=self.config.normalize_embed, - embed_cache_dir=self.config.embed_cache_dir, - max_tokens=self.config.get("max_tokens"), - unk_init=self.config.get("unk_init", "uniform"), - unk_init_param=self.config.get("unk_init_param", {-1, 1}), - apply_all=self.config.get("apply_all", True), - ) + rank_zero_info("Calculating word dictionary and embeddings.") + word_dict, embed_vecs = data_utils.load_or_build_text_dict( + dataset=self.datasets["train"] + self.datasets["val"], + vocab_file=self.config.vocab_file, + # TODO: move to config + min_vocab_freq=self.config.min_vocab_freq, + embed_file=self.config.embed_file, + silent=self.config.silent, + normalize_embed=self.config.normalize_embed, + embed_cache_dir=self.config.embed_cache_dir, + max_tokens=self.config.get("max_tokens"), + unk_init=self.config.get("unk_init", "uniform"), + unk_init_param=self.config.get("unk_init_param", {-1, 1}), + apply_all=self.config.get("apply_all", True), + ) + if word_dict is not None: torch.save(word_dict, Path(config.test_file).parent / "word_dict.vocab") torch.save(embed_vecs, Path(config.test_file).parent / "word_embeddings.tensor") # barrier - else: - while not (Path(config.test_file).parent / "word_dict.vocab").exists(): - time.sleep(15) + while not (Path(config.test_file).parent / "word_dict.vocab").exists(): + time.sleep(15) word_dict = torch.load(Path(config.test_file).parent / "word_dict.vocab") embed_vecs = torch.load(Path(config.test_file).parent / "word_embeddings.tensor") if not classes: - logging.warning(f"[Rank: {GLOBAL_RANK}] read labels.") classes = data_utils.load_or_build_label( self.datasets, self.config.label_file, 
self.config.include_test_labels ) - self.trainer = PLTTrainer( - config, classes=classes, ensemble_id=self.ensemble_id, word_dict=word_dict, embed_vecs=embed_vecs - ) + self.trainer = PLTTrainer(config, classes=classes, word_dict=word_dict, embed_vecs=embed_vecs) else: # Load pretrained tokenizer for dataset loader self.tokenizer = None @@ -132,7 +121,6 @@ def __init__( val_size=config.val_size, merge_train_val=config.merge_train_val, tokenize_text=tokenize_text, - lowercase=config.get("lowercase", True), tokenizer=config.get("tokenizer", "regex"), remove_no_label_data=config.remove_no_label_data, ) @@ -197,34 +185,35 @@ def _setup_model( else: logging.info("Initialize model from scratch.") if self.config.embed_file is not None: - # if not (Path(self.config.test_file).parent / "word_dict.vocab").exists(): - # if GLOBAL_RANK == 0: - logging.info("Calculating word dictionary and embeddings.") - logging.info("Load word dictionary ") - word_dict, embed_vecs = data_utils.load_or_build_text_dict( - dataset=self.datasets["train"] + self.datasets["val"], - vocab_file=self.config.vocab_file, - min_vocab_freq=self.config.min_vocab_freq, - embed_file=self.config.embed_file, - silent=self.config.silent, - normalize_embed=self.config.normalize_embed, - embed_cache_dir=self.config.embed_cache_dir, - max_tokens=self.config.get("max_tokens"), - unk_init=self.config.get("unk_init"), - unk_init_param=self.config.get("unk_init_param"), - apply_all=self.config.get("apply_all", False), - ) - # torch.save(word_dict, Path(self.config.test_file).parent / "word_dict.vocab") - # torch.save(embed_vecs, Path(self.config.test_file).parent / "word_embeddings.tensor") - # time.sleep(15) - # barrier - # else: - # while not (Path(self.config.test_file).parent / "word_dict.vocab").exists(): - # time.sleep(15) + if not (Path(self.config.test_file).parent / "word_dict.vocab").exists(): + rank_zero_info("Calculating word dictionary and embeddings.") + rank_zero_info("Load word dictionary") + word_dict, embed_vecs = data_utils.load_or_build_text_dict( + dataset=self.datasets["train"] + if self.config.model_name.lower() != "AttentionXML" + else self.datasets["train"] + self.datasets["val"], + vocab_file=self.config.vocab_file, + min_vocab_freq=self.config.min_vocab_freq, + embed_file=self.config.embed_file, + silent=self.config.silent, + normalize_embed=self.config.normalize_embed, + embed_cache_dir=self.config.embed_cache_dir, + max_tokens=self.config.get("max_tokens"), + unk_init=self.config.get("unk_init"), + unk_init_param=self.config.get("unk_init_param"), + apply_all=self.config.get("apply_all", False), + ) + if word_dict is not None: + torch.save(word_dict, Path(self.config.test_file).parent / "word_dict.vocab") + torch.save(embed_vecs, Path(self.config.test_file).parent / "word_embeddings.tensor") + # barrier + while not (Path(self.config.test_file).parent / "word_dict.vocab").exists(): + time.sleep(15) - # word_dict = torch.load(Path(self.config.test_file).parent / "word_dict.vocab") + word_dict = torch.load(Path(self.config.test_file).parent / "word_dict.vocab") + embed_vecs = torch.load(Path(self.config.test_file).parent / "word_embeddings.tensor") self.word_dict = word_dict - # embed_vecs = torch.load(Path(self.config.test_file).parent / "word_embeddings.tensor") + self.embed_vecs = embed_vecs if not classes: classes = data_utils.load_or_build_label( @@ -300,12 +289,12 @@ def train(self): self.trainer is not None ), "Please make sure the trainer is successfully initialized by `self._setup_trainer()`." 
-        if self.config.model_name == "FastAttentionXML":
+        if self.config.model_name.lower() == "fastattentionxml":
             self.trainer.fit(self.datasets)
         else:
-            if self.config.model_name == "AttentionXML":
+            if self.config.model_name.lower() == "attentionxml":
                 if (
-                    Path(self.config.dir_path) / f"{self.config.data_name}-{self.config.model_name}-{0}" / "Model.ckpt"
+                    Path(self.config.dir_path) / f"{self.config.data_name}-{self.config.model_name}" / "Model.ckpt"
                 ).exists():
                     return
@@ -358,7 +347,7 @@ def train(self):
                 If you want to save the best and the last model, please set `save_checkpoints` to True."
             )

-    @is_global_zero
+    @rank_zero_only
     def test(self, split="test"):
         """Test model with pytorch lightning trainer. Top-k predictions are saved if `save_k_predictions` > 0.
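
For reference, a minimal sketch of the rank-zero pattern that replaces the old `GLOBAL_RANK == 0` guards in this patch. It assumes Lightning >= 2.0, matching the `lightning.pytorch.utilities.rank_zero` imports above; the helper names `save_train_cache` and `load_and_cache` are hypothetical and not part of libmultilabel.

from pathlib import Path

import numpy as np
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only


@rank_zero_only
def save_train_cache(path: Path, records: list) -> None:
    """Hypothetical helper: write the tokenized-data cache from global rank 0 only."""
    np.save(path, records)


def load_and_cache(training_file: str, records: list) -> None:
    # rank_zero_info logs only on global rank 0, replacing `if GLOBAL_RANK == 0: logging.info(...)`.
    rank_zero_info(f"Loading data from {training_file}.")
    # rank_zero_only wraps a callable: the wrapped function executes on rank 0 and
    # returns None on every other rank. It must receive the function itself;
    # rank_zero_only(np.save(path, records)) would evaluate np.save on every rank
    # before the wrapper is created.
    save_train_cache(Path(training_file).with_suffix(".npy"), records)

The decorator form mirrors `@rank_zero_only` on the `test` methods above; wrapping and then calling, as in `rank_zero_only(np.save)(path, records)`, is equivalent.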