b2b_causal_conv1d

b2b_causal_conv1d(x, weight_proj, weight_mixer, skip_bias)

Back-to-back causal 1D convolution. A fused kernel that performs a projection convolution, pre-gating, a mixer convolution, and post-gating in a single pass. Both convolutions are causal, meaning the output at each position depends only on the current and earlier positions in the sequence. In code terms,

y_gated = b2b_causal_conv1d(x, weight_proj, weight_mixer, skip_bias)

is equivalent to,

y = conv1d_proj(x)                            # projection convolution
z = y[:, 1::3, :] * y[:, 2::3, :]             # pre-gating
w = conv1d_mixer(z) + skip_bias[:, None] * z  # mixer convolution with skip connection
y_gated = y[:, ::3, :] * w                    # post-gating
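
For concreteness, here is a minimal pure-PyTorch sketch of the same computation, assuming both convolutions are depthwise and made causal by left-padding with kernel_size - 1 zeros (the conv1d_proj and conv1d_mixer placeholders above). The helper names are illustrative, not part of the API:

import torch
import torch.nn.functional as F

def causal_depthwise_conv1d(x, weight):
    # x: (batch, channels, seq_len); weight: (channels, kernel_size)
    channels, kernel_size = weight.shape
    x = F.pad(x, (kernel_size - 1, 0))  # left-pad so position t sees only positions <= t
    return F.conv1d(x, weight.unsqueeze(1), groups=channels)

def b2b_causal_conv1d_ref(x, weight_proj, weight_mixer, skip_bias):
    y = causal_depthwise_conv1d(x, weight_proj)   # projection convolution
    z = y[:, 1::3, :] * y[:, 2::3, :]             # pre-gating
    w = causal_depthwise_conv1d(z, weight_mixer)  # mixer convolution
    w = w + skip_bias[:, None] * z                # skip connection
    return y[:, ::3, :] * w                       # post-gating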

Note

The input tensor is expected to be of shape (batch_size, 3*dim, seq_len) where dim is the number of channels in the output.

If the mixer weights are used with an FFT-based convolution, they should be flipped along the last dimension:

weight_mixer = torch.flip(weight_mixer, [-1])
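
The flip is needed because torch.nn.functional.conv1d computes cross-correlation, whereas multiplying spectra after an FFT implements true convolution; flipping the kernel converts one convention into the other. A minimal sketch of the equivalence, with illustrative shapes:

import torch
import torch.nn.functional as F

z = torch.randn(2, 4, 10)
w = torch.randn(4, 3)

# direct causal conv (cross-correlation convention, left-padded)
direct = F.conv1d(F.pad(z, (w.shape[-1] - 1, 0)), w.unsqueeze(1), groups=4)

# FFT-based conv needs the flipped kernel; zero-pad to avoid circular wrap-around
w_fft = torch.flip(w, [-1])
n = z.shape[-1] + w.shape[-1] - 1
spectrum = torch.fft.rfft(z, n=n) * torch.fft.rfft(w_fft, n=n)
fft_based = torch.fft.irfft(spectrum, n=n)[..., : z.shape[-1]]

assert torch.allclose(direct, fft_based, atol=1e-4)
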
Parameters:
  • x (torch.Tensor) – Input tensor of shape (batch_size, 3*dim, seq_len).

  • weight_proj (torch.Tensor) – Projection weight tensor of shape (3*dim, kernel_size).

  • weight_mixer (torch.Tensor) – Mixer weight tensor of shape (dim, kernel_size).

  • skip_bias (torch.Tensor) – Skip bias tensor of shape (dim,).

Returns:

Output tensor of shape (batch_size, dim, seq_len).

Return type:

torch.Tensor

Example

import torch

batch_size, dim, seq_len, kernel_size = 2, 4, 10, 3
x = torch.randn(batch_size, 3*dim, seq_len, device="cuda")
weight_proj = torch.randn(3*dim, kernel_size, device="cuda")
weight_mixer = torch.randn(dim, kernel_size, device="cuda")
skip_bias = torch.randn(dim, device="cuda")
y_gated = b2b_causal_conv1d(x, weight_proj, weight_mixer, skip_bias)
print(y_gated.shape)  # torch.Size([2, 4, 10])
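
As a sanity check, the fused output can be compared against the pure-PyTorch reference sketched earlier, assuming the fused kernel uses the same left-padding convention:

y_ref = b2b_causal_conv1d_ref(x, weight_proj, weight_mixer, skip_bias)  # hypothetical helper from the sketch above
print(torch.allclose(y_gated, y_ref, atol=1e-4))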