-
Notifications
You must be signed in to change notification settings - Fork 62
Open
Labels
Description
When using the HBR model, I get an error when I try to use `transfer` or `extend` with only one site in the original normative model and one site in the transfer/extend dataset. The same error can be reproduced in the `06_transfer_extend.ipynb` example notebook by selecting only one site.
The current tutorial examples assume two or more sites, so it is unclear how batch effects should be specified when there is only a single source site and a single transfer site.
The full error message is shown below. It is unclear whether:
- the batch effects are incorrectly specified,
- the intercept structure expected by HBR requires at least two sites, or
- single-site transfer is currently unsupported.
Error
RuntimeError Traceback (most recent call last)
Cell In[9], line 11
1 small_model = NormativeModel(
2 template_regression_model=template_hbr,
3 savemodel=True,
(...)
9 outscaler="standardize",
10 )
---> 11 small_model.fit_predict(transfer_train, transfer_test)
12 plot_centiles_advanced(
13 small_model,
14 scatter_data=transfer_train,
15 scatter_kwargs_advanced={},
16 plt_kwargs={}
17 )
File ~/.conda/envs/char2/lib/python3.12/site-packages/pcntoolkit/normative_model.py:168, in NormativeModel.fit_predict(self, fit_data, predict_data)
164 def fit_predict(self, fit_data: NormData, predict_data: NormData) -> NormData:
165 """
166 Combines model.fit and model.predict in a single operation.
167 """
--> 168 self.fit(fit_data)
169 self.predict(predict_data)
170 if self.savemodel: # Make sure model is saved
File ~/.conda/envs/char2/lib/python3.12/site-packages/pcntoolkit/normative_model.py:134, in NormativeModel.fit(self, data)
132 resp_fit_data = data.sel({"response_vars": responsevar})
133 X, be, be_maps, Y, _ = self.extract_data(resp_fit_data)
--> 134 self[responsevar].fit(X, be, be_maps, Y)
135 self.is_fitted = True
136 self.postprocess(data)
File ~/.conda/envs/char2/lib/python3.12/site-packages/pcntoolkit/regression_model/hbr.py:107, in HBR.fit(self, X, be, be_maps, Y)
105 self.pymc_model: pm.Model = self.likelihood.compile(X, be, self.be_maps, Y)
106 with self.pymc_model:
--> 107 self.idata = pm.sample(
108 self.draws,
109 tune=self.tune,
110 cores=self.cores,
111 chains=self.chains,
112 nuts_sampler=self.nuts_sampler, # type: ignore
113 init=self.init,
114 progressbar=self.progressbar,
115 )
116 self.is_fitted = True
File ~/.conda/envs/char2/lib/python3.12/site-packages/pymc/sampling/mcmc.py:809, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
804 raise ValueError(
805 "Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability."
806 )
808 with joined_blas_limiter():
--> 809 return _sample_external_nuts(
810 sampler=nuts_sampler,
811 draws=draws,
812 tune=tune,
813 chains=chains,
814 target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
815 random_seed=random_seed,
816 initvals=initvals,
817 model=model,
818 var_names=var_names,
819 progressbar=progress_bool,
820 idata_kwargs=idata_kwargs,
821 compute_convergence_checks=compute_convergence_checks,
822 nuts_sampler_kwargs=nuts_sampler_kwargs,
823 **kwargs,
824 )
826 if exclusive_nuts and not provided_steps:
827 # Special path for NUTS initialization
828 if "nuts" in kwargs:
File ~/.conda/envs/char2/lib/python3.12/site-packages/pymc/sampling/mcmc.py:349, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
344 compiled_model = nutpie.compile_pymc_model(
345 model,
346 **compile_kwargs,
347 )
348 t_start = time.time()
--> 349 idata = nutpie.sample(
350 compiled_model,
351 draws=draws,
352 tune=tune,
353 chains=chains,
354 target_accept=target_accept,
355 seed=_get_seeds_per_chain(random_seed, 1)[0],
356 progress_bar=progressbar,
357 **nuts_sampler_kwargs,
358 )
359 t_sample = time.time() - t_start
360 # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
361 # gather observed and constant data as nutpie.sample() has no access to the PyMC model
File ~/.conda/envs/char2/lib/python3.12/site-packages/nutpie/sample.py:636, in sample(compiled_model, draws, tune, chains, cores, seed, save_warmup, progress_bar, low_rank_modified_mass_matrix, init_mean, return_raw_trace, blocking, progress_template, progress_style, progress_rate, **kwargs)
633 return sampler
635 try:
--> 636 result = sampler.wait()
637 except KeyboardInterrupt:
638 result = sampler.abort()
File ~/.conda/envs/char2/lib/python3.12/site-packages/nutpie/sample.py:388, in _BackgroundSampler.wait(self, timeout)
378 def wait(self, *, timeout=None):
379 """Wait until sampling is finished and return the trace.
380
381 KeyboardInterrupt will lead to interrupt the waiting.
(...)
386 This resumes the sampler in case it had been paused.
387 """
--> 388 self._sampler.wait(timeout)
389 results = self._sampler.extract_results()
390 return self._extract(results)
RuntimeError: Could not create arrow struct
Caused by:
Invalid argument error: Incorrect array length for StructArray field "normalized_site_offset_intercept_mu_zerosum__", expected 2000 got 0
Reactions are currently unavailable