From 244b85f97619e48ca419b95b2b823c7768d5312d Mon Sep 17 00:00:00 2001 From: "Dilyara Zharikova (Baymurzina)" Date: Mon, 7 Aug 2023 11:28:48 +0300 Subject: [PATCH] fix: anthropic model params (#547) --- .../anthropic_generative_config.json | 3 +++ services/anthropic_api_lm/server.py | 16 ++++++++-------- 2 files changed, 11 insertions(+), 8 deletions(-) create mode 100644 common/generative_configs/anthropic_generative_config.json diff --git a/common/generative_configs/anthropic_generative_config.json b/common/generative_configs/anthropic_generative_config.json new file mode 100644 index 0000000000..e0858bd0ea --- /dev/null +++ b/common/generative_configs/anthropic_generative_config.json @@ -0,0 +1,3 @@ +{ + "max_tokens_to_sample": 256 +} \ No newline at end of file diff --git a/services/anthropic_api_lm/server.py b/services/anthropic_api_lm/server.py index 8b430d0fbc..c701f6fddd 100644 --- a/services/anthropic_api_lm/server.py +++ b/services/anthropic_api_lm/server.py @@ -19,13 +19,13 @@ PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH") logger.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}") -NAMING = ["Assistant", "Human"] +NAMING = [anthropic.AI_PROMPT, anthropic.HUMAN_PROMPT] app = Flask(__name__) logging.getLogger("werkzeug").setLevel("WARNING") DEFAULT_CONFIGS = { - "claude-1": json.load(open("common/generative_configs/empty_generative_config.json", "r")), - "claude-instant-1": json.load(open("common/generative_configs/empty_generative_config.json", "r")), + "claude-1": json.load(open("common/generative_configs/anthropic_generative_config.json", "r")), + "claude-instant-1": json.load(open("common/generative_configs/anthropic_generative_config.json", "r")), } @@ -33,15 +33,15 @@ def generate_responses(context, anthropic_api_key, prompt, generation_params, co assert anthropic_api_key, logger.error("Error: Anthropic API key is not specified in env") outputs = [] - dialog_context = "" + dialog_context = f"{anthropic.HUMAN_PROMPT} " if prompt: - dialog_context += prompt + "\n" + dialog_context += prompt s = len(context) % 2 - context = [f"{NAMING[(s + uttr_id) % 2]}: {uttr}" for uttr_id, uttr in enumerate(context)] + context = [f"{NAMING[(s + uttr_id) % 2]} {uttr}" for uttr_id, uttr in enumerate(context)] if continue_last_uttr: - dialog_context += "\n".join(context) + dialog_context += "".join(context) else: - dialog_context += "\n".join(context) + f"\n{NAMING[0]}:" + dialog_context += "".join(context) + f"{NAMING[0]}" logger.info(f"context inside generate_responses seen as: {dialog_context}") client = anthropic.Client(api_key=anthropic_api_key)