From ace82bc7711fc37daa02cc45722b2a1db524218d Mon Sep 17 00:00:00 2001 From: kkaris Date: Wed, 19 Jul 2023 10:47:50 -0700 Subject: [PATCH 1/7] Add endpoint for deactivating templates in models --- mira/dkg/model.py | 133 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 131 insertions(+), 2 deletions(-) diff --git a/mira/dkg/model.py b/mira/dkg/model.py index 9758373a2..931a929db 100644 --- a/mira/dkg/model.py +++ b/mira/dkg/model.py @@ -19,14 +19,14 @@ Request, ) from fastapi.responses import FileResponse -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, root_validator, validator from mira.examples.sir import sir_bilayer, sir, sir_parameterized_init from mira.metamodel import ( NaturalConversion, Template, ControlledConversion, stratify, Concept, ModelComparisonGraphdata, TemplateModelDelta, TemplateModel, Parameter, simplify_rate_laws, aggregate_parameters, - counts_to_dimensionless + counts_to_dimensionless, deactivate_templates ) from mira.modeling import Model from mira.modeling.askenet.petrinet import AskeNetPetriNetModel, ModelSpecification @@ -296,6 +296,135 @@ def model_stratification( return template_model +# template deactivation +class DeactivationQuery(BaseModel): + model: Dict[str, Any] = Field( + ..., + description="The model to deactivate transitions in", + example=askenet_petrinet_json_units_values + ) + parameters: Optional[List[str]] = Field( + None, + description="Deactivates transitions that have a parameter from the " + "provided list in their rate law", + example=["beta"] + ) + transitions: Optional[List[List[str]]] = Field( + None, + description="Deactivates transitions that have a source-target " + "pair from the provided list", + example=[["infected_population_old", "infected_population_young"]] + ) + and_or: Literal["and", "or"] = Field( + "and", + description="If both transitions and parameters are provided, " + "whether to deactivate transitions that match both " + "or either of the provided conditions. If only one " + "of transitions or parameters is provided, this " + "parameter is has no effect.", + example="and" + ) + + @validator('transitions') + def check_transitions(cls, v): + # This enforces that the transitions are a list of lists of length 2 + # (since we can't use tuples for JSON (or can we?)) + if v is not None: + for transition in v: + if len(transition) != 2: + raise ValueError( + "Each transition must be a list of length 2" + ) + return v + + @root_validator(skip_on_failure=True) + def check_a_or_b(cls, values): + if ( + values.get("parameters") is None or + values.get("parameters") == [] + ) and ( + values.get("transitions") is None or + not any(values.get("transitions", [])) + ): + raise ValueError( + 'At least one of "parameters" or "transitions" is required' + ) + return values + + +@model_blueprint.post( + "/deactivate_transitions", + response_model=ModelSpecification, + tags=["modeling"], +) +def deactivate_transitions( + query: DeactivationQuery = Body( + ..., + examples={ + "parameters": { + "model": askenet_petrinet_json_units_values, + "parameters": ["beta"], + }, + "transitions": { + # Todo: Fix example model to include transitions with the + # same source and target as the example below + "model": askenet_petrinet_json_units_values, + "transitions": [["infected_population_old", + "infected_population_young"]], + }, + }, + ) +): + """Deactivate transitions in a model""" + amr_json = query.model + tm = template_model_from_askenet_json(amr_json) + + # Create callables for deactivating transitions + if query.parameters: + def deactivate_parameter(t: Template) -> bool: + """Deactivate transitions that have a parameter in the query""" + if t.rate_law is None: + return False + for symb in t.rate_law.atoms(): + if str(symb) in set(query.parameters): + return True + else: + deactivate_parameter = None + + if query.transitions is not None: + def deactivate_transition(t: Template) -> bool: + """Deactivate template if it is a transition-like template and it + matches the source-target pair""" + if hasattr(t, "subject") and hasattr(t, "outcome"): + for subject, outcome in query.transitions: + if t.subject.name == subject and t.outcome.name == outcome: + return True + return False + else: + deactivate_transition = None + + def meta_deactivate(t: Template) -> bool: + if deactivate_parameter is not None and \ + deactivate_transition is not None: + if query.and_or == "and": + return deactivate_parameter(t) and deactivate_transition(t) + else: + return deactivate_parameter(t) or deactivate_transition(t) + elif deactivate_parameter is None: + return deactivate_transition(t) + elif deactivate_transition is None: + return deactivate_parameter(t) + else: + raise ValueError( + "Need to provide either or both of parameters or transitions" + ) + + tm_deactivated = deactivate_templates(template_model=tm, + condition=meta_deactivate) + + return AskeNetPetriNetModel(Model(tm_deactivated)).to_pydantic() + + @model_blueprint.post( "/counts_to_dimensionless_mira", response_model=TemplateModel, From 8626585e9677dd4e9a639277f0048b15b9cca2c1 Mon Sep 17 00:00:00 2001 From: kkaris Date: Wed, 19 Jul 2023 10:58:36 -0700 Subject: [PATCH 2/7] Remove unused variable --- mira/examples/sir.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mira/examples/sir.py b/mira/examples/sir.py index 0fff0cad0..44f52f49a 100644 --- a/mira/examples/sir.py +++ b/mira/examples/sir.py @@ -152,7 +152,6 @@ sir_parameterized_init.parameters['beta'].units = \ Unit(expression=1 / (sympy.Symbol('person') * sympy.Symbol('day'))) -old_beta = sir_parameterized_init.parameters['beta'].value for initial in sir_parameterized_init.initials.values(): initial.concept.units = Unit(expression=sympy.Symbol('person')) From abe8842e6a231e366811a7d328f12a8a28955e28 Mon Sep 17 00:00:00 2001 From: kkaris Date: Wed, 19 Jul 2023 11:20:03 -0700 Subject: [PATCH 3/7] Correctly call deactivate_templates --- mira/dkg/model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mira/dkg/model.py b/mira/dkg/model.py index 931a929db..ba6437927 100644 --- a/mira/dkg/model.py +++ b/mira/dkg/model.py @@ -419,10 +419,9 @@ def meta_deactivate(t: Template) -> bool: "Need to provide either or both of parameters or transitions" ) - tm_deactivated = deactivate_templates(template_model=tm, - condition=meta_deactivate) + deactivate_templates(template_model=tm, condition=meta_deactivate) - return AskeNetPetriNetModel(Model(tm_deactivated)).to_pydantic() + return AskeNetPetriNetModel(Model(tm)).to_pydantic() @model_blueprint.post( From d05772c01d984b2fe9e34df7521f40c873ad8691 Mon Sep 17 00:00:00 2001 From: kkaris Date: Wed, 19 Jul 2023 11:20:24 -0700 Subject: [PATCH 4/7] Add test for deactivation endpoint --- tests/test_model_api.py | 85 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/tests/test_model_api.py b/tests/test_model_api.py index 14fb83ff6..308777198 100644 --- a/tests/test_model_api.py +++ b/tests/test_model_api.py @@ -694,3 +694,88 @@ def test_reconstruct_ode_semantics_endpoint(self): assert len(flux_span_tm.parameters) == 11 assert all(t.rate_law for t in flux_span_tm.templates) + def test_deactivation_endpoint(self): + # Deliberately create a stratifiction that will lead to nonsense + # transitions, i.e. a transitions between age groups + age_strata = stratify(sir_parameterized_init, + key='age', + strata=['y', 'o'], + cartesian_control=True) + + # Assert that there are old to young transitions + transition_list = [] + for template in age_strata.templates: + if hasattr(template, 'subject') and hasattr(template, 'outcome'): + subj, outc = template.subject.name, template.outcome.name + if subj.endswith('_o') and outc.endswith('_y') or \ + subj.endswith('_y') and outc.endswith('_o'): + transition_list.append((subj, outc)) + assert len(transition_list), "No old to young transitions found" + + amr_sir = AskeNetPetriNetModel(Model(age_strata)).to_json() + + # Test the endpoint itself + # Should fail with 422 because of missing transitions or parameters + response = self.client.post( + "/api/deactivate_transitions", + json={"model": amr_sir} + ) + self.assertEqual(422, response.status_code) + + # Should fail with 422 because of empty transition list + response = self.client.post( + "/api/deactivate_transitions", + json={"model": amr_sir, "transitions": [[]]} + ) + self.assertEqual(422, response.status_code) + + # Should fail with 422 because of transitions are triples + response = self.client.post( + "/api/deactivate_transitions", + json={"model": amr_sir, "transitions": [['a', 'b', 'c']]} + ) + self.assertEqual(422, response.status_code) + + # Should fail with 422 because of empty parameters list + response = self.client.post( + "/api/deactivate_transitions", + json={"model": amr_sir, "parameters": []} + ) + self.assertEqual(422, response.status_code) + + # Actual Test + response = self.client.post( + "/api/deactivate_transitions", + json={"model": amr_sir, "transitions": transition_list} + ) + self.assertEqual(200, response.status_code) + + # Check that the transitions are deactivated + amr_sir_deactivated = response.json() + tm_deactivated = template_model_from_askenet_json(amr_sir_deactivated) + for template in tm_deactivated.templates: + if hasattr(template, 'subject') and hasattr(template, 'outcome'): + subj, outc = template.subject.name, template.outcome.name + if (subj, outc) in transition_list: + assert template.rate_law.args[0] == \ + sympy.core.numbers.Zero(), \ + template.rate_law + + # Test using parameter names for deactivation + deactivate_key = list(age_strata.parameters.keys())[0] + response = self.client.post( + "/api/deactivate_transitions", + json={"model": amr_sir, "parameters": [deactivate_key]} + ) + self.assertEqual(200, response.status_code) + amr_sir_deactivated_params = response.json() + tm_deactivated_params = template_model_from_askenet_json( + amr_sir_deactivated_params) + for template in tm_deactivated_params.templates: + for symb in template.rate_law.atoms(): + if str(symb) == deactivate_key: + assert ( + template.rate_law.rate_law.args[0] == + sympy.core.numbers.Zero(), + template.rate_law + ) From d8cfe226a9bff34bd16dbd3ddf75fdd73ec75705 Mon Sep 17 00:00:00 2001 From: kkaris Date: Wed, 19 Jul 2023 12:10:43 -0700 Subject: [PATCH 5/7] Update examples in api --- mira/dkg/model.py | 45 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/mira/dkg/model.py b/mira/dkg/model.py index ba6437927..3e107d09b 100644 --- a/mira/dkg/model.py +++ b/mira/dkg/model.py @@ -6,7 +6,7 @@ import uuid from pathlib import Path from textwrap import dedent -from typing import Any, Dict, List, Literal, Optional, Set, Type, Union +from typing import Any, Dict, List, Literal, Optional, Set, Type, Union, Tuple import pystow from fastapi import ( @@ -87,12 +87,20 @@ ] ) +# Used as example in the deactivation endpoint +age_strata = stratify(sir_parameterized_init, + key='age', + strata=['young', 'old'], + cartesian_control=True) + #: PetriNetModel json example petrinet_json = PetriNetModel(Model(sir)).to_pydantic() askenet_petrinet_json = AskeNetPetriNetModel(Model(sir)).to_pydantic() askenet_petrinet_json_units_values = AskeNetPetriNetModel( Model(sir_parameterized_init) ).to_pydantic() +askenet_petrinet_json_deactivate = AskeNetPetriNetModel(Model( + age_strata)).to_pydantic() @model_blueprint.post( @@ -313,7 +321,10 @@ class DeactivationQuery(BaseModel): None, description="Deactivates transitions that have a source-target " "pair from the provided list", - example=[["infected_population_old", "infected_population_young"]] + example=[ + ["infected_population_old", "infected_population_young"], + ["infected_population_young", "infected_population_old"] + ] ) and_or: Literal["and", "or"] = Field( "and", @@ -361,16 +372,24 @@ def deactivate_transitions( query: DeactivationQuery = Body( ..., examples={ - "parameters": { + "With parameters": { "model": askenet_petrinet_json_units_values, "parameters": ["beta"], }, - "transitions": { - # Todo: Fix example model to include transitions with the - # same source and target as the example below - "model": askenet_petrinet_json_units_values, - "transitions": [["infected_population_old", - "infected_population_young"]], + "With transitions": { + "model": askenet_petrinet_json_deactivate, + "transitions": list( + [t.subject.name, t.outcome.name] + for t in age_strata.templates + if hasattr(t, "subject") and hasattr(t, "outcome") and + ( + t.subject.name.endswith('_young') and + t.outcome.name.endswith('_old') + or + t.subject.name.endswith('_old') and + t.outcome.name.endswith('_young') + ) + ), }, }, ) @@ -382,7 +401,8 @@ def deactivate_transitions( # Create callables for deactivating transitions if query.parameters: def deactivate_parameter(t: Template) -> bool: - """Deactivate transitions that have a parameter in the query""" + """Deactivate templates that have the given parameter(s) in + their rate law""" if t.rate_law is None: return False for symb in t.rate_law.atoms(): @@ -854,3 +874,8 @@ def reproduce_ode_semantics_endpoint( tm = reproduce_ode_semantics(query.model) am = AskeNetPetriNetModel(Model(tm)) return am.to_pydantic() + + +from fastapi import FastAPI +app = FastAPI() +app.include_router(model_blueprint) From 8a3f83d312424bae1a7683ed820adf9267f9432c Mon Sep 17 00:00:00 2001 From: kkaris Date: Wed, 19 Jul 2023 12:12:27 -0700 Subject: [PATCH 6/7] Update test --- tests/test_model_api.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/test_model_api.py b/tests/test_model_api.py index 308777198..5957486cd 100644 --- a/tests/test_model_api.py +++ b/tests/test_model_api.py @@ -702,16 +702,6 @@ def test_deactivation_endpoint(self): strata=['y', 'o'], cartesian_control=True) - # Assert that there are old to young transitions - transition_list = [] - for template in age_strata.templates: - if hasattr(template, 'subject') and hasattr(template, 'outcome'): - subj, outc = template.subject.name, template.outcome.name - if subj.endswith('_o') and outc.endswith('_y') or \ - subj.endswith('_y') and outc.endswith('_o'): - transition_list.append((subj, outc)) - assert len(transition_list), "No old to young transitions found" - amr_sir = AskeNetPetriNetModel(Model(age_strata)).to_json() # Test the endpoint itself @@ -744,6 +734,15 @@ def test_deactivation_endpoint(self): self.assertEqual(422, response.status_code) # Actual Test + # Assert that there are old to young transitions + transition_list = [] + for template in age_strata.templates: + if hasattr(template, 'subject') and hasattr(template, 'outcome'): + subj, outc = template.subject.name, template.outcome.name + if subj.endswith('_o') and outc.endswith('_y') or \ + subj.endswith('_y') and outc.endswith('_o'): + transition_list.append((subj, outc)) + assert len(transition_list), "No old to young transitions found" response = self.client.post( "/api/deactivate_transitions", json={"model": amr_sir, "transitions": transition_list} @@ -772,10 +771,11 @@ def test_deactivation_endpoint(self): tm_deactivated_params = template_model_from_askenet_json( amr_sir_deactivated_params) for template in tm_deactivated_params.templates: - for symb in template.rate_law.atoms(): - if str(symb) == deactivate_key: - assert ( - template.rate_law.rate_law.args[0] == - sympy.core.numbers.Zero(), - template.rate_law - ) + # All rate laws must either be zero or not contain the deactivated + # parameter + if template.rate_law and not template.rate_law.is_zero: + for symb in template.rate_law.atoms(): + assert str(symb) != deactivate_key + else: + assert (template.rate_law.args[0] == sympy.core.numbers.Zero(), + template.rate_law) From 435ea85d48899e663adb2b8f38fc7c9cb3152b12 Mon Sep 17 00:00:00 2001 From: kkaris Date: Tue, 25 Jul 2023 13:11:05 -0700 Subject: [PATCH 7/7] Remove fastapi app creation used for local testing --- mira/dkg/model.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/mira/dkg/model.py b/mira/dkg/model.py index 3e107d09b..976d59902 100644 --- a/mira/dkg/model.py +++ b/mira/dkg/model.py @@ -874,8 +874,3 @@ def reproduce_ode_semantics_endpoint( tm = reproduce_ode_semantics(query.model) am = AskeNetPetriNetModel(Model(tm)) return am.to_pydantic() - - -from fastapi import FastAPI -app = FastAPI() -app.include_router(model_blueprint)