bug: ConcretizationTypeError when trying to use prob_model.predictive() #101

Closed · Labels: bug (Something isn't working)

PaulScemama opened this issue Jul 18, 2023 · 5 comments

PaulScemama (Contributor) commented on Jul 18, 2023

Bug Report

Hi! I've trained a prob_model and created checkpoints. I then run prob_model.load_state and attempt to produce predictions on the test set. However, I'm getting the following error:

...
  pspec=PartitionSpec('processes',)
] b
    from line /home/pscemama/bayesian-conformal-sets/.venv/lib/python3.10/site-packages/orbax/checkpoint/utils.py:63 (sync_global_devices)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
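
For reference, ConcretizationTypeError is raised whenever JAX needs a concrete value but encounters an abstract tracer. A minimal, Fortuna-independent way to trigger it (on recent JAX versions this may surface as the subclass TracerBoolConversionError):

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Inside jit, `x > 0` is an abstract tracer; Python's `if` needs a
    # concrete boolean, so this raises ConcretizationTypeError.
    if x > 0:
        return x
    return -x

f(jnp.array(1.0))  # raises jax.errors.ConcretizationTypeError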

The only non-standard thing I've done is use my own custom model, which is here:

from typing import Any
import flax.linen as nn
import jax.numpy as jnp
import jax

act = jax.nn.swish


class AlexNet(nn.Module):
    """
    An AlexNet model for CIFAR-10.
    """

    output_dim: int
    dtype: Any = jnp.float32

    def setup(self):
        self.hidden_layers = AlexNetHiddenLayers(dtype=self.dtype)
        self.last_layer = AlexNetLastLayer(output_dim=self.output_dim, dtype=self.dtype)

    def __call__(self, x: jnp.ndarray, train: bool = True) -> jnp.ndarray:
        x = self.hidden_layers(x, train)
        x = self.last_layer(x, train)
        return x


class AlexNetHiddenLayers(nn.Module):
    """
    Hidden convolutional layers of the AlexNet model.
    """

    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray, train: bool = True):
        # [32, 32, 3]
        x = nn.Conv(features=64, kernel_size=(3, 3), dtype=self.dtype)(x)
        # [32, 32, 64]
        x = act(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        # [16, 16, 64]

        x = nn.Conv(features=128, kernel_size=(3, 3), dtype=self.dtype)(x)
        # [16, 16, 128]
        x = act(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        # [8, 8, 128]

        x = nn.Conv(features=256, kernel_size=(2, 2), dtype=self.dtype)(x)
        # [8, 8, 256]
        x = act(x)

        x = nn.Conv(features=128, kernel_size=(2, 2), dtype=self.dtype)(x)
        # [8, 8, 128]
        x = act(x)

        x = nn.Conv(features=64, kernel_size=(2, 2), dtype=self.dtype)(x)
        # [8, 8, 64]
        x = act(x)

        x = x.reshape((x.shape[0], -1))
        return x


class AlexNetLastLayer(nn.Module):
    output_dim: int
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray, train: bool = True):
        x = nn.Dense(features=256, dtype=self.dtype)(x)
        x = act(x)
        x = nn.Dense(features=256, dtype=self.dtype)(x)
        x = act(x)
        x = nn.Dense(features=self.output_dim, dtype=self.dtype)(x)
        return x
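
As a sanity check, the model initializes and produces the expected output shape on a dummy CIFAR-10 batch (a minimal sketch; the batch size and PRNG seed are arbitrary):

import jax
import jax.numpy as jnp

model = AlexNet(output_dim=10)
dummy_batch = jnp.ones((8, 32, 32, 3), dtype=jnp.float32)  # NHWC CIFAR-10 images
variables = model.init(jax.random.PRNGKey(0), dummy_batch, train=False)
logits = model.apply(variables, dummy_batch, train=False)
print(logits.shape)  # (8, 10)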

Steps to reproduce:

# Model
prob_model = ProbClassifier(
    model=AlexNet(output_dim=10), 
    posterior_approximator=LaplacePosteriorApproximator(),
    prior=IsotropicGaussianPrior(log_var=jnp.log(PRIOR_VAR))
)
prob_model.load_state("../sgd_checkpoints/checkpoint_11532/")
test_log_probs = prob_model.predictive.log_prob(data_loader=test_loader)
# RAISES ERROR

Other information:

The data comes from a torch DataLoader and is converted with .from_torch_dataloader(); let me know if you need more information on the actual data.
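
For context, the conversion looks roughly like this (a sketch, not my exact code; I believe the Fortuna method is DataLoader.from_torch_data_loader, and the dataset and batch-size details here are placeholders):

from torch.utils.data import DataLoader as TorchDataLoader
from torchvision import datasets, transforms
from fortuna.data import DataLoader

# Hypothetical CIFAR-10 test set; root path and batch size are placeholders.
test_dataset = datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transforms.ToTensor()
)
torch_test_loader = TorchDataLoader(test_dataset, batch_size=128, shuffle=False)

# Convert the torch loader into a Fortuna data loader.
test_loader = DataLoader.from_torch_data_loader(torch_test_loader)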

My hunch is that maybe I'm doing something wrong here. Any guidance is appreciated :)

PaulScemama added the bug label on Jul 18, 2023
gianlucadetommaso (Contributor) commented:

Hi Paul,
could you provide a reproducible example? The error you get is not at Fortuna's level, so I don't really know what's going on 😄

PaulScemama (Contributor, Author) commented:

@gianlucadetommaso it looks like it was my mistake! I passed the directory "/checkpoint_11532" to load_state instead of the file "/checkpoint_11532/checkpoint". It might be useful to catch such an error (e.g. check whether the input is a directory or a file), because with Orbax you pass in the directory.
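
Something along these lines might work (a rough sketch; the standalone helper and the error messages are illustrative, not Fortuna's actual API):

from pathlib import Path

def validate_checkpoint_path(checkpoint_path: str) -> Path:
    # Fail early with a clear message if a directory is passed
    # where a checkpoint file is expected.
    path = Path(checkpoint_path)
    if path.is_dir():
        raise ValueError(
            f"'{path}' is a directory; load_state expects a checkpoint file, "
            f"e.g. '{path / 'checkpoint'}'."
        )
    if not path.is_file():
        raise FileNotFoundError(f"No checkpoint file found at '{path}'.")
    return path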

gianlucadetommaso (Contributor) commented:

Alright! In any case, I'm refactoring the checkpointing to work with Orbax instead. This is part of #96, which will also enable model sharding.

PaulScemama (Contributor, Author) commented:

Looking forward to it, @gianlucadetommaso! As I've been using Fortuna more, I've had some thoughts on possible feature enhancements / pull requests. What would be the best place to discuss them?

gianlucadetommaso (Contributor) commented:

If you want to discuss something at a high level, I would open a Discussion. If you have a more concrete bug or feature request in mind, I'd suggest opening an issue. If you find a small problem and want to propose a quick fix directly, feel free to open a PR.
