b2b_causal_conv1d#
- b2b_causal_conv1d(x, weight_proj, weight_mixer, skip_bias)#
Back-to-back causal 1D convolution. Fused kernel performing a projection convolution, pre-gating, a mixer convolution, and post-gating. Both convolutions are causal, meaning each output position depends only on the current and previous positions in the sequence. In code terms,
y_gated = b2b_causal_conv1d(x, weight_proj, weight_mixer, skip_bias)
is equivalent to,
y = conv1d_proj(x)
z = y[:, 1::3, :] * y[:, 2::3, :]
h = mixer(z) + mixer.skip_bias * z
y_gated = y[:, ::3, :] * h
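The fused operation above can be sketched in plain PyTorch. This is a hedged reference, assuming both convolutions are depthwise (one filter per channel) with left padding for causality; `causal_depthwise_conv1d` and `b2b_causal_conv1d_ref` are illustrative names, not part of the kernel's API:

```python
import torch
import torch.nn.functional as F

def causal_depthwise_conv1d(x, weight):
    # x: (batch, channels, seq_len); weight: (channels, kernel_size).
    # Left-padding by kernel_size - 1 makes each output depend only on
    # the current and previous positions.
    channels, kernel_size = weight.shape
    x = F.pad(x, (kernel_size - 1, 0))
    return F.conv1d(x, weight.unsqueeze(1), groups=channels)

def b2b_causal_conv1d_ref(x, weight_proj, weight_mixer, skip_bias):
    # Projection conv over all 3*dim channels, then split channels
    # into three interleaved groups of size dim.
    y = causal_depthwise_conv1d(x, weight_proj)
    z = y[:, 1::3, :] * y[:, 2::3, :]                     # pre-gating
    h = causal_depthwise_conv1d(z, weight_mixer) \
        + skip_bias[None, :, None] * z                    # mixer + skip
    return y[:, ::3, :] * h                               # post-gating
```

Because each stage is causal, the composed operation is causal as well: truncating future inputs leaves earlier outputs unchanged.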
Note
The input tensor is expected to be of shape (batch_size, 3*dim, seq_len), where dim is the number of channels in the output. If the mixer weights are used with an FFT-based convolution, they should be flipped along the last dimension:
weight_mixer = torch.flip(weight_mixer, [-1])
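A quick way to see why the flip is needed: a causal depthwise conv1d (PyTorch's cross-correlation with left padding) equals a linear convolution with the time-flipped filter, truncated to the original sequence length. A minimal sketch of that identity; the function names here are illustrative, not part of this API:

```python
import torch
import torch.nn.functional as F

def causal_conv_direct(x, w):
    # Depthwise cross-correlation with left padding (PyTorch convention).
    # x: (batch, channels, seq_len); w: (channels, kernel_size).
    k = w.shape[-1]
    return F.conv1d(F.pad(x, (k - 1, 0)), w.unsqueeze(1), groups=w.shape[0])

def causal_conv_fft(x, w_flipped):
    # Linear convolution via FFT with the pre-flipped filter,
    # truncated to the input length.
    n = x.shape[-1] + w_flipped.shape[-1] - 1
    X = torch.fft.rfft(x, n=n)
    W = torch.fft.rfft(w_flipped, n=n)
    return torch.fft.irfft(X * W.unsqueeze(0), n=n)[..., : x.shape[-1]]
```

Both paths produce the same output up to float32 round-off, which is why FFT-based callers must pass `torch.flip(weight_mixer, [-1])`.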
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size, 3*dim, seq_len).
weight_proj (torch.Tensor) – Projection weight tensor of shape (3*dim, kernel_size).
weight_mixer (torch.Tensor) – Mixer weight tensor of shape (dim, kernel_size).
skip_bias (torch.Tensor) – Skip bias tensor of shape (dim,).
- Returns:
Output tensor of shape (batch_size, dim, seq_len).
- Return type:
torch.Tensor
Example
batch_size, dim, seq_len, kernel_size = 2, 4, 10, 3
x = torch.randn(batch_size, 3*dim, seq_len, device="cuda")
weight_proj = torch.randn(3*dim, kernel_size, device="cuda")
weight_mixer = torch.randn(dim, kernel_size, device="cuda")
skip_bias = torch.randn(dim, device="cuda")
y_gated = b2b_causal_conv1d(x, weight_proj, weight_mixer, skip_bias)
print(y_gated.shape)  # torch.Size([2, 4, 10])