Skip to content

Commit

Permalink
Merge pull request #9 from kengz/1.1.2
Browse files Browse the repository at this point in the history
set encoding as nn.Parameter
  • Loading branch information
kengz authored Dec 31, 2021
2 parents 1f60d2f + fcbf0da commit cac3dbd
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def run_tests(self):

setup(
name='torcharc',
version='1.1.1',
version='1.1.2',
description='Build PyTorch networks by specifying architectures.',
long_description='https://github.com/kengz/torcharc',
keywords='torcharc',
Expand Down
2 changes: 1 addition & 1 deletion torcharc/module/perceiver_io/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def build_pos_encoding(self, pos: torch.Tensor, max_reso: list = None) -> torch.
spatial_encoding = torch.cat(encodings, dim=-1) # shape (x, y,... d*(2*num_freq_bands+1))
# flatten spatial dimensions into 1D
pos_encoding = rearrange(spatial_encoding, '... c -> (...) c')
return pos_encoding
return nn.Parameter(pos_encoding)

def get_pos_encoding_dim(self) -> int:
return len(self.spatial_shape) * (2 * self.num_freq_bands + int(self.cat_pos))
Expand Down

0 comments on commit cac3dbd

Please sign in to comment.