import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils.parametrizations import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


class VectorQuantize(nn.Module):
    def __init__(
        self,
        input_dim: int,
        codebook_size: int,
        codebook_dim: int,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        self.in_project = (
            WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)  # (B, D, T) -> (B, D', T)
            if self.input_dim != self.codebook_dim
            else nn.Identity()
        )
        self.out_project = (
            WNConv1d(self.codebook_dim, self.input_dim, kernel_size=1)  # (B, D', T) -> (B, D, T)
            if self.input_dim != self.codebook_dim
            else nn.Identity()
        )

        # Initialize codebook and EMA buffers
        self.register_buffer(
            "codebook", torch.zeros(codebook_size, codebook_dim).float()
        )  # (codebook_size, D'), ensure fp32
        # Placeholder, not used in inference
        self.register_buffer("inited", torch.tensor([True], dtype=torch.bool))  # (1)
        self.register_buffer(
            "cluster_size", torch.zeros(codebook_size).float()
        )  # (codebook_size), ensure fp32
        self.register_buffer(
            "embed_avg", self.codebook.clone().float()
        )  # (codebook_size, D'), ensure fp32

    def decode_code(self, embed_id):
        # embed_id: (B, T)
        embed = (
            F.embedding(embed_id, self.codebook).transpose(1, 2).float()
        )  # (B, D', T), ensure fp32
        return embed

    def encode_code(self, z: torch.Tensor):
        # z: (B, D, T)
        z = z.float()  # Ensure fp32
        z_e = self.in_project(z).float()  # (B, D', T), ensure fp32

        # Rearrange for quantization
        encodings = rearrange(z_e, "b d t -> (b t) d").float()  # (B*T, D'), ensure fp32

        # Quantization: squared Euclidean distance to every codebook entry,
        # expanded as ||e||^2 - 2*e@c^T + ||c||^2
        dist = (
            encodings.pow(2).sum(1, keepdim=True)  # (B*T, 1)
            - 2 * encodings @ self.codebook.float().t()  # (B*T, codebook_size)
            + self.codebook.float().pow(2).sum(1, keepdim=True).t()  # (1, codebook_size)
        )  # dist: (B*T, codebook_size)
        indices = (-dist).max(1)[1]  # (B*T), nearest codebook entry per vector
        indices = rearrange(indices, "(b t) -> b t", b=z.size(0))  # (B, T)

        # Get quantized vectors
        z_q = self.decode_code(indices).float()  # (B, D', T), ensure fp32

        # Straight-through estimator
        z_q = z_e + (z_q - z_e).detach()  # (B, D', T)
        z_q = self.out_project(z_q).float()  # (B, D, T), ensure fp32

        # z_q: (B, D, T), indices: (B, T)
        return z_q, indices
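
# --- Illustrative check (an added sketch, not part of the original module) ---
# encode_code above expands the squared Euclidean distance between encodings
# and codebook entries as ||e||^2 - 2*e@c^T + ||c||^2. This hedged helper,
# with made-up shapes, verifies that the expansion matches torch.cdist(...)**2
# on random data; it is never called by the module itself.
def _check_squared_dist_expansion():
    e = torch.randn(10, 4)  # flattened encodings, (B*T, D')
    c = torch.randn(16, 4)  # codebook, (codebook_size, D')
    dist = (
        e.pow(2).sum(1, keepdim=True)
        - 2 * e @ c.t()
        + c.pow(2).sum(1, keepdim=True).t()
    )  # (10, 16)
    assert torch.allclose(dist, torch.cdist(e, c).pow(2), atol=1e-4)
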
class ResidualVQ(nn.Module):
    def __init__(
        self,
        input_dim: int = 768,  # Input dimension, unrelated to RVQ
        rvq_dim: int = None,  # RVQ dimension; if it differs from input_dim/output_dim,
        #   input_dim->rvq_dim and rvq_dim->output_dim projections are added
        output_dim: int = None,  # Output dimension, unrelated to RVQ
        num_quantizers: int = 8,
        codebook_size: int = 1024,
        codebook_dim: int = 256,  # Dimension of each codebook; if it differs from rvq_dim,
        #   rvq_dim->codebook_dim and codebook_dim->rvq_dim projections are added
    ):
        super().__init__()
        self.input_dim = input_dim
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.rvq_dim = rvq_dim

        self.input_proj = (
            WNConv1d(input_dim, rvq_dim, kernel_size=1)
            if input_dim != rvq_dim
            else nn.Identity()
        )
        self.output_proj = (
            WNConv1d(rvq_dim, output_dim, kernel_size=1)
            if rvq_dim != output_dim
            else nn.Identity()
        )

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(
                    input_dim=rvq_dim,
                    codebook_size=self.codebook_size,
                    codebook_dim=codebook_dim,
                )
                for _ in range(num_quantizers)
            ]
        )

    def encode_codes(self, z: torch.Tensor):
        # z: (B, input_dim, T)
        z = self.input_proj(z)
        residual = z.clone().float()  # (B, D, T), ensure fp32
        all_indices = []

        # Quantize to tokens, stage by stage, subtracting each stage's
        # reconstruction from the running residual
        for quantizer in self.quantizers:
            z_q_i, indices_i = quantizer.encode_code(residual)  # (B, D, T), (B, T)
            residual = residual - z_q_i
            all_indices.append(indices_i)  # (B, T)

        all_indices = torch.stack(all_indices)  # (nq, B, T)
        return all_indices

    def decode_codes(self, codes):
        """Decode codes from multiple quantizers to embeddings.

        Args:
            codes: Tensor of shape (nq, B, T) containing code indices for each quantizer.

        Returns:
            emb: Tensor of shape (B, D, T) representing the decoded embeddings.
        """
        nq, B, T = codes.shape
        device = codes.device
        emb = torch.zeros(
            B, self.rvq_dim, T, device=device, dtype=torch.float32
        )  # (B, D, T)
        for i, quantizer in enumerate(self.quantizers[:nq]):
            code_i = codes[i]  # (B, T)
            quantized_i = quantizer.decode_code(code_i)  # (B, D', T)
            emb += quantizer.out_project(quantized_i)  # Accumulate quantized embeddings
        emb = self.output_proj(emb)  # (B, D, T), apply output projection
        return emb  # (B, D, T)
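
# --- Illustrative usage (an added sketch, not part of the original module) ---
# A minimal round-trip demo on random features, with dimensions chosen purely
# for illustration. It shows that encode_codes yields one code stack per
# quantizer and that decode_codes also accepts a truncated stack (nq <
# num_quantizers) for a coarser reconstruction.
if __name__ == "__main__":
    rvq = ResidualVQ(
        input_dim=768,
        rvq_dim=256,
        output_dim=768,
        num_quantizers=8,
        codebook_size=1024,
        codebook_dim=256,
    )
    x = torch.randn(2, 768, 50)           # (B, input_dim, T), hypothetical batch
    codes = rvq.encode_codes(x)           # (num_quantizers, B, T)
    print(codes.shape)                    # torch.Size([8, 2, 50])
    recon = rvq.decode_codes(codes)       # (B, output_dim, T)
    print(recon.shape)                    # torch.Size([2, 768, 50])
    coarse = rvq.decode_codes(codes[:2])  # first 2 quantizers only
    print(coarse.shape)                   # torch.Size([2, 768, 50])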