From 1c8d8b2e820849d0f549522c7342164d835caad0 Mon Sep 17 00:00:00 2001
From: Nishanth Kumar
Date: Wed, 9 Oct 2024 15:11:45 -0400
Subject: [PATCH] Fix experiment bugs (#1717)

* setup to get results we missed

* oops

* revert yaml

* fix tests

* should improve the success rate of our method!
---
 .../approaches/grammar_search_invention_approach.py | 10 ++++++----
 predicators/settings.py                             |  1 +
 predicators/utils.py                                |  3 ++-
 scripts/configs/pred_invention_vlm.yaml             |  4 +++-
 tests/test_utils.py                                 |  6 ++++++
 5 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/predicators/approaches/grammar_search_invention_approach.py b/predicators/approaches/grammar_search_invention_approach.py
index 681ae4f667..fbf33efdf4 100644
--- a/predicators/approaches/grammar_search_invention_approach.py
+++ b/predicators/approaches/grammar_search_invention_approach.py
@@ -885,6 +885,7 @@ class _ForallPredicateGrammarWrapper(_PredicateGrammar):
     base_grammar: _PredicateGrammar
 
     def enumerate(self) -> Iterator[Tuple[Predicate, float]]:
+        forall_penalty = CFG.grammar_search_forall_penalty
         for (predicate, cost) in self.base_grammar.enumerate():
             yield (predicate, cost)
             if predicate.arity == 0:
@@ -894,14 +895,15 @@ def enumerate(self) -> Iterator[Tuple[Predicate, float]]:
             forall_predicate = Predicate(str(forall_classifier), [],
                                          forall_classifier)
             assert forall_predicate.arity == 0
-            yield (forall_predicate, cost + 1)  # add arity + 1 to cost
+            yield (forall_predicate, cost + forall_penalty)
             # Generate NOT-Forall(x)
             notforall_classifier = _NegationClassifier(forall_predicate)
             notforall_predicate = Predicate(str(notforall_classifier),
                                             forall_predicate.types,
                                             notforall_classifier)
             assert notforall_predicate.arity == 0
-            yield (notforall_predicate, cost + 1)  # add arity + 1 to cost
+            yield (notforall_predicate, cost + forall_penalty)
+
             # Generate UFFs
             if predicate.arity >= 2:
                 for idx in range(predicate.arity):
@@ -911,14 +913,14 @@ def enumerate(self) -> Iterator[Tuple[Predicate, float]]:
                                               [predicate.types[idx]],
                                               uff_classifier)
                     assert uff_predicate.arity == 1
-                    yield (uff_predicate, cost + 2)  # add arity + 1 to cost
+                    yield (uff_predicate, cost + forall_penalty + 1)
                     # Negated UFF
                     notuff_classifier = _NegationClassifier(uff_predicate)
                     notuff_predicate = Predicate(str(notuff_classifier),
                                                  uff_predicate.types,
                                                  notuff_classifier)
                     assert notuff_predicate.arity == 1
-                    yield (notuff_predicate, cost + 2)  # add arity + 1 to cost
+                    yield (notuff_predicate, cost + forall_penalty + 1)
 
 
 ################################################################################
diff --git a/predicators/settings.py b/predicators/settings.py
index 6f1ffe1fcb..075c50ac83 100644
--- a/predicators/settings.py
+++ b/predicators/settings.py
@@ -660,6 +660,7 @@ class GlobalSettings:
     grammar_search_grammar_use_diff_features = False
     grammar_search_grammar_use_euclidean_dist = False
    grammar_search_use_handcoded_debug_grammar = False
+    grammar_search_forall_penalty = 1
     grammar_search_pred_selection_approach = "score_optimization"
     grammar_search_pred_clusterer = "oracle"
     grammar_search_true_pos_weight = 10
diff --git a/predicators/utils.py b/predicators/utils.py
index cedc41d238..8694a13a74 100644
--- a/predicators/utils.py
+++ b/predicators/utils.py
@@ -2485,7 +2485,8 @@ def query_vlm_for_atom_vals(
     all_vlm_responses = vlm_output_str.strip().split("\n")
     # NOTE: this assumption is likely too brittle; if this is breaking, feel
     # free to remove/adjust this and change the below parsing loop accordingly!
-    assert len(atom_queries_list) == len(all_vlm_responses)
+    if len(atom_queries_list) != len(all_vlm_responses):
+        return set()
     for i, (atom_query, curr_vlm_output_line) in enumerate(
             zip(atom_queries_list, all_vlm_responses)):
         assert atom_query + ":" in curr_vlm_output_line
diff --git a/scripts/configs/pred_invention_vlm.yaml b/scripts/configs/pred_invention_vlm.yaml
index c9d9167ecd..ecc5d9aa87 100644
--- a/scripts/configs/pred_invention_vlm.yaml
+++ b/scripts/configs/pred_invention_vlm.yaml
@@ -54,7 +54,9 @@ ENVS:
       grammar_search_vlm_atom_proposal_use_debug: False
       allow_exclude_goal_predicates: True
       grammar_search_prune_redundant_preds: True
-      grammar_search_predicate_cost_upper_bound: 7
+      grammar_search_predicate_cost_upper_bound: 13
+      grammar_search_pred_complexity_weight: 10
+      grammar_search_forall_penalty: 5
       allow_state_allclose_comparison_despite_simulator_state: True
       grammar_search_max_predicates: 200
       grammar_search_parallelize_vlm_labeling: True
diff --git a/tests/test_utils.py b/tests/test_utils.py
index f5dd4743b8..760539d5bb 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1133,6 +1133,12 @@ def _classifier2(state, objects):
     vlm_atoms_set = utils.abstract(vlm_state, [vlm_pred], _DummyVLM())
     assert len(vlm_atoms_set) == 1
     assert "IsFishy" in str(vlm_atoms_set)
+    # Now, test the case where the VLM response is wrong/bad.
+    vlm_pred2 = VLMPredicate("IsSnakey", [], lambda s, o: NotImplementedError,
+                             lambda o: "is_snakey")
+    vlm_atoms_set = utils.abstract(vlm_state, [vlm_pred, vlm_pred2],
+                                   _DummyVLM())
+    assert len(vlm_atoms_set) == 0
 
 
 def test_create_new_variables():
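
A minimal sketch of the cost change this patch makes for quantified predicates, using hypothetical helper functions rather than the repository's actual code (the real logic lives in _ForallPredicateGrammarWrapper.enumerate): Forall/NOT-Forall variants now add the configurable grammar_search_forall_penalty to the base predicate's cost instead of a hard-coded 1, and unary-free-forall (UFF) variants add the penalty plus 1, so the default penalty of 1 reproduces the old costs while the experiment config's penalty of 5 pushes quantified predicates toward the raised cost upper bound of 13.

    # Hypothetical helpers illustrating the cost arithmetic in the patch above.

    def forall_variant_cost(base_cost: float, forall_penalty: float) -> float:
        # Forall(x) and NOT-Forall(x) variants: previously base_cost + 1.
        return base_cost + forall_penalty

    def uff_variant_cost(base_cost: float, forall_penalty: float) -> float:
        # UFF and negated-UFF variants: previously base_cost + 2.
        return base_cost + forall_penalty + 1

    # The default setting (penalty = 1) keeps the old costs ...
    assert forall_variant_cost(3.0, 1) == 4.0
    assert uff_variant_cost(3.0, 1) == 5.0
    # ... while the experiment config (penalty = 5) makes quantified predicates
    # much more expensive relative to the cost upper bound of 13.
    assert forall_variant_cost(3.0, 5) == 8.0
    assert uff_variant_cost(3.0, 5) == 9.0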