Skip to content

Commit

Permalink
refactor: update parameter sweep api
Browse files Browse the repository at this point in the history
  • Loading branch information
mortonne committed Jun 10, 2022
1 parent 408ad14 commit ab7b02c
Showing 1 changed file with 49 additions and 8 deletions.
57 changes: 49 additions & 8 deletions src/cymr/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,19 +814,60 @@ def parameter_recovery(
return results

def parameter_sweep(
self, data, param_def, param_names, param_sweeps, patterns=None, n_rep=1
self,
data,
group_param,
subj_param,
sweep_param,
param_def,
patterns=None,
n_rep=1,
):
"""Simulate data with varying parameters."""
index = pd.MultiIndex.from_product(param_sweeps, names=param_names)
"""
Simulate data with varying parameters.
Parameters
----------
data : pandas.DataFrame
Data to guide simulation. Must include a 'subject' column.
May include dummy recall events if there is a dynamic
recall parameter.
group_param : dict of (str: float)
Values of parameters that apply to all subjects.
subj_param : dict of (str: dict of (str: float))
Parameters that vary by subject, indexed by subject.
sweep_param : dict of (str: numpy.ndarray)
Parameter values to sweep over.
param_def : cymr.parameters.Parameters, optional
Parameter definitions.
patterns : dict of (str: dict of (str: numpy.array)), optional
Patterns to use in the model.
n_rep : int
Number of times to repeat the simulation for each subject.
Returns
-------
results : pandas.DataFrame
Simulated data for each combination of sweep parameters.
"""
index = pd.MultiIndex.from_product(
sweep_param.values(), names=sweep_param.keys()
)
df_list = []
for idx in index:
param = param_def.fixed.copy()
param.update(dict(zip(param_names, idx)))
param = group_param.copy()
param.update(dict(zip(sweep_param.keys(), idx)))
sim = self.generate(
data, param, None, param_def, patterns=patterns, n_rep=n_rep
data, param, subj_param, param_def, patterns=patterns, n_rep=n_rep
)
df_list.append(sim)
results = pd.concat(df_list, axis=0, keys=index)
results = results.droplevel(len(param_sweeps))
results.index.rename(param_names, inplace=True)
results = results.droplevel(len(sweep_param.keys()))
results.index.rename(sweep_param.keys(), inplace=True)
return results

0 comments on commit ab7b02c

Please sign in to comment.