rearrange

rearrange(x, bhl_to_lbh)

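Permute the dimensions of a 3-D tensor between the (batch_size, hidden_dim, seq_dim) layout and the (seq_dim, batch_size, hidden_dim) layout. The direction is chosen by bhl_to_lbh, whose letters stand for batch (b), hidden (h), and sequence length (l).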

Parameters:
  • x (torch.Tensor) – Input tensor of shape (batch_size, hidden_dim, seq_dim) if bhl_to_lbh is True, otherwise (seq_dim, batch_size, hidden_dim).

  • bhl_to_lbh (bool) – If True, rearrange the tensor from (batch_size, hidden_dim, seq_dim) to (seq_dim, batch_size, hidden_dim). If False, perform the inverse rearrangement, from (seq_dim, batch_size, hidden_dim) to (batch_size, hidden_dim, seq_dim).
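Both directions are fixed axis permutations. The sketch below shows one way to express them with torch.Tensor.permute. It assumes rearrange behaves like these permutations; the function name rearrange_sketch is hypothetical, and this is not the library's actual implementation.

```python
import torch


def rearrange_sketch(x: torch.Tensor, bhl_to_lbh: bool) -> torch.Tensor:
    """Hypothetical permute-based equivalent of rearrange (not the library's code)."""
    if x.dim() != 3:
        raise ValueError(f"expected a 3-D tensor, got {x.dim()} dims")
    if bhl_to_lbh:
        # (batch, hidden, seq) -> (seq, batch, hidden): take old axes in order 2, 0, 1
        return x.permute(2, 0, 1)
    # (seq, batch, hidden) -> (batch, hidden, seq): take old axes in order 1, 2, 0
    return x.permute(1, 2, 0)


x = torch.randn(2, 3, 4)                           # (batch_size, hidden_dim, seq_dim)
print(rearrange_sketch(x, bhl_to_lbh=True).shape)  # torch.Size([4, 2, 3])
```

In this sketch, permute returns a non-contiguous view that shares storage with x. Call .contiguous() on the result if a later operation needs contiguous memory. The source does not say whether rearrange itself returns a view or a copy.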

Returns:

Output tensor of shape (seq_dim, batch_size, hidden_dim) if bhl_to_lbh is True, otherwise (batch_size, hidden_dim, seq_dim).

Return type:

torch.Tensor

Example

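import torch

x = torch.randn(2, 3, 4)  # read as (batch_size=2, hidden_dim=3, seq_dim=4)
y = rearrange(x, bhl_to_lbh=True)
print(y.shape)  # torch.Size([4, 2, 3])
# With bhl_to_lbh=False, the same x is read as (seq_dim=2, batch_size=3, hidden_dim=4)
y = rearrange(x, bhl_to_lbh=False)
print(y.shape)  # torch.Size([3, 4, 2])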
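The two directions are inverse permutations. Applying bhl_to_lbh=True and then bhl_to_lbh=False should therefore return the original tensor. The check below uses plain Tensor.permute with the axis orders (2, 0, 1) and (1, 2, 0). It assumes these are equivalent to rearrange, since the source does not show its implementation. The constant names are hypothetical.

```python
import torch

# Assumed equivalents of rearrange(x, bhl_to_lbh=True) and rearrange(x, bhl_to_lbh=False).
BHL_TO_LBH = (2, 0, 1)  # (batch, hidden, seq) -> (seq, batch, hidden)
LBH_TO_BHL = (1, 2, 0)  # (seq, batch, hidden) -> (batch, hidden, seq)

x = torch.randn(2, 3, 4)          # (batch_size, hidden_dim, seq_dim)
y = x.permute(*BHL_TO_LBH)        # (seq_dim, batch_size, hidden_dim)
x_back = y.permute(*LBH_TO_BHL)   # back to (batch_size, hidden_dim, seq_dim)

print(y.shape)                    # torch.Size([4, 2, 3])
print(torch.equal(x_back, x))     # True
# Element mapping: y[l, b, h] is x[b, h, l]
print(torch.equal(y[3, 1, 2], x[1, 2, 3]))  # True
```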