causal_conv1d

causal_conv1d(
x,
weight,
bias=None,
activation='identity',
channel_last=False,
)

Depthwise causal 1D convolution with optional activation.

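Each channel is convolved with its own kernel. Causal means the output at time \(t\) depends only on inputs at times \(\le t\). Inputs before the start of the sequence are treated as zero, so the output has the same sequence length as the input.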

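\[y_{b,c,t} = \mathrm{activation}\left( \sum_{k=0}^{K-1} x_{b,c,t-k} \cdot w_{c,k} + \beta_c \right)\]
where \(K\) is kernel_size, \(w_{c,k}\) is the weight of channel \(c\) at lag \(k\), and \(\beta_c\) is bias[c] (taken as 0 when bias is None).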
Parameters:
  • x (torch.Tensor) – Input tensor of shape (batch_size, dim, seq_len) if channel_last is False, otherwise (batch_size, seq_len, dim).

  • weight (torch.Tensor) – Weight tensor of shape (dim, kernel_size) if channel_last is False, otherwise (kernel_size, dim).

  • bias (torch.Tensor | None) – Optional bias tensor of shape (dim,).

  • activation (str) – Activation applied after the bias is added. Supported: "silu" and "identity" (the default, which leaves the result unchanged).

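  • channel_last (bool) – Whether the channel dimension (dim) is the last dimension of x, weight, and the output, i.e. x is laid out as (batch_size, seq_len, dim) rather than (batch_size, dim, seq_len). Defaults to False.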
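The two layouts hold the same values in a different axis order, so channel-first data can be passed to the channel_last=True path by transposing both x and weight, and the channel-last output can be transposed back. A minimal sketch on CPU tensors; the `.contiguous()` calls are a cautious assumption, since this page does not say whether the kernel accepts non-contiguous (strided) views.

```python
import torch

batch_size, dim, seq_len, kernel_size = 2, 4, 10, 3

# Channel-first tensors (channel_last=False).
x_cf = torch.randn(batch_size, dim, seq_len)
w_cf = torch.randn(dim, kernel_size)

# Channel-last equivalents (channel_last=True). transpose() and t() return views,
# so .contiguous() copies the data into the new memory order.
x_cl = x_cf.transpose(1, 2).contiguous()  # (batch_size, seq_len, dim)
w_cl = w_cf.t().contiguous()              # (kernel_size, dim)

print(x_cl.shape, w_cl.shape)  # torch.Size([2, 10, 4]) torch.Size([3, 4])

# Element (b, c, t) of x_cf is element (b, t, c) of x_cl; weight entry (c, k) becomes (k, c).
assert x_cl[1, 7, 2] == x_cf[1, 2, 7]
assert w_cl[0, 3] == w_cf[3, 0]

# A channel-last output y_cl of shape (batch_size, seq_len, dim)
# maps back to channel-first with y_cl.transpose(1, 2).
```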

Returns:

Output tensor of shape (batch_size, dim, seq_len) if channel_last is False, otherwise (batch_size, seq_len, dim).

Return type:

torch.Tensor

Example

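import torch
# Assumes causal_conv1d has already been imported from the package documented here.

batch_size, dim, seq_len, kernel_size = 2, 4, 10, 3

# Channel-first layout: x is (batch_size, dim, seq_len), weight is (dim, kernel_size).
channel_last = False
x = torch.randn(batch_size, dim, seq_len, device="cuda")
weight = torch.randn(dim, kernel_size, device="cuda")
bias = torch.randn(dim, device="cuda")  # keep bias on the same device as x
y = causal_conv1d(x, weight, bias, activation="silu", channel_last=channel_last)
print(y.shape)  # torch.Size([2, 4, 10])

# Channel-last layout: x is (batch_size, seq_len, dim), weight is (kernel_size, dim).
channel_last = True
x = torch.randn(batch_size, seq_len, dim, device="cuda")
weight = torch.randn(kernel_size, dim, device="cuda")
bias = torch.randn(dim, device="cuda")
y = causal_conv1d(x, weight, bias, activation="silu", channel_last=channel_last)
print(y.shape)  # torch.Size([2, 10, 4])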
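To check results on small inputs without the CUDA kernel, the formula can also be expressed with standard PyTorch ops. The helper `causal_conv1d_torch_ref` below is hypothetical, not part of this library. Because `torch.nn.functional.conv1d` computes a cross-correlation, the sketch flips the weight so that \(w_{c,0}\) multiplies the current input, as in the formula above; this assumes the kernel follows that tap order, and if it instead pairs the last tap with the current input, the `flip` should be dropped.

```python
import torch
import torch.nn.functional as F


def causal_conv1d_torch_ref(x, weight, bias=None, activation="identity", channel_last=False):
    """Hypothetical reference for the formula, built from standard PyTorch ops."""
    if activation not in ("identity", "silu"):
        raise ValueError(f"unsupported activation: {activation!r}")
    if channel_last:
        x = x.transpose(1, 2)   # (batch, seq_len, dim) -> (batch, dim, seq_len)
        weight = weight.t()     # (kernel_size, dim) -> (dim, kernel_size)
    dim, kernel_size = weight.shape
    # Depthwise conv: groups=dim, one (1, kernel_size) filter per channel.
    # flip() makes weight[:, k] multiply x[..., t - k], matching the formula.
    w = weight.flip(-1).unsqueeze(1)
    # Left-pad kernel_size - 1 zeros so output t only sees inputs at times <= t.
    y = F.conv1d(F.pad(x, (kernel_size - 1, 0)), w, bias, groups=dim)
    if activation == "silu":
        y = F.silu(y)
    return y.transpose(1, 2) if channel_last else y


x = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])  # (batch=1, dim=1, seq_len=4)
weight = torch.tensor([[1.0, 0.5, 0.25]])   # (dim=1, kernel_size=3)
print(causal_conv1d_torch_ref(x, weight).tolist())  # [[[1.0, 2.5, 4.25, 6.0]]]
```

The helper accepts tensors on any device, so the same inputs given to causal_conv1d can be passed to it and the two outputs compared with torch.allclose, using a tolerance suited to the dtype.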