liusx14
commited on
Commit
·
a50a5cd
1
Parent(s):
ee119e1
upload
Browse files- README.md +135 -2
- config.json +48 -0
- configuration_telechat.py +210 -0
- generation_config.json +14 -0
- modeling_telechat.py +1105 -0
- pytorch_model.bin.index.json +458 -0
- special_tokens_map.json +30 -0
- tokenization_telechat.py +403 -0
- tokenizer.model +3 -0
- tokenizer_config.json +54 -0
README.md
CHANGED
@@ -1,3 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
<h1>
|
3 |
+
星辰大模型(52B)
|
4 |
+
</h1>
|
5 |
+
</div>
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
# 目录
|
11 |
+
- [模型介绍](#模型介绍)
|
12 |
+
- [数据开源](#数据开源)
|
13 |
+
- [效果评测](#效果评测)
|
14 |
+
- [模型推理和部署](#模型推理和部署)
|
15 |
+
- [模型微调](#模型微调)
|
16 |
+
- [声明、协议、引用](#声明协议引用)
|
17 |
+
|
18 |
+
# 最新动态
|
19 |
+
- 5.16 开源52B版本chat模型
|
20 |
+
|
21 |
+
# 模型介绍
|
22 |
+
### 星辰大模型(52B)
|
23 |
+
- 星辰大模型(52B)是一款开源多语言大模型,其模型基座使用更优数据配比,采用课程学习方式,训练了2万亿tokens的中英文高质量数据。
|
24 |
+
- 我们开源了使用星辰语义大模型52B基座微调的对话模型,以及基于Deepspeed的微调代码和huggingface推理代码。
|
25 |
+
- 星辰大模型(52B)在模型评测中取得了领先的效果,在榜单评测上超过LLaMA-2-70B-Chat,与Qwen-72B-chat可比;通用对话性能已经超过GPT-3.5-Turbo。
|
26 |
+
|
27 |
+
### 模型结构
|
28 |
+
|
29 |
+
我们采用标准的 `Decoder-only` 结构设计了 **TeleChat** 模型,并在模型维度做了如下的一些改进:
|
30 |
+
|
31 |
+
- **位置编码**:我们使用 [Rotary Embedding](https://arxiv.org/pdf/2104.09864.pdf) 的位置编码方法,该方法将相对位置信息依赖集成到 self-attention 中,并且具有较好的位置外推性。Rotary Embedding还可以较好地与Flash-Attention v2 配合使用,将模型的训练速度提升约20%。
|
32 |
+
- **激活函数**:我们使用 [SwiGLU](https://arxiv.org/pdf/2002.05202.pdf) 激活函数来替代GELU激活函数。
|
33 |
+
- **层标准化**: 基于 [RMSNorm](https://arxiv.org/abs/1910.07467) 的 Pre-Normalization。
|
34 |
+
- **词嵌入层与输出层解耦**:我们将星辰52B的词嵌入层和输出lm head层参数分开,有助于增强训练稳定性和收敛性。
|
35 |
+
|
36 |
+
|
37 |
+
| | layer_num | hidden_size | ffn_hidden_size | head_num | tie_word_embeddings |
|
38 |
+
| ------- | --------- | ----------- | --------------- | -------- | ------------------- |
|
39 |
+
| 星辰52B | 64 | 8192 | 21824 | 64 | 否 |
|
40 |
+
|
41 |
---
|
42 |
+
|
43 |
+
我们开源的星辰52B模型:
|
44 |
+
- 支持deepspeed微调,开源了基于deepspeed的训练代码,支持Zero并行显存优化,同时集成了FlashAttention2
|
45 |
+
- 多轮能力支持。开源了多轮数据构建方式,针对多轮模型训练集成了针对多轮的mask loss训练方式,更好的聚焦多轮答案,提升问答效果。
|
46 |
+
|
47 |
+
|
48 |
+
# 效果评测
|
49 |
+
星辰52B模型相比同规模模型在评测效果方面也有较好的表现,我们的评测集涵盖了包括MMLU、AGIEval、CMMLU、 GSM8K、MATH、HumanEval 等数据集,评测能力包括了自然语言理解、知识、数学计算和推理、代码生成等
|
50 |
+
|
51 |
+
## 评测集介绍
|
52 |
+
|
53 |
+
### 通用能力
|
54 |
+
|
55 |
+
- MMLU 数据集是一个全面的英文评测数据集,涵盖了 57 个学科,包括人文学科、社会科学、自然科学、初等数学、美国历史、计算机科学、法律等等。
|
56 |
+
- CMMLU 数据集同样是一个全面的中文评估测试集,涵盖了从基础学科到高级专业水平的67个主题。
|
57 |
+
- AGIEval 数据集是一个专门为评估基础模型在难度较高的标准化考试(如大学入学考试、法学院入学考试、数学竞赛和律师资格考试)的语境中而设计的基准测试,包括中文试题和英文试题。
|
58 |
+
|
59 |
+
### 推理和代码能力
|
60 |
+
|
61 |
+
- GSM8K 数据集包含了8.5K高质量的小学数学题,能够评估语言模型在数学推理能力上的表现,我们利用[官方](https://github.com/openai/grade-school-math)的评测方案在test集上进行了4-shot测试。
|
62 |
+
|
63 |
+
- MATH 数据集包含了12.5K具有挑战性的高中数学竞赛题,难度较大,对语言模型的推理能力要求较高,基于[官方](https://github.com/hendrycks/math)的评测方案,我们在test集上进行了4-shot测试。
|
64 |
+
|
65 |
+
- HumanEval 数据集是一个由openai提供的代码能力测试数据集,它由 164 个编程问题组成,要求根据给定的问题和代码模板,生成正确的代码片段,我们利用[官方](https://github.com/openai/human-eval)评测方案在test集上进行了zero-shot测试。
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
## 评测结果如下
|
70 |
+
|
71 |
+
| Model | MMLU | CMMLU | AGIEval | GSM8K | MATH | HumanEval | BBH | HellaSwag |
|
72 |
+
| :--------------- | :------: | :-------: | :-------: | :------: | :------: | :-------: | :------: | :-------: |
|
73 |
+
| | 5-shot | 5-shot | zero-shot | 4-shot | 4-shot | zero-shot | 3-shot | zero-shot |
|
74 |
+
| LLaMA-2-70B-Chat | 63.8 | 43.3 | 37.9 | 59.3 | 10.4 | 32.3 | 60.8 | 80.6 |
|
75 |
+
| Qwen-72B-chat | 74 | 81.4 | 58.5 | 67.4 | 31.8 | 49.4 | 68 | 84.7 |
|
76 |
+
| 星辰7B-chat | 60.5 | 64.3 | 46.8 | 36.7 | 10.3 | 20.1 | 19.5 | 36.7 |
|
77 |
+
| 星辰12B-chat | 73.3 | 74.2 | 51.7 | 57.2 | 16.0 | 22.0 | 52.2 | 71.5 |
|
78 |
+
| **星辰52B-chat** | **76.6** | **73.79** | **61.1** | **63.5** | **13.5** | **36.6** | **60.3** | **86.3** |
|
79 |
+
|
80 |
+
说明:榜单均基于[OpenCompass](https://github.com/open-compass/OpenCompass/)平台提供的评测方法进行评估,而对于对比模型,我们同时参考了官方汇报结果和OpenCompass结果。
|
81 |
+
|
82 |
+
### 对话能力评测
|
83 |
+
|
84 |
+
为了评价模型的对话能力,研发团队建立了包含2500+单轮、多轮对话交互的内部评测系统,涵盖闲聊问答、专业知识、翻译、逻辑思维、长文写作、幻觉测试、安全测试、角色扮演、任务执行、数学能力等多个维度,并使用Judge模型基于详细的评价指标文档进行自动打分。在当前评测数据上,星辰52B模型的综合平均得分为83.8,高于GPT-3.5-Turbo的82.3。这一结果表明,星辰52B模型能较好地支持下游任务应用。
|
85 |
+
|
86 |
+
|
87 |
+
|
88 |
+
# 模型推理和部署
|
89 |
+
### 模型推理
|
90 |
+
当前模型支持fp16精度推理,适配4卡40G A100进行推理。具体推理操作请参考`infer.py`文件,该文件中有单轮和多轮的推理示例。推理结果示例见`infer_result.txt`。
|
91 |
+
|
92 |
+
**模型推理方法示范**
|
93 |
+
```python
|
94 |
+
import os
|
95 |
+
import torch
|
96 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
97 |
+
from transformers import GenerationConfig
|
98 |
+
PATH = "/path/to/TeleChat-52B-chat"
|
99 |
+
tokenizer = AutoTokenizer.from_pretrained(PATH, use_fast=False, trust_remote_code=True)
|
100 |
+
model = AutoModelForCausalLM.from_pretrained(PATH,torch_dtype=torch.bfloat16,device_map='auto',trust_remote_code=True)
|
101 |
+
question = "你作为一名气候保护协会的会员,你准备写一篇全球气候变化的新闻报告,要求体现出全球气候变化以前与现在情况的对比,字数要求1000字。"
|
102 |
+
generate_config = GenerationConfig.from_pretrained(PATH)
|
103 |
+
answer = model.chat(tokenizer,question, history_input_list = [], history_output_list = [],generation_config = generate_config)
|
104 |
+
print("machine:",answer)
|
105 |
+
```
|
106 |
+
|
107 |
+
# 声明、协议、引用
|
108 |
+
### 声明
|
109 |
+
我们在此声明,不要使用TeleChat模型及其衍生模型进行任何危害国家社会安全或违法的活动。同时,我们也要求使用者不要将TeleChat模型用于没有安全审查和备案的互联网服务。我们希望所有使用者遵守上述原则,确保科技发展在合法合规的环境下进行。
|
110 |
+
|
111 |
+
我们已经尽我们所能,来确保模型训练过程中使用的数据的合规性。然而,尽管我们已经做出了巨大的努力,但由于模型和数据的复杂性,仍有可能存在一些无法预见的问题。因此,如果由于使用TeleChat开源模型而导致的任何问题,包括但不限于数据安全问题、公共舆论风险,或模型被误导、滥用、传播或不当利用所带来的任何风险和问题,我们将不承担任何责任。
|
112 |
+
|
113 |
+
### 协议
|
114 |
+
社区使用 TeleChat 模型需要遵循《[TeleChat模型社区许可协议](./TeleChat模型社区许可协议.pdf)》。TeleChat模型支持商业用途,如果您计划将 TeleChat 模型或其衍生品用于商业目的,您需要通过以下联系邮箱 [email protected],提交《TeleChat模型社区许可协议》要求的申请材料。审核通过后,将特此授予您一个非排他性、全球性、不可转让、不可再许可、可撤销的商用版权许可。
|
115 |
+
|
116 |
+
### 引用
|
117 |
+
如需引用我们的工作,请使用如下 reference:
|
118 |
+
```
|
119 |
+
@misc{wang2024telechat,
|
120 |
+
title={TeleChat Technical Report},
|
121 |
+
author={Zihan Wang and Xinzhang Liu and Shixuan Liu and Yitong Yao and Yuyao Huang and Zhongjiang He and Xuelong Li and Yongxiang Li and Zhonghao Che and Zhaoxi Zhang and Yan Wang and Xin Wang and Luwen Pu and Huihan Xu and Ruiyu Fang and Yu Zhao and Jie Zhang and Xiaomeng Huang and Zhilong Lu and Jiaxin Peng and Wenjun Zheng and Shiquan Wang and Bingkai Yang and Xuewei he and Zhuoru Jiang and Qiyi Xie and Yanhan Zhang and Zhongqiu Li and Lingling Shi and Weiwei Fu and Yin Zhang and Zilu Huang and Sishi Xiong and Yuxiang Zhang and Chao Wang and Shuangyong Song},
|
122 |
+
year={2024},
|
123 |
+
eprint={2401.03804},
|
124 |
+
archivePrefix={arXiv},
|
125 |
+
primaryClass={cs.CL}
|
126 |
+
}
|
127 |
+
|
128 |
+
@misc{li2024teleflm,
|
129 |
+
title={Tele-FLM Technical Report},
|
130 |
+
author={Xiang Li and Yiqun Yao and Xin Jiang and Xuezhi Fang and Chao Wang and Xinzhang Liu and Zihan Wang and Yu Zhao and Xin Wang and Yuyao Huang and Shuangyong Song and Yongxiang Li and Zheng Zhang and Bo Zhao and Aixin Sun and Yequan Wang and Zhongjiang He and Zhongyuan Wang and Xuelong Li and Tiejun Huang},
|
131 |
+
year={2024},
|
132 |
+
eprint={2404.16645},
|
133 |
+
archivePrefix={arXiv},
|
134 |
+
primaryClass={cs.CL}
|
135 |
+
}
|
136 |
+
```
|
config.json
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"activation_function": "silu",
|
3 |
+
"add_bias_linear": false,
|
4 |
+
"attn_pdrop": 0.0,
|
5 |
+
"auto_map": {
|
6 |
+
"AutoConfig": "configuration_telechat.TELECHATConfig",
|
7 |
+
"AutoModel": "modeling_telechat.TELECHAT",
|
8 |
+
"AutoModelForCausalLM": "modeling_telechat.TELECHAT"
|
9 |
+
},
|
10 |
+
"bos_token_id": 1,
|
11 |
+
"embd_pdrop": 0.0,
|
12 |
+
"enable_flash_attn": true,
|
13 |
+
"eos_token_id": 2,
|
14 |
+
"initializer_range": 0.02,
|
15 |
+
"input_mult": 1.0,
|
16 |
+
"layer_norm_epsilon": 1e-05,
|
17 |
+
"model_type": "telechat",
|
18 |
+
"mup_base_width": 256,
|
19 |
+
"mup_scale_factor": 32.0,
|
20 |
+
"n_embd": 8192,
|
21 |
+
"n_head": 64,
|
22 |
+
"n_inner": 21824,
|
23 |
+
"n_layer": 64,
|
24 |
+
"n_positions": 4096,
|
25 |
+
"output_mult": 1.0,
|
26 |
+
"pad_token_id": 3,
|
27 |
+
"relative_encoding": "rotary",
|
28 |
+
"reorder_and_upcast_attn": true,
|
29 |
+
"resid_pdrop": 0.0,
|
30 |
+
"rotary_theta": 10000,
|
31 |
+
"rotary_use_xpos": false,
|
32 |
+
"rotary_xpos_scale_base": 512,
|
33 |
+
"scale_attn_by_inverse_layer_idx": true,
|
34 |
+
"scale_attn_weights": true,
|
35 |
+
"summary_activation": null,
|
36 |
+
"summary_first_dropout": 0.1,
|
37 |
+
"summary_proj_to_labels": true,
|
38 |
+
"summary_type": "cls_index",
|
39 |
+
"summary_use_proj": true,
|
40 |
+
"tie_word_embeddings": false,
|
41 |
+
"tokenizer_class": "TELECHATTokenizer",
|
42 |
+
"transformers_version": "4.34.1",
|
43 |
+
"unk_token_id": 0,
|
44 |
+
"use_RMSNorm": true,
|
45 |
+
"use_cache": true,
|
46 |
+
"use_mup": true,
|
47 |
+
"vocab_size": 80896
|
48 |
+
}
|
configuration_telechat.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
""" TELECHAT configuration"""
|
17 |
+
|
18 |
+
from transformers.configuration_utils import PretrainedConfig
|
19 |
+
from transformers.utils import logging
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.get_logger(__name__)
|
24 |
+
|
25 |
+
TELECHAT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
26 |
+
}
|
27 |
+
|
28 |
+
|
29 |
+
class TELECHATConfig(PretrainedConfig):
|
30 |
+
"""
|
31 |
+
xxxxxx
|
32 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
33 |
+
documentation from [`PretrainedConfig`] for more information.
|
34 |
+
Args:
|
35 |
+
vocab_size (`int`, *optional*, defaults to 50257):
|
36 |
+
Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
|
37 |
+
`inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`].
|
38 |
+
n_positions (`int`, *optional*, defaults to 1024):
|
39 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
40 |
+
just in case (e.g., 512 or 1024 or 2048).
|
41 |
+
n_embd (`int`, *optional*, defaults to 768):
|
42 |
+
Dimensionality of the embeddings and hidden states.
|
43 |
+
n_layer (`int`, *optional*, defaults to 12):
|
44 |
+
Number of hidden layers in the Transformer encoder.
|
45 |
+
n_head (`int`, *optional*, defaults to 12):
|
46 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
47 |
+
n_inner (`int`, *optional*, defaults to None):
|
48 |
+
Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
|
49 |
+
activation_function (`str`, *optional*, defaults to `"gelu"`):
|
50 |
+
Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
|
51 |
+
resid_pdrop (`float`, *optional*, defaults to 0.1):
|
52 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
53 |
+
embd_pdrop (`int`, *optional*, defaults to 0.1):
|
54 |
+
The dropout ratio for the embeddings.
|
55 |
+
attn_pdrop (`float`, *optional*, defaults to 0.1):
|
56 |
+
The dropout ratio for the attention.
|
57 |
+
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
|
58 |
+
The epsilon to use in the layer normalization layers.
|
59 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
60 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
61 |
+
summary_type (`string`, *optional*, defaults to `"cls_index"`):
|
62 |
+
Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
|
63 |
+
[`TFGPT2DoubleHeadsModel`].
|
64 |
+
Has to be one of the following options:
|
65 |
+
- `"last"`: Take the last token hidden state (like XLNet).
|
66 |
+
- `"first"`: Take the first token hidden state (like BERT).
|
67 |
+
- `"mean"`: Take the mean of all tokens hidden states.
|
68 |
+
- `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
|
69 |
+
- `"attn"`: Not implemented now, use multi-head attention.
|
70 |
+
summary_use_proj (`bool`, *optional*, defaults to `True`):
|
71 |
+
Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
|
72 |
+
[`TFGPT2DoubleHeadsModel`].
|
73 |
+
Whether or not to add a projection after the vector extraction.
|
74 |
+
summary_activation (`str`, *optional*):
|
75 |
+
Argument used when doing sequence summary. Used in for the multiple choice head in
|
76 |
+
[`GPT2DoubleHeadsModel`].
|
77 |
+
Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation.
|
78 |
+
summary_proj_to_labels (`bool`, *optional*, defaults to `True`):
|
79 |
+
Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
|
80 |
+
[`TFGPT2DoubleHeadsModel`].
|
81 |
+
Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes.
|
82 |
+
summary_first_dropout (`float`, *optional*, defaults to 0.1):
|
83 |
+
Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
|
84 |
+
[`TFGPT2DoubleHeadsModel`].
|
85 |
+
The dropout ratio to be used after the projection and activation.
|
86 |
+
scale_attn_weights (`bool`, *optional*, defaults to `True`):
|
87 |
+
Scale attention weights by dividing by sqrt(hidden_size)..
|
88 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
89 |
+
Whether or not the model should return the last key/values attentions (not used by all models).
|
90 |
+
scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
|
91 |
+
Whether to additionally scale attention weights by `1 / layer_idx + 1`.
|
92 |
+
reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
|
93 |
+
Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
|
94 |
+
dot-product/softmax to float() when training with mixed precision.
|
95 |
+
Example:
|
96 |
+
```python
|
97 |
+
>>> from transformers import GPT2Config, GPT2Model
|
98 |
+
>>> # Initializing a GPT2 configuration
|
99 |
+
>>> configuration = GPT2Config()
|
100 |
+
>>> # Initializing a model (with random weights) from the configuration
|
101 |
+
>>> model = GPT2Model(configuration)
|
102 |
+
>>> # Accessing the model configuration
|
103 |
+
>>> configuration = model.config
|
104 |
+
```"""
|
105 |
+
|
106 |
+
model_type = "telechat"
|
107 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
108 |
+
attribute_map = {
|
109 |
+
"hidden_size": "n_embd",
|
110 |
+
"max_position_embeddings": "n_positions",
|
111 |
+
"num_attention_heads": "n_head",
|
112 |
+
"num_hidden_layers": "n_layer",
|
113 |
+
}
|
114 |
+
|
115 |
+
def __init__(
|
116 |
+
self,
|
117 |
+
vocab_size=80000,
|
118 |
+
n_positions=1024,
|
119 |
+
n_embd=768,
|
120 |
+
n_layer=12,
|
121 |
+
n_head=12,
|
122 |
+
n_inner=None,
|
123 |
+
activation_function="gelu_new",
|
124 |
+
resid_pdrop=0.1,
|
125 |
+
embd_pdrop=0.1,
|
126 |
+
attn_pdrop=0.1,
|
127 |
+
layer_norm_epsilon=1e-5,
|
128 |
+
initializer_range=0.02,
|
129 |
+
summary_type="cls_index",
|
130 |
+
summary_use_proj=True,
|
131 |
+
summary_activation=None,
|
132 |
+
summary_proj_to_labels=True,
|
133 |
+
summary_first_dropout=0.1,
|
134 |
+
scale_attn_weights=True,
|
135 |
+
use_cache=True,
|
136 |
+
bos_token_id=None,
|
137 |
+
eos_token_id=None,
|
138 |
+
sep_token_id=None,
|
139 |
+
pad_token_id=None,
|
140 |
+
unk_token_id=None,
|
141 |
+
scale_attn_by_inverse_layer_idx=False,
|
142 |
+
reorder_and_upcast_attn=False,
|
143 |
+
relative_encoding=None,
|
144 |
+
rotary_theta=10000,
|
145 |
+
rotary_use_xpos=True,
|
146 |
+
rotary_xpos_scale_base=512,
|
147 |
+
use_mup=False,
|
148 |
+
mup_scale_factor=1.0,
|
149 |
+
output_mult=1.0,
|
150 |
+
input_mult=1.0,
|
151 |
+
mup_base_width=256,
|
152 |
+
enable_flash_attn=True,
|
153 |
+
use_RMSNorm=False,
|
154 |
+
add_bias_linear=True,
|
155 |
+
**kwargs,
|
156 |
+
):
|
157 |
+
self.vocab_size = vocab_size
|
158 |
+
self.n_positions = n_positions
|
159 |
+
self.n_embd = n_embd
|
160 |
+
self.n_layer = n_layer
|
161 |
+
self.n_head = n_head
|
162 |
+
self.n_inner = n_inner
|
163 |
+
self.activation_function = activation_function
|
164 |
+
self.resid_pdrop = resid_pdrop
|
165 |
+
self.embd_pdrop = embd_pdrop
|
166 |
+
self.attn_pdrop = attn_pdrop
|
167 |
+
self.layer_norm_epsilon = layer_norm_epsilon
|
168 |
+
self.initializer_range = initializer_range
|
169 |
+
self.summary_type = summary_type
|
170 |
+
self.summary_use_proj = summary_use_proj
|
171 |
+
self.summary_activation = summary_activation
|
172 |
+
self.summary_first_dropout = summary_first_dropout
|
173 |
+
self.summary_proj_to_labels = summary_proj_to_labels
|
174 |
+
self.scale_attn_weights = scale_attn_weights
|
175 |
+
self.use_cache = use_cache
|
176 |
+
self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
|
177 |
+
self.reorder_and_upcast_attn = reorder_and_upcast_attn
|
178 |
+
self.relative_encoding = relative_encoding
|
179 |
+
self.use_RMSNorm = use_RMSNorm
|
180 |
+
self.add_bias_linear = add_bias_linear
|
181 |
+
|
182 |
+
# for rotary
|
183 |
+
self.rotary_theta = rotary_theta
|
184 |
+
self.rotary_use_xpos = rotary_use_xpos
|
185 |
+
self.rotary_xpos_scale_base = rotary_xpos_scale_base
|
186 |
+
|
187 |
+
# for mup
|
188 |
+
self.use_mup = use_mup
|
189 |
+
self.mup_scale_factor = mup_scale_factor
|
190 |
+
self.output_mult = output_mult
|
191 |
+
self.input_mult = input_mult
|
192 |
+
self.mup_base_width = mup_base_width
|
193 |
+
|
194 |
+
self.bos_token_id = bos_token_id
|
195 |
+
self.eos_token_id = eos_token_id
|
196 |
+
self.unk_token_id = unk_token_id
|
197 |
+
self.sep_token_id = sep_token_id
|
198 |
+
self.pad_token_id = pad_token_id
|
199 |
+
|
200 |
+
self.enable_flash_attn = enable_flash_attn
|
201 |
+
|
202 |
+
self.architectures = ["TELECHAT"]
|
203 |
+
self.auto_map = {
|
204 |
+
"AutoConfig": "configuration_telechat.TELECHATConfig",
|
205 |
+
"AutoModel": "modeling_telechat.TELECHAT",
|
206 |
+
"AutoModelForCausalLM": "modeling_telechat.TELECHAT"
|
207 |
+
}
|
208 |
+
|
209 |
+
|
210 |
+
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, sep_token_id = sep_token_id, pad_token_id = pad_token_id, **kwargs)
|
generation_config.json
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"max_length": 2048,
|
3 |
+
"do_sample": false,
|
4 |
+
"use_cache": true,
|
5 |
+
"temperature": 0.3,
|
6 |
+
"top_k": 5,
|
7 |
+
"top_p": 0.85,
|
8 |
+
"repetition_penalty": 1.02,
|
9 |
+
"pad_token_id": 3,
|
10 |
+
"bos_token_id": 1,
|
11 |
+
"eos_token_id": 2,
|
12 |
+
"user_token_id": 20,
|
13 |
+
"bot_token_id": 21
|
14 |
+
}
|
modeling_telechat.py
ADDED
@@ -0,0 +1,1105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# This code is based on OpenAI's GPT-2 library. It has been modified from its
|
6 |
+
# original forms to accommodate architectural differences compared to GPT-2.
|
7 |
+
#
|
8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
9 |
+
# you may not use this file except in compliance with the License.
|
10 |
+
# You may obtain a copy of the License at
|
11 |
+
#
|
12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
13 |
+
#
|
14 |
+
# Unless required by applicable law or agreed to in writing, software
|
15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
17 |
+
# See the License for the specific language governing permissions and
|
18 |
+
# limitations under the License.
|
19 |
+
"""PyTorch TELECHAT model."""
|
20 |
+
|
21 |
+
from typing import Optional, Tuple, Union
|
22 |
+
|
23 |
+
import math
|
24 |
+
import torch
|
25 |
+
from einops import rearrange
|
26 |
+
from torch import einsum, nn
|
27 |
+
from torch.cuda.amp import autocast
|
28 |
+
import torch.nn.functional as F
|
29 |
+
from transformers.activations import ACT2FN
|
30 |
+
from transformers.modeling_outputs import (
|
31 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
32 |
+
CausalLMOutputWithCrossAttentions,
|
33 |
+
SequenceClassifierOutputWithPast,
|
34 |
+
)
|
35 |
+
from transformers.modeling_utils import PreTrainedModel
|
36 |
+
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_conv1d_layer
|
37 |
+
from transformers.utils import logging
|
38 |
+
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
|
39 |
+
try:
|
40 |
+
from flash_attn.flash_attn_interface import flash_attn_unpadded_func # flashattn1
|
41 |
+
print("# FLASH ATTENTION 1 DETECTED #")
|
42 |
+
except ImportError:
|
43 |
+
try:
|
44 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func # flashattn2
|
45 |
+
print("# FLASH ATTENTION 2 DETECTED #")
|
46 |
+
except ImportError:
|
47 |
+
print("# NO FLASH ATTENTION DETECTED #")
|
48 |
+
flash_attn_unpadded_func = None
|
49 |
+
from .configuration_telechat import TELECHATConfig
|
50 |
+
|
51 |
+
|
52 |
+
def debug_print_tensor(t, name, title='', show_dim=10):
|
53 |
+
# return
|
54 |
+
prefix = f'{title} -> '
|
55 |
+
if isinstance(t, torch.Tensor):
|
56 |
+
if len(t.shape) == 1:
|
57 |
+
output = f"{name}[{t.shape}]: {t[:show_dim]}"
|
58 |
+
elif len(t.shape) == 2:
|
59 |
+
output = f"{name}[{t.shape}]: {t[-1, :show_dim]}"
|
60 |
+
elif len(t.shape) == 3:
|
61 |
+
output = f" {name}[{t.shape}]: {t[-1, -1, :show_dim]}"
|
62 |
+
elif len(t.shape) == 4:
|
63 |
+
output = f"{name}[{t.shape}]: {t[-1, -1, -1, :show_dim]}"
|
64 |
+
else:
|
65 |
+
output = f"{name}[{t.shape}]"
|
66 |
+
elif isinstance(t, list):
|
67 |
+
output = f"{name} [{len(t)}]: {t[:show_dim]}"
|
68 |
+
else:
|
69 |
+
output = f"{name} 未知类型: {type(t)}"
|
70 |
+
print(prefix + output)
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
class Conv1D(nn.Module):
|
75 |
+
|
76 |
+
def __init__(self, nf, nx, bias=True):
|
77 |
+
super().__init__()
|
78 |
+
self.nf = nf
|
79 |
+
self.weight = nn.Parameter(torch.empty(nx, nf))
|
80 |
+
self.bias = None
|
81 |
+
if bias:
|
82 |
+
self.bias = nn.Parameter(torch.zeros(nf))
|
83 |
+
nn.init.normal_(self.weight, std=0.02)
|
84 |
+
|
85 |
+
def forward(self, x):
|
86 |
+
if self.bias is not None:
|
87 |
+
return torch.matmul(x, self.weight) + self.bias
|
88 |
+
else:
|
89 |
+
return torch.matmul(x, self.weight)
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
class RMSNorm(nn.Module):
|
94 |
+
def __init__(self, hidden_size, eps=1e-5):
|
95 |
+
super().__init__()
|
96 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
97 |
+
self.eps = eps
|
98 |
+
|
99 |
+
def forward(self, hidden_states):
|
100 |
+
input_dtype = hidden_states.dtype
|
101 |
+
hidden_states = hidden_states.to(torch.float32)
|
102 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
103 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
104 |
+
return self.weight * hidden_states.to(input_dtype)
|
105 |
+
|
106 |
+
|
107 |
+
logger = logging.get_logger(__name__)
|
108 |
+
|
109 |
+
|
110 |
+
def exists(v):
|
111 |
+
return v is not None
|
112 |
+
|
113 |
+
|
114 |
+
class RotaryEmbedding(nn.Module):
|
115 |
+
def __init__(self, dim, use_xpos=False, xpos_scale_base=512, theta=10000):
|
116 |
+
super().__init__()
|
117 |
+
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
|
118 |
+
self.register_buffer('inv_freq', inv_freq)
|
119 |
+
self.cache = dict()
|
120 |
+
self.cache_scale = dict()
|
121 |
+
self.use_xpos = use_xpos
|
122 |
+
if not use_xpos:
|
123 |
+
self.register_buffer('scale', None)
|
124 |
+
return
|
125 |
+
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
126 |
+
self.register_buffer('scale', scale)
|
127 |
+
self.scale_base = xpos_scale_base
|
128 |
+
|
129 |
+
def forward(self, seq, cache_key=None):
|
130 |
+
|
131 |
+
if cache_key is not None and cache_key in self.cache:
|
132 |
+
return self.cache[cache_key]
|
133 |
+
|
134 |
+
inv_freq = self.inv_freq.to(device=seq.device)
|
135 |
+
freqs = einsum('i , j -> i j', seq, inv_freq)
|
136 |
+
# first part even vector components, second part odd vector components,
|
137 |
+
# 2 * dim in dimension size
|
138 |
+
scale = torch.cat((freqs, freqs), dim=-1)
|
139 |
+
if exists(cache_key):
|
140 |
+
self.cache[cache_key] = scale
|
141 |
+
return scale
|
142 |
+
|
143 |
+
def rotate_queries_and_keys(self, q, k, seq_dim=-2):
|
144 |
+
"""
|
145 |
+
use this only when xpos is activated.
|
146 |
+
"""
|
147 |
+
assert self.use_xpos and q.device == k.device
|
148 |
+
device, seq_len_k, seq_len_q = k.device, k.shape[seq_dim], q.shape[seq_dim]
|
149 |
+
pos_seq_k = torch.arange(seq_len_k, device=device, dtype=torch.float32)
|
150 |
+
pos_seq_q = torch.arange(seq_len_k - seq_len_q, seq_len_k, device=device, dtype=torch.float32)
|
151 |
+
freqs_k = self.forward(pos_seq_k, cache_key=f"{0}:{seq_len_k}")
|
152 |
+
freqs_q = self.forward(pos_seq_q, cache_key=f"{seq_len_k - seq_len_q}:{seq_len_k}")
|
153 |
+
scale_k = self.get_scale(pos_seq_k)
|
154 |
+
scale_q = self.get_scale(pos_seq_q, offset=seq_len_k - seq_len_q) # 这里的offset是Q相对于K的offset
|
155 |
+
rotated_q = apply_rotary_emb(freqs_q, q, scale=scale_q)
|
156 |
+
rotated_k = apply_rotary_emb(freqs_k, k, scale=scale_k ** -1)
|
157 |
+
return rotated_q, rotated_k
|
158 |
+
|
159 |
+
def rotate_queries_or_keys(self, t, seq_dim=-2, offset=0):
|
160 |
+
"""
|
161 |
+
use this only when xpos is NOT activated.
|
162 |
+
"""
|
163 |
+
# t's shape e.g. -> (batchsize, headnum, seqlen, dimofhead)
|
164 |
+
assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'
|
165 |
+
device, seq_len = t.device, t.shape[seq_dim]
|
166 |
+
pos_seq_t = torch.arange(offset, offset + seq_len, device=device, dtype=torch.float32)
|
167 |
+
freqs = self.forward(pos_seq_t, cache_key=f"{offset}:{offset+seq_len}")
|
168 |
+
# freqs seqlen x dim
|
169 |
+
return apply_rotary_emb(freqs, t)
|
170 |
+
|
171 |
+
def get_scale(self, t, cache_key=None, offset=0, ):
|
172 |
+
assert self.use_xpos, 'This function is only useful for xpos.'
|
173 |
+
if exists(cache_key) and cache_key in self.cache_scale:
|
174 |
+
return self.cache_scale[cache_key]
|
175 |
+
if callable(t):
|
176 |
+
t = t()
|
177 |
+
length = len(t)
|
178 |
+
min_pos = -(length + offset) // 2
|
179 |
+
max_pos = length + offset + min_pos
|
180 |
+
power = torch.arange(min_pos, max_pos, 1).to(device=self.scale.device) / self.scale_base
|
181 |
+
scale = self.scale ** rearrange(power, 'n -> n 1')
|
182 |
+
scale = scale[-length:, :]
|
183 |
+
scale = torch.cat((scale, scale), dim=-1)
|
184 |
+
if exists(cache_key):
|
185 |
+
self.cache_scale[cache_key] = scale
|
186 |
+
return scale
|
187 |
+
|
188 |
+
|
189 |
+
def rotate_half(x):
|
190 |
+
"""
|
191 |
+
change sign so the last dimension becomes [-odd, +even]
|
192 |
+
"""
|
193 |
+
x1, x2 = torch.chunk(x, 2, dim=-1)
|
194 |
+
return torch.cat((-x2, x1), dim=-1)
|
195 |
+
|
196 |
+
|
197 |
+
def apply_rotary_emb(freqs, t, start_index=0, scale=1.):
|
198 |
+
"""
|
199 |
+
freq: seqlen x dim
|
200 |
+
t: [batchsize * headnum , seqlen , dim (dim_of_head actually)]
|
201 |
+
"""
|
202 |
+
dtype_t = t.dtype
|
203 |
+
freqs = freqs.to(device=t.device)
|
204 |
+
if isinstance(scale, torch.Tensor):
|
205 |
+
scale = scale.to(device=t.device)
|
206 |
+
rot_dim = freqs.shape[-1]
|
207 |
+
end_index = start_index + rot_dim
|
208 |
+
t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
|
209 |
+
t = (t * freqs.cos() + rotate_half(t) * freqs.sin()) * scale
|
210 |
+
rotated = torch.cat((t_left, t, t_right), dim=-1)
|
211 |
+
rotated = rotated.to(dtype=dtype_t)
|
212 |
+
return rotated
|
213 |
+
|
214 |
+
|
215 |
+
class TELECHATAttention(nn.Module):
|
216 |
+
def __init__(self, config, layer_idx=None):
|
217 |
+
super().__init__()
|
218 |
+
|
219 |
+
max_positions = config.max_position_embeddings
|
220 |
+
self.register_buffer(
|
221 |
+
"bias",
|
222 |
+
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
223 |
+
1, 1, max_positions, max_positions
|
224 |
+
),
|
225 |
+
)
|
226 |
+
self.register_buffer("masked_bias", torch.tensor(-1e4))
|
227 |
+
|
228 |
+
self.embed_dim = config.hidden_size
|
229 |
+
self.num_heads = config.num_attention_heads
|
230 |
+
self.head_dim = self.embed_dim // self.num_heads
|
231 |
+
self.split_size = self.embed_dim
|
232 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
233 |
+
raise ValueError(
|
234 |
+
f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
235 |
+
f" {self.num_heads})."
|
236 |
+
)
|
237 |
+
|
238 |
+
self.scale_attn_weights = config.scale_attn_weights
|
239 |
+
|
240 |
+
# Layer-wise attention scaling, reordering, and upcasting
|
241 |
+
self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
|
242 |
+
# for alignment with megatron-lm in softmax scale
|
243 |
+
self.layer_idx = max(1, layer_idx)
|
244 |
+
self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
|
245 |
+
|
246 |
+
self.relative_encoding = config.relative_encoding
|
247 |
+
self.rotary_use_xpos = config.rotary_use_xpos
|
248 |
+
|
249 |
+
self.use_mup = config.use_mup
|
250 |
+
|
251 |
+
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim, bias=config.add_bias_linear)
|
252 |
+
self.c_proj = Conv1D(self.embed_dim, self.embed_dim, bias=config.add_bias_linear)
|
253 |
+
|
254 |
+
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
255 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
256 |
+
|
257 |
+
self.pruned_heads = set()
|
258 |
+
|
259 |
+
self.use_flash_attn = False
|
260 |
+
|
261 |
+
|
262 |
+
|
263 |
+
def set_max_positions(self, max_positions, device='cuda'):
|
264 |
+
self.max_positions = max_positions
|
265 |
+
self.register_buffer(
|
266 |
+
"bias",
|
267 |
+
torch.tril(torch.ones((self.max_positions, self.max_positions), dtype=torch.bool)).view(
|
268 |
+
1, 1, self.max_positions, self.max_positions
|
269 |
+
).to(device=device)
|
270 |
+
)
|
271 |
+
|
272 |
+
def prune_heads(self, heads):
|
273 |
+
if len(heads) == 0:
|
274 |
+
return
|
275 |
+
heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
|
276 |
+
index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
|
277 |
+
|
278 |
+
# Prune conv1d layers
|
279 |
+
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
|
280 |
+
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
|
281 |
+
|
282 |
+
# Update hyper params
|
283 |
+
self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
|
284 |
+
self.num_heads = self.num_heads - len(heads)
|
285 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
286 |
+
|
287 |
+
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
288 |
+
# (batch, head, seq_length, head_features)
|
289 |
+
# batch_size, head_num, k_seq_len(q_seq_len), head_features
|
290 |
+
batch_size, head_num, k_seq_len, head_features = key.shape
|
291 |
+
_, _, q_seq_len, _ = query.shape
|
292 |
+
|
293 |
+
if self.use_flash_attn:
|
294 |
+
# print("*")
|
295 |
+
# attn_output = torch.nn.functional._scaled_dot_product_attention(query, key, value, is_causal=True)
|
296 |
+
# attn_weights = None
|
297 |
+
# return attn_output, attn_weights
|
298 |
+
|
299 |
+
batch_size, seqlen_q = query.shape[0], query.shape[2]
|
300 |
+
seqlen_k = key.shape[2]
|
301 |
+
|
302 |
+
query, key, value = [rearrange(x, 'b h s ... -> (b s) h ...') for x in [query, key, value]]
|
303 |
+
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
|
304 |
+
device=query.device)
|
305 |
+
is_causal = seqlen_q == seqlen_k
|
306 |
+
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
|
307 |
+
device=query.device)
|
308 |
+
dropout_p = 0
|
309 |
+
|
310 |
+
softmax_scale = 1/torch.full([], (value.size(-1) ** 0.5), dtype=value.dtype, device=value.device) if self.scale_attn_weights else 1
|
311 |
+
attn_output = flash_attn_unpadded_func(
|
312 |
+
query, key, value, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
|
313 |
+
dropout_p,
|
314 |
+
softmax_scale=softmax_scale, causal=is_causal
|
315 |
+
)
|
316 |
+
attn_output = rearrange(attn_output, '(b s) h ... -> b h s ...', b=batch_size)
|
317 |
+
attn_weights = None
|
318 |
+
return attn_output, attn_weights
|
319 |
+
|
320 |
+
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
321 |
+
|
322 |
+
if self.scale_attn_weights:
|
323 |
+
if self.use_mup:
|
324 |
+
attn_weights = attn_weights / torch.full(
|
325 |
+
[], value.size(-1) / (value.size(-1) ** 0.5), dtype=attn_weights.dtype,
|
326 |
+
device=attn_weights.device
|
327 |
+
)
|
328 |
+
else:
|
329 |
+
attn_weights = attn_weights / torch.full(
|
330 |
+
[], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
|
331 |
+
)
|
332 |
+
|
333 |
+
if not self.is_cross_attention:
|
334 |
+
# if only "normal" attention layer implements causal mask
|
335 |
+
query_length, key_length = query.size(-2), key.size(-2)
|
336 |
+
causal_mask = self.bias[:, :, key_length - query_length: key_length, :key_length]
|
337 |
+
mask_value = torch.finfo(attn_weights.dtype).min
|
338 |
+
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
339 |
+
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
340 |
+
mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
|
341 |
+
attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
|
342 |
+
|
343 |
+
if attention_mask is not None:
|
344 |
+
# Apply the attention mask
|
345 |
+
attn_weights = attn_weights + attention_mask
|
346 |
+
|
347 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
348 |
+
|
349 |
+
# Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
|
350 |
+
attn_weights = attn_weights.type(value.dtype)
|
351 |
+
attn_weights = self.attn_dropout(attn_weights)
|
352 |
+
|
353 |
+
# Mask heads if we want to
|
354 |
+
if head_mask is not None:
|
355 |
+
attn_weights = attn_weights * head_mask
|
356 |
+
|
357 |
+
attn_output = torch.matmul(attn_weights, value)
|
358 |
+
|
359 |
+
return attn_output, attn_weights
|
360 |
+
|
361 |
+
def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
|
362 |
+
# Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
|
363 |
+
bsz, num_heads, q_seq_len, dk = query.size()
|
364 |
+
_, _, k_seq_len, _ = key.size()
|
365 |
+
|
366 |
+
# Preallocate attn_weights for `baddbmm`
|
367 |
+
attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=query.dtype, device=query.device)
|
368 |
+
|
369 |
+
# Compute Scale Factor
|
370 |
+
scale_factor = 1.0
|
371 |
+
if self.scale_attn_weights:
|
372 |
+
scale_factor /= float(value.size(-1)) ** 0.5
|
373 |
+
|
374 |
+
if self.scale_attn_by_inverse_layer_idx:
|
375 |
+
scale_factor /= float(self.layer_idx)
|
376 |
+
# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
|
377 |
+
with autocast(enabled=False):
|
378 |
+
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
|
379 |
+
attn_weights = torch.baddbmm(attn_weights, q, k, beta=0, alpha=scale_factor)
|
380 |
+
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
|
381 |
+
|
382 |
+
if not self.is_cross_attention:
|
383 |
+
attn_weights = attn_weights.float()
|
384 |
+
if self.scale_attn_by_inverse_layer_idx:
|
385 |
+
attn_weights *= self.layer_idx
|
386 |
+
# if only "normal" attention layer implements causal mask
|
387 |
+
query_length, key_length = query.size(-2), key.size(-2)
|
388 |
+
causal_mask = self.bias[:, :, key_length - query_length: key_length, :key_length]
|
389 |
+
mask_value = -10000.0 # align with megatron-lm
|
390 |
+
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
391 |
+
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
392 |
+
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
|
393 |
+
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
|
394 |
+
|
395 |
+
if attention_mask is not None:
|
396 |
+
# Apply the attention mask
|
397 |
+
attn_weights = attn_weights + attention_mask
|
398 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
399 |
+
# Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
|
400 |
+
if attn_weights.dtype != torch.float32:
|
401 |
+
raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
|
402 |
+
attn_weights = attn_weights.type(value.dtype)
|
403 |
+
attn_weights = self.attn_dropout(attn_weights)
|
404 |
+
|
405 |
+
# Mask heads if we want to
|
406 |
+
if head_mask is not None:
|
407 |
+
attn_weights = attn_weights * head_mask
|
408 |
+
attn_output = torch.matmul(attn_weights, value)
|
409 |
+
return attn_output, attn_weights
|
410 |
+
|
411 |
+
def _split_heads(self, tensor, num_heads, attn_head_size):
|
412 |
+
"""
|
413 |
+
Splits hidden_size dim into attn_head_size and num_heads
|
414 |
+
"""
|
415 |
+
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
|
416 |
+
tensor = tensor.view(new_shape)
|
417 |
+
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
418 |
+
|
419 |
+
def _merge_heads(self, tensor, num_heads, attn_head_size):
|
420 |
+
"""
|
421 |
+
Merges attn_head_size dim and num_attn_heads dim into hidden_size
|
422 |
+
"""
|
423 |
+
tensor = tensor.permute(0, 2, 1, 3).contiguous()
|
424 |
+
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
|
425 |
+
return tensor.view(new_shape)
|
426 |
+
|
427 |
+
def forward(
|
428 |
+
self,
|
429 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
430 |
+
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
431 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
432 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
433 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
434 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
435 |
+
rotary_embedding: Optional[RotaryEmbedding] = None,
|
436 |
+
use_cache: Optional[bool] = False,
|
437 |
+
output_attentions: Optional[bool] = False,
|
438 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
439 |
+
if encoder_hidden_states is not None:
|
440 |
+
if not hasattr(self, "q_attn"):
|
441 |
+
raise ValueError(
|
442 |
+
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
443 |
+
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
|
444 |
+
)
|
445 |
+
|
446 |
+
query = self.q_attn(hidden_states)
|
447 |
+
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
|
448 |
+
attention_mask = encoder_attention_mask
|
449 |
+
else:
|
450 |
+
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
451 |
+
|
452 |
+
query = self._split_heads(query, self.num_heads, self.head_dim)
|
453 |
+
key = self._split_heads(key, self.num_heads, self.head_dim)
|
454 |
+
value = self._split_heads(value, self.num_heads, self.head_dim)
|
455 |
+
|
456 |
+
if layer_past is not None:
|
457 |
+
past_key, past_value = layer_past
|
458 |
+
key = torch.cat((past_key, key), dim=-2)
|
459 |
+
value = torch.cat((past_value, value), dim=-2)
|
460 |
+
|
461 |
+
if use_cache is True:
|
462 |
+
present = (key, value)
|
463 |
+
else:
|
464 |
+
present = None
|
465 |
+
|
466 |
+
batch_size, head_num, k_seq_len, head_features = key.shape
|
467 |
+
_, _, q_seq_len, _ = query.shape
|
468 |
+
query_offset = k_seq_len - q_seq_len
|
469 |
+
if rotary_embedding is not None:
|
470 |
+
query = query.contiguous().view(batch_size * head_num, q_seq_len, head_features)
|
471 |
+
key = key.contiguous().view(batch_size * head_num, k_seq_len, head_features)
|
472 |
+
|
473 |
+
# batch_size * head_num, k_seq_len(q_seq_len), head_features
|
474 |
+
if self.rotary_use_xpos:
|
475 |
+
# query: [batch_size * head_num, seqlen, hn]
|
476 |
+
query, key = rotary_embedding.rotate_queries_and_keys(query, key)
|
477 |
+
else:
|
478 |
+
query = rotary_embedding.rotate_queries_or_keys(query, offset=query_offset)
|
479 |
+
key = rotary_embedding.rotate_queries_or_keys(key)
|
480 |
+
# batch_size * head_num, k_seq_len(q_seq_len), head_features
|
481 |
+
query = query.view(batch_size, head_num, q_seq_len, head_features)
|
482 |
+
key = key.view(batch_size, head_num, k_seq_len, head_features)
|
483 |
+
|
484 |
+
if self.reorder_and_upcast_attn and not self.use_flash_attn:
|
485 |
+
attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
|
486 |
+
else:
|
487 |
+
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
488 |
+
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
489 |
+
attn_output = self.c_proj(attn_output)
|
490 |
+
attn_output = self.resid_dropout(attn_output)
|
491 |
+
outputs = (attn_output, present)
|
492 |
+
if output_attentions:
|
493 |
+
outputs += (attn_weights,)
|
494 |
+
|
495 |
+
return outputs
|
496 |
+
|
497 |
+
|
498 |
+
class TELECHATMLP(nn.Module):
|
499 |
+
def __init__(self, intermediate_size, config):
|
500 |
+
super().__init__()
|
501 |
+
embed_dim = config.hidden_size
|
502 |
+
if config.activation_function=='silu':
|
503 |
+
up_intermediate_size = 2 * intermediate_size
|
504 |
+
else:
|
505 |
+
up_intermediate_size = intermediate_size
|
506 |
+
self.c_fc = Conv1D(up_intermediate_size, embed_dim, bias=config.add_bias_linear)
|
507 |
+
self.c_proj = Conv1D(embed_dim, intermediate_size, bias=config.add_bias_linear)
|
508 |
+
if config.activation_function=='silu':
|
509 |
+
def swiglu(x):
|
510 |
+
x = torch.chunk(x, 2, dim=-1)
|
511 |
+
return F.silu(x[0]) * x[1]
|
512 |
+
self.act = swiglu
|
513 |
+
else:
|
514 |
+
self.act = ACT2FN[config.activation_function]
|
515 |
+
self.dropout = nn.Dropout(config.resid_pdrop)
|
516 |
+
|
517 |
+
def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
|
518 |
+
hidden_states = self.c_fc(hidden_states)
|
519 |
+
# print(f'activation func: {self.act}')
|
520 |
+
# print(f'before act: hidden_states {hidden_states.shape}')
|
521 |
+
hidden_states = self.act(hidden_states)
|
522 |
+
# print(f'after act: hidden_states {hidden_states.shape}')
|
523 |
+
hidden_states = self.c_proj(hidden_states)
|
524 |
+
hidden_states = self.dropout(hidden_states)
|
525 |
+
return hidden_states
|
526 |
+
|
527 |
+
|
528 |
+
class TELECHATBlock(nn.Module):
|
529 |
+
def __init__(self, config, layer_idx=None):
|
530 |
+
super().__init__()
|
531 |
+
LayerNorm = nn.LayerNorm if not config.use_RMSNorm else RMSNorm
|
532 |
+
hidden_size = config.hidden_size
|
533 |
+
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
534 |
+
self.layer_idx = layer_idx
|
535 |
+
self.ln_1 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
536 |
+
self.attn = TELECHATAttention(config, layer_idx=layer_idx)
|
537 |
+
self.ln_2 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
538 |
+
self.mlp = TELECHATMLP(inner_dim, config)
|
539 |
+
|
540 |
+
def forward(
|
541 |
+
self,
|
542 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
543 |
+
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
544 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
545 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
546 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
547 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
548 |
+
rotary_embedding: Optional[RotaryEmbedding] = None,
|
549 |
+
use_cache: Optional[bool] = False,
|
550 |
+
output_attentions: Optional[bool] = False,
|
551 |
+
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
|
552 |
+
residual = hidden_states
|
553 |
+
hidden_states = self.ln_1(hidden_states)
|
554 |
+
# debug_print_tensor(hidden_states, 'after ln_1')
|
555 |
+
attn_outputs = self.attn(
|
556 |
+
hidden_states,
|
557 |
+
layer_past=layer_past,
|
558 |
+
attention_mask=attention_mask,
|
559 |
+
head_mask=head_mask,
|
560 |
+
rotary_embedding=rotary_embedding,
|
561 |
+
use_cache=use_cache,
|
562 |
+
output_attentions=output_attentions
|
563 |
+
)
|
564 |
+
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
565 |
+
outputs = attn_outputs[1:]
|
566 |
+
# residual connection
|
567 |
+
hidden_states = attn_output + residual
|
568 |
+
|
569 |
+
residual = hidden_states
|
570 |
+
hidden_states = self.ln_2(hidden_states)
|
571 |
+
feed_forward_hidden_states = self.mlp(hidden_states)
|
572 |
+
# residual connection
|
573 |
+
hidden_states = residual + feed_forward_hidden_states
|
574 |
+
if use_cache:
|
575 |
+
outputs = (hidden_states,) + outputs
|
576 |
+
else:
|
577 |
+
outputs = (hidden_states,) + outputs[1:]
|
578 |
+
# debug_print_tensor(hidden_states, 'block output')
|
579 |
+
|
580 |
+
return outputs
|
581 |
+
|
582 |
+
|
583 |
+
class TELECHATPretrainedModel(PreTrainedModel):
|
584 |
+
"""
|
585 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
586 |
+
models.
|
587 |
+
"""
|
588 |
+
|
589 |
+
config_class = TELECHATConfig
|
590 |
+
load_tf_weights = None
|
591 |
+
base_model_prefix = "transformer"
|
592 |
+
is_parallelizable = True
|
593 |
+
supports_gradient_checkpointing = True
|
594 |
+
_no_split_modules = ["TELECHATBlock"]
|
595 |
+
|
596 |
+
def __init__(self, *inputs, **kwargs):
|
597 |
+
super().__init__(*inputs, **kwargs)
|
598 |
+
|
599 |
+
def _init_weights(self, module):
|
600 |
+
"""Initialize the weights."""
|
601 |
+
if isinstance(module, (nn.Linear, Conv1D)):
|
602 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
603 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
604 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
605 |
+
if module.bias is not None:
|
606 |
+
module.bias.data.zero_()
|
607 |
+
elif isinstance(module, nn.Embedding):
|
608 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
609 |
+
if module.padding_idx is not None:
|
610 |
+
module.weight.data[module.padding_idx].zero_()
|
611 |
+
elif isinstance(module, nn.LayerNorm) or isinstance(module, RMSNorm):
|
612 |
+
module.bias.data.zero_()
|
613 |
+
module.weight.data.fill_(1.0)
|
614 |
+
|
615 |
+
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
616 |
+
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
617 |
+
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
618 |
+
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
619 |
+
#
|
620 |
+
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
621 |
+
for name, p in module.named_parameters():
|
622 |
+
if name == "c_proj.weight":
|
623 |
+
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
624 |
+
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
|
625 |
+
|
626 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
627 |
+
if isinstance(module, TELECHATTransformer):
|
628 |
+
module.gradient_checkpointing = value
|
629 |
+
|
630 |
+
|
631 |
+
class TELECHATTransformer(TELECHATPretrainedModel):
|
632 |
+
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]
|
633 |
+
|
634 |
+
def __init__(self, config):
|
635 |
+
super().__init__(config)
|
636 |
+
|
637 |
+
self.embed_dim = config.hidden_size
|
638 |
+
|
639 |
+
self.relative_encoding = config.relative_encoding
|
640 |
+
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
|
641 |
+
|
642 |
+
self.use_mup = config.use_mup
|
643 |
+
if self.use_mup:
|
644 |
+
self.input_mult = config.input_mult
|
645 |
+
|
646 |
+
if self.relative_encoding is None:
|
647 |
+
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
648 |
+
elif self.relative_encoding == 'rotary':
|
649 |
+
pe_dim = config.n_embd // config.n_head
|
650 |
+
self.wpe = RotaryEmbedding(pe_dim,
|
651 |
+
use_xpos=config.rotary_use_xpos,
|
652 |
+
xpos_scale_base=config.rotary_xpos_scale_base,
|
653 |
+
theta=config.rotary_theta
|
654 |
+
)
|
655 |
+
|
656 |
+
else:
|
657 |
+
raise RuntimeError(
|
658 |
+
f'Unknown relative positional encoding type: `relative_encoding`={self.relative_encoding}')
|
659 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
660 |
+
self.h = nn.ModuleList([TELECHATBlock(config, layer_idx=i + 1) for i in range(config.num_hidden_layers)])
|
661 |
+
LayerNorm = nn.LayerNorm if not config.use_RMSNorm else RMSNorm
|
662 |
+
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
663 |
+
|
664 |
+
# Model parallel
|
665 |
+
self.model_parallel = False
|
666 |
+
self.device_map = None
|
667 |
+
self.gradient_checkpointing = False
|
668 |
+
|
669 |
+
# Initialize weights and apply final processing
|
670 |
+
self.post_init()
|
671 |
+
|
672 |
+
# @add_start_docstrings(PARALLELIZE_DOCSTRING)
|
673 |
+
def parallelize(self, device_map=None):
|
674 |
+
# Check validity of device_map
|
675 |
+
self.device_map = (
|
676 |
+
get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
|
677 |
+
)
|
678 |
+
assert_device_map(self.device_map, len(self.h))
|
679 |
+
self.model_parallel = True
|
680 |
+
self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
|
681 |
+
self.last_device = "cuda:" + str(max(self.device_map.keys()))
|
682 |
+
self.wte = self.wte.to(self.first_device)
|
683 |
+
self.wpe = self.wpe.to(self.first_device)
|
684 |
+
# Load onto devices
|
685 |
+
for k, v in self.device_map.items():
|
686 |
+
for block in v:
|
687 |
+
cuda_device = "cuda:" + str(k)
|
688 |
+
self.h[block] = self.h[block].to(cuda_device)
|
689 |
+
# ln_f to last
|
690 |
+
self.ln_f = self.ln_f.to(self.last_device)
|
691 |
+
|
692 |
+
def deparallelize(self):
|
693 |
+
self.model_parallel = False
|
694 |
+
self.device_map = None
|
695 |
+
self.first_device = "cpu"
|
696 |
+
self.last_device = "cpu"
|
697 |
+
self.wte = self.wte.to("cpu")
|
698 |
+
self.wpe = self.wpe.to("cpu")
|
699 |
+
for index in range(len(self.h)):
|
700 |
+
self.h[index] = self.h[index].to("cpu")
|
701 |
+
self.ln_f = self.ln_f.to("cpu")
|
702 |
+
torch.cuda.empty_cache()
|
703 |
+
|
704 |
+
def get_input_embeddings(self):
|
705 |
+
return self.wte
|
706 |
+
|
707 |
+
def set_input_embeddings(self, new_embeddings):
|
708 |
+
self.wte = new_embeddings
|
709 |
+
|
710 |
+
def _prune_heads(self, heads_to_prune):
|
711 |
+
"""
|
712 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
713 |
+
"""
|
714 |
+
for layer, heads in heads_to_prune.items():
|
715 |
+
self.h[layer].attn.prune_heads(heads)
|
716 |
+
|
717 |
+
def forward(
|
718 |
+
self,
|
719 |
+
input_ids: Optional[torch.LongTensor] = None,
|
720 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
721 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
722 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
723 |
+
position_ids: Optional[torch.LongTensor] = None,
|
724 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
725 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
726 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
727 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
728 |
+
use_cache: Optional[bool] = None,
|
729 |
+
output_attentions: Optional[bool] = None,
|
730 |
+
output_hidden_states: Optional[bool] = None,
|
731 |
+
return_dict: Optional[bool] = None,
|
732 |
+
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
733 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
734 |
+
output_hidden_states = (
|
735 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
736 |
+
)
|
737 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
738 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
739 |
+
|
740 |
+
if input_ids is not None and inputs_embeds is not None:
|
741 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
742 |
+
elif input_ids is not None:
|
743 |
+
input_shape = input_ids.size()
|
744 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
745 |
+
batch_size = input_ids.shape[0]
|
746 |
+
elif inputs_embeds is not None:
|
747 |
+
input_shape = inputs_embeds.size()[:-1]
|
748 |
+
batch_size = inputs_embeds.shape[0]
|
749 |
+
else:
|
750 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
751 |
+
|
752 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
753 |
+
|
754 |
+
if token_type_ids is not None:
|
755 |
+
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
756 |
+
if position_ids is not None:
|
757 |
+
position_ids = position_ids.view(-1, input_shape[-1])
|
758 |
+
|
759 |
+
if past_key_values is None:
|
760 |
+
past_length = 0
|
761 |
+
past_key_values = tuple([None] * len(self.h))
|
762 |
+
else:
|
763 |
+
past_length = past_key_values[0][0].size(-2)
|
764 |
+
if position_ids is None:
|
765 |
+
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
766 |
+
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
767 |
+
|
768 |
+
# GPT2Attention mask.
|
769 |
+
if attention_mask is not None:
|
770 |
+
if batch_size <= 0:
|
771 |
+
raise ValueError("batch_size has to be defined and > 0")
|
772 |
+
attention_mask = attention_mask.view(batch_size, -1)
|
773 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
774 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
775 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
776 |
+
# this attention mask is more simple than the triangular masking of causal attention
|
777 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
778 |
+
attention_mask = attention_mask[:, None, None, :]
|
779 |
+
|
780 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
781 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
782 |
+
# positions we want to attend and the dtype's smallest value for masked positions.
|
783 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
784 |
+
# effectively the same as removing these entirely.
|
785 |
+
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
786 |
+
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
787 |
+
|
788 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
789 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
790 |
+
# if self.config.add_cross_attention and encoder_hidden_states is not None:
|
791 |
+
# encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
792 |
+
# encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
793 |
+
# if encoder_attention_mask is None:
|
794 |
+
# encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
795 |
+
# encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
796 |
+
# else:
|
797 |
+
# encoder_attention_mask = None
|
798 |
+
encoder_attention_mask = None
|
799 |
+
|
800 |
+
# Prepare head mask if needed
|
801 |
+
# 1.0 in head_mask indicate we keep the head
|
802 |
+
# attention_probs has shape bsz x n_heads x N x N
|
803 |
+
# head_mask has shape n_layer x batch x n_heads x N x N
|
804 |
+
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
805 |
+
|
806 |
+
if inputs_embeds is None:
|
807 |
+
inputs_embeds = self.wte(input_ids)
|
808 |
+
|
809 |
+
# Mup
|
810 |
+
if self.use_mup:
|
811 |
+
inputs_embeds = inputs_embeds * self.input_mult
|
812 |
+
if self.relative_encoding is None:
|
813 |
+
position_embeds = self.wpe(position_ids)
|
814 |
+
hidden_states = inputs_embeds + position_embeds
|
815 |
+
elif self.relative_encoding == 'rotary':
|
816 |
+
hidden_states = inputs_embeds
|
817 |
+
if token_type_ids is not None:
|
818 |
+
token_type_embeds = self.wte(token_type_ids)
|
819 |
+
hidden_states = hidden_states + token_type_embeds
|
820 |
+
hidden_states = self.drop(hidden_states)
|
821 |
+
|
822 |
+
output_shape = input_shape + (hidden_states.size(-1),)
|
823 |
+
|
824 |
+
presents = () if use_cache else None
|
825 |
+
all_self_attentions = () if output_attentions else None
|
826 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
827 |
+
all_hidden_states = () if output_hidden_states else None
|
828 |
+
# debug_print_tensor(hidden_states, 'after embedding')
|
829 |
+
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
830 |
+
|
831 |
+
# Model parallel
|
832 |
+
if self.model_parallel:
|
833 |
+
torch.cuda.set_device(hidden_states.device)
|
834 |
+
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
835 |
+
if layer_past is not None:
|
836 |
+
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
837 |
+
# Ensure that attention_mask is always on the same device as hidden_states
|
838 |
+
if attention_mask is not None:
|
839 |
+
attention_mask = attention_mask.to(hidden_states.device)
|
840 |
+
if isinstance(head_mask, torch.Tensor):
|
841 |
+
head_mask = head_mask.to(hidden_states.device)
|
842 |
+
if output_hidden_states:
|
843 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
844 |
+
|
845 |
+
if self.gradient_checkpointing and self.training:
|
846 |
+
|
847 |
+
if use_cache:
|
848 |
+
logger.warning(
|
849 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
850 |
+
)
|
851 |
+
use_cache = False
|
852 |
+
|
853 |
+
def create_custom_forward(module):
|
854 |
+
def custom_forward(*inputs):
|
855 |
+
# None for past_key_value
|
856 |
+
return module(*inputs, use_cache, output_attentions)
|
857 |
+
|
858 |
+
return custom_forward
|
859 |
+
|
860 |
+
outputs = torch.utils.checkpoint.checkpoint(
|
861 |
+
create_custom_forward(block),
|
862 |
+
hidden_states,
|
863 |
+
None,
|
864 |
+
attention_mask,
|
865 |
+
head_mask[i],
|
866 |
+
encoder_hidden_states,
|
867 |
+
encoder_attention_mask,
|
868 |
+
)
|
869 |
+
else:
|
870 |
+
outputs = block(
|
871 |
+
hidden_states,
|
872 |
+
layer_past=layer_past,
|
873 |
+
attention_mask=attention_mask,
|
874 |
+
head_mask=head_mask[i],
|
875 |
+
encoder_hidden_states=encoder_hidden_states,
|
876 |
+
encoder_attention_mask=encoder_attention_mask,
|
877 |
+
rotary_embedding=self.wpe if self.relative_encoding == 'rotary' else None,
|
878 |
+
use_cache=use_cache,
|
879 |
+
output_attentions=output_attentions
|
880 |
+
)
|
881 |
+
|
882 |
+
hidden_states = outputs[0]
|
883 |
+
if use_cache is True:
|
884 |
+
presents = presents + (outputs[1],)
|
885 |
+
|
886 |
+
if output_attentions:
|
887 |
+
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
888 |
+
# if self.config.add_cross_attention:
|
889 |
+
# all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
890 |
+
|
891 |
+
# Model Parallel: If it's the last layer for that device, put things on the next device
|
892 |
+
if self.model_parallel:
|
893 |
+
for k, v in self.device_map.items():
|
894 |
+
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
895 |
+
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
896 |
+
|
897 |
+
hidden_states = self.ln_f(hidden_states)
|
898 |
+
|
899 |
+
hidden_states = hidden_states.view(output_shape)
|
900 |
+
# Add last hidden state
|
901 |
+
if output_hidden_states:
|
902 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
903 |
+
|
904 |
+
if not return_dict:
|
905 |
+
return tuple(
|
906 |
+
v
|
907 |
+
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
908 |
+
if v is not None
|
909 |
+
)
|
910 |
+
|
911 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
912 |
+
last_hidden_state=hidden_states,
|
913 |
+
past_key_values=presents,
|
914 |
+
hidden_states=all_hidden_states,
|
915 |
+
attentions=all_self_attentions,
|
916 |
+
cross_attentions=all_cross_attentions,
|
917 |
+
)
|
918 |
+
|
919 |
+
|
920 |
+
class TELECHAT(TELECHATPretrainedModel):
|
921 |
+
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
|
922 |
+
|
923 |
+
def __init__(self, config):
|
924 |
+
super().__init__(config)
|
925 |
+
self.transformer = TELECHATTransformer(config)
|
926 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
927 |
+
self.use_mup = config.use_mup
|
928 |
+
if self.use_mup:
|
929 |
+
self.mup_scale_factor = config.mup_scale_factor
|
930 |
+
self.output_mult = config.output_mult / self.mup_scale_factor
|
931 |
+
|
932 |
+
# 初始化时先根据config里的开关决定是否开启flashattn, 用户可以通过修改config或者model.enable_flash_attn修改flashattn的开关
|
933 |
+
self.enable_flash_attn(config.enable_flash_attn)
|
934 |
+
|
935 |
+
# Model parallel
|
936 |
+
self.model_parallel = False
|
937 |
+
self.device_map = None
|
938 |
+
|
939 |
+
# Initialize weights and apply final processing
|
940 |
+
self.post_init()
|
941 |
+
def enable_flash_attn(self, enabled: bool):
|
942 |
+
for block in self.transformer.h:
|
943 |
+
block.attn.use_flash_attn = enabled
|
944 |
+
print(f"TELECHAT flash attention {'enabled' if enabled else 'disabled'}")
|
945 |
+
# torch.backends.cuda.enable_flash_sdp(enabled)
|
946 |
+
def set_max_positions(self, max_positions):
|
947 |
+
for layer in self.transformer.h:
|
948 |
+
device = layer.ln_1.weight.device
|
949 |
+
layer.attn.set_max_positions(max_positions, device=device)
|
950 |
+
|
951 |
+
def parallelize(self, device_map=None):
|
952 |
+
self.device_map = (
|
953 |
+
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
|
954 |
+
if device_map is None
|
955 |
+
else device_map
|
956 |
+
)
|
957 |
+
assert_device_map(self.device_map, len(self.transformer.h))
|
958 |
+
self.transformer.parallelize(self.device_map)
|
959 |
+
self.lm_head = self.lm_head.to(self.transformer.first_device)
|
960 |
+
self.model_parallel = True
|
961 |
+
|
962 |
+
def deparallelize(self):
|
963 |
+
self.transformer.deparallelize()
|
964 |
+
self.transformer = self.transformer.to("cpu")
|
965 |
+
self.lm_head = self.lm_head.to("cpu")
|
966 |
+
self.model_parallel = False
|
967 |
+
torch.cuda.empty_cache()
|
968 |
+
|
969 |
+
def get_output_embeddings(self):
|
970 |
+
return self.lm_head
|
971 |
+
|
972 |
+
def set_output_embeddings(self, new_embeddings):
|
973 |
+
self.lm_head = new_embeddings
|
974 |
+
|
975 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
|
976 |
+
token_type_ids = kwargs.get("token_type_ids", None)
|
977 |
+
# only last token for inputs_ids if past is defined in kwargs
|
978 |
+
if past_key_values:
|
979 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
980 |
+
if token_type_ids is not None:
|
981 |
+
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
982 |
+
|
983 |
+
attention_mask = kwargs.get("attention_mask", None)
|
984 |
+
position_ids = kwargs.get("position_ids", None)
|
985 |
+
|
986 |
+
if attention_mask is not None and position_ids is None:
|
987 |
+
# create position_ids on the fly for batch generation
|
988 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
989 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
990 |
+
if past_key_values:
|
991 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
992 |
+
else:
|
993 |
+
position_ids = None
|
994 |
+
return {
|
995 |
+
"input_ids": input_ids,
|
996 |
+
"past_key_values": past_key_values,
|
997 |
+
"use_cache": kwargs.get("use_cache"),
|
998 |
+
"position_ids": position_ids,
|
999 |
+
"attention_mask": attention_mask,
|
1000 |
+
"token_type_ids": token_type_ids,
|
1001 |
+
}
|
1002 |
+
|
1003 |
+
def forward(
|
1004 |
+
self,
|
1005 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1006 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
1007 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1008 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
1009 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1010 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
1011 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1012 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
1013 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
1014 |
+
labels: Optional[torch.LongTensor] = None,
|
1015 |
+
use_cache: Optional[bool] = None,
|
1016 |
+
output_attentions: Optional[bool] = None,
|
1017 |
+
output_hidden_states: Optional[bool] = None,
|
1018 |
+
return_dict: Optional[bool] = None,
|
1019 |
+
) -> Union[Tuple, CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast]:
|
1020 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1021 |
+
|
1022 |
+
transformer_outputs = self.transformer(
|
1023 |
+
input_ids,
|
1024 |
+
past_key_values=past_key_values,
|
1025 |
+
attention_mask=attention_mask,
|
1026 |
+
token_type_ids=token_type_ids,
|
1027 |
+
position_ids=position_ids,
|
1028 |
+
head_mask=head_mask,
|
1029 |
+
inputs_embeds=inputs_embeds,
|
1030 |
+
encoder_hidden_states=encoder_hidden_states,
|
1031 |
+
encoder_attention_mask=encoder_attention_mask,
|
1032 |
+
use_cache=use_cache,
|
1033 |
+
output_attentions=output_attentions,
|
1034 |
+
output_hidden_states=output_hidden_states,
|
1035 |
+
return_dict=return_dict
|
1036 |
+
)
|
1037 |
+
hidden_states = transformer_outputs[0]
|
1038 |
+
|
1039 |
+
# Set device for model parallelism
|
1040 |
+
if self.model_parallel:
|
1041 |
+
torch.cuda.set_device(self.transformer.first_device)
|
1042 |
+
hidden_states = hidden_states.to(self.lm_head.weight.device)
|
1043 |
+
|
1044 |
+
lm_logits = self.lm_head(hidden_states)
|
1045 |
+
# Mup
|
1046 |
+
if self.use_mup:
|
1047 |
+
lm_logits = lm_logits * self.output_mult
|
1048 |
+
|
1049 |
+
loss = None
|
1050 |
+
if labels is not None:
|
1051 |
+
# Shift so that tokens < n predict n
|
1052 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
1053 |
+
shift_labels = labels[..., 1:].contiguous()
|
1054 |
+
# Flatten the tokens
|
1055 |
+
loss_fct = nn.CrossEntropyLoss()
|
1056 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
1057 |
+
|
1058 |
+
if not return_dict:
|
1059 |
+
output = (lm_logits,) + transformer_outputs[1:]
|
1060 |
+
return ((loss,) + output) if loss is not None else output
|
1061 |
+
|
1062 |
+
return CausalLMOutputWithCrossAttentions(
|
1063 |
+
loss=loss,
|
1064 |
+
logits=lm_logits,
|
1065 |
+
past_key_values=transformer_outputs.past_key_values,
|
1066 |
+
hidden_states=transformer_outputs.hidden_states,
|
1067 |
+
attentions=transformer_outputs.attentions,
|
1068 |
+
cross_attentions=transformer_outputs.cross_attentions,
|
1069 |
+
)
|
1070 |
+
|
1071 |
+
def chat(self,tokenizer, question, history_input_list, history_output_list,generation_config):
|
1072 |
+
'''
|
1073 |
+
:param question: 当前问题
|
1074 |
+
:param history_input_list: 历史问题列表, list of strings
|
1075 |
+
:param history_output_list: 历史回答列表, list of string
|
1076 |
+
:return: response
|
1077 |
+
'''
|
1078 |
+
|
1079 |
+
inputs = ""
|
1080 |
+
assert len(history_output_list) == len(history_output_list)
|
1081 |
+
for i in range(len(history_input_list)):
|
1082 |
+
inputs += "<_user>" + history_input_list[i] + "<_bot>" + history_output_list[i] + "<_end>"
|
1083 |
+
inputs += "<_user>" + question + "<_bot>"
|
1084 |
+
print("input:", inputs)
|
1085 |
+
input_ids = tokenizer.encode(inputs,
|
1086 |
+
return_tensors="pt"
|
1087 |
+
)
|
1088 |
+
if len(input_ids[0]) >= 2000:
|
1089 |
+
input_ids = input_ids[:, -2000:]
|
1090 |
+
input_ids = input_ids.to(0)
|
1091 |
+
output = self.generate(input_ids,generation_config)
|
1092 |
+
response = tokenizer.decode(output[0].cpu().numpy().tolist()).split('<_bot>')[-1].split('</s>')[0]
|
1093 |
+
return response
|
1094 |
+
|
1095 |
+
@staticmethod
|
1096 |
+
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
|
1097 |
+
"""
|
1098 |
+
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
|
1099 |
+
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
1100 |
+
beam_idx at every generation step.
|
1101 |
+
"""
|
1102 |
+
return tuple(
|
1103 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
1104 |
+
for layer_past in past
|
1105 |
+
)
|
pytorch_model.bin.index.json
ADDED
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"metadata": {
|
3 |
+
"total_size": 105665020032
|
4 |
+
},
|
5 |
+
"weight_map": {
|
6 |
+
"lm_head.weight": "pytorch_model-00011-of-00011.bin",
|
7 |
+
"transformer.h.0.attn.c_attn.weight": "pytorch_model-00001-of-00011.bin",
|
8 |
+
"transformer.h.0.attn.c_proj.weight": "pytorch_model-00001-of-00011.bin",
|
9 |
+
"transformer.h.0.attn.masked_bias": "pytorch_model-00001-of-00011.bin",
|
10 |
+
"transformer.h.0.ln_1.weight": "pytorch_model-00001-of-00011.bin",
|
11 |
+
"transformer.h.0.ln_2.weight": "pytorch_model-00001-of-00011.bin",
|
12 |
+
"transformer.h.0.mlp.c_fc.weight": "pytorch_model-00001-of-00011.bin",
|
13 |
+
"transformer.h.0.mlp.c_proj.weight": "pytorch_model-00001-of-00011.bin",
|
14 |
+
"transformer.h.1.attn.c_attn.weight": "pytorch_model-00001-of-00011.bin",
|
15 |
+
"transformer.h.1.attn.c_proj.weight": "pytorch_model-00001-of-00011.bin",
|
16 |
+
"transformer.h.1.attn.masked_bias": "pytorch_model-00001-of-00011.bin",
|
17 |
+
"transformer.h.1.ln_1.weight": "pytorch_model-00001-of-00011.bin",
|
18 |
+
"transformer.h.1.ln_2.weight": "pytorch_model-00001-of-00011.bin",
|
19 |
+
"transformer.h.1.mlp.c_fc.weight": "pytorch_model-00001-of-00011.bin",
|
20 |
+
"transformer.h.1.mlp.c_proj.weight": "pytorch_model-00001-of-00011.bin",
|
21 |
+
"transformer.h.10.attn.c_attn.weight": "pytorch_model-00002-of-00011.bin",
|
22 |
+
"transformer.h.10.attn.c_proj.weight": "pytorch_model-00002-of-00011.bin",
|
23 |
+
"transformer.h.10.attn.masked_bias": "pytorch_model-00002-of-00011.bin",
|
24 |
+
"transformer.h.10.ln_1.weight": "pytorch_model-00002-of-00011.bin",
|
25 |
+
"transformer.h.10.ln_2.weight": "pytorch_model-00002-of-00011.bin",
|
26 |
+
"transformer.h.10.mlp.c_fc.weight": "pytorch_model-00002-of-00011.bin",
|
27 |
+
"transformer.h.10.mlp.c_proj.weight": "pytorch_model-00002-of-00011.bin",
|
28 |
+
"transformer.h.11.attn.c_attn.weight": "pytorch_model-00002-of-00011.bin",
|
29 |
+
"transformer.h.11.attn.c_proj.weight": "pytorch_model-00002-of-00011.bin",
|
30 |
+
"transformer.h.11.attn.masked_bias": "pytorch_model-00002-of-00011.bin",
|
31 |
+
"transformer.h.11.ln_1.weight": "pytorch_model-00002-of-00011.bin",
|
32 |
+
"transformer.h.11.ln_2.weight": "pytorch_model-00002-of-00011.bin",
|
33 |
+
"transformer.h.11.mlp.c_fc.weight": "pytorch_model-00003-of-00011.bin",
|
34 |
+
"transformer.h.11.mlp.c_proj.weight": "pytorch_model-00003-of-00011.bin",
|
35 |
+
"transformer.h.12.attn.c_attn.weight": "pytorch_model-00003-of-00011.bin",
|
36 |
+
"transformer.h.12.attn.c_proj.weight": "pytorch_model-00003-of-00011.bin",
|
37 |
+
"transformer.h.12.attn.masked_bias": "pytorch_model-00003-of-00011.bin",
|
38 |
+
"transformer.h.12.ln_1.weight": "pytorch_model-00003-of-00011.bin",
|
39 |
+
"transformer.h.12.ln_2.weight": "pytorch_model-00003-of-00011.bin",
|
40 |
+
"transformer.h.12.mlp.c_fc.weight": "pytorch_model-00003-of-00011.bin",
|
41 |
+
"transformer.h.12.mlp.c_proj.weight": "pytorch_model-00003-of-00011.bin",
|
42 |
+
"transformer.h.13.attn.c_attn.weight": "pytorch_model-00003-of-00011.bin",
|
43 |
+
"transformer.h.13.attn.c_proj.weight": "pytorch_model-00003-of-00011.bin",
|
44 |
+
"transformer.h.13.attn.masked_bias": "pytorch_model-00003-of-00011.bin",
|
45 |
+
"transformer.h.13.ln_1.weight": "pytorch_model-00003-of-00011.bin",
|
46 |
+
"transformer.h.13.ln_2.weight": "pytorch_model-00003-of-00011.bin",
|
47 |
+
"transformer.h.13.mlp.c_fc.weight": "pytorch_model-00003-of-00011.bin",
|
48 |
+
"transformer.h.13.mlp.c_proj.weight": "pytorch_model-00003-of-00011.bin",
|
49 |
+
"transformer.h.14.attn.c_attn.weight": "pytorch_model-00003-of-00011.bin",
|
50 |
+
"transformer.h.14.attn.c_proj.weight": "pytorch_model-00003-of-00011.bin",
|
51 |
+
"transformer.h.14.attn.masked_bias": "pytorch_model-00003-of-00011.bin",
|
52 |
+
"transformer.h.14.ln_1.weight": "pytorch_model-00003-of-00011.bin",
|
53 |
+
"transformer.h.14.ln_2.weight": "pytorch_model-00003-of-00011.bin",
|
54 |
+
"transformer.h.14.mlp.c_fc.weight": "pytorch_model-00003-of-00011.bin",
|
55 |
+
"transformer.h.14.mlp.c_proj.weight": "pytorch_model-00003-of-00011.bin",
|
56 |
+
"transformer.h.15.attn.c_attn.weight": "pytorch_model-00003-of-00011.bin",
|
57 |
+
"transformer.h.15.attn.c_proj.weight": "pytorch_model-00003-of-00011.bin",
|
58 |
+
"transformer.h.15.attn.masked_bias": "pytorch_model-00003-of-00011.bin",
|
59 |
+
"transformer.h.15.ln_1.weight": "pytorch_model-00003-of-00011.bin",
|
60 |
+
"transformer.h.15.ln_2.weight": "pytorch_model-00003-of-00011.bin",
|
61 |
+
"transformer.h.15.mlp.c_fc.weight": "pytorch_model-00003-of-00011.bin",
|
62 |
+
"transformer.h.15.mlp.c_proj.weight": "pytorch_model-00003-of-00011.bin",
|
63 |
+
"transformer.h.16.attn.c_attn.weight": "pytorch_model-00003-of-00011.bin",
|
64 |
+
"transformer.h.16.attn.c_proj.weight": "pytorch_model-00003-of-00011.bin",
|
65 |
+
"transformer.h.16.attn.masked_bias": "pytorch_model-00003-of-00011.bin",
|
66 |
+
"transformer.h.16.ln_1.weight": "pytorch_model-00003-of-00011.bin",
|
67 |
+
"transformer.h.16.ln_2.weight": "pytorch_model-00003-of-00011.bin",
|
68 |
+
"transformer.h.16.mlp.c_fc.weight": "pytorch_model-00003-of-00011.bin",
|
69 |
+
"transformer.h.16.mlp.c_proj.weight": "pytorch_model-00003-of-00011.bin",
|
70 |
+
"transformer.h.17.attn.c_attn.weight": "pytorch_model-00003-of-00011.bin",
|
71 |
+
"transformer.h.17.attn.c_proj.weight": "pytorch_model-00003-of-00011.bin",
|
72 |
+
"transformer.h.17.attn.masked_bias": "pytorch_model-00003-of-00011.bin",
|
73 |
+
"transformer.h.17.ln_1.weight": "pytorch_model-00003-of-00011.bin",
|
74 |
+
"transformer.h.17.ln_2.weight": "pytorch_model-00003-of-00011.bin",
|
75 |
+
"transformer.h.17.mlp.c_fc.weight": "pytorch_model-00004-of-00011.bin",
|
76 |
+
"transformer.h.17.mlp.c_proj.weight": "pytorch_model-00004-of-00011.bin",
|
77 |
+
"transformer.h.18.attn.c_attn.weight": "pytorch_model-00004-of-00011.bin",
|
78 |
+
"transformer.h.18.attn.c_proj.weight": "pytorch_model-00004-of-00011.bin",
|
79 |
+
"transformer.h.18.attn.masked_bias": "pytorch_model-00004-of-00011.bin",
|
80 |
+
"transformer.h.18.ln_1.weight": "pytorch_model-00004-of-00011.bin",
|
81 |
+
"transformer.h.18.ln_2.weight": "pytorch_model-00004-of-00011.bin",
|
82 |
+
"transformer.h.18.mlp.c_fc.weight": "pytorch_model-00004-of-00011.bin",
|
83 |
+
"transformer.h.18.mlp.c_proj.weight": "pytorch_model-00004-of-00011.bin",
|
84 |
+
"transformer.h.19.attn.c_attn.weight": "pytorch_model-00004-of-00011.bin",
|
85 |
+
"transformer.h.19.attn.c_proj.weight": "pytorch_model-00004-of-00011.bin",
|
86 |
+
"transformer.h.19.attn.masked_bias": "pytorch_model-00004-of-00011.bin",
|
87 |
+
"transformer.h.19.ln_1.weight": "pytorch_model-00004-of-00011.bin",
|
88 |
+
"transformer.h.19.ln_2.weight": "pytorch_model-00004-of-00011.bin",
|
89 |
+
"transformer.h.19.mlp.c_fc.weight": "pytorch_model-00004-of-00011.bin",
|
90 |
+
"transformer.h.19.mlp.c_proj.weight": "pytorch_model-00004-of-00011.bin",
|
91 |
+
"transformer.h.2.attn.c_attn.weight": "pytorch_model-00001-of-00011.bin",
|
92 |
+
"transformer.h.2.attn.c_proj.weight": "pytorch_model-00001-of-00011.bin",
|
93 |
+
"transformer.h.2.attn.masked_bias": "pytorch_model-00001-of-00011.bin",
|
94 |
+
"transformer.h.2.ln_1.weight": "pytorch_model-00001-of-00011.bin",
|
95 |
+
"transformer.h.2.ln_2.weight": "pytorch_model-00001-of-00011.bin",
|
96 |
+
"transformer.h.2.mlp.c_fc.weight": "pytorch_model-00001-of-00011.bin",
|
97 |
+
"transformer.h.2.mlp.c_proj.weight": "pytorch_model-00001-of-00011.bin",
|
98 |
+
"transformer.h.20.attn.c_attn.weight": "pytorch_model-00004-of-00011.bin",
|
99 |
+
"transformer.h.20.attn.c_proj.weight": "pytorch_model-00004-of-00011.bin",
|
100 |
+
"transformer.h.20.attn.masked_bias": "pytorch_model-00004-of-00011.bin",
|
101 |
+
"transformer.h.20.ln_1.weight": "pytorch_model-00004-of-00011.bin",
|
102 |
+
"transformer.h.20.ln_2.weight": "pytorch_model-00004-of-00011.bin",
|
103 |
+
"transformer.h.20.mlp.c_fc.weight": "pytorch_model-00004-of-00011.bin",
|
104 |
+
"transformer.h.20.mlp.c_proj.weight": "pytorch_model-00004-of-00011.bin",
|
105 |
+
"transformer.h.21.attn.c_attn.weight": "pytorch_model-00004-of-00011.bin",
|
106 |
+
"transformer.h.21.attn.c_proj.weight": "pytorch_model-00004-of-00011.bin",
|
107 |
+
"transformer.h.21.attn.masked_bias": "pytorch_model-00004-of-00011.bin",
|
108 |
+
"transformer.h.21.ln_1.weight": "pytorch_model-00004-of-00011.bin",
|
109 |
+
"transformer.h.21.ln_2.weight": "pytorch_model-00004-of-00011.bin",
|
110 |
+
"transformer.h.21.mlp.c_fc.weight": "pytorch_model-00004-of-00011.bin",
|
111 |
+
"transformer.h.21.mlp.c_proj.weight": "pytorch_model-00004-of-00011.bin",
|
112 |
+
"transformer.h.22.attn.c_attn.weight": "pytorch_model-00004-of-00011.bin",
|
113 |
+
"transformer.h.22.attn.c_proj.weight": "pytorch_model-00004-of-00011.bin",
|
114 |
+
"transformer.h.22.attn.masked_bias": "pytorch_model-00004-of-00011.bin",
|
115 |
+
"transformer.h.22.ln_1.weight": "pytorch_model-00004-of-00011.bin",
|
116 |
+
"transformer.h.22.ln_2.weight": "pytorch_model-00004-of-00011.bin",
|
117 |
+
"transformer.h.22.mlp.c_fc.weight": "pytorch_model-00004-of-00011.bin",
|
118 |
+
"transformer.h.22.mlp.c_proj.weight": "pytorch_model-00004-of-00011.bin",
|
119 |
+
"transformer.h.23.attn.c_attn.weight": "pytorch_model-00004-of-00011.bin",
|
120 |
+
"transformer.h.23.attn.c_proj.weight": "pytorch_model-00004-of-00011.bin",
|
121 |
+
"transformer.h.23.attn.masked_bias": "pytorch_model-00004-of-00011.bin",
|
122 |
+
"transformer.h.23.ln_1.weight": "pytorch_model-00004-of-00011.bin",
|
123 |
+
"transformer.h.23.ln_2.weight": "pytorch_model-00004-of-00011.bin",
|
124 |
+
"transformer.h.23.mlp.c_fc.weight": "pytorch_model-00005-of-00011.bin",
|
125 |
+
"transformer.h.23.mlp.c_proj.weight": "pytorch_model-00005-of-00011.bin",
|
126 |
+
"transformer.h.24.attn.c_attn.weight": "pytorch_model-00005-of-00011.bin",
|
127 |
+
"transformer.h.24.attn.c_proj.weight": "pytorch_model-00005-of-00011.bin",
|
128 |
+
"transformer.h.24.attn.masked_bias": "pytorch_model-00005-of-00011.bin",
|
129 |
+
"transformer.h.24.ln_1.weight": "pytorch_model-00005-of-00011.bin",
|
130 |
+
"transformer.h.24.ln_2.weight": "pytorch_model-00005-of-00011.bin",
|
131 |
+
"transformer.h.24.mlp.c_fc.weight": "pytorch_model-00005-of-00011.bin",
|
132 |
+
"transformer.h.24.mlp.c_proj.weight": "pytorch_model-00005-of-00011.bin",
|
133 |
+
"transformer.h.25.attn.c_attn.weight": "pytorch_model-00005-of-00011.bin",
|
134 |
+
"transformer.h.25.attn.c_proj.weight": "pytorch_model-00005-of-00011.bin",
|
135 |
+
"transformer.h.25.attn.masked_bias": "pytorch_model-00005-of-00011.bin",
|
136 |
+
"transformer.h.25.ln_1.weight": "pytorch_model-00005-of-00011.bin",
|
137 |
+
"transformer.h.25.ln_2.weight": "pytorch_model-00005-of-00011.bin",
|
138 |
+
"transformer.h.25.mlp.c_fc.weight": "pytorch_model-00005-of-00011.bin",
|
139 |
+
"transformer.h.25.mlp.c_proj.weight": "pytorch_model-00005-of-00011.bin",
|
140 |
+
"transformer.h.26.attn.c_attn.weight": "pytorch_model-00005-of-00011.bin",
|
141 |
+
"transformer.h.26.attn.c_proj.weight": "pytorch_model-00005-of-00011.bin",
|
142 |
+
"transformer.h.26.attn.masked_bias": "pytorch_model-00005-of-00011.bin",
|
143 |
+
"transformer.h.26.ln_1.weight": "pytorch_model-00005-of-00011.bin",
|
144 |
+
"transformer.h.26.ln_2.weight": "pytorch_model-00005-of-00011.bin",
|
145 |
+
"transformer.h.26.mlp.c_fc.weight": "pytorch_model-00005-of-00011.bin",
|
146 |
+
"transformer.h.26.mlp.c_proj.weight": "pytorch_model-00005-of-00011.bin",
|
147 |
+
"transformer.h.27.attn.c_attn.weight": "pytorch_model-00005-of-00011.bin",
|
148 |
+
"transformer.h.27.attn.c_proj.weight": "pytorch_model-00005-of-00011.bin",
|
149 |
+
"transformer.h.27.attn.masked_bias": "pytorch_model-00005-of-00011.bin",
|
150 |
+
"transformer.h.27.ln_1.weight": "pytorch_model-00005-of-00011.bin",
|
151 |
+
"transformer.h.27.ln_2.weight": "pytorch_model-00005-of-00011.bin",
|
152 |
+
"transformer.h.27.mlp.c_fc.weight": "pytorch_model-00005-of-00011.bin",
|
153 |
+
"transformer.h.27.mlp.c_proj.weight": "pytorch_model-00005-of-00011.bin",
|
154 |
+
"transformer.h.28.attn.c_attn.weight": "pytorch_model-00005-of-00011.bin",
|
155 |
+
"transformer.h.28.attn.c_proj.weight": "pytorch_model-00005-of-00011.bin",
|
156 |
+
"transformer.h.28.attn.masked_bias": "pytorch_model-00005-of-00011.bin",
|
157 |
+
"transformer.h.28.ln_1.weight": "pytorch_model-00005-of-00011.bin",
|
158 |
+
"transformer.h.28.ln_2.weight": "pytorch_model-00005-of-00011.bin",
|
159 |
+
"transformer.h.28.mlp.c_fc.weight": "pytorch_model-00005-of-00011.bin",
|
160 |
+
"transformer.h.28.mlp.c_proj.weight": "pytorch_model-00005-of-00011.bin",
|
161 |
+
"transformer.h.29.attn.c_attn.weight": "pytorch_model-00005-of-00011.bin",
|
162 |
+
"transformer.h.29.attn.c_proj.weight": "pytorch_model-00005-of-00011.bin",
|
163 |
+
"transformer.h.29.attn.masked_bias": "pytorch_model-00005-of-00011.bin",
|
164 |
+
"transformer.h.29.ln_1.weight": "pytorch_model-00005-of-00011.bin",
|
165 |
+
"transformer.h.29.ln_2.weight": "pytorch_model-00005-of-00011.bin",
|
166 |
+
"transformer.h.29.mlp.c_fc.weight": "pytorch_model-00006-of-00011.bin",
|
167 |
+
"transformer.h.29.mlp.c_proj.weight": "pytorch_model-00006-of-00011.bin",
|
168 |
+
"transformer.h.3.attn.c_attn.weight": "pytorch_model-00001-of-00011.bin",
|
169 |
+
"transformer.h.3.attn.c_proj.weight": "pytorch_model-00001-of-00011.bin",
|
170 |
+
"transformer.h.3.attn.masked_bias": "pytorch_model-00001-of-00011.bin",
|
171 |
+
"transformer.h.3.ln_1.weight": "pytorch_model-00001-of-00011.bin",
|
172 |
+
"transformer.h.3.ln_2.weight": "pytorch_model-00001-of-00011.bin",
|
173 |
+
"transformer.h.3.mlp.c_fc.weight": "pytorch_model-00001-of-00011.bin",
|
174 |
+
"transformer.h.3.mlp.c_proj.weight": "pytorch_model-00001-of-00011.bin",
|
175 |
+
"transformer.h.30.attn.c_attn.weight": "pytorch_model-00006-of-00011.bin",
|
176 |
+
"transformer.h.30.attn.c_proj.weight": "pytorch_model-00006-of-00011.bin",
|
177 |
+
"transformer.h.30.attn.masked_bias": "pytorch_model-00006-of-00011.bin",
|
178 |
+
"transformer.h.30.ln_1.weight": "pytorch_model-00006-of-00011.bin",
|
179 |
+
"transformer.h.30.ln_2.weight": "pytorch_model-00006-of-00011.bin",
|
180 |
+
"transformer.h.30.mlp.c_fc.weight": "pytorch_model-00006-of-00011.bin",
|
181 |
+
"transformer.h.30.mlp.c_proj.weight": "pytorch_model-00006-of-00011.bin",
|
182 |
+
"transformer.h.31.attn.c_attn.weight": "pytorch_model-00006-of-00011.bin",
|
183 |
+
"transformer.h.31.attn.c_proj.weight": "pytorch_model-00006-of-00011.bin",
|
184 |
+
"transformer.h.31.attn.masked_bias": "pytorch_model-00006-of-00011.bin",
|
185 |
+
"transformer.h.31.ln_1.weight": "pytorch_model-00006-of-00011.bin",
|
186 |
+
"transformer.h.31.ln_2.weight": "pytorch_model-00006-of-00011.bin",
|
187 |
+
"transformer.h.31.mlp.c_fc.weight": "pytorch_model-00006-of-00011.bin",
|
188 |
+
"transformer.h.31.mlp.c_proj.weight": "pytorch_model-00006-of-00011.bin",
|
189 |
+
"transformer.h.32.attn.c_attn.weight": "pytorch_model-00006-of-00011.bin",
|
190 |
+
"transformer.h.32.attn.c_proj.weight": "pytorch_model-00006-of-00011.bin",
|
191 |
+
"transformer.h.32.attn.masked_bias": "pytorch_model-00006-of-00011.bin",
|
192 |
+
"transformer.h.32.ln_1.weight": "pytorch_model-00006-of-00011.bin",
|
193 |
+
"transformer.h.32.ln_2.weight": "pytorch_model-00006-of-00011.bin",
|
194 |
+
"transformer.h.32.mlp.c_fc.weight": "pytorch_model-00006-of-00011.bin",
|
195 |
+
"transformer.h.32.mlp.c_proj.weight": "pytorch_model-00006-of-00011.bin",
|
196 |
+
"transformer.h.33.attn.c_attn.weight": "pytorch_model-00006-of-00011.bin",
|
197 |
+
"transformer.h.33.attn.c_proj.weight": "pytorch_model-00006-of-00011.bin",
|
198 |
+
"transformer.h.33.attn.masked_bias": "pytorch_model-00006-of-00011.bin",
|
199 |
+
"transformer.h.33.ln_1.weight": "pytorch_model-00006-of-00011.bin",
|
200 |
+
"transformer.h.33.ln_2.weight": "pytorch_model-00006-of-00011.bin",
|
201 |
+
"transformer.h.33.mlp.c_fc.weight": "pytorch_model-00006-of-00011.bin",
|
202 |
+
"transformer.h.33.mlp.c_proj.weight": "pytorch_model-00006-of-00011.bin",
|
203 |
+
"transformer.h.34.attn.c_attn.weight": "pytorch_model-00006-of-00011.bin",
|
204 |
+
"transformer.h.34.attn.c_proj.weight": "pytorch_model-00006-of-00011.bin",
|
205 |
+
"transformer.h.34.attn.masked_bias": "pytorch_model-00006-of-00011.bin",
|
206 |
+
"transformer.h.34.ln_1.weight": "pytorch_model-00006-of-00011.bin",
|
207 |
+
"transformer.h.34.ln_2.weight": "pytorch_model-00006-of-00011.bin",
|
208 |
+
"transformer.h.34.mlp.c_fc.weight": "pytorch_model-00006-of-00011.bin",
|
209 |
+
"transformer.h.34.mlp.c_proj.weight": "pytorch_model-00006-of-00011.bin",
|
210 |
+
"transformer.h.35.attn.c_attn.weight": "pytorch_model-00006-of-00011.bin",
|
211 |
+
"transformer.h.35.attn.c_proj.weight": "pytorch_model-00006-of-00011.bin",
|
212 |
+
"transformer.h.35.attn.masked_bias": "pytorch_model-00006-of-00011.bin",
|
213 |
+
"transformer.h.35.ln_1.weight": "pytorch_model-00006-of-00011.bin",
|
214 |
+
"transformer.h.35.ln_2.weight": "pytorch_model-00006-of-00011.bin",
|
215 |
+
"transformer.h.35.mlp.c_fc.weight": "pytorch_model-00007-of-00011.bin",
|
216 |
+
"transformer.h.35.mlp.c_proj.weight": "pytorch_model-00007-of-00011.bin",
|
217 |
+
"transformer.h.36.attn.c_attn.weight": "pytorch_model-00007-of-00011.bin",
|
218 |
+
"transformer.h.36.attn.c_proj.weight": "pytorch_model-00007-of-00011.bin",
|
219 |
+
"transformer.h.36.attn.masked_bias": "pytorch_model-00007-of-00011.bin",
|
220 |
+
"transformer.h.36.ln_1.weight": "pytorch_model-00007-of-00011.bin",
|
221 |
+
"transformer.h.36.ln_2.weight": "pytorch_model-00007-of-00011.bin",
|
222 |
+
"transformer.h.36.mlp.c_fc.weight": "pytorch_model-00007-of-00011.bin",
|
223 |
+
"transformer.h.36.mlp.c_proj.weight": "pytorch_model-00007-of-00011.bin",
|
224 |
+
"transformer.h.37.attn.c_attn.weight": "pytorch_model-00007-of-00011.bin",
|
225 |
+
"transformer.h.37.attn.c_proj.weight": "pytorch_model-00007-of-00011.bin",
|
226 |
+
"transformer.h.37.attn.masked_bias": "pytorch_model-00007-of-00011.bin",
|
227 |
+
"transformer.h.37.ln_1.weight": "pytorch_model-00007-of-00011.bin",
|
228 |
+
"transformer.h.37.ln_2.weight": "pytorch_model-00007-of-00011.bin",
|
229 |
+
"transformer.h.37.mlp.c_fc.weight": "pytorch_model-00007-of-00011.bin",
|
230 |
+
"transformer.h.37.mlp.c_proj.weight": "pytorch_model-00007-of-00011.bin",
|
231 |
+
"transformer.h.38.attn.c_attn.weight": "pytorch_model-00007-of-00011.bin",
|
232 |
+
"transformer.h.38.attn.c_proj.weight": "pytorch_model-00007-of-00011.bin",
|
233 |
+
"transformer.h.38.attn.masked_bias": "pytorch_model-00007-of-00011.bin",
|
234 |
+
"transformer.h.38.ln_1.weight": "pytorch_model-00007-of-00011.bin",
|
235 |
+
"transformer.h.38.ln_2.weight": "pytorch_model-00007-of-00011.bin",
|
236 |
+
"transformer.h.38.mlp.c_fc.weight": "pytorch_model-00007-of-00011.bin",
|
237 |
+
"transformer.h.38.mlp.c_proj.weight": "pytorch_model-00007-of-00011.bin",
|
238 |
+
"transformer.h.39.attn.c_attn.weight": "pytorch_model-00007-of-00011.bin",
|
239 |
+
"transformer.h.39.attn.c_proj.weight": "pytorch_model-00007-of-00011.bin",
|
240 |
+
"transformer.h.39.attn.masked_bias": "pytorch_model-00007-of-00011.bin",
|
241 |
+
"transformer.h.39.ln_1.weight": "pytorch_model-00007-of-00011.bin",
|
242 |
+
"transformer.h.39.ln_2.weight": "pytorch_model-00007-of-00011.bin",
|
243 |
+
"transformer.h.39.mlp.c_fc.weight": "pytorch_model-00007-of-00011.bin",
|
244 |
+
"transformer.h.39.mlp.c_proj.weight": "pytorch_model-00007-of-00011.bin",
|
245 |
+
"transformer.h.4.attn.c_attn.weight": "pytorch_model-00001-of-00011.bin",
|
246 |
+
"transformer.h.4.attn.c_proj.weight": "pytorch_model-00001-of-00011.bin",
|
247 |
+
"transformer.h.4.attn.masked_bias": "pytorch_model-00001-of-00011.bin",
|
248 |
+
"transformer.h.4.ln_1.weight": "pytorch_model-00001-of-00011.bin",
|
249 |
+
"transformer.h.4.ln_2.weight": "pytorch_model-00001-of-00011.bin",
|
250 |
+
"transformer.h.4.mlp.c_fc.weight": "pytorch_model-00001-of-00011.bin",
|
251 |
+
"transformer.h.4.mlp.c_proj.weight": "pytorch_model-00001-of-00011.bin",
|
252 |
+
"transformer.h.40.attn.c_attn.weight": "pytorch_model-00007-of-00011.bin",
|
253 |
+
"transformer.h.40.attn.c_proj.weight": "pytorch_model-00007-of-00011.bin",
|
254 |
+
"transformer.h.40.attn.masked_bias": "pytorch_model-00007-of-00011.bin",
|
255 |
+
"transformer.h.40.ln_1.weight": "pytorch_model-00007-of-00011.bin",
|
256 |
+
"transformer.h.40.ln_2.weight": "pytorch_model-00007-of-00011.bin",
|
257 |
+
"transformer.h.40.mlp.c_fc.weight": "pytorch_model-00007-of-00011.bin",
|
258 |
+
"transformer.h.40.mlp.c_proj.weight": "pytorch_model-00007-of-00011.bin",
|
259 |
+
"transformer.h.41.attn.c_attn.weight": "pytorch_model-00007-of-00011.bin",
|
260 |
+
"transformer.h.41.attn.c_proj.weight": "pytorch_model-00007-of-00011.bin",
|
261 |
+
"transformer.h.41.attn.masked_bias": "pytorch_model-00007-of-00011.bin",
|
262 |
+
"transformer.h.41.ln_1.weight": "pytorch_model-00007-of-00011.bin",
|
263 |
+
"transformer.h.41.ln_2.weight": "pytorch_model-00007-of-00011.bin",
|
264 |
+
"transformer.h.41.mlp.c_fc.weight": "pytorch_model-00008-of-00011.bin",
|
265 |
+
"transformer.h.41.mlp.c_proj.weight": "pytorch_model-00008-of-00011.bin",
|
266 |
+
"transformer.h.42.attn.c_attn.weight": "pytorch_model-00008-of-00011.bin",
|
267 |
+
"transformer.h.42.attn.c_proj.weight": "pytorch_model-00008-of-00011.bin",
|
268 |
+
"transformer.h.42.attn.masked_bias": "pytorch_model-00008-of-00011.bin",
|
269 |
+
"transformer.h.42.ln_1.weight": "pytorch_model-00008-of-00011.bin",
|
270 |
+
"transformer.h.42.ln_2.weight": "pytorch_model-00008-of-00011.bin",
|
271 |
+
"transformer.h.42.mlp.c_fc.weight": "pytorch_model-00008-of-00011.bin",
|
272 |
+
"transformer.h.42.mlp.c_proj.weight": "pytorch_model-00008-of-00011.bin",
|
273 |
+
"transformer.h.43.attn.c_attn.weight": "pytorch_model-00008-of-00011.bin",
|
274 |
+
"transformer.h.43.attn.c_proj.weight": "pytorch_model-00008-of-00011.bin",
|
275 |
+
"transformer.h.43.attn.masked_bias": "pytorch_model-00008-of-00011.bin",
|
276 |
+
"transformer.h.43.ln_1.weight": "pytorch_model-00008-of-00011.bin",
|
277 |
+
"transformer.h.43.ln_2.weight": "pytorch_model-00008-of-00011.bin",
|
278 |
+
"transformer.h.43.mlp.c_fc.weight": "pytorch_model-00008-of-00011.bin",
|
279 |
+
"transformer.h.43.mlp.c_proj.weight": "pytorch_model-00008-of-00011.bin",
|
280 |
+
"transformer.h.44.attn.c_attn.weight": "pytorch_model-00008-of-00011.bin",
|
281 |
+
"transformer.h.44.attn.c_proj.weight": "pytorch_model-00008-of-00011.bin",
|
282 |
+
"transformer.h.44.attn.masked_bias": "pytorch_model-00008-of-00011.bin",
|
283 |
+
"transformer.h.44.ln_1.weight": "pytorch_model-00008-of-00011.bin",
|
284 |
+
"transformer.h.44.ln_2.weight": "pytorch_model-00008-of-00011.bin",
|
285 |
+
"transformer.h.44.mlp.c_fc.weight": "pytorch_model-00008-of-00011.bin",
|
286 |
+
"transformer.h.44.mlp.c_proj.weight": "pytorch_model-00008-of-00011.bin",
|
287 |
+
"transformer.h.45.attn.c_attn.weight": "pytorch_model-00008-of-00011.bin",
|
288 |
+
"transformer.h.45.attn.c_proj.weight": "pytorch_model-00008-of-00011.bin",
|
289 |
+
"transformer.h.45.attn.masked_bias": "pytorch_model-00008-of-00011.bin",
|
290 |
+
"transformer.h.45.ln_1.weight": "pytorch_model-00008-of-00011.bin",
|
291 |
+
"transformer.h.45.ln_2.weight": "pytorch_model-00008-of-00011.bin",
|
292 |
+
"transformer.h.45.mlp.c_fc.weight": "pytorch_model-00008-of-00011.bin",
|
293 |
+
"transformer.h.45.mlp.c_proj.weight": "pytorch_model-00008-of-00011.bin",
|
294 |
+
"transformer.h.46.attn.c_attn.weight": "pytorch_model-00008-of-00011.bin",
|
295 |
+
"transformer.h.46.attn.c_proj.weight": "pytorch_model-00008-of-00011.bin",
|
296 |
+
"transformer.h.46.attn.masked_bias": "pytorch_model-00008-of-00011.bin",
|
297 |
+
"transformer.h.46.ln_1.weight": "pytorch_model-00008-of-00011.bin",
|
298 |
+
"transformer.h.46.ln_2.weight": "pytorch_model-00008-of-00011.bin",
|
299 |
+
"transformer.h.46.mlp.c_fc.weight": "pytorch_model-00008-of-00011.bin",
|
300 |
+
"transformer.h.46.mlp.c_proj.weight": "pytorch_model-00008-of-00011.bin",
|
301 |
+
"transformer.h.47.attn.c_attn.weight": "pytorch_model-00008-of-00011.bin",
|
302 |
+
"transformer.h.47.attn.c_proj.weight": "pytorch_model-00008-of-00011.bin",
|
303 |
+
"transformer.h.47.attn.masked_bias": "pytorch_model-00008-of-00011.bin",
|
304 |
+
"transformer.h.47.ln_1.weight": "pytorch_model-00008-of-00011.bin",
|
305 |
+
"transformer.h.47.ln_2.weight": "pytorch_model-00008-of-00011.bin",
|
306 |
+
"transformer.h.47.mlp.c_fc.weight": "pytorch_model-00009-of-00011.bin",
|
307 |
+
"transformer.h.47.mlp.c_proj.weight": "pytorch_model-00009-of-00011.bin",
|
308 |
+
"transformer.h.48.attn.c_attn.weight": "pytorch_model-00009-of-00011.bin",
|
309 |
+
"transformer.h.48.attn.c_proj.weight": "pytorch_model-00009-of-00011.bin",
|
310 |
+
"transformer.h.48.attn.masked_bias": "pytorch_model-00009-of-00011.bin",
|
311 |
+
"transformer.h.48.ln_1.weight": "pytorch_model-00009-of-00011.bin",
|
312 |
+
"transformer.h.48.ln_2.weight": "pytorch_model-00009-of-00011.bin",
|
313 |
+
"transformer.h.48.mlp.c_fc.weight": "pytorch_model-00009-of-00011.bin",
|
314 |
+
"transformer.h.48.mlp.c_proj.weight": "pytorch_model-00009-of-00011.bin",
|
315 |
+
"transformer.h.49.attn.c_attn.weight": "pytorch_model-00009-of-00011.bin",
|
316 |
+
"transformer.h.49.attn.c_proj.weight": "pytorch_model-00009-of-00011.bin",
|
317 |
+
"transformer.h.49.attn.masked_bias": "pytorch_model-00009-of-00011.bin",
|
318 |
+
"transformer.h.49.ln_1.weight": "pytorch_model-00009-of-00011.bin",
|
319 |
+
"transformer.h.49.ln_2.weight": "pytorch_model-00009-of-00011.bin",
|
320 |
+
"transformer.h.49.mlp.c_fc.weight": "pytorch_model-00009-of-00011.bin",
|
321 |
+
"transformer.h.49.mlp.c_proj.weight": "pytorch_model-00009-of-00011.bin",
|
322 |
+
"transformer.h.5.attn.c_attn.weight": "pytorch_model-00001-of-00011.bin",
|
323 |
+
"transformer.h.5.attn.c_proj.weight": "pytorch_model-00001-of-00011.bin",
|
324 |
+
"transformer.h.5.attn.masked_bias": "pytorch_model-00001-of-00011.bin",
|
325 |
+
"transformer.h.5.ln_1.weight": "pytorch_model-00001-of-00011.bin",
|
326 |
+
"transformer.h.5.ln_2.weight": "pytorch_model-00001-of-00011.bin",
|
327 |
+
"transformer.h.5.mlp.c_fc.weight": "pytorch_model-00002-of-00011.bin",
|
328 |
+
"transformer.h.5.mlp.c_proj.weight": "pytorch_model-00002-of-00011.bin",
|
329 |
+
"transformer.h.50.attn.c_attn.weight": "pytorch_model-00009-of-00011.bin",
|
330 |
+
"transformer.h.50.attn.c_proj.weight": "pytorch_model-00009-of-00011.bin",
|
331 |
+
"transformer.h.50.attn.masked_bias": "pytorch_model-00009-of-00011.bin",
|
332 |
+
"transformer.h.50.ln_1.weight": "pytorch_model-00009-of-00011.bin",
|
333 |
+
"transformer.h.50.ln_2.weight": "pytorch_model-00009-of-00011.bin",
|
334 |
+
"transformer.h.50.mlp.c_fc.weight": "pytorch_model-00009-of-00011.bin",
|
335 |
+
"transformer.h.50.mlp.c_proj.weight": "pytorch_model-00009-of-00011.bin",
|
336 |
+
"transformer.h.51.attn.c_attn.weight": "pytorch_model-00009-of-00011.bin",
|
337 |
+
"transformer.h.51.attn.c_proj.weight": "pytorch_model-00009-of-00011.bin",
|
338 |
+
"transformer.h.51.attn.masked_bias": "pytorch_model-00009-of-00011.bin",
|
339 |
+
"transformer.h.51.ln_1.weight": "pytorch_model-00009-of-00011.bin",
|
340 |
+
"transformer.h.51.ln_2.weight": "pytorch_model-00009-of-00011.bin",
|
341 |
+
"transformer.h.51.mlp.c_fc.weight": "pytorch_model-00009-of-00011.bin",
|
342 |
+
"transformer.h.51.mlp.c_proj.weight": "pytorch_model-00009-of-00011.bin",
|
343 |
+
"transformer.h.52.attn.c_attn.weight": "pytorch_model-00009-of-00011.bin",
|
344 |
+
"transformer.h.52.attn.c_proj.weight": "pytorch_model-00009-of-00011.bin",
|
345 |
+
"transformer.h.52.attn.masked_bias": "pytorch_model-00009-of-00011.bin",
|
346 |
+
"transformer.h.52.ln_1.weight": "pytorch_model-00009-of-00011.bin",
|
347 |
+
"transformer.h.52.ln_2.weight": "pytorch_model-00009-of-00011.bin",
|
348 |
+
"transformer.h.52.mlp.c_fc.weight": "pytorch_model-00009-of-00011.bin",
|
349 |
+
"transformer.h.52.mlp.c_proj.weight": "pytorch_model-00009-of-00011.bin",
|
350 |
+
"transformer.h.53.attn.c_attn.weight": "pytorch_model-00009-of-00011.bin",
|
351 |
+
"transformer.h.53.attn.c_proj.weight": "pytorch_model-00009-of-00011.bin",
|
352 |
+
"transformer.h.53.attn.masked_bias": "pytorch_model-00009-of-00011.bin",
|
353 |
+
"transformer.h.53.ln_1.weight": "pytorch_model-00009-of-00011.bin",
|
354 |
+
"transformer.h.53.ln_2.weight": "pytorch_model-00009-of-00011.bin",
|
355 |
+
"transformer.h.53.mlp.c_fc.weight": "pytorch_model-00010-of-00011.bin",
|
356 |
+
"transformer.h.53.mlp.c_proj.weight": "pytorch_model-00010-of-00011.bin",
|
357 |
+
"transformer.h.54.attn.c_attn.weight": "pytorch_model-00010-of-00011.bin",
|
358 |
+
"transformer.h.54.attn.c_proj.weight": "pytorch_model-00010-of-00011.bin",
|
359 |
+
"transformer.h.54.attn.masked_bias": "pytorch_model-00010-of-00011.bin",
|
360 |
+
"transformer.h.54.ln_1.weight": "pytorch_model-00010-of-00011.bin",
|
361 |
+
"transformer.h.54.ln_2.weight": "pytorch_model-00010-of-00011.bin",
|
362 |
+
"transformer.h.54.mlp.c_fc.weight": "pytorch_model-00010-of-00011.bin",
|
363 |
+
"transformer.h.54.mlp.c_proj.weight": "pytorch_model-00010-of-00011.bin",
|
364 |
+
"transformer.h.55.attn.c_attn.weight": "pytorch_model-00010-of-00011.bin",
|
365 |
+
"transformer.h.55.attn.c_proj.weight": "pytorch_model-00010-of-00011.bin",
|
366 |
+
"transformer.h.55.attn.masked_bias": "pytorch_model-00010-of-00011.bin",
|
367 |
+
"transformer.h.55.ln_1.weight": "pytorch_model-00010-of-00011.bin",
|
368 |
+
"transformer.h.55.ln_2.weight": "pytorch_model-00010-of-00011.bin",
|
369 |
+
"transformer.h.55.mlp.c_fc.weight": "pytorch_model-00010-of-00011.bin",
|
370 |
+
"transformer.h.55.mlp.c_proj.weight": "pytorch_model-00010-of-00011.bin",
|
371 |
+
"transformer.h.56.attn.c_attn.weight": "pytorch_model-00010-of-00011.bin",
|
372 |
+
"transformer.h.56.attn.c_proj.weight": "pytorch_model-00010-of-00011.bin",
|
373 |
+
"transformer.h.56.attn.masked_bias": "pytorch_model-00010-of-00011.bin",
|
374 |
+
"transformer.h.56.ln_1.weight": "pytorch_model-00010-of-00011.bin",
|
375 |
+
"transformer.h.56.ln_2.weight": "pytorch_model-00010-of-00011.bin",
|
376 |
+
"transformer.h.56.mlp.c_fc.weight": "pytorch_model-00010-of-00011.bin",
|
377 |
+
"transformer.h.56.mlp.c_proj.weight": "pytorch_model-00010-of-00011.bin",
|
378 |
+
"transformer.h.57.attn.c_attn.weight": "pytorch_model-00010-of-00011.bin",
|
379 |
+
"transformer.h.57.attn.c_proj.weight": "pytorch_model-00010-of-00011.bin",
|
380 |
+
"transformer.h.57.attn.masked_bias": "pytorch_model-00010-of-00011.bin",
|
381 |
+
"transformer.h.57.ln_1.weight": "pytorch_model-00010-of-00011.bin",
|
382 |
+
"transformer.h.57.ln_2.weight": "pytorch_model-00010-of-00011.bin",
|
383 |
+
"transformer.h.57.mlp.c_fc.weight": "pytorch_model-00010-of-00011.bin",
|
384 |
+
"transformer.h.57.mlp.c_proj.weight": "pytorch_model-00010-of-00011.bin",
|
385 |
+
"transformer.h.58.attn.c_attn.weight": "pytorch_model-00010-of-00011.bin",
|
386 |
+
"transformer.h.58.attn.c_proj.weight": "pytorch_model-00010-of-00011.bin",
|
387 |
+
"transformer.h.58.attn.masked_bias": "pytorch_model-00010-of-00011.bin",
|
388 |
+
"transformer.h.58.ln_1.weight": "pytorch_model-00010-of-00011.bin",
|
389 |
+
"transformer.h.58.ln_2.weight": "pytorch_model-00010-of-00011.bin",
|
390 |
+
"transformer.h.58.mlp.c_fc.weight": "pytorch_model-00010-of-00011.bin",
|
391 |
+
"transformer.h.58.mlp.c_proj.weight": "pytorch_model-00010-of-00011.bin",
|
392 |
+
"transformer.h.59.attn.c_attn.weight": "pytorch_model-00010-of-00011.bin",
|
393 |
+
"transformer.h.59.attn.c_proj.weight": "pytorch_model-00010-of-00011.bin",
|
394 |
+
"transformer.h.59.attn.masked_bias": "pytorch_model-00010-of-00011.bin",
|
395 |
+
"transformer.h.59.ln_1.weight": "pytorch_model-00010-of-00011.bin",
|
396 |
+
"transformer.h.59.ln_2.weight": "pytorch_model-00010-of-00011.bin",
|
397 |
+
"transformer.h.59.mlp.c_fc.weight": "pytorch_model-00011-of-00011.bin",
|
398 |
+
"transformer.h.59.mlp.c_proj.weight": "pytorch_model-00011-of-00011.bin",
|
399 |
+
"transformer.h.6.attn.c_attn.weight": "pytorch_model-00002-of-00011.bin",
|
400 |
+
"transformer.h.6.attn.c_proj.weight": "pytorch_model-00002-of-00011.bin",
|
401 |
+
"transformer.h.6.attn.masked_bias": "pytorch_model-00002-of-00011.bin",
|
402 |
+
"transformer.h.6.ln_1.weight": "pytorch_model-00002-of-00011.bin",
|
403 |
+
"transformer.h.6.ln_2.weight": "pytorch_model-00002-of-00011.bin",
|
404 |
+
"transformer.h.6.mlp.c_fc.weight": "pytorch_model-00002-of-00011.bin",
|
405 |
+
"transformer.h.6.mlp.c_proj.weight": "pytorch_model-00002-of-00011.bin",
|
406 |
+
"transformer.h.60.attn.c_attn.weight": "pytorch_model-00011-of-00011.bin",
|
407 |
+
"transformer.h.60.attn.c_proj.weight": "pytorch_model-00011-of-00011.bin",
|
408 |
+
"transformer.h.60.attn.masked_bias": "pytorch_model-00011-of-00011.bin",
|
409 |
+
"transformer.h.60.ln_1.weight": "pytorch_model-00011-of-00011.bin",
|
410 |
+
"transformer.h.60.ln_2.weight": "pytorch_model-00011-of-00011.bin",
|
411 |
+
"transformer.h.60.mlp.c_fc.weight": "pytorch_model-00011-of-00011.bin",
|
412 |
+
"transformer.h.60.mlp.c_proj.weight": "pytorch_model-00011-of-00011.bin",
|
413 |
+
"transformer.h.61.attn.c_attn.weight": "pytorch_model-00011-of-00011.bin",
|
414 |
+
"transformer.h.61.attn.c_proj.weight": "pytorch_model-00011-of-00011.bin",
|
415 |
+
"transformer.h.61.attn.masked_bias": "pytorch_model-00011-of-00011.bin",
|
416 |
+
"transformer.h.61.ln_1.weight": "pytorch_model-00011-of-00011.bin",
|
417 |
+
"transformer.h.61.ln_2.weight": "pytorch_model-00011-of-00011.bin",
|
418 |
+
"transformer.h.61.mlp.c_fc.weight": "pytorch_model-00011-of-00011.bin",
|
419 |
+
"transformer.h.61.mlp.c_proj.weight": "pytorch_model-00011-of-00011.bin",
|
420 |
+
"transformer.h.62.attn.c_attn.weight": "pytorch_model-00011-of-00011.bin",
|
421 |
+
"transformer.h.62.attn.c_proj.weight": "pytorch_model-00011-of-00011.bin",
|
422 |
+
"transformer.h.62.attn.masked_bias": "pytorch_model-00011-of-00011.bin",
|
423 |
+
"transformer.h.62.ln_1.weight": "pytorch_model-00011-of-00011.bin",
|
424 |
+
"transformer.h.62.ln_2.weight": "pytorch_model-00011-of-00011.bin",
|
425 |
+
"transformer.h.62.mlp.c_fc.weight": "pytorch_model-00011-of-00011.bin",
|
426 |
+
"transformer.h.62.mlp.c_proj.weight": "pytorch_model-00011-of-00011.bin",
|
427 |
+
"transformer.h.63.attn.c_attn.weight": "pytorch_model-00011-of-00011.bin",
|
428 |
+
"transformer.h.63.attn.c_proj.weight": "pytorch_model-00011-of-00011.bin",
|
429 |
+
"transformer.h.63.attn.masked_bias": "pytorch_model-00011-of-00011.bin",
|
430 |
+
"transformer.h.63.ln_1.weight": "pytorch_model-00011-of-00011.bin",
|
431 |
+
"transformer.h.63.ln_2.weight": "pytorch_model-00011-of-00011.bin",
|
432 |
+
"transformer.h.63.mlp.c_fc.weight": "pytorch_model-00011-of-00011.bin",
|
433 |
+
"transformer.h.63.mlp.c_proj.weight": "pytorch_model-00011-of-00011.bin",
|
434 |
+
"transformer.h.7.attn.c_attn.weight": "pytorch_model-00002-of-00011.bin",
|
435 |
+
"transformer.h.7.attn.c_proj.weight": "pytorch_model-00002-of-00011.bin",
|
436 |
+
"transformer.h.7.attn.masked_bias": "pytorch_model-00002-of-00011.bin",
|
437 |
+
"transformer.h.7.ln_1.weight": "pytorch_model-00002-of-00011.bin",
|
438 |
+
"transformer.h.7.ln_2.weight": "pytorch_model-00002-of-00011.bin",
|
439 |
+
"transformer.h.7.mlp.c_fc.weight": "pytorch_model-00002-of-00011.bin",
|
440 |
+
"transformer.h.7.mlp.c_proj.weight": "pytorch_model-00002-of-00011.bin",
|
441 |
+
"transformer.h.8.attn.c_attn.weight": "pytorch_model-00002-of-00011.bin",
|
442 |
+
"transformer.h.8.attn.c_proj.weight": "pytorch_model-00002-of-00011.bin",
|
443 |
+
"transformer.h.8.attn.masked_bias": "pytorch_model-00002-of-00011.bin",
|
444 |
+
"transformer.h.8.ln_1.weight": "pytorch_model-00002-of-00011.bin",
|
445 |
+
"transformer.h.8.ln_2.weight": "pytorch_model-00002-of-00011.bin",
|
446 |
+
"transformer.h.8.mlp.c_fc.weight": "pytorch_model-00002-of-00011.bin",
|
447 |
+
"transformer.h.8.mlp.c_proj.weight": "pytorch_model-00002-of-00011.bin",
|
448 |
+
"transformer.h.9.attn.c_attn.weight": "pytorch_model-00002-of-00011.bin",
|
449 |
+
"transformer.h.9.attn.c_proj.weight": "pytorch_model-00002-of-00011.bin",
|
450 |
+
"transformer.h.9.attn.masked_bias": "pytorch_model-00002-of-00011.bin",
|
451 |
+
"transformer.h.9.ln_1.weight": "pytorch_model-00002-of-00011.bin",
|
452 |
+
"transformer.h.9.ln_2.weight": "pytorch_model-00002-of-00011.bin",
|
453 |
+
"transformer.h.9.mlp.c_fc.weight": "pytorch_model-00002-of-00011.bin",
|
454 |
+
"transformer.h.9.mlp.c_proj.weight": "pytorch_model-00002-of-00011.bin",
|
455 |
+
"transformer.ln_f.weight": "pytorch_model-00011-of-00011.bin",
|
456 |
+
"transformer.wte.weight": "pytorch_model-00001-of-00011.bin"
|
457 |
+
}
|
458 |
+
}
|
special_tokens_map.json
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": {
|
3 |
+
"content": "<s>",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": true,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"eos_token": {
|
10 |
+
"content": "</s>",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": true,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false
|
15 |
+
},
|
16 |
+
"pad_token": {
|
17 |
+
"content": "<pad>",
|
18 |
+
"lstrip": false,
|
19 |
+
"normalized": true,
|
20 |
+
"rstrip": false,
|
21 |
+
"single_word": false
|
22 |
+
},
|
23 |
+
"unk_token": {
|
24 |
+
"content": "<unk>",
|
25 |
+
"lstrip": false,
|
26 |
+
"normalized": true,
|
27 |
+
"rstrip": false,
|
28 |
+
"single_word": false
|
29 |
+
}
|
30 |
+
}
|
tokenization_telechat.py
ADDED
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
|
21 |
+
"""Tokenization classes for TELECHAT."""
|
22 |
+
import os
|
23 |
+
from shutil import copyfile
|
24 |
+
from typing import Any, Dict, List, Optional, Tuple
|
25 |
+
|
26 |
+
import sentencepiece as spm
|
27 |
+
import re
|
28 |
+
from transformers.convert_slow_tokenizer import import_protobuf
|
29 |
+
from transformers import AddedToken, PreTrainedTokenizer
|
30 |
+
from transformers.utils import logging
|
31 |
+
from transformers.tokenization_utils_base import TextInput
|
32 |
+
|
33 |
+
logger = logging.get_logger(__name__)
|
34 |
+
|
35 |
+
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
|
36 |
+
|
37 |
+
PRETRAINED_VOCAB_FILES_MAP = {
|
38 |
+
"vocab_file": {},
|
39 |
+
"tokenizer_file": {},
|
40 |
+
}
|
41 |
+
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
42 |
+
"telechat-tokenizer": 8192,
|
43 |
+
}
|
44 |
+
SPIECE_UNDERLINE = "▁"
|
45 |
+
|
46 |
+
|
47 |
+
class TELECHATTokenizer(PreTrainedTokenizer):
|
48 |
+
"""
|
49 |
+
Construct a TELECHAT tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
|
50 |
+
no padding token in the original model.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
vocab_file (`str`):
|
54 |
+
Path to the vocabulary file.
|
55 |
+
unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
|
56 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
57 |
+
token instead.
|
58 |
+
bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
|
59 |
+
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
|
60 |
+
eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
|
61 |
+
The end of sequence token.
|
62 |
+
pad_token (`str` or `tokenizers.AddedToken`, *optional*):
|
63 |
+
A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
|
64 |
+
attention mechanisms or loss computation.
|
65 |
+
sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*):
|
66 |
+
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
|
67 |
+
SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
|
68 |
+
to set:
|
69 |
+
|
70 |
+
- `enable_sampling`: Enable subword regularization.
|
71 |
+
- `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
|
72 |
+
|
73 |
+
- `nbest_size = {0,1}`: No sampling is performed.
|
74 |
+
- `nbest_size > 1`: samples from the nbest_size results.
|
75 |
+
- `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
|
76 |
+
using forward-filtering-and-backward-sampling algorithm.
|
77 |
+
|
78 |
+
- `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
|
79 |
+
BPE-dropout.
|
80 |
+
|
81 |
+
add_bos_token (`bool`, *optional*, defaults to `True`):
|
82 |
+
Whether or not to add an `bos_token` at the start of sequences.
|
83 |
+
add_eos_token (`bool`, *optional*, defaults to `False`):
|
84 |
+
Whether or not to add an `eos_token` at the end of sequences.
|
85 |
+
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
|
86 |
+
Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
|
87 |
+
extra spaces.
|
88 |
+
spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
|
89 |
+
Whether or not to add spaces between special tokens.
|
90 |
+
|
91 |
+
"""
|
92 |
+
|
93 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
94 |
+
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
95 |
+
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
96 |
+
model_input_names = ["input_ids", "attention_mask"]
|
97 |
+
|
98 |
+
def __init__(
|
99 |
+
self,
|
100 |
+
vocab_file,
|
101 |
+
bos_token="<s>",
|
102 |
+
eos_token="</s>",
|
103 |
+
unk_token="<unk>",
|
104 |
+
pad_token=None,
|
105 |
+
sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
106 |
+
add_bos_token=False,
|
107 |
+
add_eos_token=False,
|
108 |
+
clean_up_tokenization_spaces=False,
|
109 |
+
spaces_between_special_tokens=False,
|
110 |
+
**kwargs,
|
111 |
+
):
|
112 |
+
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
113 |
+
bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
|
114 |
+
eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
|
115 |
+
pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
|
116 |
+
self.vocab_file = vocab_file
|
117 |
+
self.add_bos_token = add_bos_token
|
118 |
+
self.add_eos_token = add_eos_token
|
119 |
+
self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
|
120 |
+
super().__init__(
|
121 |
+
bos_token=bos_token,
|
122 |
+
eos_token=eos_token,
|
123 |
+
unk_token=unk_token,
|
124 |
+
pad_token=pad_token,
|
125 |
+
add_bos_token=add_bos_token,
|
126 |
+
add_eos_token=add_eos_token,
|
127 |
+
sp_model_kwargs=self.sp_model_kwargs,
|
128 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
129 |
+
spaces_between_special_tokens=spaces_between_special_tokens,
|
130 |
+
**kwargs,
|
131 |
+
)
|
132 |
+
|
133 |
+
@property
|
134 |
+
def unk_token_length(self):
|
135 |
+
return len(self.sp_model.encode(str(self.unk_token)))
|
136 |
+
|
137 |
+
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
|
138 |
+
def get_spm_processor(self, from_slow=False):
|
139 |
+
tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
140 |
+
with open(self.vocab_file, "rb") as f:
|
141 |
+
sp_model = f.read()
|
142 |
+
model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
|
143 |
+
model = model_pb2.ModelProto.FromString(sp_model)
|
144 |
+
normalizer_spec = model_pb2.NormalizerSpec()
|
145 |
+
normalizer_spec.add_dummy_prefix = True
|
146 |
+
model.normalizer_spec.MergeFrom(normalizer_spec)
|
147 |
+
sp_model = model.SerializeToString()
|
148 |
+
tokenizer.LoadFromSerializedProto(sp_model)
|
149 |
+
return tokenizer
|
150 |
+
|
151 |
+
def __getstate__(self):
|
152 |
+
state = self.__dict__.copy()
|
153 |
+
state["sp_model"] = None
|
154 |
+
state["sp_model_proto"] = self.sp_model.serialized_model_proto()
|
155 |
+
return state
|
156 |
+
|
157 |
+
def __setstate__(self, d):
|
158 |
+
self.__dict__ = d
|
159 |
+
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
160 |
+
self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
|
161 |
+
|
162 |
+
@property
|
163 |
+
def vocab_size(self):
|
164 |
+
"""Returns vocab size"""
|
165 |
+
return self.sp_model.get_piece_size()
|
166 |
+
|
167 |
+
def get_vocab(self):
|
168 |
+
"""Returns vocab as a dict"""
|
169 |
+
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
170 |
+
vocab.update(self.added_tokens_encoder)
|
171 |
+
return vocab
|
172 |
+
|
173 |
+
def tokenize(self, text: TextInput, **kwargs) -> List[str]:
|
174 |
+
"""
|
175 |
+
Converts a string in a sequence of tokens, using the tokenizer.
|
176 |
+
|
177 |
+
Split in words for word-based vocabulary or sub-words for sub-word-based vocabularies
|
178 |
+
(BPE/SentencePieces/WordPieces). Takes care of added tokens.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
text (`str`):
|
182 |
+
The sequence to be encoded.
|
183 |
+
**kwargs (additional keyword arguments):
|
184 |
+
Passed along to the model-specific `prepare_for_tokenization` preprocessing method.
|
185 |
+
|
186 |
+
Returns:
|
187 |
+
`List[str]`: The list of tokens.
|
188 |
+
"""
|
189 |
+
split_special_tokens = kwargs.pop("split_special_tokens", self.split_special_tokens)
|
190 |
+
remove_dummy_prefix = kwargs.pop("remove_dummy_prefix", False)
|
191 |
+
|
192 |
+
text, kwargs = self.prepare_for_tokenization(text, **kwargs)
|
193 |
+
|
194 |
+
if kwargs:
|
195 |
+
logger.warning(f"Keyword arguments {kwargs} not recognized.")
|
196 |
+
|
197 |
+
if hasattr(self, "do_lower_case") and self.do_lower_case:
|
198 |
+
# convert non-special tokens to lowercase. Might be super slow as well?
|
199 |
+
escaped_special_toks = [re.escape(s_tok) for s_tok in (self.all_special_tokens)]
|
200 |
+
escaped_special_toks += [
|
201 |
+
re.escape(s_tok.content)
|
202 |
+
for s_tok in (self._added_tokens_decoder.values())
|
203 |
+
if not s_tok.special and s_tok.normalized
|
204 |
+
]
|
205 |
+
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
|
206 |
+
text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
|
207 |
+
|
208 |
+
if split_special_tokens:
|
209 |
+
no_split_token = []
|
210 |
+
tokens = [text]
|
211 |
+
else:
|
212 |
+
no_split_token = self._added_tokens_encoder.keys() # don't split on any of the added tokens
|
213 |
+
# "This is something<special_token_1> else"
|
214 |
+
tokens = self.tokens_trie.split(text)
|
215 |
+
|
216 |
+
# ["This is something", "<special_token_1>", " else"]
|
217 |
+
for i, token in enumerate(tokens):
|
218 |
+
if token in no_split_token:
|
219 |
+
tok_extended = self._added_tokens_decoder.get(self._added_tokens_encoder[token], None)
|
220 |
+
left = tokens[i - 1] if i > 0 else None
|
221 |
+
right = tokens[i + 1] if i < len(tokens) - 1 else None
|
222 |
+
if isinstance(tok_extended, AddedToken):
|
223 |
+
if tok_extended.rstrip and right:
|
224 |
+
# A bit counter-intuitive but we strip the left of the string
|
225 |
+
# since tok_extended.rstrip means the special token is eating all white spaces on its right
|
226 |
+
tokens[i + 1] = right.lstrip()
|
227 |
+
# Strip white spaces on the left
|
228 |
+
if tok_extended.lstrip and left:
|
229 |
+
tokens[i - 1] = left.rstrip() # Opposite here
|
230 |
+
if tok_extended.single_word and left and left[-1] != " ":
|
231 |
+
tokens[i - 1] += token
|
232 |
+
tokens[i] = ""
|
233 |
+
elif tok_extended.single_word and right and right[0] != " ":
|
234 |
+
tokens[i + 1] = token + tokens[i + 1]
|
235 |
+
tokens[i] = ""
|
236 |
+
else:
|
237 |
+
raise ValueError(
|
238 |
+
f"{tok_extended} cannot be tokenized because it was not properly added"
|
239 |
+
f" to the tokenizer. This means that it is not an `AddedToken` but a {type(tok_extended)}"
|
240 |
+
)
|
241 |
+
# ["This is something", "<special_token_1>", "else"]
|
242 |
+
tokenized_text = []
|
243 |
+
for token in tokens:
|
244 |
+
# Need to skip eventual empty (fully stripped) tokens
|
245 |
+
if not token:
|
246 |
+
continue
|
247 |
+
if token in no_split_token:
|
248 |
+
tokenized_text.append(token)
|
249 |
+
else:
|
250 |
+
tokenized_text.extend(self._tokenize(token, remove_dummy_prefix=remove_dummy_prefix))
|
251 |
+
# ["This", " is", " something", "<special_token_1>", "else"]
|
252 |
+
return tokenized_text
|
253 |
+
|
254 |
+
def _tokenize(self, text, **kwargs):
|
255 |
+
"""
|
256 |
+
Returns a tokenized string.
|
257 |
+
|
258 |
+
We add a option to remove dummpy prefix during tokenization instead of changing the default behaviour of the sentencepiece tokenizer.
|
259 |
+
This is useful when there're two tokenized sentences to be merged into one as the last one will have an extra dummy prefix which results in a
|
260 |
+
inconsistant pattern.
|
261 |
+
"""
|
262 |
+
tokens = self.sp_model.encode(text, out_type=str)
|
263 |
+
if text.startswith((SPIECE_UNDERLINE, " ")):
|
264 |
+
return tokens
|
265 |
+
if len(tokens) > 0 and kwargs.get("remove_dummy_prefix") is True:
|
266 |
+
tokens[0] = tokens[0].replace(SPIECE_UNDERLINE, "", 1)
|
267 |
+
return tokens
|
268 |
+
|
269 |
+
def _convert_token_to_id(self, token):
|
270 |
+
"""Converts a token (str) in an id using the vocab."""
|
271 |
+
return self.sp_model.piece_to_id(token)
|
272 |
+
|
273 |
+
def _convert_id_to_token(self, index):
|
274 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
275 |
+
token = self.sp_model.IdToPiece(index)
|
276 |
+
return token
|
277 |
+
|
278 |
+
def convert_tokens_to_string(self, tokens):
|
279 |
+
"""Converts a sequence of tokens (string) in a single string."""
|
280 |
+
current_sub_tokens = []
|
281 |
+
out_string = ""
|
282 |
+
# prev_is_special = False
|
283 |
+
for i, token in enumerate(tokens):
|
284 |
+
# make sure that special tokens are not decoded using sentencepiece model
|
285 |
+
if token in self.all_special_tokens:
|
286 |
+
# if not prev_is_special and i != 0 and self.legacy:
|
287 |
+
# out_string += " "
|
288 |
+
out_string += self.sp_model.decode(current_sub_tokens) + token
|
289 |
+
# prev_is_special = True
|
290 |
+
current_sub_tokens = []
|
291 |
+
else:
|
292 |
+
current_sub_tokens.append(token)
|
293 |
+
# prev_is_special = False
|
294 |
+
out_string += self.sp_model.decode(current_sub_tokens)
|
295 |
+
return out_string
|
296 |
+
|
297 |
+
def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
298 |
+
"""
|
299 |
+
Save the vocabulary and special tokens file to a directory.
|
300 |
+
|
301 |
+
Args:
|
302 |
+
save_directory (`str`):
|
303 |
+
The directory in which to save the vocabulary.
|
304 |
+
|
305 |
+
Returns:
|
306 |
+
`Tuple(str)`: Paths to the files saved.
|
307 |
+
"""
|
308 |
+
if not os.path.isdir(save_directory):
|
309 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
310 |
+
return
|
311 |
+
out_vocab_file = os.path.join(
|
312 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
313 |
+
)
|
314 |
+
|
315 |
+
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
316 |
+
copyfile(self.vocab_file, out_vocab_file)
|
317 |
+
elif not os.path.isfile(self.vocab_file):
|
318 |
+
with open(out_vocab_file, "wb") as fi:
|
319 |
+
content_spiece_model = self.sp_model.serialized_model_proto()
|
320 |
+
fi.write(content_spiece_model)
|
321 |
+
|
322 |
+
return (out_vocab_file,)
|
323 |
+
|
324 |
+
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
325 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
326 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
327 |
+
|
328 |
+
output = bos_token_id + token_ids_0 + eos_token_id
|
329 |
+
|
330 |
+
if token_ids_1 is not None:
|
331 |
+
output = output + bos_token_id + token_ids_1 + eos_token_id
|
332 |
+
|
333 |
+
return output
|
334 |
+
|
335 |
+
def get_special_tokens_mask(
|
336 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
337 |
+
) -> List[int]:
|
338 |
+
"""
|
339 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
340 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
341 |
+
|
342 |
+
Args:
|
343 |
+
token_ids_0 (`List[int]`):
|
344 |
+
List of IDs.
|
345 |
+
token_ids_1 (`List[int]`, *optional*):
|
346 |
+
Optional second list of IDs for sequence pairs.
|
347 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
348 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
349 |
+
|
350 |
+
Returns:
|
351 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
352 |
+
"""
|
353 |
+
if already_has_special_tokens:
|
354 |
+
return super().get_special_tokens_mask(
|
355 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
356 |
+
)
|
357 |
+
|
358 |
+
bos_token_id = [1] if self.add_bos_token else []
|
359 |
+
eos_token_id = [1] if self.add_eos_token else []
|
360 |
+
|
361 |
+
if token_ids_1 is None:
|
362 |
+
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
|
363 |
+
return (
|
364 |
+
bos_token_id
|
365 |
+
+ ([0] * len(token_ids_0))
|
366 |
+
+ eos_token_id
|
367 |
+
+ bos_token_id
|
368 |
+
+ ([0] * len(token_ids_1))
|
369 |
+
+ eos_token_id
|
370 |
+
)
|
371 |
+
|
372 |
+
def create_token_type_ids_from_sequences(
|
373 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
374 |
+
) -> List[int]:
|
375 |
+
"""
|
376 |
+
Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
|
377 |
+
sequence pair mask has the following format:
|
378 |
+
|
379 |
+
```
|
380 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
381 |
+
| first sequence | second sequence |
|
382 |
+
```
|
383 |
+
|
384 |
+
if token_ids_1 is None, only returns the first portion of the mask (0s).
|
385 |
+
|
386 |
+
Args:
|
387 |
+
token_ids_0 (`List[int]`):
|
388 |
+
List of ids.
|
389 |
+
token_ids_1 (`List[int]`, *optional*):
|
390 |
+
Optional second list of IDs for sequence pairs.
|
391 |
+
|
392 |
+
Returns:
|
393 |
+
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
|
394 |
+
"""
|
395 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
396 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
397 |
+
|
398 |
+
output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
|
399 |
+
|
400 |
+
if token_ids_1 is not None:
|
401 |
+
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
|
402 |
+
|
403 |
+
return output
|
tokenizer.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1e2bf2c2d38bab8a4d7107e36073be27be40a625b2f4e57f5a0609bdb70deed8
|
3 |
+
size 1159468
|
tokenizer_config.json
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_bos_token": false,
|
3 |
+
"add_eos_token": false,
|
4 |
+
"added_tokens_decoder": {
|
5 |
+
"0": {
|
6 |
+
"content": "<unk>",
|
7 |
+
"lstrip": false,
|
8 |
+
"normalized": true,
|
9 |
+
"rstrip": false,
|
10 |
+
"single_word": false,
|
11 |
+
"special": true
|
12 |
+
},
|
13 |
+
"1": {
|
14 |
+
"content": "<s>",
|
15 |
+
"lstrip": false,
|
16 |
+
"normalized": true,
|
17 |
+
"rstrip": false,
|
18 |
+
"single_word": false,
|
19 |
+
"special": true
|
20 |
+
},
|
21 |
+
"2": {
|
22 |
+
"content": "</s>",
|
23 |
+
"lstrip": false,
|
24 |
+
"normalized": true,
|
25 |
+
"rstrip": false,
|
26 |
+
"single_word": false,
|
27 |
+
"special": true
|
28 |
+
},
|
29 |
+
"3": {
|
30 |
+
"content": "<pad>",
|
31 |
+
"lstrip": false,
|
32 |
+
"normalized": true,
|
33 |
+
"rstrip": false,
|
34 |
+
"single_word": false,
|
35 |
+
"special": true
|
36 |
+
}
|
37 |
+
},
|
38 |
+
"auto_map": {
|
39 |
+
"AutoTokenizer": [
|
40 |
+
"tokenization_telechat.TELECHATTokenizer",
|
41 |
+
null
|
42 |
+
]
|
43 |
+
},
|
44 |
+
"bos_token": "<s>",
|
45 |
+
"clean_up_tokenization_spaces": false,
|
46 |
+
"eos_token": "</s>",
|
47 |
+
"model_max_length": 8192,
|
48 |
+
"pad_token": "<pad>",
|
49 |
+
"sp_model_kwargs": {},
|
50 |
+
"spaces_between_special_tokens": false,
|
51 |
+
"tokenizer_class": "TELECHATTokenizer",
|
52 |
+
"unk_token": "<unk>",
|
53 |
+
"use_fast": false
|
54 |
+
}
|