Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add multilabel_probabilities #15

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions small_text/classifiers/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def fit(self, train_set, weights=None):
self.model.fit(train_set.x, y, **fit_kwargs)
return self

def predict(self, data_set, return_proba=False):
def predict(self, data_set, return_proba=False,sparse_proba=False):
"""
Predicts the labels for the given dataset.

Expand All @@ -122,14 +122,20 @@ def predict(self, data_set, return_proba=False):
A dataset for which the labels are to be predicted.
return_proba : bool, default=False
If `True`, also returns a probability-like class distribution.
sparse_proba: bool, default=False
Only relevant if return_proba=True and multilabel=True.
If False probabilities for each label is returned as a numpy array shape(samples,labels)
If True only probabilities of successfully predicted labels are returned in a sparce matrix.

Returns
-------
predictions : np.ndarray[np.int32] or csr_matrix[np.int32]
List of predictions if the classifier was fitted on multi-label data,
otherwise a sparse matrix of predictions.
probas : np.ndarray[np.float32]
List of probabilities (or confidence estimates) if `return_proba` is True.
probas : np.ndarray[np.float32] or csr_matrix (optional)
List of probabilities (or confidence estimates) if `return_proba` is True and binary classification or multi-class classification.
csr_matrix if multilabel classidication and sparse_proba =False
List of lists if multilabel classidication and sparse_proba = True
"""
if len(data_set) == 0:
return empty_result(self.multi_label, self.num_classes, return_prediction=True,
Expand All @@ -138,7 +144,7 @@ def predict(self, data_set, return_proba=False):
proba = self.model.predict_proba(data_set.x)

return prediction_result(proba, self.multi_label, self.num_classes, enc=None,
return_proba=return_proba)
return_proba=return_proba, sparse_proba=sparse_proba)

def predict_proba(self, data_set):
if len(data_set) == 0:
Expand Down
8 changes: 6 additions & 2 deletions small_text/integrations/pytorch/classifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,18 @@ def __init__(self, multi_label=False, device=None, mini_batch_size=32):
def fit(self, train_set, validation_set=None, weights=None, **kwargs):
pass

def predict(self, data_set, return_proba=False):
def predict(self, data_set, return_proba=False, sparse_proba=True):
"""
Parameters
----------
data_set : small_text.data.Dataset
A dataset on whose instances predictions are made.
return_proba : bool
If True, additionally returns the confidence distribution over all classes.
sparse_proba: bool, default=True
Only relevant if return_proba=True and multilabel=True.
If False probabilities for each label is returned as a numpy array shape(samples,labels)
If True only probabilities of successfully predicted labels are returned in a sparce matrix.

Returns
-------
Expand All @@ -72,7 +76,7 @@ def predict(self, data_set, return_proba=False):
return_proba=return_proba)

proba = self.predict_proba(data_set)
predictions = prediction_result(proba, self.multi_label, self.num_classes, enc=self.enc_)
predictions = prediction_result(proba, self.multi_label, self.num_classes, enc=self.enc_,sparse_proba=sparse_proba)

if return_proba:
return predictions, proba
Expand Down
8 changes: 6 additions & 2 deletions small_text/integrations/pytorch/classifiers/kimcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def validate(self, validation_set):

return valid_loss / len(validation_set), acc / len(validation_set)

def predict(self, data_set, return_proba=False):
def predict(self, data_set, return_proba=False, sparse_proba=True):
"""
Predicts the labels for the given dataset.

Expand All @@ -468,6 +468,10 @@ def predict(self, data_set, return_proba=False):
A dataset on whose instances predictions are made.
return_proba : bool
If True, additionally returns the confidence distribution over all classes.
sparse_proba: bool, default=True
Only relevant if return_proba=True and multilabel=True.
If False probabilities for each label is returned as a numpy array shape(samples,labels)
If True only probabilities of successfully predicted labels are returned in a sparce matrix.

Returns
-------
Expand All @@ -477,7 +481,7 @@ def predict(self, data_set, return_proba=False):
probas : np.ndarray[np.float32] (optional)
List of probabilities (or confidence estimates) if `return_proba` is True.
"""
return super().predict(data_set, return_proba=return_proba)
return super().predict(data_set, return_proba=return_proba, sparse_proba=sparse_proba)

def predict_proba(self, test_set):
if len(test_set) == 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ def validate(self, validation_set):

return valid_loss / len(validation_set), acc / len(validation_set)

def predict(self, data_set, return_proba=False):
def predict(self, data_set, return_proba=False, sparse_proba=True):
"""
Parameters
----------
Expand All @@ -653,8 +653,12 @@ def predict(self, data_set, return_proba=False):
otherwise a sparse matrix of predictions.
probas : np.ndarray[np.float32], optional
List of probabilities (or confidence estimates) if `return_proba` is True.
sparse_proba: bool, default=True
Only relevant if return_proba=True and multilabel=True.
If False probabilities for each label is returned as a numpy array shape(samples,labels)
If True only probabilities of successfully predicted labels are returned in a sparce matrix.
"""
return super().predict(data_set, return_proba=return_proba)
return super().predict(data_set, return_proba=return_proba, sparse_proba=sparse_proba)

def predict_proba(self, test_set):
if len(test_set) == 0:
Expand Down
8 changes: 6 additions & 2 deletions small_text/utils/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def get_splits(train_set, validation_set, weights=None, multi_label=False, valid
return result


def prediction_result(proba, multi_label, num_classes, enc=None, return_proba=False):
def prediction_result(proba, multi_label, num_classes, enc=None, return_proba=False, sparse_proba=True):
"""Helper method which returns a single- or multi-label prediction result.

Parameters
Expand All @@ -79,6 +79,10 @@ def prediction_result(proba, multi_label, num_classes, enc=None, return_proba=Fa
Also returns the probability if `True`. This is intended to be used with `multi_label=True`
where it returns a sparse matrix with only the probabilities for the predicted labels. For
the single-label case this simply returns the given `proba` input.
sparse_proba: bool, default=True
Only relevant if return_proba=True and multilabel=True.
If True probabilities for each label is returned as a numpy array shape(samples,labels)
If 'False only probabilities of successfully predicted labels are returned in a sparce matrix.

Returns
-------
Expand All @@ -97,7 +101,7 @@ def multihot_to_list(x):
predictions = [multihot_to_list(row) for row in predictions_binarized]
predictions = list_to_csr(predictions, shape=(len(predictions), num_classes))

if return_proba:
if return_proba and sparse_proba:
data = proba[predictions_binarized.astype(bool)]
proba = csr_matrix((data, predictions.indices, predictions.indptr),
shape=predictions.shape,
Expand Down
27 changes: 25 additions & 2 deletions tests/unit/small_text/utils/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ def test_prediction_result_multilabel(self):
]))
assert_csr_matrix_equal(expected, result)

def test_prediction_result_multilabel_with_proba(self):
def test_prediction_result_multilabel_with_sparse_proba(self):
proba = np.array([
[0.1, 0.2, 0.6, 0.1],
[0.25, 0.25, 0.25, 0.25],
[0.3, 0.3, 0.2, 0.2],
[0.3, 0.2, 0.5, 0.1],
])
result, proba_result = prediction_result(proba, True, proba.shape[1], return_proba=True)
result, proba_result = prediction_result(proba, True, proba.shape[1], return_proba=True, sparse_proba=True)
expected = csr_matrix(np.array([
[0, 0, 1, 0],
[0, 0, 0, 0],
Expand All @@ -74,6 +74,29 @@ def test_prediction_result_multilabel_with_proba(self):
]))
assert_csr_matrix_equal(expected_proba, proba_result)

def test_prediction_result_multilabel_with_all_proba(self):
proba = np.array([
[0.1, 0.2, 0.6, 0.1],
[0.25, 0.25, 0.25, 0.25],
[0.3, 0.3, 0.2, 0.2],
[0.3, 0.2, 0.5, 0.1],
])
result, proba_result = prediction_result(proba, True, proba.shape[1], return_proba=True, sparse_proba=False)
expected = csr_matrix(np.array([
[0, 0, 1, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
]))
assert_csr_matrix_equal(expected, result)
expected_proba = np.array([
[0.1, 0.2, 0.6, 0.1],
[0.25, 0.25, 0.25, 0.25],
[0.3, 0.3, 0.2, 0.2],
[0.3, 0.2, 0.5, 0.1]
])
assert_array_equal(expected_proba, proba_result)

def test_prediction_result_multilabel_with_enc(self):
all_labels = [[0], [0, 1], [2, 3]]
enc = MultiLabelBinarizer()
Expand Down