Replace with versions from transformers and torch
loadams committed Sep 13, 2023
1 parent 9cea1d1 commit fa9fe94
Showing 3 changed files with 15 additions and 11 deletions.
6 changes: 3 additions & 3 deletions tests/unit/compression/test_compression.py
@@ -9,8 +9,8 @@
 import numpy as np
 from unit.megatron_model import get_gpt2_model
 from deepspeed.compression.compress import init_compression
-from unit.modeling import BertConfig
-from unit.modelingpreln import BertEncoder as BertEncoderPreln
+from transformers.models.bert.configuration_bert import BertConfig
+from transformers.models.bert.modeling_bert import BertEncoder as BertEncoder
 from deepspeed.compression.basic_layer import LinearLayer_Compress, ColumnParallelLinear_Compress, RowParallelLinear_Compress
 from deepspeed.compression.helper import convert_conv1d_to_linear
 from deepspeed.accelerator import get_accelerator
@@ -63,7 +63,7 @@ def create_bert_model():
     biases.append(torch.nn.Parameter(torch.Tensor(hidden_size)))
     biases.append(torch.nn.Parameter(torch.Tensor(hidden_size)))

-    return BertEncoderPreln(bert_config, weights, biases)
+    return BertEncoder(bert_config, weights, biases)


 class Conv1D(torch.nn.Module):
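With this change the test builds its encoder from the upstream `transformers` classes instead of the local `unit.modeling` stubs. For reference, a minimal sketch of constructing them directly; the config values below are illustrative, not the dimensions this test actually uses:

```python
import torch
from transformers.models.bert.configuration_bert import BertConfig
from transformers.models.bert.modeling_bert import BertEncoder

# Illustrative dimensions; the test defines its own hidden_size etc.
bert_config = BertConfig(
    vocab_size=30522,
    hidden_size=384,
    num_hidden_layers=4,
    num_attention_heads=12,
    intermediate_size=1536,
)

# The upstream BertEncoder is constructed from the config alone.
encoder = BertEncoder(bert_config)

# forward() consumes hidden states of shape (batch, seq_len, hidden_size).
hidden_states = torch.randn(2, 16, bert_config.hidden_size)
outputs = encoder(hidden_states)
```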
10 changes: 6 additions & 4 deletions tests/unit/ops/accelerators/test_accelerator_backward.py
@@ -12,10 +12,12 @@
 from torch import nn
 from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
 from deepspeed.accelerator import get_accelerator
-from unit.modeling import BertConfig, BertLayerNorm, BertEncoder as BertEncoderPostln
-from unit.modelingpreln import BertEncoder as BertEncoderPreln
+from transformers.models.bert.configuration_bert import BertConfig
+from transformers.models.bert.modeling_bert import BertEncoder
 from unit.common import DistributedTest, is_rocm_pytorch

+BertLayerNorm = torch.nn.LayerNorm
+
 #if not deepspeed.ops.__installed_ops__['transformer']:
 #pytest.skip(
 #    "transformer kernels are temporarily disabled because of unexplained failures",
@@ -194,9 +196,9 @@ def create_models(ds_config):
     biases[7].data.zero_()

     if (ds_config.pre_layer_norm):
-        bert_encoder = BertEncoderPreln(bert_config, weights, biases)
+        bert_encoder = BertEncoder(bert_config, weights, biases)
     else:
-        bert_encoder = BertEncoderPostln(bert_config, weights, biases)
+        bert_encoder = BertEncoder(bert_config, weights, biases)
     ds_encoder = DSEncoder(ds_config, weights, biases)

     if ds_config.fp16:
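`BertLayerNorm` previously came from the local `unit.modeling` stubs; the alias added here makes explicit that `torch.nn.LayerNorm` is its drop-in replacement. A minimal sketch of the equivalence, with illustrative values (classic BERT uses eps=1e-12):

```python
import torch

BertLayerNorm = torch.nn.LayerNorm  # same alias as in the diff above

hidden_size = 768  # illustrative
norm = BertLayerNorm(hidden_size, eps=1e-12)

x = torch.randn(2, 16, hidden_size)
y = norm(x)  # normalizes over the last dimension (hidden_size)
```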
10 changes: 6 additions & 4 deletions tests/unit/ops/accelerators/test_accelerator_forward.py
@@ -9,12 +9,14 @@
 import random
 import copy
 from torch import nn
-from unit.modelingpreln import BertEncoder as BertEncoderPreln
-from unit.modeling import BertLayerNorm, BertConfig, BertEncoder as BertEncoderPostln
+from transformers.models.bert.configuration_bert import BertConfig
+from transformers.models.bert.modeling_bert import BertEncoder
 from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
 from deepspeed.accelerator import get_accelerator
 from unit.common import DistributedTest

+BertLayerNorm = torch.nn.LayerNorm
+
 if torch.half not in get_accelerator().supported_dtypes():
     pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)

@@ -139,9 +141,9 @@ def create_models(ds_config):
     biases[7].data.zero_()

     if (ds_config.pre_layer_norm):
-        bert_encoder = BertEncoderPreln(bert_config, weights, biases)
+        bert_encoder = BertEncoder(bert_config, weights, biases)
     else:
-        bert_encoder = BertEncoderPostln(bert_config, weights, biases)
+        bert_encoder = BertEncoder(bert_config, weights, biases)
     ds_encoder = DSEncoder(ds_config, weights, biases)

     if ds_config.fp16:
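The context around these hunks shows the module-level fp16 guard the tests rely on. As a sketch, the same accelerator-gating pattern in isolation, assuming a pytest module:

```python
import pytest
import torch
from deepspeed.accelerator import get_accelerator

# Skip the whole module when the active accelerator cannot run half precision,
# mirroring the guard in test_accelerator_forward.py.
if torch.half not in get_accelerator().supported_dtypes():
    pytest.skip(f"fp16 not supported, valid dtypes: {get_accelerator().supported_dtypes()}",
                allow_module_level=True)
```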
