Skip to content

Commit

Permalink
Merge pull request #109 from BlackSamorez/sharded_tests
Browse files Browse the repository at this point in the history
Testing interfaces (soon to be refactored)
  • Loading branch information
Andrei Panferov authored Jul 25, 2023
2 parents 31329ad + b915ad8 commit aa3dd94
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tests/test_legacy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import torch
from transformers import BertModel, PreTrainedModel

from tensor_parallel import Sharded, tensor_parallel


def test_legacy_factory_and_sharded():
model = BertModel.from_pretrained("bert-base-uncased")

tp_model = tensor_parallel(model, sharded=False)
assert isinstance(tp_model, PreTrainedModel)
tp_model.wrapped_model = Sharded(tp_model.wrapped_model)

tp_model(torch.zeros(1, 8, dtype=int))

0 comments on commit aa3dd94

Please sign in to comment.