Commit d169339

remove redundant ignore_module in to_static and move assert into eager-only mode (PaddlePaddle#1044)

HydrogenSulfate authored Dec 17, 2024
1 parent 69bddd9 commit d169339
Showing 2 changed files with 14 additions and 13 deletions.
26 changes: 14 additions & 12 deletions ppsci/arch/cvit.py
@@ -517,10 +517,11 @@ def dot_product_attention_weights(
     """
     dtype = query.dtype

-    assert query.ndim == key.ndim, "q, k must have same rank."
-    assert query.shape[:-3] == key.shape[:-3], "q, k batch dims must match."
-    assert query.shape[-2] == key.shape[-2], "q, k num_heads must match."
-    assert query.shape[-1] == key.shape[-1], "q, k depths must match."
+    if paddle.in_dynamic_mode():
+        assert query.ndim == key.ndim, "q, k must have same rank."
+        assert query.shape[:-3] == key.shape[:-3], "q, k batch dims must match."
+        assert query.shape[-2] == key.shape[-2], "q, k num_heads must match."
+        assert query.shape[-1] == key.shape[-1], "q, k depths must match."

     # calculate attention matrix
     depth = query.shape[-1]
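
Note: paddle.in_dynamic_mode() returns True only under eager execution; during paddle.jit.to_static tracing, shapes can carry unknown (-1) dimensions, so these Python-level assertions are skipped rather than left to misfire on symbolic values. A minimal standalone sketch of the same pattern (checked_qk_scores is a hypothetical helper, not part of this commit):

import paddle

def checked_qk_scores(query, key):
    # Run shape sanity checks only in eager (dynamic) mode; during
    # dynamic-to-static tracing dims may be -1, making asserts unreliable.
    if paddle.in_dynamic_mode():
        assert query.ndim == key.ndim, "q, k must have same rank."
        assert query.shape[-1] == key.shape[-1], "q, k depths must match."
    # contract the depth dim: [..., q, h, d] x [..., k, h, d] -> [..., h, q, k]
    return paddle.einsum("...qhd,...khd->...hqk", query, key)
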
@@ -567,14 +567,15 @@ def dot_product_attention(
     Returns:
         paddle.Tensor: Output of shape [batch..., q_length, num_heads, v_depth_per_head].
     """
-    assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
-    assert (
-        query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
-    ), "q, k, v batch dims must match."
-    assert (
-        query.shape[-2] == key.shape[-2] == value.shape[-2]
-    ), "q, k, v num_heads must match."
-    assert key.shape[-3] == value.shape[-3], "k, v lengths must match."
+    if paddle.in_dynamic_mode():
+        assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
+        assert (
+            query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
+        ), "q, k, v batch dims must match."
+        assert (
+            query.shape[-2] == key.shape[-2] == value.shape[-2]
+        ), "q, k, v num_heads must match."
+        assert key.shape[-3] == value.shape[-3], "k, v lengths must match."

     # compute attention weights
     attn_weights = dot_product_attention_weights(
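For reference, a hedged usage sketch of the guarded function, assuming dot_product_attention is importable from ppsci.arch.cvit and that its optional keyword arguments keep their defaults; shapes follow the docstring above:

import paddle
from ppsci.arch.cvit import dot_product_attention

# [batch, length, num_heads, depth_per_head]
q = paddle.randn([2, 16, 4, 32])
k = paddle.randn([2, 24, 4, 32])
v = paddle.randn([2, 24, 4, 32])

# In eager mode the rank/shape asserts above execute; after
# paddle.jit.to_static conversion they are bypassed.
out = dot_product_attention(q, k, v)
print(out.shape)  # [2, 16, 4, 32] = [batch, q_length, num_heads, v_depth_per_head]
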
1 change: 0 additions & 1 deletion ppsci/solver/solver.py
@@ -917,7 +917,6 @@ def export(
             self.model,
             input_spec=input_spec,
             full_graph=full_graph,
-            ignore_module=ignore_modules,
         )

         # save static graph model to disk
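Per the commit title, threading ignore_module through this call was redundant, so only the keyword is dropped. A minimal sketch of an export call in this shape, using a stand-in paddle.nn.Linear rather than PaddleScience's actual Solver.export:

import paddle
from paddle.static import InputSpec

model = paddle.nn.Linear(3, 1)  # stand-in for self.model
input_spec = [InputSpec(shape=[None, 3], dtype="float32", name="x")]

# No ignore_module argument here; modules to exclude from
# dynamic-to-static transcription can instead be registered once
# via paddle.jit.ignore_module if needed.
static_model = paddle.jit.to_static(model, input_spec=input_spec, full_graph=True)
paddle.jit.save(static_model, "inference/model")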
