Skip to content

Commit

Permalink
Fix SSR mode flag with mount_gradio_app and revert changes to pytes…
Browse files Browse the repository at this point in the history
…ts (#9446)

* Revert "Fix Python unit tests on `5.0-dev` branch (#9432)"

This reverts commit 278645b.

* revert changes to pytest

* add changeset

* fix

---------

Co-authored-by: gradio-pr-bot <[email protected]>
  • Loading branch information
abidlabs and gradio-pr-bot authored Sep 26, 2024
1 parent afbd8e7 commit 0c8fafb
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 48 deletions.
5 changes: 5 additions & 0 deletions .changeset/stupid-tires-stare.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Fix SSR mode flag with `mount_gradio_app` and revert changes to pytests
2 changes: 1 addition & 1 deletion gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1567,7 +1567,7 @@ def read_main():
else (
ssr_mode
if ssr_mode is not None
else bool(os.getenv("GRADIO_SSR_MODE", "False"))
else os.getenv("GRADIO_SSR_MODE", "False").lower() == "true"
)
)

Expand Down
138 changes: 91 additions & 47 deletions test/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import json
import os
import pickle
import socket
import tempfile
import time
from contextlib import asynccontextmanager, closing
Expand All @@ -19,7 +18,6 @@
import pytest
import requests
import starlette.routing
import uvicorn
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from gradio_client import media_data
Expand Down Expand Up @@ -369,65 +367,97 @@ def test_get_file_created_by_app(self, test_client):
assert len(file_response_with_partial_range.text) == 11

def test_mount_gradio_app(self):
@asynccontextmanager
async def empty_lifespan(app: FastAPI):
yield
app = FastAPI()

app = FastAPI(lifespan=empty_lifespan)
demo = gr.Interface(
lambda s: f"Hello from ps, {s}!", "textbox", "textbox"
).queue()
demo1 = gr.Interface(
lambda s: f"Hello from py, {s}!", "textbox", "textbox"
).queue()

app = gr.mount_gradio_app(app, demo, path="/ps")
app = gr.mount_gradio_app(app, demo1, path="/py")

# Use context manager to trigger start up events
with TestClient(app) as client:
assert client.get("/ps").is_success
assert client.get("/py").is_success

demo1 = gr.Interface(lambda s: f"Hello 1, {s}!", "textbox", "textbox")
demo2 = gr.Interface(lambda s: f"Hello 2, {s}!", "textbox", "textbox")
demo3 = gr.Interface(
lambda s: f"Password-Protected Hello, {s}!", "textbox", "textbox"
def test_mount_gradio_app_with_app_kwargs(self):
app = FastAPI()
demo = gr.Interface(lambda s: f"You said {s}!", "textbox", "textbox").queue()
app = gr.mount_gradio_app(
app,
demo,
path="/echo",
app_kwargs={"docs_url": "/docs-custom"},
)
# Use context manager to trigger start up events
with TestClient(app) as client:
assert client.get("/echo/docs-custom").is_success

app = gr.mount_gradio_app(app, demo1, path="/demo1")
app = gr.mount_gradio_app(app, demo2, path="/demo2")
app = gr.mount_gradio_app(app, demo3, path="/demo-auth", auth=("a", "b"))
def test_mount_gradio_app_with_auth_and_params(self):
app = FastAPI()
demo = gr.Interface(lambda s: f"You said {s}!", "textbox", "textbox").queue()
app = gr.mount_gradio_app(
app,
demo,
path=f"{API_PREFIX}/echo",
auth=("a", "b"),
root_path=f"{API_PREFIX}/echo",
allowed_paths=["test/test_files/bus.png"],
)
# Use context manager to trigger start up events
with TestClient(app) as client:
assert client.get(f"{API_PREFIX}/echo/config").status_code == 401
assert demo.root_path == f"{API_PREFIX}/echo"
assert demo.allowed_paths == ["test/test_files/bus.png"]
assert demo.show_error

def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0)) # Bind to any free port
return s.getsockname()[1] # Get the port number
def test_mount_gradio_app_with_lifespan(self):
@asynccontextmanager
async def empty_lifespan(app: FastAPI):
yield

global port, server # noqa: PLW0603
port = None
server = None
app = FastAPI(lifespan=empty_lifespan)

def run_server():
global port, server # noqa: PLW0603
demo = gr.Interface(
lambda s: f"Hello from ps, {s}!", "textbox", "textbox"
).queue()
demo1 = gr.Interface(
lambda s: f"Hello from py, {s}!", "textbox", "textbox"
).queue()

port = get_free_port()
config = uvicorn.Config(app, host="127.0.0.1", port=port)
server = uvicorn.Server(config)
server.run()
app = gr.mount_gradio_app(app, demo, path="/ps")
app = gr.mount_gradio_app(app, demo1, path="/py")

server_thread = Thread(target=run_server, daemon=True)
server_thread.start()
# Use context manager to trigger start up events
with TestClient(app) as client:
assert client.get("/ps").is_success
assert client.get("/py").is_success

start_time = time.time()
while server is None:
time.sleep(0.01)
if time.time() - start_time > 3:
raise TimeoutError("Server did not start in time")
def test_mount_gradio_app_with_startup(self):
app = FastAPI()

base_url = f"http://127.0.0.1:{port}"
@app.on_event("startup")
async def empty_startup():
return

# Test the main routes
assert requests.get(f"{base_url}/demo1").status_code == 200
assert requests.get(f"{base_url}/demo2").status_code == 200
assert requests.get(f"{base_url}/demo-non-existent").status_code == 404
demo = gr.Interface(
lambda s: f"Hello from ps, {s}!", "textbox", "textbox"
).queue()
demo1 = gr.Interface(
lambda s: f"Hello from py, {s}!", "textbox", "textbox"
).queue()

# Test auth (TODO: Fix this)
assert (
requests.get(f"{base_url}/demo-auth").status_code
!= 200 # It should be 401, but it's 500
)
# requests.post(f"{base_url}/demo-auth/login", data={"username": "a", "password": "b"})
# assert requests.get(f"{base_url}/demo-auth").status_code == 200
app = gr.mount_gradio_app(app, demo, path="/ps")
app = gr.mount_gradio_app(app, demo1, path="/py")

server.should_exit = True # type: ignore
server_thread.join()
# Use context manager to trigger start up events
with TestClient(app) as client:
assert client.get("/ps").is_success
assert client.get("/py").is_success

def test_gradio_app_with_auth_dependency(self):
def block_anonymous(request: Request):
Expand All @@ -442,6 +472,20 @@ def block_anonymous(request: Request):
assert not client.get("/", headers={}).is_success
assert client.get("/", headers={"user": "abubakar"}).is_success

def test_mount_gradio_app_with_auth_dependency(self):
app = FastAPI()

def get_user(request: Request):
return request.headers.get("user")

demo = gr.Interface(lambda s: f"Hello from ps, {s}!", "textbox", "textbox")

app = gr.mount_gradio_app(app, demo, path="/demo", auth_dependency=get_user)

with TestClient(app) as client:
assert client.get("/demo", headers={"user": "abubakar"}).is_success
assert not client.get("/demo").is_success

def test_static_file_missing(self, test_client):
response = test_client.get(rf"{API_PREFIX}/static/not-here.js")
assert response.status_code == 404
Expand Down

0 comments on commit 0c8fafb

Please sign in to comment.