Optimization for subpixel layer on Tensor core #4523

Closed
kice opened this issue Dec 15, 2019 · 14 comments

@kice
Contributor

kice commented Dec 15, 2019

I found that the Depth_to_Space layer spends too much time changing data layout (NHWC <-> NCHW) when using tensor cores. It takes up to 25% of the run time to do the transposes.

Is it possible to reduce this kind of unnecessary data manipulation, for example by combining the reshape and/or transpose into one op?

A sample network

scale = 4

conv(3, 64),
conv(64, scale**2),
subpixel(scale),
conv(64 // scale**2, 3)

Then it will do

scale = 4

nchwToNhwc(), 
conv(3, 64),
conv(64, scale**2),
nhwcToNchw(),
reshape and transpose
nchwToNhwc(), 
conv(64 // scale**2, 3)
nhwcToNchw(),

I think the nchwToNhwc is done automatically by CUDA; maybe converting the whole model to NHWC before using the tensor cores would be a better choice.

Or some feature like this PR: #4335

@jwfromm
Contributor

jwfromm commented Dec 15, 2019

You're right that our current handling of depth2space (and space2depth) isn't optimal. We should probably have a dedicated relay op and topi implementation that directly rearranges values rather than stacking reshapes and transposes. I can try to take a stab at implementing this in the next few days.
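To illustrate the idea, a direct rearrangement in NCHW can be expressed as a single gather-style compute, so no intermediate reshape/transpose buffers are materialized. A minimal sketch, assuming TVM's te.compute API; the helper name and structure are illustrative, not the actual TOPI code:

from tvm import te

def depth_to_space_nchw(data, block_size, mode="DCR"):
    # data: te.Tensor with shape (N, C, H, W); output shape is
    # (N, C // block_size**2, H * block_size, W * block_size).
    n, c, h, w = data.shape
    bs = block_size
    out_c = c // (bs * bs)

    def _pick(nn, cc, hh, ww):
        h_in, bh = hh // bs, hh % bs
        w_in, bw = ww // bs, ww % bs
        if mode == "DCR":
            c_in = bh * bs * out_c + bw * out_c + cc
        else:  # "CRD" channel ordering
            c_in = cc * bs * bs + bh * bs + bw
        return data[nn, c_in, h_in, w_in]

    return te.compute((n, out_c, h * bs, w * bs), _pick, name="depth_to_space")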

@kice
Contributor Author

kice commented Dec 16, 2019

I think it is okay to use reshapes and transposes, as long as we can change their parameters when the data layout changes. Some models, like EDSR, sandwich convs between subpixel layers for 4x scale, which hurts the performance even more.

But if the new op can support both NCHW and NHWC, that would be nice. xD

@jwfromm
Contributor

jwfromm commented Dec 22, 2019

See PR #4566 for my implementation of native subpixel operators.

@jwfromm
Contributor

jwfromm commented Dec 28, 2019

@kice now that #4566 is merged, can you give this another try? Hopefully we have much better performance and this issue can be closed.

@kice
Contributor Author

kice commented Dec 29, 2019

After testing it, I am happy to let you know that there is no significant difference at all. And I even found a bug. xD

TVM built on Win10 with MSVC, CUDA 10.1; tested with an RTX 2060 Super.

v0.0.4
fn (%data: Tensor[(1, 3, 720, 1280), float16], %head.0.weight: Tensor[(64, 3, 3, 3), float16], %head.0.bias: Tensor[(64), float16], %body.0.body.0.weight: Tensor[(64, 64, 3, 3), float16], %body.0.body.0.bias: Tensor[(64), float16], %body.0.body.2.weight: Tensor[(64, 64, 3, 3), float16], %body.0.body.2.bias: Tensor[(64), float16], %body.1.body.0.weight: Tensor[(64, 64, 3, 3), float16], %body.1.body.0.bias: Tensor[(64), float16], %body.1.body.2.weight: Tensor[(64, 64, 3, 3), float16], %body.1.body.2.bias: Tensor[(64), float16], %body.2.weight: Tensor[(64, 64, 3, 3), float16], %body.2.bias: Tensor[(64), float16], %tail.0.0.weight: Tensor[(128, 64, 3, 3), float16], %tail.0.0.bias: Tensor[(128), float16], %tail.0.2.weight: Tensor[(64, 32, 3, 3), float16], %tail.0.2.bias: Tensor[(64), float16], %tail.1.weight: Tensor[(3, 16, 3, 3), float16], %tail.1.bias: Tensor[(3), float16]) -> Tensor[(1, 3, 2880, 5120), float16] {
  %0 = nn.conv2d(%data, %head.0.weight, padding=[1, 1], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 720, 1280), float16] */;
  %1 = nn.bias_add(%0, %head.0.bias) /* ty=Tensor[(1, 64, 720, 1280), float16] */;
  %2 = nn.conv2d(%1, %body.0.body.0.weight, padding=[1, 1], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 720, 1280), float16] */;
  %3 = nn.bias_add(%2, %body.0.body.0.bias) /* ty=Tensor[(1, 64, 720, 1280), float16] */;
  %4 = nn.leaky_relu(%3, alpha=0.1f) /* ty=Tensor[(1, 64, 720, 1280), float16] */;
  %5 = nn.conv2d(%4, %body.0.body.2.weight, padding=[1, 1], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 720, 1280), float16] */;
  %6 = nn.bias_add(%5, %body.0.body.2.bias) /* ty=Tensor[(1, 64, 720, 1280), float16] */;
  %7 = add(%6, %1) /* ty=Tensor[(1, 64, 720, 1280), float16] */;
  %8 = nn.conv2d(%7, %body.1.body.0.weight, padding=[1, 1], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 720, 1280), float16] */;
  %9 = nn.bias_add(%8, %body.1.body.0.bias) /* ty=Tensor[(1, 64, 720, 1280), float16] */;
  %10 = nn.leaky_relu(%9, alpha=0.1f) /* ty=Tensor[(1, 64, 720, 1280), float16] */;
  %11 = nn.conv2d(%10, %body.1.body.2.weight, padding=[1, 1], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 720, 1280), float16] */;
  %12 = nn.bias_add(%11, %body.1.body.2.bias) /* ty=Tensor[(1, 64, 720, 1280), float16] */;
  %13 = add(%12, %7) /* ty=Tensor[(1, 64, 720, 1280), float16] */;
  %14 = nn.conv2d(%13, %body.2.weight, padding=[1, 1], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 720, 1280), float16] */;
  %15 = nn.bias_add(%14, %body.2.bias) /* ty=Tensor[(1, 64, 720, 1280), float16] */;
  %16 = add(%15, %1) /* ty=Tensor[(1, 64, 720, 1280), float16] */;
  %17 = nn.conv2d(%16, %tail.0.0.weight, padding=[1, 1], kernel_size=[3, 3]) /* ty=Tensor[(1, 128, 720, 1280), float16] */;
  %18 = nn.bias_add(%17, %tail.0.0.bias) /* ty=Tensor[(1, 128, 720, 1280), float16] */;
  %19 = nn.depth_to_space(%18, block_size=2, mode="CRD") /* ty=Tensor[(1, 32, 1440, 2560), float16] */;
  %20 = nn.conv2d(%19, %tail.0.2.weight, padding=[1, 1], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 1440, 2560), float16] */;
  %21 = nn.bias_add(%20, %tail.0.2.bias) /* ty=Tensor[(1, 64, 1440, 2560), float16] */;
  %22 = nn.depth_to_space(%21, block_size=2, mode="CRD") /* ty=Tensor[(1, 16, 2880, 5120), float16] */;
  %23 = nn.conv2d(%22, %tail.1.weight, padding=[1, 1], kernel_size=[3, 3]) /* ty=Tensor[(1, 3, 2880, 5120), float16] */;
  nn.bias_add(%23, %tail.1.bias) /* ty=Tensor[(1, 3, 2880, 5120), float16] */
}

@kice
Contributor Author

kice commented Dec 29, 2019

[profiling screenshot: transpose version]

[profiling screenshot: native version]

Without profiling, the native version finished in 55.235 s and the transpose version in 55.639 s.
Warmup: 200 runs with CPU copy-back; benchmark: 600 runs without copy-back.
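For reference, this kind of measurement can be done with TVM's time_evaluator, roughly like so (a sketch only; gmod and ctx are assumed to be the graph runtime module and the CUDA context):

# Warmup: 200 runs, copying the output back to the CPU each time.
for _ in range(200):
    gmod.run()
    _ = gmod.get_output(0).asnumpy()

# Benchmark: 600 runs, timed on-device without copying the output back.
ftimer = gmod.module.time_evaluator("run", ctx, number=600)
print("mean per-run time: %.3f ms" % (ftimer().mean * 1e3))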

@jwfromm
Contributor

jwfromm commented Dec 29, 2019

Maybe the problem is just that this op shouldn't be fused when running on tensor cores. I suspect the access pattern just fundamentally causes bad behavior in the fused kernel. Can you try changing its fusion pattern to OPAQUE rather than INJECTIVE (around line 976 in tvm/relay/op/nn/_nn.py)? Also, what bug did you find? I can't see anything obvious in what you posted.
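For reference, the suggested change amounts to roughly this one-liner in the op registration (a sketch; the surrounding code in _nn.py may look slightly different):

# python/tvm/relay/op/nn/_nn.py (approximate location)
# Register depth_to_space as OPAQUE so the fuser keeps it as its own kernel
# instead of inlining it into neighboring injective ops.
reg.register_pattern("nn.depth_to_space", OpPattern.OPAQUE)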

@kice
Contributor Author

kice commented Dec 29, 2019

https://github.com/apache/incubator-tvm/blob/fadea9229c078f30469879ac6ab69920a2c0cbe0/python/tvm/relay/frontend/onnx.py#L542

We should decode it to str.
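A minimal sketch of the kind of fix meant here (the attribute access is an assumption; the point is that ONNX string attributes arrive as bytes):

# In the ONNX DepthToSpace converter: decode the "mode" attribute before
# comparing it against "DCR"/"CRD".
mode = attr.get("mode", b"DCR")
if isinstance(mode, bytes):
    mode = mode.decode("ascii")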

I think we should add more test cases for importing models.

@kice
Contributor Author

kice commented Dec 29, 2019

I cannot see any difference with the OPAQUE fusion pattern.

My thought is that, for every op, something (likely the GPU driver) wraps a data layout transpose before and after each op, since tensor cores like NHWC.

However, I remember that NHWC input for TVM is even slower than doing all the NCHW to NHWC transposes for the tensor cores, and we also need #4456 for NHWC conv to work.

@anijain2305
Contributor

@kice The ConvertLayout pass is merged now. I think you can try converting the whole graph from NCHW to NHWC, if that's the current bottleneck for you. Currently, the conversion is supported only from NHWC -> NCHW, but we can augment the pass to do the other way round as well.

@yzhliu
Member

yzhliu commented Jan 2, 2020

@kice could you confirm whether ConvertLayout helps to reduce the overhead?

@kice
Contributor Author

kice commented Jan 18, 2020

ConvertLayout only supports NHWC (TF) -> NCHW (TVM), but I used MXNet to do the layout conversion manually, and it turned out even slower.

I think it has something to do with the NVIDIA driver; I can still see layout conversions in nvprof.

@anijain2305
Contributor

https://docs.tvm.ai/dev/convert_layout.html

Please use the above tutorial to add support for NCHW to NHWC layout conversion.
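For anyone following along, invoking the pass looks roughly like this (a sketch; the exact ConvertLayout signature has changed across TVM versions, and mod is assumed to be the Relay module produced by a frontend):

import tvm
from tvm import relay

# Rewrite conv2d layouts to NHWC before compiling for the tensor core path;
# surrounding ops adapt to the converted convolutions.
seq = tvm.transform.Sequential(
    [relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]})]
)
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)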

tqchen closed this as completed Feb 26, 2020
@tqchen
Member

tqchen commented Feb 26, 2020

Closing for now as depth2space is implemented. Please feel free to open new threads or create discussions on the forum.
