Skip to content

Commit

Permalink
is_multilabel is renamed multiclass
Browse files Browse the repository at this point in the history
  • Loading branch information
donglihe-hub committed Dec 1, 2023
1 parent 173ee1b commit abfb587
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 23 deletions.
34 changes: 17 additions & 17 deletions libmultilabel/linear/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ def __init__(
weights: np.matrix,
bias: float,
thresholds: float | np.ndarray,
is_multilabel: bool,
multiclass: bool,
):
self.name = name
self.weights = weights
self.bias = bias
self.thresholds = thresholds
self.is_multilabel = is_multilabel
self.multiclass = multiclass

def predict_values(self, x: sparse.csr_matrix) -> np.ndarray:
"""Calculates the decision values associated with x.
Expand Down Expand Up @@ -72,7 +72,7 @@ def predict_values(self, x: sparse.csr_matrix) -> np.ndarray:
def train_1vsrest(
y: sparse.csr_matrix,
x: sparse.csr_matrix,
is_multilabel: bool,
multiclass: bool,
options: str = "",
verbose: bool = True,
) -> FlatModel:
Expand All @@ -81,7 +81,7 @@ def train_1vsrest(
Args:
y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
is_multilabel (bool): A flag indicating if the dataset is multilabel.
multiclass (bool): A flag indicating if the dataset is multiclass.
options (str, optional): The option string passed to liblinear. Defaults to ''.
verbose (bool, optional): Output extra progress information. Defaults to True.
Expand All @@ -107,7 +107,7 @@ def train_1vsrest(
weights=np.asmatrix(weights),
bias=bias,
thresholds=0,
is_multilabel=is_multilabel,
multiclass=multiclass,
)


Expand Down Expand Up @@ -162,7 +162,7 @@ def _prepare_options(x: sparse.csr_matrix, options: str) -> tuple[sparse.csr_mat
def train_thresholding(
y: sparse.csr_matrix,
x: sparse.csr_matrix,
is_multilabel: bool,
multiclass: bool,
options: str = "",
verbose: bool = True,
) -> FlatModel:
Expand All @@ -179,7 +179,7 @@ def train_thresholding(
Args:
y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
is_multilabel (bool): A flag indicating if the dataset is multilabel.
multiclass (bool): A flag indicating if the dataset is multiclass.
options (str, optional): The option string passed to liblinear. Defaults to ''.
verbose (bool, optional): Output extra progress information. Defaults to True.
Expand Down Expand Up @@ -214,7 +214,7 @@ def train_thresholding(
weights=np.asmatrix(weights),
bias=bias,
thresholds=thresholds,
is_multilabel=is_multilabel,
multiclass=multiclass,
)


Expand Down Expand Up @@ -389,7 +389,7 @@ def _fmeasure(y_true: np.ndarray, y_pred: np.ndarray) -> float:
def train_cost_sensitive(
y: sparse.csr_matrix,
x: sparse.csr_matrix,
is_multilabel: bool,
multiclass: bool,
options: str = "",
verbose: bool = True,
) -> FlatModel:
Expand All @@ -403,7 +403,7 @@ def train_cost_sensitive(
Args:
y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
is_multilabel (bool): A flag indicating if the dataset is multilabel.
multiclass (bool): A flag indicating if the dataset is multiclass.
options (str, optional): The option string passed to liblinear. Defaults to ''.
verbose (bool, optional): Output extra progress information. Defaults to True.
Expand All @@ -430,7 +430,7 @@ def train_cost_sensitive(
weights=np.asmatrix(weights),
bias=bias,
thresholds=0,
is_multilabel=is_multilabel,
multiclass=multiclass,
)


Expand Down Expand Up @@ -493,7 +493,7 @@ def _cross_validate(y: np.ndarray, x: sparse.csr_matrix, options: str, perm: np.
def train_cost_sensitive_micro(
y: sparse.csr_matrix,
x: sparse.csr_matrix,
is_multilabel: bool,
multiclass: bool,
options: str = "",
verbose: bool = True,
) -> FlatModel:
Expand All @@ -507,7 +507,7 @@ def train_cost_sensitive_micro(
Args:
y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
is_multilabel (bool): A flag indicating if the dataset is multilabel.
multiclass (bool): A flag indicating if the dataset is multiclass.
options (str, optional): The option string passed to liblinear. Defaults to ''.
verbose (bool, optional): Output extra progress information. Defaults to True.
Expand Down Expand Up @@ -557,14 +557,14 @@ def train_cost_sensitive_micro(
weights=np.asmatrix(weights),
bias=bias,
thresholds=0,
is_multilabel=is_multilabel,
multiclass=multiclass,
)


def train_binary_and_multiclass(
y: sparse.csr_matrix,
x: sparse.csr_matrix,
is_multilabel: bool,
multiclass: bool,
options: str = "",
verbose: bool = True,
) -> FlatModel:
Expand All @@ -573,7 +573,7 @@ def train_binary_and_multiclass(
Args:
y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
is_multilabel (bool): A flag indicating if the dataset is multilabel.
multiclass (bool): A flag indicating if the dataset is multiclass.
options (str, optional): The option string passed to liblinear. Defaults to ''.
verbose (bool, optional): Output extra progress information. Defaults to True.
Expand Down Expand Up @@ -614,7 +614,7 @@ def train_binary_and_multiclass(
weights=np.asmatrix(weights),
bias=bias,
thresholds=thresholds,
is_multilabel=is_multilabel,
multiclass=multiclass,
)


Expand Down
10 changes: 4 additions & 6 deletions linear_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@


def linear_test(config, model, datasets, label_mapping):
metrics = linear.get_metrics(
config.monitor_metrics, datasets["test"]["y"].shape[1], multiclass=not model.is_multilabel
)
metrics = linear.get_metrics(config.monitor_metrics, datasets["test"]["y"].shape[1], multiclass=model.multiclass)
num_instance = datasets["test"]["x"].shape[0]
k = config.save_k_predictions
if k > 0:
Expand All @@ -38,11 +36,11 @@ def linear_test(config, model, datasets, label_mapping):

def linear_train(datasets, config):
# detect task type
is_multilabel = not is_multiclass_dataset(datasets["train"], "y")
multiclass = is_multiclass_dataset(datasets["train"], "y")

# train
if config.linear_technique == "tree":
if not is_multilabel:
if multiclass:
raise ValueError("Tree model should only be used with multilabel datasets.")

model = LINEAR_TECHNIQUES[config.linear_technique](
Expand All @@ -56,7 +54,7 @@ def linear_train(datasets, config):
model = LINEAR_TECHNIQUES[config.linear_technique](
datasets["train"]["y"],
datasets["train"]["x"],
is_multilabel=is_multilabel,
multiclass=multiclass,
options=config.liblinear_options,
)
return model
Expand Down

0 comments on commit abfb587

Please sign in to comment.