Skip to content
This repository has been archived by the owner on Mar 16, 2024. It is now read-only.

Commit

Permalink
Update Embedding Workflow (#160)
Browse files Browse the repository at this point in the history
* re-factor graph approach

* rename scip

* Accept missing  in Symbol construction, e.g. allow for symbols that are directly imported

* first v of gen indices

* checkin progress

* checkin incremental progress

* cleanup

* cleanup2

* refresh embeddings

* add save to add

* add db index...

* add llama index

* Onboard langchain, llama-index, make embedding creation more efficient

* checkin

* add super simple example

* simplify demo

* migrate to automata embedding data

* test cli utils

* update workflows

* fix type error

* fix last type errors

* test

* finalize
  • Loading branch information
emrgnt-cmplxty authored Jul 10, 2023
1 parent ee324dc commit aa877dc
Show file tree
Hide file tree
Showing 63 changed files with 654 additions and 297 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ coverage*

# Ignore Local
notebooks/
automata_embedding_factory/
local_env/*
playground/*
*.sqlite3
Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@
path = scip-python
url = https://github.com/emrgnt-cmplxty/scip-python
ignore = all # necessary to avoid issues from modified package-lock.json
[submodule "automata-embedding-data"]
path = automata-embedding-data
url = [email protected]:emrgnt-cmplxty/automata-embedding-data.git
1 change: 1 addition & 0 deletions automata-embedding-data
9 changes: 9 additions & 0 deletions automata/cli/cli_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from automata.core.utils import get_root_fpath
from automata.singletons.py_module_loader import py_module_loader


def initialize_modules(*args, **kwargs) -> None:
root_path = kwargs.get("project_root_fpath") or get_root_fpath()
project_name = kwargs.get("project_name") or "automata"
rel_py_path = kwargs.get("project_rel_py_path") or project_name
py_module_loader.initialize(root_path, rel_py_path)
16 changes: 7 additions & 9 deletions automata/cli/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,17 @@ def common_options(command: click.Command, *args, **kwargs) -> click.Command:
help="Execute script in verbose mode?",
),
click.option(
"--index-file",
default="automata.scip",
help="Which index file to use for the embedding modifications.",
"--project_name",
default="automata",
help="The name of the project we are manipulating.",
),
click.option(
"--code-embedding-file",
default="symbol_code_embedding.json",
help="Which embedding file to save to.",
"--project_root_fpath",
help="The root path to the project.",
),
click.option(
"--doc-embedding-file",
default="symbol_doc_embedding_l2.json",
help="Which embedding file to save to.",
"--project_rel_py_path",
help="The relative py path to the project.",
),
]
for option in reversed(options):
Expand Down
2 changes: 1 addition & 1 deletion automata/cli/scripts/run_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from automata.config import GITHUB_API_KEY, REPOSITORY_NAME
from automata.config.base import AgentConfigName
from automata.config.openai_agent import OpenAIAutomataAgentConfigBuilder
from automata.github_management.client import GitHubClient
from automata.singletons.dependency_factory import dependency_factory
from automata.singletons.github_client import GitHubClient
from automata.singletons.py_module_loader import py_module_loader
from automata.tools.factory import AgentToolFactory

Expand Down
28 changes: 22 additions & 6 deletions automata/cli/scripts/run_code_embedding.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import logging
import os

from tqdm import tqdm

from automata.cli.cli_utils import initialize_modules
from automata.llm.providers.openai import OpenAIEmbeddingProvider
from automata.memory_store.symbol_code_embedding import SymbolCodeEmbeddingHandler
from automata.singletons.dependency_factory import dependency_factory
from automata.singletons.py_module_loader import py_module_loader
from automata.singletons.dependency_factory import DependencyFactory, dependency_factory
from automata.symbol.graph import SymbolGraph
from automata.symbol.symbol_utils import get_rankable_symbols
from automata.symbol_embedding.base import SymbolCodeEmbedding
from automata.symbol_embedding.vector_databases import (
ChromaSymbolEmbeddingVectorDatabase,
)

logger = logging.getLogger(__name__)

Expand All @@ -16,18 +21,29 @@ def main(*args, **kwargs) -> str:
"""
Update the symbol code embedding based on the specified SCIP index file.
"""
py_module_loader.initialize()
project_name = kwargs.get("project_name") or "automata"
initialize_modules(**kwargs)

symbol_graph = SymbolGraph(
os.path.join(DependencyFactory.DEFAULT_SCIP_FPATH, f"{project_name}.scip")
)

code_embedding_db = ChromaSymbolEmbeddingVectorDatabase(
project_name,
persist_directory=DependencyFactory.DEFAULT_CODE_EMBEDDING_FPATH,
factory=SymbolCodeEmbedding.from_args,
)
embedding_provider = OpenAIEmbeddingProvider()

dependency_factory.set_overrides(
**{
"symbol_graph": symbol_graph,
"code_embedding_db": code_embedding_db,
"embedding_provider": embedding_provider,
"disable_synchronization": True,
"disable_synchronization": True, # We spoof synchronization locally
}
)

symbol_graph: SymbolGraph = dependency_factory.get("symbol_graph")
symbol_code_embedding_handler: SymbolCodeEmbeddingHandler = dependency_factory.get(
"symbol_code_embedding_handler"
)
Expand All @@ -44,5 +60,5 @@ def main(*args, **kwargs) -> str:
except Exception as e:
logger.error(f"Failed to update embedding for {symbol.dotpath}: {e}")

symbol_code_embedding_handler.embedding_db.save()
symbol_code_embedding_handler.flush() # Final flush for any remaining symbols that didn't form a complete batch
return "Success"
51 changes: 38 additions & 13 deletions automata/cli/scripts/run_doc_embedding.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,60 @@
import logging
import os

from tqdm import tqdm

from automata.cli.cli_utils import initialize_modules
from automata.context_providers.symbol_synchronization import (
SymbolProviderSynchronizationContext,
)
from automata.llm.providers.openai import OpenAIEmbeddingProvider
from automata.memory_store.symbol_code_embedding import SymbolCodeEmbeddingHandler
from automata.memory_store.symbol_doc_embedding import SymbolDocEmbeddingHandler
from automata.singletons.dependency_factory import dependency_factory
from automata.singletons.py_module_loader import py_module_loader
from automata.singletons.dependency_factory import DependencyFactory, dependency_factory
from automata.symbol.graph import SymbolGraph
from automata.symbol.symbol_utils import get_rankable_symbols
from automata.symbol_embedding.base import SymbolCodeEmbedding, SymbolDocEmbedding
from automata.symbol_embedding.vector_databases import (
ChromaSymbolEmbeddingVectorDatabase,
)

logger = logging.getLogger(__name__)


def initialize_providers(embedding_level, **kwargs):
py_module_loader.initialize()
project_name = kwargs.get("project_name") or "automata"
initialize_modules(**kwargs)

symbol_graph = SymbolGraph(
os.path.join(DependencyFactory.DEFAULT_SCIP_FPATH, f"{project_name}.scip")
)
code_embedding_db = ChromaSymbolEmbeddingVectorDatabase(
project_name,
persist_directory=DependencyFactory.DEFAULT_CODE_EMBEDDING_FPATH,
factory=SymbolCodeEmbedding.from_args,
)

doc_embedding_db = ChromaSymbolEmbeddingVectorDatabase(
project_name,
persist_directory=DependencyFactory.DEFAULT_DOC_EMBEDDING_FPATH,
factory=SymbolDocEmbedding.from_args,
)

embedding_provider = OpenAIEmbeddingProvider()

overrides = {
"embedding_provider": embedding_provider,
"disable_synchronization": True,
}
dependency_factory.set_overrides(
**{
"symbol_graph": symbol_graph,
"code_embedding_db": code_embedding_db,
"doc_embedding_db": doc_embedding_db,
"embedding_provider": embedding_provider,
"disable_synchronization": True, # We synchronzie locally
}
)

if embedding_level == 3:
raise NotImplementedError("Embedding level 3 is not supported at this moment.")

dependency_factory.set_overrides(**overrides)

symbol_graph: SymbolGraph = dependency_factory.get("symbol_graph")
symbol_code_embedding_handler: SymbolCodeEmbeddingHandler = dependency_factory.get(
"symbol_code_embedding_handler"
)
Expand Down Expand Up @@ -61,8 +84,10 @@ def main(*args, **kwargs) -> str:

logger.info("Looping over filtered symbols...")
for symbol in tqdm(filtered_symbols):
logger.info(f"Caching embedding for {symbol}")
symbol_doc_embedding_handler.process_embedding(symbol)
symbol_doc_embedding_handler.embedding_db.save()
try:
logger.info(f"Caching embedding for {symbol}")
symbol_doc_embedding_handler.process_embedding(symbol)
except Exception as e:
logger.info(f"Error {e} for symbol {symbol}")

return "Success"
17 changes: 11 additions & 6 deletions automata/core/base/database/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ class ChromaVectorDatabase(VectorDatabaseProvider, Generic[K, V]):
def __init__(self, collection_name: str, persist_directory: Optional[str] = None):
self._setup_chroma_client(persist_directory)
self._collection = self.client.get_or_create_collection(collection_name)
self.persist_directory = persist_directory

def _setup_chroma_client(self, persist_directory: Optional[str] = None):
"""Setup the Chroma client, here we attempt to contain the Chroma dependency."""
Expand Down Expand Up @@ -299,12 +300,16 @@ def contains(self, key: K) -> bool:
return len(result["ids"]) != 0

def discard(self, key: K) -> None:
try:
self._collection.delete(ids=[key])
except RuntimeError as e:
# FIXME - It seems an error in Chroma is causing this to be raised falsely
if str(e) != "The requested to delete element is already deleted":
raise
self._collection.delete(ids=[key])
self._save()

def batch_discard(self, keys: List[K]) -> None:
self._collection.delete(ids=keys)
self._save()

def _save(self):
# TODO - Do we need to save after every action?
# I experienced some bugs when not doing so, for now
# we will conservatively save after every action.
if self.persist_directory:
self.save()
2 changes: 1 addition & 1 deletion automata/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_root_fpath() -> str:

def get_embedding_data_fpath() -> str:
"""Get the path to the root of the Automata config directory."""
return os.path.join(get_root_fpath(), "embedding_data")
return os.path.join(get_root_fpath(), "automata-embedding-data")


def get_config_fpath() -> str:
Expand Down
33 changes: 18 additions & 15 deletions automata/embedding/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import abc
import logging
from enum import Enum
from typing import Any, Dict, Sequence
from typing import Any, Dict, List, Sequence

import astunparse
import numpy as np

from automata.core.base.database.vector import VectorDatabaseProvider
from automata.symbol.base import Symbol

logger = logging.getLogger(__name__)
Expand All @@ -25,6 +24,10 @@ class EmbeddingVectorProvider(abc.ABC):
def build_embedding_vector(self, symbol_source: str) -> np.ndarray:
pass

@abc.abstractmethod
def batch_build_embedding_vector(self, symbol_source: List[str]) -> List[np.ndarray]:
pass


class Embedding(abc.ABC):
"""Abstract base class for different types of embeddings"""
Expand Down Expand Up @@ -53,6 +56,11 @@ def build(self, source_text: str, symbol: Symbol) -> Any:
"""An abstract method to build the embedding for a symbol"""
pass

@abc.abstractmethod
def batch_build(self, source_text: List[str], symbol: List[Symbol]) -> Any:
"""An abstract method to build the embedding for a symbol"""
pass

def fetch_embedding_source_code(self, symbol: Symbol) -> str:
"""An abstract method for embedding the context is the source code itself."""
from automata.symbol.symbol_utils import ( # imported late for mocking
Expand All @@ -63,26 +71,21 @@ def fetch_embedding_source_code(self, symbol: Symbol) -> str:


class EmbeddingHandler(abc.ABC):
"""An abstract class to handle embeddings"""
"""An abstract class to handle batch embeddings."""

@abc.abstractmethod
def __init__(
self,
embedding_db: VectorDatabaseProvider,
embedding_builder: EmbeddingBuilder,
) -> None:
"""An abstract constructor for EmbeddingHandler"""
self.embedding_db = embedding_db
self.embedding_builder = embedding_builder
def get_embeddings(self, symbols: List[Symbol]) -> List[Any]:
"""An abstract method to get the embeddings entries for a list of symbols."""
pass

@abc.abstractmethod
def get_embedding(self, symbol: Symbol) -> Any:
"""An abstract method to get the embedding for a symbol"""
def process_embedding(self, symbols: Symbol) -> None:
"""An abstract method to process the embeddings for a list of symbols."""
pass

@abc.abstractmethod
def process_embedding(self, symbol: Symbol) -> None:
"""An abstract method to process the embedding for a symbol"""
def flush(self) -> None:
"""Perform any remaining updates that do not form a complete batch."""
pass


Expand Down
5 changes: 5 additions & 0 deletions automata/llm/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,11 @@ def build_embedding_vector(self, source: str) -> np.ndarray:

return np.array(get_embedding(source, engine=self.engine))

def batch_build_embedding_vector(self, sources: List[str]) -> List[np.ndarray]:
from openai.embeddings_utils import get_embeddings

return [np.array(ele) for ele in get_embeddings(sources, engine=self.engine)]


class OpenAITool(Tool):
"""A class representing a tool that can be used by the OpenAI agent."""
Expand Down
Loading

0 comments on commit aa877dc

Please sign in to comment.