Skip to content

Commit

Permalink
rewrite for reproduction purpose only
Browse files Browse the repository at this point in the history
  • Loading branch information
donglihe-hub committed Nov 30, 2023
1 parent 8fa186e commit d0f433e
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 22 deletions.
9 changes: 0 additions & 9 deletions libmultilabel/linear/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,6 @@ def train_thresholding(
Returns:
A model which can be used in predict_values.
"""
if not is_multilabel:
raise ValueError("thresholding method doesn't support binary/multiclass datasets.")

x, options, bias = _prepare_options(x, options)

y = y.tocsc()
Expand Down Expand Up @@ -413,9 +410,6 @@ def train_cost_sensitive(
Returns:
A model which can be used in predict_values.
"""
if not is_multilabel:
raise ValueError("cost_sensitive method doesn't support binary/multiclass datasets.")

# Follows the MATLAB implementation at https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/multilabel/
x, options, bias = _prepare_options(x, options)

Expand Down Expand Up @@ -520,9 +514,6 @@ def train_cost_sensitive_micro(
Returns:
A model which can be used in predict_values.
"""
if not is_multilabel:
raise ValueError("cost_sensitive_micro method doesn't support binary/multiclass datasets.")

# Follows the MATLAB implementation at https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/multilabel/
x, options, bias = _prepare_options(x, options)

Expand Down
14 changes: 1 addition & 13 deletions linear_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,7 @@ def linear_test(config, model, datasets, label_mapping):

def linear_train(datasets, config):
# detect task type
is_multilabel = config.get("is_multilabel", "auto")
if is_multilabel == "auto":
is_multilabel = not is_multiclass_dataset(datasets["train"], "y")
elif not isinstance(is_multilabel, bool):
raise ValueError(
f'"is_multilabel" is expected to be either "auto", "True", or "False". But got "{is_multilabel}" instead.'
)

task_type = "multilabel" if is_multilabel else "binary/multiclass"
logging.info(
f'is_multilabel is set to "{config.get("is_multilabel", "auto")}". '
f"Model will be trained in {task_type} mode."
)
is_multilabel = not is_multiclass_dataset(datasets["train"], "y")

# train
if config.linear_technique == "tree":
Expand Down

0 comments on commit d0f433e

Please sign in to comment.