
ResNet spending much time in CuArrays GC #149

Closed
KristofferC opened this issue Feb 5, 2019 · 25 comments
Labels
cuda array (Stuff about CuArray.) · performance (How fast can we go?)

Comments

@KristofferC
Contributor

I was profiling why a resnet model (https://github.com/KristofferC/resnet) was running extremely slowly on Flux.

Sprinkling some sections with timers from https://github.com/KristofferC/TimerOutputs.jl and training the model for a little bit, I got:

(edit: the timings below are stale due to changes in CuArrays, see https://github.com/JuliaGPU/CuArrays.jl/issues/273#issuecomment-461943376 for an update)

 ───────────────────────────────────────────────────────────────────────
                                Time                   Allocations
                        ──────────────────────   ───────────────────────
    Tot / % measured:        82.6s / 62.9%           9.09GiB / 3.32%

 Section        ncalls     time   %tot     avg     alloc   %tot      avg
 ───────────────────────────────────────────────────────────────────────
 gc true           468    47.8s  92.1%   102ms   4.88MiB  1.58%  10.7KiB
 crossentropy       32    1.91s  3.69%  59.8ms    159MiB  51.5%  4.98MiB
 conv            1.70k    1.61s  3.10%   948μs   85.6MiB  27.7%  51.7KiB
 dense              32    488ms  0.94%  15.2ms   59.1MiB  19.1%  1.85MiB
 gc false           32   81.0ms  0.16%  2.53ms    645KiB  0.20%  20.2KiB
 ───────────────────────────────────────────────────────────────────────

The gc true section refers to only this line:

https://github.com/JuliaGPU/CuArrays.jl/blob/61e25a2d239da77a5e8f3dc9746f9f62cd9e1380/src/memory.jl#L256

This line seems to be called far too often given how expensive a gc(true) call is.
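
(For reference, a rough sketch of how sections like the ones in the table above can be annotated with TimerOutputs.jl; the training loop, model and loss names here are hypothetical stand-ins, and the gc true / gc false sections actually live inside CuArrays' memory.jl rather than in user code.)

using TimerOutputs

const to = TimerOutput()

# Hypothetical training loop showing where such timing sections could go.
for (x, y) in batches
    out = @timeit to "conv"         model(x)
    l   = @timeit to "crossentropy" crossentropy(out, y)
    # ... backward pass and parameter update ...
end
print_timer(to)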

@maleadt
Member

maleadt commented Feb 5, 2019

Not much we can do at that point though, since the GPU is OOM; see JuliaGPU/CuArrays.jl#270. Querying memory and occasionally calling gc(false) might be viable too.

@KristofferC
Contributor Author

KristofferC commented Feb 5, 2019

I am already running a very small batch size on the resnet model. Since the GC pressure from the CUDA arrays is low, won't we always end up in this situation (GPU memory being full) sooner or later?

@maleadt
Member

maleadt commented Feb 5, 2019

Since the GC pressure from the CUDA arrays is low, won't we always end up in this situation (GPU memory being full) sooner or later?

Yes, and extending the main GC to keep track of additional memory pressure and/or separate object pools doesn't seem like it's going to happen. Maybe we could maintain our own pressure metric and occasionally call out to the GC during alloc, but that would still be costly since it would be processing Julia objects too. So I'm hoping some careful memory management a la JuliaGPU/CuArrays.jl#270 will suffice.
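
(A minimal sketch of what such a home-grown pressure metric could look like; the counter, the threshold and the maybe_collect name are all hypothetical, not CuArrays API.)

# Hypothetical pressure heuristic: count bytes handed out and trigger a cheap
# incremental collection once a threshold is crossed (not actual CuArrays code).
const outstanding = Ref(0)
const pressure_limit = Ref(1 << 30)    # e.g. 1 GiB of outstanding GPU buffers

function maybe_collect(bytes)
    outstanding[] += bytes
    if outstanding[] > pressure_limit[]
        GC.gc(false)                   # quick collection; gc(true) stays a last resort
        outstanding[] = 0
    end
    return nothing
end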

@KristofferC
Contributor Author

While it is possible to do manual memory management in some cases, what about temporaries like those in a + b + c + d? In order to keep the convenience of having a GC, it feels like some support from the runtime will be needed.

@MikeInnes
Contributor

In case it's useful, I threw this together: FluxML/Flux.jl#598. For a ResNet there's a risk that we're just holding on to too much memory, even with the GC, and making it thrash.

@KristofferC
Contributor Author

KristofferC commented Feb 5, 2019

I don't think it is the parameters of the model that are taking up the space, but rather all the temporaries we allocate until we hit an OOM and gc(true) joins the party.

julia> Base.summarysize(MODEL) / (1024^2)
195.42853546142578

@maleadt
Member

maleadt commented Feb 7, 2019

@KristofferC Do you have the code for this model? I'm having a look at some memory allocation performance improvements.

@KristofferC
Contributor Author

Hopefully https://github.com/KristofferC/resnet should work. Just message otherwise.

@maleadt
Member

maleadt commented Feb 7, 2019

I'm getting an image loading problem with https://github.com/KristofferC/resnet:

Error encountered while loading "tiny-imagenet-200/train/n03804744/images/n03804744_393.JPEG".
Fatal error:
<hang>

The image in question changes on every run. How much memory does your model need for the initial run? EDIT: it also fails on my 6GB Titan, so it doesn't seem like an OOM. The Fatal error: comes from the FileIO package, FWIW, without any additional error stack-trace though.

To see which allocations are the issue (i.e. where we need to free early), you could run with CUARRAYS_TRACE_POOL=true and move that code up to report on buffers when encountering an initial OOM. Or, even better, take the diff of pool_traces before and after gc(true) and report the ones that got freed.
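
(A rough sketch of that pool_traces diff idea, assuming pool_traces is a dictionary mapping live buffers to their allocation backtraces; that layout is an assumption on my part, not confirmed CuArrays API.)

# Assumed: pool_traces is a Dict from buffer to allocation backtrace, recorded
# when running with CUARRAYS_TRACE_POOL=true.
before = copy(pool_traces)             # snapshot before the collection
GC.gc(true)
freed = setdiff(keys(before), keys(pool_traces))
for buf in freed                       # candidates for an early free
    println("buffer freed by gc(true); allocated at:")
    Base.show_backtrace(stdout, before[buf])
end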

@KristofferC
Contributor Author

The Fatal error: comes from the FileIO package, FWIW, without any additional error stack-trace though.

Seems like you just can't load images using FileIO then (I could really gripe about FileIO, but that is another topic).

Sorry, but I am not sure how to debug where the allocations are coming from. Just enable CUARRAYS_TRACE_POOL for certain blocks and see how much is allocated in that block?

@maleadt
Member

maleadt commented Feb 8, 2019

Sorry, but I am not sure how to debug where the allocations are coming from. Just enable CUARRAYS_TRACE_POOL for certain blocks and see how much is allocated in that block?

I'm having a look at this right now (in the context of denizyuret/Knet.jl#417 but that shouldn't matter).

@maleadt
Member

maleadt commented Feb 8, 2019

@KristofferC Could you have a try with JuliaGPU/CuArrays.jl#277?

@KristofferC
Contributor Author

Will try!

@KristofferC
Contributor Author

OK, so now I get:

───────────────────────────────────────────────────────────────────────
                                Time                   Allocations
                        ──────────────────────   ───────────────────────
    Tot / % measured:        67.0s / 16.8%           9.16GiB / 3.30%

 Section        ncalls     time   %tot     avg     alloc   %tot      avg
 ───────────────────────────────────────────────────────────────────────
 conv            1.70k    6.59s  58.6%  3.89ms   85.6MiB  27.6%  51.7KiB
   gc true           7    806ms  7.17%   115ms    886KiB  0.28%   127KiB
 gc true            24    2.82s  25.0%   117ms   2.88MiB  0.93%   123KiB
 crossentropy       32    1.31s  11.7%  41.0ms    161MiB  52.0%  5.03MiB
 dense              32    508ms  4.52%  15.9ms   59.9MiB  19.3%  1.87MiB
 gc false            7   25.3ms  0.23%  3.62ms    284KiB  0.09%  40.6KiB
 ───────────────────────────────────────────────────────────────────────

So most of the time is now being spent outside my tracked regions, and I will need to update the instrumentation to cover more of the code. Note, however:

  • Total time went from 82 to 67 seconds.
  • The conv call itself now seems significantly slower (maybe the deallocation just moved to happen inside it?)

Previously, the graph of GPU usage looked like:

[image: GPU usage graph (before)]

Now it looks like:

[image: GPU usage graph (after)]

@KristofferC
Contributor Author

For reference, in Pytorch with the same model it looks like:

[image: GPU usage graph for PyTorch with the same model]

(and the fans actually start spinning on my GPU :P)

@KristofferC
Contributor Author

It's weird; it is as if I am hitting different "modes" on the GPU. This time I had quite a fast run:

 ───────────────────────────────────────────────────────────────────────
                                Time                   Allocations
                        ──────────────────────   ───────────────────────
    Tot / % measured:        52.9s / 19.8%           9.15GiB / 3.27%

 Section        ncalls     time   %tot     avg     alloc   %tot      avg
 ───────────────────────────────────────────────────────────────────────
 conv            1.70k    6.58s  62.8%  3.88ms   84.8MiB  27.7%  51.2KiB
   gc true          14    1.49s  14.2%   106ms   1.11MiB  0.36%  81.5KiB
 crossentropy       32    1.70s  16.2%  53.2ms    160MiB  52.2%  5.00MiB
 gc true            16    1.69s  16.1%   105ms   1.66MiB  0.54%   106KiB
 dense              32    517ms  4.93%  16.2ms   59.9MiB  19.6%  1.87MiB
 ───────────────────────────────────────────────────────────────────────

@maleadt
Member

maleadt commented Feb 8, 2019

I'm seeing something similar... denizyuret/Knet.jl#417 (comment)
Really unsure what's causing this.

EDIT: found some logic bugs in the manager though. Fixing those, although that shouldn't cause the nondeterminism we're seeing.

@maleadt
Member

maleadt commented Feb 8, 2019

With those fixes I'm no longer seeing the changing behavior.

Anyhow, I'm pretty sure I got rid of the costly gc(true) calls, and the allocator is in much better shape now, so I'm going to go ahead and close this issue.

If you have any more timing results or insights as to where your resnet model might run into CuArrays problems, just open a new issue. I won't have any time to look into profiling it myself though.

@maleadt maleadt closed this as completed Feb 8, 2019
@maleadt
Member

maleadt commented Feb 8, 2019

Ah, found the FileIO issue: it needed ImageMagick.jl. Strange how that error didn't come through when running under your benchmark script (interrupting the hang showed a stack trace into a yield).

@maleadt
Member

maleadt commented Feb 8, 2019

julia> CuArrays.@time main(16, 5)
Epoch: 1, Batch 1 / 5, 0.67 Images / sec 
Epoch: 1, Batch 2 / 5, 4.17 Images / sec 
Epoch: 1, Batch 3 / 5, 4.10 Images / sec 
Epoch: 1, Batch 4 / 5, 4.41 Images / sec 
Epoch: 1, Batch 5 / 5, 4.31 Images / sec 
 38.837985 seconds (66.33 M CPU allocations: 3.727 GiB, 22.26% gc time) (6.27 k GPU allocations: 39.669 GiB, 45.79% gc time of which 50.69% spent allocating)

GC time is definitely still relevant.

@maleadt maleadt reopened this Feb 8, 2019
@maleadt maleadt changed the title gc(true) can be prohibitively slow to call ResNet spending much time in CuArrays GC Feb 8, 2019
@maleadt
Member

maleadt commented Feb 8, 2019

julia> TimerOutputs.reset_timer!(CuArrays.time_alloc); main(16, 1); CuArrays.time_alloc
Epoch: 1, Batch 1 / 1, 3.06 Images / sec 
 ───────────────────────────────────────────────────────────────────────────────────
                                            Time                   Allocations      
                                    ──────────────────────   ───────────────────────
          Tot / % measured:              5.25s / 54.1%            389MiB / 2.28%    

 Section                    ncalls     time   %tot     avg     alloc   %tot      avg
 ───────────────────────────────────────────────────────────────────────────────────
 pooled alloc                1.25k    2.84s   100%  2.26ms   8.86MiB  100%   7.24KiB
   step 4: reclaim unused       63    1.28s  45.3%  20.4ms   28.2KiB  0.31%        -
     reclaim                    63    299ms  10.5%  4.75ms         -  0.00%        -
     scan                       63   84.3μs  0.00%  1.34μs   20.7KiB  0.23%        -
   step 5: gc(true)              3    513ms  18.1%   171ms    310KiB  3.42%   103KiB
   step 3: gc(false)            93    485ms  17.1%  5.22ms    162KiB  1.78%  1.74KiB
   step 2: try alloc           656    474ms  16.7%   723μs   38.9KiB  0.43%        -
   step 1: check pool        1.25k    167μs  0.01%   133ns         -  0.00%        -
 scan                            1   1.60μs  0.00%  1.60μs         -  0.00%        -
 reclaim                         1   1.45μs  0.00%  1.45μs         -  0.00%        -
 ───────────────────────────────────────────────────────────────────────────────────

We are spending way too much time in there. There's no good option though: the reclaim is slow because it calls cudaMemFree, but the only alternative is gc(true), which is even slower.
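
(To make the table easier to follow, here is the step sequence paraphrased as illustrative Julia; check_pool, try_alloc and reclaim_unused are hypothetical stand-ins, not the actual memory.jl functions.)

# Illustrative paraphrase of the pooled allocation path shown in the table above;
# not the real CuArrays implementation.
function pooled_alloc(bytes)
    buf = check_pool(bytes)        # step 1: reuse a cached buffer if one fits
    buf !== nothing && return buf
    buf = try_alloc(bytes)         # step 2: plain allocation from CUDA
    buf !== nothing && return buf
    GC.gc(false)                   # step 3: quick GC, then retry
    buf = try_alloc(bytes)
    buf !== nothing && return buf
    reclaim_unused()               # step 4: free cached buffers back to CUDA, retry
    buf = try_alloc(bytes)
    buf !== nothing && return buf
    GC.gc(true)                    # step 5: full GC as the last resort
    buf = try_alloc(bytes)
    buf !== nothing && return buf
    throw(OutOfMemoryError())
end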

The first thing to try is to sprinkle some unsafe_free! calls around. There's code (currently surrounding gc(true)) to print the allocations that would get freed if the GC did its job, so those are good candidates for an early free.

An alternative approach could be to split larger blocks in order to fulfill allocation requests, but that would require yet another rework of the allocator.
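
(For illustration, a minimal sketch of what sprinkling unsafe_free! over user code might look like; the normalize_columns function is a made-up example, not part of the resnet model.)

using CuArrays

# Made-up example: free a temporary as soon as it is no longer needed, instead of
# waiting for the Julia GC to notice that it became unreachable.
function normalize_columns(A::CuArray)
    s = sum(A; dims=1)            # temporary GPU buffer
    B = A ./ s
    CuArrays.unsafe_free!(s)      # hand the temporary straight back to the pool
    return B
end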

@KristofferC
Contributor Author

You've done so much on the CuArrays.jl side, maybe it's fair we do some on the Flux side ;).

@maleadt
Member

maleadt commented Feb 9, 2019

Yeah, seeing how JuliaGPU/CuArrays.jl#279 (comment) shows that the allocator performs really well on the CuArrays test suite (which also allocates a ton), I'd say that adapting the layers higher up will benefit us more.

I tried returning larger buffers, but that increases memory pressure too much to be beneficial on the ResNet model. I'll leave this issue open until I have time to develop some better tools for tracing which outstanding objects are hurting the allocator.

@maleadt
Member

maleadt commented Sep 10, 2019

Took a while, but finally having another look at this.

Working with a super simple CuArrays allocator (a straight CUDA malloc that calls into the GC and retries when it fails, and frees memory directly), the resnet model by @KristofferC only works if I allow 6GB of GPU memory to be allocated. This is without pooling, so there are no additional allocations and no memory set aside or kept alive. That's pretty bad, seeing https://github.com/JuliaGPU/CuArrays.jl/issues/273#issuecomment-460661339 and how PyTorch consumes about 4GB (and that is with pooled memory, so its actual usage is even lower).
Thoughts about that memory usage, @MikeInnes?
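
(For concreteness, a minimal sketch of the pool-less strategy described above; driver_alloc and driver_free stand in for wrappers around the raw CUDA malloc/free calls, with driver_alloc assumed to return nothing on an out-of-memory error. These are assumptions for illustration, not actual CuArrays code.)

# Sketch of the simple allocator: no pool, just the driver plus a GC fallback.
function alloc_or_gc(driver_alloc, bytes)
    ptr = driver_alloc(bytes)
    ptr === nothing || return ptr
    GC.gc(true)                   # collect unreachable CuArrays, then retry once
    ptr = driver_alloc(bytes)
    ptr === nothing || return ptr
    throw(OutOfMemoryError())
end

# Freeing is just as direct: no caching, the memory goes straight back to CUDA.
free_direct(driver_free, ptr) = driver_free(ptr)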

@maleadt maleadt transferred this issue from JuliaGPU/CuArrays.jl May 27, 2020
@maleadt maleadt added the cuda array and performance labels May 27, 2020
@maleadt
Member

maleadt commented Mar 2, 2021

Let's close this; see #137 (comment). Please open new issues with updated MWEs if the problem persists.
