Concerning alternate priors/posteriors #82
nnethercott started this conversation in General
Replies: 1 comment
I was wondering what the support is like for implementing different priors on the latent space rather than the traditional $z_{i} \overset{iid}{\sim} \mathcal{N}(0,1)$ prior, as well as for parametric forms of the encoder posterior. For instance, is the implementation of a spike-and-slab prior on the latents (or mixture models in general) a quick modification to the pre-existing code?

I was hoping to have direct control over imposing the prior at the resolution of each individual latent, since I want to try something funky, but I wanted to make use of minimal pythae code to accomplish this. Essentially, what I'm looking for is the flexibility to do something like the following: for $z \in \mathbb{R}^{n+m}$, take $z_{i} \sim \mathcal{N}(0,1)$ for $i = 1, 2, \ldots, n$ and $z_{j} \sim f_{j}(\cdot)$ for $j = n+1, \ldots, n+m$, where the $f_{j}$'s don't necessarily admit a closed form in the computation of the KL divergence with the encoder posterior (and hence where MC estimation would be needed).
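As a concrete reading of that setup, here is a minimal sketch of such a factorized prior in plain `torch.distributions`; the Laplace is purely a placeholder assumption standing in for the $f_{j}$'s:

```python
import torch

n, m = 2, 3  # z lives in R^(n+m)

# Standard normal on the first n coordinates, an arbitrary density
# (here a Laplace, purely as a stand-in for the f_j's) on the last m.
p_i = torch.distributions.Normal(torch.zeros(n), torch.ones(n))
p_j = torch.distributions.Laplace(torch.zeros(m), torch.ones(m))

z = torch.randn(16, n + m)  # a batch of latent codes

# log p(z) factorizes across the two blocks of coordinates
log_p = p_i.log_prob(z[:, :n]).sum(-1) + p_j.log_prob(z[:, n:]).sum(-1)
```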
-
Hi @nnethercott,

```python
from pythae.models import VAE, VAEConfig
from pythae.models.base.base_utils import ModelOutput
from pythae.data.datasets import BaseDataset
from pydantic.dataclasses import dataclass
import torch


### You can define the config of your model as follows
@dataclass
class MyModelConfig(VAEConfig):
    n: int = 2
    m: int = 3

    def __post_init__(self):
        super().__post_init__()
        self.latent_dim = self.n + self.m


### Then, you can create your model
class MyVAEModel(VAE):
    def __init__(self, model_config: MyModelConfig, encoder=None, decoder=None):
        super().__init__(model_config, encoder, decoder)
        self.model_name = "MyVAEModel"
        self.n = model_config.n
        self.m = model_config.m
        self.prior = torch.distributions.Normal(0, 1)

    def forward(self, inputs: BaseDataset, **kwargs):
        x = inputs["data"]

        # encode and sample z with the reparametrization trick
        encoder_output = self.encoder(x)
        mu, log_var = encoder_output.embedding, encoder_output.log_covariance
        std = torch.exp(0.5 * log_var)
        z, _ = self._sample_gauss(mu, std)

        # split the latent code: z has shape (batch_size, n + m)
        z_i = z[:, :self.n]
        z_j = z[:, self.n:]
        #### DO WHAT YOU WANT WITH Z_i ####
        #### DO WHAT YOU WANT WITH Z_j ####

        recon_x = self.decoder(z)["reconstruction"]

        #### DEFINE YOUR OWN LOSS ####
        loss = ((recon_x.reshape(x.shape[0], -1) - x.reshape(x.shape[0], -1)) ** 2).sum(dim=-1)

        return ModelOutput(
            loss=loss.mean(),
            recon_x=recon_x,
            z=z,
        )
```

You can now add anything you want within the `forward` method. I hope this helps. Do not hesitate to reach out if you have any questions :)

Best,

Clément
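As a quick sanity check of the template above, a minimal sketch of instantiating and running the model; the `input_dim` value is just an illustrative MNIST-like shape, and with `encoder=None`/`decoder=None` pythae falls back to its default MLP networks built from the config:

```python
# Minimal smoke test for MyVAEModel (shapes are illustrative assumptions)
config = MyModelConfig(input_dim=(1, 28, 28), n=2, m=3)  # latent_dim becomes n + m = 5
model = MyVAEModel(config)  # default MLP encoder/decoder built from input_dim

out = model({"data": torch.rand(4, 1, 28, 28)})  # forward only reads inputs["data"]
print(out.loss, out.z.shape)  # scalar loss, z of shape (4, 5)
```

From there the model should plug into pythae's usual training machinery like any built-in model.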
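On the MC-estimation point from the question: since the $f_{j}$'s need not give a closed-form KL, one option is a single-sample Monte Carlo estimate of $\log q(z \mid x) - \log p(z)$ computed from the same $z$ drawn in `forward`. A sketch of what the loss section could look like, assuming a hypothetical `self.prior_j` distribution over the last $m$ coordinates:

```python
# Inside forward(), after sampling z = mu + std * eps:

# log q(z|x): diagonal Gaussian posterior from the encoder
q = torch.distributions.Normal(mu, std)
log_q = q.log_prob(z).sum(dim=-1)

# log p(z): closed-form standard normal on z_i, arbitrary density f_j on z_j
log_p = (
    self.prior.log_prob(z[:, :self.n]).sum(dim=-1)      # z_i ~ N(0, 1)
    + self.prior_j.log_prob(z[:, self.n:]).sum(dim=-1)  # z_j ~ f_j (hypothetical attribute)
)

# single-sample MC estimate of KL(q(z|x) || p(z)): E_q[log q - log p]
kld = log_q - log_p

recon_loss = ((recon_x.reshape(x.shape[0], -1) - x.reshape(x.shape[0], -1)) ** 2).sum(dim=-1)
loss = (recon_loss + kld).mean()
```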