Skip to content
This repository has been archived by the owner on Mar 16, 2024. It is now read-only.

Feature/merged default chroma #150

Merged
merged 8 commits into from
Jul 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,5 @@ local_tasks/
scripts/setup.sh
setup.sh
.direnv
docs/_build
docs/_build
.chroma
2 changes: 1 addition & 1 deletion automata/agent/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _build_initial_messages(
assert "user_input_instructions" in instruction_formatter

messages_config = load_config(
ConfigCategory.INSTRUCTION.value, self.config.instruction_version.value
ConfigCategory.INSTRUCTION.to_path(), self.config.instruction_version.to_path()
)
initial_messages = messages_config["initial_messages"]

Expand Down
4 changes: 2 additions & 2 deletions automata/cli/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def common_options(command: click.Command, *args, **kwargs) -> click.Command:
),
click.option(
"--index-file",
default="index.scip",
default="automata.scip",
help="Which index file to use for the embedding modifications.",
),
click.option(
Expand All @@ -30,7 +30,7 @@ def common_options(command: click.Command, *args, **kwargs) -> click.Command:
),
click.option(
"--doc-embedding-file",
default="symbol_doc_embedding_l3.json",
default="symbol_doc_embedding_l2.json",
help="Which embedding file to save to.",
),
]
Expand Down
4 changes: 3 additions & 1 deletion automata/cli/scripts/run_agent_config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def test_yaml_compatibility(file_path) -> None:

if __name__ == "__main__":
# Find all .yaml files in the specified directory
yaml_files = glob.glob(os.path.join(get_config_fpath(), ConfigCategory.AGENT.value, "*.yaml"))
yaml_files = glob.glob(
os.path.join(get_config_fpath(), ConfigCategory.AGENT.to_path(), "*.yaml")
)

# Run validation and compatibility tests on each YAML file
for yaml_file in yaml_files:
Expand Down
15 changes: 1 addition & 14 deletions automata/cli/scripts/run_code_embedding.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import logging
import os

from tqdm import tqdm

from automata.config.base import ConfigCategory
from automata.core.utils import get_config_fpath
from automata.llm.providers.openai import OpenAIEmbeddingProvider
from automata.memory_store.symbol_code_embedding import SymbolCodeEmbeddingHandler
from automata.singletons.dependency_factory import dependency_factory
Expand All @@ -22,20 +19,10 @@ def main(*args, **kwargs) -> str:

py_module_loader.initialize()

scip_fpath = os.path.join(
get_config_fpath(), ConfigCategory.SYMBOL.value, kwargs.get("index-file", "index.scip")
)
code_embedding_fpath = os.path.join(
get_config_fpath(),
ConfigCategory.SYMBOL.value,
kwargs.get("code-embedding-file", "symbol_code_embedding.json"),
)
embedding_provider = OpenAIEmbeddingProvider()

dependency_factory.set_overrides(
**{
"symbol_graph_scip_fpath": scip_fpath,
"code_embedding_fpath": code_embedding_fpath,
"embedding_provider": embedding_provider,
"disable_synchronization": True,
}
Expand All @@ -55,8 +42,8 @@ def main(*args, **kwargs) -> str:
for symbol in tqdm(filtered_symbols):
try:
symbol_code_embedding_handler.process_embedding(symbol)
symbol_code_embedding_handler.embedding_db.save()
except Exception as e:
logger.error(f"Failed to update embedding for {symbol.dotpath}: {e}")

symbol_code_embedding_handler.embedding_db.save()
return "Success"
31 changes: 2 additions & 29 deletions automata/cli/scripts/run_doc_embedding.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,33 @@
import logging
import os

from tqdm import tqdm

from automata.config.base import ConfigCategory
from automata.context_providers.symbol_synchronization import (
SymbolProviderSynchronizationContext,
)
from automata.core.utils import get_config_fpath
from automata.llm.providers.openai import OpenAIEmbeddingProvider
from automata.memory_store.symbol_code_embedding import SymbolCodeEmbeddingHandler
from automata.memory_store.symbol_doc_embedding import SymbolDocEmbeddingHandler
from automata.singletons.dependency_factory import dependency_factory
from automata.singletons.py_module_loader import py_module_loader
from automata.symbol.graph import SymbolGraph
from automata.symbol.symbol_utils import get_rankable_symbols
from automata.symbol_embedding.vector_databases import JSONSymbolEmbeddingVectorDatabase

logger = logging.getLogger(__name__)


def initialize_providers(embedding_level, **kwargs):
py_module_loader.initialize()

scip_fpath = os.path.join(
get_config_fpath(), ConfigCategory.SYMBOL.value, kwargs.get("index-file", "index.scip")
)
code_embedding_fpath = os.path.join(
get_config_fpath(),
ConfigCategory.SYMBOL.value,
kwargs.get("code-embedding-file", "symbol_code_embedding.json"),
)
doc_embedding_fpath = os.path.join(
get_config_fpath(),
ConfigCategory.SYMBOL.value,
kwargs.get("doc-embedding-file", f"symbol_doc_embedding_l{embedding_level}.json"),
)

embedding_provider = OpenAIEmbeddingProvider()

overrides = {
"symbol_graph_scip_fpath": scip_fpath,
"code_embedding_fpath": code_embedding_fpath,
"doc_embedding_fpath": doc_embedding_fpath,
"embedding_provider": embedding_provider,
"disable_synchronization": True,
}

if embedding_level == 3:
doc_embedding_fpath_l2 = os.path.join(
get_config_fpath(),
ConfigCategory.SYMBOL.value,
kwargs.get("doc-embedding-file", "symbol_doc_embedding_l2.json"),
)
doc_embedding_db_l2 = JSONSymbolEmbeddingVectorDatabase(doc_embedding_fpath_l2)
overrides["py_retriever_doc_embedding_db"] = doc_embedding_db_l2
raise NotImplementedError("Embedding level 3 is not supported at this moment.")

dependency_factory.set_overrides(**overrides)

Expand Down Expand Up @@ -90,6 +63,6 @@ def main(*args, **kwargs) -> str:
for symbol in tqdm(filtered_symbols):
logger.info(f"Caching embedding for {symbol}")
symbol_doc_embedding_handler.process_embedding(symbol)
symbol_doc_embedding_handler.embedding_db.save()
symbol_doc_embedding_handler.embedding_db.save()

return "Success"
2 changes: 1 addition & 1 deletion automata/cli/scripts/run_doc_post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def main(*args, **kwargs) -> str:
doc_writer = PyDocWriter(get_root_fpath())

embedding_path = os.path.join(
get_config_fpath(), ConfigCategory.SYMBOL.value, "symbol_doc_embedding_l2.json"
get_config_fpath(), ConfigCategory.SYMBOL.to_path(), "symbol_doc_embedding_l2.json"
)

embedding_db = JSONSymbolEmbeddingVectorDatabase(embedding_path)
Expand Down
40 changes: 29 additions & 11 deletions automata/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,30 @@
import yaml
from pydantic import BaseModel, PrivateAttr

from automata.core.utils import convert_kebab_to_snake
from automata.core.utils import convert_kebab_to_snake_case
from automata.tools.base import Tool


class ConfigCategory(Enum):
class PathEnum(Enum):

"""An abstract class for enums that represent paths"""

def to_path(self) -> str:
return convert_kebab_to_snake_case(self.value)


class EmbeddingDataCategory(PathEnum):
"""
A class to represent the different categories of embedding data
Presumably corresponds to folders in embedding_data/* (TODO confirm)
"""

CODE_EMBEDDING = "code-embedding"
DOC_EMBEDDING = "doc-embedding-l2"
INDICES = "indices"


class ConfigCategory(PathEnum):
"""
A class to represent the different categories of configuration options
Corresponds to folders in automata/config/*
Expand All @@ -19,19 +38,19 @@ class ConfigCategory(Enum):
AGENT = "agent"
PROMPT = "prompt"
SYMBOL = "symbol"
INSTRUCTION = "instruction_configs"
INSTRUCTION = "instruction-configs"


class InstructionConfigVersion(Enum):
class InstructionConfigVersion(PathEnum):
"""
InstructionConfigVersion: Enum of instruction versions.
Corresponds to files in automata/config/instruction_configs/*.yaml
"""

AGENT_INTRODUCTION = "agent_introduction"
AGENT_INTRODUCTION = "agent-introduction"


class AgentConfigName(Enum):
class AgentConfigName(PathEnum):
"""
AgentConfigName: Enum of agent config names.
Corresponds to files in automata/config/agent/*.yaml
Expand All @@ -45,7 +64,7 @@ class AgentConfigName(Enum):
AUTOMATA_MAIN = "automata-main"


class LLMProvider(Enum):
class LLMProvider(PathEnum):
OPENAI = "openai"


Expand Down Expand Up @@ -84,12 +103,11 @@ def get_llm_provider() -> LLMProvider:
def _load_automata_yaml_config(cls, config_name: AgentConfigName) -> Dict:
file_dir_path = os.path.dirname(os.path.abspath(__file__))
# to_path() converts kebab-case to snake_case to support the file naming convention
config_file_name = convert_kebab_to_snake(config_name.value)
config_abs_path = os.path.join(
file_dir_path,
ConfigCategory.AGENT.value,
cls.get_llm_provider().value,
f"{config_file_name}.yaml",
ConfigCategory.AGENT.to_path(),
cls.get_llm_provider().to_path(),
f"{config_name.to_path()}.yaml",
)

if not os.path.isfile(config_abs_path):
Expand Down
7 changes: 6 additions & 1 deletion automata/config/openai_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,12 @@ def load(cls, config_name: AgentConfigName) -> "OpenAIAutomataAgentConfig":
return OpenAIAutomataAgentConfig()

loaded_yaml = cls._load_automata_yaml_config(config_name)
return OpenAIAutomataAgentConfig(**loaded_yaml)
casted_config = OpenAIAutomataAgentConfig(**loaded_yaml)
# FIXME - Why is this cast necessary to ensure correct versioning?
casted_config.instruction_version = InstructionConfigVersion(
casted_config.instruction_version
)
return casted_config

@staticmethod
def get_llm_provider() -> LLMProvider:
Expand Down
Binary file removed automata/config/symbol/index.scip
Binary file not shown.
3 changes: 0 additions & 3 deletions automata/config/symbol/symbol_code_embedding.json

This file was deleted.

3 changes: 0 additions & 3 deletions automata/config/symbol/symbol_doc_embedding_l2.json

This file was deleted.

3 changes: 0 additions & 3 deletions automata/config/symbol/symbol_doc_embedding_l3.json

This file was deleted.

2 changes: 1 addition & 1 deletion automata/core/base/database/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def load(self) -> None:

def save(self) -> None:
"""Persists the database by calling `persist()` on the Chroma client."""
pass
self.client.persist()

def clear(self) -> None:
"""Clears all entries in the collection, Use with care!"""
Expand Down
7 changes: 6 additions & 1 deletion automata/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def get_root_fpath() -> str:
return os.path.join(get_root_py_fpath(), "..")


def get_embedding_data_fpath() -> str:
"""Get the path to the root of the Automata embedding data directory."""
return os.path.join(get_root_fpath(), "embedding_data")


def get_config_fpath() -> str:
"""Get the path to the root of the Automata config directory."""
return os.path.join(get_root_py_fpath(), "config")
Expand Down Expand Up @@ -61,7 +66,7 @@ def format_text(format_variables: Dict[str, str], input_text: str) -> str:
return input_text


def convert_kebab_to_snake(s: str) -> str:
def convert_kebab_to_snake_case(s: str) -> str:
"""Convert a kebab-case string to snake_case."""
return s.replace("-", "_")

Expand Down
1 change: 0 additions & 1 deletion automata/experimental/search/symbol_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def _find_pattern_in_modules(self, pattern: str) -> Dict[str, List[int]]:
"""Finds exact line matches for a given pattern string in all modules."""
matches = {}
for module_path, module in py_module_loader.items():
print("Checking module = ", module)
if module:
if isinstance(module, RedBaron):
lines = module.dumps().splitlines()
Expand Down
11 changes: 3 additions & 8 deletions automata/memory_store/symbol_code_embedding.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging

from automata.core.base.database.vector import VectorDatabaseProvider
from automata.symbol.base import Symbol
from automata.symbol_embedding.builders import SymbolCodeEmbeddingBuilder
from automata.symbol_embedding.handler import SymbolEmbeddingHandler
from automata.symbol_embedding.vector_databases import JSONSymbolEmbeddingVectorDatabase

logger = logging.getLogger(__name__)

Expand All @@ -13,18 +13,14 @@ class SymbolCodeEmbeddingHandler(SymbolEmbeddingHandler):

def __init__(
self,
embedding_db: JSONSymbolEmbeddingVectorDatabase,
embedding_db: VectorDatabaseProvider,
embedding_builder: SymbolCodeEmbeddingBuilder,
) -> None:
super().__init__(embedding_db, embedding_builder)

def process_embedding(self, symbol: Symbol) -> None:
"""Process the embedding for a `Symbol` by updating if the source code has changed."""
source_code = self.embedding_builder.fetch_embedding_source_code(symbol)

if not source_code:
raise ValueError(f"Symbol {symbol} has no source code")

if self.embedding_db.contains(symbol.dotpath):
self.update_existing_embedding(source_code, symbol)
else:
Expand All @@ -37,13 +33,12 @@ def update_existing_embedding(self, source_code: str, symbol: Symbol) -> None:
of the existing embedding. If there are differences, update the embedding.
"""
existing_embedding = self.embedding_db.get(symbol.dotpath)

if existing_embedding.document != source_code:
logger.debug("Building a new embedding for %s", symbol)
self.embedding_db.discard(symbol.dotpath)
symbol_embedding = self.embedding_builder.build(source_code, symbol)
self.embedding_db.add(symbol_embedding)
elif existing_embedding.symbol != symbol:
logger.debug("Updating the embedding for %s", symbol)
self.embedding_db.discard(symbol.dotpath)
existing_embedding.symbol = symbol
self.embedding_db.add(existing_embedding)
Expand Down
10 changes: 5 additions & 5 deletions automata/memory_store/symbol_doc_embedding.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging

from automata.core.base.database.vector import VectorDatabaseProvider
from automata.symbol.base import Symbol, SymbolDescriptor
from automata.symbol_embedding.builders import SymbolDocEmbeddingBuilder
from automata.symbol_embedding.handler import SymbolEmbeddingHandler
from automata.symbol_embedding.vector_databases import JSONSymbolEmbeddingVectorDatabase

logger = logging.getLogger(__name__)

Expand All @@ -13,7 +13,7 @@ class SymbolDocEmbeddingHandler(SymbolEmbeddingHandler):

def __init__(
self,
embedding_db: JSONSymbolEmbeddingVectorDatabase,
embedding_db: VectorDatabaseProvider,
embedding_builder: SymbolDocEmbeddingBuilder,
) -> None:
super().__init__(embedding_db, embedding_builder)
Expand All @@ -35,10 +35,10 @@ def process_embedding(self, symbol: Symbol) -> None:
return
if symbol.symbol_kind_by_suffix() == SymbolDescriptor.PyKind.Class:
symbol_embedding = self.embedding_builder.build(source_code, symbol)
else:
if not isinstance(self.embedding_builder, SymbolDocEmbeddingBuilder):
raise ValueError("SymbolDocEmbeddingHandler requires a SymbolDocEmbeddingBuilder")
elif isinstance(self.embedding_builder, SymbolDocEmbeddingBuilder):
symbol_embedding = self.embedding_builder.build_non_class(source_code, symbol)
else:
raise ValueError("SymbolDocEmbeddingHandler requires a SymbolDocEmbeddingBuilder")

self.embedding_db.add(symbol_embedding)

Expand Down
Loading