def pack_into_tensor(array, axis):
    """
    packs a given TensorArray into a tensor along a given axis

    Parameters:
    ----------
    array: TensorArray
        the tensor array to pack
    axis: int
        the axis to pack the array along

    Returns: Tensor
        the packed tensor
    """

    packed_tensor = array.pack()
    shape = packed_tensor.get_shape()
    rank = len(shape)

    dim_permutation = [axis] + range(1, axis) + [0] + range(axis + 1, rank)
    correct_shape_tensor = tf.transpose(packed_tensor, dim_permutation)

    return correct_shape_tensor
Suppose I want to stack a TensorArray of 6 elements, each of shape [3,4,5], into a tensor of shape [3,4,5,6], so axis is 3. After array.pack(), the tensor has shape [6,3,4,5]. dim_permutation is [3] + [1,2] + [0] + [] = [3,1,2,0]. After the transpose, the output shape is [5,3,4,6], which is wrong.
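The arithmetic in this example can be checked without TensorFlow, because `tf.transpose` places input dimension `perm[i]` at output position `i`. The following is a minimal sketch in plain Python. `transposed_shape` is a hypothetical helper introduced only for this illustration, and `list(range(...))` is used so the snippet also runs on Python 3.

```python
def transposed_shape(shape, perm):
    """Shape of the result of transposing a tensor of `shape` with `perm`.

    Hypothetical helper: output dimension i takes its size from input
    dimension perm[i], which is the rule tf.transpose follows.
    """
    return [shape[p] for p in perm]


packed_shape = [6, 3, 4, 5]  # shape after array.pack() in the example
axis = 3
rank = len(packed_shape)

# Permutation exactly as written in pack_into_tensor
buggy_perm = [axis] + list(range(1, axis)) + [0] + list(range(axis + 1, rank))

print(buggy_perm)                                  # [3, 1, 2, 0]
print(transposed_shape(packed_shape, buggy_perm))  # [5, 3, 4, 6]
```

The result is [5,3,4,6] rather than the intended [3,4,5,6] because the permutation swaps dimensions 0 and 3 instead of moving dimension 0 to the end.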
The correct formula should be:
    dim_permutation = range(1, axis + 1) + [0] + range(axis + 1, rank)
With this formula, the dim_permutation is [1,2,3] + [0] + [] = [1,2,3,0].
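To confirm that this permutation works for every axis and not only for axis 3, here is a small sketch in plain Python. `pack_permutation` and `transposed_shape` are hypothetical names used only for this illustration. The expression is the one implied by the `[1,2,3] + [0] + []` breakdown, written with `list(range(...))` so that it runs on Python 3.

```python
def pack_permutation(axis, rank):
    """Permutation that moves the packed dimension 0 to position `axis`.

    Hypothetical helper: dimensions 1..axis come first, then the packed
    dimension 0, then the remaining dimensions in their original order.
    """
    return list(range(1, axis + 1)) + [0] + list(range(axis + 1, rank))


def transposed_shape(shape, perm):
    """Output dimension i takes its size from input dimension perm[i]."""
    return [shape[p] for p in perm]


packed_shape = [6, 3, 4, 5]       # 6 elements of shape [3, 4, 5]
element_shape = packed_shape[1:]
rank = len(packed_shape)

for axis in range(rank):
    perm = pack_permutation(axis, rank)
    # The 6 should end up at position `axis`, other dims keep their order
    expected = element_shape[:axis] + [6] + element_shape[axis:]
    assert transposed_shape(packed_shape, perm) == expected

print(pack_permutation(3, 4))  # [1, 2, 3, 0]
```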
Similarly, the formula in unpack_into_tensorarray is also wrong. The correct code should be
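For `unpack_into_tensorarray`, the permutation presumably has to be the inverse of the packing one: the chosen axis moves to position 0 and the other dimensions keep their relative order, so that each unpacked element has the original element shape. The sketch below is an assumption about the intended fix, not code taken from the library. `unpack_permutation` and `transposed_shape` are hypothetical helpers in plain Python.

```python
def unpack_permutation(axis, rank):
    """Permutation that moves dimension `axis` to the front.

    Assumed form of the fix: the axis to unpack along goes first, followed
    by all other dimensions in their original order.
    """
    return [axis] + list(range(0, axis)) + list(range(axis + 1, rank))


def transposed_shape(shape, perm):
    """Output dimension i takes its size from input dimension perm[i]."""
    return [shape[p] for p in perm]


value_shape = [3, 4, 5, 6]  # tensor to unpack along axis 3
perm = unpack_permutation(3, len(value_shape))

print(perm)                                 # [3, 0, 1, 2]
print(transposed_shape(value_shape, perm))  # [6, 3, 4, 5]
```

With this ordering, unpacking along the leading dimension would give 6 elements of shape [3,4,5], which round-trips with the packing example above.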
Please take a look. Thanks