Skip to content
This repository has been archived by the owner on Mar 16, 2024. It is now read-only.

refactor docs #106

Merged
merged 1 commit into from
Jul 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ Examples of these classes are:
Code example for creating an instance of 'SymbolCodeEmbedding':
```python
import numpy as np
from automata.core.symbol_embedding.base import SymbolCodeEmbedding
from automata.core.base.symbol_embedding import SymbolCodeEmbedding
from automata.core.symbol.parser import parse_symbol

symbol_str = 'scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `automata.core.agent.agent_enums`/ActionIndicator#'
Expand All @@ -157,7 +157,7 @@ embedding = SymbolCodeEmbedding(symbol=symbol, source_code=source_code, vector=v

Code example for creating an instance of 'SymbolDocEmbedding':
```python
from automata.core.symbol.base import SymbolDocEmbedding
from automata.core.base.symbol_embedding import SymbolDocEmbedding
from automata.core.symbol.parser import parse_symbol
import numpy as np

Expand Down
6 changes: 3 additions & 3 deletions automata/cli/scripts/run_code_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from tqdm import tqdm

from automata.config.base import ConfigCategory
from automata.core.base.database.vector import JSONEmbeddingVectorDatabase
from automata.core.base.symbol_embedding import JSONSymbolEmbeddingVectorDatabase
from automata.core.coding.py.module_loader import py_module_loader
from automata.core.llm.providers.openai import OpenAIEmbeddingProvider
from automata.core.memory_store.symbol_code_embedding import SymbolCodeEmbeddingHandler
from automata.core.symbol.graph import SymbolGraph
from automata.core.symbol.symbol_utils import get_rankable_symbols
from automata.core.symbol_embedding.embedding_builders import SymbolCodeEmbeddingBuilder
from automata.core.symbol_embedding.builders import SymbolCodeEmbeddingBuilder
from automata.core.utils import get_config_fpath

logger = logging.getLogger(__name__)
Expand All @@ -37,7 +37,7 @@ def main(*args, **kwargs) -> str:
all_defined_symbols = symbol_graph.get_all_available_symbols()
filtered_symbols = sorted(get_rankable_symbols(all_defined_symbols), key=lambda x: x.dotpath)

embedding_db = JSONEmbeddingVectorDatabase(embedding_path)
embedding_db = JSONSymbolEmbeddingVectorDatabase(embedding_path)
embedding_provider = OpenAIEmbeddingProvider()
embedding_builder = SymbolCodeEmbeddingBuilder(embedding_provider)
embedding_handler = SymbolCodeEmbeddingHandler(embedding_db, embedding_builder)
Expand Down
10 changes: 5 additions & 5 deletions automata/cli/scripts/run_doc_embedding_l2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from tqdm import tqdm

from automata.config.base import ConfigCategory
from automata.core.base.database.vector import JSONEmbeddingVectorDatabase
from automata.core.base.symbol import SymbolDescriptor
from automata.core.base.symbol_embedding import JSONSymbolEmbeddingVectorDatabase
from automata.core.coding.py.module_loader import py_module_loader
from automata.core.context.py.retriever import (
PyContextRetriever,
Expand All @@ -16,12 +17,11 @@
)
from automata.core.memory_store.symbol_code_embedding import SymbolCodeEmbeddingHandler
from automata.core.memory_store.symbol_doc_embedding import SymbolDocEmbeddingHandler
from automata.core.symbol.base import SymbolDescriptor
from automata.core.symbol.graph import SymbolGraph
from automata.core.symbol.search.rank import SymbolRankConfig
from automata.core.symbol.search.symbol_search import SymbolSearch
from automata.core.symbol.symbol_utils import get_rankable_symbols
from automata.core.symbol_embedding.embedding_builders import (
from automata.core.symbol_embedding.builders import (
SymbolCodeEmbeddingBuilder,
SymbolDocEmbeddingBuilder,
)
Expand All @@ -46,7 +46,7 @@ def main(*args, **kwargs) -> str:
code_embedding_fpath = os.path.join(
get_config_fpath(), ConfigCategory.SYMBOL.value, "symbol_code_embedding.json"
)
code_embedding_db = JSONEmbeddingVectorDatabase(code_embedding_fpath)
code_embedding_db = JSONSymbolEmbeddingVectorDatabase(code_embedding_fpath)
embedding_provider = OpenAIEmbeddingProvider()
embedding_builder = SymbolCodeEmbeddingBuilder(embedding_provider)
code_embedding_handler = SymbolCodeEmbeddingHandler(code_embedding_db, embedding_builder)
Expand All @@ -57,7 +57,7 @@ def main(*args, **kwargs) -> str:
ConfigCategory.SYMBOL.value,
kwargs.get("symbol_doc_embedding_l2_fpath", "symbol_doc_embedding_l2.json"),
)
embedding_db_l2 = JSONEmbeddingVectorDatabase(embedding_path_l2)
embedding_db_l2 = JSONSymbolEmbeddingVectorDatabase(embedding_path_l2)

symbol_graph = SymbolGraph(scip_path)

Expand Down
12 changes: 6 additions & 6 deletions automata/cli/scripts/run_doc_embedding_l3.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from tqdm import tqdm

from automata.config.base import ConfigCategory
from automata.core.base.database.vector import JSONEmbeddingVectorDatabase
from automata.core.base.symbol import SymbolDescriptor
from automata.core.base.symbol_embedding import JSONSymbolEmbeddingVectorDatabase
from automata.core.coding.py.module_loader import py_module_loader
from automata.core.context.py.retriever import (
PyContextRetriever,
Expand All @@ -16,12 +17,11 @@
)
from automata.core.memory_store.symbol_code_embedding import SymbolCodeEmbeddingHandler
from automata.core.memory_store.symbol_doc_embedding import SymbolDocEmbeddingHandler
from automata.core.symbol.base import SymbolDescriptor
from automata.core.symbol.graph import SymbolGraph
from automata.core.symbol.search.rank import SymbolRankConfig
from automata.core.symbol.search.symbol_search import SymbolSearch
from automata.core.symbol.symbol_utils import get_rankable_symbols
from automata.core.symbol_embedding.embedding_builders import (
from automata.core.symbol_embedding.builders import (
SymbolCodeEmbeddingBuilder,
SymbolDocEmbeddingBuilder,
)
Expand All @@ -45,7 +45,7 @@ def main(*args, **kwargs) -> str:
code_embedding_fpath = os.path.join(
get_config_fpath(), ConfigCategory.SYMBOL.value, "symbol_code_embedding.json"
)
code_embedding_db = JSONEmbeddingVectorDatabase(code_embedding_fpath)
code_embedding_db = JSONSymbolEmbeddingVectorDatabase(code_embedding_fpath)
embedding_provider = OpenAIEmbeddingProvider()
embedding_builder = SymbolCodeEmbeddingBuilder(embedding_provider)
code_embedding_handler = SymbolCodeEmbeddingHandler(code_embedding_db, embedding_builder)
Expand All @@ -56,7 +56,7 @@ def main(*args, **kwargs) -> str:
ConfigCategory.SYMBOL.value,
kwargs.get("symbol_doc_embedding_l2_fpath", "symbol_doc_embedding_l2.json"),
)
embedding_db_l2 = JSONEmbeddingVectorDatabase(embedding_path_l2)
embedding_db_l2 = JSONSymbolEmbeddingVectorDatabase(embedding_path_l2)

embedding_path_l3 = os.path.join(
get_config_fpath(),
Expand All @@ -66,7 +66,7 @@ def main(*args, **kwargs) -> str:

symbol_graph = SymbolGraph(scip_path)

embedding_db_l3 = JSONEmbeddingVectorDatabase(embedding_path_l3)
embedding_db_l3 = JSONSymbolEmbeddingVectorDatabase(embedding_path_l3)

symbol_code_similarity = SymbolSimilarityCalculator(code_embedding_handler, embedding_provider)

Expand Down
4 changes: 2 additions & 2 deletions automata/cli/scripts/run_doc_post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os

from automata.config.base import ConfigCategory
from automata.core.base.database.vector import JSONEmbeddingVectorDatabase
from automata.core.base.symbol_embedding import JSONSymbolEmbeddingVectorDatabase
from automata.core.coding.py.writer import PyDocWriter
from automata.core.utils import get_config_fpath, get_root_py_fpath

Expand All @@ -19,7 +19,7 @@ def main(*args, **kwargs) -> str:
get_config_fpath(), ConfigCategory.SYMBOL.value, "symbol_doc_embedding_l2.json"
)

embedding_db = JSONEmbeddingVectorDatabase(embedding_path)
embedding_db = JSONSymbolEmbeddingVectorDatabase(embedding_path)

symbols = [embedding.symbol for embedding in embedding_db.get_all_entries()]

Expand Down
Binary file modified automata/config/symbol/index.scip
Binary file not shown.
4 changes: 2 additions & 2 deletions automata/config/symbol/symbol_code_embedding.json
Git LFS file not shown
4 changes: 2 additions & 2 deletions automata/config/symbol/symbol_doc_embedding_l2.json
Git LFS file not shown
4 changes: 2 additions & 2 deletions automata/config/symbol/symbol_doc_embedding_l3.json
Git LFS file not shown
12 changes: 6 additions & 6 deletions automata/core/agent/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, instructions: str, config: AutomataOpenAIAgentConfig) -> None
super().__init__(instructions)
self.config = config
self.iteration_count = 0
self.agent_conversations = OpenAIConversation()
self.agent_conversation_database = OpenAIConversation()
self.completed = False
self._setup()

Expand Down Expand Up @@ -133,7 +133,7 @@ def run(self) -> str:
except AgentStopIteration:
break

last_message = self.agent_conversations.get_latest_message()
last_message = self.agent_conversation_database.get_latest_message()
if self.iteration_count >= self.config.max_iterations:
raise AgentMaxIterError("The agent exceeded the maximum number of iterations.")
if not self.completed or not isinstance(last_message, OpenAIChatMessage):
Expand All @@ -148,7 +148,7 @@ def set_database_provider(self, provider: LLMConversationDatabaseProvider) -> No
if self.database_provider:
raise AgentDatabaseError("The database provider has already been set.")
self.database_provider = provider
self.agent_conversations.register_observer(provider)
self.agent_conversation_database.register_observer(provider)

def _build_initial_messages(
self, instruction_formatter: Dict[str, str]
Expand Down Expand Up @@ -217,21 +217,21 @@ def _setup(self) -> None:
AgentError: If the agent fails to initialize.
"""
logger.debug(f"Setting up agent with tools = {self.config.tools}")
self.agent_conversations.add_message(
self.agent_conversation_database.add_message(
OpenAIChatMessage(role="system", content=self.config.system_instruction)
)
for message in list(
self._build_initial_messages({"user_input_instructions": self.instructions})
):
logger.debug(f"Adding the following initial mesasge to the conversation {message}")
self.agent_conversations.add_message(message)
self.agent_conversation_database.add_message(message)
logging.debug(f"\n{('-' * 120)}")

self.chat_provider = OpenAIChatCompletionProvider(
model=self.config.model,
temperature=self.config.temperature,
stream=self.config.stream,
conversation=self.agent_conversations,
conversation=self.agent_conversation_database,
functions=self.functions,
)
self._initialized = True
Expand Down
8 changes: 4 additions & 4 deletions automata/core/agent/tool/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from automata.config.base import ConfigCategory, LLMProvider
from automata.core.agent.error import AgentGeneralError, UnknownToolError
from automata.core.base.agent import AgentToolProviders
from automata.core.base.database.vector import JSONEmbeddingVectorDatabase
from automata.core.base.singleton import Singleton
from automata.core.base.symbol_embedding import JSONSymbolEmbeddingVectorDatabase
from automata.core.base.tool import Tool
from automata.core.coding.py.reader import PyReader
from automata.core.coding.py.writer import PyWriter
Expand All @@ -24,7 +24,7 @@
from automata.core.symbol.graph import SymbolGraph
from automata.core.symbol.search.rank import SymbolRank, SymbolRankConfig
from automata.core.symbol.search.symbol_search import SymbolSearch
from automata.core.symbol_embedding.embedding_builders import SymbolDocEmbeddingBuilder
from automata.core.symbol_embedding.builders import SymbolDocEmbeddingBuilder
from automata.core.symbol_embedding.similarity import SymbolSimilarityCalculator
from automata.core.utils import get_config_fpath

Expand Down Expand Up @@ -154,7 +154,7 @@ def create_symbol_code_similarity(self) -> SymbolSimilarityCalculator:
code_embedding_fpath = self.overrides.get(
"code_embedding_fpath", DependencyFactory.DEFAULT_CODE_EMBEDDING_FPATH
)
code_embedding_db = JSONEmbeddingVectorDatabase(code_embedding_fpath)
code_embedding_db = JSONSymbolEmbeddingVectorDatabase(code_embedding_fpath)

embedding_provider = self.overrides.get("embedding_provider", OpenAIEmbeddingProvider())
code_embedding_handler = SymbolCodeEmbeddingHandler(code_embedding_db, embedding_provider)
Expand All @@ -170,7 +170,7 @@ def create_symbol_doc_similarity(self) -> SymbolSimilarityCalculator:
doc_embedding_fpath = self.overrides.get(
"doc_embedding_fpath", DependencyFactory.DEFAULT_DOC_EMBEDDING_FPATH
)
doc_embedding_db = JSONEmbeddingVectorDatabase(doc_embedding_fpath)
doc_embedding_db = JSONSymbolEmbeddingVectorDatabase(doc_embedding_fpath)

embedding_provider = self.overrides.get("embedding_provider", OpenAIEmbeddingProvider())
symbol_search = self.get("symbol_search")
Expand Down
76 changes: 0 additions & 76 deletions automata/core/base/database/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

import jsonpickle

from automata.core.symbol_embedding.base import SymbolEmbedding

logger = logging.getLogger(__name__)

T = TypeVar("T")
Expand Down Expand Up @@ -125,77 +123,3 @@ def get(self, key: K) -> T:
def clear(self):
self.data = []
self.index = {}


class JSONEmbeddingVectorDatabase(JSONVectorDatabase):
"""Concrete class to provide a vector database that saves into a JSON file."""

def __init__(self, file_path: str):
super().__init__(file_path)

def entry_to_key(self, entry: SymbolEmbedding) -> str:
"""Method to generate a hashable key from an entry of type T."""
return entry.symbol.dotpath

def get_all_entries(self) -> List[SymbolEmbedding]:
return sorted(self.data, key=lambda x: self.entry_to_key(x))


# class JSONEmbeddingVectorDatabase(VectorDatabaseProvider):
# """Concrete class to provide a vector database that saves into a JSON file."""

# def __init__(self, file_path: str):
# self.file_path = file_path
# self.data: List[SymbolEmbedding] = []
# self.index: Dict[str, int] = {}
# self.load()

# def save(self):
# """Saves the vector database to the JSON file."""
# with open(self.file_path, "w") as file:
# encoded_data = jsonpickle.encode(self.data)
# file.write(encoded_data)

# def load(self):
# """Loads the vector database from the JSON file."""
# try:
# with open(self.file_path, "r") as file:
# self.data = jsonpickle.decode(file.read())
# # We index on the dotpath of the symbol, which is unique and indepenent of commit hash
# self.index = {embedding.symbol.dotpath: i for i, embedding in enumerate(self.data)}
# except FileNotFoundError:
# logger.info(f"Creating new vector embedding db at {self.file_path}")

# def add(self, embedding: SymbolEmbedding):
# self.data.append(embedding)
# self.index[embedding.symbol.dotpath] = len(self.data) - 1

# def update_database(self, embedding: SymbolEmbedding):
# if embedding.symbol not in self.index:
# raise KeyError(f"Symbol {embedding.symbol} not in database")
# self.data[self.index[embedding.symbol.dotpath]] = embedding

# def discard(self, symbol: Symbol):
# if symbol.dotpath not in self.index:
# raise KeyError(f"Symbol {symbol} not in database")
# index = self.index[symbol.dotpath]
# del self.data[index]
# del self.index[symbol.dotpath]
# # Recalculate indices after deletion
# self.index = {embedding.symbol.dotpath: i for i, embedding in enumerate(self.data)}

# def contains(self, symbol: Symbol) -> bool:
# return symbol.dotpath in self.index

# def get(self, symbol: Symbol) -> SymbolEmbedding:
# if symbol.dotpath not in self.index:
# raise KeyError(f"Symbol {symbol} not in database")
# return self.data[self.index[symbol.dotpath]]

# def clear(self):
# self.data = []
# self.index = {}

# def get_all_entries(self) -> List[Symbol]:
# symbol_list = [embedding.symbol for embedding in self.data]
# return sorted(symbol_list, key=lambda x: str(x.dotpath))
51 changes: 51 additions & 0 deletions automata/core/base/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import abc
import logging
from enum import Enum
from typing import Dict

import numpy as np

from automata.core.base.symbol import Symbol

logger = logging.getLogger(__name__)


class EmbeddingNormType(Enum):
L1 = "l1"
L2 = "l2"
SOFTMAX = "softmax"


class EmbeddingProvider(abc.ABC):
"""A class to provide embeddings for symbols"""

@abc.abstractmethod
def build_embedding_array(self, symbol_source: str) -> np.ndarray:
pass


class EmbeddingSimilarityCalculator(abc.ABC):
@abc.abstractmethod
def calculate_query_similarity_dict(self, query_text: str) -> Dict[Symbol, float]:
"""An abstract method to get the similarity between a query and all symbols"""
pass

@abc.abstractmethod
def _calculate_embedding_similarity(self, embedding_array: np.ndarray) -> np.ndarray:
"""An abstract method to calculate the similarity between the embedding array target embeddings."""
pass

@staticmethod
def _normalize_embeddings(
embeddings_array: np.ndarray, norm_type: EmbeddingNormType
) -> np.ndarray:
if norm_type == EmbeddingNormType.L1:
norm = np.sum(np.abs(embeddings_array), axis=1, keepdims=True)
return embeddings_array / norm
elif norm_type == EmbeddingNormType.L2:
return embeddings_array / np.linalg.norm(embeddings_array, axis=1, keepdims=True)
elif norm_type == EmbeddingNormType.SOFTMAX:
e_x = np.exp(embeddings_array - np.max(embeddings_array, axis=1, keepdims=True))
return e_x / np.sum(e_x, axis=1, keepdims=True)
else:
raise ValueError(f"Invalid normalization type {norm_type}")
Loading