axis permutation wrong in utility.pack_into_tensor and utility.unpack_into_tensorarray #8

Open
JeffOwOSun opened this issue Mar 5, 2017 · 0 comments


def pack_into_tensor(array, axis):
    """
    packs a given TensorArray into a tensor along a given axis
    Parameters:
    ----------
    array: TensorArray
        the tensor array to pack
    axis: int
        the axis to pack the array along
    Returns: Tensor
        the packed tensor
    """

    packed_tensor = array.pack()
    shape = packed_tensor.get_shape()
    rank = len(shape)

    dim_permutation = [axis] + range(1, axis) + [0]  + range(axis + 1, rank)
    correct_shape_tensor = tf.transpose(packed_tensor, dim_permutation)

    return correct_shape_tensor

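Suppose I want to stack a TensorArray with 6 elements of shape [3,4,5] into a tensor of shape [3,4,5,6], so axis is 3. After array.pack(), the tensor has shape [6,3,4,5] and rank 4. dim_permutation is then [3] + [1,2] + [0] + [] = [3,1,2,0]. After the transpose, the output shape is [5,3,4,6], which is wrong: the formula swaps axis 0 with the target axis instead of moving axis 0 to the target position and shifting the axes in between down by one.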
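The shape arithmetic above can be reproduced with a small sketch. It uses NumPy instead of TensorFlow, relying only on `np.transpose` taking its axes argument in the same convention as the `perm` argument of `tf.transpose` (output axis i is input axis perm[i]). The helper name `buggy_pack_permutation` is made up for illustration, and `range()` is wrapped in `list()` so the formula also runs on Python 3.

```python
import numpy as np


def buggy_pack_permutation(axis, rank):
    # Hypothetical helper: the permutation formula currently used in
    # pack_into_tensor, with list() added so it works on Python 3 too.
    return [axis] + list(range(1, axis)) + [0] + list(range(axis + 1, rank))


# Stand-in for array.pack(): 6 elements of shape [3, 4, 5] stacked on axis 0.
packed = np.zeros((6, 3, 4, 5))

perm = buggy_pack_permutation(axis=3, rank=packed.ndim)
result_shape = np.transpose(packed, perm).shape

print(perm, result_shape)  # [3, 1, 2, 0] (5, 3, 4, 6)
```

The result is (5, 3, 4, 6) rather than the intended (3, 4, 5, 6).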

The correct formula should be

dim_permutation = range(1, axis+1) + [0] + range(axis+1, rank)

With this formula, the dim_permutation is [1,2,3] + [0] + [] = [1,2,3,0], so transposing the packed [6,3,4,5] tensor gives the expected shape [3,4,5,6].
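As a sanity check of the proposed formula for every axis, not just axis 3, here is a sketch that compares the transposed result against `np.stack(elements, axis=axis)`. NumPy stands in for TensorFlow here, on the assumption that stacking on axis 0 mirrors what array.pack() returns; `fixed_pack_permutation` is a hypothetical helper name.

```python
import numpy as np


def fixed_pack_permutation(axis, rank):
    # Hypothetical helper: move axis 0 to position `axis`,
    # shifting axes 1..axis down by one and leaving the rest in place.
    return list(range(1, axis + 1)) + [0] + list(range(axis + 1, rank))


# Six elements of shape [3, 4, 5], stacked on axis 0 like array.pack().
elements = [np.random.rand(3, 4, 5) for _ in range(6)]
packed = np.stack(elements, axis=0)  # shape (6, 3, 4, 5)

shapes = {}
for axis in range(packed.ndim):
    perm = fixed_pack_permutation(axis, packed.ndim)
    moved = np.transpose(packed, perm)
    # The transposed tensor should match stacking directly along `axis`.
    assert np.array_equal(moved, np.stack(elements, axis=axis))
    shapes[axis] = moved.shape

print(shapes[3])  # (3, 4, 5, 6)
```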

The formula in unpack_into_tensorarray is wrong in the same way. The correct code should be

dim_permutation = [axis] + range(0, axis) + range(axis+1, rank)
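The unpack direction can be checked the same way. This sketch, again in NumPy with a made-up helper name `fixed_unpack_permutation`, moves the chosen axis to the front and confirms that each leading slice equals the corresponding slice of the original tensor along that axis, which is what unpacking into a TensorArray needs.

```python
import numpy as np


def fixed_unpack_permutation(axis, rank):
    # Hypothetical helper: bring `axis` to the front,
    # keeping the remaining axes in their original order.
    return [axis] + list(range(0, axis)) + list(range(axis + 1, rank))


tensor = np.random.rand(3, 4, 5, 6)

perm = fixed_unpack_permutation(axis=3, rank=tensor.ndim)
leading = np.transpose(tensor, perm)  # shape (6, 3, 4, 5)

# Element i of the unpacked array should be the i-th slice along axis 3.
slices_match = all(
    np.array_equal(leading[i], tensor[..., i]) for i in range(tensor.shape[3])
)

print(perm, leading.shape, slices_match)  # [3, 0, 1, 2] (6, 3, 4, 5) True
```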

Please take a look. Thanks
