Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
tsilver-bdai committed Jun 29, 2023
1 parent c534660 commit f6dd4c8
Showing 1 changed file with 42 additions and 44 deletions.
86 changes: 42 additions & 44 deletions predicators/approaches/active_sampler_learning_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from predicators.approaches.online_nsrt_learning_approach import \
OnlineNSRTLearningApproach
from predicators.explorers import BaseExplorer, create_explorer
from predicators.ml_models import BinaryClassifierEnsemble, \
MLPBinaryClassifier, MLPRegressor
from predicators.ml_models import BinaryClassifier, BinaryClassifierEnsemble, \
KNeighborsClassifier, MLPBinaryClassifier, MLPRegressor
from predicators.settings import CFG
from predicators.structs import NSRT, Array, GroundAtom, LowLevelTrajectory, \
NSRTSampler, Object, ParameterizedOption, Predicate, Segment, State, \
Expand Down Expand Up @@ -247,34 +247,34 @@ class _ClassifierWrappedSamplerLearner(_WrappedSamplerLearner):

def _learn_nsrt_sampler(self, nsrt_data: _OptionSamplerDataset,
nsrt: NSRT) -> Tuple[NSRTSampler, NSRTSampler]:
X_classifier: List[List[Array]] = []
X_classifier: List[Array] = []
y_classifier: List[int] = []
for state, option, _, label in nsrt_data:
objects = option.objects
params = option.params
# input is state features and option parameters
X_classifier.append([np.array(1.0)]) # start with bias term
for obj in objects:
X_classifier[-1].extend(state[obj])
X_classifier[-1].extend(params)
assert not CFG.sampler_learning_use_goals
x_arr = _construct_sampler_input(state, objects, params)
X_classifier.append(x_arr)
y_classifier.append(label)
X_arr_classifier = np.array(X_classifier)
# output is binary signal
y_arr_classifier = np.array(y_classifier)
classifier = MLPBinaryClassifier(
seed=CFG.seed,
balance_data=CFG.mlp_classifier_balance_data,
max_train_iters=CFG.sampler_mlp_classifier_max_itr,
learning_rate=CFG.learning_rate,
weight_decay=CFG.weight_decay,
use_torch_gpu=CFG.use_torch_gpu,
train_print_every=CFG.pytorch_train_print_every,
n_iter_no_change=CFG.mlp_classifier_n_iter_no_change,
hid_sizes=CFG.mlp_classifier_hid_sizes,
n_reinitialize_tries=CFG.
sampler_mlp_classifier_n_reinitialize_tries,
weight_init="default")
if CFG.active_sampler_learning_model.endswith("mlp"):
classifier: BinaryClassifier = MLPBinaryClassifier(
seed=CFG.seed,
balance_data=CFG.mlp_classifier_balance_data,
max_train_iters=CFG.sampler_mlp_classifier_max_itr,
learning_rate=CFG.learning_rate,
weight_decay=CFG.weight_decay,
use_torch_gpu=CFG.use_torch_gpu,
train_print_every=CFG.pytorch_train_print_every,
n_iter_no_change=CFG.mlp_classifier_n_iter_no_change,
hid_sizes=CFG.mlp_classifier_hid_sizes,
n_reinitialize_tries=CFG.
sampler_mlp_classifier_n_reinitialize_tries,
weight_init="default")
else:
            assert CFG.active_sampler_learning_model.endswith("knn")
classifier = KNeighborsClassifier(seed=CFG.seed)
classifier.fit(X_arr_classifier, y_arr_classifier)

# Save the sampler classifier for external analysis.
Expand All @@ -300,17 +300,13 @@ class _ClassifierEnsembleWrappedSamplerLearner(_WrappedSamplerLearner):

def _learn_nsrt_sampler(self, nsrt_data: _OptionSamplerDataset,
nsrt: NSRT) -> Tuple[NSRTSampler, NSRTSampler]:
X_classifier: List[List[Array]] = []
X_classifier: List[Array] = []
y_classifier: List[int] = []
for state, option, _, label in nsrt_data:
objects = option.objects
params = option.params
# input is state features and option parameters
X_classifier.append([np.array(1.0)]) # start with bias term
for obj in objects:
X_classifier[-1].extend(state[obj])
X_classifier[-1].extend(params)
assert not CFG.sampler_learning_use_goals
x_arr = _construct_sampler_input(state, objects, params)
X_classifier.append(x_arr)
y_classifier.append(label)
X_arr_classifier = np.array(X_classifier)
# output is binary signal
Expand Down Expand Up @@ -446,17 +442,13 @@ def _sample_options_from_state(self,
return sampled_options

def _fit_regressor(self, nsrt_data: _OptionSamplerDataset) -> MLPRegressor:
X_regressor: List[List[Array]] = []
X_regressor: List[Array] = []
y_regressor: List[Array] = []
for state, option, _, target in nsrt_data:
objects = option.objects
params = option.params
# input is state features and option parameters
X_regressor.append([np.array(1.0)]) # start with bias term
for obj in objects:
X_regressor[-1].extend(state[obj])
X_regressor[-1].extend(params)
assert not CFG.sampler_learning_use_goals
x_arr = _construct_sampler_input(state, objects, params)
X_regressor.append(x_arr)
y_regressor.append(np.array([target]))
X_arr_regressor = np.array(X_regressor)
y_arr_regressor = np.array(y_regressor)
Expand All @@ -476,6 +468,16 @@ def _fit_regressor(self, nsrt_data: _OptionSamplerDataset) -> MLPRegressor:


# Helper functions.
def _construct_sampler_input(state: State, objects: Sequence[Object],
                             params: Array) -> Array:
    """Flatten (state, objects, params) into a single input vector for the
    sampler classifiers/regressors.

    Layout: a constant 1.0 bias term, then the state features of each
    object in order, then the continuous option parameters.
    """
    # Goal-conditioned sampler inputs are not supported by this helper.
    assert not CFG.sampler_learning_use_goals
    object_feats: list = []
    for obj in objects:
        object_feats.extend(state[obj])
    return np.array([1.0] + object_feats + list(params))


def _wrap_sampler(
base_sampler: NSRTSampler,
score_fn: _ScoreFn,
Expand Down Expand Up @@ -503,19 +505,15 @@ def _vector_score_fn_to_score_fn(vector_fn: Callable[[Array], float],

def _score_fn(state: State, objects: Sequence[Object],
param_lst: List[Array]) -> List[float]:
x_lst: List[Any] = [1.0] # start with bias term
sub = dict(zip(nsrt.parameters, objects))
for var in nsrt.parameters:
x_lst.extend(state[sub[var]])
assert not CFG.sampler_learning_use_goals
x = np.array(x_lst)
scores = [vector_fn(np.r_[x, p]) for p in param_lst]
X = np.array(
[_construct_sampler_input(state, objects, p) for p in param_lst])
        scores = [vector_fn(x) for x in X]
return scores

return _score_fn


def _classifier_to_score_fn(classifier: MLPBinaryClassifier,
def _classifier_to_score_fn(classifier: BinaryClassifier,
nsrt: NSRT) -> _ScoreFn:
return _vector_score_fn_to_score_fn(classifier.predict_proba, nsrt)

Expand Down

0 comments on commit f6dd4c8

Please sign in to comment.