fft_conv2d#
- fft_conv2d(x, weight)#
Computes 2D depthwise convolution with ‘same’ padding using FFT.
- Parameters:
  - x (torch.Tensor) – Input tensor of shape (batch_size, hidden_dim, x_dim_seq, y_dim_seq).
  - weight (torch.Tensor) – Weight tensor of shape (hidden_dim, x_dim_kernel, y_dim_kernel).
- Returns:
  Output tensor of shape (batch_size, hidden_dim, x_dim_seq, y_dim_seq).
- Return type:
  torch.Tensor
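A quick shape sketch (illustrative, with made-up sizes) shows how these layouts interact in the frequency domain: after rfft2, the 3-D weight spectrum broadcasts against the 4-D input spectrum, so each channel is multiplied by its own kernel, which is what makes the convolution depthwise:

```python
import torch

B, H, X, Y = 2, 3, 8, 8            # made-up sizes for illustration
x = torch.randn(B, H, X, Y)        # input laid out as documented
w = torch.randn(H, 5, 5)           # one kernel per channel, as documented

fx = torch.fft.rfft2(x, s=(8, 8))  # (B, H, 8, 5): rfft2 halves the last dim
fw = torch.fft.rfft2(w, s=(8, 8))  # (H, 8, 5): broadcasts over the batch dim
prod = fx * fw                     # channel i is multiplied by kernel i only
print(prod.shape)                  # torch.Size([2, 3, 8, 5])
```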
```python
import torch

def fft_conv2d(x, weight):
    B, H, X_in, Y_in = x.shape
    h_k, K_x, K_y = weight.shape
    assert H == h_k, "Input and kernel must have the same number of channels (H)."
    # 1. Determine the FFT size. It is smaller than the full linear
    #    convolution size (X_in + K_x - 1); the resulting circular
    #    wrap-around only touches the region cropped away in step 4.
    fft_shape = (min(X_in + (K_x + 1) // 2, 2 * X_in),
                 min(Y_in + (K_y + 1) // 2, 2 * Y_in))
    input_fft = torch.fft.rfft2(x, s=fft_shape)
    weight_fft = torch.fft.rfft2(weight, s=fft_shape)  # (H, ...) broadcasts over batch
    # 2. Apply the convolution theorem: pointwise product in frequency space.
    conv_fft = input_fft * weight_fft
    # 3. Inverse FFT yields the spatial convolution result.
    output_full = torch.fft.irfft2(conv_fft, s=fft_shape)
    # 4. Crop to 'same' size: keep an (X_in, Y_in) window centred on the kernel.
    crop_start_h = K_x // 2
    crop_start_w = K_y // 2
    output = output_full[:, :,
                         crop_start_h : crop_start_h + X_in,
                         crop_start_w : crop_start_w + Y_in]
    return output
```
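The sizing and cropping arithmetic can be sanity-checked with a 1-D analogue against PyTorch's direct convolution (an illustrative sketch; all names here are local to this example, not part of the documented API):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 1, 16)              # (batch, channels, length)
w = torch.randn(5)                     # odd-length kernel
L, K = x.shape[-1], w.shape[-1]

# FFT size below the full linear-convolution length (L + K - 1): the
# circular wrap-around lands entirely in the region cropped away below.
n = min(L + (K + 1) // 2, 2 * L)
full = torch.fft.irfft(torch.fft.rfft(x, n=n) * torch.fft.rfft(w, n=n), n=n)
out = full[..., K // 2 : K // 2 + L]   # 'same'-size crop

# Reference: spatial convolution = cross-correlation with a flipped kernel.
ref = F.conv1d(x, w.flip(0).view(1, 1, -1), padding=K // 2)
assert torch.allclose(out, ref, atol=1e-4)
```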
Note
The kernel must be at most twice the input size in each spatial dimension: the FFT size is capped at (2 * X_in, 2 * Y_in), and a larger kernel would let circular wrap-around leak into the retained output window.