Free CuArrays in the reverse pass #1340

Closed
mcabbott wants to merge 7 commits


@mcabbott
Member

mcabbott commented Dec 18, 2022

This adds:

  • A flag to Context to indicate that the pullback will never be called twice -- set to true for gradient, false for jacobian
  • Modifications to many rules, esp. for broadcasting, so that y=f(x) in the forward pass has finalize(y) in the reverse. This increases the largest size of Flux model which can run on a given GPU.
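
A rough sketch of the idea in the two bullets above, using toy stand-ins rather than Zygote's real types. `ToyContext`, `release!`, `maybe_release!` and `tanh_forward` are hypothetical names made up for this illustration, and `empty!` on a CPU vector stands in for freeing GPU memory:

```julia
# Toy stand-in for a context carrying a "pullback is called at most once" flag.
# All names in this block are hypothetical; none of this is Zygote's actual code.
struct ToyContext{Once} end
only_once(::ToyContext{Once}) where {Once} = Once

# Stand-in for freeing device memory: empty the vector so later misuse is visible.
release!(y::Vector) = (empty!(y); nothing)
maybe_release!(cx::ToyContext, y) = only_once(cx) ? release!(y) : nothing

# Forward pass of y = tanh.(x), returning y and a pullback that reuses y.
function tanh_forward(cx::ToyContext, x::Vector{Float64})
    y = tanh.(x)
    function tanh_pullback(dy)
        dx = dy .* (1 .- y .^ 2)   # the reverse pass needs the forward result
        maybe_release!(cx, y)      # after a one-shot pullback, y is no longer needed
        return dx
    end
    return y, tanh_pullback
end

# gradient-like use: the pullback runs once, so y can be released inside it.
y, back = tanh_forward(ToyContext{true}(), [0.0, 1.0])
dx = back([1.0, 1.0])

# jacobian-like use: the pullback runs once per output, so y must survive.
y2, back2 = tanh_forward(ToyContext{false}(), [0.0, 1.0])
row1 = back2([1.0, 0.0])
row2 = back2([0.0, 1.0])
```

Note that in this toy the released `y` is also the value handed back to the caller, so it is only a model of what happens to intermediate arrays inside a larger computation.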

Applying such modifications everywhere led to many errors, some from rules like y = x .+ false which return y === x under Zygote. So they now require a separate macro @adjoint_final.
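
To illustrate why a rule returning `y === x` is a problem, here is a small CPU-only sketch. `plus_false`, `plus_zero` and `poison!` are hypothetical helpers for this example, and filling with `NaN` stands in for freeing the memory:

```julia
# Hypothetical rule that skips the allocation and returns its input unchanged,
# mimicking the y === x behaviour described above for x .+ false.
plus_false(x::Vector{Float64}) = x

# A rule that allocates a fresh output, for comparison.
plus_zero(x::Vector{Float64}) = x .+ 0.0

# Stand-in for freeing: overwrite the contents so any later read is obviously wrong.
poison!(y::Vector{Float64}) = (fill!(y, NaN); nothing)

x = [1.0, 2.0]

y_fresh = plus_zero(x)
poison!(y_fresh)            # safe: y_fresh owns its own memory
x_after_fresh = copy(x)     # still [1.0, 2.0]

y_alias = plus_false(x)
is_alias = y_alias === x    # true: same array object
poison!(y_alias)            # unsafe: this also destroys x
x_after_alias = copy(x)     # now all NaN
```

Finalizing the output is only harmless when the rule is known to allocate it, which is one way to read the motivation for making the behaviour opt-in.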

At present this modification is applied to all CR rrules. This is probably unsafe and we should revert 2524163. It's unclear how best to opt in within ChainRules. Xref JuliaDiff/ChainRulesCore.jl#592 about the idea of a flag, but I'm not entirely sure that's the right approach.

Explicit finalising won't work well with thunks. That doesn't matter at all yet, but might after #966.

It also does not work with second derivatives, hence is disabled. Other uses of the context flag (like testing only_once(cfg) & then over-writing some array) probably also need to be disabled.

Needs FluxML/ZygoteRules.jl#23, so CI will fail. Locally there is one failure and one failure to fail (an unexpected pass):

Global Params: Error During Test at /Users/me/.julia/dev/Zygote/test/features.jl:399
  Got exception outside of a @test
  KeyError: key :(Main.global_param) not found
  Stacktrace:
    [1] getindex(d::IdDict{Any, Any}, key::Any)
      @ Base ./iddict.jl:108
    [2] macro expansion
      @ ~/.julia/dev/Zygote/test/features.jl:404 [inlined]

Compiler: Error During Test at /Users/me/.julia/dev/Zygote/test/compiler.jl:35
 Unexpected Pass
 Expression: trace_contains(bt, :badly, "compiler.jl", 24)
 Got correct result, please change to @test if no longer broken.

@chengchingwen
Member

Could we combine this with JuliaDiff/ChainRulesCore.jl#592 ?

@mcabbott
Member Author

It's possible. I think that means having two distinct structs, ZygoteRuleConfig and ZygoteOnceRuleConfig or something.

At present, BTW, most of these maybe_finals seem not to be called & I'm not sure why.

@chengchingwen
Member

> It's possible. I think that means having two distinct structs, ZygoteRuleConfig and ZygoteOnceRuleConfig or something.

Or introduce another type parameter like ZygoteRuleConfig{once} where once?

> At present, BTW, most of these maybe_finals seem not to be called & I'm not sure why.

Do you mean the finalize is not called, or it is called but the memory is not freed?

@mcabbott
Member Author

mcabbott commented Dec 19, 2022

> Do you mean the finalize is not called

With something like this

Zygote.maybe_final(x::CuArray) = begin CNT[]+=1; CUDA.unsafe_free!(x); nothing end

a big ResNet gradient used to give me CNT[] == 3 afterwards. (I thought I had this working when I opened the PR.) [Fixed in 9f01eff.]

> type parameter like ZygoteRuleConfig{once} where once

But I don't think that fits CR's mechanism; the current struct is <: RuleConfig{Union{HasReverseMode,NoForwardsMode}} and the new ones would need different supertypes.

We could also think about changing it to <: RuleConfig{Union{HasReverseMode,NoForwardsMode}, true}, in which case matching Context{..., true} would be easy.

@chengchingwen
Member

> But I don't think that fits CR's mechanism; the current struct is <: RuleConfig{Union{HasReverseMode,NoForwardsMode}} and the new ones would need different supertypes.

Couldn't it be done like struct ZygoteRuleConfig{P<:PullbackCapability} <: RuleConfig{Union{HasReverseMode,NoForwardsMode,P}}?

@mcabbott
Member Author

Oh right, that ought to work.
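
A minimal sketch of how the type-parameter suggestion could dispatch, written against local stand-ins so that it runs without ChainRulesCore. `RuleConfig`, `HasReverseMode` and `NoForwardsMode` are redefined here only for the demo; `PullbackCapability` is taken from the suggestion above, while `SinglePullback`, `RepeatedPullback` and `ToyZygoteRuleConfig` are hypothetical names:

```julia
# Local stand-ins for the ChainRulesCore types named in the discussion.
abstract type RuleConfig{T} end
struct HasReverseMode end
struct NoForwardsMode end

# Hypothetical capability types, following the P<:PullbackCapability suggestion.
abstract type PullbackCapability end
struct SinglePullback <: PullbackCapability end     # pullback runs at most once
struct RepeatedPullback <: PullbackCapability end   # pullback may run many times

# One struct whose supertype depends on its type parameter.
struct ToyZygoteRuleConfig{P<:PullbackCapability} <: RuleConfig{Union{HasReverseMode,NoForwardsMode,P}} end

# Rules can opt in by requiring the capability in the Union, in the same style
# as dispatching on RuleConfig{>:HasReverseMode}.
only_once(::RuleConfig{>:SinglePullback}) = true
only_once(::RuleConfig) = false

# Existing-style dispatch keeps working for both variants.
has_reverse(::RuleConfig{>:HasReverseMode}) = true
has_reverse(::RuleConfig) = false

once_cfg = ToyZygoteRuleConfig{SinglePullback}()
many_cfg = ToyZygoteRuleConfig{RepeatedPullback}()
```

With this layout a single struct covers both cases, and rules that never mention the new capability are unaffected.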

Current status is that some arrays are freed too early (e.g. with Metalhead's ResNet, at addact(relu)) but it's hard to isolate. Still happens if I disable all thunks. In Zygote's tests, some failures due to too-early fill!(x, NaN) (included here as a test), perhaps related.


3 participants