Skip to content

Commit

Permalink
Merge pull request emrgnt-cmplxty#160 from maks-ivanov/feature/al-23-…
Browse files Browse the repository at this point in the history
…refactor-update_module-of-python-writer

Feature/al 23 refactor update module of python writer
  • Loading branch information
maks-ivanov authored May 31, 2023
2 parents 45082ef + 075d177 commit 1c6f847
Show file tree
Hide file tree
Showing 15 changed files with 390 additions and 260 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check-black.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: Lint and Format Check

on: [push, pull_request]
on: [pull_request]

jobs:
lint-and-format:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/check-isort.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: Check Import Sorting

on: [push, pull_request]
on: [pull_request]

jobs:
import-sorting:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/check-mypy.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: Type Checking

on: [push, pull_request]
on: [pull_request]

jobs:
type-check:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test-with-codecov.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: Codecov Check
on: [push]
on: [pull-request]
jobs:
run:
runs-on: ${{ matrix.os }}
Expand Down
185 changes: 91 additions & 94 deletions automata/core/code_indexing/python_code_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

class PythonCodeRetriever:
def __init__(
self, module_tree_map: Optional[LazyModuleTreeMap] = LazyModuleTreeMap.cached_default()
self, module_tree_map: LazyModuleTreeMap = LazyModuleTreeMap.cached_default()
) -> None:
self.module_tree_map = module_tree_map

Expand Down Expand Up @@ -57,7 +57,9 @@ def get_docstring(self, module_dotpath: str, object_path: Optional[str]) -> str:
"""

module = self.module_tree_map.get_module(module_dotpath)
return PythonCodeRetriever._get_docstring(find_syntax_tree_node(module, object_path))
if module:
return PythonCodeRetriever._get_docstring(find_syntax_tree_node(module, object_path))
return NO_RESULT_FOUND_STR

def get_source_code_without_docstrings(
self, module_dotpath: str, object_path: Optional[str]
Expand Down Expand Up @@ -96,16 +98,14 @@ def _remove_docstrings(node: FSTNode) -> None:

module = self.module_tree_map.get_module(module_dotpath)

module = (
RedBaron(module.dumps()) if module else None
) # create a copy because we'll remove docstrings
result = find_syntax_tree_node(module, object_path)
if module:
module_copy = RedBaron(module.dumps())
result = find_syntax_tree_node(module_copy, object_path)

if result:
_remove_docstrings(result)
return result.dumps()
else:
return NO_RESULT_FOUND_STR
if result:
_remove_docstrings(result)
return result.dumps()
return NO_RESULT_FOUND_STR

def get_parent_function_name_by_line(self, module_dotpath: str, line_number: int) -> str:
"""
Expand All @@ -121,19 +121,16 @@ def get_parent_function_name_by_line(self, module_dotpath: str, line_number: int
"""

module = self.module_tree_map.get_module(module_dotpath)
if not module:
return NO_RESULT_FOUND_STR

node = module.at(line_number)
if node.type != "def":
node = node.parent_find("def")
if node:
if node.parent[0].type == "class":
return f"{node.parent.name}.{node.name}"
else:
return node.name
else:
return NO_RESULT_FOUND_STR
if module:
node = module.at(line_number)
if node.type != "def":
node = node.parent_find("def")
if node:
if node.parent[0].type == "class":
return f"{node.parent.name}.{node.name}"
else:
return node.name
return NO_RESULT_FOUND_STR

def get_parent_function_num_code_lines(
self, module_dotpath: str, line_number: int
Expand All @@ -150,19 +147,17 @@ def get_parent_function_num_code_lines(
"""
module = self.module_tree_map.get_module(module_dotpath)
if not module:
return NO_RESULT_FOUND_STR

node = module.at(line_number)
if node.type != "def":
node = node.parent_find("def")
if not node:
return NO_RESULT_FOUND_STR
return (
node.absolute_bounding_box.bottom_right.line
- node.absolute_bounding_box.top_left.line
+ 1
)
if module:
node = module.at(line_number)
if node.type != "def":
node = node.parent_find("def")
if node:
return (
node.absolute_bounding_box.bottom_right.line
- node.absolute_bounding_box.top_left.line
+ 1
)
return NO_RESULT_FOUND_STR

def get_parent_code_by_line(
self, module_dotpath: str, line_number: int, return_numbered=False
Expand All @@ -181,66 +176,68 @@ def get_parent_code_by_line(
"""

module = self.module_tree_map.get_module(module_dotpath)
if not module:
return NO_RESULT_FOUND_STR
node = module.at(line_number)

# retarget def or class node
if node.type not in ("def", "class") and node.parent_find(
lambda identifier: identifier in ("def", "class")
):
node = node.parent_find(lambda identifier: identifier in ("def", "class"))

path = node.path().to_baron_path()
pointer = module
result = []
if module:
node = module.at(line_number)

# retarget def or class node
if node.type not in ("def", "class") and node.parent_find(
lambda identifier: identifier in ("def", "class")
):
node = node.parent_find(lambda identifier: identifier in ("def", "class"))

path = node.path().to_baron_path()
pointer = module
result = []

for entry in path:
if isinstance(entry, int):
pointer = pointer.node_list
for x in range(entry):
start_line, start_col = (
pointer[x].absolute_bounding_box.top_left.line,
pointer[x].absolute_bounding_box.top_left.column,
)

for entry in path:
if isinstance(entry, int):
pointer = pointer.node_list
for x in range(entry):
if pointer[x].type == "string" and pointer[x].value.startswith('"""'):
result += self._create_line_number_tuples(
pointer[x], start_line, start_col
)
if pointer[x].type in ("def", "class"):
docstring = PythonCodeRetriever._get_docstring(pointer[x])
node_copy = pointer[x].copy()
node_copy.value = '"""' + docstring + '"""'
result += self._create_line_number_tuples(
node_copy, start_line, start_col
)
pointer = pointer[entry]
else:
start_line, start_col = (
pointer[x].absolute_bounding_box.top_left.line,
pointer[x].absolute_bounding_box.top_left.column,
pointer.absolute_bounding_box.top_left.line,
pointer.absolute_bounding_box.top_left.column,
)

if pointer[x].type == "string" and pointer[x].value.startswith('"""'):
result += self._create_line_number_tuples(
pointer[x], start_line, start_col
)
if pointer[x].type in ("def", "class"):
docstring = PythonCodeRetriever._get_docstring(pointer[x])
node_copy = pointer[x].copy()
node_copy.value = '"""' + docstring + '"""'
result += self._create_line_number_tuples(node_copy, start_line, start_col)
pointer = pointer[entry]
else:
start_line, start_col = (
pointer.absolute_bounding_box.top_left.line,
pointer.absolute_bounding_box.top_left.column,
)
node_copy = pointer.copy()
node_copy.value = ""
result += self._create_line_number_tuples(node_copy, start_line, start_col)
pointer = getattr(pointer, entry)

start_line, start_col = (
pointer.absolute_bounding_box.top_left.line,
pointer.absolute_bounding_box.top_left.column,
)
result += self._create_line_number_tuples(pointer, start_line, start_col)

prev_line = 1
result_str = ""
for t in result:
if t[0] > prev_line + 1:
result_str += "...\n"
if return_numbered:
result_str += f"{t[0]}: {t[1]}\n"
else:
result_str += f"{t[1]}\n"
prev_line = t[0]
return result_str
node_copy = pointer.copy()
node_copy.value = ""
result += self._create_line_number_tuples(node_copy, start_line, start_col)
pointer = getattr(pointer, entry)

start_line, start_col = (
pointer.absolute_bounding_box.top_left.line,
pointer.absolute_bounding_box.top_left.column,
)
result += self._create_line_number_tuples(pointer, start_line, start_col)

prev_line = 1
result_str = ""
for t in result:
if t[0] > prev_line + 1:
result_str += "...\n"
if return_numbered:
result_str += f"{t[0]}: {t[1]}\n"
else:
result_str += f"{t[1]}\n"
prev_line = t[0]
return result_str
return NO_RESULT_FOUND_STR

def get_expression_context(
self,
Expand Down
39 changes: 39 additions & 0 deletions automata/core/code_indexing/test/sample_modules/sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""This is a sample module"""
import math


def sample_function(name):
"""This is a sample function."""
return f"Hello, {name}! Sqrt(2) = " + str(math.sqrt(2))


class Person:
"""This is a sample class."""

def __init__(self, name):
"""This is the constructor."""
self.name = name

def say_hello(self):
"""This is a sample method."""
return f"Hello, I am {self.name}."

def run(self) -> str:
...


def f(x) -> int:
"""This is my new function"""
return x + 1


class EmptyClass:
pass


class OuterClass:
class InnerClass:
"""Inner doc strings"""

def inner_method(self):
"""Inner method doc strings"""
41 changes: 41 additions & 0 deletions automata/core/code_indexing/test/sample_modules/sample2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import List

from automata.core.base.tool import Tool
from automata.tools.python_tools.python_agent import PythonAgent


class PythonAgentToolBuilder:
"""A class for building tools to interact with PythonAgent."""

def __init__(self, python_agent: PythonAgent):
"""
Initializes a PythonAgentToolBuilder with the given PythonAgent.
Args:
python_agent (PythonAgent): A PythonAgent instance representing the agent to work with.
"""
self.python_agent = python_agent

def build_tools(self) -> List:
"""
Builds a list of Tool objects for interacting with PythonAgent.
Args:
- None
Returns:
- tools (List[Tool]): A list of Tool objects representing PythonAgent commands.
"""

def python_agent_python_task():
"""A sample task that utilizes PythonAgent."""
pass

tools = [
Tool(
"automata-task",
python_agent_python_task,
"Execute a Python task using the PythonAgent. Provide the task description in plain English.",
)
]
return tools
4 changes: 2 additions & 2 deletions automata/core/search/symbol_rank/symbol_embedding_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def update_embeddings(self, symbols_to_update: List[Symbol]):
map_symbol = desc_to_full_symbol.get(symbol_desc_identifier, None)

if not map_symbol:
logger.info("Adding a new symbol: %s" % symbol)
logger.debug("Adding a new symbol: %s" % symbol)
symbol_embedding = self.embedding_provider.get_embedding(symbol_source)
self.embedding_dict[symbol] = SymbolEmbedding(
symbol=symbol, vector=symbol_embedding, source_code=symbol_source
Expand All @@ -120,7 +120,7 @@ def update_embeddings(self, symbols_to_update: List[Symbol]):
# If the symbol is already in the embedding map, check if the source code is the same
# If not, we can update the embedding
if self.embedding_dict[map_symbol].source_code != symbol_source:
logger.info("Modifying existing embedding for symbol: %s" % symbol)
logger.debug("Modifying existing embedding for symbol: %s" % symbol)
symbol_embedding = self.embedding_provider.get_embedding(symbol_source)
self.embedding_dict[symbol] = SymbolEmbedding(
symbol=symbol, vector=symbol_embedding, source_code=symbol_source
Expand Down
6 changes: 4 additions & 2 deletions automata/core/search/symbol_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import networkx as nx
import numpy as np
Expand All @@ -8,7 +8,9 @@
from automata.core.search.symbol_types import Descriptor, Symbol, SymbolEmbedding


def convert_to_fst_object(symbol: Symbol, module_map: LazyModuleTreeMap) -> RedBaron:
def convert_to_fst_object(
symbol: Symbol, module_map: Optional[LazyModuleTreeMap] = None
) -> RedBaron:
"""
Returns the RedBaron object for the given symbol.
Args:
Expand Down
4 changes: 2 additions & 2 deletions automata/core/tasks/automata_task_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ def execute(self, task: AutomataTask):
module_map = LazyModuleTreeMap(task.path_to_root_py)
retriever = PythonCodeRetriever(module_map)
writer = PythonWriter(retriever)
writer.update_module(
writer.update_existing_module(
module_dotpath="core.agent.automata_agent",
source_code="def test123(x): return True",
write_to_disk=True,
do_write=True,
)
task.result = "Test result"

Expand Down
Loading

0 comments on commit 1c6f847

Please sign in to comment.