Skip to content

Commit

Permalink
Fix #1 (#2)
Browse files Browse the repository at this point in the history
- Multiply by a boolean to add nothing instead of an empty slice
- Use dataclass.replace to properly replace fields in reshape step
  • Loading branch information
ulupo authored May 14, 2024
1 parent c241084 commit dcab707
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 24 deletions.
24 changes: 12 additions & 12 deletions diffpass/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# Stdlib imports
from copy import deepcopy
from typing import Optional, Any, Sequence, Union
from dataclasses import fields, dataclass
from dataclasses import fields, dataclass, replace

# Progress bars
from tqdm import tqdm
Expand Down Expand Up @@ -621,30 +621,30 @@ def fit_bootstrap(
break

# Reshape results according to number of iterations performed
reshaped_fields = {}
for field_name in available_fields:
results_this_field = getattr(results, field_name)
n_optimized_results_this_field = (
len(results_this_field)
if can_optimize
else field_to_length_so_far[field_name]
)
n_unoptimized_results_this_field = (
len(results_this_field) - n_optimized_results_this_field
)

assert not n_optimized_results_this_field % n_iters_with_optimization
n_in_each_optimized_iter = (
n_optimized_results_this_field // n_iters_with_optimization
)
setattr(
results,
field_name,
[
results_this_field[
j
* n_in_each_optimized_iter : (j + 1)
* n_in_each_optimized_iter
]
for j in range(n_iters_with_optimization)
reshaped_fields[field_name] = [
results_this_field[
j * n_in_each_optimized_iter : (j + 1) * n_in_each_optimized_iter
]
+ [results_this_field[n_optimized_results_this_field:]],
for j in range(n_iters_with_optimization)
] + [results_this_field[n_optimized_results_this_field:]] * bool(
n_unoptimized_results_this_field
)
results = replace(results, **reshaped_fields)

return results
24 changes: 12 additions & 12 deletions nbs/base.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"# Stdlib imports\n",
"from copy import deepcopy\n",
"from typing import Optional, Any, Sequence, Union\n",
"from dataclasses import fields, dataclass\n",
"from dataclasses import fields, dataclass, replace\n",
"\n",
"# Progress bars\n",
"from tqdm import tqdm\n",
Expand Down Expand Up @@ -689,31 +689,31 @@
" break\n",
"\n",
" # Reshape results according to number of iterations performed\n",
" reshaped_fields = {}\n",
" for field_name in available_fields:\n",
" results_this_field = getattr(results, field_name)\n",
" n_optimized_results_this_field = (\n",
" len(results_this_field)\n",
" if can_optimize\n",
" else field_to_length_so_far[field_name]\n",
" )\n",
" n_unoptimized_results_this_field = (\n",
" len(results_this_field) - n_optimized_results_this_field\n",
" )\n",
"\n",
" assert not n_optimized_results_this_field % n_iters_with_optimization\n",
" n_in_each_optimized_iter = (\n",
" n_optimized_results_this_field // n_iters_with_optimization\n",
" )\n",
" setattr(\n",
" results,\n",
" field_name,\n",
" [\n",
" results_this_field[\n",
" j\n",
" * n_in_each_optimized_iter : (j + 1)\n",
" * n_in_each_optimized_iter\n",
" ]\n",
" for j in range(n_iters_with_optimization)\n",
" reshaped_fields[field_name] = [\n",
" results_this_field[\n",
" j * n_in_each_optimized_iter : (j + 1) * n_in_each_optimized_iter\n",
" ]\n",
" + [results_this_field[n_optimized_results_this_field:]],\n",
" for j in range(n_iters_with_optimization)\n",
" ] + [results_this_field[n_optimized_results_this_field:]] * bool(\n",
" n_unoptimized_results_this_field\n",
" )\n",
" results = replace(results, **reshaped_fields)\n",
"\n",
" return results"
]
Expand Down

0 comments on commit dcab707

Please sign in to comment.