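"""Attention-mask utilities: padding masks and block-causal (chunked) masks."""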
import math

import torch


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Return a (batch, max_len) bool mask that is True at padded positions."""
    batch_size = lengths.size(0)
    # Fall back to the longest sequence in the batch when max_len is not given.
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)  # (b, t)
    seq_length_expand = lengths.unsqueeze(-1)  # (b, 1)
    # Position t is padding iff t >= that sequence's length.
    mask = seq_range_expand >= seq_length_expand
    return mask  # (b, t)


def make_nonpad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Return a (batch, max_len) bool mask that is True at valid (non-pad) positions."""
    return ~make_pad_mask(lengths, max_len)
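

# Illustration of the two helpers above (example values, not from the original
# file): with lengths = torch.tensor([2, 3]) and max_len = 4,
#   make_pad_mask(lengths, 4)    -> [[F, F, T, T],
#                                    [F, F, F, T]]
#   make_nonpad_mask(lengths, 4) -> [[T, T, F, F],
#                                    [T, T, T, F]]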


def make_block_causal_mask(
    lengths: torch.Tensor, max_len: int = 0, chunk_size: int = 4
) -> torch.Tensor:
    """Return a (batch, t, t) bool mask letting each position attend to every
    earlier position plus all positions inside its own chunk."""
    mask = make_nonpad_mask(lengths, max_len)  # (b, t)
    # Allow attention only where both query and key positions are non-pad.
    attn_mask = torch.logical_and(mask.unsqueeze(1), mask.unsqueeze(2))  # (b, t, t)
    # Block-diagonal mask: True where query and key fall in the same chunk.
    num_blocks = math.ceil(attn_mask.shape[1] / chunk_size)
    block_mask = torch.block_diag(
        *[torch.ones(chunk_size, chunk_size) for _ in range(num_blocks)]
    )
    # Trim the last (possibly partial) chunk and match attn_mask's dtype/device.
    block_mask = block_mask[: attn_mask.shape[1], : attn_mask.shape[1]].to(
        attn_mask
    )  # (t, t)
    # Causal (lower-triangular) mask, broadcast over the batch dimension.
    diag_mask = attn_mask.new_full(
        (1, attn_mask.shape[1], attn_mask.shape[2]), fill_value=True
    ).tril()  # (1, t, t)
    # Within-chunk positions may attend ahead of the causal diagonal.
    diag_mask = diag_mask.logical_or(block_mask)
    attn_mask = attn_mask.logical_and(diag_mask)
    return attn_mask
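

if __name__ == "__main__":
    # Minimal usage sketch; the lengths below are arbitrary demo values.
    lengths = torch.tensor([2, 4])
    print(make_pad_mask(lengths))
    # tensor([[False, False,  True,  True],
    #         [False, False, False, False]])
    attn_mask = make_block_causal_mask(lengths, chunk_size=2)
    print(attn_mask[1].int())
    # For the fully valid sequence (length 4): each position sees its whole
    # 2-step chunk plus everything before it.
    # tensor([[1, 1, 0, 0],
    #         [1, 1, 0, 0],
    #         [1, 1, 1, 1],
    #         [1, 1, 1, 1]], dtype=torch.int32)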