diff --git a/fastchat/serve/api_provider.py b/fastchat/serve/api_provider.py
index 9fa852bca..1e319f0a2 100644
--- a/fastchat/serve/api_provider.py
+++ b/fastchat/serve/api_provider.py
@@ -99,17 +99,19 @@ def openai_api_stream_iter(
 ):
     import openai
 
-    is_azure = False
+    api_key = api_key or os.environ["OPENAI_API_KEY"]
+
     if "azure" in model_name:
-        is_azure = True
-        openai.api_type = "azure"
-        openai.api_version = "2023-07-01-preview"
+        client = openai.AzureOpenAI(
+            api_version="2023-07-01-preview",
+            azure_endpoint=api_base or "https://api.openai.com/v1",
+            api_key=api_key,
+        )
     else:
-        openai.api_type = "open_ai"
-        openai.api_version = None
+        client = openai.OpenAI(
+            base_url=api_base or "https://api.openai.com/v1", api_key=api_key
+        )
 
-    openai.api_base = api_base or "https://api.openai.com/v1"
-    openai.api_key = api_key or os.environ["OPENAI_API_KEY"]
     if model_name == "gpt-4-turbo":
         model_name = "gpt-4-1106-preview"
 
@@ -123,26 +125,17 @@
     }
     logger.info(f"==== request ====\n{gen_params}")
 
-    if is_azure:
-        res = openai.ChatCompletion.create(
-            engine=model_name,
-            messages=messages,
-            temperature=temperature,
-            max_tokens=max_new_tokens,
-            stream=True,
-        )
-    else:
-        res = openai.ChatCompletion.create(
-            model=model_name,
-            messages=messages,
-            temperature=temperature,
-            max_tokens=max_new_tokens,
-            stream=True,
-        )
+    res = client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        temperature=temperature,
+        max_tokens=max_new_tokens,
+        stream=True,
+    )
     text = ""
     for chunk in res:
-        if len(chunk["choices"]) > 0:
-            text += chunk["choices"][0]["delta"].get("content", "")
+        if len(chunk.choices) > 0:
+            text += chunk.choices[0].delta.content or ""
         data = {
             "text": text,
             "error_code": 0,
diff --git a/fastchat/serve/gradio_block_arena_vision.py b/fastchat/serve/gradio_block_arena_vision.py
index 60cad07c6..e9a7e7b17 100644
--- a/fastchat/serve/gradio_block_arena_vision.py
+++ b/fastchat/serve/gradio_block_arena_vision.py
@@ -41,7 +41,6 @@ def build_single_vision_language_model_ui(models, add_promotion_links=False):
     notice_markdown = f"""
 # 🏔️ Chat with Open Large Vision-Language Models
 {promotion}
-## 🤖 Choose any model to chat
 """
 
     state = gr.State()
diff --git a/fastchat/utils.py b/fastchat/utils.py
index 13a6333f0..5f14abce8 100644
--- a/fastchat/utils.py
+++ b/fastchat/utils.py
@@ -154,10 +154,7 @@ def oai_moderation(text):
     """
     import openai
 
-    openai.api_base = "https://api.openai.com/v1"
-    openai.api_key = os.environ["OPENAI_API_KEY"]
-    openai.api_type = "open_ai"
-    openai.api_version = None
+    client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
 
     threshold_dict = {
         "sexual": 0.2,
@@ -165,13 +162,13 @@
     MAX_RETRY = 3
     for _ in range(MAX_RETRY):
         try:
-            res = openai.Moderation.create(input=text)
-            flagged = res["results"][0]["flagged"]
+            res = client.moderations.create(input=text)
+            flagged = res.results[0].flagged
             for category, threshold in threshold_dict.items():
-                if res["results"][0]["category_scores"][category] > threshold:
+                if getattr(res.results[0].category_scores, category) > threshold:
                     flagged = True
                     break
-        except (openai.error.OpenAIError, KeyError, IndexError) as e:
+        except (openai.OpenAIError, KeyError, IndexError) as e:
             # flag true to be conservative
             flagged = True
             print(f"MODERATION ERROR: {e}\nInput: {text}")
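
For reference, a minimal standalone sketch of the openai>=1.0 streaming pattern the first hunk adopts. This is not part of the patch; it assumes OPENAI_API_KEY is set in the environment, and the model name and prompt are illustrative placeholders:

    import os
    import openai

    # v1 replaces the module-level openai.api_* globals with a client object.
    client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])

    stream = client.chat.completions.create(
        model="gpt-3.5-turbo",  # illustrative model name
        messages=[{"role": "user", "content": "Say hello."}],
        temperature=0.7,
        max_tokens=32,
        stream=True,
    )

    text = ""
    for chunk in stream:
        # v1 yields typed objects rather than dicts, so fields are attributes.
        # The final chunk's delta.content can be None, hence the `or ""` guard
        # that the patch uses in place of .get("content", "").
        if len(chunk.choices) > 0:
            text += chunk.choices[0].delta.content or ""
    print(text)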
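
Similarly, a sketch of the v1 moderation call used in the fastchat/utils.py hunks, again assuming OPENAI_API_KEY is set; the input text and threshold below are illustrative:

    import os
    import openai

    client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])

    res = client.moderations.create(input="some user text")
    result = res.results[0]
    flagged = result.flagged

    # In v1, per-category scores are attributes of a typed object rather than
    # dict keys; getattr() lets the patch keep looking categories up by the
    # string keys stored in threshold_dict.
    if getattr(result.category_scores, "sexual") > 0.2:
        flagged = True
    print(flagged)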