Skip to content

HBR transfer/extend fails when using only 1 source site and 1 transfer site (intercept format error) #368

@CharFraza

Description

@CharFraza

When using the HBR model, I get an error when running transfer or extend with only one site in the original normative model and one site in the transfer/extend dataset. The same error can be reproduced with the `06_transfer_extend.ipynb` example notebook (in the examples section) by selecting only one site.

The current tutorial examples assume two or more sites, so it is unclear how batch effects should be specified when working with only a single source site and a single transfer site.

The full traceback is shown below, but it is unclear whether:

  1. the batch effects are specified incorrectly,
  2. the intercept structure expected by HBR requires at least two sites, or
  3. single-site transfer is currently unsupported.
Error
RuntimeError                              Traceback (most recent call last)
Cell In[9], line 11
      1 small_model = NormativeModel(
      2     template_regression_model=template_hbr,
      3     savemodel=True,
   (...)
      9     outscaler="standardize",
     10 )
---> 11 small_model.fit_predict(transfer_train, transfer_test)
     12 plot_centiles_advanced(
     13     small_model,
     14     scatter_data=transfer_train,
     15     scatter_kwargs_advanced={},
     16     plt_kwargs={}
     17 )

File ~/.conda/envs/char2/lib/python3.12/site-packages/pcntoolkit/normative_model.py:168, in NormativeModel.fit_predict(self, fit_data, predict_data)
    164 def fit_predict(self, fit_data: NormData, predict_data: NormData) -> NormData:
    165     """
    166     Combines model.fit and model.predict in a single operation.
    167     """
--> 168     self.fit(fit_data)
    169     self.predict(predict_data)
    170     if self.savemodel:  # Make sure model is saved

File ~/.conda/envs/char2/lib/python3.12/site-packages/pcntoolkit/normative_model.py:134, in NormativeModel.fit(self, data)
    132     resp_fit_data = data.sel({"response_vars": responsevar})
    133     X, be, be_maps, Y, _ = self.extract_data(resp_fit_data)
--> 134     self[responsevar].fit(X, be, be_maps, Y)
    135 self.is_fitted = True
    136 self.postprocess(data)

File ~/.conda/envs/char2/lib/python3.12/site-packages/pcntoolkit/regression_model/hbr.py:107, in HBR.fit(self, X, be, be_maps, Y)
    105 self.pymc_model: pm.Model = self.likelihood.compile(X, be, self.be_maps, Y)
    106 with self.pymc_model:
--> 107     self.idata = pm.sample(
    108         self.draws,
    109         tune=self.tune,
    110         cores=self.cores,
    111         chains=self.chains,
    112         nuts_sampler=self.nuts_sampler,  # type: ignore
    113         init=self.init,
    114         progressbar=self.progressbar,
    115     )
    116 self.is_fitted = True

File ~/.conda/envs/char2/lib/python3.12/site-packages/pymc/sampling/mcmc.py:809, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    804         raise ValueError(
    805             "Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability."
    806         )
    808     with joined_blas_limiter():
--> 809         return _sample_external_nuts(
    810             sampler=nuts_sampler,
    811             draws=draws,
    812             tune=tune,
    813             chains=chains,
    814             target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    815             random_seed=random_seed,
    816             initvals=initvals,
    817             model=model,
    818             var_names=var_names,
    819             progressbar=progress_bool,
    820             idata_kwargs=idata_kwargs,
    821             compute_convergence_checks=compute_convergence_checks,
    822             nuts_sampler_kwargs=nuts_sampler_kwargs,
    823             **kwargs,
    824         )
    826 if exclusive_nuts and not provided_steps:
    827     # Special path for NUTS initialization
    828     if "nuts" in kwargs:

File ~/.conda/envs/char2/lib/python3.12/site-packages/pymc/sampling/mcmc.py:349, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
    344 compiled_model = nutpie.compile_pymc_model(
    345     model,
    346     **compile_kwargs,
    347 )
    348 t_start = time.time()
--> 349 idata = nutpie.sample(
    350     compiled_model,
    351     draws=draws,
    352     tune=tune,
    353     chains=chains,
    354     target_accept=target_accept,
    355     seed=_get_seeds_per_chain(random_seed, 1)[0],
    356     progress_bar=progressbar,
    357     **nuts_sampler_kwargs,
    358 )
    359 t_sample = time.time() - t_start
    360 # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
    361 # gather observed and constant data as nutpie.sample() has no access to the PyMC model

File ~/.conda/envs/char2/lib/python3.12/site-packages/nutpie/sample.py:636, in sample(compiled_model, draws, tune, chains, cores, seed, save_warmup, progress_bar, low_rank_modified_mass_matrix, init_mean, return_raw_trace, blocking, progress_template, progress_style, progress_rate, **kwargs)
    633     return sampler
    635 try:
--> 636     result = sampler.wait()
    637 except KeyboardInterrupt:
    638     result = sampler.abort()

File ~/.conda/envs/char2/lib/python3.12/site-packages/nutpie/sample.py:388, in _BackgroundSampler.wait(self, timeout)
    378 def wait(self, *, timeout=None):
    379     """Wait until sampling is finished and return the trace.
    380 
    381     KeyboardInterrupt will lead to interrupt the waiting.
   (...)
    386     This resumes the sampler in case it had been paused.
    387     """
--> 388     self._sampler.wait(timeout)
    389     results = self._sampler.extract_results()
    390     return self._extract(results)

RuntimeError: Could not create arrow struct

Caused by:
    Invalid argument error: Incorrect array length for StructArray field "normalized_site_offset_intercept_mu_zerosum__", expected 2000 got 0

Metadata

Metadata

Assignees

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions