fft_conv1d
- fft_conv1d(x, weight)
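Performs a non-causal 1D convolution using FFT routines. Each of the dim channels of x is convolved with its own filter (the matching row of weight, of length filter_dim), and the output keeps the input's sequence length.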
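As a hedged illustration, here is one plausible reading of "non-causal". It is an assumption, since the exact alignment the library uses is not stated here: a same-length convolution with a centred filter. Here L = seq_len, K = filter_dim, and x is taken as zero outside 0 <= t < L:

$$
y_{b,d,t} = \sum_{k=0}^{K-1} w_{d,k}\, x_{b,d,\,t + c - k}, \qquad c = \left\lfloor \tfrac{K-1}{2} \right\rfloor, \qquad t = 0, \dots, L-1.
$$

Zero-padding both sequences to a length N >= L + K - 1 lets the FFT compute this linear convolution without circular wraparound (convolution theorem). The cost is O(N log N) per channel, compared with O(LK) for direct summation:

$$
\mathrm{full}_{b,d} = \mathcal{F}_N^{-1}\big(\mathcal{F}_N(x_{b,d}) \odot \mathcal{F}_N(w_{d})\big), \qquad y_{b,d,t} = \mathrm{full}_{b,d,\,t + c}.
$$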
- Parameters:
  - x (torch.Tensor) – Input tensor of shape (batch_size, dim, seq_len).
  - weight (torch.Tensor) – Weight tensor of shape (dim, filter_dim).
- Returns:
  Output tensor of shape (batch_size, dim, seq_len).
- Return type:
  torch.Tensor
Example
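```python
import torch

# fft_conv1d must be imported from the package that provides it (import path not shown here).
# The tensors below are created on a CUDA device, so a GPU is required.
batch_size, dim, seq_len, filter_dim = 64, 128, 512, 1024  # filter_dim may exceed seq_len

x = torch.randn(batch_size, dim, seq_len, device="cuda")
weight = torch.randn(dim, filter_dim, device="cuda")

y = fft_conv1d(x, weight)
print(y.shape)  # torch.Size([64, 128, 512])
```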
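To make the shapes and the FFT step concrete, here is a minimal CPU sketch of a per-channel FFT convolution with the same input and output shapes. It is not the library's implementation. The names `fft_conv1d_sketch` and `direct_conv1d_reference` are hypothetical, and the centred alignment (offset `(filter_dim - 1) // 2`) is an assumption about what "non-causal" means here. The sketch is cross-checked against a grouped `torch.nn.functional.conv1d` applied to a flipped, zero-padded filter.

```python
import torch
import torch.nn.functional as F


def fft_conv1d_sketch(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical per-channel, non-causal FFT convolution (not the library's code).

    x:      (batch_size, dim, seq_len)
    weight: (dim, filter_dim)
    returns (batch_size, dim, seq_len)
    """
    seq_len = x.shape[-1]
    filter_dim = weight.shape[-1]
    # Zero-pad both signals to the full linear-convolution length so that the
    # circular convolution computed by the FFT does not wrap around.
    n = seq_len + filter_dim - 1
    x_f = torch.fft.rfft(x, n=n)        # (batch_size, dim, n // 2 + 1)
    w_f = torch.fft.rfft(weight, n=n)   # (dim, n // 2 + 1), broadcasts over batch
    full = torch.fft.irfft(x_f * w_f, n=n)  # (batch_size, dim, n): full linear convolution
    # Assumed non-causal alignment: centre the filter and keep seq_len outputs.
    start = (filter_dim - 1) // 2
    return full[..., start:start + seq_len]


def direct_conv1d_reference(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Same operation computed with grouped conv1d, used to check the FFT sketch."""
    dim, filter_dim = weight.shape
    start = (filter_dim - 1) // 2
    # conv1d computes cross-correlation, so flip the filter to get a true convolution,
    # and pad so that output index t lines up with full-convolution index t + start.
    x_pad = F.pad(x, (filter_dim - 1 - start, start))
    kernel = torch.flip(weight, dims=[-1]).unsqueeze(1)  # (dim, 1, filter_dim)
    return F.conv1d(x_pad, kernel, groups=dim)


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(2, 4, 16, dtype=torch.float64)
    for filter_dim in (5, 6, 40):  # odd, even, and longer than seq_len
        weight = torch.randn(4, filter_dim, dtype=torch.float64)
        y = fft_conv1d_sketch(x, weight)
        print(filter_dim, tuple(y.shape), torch.allclose(y, direct_conv1d_reference(x, weight)))
```

The padding to `seq_len + filter_dim - 1` is what turns the FFT's circular convolution into a linear one. If the real function uses a different alignment or circular boundary handling, only the padding length and the `start` offset in this sketch would need to change.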