Skip to content

Commit

Permalink
add attention mask
Browse files Browse the repository at this point in the history
  • Loading branch information
amva13 committed Oct 27, 2024
1 parent 05f9e9f commit 744912a
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions tdc/test/test_model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,17 @@ def testGeneformerTokenizer(self):
# input_tensor = torch.squeeze(input_tensor)
out = []
try:
ctr = 0 # stop after some passes to avoid failure
for batch in input_tensor:
out.append(model(batch))
# build an attention mask
attention_mask = torch.tensor([[x[0]!=0, x[1]!=0] for x in batch])
out.append(model(batch, attention_mask=attention_mask))
if ctr == 2:
break
ctr += 1
except Exception as e:
raise Exception("tensor shape is", input_tensor.shape, "exception was:", e, "\n cells was\n", cells)
# raise Exception("tensor shape is", input_tensor.shape, "exception was:", e, "\n cells was\n", cells)
raise Exception(e)

# input_tensor = torch.tensor(cells)
# input_tensor_squeezed = torch.squeeze(input_tensor)
Expand All @@ -148,8 +155,8 @@ def testGeneformerTokenizer(self):
# except Exception as e:
# raise Exception("tensor shape is", input_tensor.shape, "exception was: {}".format(e), "input_tensor_squeezed is\n", input_tensor, "\n\ninput_tensor normal is: {}".format(input_tensor))
assert out, "FAILURE: Geneformer output is false-like. Value = {}".format(out)
assert len(out) == input_tensor.shape[0], "FAILURE: Geneformer output and input tensor input don't have the same length. {} vs {}".format(len(out), input_tensor.shape[0])
assert len(out) == len(cells), "FAILURE: Geneformer output and tokenized cells don't have the same length. {} vs {}".format(len(out), len(cells))
assert len(out[0]) == input_tensor.shape[1], "FAILURE: Geneformer output and input tensor input don't have the same length. {} vs {}".format(len(out[0]), input_tensor.shape[1])
assert len(out[0][0]) == input_tensor.shape[2], "FAILURE: Geneformer output and tokenized cells don't have the same length. {} vs {}".format(len(out[0][0]), input_tensor.shape[2])

def tearDown(self):
try:
Expand Down

0 comments on commit 744912a

Please sign in to comment.