Skip to content

Commit

Permalink
Fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair committed Nov 1, 2024
1 parent 35c7de2 commit 295829f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
1 change: 1 addition & 0 deletions server/tests/adapters/test_medusa.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def test_batched_medusa_weights(default_causal_lm: CausalLM):

meta = AdapterBatchMetadata(
adapter_indices=torch.tensor([0, 0, 1, 1, 0, 0, 1, 1], dtype=torch.int64),
adapter_list=[0, 1, 0, 1],
adapter_set={0, 1},
adapter_segments=torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64),
segment_indices=[0, 1, 0, 1],
Expand Down
5 changes: 4 additions & 1 deletion server/tests/utils/test_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def test_batched_lora_weights(lora_ranks: List[int]):

meta = AdapterBatchMetadata(
adapter_indices=torch.tensor([0, 0, 1, 1, 0, 0, 1, 1], dtype=torch.int64),
adapter_list=[0, 1, 0, 1],
adapter_set={0, 1},
adapter_segments=torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64),
segment_indices=[0, 1, 0, 1],
Expand Down Expand Up @@ -149,6 +150,7 @@ def test_batched_lora_weights_decode(

meta = AdapterBatchMetadata(
adapter_indices=torch.tensor(adapter_indices, dtype=torch.int64),
adapter_list=adapter_indices,
adapter_set=set(adapter_indices),
adapter_segments=torch.tensor(segments, dtype=torch.int64),
segment_indices=segment_indices,
Expand Down Expand Up @@ -193,7 +195,8 @@ def test_batched_lora_weights_no_segments():

meta = AdapterBatchMetadata(
adapter_indices=torch.tensor([0, 0, 0, 0], dtype=torch.int64),
adapter_set={0, 1},
adapter_list=[0],
adapter_set={0},
adapter_segments=torch.tensor([0, 4], dtype=torch.int64),
segment_indices=[0],
)
Expand Down

0 comments on commit 295829f

Please sign in to comment.