missing adjoints for rfft and irfft #246
Labels: bug (Something isn't working)

I will generate an MWE in the morning. Testing out FNOs (SciML/NeuralOperators.jl#52) with #245. The forward pass compiles correctly, so I am guessing the adjoint isn't implemented correctly.

Comments
That the reverse pass is wrong is very probable. It was kind of hard to implement.

Can you make an MWE of the MLIR? @mofeing, do you want to give it another go?

I can try it, but I need an MWE to know which FFT kind is being used and the result type of the ops.

Yes, I will minimize this.
using FFTW, Reactant, Enzyme

function test_fft(x)
    y = rfft(x)
    z = irfft(y, size(x, 1))
    return sum(abs2, z)
end

function ∇test_fft(x)
    dx = Enzyme.make_zero(x)
    Enzyme.autodiff(Reverse, test_fft, Active, Duplicated(x, dx))
    return dx
end

x = randn(Float32, 4, 3)
x_ra = Reactant.ConcreteRArray(x)

@code_hlo test_fft(x_ra)
@code_hlo ∇test_fft(x_ra)
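For reference, the round trip irfft(rfft(x), size(x, 1)) reconstructs x exactly (up to floating point), so test_fft(x) == sum(abs2, x) and the true gradient is simply 2x. A quick plain-array sanity check, with no Reactant involved, that any correct adjoint implementation must reproduce:

using FFTW

x = randn(Float32, 4, 3)
@assert irfft(rfft(x), size(x, 1)) ≈ x  # lossless round trip
expected = 2 .* x                       # exact gradient of sum(abs2, z)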
julia> @code_hlo optimize=false ∇test_fft(x_ra)
Module:
module {
  func.func private @abs2_broadcast_scalar(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.multiply %0, %0 : tensor<f32>
    %2 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    %3 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    return %2, %3 : tensor<f32>, tensor<f32>
  }
  func.func private @"Const{typeof(test_fft)}(Main.test_fft)_autodiff"(%arg0: tensor<3x4xf32>) -> (tensor<f32>, tensor<3x4xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<3x4xf32>) -> tensor<4x3xf32>
    %1 = stablehlo.transpose %0, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    %2 = stablehlo.fft %1, type = RFFT, length = [3, 4] : (tensor<3x4xf32>) -> tensor<3x3xcomplex<f32>>
    %3 = stablehlo.transpose %2, dims = [1, 0] : (tensor<3x3xcomplex<f32>>) -> tensor<3x3xcomplex<f32>>
    %4 = stablehlo.transpose %3, dims = [1, 0] : (tensor<3x3xcomplex<f32>>) -> tensor<3x3xcomplex<f32>>
    %5 = stablehlo.fft %4, type = IRFFT, length = [3, 4] : (tensor<3x3xcomplex<f32>>) -> tensor<3x4xf32>
    %6 = stablehlo.transpose %5, dims = [1, 0] : (tensor<3x4xf32>) -> tensor<4x3xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %7 = stablehlo.broadcast_in_dim %6, dims = [0, 1] : (tensor<4x3xf32>) -> tensor<4x3xf32>
    %8:2 = enzyme.batch @abs2_broadcast_scalar(%7) {batch_shape = array<i64: 4, 3>} : (tensor<4x3xf32>) -> (tensor<4x3xf32>, tensor<4x3xf32>)
    %9 = stablehlo.reduce(%8#0 init: %cst) applies stablehlo.add across dimensions = [0, 1] : (tensor<4x3xf32>, tensor<f32>) -> tensor<f32>
    %10 = stablehlo.transpose %9, dims = [] : (tensor<f32>) -> tensor<f32>
    %11 = stablehlo.transpose %0, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    return %10, %11 : tensor<f32>, tensor<3x4xf32>
  }
  func.func @main(%arg0: tensor<3x4xf32>) -> (tensor<3x4xf32>, tensor<3x4xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<3x4xf32>) -> tensor<4x3xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<4x3xf32>
    %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %1 = stablehlo.transpose %0, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    %2 = stablehlo.transpose %cst_0, dims = [] : (tensor<f32>) -> tensor<f32>
    %3 = stablehlo.transpose %cst, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    %4:2 = enzyme.autodiff @"Const{typeof(test_fft)}(Main.test_fft)_autodiff"(%1, %2, %3) {activity = [#enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_active>]} : (tensor<3x4xf32>, tensor<f32>, tensor<3x4xf32>) -> (tensor<3x4xf32>, tensor<3x4xf32>)
    %5 = stablehlo.transpose %4#0, dims = [1, 0] : (tensor<3x4xf32>) -> tensor<4x3xf32>
    %6 = stablehlo.transpose %4#1, dims = [1, 0] : (tensor<3x4xf32>) -> tensor<4x3xf32>
    %7 = stablehlo.transpose %6, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    %8 = stablehlo.transpose %5, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    return %7, %8 : tensor<3x4xf32>, tensor<3x4xf32>
  }
}
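For whoever picks this up: the missing adjoints should reduce to the standard real-FFT pullbacks (compare the rrules for rfft/irfft shipped with AbstractFFTs). Below is a minimal 1-D Julia sketch; the MWE above uses a 2-D transform (length = [3, 4] in the MLIR), but the halved-spectrum bookkeeping only concerns the one real-to-complex dimension. The names rfft_pullback/irfft_pullback are illustrative helpers, not an existing API:

using FFTW

# Pullback of y = rfft(x) for real x of length d (so length(y) == d ÷ 2 + 1):
# halve the bins that appear twice in the full spectrum, then apply the
# unnormalized backward real transform brfft.
function rfft_pullback(ȳ::AbstractVector{<:Complex}, d::Integer)
    m = length(ȳ)
    scale = [k == 1 || (k == m && 2 * (k - 1) == d) ? 1 : 2 for k in 1:m]
    return brfft(ȳ ./ scale, d)
end

# Pullback of z = irfft(y, d), taking the cotangent z̄ of the real output:
# a forward rfft carrying the inverse transform's 1/d normalization, with
# the doubled bins scaled back up.
function irfft_pullback(z̄::AbstractVector{<:Real}, d::Integer)
    ȳ = rfft(z̄)
    m = length(ȳ)
    scale = [k == 1 || (k == m && 2 * (k - 1) == d) ? 1 : 2 for k in 1:m]
    return (scale ./ d) .* ȳ
end

# Dot-product (adjoint) test: Re⟨rfft(x), y⟩ must equal ⟨x, rfft_pullback(y)⟩.
d = 64
x = randn(d)
y = rfft(randn(d))  # a Hermitian-compatible cotangent (real DC/Nyquist bins)
@assert real(sum(conj(y) .* rfft(x))) ≈ sum(x .* rfft_pullback(y, d))

In StableHLO terms, and assuming stablehlo's IRFFT applies the usual 1/N normalization, the rfft pullback would be an IRFFT multiplied back by the total transform length (with interior bins halved first), and the irfft pullback an RFFT scaled by 1/N with interior bins doubled.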
avik-pal changed the title from "potentially incorrect gradients for fft" to "missing adjoints for rfft and irfft" on Nov 11, 2024
Ah sorry, I missed this MWE.