import pdb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .layers.utils import *
from .layers.transformer import SpatialTemporalBlock, CrossAttentionBlock


class GestureDenoiser(nn.Module):
    """Diffusion denoiser for gesture sequences.

    Predicts the denoised gesture representation from a noisy input ``x_t``,
    a diffusion timestep, an optional seed-pose context, and optional audio
    features (``at_feat``) attended to via cross-attention blocks.
    """

    def __init__(self,
                 input_dim=128,
                 latent_dim=256,
                 ff_size=1024,
                 num_layers=8,
                 num_heads=4,
                 dropout=0.1,
                 activation="gelu",
                 n_seed=8,
                 flip_sin_to_cos=True,
                 freq_shift=0,
                 cond_proj_dim=None,
                 use_exp=False,
                 seq_len=32,
                 embed_context_multiplier=4,
                 ):
        """Build all sub-modules.

        Args:
            input_dim: per-joint feature dimension of the input motion.
            latent_dim: transformer hidden size.
            ff_size: feed-forward width; ``ff_size // latent_dim`` is used as
                the MLP ratio of the transformer blocks.
            num_layers: number of spatial-temporal blocks.
            num_heads: attention heads per block.
            dropout: dropout / drop-path rate.
            activation: activation name passed to the timestep embedding.
            n_seed: number of seed frames; 0 disables the seed branch.
            flip_sin_to_cos, freq_shift: sinusoidal timestep-projection options.
            cond_proj_dim: if not None, enables a second timestep-style
                conditioning projection (``cond_time`` in ``forward``).
            use_exp: if True, a 4th "expression" joint group is added.
            seq_len: number of frames per sequence.
            embed_context_multiplier: how many seed frames are flattened into
                the seed-embedding linear layer.
        """
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.activation = activation
        self.use_exp = use_exp
        # 3 joint groups (e.g. body parts) plus one more when expressions are used.
        self.joint_num = 3 if not self.use_exp else 4

        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)

        # Cross-attention over external (audio) features; operates on the
        # per-frame concatenation of all joint groups.
        self.cross_attn_blocks = nn.ModuleList([
            CrossAttentionBlock(dim=self.latent_dim * self.joint_num,
                                num_heads=self.num_heads,
                                mlp_ratio=self.ff_size // self.latent_dim,
                                drop_path=self.dropout)
            for _ in range(3)])

        # Main spatial-temporal transformer stack.
        self.mytimmblocks = nn.ModuleList([
            SpatialTemporalBlock(dim=self.latent_dim,
                                 num_heads=self.num_heads,
                                 mlp_ratio=self.ff_size // self.latent_dim,
                                 drop_path=self.dropout)
            for _ in range(self.num_layers)])

        self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)

        self.n_seed = n_seed
        self.seq_len = seq_len
        self.embed_context_multiplier = embed_context_multiplier
        # Flattened seed frames -> one latent "style" vector.
        self.embed_text = nn.Linear(
            self.input_dim * self.joint_num * self.embed_context_multiplier,
            self.latent_dim)

        self.output_process = OutputProcess(self.input_dim, self.latent_dim)
        self.rel_pos = SinusoidalEmbeddings(self.latent_dim)
        self.input_process = InputProcess(self.input_dim, self.latent_dim)
        # Fuses the concatenated [style+time, motion] features back to latent_dim.
        self.input_process2 = nn.Linear(self.latent_dim * 2, self.latent_dim)

        self.time_embedding = TimestepEmbedding(self.latent_dim, self.latent_dim,
                                                self.activation,
                                                cond_proj_dim=cond_proj_dim,
                                                zero_init_cond=True)
        time_dim = self.latent_dim
        self.time_proj = Timesteps(time_dim, flip_sin_to_cos, freq_shift)
        if cond_proj_dim is not None:
            self.cond_proj = Timesteps(time_dim, flip_sin_to_cos, freq_shift)
        else:
            # Define the attribute explicitly so forward()'s
            # `self.cond_proj is not None` check cannot raise AttributeError.
            self.cond_proj = None

        # Null condition embedding for classifier-free guidance.
        self.null_cond_embed = nn.Parameter(
            torch.zeros(self.seq_len, self.latent_dim * self.joint_num),
            requires_grad=True)

    # dropout mask
    def prob_mask_like(self, shape, prob, device):
        """Return a boolean tensor of `shape` that is True with probability `prob`."""
        return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob

    def forward(self, x, timesteps, cond_time=None, seed=None, at_feat=None):
        """
        x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper
        timesteps: [batch_size] (int)
        seed: [batch_size, njoints, nfeats]
        """
        # Drop a singleton nfeats axis, then regroup features per joint group.
        # NOTE(review): after the squeeze, x.shape[2] is the frame axis, so the
        # reshape assumes the flattened feature dim is divisible by joint_num
        # — preserved from the original; confirm against callers.
        if x.shape[2] == 1:
            x = x.squeeze(2)
        x = x.reshape(x.shape[0], self.joint_num, -1, x.shape[2])
        bs, njoints, nfeats, nframes = x.shape  # [bs, 3, 128, 32]

        # need to be an array, especially when bs is 1
        # timesteps = timesteps.expand(bs).clone()
        time_emb = self.time_proj(timesteps)
        time_emb = time_emb.to(dtype=x.dtype)

        if cond_time is not None and self.cond_proj is not None:
            cond_time = cond_time.expand(bs).clone()
            cond_emb = self.cond_proj(cond_time)
            cond_emb = cond_emb.to(dtype=x.dtype)
            emb_t = self.time_embedding(time_emb, cond_emb)
        else:
            emb_t = self.time_embedding(time_emb)

        if self.n_seed != 0:
            embed_text = self.embed_text(seed.reshape(bs, -1))
            emb_seed = embed_text
        else:
            # Without seed frames there is no style embedding; previously this
            # path crashed with a NameError on `emb_seed` below.
            emb_seed = 0

        xseq = self.input_process(x)

        # Add the seed/style + timestep information to every joint and frame.
        embed_style_2 = (emb_seed + emb_t).unsqueeze(1).unsqueeze(2).expand(
            -1, self.joint_num, self.seq_len, -1)  # (300, 256)
        xseq = torch.cat([embed_style_2, xseq], dim=-1)  # -> [88, 300, 576]
        xseq = self.input_process2(xseq)

        # Apply rotary positional encoding per (batch, joint) sequence.
        xseq = xseq.reshape(bs * self.joint_num, nframes, -1)
        pos_emb = self.rel_pos(xseq)
        xseq, _ = apply_rotary_pos_emb(xseq, xseq, pos_emb)
        xseq = xseq.reshape(bs, self.joint_num, nframes, -1)

        # Flatten joints into the channel dim for audio cross-attention.
        # NOTE(review): this view assumes nframes == self.seq_len — confirm.
        xseq = xseq.view(bs, self.seq_len, -1)
        for block in self.cross_attn_blocks:
            xseq = block(xseq, at_feat)
        xseq = xseq.view(bs, njoints, self.seq_len, -1)

        for block in self.mytimmblocks:
            xseq = block(xseq)

        output = xseq
        output = self.output_process(output)
        return output