Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Label get param #306

Open
wants to merge 77 commits into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
a7170e7
rename params
BalzaniEdoardo Feb 1, 2025
1ba2d73
rename params
BalzaniEdoardo Feb 2, 2025
5770d19
added setter for label, added list all labels
BalzaniEdoardo Feb 2, 2025
feca6a4
fix autogenerated
BalzaniEdoardo Feb 3, 2025
0818cba
added label to composite basis
BalzaniEdoardo Feb 3, 2025
842cfc5
fix label logic
BalzaniEdoardo Feb 3, 2025
abfc2d3
added getitem
BalzaniEdoardo Feb 3, 2025
04bd322
added getitem to transformer
BalzaniEdoardo Feb 3, 2025
cf28db0
set_params return self
BalzaniEdoardo Feb 3, 2025
66fb783
improved regex
BalzaniEdoardo Feb 4, 2025
df21547
fix setter
BalzaniEdoardo Feb 4, 2025
d17a1c3
split method of basis in two
BalzaniEdoardo Feb 4, 2025
d30f0be
linted
BalzaniEdoardo Feb 4, 2025
4b39526
improved var names
BalzaniEdoardo Feb 4, 2025
cdc1b46
simplified func
BalzaniEdoardo Feb 4, 2025
ffd2890
fixed behavior
BalzaniEdoardo Feb 4, 2025
c50d4e3
use f string
BalzaniEdoardo Feb 4, 2025
5dfc988
improved exception
BalzaniEdoardo Feb 4, 2025
c73706a
renamed method
BalzaniEdoardo Feb 4, 2025
ace344d
edited tutorials
BalzaniEdoardo Feb 4, 2025
6cfc6a1
added new note on composition?
BalzaniEdoardo Feb 4, 2025
c9d2f6a
added a card
BalzaniEdoardo Feb 4, 2025
f02a920
added exception
BalzaniEdoardo Feb 4, 2025
cd1af4c
Merge branch 'main' into label_get_param
BalzaniEdoardo Feb 5, 2025
64e2280
restore original label if error is raised
BalzaniEdoardo Feb 5, 2025
1097896
use generator to getitem
BalzaniEdoardo Feb 5, 2025
59ed6c2
simplified search
BalzaniEdoardo Feb 5, 2025
02e55f5
improved generate keys
BalzaniEdoardo Feb 5, 2025
1e14d2d
refined map
BalzaniEdoardo Feb 5, 2025
655d912
improved docstrig
BalzaniEdoardo Feb 5, 2025
b4ed972
do not rely on attr
BalzaniEdoardo Feb 5, 2025
69f242e
bugfix
BalzaniEdoardo Feb 5, 2025
05531a7
initialize new
BalzaniEdoardo Feb 5, 2025
aa4f612
sort properties
BalzaniEdoardo Feb 5, 2025
7d9e5e4
added test on dynamic label update
BalzaniEdoardo Feb 5, 2025
264a076
linted and added test getitem
BalzaniEdoardo Feb 5, 2025
96ae246
improved test getitem
BalzaniEdoardo Feb 5, 2025
02b6d0f
fixed shared tests
BalzaniEdoardo Feb 5, 2025
bdb7583
fix additive basis lab test
BalzaniEdoardo Feb 5, 2025
73c3d5c
fix test labels
BalzaniEdoardo Feb 6, 2025
c887ca9
fast editing basis func
BalzaniEdoardo Feb 7, 2025
7d16bde
fixed small issues with basis listing
BalzaniEdoardo Feb 7, 2025
d4bdfa4
fix composition checks
BalzaniEdoardo Feb 7, 2025
f00a883
use autogen in getitem of composite if needed.
BalzaniEdoardo Feb 7, 2025
d3274a5
added a test for invalid labels
BalzaniEdoardo Feb 7, 2025
fec3f49
list module runtime
BalzaniEdoardo Feb 8, 2025
83e2504
added sys module lookup
BalzaniEdoardo Feb 8, 2025
9b65424
fix docs and moved list of public bases
BalzaniEdoardo Feb 10, 2025
4c4797d
fix doctests
BalzaniEdoardo Feb 10, 2025
a304bc6
linted
BalzaniEdoardo Feb 10, 2025
1c9e04b
depr warning mpl
BalzaniEdoardo Feb 10, 2025
0694082
improved note
BalzaniEdoardo Feb 10, 2025
34c3b06
improved note
BalzaniEdoardo Feb 10, 2025
e554a8d
improved decorator logic
BalzaniEdoardo Feb 10, 2025
ee71424
improve check logic
BalzaniEdoardo Feb 10, 2025
71be463
removed unused imports
BalzaniEdoardo Feb 10, 2025
21b953c
better naming
BalzaniEdoardo Feb 10, 2025
ef3ab00
fixed tests
BalzaniEdoardo Feb 10, 2025
4f92d65
deepcopied bases
BalzaniEdoardo Feb 10, 2025
37c495f
added test for nested invalid set_params
BalzaniEdoardo Feb 10, 2025
01e76b7
fixed test transf basis
BalzaniEdoardo Feb 10, 2025
d3bd7f7
removed unused import
BalzaniEdoardo Feb 10, 2025
beca52e
improved comment
BalzaniEdoardo Feb 11, 2025
41c9d07
added check for class name
BalzaniEdoardo Feb 11, 2025
ba2e99d
added few tests
BalzaniEdoardo Feb 11, 2025
1c50b0e
remove additional seen_labels
BalzaniEdoardo Feb 11, 2025
a65a944
fix test attr exception
BalzaniEdoardo Feb 11, 2025
ec27a24
simplified code
BalzaniEdoardo Feb 11, 2025
db730cd
simplify key map and new param dict compute
BalzaniEdoardo Feb 11, 2025
fd7b17e
additional code simplification
BalzaniEdoardo Feb 11, 2025
4a8b3c8
additional code simplification
BalzaniEdoardo Feb 11, 2025
67f7548
swap output arg
BalzaniEdoardo Feb 11, 2025
662c355
add dir to init
BalzaniEdoardo Feb 11, 2025
6583242
removed transformer basis
BalzaniEdoardo Feb 11, 2025
4bb3447
linted and added tests
BalzaniEdoardo Feb 11, 2025
e26b01b
fix flake8
BalzaniEdoardo Feb 11, 2025
791e669
moved basis setter logic to dedicated func
BalzaniEdoardo Feb 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions docs/background/basis/plot_01_1D_basis_function.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jupytext:
format_version: 0.13
jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
display_name: Python 3 (ipykernel)
language: python
name: python3
---
Expand Down Expand Up @@ -42,7 +42,6 @@ warnings.filterwarnings(
from nemos._documentation_utils._myst_nb_glue import glue_two_step_convolve
glue_two_step_convolve()
```

(simple_basis_function)=
Expand Down Expand Up @@ -71,7 +70,8 @@ order = 4
n_basis = 10
# Define the 1D basis function object
bspline = nmo.basis.BSplineEval(n_basis_funcs=n_basis, order=order)
bspline = nmo.basis.BSplineEval(n_basis_funcs=n_basis, order=order, label="bspline")
bspline
```

We provide the convenience method `evaluate_on_grid` for evaluating the basis on an equi-spaced grid of points that makes it easier to plot and visualize all basis elements.
Expand All @@ -85,7 +85,6 @@ plt.plot(x, y, lw=2)
plt.title("B-Spline Basis")
```


## Computing Features
All bases in the `nemos.basis` module perform a transformation of one or more time series into a set of features. This operation is always carried out by the method [`compute_features`](nemos.basis._basis.Basis.compute_features).
We can group the bases into two categories depending on the type of transformation that [`compute_features`](nemos.basis._basis.Basis.compute_features) applies:
Expand All @@ -97,8 +96,8 @@ We can group the bases into two categories depending on the type of transformati
Let's see how these two categories operate:

```{code-cell} ipython3
eval_mode = nmo.basis.BSplineEval(n_basis_funcs=n_basis)
conv_mode = nmo.basis.BSplineConv(n_basis_funcs=n_basis, window_size=100)
eval_mode = nmo.basis.BSplineEval(n_basis_funcs=n_basis, label="eval")
conv_mode = nmo.basis.BSplineConv(n_basis_funcs=n_basis, window_size=100, label="conv")
# define an input
angles = np.linspace(0, np.pi*4, 201)
Expand Down Expand Up @@ -148,7 +147,7 @@ For inputs with more than one dimension, `compute_features` assumes the first ax
For Eval bases, `compute_features` evaluates the basis and outputs a 2D feature matrix.

```{code-cell} ipython3
basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs=5)
basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs=5, label="multidim")
# generate a 3D array
inp = np.random.randn(50, 3, 2)
out = basis.compute_features(inp)
Expand All @@ -158,6 +157,12 @@ out.shape
For each of the $3 \times 2 = 6$ inputs, `n_basis_funcs = 5` features are computed. These are concatenated on the second axis of the feature matrix, for a total of
$3 \times 2 \times 5 = 30$ outputs.

This concatenation can be undone by the `split_by_feature` method of basis, which creates a dictionary with keys the labels of the basis and values a reshaped array.

```{code-cell} ipython3
basis.split_by_feature(out, axis=1)["multidim"].shape
```

#### Conv Basis

For Conv bases, `compute_features` convolves each input with `n_basis_funcs` kernels and outputs a 2D feature matrix.
Expand Down Expand Up @@ -214,7 +219,6 @@ You can specify a range for the support of your basis by setting the `bounds`
parameter at initialization of Eval bases.
Evaluating the basis at any sample outside the bounds will result in a NaN.


```{code-cell} ipython3
bspline_range = nmo.basis.BSplineEval(n_basis_funcs=n_basis, order=order, bounds=(0.2, 0.8))
Expand All @@ -226,7 +230,6 @@ print(np.round(bspline_range.compute_features([0.5, 0.1]), 3))
Let's compare the default behavior of basis (estimating the range from the samples) with
the fixed range basis.


```{code-cell} ipython3
samples = np.linspace(0, 1, 200)
fig, axs = plt.subplots(2,1, sharex=True)
Expand All @@ -237,4 +240,3 @@ axs[1].plot(samples, bspline_range.compute_features(samples), color="tomato")
axs[1].set_title("bounds=[0.2, 0.8]")
plt.tight_layout()
```

36 changes: 19 additions & 17 deletions docs/background/basis/plot_02_ND_basis_function.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jupytext:
format_version: 0.13
jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
display_name: Python 3 (ipykernel)
language: python
name: python3
---
Expand Down Expand Up @@ -123,24 +123,23 @@ $$

Here, we simply add two basis objects, `a_basis` and `b_basis`, together to define the additive basis.


```{code-cell} ipython3
import matplotlib.pyplot as plt
import numpy as np
import nemos as nmo
# Define 1D basis objects
a_basis = nmo.basis.MSplineEval(n_basis_funcs=15, order=3)
b_basis = nmo.basis.RaisedCosineLogEval(n_basis_funcs=14)
a_basis = nmo.basis.MSplineEval(n_basis_funcs=15, order=3, label="a")
b_basis = nmo.basis.RaisedCosineLogEval(n_basis_funcs=14, label="b")
# Define the 2D additive basis object
additive_basis = a_basis + b_basis
additive_basis
```

Evaluating the additive basis will require two inputs, one for each coordinate.
The total number of elements of the additive basis will be the sum of the elements of the 1D basis.


```{code-cell} ipython3
# Define a trajectory with 1000 time-points representing the recorded trajectory of the animal
T = 1000
Expand All @@ -154,14 +153,14 @@ eval_basis = additive_basis.compute_features(x_coord, y_coord)
print(f"Sum of two 1D splines with {eval_basis.shape[1]} "
f"basis element and {eval_basis.shape[0]} samples:\n"
f"\t- a_basis had {a_basis.n_basis_funcs} elements\n\t- b_basis had {b_basis.n_basis_funcs} elements.")
f"\t- a_basis had {additive_basis['a'].n_basis_funcs} elements.\n"
f"\t- b_basis had {additive_basis['b'].n_basis_funcs} elements.")
```

(plotting-2d-additive-basis-elements)=
#### Plotting 2D Additive Basis Elements
Let's select and plot a basis element from each of the basis we added.


```{code-cell} ipython3
basis_a_element = 5
basis_b_element = 1
Expand All @@ -184,21 +183,18 @@ We can visualize how these elements are extended in 2D by evaluating the additiv
on a grid of points that spans its domain and plotting the result.
We use the `evaluate_on_grid` method for this.


```{code-cell} ipython3
X, Y, Z = additive_basis.evaluate_on_grid(200, 200)
```

We can select the indices of the 2D additive basis that corresponds to the 1D original elements.


```{code-cell} ipython3
basis_elem_idx = [basis_a_element, a_basis.n_basis_funcs + basis_b_element]
```

Finally, we can plot the 2D counterparts.


```{code-cell} ipython3
_, axs = plt.subplots(1, 2, subplot_kw={'aspect': 1})
Expand All @@ -218,6 +214,19 @@ plt.tight_layout()
plt.show()
```

If we don't want to do the index algebra ourself, we can use the `split_by_feature` method to split `Z` for each additive element of the basis.

```{code-cell} ipython3
Z_split = additive_basis.split_by_feature(Z, axis=-1)
print(Z_split["a"].shape, Z_split["b"].shape)
```

And then index directly the splitted array.

```{code-cell} ipython3
element_a, element_b = Z_split["a"][basis_a_element], Z_split["b"][basis_b_element]
```

### Multiplicative Basis Object

If the aim is to capture interactions between the coordinates, the response function can be modeled as the external
Expand All @@ -230,7 +239,6 @@ $$
In this model, we define the 2D basis function as the product of two 1D basis objects.
This allows the response to capture non-linear and interaction effects between the x and y coordinates.


```{code-cell} ipython3
# 2D basis function as the product of the two 1D basis objects
prod_basis = a_basis * b_basis
Expand All @@ -239,7 +247,6 @@ prod_basis = a_basis * b_basis
Again evaluating the basis will require 2 inputs.
The number of elements of the product basis will be the product of the elements of the two 1D bases.


```{code-cell} ipython3
# Evaluate the product basis at the x and y coordinates
eval_basis = prod_basis.compute_features(x_coord, y_coord)
Expand All @@ -255,9 +262,7 @@ print(f"Product of two 1D splines with {eval_basis.shape[1]} "
Plotting works in the same way as before. To demonstrate that, we select a few pairs of 1D basis elements,
and we visualize the corresponding product.


```{code-cell} ipython3
X, Y, Z = prod_basis.evaluate_on_grid(200, 200)
# basis element pairs
Expand Down Expand Up @@ -299,7 +304,6 @@ A practical example would be characterizing the responses to position
in a linear maze and the LFP phase angle.
:::


N-Dimensional Basis
-------------------
Sometimes it may be useful to model even higher dimensional interactions, for example between the heding direction of
Expand Down Expand Up @@ -331,7 +335,6 @@ print(f"Product of three 1D splines results in {prod_basis_3.n_basis_funcs} "

The evaluation of the product of 3 basis is a 4 dimensional tensor; we can visualize slices of it.


```{code-cell} ipython3
X, Y, W, Z = prod_basis_3.evaluate_on_grid(30, 30, 30)
Expand Down Expand Up @@ -369,7 +372,6 @@ full domain of the basis.
Here we demonstrate a shortcut syntax for multiplying bases of the same class.
This is achieved using the power operator with an integer exponent.


```{code-cell} ipython3
# First, let's define a basis `power_basis` that is equivalent to `prod_basis_3`,
# but we use the power syntax this time:
Expand Down
6 changes: 2 additions & 4 deletions docs/background/plot_00_conceptual_intro.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jupytext:
format_version: 0.13
jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
display_name: Python 3 (ipykernel)
language: python
name: python3
---
Expand Down Expand Up @@ -39,6 +39,7 @@ warnings.filterwarnings(
category=RuntimeWarning,
)
```

(glm_intro_background)=
# Generalized Linear Models: An Introduction

Expand Down Expand Up @@ -154,7 +155,6 @@ to simplify things, we will look at three simple LNP neuron models as
described above, working through each step of the transform. First, we will
plot the linear transformation of the input x:


```{code-cell} ipython3
weights = np.asarray([.5, 4, -4])
intercepts = np.asarray([.5, -3, -2])
Expand All @@ -180,7 +180,6 @@ have to be non-negative! That's what the nonlinearity handles: making sure our
firing rate is always positive. We can visualize this second stage of the LNP model
by adding the `plot_nonlinear` keyword to our `lnp_schematic()` plotting function:


```{code-cell} ipython3
fig = doc_plots.lnp_schematic(input_feature, weights, intercepts,
plot_nonlinear=True)
Expand All @@ -207,7 +206,6 @@ positive, though note that the y-values have changed drastically.
Now we're ready to look at the third step of the LNP model, and see what
the generated spikes spikes look like!


```{code-cell} ipython3
# mkdocs_gallery_thumbnail_number = 3
fig = doc_plots.lnp_schematic(input_feature, weights, intercepts,
Expand Down
7 changes: 1 addition & 6 deletions docs/background/plot_03_1D_convolution.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jupytext:
format_version: 0.13
jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
display_name: Python 3 (ipykernel)
language: python
name: python3
---
Expand Down Expand Up @@ -46,7 +46,6 @@ warnings.filterwarnings(
## Generate synthetic data
Generate some simulated spike counts.


```{code-cell} ipython3
import matplotlib.patches as patches
import matplotlib.pylab as plt
Expand Down Expand Up @@ -109,7 +108,6 @@ and anti-causal effects one should use the acausal filters.
Below we provide the function [`create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor) that runs the convolution in "valid" mode and pads the convolution output
for the different filter types.


```{code-cell} ipython3
# pad according to the causal direction of the filter, after squeeze,
# the dimension is (n_filters, n_samples)
Expand All @@ -126,7 +124,6 @@ spk_acausal_conv = nmo.convolve.create_convolutional_predictor(

Plot the results


```{code-cell} ipython3
# NaN padded area
rect_causal = patches.Rectangle((0, -2.5), ws, 5, alpha=0.3, color='grey')
Expand Down Expand Up @@ -161,7 +158,6 @@ plt.vlines(np.arange(spk.shape[0]), 0, shift_spk, color='k')
plt.plot(np.arange(spk.shape[0]), spk_acausal_conv)
plt.ylabel('acausal')
plt.tight_layout()
```

```{code-cell} ipython3
Expand Down Expand Up @@ -195,7 +191,6 @@ convolution.
All the parameters of [`create_convolutional_predictor`](nemos.convolve.create_convolutional_predictor) can be passed to the object directly at initialization.
Let's see how we can get the same results through [`Basis`](nemos.basis._basis.Basis).


```{code-cell} ipython3
# define basis with different predictor causality
causal_basis = nmo.basis.RaisedCosineLinearConv(
Expand Down
20 changes: 19 additions & 1 deletion docs/how_to_guide/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ plot_05_transformer_basis.md
:::{grid-item-card}

<figure>
<a href="plot_06_sklearn_pipeline_cv_demo.html">
<img src="../_static/thumbnails/how_to_guide/plot_06_sklearn_pipeline_cv_demo.svg" style="height: 100px", alt="PyTrees."/>
</a>
</figure>

```{toctree}
Expand All @@ -85,7 +87,9 @@ plot_06_sklearn_pipeline_cv_demo.md
:::{grid-item-card}

<figure>
<a href="plot_07_glm_pytree.html">
<img src="../_static/thumbnails/how_to_guide/plot_07_glm_pytree.svg" style="height: 100px", alt="PyTrees."/>
</a>
</figure>

```{toctree}
Expand All @@ -99,10 +103,11 @@ plot_07_glm_pytree.md
:::{grid-item-card}

```{eval-rst}

.. plot:: scripts/glm_predictors.py plot_categorical_var_design_matrix
:show-source-link: False
:height: 100px

```

```{toctree}
Expand Down Expand Up @@ -133,7 +138,9 @@ custom_predictors.md
:::{grid-item-card}

<figure>
<a href="raw_history_feature.html">
<img src="../_static/glm_population_scheme.svg" style="height: 100px", alt="Coupled GLM."/>
</a>
</figure>

```{toctree}
Expand All @@ -144,4 +151,15 @@ raw_history_feature.md

:::

:::{grid-item-card}


```{toctree}
:maxdepth: 2

handling_composite_bases.md
```

:::

::::
Loading