missing adjoints for rfft and irfft #246

Open
avik-pal opened this issue Nov 8, 2024 · 7 comments
Labels
bug Something isn't working


@avik-pal
Collaborator

avik-pal commented Nov 8, 2024

I will generate an MWE in the morning. I hit this while testing FNOs (SciML/NeuralOperators.jl#52) with #245.

julia> ∇fno_compiled = @compile ∇fno(fno, ps, st, x)
error: FFT/IFFT/IRFFT take a complex tensor as input, but is given 'tensor<5x64x1024xf32>'.
LLVM ERROR: Failed to infer result type(s):
"stablehlo.fft"(...) {} : (tensor<5x64x1024xf32>) -> ( ??? )

[165838] signal 6 (-6): Aborted
in expression starting at /mnt/software/lux/NeuralOperators.jl/docs/src/tutorials/xla_compilation.md:80
unknown function (ip: 0x790d9d52a3f4)
gsignal at /usr/lib/libc.so.6 (unknown line)
abort at /usr/lib/libc.so.6 (unknown line)
_ZN4llvm18report_fatal_errorERKNS_5TwineEb.cold at /mnt/.julia/artifacts/0fb8e911ac752666a5be8cae1002f1b835e430b9/lib/libReactantExtra.so (unknown line)
_ZN4llvm18report_fatal_errorENS_9StringRefEb at /mnt/.julia/artifacts/0fb8e911ac752666a5be8cae1002f1b835e430b9/lib/libReactantExtra.so (unknown line)
_ZN4mlir6detail32reportFatalInferReturnTypesErrorERNS_14OperationStateE at /mnt/.julia/artifacts/0fb8e911ac752666a5be8cae1002f1b835e430b9/lib/libReactantExtra.so (unknown line)
_ZN4mlir9stablehlo5FftOp5buildERNS_9OpBuilderERNS_14OperationStateENS_5ValueENS0_7FftTypeEN4llvm8ArrayRefIlEE at /mnt/.julia/artifacts/0fb8e911ac752666a5be8cae1002f1b835e430b9/lib/libReactantExtra.so (unknown line)
_ZNK12_GLOBAL__N_118FftOpRevDerivative24createReverseModeAdjointEPN4mlir9OperationERNS1_9OpBuilderEPNS1_6enzyme21MGradientUtilsReverseEN4llvm11SmallVectorINS1_5ValueELj6EEE.isra.0 at /mnt/.julia/artifacts/0fb8e911ac752666a5be8cae1002f1b835e430b9/lib/libReactantExtra.so (unknown line)
_ZN4mlir6enzyme6detail41ReverseAutoDiffOpInterfaceInterfaceTraits13FallbackModelIN12_GLOBAL__N_118FftOpRevDerivativeEE24createReverseModeAdjointEPKNS2_7ConceptEPNS_9OperationERNS_9OpBuilderEPNS0_21MGradientUtilsReverseEN4llvm11SmallVectorINS_5ValueELj6EEE at /mnt/.julia/artifacts/0fb8e911ac752666a5be8cae1002f1b835e430b9/lib/libReactantExtra.so (unknown line)
_ZN4mlir6enzyme26ReverseAutoDiffOpInterface24createReverseModeAdjointERNS_9OpBuilderEPNS0_21MGradientUtilsReverseEN4llvm11SmallVectorINS_5ValueELj6EEE at /mnt/.julia/artifacts/0fb8e911ac752666a5be8cae1002f1b835e430b9/lib/libReactantExtra.so (unknown line)
_ZN4mlir6enzyme12MEnzymeLogic10visitChildEPNS_9OperationERNS_9OpBuilderEPNS0_21MGradientUtilsReverseE at /mnt/.julia/artifacts/0fb8e911ac752666a5be8cae1002f1b835e430b9/lib/libReactantExtra.so (unknown line)
_ZN4mlir6enzyme12MEnzymeLogic13visitChildrenEPNS_5BlockES3_PNS0_21MGradientUtilsReverseE at /mnt/.julia/artifacts/0fb8e911ac752666a5be8cae1002f1b835e430b9/lib/libReactantExtra.so (unknown line)
_ZN4mlir6enzyme12MEnzymeLogic13differentiateEPNS0_21MGradientUtilsReverseERNS_6RegionES5_N4llvm12function_refIFvRNS_9OpBuilderEPNS_5BlockEEEESt8functionIFSt4pairINS_5ValueESG_ENS_4TypeEEE at /mnt/.julia/artifacts/0fb8e911ac752666a5be8cae1002f1b835e430b9/lib/libReactantExtra.so (unknown line)
_ZN4mlir6enzyme12MEnzymeLogic17CreateReverseDiffENS_19FunctionOpInterfaceESt6vectorI10DIFFE_TYPESaIS4_EES6_RNS0_13MTypeAnalysisES3_IbSaIbEESA_14DerivativeModebmNS_4TypeENS0_11MFnTypeInfoESA_Pv at /mnt/.julia/artifacts/0fb8e911ac752666a5be8cae1002f1b835e430b9/lib/libReactantExtra.so (unknown line)
_ZN12_GLOBAL__N_117DifferentiatePass16lowerEnzymeCallsERN4mlir21SymbolTableCollectionENS1_19FunctionOpInterfaceE.isra.0 at /mnt/.julia/artifacts/0fb8e911ac752666a5be8cae1002f1b835e430b9/lib/libReactantExtra.so (unknown line)
_ZN12_GLOBAL__N_117DifferentiatePass14runOnOperationEv at /mnt/.julia/artifacts/0fb8e911ac752666a5be8cae1002f1b835e430b9/lib/libReactantExtra.so (unknown line)
_ZN4mlir6detail17OpToOpPassAdaptor3runEPNS_4PassEPNS_9OperationENS_15AnalysisManagerEbj at /mnt/.julia/artifacts/0fb8e911ac752666a5be8cae1002f1b835e430b9/lib/libReactantExtra.so (unknown line)
_ZN4mlir6detail17OpToOpPassAdaptor11runPipelineERNS_13OpPassManagerEPNS_9OperationENS_15AnalysisManagerEbjPNS_16PassInstrumentorEPKNS_19PassInstrumentation18PipelineParentInfoE at /mnt/.julia/artifacts/0fb8e911ac752666a5be8cae1002f1b835e430b9/lib/libReactantExtra.so (unknown line)
_ZN4mlir11PassManager9runPassesEPNS_9OperationENS_15AnalysisManagerE at /mnt/.julia/artifacts/0fb8e911ac752666a5be8cae1002f1b835e430b9/lib/libReactantExtra.so (unknown line)
_ZN4mlir11PassManager3runEPNS_9OperationE at /mnt/.julia/artifacts/0fb8e911ac752666a5be8cae1002f1b835e430b9/lib/libReactantExtra.so (unknown line)
mlirPassManagerRunOnOp at /mnt/.julia/artifacts/0fb8e911ac752666a5be8cae1002f1b835e430b9/lib/libReactantExtra.so (unknown line)
mlirPassManagerRunOnOp at /mnt/software/lux/Reactant.jl/src/mlir/libMLIR_h.jl:5853 [inlined]
run! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Pass.jl:65 [inlined]
run_pass_pipeline! at /mnt/software/lux/Reactant.jl/src/Compiler.jl:251
#compile_mlir!#4 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:283
compile_mlir! at /mnt/software/lux/Reactant.jl/src/Compiler.jl:265 [inlined]
#30 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:720
context! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:76
unknown function (ip: 0x790d2f1841d6)
#compile_xla#29 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:717
compile_xla at /mnt/software/lux/Reactant.jl/src/Compiler.jl:712 [inlined]
#compile#34 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:744
compile at /mnt/software/lux/Reactant.jl/src/Compiler.jl:743
unknown function (ip: 0x790d2f18211d)
jl_apply at /cache/build/builder-demeter6-6/julialang/julia-master/src/julia.h:2157 [inlined]
do_call at /cache/build/builder-demeter6-6/julialang/julia-master/src/interpreter.c:126
eval_value at /cache/build/builder-demeter6-6/julialang/julia-master/src/interpreter.c:223
eval_stmt_value at /cache/build/builder-demeter6-6/julialang/julia-master/src/interpreter.c:174 [inlined]
eval_body at /cache/build/builder-demeter6-6/julialang/julia-master/src/interpreter.c:663
jl_interpret_toplevel_thunk at /cache/build/builder-demeter6-6/julialang/julia-master/src/interpreter.c:821
jl_toplevel_eval_flex at /cache/build/builder-demeter6-6/julialang/julia-master/src/toplevel.c:943
jl_toplevel_eval_flex at /cache/build/builder-demeter6-6/julialang/julia-master/src/toplevel.c:886
ijl_toplevel_eval_in at /cache/build/builder-demeter6-6/julialang/julia-master/src/toplevel.c:994
eval at ./boot.jl:430 [inlined]
include_string at ./loading.jl:2643
jl_apply at /cache/build/builder-demeter6-6/julialang/julia-master/src/julia.h:2157 [inlined]
jl_f__call_latest at /cache/build/builder-demeter6-6/julialang/julia-master/src/builtins.c:875
jl_apply at /cache/build/builder-demeter6-6/julialang/julia-master/src/julia.h:2157 [inlined]
do_apply at /cache/build/builder-demeter6-6/julialang/julia-master/src/builtins.c:831
#invokelatest#2 at ./essentials.jl:1055
jl_apply at /cache/build/builder-demeter6-6/julialang/julia-master/src/julia.h:2157 [inlined]
do_apply at /cache/build/builder-demeter6-6/julialang/julia-master/src/builtins.c:831
invokelatest at ./essentials.jl:1052
jl_apply at /cache/build/builder-demeter6-6/julialang/julia-master/src/julia.h:2157 [inlined]
do_apply at /cache/build/builder-demeter6-6/julialang/julia-master/src/builtins.c:831
#inlineeval#76 at /home/avikpal/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:271
inlineeval at /home/avikpal/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:268
#69 at /home/avikpal/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:181
withpath at /home/avikpal/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/repl.jl:276
#68 at /home/avikpal/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:179
hideprompt at /home/avikpal/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/repl.jl:38
#67 at /home/avikpal/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:150 [inlined]
with_logstate at ./logging/logging.jl:522
with_logger at ./logging/logging.jl:632 [inlined]
#66 at /home/avikpal/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:263
unknown function (ip: 0x790d5272df0f)
jl_apply at /cache/build/builder-demeter6-6/julialang/julia-master/src/julia.h:2157 [inlined]
jl_f__call_latest at /cache/build/builder-demeter6-6/julialang/julia-master/src/builtins.c:875
#invokelatest#2 at ./essentials.jl:1055 [inlined]
invokelatest at ./essentials.jl:1052
unknown function (ip: 0x790d62700422)
jl_apply at /cache/build/builder-demeter6-6/julialang/julia-master/src/julia.h:2157 [inlined]
do_apply at /cache/build/builder-demeter6-6/julialang/julia-master/src/builtins.c:831
#64 at /home/avikpal/.vscode/extensions/julialang.language-julia-1.127.2/scripts/packages/VSCodeServer/src/eval.jl:34
unknown function (ip: 0x790d6273a74f)
jl_apply at /cache/build/builder-demeter6-6/julialang/julia-master/src/julia.h:2157 [inlined]
start_task at /cache/build/builder-demeter6-6/julialang/julia-master/src/task.c:1202
Allocations: 255503837 (Pool: 255490289; Big: 13548); GC: 138
[1]    165838 IOT instruction (core dumped)  julia --project=docs --threads=auto

The forward pass compiles correctly, so I suspect the adjoint is implemented incorrectly.
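For reference, a hedged sketch of what the two reverse rules compute, written with naive O(N²) 1-D DFTs in plain Julia (no FFTW, Reactant or Enzyme). All names are hypothetical, and it assumes an unnormalized `rfft`, an `irfft` scaled by 1/N, and an `irfft` that ignores the imaginary parts of the DC and (even N) Nyquist bins. The adjoint of IRFFT receives a real cotangent, so it works out to an RFFT plus a per-bin rescaling; that is consistent with the crash above, where `FftOpRevDerivative` tried to build an FFT kind that requires complex input on a real `tensor<5x64x1024xf32>`.

```julia
using Random

# Hypothetical reference helpers (not Reactant/Enzyme API): naive O(N^2) 1-D DFTs.
# k and n are 0-based in the formulas and shifted by +1 when indexing arrays.

# rfft: R^N -> C^(N÷2+1), Y[k] = Σ_n x[n] * exp(-2πi k n / N), unnormalized.
function naive_rfft(x::AbstractVector{<:Real})
    N = length(x)
    return [sum(x[n+1] * cis(-2π * k * n / N) for n in 0:N-1) for k in 0:N÷2]
end

# Bins whose conjugate partner is not stored: DC and, for even N, Nyquist.
bin_weight(k, N) = (k == 0 || (iseven(N) && k == N ÷ 2)) ? 1 : 2

# irfft: C^(N÷2+1) -> R^N, scaled by 1/N so naive_irfft(naive_rfft(x), N) ≈ x.
# Assumption: imaginary parts of the DC/Nyquist bins are ignored.
function naive_irfft(Z::AbstractVector{<:Complex}, N::Integer)
    return [sum(bin_weight(k, N) * real(Z[k+1] * cis(2π * k * n / N)) for k in 0:N÷2) / N
            for n in 0:N-1]
end

# Adjoint of naive_rfft: complex cotangent G -> real cotangent for x.
# Derived from <G, rfft(x)> = Re Σ_k conj(G[k]) * Y[k]; a real-output, irfft-like map.
rfft_adjoint(G, N) = [real(sum(G[k+1] * cis(2π * k * n / N) for k in 0:N÷2)) for n in 0:N-1]

# Adjoint of naive_irfft: REAL cotangent g -> complex cotangent for Z.
# It is an rfft of g rescaled per bin, so it never needs a complex-input FFT.
irfft_adjoint(g, N) = [bin_weight(k, N) / N * Yk for (k, Yk) in zip(0:N÷2, naive_rfft(g))]

# Inner products: Re<a, b> for complex vectors, the ordinary dot product for real ones.
realdot(a::AbstractVector{<:Complex}, b::AbstractVector{<:Complex}) = real(sum(conj.(a) .* b))
realdot(a::AbstractVector{<:Real}, b::AbstractVector{<:Real}) = sum(a .* b)

# Dot-product test: <G, A(x)> must equal <A'(G), x> when A' is the adjoint of A.
function check_adjoints(N; atol = 1e-9)
    x, g = randn(N), randn(N)
    G, Z = randn(ComplexF64, N ÷ 2 + 1), randn(ComplexF64, N ÷ 2 + 1)
    ok_rfft = isapprox(realdot(G, naive_rfft(x)), realdot(rfft_adjoint(G, N), x); atol)
    ok_irfft = isapprox(realdot(g, naive_irfft(Z, N)), realdot(irfft_adjoint(g, N), Z); atol)
    return ok_rfft && ok_irfft
end
```

With these definitions, `check_adjoints(3)` and `check_adjoints(4)` (the sizes in the MWE below) return `true`; the same dot-product test is a cheap way to validate a hand-written reverse rule before wiring it into a compiler pass.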

@mofeing
Collaborator

mofeing commented Nov 8, 2024

It's very likely that the reverse pass is wrong; it was fairly tricky to implement.

@wsmoses
Member

wsmoses commented Nov 8, 2024

Can you make an MWE of the MLIR? @mofeing, do you want to give it another go?

@mofeing
Collaborator

mofeing commented Nov 8, 2024

I can try it, but I need an MWE to know which FFT kind is being used and the result types of the ops.
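Why the kind and result types matter: a simplified plain-Julia sketch of the `stablehlo.fft` type rules, assuming the StableHLO specification's constraints as we read them (zero-sized dimensions ignored). `FftKind` and `fft_result_type` are hypothetical names, not Reactant or MLIR API. The non-RFFT branches reject a real operand with the same message as the crash above.

```julia
# Hypothetical sketch (not Reactant or MLIR API) of the stablehlo.fft type rules.
# Shapes are row-major (MLIR order); the transform acts on the trailing dims.
@enum FftKind FFT IFFT RFFT IRFFT

function fft_result_type(kind::FftKind, operand_shape::Vector{Int},
                         operand_is_complex::Bool, fft_length::Vector{Int})
    k = length(fft_length)
    1 <= k <= 3 || throw(ArgumentError("fft_length must have 1 to 3 entries"))
    k <= length(operand_shape) || throw(ArgumentError("operand rank is smaller than fft_length"))
    batch = operand_shape[1:end-k]          # leading, untransformed dims
    trailing = operand_shape[end-k+1:end]   # dims the transform acts on
    halved = vcat(fft_length[1:end-1], fft_length[end] ÷ 2 + 1)
    if kind == RFFT
        # Real input; the last transformed dim shrinks to n ÷ 2 + 1 complex bins.
        operand_is_complex && throw(ArgumentError("RFFT takes a real tensor as input"))
        trailing == fft_length || throw(ArgumentError("real operand dims must equal fft_length"))
        return (vcat(batch, halved), true)
    end
    # FFT, IFFT and IRFFT all require a complex operand (the error in this issue).
    operand_is_complex ||
        throw(ArgumentError("FFT/IFFT/IRFFT take a complex tensor as input"))
    if kind == IRFFT
        # Complex half-spectrum in, real signal with trailing dims fft_length out.
        trailing == halved || throw(ArgumentError("IRFFT operand dims must be $halved"))
        return (vcat(batch, fft_length), false)
    end
    # FFT / IFFT: complex -> complex, shape unchanged.
    return (copy(operand_shape), true)
end
```

For example, `fft_result_type(RFFT, [3, 4], false, [3, 4])` gives `([3, 3], true)`, matching the `tensor<3x4xf32> -> tensor<3x3xcomplex<f32>>` RFFT in the dump below, while asking for `IRFFT` on a real operand throws.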

@avik-pal
Collaborator Author

avik-pal commented Nov 8, 2024

Yes, I will minimize this.

@avik-pal
Collaborator Author

avik-pal commented Nov 8, 2024

using FFTW, Reactant, Enzyme

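function test_fft(x)
    y = rfft(x)                 # real -> complex half-spectrum (first dim halved)
    z = irfft(y, size(x, 1))    # back to real; z ≈ x
    return sum(abs2, z)         # so the loss is sum(abs2, x) and its gradient is 2x
end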

function ∇test_fft(x)
    dx = Enzyme.make_zero(x)
    Enzyme.autodiff(Reverse, test_fft, Active, Duplicated(x, dx))
    return dx
end

x = randn(Float32, 4, 3)
x_ra = Reactant.ConcreteRArray(x)

@code_hlo test_fft(x_ra)     # forward pass only

@code_hlo ∇test_fft(x_ra)    # reverse pass, which needs the rfft/irfft adjoints

@avik-pal
Collaborator Author

avik-pal commented Nov 8, 2024

julia> @code_hlo optimize=false ∇test_fft(x_ra)
Module:
module {
  func.func private @abs2_broadcast_scalar(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.multiply %0, %0 : tensor<f32>
    %2 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    %3 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    return %2, %3 : tensor<f32>, tensor<f32>
  }
  func.func private @"Const{typeof(test_fft)}(Main.test_fft)_autodiff"(%arg0: tensor<3x4xf32>) -> (tensor<f32>, tensor<3x4xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<3x4xf32>) -> tensor<4x3xf32>
    %1 = stablehlo.transpose %0, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    %2 = stablehlo.fft %1, type =  RFFT, length = [3, 4] : (tensor<3x4xf32>) -> tensor<3x3xcomplex<f32>>
    %3 = stablehlo.transpose %2, dims = [1, 0] : (tensor<3x3xcomplex<f32>>) -> tensor<3x3xcomplex<f32>>
    %4 = stablehlo.transpose %3, dims = [1, 0] : (tensor<3x3xcomplex<f32>>) -> tensor<3x3xcomplex<f32>>
    %5 = stablehlo.fft %4, type =  IRFFT, length = [3, 4] : (tensor<3x3xcomplex<f32>>) -> tensor<3x4xf32>
    %6 = stablehlo.transpose %5, dims = [1, 0] : (tensor<3x4xf32>) -> tensor<4x3xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %7 = stablehlo.broadcast_in_dim %6, dims = [0, 1] : (tensor<4x3xf32>) -> tensor<4x3xf32>
    %8:2 = enzyme.batch @abs2_broadcast_scalar(%7) {batch_shape = array<i64: 4, 3>} : (tensor<4x3xf32>) -> (tensor<4x3xf32>, tensor<4x3xf32>)
    %9 = stablehlo.reduce(%8#0 init: %cst) applies stablehlo.add across dimensions = [0, 1] : (tensor<4x3xf32>, tensor<f32>) -> tensor<f32>
    %10 = stablehlo.transpose %9, dims = [] : (tensor<f32>) -> tensor<f32>
    %11 = stablehlo.transpose %0, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    return %10, %11 : tensor<f32>, tensor<3x4xf32>
  }
  func.func @main(%arg0: tensor<3x4xf32>) -> (tensor<3x4xf32>, tensor<3x4xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<3x4xf32>) -> tensor<4x3xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<4x3xf32>
    %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %1 = stablehlo.transpose %0, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    %2 = stablehlo.transpose %cst_0, dims = [] : (tensor<f32>) -> tensor<f32>
    %3 = stablehlo.transpose %cst, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    %4:2 = enzyme.autodiff @"Const{typeof(test_fft)}(Main.test_fft)_autodiff"(%1, %2, %3) {activity = [#enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_active>]} : (tensor<3x4xf32>, tensor<f32>, tensor<3x4xf32>) -> (tensor<3x4xf32>, tensor<3x4xf32>)
    %5 = stablehlo.transpose %4#0, dims = [1, 0] : (tensor<3x4xf32>) -> tensor<4x3xf32>
    %6 = stablehlo.transpose %4#1, dims = [1, 0] : (tensor<3x4xf32>) -> tensor<4x3xf32>
    %7 = stablehlo.transpose %6, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    %8 = stablehlo.transpose %5, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    return %7, %8 : tensor<3x4xf32>, tensor<3x4xf32>
  }
}
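The dump also shows how the shapes line up: the Julia 4×3 input appears in MLIR as `tensor<3x4xf32>` (the `dims = [1, 0]` transposes), so Julia's `rfft`, which halves the first dimension, shows up as an `RFFT` that halves the last one. A tiny plain-Julia sketch of that bookkeeping, assuming the MLIR shape is simply the Julia shape reversed (helper names are hypothetical):

```julia
# Julia (column-major) rfft over all dims halves the FIRST dimension: n -> n ÷ 2 + 1.
julia_rfft_shape(s::Tuple) = (s[1] ÷ 2 + 1, Base.tail(s)...)

# StableHLO RFFT halves the LAST transformed dimension of a row-major tensor.
hlo_rfft_shape(s::Tuple) = (s[1:end-1]..., s[end] ÷ 2 + 1)

# Assumption read off the dump: the MLIR shape is the Julia shape reversed.
to_mlir(s::Tuple) = reverse(s)

x_shape = (4, 3)                               # size(x) in the MWE
mlir_out = hlo_rfft_shape(to_mlir(x_shape))    # (3, 3), as in tensor<3x3xcomplex<f32>>
```

Reversing `mlir_out` gives back `julia_rfft_shape(x_shape)`, which is why the adjoint rules have to be written against the halved last dimension on the MLIR side.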

@avik-pal changed the title from "potentially incorrect gradients for fft?" to "missing adjoints for rfft and irfft" on Nov 11, 2024
@mofeing
Collaborator

mofeing commented Nov 13, 2024

Ah, sorry, I missed this MWE.
I can take charge of this.

@mofeing self-assigned this on Nov 13, 2024
@mofeing added the bug (Something isn't working) label on Nov 13, 2024

3 participants