From 2fbdece9adab24b3c93b54b937f59b9482b178a8 Mon Sep 17 00:00:00 2001 From: punithsekar Date: Wed, 10 Jul 2024 07:52:37 +0000 Subject: [PATCH] #10270: yolov4 implementation --- .../yolov4/reference/downsample1.py | 73 ++++ .../yolov4/reference/downsample2.py | 65 ++++ .../yolov4/reference/downsample3.py | 56 +++ .../yolov4/reference/downsample4.py | 58 +++ .../yolov4/reference/downsample5.py | 60 +++ models/experimental/yolov4/reference/head.py | 146 +++++++ models/experimental/yolov4/reference/neck.py | 206 ++++++++++ .../experimental/yolov4/reference/resblock.py | 28 ++ .../experimental/yolov4/reference/yolov4.py | 38 ++ models/experimental/yolov4/ttnn/common.py | 100 +++++ .../experimental/yolov4/ttnn/downsample1.py | 98 +++++ .../experimental/yolov4/ttnn/downsample2.py | 104 +++++ .../experimental/yolov4/ttnn/downsample3.py | 165 ++++++++ .../experimental/yolov4/ttnn/downsample4.py | 165 ++++++++ .../experimental/yolov4/ttnn/downsample5.py | 127 ++++++ models/experimental/yolov4/ttnn/head.py | 281 ++++++++++++++ models/experimental/yolov4/ttnn/neck.py | 362 ++++++++++++++++++ models/experimental/yolov4/ttnn/yolov4.py | 113 ++++++ 18 files changed, 2245 insertions(+) create mode 100644 models/experimental/yolov4/reference/downsample1.py create mode 100644 models/experimental/yolov4/reference/downsample2.py create mode 100644 models/experimental/yolov4/reference/downsample3.py create mode 100644 models/experimental/yolov4/reference/downsample4.py create mode 100644 models/experimental/yolov4/reference/downsample5.py create mode 100644 models/experimental/yolov4/reference/head.py create mode 100644 models/experimental/yolov4/reference/neck.py create mode 100644 models/experimental/yolov4/reference/resblock.py create mode 100644 models/experimental/yolov4/reference/yolov4.py create mode 100644 models/experimental/yolov4/ttnn/common.py create mode 100644 models/experimental/yolov4/ttnn/downsample1.py create mode 100644 models/experimental/yolov4/ttnn/downsample2.py 
# Remaining "create mode" entries of the patch summary:
# create mode 100644 models/experimental/yolov4/ttnn/downsample3.py
# create mode 100644 models/experimental/yolov4/ttnn/downsample4.py
# create mode 100644 models/experimental/yolov4/ttnn/downsample5.py
# create mode 100644 models/experimental/yolov4/ttnn/head.py
# create mode 100644 models/experimental/yolov4/ttnn/neck.py
# create mode 100644 models/experimental/yolov4/ttnn/yolov4.py

# ===== File: models/experimental/yolov4/reference/downsample1.py =====
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0


import torch
import torch.nn as nn


class DownSample1(nn.Module):
    """First downsampling stage of the YOLOv4 reference model.

    Maps a 3-channel image to a 64-channel feature map. The second
    convolution has stride 2, so height and width are halved. After that
    convolution the graph splits into a 1x1 bypass branch and a residual
    branch; both are concatenated on the channel axis and fused by a
    final 1x1 convolution.

    Submodules are registered in a fixed order (c1, b1, relu, c2, b2, ...);
    keep that order, the state_dict layout depends on it.
    """

    def __init__(self):
        super().__init__()
        # Stem: 3x3 conv, stride 1.
        self.c1 = nn.Conv2d(3, 32, 3, 1, 1, bias=False)
        self.b1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        # Stride-2 conv: the actual downsampling step.
        self.c2 = nn.Conv2d(32, 64, 3, 2, 1, bias=False)
        self.b2 = nn.BatchNorm2d(64)

        # Bypass branch (1x1).
        self.c3 = nn.Conv2d(64, 64, 1, 1, 0, bias=False)
        self.b3 = nn.BatchNorm2d(64)

        # Residual branch: 1x1 entry, 1x1 bottleneck, 3x3 expand, 1x1 exit.
        self.c4 = nn.Conv2d(64, 64, 1, 1, 0, bias=False)
        self.b4 = nn.BatchNorm2d(64)

        self.c5 = nn.Conv2d(64, 32, 1, 1, 0, bias=False)
        self.b5 = nn.BatchNorm2d(32)

        self.c6 = nn.Conv2d(32, 64, 3, 1, 1, bias=False)
        self.b6 = nn.BatchNorm2d(64)

        self.c7 = nn.Conv2d(64, 64, 1, 1, 0, bias=False)
        self.b7 = nn.BatchNorm2d(64)

        # Fuses the 128-channel concatenation back to 64 channels.
        self.c8 = nn.Conv2d(128, 64, 1, 1, 0, bias=False)
        self.b8 = nn.BatchNorm2d(64)

    def forward(self, input: torch.Tensor):
        """Run the stage on an NCHW tensor with 3 channels; returns 64 channels."""
        act = self.relu

        stem = act(self.b1(self.c1(input)))
        split = act(self.b2(self.c2(stem)))

        # Bypass branch, concatenated back in at the end.
        bypass = act(self.b3(self.c3(split)))

        # Residual branch with an additive skip around c5/c6.
        trunk = act(self.b4(self.c4(split)))
        squeezed = act(self.b5(self.c5(trunk)))
        expanded = act(self.b6(self.c6(squeezed)))
        expanded = expanded + trunk

        tail = act(self.b7(self.c7(expanded)))
        merged = torch.cat([tail, bypass], dim=1)

        return act(self.b8(self.c8(merged)))


# ===== File: models/experimental/yolov4/reference/downsample2.py =====
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0


import torch
import torch.nn as nn
from models.experimental.yolov4.reference.resblock import ResBlock


class Mish(torch.nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``."""

    def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))


class DownSample2(nn.Module):
    """Second downsampling stage: 64 -> 128 channels, spatial size halved.

    Note that the attribute named ``relu`` holds a :class:`Mish` instance
    in this stage; the name is kept for compatibility with existing code.
    """

    def __init__(self):
        super().__init__()
        self.c1 = nn.Conv2d(64, 128, 3, 2, 1, bias=False)
        self.b1 = nn.BatchNorm2d(128)
        self.relu = Mish()

        # Bypass branch.
        self.c2 = nn.Conv2d(128, 64, 1, 1, 0, bias=False)
        self.b2 = nn.BatchNorm2d(64)

        # Entry of the residual branch.
        self.c3 = nn.Conv2d(128, 64, 1, 1, 0, bias=False)
        self.b3 = nn.BatchNorm2d(64)

        self.res = ResBlock(ch=64, nblocks=2)

        self.c4 = nn.Conv2d(64, 64, 1, 1, 0, bias=False)
        self.b4 = nn.BatchNorm2d(64)

        # Fuses the concatenated branches.
        self.c5 = nn.Conv2d(128, 128, 1, 1, 0, bias=False)
        self.b5 = nn.BatchNorm2d(128)

    def forward(self, input: torch.Tensor):
        """Run the stage on a 64-channel NCHW tensor; returns 128 channels."""
        act = self.relu

        split = act(self.b1(self.c1(input)))
        bypass = act(self.b2(self.c2(split)))
        trunk = act(self.b3(self.c3(split)))

        trunk = self.res(trunk)
        trunk = act(self.b4(self.c4(trunk)))

        merged = torch.cat([trunk, bypass], dim=1)
        return act(self.b5(self.c5(merged)))


# Next file in the patch: models/experimental/yolov4/reference/downsample3.py
# ===== File: models/experimental/yolov4/reference/downsample3.py =====
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0


import torch
import torch.nn as nn
from models.experimental.yolov4.reference.resblock import ResBlock


class DownSample3(nn.Module):
    """Third downsampling stage: 128 -> 256 channels, spatial size halved.

    A stride-2 3x3 convolution is followed by a split into a 1x1 bypass
    branch and a branch that runs through an 8-unit :class:`ResBlock`;
    the two are concatenated on the channel axis and fused by a 1x1 conv.
    """

    def __init__(self):
        super().__init__()
        self.c1 = nn.Conv2d(128, 256, 3, 2, 1, bias=False)
        self.b1 = nn.BatchNorm2d(256)
        self.relu = nn.ReLU(inplace=True)

        # Bypass branch.
        self.c2 = nn.Conv2d(256, 128, 1, 1, bias=False)
        self.b2 = nn.BatchNorm2d(128)

        # Entry of the residual branch.
        self.c3 = nn.Conv2d(256, 128, 1, 1, bias=False)
        self.b3 = nn.BatchNorm2d(128)

        self.res = ResBlock(128, 8)

        self.c4 = nn.Conv2d(128, 128, 1, 1, bias=False)
        self.b4 = nn.BatchNorm2d(128)

        # Fuses the concatenated branches.
        self.c5 = nn.Conv2d(256, 256, 1, 1, bias=False)
        self.b5 = nn.BatchNorm2d(256)

    def forward(self, input: torch.Tensor):
        """Run the stage on a 128-channel NCHW tensor; returns 256 channels."""
        act = self.relu

        split = act(self.b1(self.c1(input)))
        bypass = act(self.b2(self.c2(split)))
        trunk = act(self.b3(self.c3(split)))

        trunk = self.res(trunk)
        trunk = act(self.b4(self.c4(trunk)))

        merged = torch.cat([trunk, bypass], dim=1)
        return act(self.b5(self.c5(merged)))


# Next file in the patch: models/experimental/yolov4/reference/downsample4.py
# ===== File: models/experimental/yolov4/reference/downsample4.py =====
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0


import torch
import torch.nn as nn
from models.experimental.yolov4.reference.resblock import ResBlock


class DownSample4(nn.Module):
    """Fourth downsampling stage: 256 -> 512 channels, spatial size halved.

    Same topology as the previous stage: stride-2 conv, split into a 1x1
    bypass and an 8-unit :class:`ResBlock` branch, channel concat, 1x1 fuse.
    """

    def __init__(self):
        super().__init__()
        self.c1 = nn.Conv2d(256, 512, 3, 2, 1, bias=False)
        self.b1 = nn.BatchNorm2d(512)
        self.relu = nn.ReLU(inplace=True)

        # Bypass branch.
        self.c2 = nn.Conv2d(512, 256, 1, 1, 0, bias=False)
        self.b2 = nn.BatchNorm2d(256)

        # Entry of the residual branch.
        self.c3 = nn.Conv2d(512, 256, 1, 1, 0, bias=False)
        self.b3 = nn.BatchNorm2d(256)

        self.res = ResBlock(256, 8)

        self.c4 = nn.Conv2d(256, 256, 1, 1, 0, bias=False)
        self.b4 = nn.BatchNorm2d(256)

        # Fuses the concatenated branches.
        self.c5 = nn.Conv2d(512, 512, 1, 1, 0, bias=False)
        self.b5 = nn.BatchNorm2d(512)

    def forward(self, input: torch.Tensor):
        """Run the stage on a 256-channel NCHW tensor; returns 512 channels."""
        act = self.relu

        split = act(self.b1(self.c1(input)))
        bypass = act(self.b2(self.c2(split)))
        trunk = act(self.b3(self.c3(split)))

        # Residual stack.
        trunk = self.res(trunk)
        trunk = act(self.b4(self.c4(trunk)))

        merged = torch.cat([trunk, bypass], dim=1)
        return act(self.b5(self.c5(merged)))


# Next file in the patch: models/experimental/yolov4/reference/downsample5.py
# ===== File: models/experimental/yolov4/reference/downsample5.py =====
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0


import torch
import torch.nn as nn
from models.experimental.yolov4.reference.resblock import ResBlock


class DownSample5(nn.Module):
    """Fifth downsampling stage: 512 -> 1024 channels, spatial size halved.

    Stride-2 conv, split into a 1x1 bypass branch and a 4-unit
    :class:`ResBlock` branch, channel concat, 1x1 fuse.
    """

    def __init__(self):
        super().__init__()
        self.c1 = nn.Conv2d(512, 1024, 3, 2, 1, bias=False)
        self.b1 = nn.BatchNorm2d(1024)
        # One shared activation module; it is stateless, so a single
        # instance serves every conv/bn pair (the original re-assigned this
        # attribute two more times further down, which had no effect).
        self.relu = nn.ReLU(inplace=True)

        # Bypass branch.
        self.c2 = nn.Conv2d(1024, 512, 1, 1, bias=False)
        self.b2 = nn.BatchNorm2d(512)

        # Entry of the residual branch.
        self.c3 = nn.Conv2d(1024, 512, 1, 1, bias=False)
        self.b3 = nn.BatchNorm2d(512)

        self.res = ResBlock(512, 4)

        self.c4 = nn.Conv2d(512, 512, 1, 1, bias=False)
        self.b4 = nn.BatchNorm2d(512)

        # Fuses the concatenated branches.
        self.c5 = nn.Conv2d(1024, 1024, 1, 1, bias=False)
        self.b5 = nn.BatchNorm2d(1024)

    def forward(self, input: torch.Tensor):
        """Run the stage on a 512-channel NCHW tensor; returns 1024 channels."""
        x1 = self.c1(input)
        x1_b = self.b1(x1)
        x1_m = self.relu(x1_b)

        x2 = self.c2(x1_m)
        x2_b = self.b2(x2)
        x2_m = self.relu(x2_b)

        x3 = self.c3(x1_m)
        x3_b = self.b3(x3)
        x3_m = self.relu(x3_b)

        # Residual stack.
        r = self.res(x3_m)

        x4 = self.c4(r)
        x4_b = self.b4(x4)
        x4_m = self.relu(x4_b)

        # Merge the residual branch with the bypass branch.
        x4_m = torch.cat([x4_m, x2_m], dim=1)

        x5 = self.c5(x4_m)
        x5_b = self.b5(x5)
        x5_m = self.relu(x5_b)

        return x5_m


# Next file in the patch: models/experimental/yolov4/reference/head.py
# ===== File: models/experimental/yolov4/reference/head.py =====
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn as nn


class Head(nn.Module):
    """Detection head of the YOLOv4 reference model.

    ``forward`` takes a sequence of three NCHW feature maps and returns
    three 255-channel prediction maps, one per scale. Channel counts
    expected by the layers: ``inputs[0]`` has 128 channels, ``inputs[2]``
    has 256 channels at half the spatial size of ``inputs[0]``, and
    ``inputs[1]`` has 512 channels at a quarter of it.

    Submodules are registered in a fixed order; keep it, the state_dict
    layout depends on it.
    """

    def __init__(self):
        super().__init__()
        # Conv2d positional args below: in_chan, out_chan, kernel, stride, padding.
        output_ch = 255

        # --- scale 1 (finest): 3x3 conv + 1x1 prediction conv ---
        self.c1 = nn.Conv2d(128, 256, 3, 1, 1, bias=False)
        self.b1 = nn.BatchNorm2d(256)
        self.relu = nn.LeakyReLU(0.1, inplace=True)

        self.c2 = nn.Conv2d(256, output_ch, 1, 1, 0, bias=True)

        # --- scale 2: stride-2 conv from the scale-1 input ---
        self.c3 = nn.Conv2d(128, 256, 3, 2, 1, bias=False)
        self.b3 = nn.BatchNorm2d(256)

        # Alternating 1x1 / 3x3 stack applied after concat with inputs[2].
        self.c4 = nn.Conv2d(512, 256, 1, 1, 0, bias=False)
        self.b4 = nn.BatchNorm2d(256)

        self.c5 = nn.Conv2d(256, 512, 3, 1, 1, bias=False)
        self.b5 = nn.BatchNorm2d(512)

        self.c6 = nn.Conv2d(512, 256, 1, 1, 0, bias=False)
        self.b6 = nn.BatchNorm2d(256)

        self.c7 = nn.Conv2d(256, 512, 3, 1, 1, bias=False)
        self.b7 = nn.BatchNorm2d(512)

        self.c8 = nn.Conv2d(512, 256, 1, 1, 0, bias=False)
        self.b8 = nn.BatchNorm2d(256)

        self.c9 = nn.Conv2d(256, 512, 3, 1, 1, bias=False)
        self.b9 = nn.BatchNorm2d(512)

        self.c10 = nn.Conv2d(512, output_ch, 1, 1, 0, bias=True)

        # --- scale 3 (coarsest): stride-2 conv from the scale-2 features ---
        self.c11 = nn.Conv2d(256, 512, 3, 2, 1, bias=False)
        self.b11 = nn.BatchNorm2d(512)

        # Alternating 1x1 / 3x3 stack applied after concat with inputs[1].
        self.c12 = nn.Conv2d(1024, 512, 1, 1, 0, bias=False)
        self.b12 = nn.BatchNorm2d(512)

        self.c13 = nn.Conv2d(512, 1024, 3, 1, 1, bias=False)
        self.b13 = nn.BatchNorm2d(1024)

        self.c14 = nn.Conv2d(1024, 512, 1, 1, 0, bias=False)
        self.b14 = nn.BatchNorm2d(512)

        self.c15 = nn.Conv2d(512, 1024, 3, 1, 1, bias=False)
        self.b15 = nn.BatchNorm2d(1024)

        self.c16 = nn.Conv2d(1024, 512, 1, 1, 0, bias=False)
        self.b16 = nn.BatchNorm2d(512)

        self.c17 = nn.Conv2d(512, 1024, 3, 1, 1, bias=False)
        self.b17 = nn.BatchNorm2d(1024)

        self.c18 = nn.Conv2d(1024, output_ch, 1, 1, 0, bias=True)

    def _conv_bn_act(self, index, x):
        """Apply ``c<index>`` -> ``b<index>`` -> leaky ReLU to ``x``."""
        conv = getattr(self, f"c{index}")
        norm = getattr(self, f"b{index}")
        return self.relu(norm(conv(x)))

    def forward(self, inputs):
        """Return the three prediction maps ``(scale1, scale2, scale3)``."""
        # Scale 1: prediction straight from inputs[0].
        feat = self._conv_bn_act(1, inputs[0])
        pred_fine = self.c2(feat)

        # Scale 2: downsample inputs[0], merge with inputs[2].
        feat = self._conv_bn_act(3, inputs[0])
        feat = torch.cat([feat, inputs[2]], dim=1)
        for index in range(4, 9):
            feat = self._conv_bn_act(index, feat)
        # `feat` (output of c8) feeds both the scale-2 prediction and scale 3.
        mid_trunk = feat
        pred_mid = self.c10(self._conv_bn_act(9, mid_trunk))

        # Scale 3: downsample the scale-2 trunk, merge with inputs[1].
        feat = self._conv_bn_act(11, mid_trunk)
        feat = torch.cat([feat, inputs[1]], dim=1)
        for index in range(12, 18):
            feat = self._conv_bn_act(index, feat)
        pred_coarse = self.c18(feat)

        return pred_fine, pred_mid, pred_coarse


# Next file in the patch: models/experimental/yolov4/reference/neck.py
# ===== File: models/experimental/yolov4/reference/neck.py =====
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0


import torch
import torch.nn as nn


class Neck(nn.Module):
    """Neck of the YOLOv4 reference model.

    ``forward`` takes ``[d5, d4, d3]``: the outputs of downsample stages 5,
    4 and 3, with 1024, 512 and 256 channels respectively, each twice the
    spatial size of the previous one. It applies three stride-1 max-pools
    to the coarsest map, concatenates them with it, and then performs two
    rounds of "upsample and merge" with the finer maps.

    Submodules are registered in a fixed order; keep it, the state_dict
    layout depends on it.
    """

    def __init__(self):
        super().__init__()
        # Three conv/bn/leaky-relu blocks on the coarsest input.
        self.c1 = nn.Conv2d(1024, 512, 1, 1, 0, bias=False)
        self.b1 = nn.BatchNorm2d(512)
        self.relu = nn.LeakyReLU(0.1, inplace=True)

        self.c2 = nn.Conv2d(512, 1024, 3, 1, 1, bias=False)
        self.b2 = nn.BatchNorm2d(1024)

        self.c3 = nn.Conv2d(1024, 512, 1, 1, 0, bias=False)
        self.b3 = nn.BatchNorm2d(512)

        # Three stride-1 max-pools with growing kernels (size-preserving).
        self.p1 = nn.MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1, ceil_mode=False)
        self.p2 = nn.MaxPool2d(kernel_size=9, stride=1, padding=4, dilation=1, ceil_mode=False)
        self.p3 = nn.MaxPool2d(kernel_size=13, stride=1, padding=6, dilation=1, ceil_mode=False)

        # Four blocks after the pooling concat (4 * 512 = 2048 channels in).
        self.c4 = nn.Conv2d(2048, 512, 1, 1, 0, bias=False)
        self.b4 = nn.BatchNorm2d(512)

        self.c5 = nn.Conv2d(512, 1024, 3, 1, 1, bias=False)
        self.b5 = nn.BatchNorm2d(1024)

        self.c6 = nn.Conv2d(1024, 512, 1, 1, 0, bias=False)
        self.b6 = nn.BatchNorm2d(512)

        self.c7 = nn.Conv2d(512, 256, 1, 1, 0, bias=False)
        self.b7 = nn.BatchNorm2d(256)

        # Nearest-neighbour 2x upsampling, used twice in forward().
        self.u = nn.Upsample(scale_factor=(2, 2), mode="nearest")

        # Lateral conv on d4 and the six blocks after the first merge.
        self.c7_2 = nn.Conv2d(512, 256, 1, 1, 0, bias=False)
        self.b7_2 = nn.BatchNorm2d(256)

        self.c7_3 = nn.Conv2d(512, 256, 1, 1, 0, bias=False)
        self.b7_3 = nn.BatchNorm2d(256)

        self.c8 = nn.Conv2d(256, 512, 3, 1, 1, bias=False)
        self.b8 = nn.BatchNorm2d(512)

        self.c7_4 = nn.Conv2d(512, 256, 1, 1, 0, bias=False)
        self.b7_4 = nn.BatchNorm2d(256)

        self.c8_2 = nn.Conv2d(256, 512, 3, 1, 1, bias=False)
        self.b8_2 = nn.BatchNorm2d(512)

        self.c7_5 = nn.Conv2d(512, 256, 1, 1, 0, bias=False)
        self.b7_5 = nn.BatchNorm2d(256)

        self.c9 = nn.Conv2d(256, 128, 1, 1, 0, bias=False)
        self.b9 = nn.BatchNorm2d(128)

        # Lateral conv on d3 and the five blocks after the second merge.
        self.c9_2 = nn.Conv2d(256, 128, 1, 1, 0, bias=False)
        self.b9_2 = nn.BatchNorm2d(128)
        self.c9_3 = nn.Conv2d(256, 128, 1, 1, 0, bias=False)
        self.b9_3 = nn.BatchNorm2d(128)

        self.c10 = nn.Conv2d(128, 256, 3, 1, 1, bias=False)
        self.b10 = nn.BatchNorm2d(256)

        self.c9_4 = nn.Conv2d(256, 128, 1, 1, 0, bias=False)
        self.b9_4 = nn.BatchNorm2d(128)
        self.c10_2 = nn.Conv2d(128, 256, 3, 1, 1, bias=False)
        self.b10_2 = nn.BatchNorm2d(256)
        self.c9_5 = nn.Conv2d(256, 128, 1, 1, 0, bias=False)
        self.b9_5 = nn.BatchNorm2d(128)

    def _block(self, conv, norm, x):
        """Apply conv -> batch norm -> leaky ReLU."""
        return self.relu(norm(conv(x)))

    def forward(self, inputs):
        """Return ``(fine, coarse, mid)`` feature maps with 128, 512 and 256 channels."""
        block = self._block

        # Three blocks on the downsample5 output.
        x = block(self.c1, self.b1, inputs[0])
        x = block(self.c2, self.b2, x)
        x = block(self.c3, self.b3, x)

        # Concatenate the three pooled maps (largest kernel first) with x.
        x = torch.cat([self.p3(x), self.p2(x), self.p1(x), x], dim=1)

        # Four back-to-back blocks; the third one is the coarse output.
        x = block(self.c4, self.b4, x)
        x = block(self.c5, self.b5, x)
        coarse = block(self.c6, self.b6, x)
        x = block(self.c7, self.b7, coarse)

        # First merge: upsample and concatenate with the downsample4 output
        # (for a 320x320 network input that tensor is [1, 512, 20, 20]).
        upsampled = self.u(x)
        lateral = block(self.c7_2, self.b7_2, inputs[1])
        x = torch.cat([lateral, upsampled], dim=1)

        # Six back-to-back blocks; the fifth one is the mid output.
        x = block(self.c7_3, self.b7_3, x)
        x = block(self.c8, self.b8, x)
        x = block(self.c7_4, self.b7_4, x)
        x = block(self.c8_2, self.b8_2, x)
        mid = block(self.c7_5, self.b7_5, x)
        x = block(self.c9, self.b9, mid)

        # Second merge: upsample and concatenate with the downsample3 output.
        upsampled = self.u(x)
        lateral = block(self.c9_2, self.b9_2, inputs[2])
        x = torch.cat([lateral, upsampled], dim=1)

        # Five back-to-back blocks; the last one is the fine output.
        x = block(self.c9_3, self.b9_3, x)
        x = block(self.c10, self.b10, x)
        x = block(self.c9_4, self.b9_4, x)
        x = block(self.c10_2, self.b10_2, x)
        fine = block(self.c9_5, self.b9_5, x)

        return fine, coarse, mid


# Next file in the patch: models/experimental/yolov4/reference/resblock.py
# ===== File: models/experimental/yolov4/reference/resblock.py =====
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch.nn as nn


class ResBlock(nn.Module):
    """Stack of ``nblocks`` residual units.

    Each unit is 1x1 conv -> BN -> ReLU -> 3x3 conv -> BN -> ReLU, all with
    ``ch`` channels. With ``shortcut=True`` the unit's output is added to
    its input; otherwise it replaces it.
    """

    def __init__(self, ch, nblocks=1, shortcut=True):
        super().__init__()
        self.shortcut = shortcut
        self.module_list = nn.ModuleList()
        for _ in range(nblocks):
            conv1 = nn.Conv2d(ch, ch, 1, 1, 0, bias=False)
            bn1 = nn.BatchNorm2d(ch)
            relu = nn.ReLU(inplace=True)
            conv2 = nn.Conv2d(ch, ch, 3, 1, 1, bias=False)
            bn2 = nn.BatchNorm2d(ch)
            # The same ReLU instance is listed twice on purpose: it keeps
            # the parameterised layers at indices 0, 1, 3 and 4.
            resblock_one = nn.ModuleList([conv1, bn1, relu, conv2, bn2, relu])
            self.module_list.append(resblock_one)

    def forward(self, x):
        """Apply every residual unit in order and return the result."""
        for module in self.module_list:
            h = x
            for res in module:
                h = res(h)
            x = x + h if self.shortcut else h
        return x


# ===== File: models/experimental/yolov4/reference/yolov4.py =====
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0


from models.experimental.yolov4.reference.downsample1 import DownSample1
from models.experimental.yolov4.reference.downsample2 import DownSample2
from models.experimental.yolov4.reference.downsample3 import DownSample3
from models.experimental.yolov4.reference.downsample4 import DownSample4
from models.experimental.yolov4.reference.downsample5 import DownSample5
from models.experimental.yolov4.reference.neck import Neck
from models.experimental.yolov4.reference.head import Head

import torch
import torch.nn as nn


class Yolov4(nn.Module):
    """Full reference model: five downsample stages, neck, then head."""

    def __init__(self):
        super().__init__()
        self.downsample1 = DownSample1()
        self.downsample2 = DownSample2()
        self.downsample3 = DownSample3()
        self.downsample4 = DownSample4()
        self.downsample5 = DownSample5()
        self.neck = Neck()
        self.head = Head()

    def forward(self, input: torch.Tensor):
        """Return the head's three prediction maps for ``input``."""
        d1 = self.downsample1(input)
        d2 = self.downsample2(d1)
        d3 = self.downsample3(d2)
        d4 = self.downsample4(d3)
        d5 = self.downsample5(d4)
        # The neck consumes the three deepest stages and returns three maps
        # that are forwarded to the head in the same order.
        x20, x13, x6 = self.neck([d5, d4, d3])
        x4, x5, x6 = self.head([x20, x13, x6])

        return x4, x5, x6


# ===== File: models/experimental/yolov4/ttnn/common.py =====
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import ttnn


def fold_bn_to_conv_weights_bias(model, path, eps=1e-5):
    """Fold a BatchNorm layer into the preceding convolution.

    Args:
        model: Mapping from parameter name to torch tensor (a state dict).
            The conv weight is read from ``path + ".conv.0.weight"`` and the
            BatchNorm tensors from ``path + ".conv.1.*"``.
        path: Key prefix of the conv/bn pair inside ``model``.
        eps: Value added to the running variance before the square root.
            The default equals ``torch.nn.BatchNorm2d``'s default, so the
            folded conv matches what the BatchNorm layer computes. The
            original code omitted this term, which deviates from the
            reference and divides by zero for a zero running variance.

    Returns:
        ``(weight, bias)`` as ttnn tensors; the bias is reshaped to
        ``(1, 1, 1, out_channels)``.
    """
    # Per-output-channel BN tensors, broadcastable against (O, I, kH, kW).
    bn_weight = model[path + ".conv.1.weight"].unsqueeze(1).unsqueeze(1).unsqueeze(1)
    bn_running_var = model[path + ".conv.1.running_var"].unsqueeze(1).unsqueeze(1).unsqueeze(1)
    bn_running_mean = model[path + ".conv.1.running_mean"].unsqueeze(1).unsqueeze(1).unsqueeze(1)
    bn_bias = model[path + ".conv.1.bias"].unsqueeze(1).unsqueeze(1).unsqueeze(1)

    bn_std = torch.sqrt(bn_running_var + eps)

    weight = model[path + ".conv.0.weight"]
    weight = (weight / bn_std) * bn_weight

    # The conv has no bias of its own, so the folded bias is the BN shift.
    bias = -(bn_weight) * (bn_running_mean / bn_std) + bn_bias
    bias = bias.reshape(1, 1, 1, -1)

    return (
        ttnn.from_torch(
            weight,
        ),
        ttnn.from_torch(bias),
    )


class Conv:
    """One convolution of the ttnn YOLOv4 model, with host-side weights.

    Args:
        model: State dict (name -> torch tensor) holding the weights.
        path: Key prefix of this layer inside ``model``.
        input_params: ``[batch, height, width, channels]`` of the input.
        conv_params: ``(stride_h, stride_w, pad_h, pad_w)``.
        act_block_h: Optional ``act_block_h_override`` for the conv config.
        reshard: Forwarded as ``reshard_if_not_optimal``.
        deallocate: Forwarded as ``deallocate_activation``.
        height_sharding: Forwarded as ``height_sharding``.
        activation: Activation name fused into the convolution.
        fused_op: If true, BatchNorm is folded into the weights; otherwise
            the raw ``.conv.0.weight`` / ``.conv.0.bias`` tensors are used.
    """

    def __init__(
        self,
        model,
        path,
        input_params,
        conv_params,
        *,
        act_block_h=None,
        reshard=False,
        deallocate=True,
        height_sharding=True,
        activation="relu",
        fused_op=True,
    ) -> None:
        if fused_op:
            self.weights, self.bias = fold_bn_to_conv_weights_bias(model, path)
        else:
            weight = model[path + ".conv.0.weight"]
            bias = model[path + ".conv.0.bias"]
            self.weights = ttnn.from_torch(weight)
            bias = bias.reshape(1, 1, 1, -1)
            self.bias = ttnn.from_torch(bias)
        self.input_params = input_params
        # Weight layout is (out_channels, in_channels, kH, kW).
        self.kernel_size = (self.weights.shape[2], self.weights.shape[3])
        self.conv_params = conv_params
        self.out_channels = self.weights.shape[0]
        self.act_block_h = act_block_h
        self.reshard = reshard
        self.height_sharding = height_sharding
        self.deallocate = deallocate
        self.activation = activation

    def __str__(self) -> str:
        return f"Conv: {self.weights.shape} {self.bias.shape} {self.kernel_size}"

    def __call__(self, device, input_tensor):
        """Run the convolution on ``device`` and return the output tensor."""
        conv_config = ttnn.Conv2dConfig(
            dtype=ttnn.bfloat16,
            weights_dtype=ttnn.bfloat8_b,
            math_fidelity=ttnn.MathFidelity.LoFi,
            activation=self.activation,
            height_sharding=self.height_sharding,
            math_approx_mode_enabled=True,
            fp32_dest_acc_enabled=False,
            packer_l1_accum_enabled=False,
            input_channels_alignment=16 if self.input_params[3] < 16 else 32,
            transpose_shards=False,
            reshard_if_not_optimal=self.reshard,
            deallocate_activation=self.deallocate,
            reallocate_halo_output=False,
        )
        if self.act_block_h is not None:
            conv_config.act_block_h_override = self.act_block_h

        # ttnn.conv2d hands back weight and bias tensors as well; they are
        # stored so later calls pass those tensors instead of the originals.
        [output_tensor, _out_height, _out_width, self.weights, self.bias] = ttnn.conv2d(
            input_tensor=input_tensor,
            weight_tensor=self.weights,
            bias_tensor=self.bias,
            in_channels=self.input_params[3],
            out_channels=self.out_channels,
            device=device,
            kernel_size=self.kernel_size,
            stride=(self.conv_params[0], self.conv_params[1]),
            padding=(self.conv_params[2], self.conv_params[3]),
            batch_size=self.input_params[0],
            input_height=self.input_params[1],
            input_width=self.input_params[2],
            conv_config=conv_config,
        )
        return output_tensor


# Next file in the patch: models/experimental/yolov4/ttnn/downsample1.py
# ===== File: models/experimental/yolov4/ttnn/downsample1.py =====
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import ttnn
from models.experimental.yolov4.ttnn.common import Conv
from models.experimental.yolov4.reference.downsample1 import DownSample1
from tests.ttnn.utils_for_testing import assert_with_pcc
import pytest
import time


class Down1:
    """ttnn implementation of the first downsampling stage.

    Args:
        model: Either a path to a checkpoint (loaded with ``torch.load``)
            or an object exposing the loaded state dict as ``torch_model``.
    """

    def __init__(self, model) -> None:
        if isinstance(model, str):
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            torch_model = torch.load(model)
        else:
            torch_model = model.torch_model
        self.torch_model = torch_model
        # Conv(state_dict, key prefix, [batch, height, width, channels],
        #      (stride_h, stride_w, pad_h, pad_w), ...)
        self.conv1 = Conv(torch_model, "down1.conv1", [1, 320, 320, 3], (1, 1, 1, 1), act_block_h=128)
        self.conv2 = Conv(torch_model, "down1.conv2", [1, 320, 320, 32], (2, 2, 1, 1), reshard=True)
        self.conv3 = Conv(torch_model, "down1.conv3", [1, 160, 160, 64], (1, 1, 0, 0), deallocate=False)
        self.conv4 = Conv(torch_model, "down1.conv4", [1, 160, 160, 64], (1, 1, 0, 0), reshard=True)
        self.conv5 = Conv(torch_model, "down1.conv5", [1, 160, 160, 64], (1, 1, 0, 0), deallocate=False)
        self.conv6 = Conv(torch_model, "down1.conv6", [1, 160, 160, 32], (1, 1, 1, 1))
        self.conv7 = Conv(torch_model, "down1.conv7", [1, 160, 160, 64], (1, 1, 0, 0))
        self.conv8 = Conv(torch_model, "down1.conv8", [1, 160, 160, 128], (1, 1, 0, 0))
        self.convs = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8]

    def __call__(self, device, input_tensor):
        """Run the stage on ``device`` and return the conv8 output."""
        output_tensor = self.conv1(device, input_tensor)
        output_tensor_split = self.conv2(device, output_tensor)

        # Bypass branch; conv3 is built with deallocate=False because
        # conv4 reads the same input next.
        output_tensor_left = self.conv3(device, output_tensor_split)

        # Residual branch with an additive skip around conv5/conv6.
        res_block_split = self.conv4(device, output_tensor_split)
        output_tensor = self.conv5(device, res_block_split)
        output_tensor = self.conv6(device, output_tensor)
        output_tensor = res_block_split + output_tensor

        ttnn.deallocate(res_block_split)
        output_tensor = self.conv7(device, output_tensor)

        # Both branches are converted to interleaved L1 tensors before the
        # channel concat (dim=3 is the channel axis of the NHWC layout).
        output_tensor = ttnn.experimental.tensor.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG)
        output_tensor_left = ttnn.experimental.tensor.sharded_to_interleaved(output_tensor_left, ttnn.L1_MEMORY_CONFIG)
        output_tensor = ttnn.concat([output_tensor, output_tensor_left], dim=3, memory_config=ttnn.L1_MEMORY_CONFIG)
        ttnn.deallocate(output_tensor_left)

        output_tensor = self.conv8(device, output_tensor)
        return output_tensor

    def __str__(self) -> str:
        """Return one line per convolution, numbered from 1."""
        return "".join(f"{index} {conv} \n" for index, conv in enumerate(self.convs, start=1))


@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
def test_down1(device, use_program_cache):
    """Compare Down1 on device with the torch DownSample1 reference (PCC >= 0.99)."""
    ttnn_model = Down1("tests/ttnn/integration_tests/yolov4/yolov4.pth")

    # The ttnn model takes NHWC input, the torch reference takes NCHW.
    torch_input = torch.randn((1, 320, 320, 3), dtype=torch.bfloat16)
    ttnn_input = ttnn.from_torch(torch_input, dtype=ttnn.bfloat16)
    torch_input = torch_input.permute(0, 3, 1, 2).float()
    torch_model = DownSample1()

    for layer in torch_model.children():
        print(layer)

    # Checkpoint entries under "down1." are matched to the reference
    # model's keys purely by position, so both orderings must agree.
    ds_state_dict = {k: v for k, v in ttnn_model.torch_model.items() if (k.startswith("down1."))}

    keys = list(torch_model.state_dict())
    values = list(ds_state_dict.values())
    print(keys)
    new_state_dict = dict(zip(keys, values))

    torch_model.load_state_dict(new_state_dict)
    torch_model.eval()

    # First call is untimed; the loop below measures repeated calls.
    result_ttnn = ttnn_model(device, ttnn_input)

    start_time = time.time()
    for _ in range(100):
        result_ttnn = ttnn_model(device, ttnn_input)
    print(f"Time taken: {time.time() - start_time}")
    result = ttnn.to_torch(result_ttnn)

    ref = torch_model(torch_input)
    ref = ref.permute(0, 2, 3, 1)
    result = result.reshape(1, 160, 160, 64)
    assert_with_pcc(result, ref, 0.99)


# Next file in the patch: models/experimental/yolov4/ttnn/downsample2.py
# ===== File: models/experimental/yolov4/ttnn/downsample2.py =====
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import ttnn
from models.experimental.yolov4.ttnn.common import Conv
from models.experimental.yolov4.reference.downsample2 import DownSample2
from tests.ttnn.utils_for_testing import assert_with_pcc
import pytest
import time


class Down2:
    """ttnn implementation of the second downsampling stage.

    Args:
        model: Either a path to a checkpoint (loaded with ``torch.load``)
            or an object exposing the loaded state dict as ``torch_model``.
    """

    def __init__(self, model) -> None:
        if isinstance(model, str):
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            torch_model = torch.load(model)
        else:
            torch_model = model.torch_model
        self.torch_model = torch_model
        # Conv(state_dict, key prefix, [batch, height, width, channels],
        #      (stride_h, stride_w, pad_h, pad_w), ...)
        self.conv1 = Conv(torch_model, "down2.conv1", [1, 160, 160, 64], (2, 2, 1, 1), reshard=True)
        self.conv2 = Conv(torch_model, "down2.conv2", [1, 80, 80, 128], (1, 1, 0, 0), reshard=True, deallocate=False)
        self.conv3 = Conv(torch_model, "down2.conv3", [1, 80, 80, 128], (1, 1, 0, 0))
        self.conv4 = Conv(torch_model, "down2.conv4", [1, 80, 80, 64], (1, 1, 0, 0), reshard=True, deallocate=False)

        self.res1_conv1 = Conv(
            torch_model, "down2.resblock.module_list.0.0", [1, 80, 80, 64], (1, 1, 0, 0), deallocate=False
        )
        self.res1_conv2 = Conv(torch_model, "down2.resblock.module_list.0.1", [1, 80, 80, 64], (1, 1, 1, 1))
        self.res2_conv1 = Conv(
            torch_model, "down2.resblock.module_list.1.0", [1, 80, 80, 64], (1, 1, 0, 0), deallocate=False
        )
        self.res2_conv2 = Conv(torch_model, "down2.resblock.module_list.1.1", [1, 80, 80, 64], (1, 1, 1, 1))

        self.conv5 = Conv(torch_model, "down2.conv5", [1, 80, 80, 128], (1, 1, 0, 0))

        # __str__ iterates this list; it was missing, so str(Down2(...))
        # raised AttributeError. Listed in execution order.
        self.convs = [
            self.conv1,
            self.conv2,
            self.conv3,
            self.res1_conv1,
            self.res1_conv2,
            self.res2_conv1,
            self.res2_conv2,
            self.conv4,
            self.conv5,
        ]

    def __call__(self, device, input_tensor):
        """Run the stage on ``device`` and return the conv5 output."""
        output_tensor_split = self.conv1(device, input_tensor)
        # Bypass branch; conv2 is built with deallocate=False because
        # conv3 reads the same input next.
        output_tensor_left = self.conv2(device, output_tensor_split)

        res1_split = self.conv3(device, output_tensor_split)

        # Residual unit 1.
        output_tensor = self.res1_conv1(device, res1_split)
        output_tensor = self.res1_conv2(device, output_tensor)
        res2_split = res1_split + output_tensor
        ttnn.deallocate(res1_split)

        # Residual unit 2.
        output_tensor = self.res2_conv1(device, res2_split)
        output_tensor = self.res2_conv2(device, output_tensor)
        output_tensor = res2_split + output_tensor

        ttnn.deallocate(res2_split)

        output_tensor = self.conv4(device, output_tensor)

        # Both branches are converted to interleaved L1 tensors before the
        # channel concat (dim=3 is the channel axis of the NHWC layout).
        output_tensor = ttnn.experimental.tensor.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG)
        output_tensor_left = ttnn.experimental.tensor.sharded_to_interleaved(output_tensor_left, ttnn.L1_MEMORY_CONFIG)
        output_tensor = ttnn.concat([output_tensor, output_tensor_left], dim=3, memory_config=ttnn.L1_MEMORY_CONFIG)
        ttnn.deallocate(output_tensor_left)

        output_tensor = self.conv5(device, output_tensor)
        return output_tensor

    def __str__(self) -> str:
        """Return one line per convolution, numbered from 1."""
        return "".join(f"{index} {conv} \n" for index, conv in enumerate(self.convs, start=1))


@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
def test_down2(device):
    """Compare Down2 on device with the torch DownSample2 reference (PCC >= 0.97)."""
    ttnn_model = Down2("tests/ttnn/integration_tests/yolov4/yolov4.pth")

    # The ttnn model takes NHWC input, the torch reference takes NCHW.
    torch_input = torch.randn((1, 160, 160, 64), dtype=torch.bfloat16)
    ttnn_input = ttnn.from_torch(torch_input, dtype=ttnn.bfloat16)
    torch_input = torch_input.permute(0, 3, 1, 2).float()
    torch_model = DownSample2()

    # Checkpoint entries under "down2." are matched to the reference
    # model's keys purely by position, so both orderings must agree.
    ds_state_dict = {k: v for k, v in ttnn_model.torch_model.items() if (k.startswith("down2."))}

    keys = list(torch_model.state_dict())
    values = list(ds_state_dict.values())
    new_state_dict = dict(zip(keys, values))

    torch_model.load_state_dict(new_state_dict)
    torch_model.eval()

    # First call is untimed; the loop below measures repeated calls.
    result_ttnn = ttnn_model(device, ttnn_input)

    start_time = time.time()
    for _ in range(2):
        result_ttnn = ttnn_model(device, ttnn_input)
    print(f"Time taken: {time.time() - start_time}")
    result = ttnn.to_torch(result_ttnn)
    ref = torch_model(torch_input)
    ref = ref.permute(0, 2, 3, 1)
    result = result.reshape(ref.shape)
    assert_with_pcc(result, ref, 0.97)
# --git a/models/experimental/yolov4/ttnn/downsample3.py b/models/experimental/yolov4/ttnn/downsample3.py
# new file mode 100644
# index 00000000000..4c5a8ef4be0
# --- /dev/null
# +++ b/models/experimental/yolov4/ttnn/downsample3.py
# @@ -0,0 +1,165 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import ttnn
from models.experimental.yolov4.ttnn.common import Conv
from models.experimental.yolov4.reference.downsample3 import DownSample3
from tests.ttnn.utils_for_testing import assert_with_pcc
import pytest
import time


class Down3:
    """TTNN implementation of the YOLOv4 ``down3`` stage.

    ``conv1`` feeds two branches: ``conv2`` on one side, and ``conv3`` followed
    by eight residual blocks and ``conv4`` on the other. Both branch outputs
    are converted to interleaved L1 tensors, concatenated on ``dim=3`` and
    passed through ``conv5``.
    """

    def __init__(self, model) -> None:
        """Create the ``Conv`` wrappers used by this stage.

        Args:
            model: Path of a checkpoint handed to ``torch.load``, or an object
                whose ``torch_model`` attribute already holds the loaded
                checkpoint.
        """
        # isinstance (instead of an exact type comparison) also accepts str subclasses.
        if isinstance(model, str):
            torch_model = torch.load(model)
        else:
            torch_model = model.torch_model
        self.torch_model = torch_model

        # NOTE(review): the 4-tuple is presumably (stride_h, stride_w, pad_h, pad_w);
        # confirm against common.Conv.
        self.conv1 = Conv(torch_model, "down3.conv1", [1, 80, 80, 128], (2, 2, 1, 1), reshard=True)
        self.conv2 = Conv(torch_model, "down3.conv2", [1, 40, 40, 256], (1, 1, 0, 0), reshard=True, deallocate=False)
        self.conv3 = Conv(torch_model, "down3.conv3", [1, 40, 40, 256], (1, 1, 0, 0))

        # Residual blocks: each is a pair of convolutions whose result is added
        # back onto the block input in __call__.
        self._res_blocks = []
        for block_idx in range(8):
            prefix = f"down3.resblock.module_list.{block_idx}"
            first = Conv(torch_model, f"{prefix}.0", [1, 40, 40, 128], (1, 1, 0, 0), deallocate=False)
            second = Conv(torch_model, f"{prefix}.1", [1, 40, 40, 128], (1, 1, 1, 1))
            # Keep the resN_conv1 / resN_conv2 attribute names available.
            setattr(self, f"res{block_idx + 1}_conv1", first)
            setattr(self, f"res{block_idx + 1}_conv2", second)
            self._res_blocks.append((first, second))

        self.conv4 = Conv(torch_model, "down3.conv4", [1, 40, 40, 128], (1, 1, 0, 0), reshard=True, deallocate=False)

        self.conv5 = Conv(torch_model, "down3.conv5", [1, 40, 40, 256], (1, 1, 0, 0))

        # Convolutions in execution order; __str__ iterates this list.
        self.convs = [self.conv1, self.conv2, self.conv3]
        for first, second in self._res_blocks:
            self.convs.extend((first, second))
        self.convs.extend((self.conv4, self.conv5))

    def __call__(self, device, input_tensor):
        """Run the stage.

        Args:
            device: Device handle forwarded to every ``Conv`` call.
            input_tensor: Input of ``conv1``.

        Returns:
            The output tensor of ``conv5``.
        """
        output_tensor_split = self.conv1(device, input_tensor)
        output_tensor_left = self.conv2(device, output_tensor_split)

        residual = self.conv3(device, output_tensor_split)
        for first_conv, second_conv in self._res_blocks:
            output_tensor = first_conv(device, residual)
            output_tensor = second_conv(device, output_tensor)
            output_tensor = residual + output_tensor
            # The block input is no longer needed once the sum exists.
            ttnn.deallocate(residual)
            residual = output_tensor

        output_tensor = self.conv4(device, residual)

        # Both concat inputs are converted to interleaved L1 tensors first.
        output_tensor = ttnn.experimental.tensor.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG)
        output_tensor_left = ttnn.experimental.tensor.sharded_to_interleaved(output_tensor_left, ttnn.L1_MEMORY_CONFIG)
        output_tensor = ttnn.concat([output_tensor, output_tensor_left], dim=3, memory_config=ttnn.L1_MEMORY_CONFIG)
        ttnn.deallocate(output_tensor_left)

        output_tensor = self.conv5(device, output_tensor)
        return output_tensor

    def __str__(self) -> str:
        """Return one numbered line per convolution, in execution order."""
        return "".join(f"{index} {conv} \n" for index, conv in enumerate(self.convs, start=1))


@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
def test_down3(device):
    """Check the TTNN ``Down3`` output against the torch reference (PCC 0.99)."""
    ttnn_model = Down3("tests/ttnn/integration_tests/yolov4/yolov4.pth")

    torch_input = torch.randn((1, 80, 80, 128), dtype=torch.bfloat16)
    ttnn_input = ttnn.from_torch(torch_input, dtype=ttnn.bfloat16)
    # The reference model is fed a permuted (0, 3, 1, 2) float copy of the same data.
    torch_input = torch_input.permute(0, 3, 1, 2).float()
    torch_model = DownSample3()

    # Checkpoint tensors under "down3." are assigned to the reference model's
    # parameters by position (presumably because the key names differ — verify).
    ds_state_dict = {k: v for k, v in ttnn_model.torch_model.items() if k.startswith("down3.")}
    new_state_dict = dict(zip(torch_model.state_dict(), ds_state_dict.values()))

    torch_model.load_state_dict(new_state_dict)
    torch_model.eval()

    # First call is made outside the timed region.
    result_ttnn = ttnn_model(device, ttnn_input)

    start_time = time.time()
    for _ in range(2):
        result_ttnn = ttnn_model(device, ttnn_input)
    print(f"Time taken: {time.time() - start_time}")

    result = ttnn.to_torch(result_ttnn)
    ref = torch_model(torch_input)
    ref = ref.permute(0, 2, 3, 1)
    result = result.reshape(ref.shape)
    assert_with_pcc(result, ref, 0.99)


# diff --git a/models/experimental/yolov4/ttnn/downsample4.py b/models/experimental/yolov4/ttnn/downsample4.py
# new file mode 100644
# index 00000000000..ec3ce43d794
# --- /dev/null
# +++ b/models/experimental/yolov4/ttnn/downsample4.py
# @@ -0,0 +1,165 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0

import torch
import ttnn
from models.experimental.yolov4.ttnn.common import Conv
from models.experimental.yolov4.reference.downsample4 import DownSample4
from tests.ttnn.utils_for_testing import assert_with_pcc
import pytest
import time


class Down4:
    """TTNN implementation of the YOLOv4 ``down4`` stage.

    ``conv1`` feeds two branches: ``conv2`` on one side, and ``conv3`` followed
    by eight residual blocks and ``conv4`` on the other. Both branch outputs
    are converted to interleaved L1 tensors, concatenated on ``dim=3`` and
    passed through ``conv5``.
    """

    def __init__(self, model) -> None:
        """Create the ``Conv`` wrappers used by this stage.

        Args:
            model: Path of a checkpoint handed to ``torch.load``, or an object
                whose ``torch_model`` attribute already holds the loaded
                checkpoint.
        """
        # isinstance (instead of an exact type comparison) also accepts str subclasses.
        if isinstance(model, str):
            torch_model = torch.load(model)
        else:
            torch_model = model.torch_model
        self.torch_model = torch_model

        # NOTE(review): the 4-tuple is presumably (stride_h, stride_w, pad_h, pad_w);
        # confirm against common.Conv.
        self.conv1 = Conv(torch_model, "down4.conv1", [1, 40, 40, 256], (2, 2, 1, 1), reshard=True)
        self.conv2 = Conv(torch_model, "down4.conv2", [1, 20, 20, 512], (1, 1, 0, 0), reshard=True, deallocate=False)
        self.conv3 = Conv(torch_model, "down4.conv3", [1, 20, 20, 512], (1, 1, 0, 0))

        # Residual blocks: each is a pair of convolutions whose result is added
        # back onto the block input in __call__.
        self._res_blocks = []
        for block_idx in range(8):
            prefix = f"down4.resblock.module_list.{block_idx}"
            first = Conv(torch_model, f"{prefix}.0", [1, 20, 20, 256], (1, 1, 0, 0), deallocate=False)
            second = Conv(torch_model, f"{prefix}.1", [1, 20, 20, 256], (1, 1, 1, 1))
            # Keep the resN_conv1 / resN_conv2 attribute names available.
            setattr(self, f"res{block_idx + 1}_conv1", first)
            setattr(self, f"res{block_idx + 1}_conv2", second)
            self._res_blocks.append((first, second))

        self.conv4 = Conv(torch_model, "down4.conv4", [1, 20, 20, 256], (1, 1, 0, 0), reshard=True, deallocate=False)

        self.conv5 = Conv(torch_model, "down4.conv5", [1, 20, 20, 512], (1, 1, 0, 0))

        # Convolutions in execution order; __str__ iterates this list.
        self.convs = [self.conv1, self.conv2, self.conv3]
        for first, second in self._res_blocks:
            self.convs.extend((first, second))
        self.convs.extend((self.conv4, self.conv5))

    def __call__(self, device, input_tensor):
        """Run the stage.

        Args:
            device: Device handle forwarded to every ``Conv`` call.
            input_tensor: Input of ``conv1``.

        Returns:
            The output tensor of ``conv5``.
        """
        output_tensor_split = self.conv1(device, input_tensor)
        output_tensor_left = self.conv2(device, output_tensor_split)

        residual = self.conv3(device, output_tensor_split)
        for first_conv, second_conv in self._res_blocks:
            output_tensor = first_conv(device, residual)
            output_tensor = second_conv(device, output_tensor)
            output_tensor = residual + output_tensor
            # The block input is no longer needed once the sum exists.
            ttnn.deallocate(residual)
            residual = output_tensor

        output_tensor = self.conv4(device, residual)

        # Both concat inputs are converted to interleaved L1 tensors first.
        output_tensor = ttnn.experimental.tensor.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG)
        output_tensor_left = ttnn.experimental.tensor.sharded_to_interleaved(output_tensor_left, ttnn.L1_MEMORY_CONFIG)
        output_tensor = ttnn.concat([output_tensor, output_tensor_left], dim=3, memory_config=ttnn.L1_MEMORY_CONFIG)
        ttnn.deallocate(output_tensor_left)

        output_tensor = self.conv5(device, output_tensor)
        return output_tensor

    def __str__(self) -> str:
        """Return one numbered line per convolution, in execution order."""
        return "".join(f"{index} {conv} \n" for index, conv in enumerate(self.convs, start=1))


@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
def test_down4(device):
    """Check the TTNN ``Down4`` output against the torch reference (PCC 0.99)."""
    ttnn_model = Down4("tests/ttnn/integration_tests/yolov4/yolov4.pth")

    torch_input = torch.randn((1, 40, 40, 256), dtype=torch.bfloat16)
    ttnn_input = ttnn.from_torch(torch_input, dtype=ttnn.bfloat16)
    # The reference model is fed a permuted (0, 3, 1, 2) float copy of the same data.
    torch_input = torch_input.permute(0, 3, 1, 2).float()
    torch_model = DownSample4()

    # Checkpoint tensors under "down4." are assigned to the reference model's
    # parameters by position (presumably because the key names differ — verify).
    ds_state_dict = {k: v for k, v in ttnn_model.torch_model.items() if k.startswith("down4.")}
    new_state_dict = dict(zip(torch_model.state_dict(), ds_state_dict.values()))

    torch_model.load_state_dict(new_state_dict)
    torch_model.eval()

    # First call is made outside the timed region.
    result_ttnn = ttnn_model(device, ttnn_input)

    start_time = time.time()
    for _ in range(2):
        result_ttnn = ttnn_model(device, ttnn_input)
    print(f"Time taken: {time.time() - start_time}")

    result = ttnn.to_torch(result_ttnn)
    ref = torch_model(torch_input)
    ref = ref.permute(0, 2, 3, 1)
    result = result.reshape(ref.shape)
    assert_with_pcc(result, ref, 0.99)


# diff --git a/models/experimental/yolov4/ttnn/downsample5.py b/models/experimental/yolov4/ttnn/downsample5.py
# new file mode 100644
# index 00000000000..4ec836be58b
# --- /dev/null
# +++ b/models/experimental/yolov4/ttnn/downsample5.py
# @@ -0,0 +1,127 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import ttnn
from models.experimental.yolov4.ttnn.common import Conv
from models.experimental.yolov4.reference.downsample5 import DownSample5
from tests.ttnn.utils_for_testing import assert_with_pcc
import pytest
import time


class Down5:
    """TTNN implementation of the YOLOv4 ``down5`` stage.

    ``conv1`` feeds two branches: ``conv2`` on one side, and ``conv3`` followed
    by four residual blocks and ``conv4`` on the other. Both branch outputs
    are converted to interleaved L1 tensors, concatenated on ``dim=3`` and
    passed through ``conv5``.
    """

    def __init__(self, model) -> None:
        """Create the ``Conv`` wrappers used by this stage.

        Args:
            model: Path of a checkpoint handed to ``torch.load``, or an object
                whose ``torch_model`` attribute already holds the loaded
                checkpoint.
        """
        # isinstance (instead of an exact type comparison) also accepts str subclasses.
        if isinstance(model, str):
            torch_model = torch.load(model)
        else:
            torch_model = model.torch_model
        self.torch_model = torch_model

        # NOTE(review): the 4-tuple is presumably (stride_h, stride_w, pad_h, pad_w);
        # confirm against common.Conv.
        self.conv1 = Conv(
            torch_model, "down5.conv1", [1, 20, 20, 512], (2, 2, 1, 1), reshard=True, height_sharding=False
        )
        self.conv2 = Conv(torch_model, "down5.conv2", [1, 10, 10, 1024], (1, 1, 0, 0), reshard=True, deallocate=False)
        self.conv3 = Conv(torch_model, "down5.conv3", [1, 10, 10, 1024], (1, 1, 0, 0))

        # Residual blocks: each is a pair of convolutions whose result is added
        # back onto the block input in __call__.
        self._res_blocks = []
        for block_idx in range(4):
            prefix = f"down5.resblock.module_list.{block_idx}"
            first = Conv(torch_model, f"{prefix}.0", [1, 10, 10, 512], (1, 1, 0, 0), deallocate=False)
            second = Conv(torch_model, f"{prefix}.1", [1, 10, 10, 512], (1, 1, 1, 1))
            # Keep the resN_conv1 / resN_conv2 attribute names available.
            setattr(self, f"res{block_idx + 1}_conv1", first)
            setattr(self, f"res{block_idx + 1}_conv2", second)
            self._res_blocks.append((first, second))

        self.conv4 = Conv(torch_model, "down5.conv4", [1, 10, 10, 512], (1, 1, 0, 0), reshard=True, deallocate=False)

        self.conv5 = Conv(torch_model, "down5.conv5", [1, 10, 10, 1024], (1, 1, 0, 0), height_sharding=False)

        # Convolutions in execution order; __str__ iterates this list.
        self.convs = [self.conv1, self.conv2, self.conv3]
        for first, second in self._res_blocks:
            self.convs.extend((first, second))
        self.convs.extend((self.conv4, self.conv5))

    def __call__(self, device, input_tensor):
        """Run the stage.

        Args:
            device: Device handle forwarded to every ``Conv`` call.
            input_tensor: Input of ``conv1``.

        Returns:
            The output tensor of ``conv5``.
        """
        output_tensor_split = self.conv1(device, input_tensor)
        output_tensor_left = self.conv2(device, output_tensor_split)

        residual = self.conv3(device, output_tensor_split)
        for first_conv, second_conv in self._res_blocks:
            output_tensor = first_conv(device, residual)
            output_tensor = second_conv(device, output_tensor)
            output_tensor = residual + output_tensor
            # The block input is no longer needed once the sum exists.
            ttnn.deallocate(residual)
            residual = output_tensor

        output_tensor = self.conv4(device, residual)

        # Both concat inputs are converted to interleaved L1 tensors first.
        output_tensor = ttnn.experimental.tensor.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG)
        output_tensor_left = ttnn.experimental.tensor.sharded_to_interleaved(output_tensor_left, ttnn.L1_MEMORY_CONFIG)
        output_tensor = ttnn.concat([output_tensor, output_tensor_left], dim=3, memory_config=ttnn.L1_MEMORY_CONFIG)
        ttnn.deallocate(output_tensor_left)

        output_tensor = self.conv5(device, output_tensor)
        return output_tensor

    def __str__(self) -> str:
        """Return one numbered line per convolution, in execution order."""
        return "".join(f"{index} {conv} \n" for index, conv in enumerate(self.convs, start=1))


@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
def test_down5(device):
    """Check the TTNN ``Down5`` output against the torch reference (PCC 0.99)."""
    ttnn_model = Down5("tests/ttnn/integration_tests/yolov4/yolov4.pth")

    torch_input = torch.randn((1, 20, 20, 512), dtype=torch.bfloat16)
    ttnn_input = ttnn.from_torch(torch_input, dtype=ttnn.bfloat16)
    # The reference model is fed a permuted (0, 3, 1, 2) float copy of the same data.
    torch_input = torch_input.permute(0, 3, 1, 2).float()
    torch_model = DownSample5()

    # Checkpoint tensors under "down5." are assigned to the reference model's
    # parameters by position (presumably because the key names differ — verify).
    ds_state_dict = {k: v for k, v in ttnn_model.torch_model.items() if k.startswith("down5.")}
    new_state_dict = dict(zip(torch_model.state_dict(), ds_state_dict.values()))

    torch_model.load_state_dict(new_state_dict)
    torch_model.eval()

    # First call is made outside the timed region.
    result_ttnn = ttnn_model(device, ttnn_input)

    start_time = time.time()
    for _ in range(2):
        result_ttnn = ttnn_model(device, ttnn_input)
    print(f"Time taken: {time.time() - start_time}")

    result = ttnn.to_torch(result_ttnn)
    ref = torch_model(torch_input)
    ref = ref.permute(0, 2, 3, 1)
    result = result.reshape(ref.shape)
    assert_with_pcc(result, ref, 0.99)


# diff --git a/models/experimental/yolov4/ttnn/head.py b/models/experimental/yolov4/ttnn/head.py
# new file mode 100644
# index 00000000000..8a0ca42d391
# --- /dev/null
# +++ b/models/experimental/yolov4/ttnn/head.py
# @@ -0,0 +1,281 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0

import torch
import ttnn
from models.experimental.yolov4.ttnn.common import Conv
from models.experimental.yolov4.reference.head import Head
from tests.ttnn.utils_for_testing import assert_with_pcc
import pytest
import time


class TtHead:
    """TTNN implementation of the YOLOv4 head.

    Consumes three input tensors and produces three outputs, taken from
    ``conv2``, ``conv10`` and ``conv18``. Every convolution is built with an
    empty ``activation`` string; ``ttnn.leaky_relu`` (slope 0.1) is applied
    explicitly in ``__call__`` after every convolution except those three.
    """

    def __init__(self, model) -> None:
        """Create the eighteen ``Conv`` wrappers of the head.

        Args:
            model: Path of a checkpoint handed to ``torch.load``, or an object
                whose ``torch_model`` attribute already holds the loaded
                checkpoint.
        """
        # isinstance (instead of an exact type comparison) also accepts str subclasses.
        if isinstance(model, str):
            torch_model = torch.load(model)
        else:
            torch_model = model.torch_model
        self.torch_model = torch_model

        def head_conv(index, input_shape, conv_params, **options):
            # All head convolutions share reshard=True and activation="".
            return Conv(
                torch_model, f"head.conv{index}", input_shape, conv_params, reshard=True, activation="", **options
            )

        # NOTE(review): the 4-tuple is presumably (stride_h, stride_w, pad_h, pad_w);
        # confirm against common.Conv.
        self.conv1 = head_conv(1, [1, 40, 40, 128], (1, 1, 1, 1), deallocate=False)
        self.conv2 = head_conv(2, [1, 40, 40, 256], (1, 1, 0, 0), fused_op=False)
        self.conv3 = head_conv(3, [1, 40, 40, 128], (2, 2, 1, 1), deallocate=False)
        self.conv4 = head_conv(4, [1, 20, 20, 512], (1, 1, 0, 0), height_sharding=False)
        self.conv5 = head_conv(5, [1, 20, 20, 256], (1, 1, 1, 1))
        self.conv6 = head_conv(6, [1, 20, 20, 512], (1, 1, 0, 0), height_sharding=False)
        self.conv7 = head_conv(7, [1, 20, 20, 256], (1, 1, 1, 1))
        self.conv8 = head_conv(8, [1, 20, 20, 512], (1, 1, 0, 0), height_sharding=False)
        self.conv9 = head_conv(9, [1, 20, 20, 256], (1, 1, 1, 1), deallocate=False)
        self.conv10 = head_conv(10, [1, 20, 20, 512], (1, 1, 0, 0), height_sharding=False, fused_op=False)
        self.conv11 = head_conv(11, [1, 20, 20, 256], (2, 2, 1, 1))
        self.conv12 = head_conv(12, [1, 10, 10, 1024], (1, 1, 0, 0), height_sharding=False)
        self.conv13 = head_conv(13, [1, 10, 10, 512], (1, 1, 1, 1), height_sharding=False)
        self.conv14 = head_conv(14, [1, 10, 10, 1024], (1, 1, 0, 0), height_sharding=False)
        self.conv15 = head_conv(15, [1, 10, 10, 512], (1, 1, 1, 1), height_sharding=False)
        self.conv16 = head_conv(16, [1, 10, 10, 1024], (1, 1, 0, 0), height_sharding=False)
        self.conv17 = head_conv(17, [1, 10, 10, 512], (1, 1, 1, 1), height_sharding=False)
        self.conv18 = head_conv(18, [1, 10, 10, 1024], (1, 1, 0, 0), fused_op=False, height_sharding=False)

    def __call__(self, device, input_tensor, model=None):
        """Run the head.

        Args:
            device: Device handle forwarded to every ``Conv`` call.
            input_tensor: Indexable of three tensors. Index 0 feeds ``conv1``
                and ``conv3``; index 2 is concatenated with the ``conv3``
                branch; index 1 is concatenated with the ``conv11`` branch.
            model: Unused; kept (now optional) so existing three-argument
                callers keep working.

        Returns:
            Tuple of the outputs of ``conv2``, ``conv10`` and ``conv18``.
        """
        output_tensor = self.conv1(device, input_tensor[0])
        output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1)

        output_tensor_left_1 = self.conv2(device, output_tensor)

        output_tensor = self.conv3(device, input_tensor[0])
        output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1)
        outfrom_Neck1 = input_tensor[2]

        output_tensor = ttnn.experimental.tensor.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG)
        # The head sub-module test passes an interleaved tensor, so only convert when sharded.
        if outfrom_Neck1.memory_config().is_sharded():
            outfrom_Neck1 = ttnn.experimental.tensor.sharded_to_interleaved(outfrom_Neck1, ttnn.L1_MEMORY_CONFIG)

        output_tensor = ttnn.concat([output_tensor, outfrom_Neck1], dim=3, memory_config=ttnn.L1_MEMORY_CONFIG)
        ttnn.deallocate(outfrom_Neck1)

        for conv in (self.conv4, self.conv5, self.conv6, self.conv7):
            output_tensor = conv(device, output_tensor)
            output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1)

        output_tensor = self.conv8(device, output_tensor)
        # This activation is consumed twice: by conv9 and by conv11.
        output_tensor_split = ttnn.leaky_relu(output_tensor, slope=0.1)

        output_tensor = self.conv9(device, output_tensor_split)
        output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1)

        output_tensor_left_2 = self.conv10(device, output_tensor)

        output_tensor = self.conv11(device, output_tensor_split)
        output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1)

        outfromNeck2 = input_tensor[1]
        output_tensor = ttnn.experimental.tensor.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG)
        # The head sub-module test passes an interleaved tensor, so only convert when sharded.
        if outfromNeck2.memory_config().is_sharded():
            outfromNeck2 = ttnn.experimental.tensor.sharded_to_interleaved(outfromNeck2, ttnn.L1_MEMORY_CONFIG)
        output_tensor = ttnn.concat([output_tensor, outfromNeck2], dim=3, memory_config=ttnn.L1_MEMORY_CONFIG)
        ttnn.deallocate(outfromNeck2)

        for conv in (self.conv12, self.conv13, self.conv14, self.conv15, self.conv16, self.conv17):
            output_tensor = conv(device, output_tensor)
            output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1)

        output_tensor_left_3 = self.conv18(device, output_tensor)

        return output_tensor_left_1, output_tensor_left_2, output_tensor_left_3


def _to_tile_device_tensor(torch_tensor, flat_shape, device):
    """Convert a torch tensor to bfloat16 ttnn, reshape it, tile it and move it to ``device``."""
    tensor = ttnn.from_torch(torch_tensor, dtype=ttnn.bfloat16)
    tensor = ttnn.reshape(tensor, flat_shape)
    tensor = ttnn.to_layout(tensor, layout=ttnn.TILE_LAYOUT)
    return ttnn.to_device(tensor, device=device)


@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
def test_head(device, reset_seeds):
    """Check the three TTNN ``TtHead`` outputs against the torch reference (PCC 0.99)."""
    ttnn_model = TtHead("tests/ttnn/integration_tests/yolov4/yolov4.pth")

    torch_input_tensor1 = torch.randn(1, 40, 40, 128, dtype=torch.bfloat16)
    torch_input_tensor2 = torch.randn(1, 10, 10, 512, dtype=torch.bfloat16)
    torch_input_tensor3 = torch.randn(1, 20, 20, 256, dtype=torch.bfloat16)

    ttnn_input_tensor = [
        _to_tile_device_tensor(torch_input_tensor1, (1, 1, 1600, 128), device),
        _to_tile_device_tensor(torch_input_tensor2, (1, 1, 100, 512), device),
        _to_tile_device_tensor(torch_input_tensor3, (1, 1, 400, 256), device),
    ]
    # The reference model is fed permuted (0, 3, 1, 2) float copies of the same data.
    torch_input_tensor = [
        tensor.permute(0, 3, 1, 2).float()
        for tensor in (torch_input_tensor1, torch_input_tensor2, torch_input_tensor3)
    ]

    torch_model = Head()

    # Checkpoint tensors under "head." are assigned to the reference model's
    # parameters by position (presumably because the key names differ — verify).
    ds_state_dict = {k: v for k, v in ttnn_model.torch_model.items() if k.startswith("head.")}
    new_state_dict = dict(zip(torch_model.state_dict(), ds_state_dict.values()))

    torch_model.load_state_dict(new_state_dict)
    torch_model.eval()

    result_ttnn = ttnn_model(device, ttnn_input_tensor, torch_model)
    # start_time = time.time()
    # for x in range(1):
    #    result_ttnn = ttnn_model(device, ttnn_input_tensor)
    # print(f"Time taken: {time.time() - start_time}")

    references = torch_model(torch_input_tensor)

    for result_tt, ref in zip(result_ttnn, references):
        result = ttnn.to_torch(result_tt)
        result = result.reshape(1, ref.shape[2], ref.shape[3], 256)
        result = result.permute(0, 3, 1, 2)
        # Output is sliced because ttnn.conv returns 256 channels instead of 255.
        result = result[:, :255, :, :]
        assert_with_pcc(result, ref, 0.99)


# diff --git a/models/experimental/yolov4/ttnn/neck.py b/models/experimental/yolov4/ttnn/neck.py
# new file mode 100644
# index 00000000000..a88ba66edf8
# --- /dev/null
# +++ b/models/experimental/yolov4/ttnn/neck.py
# @@ -0,0 +1,362 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import ttnn +from models.experimental.yolov4.ttnn.common import Conv +from models.experimental.yolov4.reference.neck import Neck +from tests.ttnn.utils_for_testing import assert_with_pcc +import pytest +import time +from tt_lib.fallback_ops import fallback_ops + + +class TtNeck: + def __init__(self, model) -> None: + if type(model) is str: + torch_model = torch.load(model) + else: + torch_model = model.torch_model + self.torch_model = torch_model + self.conv1 = Conv( + torch_model, + "neek.conv1", + [1, 10, 10, 1024], + (1, 1, 0, 0), + height_sharding=False, + reshard=True, + activation="", + ) + self.conv2 = Conv( + torch_model, + "neek.conv2", + [1, 10, 10, 512], + (1, 1, 1, 1), + height_sharding=False, + reshard=True, + activation="", + ) + self.conv3 = Conv( + torch_model, + "neek.conv3", + [1, 10, 10, 1024], + (1, 1, 0, 0), + height_sharding=False, + reshard=True, + activation="", + ) + + self.p1 = fallback_ops.MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1, ceil_mode=False) + # ttnn.MaxPool2d( + # kernel_size=(5, 5), + # stride=(1, 1), + # padding=(2, 2), + # dilation=(1, 1), + # dtype=ttnn.bfloat16, + # device=self.device, + # batch_size=self.batch_size, + # input_height=10, + # input_width=10, + # reader_patterns_cache=self.max_pool_reader_patterns_cache, + # deallocate_activation=True, + # parallel_config_override={}, + # channels=512 + # ) + + self.p2 = fallback_ops.MaxPool2d(kernel_size=9, stride=1, padding=4, dilation=1, ceil_mode=False) + # ttnn.MaxPool2d( + # kernel_size=(9, 9), + # stride=(1, 1), + # padding=(4, 4), + # dilation=(1, 1), + # dtype=ttnn.bfloat16, + # device=self.device, + # batch_size=self.batch_size, + # input_height=10, + # input_width=10, + # reader_patterns_cache=self.max_pool_reader_patterns_cache, + # deallocate_activation=True, + # parallel_config_override={}, + # channels=512 + # ) + + self.p3 = fallback_ops.MaxPool2d(kernel_size=13, stride=1, padding=6, 
dilation=1, ceil_mode=False) + # ttnn.MaxPool2d( + # kernel_size=(13, 13), + # stride=(1, 1), + # padding=(6, 6), + # dilation=(1, 1), + # dtype=ttnn.bfloat16, + # device=self.device, + # batch_size=self.batch_size, + # input_height=10, + # input_width=10, + # reader_patterns_cache=self.max_pool_reader_patterns_cache, + # deallocate_activation=True, + # parallel_config_override={}, + # channels=512 + # ) + + self.conv4 = Conv( + torch_model, + "neek.conv4", + [1, 10, 10, 2048], + (1, 1, 0, 0), + height_sharding=False, + reshard=True, + activation="", + ) + self.conv5 = Conv( + torch_model, + "neek.conv5", + [1, 10, 10, 512], + (1, 1, 1, 1), + height_sharding=False, + reshard=True, + activation="", + ) + self.conv6 = Conv( + torch_model, + "neek.conv6", + [1, 10, 10, 1024], + (1, 1, 0, 0), + height_sharding=False, + reshard=True, + activation="", + ) + self.conv7 = Conv( + torch_model, + "neek.conv7", + [1, 10, 10, 512], + (1, 1, 0, 0), + height_sharding=False, + reshard=True, + deallocate=False, + activation="", + ) + self.conv7_2 = Conv( + torch_model, + "neek.conv8", + [1, 20, 20, 512], + (1, 1, 0, 0), + height_sharding=False, + reshard=True, + activation="", + ) + self.conv7_3 = Conv( + torch_model, + "neek.conv9", + [1, 20, 20, 512], + (1, 1, 0, 0), + height_sharding=False, + reshard=True, + activation="", + ) + self.conv8 = Conv(torch_model, "neek.conv10", [1, 20, 20, 256], (1, 1, 1, 1), reshard=True, activation="") + self.conv7_4 = Conv( + torch_model, + "neek.conv11", + [1, 20, 20, 512], + (1, 1, 0, 0), + height_sharding=False, + reshard=True, + activation="", + ) + self.conv8_2 = Conv(torch_model, "neek.conv12", [1, 20, 20, 256], (1, 1, 1, 1), reshard=True, activation="") + self.conv7_5 = Conv( + torch_model, + "neek.conv13", + [1, 20, 20, 512], + (1, 1, 0, 0), + height_sharding=False, + reshard=True, + activation="", + ) + + self.conv9 = Conv( + torch_model, "neek.conv14", [1, 20, 20, 256], (1, 1, 0, 0), reshard=True, deallocate=False, activation="" + ) + 
self.conv9_2 = Conv(torch_model, "neek.conv15", [1, 40, 40, 256], (1, 1, 0, 0), reshard=True, activation="") + self.conv9_3 = Conv(torch_model, "neek.conv16", [1, 40, 40, 256], (1, 1, 0, 0), reshard=True, activation="") + self.conv10 = Conv(torch_model, "neek.conv17", [1, 40, 40, 128], (1, 1, 1, 1), reshard=True, activation="") + + self.conv9_4 = Conv(torch_model, "neek.conv18", [1, 40, 40, 256], (1, 1, 0, 0), reshard=True, activation="") + self.conv10_2 = Conv(torch_model, "neek.conv19", [1, 40, 40, 128], (1, 1, 1, 1), reshard=True, activation="") + self.conv9_5 = Conv(torch_model, "neek.conv20", [1, 40, 40, 256], (1, 1, 0, 0), reshard=True, activation="") + + def __call__(self, device, input_tensor): + output_tensor = self.conv1(device, input_tensor[0]) + output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1) + + output_tensor = self.conv2(device, output_tensor) + output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1) + + output_tensor = self.conv3(device, output_tensor) + output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1) + + output_tensor_conv3 = ttnn.experimental.tensor.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + # Once issue #7746 is resolved we will use ttnn.MaxPool instead of fallback.MaxPool + output_tensor_conv3 = ttnn.to_layout(output_tensor_conv3, layout=ttnn.ROW_MAJOR_LAYOUT) + output_tensor_conv3 = ttnn.reshape( + output_tensor_conv3, (1, 10, 10, 512) + ) # hard coded the shape as in future we will be using ttnn.MaxPool + output_tensor_conv3 = ttnn.permute(output_tensor_conv3, (0, 3, 1, 2)) + + pool_1 = self.p1(output_tensor_conv3) + pool_2 = self.p2(output_tensor_conv3) + pool_3 = self.p3(output_tensor_conv3) + + pool_1 = ttnn.permute(pool_1, (0, 2, 3, 1)) + pool_1 = ttnn.reshape(pool_1, (1, 1, pool_1.shape[1] * pool_1.shape[2], pool_1.shape[3])) + pool_2 = ttnn.permute(pool_2, (0, 2, 3, 1)) + pool_2 = ttnn.reshape(pool_2, (1, 1, pool_2.shape[1] * pool_2.shape[2], pool_2.shape[3])) + pool_3 = ttnn.permute(pool_3, (0, 
2, 3, 1)) + pool_3 = ttnn.reshape(pool_3, (1, 1, pool_3.shape[1] * pool_3.shape[2], pool_3.shape[3])) + pool_1 = ttnn.to_layout(pool_1, layout=ttnn.TILE_LAYOUT) + pool_2 = ttnn.to_layout(pool_2, layout=ttnn.TILE_LAYOUT) + pool_3 = ttnn.to_layout(pool_3, layout=ttnn.TILE_LAYOUT) + + output_tensor = ttnn.experimental.tensor.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + output_tensor = ttnn.concat([pool_3, pool_2, pool_1, output_tensor], dim=3, memory_config=ttnn.L1_MEMORY_CONFIG) + ttnn.deallocate(pool_3) + ttnn.deallocate(pool_2) + ttnn.deallocate(pool_1) + + output_tensor = self.conv4(device, output_tensor) + output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1) + + output_tensor = self.conv5(device, output_tensor) + output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1) + + output_tensor = self.conv6(device, output_tensor) + output_tensor_left_1 = ttnn.leaky_relu(output_tensor, slope=0.1) + + output_tensor = self.conv7(device, output_tensor_left_1) + output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1) + + output_tensor = ttnn.experimental.tensor.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) + output_tensor_upsample_1 = ttnn.upsample(output_tensor, (1, 4, 1), memory_config=ttnn.L1_MEMORY_CONFIG) + output_tensor_upsample_1 = ttnn.to_layout(output_tensor_upsample_1, layout=ttnn.TILE_LAYOUT) + + outDowSample5 = input_tensor[1] + output_tensor = self.conv7_2(device, outDowSample5) + output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1) + + output_tensor = ttnn.experimental.tensor.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + + output_tensor = ttnn.concat( + [output_tensor, output_tensor_upsample_1], dim=3, memory_config=ttnn.L1_MEMORY_CONFIG + ) + ttnn.deallocate(output_tensor_upsample_1) + + output_tensor = self.conv7_3(device, output_tensor) + output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1) + + output_tensor = 
self.conv8(device, output_tensor) + output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1) + + output_tensor = self.conv7_4(device, output_tensor) + output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1) + + output_tensor = self.conv8_2(device, output_tensor) + output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1) + + output_tensor = self.conv7_5(device, output_tensor) + output_tensor_left_2 = ttnn.leaky_relu(output_tensor, slope=0.1) + + output_tensor = self.conv9(device, output_tensor_left_2) + output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1) + + output_tensor = ttnn.experimental.tensor.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) + output_tensor_upsample_2 = ttnn.upsample(output_tensor, (1, 4, 1), memory_config=ttnn.L1_MEMORY_CONFIG) + output_tensor_upsample_2 = ttnn.to_layout(output_tensor_upsample_2, ttnn.TILE_LAYOUT) + + outDowSample3 = input_tensor[2] + + output_tensor = self.conv9_2(device, outDowSample3) + output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1) + + output_tensor = ttnn.experimental.tensor.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG) + output_tensor = ttnn.concat( + [output_tensor, output_tensor_upsample_2], dim=3, memory_config=ttnn.L1_MEMORY_CONFIG + ) + ttnn.deallocate(output_tensor_upsample_2) + + output_tensor = self.conv9_3(device, output_tensor) + output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1) + + output_tensor = self.conv10(device, output_tensor) + output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1) + + output_tensor = self.conv9_4(device, output_tensor) + output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1) + + output_tensor = self.conv10_2(device, output_tensor) + output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1) + + output_tensor = self.conv9_5(device, output_tensor) + output_tensor = ttnn.leaky_relu(output_tensor, slope=0.1) + + return output_tensor, output_tensor_left_1, 
output_tensor_left_2 + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +def test_neck(device, reset_seeds): + ttnn_model = TtNeck("tests/ttnn/integration_tests/yolov4/yolov4.pth") + + torch_input_tensor1 = torch.randn(1, 10, 10, 1024, dtype=torch.bfloat16) + torch_input_tensor2 = torch.randn(1, 20, 20, 512, dtype=torch.bfloat16) + torch_input_tensor3 = torch.randn(1, 40, 40, 256, dtype=torch.bfloat16) + ttnn_input_tensor1 = ttnn.from_torch(torch_input_tensor1, dtype=ttnn.bfloat16) + ttnn_input_tensor1 = ttnn.reshape(ttnn_input_tensor1, (1, 1, 100, 1024)) + ttnn_input_tensor1 = ttnn.to_layout(ttnn_input_tensor1, layout=ttnn.TILE_LAYOUT) + ttnn_input_tensor1 = ttnn.to_device(ttnn_input_tensor1, device=device) + ttnn_input_tensor2 = ttnn.from_torch(torch_input_tensor2, dtype=ttnn.bfloat16) + ttnn_input_tensor2 = ttnn.reshape(ttnn_input_tensor2, (1, 1, 400, 512)) + ttnn_input_tensor2 = ttnn.to_layout(ttnn_input_tensor2, layout=ttnn.TILE_LAYOUT) + ttnn_input_tensor2 = ttnn.to_device(ttnn_input_tensor2, device=device) + ttnn_input_tensor3 = ttnn.from_torch(torch_input_tensor3, dtype=ttnn.bfloat16) + ttnn_input_tensor3 = ttnn.reshape(ttnn_input_tensor3, (1, 1, 1600, 256)) + ttnn_input_tensor3 = ttnn.to_layout(ttnn_input_tensor3, layout=ttnn.TILE_LAYOUT) + ttnn_input_tensor3 = ttnn.to_device(ttnn_input_tensor3, device=device) + ttnn_input_tensor = [ttnn_input_tensor1, ttnn_input_tensor2, ttnn_input_tensor3] + torch_input_tensor1 = torch_input_tensor1.permute(0, 3, 1, 2).float() + torch_input_tensor2 = torch_input_tensor2.permute(0, 3, 1, 2).float() + torch_input_tensor3 = torch_input_tensor3.permute(0, 3, 1, 2).float() + torch_input_tensor = [torch_input_tensor1, torch_input_tensor2, torch_input_tensor3] + torch_model = Neck() + + new_state_dict = {} + ds_state_dict = {k: v for k, v in ttnn_model.torch_model.items() if (k.startswith("neek."))} + + keys = [name for name, parameter in torch_model.state_dict().items()] + values = 
[parameter for name, parameter in ds_state_dict.items()] + for i in range(len(keys)): + new_state_dict[keys[i]] = values[i] + + torch_model.load_state_dict(new_state_dict) + torch_model.eval() + + result_ttnn = ttnn_model(device, ttnn_input_tensor) + start_time = time.time() + for x in range(2): + result_ttnn = ttnn_model(device, ttnn_input_tensor) + print(f"Time taken: {time.time() - start_time}") + + result_1 = ttnn.to_torch(result_ttnn[0]) + result_2 = ttnn.to_torch(result_ttnn[1]) + result_3 = ttnn.to_torch(result_ttnn[2]) + ref1, ref2, ref3 = torch_model(torch_input_tensor) + ref1 = ref1.permute(0, 2, 3, 1) + ref2 = ref2.permute(0, 2, 3, 1) + ref3 = ref3.permute(0, 2, 3, 1) + result1 = result_1.reshape(ref1.shape) + result2 = result_2.reshape(ref2.shape) + result3 = result_3.reshape(ref3.shape) + assert_with_pcc(result1, ref1, 0.94) # PCC = 0.94 + assert_with_pcc(result2, ref2, 0.99) # PCC = 0.99 + assert_with_pcc(result3, ref3, 0.96) # PCC = 0.96 diff --git a/models/experimental/yolov4/ttnn/yolov4.py b/models/experimental/yolov4/ttnn/yolov4.py new file mode 100644 index 00000000000..0dc5a159342 --- /dev/null +++ b/models/experimental/yolov4/ttnn/yolov4.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
# SPDX-License-Identifier: Apache-2.0

from typing import Any
import ttnn
import torch
import pytest

# Debug/reporting switches assigned as attributes on the ttnn module. They are
# set before the model sub-modules below are imported, so this ordering is kept.
ttnn.enable_fast_runtime_mode = False
ttnn.enable_logging = True
ttnn.report_name = "yolo_fail"
ttnn.enable_graph_report = False
ttnn.enable_detailed_buffer_report = True
ttnn.enable_detailed_tensor_report = True
ttnn.enable_comparison_mode = False

from models.experimental.yolov4.reference.yolov4 import Yolov4
from tests.ttnn.utils_for_testing import assert_with_pcc
from models.experimental.yolov4.ttnn.downsample1 import Down1
from models.experimental.yolov4.ttnn.downsample2 import Down2
from models.experimental.yolov4.ttnn.downsample3 import Down3
from models.experimental.yolov4.ttnn.downsample4 import Down4
from models.experimental.yolov4.ttnn.downsample5 import Down5
from models.experimental.yolov4.ttnn.neck import TtNeck
from models.experimental.yolov4.ttnn.head import TtHead


class TtYOLOv4:
    """ttnn YOLOv4 pipeline: five downsample stages, then the neck, then the head."""

    def __init__(self, path) -> None:
        """Load the checkpoint stored at ``path`` and build every sub-module.

        Each sub-module is constructed with this object itself as its only
        argument (presumably so it can read ``torch_model`` -- confirm against
        the sub-module constructors).
        """
        self.torch_model = torch.load(path)
        self.torch_keys = self.torch_model.keys()

        # Backbone stages, created in forward order.
        self.down1 = Down1(self)
        self.down2 = Down2(self)
        self.down3 = Down3(self)
        self.down4 = Down4(self)
        self.down5 = Down5(self)

        self.neck = TtNeck(self)
        self.head = TtHead(self)

        # Stages rendered by __str__; the list is left empty, so str(self) == "".
        self.downs = []  # [self.down1]

    def __call__(self, device, input_tensor, model):
        """Run the whole network on ``device`` and return the three head outputs.

        ``model`` only needs a ``head`` attribute; it is forwarded unchanged to
        the ttnn head together with the neck outputs.
        """
        stage1 = self.down1(device, input_tensor)
        stage2 = self.down2(device, stage1)
        ttnn.deallocate(stage1)
        stage3 = self.down3(device, stage2)
        ttnn.deallocate(stage2)
        stage4 = self.down4(device, stage3)
        stage5 = self.down5(device, stage4)

        # Stages 3-5 are not deallocated here: all three are inputs of the neck.
        neck_a, neck_b, neck_c = self.neck(device, [stage5, stage4, stage3])
        head_a, head_b, head_c = self.head(device, [neck_a, neck_b, neck_c], model.head)

        return head_a, head_b, head_c

    def __str__(self) -> str:
        """Return ``str()`` of every entry in ``self.downs``, each followed by " \\n"."""
        return "".join(str(down) + " \n" for down in self.downs)


@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
def test_yolov4(device, reset_seeds):
    """Check the ttnn YOLOv4 outputs against the torch reference with assert_with_pcc."""
    ttnn_model = TtYOLOv4("tests/ttnn/integration_tests/yolov4/yolov4.pth")
    print(ttnn_model)

    # The same random NHWC image feeds both models; torch gets it as NCHW float32.
    torch_input = torch.randn((1, 320, 320, 3), dtype=torch.bfloat16)
    ttnn_input = ttnn.from_torch(torch_input, dtype=ttnn.bfloat16)
    torch_input = torch_input.permute(0, 3, 1, 2).float()
    torch_model = Yolov4()

    for layer in torch_model.children():
        print(layer)

    # Checkpoint tensors are matched to the reference parameter names purely by
    # position: the i-th checkpoint entry is stored under the i-th reference key.
    keys = list(torch_model.state_dict())
    values = list(ttnn_model.torch_model.values())
    print(keys)
    new_state_dict = {key: values[idx] for idx, key in enumerate(keys)}

    torch_model.load_state_dict(new_state_dict)
    torch_model.eval()

    tt_out_1, tt_out_2, tt_out_3 = ttnn_model(device, ttnn_input, torch_model)
    tt_results = [ttnn.to_torch(tt_out) for tt_out in (tt_out_1, tt_out_2, tt_out_3)]

    references = tuple(torch_model(torch_input))
    ref1, ref2, ref3 = references

    # Reshape the flattened ttnn output to NHWC using the reference spatial size,
    # then move channels first so it lines up with the NCHW reference.
    nchw_results = [
        result.reshape(1, ref.shape[2], ref.shape[3], 256).permute(0, 3, 1, 2)
        for result, ref in zip(tt_results, (ref1, ref2, ref3))
    ]

    # Output is sliced because ttnn.conv returns 256 channels instead of 255.
    sliced_results = [result[:, :255, :, :] for result in nchw_results]

    # Minimum PCC per output, in output order.
    pcc_thresholds = (0.95, 0.96, 0.98)
    for result, ref, pcc in zip(sliced_results, (ref1, ref2, ref3), pcc_thresholds):
        assert_with_pcc(result, ref, pcc)