diff --git a/.github/workflows/check-black.yml b/.github/workflows/check-black.yml index 62321d9f..7a3a2aa2 100644 --- a/.github/workflows/check-black.yml +++ b/.github/workflows/check-black.yml @@ -1,6 +1,6 @@ name: Lint and Format Check -on: [push, pull_request] +on: [pull_request] jobs: lint-and-format: diff --git a/.github/workflows/check-isort.yml b/.github/workflows/check-isort.yml index 40e799c0..9114cf17 100644 --- a/.github/workflows/check-isort.yml +++ b/.github/workflows/check-isort.yml @@ -1,6 +1,6 @@ name: Check Import Sorting -on: [push, pull_request] +on: [pull_request] jobs: import-sorting: diff --git a/.github/workflows/check-mypy.yml b/.github/workflows/check-mypy.yml index dda88cc1..f7200fd3 100644 --- a/.github/workflows/check-mypy.yml +++ b/.github/workflows/check-mypy.yml @@ -1,6 +1,6 @@ name: Type Checking -on: [push, pull_request] +on: [pull_request] jobs: type-check: diff --git a/.github/workflows/test-with-codecov.yml b/.github/workflows/test-with-codecov.yml index 175aac9b..a3eb163d 100644 --- a/.github/workflows/test-with-codecov.yml +++ b/.github/workflows/test-with-codecov.yml @@ -1,5 +1,5 @@ name: Codecov Check -on: [push] +on: [pull_request] jobs: run: runs-on: ${{ matrix.os }} diff --git a/automata/core/code_indexing/python_code_retriever.py b/automata/core/code_indexing/python_code_retriever.py index 53e2aefc..00453c74 100644 --- a/automata/core/code_indexing/python_code_retriever.py +++ b/automata/core/code_indexing/python_code_retriever.py @@ -16,7 +16,7 @@ class PythonCodeRetriever: def __init__( - self, module_tree_map: Optional[LazyModuleTreeMap] = LazyModuleTreeMap.cached_default() + self, module_tree_map: LazyModuleTreeMap = LazyModuleTreeMap.cached_default() ) -> None: self.module_tree_map = module_tree_map @@ -57,7 +57,9 @@ def get_docstring(self, module_dotpath: str, object_path: Optional[str]) -> str: """ module = self.module_tree_map.get_module(module_dotpath) - return 
PythonCodeRetriever._get_docstring(find_syntax_tree_node(module, object_path)) + if module: + return PythonCodeRetriever._get_docstring(find_syntax_tree_node(module, object_path)) + return NO_RESULT_FOUND_STR def get_source_code_without_docstrings( self, module_dotpath: str, object_path: Optional[str] @@ -96,16 +98,14 @@ def _remove_docstrings(node: FSTNode) -> None: module = self.module_tree_map.get_module(module_dotpath) - module = ( - RedBaron(module.dumps()) if module else None - ) # create a copy because we'll remove docstrings - result = find_syntax_tree_node(module, object_path) + if module: + module_copy = RedBaron(module.dumps()) + result = find_syntax_tree_node(module_copy, object_path) - if result: - _remove_docstrings(result) - return result.dumps() - else: - return NO_RESULT_FOUND_STR + if result: + _remove_docstrings(result) + return result.dumps() + return NO_RESULT_FOUND_STR def get_parent_function_name_by_line(self, module_dotpath: str, line_number: int) -> str: """ @@ -121,19 +121,16 @@ def get_parent_function_name_by_line(self, module_dotpath: str, line_number: int """ module = self.module_tree_map.get_module(module_dotpath) - if not module: - return NO_RESULT_FOUND_STR - - node = module.at(line_number) - if node.type != "def": - node = node.parent_find("def") - if node: - if node.parent[0].type == "class": - return f"{node.parent.name}.{node.name}" - else: - return node.name - else: - return NO_RESULT_FOUND_STR + if module: + node = module.at(line_number) + if node.type != "def": + node = node.parent_find("def") + if node: + if node.parent[0].type == "class": + return f"{node.parent.name}.{node.name}" + else: + return node.name + return NO_RESULT_FOUND_STR def get_parent_function_num_code_lines( self, module_dotpath: str, line_number: int @@ -150,19 +147,17 @@ def get_parent_function_num_code_lines( """ module = self.module_tree_map.get_module(module_dotpath) - if not module: - return NO_RESULT_FOUND_STR - - node = module.at(line_number) - if 
node.type != "def": - node = node.parent_find("def") - if not node: - return NO_RESULT_FOUND_STR - return ( - node.absolute_bounding_box.bottom_right.line - - node.absolute_bounding_box.top_left.line - + 1 - ) + if module: + node = module.at(line_number) + if node.type != "def": + node = node.parent_find("def") + if node: + return ( + node.absolute_bounding_box.bottom_right.line + - node.absolute_bounding_box.top_left.line + + 1 + ) + return NO_RESULT_FOUND_STR def get_parent_code_by_line( self, module_dotpath: str, line_number: int, return_numbered=False @@ -181,66 +176,68 @@ def get_parent_code_by_line( """ module = self.module_tree_map.get_module(module_dotpath) - if not module: - return NO_RESULT_FOUND_STR - node = module.at(line_number) - - # retarget def or class node - if node.type not in ("def", "class") and node.parent_find( - lambda identifier: identifier in ("def", "class") - ): - node = node.parent_find(lambda identifier: identifier in ("def", "class")) - - path = node.path().to_baron_path() - pointer = module - result = [] + if module: + node = module.at(line_number) + + # retarget def or class node + if node.type not in ("def", "class") and node.parent_find( + lambda identifier: identifier in ("def", "class") + ): + node = node.parent_find(lambda identifier: identifier in ("def", "class")) + + path = node.path().to_baron_path() + pointer = module + result = [] + + for entry in path: + if isinstance(entry, int): + pointer = pointer.node_list + for x in range(entry): + start_line, start_col = ( + pointer[x].absolute_bounding_box.top_left.line, + pointer[x].absolute_bounding_box.top_left.column, + ) - for entry in path: - if isinstance(entry, int): - pointer = pointer.node_list - for x in range(entry): + if pointer[x].type == "string" and pointer[x].value.startswith('"""'): + result += self._create_line_number_tuples( + pointer[x], start_line, start_col + ) + if pointer[x].type in ("def", "class"): + docstring = 
PythonCodeRetriever._get_docstring(pointer[x]) + node_copy = pointer[x].copy() + node_copy.value = '"""' + docstring + '"""' + result += self._create_line_number_tuples( + node_copy, start_line, start_col + ) + pointer = pointer[entry] + else: start_line, start_col = ( - pointer[x].absolute_bounding_box.top_left.line, - pointer[x].absolute_bounding_box.top_left.column, + pointer.absolute_bounding_box.top_left.line, + pointer.absolute_bounding_box.top_left.column, ) - - if pointer[x].type == "string" and pointer[x].value.startswith('"""'): - result += self._create_line_number_tuples( - pointer[x], start_line, start_col - ) - if pointer[x].type in ("def", "class"): - docstring = PythonCodeRetriever._get_docstring(pointer[x]) - node_copy = pointer[x].copy() - node_copy.value = '"""' + docstring + '"""' - result += self._create_line_number_tuples(node_copy, start_line, start_col) - pointer = pointer[entry] - else: - start_line, start_col = ( - pointer.absolute_bounding_box.top_left.line, - pointer.absolute_bounding_box.top_left.column, - ) - node_copy = pointer.copy() - node_copy.value = "" - result += self._create_line_number_tuples(node_copy, start_line, start_col) - pointer = getattr(pointer, entry) - - start_line, start_col = ( - pointer.absolute_bounding_box.top_left.line, - pointer.absolute_bounding_box.top_left.column, - ) - result += self._create_line_number_tuples(pointer, start_line, start_col) - - prev_line = 1 - result_str = "" - for t in result: - if t[0] > prev_line + 1: - result_str += "...\n" - if return_numbered: - result_str += f"{t[0]}: {t[1]}\n" - else: - result_str += f"{t[1]}\n" - prev_line = t[0] - return result_str + node_copy = pointer.copy() + node_copy.value = "" + result += self._create_line_number_tuples(node_copy, start_line, start_col) + pointer = getattr(pointer, entry) + + start_line, start_col = ( + pointer.absolute_bounding_box.top_left.line, + pointer.absolute_bounding_box.top_left.column, + ) + result += 
self._create_line_number_tuples(pointer, start_line, start_col) + + prev_line = 1 + result_str = "" + for t in result: + if t[0] > prev_line + 1: + result_str += "...\n" + if return_numbered: + result_str += f"{t[0]}: {t[1]}\n" + else: + result_str += f"{t[1]}\n" + prev_line = t[0] + return result_str + return NO_RESULT_FOUND_STR def get_expression_context( self, diff --git a/automata/core/code_indexing/test/sample_modules/sample.py b/automata/core/code_indexing/test/sample_modules/sample.py new file mode 100644 index 00000000..aabfc8e7 --- /dev/null +++ b/automata/core/code_indexing/test/sample_modules/sample.py @@ -0,0 +1,39 @@ +"""This is a sample module""" +import math + + +def sample_function(name): + """This is a sample function.""" + return f"Hello, {name}! Sqrt(2) = " + str(math.sqrt(2)) + + +class Person: + """This is a sample class.""" + + def __init__(self, name): + """This is the constructor.""" + self.name = name + + def say_hello(self): + """This is a sample method.""" + return f"Hello, I am {self.name}." + + def run(self) -> str: + ... + + +def f(x) -> int: + """This is my new function""" + return x + 1 + + +class EmptyClass: + pass + + +class OuterClass: + class InnerClass: + """Inner doc strings""" + + def inner_method(self): + """Inner method doc strings""" diff --git a/automata/core/code_indexing/test/sample_modules/sample2.py b/automata/core/code_indexing/test/sample_modules/sample2.py new file mode 100644 index 00000000..ba8ecad7 --- /dev/null +++ b/automata/core/code_indexing/test/sample_modules/sample2.py @@ -0,0 +1,41 @@ +from typing import List + +from automata.core.base.tool import Tool +from automata.tools.python_tools.python_agent import PythonAgent + + +class PythonAgentToolBuilder: + """A class for building tools to interact with PythonAgent.""" + + def __init__(self, python_agent: PythonAgent): + """ + Initializes a PythonAgentToolBuilder with the given PythonAgent. 
+ + Args: + python_agent (PythonAgent): A PythonAgent instance representing the agent to work with. + """ + self.python_agent = python_agent + + def build_tools(self) -> List: + """ + Builds a list of Tool objects for interacting with PythonAgent. + + Args: + - None + + Returns: + - tools (List[Tool]): A list of Tool objects representing PythonAgent commands. + """ + + def python_agent_python_task(): + """A sample task that utilizes PythonAgent.""" + pass + + tools = [ + Tool( + "automata-task", + python_agent_python_task, + "Execute a Python task using the PythonAgent. Provide the task description in plain English.", + ) + ] + return tools diff --git a/automata/core/search/symbol_rank/symbol_embedding_map.py b/automata/core/search/symbol_rank/symbol_embedding_map.py index 52c0db84..8078c69b 100644 --- a/automata/core/search/symbol_rank/symbol_embedding_map.py +++ b/automata/core/search/symbol_rank/symbol_embedding_map.py @@ -111,7 +111,7 @@ def update_embeddings(self, symbols_to_update: List[Symbol]): map_symbol = desc_to_full_symbol.get(symbol_desc_identifier, None) if not map_symbol: - logger.info("Adding a new symbol: %s" % symbol) + logger.debug("Adding a new symbol: %s" % symbol) symbol_embedding = self.embedding_provider.get_embedding(symbol_source) self.embedding_dict[symbol] = SymbolEmbedding( symbol=symbol, vector=symbol_embedding, source_code=symbol_source @@ -120,7 +120,7 @@ def update_embeddings(self, symbols_to_update: List[Symbol]): # If the symbol is already in the embedding map, check if the source code is the same # If not, we can update the embedding if self.embedding_dict[map_symbol].source_code != symbol_source: - logger.info("Modifying existing embedding for symbol: %s" % symbol) + logger.debug("Modifying existing embedding for symbol: %s" % symbol) symbol_embedding = self.embedding_provider.get_embedding(symbol_source) self.embedding_dict[symbol] = SymbolEmbedding( symbol=symbol, vector=symbol_embedding, source_code=symbol_source diff --git 
a/automata/core/search/symbol_utils.py b/automata/core/search/symbol_utils.py index 6341f423..f013048d 100644 --- a/automata/core/search/symbol_utils.py +++ b/automata/core/search/symbol_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import networkx as nx import numpy as np @@ -8,7 +8,9 @@ from automata.core.search.symbol_types import Descriptor, Symbol, SymbolEmbedding -def convert_to_fst_object(symbol: Symbol, module_map: LazyModuleTreeMap) -> RedBaron: +def convert_to_fst_object( + symbol: Symbol, module_map: Optional[LazyModuleTreeMap] = None +) -> RedBaron: """ Returns the RedBaron object for the given symbol. Args: diff --git a/automata/core/tasks/automata_task_executor.py b/automata/core/tasks/automata_task_executor.py index 483e8a89..8692ef07 100644 --- a/automata/core/tasks/automata_task_executor.py +++ b/automata/core/tasks/automata_task_executor.py @@ -55,10 +55,10 @@ def execute(self, task: AutomataTask): module_map = LazyModuleTreeMap(task.path_to_root_py) retriever = PythonCodeRetriever(module_map) writer = PythonWriter(retriever) - writer.update_module( + writer.update_existing_module( module_dotpath="core.agent.automata_agent", source_code="def test123(x): return True", - write_to_disk=True, + do_write=True, ) task.result = "Test result" diff --git a/automata/tool_management/python_writer_tool_manager.py b/automata/tool_management/python_writer_tool_manager.py index 1c9ee8c4..aceaea21 100644 --- a/automata/tool_management/python_writer_tool_manager.py +++ b/automata/tool_management/python_writer_tool_manager.py @@ -35,23 +35,23 @@ def __init__( self.automata_version = ( kwargs.get("automata_version") or AgentConfigName.AUTOMATA_WRITER_PROD ) - self.model = kwargs.get("model") or "gpt-4" - self.verbose = kwargs.get("verbose") or False - self.stream = kwargs.get("stream") or True - self.temperature = kwargs.get("temperature") or 0.7 + self.model 
= kwargs.get("model", "gpt-4") + self.verbose = kwargs.get("verbose", False) + self.stream = kwargs.get("stream", True) + self.temperature = kwargs.get("temperature", 0.7) + self.do_write = kwargs.get("do_write", True) def build_tools(self) -> List[Tool]: """Builds a list of Tool object for interacting with PythonWriter.""" tools = [ Tool( name="python-writer-update-module", - func=lambda module_object_code_tuple: self._writer_update_module( + func=lambda module_object_code_tuple: self._update_existing_module( *module_object_code_tuple ), - description=f"Modifies the python code of a function, class, method, or module after receiving" - f" an input module path, source code, and optional class name. If the specified object or dependencies do not exist," - f" then they are created automatically. If the object already exists," - f" then the existing code is modified." + description=f"Inserts or updates the python code of a function, class, method in an existing module" + f" If a given object or its child object do not exist," + f" then they are created automatically. If the object already exists, then the existing code is modified." f" For example -" f' to implement a method "my_method" of "MyClass" in the module "my_file.py" which exists in "my_folder",' f" the correct function call follows:\n" @@ -66,22 +66,69 @@ def build_tools(self) -> List[Tool]: f"Provide the full code as input, as this tool has no context outside of passed arguments.\n", return_direct=True, ), + Tool( + name="python-writer-create-new-module", + func=lambda module_object_code_tuple: self._create_new_module( + *module_object_code_tuple + ), + description=f"Creates a new module at the given path with the given code. 
For example:" + f" - tool_query_1\n" + f" - tool_name\n" + f" - python-writer-create-new-module\n" + f" - tool_args\n" + f" - my_folder.my_file\n" + f' - import math\ndef my_method() -> None:\n """My Method"""\n print(math.sqrt(4))\n', + return_direct=True, + ), + Tool( + name="python-writer-delete-from-existing-module", + func=lambda module_object_code_tuple: self._delete_from_existing_module( + *module_object_code_tuple + ), + description=f"Deletes python objects and their code by name from existing module. For example:" + f" - tool_query_1\n" + f" - tool_name\n" + f" - python-writer-delete-from-existing-module\n" + f" - tool_args\n" + f" - my_folder.my_file\n" + f" - MyClass.my_method\n", + return_direct=True, + ), ] return tools - def _writer_update_module( - self, module_dotpath: str, class_name: Optional[str], code: str + def _update_existing_module( + self, + module_dotpath: str, + disambiguator: Optional[str], + code: str, ) -> str: """Writes the given code to the given module path and class name.""" try: - print("Attempting to write update to module_path = ", module_dotpath) - self.writer.update_module( - source_code=code, - do_extend=True, - module_dotpath=module_dotpath, - write_to_disk=True, - class_name=class_name, - ) + print("Attempting to write update to existing module_path = ", module_dotpath) + self.writer.update_existing_module(module_dotpath, code, disambiguator, self.do_write) return "Success" except Exception as e: return "Failed to update the module with error - " + str(e) + + def _delete_from_existing_module( + self, + module_dotpath: str, + object_dotpath: str, + ) -> str: + """Deletes the object at the given object dotpath from the given module path.""" + try: + print("Attempting to reduce existing module_path = ", module_dotpath) + self.writer.delete_from_existing__module(module_dotpath, object_dotpath, self.do_write) + return "Success" + except Exception as e: + return "Failed to reduce the module with error - " + str(e) + + def 
_create_new_module(self, module_dotpath, code): + """Creates a new module at the given module path from the given code.""" + try: + print("Attempting to write new module_path = ", module_dotpath) + self.writer.create_new_module(module_dotpath, code, self.do_write) + return "Success" + except Exception as e: + return "Failed to create the module with error - " + str(e) diff --git a/automata/tool_management/tests/test_python_writer_tool_manager.py b/automata/tool_management/tests/test_python_writer_tool_manager.py index fcf387ce..3178b31e 100644 --- a/automata/tool_management/tests/test_python_writer_tool_manager.py +++ b/automata/tool_management/tests/test_python_writer_tool_manager.py @@ -30,7 +30,7 @@ def test_init(python_writer_tool_builder): def test_build_tools(python_writer_tool_builder): tools = python_writer_tool_builder.build_tools() - assert len(tools) == 1 + assert len(tools) == 3 for tool in tools: assert isinstance(tool, Tool) @@ -41,16 +41,15 @@ def test_bootstrap_module_with_new_function(python_writer_tool_builder): absolute_path = os.sep.join(os.path.abspath(current_file).split(os.sep)[:-1]) tools = python_writer_tool_builder.build_tools() - code_writer = tools[0] + create_module_tool = tools[1] function_def = "def f(x):\n return x + 1" package = "sample_code" module = "sample3" file_py_path = f"{package}.{module}" file_abs_path = os.path.join(absolute_path, package, f"{module}.py") - code_writer.func((file_py_path, None, function_def)) + create_module_tool.func((file_py_path, function_def)) - new_sample_text = None with open(file_abs_path, "r", encoding="utf-8") as f: new_sample_text = f.read() assert new_sample_text.strip() == function_def @@ -61,14 +60,13 @@ def test_bootstrap_module_with_new_function(python_writer_tool_builder): def test_extend_module_with_new_function(python_writer_tool_builder): current_file = inspect.getframeinfo(inspect.currentframe()).filename absolute_path = os.sep.join(os.path.abspath(current_file).split(os.sep)[:-1]) - 
prev_text = None with open(os.path.join(absolute_path, "sample_code", "sample.py"), "r", encoding="utf-8") as f: prev_text = f.read() assert prev_text is not None, "Could not read sample.py" tools = python_writer_tool_builder.build_tools() code_writer = tools[0] - function_def = "def f(x):\n return x + 1" + function_def = "def g(x):\n return x + 1" package = "sample_code" module = "sample" @@ -77,7 +75,6 @@ def test_extend_module_with_new_function(python_writer_tool_builder): file_abs_path = os.path.join(absolute_path, file_rel_path) code_writer.func((file_py_path, None, function_def)) - new_sample_text = None with open(file_abs_path, "r", encoding="utf-8") as f: new_sample_text = f.read() assert function_def in new_sample_text diff --git a/automata/tools/python_tools/python_writer.py b/automata/tools/python_tools/python_writer.py index f3c00537..98b58e41 100644 --- a/automata/tools/python_tools/python_writer.py +++ b/automata/tools/python_tools/python_writer.py @@ -37,7 +37,7 @@ class ClassToRemove: import subprocess from typing import Optional, Union, cast -from redbaron import ClassNode, Node, NodeList, RedBaron +from redbaron import ClassNode, DefNode, Node, NodeList, RedBaron from automata.core.code_indexing.python_code_retriever import PythonCodeRetriever from automata.core.code_indexing.syntax_tree_navigation import ( @@ -74,7 +74,7 @@ class PythonWriter: class ModuleNotFound(Exception): pass - class ClassNotFound(Exception): + class ClassOrFunctionNotFound(Exception): pass class InvalidArguments(Exception): @@ -86,82 +86,87 @@ def __init__(self, python_retriever: PythonCodeRetriever): """ self.code_retriever = python_retriever - def update_module(self, source_code: str, do_extend: bool = True, **kwargs) -> None: + def create_new_module( + self, module_dotpath: str, source_code: str, do_write: bool = False + ) -> None: """ - Perform an in-place extention or reduction of a module object according to the received code. 
+ Create a new module object from source code. Args: - source_code (str): The source_code containing the updates or deletions. - do_extend (bool): True for adding/updating, False for reducing/deleting. - module_obj (Optional[Module], keyword): The module object to be updated. - module_path (Optional[str], keyword): The path of the module to be updated. - class_name (Optional[str], keyword): The name of the class where the update should be applied, will default to module. - write_to_disk (Optional[bool], keyword): Writes the changed module to disk. - - Raises: - InvalidArguments: If both module_obj and module_dotpath are provided or none of them. + source_code (str): The source code of the module. + module_dotpath (str): The path of the module. Returns: - Module: The updated module object. + RedBaron: The module object. """ - module_obj = kwargs.get("module_obj") - module_dotpath = kwargs.get("module_dotpath") - class_name = kwargs.get("class_name") or "" - write_to_disk = kwargs.get("write_to_disk") or False - - logger.info( - "\n---Updating module---\nPath:\n%s\nClass Name:\n%s\nSource Code:\n%s\nWriting to disk:\n%s\n" - % (module_dotpath, class_name, source_code, write_to_disk) - ) + self._create_module_from_source_code(module_dotpath, source_code) + if do_write: + self._write_module_to_disk(module_dotpath) - self._validate_args(module_obj, module_dotpath, write_to_disk) - source_code = PythonWriter._clean_input_code(source_code) + def update_existing_module( + self, + module_dotpath: str, + source_code: str, + disambiguator: Optional[str] = "", + do_write: bool = False, + ) -> None: + """ + Update code or insert new code into an existing module. - if module_dotpath: - module_dotpath = cast(str, module_dotpath) + Args: + source_code (str): The source code of the part of the module that needs to be updated or insert. + module_dotpath (str): The path of the module. 
+ disambiguator (Optional[str]): The name of the class or function scope where the update should be applied, will default to module. + do_write (bool): Write the module to disk after updating. - # christ on a bike - is_new_module = ( - not module_obj - and module_dotpath - and module_dotpath not in self.code_retriever.module_tree_map + Returns: + RedBaron: The module object. + """ + module_obj = self.code_retriever.module_tree_map.get_module(module_dotpath) + if not module_obj: + raise PythonWriter.ModuleNotFound( + f"Module not found in module dictionary: {module_dotpath}" + ) + PythonWriter._update_existing_module( + source_code, + module_dotpath, + module_obj, + disambiguator=disambiguator, ) + if do_write: + self._write_module_to_disk(module_dotpath) - is_existing_module = ( - module_obj - and self.code_retriever.module_tree_map.get_existing_module_dotpath(module_obj) - or module_dotpath in self.code_retriever.module_tree_map - ) + def delete_from_existing__module( + self, module_dotpath: str, object_dotpath: str, do_write: bool = False + ): + """ + Reduce an existing module by removing a class or function. - if is_new_module: - self._create_module_from_source_code(module_dotpath, source_code) - elif is_existing_module: - if module_obj: - module_dotpath = self.code_retriever.module_tree_map.get_existing_module_dotpath( - module_obj - ) - module_obj = self.code_retriever.module_tree_map.get_module(module_dotpath) - PythonWriter._update_module( - source_code, - module_dotpath, - module_obj, - do_extend, - class_name, - ) - else: - raise PythonWriter.InvalidArguments( - f"Module is neither new nor existing, somehow: {module_dotpath}" - ) + Args: + module_dotpath (str): The path of the module. + object_dotpath (str): The name of the class or function to remove, including the name of the scope it is in, like ClassName.function_name + do_write (bool): Write the module to disk after updating. 
+ - if write_to_disk: - self.write_module(module_dotpath) + Returns: + RedBaron: The module object. + """ + module_obj = self.code_retriever.module_tree_map.get_module(module_dotpath) + if not module_obj: + raise PythonWriter.ModuleNotFound( + f"Module not found in module dictionary: {module_dotpath}" + ) + node = find_syntax_tree_node(module_obj, object_dotpath) + if node: + PythonWriter._delete_node(node) + if do_write: + self._write_module_to_disk(module_dotpath) - def write_module(self, module_dotpath: str) -> None: + def _write_module_to_disk(self, module_dotpath: str) -> None: """ Write the modified module to a file at the specified output path. Args: - module_dotpath (str): The file path where the modified module should be written. + module_dotpath (str) """ if module_dotpath not in self.code_retriever.module_tree_map: raise PythonWriter.ModuleNotFound( @@ -171,6 +176,12 @@ def write_module(self, module_dotpath: str) -> None: module_fpath = self.code_retriever.module_tree_map.get_existing_module_fpath_by_dotpath( module_dotpath ) + + if not module_fpath: + raise PythonWriter.ModuleNotFound( + f"Module fpath not found in module map for dotpath: {module_dotpath}" + ) + module_fpath = cast(str, module_fpath) with open(module_fpath, "w") as output_file: output_file.write(source_code) subprocess.run(["black", module_fpath]) @@ -188,26 +199,11 @@ def _create_module_from_source_code(self, module_dotpath: str, source_code: str) return parsed @staticmethod - def _validate_args( - module_obj: Optional[RedBaron], module_dotpath: Optional[str], write_to_disk: bool - ) -> None: - """Validate the arguments passed to the update_module method.""" - if not (module_obj or module_dotpath) or (module_obj and module_dotpath): - raise PythonWriter.InvalidArguments( - "Provide either 'module_obj' or 'module_path', not both or none." - ) - if not module_dotpath and write_to_disk: - raise PythonWriter.InvalidArguments( - "Provide 'module_path' to write the module to disk." 
- ) - - @staticmethod - def _update_module( + def _update_existing_module( source_code: str, module_dotpath: str, existing_module_obj: RedBaron, - do_extend: bool, - class_name: str = "", + disambiguator: Optional[str], ) -> None: """ Update a module object according to the received code. @@ -216,49 +212,43 @@ def _update_module( source_code (str): The code containing the updates. module_dotpath (str): The relative path to the module. existing_module_obj Module: The module object to be updated. - do_extend (bool): If True, add or update the code; if False, remove the code. + disambiguator (str): The name of the class or function scope to be updated, useful for nested definitions. """ new_fst = RedBaron(source_code) new_import_nodes = find_import_syntax_tree_nodes(new_fst) - PythonWriter._manage_imports(existing_module_obj, new_import_nodes, do_extend) + PythonWriter._update_imports(existing_module_obj, new_import_nodes) new_class_or_function_nodes = find_all_function_and_class_syntax_tree_nodes(new_fst) - if class_name: # splice the class - existing_class = find_syntax_tree_node(existing_module_obj, class_name) - if isinstance(existing_class, ClassNode): + if disambiguator: # splice the class + disambiguator_node = find_syntax_tree_node(existing_module_obj, disambiguator) + if isinstance(disambiguator_node, (ClassNode, DefNode)): PythonWriter._update_node_with_children( - new_class_or_function_nodes, existing_class, do_extend + new_class_or_function_nodes, + disambiguator_node, ) - - elif not do_extend: - raise PythonWriter.ClassNotFound( - f"Class {class_name} not found in module {module_dotpath}" + else: + raise PythonWriter.ClassOrFunctionNotFound( + f"Node {disambiguator} not found in module {module_dotpath}" ) - PythonWriter._update_node_with_children( - new_class_or_function_nodes, existing_module_obj, do_extend - ) + PythonWriter._update_node_with_children(new_class_or_function_nodes, existing_module_obj) @staticmethod def _update_node_with_children( 
class_or_function_nodes: NodeList, node_to_update: Union[ClassNode, RedBaron], - do_extend: bool, ) -> None: """Update a class object according to the received code.""" for new_node in class_or_function_nodes: child_node_name = new_node.name existing_node = find_syntax_tree_node(node_to_update, child_node_name) - if do_extend: - if existing_node: - existing_node.replace(new_node) - else: - node_to_update.append(new_node) - elif existing_node: - PythonWriter.delete_node(existing_node) + if existing_node: + existing_node.replace(new_node) + else: + node_to_update.append(new_node) @staticmethod - def delete_node(node: Node) -> None: + def _delete_node(node: Node) -> None: """Delete a node from the FST.""" parent = node.parent parent_index = node.index_on_parent @@ -310,19 +300,16 @@ def replace(match): return source_code @staticmethod - def _manage_imports( - module_obj: RedBaron, new_import_statements: NodeList, do_extend: bool - ) -> None: + def _update_imports(module_obj: RedBaron, new_import_statements: NodeList) -> None: """Manage the imports in the module.""" + first_import = module_obj.find(lambda identifier: identifier in ("import", "from_import")) + for new_import_statement in new_import_statements: existing_import_statement = find_import_syntax_tree_node_by_name( module_obj, new_import_statement.name ) - if do_extend: - if existing_import_statement: - existing_import_statement.replace(new_import_statement) + if not existing_import_statement: + if first_import: + first_import.insert_before(new_import_statement) # we will run isort later else: - module_obj.append(new_import_statement) - else: - if existing_import_statement: - PythonWriter.delete_node(existing_import_statement) + module_obj.insert(0, new_import_statement) diff --git a/automata/tools/python_tools/tests/test_python_writer.py b/automata/tools/python_tools/tests/test_python_writer.py index 8b9b752e..225d5547 100644 --- a/automata/tools/python_tools/tests/test_python_writer.py +++ 
b/automata/tools/python_tools/tests/test_python_writer.py @@ -180,21 +180,22 @@ def test_create_class_source_class(): def test_extend_module(python_writer): + # Arrange + # create module mock_generator = MockCodeGenerator( has_class=True, has_class_docstring=True, has_function=True, has_function_docstring=True ) source_code = mock_generator.generate_code() - module_obj = python_writer._create_module_from_source_code("sample_module_2", source_code) - mock_generator._check_module_obj(module_obj) - + python_writer.create_new_module("sample_module_2", source_code) mock_generator_2 = MockCodeGenerator( has_class=True, has_class_docstring=True, has_function=True, has_function_docstring=True ) source_code_2 = mock_generator_2.generate_code() - python_writer.update_module(source_code=source_code_2, module_obj=module_obj, do_extend=True) + python_writer.update_existing_module("sample_module_2", source_code_2) # Check module 2 is merged into module 1 + module_obj = python_writer.code_retriever.module_tree_map.get_module("sample_module_2") mock_generator._check_module_obj(module_obj) mock_generator._check_class_obj(module_obj[0]) mock_generator._check_function_obj(module_obj[1]) @@ -207,27 +208,49 @@ def test_reduce_module(python_writer): has_class=True, has_class_docstring=True, has_function=True, has_function_docstring=True ) source_code = mock_generator.generate_code() - module_obj = python_writer._create_module_from_source_code("sample_module_2", source_code) + python_writer.create_new_module("sample_module_2", source_code) + module_obj = python_writer.code_retriever.module_tree_map.get_module("sample_module_2") class_obj = module_obj.find("class") function_obj = module_obj.find_all("def")[-1] - python_writer.update_module( - source_code=class_obj.dumps(), module_obj=module_obj, do_extend=False - ) + python_writer.delete_from_existing__module("sample_module_2", class_obj.name) assert module_obj[0] == function_obj +def assert_code_lines_equal(code_1: str, code_2: str): + 
code_1_lines = [line for line in code_1.splitlines() if line.strip()] + code_2_lines = [line for line in code_2.splitlines() if line.strip()] + assert all(line_1 == line_2 for line_1, line_2 in zip(code_1_lines, code_2_lines)) + + def test_create_update_write_module(python_writer): mock_generator = MockCodeGenerator( has_class=True, has_class_docstring=True, has_function=True, has_function_docstring=True ) source_code = mock_generator.generate_code() - python_writer.update_module( - source_code=source_code, module_dotpath="sample_module_2", do_extend=False - ) - python_writer.write_module("sample_module_2") + python_writer.create_new_module("sample_module_write", source_code, do_write=True) root_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "sample_modules") - os.remove(os.path.join(root_dir, "sample_module_2.py")) + fpath = os.path.join(root_dir, "sample_module_write.py") + assert os.path.exists(fpath) + with open(fpath, "r") as f: + contents = f.read() + assert_code_lines_equal(source_code, contents) + + mock_generator_2 = MockCodeGenerator( + has_class=True, has_class_docstring=True, has_function=True, has_function_docstring=True + ) + source_code_2 = mock_generator_2.generate_code() + + assert source_code != source_code_2 + python_writer.update_existing_module( + source_code=source_code_2, module_dotpath="sample_module_write", do_write=True + ) + + with open(fpath, "r") as f: + contents = f.read() + assert_code_lines_equal("\n".join([source_code, source_code_2]), contents) + + os.remove(fpath) def test_create_function_with_arguments(): @@ -303,13 +326,10 @@ def test_reduce_module_remove_function(python_writer): source_code = mock_generator.generate_code() module_obj = python_writer._create_module_from_source_code("sample_module_2", source_code) - class_obj = module_obj[1] function_obj = module_obj[2] - python_writer.update_module( - source_code=function_obj.dumps(), module_obj=module_obj, do_extend=False - ) + 
python_writer.delete_from_existing__module("sample_module_2", function_obj.name) assert module_obj[0] == class_obj assert len(module_obj.filtered()) == 1 @@ -326,8 +346,8 @@ def {mock_generator.function_name}(): """ ) module_obj = python_writer._create_module_from_source_code("sample_module_2", source_code) - python_writer.update_module( - source_code=source_code_updated, module_obj=module_obj, do_extend=True + python_writer.update_existing_module( + source_code=source_code_updated, module_dotpath="sample_module_2" ) updated_function_obj = find_syntax_tree_node(module_obj, mock_generator.function_name) assert len(updated_function_obj) == 1 @@ -348,7 +368,7 @@ def test_write_and_retrieve_mock_code(python_writer): source_code = mock_generator.generate_code() python_writer._create_module_from_source_code("sample_module_2", source_code) - python_writer.write_module("sample_module_2") + python_writer._write_module_to_disk("sample_module_2") sample_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "sample_modules") module_map = LazyModuleTreeMap(sample_dir) diff --git a/automata/tools/tools.md b/automata/tools/tools.md index d03285eb..a641b850 100644 --- a/automata/tools/tools.md +++ b/automata/tools/tools.md @@ -39,7 +39,7 @@ updated_module = writer.update_module( ) # Write the updated module to disk -writer.write_module("output/package/module.py") +writer._write_module_to_disk("output/package/module.py") ``` ## References