Skip to content
This repository has been archived by the owner on Mar 16, 2024. It is now read-only.

Commit

Permalink
Feature/merged default chroma (#150)
Browse files Browse the repository at this point in the history
* Change default db to chroma, remove old db setup

* fix test error

* steps towards working code embedding

* Adding embedding dbs for now

* refresh code embeddings

* update code embedding db

* update embedding db

* finalize db update
  • Loading branch information
emrgnt-cmplxty committed Jul 8, 2023
1 parent 08b1aa7 commit 8c0282d
Show file tree
Hide file tree
Showing 43 changed files with 192 additions and 203 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,5 @@ local_tasks/
scripts/setup.sh
setup.sh
.direnv
docs/_build
docs/_build
.chroma
2 changes: 1 addition & 1 deletion automata/agent/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _build_initial_messages(
assert "user_input_instructions" in instruction_formatter

messages_config = load_config(
ConfigCategory.INSTRUCTION.value, self.config.instruction_version.value
ConfigCategory.INSTRUCTION.to_path(), self.config.instruction_version.to_path()
)
initial_messages = messages_config["initial_messages"]

Expand Down
4 changes: 2 additions & 2 deletions automata/cli/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def common_options(command: click.Command, *args, **kwargs) -> click.Command:
),
click.option(
"--index-file",
default="index.scip",
default="automata.scip",
help="Which index file to use for the embedding modifications.",
),
click.option(
Expand All @@ -30,7 +30,7 @@ def common_options(command: click.Command, *args, **kwargs) -> click.Command:
),
click.option(
"--doc-embedding-file",
default="symbol_doc_embedding_l3.json",
default="symbol_doc_embedding_l2.json",
help="Which embedding file to save to.",
),
]
Expand Down
4 changes: 3 additions & 1 deletion automata/cli/scripts/run_agent_config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def test_yaml_compatibility(file_path) -> None:

if __name__ == "__main__":
# Find all .yaml files in the specified directory
yaml_files = glob.glob(os.path.join(get_config_fpath(), ConfigCategory.AGENT.value, "*.yaml"))
yaml_files = glob.glob(
os.path.join(get_config_fpath(), ConfigCategory.AGENT.to_path(), "*.yaml")
)

# Run validation and compatibility tests on each YAML file
for yaml_file in yaml_files:
Expand Down
15 changes: 1 addition & 14 deletions automata/cli/scripts/run_code_embedding.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import logging
import os

from tqdm import tqdm

from automata.config.base import ConfigCategory
from automata.core.utils import get_config_fpath
from automata.llm.providers.openai import OpenAIEmbeddingProvider
from automata.memory_store.symbol_code_embedding import SymbolCodeEmbeddingHandler
from automata.singletons.dependency_factory import dependency_factory
Expand All @@ -22,20 +19,10 @@ def main(*args, **kwargs) -> str:

py_module_loader.initialize()

scip_fpath = os.path.join(
get_config_fpath(), ConfigCategory.SYMBOL.value, kwargs.get("index-file", "index.scip")
)
code_embedding_fpath = os.path.join(
get_config_fpath(),
ConfigCategory.SYMBOL.value,
kwargs.get("code-embedding-file", "symbol_code_embedding.json"),
)
embedding_provider = OpenAIEmbeddingProvider()

dependency_factory.set_overrides(
**{
"symbol_graph_scip_fpath": scip_fpath,
"code_embedding_fpath": code_embedding_fpath,
"embedding_provider": embedding_provider,
"disable_synchronization": True,
}
Expand All @@ -55,8 +42,8 @@ def main(*args, **kwargs) -> str:
for symbol in tqdm(filtered_symbols):
try:
symbol_code_embedding_handler.process_embedding(symbol)
symbol_code_embedding_handler.embedding_db.save()
except Exception as e:
logger.error(f"Failed to update embedding for {symbol.dotpath}: {e}")

symbol_code_embedding_handler.embedding_db.save()
return "Success"
31 changes: 2 additions & 29 deletions automata/cli/scripts/run_doc_embedding.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
import logging
import os

from tqdm import tqdm

from automata.config.base import ConfigCategory
from automata.context_providers.symbol_synchronization import (
SymbolProviderSynchronizationContext,
)
from automata.core.utils import get_config_fpath
from automata.llm.providers.openai import OpenAIEmbeddingProvider
from automata.memory_store.symbol_code_embedding import SymbolCodeEmbeddingHandler
from automata.memory_store.symbol_doc_embedding import SymbolDocEmbeddingHandler
from automata.singletons.dependency_factory import dependency_factory
from automata.singletons.py_module_loader import py_module_loader, pyast_module_loader
from automata.symbol.graph import SymbolGraph
from automata.symbol.symbol_utils import get_rankable_symbols
from automata.symbol_embedding.vector_databases import JSONSymbolEmbeddingVectorDatabase

logger = logging.getLogger(__name__)

Expand All @@ -24,38 +20,15 @@ def initialize_providers(embedding_level, **kwargs):
py_module_loader.initialize()
pyast_module_loader.initialize()

scip_fpath = os.path.join(
get_config_fpath(), ConfigCategory.SYMBOL.value, kwargs.get("index-file", "index.scip")
)
code_embedding_fpath = os.path.join(
get_config_fpath(),
ConfigCategory.SYMBOL.value,
kwargs.get("code-embedding-file", "symbol_code_embedding.json"),
)
doc_embedding_fpath = os.path.join(
get_config_fpath(),
ConfigCategory.SYMBOL.value,
kwargs.get("doc-embedding-file", f"symbol_doc_embedding_l{embedding_level}.json"),
)

embedding_provider = OpenAIEmbeddingProvider()

overrides = {
"symbol_graph_scip_fpath": scip_fpath,
"code_embedding_fpath": code_embedding_fpath,
"doc_embedding_fpath": doc_embedding_fpath,
"embedding_provider": embedding_provider,
"disable_synchronization": True,
}

if embedding_level == 3:
doc_embedding_fpath_l2 = os.path.join(
get_config_fpath(),
ConfigCategory.SYMBOL.value,
kwargs.get("doc-embedding-file", "symbol_doc_embedding_l2.json"),
)
doc_embedding_db_l2 = JSONSymbolEmbeddingVectorDatabase(doc_embedding_fpath_l2)
overrides["py_retriever_doc_embedding_db"] = doc_embedding_db_l2
raise NotImplementedError("Embedding level 3 is not supported at this moment.")

dependency_factory.set_overrides(**overrides)

Expand Down Expand Up @@ -91,6 +64,6 @@ def main(*args, **kwargs) -> str:
for symbol in tqdm(filtered_symbols):
logger.info(f"Caching embedding for {symbol}")
symbol_doc_embedding_handler.process_embedding(symbol)
symbol_doc_embedding_handler.embedding_db.save()
symbol_doc_embedding_handler.embedding_db.save()

return "Success"
2 changes: 1 addition & 1 deletion automata/cli/scripts/run_doc_post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def main(*args, **kwargs) -> str:
doc_writer = PyDocWriter(get_root_fpath())

embedding_path = os.path.join(
get_config_fpath(), ConfigCategory.SYMBOL.value, "symbol_doc_embedding_l2.json"
get_config_fpath(), ConfigCategory.SYMBOL.to_path(), "symbol_doc_embedding_l2.json"
)

embedding_db = JSONSymbolEmbeddingVectorDatabase(embedding_path)
Expand Down
40 changes: 29 additions & 11 deletions automata/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,30 @@
import yaml
from pydantic import BaseModel, PrivateAttr

from automata.core.utils import convert_kebab_to_snake
from automata.core.utils import convert_kebab_to_snake_case
from automata.tools.base import Tool


class ConfigCategory(Enum):
class PathEnum(Enum):

"""An abstract class for enums that represent paths"""

def to_path(self) -> str:
return convert_kebab_to_snake_case(self.value)


class EmbeddingDataCategory(PathEnum):
"""
A class to represent the different categories of configuration options
Corresponds folders in automata/configs/*
"""

CODE_EMBEDDING = "code-embedding"
DOC_EMBEDDING = "doc-embedding-l2"
INDICES = "indices"


class ConfigCategory(PathEnum):
"""
A class to represent the different categories of configuration options
Corresponds folders in automata/configs/*
Expand All @@ -19,19 +38,19 @@ class ConfigCategory(Enum):
AGENT = "agent"
PROMPT = "prompt"
SYMBOL = "symbol"
INSTRUCTION = "instruction_configs"
INSTRUCTION = "instruction-configs"


class InstructionConfigVersion(Enum):
class InstructionConfigVersion(PathEnum):
"""
InstructionConfigVersion: Enum of instruction versions.
Corresponds files in automata/configs/instruction_configs/*.yaml
"""

AGENT_INTRODUCTION = "agent_introduction"
AGENT_INTRODUCTION = "agent-introduction"


class AgentConfigName(Enum):
class AgentConfigName(PathEnum):
"""
AgentConfigName: Enum of agent config names.
Corresponds files in automata/config/agent/*.yaml
Expand All @@ -45,7 +64,7 @@ class AgentConfigName(Enum):
AUTOMATA_MAIN = "automata-main"


class LLMProvider(Enum):
class LLMProvider(PathEnum):
OPENAI = "openai"


Expand Down Expand Up @@ -84,12 +103,11 @@ def get_llm_provider() -> LLMProvider:
def _load_automata_yaml_config(cls, config_name: AgentConfigName) -> Dict:
file_dir_path = os.path.dirname(os.path.abspath(__file__))
# convert kebab to snake case to support file naming convention
config_file_name = convert_kebab_to_snake(config_name.value)
config_abs_path = os.path.join(
file_dir_path,
ConfigCategory.AGENT.value,
cls.get_llm_provider().value,
f"{config_file_name}.yaml",
ConfigCategory.AGENT.to_path(),
cls.get_llm_provider().to_path(),
f"{config_name.to_path()}.yaml",
)

if not os.path.isfile(config_abs_path):
Expand Down
7 changes: 6 additions & 1 deletion automata/config/openai_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,12 @@ def load(cls, config_name: AgentConfigName) -> "OpenAIAutomataAgentConfig":
return OpenAIAutomataAgentConfig()

loaded_yaml = cls._load_automata_yaml_config(config_name)
return OpenAIAutomataAgentConfig(**loaded_yaml)
casted_config = OpenAIAutomataAgentConfig(**loaded_yaml)
# FIXME - Why is this cast necessary to ensure correct versioning?
casted_config.instruction_version = InstructionConfigVersion(
casted_config.instruction_version
)
return casted_config

@staticmethod
def get_llm_provider() -> LLMProvider:
Expand Down
Binary file removed automata/config/symbol/index.scip
Binary file not shown.
3 changes: 0 additions & 3 deletions automata/config/symbol/symbol_code_embedding.json

This file was deleted.

3 changes: 0 additions & 3 deletions automata/config/symbol/symbol_doc_embedding_l2.json

This file was deleted.

3 changes: 0 additions & 3 deletions automata/config/symbol/symbol_doc_embedding_l3.json

This file was deleted.

2 changes: 1 addition & 1 deletion automata/core/base/database/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def load(self) -> None:

def save(self) -> None:
"""As Chroma is a live database, no specific save action is required."""
pass
self.client.persist()

def clear(self) -> None:
"""Clears all entries in the collection, Use with care!"""
Expand Down
7 changes: 6 additions & 1 deletion automata/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def get_root_fpath() -> str:
return os.path.join(get_root_py_fpath(), "..")


def get_embedding_data_fpath() -> str:
"""Get the path to the root of the Automata config directory."""
return os.path.join(get_root_fpath(), "embedding_data")


def get_config_fpath() -> str:
"""Get the path to the root of the Automata config directory."""
return os.path.join(get_root_py_fpath(), "config")
Expand Down Expand Up @@ -61,7 +66,7 @@ def format_text(format_variables: Dict[str, str], input_text: str) -> str:
return input_text


def convert_kebab_to_snake(s: str) -> str:
def convert_kebab_to_snake_case(s: str) -> str:
"""Convert a kebab-case string to snake_case."""
return s.replace("-", "_")

Expand Down
11 changes: 3 additions & 8 deletions automata/memory_store/symbol_code_embedding.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging

from automata.core.base.database.vector import VectorDatabaseProvider
from automata.symbol.base import Symbol
from automata.symbol_embedding.builders import SymbolCodeEmbeddingBuilder
from automata.symbol_embedding.handler import SymbolEmbeddingHandler
from automata.symbol_embedding.vector_databases import JSONSymbolEmbeddingVectorDatabase

logger = logging.getLogger(__name__)

Expand All @@ -13,18 +13,14 @@ class SymbolCodeEmbeddingHandler(SymbolEmbeddingHandler):

def __init__(
self,
embedding_db: JSONSymbolEmbeddingVectorDatabase,
embedding_db: VectorDatabaseProvider,
embedding_builder: SymbolCodeEmbeddingBuilder,
) -> None:
super().__init__(embedding_db, embedding_builder)

def process_embedding(self, symbol: Symbol) -> None:
"""Process the embedding for a `Symbol` by updating if the source code has changed."""
source_code = self.embedding_builder.fetch_embedding_source_code(symbol)

if not source_code:
raise ValueError(f"Symbol {symbol} has no source code")

if self.embedding_db.contains(symbol.dotpath):
self.update_existing_embedding(source_code, symbol)
else:
Expand All @@ -37,13 +33,12 @@ def update_existing_embedding(self, source_code: str, symbol: Symbol) -> None:
of the existing embedding. If there are differences, update the embedding.
"""
existing_embedding = self.embedding_db.get(symbol.dotpath)

if existing_embedding.document != source_code:
logger.debug("Building a new embedding for %s", symbol)
self.embedding_db.discard(symbol.dotpath)
symbol_embedding = self.embedding_builder.build(source_code, symbol)
self.embedding_db.add(symbol_embedding)
elif existing_embedding.symbol != symbol:
logger.debug("Updating the embedding for %s", symbol)
self.embedding_db.discard(symbol.dotpath)
existing_embedding.symbol = symbol
self.embedding_db.add(existing_embedding)
Expand Down
10 changes: 5 additions & 5 deletions automata/memory_store/symbol_doc_embedding.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging

from automata.core.base.database.vector import VectorDatabaseProvider
from automata.symbol.base import Symbol, SymbolDescriptor
from automata.symbol_embedding.builders import SymbolDocEmbeddingBuilder
from automata.symbol_embedding.handler import SymbolEmbeddingHandler
from automata.symbol_embedding.vector_databases import JSONSymbolEmbeddingVectorDatabase

logger = logging.getLogger(__name__)

Expand All @@ -13,7 +13,7 @@ class SymbolDocEmbeddingHandler(SymbolEmbeddingHandler):

def __init__(
self,
embedding_db: JSONSymbolEmbeddingVectorDatabase,
embedding_db: VectorDatabaseProvider,
embedding_builder: SymbolDocEmbeddingBuilder,
) -> None:
super().__init__(embedding_db, embedding_builder)
Expand All @@ -35,10 +35,10 @@ def process_embedding(self, symbol: Symbol) -> None:
return
if symbol.symbol_kind_by_suffix() == SymbolDescriptor.PyKind.Class:
symbol_embedding = self.embedding_builder.build(source_code, symbol)
else:
if not isinstance(self.embedding_builder, SymbolDocEmbeddingBuilder):
raise ValueError("SymbolDocEmbeddingHandler requires a SymbolDocEmbeddingBuilder")
elif isinstance(self.embedding_builder, SymbolDocEmbeddingBuilder):
symbol_embedding = self.embedding_builder.build_non_class(source_code, symbol)
else:
raise ValueError("SymbolDocEmbeddingHandler requires a SymbolDocEmbeddingBuilder")

self.embedding_db.add(symbol_embedding)

Expand Down
Loading

0 comments on commit 8c0282d

Please sign in to comment.