import math

import torch


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Build a (b, t) boolean mask where True marks padded positions."""
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)  # (b, t)
    seq_length_expand = lengths.unsqueeze(-1)  # (b, 1)
    mask = seq_range_expand >= seq_length_expand
    return mask  # (b, t)


def make_nonpad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Inverse of make_pad_mask: True marks valid (non-padded) positions."""
    return ~make_pad_mask(lengths, max_len)
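
# Illustrative example (added for clarity, not in the original file):
# for lengths = torch.tensor([2, 4]) and max_len = 4, make_pad_mask returns
#     [[False, False, True,  True ],
#      [False, False, False, False]]
# i.e. True marks padded positions; make_nonpad_mask simply flips the mask.
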
def make_block_causal_mask(
    lengths: torch.Tensor, max_len: int = 0, chunk_size: int = 4
) -> torch.Tensor:
    """Build a (b, t, t) block-causal attention mask: each position attends to
    all earlier chunks plus every position inside its own chunk."""
    mask = make_nonpad_mask(lengths, max_len)  # (b, t)
    # Pairwise validity: both query and key positions must be non-padded.
    attn_mask = torch.logical_and(mask.unsqueeze(1), mask.unsqueeze(2))  # (b, t, t)
    # Block-diagonal mask that allows full attention within each chunk.
    num_blocks = math.ceil(attn_mask.shape[1] / chunk_size)
    block_mask = torch.block_diag(
        *[torch.ones(chunk_size, chunk_size) for _ in range(num_blocks)]
    )
    block_mask = block_mask[: attn_mask.shape[1], : attn_mask.shape[1]].to(
        attn_mask
    )  # (t, t), cropped in case t is not a multiple of chunk_size
    # Lower-triangular (causal) mask, then relaxed within each chunk.
    diag_mask = attn_mask.new_full(
        (1, attn_mask.shape[1], attn_mask.shape[2]), fill_value=True
    ).tril()  # (1, t, t)
    diag_mask = diag_mask.logical_or(block_mask)
    attn_mask = attn_mask.logical_and(diag_mask)
    return attn_mask
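

# Minimal usage sketch (an addition for clarity, not in the original file):
# build a block-causal mask for a batch of two sequences of lengths 6 and 8.
if __name__ == "__main__":
    lengths = torch.tensor([6, 8])
    attn_mask = make_block_causal_mask(lengths, chunk_size=4)  # (2, 8, 8)
    # Within each chunk of 4 frames every position may attend to every other
    # position in that chunk; across chunks attention is strictly causal, and
    # padded positions (frames 6-7 of the first sequence) remain False.
    print(attn_mask.shape)
    print(attn_mask[0].int())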