1- # ffts
1+ module AbstractFFTsChainRulesCoreExt
2+
3+ using AbstractFFTs
4+ import ChainRulesCore
5+
26function ChainRulesCore. frule ((_, Δx, _), :: typeof (fft), x:: AbstractArray , dims)
37 y = fft (x, dims)
48 Δy = fft (Δx, dims)
@@ -46,7 +50,7 @@ function ChainRulesCore.frule((_, Δx, _), ::typeof(ifft), x::AbstractArray, dim
4650end
4751function ChainRulesCore. rrule (:: typeof (ifft), x:: AbstractArray , dims)
4852 y = ifft (x, dims)
49- invN = normalization (y, dims)
53+ invN = AbstractFFTs . normalization (y, dims)
5054 project_x = ChainRulesCore. ProjectTo (x)
5155 function ifft_pullback (ȳ)
5256 x̄ = project_x (invN .* fft (ChainRulesCore. unthunk (ȳ), dims))
@@ -66,7 +70,7 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
6670 # compute scaling factors
6771 halfdim = first (dims)
6872 n = size (x, halfdim)
69- invN = normalization (y, dims)
73+ invN = AbstractFFTs . normalization (y, dims)
7074 twoinvN = 2 * invN
7175 scale = reshape (
7276 [i == 1 || (i == n && 2 * (i - 1 ) == d) ? invN : twoinvN for i in 1 : n],
@@ -150,3 +154,5 @@ function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims)
150154 end
151155 return y, ifftshift_pullback
152156end
157+
158+ end # module
0 commit comments