Skip to content

Commit dcab707

Browse files
authored
Fix #1 (#2)
- Multiply by a boolean to add nothing instead of an empty slice - Use dataclass.replace to properly replace fields in reshape step
1 parent c241084 commit dcab707

File tree

2 files changed

+24
-24
lines changed

2 files changed

+24
-24
lines changed

diffpass/base.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# Stdlib imports
99
from copy import deepcopy
1010
from typing import Optional, Any, Sequence, Union
11-
from dataclasses import fields, dataclass
11+
from dataclasses import fields, dataclass, replace
1212

1313
# Progress bars
1414
from tqdm import tqdm
@@ -621,30 +621,30 @@ def fit_bootstrap(
621621
break
622622

623623
# Reshape results according to number of iterations performed
624+
reshaped_fields = {}
624625
for field_name in available_fields:
625626
results_this_field = getattr(results, field_name)
626627
n_optimized_results_this_field = (
627628
len(results_this_field)
628629
if can_optimize
629630
else field_to_length_so_far[field_name]
630631
)
632+
n_unoptimized_results_this_field = (
633+
len(results_this_field) - n_optimized_results_this_field
634+
)
631635

632636
assert not n_optimized_results_this_field % n_iters_with_optimization
633637
n_in_each_optimized_iter = (
634638
n_optimized_results_this_field // n_iters_with_optimization
635639
)
636-
setattr(
637-
results,
638-
field_name,
639-
[
640-
results_this_field[
641-
j
642-
* n_in_each_optimized_iter : (j + 1)
643-
* n_in_each_optimized_iter
644-
]
645-
for j in range(n_iters_with_optimization)
640+
reshaped_fields[field_name] = [
641+
results_this_field[
642+
j * n_in_each_optimized_iter : (j + 1) * n_in_each_optimized_iter
646643
]
647-
+ [results_this_field[n_optimized_results_this_field:]],
644+
for j in range(n_iters_with_optimization)
645+
] + [results_this_field[n_optimized_results_this_field:]] * bool(
646+
n_unoptimized_results_this_field
648647
)
648+
results = replace(results, **reshaped_fields)
649649

650650
return results

nbs/base.ipynb

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
"# Stdlib imports\n",
5353
"from copy import deepcopy\n",
5454
"from typing import Optional, Any, Sequence, Union\n",
55-
"from dataclasses import fields, dataclass\n",
55+
"from dataclasses import fields, dataclass, replace\n",
5656
"\n",
5757
"# Progress bars\n",
5858
"from tqdm import tqdm\n",
@@ -689,31 +689,31 @@
689689
" break\n",
690690
"\n",
691691
" # Reshape results according to number of iterations performed\n",
692+
" reshaped_fields = {}\n",
692693
" for field_name in available_fields:\n",
693694
" results_this_field = getattr(results, field_name)\n",
694695
" n_optimized_results_this_field = (\n",
695696
" len(results_this_field)\n",
696697
" if can_optimize\n",
697698
" else field_to_length_so_far[field_name]\n",
698699
" )\n",
700+
" n_unoptimized_results_this_field = (\n",
701+
" len(results_this_field) - n_optimized_results_this_field\n",
702+
" )\n",
699703
"\n",
700704
" assert not n_optimized_results_this_field % n_iters_with_optimization\n",
701705
" n_in_each_optimized_iter = (\n",
702706
" n_optimized_results_this_field // n_iters_with_optimization\n",
703707
" )\n",
704-
" setattr(\n",
705-
" results,\n",
706-
" field_name,\n",
707-
" [\n",
708-
" results_this_field[\n",
709-
" j\n",
710-
" * n_in_each_optimized_iter : (j + 1)\n",
711-
" * n_in_each_optimized_iter\n",
712-
" ]\n",
713-
" for j in range(n_iters_with_optimization)\n",
708+
" reshaped_fields[field_name] = [\n",
709+
" results_this_field[\n",
710+
" j * n_in_each_optimized_iter : (j + 1) * n_in_each_optimized_iter\n",
714711
" ]\n",
715-
" + [results_this_field[n_optimized_results_this_field:]],\n",
712+
" for j in range(n_iters_with_optimization)\n",
713+
" ] + [results_this_field[n_optimized_results_this_field:]] * bool(\n",
714+
" n_unoptimized_results_this_field\n",
716715
" )\n",
716+
" results = replace(results, **reshaped_fields)\n",
717717
"\n",
718718
" return results"
719719
]

0 commit comments

Comments
 (0)