diff --git a/libmultilabel/linear/tree.py b/libmultilabel/linear/tree.py index 0c428fec..4436fb09 100644 --- a/libmultilabel/linear/tree.py +++ b/libmultilabel/linear/tree.py @@ -7,6 +7,7 @@ import sklearn.cluster import sklearn.preprocessing from tqdm import tqdm +import psutil from . import linear @@ -135,13 +136,25 @@ def train_tree( root = _build_tree(label_representation, np.arange(y.shape[1]), 0, K, dmax) num_nodes = 0 + label_feature_used = (y.T * (x != 0)).tocsr() def count(node): nonlocal num_nodes num_nodes += 1 + node.num_nnz_feat = np.count_nonzero(label_feature_used[node.label_map,:].sum(axis=0)) root.dfs(count) + # Calculate the total memory (excluding swap) on the local machine + total_memory = psutil.virtual_memory().total + print(f'{total_memory / (1024**3):.3f} GB') + + model_size = get_estimated_model_size(root, num_nodes) + print(f'*** model_size: {model_size / (1024**3):.3f} GB') + + if (total_memory <= model_size): + raise MemoryError(f'Not enough memory to train the model. 
def get_estimated_model_size(root: "Node", num_nodes: int) -> float:
    """Estimate the number of bytes needed to store the weights of a tree model.

    Every node contributes ``num_nnz_feat * num_branches`` weight entries,
    where ``num_branches`` is the number of labels for a leaf and the number
    of children for an internal node.

    Args:
        root: Root of the tree. Each node visited by ``root.dfs`` must already
            have its ``num_nnz_feat`` attribute set.
        num_nodes: Number of nodes in the tree. Kept for backward
            compatibility only: the estimate is accumulated while walking the
            tree, so an inaccurate count can no longer overrun a buffer.

    Returns:
        The estimated size in bytes.
    """
    # Each stored entry of a sparse weight matrix costs one index (int64,
    # 8 bytes) plus one value (float, 8 bytes).
    bytes_per_entry = 16
    total_entries = 0

    def collect_stat(node: "Node") -> None:
        nonlocal total_entries
        # A leaf has one weight column per label; an internal node has one
        # per child.
        if node.isLeaf():
            num_branches = len(node.label_map)
        else:
            num_branches = len(node.children)
        total_entries += node.num_nnz_feat * num_branches

    root.dfs(collect_stat)

    return float(total_entries * bytes_per_entry)