Source code for pynlo.utility.fft

# -*- coding: utf-8 -*-
"""
Aliases to fast FFT implementations. Redirects to scipy methods if not
aliased here.
"""

# Public names exported by ``from ... import *``. The four transforms are
# defined in this module; the shift helpers are not defined here and are
# resolved through the module-level attribute fallback to the FFT backend.
__all__ = [
    "fft",
    "ifft",
    "rfft",
    "irfft",
    "fftshift",
    "ifftshift",
]


# %% Imports

import os

# Select the FFT backend once, at import time. `_fft` always ends up bound to
# `scipy.fft`; when pyFFTW is importable it is additionally registered as the
# global backend so that every `scipy.fft` call is dispatched to FFTW3.
try: # Attempt to set backend as FFTW3
    import pyfftw.interfaces.scipy_fft as backend
    import scipy.fft as _fft
    _fft.set_global_backend(backend)

    import pyfftw
    # Keep FFTW plans cached between calls so repeated transforms of the
    # same shape do not pay the planning cost again.
    pyfftw.interfaces.cache.enable()
    # os.cpu_count() returns None when the number of CPUs cannot be
    # determined; fall back to a single thread instead of storing None.
    pyfftw.config.NUM_THREADS = os.cpu_count() or 1
    print('Using FFTW FFT backend')

except ImportError: # If FFTW3 is not installed, fall back to native Scipy
    import scipy.fft as _fft
    print('Using Scipy FFT backend')


# %% Transforms
# 
# ---- FFTs
def fft(x, fsc=1.0, n=None, axis=-1, overwrite_x=False):
    """
    Compute a scaled 1D forward FFT along a single axis.

    The transform is evaluated by the module's FFT backend (FFTW3 or
    `scipy`) with ``norm='backward'`` and the result is then multiplied by
    `fsc`.

    Parameters
    ----------
    x : array_like
        Input array, can be complex.
    fsc : float, optional
        The forward transform scale factor. The default is 1.0.
    n : int, optional
        Length of the transformed axis of the output. If `n` is smaller than
        the length of the input, the input is cropped. If it is larger, the
        input is padded with zeros.
    axis : int, optional
        Axis over which to compute the FFT. The default is the last axis.
    overwrite_x : bool, optional
        If True, the contents of x may be overwritten during the
        computation. The default is False.

    Returns
    -------
    complex ndarray
        The transformed array.

    """
    spectrum = _fft.fft(
        x, n=n, axis=axis, overwrite_x=overwrite_x, norm='backward')
    return fsc * spectrum
#
def ifft(x, fsc=1.0, n=None, axis=-1, overwrite_x=False):
    """
    Compute a scaled 1D inverse FFT along a single axis.

    The inverse transform is evaluated by the module's FFT backend (FFTW3
    or `scipy`) with ``norm='backward'`` and the result is then multiplied
    by ``1/fsc``, so it undoes a forward transform taken with the same
    `fsc`.

    Parameters
    ----------
    x : array_like
        Input array, can be complex.
    fsc : float, optional
        The forward transform scale factor. Internally, this function sets
        the reverse transform scale factor as ``1/(n*fsc)``. The default is
        1.0.
    n : int, optional
        Length of the transformed axis of the output. If `n` is smaller than
        the length of the input, the input is cropped. If it is larger, the
        input is padded with zeros.
    axis : int, optional
        Axis over which to compute the inverse FFT. The default is the last
        axis.
    overwrite_x : bool, optional
        If True, the contents of x may be overwritten during the
        computation. The default is False.

    Returns
    -------
    complex ndarray
        The transformed array.

    """
    inverse_scale = 1/fsc
    signal = _fft.ifft(
        x, n=n, axis=axis, overwrite_x=overwrite_x, norm='backward')
    return inverse_scale * signal
# ---- Real FFTs
def rfft(x, fsc=1.0, n=None, axis=-1):
    """
    Compute a scaled 1D forward FFT of real input along a single axis.

    Only the positive-frequency half of the spectrum is returned, so the
    transformed axis of the complex output has length ``n//2 + 1``. The
    transform is evaluated by the module's FFT backend (FFTW3 or `scipy`)
    with ``norm='backward'`` and the result is then multiplied by `fsc`.

    Parameters
    ----------
    x : array_like
        Input array, must be real.
    fsc : float, optional
        The forward transform scale factor. The default is 1.0.
    n : int, optional
        Number of points to use along the transformed axis of the input. If
        `n` is smaller than the length of the input, the input is cropped.
        If it is larger, the input is padded with zeros.
    axis : int, optional
        Axis over which to compute the FFT. The default is the last axis.

    Returns
    -------
    complex ndarray
        The transformed array.

    """
    half_spectrum = _fft.rfft(x, n=n, axis=axis, norm='backward')
    return fsc * half_spectrum
#
def irfft(x, fsc=1.0, n=None, axis=-1):
    """
    Compute a scaled 1D inverse FFT that yields real output.

    The input is taken to hold only the positive-frequency half of a
    spectrum. The inverse transform is evaluated by the module's FFT backend
    (FFTW3 or `scipy`) with ``norm='backward'`` and the result is then
    multiplied by ``1/fsc``.

    When `n` is omitted, the transformed axis of the output has length
    ``2*(m-1)``, where `m` is the length of that axis in the input. An odd
    number of output points can only be obtained by passing `n` explicitly.

    Parameters
    ----------
    x : array_like
        Input array, can be complex.
    fsc : float, optional
        The forward transform scale factor. Internally, this function sets
        the reverse transform scale factor as ``1/(n*fsc)``. The default is
        1.0.
    n : int, optional
        Length of the transformed axis of the output. For `n` output points,
        ``n//2+1`` input points are necessary. If the input is longer than
        this, it is cropped. If it is shorter than this, it is padded with
        zeros.
    axis : int, optional
        Axis over which to compute the inverse FFT. The default is the last
        axis.

    Returns
    -------
    ndarray
        The transformed array.

    """
    inverse_scale = 1/fsc
    signal = _fft.irfft(x, n=n, axis=axis, norm='backward')
    return inverse_scale * signal
# %% Passthrough to Scipy methods


def __getattr__(name):
    """
    Resolve attributes that are not defined in this module.

    Called by Python only after normal module attribute lookup has failed.
    Regular names are looked up on the `scipy.fft` module bound to `_fft`.
    Dunder names are never forwarded: `scipy.fft` is a package, so passing
    through probes such as ``__path__`` would make this plain module look
    like a package to the import system and to introspection tools.

    Parameters
    ----------
    name : str
        Name of the attribute being looked up.

    Returns
    -------
    object
        The attribute of the same name found on `_fft`.

    Raises
    ------
    AttributeError
        If `name` is a dunder name or is not an attribute of `_fft`. The
        message names this module, not the backend.

    """
    is_dunder = name.startswith("__") and name.endswith("__")
    if not is_dunder:
        try:
            return getattr(_fft, name)
        except AttributeError:
            # Fall through so the error is raised once, outside the handler,
            # without the backend's own AttributeError chained onto it.
            pass
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def __dir__():
    """
    List the names reachable from this module.

    Returns
    -------
    list of str
        Sorted union of this module's globals and ``dir(_fft)``, so that the
        names served by the attribute fallback show up in lookups.

    """
    return sorted(set(globals()) | set(dir(_fft)))