fft_causal_conv1d
- fft_causal_conv1d(x, weight)
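Computes a causal 1D convolution of each channel of x with the matching row of weight, so that each output position depends only on the current and earlier input positions. The convolution is evaluated with FFT routines instead of direct summation.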
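The FFT approach can be sketched in NumPy as follows. This is an illustration, not the library's implementation: `fft_causal_conv1d_sketch` and `direct_causal_conv1d` are hypothetical names, NumPy stands in for the actual PyTorch/CUDA code, and the kernel orientation y[b, d, t] = Σ_k weight[d, k] · x[b, d, t − k] is an assumption (the real function may instead use the flipped, cross-correlation orientation of `torch.nn.functional.conv1d`). Zero-padding to at least seq_len + kernel_size − 1 samples makes the FFT's circular convolution equal to the linear one, and keeping the first seq_len outputs gives the causal result.

```python
import numpy as np


def fft_causal_conv1d_sketch(x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Hypothetical FFT-based causal depthwise conv (illustration only).

    x: (batch_size, dim, seq_len); weight: (dim, kernel_size).
    Assumed convention: y[b, d, t] = sum_k weight[d, k] * x[b, d, t - k],
    with x treated as zero for negative time indices.
    """
    seq_len = x.shape[-1]
    kernel_size = weight.shape[-1]
    # Pad to a power of two >= seq_len + kernel_size - 1 so that the
    # circular convolution computed via the FFT has no wrap-around.
    n = 1
    while n < seq_len + kernel_size - 1:
        n *= 2
    x_f = np.fft.rfft(x, n=n, axis=-1)         # (batch, dim, n // 2 + 1)
    w_f = np.fft.rfft(weight, n=n, axis=-1)    # (dim, n // 2 + 1), broadcasts over batch
    y = np.fft.irfft(x_f * w_f, n=n, axis=-1)  # (batch, dim, n)
    # The first seq_len samples are the causal outputs.
    return y[..., :seq_len]


def direct_causal_conv1d(x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Hypothetical reference: the same sum evaluated by direct summation."""
    seq_len = x.shape[-1]
    kernel_size = weight.shape[-1]
    y = np.zeros_like(x)
    # Taps with k >= seq_len never touch a valid input, so skip them.
    for k in range(min(kernel_size, seq_len)):
        # Shift x right by k steps (zeros enter on the left), scale by weight[:, k].
        y[..., k:] += weight[:, k, None] * x[..., : seq_len - k]
    return y


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 50))
w = rng.standard_normal((3, 16))
y = fft_causal_conv1d_sketch(x, w)
print(y.shape)                                   # (2, 3, 50)
print(np.allclose(y, direct_causal_conv1d(x, w)))  # True
```

The direct version does seq_len × kernel_size multiply-adds per channel, while the FFT version's cost is dominated by transforms of the padded length, which is the trade-off the Note below refers to.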
Note
This is faster than causal_conv1d() for kernel sizes >= 128.

- Parameters:
  x (torch.Tensor) – Input tensor of shape (batch_size, dim, seq_len).
  weight (torch.Tensor) – Weight tensor of shape (dim, kernel_size).
- Returns:
  Output tensor of shape (batch_size, dim, seq_len).
- Return type:
  torch.Tensor
Example
```python
import torch
# fft_causal_conv1d must be imported from the package that provides it.

batch_size, dim, seq_len, kernel_size = 64, 128, 4096, 256
x = torch.randn(batch_size, dim, seq_len, device="cuda")
weight = torch.randn(dim, kernel_size, device="cuda")
y = fft_causal_conv1d(x, weight)
print(y.shape)  # torch.Size([64, 128, 4096])
```
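For the sequence length used in the example above, a rough operation count suggests why an FFT path pays off at large kernel sizes. This is a back-of-the-envelope sketch, not a model of the library's kernels: it assumes the FFT path pads to the next power of two >= seq_len + kernel_size − 1 and counts only N · log2(N) for a single transform, ignoring the multiple transforms, complex arithmetic, and memory traffic a real implementation pays for. The helper names are hypothetical.

```python
# Rough per-row work estimates (one batch element, one channel).


def direct_macs(seq_len: int, kernel_size: int) -> int:
    """Multiply-accumulates for direct causal summation: output t uses
    min(t + 1, kernel_size) taps, since earlier inputs are zero padding."""
    return sum(min(t + 1, kernel_size) for t in range(seq_len))


def padded_fft_len(seq_len: int, kernel_size: int) -> int:
    """Smallest power of two >= seq_len + kernel_size - 1 (assumed padding rule)."""
    n = 1
    while n < seq_len + kernel_size - 1:
        n *= 2
    return n


def fft_work_estimate(seq_len: int, kernel_size: int) -> int:
    """N * log2(N) for the padded length N; constant factors are ignored."""
    n = padded_fft_len(seq_len, kernel_size)
    return n * (n.bit_length() - 1)  # exact log2 because n is a power of two


for kernel_size in (4, 128, 256):
    print(kernel_size, direct_macs(4096, kernel_size), fft_work_estimate(4096, kernel_size))
# 4 16378 106496
# 128 516160 106496
# 256 1015936 106496
```

The direct cost grows linearly with kernel_size, while the FFT estimate stays flat until the padded length crosses the next power of two. Because constant factors are ignored, these numbers do not reproduce the 128 threshold in the Note; that figure comes from the documentation, not from this estimate.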