# GestureLSM/trainer/generative_trainer.py
import os
import pprint
import random
import sys
import time
import warnings
from typing import Dict
import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from dataloaders import data_tools
from dataloaders.data_tools import joints_list
from loguru import logger
from models.vq.model import RVQVAE
from optimizers.optim_factory import create_optimizer
from optimizers.scheduler_factory import create_scheduler
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
from trainer.base_trainer import BaseTrainer
from utils import (
data_transfer,
logger_tools,
metric,
other_tools,
other_tools_hf,
rotation_conversions as rc,
)
from utils.joints import hands_body_mask, lower_body_mask, upper_body_mask
def convert_15d_to_6d(motion):
"""
Convert 15D motion to 6D motion, the current motion is 15D, but the eval model is 6D
"""
bs = motion.shape[0]
motion_6d = motion.reshape(bs, -1, 55, 15)[:, :, :, 6:12]
motion_6d = motion_6d.reshape(bs, -1, 55 * 6)
return motion_6d
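# Usage sketch for convert_15d_to_6d (hypothetical shapes): a motion batch of shape
# [bs, T, 55 * 15] is reduced to its 6D-rotation channels, giving [bs, T, 55 * 6].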
class CustomTrainer(BaseTrainer):
"""
Generative Trainer to support various generative models
"""
def __init__(self, cfg, args):
super().__init__(cfg, args)
self.cfg = cfg
self.args = args
self.joints = 55
self.ori_joint_list = joints_list["beat_smplx_joints"]
self.tar_joint_list_face = joints_list["beat_smplx_face"]
self.tar_joint_list_upper = joints_list["beat_smplx_upper"]
self.tar_joint_list_hands = joints_list["beat_smplx_hands"]
self.tar_joint_list_lower = joints_list["beat_smplx_lower"]
self.joint_mask_face = np.zeros(len(list(self.ori_joint_list.keys())) * 3)
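        # Each body-part mask marks the axis-angle channels (3 per joint) belonging to
        # that part; joints_list entries appear to store (channel count, end offset),
        # so the slice [end - count : end] selects a joint's channels.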
for joint_name in self.tar_joint_list_face:
self.joint_mask_face[
self.ori_joint_list[joint_name][1]
- self.ori_joint_list[joint_name][0] : self.ori_joint_list[joint_name][
1
]
] = 1
self.joint_mask_upper = np.zeros(len(list(self.ori_joint_list.keys())) * 3)
for joint_name in self.tar_joint_list_upper:
self.joint_mask_upper[
self.ori_joint_list[joint_name][1]
- self.ori_joint_list[joint_name][0] : self.ori_joint_list[joint_name][
1
]
] = 1
self.joint_mask_hands = np.zeros(len(list(self.ori_joint_list.keys())) * 3)
for joint_name in self.tar_joint_list_hands:
self.joint_mask_hands[
self.ori_joint_list[joint_name][1]
- self.ori_joint_list[joint_name][0] : self.ori_joint_list[joint_name][
1
]
] = 1
self.joint_mask_lower = np.zeros(len(list(self.ori_joint_list.keys())) * 3)
for joint_name in self.tar_joint_list_lower:
self.joint_mask_lower[
self.ori_joint_list[joint_name][1]
- self.ori_joint_list[joint_name][0] : self.ori_joint_list[joint_name][
1
]
] = 1
self.tracker = other_tools.EpochTracker(
["fgd", "bc", "l1div", "predict_x0_loss", "test_clip_fgd"],
[True, True, True, True, True],
)
##### Model #####
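        # Dynamically import models.<cfg.model.model_name> and instantiate the
        # generator class named by cfg.model.g_name.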
model_module = __import__(
f"models.{cfg.model.model_name}", fromlist=["something"]
)
if self.cfg.ddp:
self.model = getattr(model_module, cfg.model.g_name)(cfg).to(self.rank)
process_group = torch.distributed.new_group()
self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
self.model, process_group
)
self.model = DDP(
self.model,
device_ids=[self.rank],
output_device=self.rank,
broadcast_buffers=False,
find_unused_parameters=False,
)
else:
self.model = getattr(model_module, cfg.model.g_name)(cfg)
if self.args.mode == "train":
if self.rank == 0:
logger.info(self.model)
logger.info(f"init {self.cfg.model.g_name} success")
wandb.watch(self.model)
##### Optimizer and Scheduler #####
self.opt = create_optimizer(self.cfg.solver, self.model)
self.opt_s = create_scheduler(self.cfg.solver, self.opt)
        ##### VQ-VAE models #####
        # Initialize and load pretrained VQ-VAE models for the different body parts.
self.vq_models = self._create_body_vq_models()
# Set all VQ models to eval mode
for model in self.vq_models.values():
model.eval()
self.vq_model_upper, self.vq_model_hands, self.vq_model_lower = (
self.vq_models.values()
)
##### Loss functions #####
self.reclatent_loss = nn.MSELoss()
self.vel_loss = torch.nn.L1Loss(reduction="mean")
##### Normalization #####
self.mean = np.load("./mean_std/beatx_2_330_mean.npy")
self.std = np.load("./mean_std/beatx_2_330_std.npy")
# Extract body part specific normalizations
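        # upper_body_mask / hands_body_mask / lower_body_mask come from utils.joints and
        # index into the per-frame feature vector of the beatx_2_330 statistics.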
for part in ["upper", "hands", "lower"]:
mask = globals()[f"{part}_body_mask"]
setattr(self, f"mean_{part}", torch.from_numpy(self.mean[mask]))
setattr(self, f"std_{part}", torch.from_numpy(self.std[mask]))
self.trans_mean = torch.from_numpy(
np.load("./mean_std/beatx_2_trans_mean.npy")
)
self.trans_std = torch.from_numpy(
np.load("./mean_std/beatx_2_trans_std.npy")
)
        if self.args.checkpoint:
            checkpoint = torch.load(self.args.checkpoint, weights_only=False)
            try:
                ckpt_state_dict = checkpoint["model_state_dict"]
            except KeyError:
                ckpt_state_dict = checkpoint["model_state"]
# remove 'audioEncoder' from the state_dict due to legacy issues
ckpt_state_dict = {
k: v
for k, v in ckpt_state_dict.items()
if "modality_encoder.audio_encoder." not in k
}
self.model.load_state_dict(ckpt_state_dict, strict=False)
logger.info(f"Loaded checkpoint from {self.args.checkpoint}")
def _create_body_vq_models(self) -> Dict[str, RVQVAE]:
"""Create VQ-VAE models for body parts."""
        vq_configs = {
            "upper": {"dim_pose": 78},   # 13 upper-body joints x 6D rotation
            "hands": {"dim_pose": 180},  # 30 hand joints x 6D rotation
            "lower": {"dim_pose": 57},   # 9 lower-body joints x 6D rotation + 3D translation velocity
        }
vq_models = {}
for part, config in vq_configs.items():
model = self._create_rvqvae_model(config["dim_pose"], part)
vq_models[part] = model
return vq_models
def _create_rvqvae_model(self, dim_pose: int, body_part: str) -> RVQVAE:
"""Create a single RVQVAE model with specified configuration."""
vq_args = self.args
vq_args.num_quantizers = 6
vq_args.shared_codebook = False
vq_args.quantize_dropout_prob = 0.2
vq_args.quantize_dropout_cutoff_index = 0
vq_args.mu = 0.99
vq_args.beta = 1.0
model = RVQVAE(
vq_args,
input_width=dim_pose,
nb_code=1024,
code_dim=128,
output_emb_width=128,
down_t=2,
stride_t=2,
width=512,
depth=3,
dilation_growth_rate=3,
activation="relu",
norm=None,
)
# Load pretrained weights
checkpoint_path = getattr(self.cfg, f"vqvae_{body_part}_path")
model.load_state_dict(torch.load(checkpoint_path)["net"])
return model
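    # The two helpers below scatter body-part features back into the full 55-joint
    # (165-channel) axis-angle layout using the boolean joint masks built in __init__.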
def inverse_selection(self, filtered_t, selection_array, n):
original_shape_t = np.zeros((n, selection_array.size))
selected_indices = np.where(selection_array == 1)[0]
for i in range(n):
original_shape_t[i, selected_indices] = filtered_t[i]
return original_shape_t
def inverse_selection_tensor(self, filtered_t, selection_array, n):
selection_array = torch.from_numpy(selection_array)
original_shape_t = torch.zeros((n, 165))
selected_indices = torch.where(selection_array == 1)[0]
for i in range(n):
original_shape_t[i, selected_indices] = filtered_t[i]
return original_shape_t
def _load_data(self, dict_data):
facial_rep = dict_data["facial"]
beta = dict_data["beta"]
tar_trans = dict_data["trans"]
tar_id = dict_data["id"]
# process the pose data
tar_pose = dict_data["pose"][:, :, :165]
tar_trans_v = dict_data["trans_v"]
bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints
tar_pose_hands = tar_pose[:, :, 25 * 3 : 55 * 3]
tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3))
tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30 * 6)
tar_pose_upper = tar_pose[:, :, self.joint_mask_upper.astype(bool)]
tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3))
tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13 * 6)
tar_pose_leg = tar_pose[:, :, self.joint_mask_lower.astype(bool)]
tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3))
tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9 * 6)
tar_pose_lower = tar_pose_leg
tar_pose_upper = (tar_pose_upper - self.mean_upper) / self.std_upper
tar_pose_hands = (tar_pose_hands - self.mean_hands) / self.std_hands
tar_pose_lower = (tar_pose_lower - self.mean_lower) / self.std_lower
tar_trans_v = (tar_trans_v - self.trans_mean) / self.trans_std
tar_pose_lower = torch.cat([tar_pose_lower, tar_trans_v], dim=-1)
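        # Encode each normalized body-part stream with its pretrained RVQ-VAE (kept in
        # eval mode); the encoders downsample the sequence temporally by cfg.vqvae_squeeze_scale.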
latent_upper_top = self.vq_model_upper.map2latent(tar_pose_upper)
latent_hands_top = self.vq_model_hands.map2latent(tar_pose_hands)
latent_lower_top = self.vq_model_lower.map2latent(tar_pose_lower)
latent_lengths = [latent_upper_top.shape[1], latent_hands_top.shape[1], latent_lower_top.shape[1]]
if len(set(latent_lengths)) != 1:
min_len = min(latent_lengths)
            logger.warning(
                f"Latent length mismatch (upper={latent_upper_top.shape[1]}, "
                f"hands={latent_hands_top.shape[1]}, lower={latent_lower_top.shape[1]}); "
                f"truncating to {min_len}"
            )
latent_upper_top = latent_upper_top[:, :min_len, :]
latent_hands_top = latent_hands_top[:, :min_len, :]
latent_lower_top = latent_lower_top[:, :min_len, :]
        # TODO: verify whether this latent scaling (divide by 5) is actually needed;
        # _g_test applies the inverse scaling (multiply by 5) before decoding.
# latent_in = torch.cat([latent_upper_top, latent_hands_top, latent_lower_top], dim=2)
latent_in = (
torch.cat([latent_upper_top, latent_hands_top, latent_lower_top], dim=2) / 5
)
word = dict_data.get("word", None)
        # no style annotation is available, so the style feature is always None
style_feature = None
audio_onset = None
if self.cfg.data.onset_rep:
audio_onset = dict_data["audio_onset"]
return {
"audio_onset": audio_onset,
"word": word,
"latent_in": latent_in,
"tar_id": tar_id,
"facial_rep": facial_rep,
"beta": beta,
"tar_pose": tar_pose,
"trans": tar_trans,
"style_feature": style_feature,
}
def _g_training(self, loaded_data, mode="train", epoch=0):
self.model.train()
cond_ = {"y": {}}
cond_["y"]["audio_onset"] = loaded_data["audio_onset"]
cond_["y"]["word"] = loaded_data["word"]
cond_["y"]["id"] = loaded_data["tar_id"]
cond_["y"]["seed"] = loaded_data["latent_in"][:, : self.cfg.pre_frames]
cond_["y"]["style_feature"] = loaded_data["style_feature"]
x0 = loaded_data["latent_in"]
x0 = x0.permute(0, 2, 1).unsqueeze(2)
        # self.model is wrapped in DDP only when cfg.ddp is set; unwrap it if present.
        net = self.model.module if hasattr(self.model, "module") else self.model
        g_loss_final = net.train_forward(cond_, x0)["loss"]
self.tracker.update_meter("predict_x0_loss", "train", g_loss_final.item())
if mode == "train":
return g_loss_final
def _g_test(self, loaded_data):
self.model.eval()
tar_beta = loaded_data["beta"]
tar_pose = loaded_data["tar_pose"]
tar_exps = loaded_data["facial_rep"]
tar_trans = loaded_data["trans"]
audio_onset = loaded_data["audio_onset"]
in_word = loaded_data["word"]
in_x0 = loaded_data["latent_in"]
in_seed = loaded_data["latent_in"]
bs, n, j = (
loaded_data["tar_pose"].shape[0],
loaded_data["tar_pose"].shape[1],
self.joints,
)
remain = n % 8
if remain != 0:
tar_pose = tar_pose[:, :-remain, :]
tar_beta = tar_beta[:, :-remain, :]
tar_exps = tar_exps[:, :-remain, :]
in_x0 = in_x0[
:, : in_x0.shape[1] - (remain // self.cfg.vqvae_squeeze_scale), :
]
in_seed = in_seed[
:, : in_x0.shape[1] - (remain // self.cfg.vqvae_squeeze_scale), :
]
in_word = in_word[:, :-remain]
n = n - remain
rec_all_upper = []
rec_all_lower = []
rec_all_hands = []
vqvae_squeeze_scale = self.cfg.vqvae_squeeze_scale
pre_frames_scaled = self.cfg.pre_frames * vqvae_squeeze_scale
roundt = (n - pre_frames_scaled) // (
self.cfg.data.pose_length - pre_frames_scaled
)
remain = (n - pre_frames_scaled) % (
self.cfg.data.pose_length - pre_frames_scaled
)
        round_l = self.cfg.data.pose_length - pre_frames_scaled
in_audio_onset_tmp = None
in_word_tmp = None
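        # Autoregressive generation over fixed-length windows: each window is conditioned
        # on the last cfg.pre_frames latents of the previous window's output; the
        # ground-truth seed latents are used only for the very first window.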
for i in range(0, roundt):
if audio_onset is not None:
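                # 16000 // 30: audio samples per pose frame (16 kHz audio at 30 fps pose)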
in_audio_onset_tmp = audio_onset[
:,
i * (16000 // 30 * round_l) : (i + 1) * (16000 // 30 * round_l)
+ 16000 // 30 * self.cfg.pre_frames * vqvae_squeeze_scale,
]
if in_word is not None:
in_word_tmp = in_word[
:,
i * (round_l) : (i + 1) * (round_l)
+ self.cfg.pre_frames * vqvae_squeeze_scale,
]
in_id_tmp = loaded_data["tar_id"][
:, i * (round_l) : (i + 1) * (round_l) + self.cfg.pre_frames
]
in_seed_tmp = in_seed[
:,
i
* (round_l)
// vqvae_squeeze_scale : (i + 1)
* (round_l)
// vqvae_squeeze_scale
+ self.cfg.pre_frames,
]
if i == 0:
in_seed_tmp = in_seed_tmp[:, : self.cfg.pre_frames, :]
else:
in_seed_tmp = last_sample[:, -self.cfg.pre_frames :, :]
cond_ = {"y": {}}
cond_["y"]["audio_onset"] = in_audio_onset_tmp
cond_["y"]["word"] = in_word_tmp
cond_["y"]["id"] = in_id_tmp
cond_["y"]["seed"] = in_seed_tmp
cond_["y"]["style_feature"] = torch.zeros([bs, 512])
sample = self.model(cond_)["latents"]
sample = sample.squeeze(2).permute(0, 2, 1)
last_sample = sample.clone()
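            # The generated latent is the concatenation [upper | hands | lower],
            # each slice being code_dim channels wide.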
code_dim = self.vq_model_upper.code_dim
rec_latent_upper = sample[..., :code_dim]
rec_latent_hands = sample[..., code_dim : code_dim * 2]
rec_latent_lower = sample[..., code_dim * 2 : code_dim * 3]
if i == 0:
rec_all_upper.append(rec_latent_upper)
rec_all_hands.append(rec_latent_hands)
rec_all_lower.append(rec_latent_lower)
else:
rec_all_upper.append(rec_latent_upper[:, self.cfg.pre_frames :])
rec_all_hands.append(rec_latent_hands[:, self.cfg.pre_frames :])
rec_all_lower.append(rec_latent_lower[:, self.cfg.pre_frames :])
try:
rec_all_upper = torch.cat(rec_all_upper, dim=1) * 5
rec_all_hands = torch.cat(rec_all_hands, dim=1) * 5
rec_all_lower = torch.cat(rec_all_lower, dim=1) * 5
except RuntimeError as exc:
shape_summary = {
"upper": [tuple(t.shape) for t in rec_all_upper],
"hands": [tuple(t.shape) for t in rec_all_hands],
"lower": [tuple(t.shape) for t in rec_all_lower],
}
            logger.error(
                f"Failed to concatenate latent segments: {exc} | shapes={shape_summary}"
            )
raise
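        # Decode each body-part latent stream back to pose space, then undo the
        # per-part normalization and the translation-velocity normalization.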
rec_upper = self.vq_model_upper.latent2origin(rec_all_upper)[0]
rec_hands = self.vq_model_hands.latent2origin(rec_all_hands)[0]
rec_lower = self.vq_model_lower.latent2origin(rec_all_lower)[0]
rec_trans_v = rec_lower[..., -3:]
rec_trans_v = rec_trans_v * self.trans_std + self.trans_mean
        # Integrate the denormalized translation velocity over time to recover global
        # translation; the vertical (y) channel is taken directly rather than integrated.
        rec_trans = torch.cumsum(rec_trans_v, dim=-2)
        rec_trans[..., 1] = rec_trans_v[..., 1]
rec_lower = rec_lower[..., :-3]
rec_upper = rec_upper * self.std_upper + self.mean_upper
rec_hands = rec_hands * self.std_hands + self.mean_hands
rec_lower = rec_lower * self.std_lower + self.mean_lower
n = n - remain
tar_pose = tar_pose[:, :n, :]
tar_exps = tar_exps[:, :n, :]
tar_trans = tar_trans[:, :n, :]
tar_beta = tar_beta[:, :n, :]
        # Facial expressions are not generated by this model; fall back to the
        # ground-truth expressions for evaluation.
        rec_exps = tar_exps
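        # NOTE: the translation integrated from velocity above is overridden with the
        # ground-truth translation for evaluation.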
rec_trans = tar_trans
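        # Convert each body part from 6D rotations back to axis-angle and scatter it into
        # the full 55-joint layout; the jaw rotation (channels 66:69) is copied from the target.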
rec_pose_legs = rec_lower[:, :, :54]
bs, n = rec_pose_legs.shape[0], rec_pose_legs.shape[1]
rec_pose_upper = rec_upper.reshape(bs, n, 13, 6)
rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper) #
rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs * n, 13 * 3)
rec_pose_upper_recover = self.inverse_selection_tensor(
rec_pose_upper, self.joint_mask_upper, bs * n
)
rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6)
rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower)
rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs * n, 9 * 3)
rec_pose_lower_recover = self.inverse_selection_tensor(
rec_pose_lower, self.joint_mask_lower, bs * n
)
rec_pose_hands = rec_hands.reshape(bs, n, 30, 6)
rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands)
rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs * n, 30 * 3)
rec_pose_hands_recover = self.inverse_selection_tensor(
rec_pose_hands, self.joint_mask_hands, bs * n
)
rec_pose = (
rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover
)
rec_pose[:, 66:69] = tar_pose.reshape(bs * n, 55 * 3)[:, 66:69]
rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs * n, j, 3))
rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j * 6)
tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs * n, j, 3))
tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j * 6)
return {
"rec_pose": rec_pose,
"rec_exps": rec_exps,
"rec_trans": rec_trans,
"tar_pose": tar_pose,
"tar_exps": tar_exps,
"tar_beta": tar_beta,
"tar_trans": tar_trans,
}
def train(self, epoch):
self.model.train()
t_start = time.time()
self.tracker.reset()
for its, batch_data in enumerate(self.train_loader):
loaded_data = self._load_data(batch_data)
t_data = time.time() - t_start
self.opt.zero_grad()
g_loss_final = 0
g_loss_final += self._g_training(loaded_data, "train", epoch)
g_loss_final.backward()
if self.cfg.solver.grad_norm != 0:
torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.cfg.solver.grad_norm
)
self.opt.step()
            mem_cost = torch.cuda.memory_reserved() / 1e9  # memory_cached() is deprecated
lr_g = self.opt.param_groups[0]["lr"]
t_train = time.time() - t_start - t_data
t_start = time.time()
if its % self.cfg.log_period == 0:
self.train_recording(epoch, its, t_data, t_train, mem_cost, lr_g)
if self.cfg.debug:
if its == 1:
break
self.opt_s.step(epoch)
@torch.no_grad()
def _common_test_inference(
self, data_loader, epoch, mode="val", max_iterations=None, save_results=False
):
"""
Common inference logic shared by val, test, test_clip, and test_render methods.
Args:
data_loader: The data loader to iterate over
epoch: Current epoch number
mode: Mode string for logging ("val", "test", "test_clip", "test_render")
max_iterations: Maximum number of iterations (None for no limit)
save_results: Whether to save result files
Returns:
Dictionary containing computed metrics and results
"""
start_time = time.time()
total_length = 0
test_seq_list = self.test_data.selected_file
align = 0
latent_out = []
latent_ori = []
l2_all = 0
lvel = 0
results = []
# Setup save path for test mode
results_save_path = None
if save_results:
results_save_path = self.checkpoint_path + f"/{epoch}/"
if mode == "test_render":
if os.path.exists(results_save_path):
import shutil
shutil.rmtree(results_save_path)
os.makedirs(results_save_path, exist_ok=True)
self.model.eval()
self.smplx.eval()
if hasattr(self, "eval_copy"):
self.eval_copy.eval()
with torch.no_grad():
iterator = enumerate(data_loader)
if mode in ["test_clip", "test"]:
iterator = enumerate(
tqdm(data_loader, desc=f"Testing {mode}", leave=True)
)
for its, batch_data in iterator:
if max_iterations is not None and its > max_iterations:
break
loaded_data = self._load_data(batch_data)
net_out = self._g_test(loaded_data)
tar_pose = net_out["tar_pose"]
rec_pose = net_out["rec_pose"]
tar_exps = net_out["tar_exps"]
tar_beta = net_out["tar_beta"]
rec_trans = net_out["rec_trans"]
tar_trans = net_out.get("tar_trans", rec_trans)
rec_exps = net_out.get("rec_exps", tar_exps)
bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints
# Handle frame rate conversion
if (30 / self.cfg.data.pose_fps) != 1:
assert 30 % self.cfg.data.pose_fps == 0
n *= int(30 / self.cfg.data.pose_fps)
tar_pose = torch.nn.functional.interpolate(
tar_pose.permute(0, 2, 1),
scale_factor=30 / self.cfg.data.pose_fps,
mode="linear",
).permute(0, 2, 1)
scale_factor = (
30 / self.cfg.data.pose_fps
if mode != "test"
else 30 / self.cfg.pose_fps
)
rec_pose = torch.nn.functional.interpolate(
rec_pose.permute(0, 2, 1),
scale_factor=scale_factor,
mode="linear",
).permute(0, 2, 1)
# Calculate latent representations for evaluation
if hasattr(self, "eval_copy") and mode != "test_render":
remain = n % self.cfg.vae_test_len
latent_out.append(
self.eval_copy.map2latent(rec_pose[:, : n - remain])
.reshape(-1, self.cfg.vae_length)
.detach()
.cpu()
.numpy()
)
latent_ori.append(
self.eval_copy.map2latent(tar_pose[:, : n - remain])
.reshape(-1, self.cfg.vae_length)
.detach()
.cpu()
.numpy()
)
rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs * n, j, 6))
rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs * n, j * 3)
tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs * n, j, 6))
tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs * n, j * 3)
# Generate SMPLX vertices and joints
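                # transl and expression are zeroed out (x - x) so the joint-based metrics
                # below depend only on body articulation, not global translation or face.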
vertices_rec = self.smplx(
betas=tar_beta.reshape(bs * n, 300),
transl=rec_trans.reshape(bs * n, 3) - rec_trans.reshape(bs * n, 3),
expression=tar_exps.reshape(bs * n, 100)
- tar_exps.reshape(bs * n, 100),
jaw_pose=rec_pose[:, 66:69],
global_orient=rec_pose[:, :3],
body_pose=rec_pose[:, 3 : 21 * 3 + 3],
left_hand_pose=rec_pose[:, 25 * 3 : 40 * 3],
right_hand_pose=rec_pose[:, 40 * 3 : 55 * 3],
return_joints=True,
leye_pose=rec_pose[:, 69:72],
reye_pose=rec_pose[:, 72:75],
)
joints_rec = (
vertices_rec["joints"]
.detach()
.cpu()
.numpy()
.reshape(bs, n, 127 * 3)[0, :n, : 55 * 3]
)
# Calculate L1 diversity
if hasattr(self, "l1_calculator"):
_ = self.l1_calculator.run(joints_rec)
# Calculate alignment for single batch
if (
hasattr(self, "alignmenter")
and self.alignmenter is not None
and bs == 1
and mode != "test_render"
):
in_audio_eval, sr = librosa.load(
self.cfg.data.data_path
+ "wave16k/"
+ test_seq_list.iloc[its]["id"]
+ ".wav"
)
in_audio_eval = librosa.resample(
in_audio_eval, orig_sr=sr, target_sr=self.cfg.data.audio_sr
)
a_offset = int(
self.align_mask
* (self.cfg.data.audio_sr / self.cfg.data.pose_fps)
)
onset_bt = self.alignmenter.load_audio(
in_audio_eval[
: int(self.cfg.data.audio_sr / self.cfg.data.pose_fps * n)
],
a_offset,
len(in_audio_eval) - a_offset,
True,
)
beat_vel = self.alignmenter.load_pose(
joints_rec, self.align_mask, n - self.align_mask, 30, True
)
align += self.alignmenter.calculate_align(
onset_bt, beat_vel, 30
) * (n - 2 * self.align_mask)
# Mode-specific processing
if mode == "test" and save_results:
# Calculate facial losses for test mode
vertices_rec_face = self.smplx(
betas=tar_beta.reshape(bs * n, 300),
transl=rec_trans.reshape(bs * n, 3)
- rec_trans.reshape(bs * n, 3),
expression=rec_exps.reshape(bs * n, 100),
jaw_pose=rec_pose[:, 66:69],
global_orient=rec_pose[:, :3] - rec_pose[:, :3],
body_pose=rec_pose[:, 3 : 21 * 3 + 3]
- rec_pose[:, 3 : 21 * 3 + 3],
left_hand_pose=rec_pose[:, 25 * 3 : 40 * 3]
- rec_pose[:, 25 * 3 : 40 * 3],
right_hand_pose=rec_pose[:, 40 * 3 : 55 * 3]
- rec_pose[:, 40 * 3 : 55 * 3],
return_verts=True,
return_joints=True,
leye_pose=rec_pose[:, 69:72] - rec_pose[:, 69:72],
reye_pose=rec_pose[:, 72:75] - rec_pose[:, 72:75],
)
vertices_tar_face = self.smplx(
betas=tar_beta.reshape(bs * n, 300),
transl=tar_trans.reshape(bs * n, 3)
- tar_trans.reshape(bs * n, 3),
expression=tar_exps.reshape(bs * n, 100),
jaw_pose=tar_pose[:, 66:69],
global_orient=tar_pose[:, :3] - tar_pose[:, :3],
body_pose=tar_pose[:, 3 : 21 * 3 + 3]
- tar_pose[:, 3 : 21 * 3 + 3],
left_hand_pose=tar_pose[:, 25 * 3 : 40 * 3]
- tar_pose[:, 25 * 3 : 40 * 3],
right_hand_pose=tar_pose[:, 40 * 3 : 55 * 3]
- tar_pose[:, 40 * 3 : 55 * 3],
return_verts=True,
return_joints=True,
leye_pose=tar_pose[:, 69:72] - tar_pose[:, 69:72],
reye_pose=tar_pose[:, 72:75] - tar_pose[:, 72:75],
)
facial_rec = (
vertices_rec_face["vertices"].reshape(1, n, -1)[0, :n].cpu()
)
facial_tar = (
vertices_tar_face["vertices"].reshape(1, n, -1)[0, :n].cpu()
)
face_vel_loss = self.vel_loss(
facial_rec[1:, :] - facial_tar[:-1, :],
facial_tar[1:, :] - facial_tar[:-1, :],
)
l2 = self.reclatent_loss(facial_rec, facial_tar)
l2_all += l2.item() * n
lvel += face_vel_loss.item() * n
# Save results if needed
if save_results:
if mode == "test":
# Save NPZ files for test mode
tar_pose_np = tar_pose.detach().cpu().numpy()
rec_pose_np = rec_pose.detach().cpu().numpy()
rec_trans_np = (
rec_trans.detach().cpu().numpy().reshape(bs * n, 3)
)
rec_exp_np = (
rec_exps.detach().cpu().numpy().reshape(bs * n, 100)
)
tar_exp_np = (
tar_exps.detach().cpu().numpy().reshape(bs * n, 100)
)
tar_trans_np = (
tar_trans.detach().cpu().numpy().reshape(bs * n, 3)
)
gt_npz = np.load(
self.cfg.data.data_path
+ self.cfg.data.pose_rep
+ "/"
+ test_seq_list.iloc[its]["id"]
+ ".npz",
allow_pickle=True,
)
np.savez(
results_save_path
+ "gt_"
+ test_seq_list.iloc[its]["id"]
+ ".npz",
betas=gt_npz["betas"],
poses=tar_pose_np,
expressions=tar_exp_np,
trans=tar_trans_np,
model="smplx2020",
gender="neutral",
mocap_frame_rate=30,
)
np.savez(
results_save_path
+ "res_"
+ test_seq_list.iloc[its]["id"]
+ ".npz",
betas=gt_npz["betas"],
poses=rec_pose_np,
expressions=rec_exp_np,
trans=rec_trans_np,
model="smplx2020",
gender="neutral",
mocap_frame_rate=30,
)
elif mode == "test_render":
# Save results and render for test_render mode
audio_name = loaded_data["audio_name"][0]
rec_pose_np = rec_pose.detach().cpu().numpy()
rec_trans_np = (
rec_trans.detach().cpu().numpy().reshape(bs * n, 3)
)
rec_exp_np = (
rec_exps.detach().cpu().numpy().reshape(bs * n, 100)
)
gt_npz = np.load(
"./demo/examples/2_scott_0_1_1.npz", allow_pickle=True
)
file_name = audio_name.split("/")[-1].split(".")[0]
results_npz_file_save_path = (
results_save_path + f"result_{file_name}.npz"
)
np.savez(
results_npz_file_save_path,
betas=gt_npz["betas"],
poses=rec_pose_np,
expressions=rec_exp_np,
trans=rec_trans_np,
model="smplx2020",
gender="neutral",
mocap_frame_rate=30,
)
render_vid_path = other_tools_hf.render_one_sequence_no_gt(
results_npz_file_save_path,
results_save_path,
audio_name,
self.cfg.data_path_1 + "smplx_models/",
use_matplotlib=False,
args=self.cfg,
)
total_length += n
return {
"total_length": total_length,
"align": align,
"latent_out": latent_out,
"latent_ori": latent_ori,
"l2_all": l2_all,
"lvel": lvel,
"start_time": start_time,
}
def val(self, epoch):
self.tracker.reset()
results = self._common_test_inference(
self.test_loader, epoch, mode="val", max_iterations=15
)
total_length = results["total_length"]
align = results["align"]
latent_out = results["latent_out"]
latent_ori = results["latent_ori"]
l2_all = results["l2_all"]
lvel = results["lvel"]
start_time = results["start_time"]
logger.info(f"l2 loss: {l2_all/total_length:.10f}")
logger.info(f"lvel loss: {lvel/total_length:.10f}")
latent_out_all = np.concatenate(latent_out, axis=0)
latent_ori_all = np.concatenate(latent_ori, axis=0)
fgd = data_tools.FIDCalculator.frechet_distance(latent_out_all, latent_ori_all)
logger.info(f"fgd score: {fgd}")
self.tracker.update_meter("fgd", "val", fgd)
align_avg = align / (total_length - 2 * len(self.test_loader) * self.align_mask)
logger.info(f"align score: {align_avg}")
self.tracker.update_meter("bc", "val", align_avg)
l1div = self.l1_calculator.avg()
logger.info(f"l1div score: {l1div}")
self.tracker.update_meter("l1div", "val", l1div)
self.val_recording(epoch)
end_time = time.time() - start_time
logger.info(
f"total inference time: {int(end_time)} s for {int(total_length/self.cfg.data.pose_fps)} s motion"
)
def test_clip(self, epoch):
self.tracker.reset()
# Test on CLIP dataset
results_clip = self._common_test_inference(
self.test_clip_loader, epoch, mode="test_clip"
)
total_length_clip = results_clip["total_length"]
latent_out_clip = results_clip["latent_out"]
latent_ori_clip = results_clip["latent_ori"]
start_time = results_clip["start_time"]
latent_out_all_clip = np.concatenate(latent_out_clip, axis=0)
latent_ori_all_clip = np.concatenate(latent_ori_clip, axis=0)
fgd_clip = data_tools.FIDCalculator.frechet_distance(
latent_out_all_clip, latent_ori_all_clip
)
logger.info(f"test_clip fgd score: {fgd_clip}")
self.tracker.update_meter("test_clip_fgd", "val", fgd_clip)
current_time = time.time()
test_clip_time = current_time - start_time
logger.info(
f"total test_clip inference time: {int(test_clip_time)} s for {int(total_length_clip/self.cfg.data.pose_fps)} s motion"
)
# Test on regular test dataset for recording
results_test = self._common_test_inference(
self.test_loader, epoch, mode="test_clip"
)
total_length = results_test["total_length"]
align = results_test["align"]
latent_out = results_test["latent_out"]
latent_ori = results_test["latent_ori"]
latent_out_all = np.concatenate(latent_out, axis=0)
latent_ori_all = np.concatenate(latent_ori, axis=0)
fgd = data_tools.FIDCalculator.frechet_distance(latent_out_all, latent_ori_all)
logger.info(f"fgd score: {fgd}")
self.tracker.update_meter("fgd", "val", fgd)
align_avg = align / (total_length - 2 * len(self.test_loader) * self.align_mask)
logger.info(f"align score: {align_avg}")
self.tracker.update_meter("bc", "val", align_avg)
l1div = self.l1_calculator.avg()
logger.info(f"l1div score: {l1div}")
self.tracker.update_meter("l1div", "val", l1div)
self.val_recording(epoch)
end_time = time.time() - current_time
logger.info(
f"total inference time: {int(end_time)} s for {int(total_length/self.cfg.data.pose_fps)} s motion"
)
def test(self, epoch):
results_save_path = self.checkpoint_path + f"/{epoch}/"
os.makedirs(results_save_path, exist_ok=True)
results = self._common_test_inference(
self.test_loader, epoch, mode="test", save_results=True
)
total_length = results["total_length"]
align = results["align"]
latent_out = results["latent_out"]
latent_ori = results["latent_ori"]
l2_all = results["l2_all"]
lvel = results["lvel"]
start_time = results["start_time"]
logger.info(f"l2 loss: {l2_all/total_length:.10f}")
logger.info(f"lvel loss: {lvel/total_length:.10f}")
latent_out_all = np.concatenate(latent_out, axis=0)
latent_ori_all = np.concatenate(latent_ori, axis=0)
fgd = data_tools.FIDCalculator.frechet_distance(latent_out_all, latent_ori_all)
logger.info(f"fgd score: {fgd}")
self.test_recording("fgd", fgd, epoch)
align_avg = align / (total_length - 2 * len(self.test_loader) * self.align_mask)
logger.info(f"align score: {align_avg}")
self.test_recording("bc", align_avg, epoch)
l1div = self.l1_calculator.avg()
logger.info(f"l1div score: {l1div}")
self.test_recording("l1div", l1div, epoch)
end_time = time.time() - start_time
logger.info(
f"total inference time: {int(end_time)} s for {int(total_length/self.cfg.data.pose_fps)} s motion"
)
def test_render(self, epoch):
        """
        Given input audio and text, generate motion and render it to video;
        no losses or metrics are computed.
        """
        import platform
        if platform.system() == "Linux":
            os.environ["PYOPENGL_PLATFORM"] = "egl"
results = self._common_test_inference(
self.test_loader, epoch, mode="test_render", save_results=True
)
def load_checkpoint(self, checkpoint):
# checkpoint is already a dict, do NOT call torch.load again!
        try:
            ckpt_state_dict = checkpoint["model_state_dict"]
        except KeyError:
            ckpt_state_dict = checkpoint["model_state"]
# remove 'audioEncoder' from the state_dict due to legacy issues
ckpt_state_dict = {
k: v
for k, v in ckpt_state_dict.items()
if "modality_encoder.audio_encoder." not in k
}
self.model.load_state_dict(ckpt_state_dict, strict=False)
        try:
            self.opt.load_state_dict(checkpoint["optimizer_state_dict"])
        except (KeyError, ValueError):
            logger.warning("No optimizer state loaded!")
if (
"scheduler_state_dict" in checkpoint
and checkpoint["scheduler_state_dict"] is not None
):
self.opt_s.load_state_dict(checkpoint["scheduler_state_dict"])
if "val_best" in checkpoint:
self.val_best = checkpoint["val_best"]
logger.info("Checkpoint loaded successfully.")