fft_conv1d

fft_conv1d(x, weight)

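fft_conv1d performs a non-causal 1D convolution of x with weight using FFT routines, so each output position may depend on both earlier and later inputs. Each of the dim channels is convolved with its own filter (the matching row of weight), and the output keeps the input's sequence length.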
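This page does not document the routine's internals. The following is a minimal sketch of how an FFT-based, non-causal, per-channel convolution with these shapes can be computed, assuming the filter is centered at index filter_dim // 2; the real fft_conv1d may use a different alignment, padding, or precision. The name fft_conv1d_sketch is hypothetical. The key step is zero-padding both signals to seq_len + filter_dim - 1 before the FFT, so that the pointwise product in the frequency domain yields a linear rather than a circular (wrap-around) convolution.

```python
import torch


def fft_conv1d_sketch(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical per-channel, non-causal FFT convolution (not the library's code).

    x:      (batch_size, dim, seq_len)
    weight: (dim, filter_dim)
    Assumption: the filter is centered at index filter_dim // 2, i.e.
        y[b, d, t] = sum_k weight[d, k] * x[b, d, t - k + filter_dim // 2],
    with x treated as zero outside [0, seq_len).
    """
    seq_len = x.shape[-1]
    filter_dim = weight.shape[-1]
    # Full linear-convolution length; padding to it avoids circular wrap-around.
    n = seq_len + filter_dim - 1
    x_f = torch.fft.rfft(x, n=n, dim=-1)       # (batch_size, dim, n // 2 + 1)
    w_f = torch.fft.rfft(weight, n=n, dim=-1)  # (dim, n // 2 + 1), broadcasts over batch
    full = torch.fft.irfft(x_f * w_f, n=n, dim=-1)  # (batch_size, dim, n)
    # Keep the seq_len outputs centered on the filter: the result uses both
    # past and future inputs, which is what makes it non-causal.
    start = filter_dim // 2
    return full[..., start:start + seq_len]


# Usage with small CPU tensors, following the documented shape convention.
x = torch.randn(2, 4, 16)
weight = torch.randn(4, 32)  # filter longer than the sequence, as in the example
y = fft_conv1d_sketch(x, weight)
print(y.shape)  # torch.Size([2, 4, 16])
```

The FFT route costs O(N log N) per channel with N = seq_len + filter_dim - 1, versus O(seq_len · filter_dim) for direct summation, which is the usual reason to prefer it for long filters such as the filter_dim = 1024 used in the example.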

Parameters:
  • x (torch.Tensor) – Input tensor of shape (batch_size, dim, seq_len).

  • weight (torch.Tensor) – Weight tensor of shape (dim, filter_dim).

Returns:

Output tensor of shape (batch_size, dim, seq_len).

Return type:

torch.Tensor

Example

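import torch
# fft_conv1d is assumed to be imported from its package (module path not shown on this page).

batch_size, dim, seq_len, filter_dim = 64, 128, 512, 1024
x = torch.randn(batch_size, dim, seq_len, device="cuda")  # requires a CUDA device
weight = torch.randn(dim, filter_dim, device="cuda")
y = fft_conv1d(x, weight)
print(y.shape)  # torch.Size([64, 128, 512])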