rai_toolbox.to_batch

rai_toolbox.to_batch(p, param_ndim)

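Returns a view of p reshaped to shape-(N, d0, …). The trailing dimensions (d0, …) form a single parameter, and all leading dimensions of p are flattened into the batch dimension N. param_ndim controls how many dimensions (d0, …) contains.

See Parameters and Returns for the exact rules.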

Parameters:
p: Tensor

The tensor to reshape.

param_ndim: Optional[int]

Determines how many trailing dimensions of p make up a single parameter, and therefore the shape (d0, …) of each entry in the batch.

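  • A positive number is the dimensionality of the tensor that the transformation will act on: the trailing param_ndim dimensions of p are kept, and all leading dimensions are flattened into N.

  • Zero treats every element of p as its own parameter, yielding shape-(p.numel(), 1).

  • A negative number is an offset from the dimensionality of the tensor: param_ndim=-k is equivalent to param_ndim=p.ndim - k.

  • None means that the transformation will be applied to the tensor without any broadcasting. This is equivalent to param_ndim=p.ndim.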
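The rules for param_ndim amount to a small normalization step that turns any accepted value into a count of trailing dimensions. The sketch below is a hypothetical helper, resolve_param_ndim, written in plain Python; it is not part of rai_toolbox. Rejecting out-of-range values with a ValueError is an assumption, since their behavior is not documented here.

```python
from typing import Optional


def resolve_param_ndim(ndim: int, param_ndim: Optional[int]) -> int:
    """Return how many trailing dimensions form one parameter.

    Hypothetical helper mirroring the documented rules; not part of rai_toolbox.
    """
    if param_ndim is None:
        # None: the whole tensor is a single parameter (same as param_ndim=ndim)
        return ndim
    if param_ndim < 0:
        # Negative: an offset from the tensor's dimensionality
        param_ndim = ndim + param_ndim
    if not 0 <= param_ndim <= ndim:
        # Assumption: values outside [-ndim, ndim] are rejected
        raise ValueError(f"param_ndim out of range for a {ndim}-d tensor")
    return param_ndim


# For a shape-(3, 5, 2) tensor, ndim == 3
print(resolve_param_ndim(3, None))  # 3
print(resolve_param_ndim(3, -1))    # 2
print(resolve_param_ndim(3, 0))     # 0
```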

Returns:
reshaped_p: Tensor, shape-(N, d0, …)

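Where:

  • (d0, …) has length param_ndim for param_ndim > 0

  • (d0, …) is (1,) for param_ndim == 0

  • (d0, …) has length p.ndim - |param_ndim| for param_ndim < 0, and is again (1,) when that length is 0

N is the product of the remaining leading dimensions of p, or 1 when there are none.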
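The shape arithmetic described under Returns can be checked without torch. batched_shape below is a hypothetical plain-Python helper (not part of rai_toolbox) that computes the shape to_batch is documented to return for a given input shape. Raising ValueError for out-of-range param_ndim is an assumption. The final loop compares it against the shapes listed in the Examples.

```python
import math
from typing import Optional, Tuple


def batched_shape(shape: Tuple[int, ...], param_ndim: Optional[int]) -> Tuple[int, ...]:
    """Shape (N, d0, ...) that results from batching a tensor of `shape`.

    Hypothetical helper for illustration; not part of rai_toolbox.
    """
    ndim = len(shape)
    # Normalize param_ndim to k, the number of trailing dims kept
    if param_ndim is None:
        k = ndim
    elif param_ndim < 0:
        k = ndim + param_ndim
    else:
        k = param_ndim
    if not 0 <= k <= ndim:
        # Assumption: out-of-range values are rejected
        raise ValueError(f"param_ndim out of range for a {ndim}-d shape")
    leading, trailing = shape[: ndim - k], shape[ndim - k :]
    # N is the product of the leading dims; math.prod(()) == 1 when there are none
    n = math.prod(leading)
    # With no trailing dims, each element is stored as a shape-(1,) parameter
    return (n,) + (trailing if trailing else (1,))


# Shapes from the Examples section, for a shape-(3, 5, 2) input
expected = {
    0: (30, 1), 1: (15, 2), 2: (3, 5, 2), 3: (1, 3, 5, 2),
    None: (1, 3, 5, 2), -1: (3, 5, 2), -2: (15, 2), -3: (30, 1),
}
for p_ndim, out_shape in expected.items():
    assert batched_shape((3, 5, 2), p_ndim) == out_shape
```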

Examples

>>> import torch as tr
>>> from rai_toolbox import to_batch
>>> x = tr.rand((3, 5, 2))
>>> to_batch(x, param_ndim=0).shape
torch.Size([30, 1])
>>> to_batch(x, param_ndim=1).shape
torch.Size([15, 2])
>>> to_batch(x, param_ndim=2).shape
torch.Size([3, 5, 2])
>>> to_batch(x, param_ndim=3).shape
torch.Size([1, 3, 5, 2])
>>> to_batch(x, param_ndim=None).shape  # same as `param_ndim=x.ndim`
torch.Size([1, 3, 5, 2])
>>> to_batch(x, param_ndim=-1).shape
torch.Size([3, 5, 2])
>>> to_batch(x, param_ndim=-2).shape
torch.Size([15, 2])
>>> to_batch(x, param_ndim=-3).shape
torch.Size([30, 1])