From 777daaa154f05f088d1dbf89859c052e1e7482c2 Mon Sep 17 00:00:00 2001 From: David Duvenaud Date: Sat, 27 Apr 2024 21:21:30 -0400 Subject: [PATCH] Got FFT working again with new syntax. --- lib/fft.dx | 50 ++++++++++++++++++++++++---------------------- tests/fft-tests.dx | 2 +- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/lib/fft.dx b/lib/fft.dx index 825a2c6fc..7f6637e1b 100644 --- a/lib/fft.dx +++ b/lib/fft.dx @@ -11,7 +11,7 @@ import complex '## Helper functions -def odd_sized_palindrome(mid:a, seq:n=>a) -> ((n `Either` () `Either` n)=>a) given (a, n|Ix) = +def odd_sized_palindrome(mid:a, seq:n=>a) -> ((n `Either` () `Either` n)=>a) given (a:Type, n|Ix) = # Turns sequence 12345 into 543212345. for i. case i of @@ -33,11 +33,11 @@ def butterfly_ixs(j':halfn, pow2:Nat) -> (n, n, n, n) given (halfn|Ix, n|Ix) = # Note: with fancier index sets, this might be replacable by reshapes. j = ordinal j' k = ((idiv j pow2) * pow2 * 2) + mod j pow2 - left_write_ix = unsafe_from_ordinal k - right_write_ix = unsafe_from_ordinal (k + pow2) + left_write_ix : n = unsafe_from_ordinal k + right_write_ix : n = unsafe_from_ordinal (k + pow2) - left_read_ix = unsafe_from_ordinal j - right_read_ix = unsafe_from_ordinal (j + size halfn) + left_read_ix : n = unsafe_from_ordinal j + right_read_ix : n = unsafe_from_ordinal (j + size halfn) (left_read_ix, right_read_ix, left_write_ix, right_write_ix) def power_of_2_fft( @@ -59,8 +59,9 @@ def power_of_2_fft( log2_half_n = unsafe_nat_diff log2_n 1 # TODO: use `i` as a proof that log2_n > 0 xRef := yield_accum (AddMonoid Complex) \bufRef. for j:((Fin log2_half_n)=>(Fin 2)). # Executes in parallel. + t = (Fin log2_n) => Fin 2 (left_read_ix, right_read_ix, - left_write_ix, right_write_ix) = butterfly_ixs j ipow2 + left_write_ix, right_write_ix) : (t, t, t, t) = butterfly_ixs j ipow2 # Read one element from the last buffer, scaled. angle = dir_const * (n_to_f $ mod (ordinal j) ipow2) / n_to_f ipow2 @@ -78,7 +79,7 @@ def power_of_2_fft( def pad_to_power_of_2( log2_m:Nat, pad_val:a, xs:n=>a - ) -> ((Fin log2_m)=>(Fin 2))=>a given (a, n|Ix) = + ) -> ((Fin log2_m)=>(Fin 2))=>a given (a:Type, n|Ix) = flatsize = intpow2 log2_m padded_flat = pad_to (Fin flatsize) pad_val xs unsafe_cast_table(to=(Fin log2_m)=>(Fin 2), padded_flat) @@ -91,21 +92,22 @@ def convolve_complex( # Pad and convert to Fourier domain. min_convolve_size = (size n + size m) -| 1 log_working_size = nextpow2 min_convolve_size - u_padded = pad_to_power_of_2 log_working_size zero u - v_padded = pad_to_power_of_2 log_working_size zero v + sn = size n + u_padded = pad_to_power_of_2 log_working_size (zero::Complex) u + v_padded = pad_to_power_of_2 log_working_size (zero::Complex) v spectral_u = power_of_2_fft ForwardFT u_padded spectral_v = power_of_2_fft ForwardFT v_padded # Pointwise multiply. - spectral_conv = for i. spectral_u[i] * spectral_v[i] + spectral_conv = for i:(Fin log_working_size)=>(Fin 2). spectral_u[i] * spectral_v[i] # Convert back to primal domain and undo padding. padded_conv = power_of_2_fft InverseFT spectral_conv slice padded_conv 0 (Either n m) def convolve(u:n=>Float, v:m=>Float) -> (Either n m =>Float) given (n|Ix, m|Ix) = - u' = for i. Complex u[i] 0.0 - v' = for i. Complex v[i] 0.0 + u' = for i:n. Complex u[i] 0.0 + v' = for i:m. Complex v[i] 0.0 ans = convolve_complex u' v' for i. ans[i].re @@ -114,14 +116,14 @@ def bluestein(x: n=>Complex) -> n=>Complex given (n|Ix) = # Converts the general FFT into a convolution, # which is then solved with calls to a power-of-2 FFT. im = Complex 0.0 1.0 - wks = for i. + wks = for i:n. i_squared = n_to_f $ sq $ ordinal i exp $ (-im) * (Complex (pi * i_squared / (n_to_f (size n))) 0.0) AsList(_, tailTable) = tail wks 1 back_and_forth = odd_sized_palindrome (head wks) tailTable - xq = for i. x[i] * wks[i] - back_and_forth_conj = for i. complex_conj back_and_forth[i] + xq = for i:n. x[i] * wks[i] + back_and_forth_conj = each back_and_forth complex_conj convolution = convolve_complex xq back_and_forth_conj convslice = slice convolution (unsafe_nat_diff (size n) 1) n for i. wks[i] * convslice[i] @@ -147,19 +149,19 @@ def ifft(xs: n=>Complex) -> n=>Complex given (n|Ix) = ret = power_of_2_fft InverseFT castx unsafe_cast_table(to=n, ret) else - unscaled_fft = fft (for i. complex_conj xs[i]) + unscaled_fft = fft (each xs complex_conj) for i. (complex_conj unscaled_fft[i]) / (n_to_f (size n)) -def fft_real(x: n=>Float) -> n=>Complex given (n|Ix) = fft for i. Complex x[i] 0.0 -def ifft_real(x: n=>Float) -> n=>Complex given (n|Ix) = ifft for i. Complex x[i] 0.0 +def fft_real(x: n=>Float) -> n=>Complex given (n|Ix) = fft for i:n. Complex x[i] 0.0 +def ifft_real(x: n=>Float) -> n=>Complex given (n|Ix) = ifft for i:n. Complex x[i] 0.0 def fft2(x: n=>m=>Complex) -> n=>m=>Complex given (n|Ix, m|Ix) = - x' = for i. fft x[i] - transpose for i. fft (transpose x')[i] + x' = for i:n. fft x[i] + transpose for i:m. fft (transpose x')[i] def ifft2(x: n=>m=>Complex) -> n=>m=>Complex given (n|Ix, m|Ix) = - x' = for i. ifft x[i] - transpose for i. ifft (transpose x')[i] + x' = for i:n. ifft x[i] + transpose for i:m. ifft (transpose x')[i] -def fft2_real(x: n=>m=>Float) -> n=>m=>Complex given (n|Ix, m|Ix) = fft2 for i j. Complex x[i,j] 0.0 -def ifft2_real(x: n=>m=>Float) -> n=>m=>Complex given (n|Ix, m|Ix) = ifft2 for i j. Complex x[i,j] 0.0 +def fft2_real(x: n=>m=>Float) -> n=>m=>Complex given (n|Ix, m|Ix) = fft2 for i:n j:m. Complex x[i,j] 0.0 +def ifft2_real(x: n=>m=>Float) -> n=>m=>Complex given (n|Ix, m|Ix) = ifft2 for i:n j:m. Complex x[i,j] 0.0 diff --git a/tests/fft-tests.dx b/tests/fft-tests.dx index ec0369bb3..9d0e70a4f 100644 --- a/tests/fft-tests.dx +++ b/tests/fft-tests.dx @@ -1,7 +1,7 @@ import complex import fft -:p map nextpow2 [0, 1, 2, 3, 4, 7, 8, 9, 1023, 1024, 1025] +:p each [0, 1, 2, 3, 4, 7, 8, 9, 1023, 1024, 1025] nextpow2 > [0, 0, 1, 2, 2, 3, 3, 4, 10, 10, 11] a : (Fin 4)=>Complex = arb $ new_key 0