Skip to content

Commit

Permalink
Merge pull request #22 from Jacob-Stevens-Haas/hotfix-21
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas authored Mar 21, 2024
2 parents e236e17 + dc29e6e commit 971d6cf
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 24 deletions.
33 changes: 14 additions & 19 deletions src/gen_experiments/gridsearch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _amax_to_full_inds(

if amax_inds is ...: # grab each element from arrays in list of lists of arrays
return {
np_to_primitive(el)
void_to_tuple(el)
for ar_list in amax_arrays
for arr in ar_list
for el in arr.flatten()
Expand All @@ -75,15 +75,15 @@ def _amax_to_full_inds(
for ind in amax_inds:
if ind is ...: # grab each element from arrays in list of lists of arrays
all_inds |= {
np_to_primitive(el)
void_to_tuple(el)
for ar_list in amax_arrays
for arr in ar_list
for el in arr.flatten()
}
elif isinstance(ind[0], int):
all_inds |= {np_to_primitive(cast(np.void, plot_axis_results[ind]))}
all_inds |= {void_to_tuple(cast(np.void, plot_axis_results[ind]))}
else: # ind[0] is slice(None)
all_inds |= {np_to_primitive(el) for el in plot_axis_results[ind]}
all_inds |= {void_to_tuple(el) for el in plot_axis_results[ind]}
return all_inds


Expand Down Expand Up @@ -383,10 +383,11 @@ def _marginalize_grid_views(
grid_decisions: Iterable[str],
results: Annotated[NDArray[T], "shape (n_metrics, *n_gridsearch_values)"],
max_or_min: Sequence[str],
) -> tuple[list[GridsearchResult[T]], list[GridsearchResult]]:
"""Marginalize unnecessary dimensions by taking max across axes.
) -> tuple[list[GridsearchResult[T]], list[GridsearchResult[np.void]]]:
"""Marginalize unnecessary dimensions by taking max or min across axes.
Ignores NaN values and strips the metric index from the argoptima.
Ignores NaN values
Args:
grid_decisions: list of how to treat each non-metric gridsearch
axis. An array of metrics for each "plot" grid decision
Expand All @@ -396,9 +397,8 @@ def _marginalize_grid_views(
max_or_min: either "max" or "min" for each row of results
Returns:
a list of the metric optima for each plottable grid decision, and
a list of the flattened argoptima.
a list of the flattened argoptima, with metric removed
"""
arg_dtype = np.dtype(",".join(results.ndim * "i"))
plot_param_inds = [ind for ind, val in enumerate(grid_decisions) if val == "plot"]
grid_searches = []
args_maxes = []
Expand All @@ -409,13 +409,8 @@ def _marginalize_grid_views(
[opt(result, axis=reduce_axes) for opt, result in zip(optfuns, results)]
)
sub_arrs = []
for m_ind, (result, opt) in enumerate(zip(results, max_or_min)):

def _metric_pad(tp: tuple[int, ...]) -> np.void:
return np.void((m_ind, *tp), dtype=arg_dtype)

pad_m_ind = np.vectorize(_metric_pad)
arg_max = pad_m_ind(_argopt(result, reduce_axes, opt))
for result, opt in zip(results, max_or_min):
arg_max = _argopt(result, reduce_axes, opt)
sub_arrs.append(arg_max)

args_max = np.stack(sub_arrs)
Expand Down Expand Up @@ -613,7 +608,7 @@ def find_gridpoints(
for index_of_ax, indexes_in_ax in keep_axes:
amax_arr = ser[index_of_ax][1]
amax_want = amax_arr[np.ix_(inds_of_metrics, indexes_in_ax)].flatten()
partial_match |= {np_to_primitive(el) for el in amax_want}
partial_match |= {void_to_tuple(el) for el in amax_want}
logger.info(
f"Found {len(partial_match)} gridpoints that match metric-plot_axis criteria"
)
Expand All @@ -633,7 +628,7 @@ def check_values(criteria: Any | Callable[..., bool], candidate: Any) -> bool:
logger.debug(f"Checking whether {point['pind']} matches param query")
for params_match in params_or:
if all(
check_values(value, point["params"][param])
param in point["params"] and check_values(value, point["params"][param])
for param, value in params_match.items()
):
results.append(point)
Expand Down Expand Up @@ -664,6 +659,6 @@ def _expand_ellipsis_axis(
raise TypeError("Keep_axis does not have an ellipsis or is not a 2-tuple")


def np_to_primitive(tuple_like: np.void) -> tuple[int, ...]:
def void_to_tuple(tuple_like: np.void) -> tuple[int, ...]:
"""Turn a void that represents a tuple of ints into a tuple of ints"""
return tuple(int(el) for el in cast(Iterable, tuple_like))
2 changes: 1 addition & 1 deletion src/gen_experiments/gridsearch/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class GridLocator:
Annotated[GridsearchResult[np.void], "arg_opts"],
]
],
"len=n_grid_axes",
"len=n_plot_axes",
]

ExpResult = dict[str, Any]
Expand Down
9 changes: 5 additions & 4 deletions tests/test_gridsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ def test_marginalize_grid_views():
for result, expected in zip(res_val, expected_val):
np.testing.assert_array_equal(result, expected)

ts = "i,i,i,i"
ts = "i,i,i"
expected_ind = [
np.array([[(0, 0, 0, 0), (0, 1, 1, 1)], [(1, 0, 0, 0), (1, 1, 1, 0)]], ts),
np.array([[(0, 0, 0, 0), (0, 1, 1, 1)], [(1, 1, 1, 0), (1, 0, 0, 1)]], ts),
np.array([[(0, 0, 0), (1, 1, 1)], [(0, 0, 0), (1, 1, 0)]], ts),
np.array([[(0, 0, 0), (1, 1, 1)], [(1, 1, 0), (0, 0, 1)]], ts),
]
for result, expected in zip(res_ind, expected_ind):
np.testing.assert_array_equal(result, expected)
Expand Down Expand Up @@ -187,8 +187,9 @@ def gridsearch_results():
gridsearch.GridLocator(
..., (..., ...), [{"diff_params.alpha": 0.1}, {"diff_params.alpha": 0.3}]
),
gridsearch.GridLocator(params_or=[{"diff_params.alpha": 0.1}, {"foo": 0}]),
),
ids=("exact", "object", "callable", "by_axis", "or"),
ids=("exact", "object", "callable", "by_axis", "or", "missingkey"),
)
def test_find_gridpoints(gridsearch_results, locator):
want, full_details = gridsearch_results
Expand Down

0 comments on commit 971d6cf

Please sign in to comment.