Skip to content

Commit

Permalink
Fix experiment bugs (#1717)
Browse files Browse the repository at this point in the history
* setup to get results we missed

* oops

* revert yaml

* fix tests

* should improve the success rate of our method!
  • Loading branch information
NishanthJKumar authored Oct 9, 2024
1 parent 2f49d79 commit 1c8d8b2
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 6 deletions.
10 changes: 6 additions & 4 deletions predicators/approaches/grammar_search_invention_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,7 @@ class _ForallPredicateGrammarWrapper(_PredicateGrammar):
base_grammar: _PredicateGrammar

def enumerate(self) -> Iterator[Tuple[Predicate, float]]:
forall_penalty = CFG.grammar_search_forall_penalty
for (predicate, cost) in self.base_grammar.enumerate():
yield (predicate, cost)
if predicate.arity == 0:
Expand All @@ -894,14 +895,15 @@ def enumerate(self) -> Iterator[Tuple[Predicate, float]]:
forall_predicate = Predicate(str(forall_classifier), [],
forall_classifier)
assert forall_predicate.arity == 0
yield (forall_predicate, cost + 1) # add arity + 1 to cost
yield (forall_predicate, cost + forall_penalty)
# Generate NOT-Forall(x)
notforall_classifier = _NegationClassifier(forall_predicate)
notforall_predicate = Predicate(str(notforall_classifier),
forall_predicate.types,
notforall_classifier)
assert notforall_predicate.arity == 0
yield (notforall_predicate, cost + 1) # add arity + 1 to cost
yield (notforall_predicate, cost + forall_penalty)

# Generate UFFs
if predicate.arity >= 2:
for idx in range(predicate.arity):
Expand All @@ -911,14 +913,14 @@ def enumerate(self) -> Iterator[Tuple[Predicate, float]]:
[predicate.types[idx]],
uff_classifier)
assert uff_predicate.arity == 1
yield (uff_predicate, cost + 2) # add arity + 1 to cost
yield (uff_predicate, cost + forall_penalty + 1)
# Negated UFF
notuff_classifier = _NegationClassifier(uff_predicate)
notuff_predicate = Predicate(str(notuff_classifier),
uff_predicate.types,
notuff_classifier)
assert notuff_predicate.arity == 1
yield (notuff_predicate, cost + 2) # add arity + 1 to cost
yield (notuff_predicate, cost + forall_penalty + 1)


################################################################################
Expand Down
1 change: 1 addition & 0 deletions predicators/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,7 @@ class GlobalSettings:
grammar_search_grammar_use_diff_features = False
grammar_search_grammar_use_euclidean_dist = False
grammar_search_use_handcoded_debug_grammar = False
grammar_search_forall_penalty = 1
grammar_search_pred_selection_approach = "score_optimization"
grammar_search_pred_clusterer = "oracle"
grammar_search_true_pos_weight = 10
Expand Down
3 changes: 2 additions & 1 deletion predicators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2485,7 +2485,8 @@ def query_vlm_for_atom_vals(
all_vlm_responses = vlm_output_str.strip().split("\n")
# NOTE: this assumption is likely too brittle; if this is breaking, feel
# free to remove/adjust this and change the below parsing loop accordingly!
assert len(atom_queries_list) == len(all_vlm_responses)
if len(atom_queries_list) != len(all_vlm_responses):
return set()
for i, (atom_query, curr_vlm_output_line) in enumerate(
zip(atom_queries_list, all_vlm_responses)):
assert atom_query + ":" in curr_vlm_output_line
Expand Down
4 changes: 3 additions & 1 deletion scripts/configs/pred_invention_vlm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ ENVS:
grammar_search_vlm_atom_proposal_use_debug: False
allow_exclude_goal_predicates: True
grammar_search_prune_redundant_preds: True
grammar_search_predicate_cost_upper_bound: 7
grammar_search_predicate_cost_upper_bound: 13
grammar_search_pred_complexity_weight: 10
grammar_search_forall_penalty: 5
allow_state_allclose_comparison_despite_simulator_state: True
grammar_search_max_predicates: 200
grammar_search_parallelize_vlm_labeling: True
Expand Down
6 changes: 6 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,6 +1133,12 @@ def _classifier2(state, objects):
vlm_atoms_set = utils.abstract(vlm_state, [vlm_pred], _DummyVLM())
assert len(vlm_atoms_set) == 1
assert "IsFishy" in str(vlm_atoms_set)
# Now, teset the case where the VLM response is wrong/bad.
vlm_pred2 = VLMPredicate("IsSnakey", [], lambda s, o: NotImplementedError,
lambda o: "is_snakey")
vlm_atoms_set = utils.abstract(vlm_state, [vlm_pred, vlm_pred2],
_DummyVLM())
assert len(vlm_atoms_set) == 0


def test_create_new_variables():
Expand Down

0 comments on commit 1c8d8b2

Please sign in to comment.