Skip to content

Commit

Permalink
fix format
Browse files Browse the repository at this point in the history
  • Loading branch information
EdenWuyifan committed Jun 30, 2023
1 parent 552bc6f commit be8c240
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions alpha_automl/builtin_primitives/semisupervised_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


class SkSelfTrainingClassifier(BasePrimitive):
sdg_params = dict(alpha=1e-5, penalty="l2", loss="log_loss")
sdg_params = dict(alpha=1e-5, penalty='l2', loss='log_loss')
model = SelfTrainingClassifier(SGDClassifier(**sdg_params), verbose=True)

def fit(self, X, y=None):
Expand All @@ -38,14 +38,14 @@ def fit(self, X, y=None):
if isinstance(X, np.ndarray):
self.pipe = Pipeline(
[
("sklearn.semi_supervised.LabelSpreading", LabelSpreading()),
('sklearn.semi_supervised.LabelSpreading', LabelSpreading()),
]
)
else:
self.pipe = Pipeline(
[
("toarray", FunctionTransformer(lambda x: x.toarray())),
("sklearn.semi_supervised.LabelSpreading", LabelSpreading()),
('toarray', FunctionTransformer(lambda x: x.toarray())),
('sklearn.semi_supervised.LabelSpreading', LabelSpreading()),
]
)
self.pipe.fit(X, y)
Expand All @@ -62,14 +62,14 @@ def fit(self, X, y=None):
if isinstance(X, np.ndarray):
self.pipe = Pipeline(
[
("sklearn.semi_supervised.LabelPropagation", LabelPropagation()),
('sklearn.semi_supervised.LabelPropagation', LabelPropagation()),
]
)
else:
self.pipe = Pipeline(
[
("toarray", FunctionTransformer(lambda x: x.toarray())),
("sklearn.semi_supervised.LabelPropagation", LabelPropagation()),
('toarray', FunctionTransformer(lambda x: x.toarray())),
('sklearn.semi_supervised.LabelPropagation', LabelPropagation()),
]
)
self.pipe.fit(X, y)
Expand Down Expand Up @@ -115,7 +115,7 @@ def fit(self, X, y):
entIdx = np.rec.fromarrays((entropies, unlabeledIx))
entIdx.sort(axis=0)

labelableIndices = entIdx["f1"][-num_instances_to_label:].reshape((-1,))
labelableIndices = entIdx['f1'][-num_instances_to_label:].reshape((-1,))

predictions = self.pipeline.predict(X[labelableIndices])

Expand Down

0 comments on commit be8c240

Please sign in to comment.