Skip to content

Commit 4003ed8

Browse files
authored
Merge pull request gyorilab#332 from gyorilab/stratify_improvement
Reimplement stratification logic for parameter consistency
2 parents a8623e7 + a7f1875 commit 4003ed8

File tree

6 files changed

+202
-128
lines changed

6 files changed

+202
-128
lines changed

mira/dkg/api.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def get_entity(
7373
curie: str = Path(
7474
...,
7575
description="A compact URI (CURIE) for an entity in the form of ``<prefix>:<local unique identifier>``",
76-
example="ido:0000511",
76+
examples=["ido:0000511"],
7777
),
7878
):
7979
"""Get information about an entity (e.g., its name, description synonyms, alternative identifiers,
@@ -97,7 +97,7 @@ def get_entities(
9797
...,
9898
description="A comma-separated list of compact URIs (CURIEs) for an "
9999
"entity in the form of ``<prefix>:<local unique identifier>,...``",
100-
example="ido:0000511,ido:0000512",
100+
examples=["ido:0000511,ido:0000512"],
101101
),
102102
):
103103
"""
@@ -158,7 +158,7 @@ def get_transitive_closure(
158158
relation_types: List[str] = Query(
159159
...,
160160
description="A list of relation types to get a transitive closure for",
161-
example=DKG_REFINER_RELS,
161+
examples=[DKG_REFINER_RELS],
162162
),
163163
):
164164
"""Get a transitive closure of the requested type(s)"""
@@ -384,13 +384,13 @@ def is_ontological_child(
384384
)
385385
def search(
386386
request: Request,
387-
q: str = Query(..., example="infect", description="The search query"),
387+
q: str = Query(..., examples=["infect"], description="The search query"),
388388
limit: int = 25,
389389
offset: int = 0,
390390
prefixes: Optional[str] = Query(
391391
default=None,
392392
description="A comma-separated list of prefixes",
393-
examples={
393+
examples=[{
394394
"no prefix filter": {
395395
"summary": "Don't filter by prefix",
396396
"value": None,
@@ -399,7 +399,7 @@ def search(
399399
"summary": "Search for units, which have Wikidata prefixes",
400400
"value": "wikidata",
401401
},
402-
},
402+
}],
403403
),
404404
labels: Optional[str] = Query(
405405
default=None,

mira/dkg/grounding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def ground_get(
149149
description="The text to be grounded. Warning: grounding does not work well for "
150150
"substring matches, i.e., if searching only for 'infected'. In these "
151151
"cases, using the search API is more appropriate.",
152-
example="Infected Population",
152+
examples=["Infected Population"],
153153
),
154154
):
155155
"""Ground text with Gilda."""

mira/metamodel/ops.py

Lines changed: 117 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828

2929
def stratify(
3030
template_model: TemplateModel,
31-
*,
3231
key: str,
3332
strata: Collection[str],
3433
strata_curie_to_name: Optional[Mapping[str, str]] = None,
@@ -42,6 +41,7 @@ def stratify(
4241
params_to_preserve: Optional[Collection[str]] = None,
4342
concepts_to_stratify: Optional[Collection[str]] = None,
4443
concepts_to_preserve: Optional[Collection[str]] = None,
44+
param_renaming_uses_strata_names: Optional[bool] = False,
4545
) -> TemplateModel:
4646
"""Multiplies a model into several strata.
4747
@@ -95,23 +95,23 @@ def stratify(
9595
params_to_stratify :
9696
A list of parameters to stratify. If none given, will stratify all
9797
parameters.
98-
params_to_preserve:
98+
params_to_preserve :
9999
A list of parameters to preserve. If none given, will stratify all
100100
parameters.
101101
concepts_to_stratify :
102102
A list of concepts to stratify. If none given, will stratify all
103103
concepts.
104-
concepts_to_preserve:
104+
concepts_to_preserve :
105105
A list of concepts to preserve. If none given, will stratify all
106106
concepts.
107-
107+
param_renaming_uses_strata_names :
108+
If true, the strata names will be used in the parameter renaming.
109+
If false, the strata indices will be used. Default: False
108110
Returns
109111
-------
110112
:
111113
A stratified template model
112114
"""
113-
strata = sorted(strata)
114-
115115
if strata_name_lookup and strata_curie_to_name is None:
116116
from mira.dkg.web_client import get_entities_web, MissingBaseUrlError
117117
try:
@@ -137,8 +137,6 @@ def stratify(
137137

138138
# List of new templates
139139
templates = []
140-
# Counter to keep track of how many times a parameter has been stratified
141-
params_count = Counter()
142140

143141
# Figure out excluded concepts
144142
if concepts_to_stratify is None:
@@ -154,7 +152,10 @@ def stratify(
154152
concept_names - set(concepts_to_stratify)
155153
)
156154

155+
stratum_index_map = {stratum: i for i, stratum in enumerate(strata)}
156+
157157
keep_unstratified_parameters = set()
158+
all_param_mappings = defaultdict(set)
158159
for template in template_model.templates:
159160
# If the template doesn't have any concepts that need to be stratified
160161
# then we can just keep it as is and skip the rest of the loop
@@ -165,77 +166,100 @@ def stratify(
165166
templates.append(deepcopy(template))
166167
continue
167168

168-
# Generate a derived template for each strata
169-
for stratum in strata:
170-
new_template = template.with_context(
171-
do_rename=modify_names, exclude_concepts=exclude_concepts,
172-
curie_to_name_map=strata_curie_to_name,
173-
**{key: stratum},
174-
)
175-
rewrite_rate_law(template_model=template_model,
176-
old_template=template,
177-
new_template=new_template,
178-
params_count=params_count,
179-
params_to_stratify=params_to_stratify,
180-
params_to_preserve=params_to_preserve)
181-
# parameters = list(template_model.get_parameters_from_rate_law(template.rate_law))
182-
# if len(parameters) == 1:
183-
# new_template.set_mass_action_rate_law(parameters[0])
184-
templates.append(new_template)
185-
186-
# assume all controllers have to get stratified together
187-
# and mixing of strata doesn't occur during control
188-
controllers = template.get_controllers()
189-
if cartesian_control and controllers:
190-
remaining_strata = [s for s in strata if s != stratum]
191-
192-
# use itt.product to generate all combinations of remaining
193-
# strata for remaining controllers. for example, if there
169+
# Check if we will have any controllers in the template
170+
ncontrollers = num_controllers(template)
171+
# If we have controllers, and we want cartesian control then
172+
# we will stratify controllers separately
173+
stratify_controllers = (ncontrollers > 0) and cartesian_control
174+
175+
# Generate a derived template for each stratum
176+
for stratum, stratum_idx in stratum_index_map.items():
177+
template_strata = []
178+
new_template = deepcopy(template)
179+
# We have to make sure that we only add the stratum to the
180+
# list of template strata if we stratified any of the non-controllers
181+
# in this first for loop
182+
any_noncontrollers_stratified = False
183+
# We apply this stratum to each concept except for controllers
184+
# in case we will separately stratify those
185+
for concept in new_template.get_concepts_flat(
186+
exclude_controllers=stratify_controllers,
187+
refresh=True):
188+
if concept.name in exclude_concepts:
189+
continue
190+
concept.with_context(
191+
do_rename=modify_names,
192+
curie_to_name_map=strata_curie_to_name,
193+
inplace=True,
194+
**{key: stratum})
195+
any_noncontrollers_stratified = True
196+
197+
# If we don't stratify controllers then we are done and can just
198+
# make the new rate law, then append this new template
199+
if not stratify_controllers:
200+
# We only need to do this if we stratified any of the non-controllers
201+
if any_noncontrollers_stratified:
202+
template_strata = [stratum if
203+
param_renaming_uses_strata_names else stratum_idx]
204+
param_mappings = rewrite_rate_law(template_model=template_model,
205+
old_template=template,
206+
new_template=new_template,
207+
template_strata=template_strata,
208+
params_to_stratify=params_to_stratify,
209+
params_to_preserve=params_to_preserve)
210+
for old_param, new_param in param_mappings.items():
211+
all_param_mappings[old_param].add(new_param)
212+
templates.append(new_template)
213+
# Otherwise we are stratifying controllers separately
214+
else:
215+
# Use itt.product to generate all combinations of
216+
# strata for controllers. For example, if there
194217
# are two controllers A and B and stratification is into
195218
# old, middle, and young, then there will be the following 9:
196219
# (A_old, B_old), (A_old, B_middle), (A_old, B_young),
197220
# (A_middle, B_old), (A_middle, B_middle), (A_middle, B_young),
198221
# (A_young, B_old), (A_young, B_middle), (A_young, B_young)
199-
c_strata_tuples = itt.product(remaining_strata, repeat=len(controllers))
200-
for c_strata_tuple in c_strata_tuples:
201-
stratified_controllers = [
202-
controller.with_context(do_rename=modify_names, **{key: c_stratum})
203-
if controller.name not in exclude_concepts
204-
else controller
205-
for controller, c_stratum in zip(controllers, c_strata_tuple)
206-
]
207-
if isinstance(template, (GroupedControlledConversion, GroupedControlledProduction)):
208-
stratified_template = new_template.with_controllers(stratified_controllers)
209-
elif isinstance(template, (ControlledConversion, ControlledProduction,
210-
ControlledDegradation, ControlledReplication)):
211-
assert len(stratified_controllers) == 1
212-
stratified_template = new_template.with_controller(stratified_controllers[0])
213-
else:
214-
raise NotImplementedError
215-
# the old template is used here on purpose for easier bookkeeping
216-
rewrite_rate_law(template_model=template_model,
217-
old_template=template,
218-
new_template=stratified_template,
219-
params_count=params_count,
220-
params_to_stratify=params_to_stratify,
221-
params_to_preserve=params_to_preserve)
222+
for c_strata_tuple in itt.product(strata, repeat=ncontrollers):
223+
stratified_template = deepcopy(new_template)
224+
stratified_controllers = stratified_template.get_controllers()
225+
template_strata = [stratum if param_renaming_uses_strata_names
226+
else stratum_idx]
227+
# We now apply the stratum assigned to each controller in this particular
228+
# tuple to the controller
229+
for controller, c_stratum in zip(stratified_controllers, c_strata_tuple):
230+
controller.with_context(do_rename=modify_names, inplace=True,
231+
**{key: c_stratum})
232+
template_strata.append(c_stratum if param_renaming_uses_strata_names
233+
else stratum_index_map[c_stratum])
234+
235+
# Wew can now rewrite the rate law for this stratified template,
236+
# then append the new template
237+
param_mappings = rewrite_rate_law(template_model=template_model,
238+
old_template=template,
239+
new_template=stratified_template,
240+
template_strata=template_strata,
241+
params_to_stratify=params_to_stratify,
242+
params_to_preserve=params_to_preserve)
243+
for old_param, new_param in param_mappings.items():
244+
all_param_mappings[old_param].add(new_param)
222245
templates.append(stratified_template)
223246

224247
parameters = {}
225248
for parameter_key, parameter in template_model.parameters.items():
226-
if parameter_key not in params_count:
249+
if parameter_key not in all_param_mappings:
227250
parameters[parameter_key] = parameter
228251
continue
229252
# We need to keep the original param if it has been broken
230253
# up but not in every instance. We then also
231254
# generate the counted parameter variants
232255
elif parameter_key in keep_unstratified_parameters:
233256
parameters[parameter_key] = parameter
234-
# note that `params_count[key]` will be 1 higher than the number of uses
235-
for i in range(params_count[parameter_key]):
257+
# We otherwise generate variants of the parameter based
258+
# on the previously complied parameter mappings
259+
for stratified_param in all_param_mappings[parameter_key]:
236260
d = deepcopy(parameter)
237-
d.name = f"{parameter_key}_{i}"
238-
parameters[d.name] = d
261+
d.name = stratified_param
262+
parameters[stratified_param] = d
239263

240264
# Create new initial values for each of the strata
241265
# of the original compartments, copied from the initial
@@ -320,7 +344,7 @@ def rewrite_rate_law(
320344
template_model: TemplateModel,
321345
old_template: Template,
322346
new_template: Template,
323-
params_count: Counter,
347+
template_strata: List[int],
324348
params_to_stratify: Optional[Collection[str]] = None,
325349
params_to_preserve: Optional[Collection[str]] = None,
326350
):
@@ -337,9 +361,9 @@ def rewrite_rate_law(
337361
new_template :
338362
The new template. One of the templates created by stratification of
339363
``old_template``.
340-
params_count :
341-
A counter that keeps track of how many times a parameter has been
342-
stratified.
364+
template_strata :
365+
A list of strata indices that have been applied to the template,
366+
used for parameter naming.
343367
params_to_stratify :
344368
A list of parameters to stratify. If none given, will stratify all
345369
parameters.
@@ -351,7 +375,7 @@ def rewrite_rate_law(
351375
# to the stratified controllers in for the originals
352376
rate_law = old_template.rate_law
353377
if not rate_law:
354-
return
378+
return {}
355379

356380
# If the template has controllers/subjects that affect the rate law
357381
# and there is an overlap between these, then simple substitution
@@ -362,28 +386,7 @@ def rewrite_rate_law(
362386
old_template.get_controllers()}:
363387
has_subject_controller_overlap = True
364388

365-
# Step 1. Identify the mass action symbol and rename it with a
366-
parameters = list(template_model.get_parameters_from_rate_law(rate_law))
367-
for parameter in parameters:
368-
# If a parameter is explicitly listed as one to preserve, then
369-
# don't stratify it
370-
if params_to_preserve is not None and parameter in params_to_preserve:
371-
continue
372-
# If we have an explicit stratification list then if something isn't
373-
# in the list then don't stratify it.
374-
elif params_to_stratify is not None and parameter not in params_to_stratify:
375-
continue
376-
# Otherwise we go ahead with stratification, i.e., in cases
377-
# where nothing was said about parameter stratification or the
378-
# parameter was listed explicitly to be stratified
379-
else:
380-
rate_law = rate_law.subs(
381-
parameter,
382-
sympy.Symbol(f"{parameter}_{params_count[parameter]}")
383-
)
384-
params_count[parameter] += 1 # increment this each time to keep unique
385-
386-
# Step 2. Rename symbols based on the new concepts
389+
# Step 1. Rename controllers
387390
for old_controller, new_controller in zip(
388391
old_template.get_controllers(), new_template.get_controllers(),
389392
):
@@ -405,7 +408,7 @@ def rewrite_rate_law(
405408
sympy.Symbol(new_controller.name),
406409
)
407410

408-
# Step 3. Rename subject and object
411+
# Step 2. Rename subject and object
409412
old_cbr = old_template.get_concepts_by_role()
410413
new_cbr = new_template.get_concepts_by_role()
411414
if "subject" in old_cbr and "subject" in new_cbr:
@@ -419,7 +422,31 @@ def rewrite_rate_law(
419422
sympy.Symbol(new_template.outcome.name),
420423
)
421424

425+
# Step 3. Rename parameters by generating new parameters
426+
# named according to the strata that were applied to the
427+
# given template
428+
parameters = list(template_model.get_parameters_from_rate_law(rate_law))
429+
param_mappings = {}
430+
for parameter in parameters:
431+
# If a parameter is explicitly listed as one to preserve, then
432+
# don't stratify it
433+
if params_to_preserve is not None and parameter in params_to_preserve:
434+
continue
435+
# If we have an explicit stratification list then if something isn't
436+
# in the list then don't stratify it.
437+
elif params_to_stratify is not None and parameter not in params_to_stratify:
438+
continue
439+
# Otherwise we go ahead with stratification, i.e., in cases
440+
# where nothing was said about parameter stratification or the
441+
# parameter was listed explicitly to be stratified
442+
else:
443+
param_suffix = '_'.join([str(s) for s in template_strata])
444+
new_param = f'{parameter}_{param_suffix}'
445+
param_mappings[parameter] = new_param
446+
rate_law = rate_law.subs(parameter, sympy.Symbol(new_param))
447+
422448
new_template.rate_law = rate_law
449+
return param_mappings
423450

424451

425452
def simplify_rate_laws(template_model: TemplateModel):

0 commit comments

Comments
 (0)