Skip to content

Commit

Permalink
[FIX] Try correcting einsum equation
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Apr 28, 2023
1 parent e3886cd commit 6cd4466
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions einconv/diag_ggn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _conv_diag_ggn_einsum_equation(N: int) -> str:
result_str = ""

# requires 6 + 5 * N letters
letters = get_letters(6 + 5 * N)
letters = get_letters(5 + 5 * N)

# class dimension
class_letter = letters.pop()
Expand All @@ -42,13 +42,12 @@ def _conv_diag_ggn_einsum_equation(N: int) -> str:
sqrt_ggn2_str += batch_letter

# group dimension
group1_letter = letters.pop()
input1_str += group1_letter
sqrt_ggn1_str += group1_letter

group2_letter = letters.pop()
input2_str += group2_letter
sqrt_ggn2_str += group2_letter
group_letter = letters.pop()
input1_str += group_letter
sqrt_ggn1_str += group_letter
input2_str += group_letter
sqrt_ggn2_str += group_letter
result_str += group_letter

# output channel dimension
out_channel_letter = letters.pop()
Expand Down

0 comments on commit 6cd4466

Please sign in to comment.