Skip to content

Commit

Permalink
Fixed bug for PL=1
Browse files Browse the repository at this point in the history
Signed-off-by: Amit Raj <[email protected]>
  • Loading branch information
quic-amitraj committed Oct 10, 2024
1 parent 08ca83c commit da61970
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
9 changes: 7 additions & 2 deletions QEfficient/compile/compile_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,15 @@ def create_and_dump_specializations(
"batch_size": str(batch_size),
"seq_len": str(prompt_len),
"ctx_len": str(ctx_len),
},
{"batch_size": str(batch_size), "seq_len": "1", "ctx_len": str(ctx_len)},
}
]
}

# To handle repetative input in specializations when prompt_len is 1
if prompt_len != 1:
specializations["specializations"].append(
{"batch_size": str(batch_size), "seq_len": "1", "ctx_len": str(ctx_len)}
)
# If continuous batching is enabled by proving full_batch_size we need to add FBS to the specialization file and update the batch size of decoder part to FBS
if full_batch_size is not None:
specializations["specializations"][0]["full_batch_size"] = str(full_batch_size)
Expand Down
5 changes: 4 additions & 1 deletion QEfficient/generation/text_generation_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,10 @@ def _fetch_vocab_size(
Returns:
vocab_size: The vocabulary size fetched from the session's allowed shapes.
"""
return [x[self.session.binding_index_map["logits"]] for x in self.session.allowed_shapes][0][1][2]
if self.session.allowed_shapes:
return [x[self.session.binding_index_map["logits"]] for x in self.session.allowed_shapes][0][1][2]

return self.session.bindings[self.session.binding_index_map["logits"]].dims[2]

def _fetch_generation_len(self, generation_len, max_gen_len):
"""
Expand Down

0 comments on commit da61970

Please sign in to comment.