Skip to content

Commit

Permalink
fix question answering error when running pipeline model on cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft committed Aug 29, 2023
1 parent 6edbfd3 commit e4ba134
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ def predict_qa(self, questions, start):
for q in questions:
question, context = q.split(SEP)
d = self._qa_model.tokenizer(question, context)
device = self._qa_model.device
out = self._qa_model.model.forward(
**{k: torch.tensor(d[k]).reshape(1, -1) for k in d})
**{k: torch.tensor(
d[k], device=device).reshape(1, -1) for k in d})
logits = out.start_logits if start else out.end_logits
outs.append(logits.reshape(-1).detach().numpy())
outs.append(logits.reshape(-1).detach().cpu().numpy())
return outs

def predict_qa_start(self, questions):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,8 @@ def _initialize_managers(self):
self.image_mode,
self.max_evals,
self.num_masks,
self.mask_res)
self.mask_res,
self.device)
self._error_analysis_manager = ErrorAnalysisManager(
self._wrapped_model, self.test, self._ext_test_df,
self.target_column,
Expand Down

0 comments on commit e4ba134

Please sign in to comment.