diff --git a/automata/fa/dfa.py b/automata/fa/dfa.py index 28366d8c..9e763fb7 100644 --- a/automata/fa/dfa.py +++ b/automata/fa/dfa.py @@ -499,28 +499,46 @@ def _minify( If the input DFA is partial, then the result is also a partial DFA """ - # First, assemble backmap and equivalence class data structure - eq_classes = PartitionRefinement(reachable_states) - refinement = eq_classes.refine(reachable_final_states) - - final_states_id = ( - refinement[0][0] if refinement else next(iter(eq_classes.get_set_ids())) - ) + reachable_states = set(reachable_states) # Per input-symbol backmap (tgt -> origin states) transition_back_map: Dict[str, Dict[DFAStateT, List[DFAStateT]]] = { - symbol: {end_state: list() for end_state in reachable_states} + symbol: {end_state: [] for end_state in reachable_states} for symbol in input_symbols } + trap_state = None + for start_state, path in transitions.items(): if start_state in reachable_states: - for symbol, end_state in path.items(): - symbol_dict = transition_back_map[symbol] - # If statement here needed to ignore certain transitions - # when minifying a partial DFA. - if end_state in symbol_dict: - symbol_dict[end_state].append(start_state) + for symbol in input_symbols: + end_state = path.get(symbol) + if end_state is not None: + symbol_dict = transition_back_map[symbol] + # If statement here needed to ignore certain transitions + # for non-reachable states + if end_state in symbol_dict: + symbol_dict[end_state].append(start_state) + else: + # Add trap state if needed + if trap_state is None: + trap_state = next( + x for x in count(-1, -1) if x not in reachable_states + ) + for trap_symbol in input_symbols: + transition_back_map[trap_symbol][trap_state] = [] + + reachable_states.add(trap_state) + + transition_back_map[symbol][trap_state].append(start_state) + + # Set up equivalence class data structure + eq_classes = PartitionRefinement(reachable_states) + refinement = eq_classes.refine(reachable_final_states) + + final_states_id = ( + refinement[0][0] if refinement else next(iter(eq_classes.get_set_ids())) + ) origin_dicts = tuple(transition_back_map.values()) processing = {final_states_id} @@ -558,7 +576,12 @@ def _minify( ) # need a backmap to prevent constant calls to index - back_map = {state: name for name, eq in eq_class_name_pairs for state in eq} + back_map = { + state: name + for name, eq in eq_class_name_pairs + for state in eq + if trap_state not in eq + } new_input_symbols = input_symbols new_states = frozenset(back_map.values()) @@ -567,12 +590,17 @@ def _minify( new_transitions = {} for name, eq in eq_class_name_pairs: + # For trap state, can just leave out + if trap_state in eq: + continue + eq_class_rep = next(iter(eq)) + inner_transition_dict_old = transitions[eq_class_rep] new_transitions[name] = { letter: back_map[inner_transition_dict_old[letter]] for letter in inner_transition_dict_old.keys() - if inner_transition_dict_old[letter] in reachable_states + if inner_transition_dict_old[letter] in back_map.keys() } allow_partial = any( diff --git a/tests/test_dfa.py b/tests/test_dfa.py index 4af7dab6..2d871d9c 100644 --- a/tests/test_dfa.py +++ b/tests/test_dfa.py @@ -1280,6 +1280,40 @@ def test_minify_partial_dfa(self) -> None: self.assertEqual(len(minified_partial_dfa.states), 4) self.assertEqual(minified_partial_dfa, partial_dfa_extra_state) + def test_minify_partial_dfa_correctness(self) -> None: + """ + Test correctness of minifying partial DFAs. + Test added because of issues raised here: + https://github.com/caleb531/automata/issues/182 + """ + + input_symbols = {"a", "b", "c"} + dfa = DFA.from_finite_language( + language={"ab", "abcb"}, input_symbols=input_symbols, as_partial=True + ) + + self.assertEqual(dfa.minify(), dfa) + + dfa2 = DFA.from_finite_language( + language={"ab", "abba", "cbab"}, + input_symbols=input_symbols, + as_partial=True, + ) + + self.assertEqual(dfa2.minify(), dfa2) + + self.assertEqual(dfa.union(dfa2, minify=False), dfa.union(dfa2, minify=True)) + self.assertEqual( + dfa.intersection(dfa2, minify=False), dfa.intersection(dfa2, minify=True) + ) + self.assertEqual( + dfa.symmetric_difference(dfa2, minify=False), + dfa.symmetric_difference(dfa2, minify=True), + ) + self.assertEqual( + dfa.difference(dfa2, minify=False), dfa.difference(dfa2, minify=True) + ) + def test_init_nfa_simple(self) -> None: """Should convert to a DFA a simple NFA.""" nfa = NFA(