Skip to content

Commit d1473bf

Browse files
clinssenC.A.P. Linssen
andauthored
Refactor inline expressions expansion into a transformer (#1093)
* refactor inline expressions expansion into a transformer * refactor inline expressions expansion into a transformer * refactor inline expressions expansion into a transformer * refactor inline expressions expansion into a transformer * refactor inline expressions expansion into a transformer --------- Co-authored-by: C.A.P. Linssen <[email protected]>
1 parent 346fcbe commit d1473bf

File tree

8 files changed

+250
-87
lines changed

8 files changed

+250
-87
lines changed

pynestml/codegeneration/nest_code_generator.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from pynestml.symbols.real_type_symbol import RealTypeSymbol
6262
from pynestml.symbols.unit_type_symbol import UnitTypeSymbol
6363
from pynestml.symbols.symbol import SymbolKind
64+
from pynestml.transformers.inline_expression_expansion_transformer import InlineExpressionExpansionTransformer
6465
from pynestml.utils.ast_utils import ASTUtils
6566
from pynestml.utils.logger import Logger
6667
from pynestml.utils.logger import LoggingLevel
@@ -322,8 +323,7 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di
322323
equations_block = neuron.get_equations_blocks()[0]
323324

324325
kernel_buffers = ASTUtils.generate_kernel_buffers(neuron, equations_block)
325-
ASTUtils.make_inline_expressions_self_contained(equations_block.get_inline_expressions())
326-
ASTUtils.replace_inline_expressions_through_defining_expressions(equations_block.get_ode_equations(), equations_block.get_inline_expressions())
326+
InlineExpressionExpansionTransformer().transform(neuron)
327327
delta_factors = ASTUtils.get_delta_factors_(neuron, equations_block)
328328
ASTUtils.replace_convolve_calls_with_buffers_(neuron, equations_block)
329329

@@ -400,9 +400,7 @@ def analyse_synapse(self, synapse: ASTModel) -> Dict[str, ASTAssignment]:
400400
equations_block = synapse.get_equations_blocks()[0]
401401

402402
kernel_buffers = ASTUtils.generate_kernel_buffers(synapse, equations_block)
403-
ASTUtils.make_inline_expressions_self_contained(equations_block.get_inline_expressions())
404-
ASTUtils.replace_inline_expressions_through_defining_expressions(
405-
equations_block.get_ode_equations(), equations_block.get_inline_expressions())
403+
InlineExpressionExpansionTransformer().transform(synapse)
406404
delta_factors = ASTUtils.get_delta_factors_(synapse, equations_block)
407405
ASTUtils.replace_convolve_calls_with_buffers_(synapse, equations_block)
408406

pynestml/codegeneration/nest_compartmental_code_generator.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from pynestml.meta_model.ast_variable import ASTVariable
5454
from pynestml.symbol_table.symbol_table import SymbolTable
5555
from pynestml.symbols.symbol import SymbolKind
56+
from pynestml.transformers.inline_expression_expansion_transformer import InlineExpressionExpansionTransformer
5657
from pynestml.utils.mechanism_processing import MechanismProcessing
5758
from pynestml.utils.channel_processing import ChannelProcessing
5859
from pynestml.utils.concentration_processing import ConcentrationProcessing
@@ -436,13 +437,9 @@ def analyse_neuron(self, neuron: ASTModel) -> List[ASTAssignment]:
436437
ASTUtils.replace_convolve_calls_with_buffers_(neuron, equations_block)
437438

438439
# substitute inline expressions with each other
439-
# such that no inline expression references another inline expression
440-
ASTUtils.make_inline_expressions_self_contained(
441-
equations_block.get_inline_expressions())
442-
443-
# dereference inline_expressions inside ode equations
444-
ASTUtils.replace_inline_expressions_through_defining_expressions(
445-
equations_block.get_ode_equations(), equations_block.get_inline_expressions())
440+
# such that no inline expression references another inline expression;
441+
# deference inline_expressions inside ode_equations
442+
InlineExpressionExpansionTransformer().transform(neuron)
446443

447444
# generate update expressions using ode toolbox
448445
# for each equation in the equation block attempt to solve analytically

pynestml/frontend/pynestml_frontend.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,8 @@ def code_generator_from_target_name(target_name: str, options: Optional[Mapping[
132132
Logger.log_message(None, code, message, None, LoggingLevel.INFO)
133133
return CodeGenerator("", options)
134134

135-
# cannot reach here due to earlier assert -- silence
135+
# cannot reach here due to earlier assert -- silence static checker warnings
136136
assert "Unknown code generator requested: " + target_name
137-
# static checker warnings
138137

139138

140139
def builder_from_target_name(target_name: str, options: Optional[Mapping[str, Any]] = None) -> Tuple[Builder, Dict[str, Any]]:
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# inline_expression_expansion_transformer.py
4+
#
5+
# This file is part of NEST.
6+
#
7+
# Copyright (C) 2004 The NEST Initiative
8+
#
9+
# NEST is free software: you can redistribute it and/or modify
10+
# it under the terms of the GNU General Public License as published by
11+
# the Free Software Foundation, either version 2 of the License, or
12+
# (at your option) any later version.
13+
#
14+
# NEST is distributed in the hope that it will be useful,
15+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
16+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17+
# GNU General Public License for more details.
18+
#
19+
# You should have received a copy of the GNU General Public License
20+
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
21+
22+
from __future__ import annotations
23+
24+
from typing import List, Optional, Mapping, Any, Union, Sequence
25+
26+
import re
27+
28+
from pynestml.frontend.frontend_configuration import FrontendConfiguration
29+
from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
30+
from pynestml.meta_model.ast_node import ASTNode
31+
from pynestml.meta_model.ast_ode_equation import ASTOdeEquation
32+
from pynestml.transformers.transformer import Transformer
33+
from pynestml.utils.ast_utils import ASTUtils
34+
from pynestml.utils.logger import Logger, LoggingLevel
35+
from pynestml.utils.string_utils import removesuffix
36+
from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor
37+
from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
38+
from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
39+
40+
41+
class InlineExpressionExpansionTransformer(Transformer):
42+
r"""
43+
Make inline expressions self contained, i.e. without any references to other inline expressions.
44+
45+
Additionally, replace variable symbols referencing inline expressions in defining expressions of ODEs with the corresponding defining expressions from the inline expressions.
46+
"""
47+
48+
_variable_matching_template = r'(\b)({})(\b)'
49+
50+
def __init__(self, options: Optional[Mapping[str, Any]] = None):
51+
super(Transformer, self).__init__(options)
52+
53+
def transform(self, models: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode, Sequence[ASTNode]]:
54+
single = False
55+
if isinstance(models, ASTNode):
56+
single = True
57+
models = [models]
58+
59+
for model in models:
60+
if not model.get_equations_blocks():
61+
continue
62+
63+
for equations_block in model.get_equations_blocks():
64+
self.make_inline_expressions_self_contained(equations_block.get_inline_expressions())
65+
66+
for equations_block in model.get_equations_blocks():
67+
self.replace_inline_expressions_through_defining_expressions(equations_block.get_ode_equations(), equations_block.get_inline_expressions())
68+
69+
if single:
70+
return models[0]
71+
72+
return models
73+
74+
def make_inline_expressions_self_contained(self, inline_expressions: List[ASTInlineExpression]) -> List[ASTInlineExpression]:
75+
r"""
76+
Make inline expressions self contained, i.e. without any references to other inline expressions.
77+
78+
:param inline_expressions: A sorted list with entries ASTInlineExpression.
79+
:return: A list with ASTInlineExpressions. Defining expressions don't depend on each other.
80+
"""
81+
from pynestml.utils.model_parser import ModelParser
82+
from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
83+
84+
for source in inline_expressions:
85+
source_position = source.get_source_position()
86+
for target in inline_expressions:
87+
matcher = re.compile(self._variable_matching_template.format(source.get_variable_name()))
88+
target_definition = str(target.get_expression())
89+
target_definition = re.sub(matcher, "(" + str(source.get_expression()) + ")", target_definition)
90+
old_parent = target.expression.parent_
91+
target.expression = ModelParser.parse_expression(target_definition)
92+
target.expression.update_scope(source.get_scope())
93+
target.expression.parent_ = old_parent
94+
target.expression.accept(ASTParentVisitor())
95+
target.expression.accept(ASTSymbolTableVisitor())
96+
97+
def log_set_source_position(node):
98+
if node.get_source_position().is_added_source_position():
99+
node.set_source_position(source_position)
100+
101+
target.expression.accept(ASTHigherOrderVisitor(visit_funcs=log_set_source_position))
102+
103+
return inline_expressions
104+
105+
@classmethod
106+
def replace_inline_expressions_through_defining_expressions(self, definitions: Sequence[ASTOdeEquation],
107+
inline_expressions: Sequence[ASTInlineExpression]) -> Sequence[ASTOdeEquation]:
108+
r"""
109+
Replace variable symbols referencing inline expressions in defining expressions of ODEs with the corresponding defining expressions from the inline expressions.
110+
111+
:param definitions: A list of ODE definitions (**updated in-place**).
112+
:param inline_expressions: A list of inline expression definitions.
113+
:return: A list of updated ODE definitions (same as the ``definitions`` parameter).
114+
"""
115+
from pynestml.utils.model_parser import ModelParser
116+
from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
117+
118+
for m in inline_expressions:
119+
if "mechanism" not in [e.namespace for e in m.get_decorators()]:
120+
"""
121+
exclude compartmental mechanism definitions in order to have the
122+
inline as a barrier inbetween odes that are meant to be solved independently
123+
"""
124+
source_position = m.get_source_position()
125+
for target in definitions:
126+
matcher = re.compile(self._variable_matching_template.format(m.get_variable_name()))
127+
target_definition = str(target.get_rhs())
128+
target_definition = re.sub(matcher, "(" + str(m.get_expression()) + ")", target_definition)
129+
old_parent = target.rhs.parent_
130+
target.rhs = ModelParser.parse_expression(target_definition)
131+
target.update_scope(m.get_scope())
132+
target.rhs.parent_ = old_parent
133+
target.rhs.accept(ASTParentVisitor())
134+
target.accept(ASTSymbolTableVisitor())
135+
136+
def log_set_source_position(node):
137+
if node.get_source_position().is_added_source_position():
138+
node.set_source_position(source_position)
139+
140+
target.accept(ASTHigherOrderVisitor(visit_funcs=log_set_source_position))
141+
142+
return definitions

pynestml/utils/ast_utils.py

Lines changed: 1 addition & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from pynestml.utils.messages import Messages
6666
from pynestml.utils.string_utils import removesuffix
6767
from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor
68+
from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
6869
from pynestml.visitors.ast_visitor import ASTVisitor
6970

7071

@@ -1027,8 +1028,6 @@ def has_equation_with_delay_variable(cls, equations_with_delay_vars: ASTOdeEquat
10271028
return True
10281029
return False
10291030

1030-
_variable_matching_template = r'(\b)({})(\b)'
1031-
10321031
@classmethod
10331032
def add_declarations_to_internals(cls, neuron: ASTModel, declarations: Mapping[str, str]) -> ASTModel:
10341033
"""
@@ -2080,74 +2079,6 @@ def remove_ode_definitions_from_equations_block(cls, model: ASTModel) -> None:
20802079
for decl in decl_to_remove:
20812080
equations_block.get_declarations().remove(decl)
20822081

2083-
@classmethod
2084-
def make_inline_expressions_self_contained(cls, inline_expressions: List[ASTInlineExpression]) -> List[ASTInlineExpression]:
2085-
"""
2086-
Make inline_expressions self contained, i.e. without any references to other inline_expressions.
2087-
2088-
TODO: it should be a method inside of the ASTInlineExpression
2089-
TODO: this should be done by means of a visitor
2090-
2091-
:param inline_expressions: A sorted list with entries ASTInlineExpression.
2092-
:return: A list with ASTInlineExpressions. Defining expressions don't depend on each other.
2093-
"""
2094-
from pynestml.utils.model_parser import ModelParser
2095-
from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
2096-
2097-
for source in inline_expressions:
2098-
source_position = source.get_source_position()
2099-
for target in inline_expressions:
2100-
matcher = re.compile(cls._variable_matching_template.format(source.get_variable_name()))
2101-
target_definition = str(target.get_expression())
2102-
target_definition = re.sub(matcher, "(" + str(source.get_expression()) + ")", target_definition)
2103-
target.expression = ModelParser.parse_expression(target_definition)
2104-
target.expression.update_scope(source.get_scope())
2105-
target.expression.accept(ASTSymbolTableVisitor())
2106-
2107-
def log_set_source_position(node):
2108-
if node.get_source_position().is_added_source_position():
2109-
node.set_source_position(source_position)
2110-
2111-
target.expression.accept(ASTHigherOrderVisitor(visit_funcs=log_set_source_position))
2112-
2113-
return inline_expressions
2114-
2115-
@classmethod
2116-
def replace_inline_expressions_through_defining_expressions(cls, definitions: Sequence[ASTOdeEquation],
2117-
inline_expressions: Sequence[ASTInlineExpression]) -> Sequence[ASTOdeEquation]:
2118-
"""
2119-
Replaces symbols from `inline_expressions` in `definitions` with corresponding defining expressions from `inline_expressions`.
2120-
2121-
:param definitions: A list of ODE definitions (**updated in-place**).
2122-
:param inline_expressions: A list of inline expression definitions.
2123-
:return: A list of updated ODE definitions (same as the ``definitions`` parameter).
2124-
"""
2125-
from pynestml.utils.model_parser import ModelParser
2126-
from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
2127-
2128-
for m in inline_expressions:
2129-
if "mechanism" not in [e.namespace for e in m.get_decorators()]:
2130-
"""
2131-
exclude compartmental mechanism definitions in order to have the
2132-
inline as a barrier inbetween odes that are meant to be solved independently
2133-
"""
2134-
source_position = m.get_source_position()
2135-
for target in definitions:
2136-
matcher = re.compile(cls._variable_matching_template.format(m.get_variable_name()))
2137-
target_definition = str(target.get_rhs())
2138-
target_definition = re.sub(matcher, "(" + str(m.get_expression()) + ")", target_definition)
2139-
target.rhs = ModelParser.parse_expression(target_definition)
2140-
target.update_scope(m.get_scope())
2141-
target.accept(ASTSymbolTableVisitor())
2142-
2143-
def log_set_source_position(node):
2144-
if node.get_source_position().is_added_source_position():
2145-
node.set_source_position(source_position)
2146-
2147-
target.accept(ASTHigherOrderVisitor(visit_funcs=log_set_source_position))
2148-
2149-
return definitions
2150-
21512082
@classmethod
21522083
def get_delta_factors_(cls, neuron: ASTModel, equations_block: ASTEquationsBlock) -> dict:
21532084
r"""

pynestml/visitors/ast_symbol_table_visitor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from pynestml.meta_model.ast_model_body import ASTModelBody
2525
from pynestml.meta_model.ast_namespace_decorator import ASTNamespaceDecorator
2626
from pynestml.meta_model.ast_declaration import ASTDeclaration
27+
from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
2728
from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression
2829
from pynestml.meta_model.ast_stmt import ASTStmt
2930
from pynestml.meta_model.ast_variable import ASTVariable
@@ -473,11 +474,10 @@ def visit_variable(self, node: ASTVariable):
473474
node.get_vector_parameter().update_scope(node.get_scope())
474475
node.get_vector_parameter().accept(self)
475476

476-
def visit_inline_expression(self, node):
477+
def visit_inline_expression(self, node: ASTInlineExpression):
477478
"""
478-
Private method: Used to visit a single ode-function, create the corresponding symbol and update the scope.
479+
Private method: Used to visit a single inline expression, create the corresponding symbol and update the scope.
479480
:param node: a single inline expression.
480-
:type node: ASTInlineExpression
481481
"""
482482

483483
# split the decorators in the AST up into namespace decorators and other decorators
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""
2+
beta_function_with_inline_expression_neuron
3+
###########################################
4+
5+
Description
6+
+++++++++++
7+
8+
Used for testing processing of inline expressions.
9+
10+
11+
Copyright
12+
+++++++++
13+
14+
This file is part of NEST.
15+
16+
Copyright (C) 2004 The NEST Initiative
17+
18+
NEST is free software: you can redistribute it and/or modify
19+
it under the terms of the GNU General Public License as published by
20+
the Free Software Foundation, either version 2 of the License, or
21+
(at your option) any later version.
22+
23+
NEST is distributed in the hope that it will be useful,
24+
but WITHOUT ANY WARRANTY; without even the implied warranty of
25+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
26+
GNU General Public License for more details.
27+
You should have received a copy of the GNU General Public License
28+
along with NEST. If not, see <http://www.gnu.org/licenses/>.
29+
"""
30+
model beta_function_with_inline_expression_neuron:
31+
32+
parameters:
33+
tau1 ms = 20 ms ## decay time
34+
tau2 ms = 10 ms ## rise time
35+
36+
state:
37+
x_ pA/ms = 0 pA/ms
38+
x pA = 0 pA
39+
40+
internals:
41+
alpha real = 42.
42+
43+
equations:
44+
x' = x_ - x / tau2
45+
x_' = - x_ / tau1
46+
47+
recordable inline z pA = x
48+
49+
input:
50+
weighted_input_spikes <- spike
51+
52+
output:
53+
spike
54+
55+
update:
56+
integrate_odes()
57+
58+
onReceive(weighted_input_spikes):
59+
x_ += alpha * (1 / tau2 - 1 / tau1) * pA * weighted_input_spikes * s

0 commit comments

Comments
 (0)