test chat interface fixes
abidlabs committed Sep 20, 2024
1 parent 182d2bf commit edb40f6
Showing 1 changed file with 38 additions and 46 deletions.
84 changes: 38 additions & 46 deletions test/test_chat_interface.py
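The diff below makes two changes: the `examples` passed to `gr.ChatInterface` in the caching tests switch from the dict form `[{"text": "hello"}, {"text": "hi"}]` to plain strings `["hello", "hi"]` (the same change is applied to the lazy, async, and streaming variants), and the two previously commented-out tests covering `additional_inputs` are re-enabled. A minimal runnable sketch of the updated call style; the `double` handler here is a hypothetical stand-in for the one defined elsewhere in the test module:

import gradio as gr

def double(message, history):
    # Hypothetical echo-style handler; the real one lives in the test module.
    return message + " " + message

demo = gr.ChatInterface(
    double,
    examples=["hello", "hi"],  # plain strings, as the updated tests now pass
    cache_examples=True,
)

if __name__ == "__main__":
    demo.launch()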
@@ -86,9 +86,7 @@ def test_example_caching(self):
             "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
         ):
             chatbot = gr.ChatInterface(
-                double,
-                examples=[{"text": "hello"}, {"text": "hi"}],
-                cache_examples=True,
+                double, examples=["hello", "hi"], cache_examples=True
             )
             prediction_hello = chatbot.examples_handler.load_from_cache(0)
             prediction_hi = chatbot.examples_handler.load_from_cache(1)
@@ -102,7 +100,7 @@ async def test_example_caching_lazy(self):
         ):
             chatbot = gr.ChatInterface(
                 double,
-                examples=[{"text": "hello"}, {"text": "hi"}],
+                examples=["hello", "hi"],
                 cache_examples=True,
                 cache_mode="lazy",
             )
@@ -121,9 +119,7 @@ def test_example_caching_async(self):
             "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
         ):
             chatbot = gr.ChatInterface(
-                async_greet,
-                examples=[{"text": "abubakar"}, {"text": "tom"}],
-                cache_examples=True,
+                async_greet, examples=["abubakar", "tom"], cache_examples=True
             )
             prediction_hello = chatbot.examples_handler.load_from_cache(0)
             prediction_hi = chatbot.examples_handler.load_from_cache(1)
@@ -135,9 +131,7 @@ def test_example_caching_with_streaming(self):
             "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
         ):
             chatbot = gr.ChatInterface(
-                stream,
-                examples=[{"text": "hello"}, {"text": "hi"}],
-                cache_examples=True,
+                stream, examples=["hello", "hi"], cache_examples=True
             )
             prediction_hello = chatbot.examples_handler.load_from_cache(0)
             prediction_hi = chatbot.examples_handler.load_from_cache(1)
@@ -149,9 +143,7 @@ def test_example_caching_with_streaming_async(self):
             "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
         ):
             chatbot = gr.ChatInterface(
-                async_stream,
-                examples=[{"text": "hello"}, {"text": "hi"}],
-                cache_examples=True,
+                async_stream, examples=["hello", "hi"], cache_examples=True
             )
             prediction_hello = chatbot.examples_handler.load_from_cache(0)
             prediction_hi = chatbot.examples_handler.load_from_cache(1)
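The final hunk re-enables the two commented-out tests for example caching with `additional_inputs`. Their assertions pin down what the `echo_system_prompt_plus_message` helper (defined elsewhere in the test file, not shown in this diff) must do; a plausible reconstruction consistent with those assertions:

def echo_system_prompt_plus_message(message, history, system_prompt, tokens):
    # Assumed shape: prepend the system prompt, then truncate to `tokens` characters,
    # so ("hello", "robot", 100) -> "robot hello" and ("hi", "robot", 2) -> "ro".
    response = f"{system_prompt} {message}"
    return response[: int(tokens)]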
@@ -185,39 +177,39 @@ def test_setting_accordion_params(self, monkeypatch):
         assert accordion.get_config().get("open") is True
         assert accordion.get_config().get("label") == "MOAR"
 
-    # def test_example_caching_with_additional_inputs(self, monkeypatch):
-    #     with patch(
-    #         "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
-    #     ):
-    #         chatbot = gr.ChatInterface(
-    #             echo_system_prompt_plus_message,
-    #             additional_inputs=["textbox", "slider"],
-    #             examples=[["hello", "robot", 100], ["hi", "robot", 2]],
-    #             cache_examples=True,
-    #         )
-    #         prediction_hello = chatbot.examples_handler.load_from_cache(0)
-    #         prediction_hi = chatbot.examples_handler.load_from_cache(1)
-    #         assert prediction_hello[0].root[0] == ("hello", "robot hello")
-    #         assert prediction_hi[0].root[0] == ("hi", "ro")
-
-    # def test_example_caching_with_additional_inputs_already_rendered(self, monkeypatch):
-    #     with patch(
-    #         "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
-    #     ):
-    #         with gr.Blocks():
-    #             with gr.Accordion("Inputs"):
-    #                 text = gr.Textbox()
-    #                 slider = gr.Slider()
-    #                 chatbot = gr.ChatInterface(
-    #                     echo_system_prompt_plus_message,
-    #                     additional_inputs=[text, slider],
-    #                     examples=[["hello", "robot", 100], ["hi", "robot", 2]],
-    #                     cache_examples=True,
-    #                 )
-    #         prediction_hello = chatbot.examples_handler.load_from_cache(0)
-    #         prediction_hi = chatbot.examples_handler.load_from_cache(1)
-    #         assert prediction_hello[0].root[0] == ("hello", "robot hello")
-    #         assert prediction_hi[0].root[0] == ("hi", "ro")
+    def test_example_caching_with_additional_inputs(self, monkeypatch):
+        with patch(
+            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
+        ):
+            chatbot = gr.ChatInterface(
+                echo_system_prompt_plus_message,
+                additional_inputs=["textbox", "slider"],
+                examples=[["hello", "robot", 100], ["hi", "robot", 2]],
+                cache_examples=True,
+            )
+            prediction_hello = chatbot.examples_handler.load_from_cache(0)
+            prediction_hi = chatbot.examples_handler.load_from_cache(1)
+            assert prediction_hello[0].root[0] == ("hello", "robot hello")
+            assert prediction_hi[0].root[0] == ("hi", "ro")
+
+    def test_example_caching_with_additional_inputs_already_rendered(self, monkeypatch):
+        with patch(
+            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
+        ):
+            with gr.Blocks():
+                with gr.Accordion("Inputs"):
+                    text = gr.Textbox()
+                    slider = gr.Slider()
+                    chatbot = gr.ChatInterface(
+                        echo_system_prompt_plus_message,
+                        additional_inputs=[text, slider],
+                        examples=[["hello", "robot", 100], ["hi", "robot", 2]],
+                        cache_examples=True,
+                    )
+            prediction_hello = chatbot.examples_handler.load_from_cache(0)
+            prediction_hi = chatbot.examples_handler.load_from_cache(1)
+            assert prediction_hello[0].root[0] == ("hello", "robot hello")
+            assert prediction_hi[0].root[0] == ("hi", "ro")
 
     def test_custom_chatbot_with_events(self):
         with gr.Blocks() as demo:
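For reference, the earlier caching tests call three more handlers that this diff does not show. Plausible minimal shapes, assuming the usual sync, async, and streaming contracts of `gr.ChatInterface` (these are illustrative guesses, not code from this commit):

async def async_greet(message, history):
    # Assumed async handler used by test_example_caching_async.
    return "hi, " + message

def stream(message, history):
    # A streaming handler yields successively longer prefixes of its reply.
    for i in range(len(message)):
        yield message[: i + 1]

async def async_stream(message, history):
    # Async variant of the streaming handler above.
    for i in range(len(message)):
        yield message[: i + 1]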
