Commit

wip
tsilver-bdai committed Jun 29, 2023
1 parent f6dd4c8 commit 606d326
Showing 2 changed files with 44 additions and 10 deletions.
53 changes: 43 additions & 10 deletions predicators/approaches/active_sampler_learning_approach.py
@@ -252,7 +252,8 @@ def _learn_nsrt_sampler(self, nsrt_data: _OptionSamplerDataset,
         for state, option, _, label in nsrt_data:
             objects = option.objects
             params = option.params
-            x_arr = _construct_sampler_input(state, objects, params)
+            x_arr = _construct_sampler_input(state, objects, params,
+                                             option.parent)
             X_classifier.append(x_arr)
             y_classifier.append(label)
         X_arr_classifier = np.array(X_classifier)
@@ -305,7 +306,8 @@ def _learn_nsrt_sampler(self, nsrt_data: _OptionSamplerDataset,
         for state, option, _, label in nsrt_data:
             objects = option.objects
             params = option.params
-            x_arr = _construct_sampler_input(state, objects, params)
+            x_arr = _construct_sampler_input(state, objects, params,
+                                             option.parent)
             X_classifier.append(x_arr)
             y_classifier.append(label)
         X_arr_classifier = np.array(X_classifier)
@@ -447,7 +449,8 @@ def _fit_regressor(self, nsrt_data: _OptionSamplerDataset) -> MLPRegressor:
         for state, option, _, target in nsrt_data:
             objects = option.objects
             params = option.params
-            x_arr = _construct_sampler_input(state, objects, params)
+            x_arr = _construct_sampler_input(state, objects, params,
+                                             option.parent)
             X_regressor.append(x_arr)
             y_regressor.append(np.array([target]))
         X_arr_regressor = np.array(X_regressor)
@@ -469,12 +472,40 @@ def _fit_regressor(self, nsrt_data: _OptionSamplerDataset) -> MLPRegressor:

 # Helper functions.
 def _construct_sampler_input(state: State, objects: Sequence[Object],
-                             params: Array) -> Array:
-    sampler_input_lst = [1.0]  # start with bias term
-    for obj in objects:
-        sampler_input_lst.extend(state[obj])
-    sampler_input_lst.extend(params)
+                             params: Array,
+                             param_option: ParameterizedOption) -> Array:
+
+    assert not CFG.sampler_learning_use_goals
+    sampler_input_lst = [1.0]  # start with bias term
+    if CFG.active_sampler_learning_feature_selection == "all":
+        for obj in objects:
+            sampler_input_lst.extend(state[obj])
+        sampler_input_lst.extend(params)
+
+    else:
+        assert CFG.active_sampler_learning_feature_selection == "oracle"
+        assert CFG.env == "bumpy_cover"
+        if param_option.name == "Pick":
+            # In this case, the x-data should be
+            # [block_bumpy, relative_pick_loc]
+            assert len(objects) == 1
+            block = objects[0]
+            block_pos = state[block][3]
+            block_bumpy = state[block][5]
+            sampler_input_lst.append(block_bumpy)
+            assert len(params) == 1
+            sampler_input_lst.append(params[0] - block_pos)
+        else:
+            assert param_option.name == "Place"
+            assert len(objects) == 2
+            block, target = objects
+            target_pos = state[target][3]
+            grasp = state[block][4]
+            target_width = state[target][2]
+            sampler_input_lst.extend([grasp, target_width])
+            assert len(params) == 1
+            sampler_input_lst.append(params[0] - target_pos)
+
     return np.array(sampler_input_lst)


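For readers skimming the diff, here is a small self-contained sketch of the feature vectors the new "oracle" branch would build for the bumpy_cover Pick and Place options. The feature indices follow the code above; the numeric state values and parameter values are made-up placeholders, not taken from the repository.

import numpy as np

# Placeholder feature vectors standing in for state[block] and state[target].
# Per the diff: block features use index 3 for position, 4 for grasp, and 5
# for bumpiness; target features use index 2 for width and 3 for position.
block_feats = np.array([1.0, 0.0, 0.1, 0.45, -1.0, 1.0])
target_feats = np.array([0.0, 1.0, 0.2, 0.70])

# Oracle features for Pick: [bias, block_bumpy, relative_pick_loc].
pick_params = np.array([0.50])
pick_x = np.array([1.0, block_feats[5], pick_params[0] - block_feats[3]])

# Oracle features for Place: [bias, grasp, target_width, relative_place_loc].
place_params = np.array([0.72])
place_x = np.array(
    [1.0, block_feats[4], target_feats[2], place_params[0] - target_feats[3]])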
@@ -505,8 +536,10 @@ def _vector_score_fn_to_score_fn(vector_fn: Callable[[Array], float],

     def _score_fn(state: State, objects: Sequence[Object],
                   param_lst: List[Array]) -> List[float]:
-        X = np.array(
-            [_construct_sampler_input(state, objects, p) for p in param_lst])
+        X = np.array([
+            _construct_sampler_input(state, objects, p, nsrt.option)
+            for p in param_lst
+        ])
         scores = [vector_fn(X) for p in param_lst]
         return scores

1 change: 1 addition & 0 deletions predicators/settings.py
@@ -519,6 +519,7 @@ class GlobalSettings:

     # active sampler learning parameters
     active_sampler_learning_model = "myopic_classifier_mlp"
+    active_sampler_learning_feature_selection = "all"
     active_sampler_learning_use_teacher = True
     active_sampler_learning_num_samples = 100
     active_sampler_learning_score_gamma = 0.5
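The new setting defaults to "all", which preserves the old behavior of concatenating every object's features with the option parameters. Assuming the repository's usual mechanism for overriding GlobalSettings fields (utils.reset_config in tests, or the matching command-line flags), the oracle features could presumably be enabled roughly as follows; this snippet is an illustration under those assumptions, not part of the commit.

# Hypothetical override, assuming predicators' standard config utilities;
# the approach name string is assumed, not taken from this commit.
from predicators import utils

utils.reset_config({
    "env": "bumpy_cover",
    "approach": "active_sampler_learning",
    "active_sampler_learning_feature_selection": "oracle",
})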
