Skip to content

Commit 0c42b2a

Browse files
Implement fold expression approach for variadic parameters in compi_lambda
Co-Authored-By: Serg Kryvonos <[email protected]>
1 parent 7005fa4 commit 0c42b2a

File tree

3 files changed

+36
-26
lines changed

3 files changed

+36
-26
lines changed

docs/modules/ROOT/pages/development/python_wrapper.adoc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ result2 = lambda_func([2, 5]) # 2*5 + 2 = 12
8888

8989
[NOTE]
9090
====
91-
The `compi_lambda` method currently supports up to 3 variables due to C++ template limitations. For expressions with more variables, use `compile_into_lambda` instead.
91+
The `compi_lambda` method supports up to 10 variables. For expressions with more variables, use `compile_into_lambda` instead, which has no limit on the number of variables.
9292
====
9393

9494
== Implementation Details

omnn/variable/tests/test_compilambda.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,22 @@ def test_single_variable_lambda(self):
4949

5050
result = lambda_func([4])
5151
self.assertEqual(float(result), 17.0) # 4*4 + 1 = 17
52+
53+
def test_multiple_variables_lambda(self):
54+
x = variable.Variable()
55+
y = variable.Variable()
56+
z = variable.Variable()
57+
w = variable.Variable()
58+
59+
expr = x + y * z + w
60+
61+
lambda_func = expr.compi_lambda([x, y, z, w])
62+
63+
result = lambda_func([2, 3, 4, 5])
64+
self.assertEqual(float(result), 19.0) # 2 + 3*4 + 5 = 19
65+
66+
result = lambda_func([1, 2, 3, 4])
67+
self.assertEqual(float(result), 11.0) # 1 + 2*3 + 4 = 11
5268

5369
if __name__ == "__main__":
5470
unittest.main()

omnn/variable/valuable.cpp

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,6 @@ BOOST_PYTHON_MODULE(variable)
177177
" result = lambda_func([3, 4]) # 3 + 2*4 = 11")
178178

179179
.def("compi_lambda", +[](const Valuable& v, const boost::python::list& variables) {
180-
if (len(variables) > 3) {
181-
throw std::runtime_error("Currently only supports up to 3 variables due to C++ template limitations");
182-
}
183-
184180
std::vector<Variable> vars;
185181
for (int i = 0; i < len(variables); ++i) {
186182
vars.push_back(extract<Variable>(variables[i]));
@@ -197,29 +193,27 @@ BOOST_PYTHON_MODULE(variable)
197193
values.push_back(extract<Valuable>(args[i]));
198194
}
199195

200-
auto lambda = [&]() {
201-
switch (vars.size()) {
202-
case 1:
203-
return v.CompiLambda(vars[0]);
204-
case 2:
205-
return v.CompiLambda(vars[0], vars[1]);
206-
case 3:
207-
return v.CompiLambda(vars[0], vars[1], vars[2]);
208-
default:
209-
throw std::runtime_error("Unsupported number of variables");
210-
}
211-
}();
196+
auto callCompiLambda = [&v](auto&&... vars) {
197+
return v.CompiLambda(std::forward<decltype(vars)>(vars)...);
198+
};
199+
200+
auto callLambda = [](auto&& lambda, auto&&... args) {
201+
return lambda(std::forward<decltype(args)>(args)...);
202+
};
212203

213-
switch (values.size()) {
214-
case 1:
215-
return lambda(values[0]);
216-
case 2:
217-
return lambda(values[0], values[1]);
218-
case 3:
219-
return lambda(values[0], values[1], values[2]);
220-
default:
221-
throw std::runtime_error("Unsupported number of arguments");
204+
if (vars.size() > 10) {
205+
throw std::runtime_error("Currently only supports up to 10 variables due to implementation constraints");
222206
}
207+
208+
auto invokeLambda = [&]() -> Valuable {
209+
return [&]<size_t... I>(std::index_sequence<I...>) {
210+
auto lambda = callCompiLambda(vars[I]...);
211+
212+
return callLambda(lambda, values[I]...);
213+
}(std::make_index_sequence<vars.size()>{});
214+
};
215+
216+
return invokeLambda();
223217
}
224218
);
225219
},

0 commit comments

Comments
 (0)