Commit

Zipformer Onnx FP16 (#1671)
Signed-off-by: manickavela29 <[email protected]>
manickavela29 authored Jun 27, 2024
1 parent b594a38 commit eaab2c8
Showing 3 changed files with 58 additions and 5 deletions.
32 changes: 29 additions & 3 deletions egs/librispeech/ASR/zipformer/export-onnx-streaming.py
@@ -48,7 +48,8 @@
--joiner-dim 512 \
--causal True \
--chunk-size 16 \
--left-context-frames 128
--left-context-frames 128 \
--fp16 True
The --chunk-size in training is "16,32,64,-1", so we select one of them
(excluding -1) during streaming export. The same applies to `--left-context`,
@@ -73,6 +74,7 @@
import torch
import torch.nn as nn
from decoder import Decoder
from onnxconverter_common import float16
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_model, get_params
@@ -154,6 +156,13 @@ def get_parser():
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)

parser.add_argument(
"--fp16",
type=str2bool,
default=False,
help="Whether to export models in fp16",
)

add_model_arguments(parser)

return parser
@@ -479,7 +488,6 @@ def build_inputs_outputs(tensors, i):

add_meta_data(filename=encoder_filename, meta_data=meta_data)


def export_decoder_model_onnx(
decoder_model: OnnxDecoder,
decoder_filename: str,
@@ -747,11 +755,29 @@ def main():
    )
    logging.info(f"Exported joiner to {joiner_filename}")

    if params.fp16:
        logging.info("Generate fp16 models")

        encoder = onnx.load(encoder_filename)
        encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
        encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx"
        onnx.save(encoder_fp16, encoder_filename_fp16)

        decoder = onnx.load(decoder_filename)
        decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True)
        decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx"
        onnx.save(decoder_fp16, decoder_filename_fp16)

        joiner = onnx.load(joiner_filename)
        joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True)
        joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx"
        onnx.save(joiner_fp16, joiner_filename_fp16)

    # Generate int8 quantization models
    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

    logging.info("Generate int8 quantization models")

    encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
    quantize_dynamic(
        model_input=encoder_filename,
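For context (not part of this commit), below is a minimal sketch of how an exported fp16 model could be inspected with onnxruntime. The file path is an illustrative assumption following the naming used above; because the conversion passes keep_io_types=True, the graph inputs and outputs stay in float32, so existing fp32 callers work unchanged.

# Minimal sketch, not part of this commit: inspect an exported fp16 encoder.
# The path below is an illustrative assumption based on the export naming above.
import onnxruntime as ort

session = ort.InferenceSession(
    "exp/encoder-epoch-99-avg-1.fp16.onnx",
    providers=["CPUExecutionProvider"],
)

# keep_io_types=True keeps graph inputs/outputs in float32, so callers
# still feed and receive float32 tensors even though weights are fp16.
for node in session.get_inputs():
    print("input :", node.name, node.type, node.shape)
for node in session.get_outputs():
    print("output:", node.name, node.type, node.shape)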
30 changes: 28 additions & 2 deletions egs/librispeech/ASR/zipformer/export-onnx.py
@@ -48,8 +48,8 @@
--joiner-dim 512 \
--causal False \
--chunk-size "16,32,64,-1" \
--left-context-frames "64,128,256,-1"
--left-context-frames "64,128,256,-1" \
--fp16 True
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
@@ -70,6 +70,7 @@
import torch
import torch.nn as nn
from decoder import Decoder
from onnxconverter_common import float16
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_model, get_params
@@ -151,6 +152,13 @@ def get_parser():
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)

parser.add_argument(
"--fp16",
type=str2bool,
default=False,
help="Whether to export models in fp16",
)

add_model_arguments(parser)

return parser
@@ -584,6 +592,24 @@ def main():
    )
    logging.info(f"Exported joiner to {joiner_filename}")

    if params.fp16:
        logging.info("Generate fp16 models")

        encoder = onnx.load(encoder_filename)
        encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
        encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx"
        onnx.save(encoder_fp16, encoder_filename_fp16)

        decoder = onnx.load(decoder_filename)
        decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True)
        decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx"
        onnx.save(decoder_fp16, decoder_filename_fp16)

        joiner = onnx.load(joiner_filename)
        joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True)
        joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx"
        onnx.save(joiner_fp16, joiner_filename_fp16)

    # Generate int8 quantization models
    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

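As an optional sanity check (not part of this commit), the fp32 and fp16 decoders could be compared on the same input. The file paths, the context size of 2, and the vocabulary size of 500 are illustrative assumptions based on the defaults in the export commands above.

# Minimal sketch, not part of this commit: compare fp32 vs fp16 decoder outputs.
# Paths, context size (2), and vocab size (500) are illustrative assumptions.
import numpy as np
import onnxruntime as ort

fp32_sess = ort.InferenceSession(
    "exp/decoder-epoch-99-avg-1.onnx", providers=["CPUExecutionProvider"]
)
fp16_sess = ort.InferenceSession(
    "exp/decoder-epoch-99-avg-1.fp16.onnx", providers=["CPUExecutionProvider"]
)

# The exported decoder takes a (N, context_size) tensor of int64 token ids.
y = np.random.randint(0, 500, size=(4, 2), dtype=np.int64)
name = fp32_sess.get_inputs()[0].name

out_fp32 = fp32_sess.run(None, {name: y})[0]
out_fp16 = fp16_sess.run(None, {name: y})[0]

# Small differences are expected from the reduced precision of the weights.
print("max abs diff:", np.abs(out_fp32 - out_fp16).max())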
1 change: 1 addition & 0 deletions requirements.txt
@@ -12,6 +12,7 @@ onnx>=1.15.0
onnxruntime>=1.16.3
onnxoptimizer
onnxsim
onnxconverter_common

# style check session:
black==22.3.0
