diff --git a/tests/test_all.py b/tests/test_all.py index 13274fd..56da312 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -1,6 +1,9 @@ +import numpy as np +import pysindy as ps import pytest from gen_experiments.typing import NestedDict +from gen_experiments.utils import unionize_coeff_matrices def test_flatten_nested_dict(): @@ -19,3 +22,16 @@ def test_flatten_nested_bad_dict(): with pytest.raises(TypeError, match="Only string keys allowed"): deep = NestedDict(a={1: 1}) deep.flatten() + + +def test_unionize_coeff_matrices(): + # lib = ps.PolynomialLibrary().fit(np.array([[1, 1]])) + model = ps.SINDy(feature_names=["x", "y"]) + data = np.arange(10) + data = np.vstack((data, data)).T + model.fit(data, 0.1) + coeff_true = [{"y": -1, "zorp_x": 0.1}, {"x": 1, "zorp_y": 0.1}] + true, est, feats = unionize_coeff_matrices(model, coeff_true) + assert len(feats) == true.shape[1] + assert len(feats) == est.shape[1] + assert est.shape == true.shape