Fix Cached Examples for Streamed Media #9373

Merged
merged 14 commits into from
Sep 20, 2024
6 changes: 6 additions & 0 deletions .changeset/green-pigs-wonder.md
@@ -0,0 +1,6 @@
---
"@gradio/audio": minor
"gradio": minor
---

feat:Fix Cached Examples for Streamed Media
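
For context, a minimal sketch of the kind of app this PR targets: a streaming output event whose examples are cached. Before this change, caching dumped the raw concatenated bytes into a temp file; after it, each streaming component combines its chunks into a single playable file. The sketch below is an assumption modeled on the stream_audio_out demo; `sample.wav` and the one-second chunking are hypothetical.

```python
# Hedged sketch (not taken verbatim from the repo): a streaming audio output
# whose examples get cached. Each `yield` returns one mp3 chunk; the server is
# responsible for combining the chunks into one file for the example cache.
import gradio as gr
from pydub import AudioSegment

def stream_audio(path):
    audio = AudioSegment.from_file(path)
    for i in range(0, len(audio), 1000):            # 1-second slices (ms units)
        chunk_path = f"chunk_{i // 1000}.mp3"
        audio[i : i + 1000].export(chunk_path, format="mp3")
        yield chunk_path

with gr.Blocks() as demo:
    inp = gr.Audio(type="filepath", label="input")
    out = gr.Audio(streaming=True, autoplay=True, label="stream")
    gr.Button("Stream").click(stream_audio, inp, out)
    gr.Examples([["sample.wav"]], [inp], fn=stream_audio, outputs=out,
                cache_examples=True)

if __name__ == "__main__":
    demo.launch()
```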
3 changes: 3 additions & 0 deletions .gitignore
@@ -56,6 +56,9 @@ demo/annotatedimage_component/*.png
demo/fake_diffusion_with_gif/*.gif
demo/cancel_events/cancel_events_output_log.txt
demo/unload_event_test/output_log.txt
demo/stream_video_out/output_*.ts
demo/stream_video_out/output_*.mp4
demo/stream_audio_out/*.mp3

# Etc
.idea/*
Binary file removed demo/stream_video_out/output_0.mp4
Binary file not shown.
Binary file removed demo/stream_video_out/output_0.ts
Binary file not shown.
Binary file removed demo/stream_video_out/output_1.mp4
Binary file not shown.
Binary file removed demo/stream_video_out/output_1.ts
Binary file not shown.
2 changes: 1 addition & 1 deletion demo/stream_video_out/run.ipynb
@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: stream_video_out"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio opencv-python"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/stream_video_out/output_0.mp4\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/stream_video_out/output_0.ts\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/stream_video_out/output_1.mp4\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/stream_video_out/output_1.ts\n", "os.mkdir('video')\n", "!wget -q -O video/compliment_bot_screen_recording_3x.mp4 https://github.com/gradio-app/gradio/raw/main/demo/stream_video_out/video/compliment_bot_screen_recording_3x.mp4"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import cv2\n", "import os\n", "from pathlib import Path\n", "import atexit\n", "\n", "current_dir = Path(__file__).resolve().parent\n", "\n", "\n", "def delete_files():\n", " for p in Path(current_dir).glob(\"*.ts\"):\n", " p.unlink()\n", " for p in Path(current_dir).glob(\"*.mp4\"):\n", " p.unlink()\n", "\n", "atexit.register(delete_files)\n", "\n", "\n", "def process_video(input_video, stream_as_mp4):\n", " cap = cv2.VideoCapture(input_video)\n", "\n", " video_codec = cv2.VideoWriter_fourcc(*\"mp4v\") if stream_as_mp4 else cv2.VideoWriter_fourcc(*\"x264\") # type: ignore\n", " fps = int(cap.get(cv2.CAP_PROP_FPS))\n", " width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n", " height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n", "\n", " iterating, frame = cap.read()\n", "\n", " n_frames = 0\n", " n_chunks = 0\n", " name = str(current_dir / f\"output_{n_chunks}{'.mp4' if stream_as_mp4 else '.ts'}\")\n", " segment_file = cv2.VideoWriter(name, video_codec, fps, (width, height)) # type: ignore\n", "\n", " while iterating:\n", "\n", " # flip frame vertically\n", " frame = cv2.flip(frame, 0)\n", " display_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n", " segment_file.write(display_frame)\n", " n_frames += 1\n", " if n_frames == 3 * fps:\n", " n_chunks += 1\n", " segment_file.release()\n", " n_frames = 0\n", " yield name\n", " name = str(current_dir / f\"output_{n_chunks}{'.mp4' if stream_as_mp4 else '.ts'}\")\n", " segment_file = cv2.VideoWriter(name, video_codec, fps, (width, height)) # type: ignore\n", "\n", " iterating, frame = cap.read()\n", "\n", " segment_file.release()\n", " yield name\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\"# Video Streaming Out \ud83d\udcf9\")\n", " with gr.Row():\n", " with gr.Column():\n", " input_video = gr.Video(label=\"input\")\n", " checkbox = gr.Checkbox(label=\"Stream as MP4 file?\", value=False)\n", " with gr.Column():\n", " processed_frames = gr.Video(label=\"stream\", streaming=True, autoplay=True, elem_id=\"stream_video_output\")\n", " with gr.Row():\n", " process_video_btn = gr.Button(\"process video\")\n", "\n", " process_video_btn.click(process_video, [input_video, checkbox], [processed_frames])\n", "\n", " gr.Examples(\n", " [[os.path.join(os.path.abspath(''), 
\"video/compliment_bot_screen_recording_3x.mp4\"), False],\n", " [os.path.join(os.path.abspath(''), \"video/compliment_bot_screen_recording_3x.mp4\"), True]],\n", " [input_video, checkbox],\n", " fn=process_video,\n", " outputs=processed_frames,\n", " cache_examples=False,\n", " )\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: stream_video_out"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio opencv-python"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('video')\n", "!wget -q -O video/compliment_bot_screen_recording_3x.mp4 https://github.com/gradio-app/gradio/raw/main/demo/stream_video_out/video/compliment_bot_screen_recording_3x.mp4"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import cv2\n", "import os\n", "from pathlib import Path\n", "import atexit\n", "\n", "current_dir = Path(__file__).resolve().parent\n", "\n", "\n", "def delete_files():\n", " for p in Path(current_dir).glob(\"*.ts\"):\n", " p.unlink()\n", " for p in Path(current_dir).glob(\"*.mp4\"):\n", " p.unlink()\n", "\n", "atexit.register(delete_files)\n", "\n", "\n", "def process_video(input_video, stream_as_mp4):\n", " cap = cv2.VideoCapture(input_video)\n", "\n", " video_codec = cv2.VideoWriter_fourcc(*\"mp4v\") if stream_as_mp4 else cv2.VideoWriter_fourcc(*\"x264\") # type: ignore\n", " fps = int(cap.get(cv2.CAP_PROP_FPS))\n", " width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n", " height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n", "\n", " iterating, frame = cap.read()\n", "\n", " n_frames = 0\n", " n_chunks = 0\n", " name = str(current_dir / f\"output_{n_chunks}{'.mp4' if stream_as_mp4 else '.ts'}\")\n", " segment_file = cv2.VideoWriter(name, video_codec, fps, (width, height)) # type: ignore\n", "\n", " while iterating:\n", "\n", " # flip frame vertically\n", " frame = cv2.flip(frame, 0)\n", " display_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n", " segment_file.write(display_frame)\n", " n_frames += 1\n", " if n_frames == 3 * fps:\n", " n_chunks += 1\n", " segment_file.release()\n", " n_frames = 0\n", " yield name\n", " name = str(current_dir / f\"output_{n_chunks}{'.mp4' if stream_as_mp4 else '.ts'}\")\n", " segment_file = cv2.VideoWriter(name, video_codec, fps, (width, height)) # type: ignore\n", "\n", " iterating, frame = cap.read()\n", "\n", " segment_file.release()\n", " yield name\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\"# Video Streaming Out \ud83d\udcf9\")\n", " with gr.Row():\n", " with gr.Column():\n", " input_video = gr.Video(label=\"input\")\n", " checkbox = gr.Checkbox(label=\"Stream as MP4 file?\", value=False)\n", " with gr.Column():\n", " processed_frames = gr.Video(label=\"stream\", streaming=True, autoplay=True, elem_id=\"stream_video_output\")\n", " with gr.Row():\n", " process_video_btn = gr.Button(\"process video\")\n", "\n", " process_video_btn.click(process_video, [input_video, checkbox], [processed_frames])\n", "\n", " gr.Examples(\n", " [[os.path.join(os.path.abspath(''), \"video/compliment_bot_screen_recording_3x.mp4\"), False],\n", " [os.path.join(os.path.abspath(''), \"video/compliment_bot_screen_recording_3x.mp4\"), True]],\n", " [input_video, checkbox],\n", " fn=process_video,\n", " outputs=processed_frames,\n", " cache_examples=False,\n", " )\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
8 changes: 7 additions & 1 deletion gradio/blocks.py
@@ -1823,7 +1823,13 @@ async def handle_streaming_outputs(
first_chunk,
)
if first_chunk:
stream_run[output_id] = MediaStream()
desired_output_format = None
if orig_name := output_data.get("orig_name"):
desired_output_format = Path(orig_name).suffix[1:]
stream_run[output_id] = MediaStream(
desired_output_format=desired_output_format
)
stream_run[output_id]

await stream_run[output_id].add_segment(binary_data)
output_data = await processing_utils.async_move_files_to_cache(
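
The key addition in `handle_streaming_outputs` is deriving the target format from the first chunk's `orig_name`. A small illustration, with a hypothetical `orig_name` value:

```python
# Illustration of the format detection above: the first chunk's orig_name
# (e.g. "audio-stream.mp3" or "video-stream.mp4") determines the extension
# requested when the stream is later combined for caching or download.
from pathlib import Path

output_data = {"orig_name": "video-stream.mp4"}   # hypothetical first chunk
desired_output_format = None
if orig_name := output_data.get("orig_name"):
    desired_output_format = Path(orig_name).suffix[1:]
print(desired_output_format)  # -> "mp4"
```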
21 changes: 21 additions & 0 deletions gradio/components/audio.py
@@ -357,6 +357,27 @@ async def stream_output(
value, duration = await self.covert_to_adts(binary_data)
return {"data": value, "duration": duration, "extension": ".aac"}, output_file

async def combine_stream(
self,
stream: list[bytes],
desired_output_format: str | None = None,
only_file=False, # noqa: ARG002
) -> FileData:
output_file = FileData(
path=processing_utils.save_bytes_to_cache(
b"".join(stream), "audio.mp3", cache_dir=self.GRADIO_CACHE
),
is_stream=False,
orig_name="audio-stream.mp3",
)
if desired_output_format and desired_output_format != "mp3":
new_path = Path(output_file.path).with_suffix(f".{desired_output_format}")
AudioSegment.from_file(output_file.path).export(
new_path, format=desired_output_format
)
output_file.path = str(new_path)
return output_file

def process_example(
self, value: tuple[int, np.ndarray] | str | Path | bytes | None
) -> str:
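
The new `Audio.combine_stream` mirrors the test added at the bottom of this PR; a usage sketch (the chunk files are hypothetical placeholders):

```python
# Usage sketch of Audio.combine_stream: mp3 byte chunks are concatenated, and
# re-exported via pydub only when a different output format is requested.
import asyncio
from pathlib import Path
import gradio as gr

async def main():
    chunks = [Path("chunk_0.mp3").read_bytes(), Path("chunk_1.mp3").read_bytes()]
    combined = await gr.Audio().combine_stream(chunks, desired_output_format="wav")
    print(combined.path)  # path ends in .wav after conversion

asyncio.run(main())
```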
16 changes: 16 additions & 0 deletions gradio/components/base.py
@@ -23,6 +23,7 @@
from gradio.data_classes import (
BaseModel,
DeveloperPath,
FileData,
FileDataDict,
GradioDataModel,
MediaStreamChunk,
@@ -383,6 +384,21 @@ async def stream_output(
) -> tuple[MediaStreamChunk | None, FileDataDict | dict]:
pass

@abc.abstractmethod
async def combine_stream(
self,
stream: list[bytes],
desired_output_format: str | None = None,
only_file=False,
) -> GradioDataModel | FileData:
"""Combine all of the stream chunks into a single file.

This is needed for downloading the stream and for caching examples.
If `only_file` is True, only the FileData corresponding to the file should be returned (needed for downloading the stream).
The desired_output_format optionally converts the combined file. Should only be used for cached examples.
"""
pass


class StreamingInput(metaclass=abc.ABCMeta):
def __init__(self, *args, **kwargs) -> None:
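
Because the method is abstract, every `StreamingOutput` component must now describe how its chunks become one file. A schematic stand-in for the contract, with all names hypothetical and no real gradio plumbing:

```python
# Schematic stand-in (not a real gradio component) for the combine_stream
# contract: join the chunks into one cached file, honoring only_file for
# stream downloads and desired_output_format for cached examples.
import tempfile
from pathlib import Path

class FakeStreamingOutput:
    async def combine_stream(self, stream, desired_output_format=None, only_file=False):
        suffix = f".{desired_output_format}" if desired_output_format else ".bin"
        out = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        Path(out.name).write_bytes(b"".join(stream))
        return out.name  # a real component would wrap this in FileData
```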
65 changes: 64 additions & 1 deletion gradio/components/video.py
@@ -24,6 +24,7 @@
if TYPE_CHECKING:
from gradio.components import Timer


if not wasm_utils.IS_WASM:
# TODO: Support ffmpeg on Wasm
from ffmpy import FFmpeg
@@ -485,6 +486,66 @@ async def async_convert_mp4_to_ts(mp4_file, ts_file):

return ts_file

async def combine_stream(
self,
stream: list[bytes],
desired_output_format: str | None = None, # noqa: ARG002
only_file=False,
) -> VideoData | FileData:
"""Combine video chunks into a single video file.

Do not take desired_output_format into consideration as
mp4 is a safe format for playing in browser.
"""
if wasm_utils.IS_WASM:
raise wasm_utils.WasmUnsupportedError(
"Streaming is not supported in the Wasm mode."
)

# Use an mp4 extension here so that the cached example
# is playable in the browser
output_file = tempfile.NamedTemporaryFile(
delete=False, suffix=".mp4", dir=self.GRADIO_CACHE
)

ts_files = [
processing_utils.save_bytes_to_cache(
s, "video_chunk.ts", cache_dir=self.GRADIO_CACHE
)
for s in stream
]

command = [
"ffmpeg",
"-i",
f'concat:{"|".join(ts_files)}',
"-y",
"-safe",
"0",
"-c",
"copy",
output_file.name,
]
process = await asyncio.create_subprocess_exec(
*command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)

_, stderr = await process.communicate()

if process.returncode != 0:
error_message = stderr.decode().strip()
raise RuntimeError(f"FFmpeg command failed: {error_message}")
video = FileData(
path=output_file.name,
is_stream=False,
orig_name="video-stream.mp4",
)
if only_file:
return video

output = VideoData(video=video)
return output

async def stream_output(
self,
value: str | None,
@@ -495,7 +556,9 @@ async def stream_output(
"video": {
"path": output_id,
"is_stream": True,
"orig_name": "video-stream.ts",
# Need to set orig_name so that downloaded file has correct
# extension
"orig_name": "video-stream.mp4",
}
}
if value is None:
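
The video path relies on ffmpeg's `concat:` protocol to remux the MPEG-TS segments into one browser-playable mp4 without re-encoding. Roughly the same invocation from the Python side (segment names hypothetical, ffmpeg assumed on PATH):

```python
# Rough equivalent of the ffmpeg call issued by Video.combine_stream: the .ts
# chunks are concatenated and stream-copied ("-c copy") into a single mp4.
import subprocess

ts_files = ["video_chunk_0.ts", "video_chunk_1.ts"]  # hypothetical segments
subprocess.run(
    ["ffmpeg", "-i", f'concat:{"|".join(ts_files)}',
     "-y", "-safe", "0", "-c", "copy", "combined.mp4"],
    check=True,
)
```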
26 changes: 10 additions & 16 deletions gradio/helpers.py
@@ -9,7 +9,6 @@
import csv
import inspect
import os
import tempfile
import warnings
from collections.abc import Callable, Iterable, Sequence
from functools import partial
@@ -604,29 +603,24 @@ async def merge_generated_values_into_output(
for output_index, output_component in enumerate(components):
if isinstance(output_component, StreamingOutput) and output_component.streaming:
binary_chunks = []
desired_output_format = None
for i, chunk in enumerate(generated_values):
if len(components) > 1:
chunk = chunk[output_index]
processed_chunk = output_component.postprocess(chunk)
if isinstance(processed_chunk, (GradioModel, GradioRootModel)):
processed_chunk = processed_chunk.model_dump()
binary_chunks.append(
(await output_component.stream_output(processed_chunk, "", i == 0))[
0
]
stream_chunk = await output_component.stream_output(
processed_chunk, "", i == 0
)
binary_data = b"".join([d["data"] for d in binary_chunks])
tempdir = os.environ.get("GRADIO_TEMP_DIR") or str(
Path(tempfile.gettempdir()) / "gradio"
if i == 0 and (orig_name := stream_chunk[1].get("orig_name")):
desired_output_format = Path(orig_name).suffix[1:]
if stream_chunk[0]:
binary_chunks.append(stream_chunk[0]["data"])
combined_output = await output_component.combine_stream(
binary_chunks, desired_output_format=desired_output_format
)
os.makedirs(tempdir, exist_ok=True)
temp_file = tempfile.NamedTemporaryFile(dir=tempdir, delete=False)
with open(temp_file.name, "wb") as f:
f.write(binary_data)

output[output_index] = {
"path": temp_file.name,
}
output[output_index] = combined_output.model_dump()

return output

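
Condensed, the new example-caching path in `merge_generated_values_into_output` looks roughly like this (a paraphrase for readability, not the literal function; the helper name is hypothetical):

```python
# Condensed paraphrase of the caching flow above: collect the binary chunks,
# infer the output format from the first chunk's orig_name, then delegate to
# the component's combine_stream instead of writing raw bytes to a temp file.
from pathlib import Path

async def cache_streamed_output(component, generated_values):
    binary_chunks, desired_output_format = [], None
    for i, value in enumerate(generated_values):
        processed = component.postprocess(value)
        chunk, file_data = await component.stream_output(processed, "", i == 0)
        if i == 0 and (orig_name := file_data.get("orig_name")):
            desired_output_format = Path(orig_name).suffix[1:]
        if chunk:
            binary_chunks.append(chunk["data"])
    combined = await component.combine_stream(
        binary_chunks, desired_output_format=desired_output_format
    )
    return combined.model_dump()
```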
4 changes: 3 additions & 1 deletion gradio/route_utils.py
@@ -892,12 +892,14 @@ async def _handler(app: App):


class MediaStream:
def __init__(self):
def __init__(self, desired_output_format: str | None = None):
self.segments: list[MediaStreamChunk] = []
self.combined_file: str | None = None
self.ended = False
self.segment_index = 0
self.playlist = "#EXTM3U\n#EXT-X-PLAYLIST-TYPE:EVENT\n#EXT-X-TARGETDURATION:10\n#EXT-X-VERSION:4\n#EXT-X-MEDIA-SEQUENCE:0\n"
self.max_duration = 5
self.desired_output_format = desired_output_format

async def add_segment(self, data: MediaStreamChunk | None):
if not data:
22 changes: 13 additions & 9 deletions gradio/routes.py
@@ -695,15 +695,19 @@ async def _(session_hash: str, run: int, component_id: int):
if not stream:
return Response(status_code=404)

byte_stream = b""
extension = ""
for segment in stream.segments:
extension = segment["extension"]
byte_stream += segment["data"]

media_type = "video/MP2T" if extension == ".ts" else "audio/aac"

return Response(content=byte_stream, media_type=media_type)
if not stream.combined_file:
stream_data = [s["data"] for s in stream.segments]
combined_file = (
await app.get_blocks()
.get_component(component_id)
.combine_stream( # type: ignore
stream_data,
only_file=True,
desired_output_format=stream.desired_output_format,
)
)
stream.combined_file = combined_file.path
return FileResponse(stream.combined_file)

@router.get("/file/{path:path}", dependencies=[Depends(login_check)])
async def file_deprecated(path: str, request: fastapi.Request):
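
On the download route, the combined file is now built lazily and memoized on the `MediaStream`, so repeated requests reuse the first result. The pattern, stripped of the FastAPI plumbing (names hypothetical):

```python
# Memoization pattern used by the stream-download endpoint: combine the
# segments at most once per MediaStream and cache the resulting file path.
async def get_combined_file(stream, component):
    if not stream.combined_file:
        data = [segment["data"] for segment in stream.segments]
        file = await component.combine_stream(
            data, only_file=True,
            desired_output_format=stream.desired_output_format,
        )
        stream.combined_file = file.path
    return stream.combined_file
```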
2 changes: 1 addition & 1 deletion js/audio/player/AudioPlayer.svelte
@@ -66,7 +66,7 @@
});
};

$: if (container !== undefined) {
$: if (!value?.is_stream && container !== undefined && container !== null) {
if (waveform !== undefined) waveform.destroy();
container.innerHTML = "";
create_waveform();
18 changes: 18 additions & 0 deletions test/components/test_audio.py
@@ -180,3 +180,21 @@ def test_prepost_process_to_mp3(self, gradio_temp_dir):
(48000, np.random.randint(-256, 256, (5, 3)).astype(np.int16))
).model_dump() # type: ignore
assert output["path"].endswith("mp3")

@pytest.mark.asyncio
async def test_combine_stream_audio(self, gradio_temp_dir):
x_wav = FileData(
path=processing_utils.save_base64_to_cache(
media_data.BASE64_MICROPHONE["data"], cache_dir=gradio_temp_dir
)
)
bytes_output = [Path(x_wav.path).read_bytes()] * 2
output = await gr.Audio().combine_stream(
bytes_output, desired_output_format="wav"
)
assert str(output.path).endswith("wav")

output = await gr.Audio().combine_stream(
bytes_output, desired_output_format=None
)
assert str(output.path).endswith("mp3")