Add use_file_output to streaming methods (#355)
Signed-off-by: Mattt Zmuda <[email protected]>
mattt authored Sep 25, 2024
1 parent 4885f19 commit e53bd02
Showing 3 changed files with 48 additions and 10 deletions.
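In summary, this commit threads a new optional use_file_output flag from the streaming entry points (Client.stream, Client.async_stream, Prediction.stream, Prediction.async_stream, and the module-level stream/async_stream helpers) down to EventSource, which now also receives the client. When the flag is set, the data of each OUTPUT event is passed through replicate.helpers.transform_output before the event is yielded.

A minimal sketch of how a caller might use the new parameter on the synchronous client API. The model reference and input are illustrative placeholders, not part of this commit:

```python
import replicate

# Reads REPLICATE_API_TOKEN from the environment.
client = replicate.Client()

# "owner/model" and the input dict are placeholders for a real model and its inputs.
for event in client.stream(
    "owner/model",
    input={"prompt": "a watercolor painting of a lighthouse"},
    use_file_output=True,  # OUTPUT event data is passed through transform_output
):
    print(event.event, event.data)
```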
6 changes: 4 additions & 2 deletions replicate/client.py
@@ -190,25 +190,27 @@ def stream(
self,
ref: str,
input: Optional[Dict[str, Any]] = None,
use_file_output: Optional[bool] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Iterator["ServerSentEvent"]:
"""
Stream a model's output.
"""

return stream(self, ref, input, **params)
return stream(self, ref, input, use_file_output, **params)

async def async_stream(
self,
ref: str,
input: Optional[Dict[str, Any]] = None,
use_file_output: Optional[bool] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> AsyncIterator["ServerSentEvent"]:
"""
Stream a model's output asynchronously.
"""

return async_stream(self, ref, input, **params)
return async_stream(self, ref, input, use_file_output, **params)


# Adapted from https://github.com/encode/httpx/issues/108#issuecomment-1132753155
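The asynchronous path mirrors the synchronous one. A hedged sketch of calling Client.async_stream with the new flag; the model reference and input are again placeholders. Note that async_stream is an async method that returns the event iterator, so it is awaited before iteration:

```python
import asyncio

import replicate


async def main() -> None:
    client = replicate.Client()
    # "owner/model" and the input dict are illustrative placeholders.
    async for event in await client.async_stream(
        "owner/model",
        input={"prompt": "a photo of an astronaut riding a horse"},
        use_file_output=True,
    ):
        print(event.event, event.data)


asyncio.run(main())
```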
18 changes: 14 additions & 4 deletions replicate/prediction.py
@@ -153,7 +153,10 @@ async def async_wait(self) -> None:
await asyncio.sleep(self._client.poll_interval)
await self.async_reload()

def stream(self) -> Iterator["ServerSentEvent"]:
def stream(
self,
use_file_output: Optional[bool] = None,
) -> Iterator["ServerSentEvent"]:
"""
Stream the prediction output.
@@ -170,9 +173,14 @@ def stream(self) -> Iterator["ServerSentEvent"]:
headers["Cache-Control"] = "no-store"

with self._client._client.stream("GET", url, headers=headers) as response:
yield from EventSource(response)
yield from EventSource(
self._client, response, use_file_output=use_file_output
)

async def async_stream(self) -> AsyncIterator["ServerSentEvent"]:
async def async_stream(
self,
use_file_output: Optional[bool] = None,
) -> AsyncIterator["ServerSentEvent"]:
"""
Stream the prediction output asynchronously.
@@ -194,7 +202,9 @@ async def async_stream(self) -> AsyncIterator["ServerSentEvent"]:
async with self._client._async_client.stream(
"GET", url, headers=headers
) as response:
async for event in EventSource(response):
async for event in EventSource(
self._client, response, use_file_output=use_file_output
):
yield event

def cancel(self) -> None:
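For predictions created directly, the same flag can now be passed to Prediction.stream and Prediction.async_stream. A rough sketch, assuming a prediction created with stream=True so that it carries a stream URL; the version id and input are placeholders:

```python
import replicate

client = replicate.Client()

# Placeholder version id and input; stream=True requests a stream URL
# so that prediction.stream() has something to connect to.
prediction = client.predictions.create(
    version="replace-with-a-real-version-id",
    input={"prompt": "a pencil sketch of a cat"},
    stream=True,
)

for event in prediction.stream(use_file_output=True):
    print(event.event, event.data)
```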
34 changes: 30 additions & 4 deletions replicate/stream.py
@@ -15,6 +15,7 @@

from replicate import identifier
from replicate.exceptions import ReplicateError
from replicate.helpers import transform_output

try:
from pydantic import v1 as pydantic # type: ignore
@@ -62,10 +63,19 @@ class EventSource:
A server-sent event source.
"""

client: "Client"
response: "httpx.Response"

def __init__(self, response: "httpx.Response") -> None:
use_file_output: bool

def __init__(
self,
client: "Client",
response: "httpx.Response",
use_file_output: Optional[bool] = None,
) -> None:
self.client = client
self.response = response
self.use_file_output = use_file_output or False
content_type, _, _ = response.headers["content-type"].partition(";")
if content_type != "text/event-stream":
raise ValueError(
@@ -147,6 +157,12 @@ def __iter__(self) -> Iterator[ServerSentEvent]:
if sse.event == ServerSentEvent.EventType.ERROR:
raise RuntimeError(sse.data)

if (
self.use_file_output
and sse.event == ServerSentEvent.EventType.OUTPUT
):
sse.data = transform_output(sse.data, client=self.client)

yield sse

if sse.event == ServerSentEvent.EventType.DONE:
@@ -161,6 +177,12 @@ async def __aiter__(self) -> AsyncIterator[ServerSentEvent]:
if sse.event == ServerSentEvent.EventType.ERROR:
raise RuntimeError(sse.data)

if (
self.use_file_output
and sse.event == ServerSentEvent.EventType.OUTPUT
):
sse.data = transform_output(sse.data, client=self.client)

yield sse

if sse.event == ServerSentEvent.EventType.DONE:
@@ -171,6 +193,7 @@ def stream(
client: "Client",
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
input: Optional[Dict[str, Any]] = None,
use_file_output: Optional[bool] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Iterator[ServerSentEvent]:
"""
@@ -204,13 +227,14 @@
headers["Cache-Control"] = "no-store"

with client._client.stream("GET", url, headers=headers) as response:
yield from EventSource(response)
yield from EventSource(client, response, use_file_output=use_file_output)


async def async_stream(
client: "Client",
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
input: Optional[Dict[str, Any]] = None,
use_file_output: Optional[bool] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> AsyncIterator[ServerSentEvent]:
"""
@@ -244,7 +268,9 @@ async def async_stream(
headers["Cache-Control"] = "no-store"

async with client._async_client.stream("GET", url, headers=headers) as response:
async for event in EventSource(response):
async for event in EventSource(
client, response, use_file_output=use_file_output
):
yield event


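The commit does not include transform_output itself, only its import and call sites. As a purely illustrative guess at the kind of transformation it applies when use_file_output is enabled, the hypothetical helper below recursively replaces URL strings in an output value with simple wrapper objects. The names wrap_file_urls and FileLike are assumptions made for illustration, not replicate's implementation:

```python
from typing import Any


class FileLike:
    """Hypothetical stand-in for a file-output wrapper around a URL."""

    def __init__(self, url: str) -> None:
        self.url = url


def wrap_file_urls(value: Any) -> Any:
    """Recursively replace http(s)/data URLs in a nested output value
    with FileLike wrappers, leaving all other values untouched."""
    if isinstance(value, str) and value.startswith(("http://", "https://", "data:")):
        return FileLike(value)
    if isinstance(value, list):
        return [wrap_file_urls(v) for v in value]
    if isinstance(value, dict):
        return {k: wrap_file_urls(v) for k, v in value.items()}
    return value
```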
