fft_conv2d

fft_conv2d(x, weight)

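Computes a 2D depthwise convolution with ‘same’ padding using the FFT. Each of the hidden_dim channels is convolved with its own kernel, and the output keeps the input's spatial size. Because the FFT implements a true convolution, the result equals a cross-correlation (as computed by torch.nn.functional.conv2d) with the kernel flipped along both spatial axes.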
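Written out per batch index b and channel c, the operation is the following. This is a restatement read off the source listing on this page, using its names X_in, Y_in, K_x, K_y; kernel taps whose index falls outside 0 ≤ index < K are treated as zero, and N_x, N_y are the zero-padded FFT sizes.

```latex
% 'same' depthwise convolution, crop offset \lfloor K/2 \rfloor
y_{b,c}[i,j] = \sum_{m=0}^{X_{\mathrm{in}}-1} \sum_{n=0}^{Y_{\mathrm{in}}-1}
  x_{b,c}[m,n]\; w_c\!\left[i + \lfloor K_x/2 \rfloor - m,\; j + \lfloor K_y/2 \rfloor - n\right],
\qquad 0 \le i < X_{\mathrm{in}},\; 0 \le j < Y_{\mathrm{in}}

% Evaluated with the convolution theorem on an N_x \times N_y zero-padded grid,
% then cropped to rows \lfloor K_x/2 \rfloor .. \lfloor K_x/2 \rfloor + X_{\mathrm{in}} - 1
% and columns \lfloor K_y/2 \rfloor .. \lfloor K_y/2 \rfloor + Y_{\mathrm{in}} - 1:
y_{b,c} = \operatorname{crop}\Big(\mathcal{F}^{-1}\big[\mathcal{F}(x_{b,c}) \cdot \mathcal{F}(w_c)\big]\Big),
\quad N_x = \min\!\big(X_{\mathrm{in}} + \lceil K_x/2 \rceil,\; 2X_{\mathrm{in}}\big),
\quad N_y = \min\!\big(Y_{\mathrm{in}} + \lceil K_y/2 \rceil,\; 2Y_{\mathrm{in}}\big)
```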

Parameters:
  • x (torch.Tensor) – Input tensor of shape (batch_size, hidden_dim, x_dim_seq, y_dim_seq).

  • weight (torch.Tensor) – Weight tensor of shape (hidden_dim, x_dim_kernel, y_dim_kernel), holding one 2D kernel per channel.

Returns:

Output tensor with shape (batch_size, hidden_dim, x_dim_seq, y_dim_seq).

Return type:

torch.Tensor

import torch


def fft_conv2d(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    B, H, X_in, Y_in = x.shape
    H_k, K_x, K_y = weight.shape  # 3D weight: one (K_x, K_y) kernel per channel
    assert H == H_k, "Input and kernel must have the same number of channels (H)."

    # 1. Choose an FFT size large enough that circular wrap-around does not
    #    reach the cropped 'same' window. A full linear convolution would need
    #    X_in + K_x - 1 points; the crop in step 4 needs only this many.
    fft_shape = (
        min(X_in + (K_x + 1) // 2, 2 * X_in),
        min(Y_in + (K_y + 1) // 2, 2 * Y_in),
    )
    input_fft = torch.fft.rfft2(x, s=fft_shape)        # (B, H, N_x, N_y // 2 + 1)
    weight_fft = torch.fft.rfft2(weight, s=fft_shape)  # (H, N_x, N_y // 2 + 1)

    # 2. Apply the convolution theorem; weight_fft broadcasts over the batch dim.
    conv_fft = input_fft * weight_fft

    # 3. Inverse FFT back to the spatial domain (zero-padded circular convolution).
    output_full = torch.fft.irfft2(conv_fft, s=fft_shape)

    # 4. Crop to the 'same' size (X_in, Y_in). Starting at K // 2 centers the
    #    output on the input.
    crop_start_h = K_x // 2
    crop_start_w = K_y // 2
    output = output_full[:, :, crop_start_h : crop_start_h + X_in,
                         crop_start_w : crop_start_w + Y_in]
    return output
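A minimal sanity check of the listing above against PyTorch's direct convolution, assuming PyTorch 1.9 or later (for padding="same" in torch.nn.functional.conv2d). The function body is restated so the snippet runs on its own, and reference_same_conv is a hypothetical helper written only for this comparison. Since conv2d computes a cross-correlation, the reference flips the kernel; groups=H makes it depthwise.

```python
import torch
import torch.nn.functional as F


def fft_conv2d(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Compact restatement of fft_conv2d, so this check is self-contained."""
    _, H, X_in, Y_in = x.shape
    H_k, K_x, K_y = weight.shape
    assert H == H_k
    fft_shape = (min(X_in + (K_x + 1) // 2, 2 * X_in),
                 min(Y_in + (K_y + 1) // 2, 2 * Y_in))
    conv_fft = torch.fft.rfft2(x, s=fft_shape) * torch.fft.rfft2(weight, s=fft_shape)
    full = torch.fft.irfft2(conv_fft, s=fft_shape)
    sx, sy = K_x // 2, K_y // 2
    return full[:, :, sx:sx + X_in, sy:sy + Y_in]


def reference_same_conv(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference: depthwise 'same' convolution via F.conv2d.

    F.conv2d is a cross-correlation, so flip the kernel in both spatial dims
    to obtain a true convolution; groups=H applies one kernel per channel.
    """
    H = x.shape[1]
    kernel = torch.flip(weight, dims=(-2, -1)).unsqueeze(1)  # (H, 1, K_x, K_y)
    return F.conv2d(x, kernel, padding="same", groups=H)


torch.manual_seed(0)
x = torch.randn(2, 3, 8, 10, dtype=torch.float64)
for K_x, K_y in [(3, 3), (5, 2), (4, 7)]:
    w = torch.randn(3, K_x, K_y, dtype=torch.float64)
    out = fft_conv2d(x, w)
    ref = reference_same_conv(x, w)
    print((K_x, K_y), tuple(out.shape), torch.allclose(out, ref, atol=1e-10))
```

float64 is used so that FFT round-off stays far below the comparison tolerance. The kernel sizes include even ones, where PyTorch's 'same' padding is asymmetric; in the cases checked, that asymmetry lines up with the K // 2 crop offset.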

Note

Each spatial dimension of the kernel must be at most twice the corresponding input dimension: x_dim_kernel ≤ 2 · x_dim_seq and y_dim_kernel ≤ 2 · y_dim_seq.
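The bound in the note can be checked exhaustively on small sizes. Because the zero-padded 2D DFT treats each axis independently, a 1D analogue of one axis is enough. The sketch below is illustrative only: same_conv_fft and same_conv_direct are hypothetical helpers that mirror the FFT size rule, the K // 2 crop offset, and the direct sum. The sweep covers every kernel length 1 ≤ K ≤ 2X for inputs up to length 8.

```python
import torch


def same_conv_direct(x, w):
    """Hypothetical reference: 1D 'same' convolution with crop offset K // 2."""
    X, K = len(x), len(w)
    out = []
    for i in range(X):
        s = 0.0
        for j in range(X):
            t = i + K // 2 - j  # kernel tap index; out-of-range taps are zero
            if 0 <= t < K:
                s += x[j] * w[t]
        out.append(s)
    return out


def same_conv_fft(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Hypothetical 1D analogue of one axis of fft_conv2d."""
    X, K = x.shape[-1], w.shape[-1]
    N = min(X + (K + 1) // 2, 2 * X)  # same FFT size rule as fft_conv2d
    full = torch.fft.irfft(torch.fft.rfft(x, n=N) * torch.fft.rfft(w, n=N), n=N)
    start = K // 2
    return full[start:start + X]


torch.manual_seed(0)
for X in range(1, 9):
    for K in range(1, 2 * X + 1):  # every kernel length the note allows
        xs = torch.randn(X, dtype=torch.float64)
        ws = torch.randn(K, dtype=torch.float64)
        ref = torch.tensor(same_conv_direct(xs.tolist(), ws.tolist()), dtype=torch.float64)
        assert torch.allclose(same_conv_fft(xs, ws), ref, atol=1e-10), (X, K)
print("FFT size rule matched the direct sum for all K <= 2X, X <= 8")
```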