rai_toolbox.to_batch
- rai_toolbox.to_batch(p, param_ndim)
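Returns a view of p, reshaped as shape-(N, d0, …). For positive param_ndim, (d0, …) are the last param_ndim dimensions of p and N is the product of the remaining leading dimensions, so N == 1 when param_ndim == p.ndim. See Parameters and Returns for the zero, negative, and None cases.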
- Parameters:
- p: Tensor
The tensor to reshape.
- param_ndim: Optional[int]
Determines the shape of the resulting parameter.
A positive number determines the dimensionality of the tensor that the transformation will act on.
Zero treats every element of p as its own parameter.
A negative number indicates the ‘offset’ from the dimensionality of the tensor, i.e., each parameter has p.ndim - |param_ndim| dimensions.
None means that the transformation will be applied to the tensor without any broadcasting. This is equivalent to param_ndim=p.ndim.
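The rules above reduce every allowed value of param_ndim to a single non-negative count: how many trailing dimensions of p make up one parameter. The helper below is a hypothetical sketch written for illustration; resolve_param_ndim is not part of rai_toolbox, and raising ValueError for out-of-range values is an assumption about how invalid input would be handled.

```python
from typing import Optional


def resolve_param_ndim(ndim: int, param_ndim: Optional[int]) -> int:
    """Hypothetical helper: number of trailing dims of an `ndim`-d tensor forming one parameter."""
    if param_ndim is None:
        k = ndim  # None is equivalent to param_ndim=p.ndim
    elif param_ndim < 0:
        k = ndim + param_ndim  # negative: offset from p.ndim, i.e. p.ndim - |param_ndim|
    else:
        k = param_ndim  # zero or positive: used as-is
    # Assumption: counts outside [0, ndim] are rejected rather than silently clipped.
    if not 0 <= k <= ndim:
        raise ValueError(f"param_ndim={param_ndim} is invalid for a {ndim}-d tensor")
    return k


# For a 3-d tensor such as x = tr.rand((3, 5, 2)):
print([resolve_param_ndim(3, v) for v in (None, 2, -1, -3)])  # [3, 2, 2, 0]
```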
- Returns:
- reshaped_p: Tensor, shape-(N, d0, …)
where
- (d0, …) is of length param_ndim for param_ndim > 0
- (d0, …) is (1,) for param_ndim == 0
- (d0, …) is of length p.ndim - |param_ndim| for param_ndim < 0
- (d0, …) is p.shape, so N == 1, for param_ndim=None
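Once param_ndim has been reduced to a non-negative count k of trailing dimensions (per the cases above), the output shape follows from flattening the leading p.ndim - k dimensions into N. The pure-Python function below is a hypothetical sketch of that shape arithmetic only; batched_shape is not rai_toolbox's implementation, and it assumes 0 <= k <= len(shape).

```python
import math
from typing import Tuple


def batched_shape(shape: Tuple[int, ...], k: int) -> Tuple[int, ...]:
    """Hypothetical sketch: shape of the batched result for k trailing parameter dims."""
    if not 0 <= k <= len(shape):
        raise ValueError(f"k={k} is out of range for shape {shape}")
    if k == 0:
        # Every element is its own parameter, so (d0, ...) is (1,)
        return (math.prod(shape), 1)
    split = len(shape) - k
    # Leading dims collapse into N; math.prod(()) == 1, so k == len(shape) gives N == 1
    return (math.prod(shape[:split]),) + tuple(shape[split:])


print(batched_shape((3, 5, 2), 1))  # (15, 2)
```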
Examples
>>> import torch as tr
>>> from rai_toolbox import to_batch
>>> x = tr.rand((3, 5, 2))
>>> to_batch(x, param_ndim=0).shape
torch.Size([30, 1])
>>> to_batch(x, param_ndim=1).shape
torch.Size([15, 2])
>>> to_batch(x, param_ndim=2).shape
torch.Size([3, 5, 2])
>>> to_batch(x, param_ndim=3).shape
torch.Size([1, 3, 5, 2])
>>> to_batch(x, param_ndim=None).shape  # same as `param_ndim=x.ndim`
torch.Size([1, 3, 5, 2])
>>> to_batch(x, param_ndim=-1).shape
torch.Size([3, 5, 2])
>>> to_batch(x, param_ndim=-2).shape
torch.Size([15, 2])
>>> to_batch(x, param_ndim=-3).shape
torch.Size([30, 1])
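A batched layout like this lets a per-parameter operation be written once, along dim 1 of the result. The sketch below normalizes every length-2 vector along the last axis of x to unit L2 norm. It uses plain torch reshape as a stand-in for to_batch(x, param_ndim=1), which per the examples above yields the same (15, 2) shape, so that it runs without rai_toolbox installed; it is an illustration, not how rai_toolbox itself uses to_batch.

```python
import torch as tr

x = tr.rand((3, 5, 2))

# Same shape as to_batch(x, param_ndim=1): (15, 2), i.e. 15 independent 2-vectors
batch = x.reshape(-1, x.shape[-1])

# Per-parameter L2 norms, shape (15, 1); the clamp guards against division by zero
norms = batch.norm(dim=1, keepdim=True).clamp_min(1e-12)

# Normalize each parameter, then restore the original (3, 5, 2) layout
normalized = (batch / norms).reshape(x.shape)

print(normalized.shape)  # torch.Size([3, 5, 2])
```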