Add CodeInterpreter service #51

Merged 4 commits on Feb 19, 2024
7 changes: 7 additions & 0 deletions .env.example
@@ -2,6 +2,13 @@ API_BASE_URL=https://rag.superagent.sh
OPENAI_API_KEY=
COHERE_API_KEY=


# Optional for VM sandboxes on E2B Cloud
E2B_API_KEY=

# Public code interpreter setup in E2B Cloud
E2B_SANDBOX_NAME="super-rag"

# Optional for walkthrough
PINECONE_API_KEY=
PINECONE_HOST=
2 changes: 1 addition & 1 deletion Makefile
@@ -10,4 +10,4 @@ lint_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d master | gr
lint lint_diff:
	poetry run black $(PYTHON_FILES) --check
	poetry run ruff .
	poetry run vulture . --exclude=venv
152 changes: 141 additions & 11 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
@@ -29,6 +29,7 @@ pypdf = "^4.0.1"
docx2txt = "^0.8"
python-dotenv = "^1.0.1"
unstructured = {version = "^0.12.4", extras = ["all-docs"], python = "<3.12"}
e2b = "^0.14.4"

[tool.poetry.group.dev.dependencies]
termcolor = "^2.4.0"
@@ -43,6 +44,7 @@ build-backend = "poetry.core.masonry.api"
exclude = [
"*/test_*.py",
"*/.venv/*.py",
"*/sandboxes/*",
]
ignore_decorators = ["@app.route", "@require_*"]
ignore_names = ["visit_*", "do_*"]
@@ -57,4 +59,5 @@ exclude = [
"*/docs/*.py",
"*/test_*.py",
"*/.venv/*.py",
"*/sandboxes/*",
]
18 changes: 18 additions & 0 deletions sandboxes/e2b/super-rag/e2b.Dockerfile
@@ -0,0 +1,18 @@
FROM python:3.11.6

RUN DEBIAN_FRONTEND=noninteractive apt-get update && apt-get install -y \
build-essential curl git util-linux

ENV PIP_DEFAULT_TIMEOUT=100 \
PIP_DISABLE_PIP_VERSION_CHECK=1 \
PIP_NO_CACHE_DIR=1

WORKDIR /code

COPY ./requirements.txt requirements.txt
RUN pip install -r requirements.txt

RUN mkdir -p /home/user/artifacts

RUN echo "export MPLBACKEND=module://e2b_matplotlib_backend" >>~/.bashrc
COPY e2b_matplotlib_backend.py /usr/local/lib/python3.11/site-packages/e2b_matplotlib_backend.py
14 changes: 14 additions & 0 deletions sandboxes/e2b/super-rag/e2b.toml
@@ -0,0 +1,14 @@
# This is a config for E2B sandbox template.
# You can use 'template_id' (1db0ceubobh88yem7h90) or 'template_name' (super-rag) from this config to spawn a sandbox:

# Python SDK
# from e2b import Sandbox
# sandbox = Sandbox(template='super-rag')

# JS SDK
# import { Sandbox } from 'e2b'
# const sandbox = await Sandbox.create({ template: 'super-rag' })

dockerfile = "e2b.Dockerfile"
template_name = "super-rag"
template_id = "1db0ceubobh88yem7h90"
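
The comments above show the two documented ways to reference this template. A minimal Python sketch (not part of the PR) of the same spawn/reconnect flow the service uses, with a hypothetical session ID:

from e2b import Sandbox

# Spawn a sandbox from this template, tagging it with session metadata.
# "demo-session" is a hypothetical ID for illustration.
sandbox = Sandbox(template="super-rag", metadata={"session_id": "demo-session"})

# Later, look the sandbox up by its metadata and reconnect to it,
# mirroring _ensure_sandbox in service/code_interpreter.py below.
for info in Sandbox.list():
    if info.metadata and info.metadata.get("session_id") == "demo-session":
        sandbox = Sandbox.reconnect(info.sandbox_id)
        break

sandbox.close()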
22 changes: 22 additions & 0 deletions sandboxes/e2b/super-rag/e2b_matplotlib_backend.py
@@ -0,0 +1,22 @@
# flake8: noqa
from time import strftime

from matplotlib.backend_bases import FigureManagerBase, Gcf
from matplotlib.backends.backend_agg import FigureCanvasAgg

dateformat = "%Y%m%d-%H%M%S"
FigureCanvas = FigureCanvasAgg


class FigureManager(FigureManagerBase):
    def show(self):
        self.canvas.figure.savefig(
            f"/home/user/artifacts/figure_{strftime(dateformat)}.png"
        )


def show(*_args, **_kwargs):
    for figmanager in Gcf.get_all_fig_managers():
        figmanager.canvas.figure.savefig(
            f"/home/user/artifacts/figure_{strftime(dateformat)}.png"
        )
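
Inside the sandbox, the Dockerfile exports MPLBACKEND=module://e2b_matplotlib_backend, so this backend intercepts plt.show() and writes figures to the artifacts directory instead of opening a window. A minimal sketch of what sandboxed user code would run (the explicit environment assignment stands in for the .bashrc export, which only applies to login shells):

import os

# Select the custom backend; this must happen before matplotlib.pyplot is imported.
os.environ["MPLBACKEND"] = "module://e2b_matplotlib_backend"

import matplotlib.pyplot as plt

plt.plot([1, 2, 3], [1, 4, 9])
plt.show()  # saves /home/user/artifacts/figure_<YYYYmmdd-HHMMSS>.png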
29 changes: 29 additions & 0 deletions sandboxes/e2b/super-rag/requirements.txt
@@ -0,0 +1,29 @@
pandas
opencv-python
imageio
scikit-learn
spacy
bokeh
pytest
aiohttp
python-docx
nltk
textblob
beautifulsoup4
seaborn
plotly
tornado
matplotlib
xarray
librosa
gensim
soundfile
pytz
requests
scikit-image
xlrd
scipy
numpy
openpyxl
joblib
urllib3
111 changes: 111 additions & 0 deletions service/code_interpreter.py
@@ -0,0 +1,111 @@
import asyncio
import logging
import time
from typing import List

from e2b import Sandbox

logging.getLogger("e2b").setLevel(logging.INFO)


class CodeInterpreterService:
    timeout = 3 * 60  # 3 minutes

    @staticmethod
    def _get_file_path(file_url: str):
        """Get the file path in the sandbox for a given file_url."""
return "/code/" + str(hash(file_url))

    async def _upload_file(self, file_url: str):
        """Upload a file to the sandbox."""
        process = await asyncio.to_thread(
            self.sandbox.process.start_and_wait,
            f"wget -O {self._get_file_path(file_url)} {file_url}",
        )

        if process.exit_code != 0:
            raise Exception(
                f"Error downloading file {file_url} to sandbox {self.sandbox.id}"
            )

    def _ensure_sandbox(
        self,
        session_id: str | None,
    ):
        """
        Ensure a sandbox for the given session_id exists; if not, create one.
        If no session_id is given, create a new sandbox that will
        be deleted after exiting the object context.
        """
        if not session_id:
            return Sandbox("super-rag")

        sandboxes = Sandbox.list()
        for s in sandboxes:
            if not s.metadata:
                continue
            if s.metadata.get("session_id") == session_id:
                return Sandbox.reconnect(s.sandbox_id)

        return Sandbox(metadata={"session_id": session_id}, template="super-rag")

    def __init__(
        self,
        session_id: str | None,
        file_urls: List[str],
    ):
        self.session_id = session_id
        self.file_urls = file_urls
        self._is_initialized = False

        self.sandbox = self._ensure_sandbox(session_id)

    async def __aenter__(self):
        if not self._is_initialized:
            self._is_initialized = True
            for file_url in self.file_urls:
                await self._upload_file(file_url)

        return self

    async def __aexit__(self, _exc_type, _exc_value, _traceback):
        if self.session_id:
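            # keep_alive is meant to extend the sandbox's lifetime past the
            # disconnect below, so a later request with the same session_id
            # can reconnect to it.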
            self.sandbox.keep_alive(self.timeout)
        self.sandbox.close()

    def get_files_code(self):
        """
        Get the code to read the files in the sandbox.
        This can be used for instructing the LLM how to access the loaded files.
        """

        # TODO: Add support for xlsx, json
        files_code = "\n".join(
            f'df{i} = pd.read_csv("{self._get_file_path(url)}")  # {url}'
            for i, url in enumerate(self.file_urls)
        )

        return f"""
import pandas as pd

{files_code}

"""

    async def run_python(self, code: str):
        files_code = self.get_files_code()

        templated_code = f"""
{files_code}
{code}
"""

        epoch_time = time.time()
        codefile_path = f"/tmp/main-{epoch_time}.py"
        self.sandbox.filesystem.write(codefile_path, templated_code)
        process = await asyncio.to_thread(
            self.sandbox.process.start_and_wait,
            f"python {codefile_path}",
        )

        return process
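
A hypothetical usage sketch (not part of the PR): upload a CSV, run a snippet against the generated df0 alias, and read the process output. The URL is a placeholder.

import asyncio

from service.code_interpreter import CodeInterpreterService


async def main():
    async with CodeInterpreterService(
        session_id=None,  # no session: the sandbox is discarded after exit
        file_urls=["https://example.com/data.csv"],  # placeholder URL
    ) as service:
        output = await service.run_python(code="print(df0.describe())")
        print(output.stdout)
        print(output.stderr)


asyncio.run(main())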
23 changes: 18 additions & 5 deletions service/router.py
@@ -5,6 +5,7 @@

from models.document import BaseDocumentChunk
from models.query import RequestPayload
from service.code_interpreter import CodeInterpreterService
from service.embedding import get_encoder
from utils.logger import logger
from utils.summarise import SUMMARY_SUFFIX
@@ -55,9 +56,21 @@ async def query(payload: RequestPayload) -> list[BaseDocumentChunk]:
    )
    return await get_documents(vector_service=vector_service, payload=payload)

-   vector_service: BaseVectorDatabase = get_vector_service(
-       index_name=payload.index_name,
-       credentials=payload.vector_database,
-       encoder=encoder,
-   )
+   # vector_service: BaseVectorDatabase = get_vector_service(
+   #     index_name=payload.index_name,
+   #     credentials=payload.vector_database,
+   #     encoder=encoder,
+   # )
+
+   async with CodeInterpreterService(
+       session_id=payload.session_id,
+       file_urls=[
+           "https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv"
+       ],
+   ) as service:
+       code = "df0.info()"
+       output = await service.run_python(code=code)
+       print(output.stderr)
+       print(output.stdout)

    return await get_documents(vector_service=vector_service, payload=payload)
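
As merged, this hunk comments out vector_service while the final return still passes it to get_documents, which will raise NameError when this code path runs. One possible reconciliation (not in this PR), sketched under the assumption that the hard-coded titanic.csv demo is meant to run alongside the normal retrieval flow rather than replace it:

    # Keep the vector service so the final return still works, and run the
    # interpreter demo beside it.
    vector_service: BaseVectorDatabase = get_vector_service(
        index_name=payload.index_name,
        credentials=payload.vector_database,
        encoder=encoder,
    )

    async with CodeInterpreterService(
        session_id=payload.session_id,
        file_urls=[
            "https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv"
        ],
    ) as service:
        output = await service.run_python(code="df0.info()")
        print(output.stderr)
        print(output.stdout)

    return await get_documents(vector_service=vector_service, payload=payload)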