weights_logits:4 the MultiOutSizeLinear.forward out is always zero #44

Closed
splendourbell opened this issue May 22, 2024 · 4 comments · Fixed by #46
Labels
bug Something isn't working

Comments

@splendourbell

In MultiOutSizeLinear.forward,
self.out_features_ls is [32, 64, 128, 256, 512],
because weights_logits: 4 is multiplied with the patch sizes [8, 16, 32, 64, 128].

So when out_feat_size is 8,
torch.eq(out_feat_size, feat_size).unsqueeze(-1) is always False,
and out is always zero. Is that right?

def forward(
    self,
    x: Float[torch.Tensor, "*batch in_feat"],
    out_feat_size: Int[torch.Tensor, "*batch"],
) -> Float[torch.Tensor, "*batch max_feat"]:
    out = 0
    for idx, feat_size in enumerate(self.out_features_ls):
        weight = self.weight[idx] * self.mask[idx]
        bias = self.bias[idx] if self.bias is not None else 0
        out = out + (
            torch.eq(out_feat_size, feat_size).unsqueeze(-1)
            * (einsum(weight, x, "out inp, ... inp -> ... out") + bias)
        )
    return out
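
For illustration, a minimal sketch of the reported situation (values assumed from the description above): with the scaled sizes in out_features_ls, tokens with patch size 8 or 16 never match any feat_size, so every mask is all-False and the masked sum stays at zero.

    import torch

    out_features_ls = [32, 64, 128, 256, 512]  # dim * patch_size, as reported
    out_feat_size = torch.tensor([8, 16, 8])   # actual patch sizes per token

    # None of the scaled sizes equals 8 or 16, so every mask is all-False.
    masks = [torch.eq(out_feat_size, feat_size) for feat_size in out_features_ls]
    print(any(mask.any() for mask in masks))   # False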
@gorold (Contributor) commented May 24, 2024

self.out_features_ls should be [8, 16, 32, 64, 128] based on the current hyperparameters.

Not too sure which weights_logits you are referring to.

out_feat_size is a tensor representing the patch size for each token. torch.eq(...) behaves as a mask that only adds the branch matching the current feat_size to out. So, out should be the prediction of each token based on the appropriate patch size, with zero padding.
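
A minimal sketch of that intended masking behaviour, assuming out_features_ls really is the patch sizes [8, 16, 32, 64, 128] (the per-token patch sizes below are made up for illustration):

    import torch

    out_features_ls = [8, 16, 32, 64, 128]    # assumed patch sizes
    out_feat_size = torch.tensor([8, 32, 8])  # hypothetical patch size per token

    # Each token matches exactly one branch; all other branches contribute zero.
    for feat_size in out_features_ls:
        mask = torch.eq(out_feat_size, feat_size).unsqueeze(-1)
        print(feat_size, mask.squeeze(-1).tolist())
    # 8  [True, False, True]
    # 16 [False, False, False]
    # 32 [False, True, False]
    # ...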

@splendourbell (Author)

In the DistrParamProj.__init__ function:

    print(args_dim)
    {'weights_logits': 4, 'components': [{'df': 1, 'loc': 1, 'scale': 1}, {'loc': 1}, {'total_count': 1, 'logits': 1}, {'loc': 1, 'scale': 1}]}

The code:

    else proj_layer(
        in_features, tuple(dim * of for of in out_features), **kwargs
    )

When dim is 4 (the weights_logits param), that tuple is (32, 64, 128, 256, 512),
and then the proj_layer (MultiOutSizeLinear) forward output is always zero (see the sketch after the class definition below).

class DistrParamProj(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int | tuple[int, ...] | list[int],
        args_dim: PyTree[int, "T"],
        domain_map: PyTree[Callable[[torch.Tensor], torch.Tensor], "T"],
        proj_layer: Callable[..., nn.Module] = MultiOutSizeLinear,
        **kwargs: Any,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.args_dim = args_dim
        self.domain_map = domain_map
        self.proj = convert_to_module(
            tree_map(
                lambda dim: (
                    proj_layer(in_features, dim * out_features, **kwargs)
                    if isinstance(out_features, int)
                    else proj_layer(
                        in_features, tuple(dim * of for of in out_features), **kwargs
                    )
                ),
                args_dim,
            )
        )
        self.out_size = (
            out_features if isinstance(out_features, int) else max(out_features)
        )
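
For illustration, the scaling above with assumed values (taking out_features to be the patch sizes (8, 16, 32, 64, 128)):

    dim = 4                              # args_dim['weights_logits']
    out_features = (8, 16, 32, 64, 128)  # assumed patch sizes
    print(tuple(dim * of for of in out_features))
    # (32, 64, 128, 256, 512) -- this is what MultiOutSizeLinear receives as
    # out_features_ls, so tokens with patch size 8 or 16 never match any branch.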

@gorold (Contributor) commented May 26, 2024

I see... I think I get what you mean, will look into it, thanks!

@gorold (Contributor) commented May 26, 2024

Seems like this is a pretty major bug. Fixing it would give better outputs for predictions with patch sizes 8 and 16 (with the current configuration) and improve performance for low-frequency data. Thanks for catching this!
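
For context, one possible direction (purely a hypothetical sketch, not necessarily what #46 does): keep the unscaled patch sizes alongside the dim-scaled output dims and mask on the patch size instead of the scaled feat_size. The attribute self.patch_sizes below is hypothetical.

    # Hypothetical sketch only -- not the actual fix in #46.
    # Assumes self.patch_sizes holds the unscaled sizes, e.g. [8, 16, 32, 64, 128],
    # while self.weight / self.mask keep their dim-scaled output dimensions.
    out = 0
    for idx, patch_size in enumerate(self.patch_sizes):
        weight = self.weight[idx] * self.mask[idx]
        bias = self.bias[idx] if self.bias is not None else 0
        out = out + (
            torch.eq(out_feat_size, patch_size).unsqueeze(-1)  # mask on patch size
            * (einsum(weight, x, "out inp, ... inp -> ... out") + bias)
        )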
