Create edgedb.ai package #489

Merged: 5 commits, May 1, 2024
Changes from all commits
32 changes: 32 additions & 0 deletions edgedb/ai/__init__.py
@@ -0,0 +1,32 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2024-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from .types import AIOptions, ChatParticipantRole, Prompt, QueryContext
from .core import create_ai, EdgeDBAI
from .core import create_async_ai, AsyncEdgeDBAI

__all__ = [
"AIOptions",
"ChatParticipantRole",
"Prompt",
"QueryContext",
"create_ai",
"EdgeDBAI",
"create_async_ai",
"AsyncEdgeDBAI",
]
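These exports form the public surface of the new package. A minimal usage sketch of the blocking API, assuming a running EdgeDB instance with the ext::ai extension configured; the model name and context query below are placeholders only:

import edgedb
import edgedb.ai

client = edgedb.create_client()
# create_ai() checks the connection, then wraps it with the chosen model.
gpt = edgedb.ai.create_ai(client, model="gpt-4-turbo-preview")
# with_context() returns a copy bound to a context query over your schema.
knowledge = gpt.with_context(query="Knowledge")
print(knowledge.query_rag("What does the knowledge base say about EdgeDB?"))
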
174 changes: 174 additions & 0 deletions edgedb/ai/core.py
@@ -0,0 +1,174 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2024-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations
import typing

import edgedb
import httpx
import httpx_sse

from . import types


def create_ai(client: edgedb.Client, **kwargs) -> EdgeDBAI:
client.ensure_connected()
return EdgeDBAI(client, types.AIOptions(**kwargs))


async def create_async_ai(
client: edgedb.AsyncIOClient, **kwargs
) -> AsyncEdgeDBAI:
await client.ensure_connected()
return AsyncEdgeDBAI(client, types.AIOptions(**kwargs))


class BaseEdgeDBAI:
options: types.AIOptions
context: types.QueryContext
client_cls = NotImplemented

def __init__(
self,
client: typing.Union[edgedb.Client, edgedb.AsyncIOClient],
options: types.AIOptions,
**kwargs,
):
pool = client._impl
host, port = pool._working_addr
params = pool._working_params
proto = "http" if params.tls_security == "insecure" else "https"
branch = params.branch
self.options = options
self.context = types.QueryContext(**kwargs)
args = dict(
base_url=f"{proto}://{host}:{port}/branch/{branch}/ext/ai",
verify=params.ssl_ctx,
)
if params.password is not None:
args["auth"] = (params.user, params.password)
elif params.secret_key is not None:
args["headers"] = {"Authorization": f"Bearer {params.secret_key}"}
self._init_client(**args)

def _init_client(self, **kwargs):
raise NotImplementedError

def with_config(self, **kwargs) -> typing.Self:
cls = type(self)
rv = cls.__new__(cls)
rv.options = self.options.derive(kwargs)
rv.context = self.context
rv.client = self.client
return rv

def with_context(self, **kwargs) -> typing.Self:
cls = type(self)
rv = cls.__new__(cls)
rv.options = self.options
rv.context = self.context.derive(kwargs)
rv.client = self.client
return rv


class EdgeDBAI(BaseEdgeDBAI):
client: httpx.Client

def _init_client(self, **kwargs):
self.client = httpx.Client(**kwargs)

Member:
Need a wrapper for the /embeddings endpoint too.

Member Author:
I'll add it in a new PR 🙏
(A hypothetical sketch of such a wrapper follows this file's diff below.)

def query_rag(
self, message: str, context: typing.Optional[types.QueryContext] = None
) -> str:
if context is None:
context = self.context
resp = self.client.post(
**types.RAGRequest(
model=self.options.model,
prompt=self.options.prompt,
context=context,
query=message,
stream=False,
).to_httpx_request()
)
resp.raise_for_status()
return resp.json()["response"]

def stream_rag(
self, message: str, context: typing.Optional[types.QueryContext] = None
):
if context is None:
context = self.context
with httpx_sse.connect_sse(
self.client,
"post",
**types.RAGRequest(
model=self.options.model,
prompt=self.options.prompt,
context=context,
query=message,
stream=True,
).to_httpx_request(),
) as event_source:
event_source.response.raise_for_status()
for sse in event_source.iter_sse():
yield sse.data


class AsyncEdgeDBAI(BaseEdgeDBAI):
client: httpx.AsyncClient

def _init_client(self, **kwargs):
self.client = httpx.AsyncClient(**kwargs)

async def query_rag(
self, message: str, context: typing.Optional[types.QueryContext] = None
) -> str:
if context is None:
context = self.context
resp = await self.client.post(
**types.RAGRequest(
model=self.options.model,
prompt=self.options.prompt,
context=context,
query=message,
stream=False,
).to_httpx_request()
)
resp.raise_for_status()
return resp.json()["response"]

async def stream_rag(
self, message: str, context: typing.Optional[types.QueryContext] = None
):
if context is None:
context = self.context
async with httpx_sse.aconnect_sse(
self.client,
"post",
**types.RAGRequest(
model=self.options.model,
prompt=self.options.prompt,
context=context,
query=message,
stream=True,
).to_httpx_request(),
) as event_source:
event_source.response.raise_for_status()
async for sse in event_source.aiter_sse():
yield sse.data
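For completeness, a sketch of the async client and streaming path, under the same assumptions as above (placeholder model name and context query, a running instance with ext::ai configured):

import asyncio
import edgedb
import edgedb.ai

async def main() -> None:
    client = edgedb.create_async_client()
    gpt = await edgedb.ai.create_async_ai(client, model="gpt-4-turbo-preview")
    knowledge = gpt.with_context(query="Knowledge")
    # stream_rag() is an async generator yielding raw SSE payloads
    # as the server produces them.
    async for chunk in knowledge.stream_rag("Summarize the knowledge base."):
        print(chunk)

asyncio.run(main())

The review thread above asks for a wrapper around the /embeddings endpoint; that is deferred to a follow-up PR. A purely hypothetical sketch of what such a wrapper could look like, assuming an OpenAI-style request and response shape that this diff does not define:

# Hypothetical only: the method name, payload, and response shape are
# assumptions and are not part of this PR.
class EmbeddingsAI(edgedb.ai.EdgeDBAI):
    def generate_embeddings(self, *inputs: str, model: str) -> list[list[float]]:
        resp = self.client.post(
            "/embeddings",
            headers={"Content-Type": "application/json"},
            json={"model": model, "input": list(inputs)},
        )
        resp.raise_for_status()
        return [item["embedding"] for item in resp.json()["data"]]
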
81 changes: 81 additions & 0 deletions edgedb/ai/types.py
@@ -0,0 +1,81 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2024-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import typing

import dataclasses as dc
import enum


class ChatParticipantRole(enum.Enum):
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"


class Custom(typing.TypedDict):
role: ChatParticipantRole
content: str


class Prompt:
name: typing.Optional[str]
id: typing.Optional[str]
custom: typing.Optional[typing.List[Custom]]


@dc.dataclass
class AIOptions:
model: str
prompt: typing.Optional[Prompt] = None

def derive(self, kwargs):
return AIOptions(**{**dc.asdict(self), **kwargs})


@dc.dataclass
class QueryContext:
query: str = ""
variables: typing.Optional[typing.Dict[str, typing.Any]] = None
globals: typing.Optional[typing.Dict[str, typing.Any]] = None
max_object_count: typing.Optional[int] = None

def derive(self, kwargs):
return QueryContext(**{**dc.asdict(self), **kwargs})


@dc.dataclass
class RAGRequest:
model: str
prompt: typing.Optional[Prompt]
context: QueryContext
query: str
stream: typing.Optional[bool]

def to_httpx_request(self) -> typing.Dict[str, typing.Any]:
return dict(
url="/rag",
headers={
"Content-Type": "application/json",
"Accept": (
"text/event-stream" if self.stream else "application/json"
),
},
json=dc.asdict(self),
)
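The derive() helpers and to_httpx_request() are the glue behind with_config()/with_context() and the HTTP calls in core.py. A short illustration (the values are placeholders):

from edgedb.ai.types import QueryContext, RAGRequest

ctx = QueryContext(query="Knowledge", max_object_count=5)
# derive() rebuilds the dataclass with overrides layered on top of asdict(),
# which is how with_context()/with_config() produce modified copies.
wider = ctx.derive({"max_object_count": 20})
assert wider.query == "Knowledge" and wider.max_object_count == 20

# to_httpx_request() yields httpx keyword arguments; streaming requests
# advertise text/event-stream instead of application/json.
req = RAGRequest(
    model="mistral-large-latest",  # placeholder model name
    prompt=None,
    context=wider,
    query="What is EdgeDB?",
    stream=True,
)
assert req.to_httpx_request()["headers"]["Accept"] == "text/event-stream"
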
6 changes: 6 additions & 0 deletions setup.py
@@ -59,7 +59,13 @@
'sphinx_rtd_theme~=1.0.0',
]

AI_DEPENDENCIES = [
'httpx~=0.27.0',
'httpx-sse~=0.4.0',
]

EXTRA_DEPENDENCIES = {
'ai': AI_DEPENDENCIES,
'docs': DOC_DEPENDENCIES,
'test': TEST_DEPENDENCIES,
# Dependencies required to develop edgedb.
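With this change, the httpx dependencies for the AI bindings are installed through the new optional extra, e.g. pip install 'edgedb[ai]', alongside the existing docs and test extras.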