Skip to content

Commit

Permalink
Jamba: update integration tests (huggingface#32250)
Browse files Browse the repository at this point in the history
* try test updates

* a few more changes

* a few more changes

* a few more changes

* [run slow] jamba

* skip logits checks on older gpus

* [run slow] jamba

* oops

* [run slow] jamba

* Update tests/models/jamba/test_modeling_jamba.py

Co-authored-by: amyeroberts <[email protected]>

* Update tests/models/jamba/test_modeling_jamba.py

Co-authored-by: amyeroberts <[email protected]>

---------

Co-authored-by: amyeroberts <[email protected]>
  • Loading branch information
2 people authored and zucchini-nlp committed Aug 30, 2024
1 parent bcc4cd5 commit d7f3dfd
Showing 1 changed file with 79 additions and 49 deletions.
128 changes: 79 additions & 49 deletions tests/models/jamba/test_modeling_jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,82 +650,112 @@ def test_new_cache_format(self, num_beams, do_sample):
class JambaModelIntegrationTest(unittest.TestCase):
model = None
tokenizer = None
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
# Depending on the hardware we get different logits / generations
cuda_compute_capability_major_version = None

@classmethod
def setUpClass(cls):
model_id = "ai21labs/Jamba-tiny-random"
cls.model = JambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
cls.tokenizer = AutoTokenizer.from_pretrained(model_id)
if is_torch_available() and torch.cuda.is_available():
# 8 is for A100 / A10 and 7 for T4
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

@slow
def test_simple_generate(self):
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
#
# Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
# considering differences in hardware processing and potential deviations in generated text.
EXPECTED_TEXTS = {
7: "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas",
8: "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew llam bb",
9: "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew llam bb",
}

self.model.to(torch_device)

input_ids = self.tokenizer("Hey how are you doing on this lovely evening?", return_tensors="pt")[
"input_ids"
].to(torch_device)
out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10)
output_sentence = self.tokenizer.decode(out[0, :])
self.assertEqual(
output_sentence,
"<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew cases Cats",
)
self.assertEqual(output_sentence, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])

with torch.no_grad():
logits = self.model(input_ids=input_ids).logits
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
if self.cuda_compute_capability_major_version == 8:
with torch.no_grad():
logits = self.model(input_ids=input_ids).logits

EXPECTED_LOGITS_NO_GRAD = torch.tensor(
[
0.0140, -0.2246, 0.0408, -0.1016, 0.0471, 0.2715, -0.1465, 0.1631,
-0.2949, -0.0297, 0.0250, -0.5586, -0.2139, -0.1426, -0.1602, 0.1309,
0.0703, 0.2236, 0.1729, -0.2285, -0.1152, -0.1177, -0.1367, 0.0289,
0.1245, 0.2363, 0.0442, 0.1094, -0.1348, -0.2295, 0.1494, -0.3945,
0.1777, -0.4570, -0.0408, 0.2412, 0.1562, -0.1943, 0.2373, -0.0593
]
, dtype=torch.float32) # fmt: skip
EXPECTED_LOGITS_NO_GRAD = torch.tensor(
[
0.0134, -0.2197, 0.0396, -0.1011, 0.0459, 0.2793, -0.1465, 0.1660,
-0.2930, -0.0278, 0.0269, -0.5586, -0.2109, -0.1426, -0.1553, 0.1279,
0.0713, 0.2246, 0.1660, -0.2314, -0.1187, -0.1162, -0.1377, 0.0292,
0.1245, 0.2275, 0.0374, 0.1089, -0.1348, -0.2305, 0.1484, -0.3906,
0.1709, -0.4590, -0.0447, 0.2422, 0.1592, -0.1855, 0.2441, -0.0562
]
, dtype=torch.float32) # fmt: skip

torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD, rtol=1e-3, atol=1e-3)

@slow
def test_simple_batched_generate_with_padding(self):
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
#
# Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
# considering differences in hardware processing and potential deviations in generated text.
EXPECTED_TEXTS = {
7: [
"<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew cases Cats",
"<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",
],
8: [
"<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas",
"<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",
],
9: [
"<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas",
"<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",
],
}

self.model.to(torch_device)

inputs = self.tokenizer(
["Hey how are you doing on this lovely evening?", "Tell me a story"], padding=True, return_tensors="pt"
).to(torch_device)
out = self.model.generate(**inputs, do_sample=False, max_new_tokens=10)
output_sentences = self.tokenizer.batch_decode(out)
self.assertEqual(
output_sentences[0],
"<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew cases Cats",
)
self.assertEqual(
output_sentences[1],
"<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",
)
self.assertEqual(output_sentences[0], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][0])
self.assertEqual(output_sentences[1], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][1])

with torch.no_grad():
logits = self.model(input_ids=inputs["input_ids"]).logits

EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor(
[
0.0140, -0.2246, 0.0408, -0.1016, 0.0471, 0.2715, -0.1465, 0.1631,
-0.2949, -0.0297, 0.0250, -0.5586, -0.2139, -0.1426, -0.1602, 0.1309,
0.0703, 0.2236, 0.1729, -0.2285, -0.1152, -0.1177, -0.1367, 0.0289,
0.1245, 0.2363, 0.0442, 0.1094, -0.1348, -0.2295, 0.1494, -0.3945,
0.1777, -0.4570, -0.0408, 0.2412, 0.1562, -0.1943, 0.2373, -0.0593
]
, dtype=torch.float32) # fmt: skip

EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
[
-0.1289, 0.2363, -0.4180, -0.0302, -0.0476, 0.0327, 0.2578, 0.0874,
0.1484, 0.2305, -0.1152, -0.1396, -0.1494, -0.1113, -0.0021, -0.2832,
0.2002, -0.2676, 0.0598, -0.1982, -0.2539, -0.1133, -0.1973, 0.2148,
0.0559, 0.1670, 0.1846, 0.1270, 0.1680, -0.1250, -0.2656, -0.2871,
0.2344, 0.2637, 0.0510, -0.1855, 0.2158, -0.1289, 0.1758, 0.0074
]
, dtype=torch.float32) # fmt: skip

torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=1e-3)
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
if self.cuda_compute_capability_major_version == 8:
with torch.no_grad():
logits = self.model(input_ids=inputs["input_ids"]).logits

EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor(
[
0.0166, -0.2227, 0.0396, -0.1035, 0.0459, 0.2754, -0.1445, 0.1641,
-0.2910, -0.0273, 0.0227, -0.5547, -0.2139, -0.1396, -0.1582, 0.1289,
0.0713, 0.2256, 0.1699, -0.2295, -0.1182, -0.1167, -0.1387, 0.0261,
0.1270, 0.2285, 0.0403, 0.1108, -0.1318, -0.2334, 0.1455, -0.3945,
0.1729, -0.4609, -0.0410, 0.2412, 0.1572, -0.1895, 0.2402, -0.0583
]
, dtype=torch.float32) # fmt: skip

EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
[
-0.1318, 0.2354, -0.4160, -0.0325, -0.0461, 0.0342, 0.2578, 0.0874,
0.1484, 0.2266, -0.1182, -0.1396, -0.1494, -0.1089, -0.0019, -0.2852,
0.1973, -0.2676, 0.0586, -0.1992, -0.2520, -0.1147, -0.1973, 0.2129,
0.0520, 0.1699, 0.1816, 0.1289, 0.1699, -0.1216, -0.2656, -0.2891,
0.2363, 0.2656, 0.0488, -0.1875, 0.2148, -0.1250, 0.1816, 0.0077
]
, dtype=torch.float32) # fmt: skip

torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=1e-3)

0 comments on commit d7f3dfd

Please sign in to comment.