# YModel2-s0 / ymodel2.py
import math
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union, List
from transformers import PreTrainedModel, GenerationMixin
from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.configuration_utils import PretrainedConfig
class YConfig2(PretrainedConfig):
model_type = "ynet2"
def __init__(
self,
dropout: float = 0.1,
bos_token_id: int = 1,
eos_token_id: int = 2,
hidden_act: str = 'gelu_pytorch_tanh',# silu 4.687 / gelu 4.662 / mish 4.695 / relu2 4.755 / laplace
hidden_size: int = 768,
num_layers: int = 9,
max_position_embeddings: int = 8192,
vocab_size: int = 6400,
rms_norm_eps: float = 1e-8,
        rope_theta: float = 5e4,
self_distill: bool = True,
### FFN ###
        intermediate_size: Optional[int] = None,  # defaults to int(2.5 * hidden_size) when left as None (see FFN)
### attn ###
num_heads: int = 4,
head_dim: int = 64,
**kwargs
):
super().__init__(**kwargs)
self.dropout = dropout
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.hidden_act = hidden_act
self.hidden_size = hidden_size
        self.num_layers = num_layers  # number of layers
self.max_position_embeddings = max_position_embeddings
self.vocab_size = vocab_size
self.rms_norm_eps = rms_norm_eps
self.rope_theta = rope_theta
self.self_distill = self_distill
### FFN ###
        self.intermediate_size = intermediate_size  # FFN intermediate dimension
### attn ###
        self.num_heads = num_heads  # number of query heads
        self.head_dim = head_dim  # dimension per head
def scale_lvl(self, lvl:int=0):
if lvl == 0:
# normal settings [99.312m]
self.num_layers = 16
self.hidden_size = 768
self.num_heads = 16
self.head_dim = 128
self.intermediate_size = 2048
elif lvl == -1:
self.num_layers = 8
self.hidden_size = 512 # base = 4.662 16h/64d 30
self.num_heads = 8 # 2*heads 4.578/20.84
self.head_dim = 64 # 2*dim 4.576/22.8
self.intermediate_size = 1536
else:
raise ValueError(f"Invalid level: {lvl}")
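# Illustrative sketch, not part of the original upload: how a config might be
# built and rescaled with scale_lvl. The helper name `_demo_config` is hypothetical.
def _demo_config():
    cfg = YConfig2()
    cfg.scale_lvl(0)  # ~99M-parameter preset
    return cfg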
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float())
output = output * self.weight.float()
return output.type_as(x)
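# Minimal numerical sketch (hypothetical helper, added for illustration): with the
# default unit weight, RMSNorm drives the last-dim RMS of its output to ~1.
def _demo_rmsnorm():
    norm = RMSNorm(8)
    x = torch.randn(4, 8) * 3.0
    y = norm(x)
    rms = y.pow(2).mean(-1).sqrt()
    assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)
    return y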
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), theta: float = 5e4):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
return freqs_cos, freqs_sin
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0):
def rotate_half(x):
return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)
q_embed = (q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))
k_embed = (k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))
return q_embed, k_embed
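# Shape sketch (hypothetical helper, not part of the original file): applying the
# rotary tables from precompute_freqs_cis leaves the q/k shapes unchanged.
def _demo_rope():
    cos, sin = precompute_freqs_cis(dim=64, end=128)
    q = torch.randn(2, 4, 16, 64)  # [batch, heads, seq_len, head_dim]
    k = torch.randn(2, 1, 16, 64)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos[:16], sin[:16])
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
    return q_rot, k_rot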
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
b, h, l, ch = x.shape
if n_rep == 1:
return x
return (
x[:, :, None, :, :]
.expand(b, h, n_rep, l, ch)
.reshape(b, h * n_rep, l, ch)
)
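# Shape sketch (hypothetical helper): repeat_kv expands the kv-head dimension so
# grouped key/value heads can be matched against every query head.
def _demo_repeat_kv():
    kv = torch.randn(2, 2, 16, 64)  # [b, kv_heads, seq_len, head_dim]
    expanded = repeat_kv(kv, 4)     # -> [b, kv_heads * 4, seq_len, head_dim]
    assert expanded.shape == (2, 8, 16, 64)
    return expanded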
class FFN(nn.Module):
def __init__(self, config: YConfig2):
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size or int(2.5 * config.hidden_size)
self.gate_act = ACT2FN[config.hidden_act]
self.up = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
# self.up = nn.Linear(self.hidden_size, self.intermediate_size)
# self.gate = nn.Linear(self.hidden_size, self.intermediate_size)
self.down = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, g = self.up(x).chunk(2, dim=-1)
# x, g = self.up(x), self.gate(x)
x = self.gate_act(g) * x
x = self.down(x)
return x
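# Usage sketch (hypothetical helper, small sizes chosen only for illustration):
# the FFN is a GLU-style block whose single up-projection yields both the value
# and gate halves, which are recombined and projected back to hidden_size.
def _demo_ffn():
    cfg = YConfig2(hidden_size=64, intermediate_size=160)
    ffn = FFN(cfg)
    x = torch.randn(2, 8, 64)
    assert ffn(x).shape == (2, 8, 64)
    return ffn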
class PEGA2(nn.Module):
def __init__(self, config: YConfig2):
super().__init__()
self.dropout = config.dropout # dropout rate
        self.hidden_size = config.hidden_size  # input hidden size
        self.num_heads = config.num_heads  # total number of attention heads
        self.head_dim = config.head_dim  # dimension per head
self.gate_act = ACT2FN[config.hidden_act]
self.delta_kv_only = False
assert self.num_heads % 2 == 0, "num_heads must be even."
# 2d opt: fused 29.5/4.693 split: 28.7/4.791
# qpe, q
self.qkv_list = [
self.num_heads // 2 * self.head_dim, # qpe
self.num_heads // 2 * self.head_dim, # qnope
self.head_dim, # kpe
self.head_dim, # kv
]
self.qkv = nn.Sequential(
nn.Linear(self.hidden_size, self.head_dim, bias=False),
nn.Linear(self.head_dim, sum(self.qkv_list), bias=False)
)
# self.z = nn.Linear(self.hidden_size, self.head_dim, bias=False)
# self.qpe = nn.Linear(self.head_dim, self.num_heads // 2 * self.head_dim, bias=False)
# self.qnope = nn.Linear(self.head_dim, self.num_heads // 2 * self.head_dim, bias=False)
# self.kpe = nn.Linear(self.head_dim, self.head_dim, bias=False)
# self.kv = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.o = nn.Linear(self.head_dim // 2 * self.num_heads, self.hidden_size, bias=False)
self.rsqrt_dim = 1.0 / math.sqrt(self.head_dim)
# init 2k 4.693 --> 4.687
scale_lora = math.sqrt(
(sum(self.qkv_list) + self.head_dim) * (self.head_dim + self.head_dim) /
(2 * self.head_dim * (self.hidden_size + sum(self.qkv_list)))
)
self.qkv[1].weight.data *= scale_lora
def forward(
self,
x: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
past_key_value: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
cos, sin = position_embeddings # [L, head_dim]
b, l, _ = x.shape
# fused
qkv = self.qkv(x)
        qpe, q, kpe, kv = torch.split(qkv, self.qkv_list, dim=-1)  # qpe/q: [b, l, h//2 * hd], kpe/kv: [b, l, hd]
# z = self.z(x)
# qpe, q, kpe, kv = (
# self.qpe(z),
# self.qnope(z),
# self.kpe(z),
# self.kv(z)
# )
        # apply RoPE
        q = q.view(b, l, self.num_heads // 2, self.head_dim).permute(0, 2, 1, 3)    # [b, h // 2, l, hd]
        qpe = qpe.view(b, l, self.num_heads // 2, self.head_dim).permute(0, 2, 1, 3)  # [b, h // 2, l, hd]
kv = kv.unsqueeze(1) # [b, 1, l, hd]
kpe = kpe.unsqueeze(1) # [b, 1, l, hd]
qpe, kpe = apply_rotary_pos_emb(qpe, kpe, cos[:l], sin[:l])
        # concatenate rotary and non-rotary parts
q = torch.cat([qpe, q], dim=1) # [b, h, l, hd]
kv = torch.cat([kpe, kv], dim=1) # [b, 2, l, hd]
deltakv = None
if self.delta_kv_only:
            # return only the newly computed (delta) kv
            deltakv = kv
        # kv cache
if past_key_value is not None:
kv = torch.cat([past_key_value, kv], dim=2)
past_kv = kv if use_cache else None
_, _, l_all, _ = kv.shape
dropout_p = self.dropout if self.training else 0.0
attn_mask = None
        if attention_mask is not None:
            attn_mask = attention_mask.view(b, 1, 1, -1).expand(b, 1, l, -1).bool()
        if self.training:
            if attn_mask is not None:
                # SDPA rejects attn_mask together with is_causal=True, so fold causality into the mask
                causal = torch.ones(l, l_all, dtype=torch.bool, device=q.device).tril(diagonal=l_all - l)
                attn_mask = attn_mask & causal
            o = nn.functional.scaled_dot_product_attention(
                q, repeat_kv(kv, self.num_heads // 2), repeat_kv(kv[:, 1:, :, :], self.num_heads),
                attn_mask=attn_mask, dropout_p=dropout_p, is_causal=attn_mask is None
            )
else:
o = self.sdpa_math(
q, repeat_kv(kv, self.num_heads // 2), repeat_kv(kv[:, 1:, :, :], self.num_heads),
attn_mask, 0.0
)
# o: [b, h, l, hc]
# gate 2k4b peg: 5.169 nopeg: 5.179 +gate:5.210(4.622)
ope, onope = o.permute(0, 2, 1, 3).chunk(2, dim=2) # [b, l, h // 2, hc]
# o = onope * self.gate_act(ope) # [b, l, h // 2, hc] not stable
o = ope * self.gate_act(onope) # [b, l, h // 2, hc] testing
out = o.reshape(b, l, -1)
out = self.o(out)
out = nn.functional.dropout(out, p=self.dropout, training=self.training)
return out, (deltakv if self.delta_kv_only else past_kv)
def sdpa_math(self, q:torch.Tensor, k:torch.Tensor, v:torch.Tensor, attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0) -> torch.Tensor:
b, h, l, c = q.shape
scores = (q @ k.transpose(-2, -1)) * self.rsqrt_dim
        causal_mask = torch.triu(
            torch.full((l, l), float("-inf"), device=scores.device),
            diagonal=1
        ).unsqueeze(0).unsqueeze(0)  # [1, 1, l, l]
        # zero-pad on the left so the mask matches scores' shape [1, 1, l, l_all]
        causal_mask = nn.functional.pad(causal_mask, (scores.shape[-1] - l, 0), "constant", 0.0)  # [1, 1, l, l_all]
        scores += causal_mask
if attn_mask is not None:
attn_mask = (1.0 - attn_mask.type_as(scores)) * -1e9
scores = scores + attn_mask
scores = nn.functional.softmax(scores.float(), dim=-1).type_as(q)
scores = nn.functional.dropout(scores, p=dropout_p, training=self.training)# [b, h, l, l]
output = scores @ v
return output
def use_delta_kv_only(self, enable:bool=True):
        # return only the newly computed (delta) kv to reduce memory use
self.delta_kv_only = enable
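# Usage sketch (hypothetical helper, not part of the original upload): one full
# PEGA2 pass followed by a single cached decoding step; the caller is expected to
# slice the rotary tables to the current positions.
def _demo_pega2():
    cfg = YConfig2(hidden_size=64, num_heads=4, head_dim=16)
    attn = PEGA2(cfg).eval()
    cos, sin = precompute_freqs_cis(dim=cfg.head_dim, end=32, theta=cfg.rope_theta)
    x = torch.randn(1, 8, 64)
    with torch.no_grad():
        out, past_kv = attn(x, (cos[:8], sin[:8]), use_cache=True)
        out2, past_kv = attn(torch.randn(1, 1, 64), (cos[8:9], sin[8:9]),
                             past_key_value=past_kv, use_cache=True)
    assert out.shape == (1, 8, 64) and out2.shape == (1, 1, 64)
    return out2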
class Attn(nn.Module):
def __init__(self, config: YConfig2):
super().__init__()
self.dropout = config.dropout # dropout rate
        self.hidden_size = config.hidden_size  # input hidden size
        self.num_heads = config.num_heads  # total number of attention heads
        self.head_dim = config.head_dim  # dimension per head
self.gate_act = ACT2FN[config.hidden_act]
self.delta_kv_only = False
assert self.num_heads % 2 == 0, "num_heads must be even."
##### sparse #####
        # q, k, v
self.qkv_list = [
self.num_heads * self.head_dim, # q
2 * self.head_dim, # k
2 * self.head_dim, # v
]
        self.qkv = nn.Linear(self.hidden_size, sum(self.qkv_list), bias=False)
        self.o = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=False)
        self.rsqrt_dim = 1.0 / math.sqrt(self.head_dim)  # needed by sdpa_math
def forward(
self,
x: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
past_key_value: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
cos, sin = position_embeddings # [L, head_dim]
b, l, _ = x.shape
# dense
qkv = self.qkv(x)
        q, k, v = torch.split(qkv, self.qkv_list, dim=-1)  # q: [b, l, h * hd], k/v: [b, l, 2 * hd]
# qpe, q, kpe, kv = (
# self.qpe(x),
# self.qnope(x),
# self.kpe(x),
# self.kv(x)
# )
        # apply RoPE
        q = q.view(b, l, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # [b, h, l, hd]
k = k.view(b, l, 2, self.head_dim).permute(0, 2, 1, 3) # [b, 2, l, hd]
v = v.view(b, l, 2, self.head_dim).permute(0, 2, 1, 3) # [b, 2, l, hd]
q, k = apply_rotary_pos_emb(q, k, cos[:l], sin[:l])
deltakv = None
if self.delta_kv_only:
            # return only the newly computed (delta) kv
            deltakv = (k, v)
        # kv cache
if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)
past_kv = (k, v) if use_cache else None
_, _, l_all, _ = k.shape
dropout_p = self.dropout if self.training else 0.0
attn_mask = None
        if attention_mask is not None:
            attn_mask = attention_mask.view(b, 1, 1, -1).expand(b, 1, l, -1).bool()
        if self.training:
            if attn_mask is not None:
                # SDPA rejects attn_mask together with is_causal=True, so fold causality into the mask
                causal = torch.ones(l, l_all, dtype=torch.bool, device=q.device).tril(diagonal=l_all - l)
                attn_mask = attn_mask & causal
            o = nn.functional.scaled_dot_product_attention(
                q, repeat_kv(k, self.num_heads // 2), repeat_kv(v, self.num_heads // 2),
                attn_mask=attn_mask, dropout_p=dropout_p, is_causal=attn_mask is None
            )
else:
o = self.sdpa_math(
                q, repeat_kv(k, self.num_heads // 2), repeat_kv(v, self.num_heads // 2),
attn_mask, 0.0
)
# o: [b, h, l, hc]
out = o.permute(0, 2, 1, 3).reshape(b, l, -1)
out = self.o(out)
out = nn.functional.dropout(out, p=self.dropout, training=self.training)
return out, (deltakv if self.delta_kv_only else past_kv)
def sdpa_math(self, q:torch.Tensor, k:torch.Tensor, v:torch.Tensor, attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0) -> torch.Tensor:
b, h, l, c = q.shape
scores = (q @ k.transpose(-2, -1)) * self.rsqrt_dim
        causal_mask = torch.triu(
            torch.full((l, l), float("-inf"), device=scores.device),
            diagonal=1
        ).unsqueeze(0).unsqueeze(0)  # [1, 1, l, l]
        # zero-pad on the left so the mask matches scores' shape [1, 1, l, l_all]
        causal_mask = nn.functional.pad(causal_mask, (scores.shape[-1] - l, 0), "constant", 0.0)  # [1, 1, l, l_all]
        scores += causal_mask
if attn_mask is not None:
attn_mask = (1.0 - attn_mask.type_as(scores)) * -1e9
scores = scores + attn_mask
scores = nn.functional.softmax(scores.float(), dim=-1).type_as(q)
scores = nn.functional.dropout(scores, p=dropout_p, training=self.training)# [b, h, l, l]
output = scores @ v
return output
def use_delta_kv_only(self, enable:bool=True):
        # return only the newly computed (delta) kv to reduce memory use
self.delta_kv_only = enable
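# Usage sketch (hypothetical helper): the dense Attn variant shares two key/value
# heads across all query heads and caches (k, v) as a tuple.
def _demo_attn():
    cfg = YConfig2(hidden_size=64, num_heads=4, head_dim=16)
    attn = Attn(cfg).eval()
    cos, sin = precompute_freqs_cis(dim=cfg.head_dim, end=32, theta=cfg.rope_theta)
    x = torch.randn(1, 8, 64)
    with torch.no_grad():
        out, past_kv = attn(x, (cos[:8], sin[:8]), use_cache=True)
        out2, _ = attn(torch.randn(1, 1, 64), (cos[8:9], sin[8:9]),
                       past_key_value=past_kv, use_cache=True)
    assert out.shape == (1, 8, 64) and out2.shape == (1, 1, 64)
    return out2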
class YBlock2(nn.Module):
def __init__(self, config: YConfig2):
super().__init__()
self.attn = PEGA2(config)
self.ffn = FFN(config)
self.norm1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(self,
x: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
past_key_value: Optional[torch.Tensor] = None, # ffn_shard * kv cache
use_cache: bool = False,
attention_mask: Optional[torch.Tensor] = None
):
# attention
residual = x
x = self.norm1(x)
attn_out, past_kv = self.attn(
x,
position_embeddings,
past_key_value=past_key_value,
attention_mask=attention_mask,
use_cache=use_cache,
)
x = residual + attn_out
# ffn
residual = x
x = self.norm2(x)
        ffn_out = self.ffn(x)
        x = residual + ffn_out
return x, past_kv
def use_delta_kv_only(self, enable:bool=True):
self.attn.use_delta_kv_only(enable)
class YModel2(nn.Module):
def __init__(self, config: YConfig2):
super().__init__()
self.vocab_size = config.vocab_size
self.num_layers = config.num_layers
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.dropout = config.dropout
self.use_self_distill = config.self_distill
self.layers = nn.ModuleList([
YBlock2(config) for _ in range(config.num_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
freqs_cos, freqs_sin = precompute_freqs_cis(dim=config.head_dim,
end=config.max_position_embeddings, theta=config.rope_theta)
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
def forward(self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.Tensor]] = None,
use_cache: bool = False,
**kwargs
):
batch_size, seq_length = input_ids.shape
past_key_values = past_key_values or [None] * self.num_layers
start_pos = past_key_values[0].shape[-2] if past_key_values[0] is not None else 0
x = self.embed_tokens(input_ids)
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
position_embeddings = (
self.freqs_cos[start_pos:start_pos + seq_length],
self.freqs_sin[start_pos:start_pos + seq_length]
)
presents = []
cos_loss = None
for i, layer in enumerate(self.layers):
x0 = x
x, past_kv = layer(
x=x,
position_embeddings=position_embeddings,
past_key_value=past_key_values[i],
attention_mask=attention_mask,
use_cache=use_cache
)
if self.training and self.use_self_distill:
xd = x.detach()
# cosine loss
c_loss = 1.0 - nn.functional.cosine_similarity(x0, xd, dim=-1).mean()
cos_loss = c_loss + cos_loss if cos_loss is not None else c_loss
presents.append(past_kv)
if cos_loss is not None:
cos_loss = cos_loss / self.num_layers
x = self.norm(x)
return x, presents, cos_loss
def delta_kv_only(self, delta_kv:bool=True):
for layer in self.layers:
layer.use_delta_kv_only(delta_kv)
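# Sketch (hypothetical helper, tiny sizes for illustration): in training mode with
# self_distill enabled, the backbone also returns the mean cosine distance between
# each block's input and its (detached) output, averaged over layers.
def _demo_self_distill_loss():
    cfg = YConfig2(hidden_size=64, num_layers=2, num_heads=4, head_dim=16,
                   intermediate_size=160, vocab_size=100)
    backbone = YModel2(cfg).train()
    ids = torch.randint(0, cfg.vocab_size, (2, 8))
    hidden, presents, cos_loss = backbone(input_ids=ids)
    assert hidden.shape == (2, 8, 64) and cos_loss is not None
    return cos_loss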
class YForCausalLM2(PreTrainedModel, GenerationMixin):
config_class = YConfig2
def __init__(self, config: YConfig2 = None):
self.config = config or YConfig2()
super().__init__(self.config)
self.model = YModel2(self.config)
self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
self.model.embed_tokens.weight = self.lm_head.weight
self.OUT = CausalLMOutputWithPast()
def forward(self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
logits_to_keep: Union[int, torch.Tensor] = 0,
**args):
h, past_kvs, cos_loss = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
**args
)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(h[:, slice_indices, :])
        self.OUT['last_hidden_state'] = h
        self.OUT['logits'] = logits
        self.OUT['past_key_values'] = past_kvs
        if self.config.self_distill:
            self.OUT['dist_loss'] = cos_loss
return self.OUT
def delta_kv_only(self, delta_kv:bool=True):
self.model.delta_kv_only(delta_kv)
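# Minimal end-to-end sketch (added for illustration; sizes and token ids are
# arbitrary assumptions, not the released checkpoint's settings): one prefill
# forward pass, then a single greedy step reusing the returned kv cache.
if __name__ == "__main__":
    cfg = YConfig2(hidden_size=64, num_layers=2, num_heads=4, head_dim=16,
                   intermediate_size=160, vocab_size=100)
    model = YForCausalLM2(cfg).eval()
    ids = torch.randint(0, cfg.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(input_ids=ids, use_cache=True)
        next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        step = model(input_ids=next_id, past_key_values=out.past_key_values, use_cache=True)
    print(step.logits.shape)  # expected: torch.Size([1, 1, 100])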