Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
tsilver-bdai committed Jun 29, 2023
1 parent c6f7fce commit 9aea569
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 8 deletions.
5 changes: 4 additions & 1 deletion predicators/approaches/active_sampler_learning_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,10 @@ def _learn_nsrt_sampler(self, nsrt_data: _OptionSamplerDataset,
weight_init="default")
else:
assert CFG.active_sampler_learning_model.endswith("knn")
classifier = KNeighborsClassifier(seed=CFG.seed)
n_neighbors = min(len(X_arr_classifier),
CFG.active_sampler_learning_knn_neighbors)
classifier = KNeighborsClassifier(seed=CFG.seed,
n_neighbors=n_neighbors)
classifier.fit(X_arr_classifier, y_arr_classifier)

# Save the sampler classifier for external analysis.
Expand Down
3 changes: 3 additions & 0 deletions predicators/ml_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,9 @@ def classify(self, x: Array) -> bool:

def predict_proba(self, x: Array) -> float:
probs = self._model.predict_proba([x])[0]
# Special case: only one class.
if probs.shape == (1, ):
return float(self.classify(x))
assert probs.shape == (2, ) # [P(x is class 0), P(x is class 1)]
return probs[1] # return the second element of probs

Expand Down
1 change: 1 addition & 0 deletions predicators/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ class GlobalSettings:
# active sampler learning parameters
active_sampler_learning_model = "myopic_classifier_mlp"
active_sampler_learning_feature_selection = "all"
active_sampler_learning_knn_neighbors = 3
active_sampler_learning_use_teacher = True
active_sampler_learning_num_samples = 100
active_sampler_learning_score_gamma = 0.5
Expand Down
18 changes: 11 additions & 7 deletions tests/approaches/test_active_sampler_learning_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,22 @@
from predicators.teacher import Teacher


@pytest.mark.parametrize("model_name,right_targets,num_demo",
[("myopic_classifier_mlp", False, 0),
("myopic_classifier_mlp", True, 1),
("myopic_classifier_ensemble", False, 0),
("myopic_classifier_ensemble", False, 1),
("fitted_q", False, 0), ("fitted_q", True, 0)])
def test_active_sampler_learning_approach(model_name, right_targets, num_demo):
@pytest.mark.parametrize("model_name,right_targets,num_demo,feat_type",
[("myopic_classifier_mlp", False, 0, "all"),
("myopic_classifier_mlp", True, 1, "all"),
("myopic_classifier_ensemble", False, 0, "all"),
("myopic_classifier_ensemble", False, 1, "all"),
("fitted_q", False, 0, "all"),
("fitted_q", True, 0, "all"),
("myopic_classifier_knn", False, 0, "oracle")])
def test_active_sampler_learning_approach(model_name, right_targets, num_demo,
feat_type):
"""Test for ActiveSamplerLearningApproach class, entire pipeline."""
utils.reset_config({
"env": "bumpy_cover",
"approach": "active_sampler_learning",
"active_sampler_learning_model": model_name,
"active_sampler_learning_feature_selection": feat_type,
"timeout": 10,
"strips_learner": "oracle",
"sampler_learner": "oracle",
Expand Down

0 comments on commit 9aea569

Please sign in to comment.