Skip to content

Commit

Permalink
fix arg missing issue
Browse files Browse the repository at this point in the history
  • Loading branch information
donglihe-hub committed Nov 29, 2023
1 parent 8326176 commit aecb787
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions libmultilabel/linear/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,14 @@ def _train_node(y: sparse.csr_matrix, x: sparse.csr_matrix, options: str, node:
node (Node): Node to be trained.
"""
if node.isLeaf():
node.model = linear.train_1vsrest(y[:, node.label_map], x, options, False)
node.model = linear.train_1vsrest(y[:, node.label_map], x, True, options, False)
else:
# meta_y[i, j] is 1 if the ith instance is relevant to the jth child.
# getnnz returns an ndarray of shape number of instances.
# This must be reshaped into number of instances * 1 to be interpreted as a column.
meta_y = [y[:, child.label_map].getnnz(axis=1)[:, np.newaxis] > 0 for child in node.children]
meta_y = sparse.csr_matrix(np.hstack(meta_y))
node.model = linear.train_1vsrest(meta_y, x, options, False)
node.model = linear.train_1vsrest(meta_y, x, True, options, False)

node.model.weights = sparse.csc_matrix(node.model.weights)

Expand Down

0 comments on commit aecb787

Please sign in to comment.