File size: 1,393 Bytes
71cd91e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import math
import torch


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Build a boolean mask that is True at padded positions.

    Args:
        lengths: 1-D tensor of sequence lengths, one entry per batch element.
        max_len: Width of the mask. When ``<= 0`` (the default) the largest
            value in ``lengths`` is used, or 0 for an empty batch.

    Returns:
        Bool tensor of shape ``(batch, max_len)`` where ``mask[b, t]`` is True
        iff ``t >= lengths[b]``, i.e. position ``t`` is padding for row ``b``.
    """
    batch_size = lengths.size(0)
    if max_len <= 0:
        # ``lengths.max()`` raises on an empty tensor, so guard the empty batch.
        max_len = int(lengths.max().item()) if batch_size > 0 else 0
    seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)  # (b, 1), broadcasts over t
    return seq_range_expand >= seq_length_expand  # (b, t)


def make_nonpad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Return the complement of ``make_pad_mask``: True at real (non-padded) positions.

    Args:
        lengths: Per-sequence lengths, forwarded unchanged to ``make_pad_mask``.
        max_len: Mask width, forwarded unchanged to ``make_pad_mask``.

    Returns:
        Bool tensor of the same shape as ``make_pad_mask(lengths, max_len)``
        with every element negated.
    """
    padded = make_pad_mask(lengths, max_len)
    return padded.logical_not()


def make_block_causal_mask(
    lengths: torch.Tensor, max_len: int = 0, chunk_size: int = 4
) -> torch.Tensor:
    """Build a chunk-wise causal self-attention mask.

    The time axis is split into consecutive chunks of ``chunk_size`` frames.
    Query position ``i`` may attend to key position ``j`` when ``j`` lies in
    the same chunk as ``i`` or in an earlier chunk, and both positions are
    non-padding. This is equivalent to "causal (``j <= i``) OR same chunk":
    a frame also sees the later frames inside its own chunk.

    Args:
        lengths: Per-sequence lengths, forwarded to ``make_nonpad_mask``.
        max_len: Time dimension of the mask, forwarded to ``make_nonpad_mask``.
        chunk_size: Number of frames per chunk; must be positive.

    Returns:
        Bool tensor of shape ``(batch, t, t)``; True marks an allowed
        (query, key) pair.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    mask = make_nonpad_mask(lengths, max_len)  # (b, t)
    # A pair is valid only if both the query (dim 1) and key (dim 2) are real frames.
    attn_mask = torch.logical_and(mask.unsqueeze(1), mask.unsqueeze(2))  # (b, t, t)

    num_frames = attn_mask.shape[1]
    # Chunk index of every frame, built directly on the mask's device instead of
    # materialising a float32 block-diagonal matrix on the CPU and copying it over.
    frame_idx = torch.arange(num_frames, device=attn_mask.device)
    chunk_ids = torch.div(frame_idx, chunk_size, rounding_mode="floor")
    # chunk_mask[i, j] is True iff key j's chunk is not later than query i's chunk,
    # i.e. (j <= i) OR (i and j share a chunk).
    chunk_mask = chunk_ids.unsqueeze(0) <= chunk_ids.unsqueeze(1)  # (t, t)

    return attn_mask.logical_and(chunk_mask)  # (b, t, t)