Skip to content

Commit

Permalink
Do not use memory to store nnz statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
ericliu8168 committed Aug 21, 2024
1 parent 56bf1a5 commit 73b8ddd
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions libmultilabel/linear/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def count(node):
total_memory = psutil.virtual_memory().total
print(f'Your system memory is: {total_memory / (1024**3):.3f} GB')

model_size = get_estimated_model_size(root, num_nodes)
model_size = get_estimated_model_size(root)
print(f'The estimated tree model size is: {model_size / (1024**3):.3f} GB')

if (total_memory <= model_size):
Expand Down Expand Up @@ -208,24 +208,21 @@ def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray,
return Node(label_map=label_map, children=children)


def get_estimated_model_size(root, num_nodes):
num_nnz_feat, num_branches = np.zeros(num_nodes), np.zeros(num_nodes)
num_nodes = 0
def get_estimated_model_size(root):
total_num_weights = 0

def collect_stat(node: Node):
nonlocal num_nodes
num_nnz_feat[num_nodes] = node.num_nnz_feat
nonlocal total_num_weights

if node.isLeaf():
num_branches[num_nodes] = len(node.label_map)
total_num_weights += len(node.label_map) * node.num_nnz_feat
else:
num_branches[num_nodes] = len(node.children)

num_nodes += 1
total_num_weights += len(node.children) * node.num_nnz_feat

root.dfs(collect_stat)

# Multiply by 16 because each stored entry of a sparse matrix takes 16 bytes: 8 for its index (int64) and 8 for its float value
return np.dot(num_nnz_feat, num_branches) * 16
return total_num_weights * 16


def _train_node(y: sparse.csr_matrix, x: sparse.csr_matrix, options: str, node: Node):
Expand Down

0 comments on commit 73b8ddd

Please sign in to comment.