from typing import List, Optional

from transformers import PretrainedConfig
class VQVAEConfig(PretrainedConfig):
    """Configuration object holding the hyper-parameters of a VQ-VAE model.

    Each constructor argument below is stored as an attribute of the same
    name. Any remaining keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        embedding_dim: Size of an embedding vector (presumably the dimension
            of each codebook vector -- confirm against the model code).
            Defaults to 256.
        n_codes: Number of codes (presumably the codebook size -- confirm).
            Defaults to 2048.
        n_hiddens: Hidden width used by the model (assumed to be a channel
            count -- confirm). Defaults to 240.
        n_res_layers: Number of residual layers. Defaults to 4.
        downsample: Downsampling factors (presumably one per input axis --
            confirm against the encoder). ``None`` selects the default
            ``[2, 4, 4]``. The value is always copied into a new list, so the
            config never shares a list object with the caller or with other
            instances.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "VQVAE"

    def __init__(
        self,
        embedding_dim: int = 256,
        n_codes: int = 2048,
        n_hiddens: int = 240,
        n_res_layers: int = 4,
        downsample: Optional[List[int]] = None,
        **kwargs,
    ):
        self.embedding_dim = embedding_dim
        self.n_codes = n_codes
        self.n_hiddens = n_hiddens
        self.n_res_layers = n_res_layers
        # A list literal used as a default is evaluated once at definition
        # time and shared by every instance, so an in-place edit on one config
        # would silently change all others. Build a fresh list per instance,
        # and copy caller input so later caller-side mutation cannot leak in.
        self.downsample = [2, 4, 4] if downsample is None else list(downsample)
        super().__init__(**kwargs)