From 73b8ddd0b510dd33a32c743f761bc7440ea96901 Mon Sep 17 00:00:00 2001 From: ericliu8168 Date: Wed, 21 Aug 2024 15:39:55 +0800 Subject: [PATCH] not use memory to store nnz statistics --- libmultilabel/linear/tree.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/libmultilabel/linear/tree.py b/libmultilabel/linear/tree.py index 0227c911..71e285b5 100644 --- a/libmultilabel/linear/tree.py +++ b/libmultilabel/linear/tree.py @@ -149,7 +149,7 @@ def count(node): total_memory = psutil.virtual_memory().total print(f'Your system memory is: {total_memory / (1024**3):.3f} GB') - model_size = get_estimated_model_size(root, num_nodes) + model_size = get_estimated_model_size(root) print(f'The estimated tree model size is: {model_size / (1024**3):.3f} GB') if (total_memory <= model_size): @@ -208,24 +208,21 @@ def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray, return Node(label_map=label_map, children=children) -def get_estimated_model_size(root, num_nodes): - num_nnz_feat, num_branches = np.zeros(num_nodes), np.zeros(num_nodes) - num_nodes = 0 +def get_estimated_model_size(root): + total_num_weights = 0 + def collect_stat(node: Node): - nonlocal num_nodes - num_nnz_feat[num_nodes] = node.num_nnz_feat + nonlocal total_num_weights if node.isLeaf(): - num_branches[num_nodes] = len(node.label_map) + total_num_weights += len(node.label_map) * node.num_nnz_feat else: - num_branches[num_nodes] = len(node.children) - - num_nodes += 1 + total_num_weights += len(node.children) * node.num_nnz_feat root.dfs(collect_stat) # 16 is because when storing sparse matrices, indices (int64) require 8 bytes and floats require 8 bytes - return np.dot(num_nnz_feat, num_branches) * 16 + return total_num_weights * 16 def _train_node(y: sparse.csr_matrix, x: sparse.csr_matrix, options: str, node: Node):