-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodeling_cheems_OTCE.py
2775 lines (2338 loc) · 127 KB
/
modeling_cheems_OTCE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2024 Jingze Shi and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Cheems OTCE model."""
import inspect
import math
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
AttentionMaskConverter,
)
from transformers.modeling_outputs import (
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
from transformers.utils import (
is_flash_attn_greater_or_equal_2_10,
logging,
)
from transformers.utils.import_utils import (
is_flash_attn_2_available,
is_mamba_ssm_available,
is_causal_conv1d_available
)
from transformers.utils.import_utils import is_torch_fx_available
from .configuration_cheems_OTCE import CheemsOTCEConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
if is_torch_fx_available():
if not is_torch_greater_or_equal_than_1_13:
import torch.fx
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
if is_mamba_ssm_available():
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
else:
selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_fn, causal_conv1d_update = None, None
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
logger = logging.get_logger(__name__)
def load_balancing_loss_func(
    gate_logits: torch.Tensor,
    num_experts: Optional[int] = None,
    top_k=2,
    attention_mask: Optional[torch.Tensor] = None
) -> float:
    r"""
    Computes the auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
            Logits from the `router`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [batch_size X sequence_length, num_experts]. Entries whose second dimension is 1 are
            left out of the computation.
        num_experts (`int`, *optional*):
            Number of experts.
        top_k (`int`, *optional*, defaults to 2):
            Number of experts selected per token.
        attention_mask (`torch.Tensor`, None):
            The attention_mask used in forward function
            shape [batch_size X sequence_length] if not None.

    Returns:
        The auxiliary loss. `0` is returned when `gate_logits` is `None` or is not a tuple.
    """
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0

    # Gather every layer's router logits on one device; routers with a single output are skipped.
    compute_device = gate_logits[0].device
    concatenated_gate_logits = torch.cat(
        [layer_gate.to(compute_device) for layer_gate in gate_logits if layer_gate.shape[1] > 1], dim=0
    )

    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    # The unmasked statistics are the ones to use when NO mask is given. The condition used to be
    # inverted, which dereferenced `attention_mask.shape` on None and ignored real padding masks.
    if attention_mask is None:
        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)

        # Compute the mask that masks all padding tokens to 0, with the same shape as expert_mask
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )
        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
            expert_attention_mask, dim=0
        )

        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )
        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
            router_per_expert_attention_mask, dim=0
        )

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).

    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim). When `n_rep` is 1 the input is returned untouched.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the heads.
    widened = hidden_states[:, :, None, :, :].expand(bsz, kv_heads, n_rep, seq_len, dim)
    return widened.reshape(bsz, kv_heads * n_rep, seq_len, dim)
def rotate_half(x):
    """
    Rotates half the hidden dims of the input.

    The last dimension is split at its midpoint (floor division); the negated second part is placed
    in front of the first part.
    """
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)
def apply_QK_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """
    Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The dimension along which cos[position_ids] and sin[position_ids] are unsqueezed so that they
            broadcast against q and k. cos[position_ids] and sin[position_ids] have the shape
            [batch_size, seq_len, head_dim]; use unsqueeze_dim=1 when q and k are
            [batch_size, heads, seq_len, head_dim] and unsqueeze_dim=2 when they are
            [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    # Pick the table rows for the requested positions and add the broadcast axis.
    cos_rows = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_rows = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _rotate(states):
        return (states * cos_rows) + (rotate_half(states) * sin_rows)

    return _rotate(q), _rotate(k)
def apply_BC_rotary_pos_emb(b, c, cos, sin, position_ids):
    """
    Applies Rotary Position Embedding to the B and C tensors.

    Unlike the query/key variant, no extra broadcast axis is inserted: the selected cos/sin rows are
    multiplied with `b` and `c` directly.

    Args:
        b (`torch.Tensor`): The B tensor. [batch_size, seq_len, ssm_state_size]
        c (`torch.Tensor`): The C tensor. [batch_size, seq_len, ssm_state_size]
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`): The position indices of the tokens.

    Returns:
        `tuple(torch.Tensor)` comprising of the B and C tensors rotated using the Rotary Position Embedding.
    """
    cos_rows = cos[position_ids]
    sin_rows = sin[position_ids]

    def _rotate(states):
        return (states * cos_rows) + (rotate_half(states) * sin_rows)

    return _rotate(b), _rotate(c)
class RotaryEmbedding(nn.Module):
    """
    Rotary position embedding tables.

    Holds non-persistent `cos_cached` / `sin_cached` buffers with one row per position. The tables are
    built for `max_position_embeddings` positions at construction time and rebuilt whenever a longer
    sequence length is requested in `forward`.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # One inverse frequency per pair of channels.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for positions `0 .. seq_len - 1`."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """
        Return the first `seq_len` rows of the cos and sin tables, cast to `x.dtype`.

        Args:
            x: tensor laid out as [bs, num_attention_heads, seq_len, head_size]; only its dtype, device
                and (when `seq_len` is omitted) its second-to-last dimension are used.
            seq_len: number of positions to return. Defaults to `x.shape[-2]`.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            # The default used to reach `None > int` below and raise TypeError.
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
class RMSNorm(nn.Module):
    """
    Root-mean-square normalisation over the last dimension, with optional learned weight and bias.
    """

    def __init__(self, hidden_size, eps: float = 1e-6, elementwise_affine: bool = True, bias: bool = True):
        """
        RMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: an int or an iterable of ints giving the shape of the weight / bias parameters.
            eps: constant added to the mean square before taking the reciprocal square root.
            elementwise_affine: when True a learnable `weight` (initialised to ones) is created.
            bias: when True (and `elementwise_affine` is True) a learnable `bias` (initialised to zeros)
                is created.
        """
        super().__init__()
        if isinstance(hidden_size, int):
            hidden_size = (hidden_size,)
        self.hidden_size = tuple(hidden_size)
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.hidden_size))
            if bias:
                self.bias = nn.Parameter(torch.empty(self.hidden_size))
            else:
                self.register_parameter("bias", None)
        else:
            # Register the names anyway so `self.weight` / `self.bias` always exist.
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise `weight` to ones and `bias` to zeros when they exist."""
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            if self.bias is not None:
                nn.init.zeros_(self.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Normalise `hidden_states` over its last dimension and return it in the input dtype."""
        input_dtype = hidden_states.dtype
        # The mean square is always accumulated in float32.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        # weight and bias
        if self.elementwise_affine:
            hidden_states = (hidden_states * self.weight).to(input_dtype)
            if self.bias is not None:
                hidden_states = (hidden_states + self.bias).to(input_dtype)
        else:
            # Without the affine step nothing cast the float32-promoted product back, so
            # half-precision inputs used to come out as float32.
            hidden_states = hidden_states.to(input_dtype)

        return hidden_states
class HybridMambaAttentionDynamicCache(DynamicCache):
    """
    A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
    (which has a constant shape regardless of seq_len).

    This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
    and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor:
    For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
    while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
    For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
    while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
    and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
    """

    def __init__(self, config: CheemsOTCEConfig, batch_size, dtype=torch.float16, device=None):
        self.dtype = dtype
        self.layers_block_type = config.layers_block_type
        self.has_previous_state = False  # only used by mamba
        d_inner = config.mamba_expand * config.hidden_size
        d_state = config.mamba_d_state
        d_conv = config.mamba_d_conv
        layer_count = config.num_hidden_layers

        def _empty_slot():
            # (batch_size, 0) placeholder for a slot the layer does not use.
            return torch.tensor([[]] * batch_size, device=device)

        self.conv_states = []
        self.ssm_states = []
        for layer_idx in range(layer_count):
            if self.layers_block_type[layer_idx] != "mamba":
                self.conv_states.append(_empty_slot())
                self.ssm_states.append(_empty_slot())
                continue
            self.conv_states.append(torch.zeros(batch_size, d_inner, d_conv, device=device, dtype=dtype))
            self.ssm_states.append(torch.zeros(batch_size, d_inner, d_state, device=device, dtype=dtype))

        # Every layer starts with empty key/value placeholders.
        self.key_cache = [_empty_slot() for _ in range(layer_count)]
        self.value_cache = [_empty_slot() for _ in range(layer_count)]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Store `key_states` / `value_states` for `layer_idx` and return the full cached tensors.

        A placeholder (last dimension of size 0) is replaced outright; otherwise the new states are
        concatenated to the cached ones along dim 2. `cache_kwargs` is accepted but not read.
        """
        # Update the cache
        if self.key_cache[layer_idx].shape[-1] != 0:
            grown_keys = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
            grown_values = torch.cat([self.value_cache[layer_idx], value_states], dim=2)
            self.key_cache[layer_idx] = grown_keys
            self.value_cache[layer_idx] = grown_values
        else:
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """
        Reorders the cache for beam search, given the selected beam indices.
        """
        for layer_idx in range(len(self.key_cache)):
            # Same treatment for all four stores: gather the batch rows on the tensor's own device.
            for store in (self.key_cache, self.value_cache, self.conv_states, self.ssm_states):
                cached = store[layer_idx]
                store[layer_idx] = cached.index_select(0, beam_idx.to(cached.device))

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
        raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
class CheemsAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer and "Generating Long Sequences with Sparse Transformers".

    This is the eager (matmul + softmax) implementation. Queries are additionally multiplied by a
    position-dependent log-length factor before the dot product (see `forward`).
    """

    def __init__(self, config: CheemsOTCEConfig, layer_idx: Optional[int] = None):
        """
        Args:
            config: model configuration; the attributes read here are `hidden_size`,
                `num_attention_heads`, `num_key_value_heads`, `max_position_embeddings`, `rope_theta`,
                `attention_dropout`, `hidden_bias` and `attention_rope`.
            layer_idx: index of this layer, used to address the key/value cache.

        Raises:
            ValueError: if `hidden_size` is not divisible by `num_attention_heads`.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning(
                f"实例化 {self.__class__.__name__} 时没有传递 `layer_idx` 不推荐, 并且在使用缓存时会导致在前向调用期间出现错误. 请确保在创建此类时提供 `layer_idx`."
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is deprecated and will lead to errors during forward calls when using caches. Please make sure to provide `layer_idx` when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size 必须能被 num_heads 整除 (得到 `hidden_size`: {self.hidden_size}"
                f" 和 `num_heads`: {self.num_heads})."
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(
            self.hidden_size,
            self.num_heads * self.head_dim,
            bias=config.hidden_bias,
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.hidden_bias,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.hidden_bias,
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim,
            self.hidden_size,
            bias=config.hidden_bias,
        )

        # Rotary embeddings on Q/K are optional and controlled by the config.
        self.attention_rope = config.attention_rope
        if self.attention_rope:
            self.QK_rotary_emb = RotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Run eager attention over `hidden_states` of shape [bsz, q_len, hidden_size].

        Args:
            hidden_states: input activations.
            attention_mask: additive 4D mask; it is sliced to the key length on its last dimension and
                added to the attention scores.
            position_ids: positions used to index the rotary tables (only read when `attention_rope`).
            past_key_value: cache whose `update` returns the full key/value tensors for this layer.
            output_attentions: when False the returned attention weights are `None`.
            use_cache: accepted for interface compatibility; not read here.
            cache_position: accepted for interface compatibility; not read here.

        Returns:
            `(attn_output, attn_weights or None, past_key_value)`.

        Raises:
            ValueError: if the attention output does not have shape [bsz, num_heads, q_len, head_dim].
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # [bsz, heads, q_len, head_dim] for queries, [bsz, num_key_value_heads, q_len, head_dim] for keys/values
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # `Tensor.size` is a method: subscripting it (`size[-2]`) raised TypeError on every call.
        kv_seq_len = key_states.shape[-2]

        cache_kwargs = None
        if self.attention_rope:
            # Because the input can be padded or offset by a cache, the table must reach the largest
            # position id, not just the current chunk length (mirrors the flash-attention variant).
            rotary_seq_len = kv_seq_len
            if position_ids is not None:
                rotary_seq_len = max(kv_seq_len, int(position_ids.max().item()) + 1)
            cos, sin = self.QK_rotary_emb(value_states, seq_len=rotary_seq_len)
            query_states, key_states = apply_QK_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
            cache_kwargs = {"sin": sin, "cos": cos}

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Log-length query scaling. Tensors here are [bsz, heads, seq, head_dim], so the sequence axis
        # is -2; reading `size(1)` picked up the head count and scaled per head instead of per position.
        offset = 64
        query_length = query_states.size(-2)
        key_length = key_states.size(-2)
        logn = torch.arange(
            offset + 1, offset + key_length + 1, dtype=torch.float32, device=query_states.device
        )[-query_length:]  # [query_length]
        base = torch.tensor(4096).to(query_states.device)  # Training data average length
        logn = torch.log(logn) / torch.log(base)
        logn[logn < 1.0] = 1.0  # never scale queries down
        logn = logn.to(query_states.dtype).view(1, 1, query_length, 1)
        query_states = query_states * logn

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # Upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output`应该是大小为{(bsz, self.num_heads, q_len, self.head_dim)}, 但是是{attn_output.size()}"
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
class CheemsFlashAttention2(CheemsAttention):
"""
cheems flash attention 模块. 此模块继承自 `CheemsAttention`, 因为模块的权重保持不变. 唯一需要更改的是在前向传递中, 它需要正确调用flash attention的公共API, 并在输入包含任何填充标记的情况下处理它们.
cheems flash attention module. This module inherits from `CheemsAttention` as the weights of the module remain the same. The only thing that needs to be changed is to correctly call the public API of flash attention in the forward pass and handle them in case the input contains any padding tokens.
"""
def __init__(self, *args, **kwargs):
    """Build the parent attention module and record which causal-mask alignment flash-attn uses."""
    super().__init__(*args, **kwargs)

    # TODO: Remove this once RoCm's Flash Attention is upgraded to 2.1. flash_attn<2.1 generates a top-left aligned causal mask, while we need a bottom-right aligned one here, which is the default setting for flash_attn>=2.1. This attribute is used to handle this difference. Refer to: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
    # Note that for flash_attn<2.1, using q_seqlen != k_seqlen (except for q_seqlen == 1) will produce a wrong mask (top-left).
    has_bottom_right_mask = is_flash_attn_greater_or_equal_2_10()
    self._flash_attn_uses_top_left_mask = not has_bottom_right_mask
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
):
    """
    Run attention through the flash-attention kernels.

    Args:
        hidden_states: input activations of shape [bsz, q_len, hidden_size].
        attention_mask: padding mask forwarded to `_flash_attention_forward`; it is trimmed to the
            sliding window when the cache is being sliced.
        position_ids: positions used to index the rotary tables (read when `attention_rope` is set).
        past_key_value: cache whose `update` returns the full key/value tensors for this layer.
        output_attentions: accepted for interface compatibility; this path never produces attention
            weights, so the second return value is always `None`.
        use_cache: accepted for interface compatibility; not read here.
        cache_position: its first element tells whether the cache already holds tokens.

    Returns:
        `(attn_output, None, past_key_value)`.

    Raises:
        ValueError: if the sliced past keys do not have `sliding_window - 1` positions.
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    # `Tensor.size` is a method: subscripting it (`size[-2]`) raised TypeError on every call.
    kv_seq_len = key_states.shape[-2]

    # Because the input can be padded, the absolute sequence length depends on the max position id.
    if self.attention_rope:
        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
        cos, sin = self.QK_rotary_emb(value_states, seq_len=rotary_seq_len)
        query_states, key_states = apply_QK_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    use_sliding_windows = (
        _flash_supports_window_size
        and getattr(self.config, "sliding_window", None) is not None
        and kv_seq_len > self.config.sliding_window
    )

    if not _flash_supports_window_size:
        logger.warning_once(
            "当前的flash attention版本不支持滑动窗口注意力, 为了更高效的内存实现, 请确保升级flash-attn库."
            "The current version of flash attention does not support sliding window attention. For more memory-efficient implementation, make sure to upgrade the flash-attn library."
        )

    if past_key_value is not None:
        # Activate slicing cache only if the config has a `sliding_window` attribute.
        # A missing `cache_position` is treated as an empty cache instead of raising TypeError.
        cache_has_contents = cache_position is not None and cache_position[0] > 0
        if (
            getattr(self.config, "sliding_window", None) is not None
            and kv_seq_len > self.config.sliding_window
            and cache_has_contents
        ):
            slicing_tokens = 1 - self.config.sliding_window

            # NOTE(review): the sliced tensors are only shape-checked below; they are not written
            # back into the cache — confirm whether that is intended.
            past_key = past_key_value[self.layer_idx][0]
            past_value = past_key_value[self.layer_idx][1]

            past_key = past_key[:, :, slicing_tokens:, :].contiguous()
            past_value = past_value[:, :, slicing_tokens:, :].contiguous()

            if past_key.shape[-2] != self.config.sliding_window - 1:
                raise ValueError(
                    f"过去的键必须具有形状(`batch_size, num_heads, self.config.sliding_window-1, head_dim`), 得到{past_key.shape}"
                    f"Past keys must have shape (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got {past_key.shape}"
                )

            if attention_mask is not None:
                attention_mask = attention_mask[:, slicing_tokens:]
                attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

        if self.attention_rope:
            cache_kwargs = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        else:
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

    # Repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)
    dropout_rate = 0.0 if not self.training else self.attention_dropout

    # In PEFT, we usually convert layer norms to float32 for stability reasons, so input hidden states are silently converted to float32. Therefore, we need to convert them back to float16 to ensure everything works as expected.
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        logger.warning_once(
            f"输入隐藏状态似乎被静默地转换为float32, 这可能与您已经将嵌入或层规范层转换为float32有关. 我们将把输入转换回{target_dtype}."
            f"Input hidden states seem to have been silently converted to float32, which might be related to you already converting embeddings or layer norm layers to float32. We will convert the input back to {target_dtype}."
        )

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    # Reshape to fit the expected shapes for Flash Attention: [bsz, seq, heads, head_dim]
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    # Log-length query scaling; after the transpose above dim 1 is the sequence axis.
    offset = 64
    query_length = query_states.size(1)
    key_length = key_states.size(1)
    logn = torch.arange(
        offset + 1, offset + key_length + 1, dtype=torch.float32, device=query_states.device
    )[-query_length:]  # [query_length]
    base = torch.tensor(4096).to(query_states.device)
    logn = torch.log(logn) / torch.log(base)
    logn[logn < 1.0] = 1.0  # never scale queries down
    logn = logn.to(query_states.dtype).view(1, query_length, 1, 1)
    query_states = query_states * logn

    attn_output = self._flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        q_len,
        dropout=dropout_rate,
        use_sliding_windows=use_sliding_windows,
    )

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)

    # No attention probabilities are computed on this path. The name used to be assigned only when
    # `output_attentions` was False, so requesting them raised UnboundLocalError.
    attn_weights = None

    return attn_output, attn_weights, past_key_value
def _flash_attention_forward(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    query_length,
    dropout=0.0,
    softmax_scale=None,
    use_sliding_windows=False,
):
    """
    Run the Flash Attention kernel, unpadding the inputs first when a padding mask is supplied.

    When `attention_mask` is given, the inputs are unpadded through `self._upad_input`, the
    variable-length kernel is called, and the result is padded back with `pad_input`.
    Otherwise the dense kernel is called directly.

    Args:
        query_states (`torch.Tensor`):
            Input query states to pass to the Flash Attention API.
        key_states (`torch.Tensor`):
            Input key states to pass to the Flash Attention API.
        value_states (`torch.Tensor`):
            Input value states to pass to the Flash Attention API.
        attention_mask (`torch.Tensor`):
            Padding mask of size `(batch_size, seq_len)` where 0 marks padding tokens and
            1 marks non-padding tokens. `None` selects the dense (no unpadding) path.
        query_length (`int`):
            Query sequence length; forwarded to `_upad_input` and `pad_input`.
        dropout (`float`, *optional*):
            Attention dropout probability.
        softmax_scale (`float`, *optional*):
            Scale applied to QK^T before the softmax. Defaults to 1 / sqrt(head_dim).
        use_sliding_windows (`bool`, *optional*):
            Whether to activate sliding window attention; when set, the window size is read
            from `self.config.sliding_window` for both sides.

    Returns:
        The attention output produced by the kernel (re-padded on the masked path).
    """
    if self._flash_attn_uses_top_left_mask:
        # TODO: Remove the `query_length != 1` check once RoCm's Flash Attention is upgraded to 2.1.
        # For more details, refer to the comments in LlamaFlashAttention2 __init__.
        causal = self.is_causal and query_length != 1
    else:
        causal = self.is_causal

    # No padding mask: the dense kernel can consume the tensors as they are.
    if attention_mask is None:
        window_kwargs = (
            {"window_size": (self.config.sliding_window, self.config.sliding_window)}
            if use_sliding_windows
            else {}
        )
        return flash_attn_func(
            query_states,
            key_states,
            value_states,
            dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **window_kwargs,
        )

    # A padding mask was supplied: unpad, run the varlen kernel, then pad back.
    batch_size = query_states.shape[0]
    (
        unpadded_q,
        unpadded_k,
        unpadded_v,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    ) = self._upad_input(query_states, key_states, value_states, attention_mask, query_length)

    # The window argument is only passed when sliding-window attention is requested.
    window_kwargs = (
        {"window_size": (self.config.sliding_window, self.config.sliding_window)}
        if use_sliding_windows
        else {}
    )
    attn_output_unpad = flash_attn_varlen_func(
        unpadded_q,
        unpadded_k,
        unpadded_v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_in_batch_q,
        max_seqlen_k=max_seqlen_in_batch_k,
        dropout_p=dropout,
        softmax_scale=softmax_scale,
        causal=causal,
        **window_kwargs,
    )
    return pad_input(attn_output_unpad, indices_q, batch_size, query_length)
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
    """
    Remove padding positions from the query/key/value tensors for the varlen Flash Attention kernel.

    Args:
        query_layer (`torch.Tensor`):
            Query states; the sequence is on dim 1 and the head dimension is last.
        key_layer (`torch.Tensor`):
            Key states of shape `(batch_size, kv_seq_len, num_kv_heads, head_dim)`.
        value_layer (`torch.Tensor`):
            Value states, flattened the same way as `key_layer`.
        attention_mask (`torch.Tensor`):
            Padding mask whose last dimension runs over sequence positions
            (0 for padding, 1 for real tokens, per the caller's documentation).
        query_length (`int`):
            Number of query positions.

    Returns:
        A tuple `(query_layer, key_layer, value_layer, indices_q,
        (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k))`
        where the three tensors have been unpadded.
    """
    batch_size, kv_seq_len, num_kv_heads, head_dim = key_layer.shape

    # In the first iteration, we need to correctly recreate the padding mask by slicing it
    # at the right positions.
    if kv_seq_len != attention_mask.shape[-1]:
        attention_mask_num_tokens = attention_mask.shape[-1]
        attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]

    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)

    key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_kv_heads, head_dim), indices_k)
    value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_kv_heads, head_dim), indices_k)

    if query_length == kv_seq_len:
        # Queries cover the same positions as the keys, so the key indices can be reused.
        # The head count is inferred with -1 rather than taken from the key tensor: with
        # grouped-query attention the query may have more heads than key/value, and reusing
        # the key's head count would make this reshape fail. The result is unchanged when
        # the two head counts are equal.
        query_layer = index_first_axis(
            query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k
        )
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
    elif query_length == 1:
        # One query token per batch entry: each sequence has length exactly 1.
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(
            batch_size + 1, dtype=torch.int32, device=query_layer.device
        )  # This is very bad as there is a memcpy here.
        indices_q = cu_seqlens_q[:-1]
        query_layer = query_layer.squeeze(1)
    else:
        # The -query_length: slice assumes left padding.
        attention_mask = attention_mask[:, -query_length:]
        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

    return (
        query_layer,
        key_layer,
        value_layer,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    )
class CheemsSdpaAttention(CheemsAttention):
"""
cheems attention 模块使用torch.nn.functional.scaled_dot_product_attention. 该模块继承自`CheemsAttention`, 因为模块的权重保持不变. 唯一的更改是在前向传递中, 以适应SDPA API.
cheems attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from `CheemsAttention` as the weights of the module remain the same. The only thing that needs to be changed is to adapt to the SDPA API in the forward pass.
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: 一旦实现了这一点, 通过例如`model.config.attn_implementation = "manual"`来改进这个警告.
# TODO: Improve this warning by implementing it once, e.g. by setting `model.config.attn_implementation = "manual"`.
logger.warning_once(
"CheemsModel正在使用CheemsSdpaAttention, 但`torch.nn.functional.scaled_dot_product_attention`不支持`output_attentions=True`. 回退到手动注意力实现, 但是从Transformers版本v5.0.0开始, 将需要指定手动实现. 可以在加载模型时使用参数`attn_implementation='eager'`来删除此警告."
"CheemsModel is using CheemsSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to manual attention implementation, but specifying manual implementation will be required starting from Transformers version v5.0.0. You can remove this warning by specifying manual implementation when loading the model using the parameter `attn_implementation='eager'`."
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if self.attention_rope:
cos, sin = self.QK_rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_QK_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
else:
if past_key_value is not None:
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]