Skip to content

Commit

Permalink
feat: ✨ Add input/output size increase function: step_up_size
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed May 6, 2024
1 parent 1c80eaf commit 561eb5c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 22 deletions.
9 changes: 9 additions & 0 deletions src/leibnetz/leibnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,15 @@ def is_valid_input_shape(self, input_key, input_shape):
== 0
).all()

def step_up_size(self, steps: int = 1):
for n in range(steps):
target_arrays = {}
for name, metadata in self.output_shapes.items():
target_arrays[name] = tuple(
(tuple(s + 1 for s in metadata["shape"]), metadata["scale"])
)
self.compute_shapes(target_arrays, set=True)

def step_valid_shapes(self, input_key):
input_scale = self._input_shapes[input_key][1]
step_size = self.least_common_scale / input_scale
Expand Down
37 changes: 15 additions & 22 deletions src/leibnetz/nodes/additive_attention_gate_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,18 +124,18 @@ def forward(self, inputs): # TODO
g1 = self.W_g(g)
x1 = self.W_x(x)
# change it to crop
if np.all(x1.shape[-self.ndims :] != g1.shape[-self.ndims :]):
smallest_shape = np.min(
[x1.shape[-self.ndims :], g1.shape[-self.ndims :]], axis=0
)
assert len(smallest_shape) == self.ndims, (
f"Input shapes {[x1.shape,g1.shape]} have wrong dimensionality for node {self.id}, "
f"with expected inputs {self.input_keys} of dimensionality {self.ndims}"
)
if np.all(x1.shape[-self.ndims :] != smallest_shape):
x1 = self.crop(x1, smallest_shape)
if np.all(g1.shape[-self.ndims :] != smallest_shape):
g1 = self.crop(g1, smallest_shape)
# if np.all(x1.shape[-self.ndims :] != g1.shape[-self.ndims :]):
smallest_shape = np.min(
[x1.shape[-self.ndims :], g1.shape[-self.ndims :]], axis=0
)
assert len(smallest_shape) == self.ndims, (
f"Input shapes {[x1.shape,g1.shape]} have wrong dimensionality for node {self.id}, "
f"with expected inputs {self.input_keys} of dimensionality {self.ndims}"
)
# if np.all(x1.shape[-self.ndims :] != smallest_shape):
x1 = self.crop(x1, smallest_shape)
# if np.all(g1.shape[-self.ndims :] != smallest_shape):
g1 = self.crop(g1, smallest_shape)
psi = torch.nn.functional.relu(g1 + x1)
psi = self.psi(psi)
psi = torch.nn.functional.softmax(psi, dim=1)
Expand Down Expand Up @@ -188,17 +188,10 @@ def factor_crop(self, input_shape): # TODO
return (spatial_shape - target_spatial_shape) / self.scale

def crop_to_factor(self, x): # TODO
shape = x.size()
shape = shape[-self.ndims :]
shape = x.shape[-self.ndims :]
target_shape = shape - self.factor_crop(shape)
if (target_shape != shape).all():
assert all(((t > c) for t, c in zip(target_shape, self.resample_crop))), (
"Feature map with shape %s is too small to ensure "
"translation equivariance with self.least_common_scale %s and following "
"resamples %s" % (x.size(), self.least_common_scale, self.kernel_sizes)
)

return self.crop(x, target_shape.astype(int))
# if (target_shape != shape).all():
return self.crop(x, target_shape.astype(int))

return x

Expand Down

0 comments on commit 561eb5c

Please sign in to comment.