Skip to content

Commit

Permalink
Got FFT working again with new syntax.
Browse files Browse the repository at this point in the history
  • Loading branch information
duvenaud committed Apr 28, 2024
1 parent 25e2e38 commit 777daaa
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 25 deletions.
50 changes: 26 additions & 24 deletions lib/fft.dx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import complex

'## Helper functions

def odd_sized_palindrome(mid:a, seq:n=>a) -> ((n `Either` () `Either` n)=>a) given (a, n|Ix) =
def odd_sized_palindrome(mid:a, seq:n=>a) -> ((n `Either` () `Either` n)=>a) given (a:Type, n|Ix) =
# Turns sequence 12345 into 543212345.
for i.
case i of
Expand All @@ -33,11 +33,11 @@ def butterfly_ixs(j':halfn, pow2:Nat) -> (n, n, n, n) given (halfn|Ix, n|Ix) =
# Note: with fancier index sets, this might be replacable by reshapes.
j = ordinal j'
k = ((idiv j pow2) * pow2 * 2) + mod j pow2
left_write_ix = unsafe_from_ordinal k
right_write_ix = unsafe_from_ordinal (k + pow2)
left_write_ix : n = unsafe_from_ordinal k
right_write_ix : n = unsafe_from_ordinal (k + pow2)

left_read_ix = unsafe_from_ordinal j
right_read_ix = unsafe_from_ordinal (j + size halfn)
left_read_ix : n = unsafe_from_ordinal j
right_read_ix : n = unsafe_from_ordinal (j + size halfn)
(left_read_ix, right_read_ix, left_write_ix, right_write_ix)

def power_of_2_fft(
Expand All @@ -59,8 +59,9 @@ def power_of_2_fft(
log2_half_n = unsafe_nat_diff log2_n 1 # TODO: use `i` as a proof that log2_n > 0
xRef := yield_accum (AddMonoid Complex) \bufRef.
for j:((Fin log2_half_n)=>(Fin 2)). # Executes in parallel.
t = (Fin log2_n) => Fin 2
(left_read_ix, right_read_ix,
left_write_ix, right_write_ix) = butterfly_ixs j ipow2
left_write_ix, right_write_ix) : (t, t, t, t) = butterfly_ixs j ipow2

# Read one element from the last buffer, scaled.
angle = dir_const * (n_to_f $ mod (ordinal j) ipow2) / n_to_f ipow2
Expand All @@ -78,7 +79,7 @@ def power_of_2_fft(
def pad_to_power_of_2(
log2_m:Nat,
pad_val:a, xs:n=>a
) -> ((Fin log2_m)=>(Fin 2))=>a given (a, n|Ix) =
) -> ((Fin log2_m)=>(Fin 2))=>a given (a:Type, n|Ix) =
flatsize = intpow2 log2_m
padded_flat = pad_to (Fin flatsize) pad_val xs
unsafe_cast_table(to=(Fin log2_m)=>(Fin 2), padded_flat)
Expand All @@ -91,21 +92,22 @@ def convolve_complex(
# Pad and convert to Fourier domain.
min_convolve_size = (size n + size m) -| 1
log_working_size = nextpow2 min_convolve_size
u_padded = pad_to_power_of_2 log_working_size zero u
v_padded = pad_to_power_of_2 log_working_size zero v
sn = size n
u_padded = pad_to_power_of_2 log_working_size (zero::Complex) u
v_padded = pad_to_power_of_2 log_working_size (zero::Complex) v
spectral_u = power_of_2_fft ForwardFT u_padded
spectral_v = power_of_2_fft ForwardFT v_padded

# Pointwise multiply.
spectral_conv = for i. spectral_u[i] * spectral_v[i]
spectral_conv = for i:(Fin log_working_size)=>(Fin 2). spectral_u[i] * spectral_v[i]

# Convert back to primal domain and undo padding.
padded_conv = power_of_2_fft InverseFT spectral_conv
slice padded_conv 0 (Either n m)

def convolve(u:n=>Float, v:m=>Float) -> (Either n m =>Float) given (n|Ix, m|Ix) =
u' = for i. Complex u[i] 0.0
v' = for i. Complex v[i] 0.0
u' = for i:n. Complex u[i] 0.0
v' = for i:m. Complex v[i] 0.0
ans = convolve_complex u' v'
for i. ans[i].re

Expand All @@ -114,14 +116,14 @@ def bluestein(x: n=>Complex) -> n=>Complex given (n|Ix) =
# Converts the general FFT into a convolution,
# which is then solved with calls to a power-of-2 FFT.
im = Complex 0.0 1.0
wks = for i.
wks = for i:n.
i_squared = n_to_f $ sq $ ordinal i
exp $ (-im) * (Complex (pi * i_squared / (n_to_f (size n))) 0.0)

AsList(_, tailTable) = tail wks 1
back_and_forth = odd_sized_palindrome (head wks) tailTable
xq = for i. x[i] * wks[i]
back_and_forth_conj = for i. complex_conj back_and_forth[i]
xq = for i:n. x[i] * wks[i]
back_and_forth_conj = each back_and_forth complex_conj
convolution = convolve_complex xq back_and_forth_conj
convslice = slice convolution (unsafe_nat_diff (size n) 1) n
for i. wks[i] * convslice[i]
Expand All @@ -147,19 +149,19 @@ def ifft(xs: n=>Complex) -> n=>Complex given (n|Ix) =
ret = power_of_2_fft InverseFT castx
unsafe_cast_table(to=n, ret)
else
unscaled_fft = fft (for i. complex_conj xs[i])
unscaled_fft = fft (each xs complex_conj)
for i. (complex_conj unscaled_fft[i]) / (n_to_f (size n))

def fft_real(x: n=>Float) -> n=>Complex given (n|Ix) = fft for i. Complex x[i] 0.0
def ifft_real(x: n=>Float) -> n=>Complex given (n|Ix) = ifft for i. Complex x[i] 0.0
def fft_real(x: n=>Float) -> n=>Complex given (n|Ix) = fft for i:n. Complex x[i] 0.0
def ifft_real(x: n=>Float) -> n=>Complex given (n|Ix) = ifft for i:n. Complex x[i] 0.0

def fft2(x: n=>m=>Complex) -> n=>m=>Complex given (n|Ix, m|Ix) =
x' = for i. fft x[i]
transpose for i. fft (transpose x')[i]
x' = for i:n. fft x[i]
transpose for i:m. fft (transpose x')[i]

def ifft2(x: n=>m=>Complex) -> n=>m=>Complex given (n|Ix, m|Ix) =
x' = for i. ifft x[i]
transpose for i. ifft (transpose x')[i]
x' = for i:n. ifft x[i]
transpose for i:m. ifft (transpose x')[i]

def fft2_real(x: n=>m=>Float) -> n=>m=>Complex given (n|Ix, m|Ix) = fft2 for i j. Complex x[i,j] 0.0
def ifft2_real(x: n=>m=>Float) -> n=>m=>Complex given (n|Ix, m|Ix) = ifft2 for i j. Complex x[i,j] 0.0
def fft2_real(x: n=>m=>Float) -> n=>m=>Complex given (n|Ix, m|Ix) = fft2 for i:n j:m. Complex x[i,j] 0.0
def ifft2_real(x: n=>m=>Float) -> n=>m=>Complex given (n|Ix, m|Ix) = ifft2 for i:n j:m. Complex x[i,j] 0.0
2 changes: 1 addition & 1 deletion tests/fft-tests.dx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import complex
import fft

:p map nextpow2 [0, 1, 2, 3, 4, 7, 8, 9, 1023, 1024, 1025]
:p each [0, 1, 2, 3, 4, 7, 8, 9, 1023, 1024, 1025] nextpow2
> [0, 0, 1, 2, 2, 3, 3, 4, 10, 10, 11]

a : (Fin 4)=>Complex = arb $ new_key 0
Expand Down

0 comments on commit 777daaa

Please sign in to comment.