fft_causal_conv1d

fft_causal_conv1d(x, weight)

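Computes a causal 1D convolution using FFT routines instead of direct summation. Each channel is convolved with its own filter (weight has one row per channel), and the output at position t depends only on inputs at positions <= t.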
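The FFT approach can be illustrated with a minimal pure-PyTorch sketch. This is not this library's implementation: the helper names fft_causal_conv1d_ref and direct_causal_conv1d_ref are hypothetical, and the sketch assumes the convention y[t] = sum_k weight[k] * x[t - k] (tap 0 multiplies the current input), which may be flipped relative to the real kernel. Zero-padding to length seq_len + kernel_size - 1 makes the FFT's circular convolution equal to the linear one, and keeping only the first seq_len outputs yields the causal result.

```python
import torch
import torch.nn.functional as F


def fft_causal_conv1d_ref(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference: per-channel causal conv via FFT.

    Assumes y[b, d, t] = sum_k weight[d, k] * x[b, d, t - k],
    with x treated as zero before t = 0.
    """
    seq_len = x.shape[-1]
    kernel_size = weight.shape[-1]
    # Pad to the full linear-convolution length so nothing wraps around
    # from the end of the sequence into the early positions.
    n = seq_len + kernel_size - 1
    x_f = torch.fft.rfft(x, n=n)         # (batch, dim, n // 2 + 1)
    w_f = torch.fft.rfft(weight, n=n)    # (dim, n // 2 + 1), broadcasts over batch
    y = torch.fft.irfft(x_f * w_f, n=n)  # (batch, dim, n)
    # The first seq_len outputs are the causal part: y[t] only sees x[0..t].
    return y[..., :seq_len]


def direct_causal_conv1d_ref(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference: same convention, computed by direct summation."""
    dim, kernel_size = weight.shape
    # Left-pad by kernel_size - 1 so each output uses only current and past inputs.
    x_pad = F.pad(x, (kernel_size - 1, 0))
    # F.conv1d computes cross-correlation, so flip the taps to get convolution;
    # groups=dim applies one (1, kernel_size) filter per channel.
    w = torch.flip(weight, dims=[-1]).unsqueeze(1)  # (dim, 1, kernel_size)
    return F.conv1d(x_pad, w, groups=dim)


torch.manual_seed(0)
x = torch.randn(2, 4, 64, dtype=torch.float64)
w = torch.randn(4, 16, dtype=torch.float64)
y_fft = fft_causal_conv1d_ref(x, w)
y_direct = direct_causal_conv1d_ref(x, w)
print(y_fft.shape)                       # torch.Size([2, 4, 64])
print(torch.allclose(y_fft, y_direct))   # True
```

Both paths compute the same sums; the FFT version trades the per-output loop over kernel taps for a transform whose cost grows with the padded length rather than with kernel_size.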

Note

For kernel sizes >= 128, this is faster than causal_conv1d().
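A rough operation count suggests why the FFT path pays off for long kernels. Per batch element and channel, with sequence length L and kernel size K (and assuming the convolution convention below), direct summation scales with the product L K, while the FFT path scales with the padded transform length N:

```latex
\text{direct: } y_t = \sum_{k=0}^{K-1} w_k \, x_{t-k} \;\Rightarrow\; O(L K)
\qquad
\text{FFT: } O(N \log N), \quad N \ge L + K - 1
```

These are asymptotic counts only; the exact crossover point depends on implementation constants and hardware, so treat them as a rough guide rather than a derivation of the K >= 128 threshold.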

Parameters:
  • x (torch.Tensor) – Input tensor of shape (batch_size, dim, seq_len).

  • weight (torch.Tensor) – Weight tensor of shape (dim, kernel_size).

Returns:

Output tensor of shape (batch_size, dim, seq_len).

Return type:

torch.Tensor

Example

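import torch
# fft_causal_conv1d is imported from this package (import path not shown on this page).

batch_size, dim, seq_len, kernel_size = 64, 128, 4096, 256
# Inputs are allocated on the GPU.
x = torch.randn(batch_size, dim, seq_len, device="cuda")
weight = torch.randn(dim, kernel_size, device="cuda")
y = fft_causal_conv1d(x, weight)  # output has the same shape as x
print(y.shape)  # torch.Size([64, 128, 4096])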