Skip to content

Commit

Permalink
fix: context format
Browse files Browse the repository at this point in the history
  • Loading branch information
dilyararimovna committed Aug 12, 2022
1 parent 5b3a822 commit b70eb8e
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions services/dialogpt/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
logger = logging.getLogger(__name__)

PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH")
N_HYPOTHESES_TO_GENERATE = int(os.environ.get("N_HYPOTHESES_TO_GENERATE", 1))
CONFIG_NAME = os.environ.get("CONFIG_NAME")
logging.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")
DEFAULT_CONFIDENCE = 0.9
N_HYPOTHESES_TO_GENERATE = int(os.environ.get("N_HYPOTHESES_TO_GENERATE", 1))
ZERO_CONFIDENCE = 0.0
MAX_HISTORY_DEPTH = 3
with open(CONFIG_NAME, "r") as f:
Expand All @@ -43,7 +43,7 @@
logging.getLogger("werkzeug").setLevel("WARNING")


def generate_response(context, model, tokenizer):
def generate_responses(context, model, tokenizer):
encoded_context = []
for uttr in context[-MAX_HISTORY_DEPTH:]:
encoded_context += [tokenizer.encode(uttr + " " + tokenizer.eos_token, return_tensors="pt")]
Expand All @@ -56,7 +56,7 @@ def generate_response(context, model, tokenizer):
if torch.cuda.is_available():
chat_history_ids = chat_history_ids.cpu()

outputs = [tokenizer.decode(x[len(bot_input_ids[0]) :], skip_special_tokens=True) for x in chat_history_ids]
outputs = [tokenizer.decode(x[len(bot_input_ids[0]):], skip_special_tokens=True) for x in chat_history_ids]
return outputs


Expand All @@ -71,8 +71,8 @@ def respond():
for context in contexts:
curr_responses = []
curr_confidences = []
responses = generate_response(context, model, tokenizer)
for response in responses:
outputs = generate_responses(context, model, tokenizer)
for response in outputs:
if len(response) > 3:
# drop too short responses
curr_responses += [response]
Expand Down

0 comments on commit b70eb8e

Please sign in to comment.