Skip to content

Commit

Permalink
[Misc] Lint code and fix code smells (#1871)
Browse files Browse the repository at this point in the history
  • Loading branch information
deshraj committed Sep 17, 2024
1 parent 0a78cb9 commit 55c54be
Show file tree
Hide file tree
Showing 57 changed files with 1,165 additions and 1,344 deletions.
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ install_all:

# Format code with ruff
format:
poetry run ruff check . --fix $(RUFF_OPTIONS)
poetry run ruff format mem0/

# Sort imports with isort
sort:
poetry run isort . $(ISORT_OPTIONS)
poetry run isort mem0/

# Lint code with ruff
lint:
poetry run ruff .
poetry run ruff check mem0/

docs:
cd docs && mintlify dev
Expand Down
20 changes: 7 additions & 13 deletions cookbooks/add_memory_using_qdrant_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,21 @@

# Loading OpenAI API Key
load_dotenv()
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
USER_ID = "test"
quadrant_host="xx.gcp.cloud.qdrant.io"
quadrant_host = "xx.gcp.cloud.qdrant.io"

# creating the config attributes
collection_name="memory" # this is the collection I created in QDRANT cloud
api_key=os.environ.get("QDRANT_API_KEY") # Getting the QDRANT api KEY
host=quadrant_host
port=6333 #Default port for QDRANT cloud
collection_name = "memory" # this is the collection I created in QDRANT cloud
api_key = os.environ.get("QDRANT_API_KEY") # Getting the QDRANT api KEY
host = quadrant_host
port = 6333 # Default port for QDRANT cloud

# Creating the config dict
config = {
"vector_store": {
"provider": "qdrant",
"config": {
"collection_name": collection_name,
"host": host,
"port": port,
"path": None,
"api_key":api_key
}
"config": {"collection_name": collection_name, "host": host, "port": port, "path": None, "api_key": api_key},
}
}

Expand Down
356 changes: 178 additions & 178 deletions cookbooks/mem0-multion.ipynb

Large diffs are not rendered by default.

590 changes: 290 additions & 300 deletions cookbooks/multion_travel_agent.ipynb

Large diffs are not rendered by default.

22 changes: 10 additions & 12 deletions mem0/client/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from mem0.memory.telemetry import capture_client_event

logger = logging.getLogger(__name__)
warnings.filterwarnings('always', category=DeprecationWarning, message="The 'session_id' parameter is deprecated. Use 'run_id' instead.")
warnings.filterwarnings(
"always",
category=DeprecationWarning,
    message="The 'session_id' parameter is deprecated. Use 'run_id' instead.",
)

# Setup user config
setup_config()
Expand Down Expand Up @@ -82,14 +86,10 @@ def _validate_api_key(self):
response = self.client.get("/v1/memories/", params={"user_id": "test"})
response.raise_for_status()
except httpx.HTTPStatusError:
raise ValueError(
"Invalid API Key. Please get a valid API Key from https://app.mem0.ai"
)
raise ValueError("Invalid API Key. Please get a valid API Key from https://app.mem0.ai")

@api_error_handler
def add(
self, messages: Union[str, List[Dict[str, str]]], **kwargs
) -> Dict[str, Any]:
def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]:
"""Add a new memory.
Args:
Expand Down Expand Up @@ -253,9 +253,7 @@ def delete_users(self) -> Dict[str, str]:
"""Delete all users, agents, or sessions."""
entities = self.users()
for entity in entities["results"]:
response = self.client.delete(
f"/v1/entities/{entity['type']}/{entity['id']}/"
)
response = self.client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/")
response.raise_for_status()

capture_client_event("client.delete_users", self)
Expand Down Expand Up @@ -312,7 +310,7 @@ def _prepare_payload(
"The 'session_id' parameter is deprecated and will be removed in version 0.1.20. "
"Use 'run_id' instead.",
DeprecationWarning,
stacklevel=2
stacklevel=2,
)
kwargs["run_id"] = kwargs.pop("session_id")

Expand All @@ -335,7 +333,7 @@ def _prepare_params(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"The 'session_id' parameter is deprecated and will be removed in version 0.1.20. "
"Use 'run_id' instead.",
DeprecationWarning,
stacklevel=2
stacklevel=2,
)
kwargs["run_id"] = kwargs.pop("session_id")

Expand Down
29 changes: 12 additions & 17 deletions mem0/configs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,10 @@ class MemoryItem(BaseModel):
) # TODO After prompt changes from platform, update this
hash: Optional[str] = Field(None, description="The hash of the memory")
# The metadata value can be anything and not just string. Fix it
metadata: Optional[Dict[str, Any]] = Field(
None, description="Additional metadata for the text data"
)
score: Optional[float] = Field(
None, description="The score associated with the text data"
)
created_at: Optional[str] = Field(
None, description="The timestamp when the memory was created"
)
updated_at: Optional[str] = Field(
None, description="The timestamp when the memory was updated"
)
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata for the text data")
score: Optional[float] = Field(None, description="The score associated with the text data")
created_at: Optional[str] = Field(None, description="The timestamp when the memory was created")
updated_at: Optional[str] = Field(None, description="The timestamp when the memory was updated")


class MemoryConfig(BaseModel):
Expand Down Expand Up @@ -60,7 +52,7 @@ class MemoryConfig(BaseModel):
description="Custom prompt for the memory",
default=None,
)


class AzureConfig(BaseModel):
"""
Expand All @@ -73,7 +65,10 @@ class AzureConfig(BaseModel):
api_version (str): The version of the Azure API being used.
"""

api_key: str = Field(description="The API key used for authenticating with the Azure service.", default=None)
azure_deployment : str = Field(description="The name of the Azure deployment.", default=None)
azure_endpoint : str = Field(description="The endpoint URL for the Azure service.", default=None)
api_version : str = Field(description="The version of the Azure API being used.", default=None)
api_key: str = Field(
description="The API key used for authenticating with the Azure service.",
default=None,
)
azure_deployment: str = Field(description="The name of the Azure deployment.", default=None)
azure_endpoint: str = Field(description="The endpoint URL for the Azure service.", default=None)
api_version: str = Field(description="The version of the Azure API being used.", default=None)
2 changes: 1 addition & 1 deletion mem0/configs/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,6 @@ def __init__(

# Huggingface specific
self.model_kwargs = model_kwargs or {}

# AzureOpenAI specific
self.azure_kwargs = AzureConfig(**azure_kwargs) or {}
1 change: 1 addition & 0 deletions mem0/configs/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
If you do not find any relevant facts, user memories, or preferences in the below conversation, you can return an empty list corresponding to the "facts" key.
"""


def get_update_memory_messages(retrieved_old_memory_dict, response_content):
return f"""You are a smart memory manager which controls the memory of a system.
You can perform four operations: (1) add into the memory, (2) update the memory, (3) delete from the memory, and (4) no change.
Expand Down
4 changes: 1 addition & 3 deletions mem0/configs/vector_stores/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ class ChromaDbConfig(BaseModel):
Client: ClassVar[type] = Client

collection_name: str = Field("mem0", description="Default name for the collection")
client: Optional[Client] = Field(
None, description="Existing ChromaDB client instance"
)
client: Optional[Client] = Field(None, description="Existing ChromaDB client instance")
path: Optional[str] = Field(None, description="Path to the database directory")
host: Optional[str] = Field(None, description="Database connection remote host")
port: Optional[int] = Field(None, description="Database connection remote port")
Expand Down
22 changes: 12 additions & 10 deletions mem0/configs/vector_stores/milvus.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
from enum import Enum
from typing import Dict, Any
from pydantic import BaseModel, model_validator, Field
from typing import Any, Dict

from pydantic import BaseModel, Field, model_validator


class MetricType(str, Enum):
"""
Metric Constant for milvus/ zilliz server.
"""

def __str__(self) -> str:
return str(self.value)

L2 = "L2"
IP = "IP"
COSINE = "COSINE"
HAMMING = "HAMMING"
JACCARD = "JACCARD"
IP = "IP"
COSINE = "COSINE"
HAMMING = "HAMMING"
JACCARD = "JACCARD"


class MilvusDBConfig(BaseModel):
url: str = Field("http://localhost:19530", description="Full URL for Milvus/Zilliz server")
token: str = Field(None, description="Token for Zilliz server / local setup defaults to None.")
Expand All @@ -38,4 +40,4 @@ def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:

model_config = {
"arbitrary_types_allowed": True,
}
}
5 changes: 1 addition & 4 deletions mem0/configs/vector_stores/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@


class PGVectorConfig(BaseModel):

dbname: str = Field("postgres", description="Default name for the database")
collection_name: str = Field("mem0", description="Default name for the collection")
embedding_model_dims: Optional[int] = Field(
1536, description="Dimensions of the embedding model"
)
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
user: Optional[str] = Field(None, description="Database user")
password: Optional[str] = Field(None, description="Database password")
host: Optional[str] = Field(None, description="Database host. Default is localhost")
Expand Down
16 changes: 4 additions & 12 deletions mem0/configs/vector_stores/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,11 @@ class QdrantConfig(BaseModel):
QdrantClient: ClassVar[type] = QdrantClient

collection_name: str = Field("mem0", description="Name of the collection")
embedding_model_dims: Optional[int] = Field(
1536, description="Dimensions of the embedding model"
)
client: Optional[QdrantClient] = Field(
None, description="Existing Qdrant client instance"
)
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
client: Optional[QdrantClient] = Field(None, description="Existing Qdrant client instance")
host: Optional[str] = Field(None, description="Host address for Qdrant server")
port: Optional[int] = Field(None, description="Port for Qdrant server")
path: Optional[str] = Field(
"/tmp/qdrant", description="Path for local Qdrant database"
)
path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
url: Optional[str] = Field(None, description="Full URL for Qdrant server")
api_key: Optional[str] = Field(None, description="API key for Qdrant server")
on_disk: Optional[bool] = Field(False, description="Enables persistent storage")
Expand All @@ -35,9 +29,7 @@ def check_host_port_or_path(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values.get("api_key"),
)
if not path and not (host and port) and not (url and api_key):
raise ValueError(
"Either 'host' and 'port' or 'url' and 'api_key' or 'path' must be provided."
)
raise ValueError("Either 'host' and 'port' or 'url' and 'api_key' or 'path' must be provided.")
return values

@model_validator(mode="before")
Expand Down
14 changes: 5 additions & 9 deletions mem0/embeddings/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None):
azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("EMBEDDING_AZURE_DEPLOYMENT")
azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("EMBEDDING_AZURE_ENDPOINT")
api_version = self.config.azure_kwargs.api_version or os.getenv("EMBEDDING_AZURE_API_VERSION")

self.client = AzureOpenAI(
azure_deployment=azure_deployment,
azure_deployment=azure_deployment,
azure_endpoint=azure_endpoint,
api_version=api_version,
api_key=api_key,
http_client=self.config.http_client
)
http_client=self.config.http_client,
)

def embed(self, text):
"""
Expand All @@ -35,8 +35,4 @@ def embed(self, text):
list: The embedding vector.
"""
text = text.replace("\n", " ")
return (
self.client.embeddings.create(input=[text], model=self.config.model)
.data[0]
.embedding
)
return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding
4 changes: 1 addition & 3 deletions mem0/embeddings/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ class EmbedderConfig(BaseModel):
description="Provider of the embedding model (e.g., 'ollama', 'openai')",
default="openai",
)
config: Optional[dict] = Field(
description="Configuration for the specific embedding model", default={}
)
config: Optional[dict] = Field(description="Configuration for the specific embedding model", default={})

@field_validator("config")
def validate_config(cls, v, values):
Expand Down
2 changes: 1 addition & 1 deletion mem0/embeddings/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ollama import Client
except ImportError:
user_input = input("The 'ollama' library is required. Install it now? [y/N]: ")
if user_input.lower() == 'y':
if user_input.lower() == "y":
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "ollama"])
from ollama import Client
Expand Down
6 changes: 1 addition & 5 deletions mem0/embeddings/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,4 @@ def embed(self, text):
list: The embedding vector.
"""
text = text.replace("\n", " ")
return (
self.client.embeddings.create(input=[text], model=self.config.model)
.data[0]
.embedding
)
return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding
5 changes: 3 additions & 2 deletions mem0/embeddings/vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase


class VertexAI(EmbeddingBase):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config)
Expand Down Expand Up @@ -34,6 +35,6 @@ def embed(self, text):
Returns:
list: The embedding vector.
"""
embeddings = self.model.get_embeddings(texts=[text], output_dimensionality= self.config.embedding_dims)
embeddings = self.model.get_embeddings(texts=[text], output_dimensionality=self.config.embedding_dims)

return embeddings[0].values
23 changes: 5 additions & 18 deletions mem0/graphs/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,16 @@ def check_host_port_or_path(cls, values):
values.get("password"),
)
if not url or not username or not password:
raise ValueError(
"Please provide 'url', 'username' and 'password'."
)
raise ValueError("Please provide 'url', 'username' and 'password'.")
return values


class GraphStoreConfig(BaseModel):
provider: str = Field(
description="Provider of the data store (e.g., 'neo4j')",
default="neo4j"
)
config: Neo4jConfig = Field(
description="Configuration for the specific data store",
default=None
)
llm: Optional[LlmConfig] = Field(
description="LLM configuration for querying the graph store",
default=None
)
provider: str = Field(description="Provider of the data store (e.g., 'neo4j')", default="neo4j")
config: Neo4jConfig = Field(description="Configuration for the specific data store", default=None)
llm: Optional[LlmConfig] = Field(description="LLM configuration for querying the graph store", default=None)
custom_prompt: Optional[str] = Field(
description="Custom prompt to fetch entities from the given text",
default=None
description="Custom prompt to fetch entities from the given text", default=None
)

@field_validator("config")
Expand All @@ -49,4 +37,3 @@ def validate_config(cls, v, values):
return Neo4jConfig(**v.model_dump())
else:
raise ValueError(f"Unsupported graph store provider: {provider}")

Loading

0 comments on commit 55c54be

Please sign in to comment.