rearrange
- rearrange(x, bhl_to_lbh)
Rearrange the tensor dimensions from (batch_size, hidden_dim, seq_dim) to (seq_dim, batch_size, hidden_dim), or vice versa.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size, hidden_dim, seq_dim) if bhl_to_lbh is True, otherwise (seq_dim, batch_size, hidden_dim).
bhl_to_lbh (bool) – If True, rearrange the tensor from (batch_size, hidden_dim, seq_dim) to (seq_dim, batch_size, hidden_dim). If False, perform the inverse rearrangement.
- Returns:
Output tensor of shape (seq_dim, batch_size, hidden_dim) if bhl_to_lbh is True, otherwise (batch_size, hidden_dim, seq_dim).
- Return type:
torch.Tensor
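A minimal sketch of the behaviour described above, assuming the function is a pure axis permutation built on `torch.Tensor.permute`. The library's actual implementation is not shown in this reference; `rearrange_sketch` is a hypothetical name, and the 3-D input check is an added assumption.

```python
import torch


def rearrange_sketch(x: torch.Tensor, bhl_to_lbh: bool) -> torch.Tensor:
    """Hypothetical stand-in for rearrange: permute axes between
    (batch_size, hidden_dim, seq_dim) and (seq_dim, batch_size, hidden_dim)."""
    # Assumption: only 3-D inputs are meaningful for this layout swap.
    if x.dim() != 3:
        raise ValueError(f"expected a 3-D tensor, got {x.dim()} dims")
    if bhl_to_lbh:
        # (b, h, l) -> (l, b, h): new axis 0 is old axis 2, then old 0, then old 1
        return x.permute(2, 0, 1)
    # (l, b, h) -> (b, h, l): the inverse of permutation (2, 0, 1)
    return x.permute(1, 2, 0)


x = torch.randn(2, 3, 4)
print(rearrange_sketch(x, bhl_to_lbh=True).shape)   # torch.Size([4, 2, 3])
print(rearrange_sketch(x, bhl_to_lbh=False).shape)  # torch.Size([3, 4, 2])
```

The index tuple (1, 2, 0) is the inverse of (2, 0, 1), so applying the two directions one after the other restores the original axis order.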
Example
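import torch

x = torch.randn(2, 3, 4)
y = rearrange(x, bhl_to_lbh=True)   # x read as (batch_size, hidden_dim, seq_dim)
print(y.shape)  # torch.Size([4, 2, 3])
y = rearrange(x, bhl_to_lbh=False)  # x read as (seq_dim, batch_size, hidden_dim)
print(y.shape)  # torch.Size([3, 4, 2])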
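Because the two directions are inverse permutations, a round trip should return the input unchanged. If `rearrange` is implemented with `torch.Tensor.permute` (an assumption; this reference does not say), its result is a strided view of the same storage rather than a copy, so it is generally not contiguous. The snippet below uses raw `permute` calls as stand-ins for the two directions:

```python
import torch

x = torch.randn(2, 3, 4)  # (batch_size, hidden_dim, seq_dim)

# Stand-in for rearrange(x, bhl_to_lbh=True): (b, h, l) -> (l, b, h)
y = x.permute(2, 0, 1)
print(y.shape)            # torch.Size([4, 2, 3])
print(y.is_contiguous())  # False: same storage, reordered strides

# Stand-in for rearrange(y, bhl_to_lbh=False): (l, b, h) -> (b, h, l)
z = y.permute(1, 2, 0)
print(torch.equal(z, x))             # True: the round trip restores the input
print(z.data_ptr() == x.data_ptr())  # True: no data was copied

# Call .contiguous() if a later op (e.g. .view) needs a dense layout
y_dense = y.contiguous()
print(y_dense.is_contiguous())  # True
```

If a downstream op such as `.view` complains about strides after rearranging, materializing the result with `.contiguous()` (or using `.reshape`) is the usual remedy.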