wangrongsheng
commited on
Upload 81 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- LLM-Detector-V4-11w/src/api_demo.py +14 -0
- LLM-Detector-V4-11w/src/cli_demo.py +47 -0
- LLM-Detector-V4-11w/src/evaluate.py +10 -0
- LLM-Detector-V4-11w/src/export_model.py +9 -0
- LLM-Detector-V4-11w/src/llmtuner/__init__.py +10 -0
- LLM-Detector-V4-11w/src/llmtuner/api/__init__.py +1 -0
- LLM-Detector-V4-11w/src/llmtuner/api/app.py +165 -0
- LLM-Detector-V4-11w/src/llmtuner/api/protocol.py +83 -0
- LLM-Detector-V4-11w/src/llmtuner/chat/__init__.py +1 -0
- LLM-Detector-V4-11w/src/llmtuner/chat/chat_model.py +132 -0
- LLM-Detector-V4-11w/src/llmtuner/data/__init__.py +4 -0
- LLM-Detector-V4-11w/src/llmtuner/data/loader.py +148 -0
- LLM-Detector-V4-11w/src/llmtuner/data/preprocess.py +275 -0
- LLM-Detector-V4-11w/src/llmtuner/data/template.py +747 -0
- LLM-Detector-V4-11w/src/llmtuner/data/utils.py +61 -0
- LLM-Detector-V4-11w/src/llmtuner/eval/__init__.py +1 -0
- LLM-Detector-V4-11w/src/llmtuner/eval/evaluator.py +124 -0
- LLM-Detector-V4-11w/src/llmtuner/eval/template.py +86 -0
- LLM-Detector-V4-11w/src/llmtuner/extras/__init__.py +0 -0
- LLM-Detector-V4-11w/src/llmtuner/extras/callbacks.py +165 -0
- LLM-Detector-V4-11w/src/llmtuner/extras/constants.py +587 -0
- LLM-Detector-V4-11w/src/llmtuner/extras/logging.py +49 -0
- LLM-Detector-V4-11w/src/llmtuner/extras/misc.py +140 -0
- LLM-Detector-V4-11w/src/llmtuner/extras/packages.py +55 -0
- LLM-Detector-V4-11w/src/llmtuner/extras/patches/__init__.py +0 -0
- LLM-Detector-V4-11w/src/llmtuner/extras/patches/llama_patch.py +224 -0
- LLM-Detector-V4-11w/src/llmtuner/extras/ploting.py +55 -0
- LLM-Detector-V4-11w/src/llmtuner/hparams/__init__.py +5 -0
- LLM-Detector-V4-11w/src/llmtuner/hparams/data_args.py +179 -0
- LLM-Detector-V4-11w/src/llmtuner/hparams/evaluation_args.py +55 -0
- LLM-Detector-V4-11w/src/llmtuner/hparams/finetuning_args.py +196 -0
- LLM-Detector-V4-11w/src/llmtuner/hparams/generating_args.py +53 -0
- LLM-Detector-V4-11w/src/llmtuner/hparams/model_args.py +76 -0
- LLM-Detector-V4-11w/src/llmtuner/model/__init__.py +5 -0
- LLM-Detector-V4-11w/src/llmtuner/model/adapter.py +108 -0
- LLM-Detector-V4-11w/src/llmtuner/model/loader.py +235 -0
- LLM-Detector-V4-11w/src/llmtuner/model/parser.py +205 -0
- LLM-Detector-V4-11w/src/llmtuner/model/utils.py +183 -0
- LLM-Detector-V4-11w/src/llmtuner/train/__init__.py +1 -0
- LLM-Detector-V4-11w/src/llmtuner/train/dpo/__init__.py +1 -0
- LLM-Detector-V4-11w/src/llmtuner/train/dpo/collator.py +51 -0
- LLM-Detector-V4-11w/src/llmtuner/train/dpo/trainer.py +75 -0
- LLM-Detector-V4-11w/src/llmtuner/train/dpo/workflow.py +80 -0
- LLM-Detector-V4-11w/src/llmtuner/train/ppo/__init__.py +1 -0
- LLM-Detector-V4-11w/src/llmtuner/train/ppo/trainer.py +359 -0
- LLM-Detector-V4-11w/src/llmtuner/train/ppo/utils.py +35 -0
- LLM-Detector-V4-11w/src/llmtuner/train/ppo/workflow.py +100 -0
- LLM-Detector-V4-11w/src/llmtuner/train/pt/__init__.py +1 -0
- LLM-Detector-V4-11w/src/llmtuner/train/pt/workflow.py +62 -0
- LLM-Detector-V4-11w/src/llmtuner/train/rm/__init__.py +1 -0
LLM-Detector-V4-11w/src/api_demo.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import uvicorn
|
2 |
+
|
3 |
+
from llmtuner import ChatModel, create_app
|
4 |
+
|
5 |
+
|
6 |
+
def main():
|
7 |
+
chat_model = ChatModel()
|
8 |
+
app = create_app(chat_model)
|
9 |
+
print("Visit http://localhost:8000/docs for API document.")
|
10 |
+
uvicorn.run(app, host="0.0.0.0", port=8001, workers=1)
|
11 |
+
|
12 |
+
|
13 |
+
if __name__ == "__main__":
|
14 |
+
main()
|
LLM-Detector-V4-11w/src/cli_demo.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llmtuner import ChatModel
|
2 |
+
from llmtuner.extras.misc import torch_gc
|
3 |
+
|
4 |
+
try:
|
5 |
+
import platform
|
6 |
+
if platform.system() != "Windows":
|
7 |
+
import readline
|
8 |
+
except ImportError:
|
9 |
+
print("Install `readline` for a better experience.")
|
10 |
+
|
11 |
+
|
12 |
+
def main():
|
13 |
+
chat_model = ChatModel()
|
14 |
+
history = []
|
15 |
+
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
|
16 |
+
|
17 |
+
while True:
|
18 |
+
try:
|
19 |
+
query = input("\nUser: ")
|
20 |
+
except UnicodeDecodeError:
|
21 |
+
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
|
22 |
+
continue
|
23 |
+
except Exception:
|
24 |
+
raise
|
25 |
+
|
26 |
+
if query.strip() == "exit":
|
27 |
+
break
|
28 |
+
|
29 |
+
if query.strip() == "clear":
|
30 |
+
history = []
|
31 |
+
torch_gc()
|
32 |
+
print("History has been removed.")
|
33 |
+
continue
|
34 |
+
|
35 |
+
print("Assistant: ", end="", flush=True)
|
36 |
+
|
37 |
+
response = ""
|
38 |
+
for new_text in chat_model.stream_chat(query, history):
|
39 |
+
print(new_text, end="", flush=True)
|
40 |
+
response += new_text
|
41 |
+
print()
|
42 |
+
|
43 |
+
history = history + [(query, response)]
|
44 |
+
|
45 |
+
|
46 |
+
if __name__ == "__main__":
|
47 |
+
main()
|
LLM-Detector-V4-11w/src/evaluate.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llmtuner import Evaluator
|
2 |
+
|
3 |
+
|
4 |
+
def main():
|
5 |
+
evaluator = Evaluator()
|
6 |
+
evaluator.eval()
|
7 |
+
|
8 |
+
|
9 |
+
if __name__ == "__main__":
|
10 |
+
main()
|
LLM-Detector-V4-11w/src/export_model.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llmtuner import export_model
|
2 |
+
|
3 |
+
|
4 |
+
def main():
|
5 |
+
export_model()
|
6 |
+
|
7 |
+
|
8 |
+
if __name__ == "__main__":
|
9 |
+
main()
|
LLM-Detector-V4-11w/src/llmtuner/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Level: api, webui > chat, eval, train > data, model > extras, hparams
|
2 |
+
|
3 |
+
from llmtuner.api import create_app
|
4 |
+
from llmtuner.chat import ChatModel
|
5 |
+
from llmtuner.eval import Evaluator
|
6 |
+
from llmtuner.train import export_model, run_exp
|
7 |
+
from llmtuner.webui import create_ui, create_web_demo
|
8 |
+
|
9 |
+
|
10 |
+
__version__ = "0.3.2"
|
LLM-Detector-V4-11w/src/llmtuner/api/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from llmtuner.api.app import create_app
|
LLM-Detector-V4-11w/src/llmtuner/api/app.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from typing import List, Tuple
|
3 |
+
from pydantic import BaseModel
|
4 |
+
from contextlib import asynccontextmanager
|
5 |
+
|
6 |
+
from llmtuner.api.protocol import (
|
7 |
+
Role,
|
8 |
+
Finish,
|
9 |
+
ModelCard,
|
10 |
+
ModelList,
|
11 |
+
ChatMessage,
|
12 |
+
DeltaMessage,
|
13 |
+
ChatCompletionRequest,
|
14 |
+
ChatCompletionResponse,
|
15 |
+
ChatCompletionStreamResponse,
|
16 |
+
ChatCompletionResponseChoice,
|
17 |
+
ChatCompletionResponseStreamChoice,
|
18 |
+
ChatCompletionResponseUsage
|
19 |
+
)
|
20 |
+
from llmtuner.chat import ChatModel
|
21 |
+
from llmtuner.extras.misc import torch_gc
|
22 |
+
from llmtuner.extras.packages import (
|
23 |
+
is_fastapi_availble, is_starlette_available, is_uvicorn_available
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
if is_fastapi_availble():
|
28 |
+
from fastapi import FastAPI, HTTPException, status
|
29 |
+
from fastapi.middleware.cors import CORSMiddleware
|
30 |
+
|
31 |
+
|
32 |
+
if is_starlette_available():
|
33 |
+
from sse_starlette import EventSourceResponse
|
34 |
+
|
35 |
+
|
36 |
+
if is_uvicorn_available():
|
37 |
+
import uvicorn
|
38 |
+
|
39 |
+
|
40 |
+
@asynccontextmanager
|
41 |
+
async def lifespan(app: "FastAPI"): # collects GPU memory
|
42 |
+
yield
|
43 |
+
torch_gc()
|
44 |
+
|
45 |
+
|
46 |
+
def to_json(data: BaseModel) -> str:
|
47 |
+
try: # pydantic v2
|
48 |
+
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
|
49 |
+
except: # pydantic v1
|
50 |
+
return data.json(exclude_unset=True, ensure_ascii=False)
|
51 |
+
|
52 |
+
|
53 |
+
def create_app(chat_model: "ChatModel") -> "FastAPI":
|
54 |
+
app = FastAPI(lifespan=lifespan)
|
55 |
+
|
56 |
+
app.add_middleware(
|
57 |
+
CORSMiddleware,
|
58 |
+
allow_origins=["*"],
|
59 |
+
allow_credentials=True,
|
60 |
+
allow_methods=["*"],
|
61 |
+
allow_headers=["*"],
|
62 |
+
)
|
63 |
+
|
64 |
+
@app.get("/v1/models", response_model=ModelList)
|
65 |
+
async def list_models():
|
66 |
+
model_card = ModelCard(id="gpt-3.5-turbo")
|
67 |
+
return ModelList(data=[model_card])
|
68 |
+
|
69 |
+
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
|
70 |
+
async def create_chat_completion(request: ChatCompletionRequest):
|
71 |
+
if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
|
72 |
+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
73 |
+
|
74 |
+
query = request.messages[-1].content
|
75 |
+
prev_messages = request.messages[:-1]
|
76 |
+
if len(prev_messages) and prev_messages[0].role == Role.SYSTEM:
|
77 |
+
system = prev_messages.pop(0).content
|
78 |
+
else:
|
79 |
+
system = None
|
80 |
+
|
81 |
+
history = []
|
82 |
+
if len(prev_messages) % 2 == 0:
|
83 |
+
for i in range(0, len(prev_messages), 2):
|
84 |
+
if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
|
85 |
+
history.append([prev_messages[i].content, prev_messages[i+1].content])
|
86 |
+
else:
|
87 |
+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
88 |
+
else:
|
89 |
+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
90 |
+
|
91 |
+
if request.stream:
|
92 |
+
generate = predict(query, history, system, request)
|
93 |
+
return EventSourceResponse(generate, media_type="text/event-stream")
|
94 |
+
|
95 |
+
responses = chat_model.chat(
|
96 |
+
query, history, system,
|
97 |
+
do_sample=request.do_sample,
|
98 |
+
temperature=request.temperature,
|
99 |
+
top_p=request.top_p,
|
100 |
+
max_new_tokens=request.max_tokens,
|
101 |
+
num_return_sequences=request.n
|
102 |
+
)
|
103 |
+
|
104 |
+
prompt_length, response_length = 0, 0
|
105 |
+
choices = []
|
106 |
+
for i, response in enumerate(responses):
|
107 |
+
choices.append(ChatCompletionResponseChoice(
|
108 |
+
index=i,
|
109 |
+
message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
|
110 |
+
finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
|
111 |
+
))
|
112 |
+
prompt_length = response.prompt_length
|
113 |
+
response_length += response.response_length
|
114 |
+
|
115 |
+
usage = ChatCompletionResponseUsage(
|
116 |
+
prompt_tokens=prompt_length,
|
117 |
+
completion_tokens=response_length,
|
118 |
+
total_tokens=prompt_length+response_length
|
119 |
+
)
|
120 |
+
|
121 |
+
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
|
122 |
+
|
123 |
+
async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
|
124 |
+
choice_data = ChatCompletionResponseStreamChoice(
|
125 |
+
index=0,
|
126 |
+
delta=DeltaMessage(role=Role.ASSISTANT),
|
127 |
+
finish_reason=None
|
128 |
+
)
|
129 |
+
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
130 |
+
yield to_json(chunk)
|
131 |
+
|
132 |
+
for new_text in chat_model.stream_chat(
|
133 |
+
query, history, system,
|
134 |
+
do_sample=request.do_sample,
|
135 |
+
temperature=request.temperature,
|
136 |
+
top_p=request.top_p,
|
137 |
+
max_new_tokens=request.max_tokens
|
138 |
+
):
|
139 |
+
if len(new_text) == 0:
|
140 |
+
continue
|
141 |
+
|
142 |
+
choice_data = ChatCompletionResponseStreamChoice(
|
143 |
+
index=0,
|
144 |
+
delta=DeltaMessage(content=new_text),
|
145 |
+
finish_reason=None
|
146 |
+
)
|
147 |
+
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
148 |
+
yield to_json(chunk)
|
149 |
+
|
150 |
+
choice_data = ChatCompletionResponseStreamChoice(
|
151 |
+
index=0,
|
152 |
+
delta=DeltaMessage(),
|
153 |
+
finish_reason=Finish.STOP
|
154 |
+
)
|
155 |
+
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
156 |
+
yield to_json(chunk)
|
157 |
+
yield "[DONE]"
|
158 |
+
|
159 |
+
return app
|
160 |
+
|
161 |
+
|
162 |
+
if __name__ == "__main__":
|
163 |
+
chat_model = ChatModel()
|
164 |
+
app = create_app(chat_model)
|
165 |
+
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
|
LLM-Detector-V4-11w/src/llmtuner/api/protocol.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from enum import Enum
|
3 |
+
from pydantic import BaseModel, Field
|
4 |
+
from typing import List, Optional
|
5 |
+
|
6 |
+
|
7 |
+
class Role(str, Enum):
|
8 |
+
USER = "user"
|
9 |
+
ASSISTANT = "assistant"
|
10 |
+
SYSTEM = "system"
|
11 |
+
|
12 |
+
|
13 |
+
class Finish(str, Enum):
|
14 |
+
STOP = "stop"
|
15 |
+
LENGTH = "length"
|
16 |
+
|
17 |
+
|
18 |
+
class ModelCard(BaseModel):
|
19 |
+
id: str
|
20 |
+
object: Optional[str] = "model"
|
21 |
+
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
22 |
+
owned_by: Optional[str] = "owner"
|
23 |
+
|
24 |
+
|
25 |
+
class ModelList(BaseModel):
|
26 |
+
object: Optional[str] = "list"
|
27 |
+
data: Optional[List[ModelCard]] = []
|
28 |
+
|
29 |
+
|
30 |
+
class ChatMessage(BaseModel):
|
31 |
+
role: Role
|
32 |
+
content: str
|
33 |
+
|
34 |
+
|
35 |
+
class DeltaMessage(BaseModel):
|
36 |
+
role: Optional[Role] = None
|
37 |
+
content: Optional[str] = None
|
38 |
+
|
39 |
+
|
40 |
+
class ChatCompletionRequest(BaseModel):
|
41 |
+
model: str
|
42 |
+
messages: List[ChatMessage]
|
43 |
+
do_sample: Optional[bool] = True
|
44 |
+
temperature: Optional[float] = None
|
45 |
+
top_p: Optional[float] = None
|
46 |
+
n: Optional[int] = 1
|
47 |
+
max_tokens: Optional[int] = None
|
48 |
+
stream: Optional[bool] = False
|
49 |
+
|
50 |
+
|
51 |
+
class ChatCompletionResponseChoice(BaseModel):
|
52 |
+
index: int
|
53 |
+
message: ChatMessage
|
54 |
+
finish_reason: Finish
|
55 |
+
|
56 |
+
|
57 |
+
class ChatCompletionResponseStreamChoice(BaseModel):
|
58 |
+
index: int
|
59 |
+
delta: DeltaMessage
|
60 |
+
finish_reason: Optional[Finish] = None
|
61 |
+
|
62 |
+
|
63 |
+
class ChatCompletionResponseUsage(BaseModel):
|
64 |
+
prompt_tokens: int
|
65 |
+
completion_tokens: int
|
66 |
+
total_tokens: int
|
67 |
+
|
68 |
+
|
69 |
+
class ChatCompletionResponse(BaseModel):
|
70 |
+
id: Optional[str] = "chatcmpl-default"
|
71 |
+
object: Optional[str] = "chat.completion"
|
72 |
+
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
73 |
+
model: str
|
74 |
+
choices: List[ChatCompletionResponseChoice]
|
75 |
+
usage: ChatCompletionResponseUsage
|
76 |
+
|
77 |
+
|
78 |
+
class ChatCompletionStreamResponse(BaseModel):
|
79 |
+
id: Optional[str] = "chatcmpl-default"
|
80 |
+
object: Optional[str] = "chat.completion.chunk"
|
81 |
+
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
82 |
+
model: str
|
83 |
+
choices: List[ChatCompletionResponseStreamChoice]
|
LLM-Detector-V4-11w/src/llmtuner/chat/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from llmtuner.chat.chat_model import ChatModel
|
LLM-Detector-V4-11w/src/llmtuner/chat/chat_model.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
|
4 |
+
from threading import Thread
|
5 |
+
from transformers import GenerationConfig, TextIteratorStreamer
|
6 |
+
|
7 |
+
from llmtuner.data.template import get_template_and_fix_tokenizer
|
8 |
+
from llmtuner.extras.misc import get_logits_processor
|
9 |
+
from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass
|
13 |
+
class Response:
|
14 |
+
|
15 |
+
response_text: str
|
16 |
+
response_length: int
|
17 |
+
prompt_length: int
|
18 |
+
finish_reason: Literal["stop", "length"]
|
19 |
+
|
20 |
+
|
21 |
+
class ChatModel:
|
22 |
+
|
23 |
+
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
24 |
+
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
|
25 |
+
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
26 |
+
self.tokenizer.padding_side = "left"
|
27 |
+
self.model = dispatch_model(self.model)
|
28 |
+
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
|
29 |
+
self.system_prompt = data_args.system_prompt
|
30 |
+
|
31 |
+
def _process_args(
|
32 |
+
self,
|
33 |
+
query: str,
|
34 |
+
history: Optional[List[Tuple[str, str]]] = None,
|
35 |
+
system: Optional[str] = None,
|
36 |
+
**input_kwargs
|
37 |
+
) -> Tuple[Dict[str, Any], int]:
|
38 |
+
system = system or self.system_prompt
|
39 |
+
prompt, _ = self.template.encode_oneturn(
|
40 |
+
tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
|
41 |
+
)
|
42 |
+
prompt_length = len(prompt)
|
43 |
+
input_ids = torch.tensor([prompt], device=self.model.device)
|
44 |
+
|
45 |
+
do_sample = input_kwargs.pop("do_sample", None)
|
46 |
+
temperature = input_kwargs.pop("temperature", None)
|
47 |
+
top_p = input_kwargs.pop("top_p", None)
|
48 |
+
top_k = input_kwargs.pop("top_k", None)
|
49 |
+
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
|
50 |
+
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
|
51 |
+
max_length = input_kwargs.pop("max_length", None)
|
52 |
+
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
|
53 |
+
|
54 |
+
generating_args = self.generating_args.to_dict()
|
55 |
+
generating_args.update(dict(
|
56 |
+
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
|
57 |
+
temperature=temperature or generating_args["temperature"],
|
58 |
+
top_p=top_p or generating_args["top_p"],
|
59 |
+
top_k=top_k or generating_args["top_k"],
|
60 |
+
num_return_sequences=num_return_sequences or 1,
|
61 |
+
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
|
62 |
+
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
63 |
+
pad_token_id=self.tokenizer.pad_token_id
|
64 |
+
))
|
65 |
+
|
66 |
+
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
|
67 |
+
generating_args["do_sample"] = True
|
68 |
+
|
69 |
+
if max_length:
|
70 |
+
generating_args.pop("max_new_tokens", None)
|
71 |
+
generating_args["max_length"] = max_length
|
72 |
+
|
73 |
+
if max_new_tokens:
|
74 |
+
generating_args.pop("max_length", None)
|
75 |
+
generating_args["max_new_tokens"] = max_new_tokens
|
76 |
+
|
77 |
+
gen_kwargs = dict(
|
78 |
+
inputs=input_ids,
|
79 |
+
generation_config=GenerationConfig(**generating_args),
|
80 |
+
logits_processor=get_logits_processor()
|
81 |
+
)
|
82 |
+
|
83 |
+
return gen_kwargs, prompt_length
|
84 |
+
|
85 |
+
@torch.inference_mode()
|
86 |
+
def chat(
|
87 |
+
self,
|
88 |
+
query: str,
|
89 |
+
history: Optional[List[Tuple[str, str]]] = None,
|
90 |
+
system: Optional[str] = None,
|
91 |
+
**input_kwargs
|
92 |
+
) -> List[Response]:
|
93 |
+
r"""
|
94 |
+
Args: query, history, system, **input_kwargs
|
95 |
+
|
96 |
+
Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
|
97 |
+
"""
|
98 |
+
gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs)
|
99 |
+
generate_output = self.model.generate(**gen_kwargs)
|
100 |
+
response_ids = generate_output[:, prompt_length:]
|
101 |
+
response = self.tokenizer.batch_decode(
|
102 |
+
response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
103 |
+
)
|
104 |
+
results = []
|
105 |
+
for i in range(len(response)):
|
106 |
+
eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
|
107 |
+
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
|
108 |
+
results.append(Response(
|
109 |
+
response_text=response[i],
|
110 |
+
response_length=response_length,
|
111 |
+
prompt_length=prompt_length,
|
112 |
+
finish_reason="stop" if len(eos_index) else "length"
|
113 |
+
))
|
114 |
+
|
115 |
+
return results
|
116 |
+
|
117 |
+
@torch.inference_mode()
|
118 |
+
def stream_chat(
|
119 |
+
self,
|
120 |
+
query: str,
|
121 |
+
history: Optional[List[Tuple[str, str]]] = None,
|
122 |
+
system: Optional[str] = None,
|
123 |
+
**input_kwargs
|
124 |
+
) -> Generator[str, None, None]:
|
125 |
+
gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs)
|
126 |
+
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
127 |
+
gen_kwargs["streamer"] = streamer
|
128 |
+
|
129 |
+
thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
|
130 |
+
thread.start()
|
131 |
+
|
132 |
+
yield from streamer
|
LLM-Detector-V4-11w/src/llmtuner/data/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llmtuner.data.loader import get_dataset
|
2 |
+
from llmtuner.data.preprocess import preprocess_dataset
|
3 |
+
from llmtuner.data.template import get_template_and_fix_tokenizer
|
4 |
+
from llmtuner.data.utils import split_dataset
|
LLM-Detector-V4-11w/src/llmtuner/data/loader.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
3 |
+
|
4 |
+
from datasets import concatenate_datasets, interleave_datasets, load_dataset
|
5 |
+
|
6 |
+
from llmtuner.data.utils import checksum, EXT2TYPE
|
7 |
+
from llmtuner.extras.logging import get_logger
|
8 |
+
|
9 |
+
if TYPE_CHECKING:
|
10 |
+
from datasets import Dataset, IterableDataset
|
11 |
+
from llmtuner.hparams import ModelArguments, DataArguments
|
12 |
+
|
13 |
+
|
14 |
+
logger = get_logger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
def get_dataset(
|
18 |
+
model_args: "ModelArguments",
|
19 |
+
data_args: "DataArguments"
|
20 |
+
) -> Union["Dataset", "IterableDataset"]:
|
21 |
+
max_samples = data_args.max_samples
|
22 |
+
all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
|
23 |
+
|
24 |
+
for dataset_attr in data_args.dataset_list:
|
25 |
+
logger.info("Loading dataset {}...".format(dataset_attr))
|
26 |
+
|
27 |
+
if dataset_attr.load_from == "hf_hub":
|
28 |
+
data_path = dataset_attr.dataset_name
|
29 |
+
data_name = dataset_attr.subset
|
30 |
+
data_files = None
|
31 |
+
elif dataset_attr.load_from == "script":
|
32 |
+
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
33 |
+
data_name = dataset_attr.subset
|
34 |
+
data_files = None
|
35 |
+
elif dataset_attr.load_from == "file":
|
36 |
+
data_path, data_name = None, None
|
37 |
+
data_files: List[str] = []
|
38 |
+
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is directory
|
39 |
+
for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
|
40 |
+
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
|
41 |
+
if data_path is None:
|
42 |
+
data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
|
43 |
+
else:
|
44 |
+
assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
|
45 |
+
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is file
|
46 |
+
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
|
47 |
+
data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
|
48 |
+
else:
|
49 |
+
raise ValueError("File not found.")
|
50 |
+
|
51 |
+
assert data_path, "File extension must be txt, csv, json or jsonl."
|
52 |
+
checksum(data_files, dataset_attr.dataset_sha1)
|
53 |
+
else:
|
54 |
+
raise NotImplementedError
|
55 |
+
|
56 |
+
dataset = load_dataset(
|
57 |
+
path=data_path,
|
58 |
+
name=data_name,
|
59 |
+
data_files=data_files,
|
60 |
+
split=data_args.split,
|
61 |
+
cache_dir=model_args.cache_dir,
|
62 |
+
token=model_args.hf_hub_token,
|
63 |
+
streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
|
64 |
+
)
|
65 |
+
|
66 |
+
if data_args.streaming and (dataset_attr.load_from == "file"):
|
67 |
+
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
68 |
+
|
69 |
+
if max_samples is not None: # truncate dataset
|
70 |
+
dataset = dataset.select(range(min(len(dataset), max_samples)))
|
71 |
+
|
72 |
+
def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
73 |
+
# convert dataset from sharegpt format to alpaca format
|
74 |
+
outputs = {"prompt": [], "query": [], "response": [], "history": []}
|
75 |
+
for msg_list in examples[dataset_attr.messages]:
|
76 |
+
msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2
|
77 |
+
if len(msg_list) == 0:
|
78 |
+
continue
|
79 |
+
|
80 |
+
msg_pairs = []
|
81 |
+
user_role, assistant_role = None, None
|
82 |
+
for idx in range(0, len(msg_list), 2):
|
83 |
+
if user_role is None and assistant_role is None:
|
84 |
+
user_role = msg_list[idx][dataset_attr.role]
|
85 |
+
assistant_role = msg_list[idx + 1][dataset_attr.role]
|
86 |
+
else:
|
87 |
+
if (
|
88 |
+
msg_list[idx][dataset_attr.role] != user_role
|
89 |
+
or msg_list[idx+1][dataset_attr.role] != assistant_role
|
90 |
+
):
|
91 |
+
raise ValueError("Only accepts conversation in u/a/u/a/u/a order.")
|
92 |
+
msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content]))
|
93 |
+
|
94 |
+
if len(msg_pairs) != 0:
|
95 |
+
outputs["prompt"].append(msg_pairs[-1][0])
|
96 |
+
outputs["query"].append("")
|
97 |
+
outputs["response"].append(msg_pairs[-1][1])
|
98 |
+
outputs["history"].append(msg_pairs[:-1])
|
99 |
+
|
100 |
+
return outputs
|
101 |
+
|
102 |
+
if dataset_attr.formatting == "sharegpt": # convert format
|
103 |
+
column_names = list(next(iter(dataset)).keys())
|
104 |
+
kwargs = {}
|
105 |
+
if not data_args.streaming:
|
106 |
+
kwargs = dict(
|
107 |
+
num_proc=data_args.preprocessing_num_workers,
|
108 |
+
load_from_cache_file=(not data_args.overwrite_cache),
|
109 |
+
desc="Converting format of dataset"
|
110 |
+
)
|
111 |
+
|
112 |
+
dataset = dataset.map(
|
113 |
+
convert_format,
|
114 |
+
batched=True,
|
115 |
+
remove_columns=column_names,
|
116 |
+
**kwargs
|
117 |
+
)
|
118 |
+
else:
|
119 |
+
for column_name in ["prompt", "query", "response", "history"]: # align dataset
|
120 |
+
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
|
121 |
+
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
122 |
+
|
123 |
+
if dataset_attr.system_prompt: # add system prompt
|
124 |
+
system_prompt = dataset_attr.system_prompt
|
125 |
+
if data_args.streaming:
|
126 |
+
dataset = dataset.map(lambda _: {"system": system_prompt})
|
127 |
+
else:
|
128 |
+
dataset = dataset.add_column("system", [system_prompt] * len(dataset))
|
129 |
+
|
130 |
+
all_datasets.append(dataset)
|
131 |
+
|
132 |
+
if len(data_args.dataset_list) == 1:
|
133 |
+
return all_datasets[0]
|
134 |
+
elif data_args.mix_strategy == "concat":
|
135 |
+
if data_args.streaming:
|
136 |
+
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
|
137 |
+
return concatenate_datasets(all_datasets)
|
138 |
+
elif data_args.mix_strategy.startswith("interleave"):
|
139 |
+
if not data_args.streaming:
|
140 |
+
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
141 |
+
return interleave_datasets(
|
142 |
+
datasets=all_datasets,
|
143 |
+
probabilities=data_args.interleave_probs,
|
144 |
+
seed=data_args.seed,
|
145 |
+
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
|
146 |
+
)
|
147 |
+
else:
|
148 |
+
raise ValueError("Unknown mixing strategy.")
|
LLM-Detector-V4-11w/src/llmtuner/data/preprocess.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tiktoken
|
3 |
+
from itertools import chain
|
4 |
+
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union
|
5 |
+
|
6 |
+
from datasets import load_from_disk
|
7 |
+
|
8 |
+
from llmtuner.data.template import get_template_and_fix_tokenizer
|
9 |
+
from llmtuner.extras.constants import IGNORE_INDEX
|
10 |
+
from llmtuner.extras.logging import get_logger
|
11 |
+
|
12 |
+
if TYPE_CHECKING:
|
13 |
+
from datasets import Dataset, IterableDataset
|
14 |
+
from transformers import Seq2SeqTrainingArguments
|
15 |
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
16 |
+
from llmtuner.hparams import DataArguments
|
17 |
+
|
18 |
+
|
19 |
+
logger = get_logger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
|
23 |
+
for i in range(len(examples["prompt"])):
|
24 |
+
query, response = examples["prompt"][i], examples["response"][i]
|
25 |
+
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
|
26 |
+
history = examples["history"][i] if "history" in examples else None
|
27 |
+
system = examples["system"][i] if "system" in examples else None
|
28 |
+
yield query, response, history, system
|
29 |
+
|
30 |
+
|
31 |
+
def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
|
32 |
+
max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
|
33 |
+
max_target_len = max(max_target_len, data_args.reserved_label_len)
|
34 |
+
max_source_len = data_args.cutoff_len - max_target_len
|
35 |
+
return max_source_len, max_target_len
|
36 |
+
|
37 |
+
|
38 |
+
def preprocess_dataset(
|
39 |
+
dataset: Union["Dataset", "IterableDataset"],
|
40 |
+
tokenizer: "PreTrainedTokenizer",
|
41 |
+
data_args: "DataArguments",
|
42 |
+
training_args: "Seq2SeqTrainingArguments",
|
43 |
+
stage: Literal["pt", "sft", "rm", "ppo"]
|
44 |
+
) -> Union["Dataset", "IterableDataset"]:
|
45 |
+
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
|
46 |
+
|
47 |
+
if data_args.train_on_prompt and template.efficient_eos:
|
48 |
+
raise ValueError("Current template does not support `train_on_prompt`.")
|
49 |
+
|
50 |
+
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
51 |
+
# build grouped texts with format `X1 X2 X3 ...`
|
52 |
+
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
|
53 |
+
kwargs = dict(allowed_special="all")
|
54 |
+
else:
|
55 |
+
kwargs = dict(add_special_tokens=True)
|
56 |
+
|
57 |
+
if hasattr(tokenizer, "add_eos_token"): # for LLaMA tokenizer
|
58 |
+
add_eos_token_flag = getattr(tokenizer, "add_eos_token")
|
59 |
+
setattr(tokenizer, "add_eos_token", True)
|
60 |
+
|
61 |
+
tokenized_examples = tokenizer(examples["prompt"], **kwargs)
|
62 |
+
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
63 |
+
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
64 |
+
block_size = data_args.cutoff_len
|
65 |
+
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
|
66 |
+
total_length = (total_length // block_size) * block_size
|
67 |
+
# split by chunks of cutoff_len
|
68 |
+
result = {
|
69 |
+
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
|
70 |
+
for k, t in concatenated_examples.items()
|
71 |
+
}
|
72 |
+
# make sure the saved tokenizer is the same as the original one
|
73 |
+
if hasattr(tokenizer, "add_eos_token"):
|
74 |
+
setattr(tokenizer, "add_eos_token", add_eos_token_flag)
|
75 |
+
return result
|
76 |
+
|
77 |
+
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
78 |
+
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
79 |
+
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
80 |
+
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
81 |
+
|
82 |
+
for query, response, history, system in construct_example(examples):
|
83 |
+
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
|
84 |
+
continue
|
85 |
+
|
86 |
+
input_ids, labels = [], []
|
87 |
+
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
88 |
+
tokenizer, query, response, history, system
|
89 |
+
)):
|
90 |
+
source_len, target_len = len(source_ids), len(target_ids)
|
91 |
+
max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
|
92 |
+
if source_len > max_source_len:
|
93 |
+
source_ids = source_ids[:max_source_len]
|
94 |
+
if target_len > max_target_len:
|
95 |
+
target_ids = target_ids[:max_target_len]
|
96 |
+
|
97 |
+
if data_args.train_on_prompt:
|
98 |
+
source_mask = source_ids
|
99 |
+
elif turn_idx != 0 and template.efficient_eos:
|
100 |
+
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
101 |
+
else:
|
102 |
+
source_mask = [IGNORE_INDEX] * len(source_ids)
|
103 |
+
|
104 |
+
input_ids += source_ids + target_ids
|
105 |
+
labels += source_mask + target_ids
|
106 |
+
|
107 |
+
if template.efficient_eos:
|
108 |
+
input_ids += [tokenizer.eos_token_id]
|
109 |
+
labels += [tokenizer.eos_token_id]
|
110 |
+
|
111 |
+
if len(input_ids) > data_args.cutoff_len:
|
112 |
+
input_ids = input_ids[:data_args.cutoff_len]
|
113 |
+
labels = labels[:data_args.cutoff_len]
|
114 |
+
|
115 |
+
model_inputs["input_ids"].append(input_ids)
|
116 |
+
model_inputs["attention_mask"].append([1] * len(input_ids))
|
117 |
+
model_inputs["labels"].append(labels)
|
118 |
+
|
119 |
+
return model_inputs
|
120 |
+
|
121 |
+
def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
122 |
+
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
123 |
+
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
|
124 |
+
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
125 |
+
input_ids, labels = [], []
|
126 |
+
for query, response, history, system in construct_example(examples):
|
127 |
+
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
|
128 |
+
continue
|
129 |
+
|
130 |
+
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
131 |
+
tokenizer, query, response, history, system
|
132 |
+
)):
|
133 |
+
if data_args.train_on_prompt:
|
134 |
+
source_mask = source_ids
|
135 |
+
elif turn_idx != 0 and template.efficient_eos:
|
136 |
+
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
137 |
+
else:
|
138 |
+
source_mask = [IGNORE_INDEX] * len(source_ids)
|
139 |
+
input_ids += source_ids + target_ids
|
140 |
+
labels += source_mask + target_ids
|
141 |
+
|
142 |
+
if template.efficient_eos:
|
143 |
+
input_ids += [tokenizer.eos_token_id]
|
144 |
+
labels += [tokenizer.eos_token_id]
|
145 |
+
|
146 |
+
total_length = len(input_ids)
|
147 |
+
block_size = data_args.cutoff_len
|
148 |
+
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
|
149 |
+
total_length = (total_length // block_size) * block_size
|
150 |
+
# split by chunks of cutoff_len
|
151 |
+
for i in range(0, total_length, block_size):
|
152 |
+
model_inputs["input_ids"].append(input_ids[i: i + block_size])
|
153 |
+
model_inputs["attention_mask"].append([1] * block_size)
|
154 |
+
model_inputs["labels"].append(labels[i: i + block_size])
|
155 |
+
|
156 |
+
return model_inputs
|
157 |
+
|
158 |
+
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
159 |
+
# build inputs with format `<bos> X` and labels with format `Y <eos>`
|
160 |
+
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
161 |
+
|
162 |
+
for query, response, history, system in construct_example(examples):
|
163 |
+
if not (isinstance(query, str) and query != ""):
|
164 |
+
continue
|
165 |
+
|
166 |
+
input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)
|
167 |
+
|
168 |
+
if template.efficient_eos:
|
169 |
+
labels += [tokenizer.eos_token_id]
|
170 |
+
|
171 |
+
if len(input_ids) > data_args.cutoff_len:
|
172 |
+
input_ids = input_ids[:data_args.cutoff_len]
|
173 |
+
if len(labels) > data_args.cutoff_len:
|
174 |
+
labels = labels[:data_args.cutoff_len]
|
175 |
+
|
176 |
+
model_inputs["input_ids"].append(input_ids)
|
177 |
+
model_inputs["attention_mask"].append([1] * len(input_ids))
|
178 |
+
model_inputs["labels"].append(labels)
|
179 |
+
|
180 |
+
return model_inputs
|
181 |
+
|
182 |
+
def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
183 |
+
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
184 |
+
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
|
185 |
+
for query, response, history, system in construct_example(examples):
|
186 |
+
if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
|
187 |
+
continue
|
188 |
+
|
189 |
+
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
|
190 |
+
_, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
|
191 |
+
|
192 |
+
if template.efficient_eos:
|
193 |
+
chosen_ids += [tokenizer.eos_token_id]
|
194 |
+
rejected_ids += [tokenizer.eos_token_id]
|
195 |
+
|
196 |
+
source_len, target_len = len(prompt_ids), max(len(chosen_ids), len(rejected_ids))
|
197 |
+
max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
|
198 |
+
if source_len > max_source_len:
|
199 |
+
prompt_ids = prompt_ids[:max_source_len]
|
200 |
+
if target_len > max_target_len:
|
201 |
+
chosen_ids = chosen_ids[:max_target_len]
|
202 |
+
rejected_ids = rejected_ids[:max_target_len]
|
203 |
+
|
204 |
+
model_inputs["prompt_ids"].append(prompt_ids)
|
205 |
+
model_inputs["chosen_ids"].append(chosen_ids)
|
206 |
+
model_inputs["rejected_ids"].append(rejected_ids)
|
207 |
+
|
208 |
+
return model_inputs
|
209 |
+
|
210 |
+
def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None:
|
211 |
+
print("input_ids:\n{}".format(example["input_ids"]))
|
212 |
+
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
213 |
+
print("label_ids:\n{}".format(example["labels"]))
|
214 |
+
print("labels:\n{}".format(
|
215 |
+
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
|
216 |
+
))
|
217 |
+
|
218 |
+
def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None:
|
219 |
+
print("prompt_ids:\n{}".format(example["prompt_ids"]))
|
220 |
+
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
|
221 |
+
print("chosen_ids:\n{}".format(example["chosen_ids"]))
|
222 |
+
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
|
223 |
+
print("rejected_ids:\n{}".format(example["rejected_ids"]))
|
224 |
+
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
|
225 |
+
|
226 |
+
def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None:
|
227 |
+
print("input_ids:\n{}".format(example["input_ids"]))
|
228 |
+
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
229 |
+
|
230 |
+
if stage == "pt":
|
231 |
+
preprocess_func = preprocess_pretrain_dataset
|
232 |
+
print_function = print_unsupervised_dataset_example
|
233 |
+
elif stage == "sft" and not training_args.predict_with_generate:
|
234 |
+
preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
|
235 |
+
print_function = print_supervised_dataset_example
|
236 |
+
elif stage == "rm":
|
237 |
+
preprocess_func = preprocess_pairwise_dataset
|
238 |
+
print_function = print_pairwise_dataset_example
|
239 |
+
else:
|
240 |
+
preprocess_func = preprocess_unsupervised_dataset
|
241 |
+
print_function = print_unsupervised_dataset_example
|
242 |
+
|
243 |
+
if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
|
244 |
+
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
245 |
+
return load_from_disk(data_args.cache_path)
|
246 |
+
|
247 |
+
with training_args.main_process_first(desc="dataset map pre-processing"):
|
248 |
+
column_names = list(next(iter(dataset)).keys())
|
249 |
+
kwargs = {}
|
250 |
+
if not data_args.streaming:
|
251 |
+
kwargs = dict(
|
252 |
+
num_proc=data_args.preprocessing_num_workers,
|
253 |
+
load_from_cache_file=(not data_args.overwrite_cache),
|
254 |
+
desc="Running tokenizer on dataset"
|
255 |
+
)
|
256 |
+
|
257 |
+
dataset = dataset.map(
|
258 |
+
preprocess_func,
|
259 |
+
batched=True,
|
260 |
+
remove_columns=column_names,
|
261 |
+
**kwargs
|
262 |
+
)
|
263 |
+
|
264 |
+
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
|
265 |
+
if training_args.should_save:
|
266 |
+
dataset.save_to_disk(data_args.cache_path)
|
267 |
+
raise SystemExit("Dataset saved, rerun this script with the same `--cache_path`.")
|
268 |
+
|
269 |
+
if training_args.should_log:
|
270 |
+
try:
|
271 |
+
print_function(next(iter(dataset)))
|
272 |
+
except StopIteration:
|
273 |
+
raise RuntimeError("Empty dataset!")
|
274 |
+
|
275 |
+
return dataset
|
LLM-Detector-V4-11w/src/llmtuner/data/template.py
ADDED
@@ -0,0 +1,747 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tiktoken
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
from llmtuner.extras.logging import get_logger
|
6 |
+
|
7 |
+
if TYPE_CHECKING:
|
8 |
+
from transformers import PreTrainedTokenizer
|
9 |
+
|
10 |
+
|
11 |
+
logger = get_logger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
@dataclass
|
15 |
+
class Template:
|
16 |
+
|
17 |
+
prefix: List[Union[str, Dict[str, str]]]
|
18 |
+
prompt: List[Union[str, Dict[str, str]]]
|
19 |
+
system: str
|
20 |
+
sep: List[Union[str, Dict[str, str]]]
|
21 |
+
stop_words: List[str]
|
22 |
+
use_history: bool
|
23 |
+
efficient_eos: bool
|
24 |
+
|
25 |
+
def encode_oneturn(
|
26 |
+
self,
|
27 |
+
tokenizer: "PreTrainedTokenizer",
|
28 |
+
query: str,
|
29 |
+
resp: str,
|
30 |
+
history: Optional[List[Tuple[str, str]]] = None,
|
31 |
+
system: Optional[str] = None
|
32 |
+
) -> Tuple[List[int], List[int]]:
|
33 |
+
r"""
|
34 |
+
Returns a single pair of token ids representing prompt and response respectively.
|
35 |
+
"""
|
36 |
+
system, history = self._format(query, resp, history, system)
|
37 |
+
encoded_pairs = self._encode(tokenizer, system, history)
|
38 |
+
prompt_ids = []
|
39 |
+
for query_ids, resp_ids in encoded_pairs[:-1]:
|
40 |
+
prompt_ids = prompt_ids + query_ids + resp_ids
|
41 |
+
prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
|
42 |
+
return prompt_ids, answer_ids
|
43 |
+
|
44 |
+
def encode_multiturn(
|
45 |
+
self,
|
46 |
+
tokenizer: "PreTrainedTokenizer",
|
47 |
+
query: str,
|
48 |
+
resp: str,
|
49 |
+
history: Optional[List[Tuple[str, str]]] = None,
|
50 |
+
system: Optional[str] = None
|
51 |
+
) -> List[Tuple[List[int], List[int]]]:
|
52 |
+
r"""
|
53 |
+
Returns multiple pairs of token ids representing prompts and responses respectively.
|
54 |
+
"""
|
55 |
+
system, history = self._format(query, resp, history, system)
|
56 |
+
encoded_pairs = self._encode(tokenizer, system, history)
|
57 |
+
return encoded_pairs
|
58 |
+
|
59 |
+
def _format(
|
60 |
+
self,
|
61 |
+
query: str,
|
62 |
+
resp: str,
|
63 |
+
history: Optional[List[Tuple[str, str]]] = None,
|
64 |
+
system: Optional[str] = None
|
65 |
+
) -> Tuple[str, List[Tuple[str, str]]]:
|
66 |
+
r"""
|
67 |
+
Aligns inputs to the standard format.
|
68 |
+
"""
|
69 |
+
system = system or self.system # use system if provided
|
70 |
+
history = history if (history and self.use_history) else []
|
71 |
+
history = history + [(query, resp)]
|
72 |
+
return system, history
|
73 |
+
|
74 |
+
def _get_special_ids(
|
75 |
+
self,
|
76 |
+
tokenizer: "PreTrainedTokenizer"
|
77 |
+
) -> Tuple[List[int], List[int]]:
|
78 |
+
if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
|
79 |
+
bos_ids = [tokenizer.bos_token_id]
|
80 |
+
else: # baichuan, qwen and gpt2 models have no bos token
|
81 |
+
bos_ids = []
|
82 |
+
|
83 |
+
if tokenizer.eos_token_id is None:
|
84 |
+
raise ValueError("EOS token is required.")
|
85 |
+
|
86 |
+
if self.efficient_eos: # used in baichuan, qwen, chatglm, etc.
|
87 |
+
eos_ids = []
|
88 |
+
else:
|
89 |
+
eos_ids = [tokenizer.eos_token_id]
|
90 |
+
|
91 |
+
return bos_ids, eos_ids
|
92 |
+
|
93 |
+
def _encode(
|
94 |
+
self,
|
95 |
+
tokenizer: "PreTrainedTokenizer",
|
96 |
+
system: str,
|
97 |
+
history: List[Tuple[str, str]]
|
98 |
+
) -> List[Tuple[List[int], List[int]]]:
|
99 |
+
r"""
|
100 |
+
Encodes formatted inputs to pairs of token ids.
|
101 |
+
Turn 0: bos + prefix + sep + query resp + eos
|
102 |
+
Turn t: sep + bos + query resp + eos
|
103 |
+
"""
|
104 |
+
bos_ids, eos_ids = self._get_special_ids(tokenizer)
|
105 |
+
sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
|
106 |
+
encoded_pairs = []
|
107 |
+
for turn_idx, (query, resp) in enumerate(history):
|
108 |
+
if turn_idx == 0:
|
109 |
+
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
|
110 |
+
if len(prefix_ids) != 0: # has prefix
|
111 |
+
prefix_ids = bos_ids + prefix_ids + sep_ids
|
112 |
+
else:
|
113 |
+
prefix_ids = bos_ids
|
114 |
+
else:
|
115 |
+
prefix_ids = sep_ids + bos_ids
|
116 |
+
|
117 |
+
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx+1))
|
118 |
+
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
|
119 |
+
encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
|
120 |
+
return encoded_pairs
|
121 |
+
|
122 |
+
def _convert_inputs_to_ids(
|
123 |
+
self,
|
124 |
+
tokenizer: "PreTrainedTokenizer",
|
125 |
+
context: List[Union[str, Dict[str, str]]],
|
126 |
+
system: Optional[str] = None,
|
127 |
+
query: Optional[str] = None,
|
128 |
+
idx: Optional[str] = None
|
129 |
+
) -> List[int]:
|
130 |
+
r"""
|
131 |
+
Converts context to token ids.
|
132 |
+
"""
|
133 |
+
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
|
134 |
+
kwargs = dict(allowed_special="all")
|
135 |
+
else:
|
136 |
+
kwargs = dict(add_special_tokens=False)
|
137 |
+
|
138 |
+
token_ids = []
|
139 |
+
for elem in context:
|
140 |
+
if isinstance(elem, str):
|
141 |
+
elem = elem.replace("{{system}}", system, 1) if system is not None else elem
|
142 |
+
elem = elem.replace("{{query}}", query, 1) if query is not None else elem
|
143 |
+
elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
|
144 |
+
if len(elem) != 0:
|
145 |
+
token_ids = token_ids + tokenizer.encode(elem, **kwargs)
|
146 |
+
elif isinstance(elem, dict):
|
147 |
+
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
|
148 |
+
else:
|
149 |
+
raise ValueError("Input must be string or dict[str, str], got {}".format(type(elem)))
|
150 |
+
|
151 |
+
return token_ids
|
152 |
+
|
153 |
+
|
154 |
+
@dataclass
|
155 |
+
class Llama2Template(Template):
|
156 |
+
|
157 |
+
def _encode(
|
158 |
+
self,
|
159 |
+
tokenizer: "PreTrainedTokenizer",
|
160 |
+
system: str,
|
161 |
+
history: List[Tuple[str, str]]
|
162 |
+
) -> List[Tuple[List[int], List[int]]]:
|
163 |
+
r"""
|
164 |
+
Encodes formatted inputs to pairs of token ids.
|
165 |
+
Turn 0: bos + prefix + query resp + eos
|
166 |
+
Turn t: bos + query resp + eos
|
167 |
+
"""
|
168 |
+
bos_ids, eos_ids = self._get_special_ids(tokenizer)
|
169 |
+
encoded_pairs = []
|
170 |
+
for turn_idx, (query, resp) in enumerate(history):
|
171 |
+
if turn_idx == 0: # llama2 template has no sep_ids
|
172 |
+
query = self.prefix[0].replace("{{system}}", system) + query
|
173 |
+
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
|
174 |
+
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
|
175 |
+
encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
|
176 |
+
return encoded_pairs
|
177 |
+
|
178 |
+
|
179 |
+
templates: Dict[str, Template] = {}
|
180 |
+
|
181 |
+
|
182 |
+
def register_template(
|
183 |
+
name: str,
|
184 |
+
prefix: List[Union[str, Dict[str, str]]],
|
185 |
+
prompt: List[Union[str, Dict[str, str]]],
|
186 |
+
system: str,
|
187 |
+
sep: List[Union[str, Dict[str, str]]],
|
188 |
+
stop_words: Optional[List[str]] = [],
|
189 |
+
use_history: Optional[bool] = True,
|
190 |
+
efficient_eos: Optional[bool] = False
|
191 |
+
) -> None:
|
192 |
+
template_class = Llama2Template if "llama2" in name else Template
|
193 |
+
templates[name] = template_class(
|
194 |
+
prefix=prefix,
|
195 |
+
prompt=prompt,
|
196 |
+
system=system,
|
197 |
+
sep=sep,
|
198 |
+
stop_words=stop_words,
|
199 |
+
use_history=use_history,
|
200 |
+
efficient_eos=efficient_eos
|
201 |
+
)
|
202 |
+
|
203 |
+
|
204 |
+
def get_template_and_fix_tokenizer(
|
205 |
+
name: str,
|
206 |
+
tokenizer: "PreTrainedTokenizer"
|
207 |
+
) -> Template:
|
208 |
+
if tokenizer.eos_token_id is None:
|
209 |
+
tokenizer.eos_token = "<|endoftext|>"
|
210 |
+
logger.info("Add eos token: {}".format(tokenizer.eos_token))
|
211 |
+
|
212 |
+
if tokenizer.pad_token_id is None:
|
213 |
+
tokenizer.pad_token = tokenizer.eos_token
|
214 |
+
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
215 |
+
|
216 |
+
if name is None:
|
217 |
+
return None
|
218 |
+
|
219 |
+
template = templates.get(name, None)
|
220 |
+
assert template is not None, "Template {} does not exist.".format(name)
|
221 |
+
tokenizer.add_special_tokens(
|
222 |
+
dict(additional_special_tokens=template.stop_words),
|
223 |
+
replace_additional_special_tokens=False
|
224 |
+
)
|
225 |
+
return template
|
226 |
+
|
227 |
+
|
228 |
+
register_template(
|
229 |
+
name="alpaca",
|
230 |
+
prefix=[
|
231 |
+
"{{system}}"
|
232 |
+
],
|
233 |
+
prompt=[
|
234 |
+
"### Instruction:\n{{query}}\n\n### Response:\n"
|
235 |
+
],
|
236 |
+
system=(
|
237 |
+
"Below is an instruction that describes a task. "
|
238 |
+
"Write a response that appropriately completes the request."
|
239 |
+
),
|
240 |
+
sep=[
|
241 |
+
"\n\n"
|
242 |
+
]
|
243 |
+
)
|
244 |
+
|
245 |
+
|
246 |
+
register_template(
|
247 |
+
name="aquila",
|
248 |
+
prefix=[
|
249 |
+
"{{system}}"
|
250 |
+
],
|
251 |
+
prompt=[
|
252 |
+
"Human: {{query}}###Assistant:"
|
253 |
+
],
|
254 |
+
system=(
|
255 |
+
"A chat between a curious human and an artificial intelligence assistant. "
|
256 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions."
|
257 |
+
),
|
258 |
+
sep=[
|
259 |
+
"###"
|
260 |
+
],
|
261 |
+
stop_words=[
|
262 |
+
"</s>"
|
263 |
+
],
|
264 |
+
efficient_eos=True
|
265 |
+
)
|
266 |
+
|
267 |
+
|
268 |
+
register_template(
|
269 |
+
name="baichuan",
|
270 |
+
prefix=[
|
271 |
+
"{{system}}"
|
272 |
+
],
|
273 |
+
prompt=[
|
274 |
+
{"token": "<reserved_102>"}, # user token
|
275 |
+
"{{query}}",
|
276 |
+
{"token": "<reserved_103>"} # assistant token
|
277 |
+
],
|
278 |
+
system="",
|
279 |
+
sep=[],
|
280 |
+
efficient_eos=True
|
281 |
+
)
|
282 |
+
|
283 |
+
|
284 |
+
register_template(
|
285 |
+
name="baichuan2",
|
286 |
+
prefix=[
|
287 |
+
"{{system}}"
|
288 |
+
],
|
289 |
+
prompt=[
|
290 |
+
{"token": "<reserved_106>"}, # user token
|
291 |
+
"{{query}}",
|
292 |
+
{"token": "<reserved_107>"} # assistant token
|
293 |
+
],
|
294 |
+
system="",
|
295 |
+
sep=[],
|
296 |
+
efficient_eos=True
|
297 |
+
)
|
298 |
+
|
299 |
+
|
300 |
+
register_template(
|
301 |
+
name="belle",
|
302 |
+
prefix=[
|
303 |
+
"{{system}}"
|
304 |
+
],
|
305 |
+
prompt=[
|
306 |
+
"Human: {{query}}\n\nBelle: "
|
307 |
+
],
|
308 |
+
system="",
|
309 |
+
sep=[
|
310 |
+
"\n\n"
|
311 |
+
]
|
312 |
+
)
|
313 |
+
|
314 |
+
|
315 |
+
register_template(
|
316 |
+
name="bluelm",
|
317 |
+
prefix=[
|
318 |
+
"{{system}}"
|
319 |
+
],
|
320 |
+
prompt=[
|
321 |
+
{"token": "[|Human|]:"},
|
322 |
+
"{{query}}",
|
323 |
+
{"token": "[|AI|]:"}
|
324 |
+
],
|
325 |
+
system="",
|
326 |
+
sep=[]
|
327 |
+
)
|
328 |
+
|
329 |
+
|
330 |
+
register_template(
|
331 |
+
name="chatglm2",
|
332 |
+
prefix=[
|
333 |
+
{"token": "[gMASK]"},
|
334 |
+
{"token": "sop"},
|
335 |
+
"{{system}}"
|
336 |
+
],
|
337 |
+
prompt=[
|
338 |
+
"[Round {{idx}}]\n\n问:{{query}}\n\n答:"
|
339 |
+
],
|
340 |
+
system="",
|
341 |
+
sep=[
|
342 |
+
"\n\n"
|
343 |
+
],
|
344 |
+
efficient_eos=True
|
345 |
+
)
|
346 |
+
|
347 |
+
|
348 |
+
register_template(
|
349 |
+
name="chatglm3",
|
350 |
+
prefix=[
|
351 |
+
{"token": "[gMASK]"},
|
352 |
+
{"token": "sop"},
|
353 |
+
{"token": "<|system|>"},
|
354 |
+
"\n",
|
355 |
+
"{{system}}"
|
356 |
+
],
|
357 |
+
prompt=[
|
358 |
+
{"token": "<|user|>"},
|
359 |
+
"\n",
|
360 |
+
"{{query}}",
|
361 |
+
{"token": "<|assistant|>"},
|
362 |
+
"\n" # add an extra newline to avoid error in ChatGLM's process_response method
|
363 |
+
],
|
364 |
+
system=(
|
365 |
+
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
|
366 |
+
"Follow the user's instructions carefully. Respond using markdown."
|
367 |
+
),
|
368 |
+
sep=[],
|
369 |
+
stop_words=[
|
370 |
+
"<|user|>",
|
371 |
+
"<|observation|>"
|
372 |
+
],
|
373 |
+
efficient_eos=True
|
374 |
+
)
|
375 |
+
|
376 |
+
|
377 |
+
register_template(
|
378 |
+
name="chatglm3_raw", # the raw template for tool tuning
|
379 |
+
prefix=[
|
380 |
+
{"token": "[gMASK]"},
|
381 |
+
{"token": "sop"},
|
382 |
+
{"token": "<|system|>"},
|
383 |
+
"\n",
|
384 |
+
"{{system}}"
|
385 |
+
],
|
386 |
+
prompt=[
|
387 |
+
{"token": "<|user|>"},
|
388 |
+
"\n",
|
389 |
+
"{{query}}",
|
390 |
+
{"token": "<|assistant|>"}
|
391 |
+
],
|
392 |
+
system=(
|
393 |
+
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
|
394 |
+
"Follow the user's instructions carefully. Respond using markdown."
|
395 |
+
),
|
396 |
+
sep=[],
|
397 |
+
stop_words=[
|
398 |
+
"<|user|>",
|
399 |
+
"<|observation|>"
|
400 |
+
],
|
401 |
+
efficient_eos=True
|
402 |
+
)
|
403 |
+
|
404 |
+
|
405 |
+
register_template(
|
406 |
+
name="deepseek",
|
407 |
+
prefix=[
|
408 |
+
"{{system}}"
|
409 |
+
],
|
410 |
+
prompt=[
|
411 |
+
"User: {{query}}\n\nAssistant:"
|
412 |
+
],
|
413 |
+
system="",
|
414 |
+
sep=[]
|
415 |
+
)
|
416 |
+
|
417 |
+
|
418 |
+
register_template(
|
419 |
+
name="deepseekcoder",
|
420 |
+
prefix=[
|
421 |
+
"{{system}}"
|
422 |
+
],
|
423 |
+
prompt=[
|
424 |
+
"### Instruction:\n{{query}}\n### Response:\n"
|
425 |
+
],
|
426 |
+
system=(
|
427 |
+
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
|
428 |
+
"developed by Deepseek Company, and you only answer questions related to computer science. "
|
429 |
+
"For politically sensitive questions, security and privacy issues, "
|
430 |
+
"and other non-computer science questions, you will refuse to answer\n"
|
431 |
+
),
|
432 |
+
sep=[
|
433 |
+
"\n",
|
434 |
+
{"token": "<|EOT|>"},
|
435 |
+
"\n"
|
436 |
+
],
|
437 |
+
stop_words=[
|
438 |
+
"<|EOT|>"
|
439 |
+
],
|
440 |
+
efficient_eos=True
|
441 |
+
)
|
442 |
+
|
443 |
+
|
444 |
+
register_template(
|
445 |
+
name="default",
|
446 |
+
prefix=[
|
447 |
+
"{{system}}"
|
448 |
+
],
|
449 |
+
prompt=[
|
450 |
+
"Human: {{query}}\nAssistant:"
|
451 |
+
],
|
452 |
+
system=(
|
453 |
+
"A chat between a curious user and an artificial intelligence assistant. "
|
454 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
455 |
+
),
|
456 |
+
sep=[
|
457 |
+
"\n"
|
458 |
+
]
|
459 |
+
)
|
460 |
+
|
461 |
+
|
462 |
+
register_template(
|
463 |
+
name="falcon",
|
464 |
+
prefix=[
|
465 |
+
"{{system}}"
|
466 |
+
],
|
467 |
+
prompt=[
|
468 |
+
"User: {{query}}\nFalcon:"
|
469 |
+
],
|
470 |
+
system="",
|
471 |
+
sep=[
|
472 |
+
"\n"
|
473 |
+
],
|
474 |
+
efficient_eos=True
|
475 |
+
)
|
476 |
+
|
477 |
+
|
478 |
+
register_template(
|
479 |
+
name="intern",
|
480 |
+
prefix=[
|
481 |
+
"{{system}}"
|
482 |
+
],
|
483 |
+
prompt=[
|
484 |
+
"<|User|>:{{query}}",
|
485 |
+
{"token": "<eoh>"},
|
486 |
+
"\n<|Bot|>:"
|
487 |
+
],
|
488 |
+
system="",
|
489 |
+
sep=[
|
490 |
+
{"token": "<eoa>"},
|
491 |
+
"\n"
|
492 |
+
],
|
493 |
+
stop_words=[
|
494 |
+
"<eoa>"
|
495 |
+
],
|
496 |
+
efficient_eos=True
|
497 |
+
)
|
498 |
+
|
499 |
+
|
500 |
+
register_template(
|
501 |
+
name="llama2",
|
502 |
+
prefix=[
|
503 |
+
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
|
504 |
+
],
|
505 |
+
prompt=[
|
506 |
+
"[INST] {{query}} [/INST]"
|
507 |
+
],
|
508 |
+
system=(
|
509 |
+
"You are a helpful, respectful and honest assistant. "
|
510 |
+
"Always answer as helpfully as possible, while being safe. "
|
511 |
+
"Your answers should not include any harmful, unethical, "
|
512 |
+
"racist, sexist, toxic, dangerous, or illegal content. "
|
513 |
+
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
514 |
+
"If a question does not make any sense, or is not factually coherent, "
|
515 |
+
"explain why instead of answering something not correct. "
|
516 |
+
"If you don't know the answer to a question, please don't share false information."
|
517 |
+
),
|
518 |
+
sep=[]
|
519 |
+
)
|
520 |
+
|
521 |
+
|
522 |
+
register_template(
|
523 |
+
name="llama2_zh",
|
524 |
+
prefix=[
|
525 |
+
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
|
526 |
+
],
|
527 |
+
prompt=[
|
528 |
+
"[INST] {{query}} [/INST]"
|
529 |
+
],
|
530 |
+
system="You are a helpful assistant. 你是一个乐于助人的助手。",
|
531 |
+
sep=[]
|
532 |
+
)
|
533 |
+
|
534 |
+
|
535 |
+
register_template(
|
536 |
+
name="mistral",
|
537 |
+
prefix=[
|
538 |
+
"{{system}}"
|
539 |
+
],
|
540 |
+
prompt=[
|
541 |
+
"[INST] {{query}} [/INST]"
|
542 |
+
],
|
543 |
+
system="",
|
544 |
+
sep=[
|
545 |
+
" "
|
546 |
+
]
|
547 |
+
)
|
548 |
+
|
549 |
+
|
550 |
+
register_template(
|
551 |
+
name="openchat",
|
552 |
+
prefix=[
|
553 |
+
"{{system}}"
|
554 |
+
],
|
555 |
+
prompt=[
|
556 |
+
"GPT4 Correct User: {{query}}",
|
557 |
+
{"token": "<|end_of_turn|>"},
|
558 |
+
"GPT4 Correct Assistant:"
|
559 |
+
],
|
560 |
+
system="",
|
561 |
+
sep=[
|
562 |
+
{"token": "<|end_of_turn|>"}
|
563 |
+
],
|
564 |
+
stop_words=[
|
565 |
+
"<|end_of_turn|>"
|
566 |
+
],
|
567 |
+
efficient_eos=True
|
568 |
+
)
|
569 |
+
|
570 |
+
|
571 |
+
register_template(
|
572 |
+
name="qwen",
|
573 |
+
prefix=[
|
574 |
+
{"token": "<|im_start|>"},
|
575 |
+
"system\n{{system}}"
|
576 |
+
],
|
577 |
+
prompt=[
|
578 |
+
{"token": "<|im_start|>"},
|
579 |
+
"user\n{{query}}",
|
580 |
+
{"token": "<|im_end|>"},
|
581 |
+
"\n",
|
582 |
+
{"token": "<|im_start|>"},
|
583 |
+
"assistant\n"
|
584 |
+
],
|
585 |
+
system="You are a helpful assistant.",
|
586 |
+
sep=[
|
587 |
+
{"token": "<|im_end|>"},
|
588 |
+
"\n"
|
589 |
+
],
|
590 |
+
stop_words=[
|
591 |
+
"<|im_end|>"
|
592 |
+
],
|
593 |
+
efficient_eos=True
|
594 |
+
)
|
595 |
+
|
596 |
+
|
597 |
+
register_template(
|
598 |
+
name="starchat",
|
599 |
+
prefix=[
|
600 |
+
{"token": "<|system|>"},
|
601 |
+
"\n{{system}}",
|
602 |
+
],
|
603 |
+
prompt=[
|
604 |
+
{"token": "<|user|>"},
|
605 |
+
"\n{{query}}",
|
606 |
+
{"token": "<|end|>"},
|
607 |
+
"\n",
|
608 |
+
{"token": "<|assistant|>"}
|
609 |
+
],
|
610 |
+
system="",
|
611 |
+
sep=[
|
612 |
+
{"token": "<|end|>"},
|
613 |
+
"\n"
|
614 |
+
],
|
615 |
+
stop_words=[
|
616 |
+
"<|end|>"
|
617 |
+
],
|
618 |
+
efficient_eos=True
|
619 |
+
)
|
620 |
+
|
621 |
+
|
622 |
+
r"""
|
623 |
+
Supports language model inference without histories.
|
624 |
+
"""
|
625 |
+
register_template(
|
626 |
+
name="vanilla",
|
627 |
+
prefix=[],
|
628 |
+
prompt=[
|
629 |
+
"{{query}}"
|
630 |
+
],
|
631 |
+
system="",
|
632 |
+
sep=[],
|
633 |
+
use_history=False
|
634 |
+
)
|
635 |
+
|
636 |
+
|
637 |
+
register_template(
|
638 |
+
name="vicuna",
|
639 |
+
prefix=[
|
640 |
+
"{{system}}"
|
641 |
+
],
|
642 |
+
prompt=[
|
643 |
+
"USER: {{query}} ASSISTANT:"
|
644 |
+
],
|
645 |
+
system=(
|
646 |
+
"A chat between a curious user and an artificial intelligence assistant. "
|
647 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
648 |
+
),
|
649 |
+
sep=[]
|
650 |
+
)
|
651 |
+
|
652 |
+
|
653 |
+
register_template(
|
654 |
+
name="xverse",
|
655 |
+
prefix=[
|
656 |
+
"{{system}}"
|
657 |
+
],
|
658 |
+
prompt=[
|
659 |
+
"Human: {{query}}\n\nAssistant: "
|
660 |
+
],
|
661 |
+
system="",
|
662 |
+
sep=[]
|
663 |
+
)
|
664 |
+
|
665 |
+
|
666 |
+
register_template(
|
667 |
+
name="yayi",
|
668 |
+
prefix=[
|
669 |
+
{"token": "<|System|>"},
|
670 |
+
":\n{{system}}"
|
671 |
+
],
|
672 |
+
prompt=[
|
673 |
+
{"token": "<|Human|>"},
|
674 |
+
":\n{{query}}\n\n",
|
675 |
+
{"token": "<|YaYi|>"},
|
676 |
+
":"
|
677 |
+
],
|
678 |
+
system=(
|
679 |
+
"You are a helpful, respectful and honest assistant named YaYi "
|
680 |
+
"developed by Beijing Wenge Technology Co.,Ltd. "
|
681 |
+
"Always answer as helpfully as possible, while being safe. "
|
682 |
+
"Your answers should not include any harmful, unethical, "
|
683 |
+
"racist, sexist, toxic, dangerous, or illegal content. "
|
684 |
+
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
685 |
+
"If a question does not make any sense, or is not factually coherent, "
|
686 |
+
"explain why instead of answering something not correct. "
|
687 |
+
"If you don't know the answer to a question, please don't share false information."
|
688 |
+
),
|
689 |
+
sep=[
|
690 |
+
"\n\n"
|
691 |
+
],
|
692 |
+
stop_words=[
|
693 |
+
"<|End|>"
|
694 |
+
]
|
695 |
+
)
|
696 |
+
|
697 |
+
|
698 |
+
register_template(
|
699 |
+
name="yi",
|
700 |
+
prefix=[
|
701 |
+
"{{system}}"
|
702 |
+
],
|
703 |
+
prompt=[
|
704 |
+
"<|im_start|>user\n{{query}}<|im_end|>\n<|im_start|>assistant\n"
|
705 |
+
],
|
706 |
+
system="",
|
707 |
+
sep=[
|
708 |
+
"<|im_end|>\n"
|
709 |
+
],
|
710 |
+
efficient_eos=True
|
711 |
+
)
|
712 |
+
|
713 |
+
|
714 |
+
register_template(
|
715 |
+
name="zephyr",
|
716 |
+
prefix=[
|
717 |
+
{"token": "<|system|>"},
|
718 |
+
"\n{{system}}",
|
719 |
+
{"token": "</s>"}
|
720 |
+
],
|
721 |
+
prompt=[
|
722 |
+
{"token": "<|user|>"},
|
723 |
+
"\n{{query}}",
|
724 |
+
{"token": "</s>"},
|
725 |
+
{"token": "<|assistant|>"}
|
726 |
+
],
|
727 |
+
system="You are a friendly chatbot who always responds in the style of a pirate",
|
728 |
+
sep=[]
|
729 |
+
)
|
730 |
+
|
731 |
+
|
732 |
+
register_template(
|
733 |
+
name="ziya",
|
734 |
+
prefix=[
|
735 |
+
"{{system}}"
|
736 |
+
],
|
737 |
+
prompt=[
|
738 |
+
{"token": "<human>"},
|
739 |
+
":{{query}}\n",
|
740 |
+
{"token": "<bot>"},
|
741 |
+
":"
|
742 |
+
],
|
743 |
+
system="",
|
744 |
+
sep=[
|
745 |
+
"\n"
|
746 |
+
]
|
747 |
+
)
|
LLM-Detector-V4-11w/src/llmtuner/data/utils.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
3 |
+
|
4 |
+
from llmtuner.extras.logging import get_logger
|
5 |
+
|
6 |
+
if TYPE_CHECKING:
|
7 |
+
from datasets import Dataset, IterableDataset
|
8 |
+
from transformers import TrainingArguments
|
9 |
+
from llmtuner.hparams import DataArguments
|
10 |
+
|
11 |
+
|
12 |
+
logger = get_logger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
EXT2TYPE = {
|
16 |
+
"arrow": "arrow",
|
17 |
+
"csv": "csv",
|
18 |
+
"json": "json",
|
19 |
+
"jsonl": "json",
|
20 |
+
"parquet": "parquet",
|
21 |
+
"txt": "text"
|
22 |
+
}
|
23 |
+
|
24 |
+
|
25 |
+
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
|
26 |
+
if file_sha1 is None:
|
27 |
+
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
|
28 |
+
return
|
29 |
+
|
30 |
+
if len(data_files) != 1:
|
31 |
+
logger.warning("Checksum failed: too many files.")
|
32 |
+
return
|
33 |
+
|
34 |
+
with open(data_files[0], "rb") as f:
|
35 |
+
sha1 = hashlib.sha1(f.read()).hexdigest()
|
36 |
+
if sha1 != file_sha1:
|
37 |
+
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
|
38 |
+
|
39 |
+
|
40 |
+
def split_dataset(
|
41 |
+
dataset: Union["Dataset", "IterableDataset"],
|
42 |
+
data_args: "DataArguments",
|
43 |
+
training_args: "TrainingArguments"
|
44 |
+
) -> Dict[str, "Dataset"]:
|
45 |
+
if training_args.do_train:
|
46 |
+
if data_args.val_size > 1e-6: # Split the dataset
|
47 |
+
if data_args.streaming:
|
48 |
+
val_set = dataset.take(int(data_args.val_size))
|
49 |
+
train_set = dataset.skip(int(data_args.val_size))
|
50 |
+
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
51 |
+
return {"train_dataset": train_set, "eval_dataset": val_set}
|
52 |
+
else:
|
53 |
+
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
|
54 |
+
dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
|
55 |
+
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
|
56 |
+
else:
|
57 |
+
if data_args.streaming:
|
58 |
+
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
59 |
+
return {"train_dataset": dataset}
|
60 |
+
else: # do_eval or do_predict
|
61 |
+
return {"eval_dataset": dataset}
|
LLM-Detector-V4-11w/src/llmtuner/eval/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from llmtuner.eval.evaluator import Evaluator
|
LLM-Detector-V4-11w/src/llmtuner/eval/evaluator.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
|
2 |
+
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
import torch
|
6 |
+
import inspect
|
7 |
+
import tiktoken
|
8 |
+
import numpy as np
|
9 |
+
from tqdm import tqdm, trange
|
10 |
+
from typing import Any, Dict, List, Optional
|
11 |
+
|
12 |
+
from datasets import load_dataset
|
13 |
+
from transformers.utils import cached_file
|
14 |
+
|
15 |
+
from llmtuner.data.template import get_template_and_fix_tokenizer
|
16 |
+
from llmtuner.eval.template import get_eval_template
|
17 |
+
from llmtuner.extras.constants import CHOICES, SUBJECTS
|
18 |
+
from llmtuner.model import dispatch_model, get_eval_args, load_model_and_tokenizer
|
19 |
+
|
20 |
+
|
21 |
+
class Evaluator:
|
22 |
+
|
23 |
+
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
24 |
+
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
|
25 |
+
self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
|
26 |
+
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
|
27 |
+
self.model = dispatch_model(self.model)
|
28 |
+
self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
|
29 |
+
self.eval_template = get_eval_template(self.eval_args.lang)
|
30 |
+
self.choice_inputs = self._encode_choices()
|
31 |
+
|
32 |
+
def _encode_choices(self) -> List[int]:
|
33 |
+
if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
|
34 |
+
kwargs = dict(allowed_special="all")
|
35 |
+
else:
|
36 |
+
kwargs = dict(add_special_tokens=False)
|
37 |
+
|
38 |
+
return [self.tokenizer.encode(self.eval_template.prefix + ch, **kwargs)[-1] for ch in CHOICES]
|
39 |
+
|
40 |
+
@torch.inference_mode()
|
41 |
+
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
|
42 |
+
logits = self.model(**batch_input).logits
|
43 |
+
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
|
44 |
+
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
|
45 |
+
choice_probs = torch.nn.functional.softmax(word_probs[:, self.choice_inputs], dim=-1).detach()
|
46 |
+
return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
|
47 |
+
|
48 |
+
def eval(self) -> None:
|
49 |
+
if "token" in inspect.signature(cached_file).parameters:
|
50 |
+
kwargs = {"token": self.model_args.hf_hub_token}
|
51 |
+
elif "use_auth_token" in inspect.signature(cached_file).parameters: # for transformers==4.31.0
|
52 |
+
kwargs = {"use_auth_token": self.model_args.hf_hub_token}
|
53 |
+
|
54 |
+
mapping = cached_file(
|
55 |
+
path_or_repo_id = os.path.join(self.eval_args.task_dir, self.eval_args.task),
|
56 |
+
filename="mapping.json",
|
57 |
+
cache_dir=self.model_args.cache_dir,
|
58 |
+
**kwargs
|
59 |
+
)
|
60 |
+
|
61 |
+
with open(mapping, "r", encoding="utf-8") as f:
|
62 |
+
categorys: Dict[str, Dict[str, str]] = json.load(f)
|
63 |
+
|
64 |
+
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
|
65 |
+
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
|
66 |
+
results = {}
|
67 |
+
for subject in pbar:
|
68 |
+
dataset = load_dataset(
|
69 |
+
path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
|
70 |
+
name=subject,
|
71 |
+
cache_dir=self.model_args.cache_dir,
|
72 |
+
download_mode=self.eval_args.download_mode,
|
73 |
+
token=self.model_args.hf_hub_token
|
74 |
+
)
|
75 |
+
pbar.set_postfix_str(categorys[subject]["name"])
|
76 |
+
inputs, outputs, labels = [], [], []
|
77 |
+
for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
|
78 |
+
support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
|
79 |
+
query, resp, history = self.eval_template.format_example(
|
80 |
+
target_data=dataset[self.data_args.split][i],
|
81 |
+
support_set=support_set,
|
82 |
+
subject_name=categorys[subject]["name"],
|
83 |
+
use_history=self.template.use_history
|
84 |
+
)
|
85 |
+
input_ids, _ = self.template.encode_oneturn(
|
86 |
+
tokenizer=self.tokenizer, query=query, resp=resp, history=history
|
87 |
+
)
|
88 |
+
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
|
89 |
+
labels.append(resp)
|
90 |
+
|
91 |
+
for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False):
|
92 |
+
batch_input = self.tokenizer.pad(
|
93 |
+
inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
|
94 |
+
).to(self.model.device)
|
95 |
+
preds = self.batch_inference(batch_input)
|
96 |
+
outputs += preds
|
97 |
+
|
98 |
+
corrects = (np.array(outputs) == np.array(labels))
|
99 |
+
category_name = categorys[subject]["category"]
|
100 |
+
category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
|
101 |
+
category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
|
102 |
+
results[subject] = {str(i): outputs[i] for i in range(len(outputs))}
|
103 |
+
|
104 |
+
pbar.close()
|
105 |
+
self._save_results(category_corrects, results)
|
106 |
+
|
107 |
+
def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
|
108 |
+
score_info = "\n".join([
|
109 |
+
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
|
110 |
+
for category_name, category_correct in category_corrects.items() if len(category_correct)
|
111 |
+
])
|
112 |
+
print(score_info)
|
113 |
+
if self.eval_args.save_dir is not None:
|
114 |
+
os.makedirs(self.eval_args.save_dir, exist_ok=False)
|
115 |
+
with open(os.path.join(self.eval_args.save_dir, "results.json"), "w", encoding="utf-8", newline="\n") as f:
|
116 |
+
json.dump(results, f, indent=2)
|
117 |
+
|
118 |
+
with open(os.path.join(self.eval_args.save_dir, "results.log"), "w", encoding="utf-8", newline="\n") as f:
|
119 |
+
f.write(score_info)
|
120 |
+
|
121 |
+
|
122 |
+
if __name__ == "__main__":
|
123 |
+
evaluator = Evaluator()
|
124 |
+
evaluator.eval()
|
LLM-Detector-V4-11w/src/llmtuner/eval/template.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import TYPE_CHECKING, Dict, List, Tuple
|
3 |
+
|
4 |
+
from llmtuner.extras.constants import CHOICES
|
5 |
+
|
6 |
+
if TYPE_CHECKING:
|
7 |
+
from datasets import Dataset
|
8 |
+
|
9 |
+
|
10 |
+
@dataclass
|
11 |
+
class EvalTemplate:
|
12 |
+
|
13 |
+
system: str
|
14 |
+
choice: str
|
15 |
+
answer: str
|
16 |
+
prefix: str
|
17 |
+
|
18 |
+
def parse_example(
|
19 |
+
self,
|
20 |
+
example: Dict[str, str]
|
21 |
+
) -> Tuple[str, str]:
|
22 |
+
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
|
23 |
+
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
|
24 |
+
|
25 |
+
def format_example(
|
26 |
+
self,
|
27 |
+
target_data: Dict[str, str],
|
28 |
+
support_set: "Dataset",
|
29 |
+
subject_name: str,
|
30 |
+
use_history: bool
|
31 |
+
) -> Tuple[str, str, List[Tuple[str, str]]]:
|
32 |
+
query, resp = self.parse_example(target_data)
|
33 |
+
history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
|
34 |
+
|
35 |
+
if len(history):
|
36 |
+
temp = history.pop(0)
|
37 |
+
history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
|
38 |
+
else:
|
39 |
+
query = self.system.format(subject=subject_name) + query
|
40 |
+
|
41 |
+
if not use_history:
|
42 |
+
query = "\n\n".join(["".join(item) for item in history] + [query])
|
43 |
+
history = []
|
44 |
+
return query.strip(), resp, history
|
45 |
+
|
46 |
+
|
47 |
+
eval_templates: Dict[str, EvalTemplate] = {}
|
48 |
+
|
49 |
+
|
50 |
+
def register_eval_template(
|
51 |
+
name: str,
|
52 |
+
system: str,
|
53 |
+
choice: str,
|
54 |
+
answer: str,
|
55 |
+
prefix: str
|
56 |
+
) -> None:
|
57 |
+
eval_templates[name] = EvalTemplate(
|
58 |
+
system=system,
|
59 |
+
choice=choice,
|
60 |
+
answer=answer,
|
61 |
+
prefix=prefix
|
62 |
+
)
|
63 |
+
|
64 |
+
|
65 |
+
def get_eval_template(name: str) -> EvalTemplate:
|
66 |
+
eval_template = eval_templates.get(name, None)
|
67 |
+
assert eval_template is not None, "Template {} does not exist.".format(name)
|
68 |
+
return eval_template
|
69 |
+
|
70 |
+
|
71 |
+
register_eval_template(
|
72 |
+
name="en",
|
73 |
+
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
|
74 |
+
choice="\n{choice}. {content}",
|
75 |
+
answer="\nAnswer: ",
|
76 |
+
prefix=" "
|
77 |
+
)
|
78 |
+
|
79 |
+
|
80 |
+
register_eval_template(
|
81 |
+
name="zh",
|
82 |
+
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
|
83 |
+
choice="\n{choice}. {content}",
|
84 |
+
answer="\n答案:",
|
85 |
+
prefix="\n"
|
86 |
+
)
|
LLM-Detector-V4-11w/src/llmtuner/extras/__init__.py
ADDED
File without changes
|
LLM-Detector-V4-11w/src/llmtuner/extras/callbacks.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
from typing import TYPE_CHECKING
|
5 |
+
from datetime import timedelta
|
6 |
+
|
7 |
+
from transformers import TrainerCallback
|
8 |
+
from transformers.modeling_utils import custom_object_save, unwrap_model
|
9 |
+
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
|
10 |
+
|
11 |
+
from llmtuner.extras.constants import LOG_FILE_NAME
|
12 |
+
from llmtuner.extras.logging import get_logger
|
13 |
+
|
14 |
+
if TYPE_CHECKING:
|
15 |
+
from transformers import TrainingArguments, TrainerState, TrainerControl
|
16 |
+
from trl import AutoModelForCausalLMWithValueHead
|
17 |
+
|
18 |
+
|
19 |
+
logger = get_logger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
def _save_model_with_valuehead(model: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
|
23 |
+
model.pretrained_model.config.save_pretrained(output_dir)
|
24 |
+
if model.pretrained_model.can_generate():
|
25 |
+
model.pretrained_model.generation_config.save_pretrained(output_dir)
|
26 |
+
if getattr(model, "is_peft_model", False):
|
27 |
+
model.pretrained_model.save_pretrained(output_dir)
|
28 |
+
elif getattr(model.pretrained_model, "_auto_class", None): # must not a peft model
|
29 |
+
custom_object_save(model.pretrained_model, output_dir, config=model.pretrained_model.config)
|
30 |
+
|
31 |
+
|
32 |
+
class SavePeftModelCallback(TrainerCallback):
|
33 |
+
|
34 |
+
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
35 |
+
r"""
|
36 |
+
Event called after a checkpoint save.
|
37 |
+
"""
|
38 |
+
if args.should_save:
|
39 |
+
_save_model_with_valuehead(
|
40 |
+
model=unwrap_model(kwargs.pop("model")),
|
41 |
+
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
42 |
+
)
|
43 |
+
|
44 |
+
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
45 |
+
r"""
|
46 |
+
Event called at the end of training.
|
47 |
+
"""
|
48 |
+
if args.should_save:
|
49 |
+
_save_model_with_valuehead(model=unwrap_model(kwargs.pop("model")), output_dir=args.output_dir)
|
50 |
+
|
51 |
+
|
52 |
+
class LogCallback(TrainerCallback):
|
53 |
+
|
54 |
+
def __init__(self, runner=None):
|
55 |
+
self.runner = runner
|
56 |
+
self.in_training = False
|
57 |
+
self.start_time = time.time()
|
58 |
+
self.cur_steps = 0
|
59 |
+
self.max_steps = 0
|
60 |
+
self.elapsed_time = ""
|
61 |
+
self.remaining_time = ""
|
62 |
+
|
63 |
+
def timing(self):
|
64 |
+
cur_time = time.time()
|
65 |
+
elapsed_time = cur_time - self.start_time
|
66 |
+
avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
|
67 |
+
remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
|
68 |
+
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
|
69 |
+
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
|
70 |
+
|
71 |
+
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
72 |
+
r"""
|
73 |
+
Event called at the beginning of training.
|
74 |
+
"""
|
75 |
+
if state.is_local_process_zero:
|
76 |
+
self.in_training = True
|
77 |
+
self.start_time = time.time()
|
78 |
+
self.max_steps = state.max_steps
|
79 |
+
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
|
80 |
+
logger.warning("Previous log file in this folder will be deleted.")
|
81 |
+
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
|
82 |
+
|
83 |
+
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
84 |
+
r"""
|
85 |
+
Event called at the end of training.
|
86 |
+
"""
|
87 |
+
if state.is_local_process_zero:
|
88 |
+
self.in_training = False
|
89 |
+
self.cur_steps = 0
|
90 |
+
self.max_steps = 0
|
91 |
+
|
92 |
+
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
93 |
+
r"""
|
94 |
+
Event called at the end of an substep during gradient accumulation.
|
95 |
+
"""
|
96 |
+
if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
|
97 |
+
control.should_epoch_stop = True
|
98 |
+
control.should_training_stop = True
|
99 |
+
|
100 |
+
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
101 |
+
r"""
|
102 |
+
Event called at the end of a training step.
|
103 |
+
"""
|
104 |
+
if state.is_local_process_zero:
|
105 |
+
self.cur_steps = state.global_step
|
106 |
+
self.timing()
|
107 |
+
if self.runner is not None and self.runner.aborted:
|
108 |
+
control.should_epoch_stop = True
|
109 |
+
control.should_training_stop = True
|
110 |
+
|
111 |
+
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
112 |
+
r"""
|
113 |
+
Event called after an evaluation phase.
|
114 |
+
"""
|
115 |
+
if state.is_local_process_zero and not self.in_training:
|
116 |
+
self.cur_steps = 0
|
117 |
+
self.max_steps = 0
|
118 |
+
|
119 |
+
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
|
120 |
+
r"""
|
121 |
+
Event called after a successful prediction.
|
122 |
+
"""
|
123 |
+
if state.is_local_process_zero and not self.in_training:
|
124 |
+
self.cur_steps = 0
|
125 |
+
self.max_steps = 0
|
126 |
+
|
127 |
+
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
|
128 |
+
r"""
|
129 |
+
Event called after logging the last logs.
|
130 |
+
"""
|
131 |
+
if not state.is_local_process_zero:
|
132 |
+
return
|
133 |
+
|
134 |
+
logs = dict(
|
135 |
+
current_steps=self.cur_steps,
|
136 |
+
total_steps=self.max_steps,
|
137 |
+
loss=state.log_history[-1].get("loss", None),
|
138 |
+
eval_loss=state.log_history[-1].get("eval_loss", None),
|
139 |
+
predict_loss=state.log_history[-1].get("predict_loss", None),
|
140 |
+
reward=state.log_history[-1].get("reward", None),
|
141 |
+
learning_rate=state.log_history[-1].get("learning_rate", None),
|
142 |
+
epoch=state.log_history[-1].get("epoch", None),
|
143 |
+
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
|
144 |
+
elapsed_time=self.elapsed_time,
|
145 |
+
remaining_time=self.remaining_time
|
146 |
+
)
|
147 |
+
if self.runner is not None:
|
148 |
+
logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
|
149 |
+
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
|
150 |
+
))
|
151 |
+
|
152 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
153 |
+
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
|
154 |
+
f.write(json.dumps(logs) + "\n")
|
155 |
+
|
156 |
+
def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
157 |
+
r"""
|
158 |
+
Event called after a prediction step.
|
159 |
+
"""
|
160 |
+
eval_dataloader = kwargs.pop("eval_dataloader", None)
|
161 |
+
if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
|
162 |
+
if self.max_steps == 0:
|
163 |
+
self.max_steps = len(eval_dataloader)
|
164 |
+
self.cur_steps += 1
|
165 |
+
self.timing()
|
LLM-Detector-V4-11w/src/llmtuner/extras/constants.py
ADDED
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
from collections import defaultdict, OrderedDict
|
3 |
+
from typing import Dict, Optional
|
4 |
+
|
5 |
+
|
6 |
+
CHOICES = ["A", "B", "C", "D"]
|
7 |
+
|
8 |
+
DEFAULT_MODULE = defaultdict(str)
|
9 |
+
|
10 |
+
DEFAULT_TEMPLATE = defaultdict(str)
|
11 |
+
|
12 |
+
IGNORE_INDEX = -100
|
13 |
+
|
14 |
+
LAYERNORM_NAMES = {"norm", "ln"}
|
15 |
+
|
16 |
+
LOG_FILE_NAME = "trainer_log.jsonl"
|
17 |
+
|
18 |
+
METHODS = ["full", "freeze", "lora"]
|
19 |
+
|
20 |
+
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
|
21 |
+
|
22 |
+
SUPPORTED_MODELS = OrderedDict()
|
23 |
+
|
24 |
+
TRAINING_STAGES = {
|
25 |
+
"Supervised Fine-Tuning": "sft",
|
26 |
+
"Reward Modeling": "rm",
|
27 |
+
"PPO": "ppo",
|
28 |
+
"DPO": "dpo",
|
29 |
+
"Pre-Training": "pt"
|
30 |
+
}
|
31 |
+
|
32 |
+
class DownloadSource(str, Enum):
|
33 |
+
DEFAULT = "hf"
|
34 |
+
MODELSCOPE = "ms"
|
35 |
+
|
36 |
+
|
37 |
+
def register_model_group(
|
38 |
+
models: Dict[str, Dict[DownloadSource, str]],
|
39 |
+
module: Optional[str] = None,
|
40 |
+
template: Optional[str] = None
|
41 |
+
) -> None:
|
42 |
+
prefix = None
|
43 |
+
for name, path in models.items():
|
44 |
+
if prefix is None:
|
45 |
+
prefix = name.split("-")[0]
|
46 |
+
else:
|
47 |
+
assert prefix == name.split("-")[0], "prefix should be identical."
|
48 |
+
SUPPORTED_MODELS[name] = path
|
49 |
+
if module is not None:
|
50 |
+
DEFAULT_MODULE[prefix] = module
|
51 |
+
if template is not None:
|
52 |
+
DEFAULT_TEMPLATE[prefix] = template
|
53 |
+
|
54 |
+
|
55 |
+
register_model_group(
|
56 |
+
models={
|
57 |
+
"Baichuan-7B-Base": {
|
58 |
+
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
|
59 |
+
DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B"
|
60 |
+
},
|
61 |
+
"Baichuan-13B-Base": {
|
62 |
+
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
|
63 |
+
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base"
|
64 |
+
},
|
65 |
+
"Baichuan-13B-Chat": {
|
66 |
+
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
|
67 |
+
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat"
|
68 |
+
}
|
69 |
+
},
|
70 |
+
module="W_pack",
|
71 |
+
template="baichuan"
|
72 |
+
)
|
73 |
+
|
74 |
+
|
75 |
+
register_model_group(
|
76 |
+
models={
|
77 |
+
"Baichuan2-7B-Base": {
|
78 |
+
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
|
79 |
+
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base"
|
80 |
+
},
|
81 |
+
"Baichuan2-13B-Base": {
|
82 |
+
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
|
83 |
+
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base"
|
84 |
+
},
|
85 |
+
"Baichuan2-7B-Chat": {
|
86 |
+
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
|
87 |
+
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat"
|
88 |
+
},
|
89 |
+
"Baichuan2-13B-Chat": {
|
90 |
+
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
|
91 |
+
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat"
|
92 |
+
}
|
93 |
+
},
|
94 |
+
module="W_pack",
|
95 |
+
template="baichuan2"
|
96 |
+
)
|
97 |
+
|
98 |
+
|
99 |
+
register_model_group(
|
100 |
+
models={
|
101 |
+
"BLOOM-560M": {
|
102 |
+
DownloadSource.DEFAULT: "bigscience/bloom-560m",
|
103 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m"
|
104 |
+
},
|
105 |
+
"BLOOM-3B": {
|
106 |
+
DownloadSource.DEFAULT: "bigscience/bloom-3b",
|
107 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b"
|
108 |
+
},
|
109 |
+
"BLOOM-7B1": {
|
110 |
+
DownloadSource.DEFAULT: "bigscience/bloom-7b1",
|
111 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1"
|
112 |
+
}
|
113 |
+
},
|
114 |
+
module="query_key_value"
|
115 |
+
)
|
116 |
+
|
117 |
+
|
118 |
+
register_model_group(
|
119 |
+
models={
|
120 |
+
"BLOOMZ-560M": {
|
121 |
+
DownloadSource.DEFAULT: "bigscience/bloomz-560m",
|
122 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m"
|
123 |
+
},
|
124 |
+
"BLOOMZ-3B": {
|
125 |
+
DownloadSource.DEFAULT: "bigscience/bloomz-3b",
|
126 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b"
|
127 |
+
},
|
128 |
+
"BLOOMZ-7B1-mt": {
|
129 |
+
DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt",
|
130 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt"
|
131 |
+
}
|
132 |
+
},
|
133 |
+
module="query_key_value"
|
134 |
+
)
|
135 |
+
|
136 |
+
|
137 |
+
register_model_group(
|
138 |
+
models={
|
139 |
+
"BlueLM-7B-Base": {
|
140 |
+
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
|
141 |
+
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base"
|
142 |
+
},
|
143 |
+
"BlueLM-7B-Chat": {
|
144 |
+
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
|
145 |
+
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat"
|
146 |
+
}
|
147 |
+
},
|
148 |
+
template="bluelm"
|
149 |
+
)
|
150 |
+
|
151 |
+
|
152 |
+
register_model_group(
|
153 |
+
models={
|
154 |
+
"ChatGLM2-6B-Chat": {
|
155 |
+
DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
|
156 |
+
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b"
|
157 |
+
}
|
158 |
+
},
|
159 |
+
module="query_key_value",
|
160 |
+
template="chatglm2"
|
161 |
+
)
|
162 |
+
|
163 |
+
|
164 |
+
register_model_group(
|
165 |
+
models={
|
166 |
+
"ChatGLM3-6B-Base": {
|
167 |
+
DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
|
168 |
+
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base"
|
169 |
+
},
|
170 |
+
"ChatGLM3-6B-Chat": {
|
171 |
+
DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
|
172 |
+
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b"
|
173 |
+
}
|
174 |
+
},
|
175 |
+
module="query_key_value",
|
176 |
+
template="chatglm3"
|
177 |
+
)
|
178 |
+
|
179 |
+
|
180 |
+
register_model_group(
|
181 |
+
models={
|
182 |
+
"ChineseLLaMA2-1.3B": {
|
183 |
+
DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
|
184 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b"
|
185 |
+
},
|
186 |
+
"ChineseLLaMA2-7B": {
|
187 |
+
DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
|
188 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b"
|
189 |
+
},
|
190 |
+
"ChineseLLaMA2-13B": {
|
191 |
+
DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
|
192 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b"
|
193 |
+
},
|
194 |
+
"ChineseLLaMA2-1.3B-Chat": {
|
195 |
+
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
|
196 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b"
|
197 |
+
},
|
198 |
+
"ChineseLLaMA2-7B-Chat": {
|
199 |
+
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
|
200 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b"
|
201 |
+
},
|
202 |
+
"ChineseLLaMA2-13B-Chat": {
|
203 |
+
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
|
204 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b"
|
205 |
+
}
|
206 |
+
},
|
207 |
+
template="llama2_zh"
|
208 |
+
)
|
209 |
+
|
210 |
+
|
211 |
+
register_model_group(
|
212 |
+
models={
|
213 |
+
"DeepseekLLM-7B-Base": {
|
214 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base",
|
215 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base"
|
216 |
+
},
|
217 |
+
"DeepseekLLM-67B-Base": {
|
218 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base",
|
219 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base"
|
220 |
+
},
|
221 |
+
"DeepseekLLM-7B-Chat": {
|
222 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat",
|
223 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat"
|
224 |
+
},
|
225 |
+
"DeepseekLLM-67B-Chat": {
|
226 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat",
|
227 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat"
|
228 |
+
}
|
229 |
+
},
|
230 |
+
template="deepseek"
|
231 |
+
)
|
232 |
+
|
233 |
+
|
234 |
+
register_model_group(
|
235 |
+
models={
|
236 |
+
"DeepseekCoder-6.7B-Base": {
|
237 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
|
238 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base"
|
239 |
+
},
|
240 |
+
"DeepseekCoder-33B-Base": {
|
241 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
|
242 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base"
|
243 |
+
},
|
244 |
+
"DeepseekCoder-6.7B-Chat": {
|
245 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
|
246 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct"
|
247 |
+
},
|
248 |
+
"DeepseekCoder-33B-Chat": {
|
249 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
|
250 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct"
|
251 |
+
}
|
252 |
+
},
|
253 |
+
template="deepseekcoder"
|
254 |
+
)
|
255 |
+
|
256 |
+
|
257 |
+
register_model_group(
|
258 |
+
models={
|
259 |
+
"Falcon-7B": {
|
260 |
+
DownloadSource.DEFAULT: "tiiuae/falcon-7b",
|
261 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b"
|
262 |
+
},
|
263 |
+
"Falcon-40B": {
|
264 |
+
DownloadSource.DEFAULT: "tiiuae/falcon-40b",
|
265 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b"
|
266 |
+
},
|
267 |
+
"Falcon-180B": {
|
268 |
+
DownloadSource.DEFAULT: "tiiuae/falcon-180b",
|
269 |
+
DownloadSource.MODELSCOPE: "modelscope/falcon-180B"
|
270 |
+
},
|
271 |
+
"Falcon-7B-Chat": {
|
272 |
+
DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct",
|
273 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct"
|
274 |
+
},
|
275 |
+
"Falcon-40B-Chat": {
|
276 |
+
DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct",
|
277 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct"
|
278 |
+
},
|
279 |
+
"Falcon-180B-Chat": {
|
280 |
+
DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat",
|
281 |
+
DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat"
|
282 |
+
}
|
283 |
+
},
|
284 |
+
module="query_key_value",
|
285 |
+
template="falcon"
|
286 |
+
)
|
287 |
+
|
288 |
+
|
289 |
+
register_model_group(
|
290 |
+
models={
|
291 |
+
"InternLM-7B": {
|
292 |
+
DownloadSource.DEFAULT: "internlm/internlm-7b",
|
293 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b"
|
294 |
+
},
|
295 |
+
"InternLM-20B": {
|
296 |
+
DownloadSource.DEFAULT: "internlm/internlm-20b",
|
297 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b"
|
298 |
+
},
|
299 |
+
"InternLM-7B-Chat": {
|
300 |
+
DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
|
301 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b"
|
302 |
+
},
|
303 |
+
"InternLM-20B-Chat": {
|
304 |
+
DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
|
305 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b"
|
306 |
+
}
|
307 |
+
},
|
308 |
+
template="intern"
|
309 |
+
)
|
310 |
+
|
311 |
+
|
312 |
+
register_model_group(
|
313 |
+
models={
|
314 |
+
"LingoWhale-8B": {
|
315 |
+
DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
|
316 |
+
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B"
|
317 |
+
}
|
318 |
+
},
|
319 |
+
module="qkv_proj"
|
320 |
+
)
|
321 |
+
|
322 |
+
|
323 |
+
register_model_group(
|
324 |
+
models={
|
325 |
+
"LLaMA-7B": {
|
326 |
+
DownloadSource.DEFAULT: "huggyllama/llama-7b",
|
327 |
+
DownloadSource.MODELSCOPE: "skyline2006/llama-7b"
|
328 |
+
},
|
329 |
+
"LLaMA-13B": {
|
330 |
+
DownloadSource.DEFAULT: "huggyllama/llama-13b",
|
331 |
+
DownloadSource.MODELSCOPE: "skyline2006/llama-13b"
|
332 |
+
},
|
333 |
+
"LLaMA-30B": {
|
334 |
+
DownloadSource.DEFAULT: "huggyllama/llama-30b",
|
335 |
+
DownloadSource.MODELSCOPE: "skyline2006/llama-30b"
|
336 |
+
},
|
337 |
+
"LLaMA-65B": {
|
338 |
+
DownloadSource.DEFAULT: "huggyllama/llama-65b",
|
339 |
+
DownloadSource.MODELSCOPE: "skyline2006/llama-65b"
|
340 |
+
}
|
341 |
+
}
|
342 |
+
)
|
343 |
+
|
344 |
+
|
345 |
+
register_model_group(
|
346 |
+
models={
|
347 |
+
"LLaMA2-7B": {
|
348 |
+
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
|
349 |
+
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms"
|
350 |
+
},
|
351 |
+
"LLaMA2-13B": {
|
352 |
+
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
|
353 |
+
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms"
|
354 |
+
},
|
355 |
+
"LLaMA2-70B": {
|
356 |
+
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
|
357 |
+
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms"
|
358 |
+
},
|
359 |
+
"LLaMA2-7B-Chat": {
|
360 |
+
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
|
361 |
+
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms"
|
362 |
+
},
|
363 |
+
"LLaMA2-13B-Chat": {
|
364 |
+
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
|
365 |
+
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms"
|
366 |
+
},
|
367 |
+
"LLaMA2-70B-Chat": {
|
368 |
+
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
|
369 |
+
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms"
|
370 |
+
}
|
371 |
+
},
|
372 |
+
template="llama2"
|
373 |
+
)
|
374 |
+
|
375 |
+
|
376 |
+
register_model_group(
|
377 |
+
models={
|
378 |
+
"Mistral-7B": {
|
379 |
+
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
|
380 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1"
|
381 |
+
},
|
382 |
+
"Mistral-7B-Chat": {
|
383 |
+
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
|
384 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1"
|
385 |
+
}
|
386 |
+
},
|
387 |
+
template="mistral"
|
388 |
+
)
|
389 |
+
|
390 |
+
|
391 |
+
register_model_group(
|
392 |
+
models={
|
393 |
+
"OpenChat3.5-7B-Chat": {
|
394 |
+
DownloadSource.DEFAULT: "openchat/openchat_3.5",
|
395 |
+
DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5"
|
396 |
+
}
|
397 |
+
},
|
398 |
+
template="openchat"
|
399 |
+
)
|
400 |
+
|
401 |
+
|
402 |
+
register_model_group(
|
403 |
+
models={
|
404 |
+
"Phi1.5-1.3B": {
|
405 |
+
DownloadSource.DEFAULT: "microsoft/phi-1_5",
|
406 |
+
DownloadSource.MODELSCOPE: "allspace/PHI_1-5"
|
407 |
+
}
|
408 |
+
},
|
409 |
+
module="Wqkv"
|
410 |
+
)
|
411 |
+
|
412 |
+
|
413 |
+
register_model_group(
|
414 |
+
models={
|
415 |
+
"Qwen-1.8B": {
|
416 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
|
417 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B"
|
418 |
+
},
|
419 |
+
"Qwen-7B": {
|
420 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-7B",
|
421 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-7B"
|
422 |
+
},
|
423 |
+
"Qwen-14B": {
|
424 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-14B",
|
425 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-14B"
|
426 |
+
},
|
427 |
+
"Qwen-72B": {
|
428 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-72B",
|
429 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-72B"
|
430 |
+
},
|
431 |
+
"Qwen-1.8B-Chat": {
|
432 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
|
433 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat"
|
434 |
+
},
|
435 |
+
"Qwen-7B-Chat": {
|
436 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
|
437 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat"
|
438 |
+
},
|
439 |
+
"Qwen-14B-Chat": {
|
440 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
|
441 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat"
|
442 |
+
},
|
443 |
+
"Qwen-72B-Chat": {
|
444 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
|
445 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat"
|
446 |
+
},
|
447 |
+
"Qwen-1.8B-int8-Chat": {
|
448 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
|
449 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8"
|
450 |
+
},
|
451 |
+
"Qwen-1.8B-int4-Chat": {
|
452 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
|
453 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4"
|
454 |
+
},
|
455 |
+
"Qwen-7B-int8-Chat": {
|
456 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
|
457 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8"
|
458 |
+
},
|
459 |
+
"Qwen-7B-int4-Chat": {
|
460 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
|
461 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4"
|
462 |
+
},
|
463 |
+
"Qwen-14B-int8-Chat": {
|
464 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
|
465 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8"
|
466 |
+
},
|
467 |
+
"Qwen-14B-int4-Chat": {
|
468 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
|
469 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4"
|
470 |
+
},
|
471 |
+
"Qwen-72B-int8-Chat": {
|
472 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
|
473 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8"
|
474 |
+
},
|
475 |
+
"Qwen-72B-int4-Chat": {
|
476 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
|
477 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4"
|
478 |
+
}
|
479 |
+
},
|
480 |
+
module="c_attn",
|
481 |
+
template="qwen"
|
482 |
+
)
|
483 |
+
|
484 |
+
|
485 |
+
register_model_group(
|
486 |
+
models={
|
487 |
+
"Skywork-13B-Base": {
|
488 |
+
DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
|
489 |
+
DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base"
|
490 |
+
}
|
491 |
+
}
|
492 |
+
)
|
493 |
+
|
494 |
+
|
495 |
+
register_model_group(
|
496 |
+
models={
|
497 |
+
"Vicuna1.5-7B-Chat": {
|
498 |
+
DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
|
499 |
+
DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5"
|
500 |
+
},
|
501 |
+
"Vicuna1.5-13B-Chat": {
|
502 |
+
DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
|
503 |
+
DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5"
|
504 |
+
}
|
505 |
+
},
|
506 |
+
template="vicuna"
|
507 |
+
)
|
508 |
+
|
509 |
+
|
510 |
+
register_model_group(
|
511 |
+
models={
|
512 |
+
"XVERSE-7B": {
|
513 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-7B",
|
514 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B"
|
515 |
+
},
|
516 |
+
"XVERSE-13B": {
|
517 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-13B",
|
518 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B"
|
519 |
+
},
|
520 |
+
"XVERSE-65B": {
|
521 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-65B",
|
522 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B"
|
523 |
+
},
|
524 |
+
"XVERSE-7B-Chat": {
|
525 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
|
526 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat"
|
527 |
+
},
|
528 |
+
"XVERSE-13B-Chat": {
|
529 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
|
530 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat"
|
531 |
+
}
|
532 |
+
},
|
533 |
+
template="xverse"
|
534 |
+
)
|
535 |
+
|
536 |
+
|
537 |
+
register_model_group(
|
538 |
+
models={
|
539 |
+
"Yayi-7B": {
|
540 |
+
DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
|
541 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2"
|
542 |
+
},
|
543 |
+
"Yayi-13B": {
|
544 |
+
DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
|
545 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2"
|
546 |
+
}
|
547 |
+
},
|
548 |
+
template="yayi"
|
549 |
+
)
|
550 |
+
|
551 |
+
|
552 |
+
register_model_group(
|
553 |
+
models={
|
554 |
+
"Yi-6B": {
|
555 |
+
DownloadSource.DEFAULT: "01-ai/Yi-6B",
|
556 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-6B"
|
557 |
+
},
|
558 |
+
"Yi-34B": {
|
559 |
+
DownloadSource.DEFAULT: "01-ai/Yi-34B",
|
560 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-34B"
|
561 |
+
},
|
562 |
+
"Yi-34B-Chat": {
|
563 |
+
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
|
564 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat"
|
565 |
+
},
|
566 |
+
"Yi-34B-int8-Chat": {
|
567 |
+
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
|
568 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits"
|
569 |
+
}
|
570 |
+
},
|
571 |
+
template="yi"
|
572 |
+
)
|
573 |
+
|
574 |
+
|
575 |
+
register_model_group(
|
576 |
+
models={
|
577 |
+
"Zephyr-7B-Alpha-Chat": {
|
578 |
+
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha",
|
579 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha"
|
580 |
+
},
|
581 |
+
"Zephyr-7B-Beta-Chat": {
|
582 |
+
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
|
583 |
+
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta"
|
584 |
+
}
|
585 |
+
},
|
586 |
+
template="zephyr"
|
587 |
+
)
|
LLM-Detector-V4-11w/src/llmtuner/extras/logging.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import logging
|
3 |
+
|
4 |
+
|
5 |
+
class LoggerHandler(logging.Handler):
|
6 |
+
r"""
|
7 |
+
Logger handler used in Web UI.
|
8 |
+
"""
|
9 |
+
|
10 |
+
def __init__(self):
|
11 |
+
super().__init__()
|
12 |
+
self.log = ""
|
13 |
+
|
14 |
+
def reset(self):
|
15 |
+
self.log = ""
|
16 |
+
|
17 |
+
def emit(self, record):
|
18 |
+
if record.name == "httpx":
|
19 |
+
return
|
20 |
+
log_entry = self.format(record)
|
21 |
+
self.log += log_entry
|
22 |
+
self.log += "\n\n"
|
23 |
+
|
24 |
+
|
25 |
+
def get_logger(name: str) -> logging.Logger:
|
26 |
+
r"""
|
27 |
+
Gets a standard logger with a stream hander to stdout.
|
28 |
+
"""
|
29 |
+
formatter = logging.Formatter(
|
30 |
+
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
31 |
+
datefmt="%m/%d/%Y %H:%M:%S"
|
32 |
+
)
|
33 |
+
handler = logging.StreamHandler(sys.stdout)
|
34 |
+
handler.setFormatter(formatter)
|
35 |
+
|
36 |
+
logger = logging.getLogger(name)
|
37 |
+
logger.setLevel(logging.INFO)
|
38 |
+
logger.addHandler(handler)
|
39 |
+
|
40 |
+
return logger
|
41 |
+
|
42 |
+
|
43 |
+
def reset_logging() -> None:
|
44 |
+
r"""
|
45 |
+
Removes basic config of root logger. (unused in script)
|
46 |
+
"""
|
47 |
+
root = logging.getLogger()
|
48 |
+
list(map(root.removeHandler, root.handlers))
|
49 |
+
list(map(root.removeFilter, root.filters))
|
LLM-Detector-V4-11w/src/llmtuner/extras/misc.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import torch
|
5 |
+
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
6 |
+
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
|
7 |
+
|
8 |
+
try:
|
9 |
+
from transformers.utils import (
|
10 |
+
is_torch_bf16_cpu_available,
|
11 |
+
is_torch_bf16_gpu_available,
|
12 |
+
is_torch_cuda_available,
|
13 |
+
is_torch_npu_available
|
14 |
+
)
|
15 |
+
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
16 |
+
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
|
17 |
+
except ImportError:
|
18 |
+
_is_fp16_available = torch.cuda.is_available()
|
19 |
+
try:
|
20 |
+
_is_bf16_available = torch.cuda.is_bf16_supported()
|
21 |
+
except:
|
22 |
+
_is_bf16_available = False
|
23 |
+
|
24 |
+
if TYPE_CHECKING:
|
25 |
+
from transformers import HfArgumentParser
|
26 |
+
from llmtuner.hparams import ModelArguments
|
27 |
+
|
28 |
+
|
29 |
+
class AverageMeter:
|
30 |
+
r"""
|
31 |
+
Computes and stores the average and current value.
|
32 |
+
"""
|
33 |
+
def __init__(self):
|
34 |
+
self.reset()
|
35 |
+
|
36 |
+
def reset(self):
|
37 |
+
self.val = 0
|
38 |
+
self.avg = 0
|
39 |
+
self.sum = 0
|
40 |
+
self.count = 0
|
41 |
+
|
42 |
+
def update(self, val, n=1):
|
43 |
+
self.val = val
|
44 |
+
self.sum += val * n
|
45 |
+
self.count += n
|
46 |
+
self.avg = self.sum / self.count
|
47 |
+
|
48 |
+
|
49 |
+
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
50 |
+
r"""
|
51 |
+
Returns the number of trainable parameters and number of all parameters in the model.
|
52 |
+
"""
|
53 |
+
trainable_params, all_param = 0, 0
|
54 |
+
for param in model.parameters():
|
55 |
+
num_params = param.numel()
|
56 |
+
# if using DS Zero 3 and the weights are initialized empty
|
57 |
+
if num_params == 0 and hasattr(param, "ds_numel"):
|
58 |
+
num_params = param.ds_numel
|
59 |
+
|
60 |
+
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
|
61 |
+
if param.__class__.__name__ == "Params4bit":
|
62 |
+
num_params = num_params * 2
|
63 |
+
|
64 |
+
all_param += num_params
|
65 |
+
if param.requires_grad:
|
66 |
+
trainable_params += num_params
|
67 |
+
|
68 |
+
return trainable_params, all_param
|
69 |
+
|
70 |
+
|
71 |
+
def get_current_device() -> str:
|
72 |
+
import accelerate
|
73 |
+
if accelerate.utils.is_xpu_available():
|
74 |
+
return "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
75 |
+
elif accelerate.utils.is_npu_available() or torch.cuda.is_available():
|
76 |
+
return "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
77 |
+
else:
|
78 |
+
return "cpu"
|
79 |
+
|
80 |
+
|
81 |
+
def get_logits_processor() -> "LogitsProcessorList":
|
82 |
+
r"""
|
83 |
+
Gets logits processor that removes NaN and Inf logits.
|
84 |
+
"""
|
85 |
+
logits_processor = LogitsProcessorList()
|
86 |
+
logits_processor.append(InfNanRemoveLogitsProcessor())
|
87 |
+
return logits_processor
|
88 |
+
|
89 |
+
|
90 |
+
def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
|
91 |
+
r"""
|
92 |
+
Infers the optimal dtype according to the model_dtype and device compatibility.
|
93 |
+
"""
|
94 |
+
if _is_bf16_available and model_dtype == torch.bfloat16:
|
95 |
+
return torch.bfloat16
|
96 |
+
elif _is_fp16_available:
|
97 |
+
return torch.float16
|
98 |
+
else:
|
99 |
+
return torch.float32
|
100 |
+
|
101 |
+
|
102 |
+
def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
103 |
+
if args is not None:
|
104 |
+
return parser.parse_dict(args)
|
105 |
+
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
106 |
+
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
107 |
+
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
108 |
+
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
|
109 |
+
else:
|
110 |
+
return parser.parse_args_into_dataclasses()
|
111 |
+
|
112 |
+
|
113 |
+
def torch_gc() -> None:
|
114 |
+
r"""
|
115 |
+
Collects GPU memory.
|
116 |
+
"""
|
117 |
+
gc.collect()
|
118 |
+
if torch.cuda.is_available():
|
119 |
+
torch.cuda.empty_cache()
|
120 |
+
torch.cuda.ipc_collect()
|
121 |
+
|
122 |
+
|
123 |
+
def try_download_model_from_ms(model_args: "ModelArguments") -> None:
|
124 |
+
if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
|
125 |
+
return
|
126 |
+
|
127 |
+
try:
|
128 |
+
from modelscope import snapshot_download # type: ignore
|
129 |
+
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
|
130 |
+
model_args.model_name_or_path = snapshot_download(
|
131 |
+
model_args.model_name_or_path,
|
132 |
+
revision=revision,
|
133 |
+
cache_dir=model_args.cache_dir
|
134 |
+
)
|
135 |
+
except ImportError:
|
136 |
+
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
137 |
+
|
138 |
+
|
139 |
+
def use_modelscope() -> bool:
|
140 |
+
return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))
|
LLM-Detector-V4-11w/src/llmtuner/extras/packages.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib.metadata
|
2 |
+
import importlib.util
|
3 |
+
|
4 |
+
|
5 |
+
def is_package_available(name: str) -> bool:
|
6 |
+
return importlib.util.find_spec(name) is not None
|
7 |
+
|
8 |
+
|
9 |
+
def get_package_version(name: str) -> str:
|
10 |
+
try:
|
11 |
+
return importlib.metadata.version(name)
|
12 |
+
except:
|
13 |
+
return "0.0.0"
|
14 |
+
|
15 |
+
|
16 |
+
_fastapi_available = is_package_available("fastapi")
|
17 |
+
_flash_attn2_available = is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
|
18 |
+
_jieba_available = is_package_available("jieba")
|
19 |
+
_matplotlib_available = is_package_available("matplotlib")
|
20 |
+
_nltk_available = is_package_available("nltk")
|
21 |
+
_rouge_available = is_package_available("rouge_chinese")
|
22 |
+
_starlette_available = is_package_available("sse_starlette")
|
23 |
+
_uvicorn_available = is_package_available("uvicorn")
|
24 |
+
|
25 |
+
|
26 |
+
def is_fastapi_availble():
|
27 |
+
return _fastapi_available
|
28 |
+
|
29 |
+
|
30 |
+
def is_flash_attn2_available():
|
31 |
+
return _flash_attn2_available
|
32 |
+
|
33 |
+
|
34 |
+
def is_jieba_available():
|
35 |
+
return _jieba_available
|
36 |
+
|
37 |
+
|
38 |
+
def is_matplotlib_available():
|
39 |
+
return _matplotlib_available
|
40 |
+
|
41 |
+
|
42 |
+
def is_nltk_available():
|
43 |
+
return _nltk_available
|
44 |
+
|
45 |
+
|
46 |
+
def is_rouge_available():
|
47 |
+
return _rouge_available
|
48 |
+
|
49 |
+
|
50 |
+
def is_starlette_available():
|
51 |
+
return _starlette_available
|
52 |
+
|
53 |
+
|
54 |
+
def is_uvicorn_available():
|
55 |
+
return _uvicorn_available
|
LLM-Detector-V4-11w/src/llmtuner/extras/patches/__init__.py
ADDED
File without changes
|
LLM-Detector-V4-11w/src/llmtuner/extras/patches/llama_patch.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from typing import Optional, Tuple
|
5 |
+
from transformers.utils import logging
|
6 |
+
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
|
7 |
+
|
8 |
+
try:
|
9 |
+
from transformers.models.llama.modeling_llama import repeat_kv
|
10 |
+
except ImportError:
|
11 |
+
print("Please upgrade `transformers`.")
|
12 |
+
|
13 |
+
from llmtuner.extras.packages import is_flash_attn2_available
|
14 |
+
|
15 |
+
|
16 |
+
if is_flash_attn2_available():
|
17 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
|
18 |
+
from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
|
19 |
+
|
20 |
+
|
21 |
+
logger = logging.get_logger(__name__)
|
22 |
+
|
23 |
+
|
24 |
+
# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
25 |
+
class LlamaShiftShortAttention(LlamaAttention):
|
26 |
+
|
27 |
+
def forward(
|
28 |
+
self,
|
29 |
+
hidden_states: torch.Tensor,
|
30 |
+
attention_mask: Optional[torch.Tensor] = None,
|
31 |
+
position_ids: Optional[torch.LongTensor] = None,
|
32 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
33 |
+
output_attentions: bool = False,
|
34 |
+
use_cache: bool = False,
|
35 |
+
**kwargs
|
36 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
37 |
+
bsz, q_len, _ = hidden_states.size()
|
38 |
+
|
39 |
+
query_states = self.q_proj(hidden_states)
|
40 |
+
key_states = self.k_proj(hidden_states)
|
41 |
+
value_states = self.v_proj(hidden_states)
|
42 |
+
|
43 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
44 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
45 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
46 |
+
|
47 |
+
kv_seq_len = key_states.shape[-2]
|
48 |
+
if past_key_value is not None:
|
49 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
50 |
+
|
51 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
52 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
53 |
+
|
54 |
+
if past_key_value is not None: # reuse k, v, self_attention
|
55 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
56 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
57 |
+
|
58 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
59 |
+
|
60 |
+
if getattr(self, "num_key_value_groups"):
|
61 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
62 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
63 |
+
|
64 |
+
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
65 |
+
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
66 |
+
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
67 |
+
num_groups = q_len // groupsz
|
68 |
+
def shift(state: torch.Tensor) -> torch.Tensor:
|
69 |
+
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
|
70 |
+
state = torch.cat((
|
71 |
+
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
|
72 |
+
), dim=2)
|
73 |
+
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
|
74 |
+
|
75 |
+
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
|
76 |
+
if attention_mask is not None:
|
77 |
+
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
|
78 |
+
|
79 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
80 |
+
|
81 |
+
if attention_mask is not None:
|
82 |
+
attn_weights = attn_weights + attention_mask
|
83 |
+
|
84 |
+
# upcast attention to fp32
|
85 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
86 |
+
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
|
87 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
88 |
+
|
89 |
+
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
90 |
+
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
91 |
+
attn_output = torch.cat((
|
92 |
+
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
|
93 |
+
))
|
94 |
+
|
95 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
96 |
+
attn_output = self.o_proj(attn_output)
|
97 |
+
|
98 |
+
if not output_attentions:
|
99 |
+
attn_weights = None
|
100 |
+
|
101 |
+
return attn_output, attn_weights, past_key_value
|
102 |
+
|
103 |
+
|
104 |
+
class LlamaFlashAttention2(LlamaAttention):
|
105 |
+
|
106 |
+
def forward(
|
107 |
+
self,
|
108 |
+
hidden_states: torch.Tensor,
|
109 |
+
attention_mask: Optional[torch.Tensor] = None,
|
110 |
+
position_ids: Optional[torch.LongTensor] = None,
|
111 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
112 |
+
output_attentions: bool = False,
|
113 |
+
use_cache: bool = False,
|
114 |
+
**kwargs
|
115 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
116 |
+
# LlamaFlashAttention2 attention does not support output_attentions
|
117 |
+
output_attentions = False
|
118 |
+
|
119 |
+
bsz, q_len, _ = hidden_states.size()
|
120 |
+
|
121 |
+
query_states = self.q_proj(hidden_states)
|
122 |
+
key_states = self.k_proj(hidden_states)
|
123 |
+
value_states = self.v_proj(hidden_states)
|
124 |
+
|
125 |
+
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
|
126 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
127 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
128 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
129 |
+
|
130 |
+
kv_seq_len = key_states.shape[-2]
|
131 |
+
if past_key_value is not None:
|
132 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
133 |
+
|
134 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
135 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
136 |
+
|
137 |
+
if past_key_value is not None: # reuse k, v, self_attention
|
138 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
139 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
140 |
+
|
141 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
142 |
+
|
143 |
+
# cast to half precision
|
144 |
+
input_dtype = query_states.dtype
|
145 |
+
if input_dtype == torch.float32:
|
146 |
+
logger.warning_once("The input hidden states seems to be silently casted in float32.")
|
147 |
+
query_states = query_states.to(self.config.torch_dtype)
|
148 |
+
key_states = key_states.to(self.config.torch_dtype)
|
149 |
+
value_states = value_states.to(self.config.torch_dtype)
|
150 |
+
|
151 |
+
if getattr(self, "num_key_value_groups", None):
|
152 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
153 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
154 |
+
|
155 |
+
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
156 |
+
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
157 |
+
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
158 |
+
|
159 |
+
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
160 |
+
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
161 |
+
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
162 |
+
num_groups = q_len // groupsz
|
163 |
+
def shift(state: torch.Tensor) -> torch.Tensor:
|
164 |
+
state = torch.cat((
|
165 |
+
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
|
166 |
+
), dim=2)
|
167 |
+
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
|
168 |
+
|
169 |
+
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
|
170 |
+
if attention_mask is not None:
|
171 |
+
attention_mask = attention_mask.reshape(bsz * num_groups, groupsz)
|
172 |
+
|
173 |
+
if attention_mask is not None:
|
174 |
+
logger.warning_once("Padded sequences are less efficient in FlashAttention.")
|
175 |
+
# -q_len: assumes left padding when q_len != kv_len
|
176 |
+
unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask[:, -q_len:])
|
177 |
+
unpadded_k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask)
|
178 |
+
unpadded_v, _, _, _ = unpad_input(value_states, attention_mask)
|
179 |
+
attn_output_unpad = flash_attn_varlen_func(
|
180 |
+
unpadded_q,
|
181 |
+
unpadded_k,
|
182 |
+
unpadded_v,
|
183 |
+
cu_seqlens_q=cu_seqlens_q,
|
184 |
+
cu_seqlens_k=cu_seqlens_k,
|
185 |
+
max_seqlen_q=max_seqlen_q,
|
186 |
+
max_seqlen_k=max_seqlen_k,
|
187 |
+
dropout_p=0.0,
|
188 |
+
softmax_scale=None,
|
189 |
+
causal=True,
|
190 |
+
)
|
191 |
+
attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len)
|
192 |
+
else:
|
193 |
+
attn_output = flash_attn_func(
|
194 |
+
query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True
|
195 |
+
)
|
196 |
+
|
197 |
+
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
198 |
+
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
199 |
+
attn_output = torch.cat((
|
200 |
+
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
|
201 |
+
))
|
202 |
+
|
203 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
204 |
+
attn_output = self.o_proj(attn_output)
|
205 |
+
|
206 |
+
if not output_attentions:
|
207 |
+
attn_weights = None
|
208 |
+
|
209 |
+
return attn_output, attn_weights, past_key_value
|
210 |
+
|
211 |
+
|
212 |
+
# Disable the transformation of the attention mask in LlamaModel as flash attention
|
213 |
+
# takes a boolean padding_mask. Fills in the past kv length for use in forward.
|
214 |
+
def _prepare_decoder_attention_mask(
|
215 |
+
self,
|
216 |
+
attention_mask: torch.Tensor,
|
217 |
+
input_shape: torch.Tensor,
|
218 |
+
inputs_embeds: torch.Tensor,
|
219 |
+
past_key_values_length: int
|
220 |
+
) -> torch.Tensor:
|
221 |
+
if attention_mask is not None and torch.all(attention_mask):
|
222 |
+
return None # This uses the faster call when training with full samples
|
223 |
+
|
224 |
+
return attention_mask
|
LLM-Detector-V4-11w/src/llmtuner/extras/ploting.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
import json
|
4 |
+
from typing import List, Optional
|
5 |
+
from transformers.trainer import TRAINER_STATE_NAME
|
6 |
+
|
7 |
+
from llmtuner.extras.logging import get_logger
|
8 |
+
from llmtuner.extras.packages import is_matplotlib_available
|
9 |
+
|
10 |
+
if is_matplotlib_available():
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
|
13 |
+
|
14 |
+
logger = get_logger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
def smooth(scalars: List[float]) -> List[float]:
|
18 |
+
r"""
|
19 |
+
EMA implementation according to TensorBoard.
|
20 |
+
"""
|
21 |
+
last = scalars[0]
|
22 |
+
smoothed = list()
|
23 |
+
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
|
24 |
+
for next_val in scalars:
|
25 |
+
smoothed_val = last * weight + (1 - weight) * next_val
|
26 |
+
smoothed.append(smoothed_val)
|
27 |
+
last = smoothed_val
|
28 |
+
return smoothed
|
29 |
+
|
30 |
+
|
31 |
+
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
|
32 |
+
|
33 |
+
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
34 |
+
data = json.load(f)
|
35 |
+
|
36 |
+
for key in keys:
|
37 |
+
steps, metrics = [], []
|
38 |
+
for i in range(len(data["log_history"])):
|
39 |
+
if key in data["log_history"][i]:
|
40 |
+
steps.append(data["log_history"][i]["step"])
|
41 |
+
metrics.append(data["log_history"][i][key])
|
42 |
+
|
43 |
+
if len(metrics) == 0:
|
44 |
+
logger.warning(f"No metric {key} to plot.")
|
45 |
+
continue
|
46 |
+
|
47 |
+
plt.figure()
|
48 |
+
plt.plot(steps, metrics, alpha=0.4, label="original")
|
49 |
+
plt.plot(steps, smooth(metrics), label="smoothed")
|
50 |
+
plt.title("training {} of {}".format(key, save_dictionary))
|
51 |
+
plt.xlabel("step")
|
52 |
+
plt.ylabel(key)
|
53 |
+
plt.legend()
|
54 |
+
plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
|
55 |
+
print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))
|
LLM-Detector-V4-11w/src/llmtuner/hparams/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .data_args import DataArguments
|
2 |
+
from .evaluation_args import EvaluationArguments
|
3 |
+
from .finetuning_args import FinetuningArguments
|
4 |
+
from .generating_args import GeneratingArguments
|
5 |
+
from .model_args import ModelArguments
|
LLM-Detector-V4-11w/src/llmtuner/hparams/data_args.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
from typing import List, Literal, Optional
|
4 |
+
from dataclasses import dataclass, field
|
5 |
+
|
6 |
+
|
7 |
+
DATA_CONFIG = "dataset_info.json"
|
8 |
+
|
9 |
+
|
10 |
+
@dataclass
|
11 |
+
class DatasetAttr:
|
12 |
+
|
13 |
+
load_from: str
|
14 |
+
dataset_name: Optional[str] = None
|
15 |
+
dataset_sha1: Optional[str] = None
|
16 |
+
system_prompt: Optional[str] = None
|
17 |
+
subset: Optional[str] = None
|
18 |
+
ranking: Optional[bool] = False
|
19 |
+
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
|
20 |
+
|
21 |
+
prompt: Optional[str] = "instruction"
|
22 |
+
query: Optional[str] = "input"
|
23 |
+
response: Optional[str] = "output"
|
24 |
+
history: Optional[str] = None
|
25 |
+
messages: Optional[str] = "conversations"
|
26 |
+
role: Optional[str] = "from"
|
27 |
+
content: Optional[str] = "value"
|
28 |
+
|
29 |
+
def __repr__(self) -> str:
|
30 |
+
return self.dataset_name
|
31 |
+
|
32 |
+
|
33 |
+
@dataclass
|
34 |
+
class DataArguments:
|
35 |
+
r"""
|
36 |
+
Arguments pertaining to what data we are going to input our model for training and evaluation.
|
37 |
+
"""
|
38 |
+
template: Optional[str] = field(
|
39 |
+
default=None,
|
40 |
+
metadata={"help": "Which template to use for constructing prompts in training and inference."}
|
41 |
+
)
|
42 |
+
dataset: Optional[str] = field(
|
43 |
+
default=None,
|
44 |
+
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
|
45 |
+
)
|
46 |
+
dataset_dir: Optional[str] = field(
|
47 |
+
default="data",
|
48 |
+
metadata={"help": "Path to the folder containing the datasets."}
|
49 |
+
)
|
50 |
+
split: Optional[str] = field(
|
51 |
+
default="train",
|
52 |
+
metadata={"help": "Which dataset split to use for training and evaluation."}
|
53 |
+
)
|
54 |
+
cutoff_len: Optional[int] = field(
|
55 |
+
default=1024,
|
56 |
+
metadata={"help": "The maximum length of the model inputs after tokenization."}
|
57 |
+
)
|
58 |
+
reserved_label_len: Optional[int] = field(
|
59 |
+
default=1,
|
60 |
+
metadata={"help": "The maximum length reserved for label after tokenization."}
|
61 |
+
)
|
62 |
+
train_on_prompt: Optional[bool] = field(
|
63 |
+
default=False,
|
64 |
+
metadata={"help": "Whether to disable the mask on the prompt or not."}
|
65 |
+
)
|
66 |
+
streaming: Optional[bool] = field(
|
67 |
+
default=False,
|
68 |
+
metadata={"help": "Enable dataset streaming."}
|
69 |
+
)
|
70 |
+
buffer_size: Optional[int] = field(
|
71 |
+
default=16384,
|
72 |
+
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
|
73 |
+
)
|
74 |
+
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
|
75 |
+
default="concat",
|
76 |
+
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}
|
77 |
+
)
|
78 |
+
interleave_probs: Optional[str] = field(
|
79 |
+
default=None,
|
80 |
+
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}
|
81 |
+
)
|
82 |
+
overwrite_cache: Optional[bool] = field(
|
83 |
+
default=False,
|
84 |
+
metadata={"help": "Overwrite the cached training and evaluation sets."}
|
85 |
+
)
|
86 |
+
preprocessing_num_workers: Optional[int] = field(
|
87 |
+
default=None,
|
88 |
+
metadata={"help": "The number of processes to use for the preprocessing."}
|
89 |
+
)
|
90 |
+
max_samples: Optional[int] = field(
|
91 |
+
default=None,
|
92 |
+
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
|
93 |
+
)
|
94 |
+
eval_num_beams: Optional[int] = field(
|
95 |
+
default=None,
|
96 |
+
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
|
97 |
+
)
|
98 |
+
ignore_pad_token_for_loss: Optional[bool] = field(
|
99 |
+
default=True,
|
100 |
+
metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
|
101 |
+
)
|
102 |
+
system_prompt: Optional[str] = field(
|
103 |
+
default=None,
|
104 |
+
metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."}
|
105 |
+
)
|
106 |
+
val_size: Optional[float] = field(
|
107 |
+
default=0,
|
108 |
+
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
|
109 |
+
)
|
110 |
+
sft_packing: Optional[bool] = field(
|
111 |
+
default=False,
|
112 |
+
metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
|
113 |
+
)
|
114 |
+
cache_path: Optional[str] = field(
|
115 |
+
default=None,
|
116 |
+
metadata={"help": "Path to save or load the preprocessed datasets."}
|
117 |
+
)
|
118 |
+
|
119 |
+
def __post_init__(self):
|
120 |
+
if self.reserved_label_len >= self.cutoff_len:
|
121 |
+
raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.")
|
122 |
+
|
123 |
+
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
|
124 |
+
raise ValueError("Streaming mode should have an integer val size.")
|
125 |
+
|
126 |
+
if self.streaming and self.max_samples is not None:
|
127 |
+
raise ValueError("`max_samples` is incompatible with `streaming`.")
|
128 |
+
|
129 |
+
if self.streaming and self.cache_path:
|
130 |
+
raise ValueError("`cache_path` is incompatible with `streaming`.")
|
131 |
+
|
132 |
+
def init_for_training(self, seed: int): # support mixing multiple datasets
|
133 |
+
self.seed = seed
|
134 |
+
dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
|
135 |
+
try:
|
136 |
+
with open(os.path.join(self.dataset_dir, DATA_CONFIG), "r") as f:
|
137 |
+
dataset_info = json.load(f)
|
138 |
+
except Exception as err:
|
139 |
+
if self.dataset is not None:
|
140 |
+
raise ValueError("Cannot open {} due to {}.".format(os.path.join(self.dataset_dir, DATA_CONFIG), str(err)))
|
141 |
+
dataset_info = None
|
142 |
+
|
143 |
+
prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
|
144 |
+
prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
|
145 |
+
assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."
|
146 |
+
|
147 |
+
if self.interleave_probs is not None:
|
148 |
+
self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
|
149 |
+
|
150 |
+
self.dataset_list: List[DatasetAttr] = []
|
151 |
+
for i, name in enumerate(dataset_names):
|
152 |
+
if name not in dataset_info:
|
153 |
+
raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
|
154 |
+
|
155 |
+
if "hf_hub_url" in dataset_info[name]:
|
156 |
+
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
|
157 |
+
elif "script_url" in dataset_info[name]:
|
158 |
+
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
|
159 |
+
else:
|
160 |
+
dataset_attr = DatasetAttr(
|
161 |
+
"file",
|
162 |
+
dataset_name=dataset_info[name]["file_name"],
|
163 |
+
dataset_sha1=dataset_info[name].get("file_sha1", None)
|
164 |
+
)
|
165 |
+
|
166 |
+
if "columns" in dataset_info[name]:
|
167 |
+
dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
|
168 |
+
dataset_attr.query = dataset_info[name]["columns"].get("query", None)
|
169 |
+
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
|
170 |
+
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
|
171 |
+
dataset_attr.messages = dataset_info[name]["columns"].get("messages", None)
|
172 |
+
dataset_attr.role = dataset_info[name]["columns"].get("role", None)
|
173 |
+
dataset_attr.content = dataset_info[name]["columns"].get("content", None)
|
174 |
+
|
175 |
+
dataset_attr.subset = dataset_info[name].get("subset", None)
|
176 |
+
dataset_attr.ranking = dataset_info[name].get("ranking", False)
|
177 |
+
dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
|
178 |
+
dataset_attr.system_prompt = prompt_list[i]
|
179 |
+
self.dataset_list.append(dataset_attr)
|
LLM-Detector-V4-11w/src/llmtuner/hparams/evaluation_args.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Literal, Optional
|
3 |
+
from dataclasses import dataclass, field
|
4 |
+
|
5 |
+
from datasets import DownloadMode
|
6 |
+
|
7 |
+
|
8 |
+
@dataclass
|
9 |
+
class EvaluationArguments:
|
10 |
+
r"""
|
11 |
+
Arguments pertaining to specify the evaluation parameters.
|
12 |
+
"""
|
13 |
+
task: str = field(
|
14 |
+
metadata={"help": "Name of the evaluation task."}
|
15 |
+
)
|
16 |
+
task_dir: Optional[str] = field(
|
17 |
+
default="evaluation",
|
18 |
+
metadata={"help": "Path to the folder containing the evaluation datasets."}
|
19 |
+
)
|
20 |
+
batch_size: Optional[int] = field(
|
21 |
+
default=4,
|
22 |
+
metadata={"help": "The batch size per GPU for evaluation."}
|
23 |
+
)
|
24 |
+
seed: Optional[int] = field(
|
25 |
+
default=42,
|
26 |
+
metadata={"help": "Random seed to be used with data loaders."}
|
27 |
+
)
|
28 |
+
lang: Optional[Literal["en", "zh"]] = field(
|
29 |
+
default="en",
|
30 |
+
metadata={"help": "Language used at evaluation."}
|
31 |
+
)
|
32 |
+
n_shot: Optional[int] = field(
|
33 |
+
default=5,
|
34 |
+
metadata={"help": "Number of examplars for few-shot learning."}
|
35 |
+
)
|
36 |
+
save_dir: Optional[str] = field(
|
37 |
+
default=None,
|
38 |
+
metadata={"help": "Path to save the evaluation results."}
|
39 |
+
)
|
40 |
+
download_mode: Optional[DownloadMode] = field(
|
41 |
+
default=DownloadMode.REUSE_DATASET_IF_EXISTS,
|
42 |
+
metadata={"help": "Download mode used for the evaluation datasets."}
|
43 |
+
)
|
44 |
+
|
45 |
+
def __post_init__(self):
|
46 |
+
task_available = []
|
47 |
+
for folder in os.listdir(self.task_dir):
|
48 |
+
if os.path.isdir(os.path.join(self.task_dir, folder)):
|
49 |
+
task_available.append(folder)
|
50 |
+
|
51 |
+
if self.task not in task_available:
|
52 |
+
raise ValueError("Task {} not found in {}.".format(self.task, self.task_dir))
|
53 |
+
|
54 |
+
if self.save_dir is not None and os.path.exists(self.save_dir):
|
55 |
+
raise ValueError("`save_dir` already exists, use another one.")
|
LLM-Detector-V4-11w/src/llmtuner/hparams/finetuning_args.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from typing import Literal, Optional
|
3 |
+
from dataclasses import asdict, dataclass, field
|
4 |
+
|
5 |
+
|
6 |
+
@dataclass
|
7 |
+
class FreezeArguments:
|
8 |
+
r"""
|
9 |
+
Arguments pertaining to the freeze (partial-parameter) training.
|
10 |
+
"""
|
11 |
+
name_module_trainable: Optional[str] = field(
|
12 |
+
default="mlp",
|
13 |
+
metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
|
14 |
+
Use commas to separate multiple modules. \
|
15 |
+
LLaMA choices: [\"mlp\", \"self_attn\"], \
|
16 |
+
BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \
|
17 |
+
Qwen choices: [\"mlp\", \"attn\"], \
|
18 |
+
Phi-1.5 choices: [\"mlp\", \"mixer\"], \
|
19 |
+
Others choices: the same as LLaMA."}
|
20 |
+
)
|
21 |
+
num_layer_trainable: Optional[int] = field(
|
22 |
+
default=3,
|
23 |
+
metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."}
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
@dataclass
|
28 |
+
class LoraArguments:
|
29 |
+
r"""
|
30 |
+
Arguments pertaining to the LoRA training.
|
31 |
+
"""
|
32 |
+
additional_target: Optional[str] = field(
|
33 |
+
default=None,
|
34 |
+
metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."}
|
35 |
+
)
|
36 |
+
lora_alpha: Optional[float] = field(
|
37 |
+
default=None,
|
38 |
+
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2.0)."}
|
39 |
+
)
|
40 |
+
lora_dropout: Optional[float] = field(
|
41 |
+
default=0.1,
|
42 |
+
metadata={"help": "Dropout rate for the LoRA fine-tuning."}
|
43 |
+
)
|
44 |
+
lora_rank: Optional[int] = field(
|
45 |
+
default=8,
|
46 |
+
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
|
47 |
+
)
|
48 |
+
lora_target: Optional[str] = field(
|
49 |
+
default=None,
|
50 |
+
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
|
51 |
+
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
52 |
+
BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"dense\", \"dense_h_to_4h\", \"dense_4h_to_h\"], \
|
53 |
+
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
54 |
+
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
|
55 |
+
Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
|
56 |
+
Others choices: the same as LLaMA."}
|
57 |
+
)
|
58 |
+
resume_lora_training: Optional[bool] = field(
|
59 |
+
default=True,
|
60 |
+
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
|
61 |
+
)
|
62 |
+
|
63 |
+
|
64 |
+
@dataclass
|
65 |
+
class RLHFArguments:
|
66 |
+
r"""
|
67 |
+
Arguments pertaining to the PPO and DPO training.
|
68 |
+
"""
|
69 |
+
dpo_beta: Optional[float] = field(
|
70 |
+
default=0.1,
|
71 |
+
metadata={"help": "The beta parameter for the DPO loss."}
|
72 |
+
)
|
73 |
+
ppo_buffer_size: Optional[int] = field(
|
74 |
+
default=1,
|
75 |
+
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}
|
76 |
+
)
|
77 |
+
ppo_epochs: Optional[int] = field(
|
78 |
+
default=4,
|
79 |
+
metadata={"help": "The number of epochs to perform in a PPO optimization step."}
|
80 |
+
)
|
81 |
+
ppo_logger: Optional[str] = field(
|
82 |
+
default=None,
|
83 |
+
metadata={"help": "Log with either \"wandb\" or \"tensorboard\" in PPO training."}
|
84 |
+
)
|
85 |
+
ppo_score_norm: Optional[bool] = field(
|
86 |
+
default=False,
|
87 |
+
metadata={"help": "Use score normalization in PPO training."}
|
88 |
+
)
|
89 |
+
ppo_target: Optional[float] = field(
|
90 |
+
default=6.0,
|
91 |
+
metadata={"help": "Target KL value for adaptive KL control in PPO training."}
|
92 |
+
)
|
93 |
+
ppo_whiten_rewards: Optional[bool] = field(
|
94 |
+
default=False,
|
95 |
+
metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
|
96 |
+
)
|
97 |
+
ref_model: Optional[str] = field(
|
98 |
+
default=None,
|
99 |
+
metadata={"help": "Path to the reference model used for the PPO or DPO training."}
|
100 |
+
)
|
101 |
+
ref_model_checkpoint: Optional[str] = field(
|
102 |
+
default=None,
|
103 |
+
metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."}
|
104 |
+
)
|
105 |
+
ref_model_quantization_bit: Optional[int] = field(
|
106 |
+
default=None,
|
107 |
+
metadata={"help": "The number of bits to quantize the reference model."}
|
108 |
+
)
|
109 |
+
reward_model: Optional[str] = field(
|
110 |
+
default=None,
|
111 |
+
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
|
112 |
+
)
|
113 |
+
reward_model_checkpoint: Optional[str] = field(
|
114 |
+
default=None,
|
115 |
+
metadata={"help": "Path to the directory(s) containing the model checkpoints of the reward model."}
|
116 |
+
)
|
117 |
+
reward_model_quantization_bit: Optional[int] = field(
|
118 |
+
default=None,
|
119 |
+
metadata={"help": "The number of bits to quantize the reward model."}
|
120 |
+
)
|
121 |
+
reward_model_type: Optional[Literal["lora", "full"]] = field(
|
122 |
+
default="lora",
|
123 |
+
metadata={"help": "The checkpoint type of the reward model. The lora type only supports lora training."}
|
124 |
+
)
|
125 |
+
|
126 |
+
|
127 |
+
@dataclass
|
128 |
+
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
|
129 |
+
r"""
|
130 |
+
Arguments pertaining to which techniques we are going to fine-tuning with.
|
131 |
+
"""
|
132 |
+
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
|
133 |
+
default="sft",
|
134 |
+
metadata={"help": "Which stage will be performed in training."}
|
135 |
+
)
|
136 |
+
finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
|
137 |
+
default="lora",
|
138 |
+
metadata={"help": "Which fine-tuning method to use."}
|
139 |
+
)
|
140 |
+
upcast_layernorm: Optional[bool] = field(
|
141 |
+
default=False,
|
142 |
+
metadata={"help": "Whether to upcast the layernorm weights in fp32."}
|
143 |
+
)
|
144 |
+
neft_alpha: Optional[float] = field(
|
145 |
+
default=0,
|
146 |
+
metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
|
147 |
+
)
|
148 |
+
export_dir: Optional[str] = field(
|
149 |
+
default=None,
|
150 |
+
metadata={"help": "Path to the directory to save the exported model."}
|
151 |
+
)
|
152 |
+
export_size: Optional[int] = field(
|
153 |
+
default=1,
|
154 |
+
metadata={"help": "The file shard size (in GB) of the exported model."}
|
155 |
+
)
|
156 |
+
plot_loss: Optional[bool] = field(
|
157 |
+
default=False,
|
158 |
+
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
|
159 |
+
)
|
160 |
+
|
161 |
+
def __post_init__(self):
|
162 |
+
def split_arg(arg):
|
163 |
+
if isinstance(arg, str):
|
164 |
+
return [item.strip() for item in arg.split(",")]
|
165 |
+
return arg
|
166 |
+
|
167 |
+
self.name_module_trainable = split_arg(self.name_module_trainable)
|
168 |
+
self.lora_alpha = self.lora_alpha or float(self.lora_rank * 2.0)
|
169 |
+
self.lora_target = split_arg(self.lora_target)
|
170 |
+
self.additional_target = split_arg(self.additional_target)
|
171 |
+
self.ref_model_checkpoint = split_arg(self.ref_model_checkpoint)
|
172 |
+
self.reward_model_checkpoint = split_arg(self.reward_model_checkpoint)
|
173 |
+
|
174 |
+
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
175 |
+
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
176 |
+
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
177 |
+
|
178 |
+
if self.stage == "ppo" and self.reward_model is None:
|
179 |
+
raise ValueError("Reward model is necessary for PPO training.")
|
180 |
+
|
181 |
+
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
|
182 |
+
raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.")
|
183 |
+
|
184 |
+
def save_to_json(self, json_path: str):
|
185 |
+
r"""Saves the content of this instance in JSON format inside `json_path`."""
|
186 |
+
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
|
187 |
+
with open(json_path, "w", encoding="utf-8") as f:
|
188 |
+
f.write(json_string)
|
189 |
+
|
190 |
+
@classmethod
|
191 |
+
def load_from_json(cls, json_path: str):
|
192 |
+
r"""Creates an instance from the content of `json_path`."""
|
193 |
+
with open(json_path, "r", encoding="utf-8") as f:
|
194 |
+
text = f.read()
|
195 |
+
|
196 |
+
return cls(**json.loads(text))
|
LLM-Detector-V4-11w/src/llmtuner/hparams/generating_args.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, Optional
|
2 |
+
from dataclasses import asdict, dataclass, field
|
3 |
+
|
4 |
+
|
5 |
+
@dataclass
|
6 |
+
class GeneratingArguments:
|
7 |
+
r"""
|
8 |
+
Arguments pertaining to specify the decoding parameters.
|
9 |
+
"""
|
10 |
+
do_sample: Optional[bool] = field(
|
11 |
+
default=True,
|
12 |
+
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
|
13 |
+
)
|
14 |
+
temperature: Optional[float] = field(
|
15 |
+
default=0.95,
|
16 |
+
metadata={"help": "The value used to modulate the next token probabilities."}
|
17 |
+
)
|
18 |
+
top_p: Optional[float] = field(
|
19 |
+
default=0.7,
|
20 |
+
metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
|
21 |
+
)
|
22 |
+
top_k: Optional[int] = field(
|
23 |
+
default=50,
|
24 |
+
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
|
25 |
+
)
|
26 |
+
num_beams: Optional[int] = field(
|
27 |
+
default=1,
|
28 |
+
metadata={"help": "Number of beams for beam search. 1 means no beam search."}
|
29 |
+
)
|
30 |
+
max_length: Optional[int] = field(
|
31 |
+
default=512,
|
32 |
+
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
|
33 |
+
)
|
34 |
+
max_new_tokens: Optional[int] = field(
|
35 |
+
default=512,
|
36 |
+
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
|
37 |
+
)
|
38 |
+
repetition_penalty: Optional[float] = field(
|
39 |
+
default=1.0,
|
40 |
+
metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
|
41 |
+
)
|
42 |
+
length_penalty: Optional[float] = field(
|
43 |
+
default=1.0,
|
44 |
+
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
|
45 |
+
)
|
46 |
+
|
47 |
+
def to_dict(self) -> Dict[str, Any]:
|
48 |
+
args = asdict(self)
|
49 |
+
if args.get("max_new_tokens", -1) > 0:
|
50 |
+
args.pop("max_length", None)
|
51 |
+
else:
|
52 |
+
args.pop("max_new_tokens", None)
|
53 |
+
return args
|
LLM-Detector-V4-11w/src/llmtuner/hparams/model_args.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, Literal, Optional
|
2 |
+
from dataclasses import asdict, dataclass, field
|
3 |
+
|
4 |
+
|
5 |
+
@dataclass
|
6 |
+
class ModelArguments:
|
7 |
+
r"""
|
8 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
|
9 |
+
"""
|
10 |
+
model_name_or_path: str = field(
|
11 |
+
metadata={"help": "Path to pretrained model or model identifier from \
|
12 |
+
huggingface.co/models or modelscope.cn/models."}
|
13 |
+
)
|
14 |
+
cache_dir: Optional[str] = field(
|
15 |
+
default=None,
|
16 |
+
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
|
17 |
+
)
|
18 |
+
use_fast_tokenizer: Optional[bool] = field(
|
19 |
+
default=True,
|
20 |
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
|
21 |
+
)
|
22 |
+
split_special_tokens: Optional[bool] = field(
|
23 |
+
default=False,
|
24 |
+
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}
|
25 |
+
)
|
26 |
+
model_revision: Optional[str] = field(
|
27 |
+
default="main",
|
28 |
+
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
|
29 |
+
)
|
30 |
+
quantization_bit: Optional[int] = field(
|
31 |
+
default=None,
|
32 |
+
metadata={"help": "The number of bits to quantize the model."}
|
33 |
+
)
|
34 |
+
quantization_type: Optional[Literal["fp4", "nf4"]] = field(
|
35 |
+
default="nf4",
|
36 |
+
metadata={"help": "Quantization data type to use in int4 training."}
|
37 |
+
)
|
38 |
+
double_quantization: Optional[bool] = field(
|
39 |
+
default=True,
|
40 |
+
metadata={"help": "Whether to use double quantization in int4 training or not."}
|
41 |
+
)
|
42 |
+
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
|
43 |
+
default=None,
|
44 |
+
metadata={"help": "Adopt scaled rotary positional embeddings."}
|
45 |
+
)
|
46 |
+
checkpoint_dir: Optional[str] = field(
|
47 |
+
default=None,
|
48 |
+
metadata={"help": "Path to the directory(s) containing the model checkpoints as well as the configurations."}
|
49 |
+
)
|
50 |
+
flash_attn: Optional[bool] = field(
|
51 |
+
default=False,
|
52 |
+
metadata={"help": "Enable FlashAttention-2 for faster training."}
|
53 |
+
)
|
54 |
+
shift_attn: Optional[bool] = field(
|
55 |
+
default=False,
|
56 |
+
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
|
57 |
+
)
|
58 |
+
hf_hub_token: Optional[str] = field(
|
59 |
+
default=None,
|
60 |
+
metadata={"help": "Auth token to log in with Hugging Face Hub."}
|
61 |
+
)
|
62 |
+
|
63 |
+
def __post_init__(self):
|
64 |
+
self.compute_dtype = None
|
65 |
+
self.model_max_length = None
|
66 |
+
|
67 |
+
if self.split_special_tokens and self.use_fast_tokenizer:
|
68 |
+
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
69 |
+
|
70 |
+
if self.checkpoint_dir is not None: # support merging multiple lora weights
|
71 |
+
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
|
72 |
+
|
73 |
+
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
74 |
+
|
75 |
+
def to_dict(self) -> Dict[str, Any]:
|
76 |
+
return asdict(self)
|
LLM-Detector-V4-11w/src/llmtuner/model/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Level: loader > adapter > parser, utils
|
2 |
+
|
3 |
+
from llmtuner.model.loader import load_model_and_tokenizer
|
4 |
+
from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args
|
5 |
+
from llmtuner.model.utils import dispatch_model, get_modelcard_args, load_valuehead_params
|
LLM-Detector-V4-11w/src/llmtuner/model/adapter.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from typing import TYPE_CHECKING
|
3 |
+
from peft import PeftModel, TaskType, LoraConfig, get_peft_model
|
4 |
+
|
5 |
+
from llmtuner.extras.logging import get_logger
|
6 |
+
from llmtuner.model.utils import find_all_linear_modules
|
7 |
+
|
8 |
+
if TYPE_CHECKING:
|
9 |
+
from transformers.modeling_utils import PreTrainedModel
|
10 |
+
from llmtuner.hparams import ModelArguments, FinetuningArguments
|
11 |
+
|
12 |
+
|
13 |
+
logger = get_logger(__name__)
|
14 |
+
|
15 |
+
|
16 |
+
def init_adapter(
|
17 |
+
model: "PreTrainedModel",
|
18 |
+
model_args: "ModelArguments",
|
19 |
+
finetuning_args: "FinetuningArguments",
|
20 |
+
is_trainable: bool
|
21 |
+
) -> "PreTrainedModel":
|
22 |
+
r"""
|
23 |
+
Initializes the adapters.
|
24 |
+
|
25 |
+
Support full-parameter, freeze and LoRA training.
|
26 |
+
|
27 |
+
Note that the trainable parameters must be cast to float32.
|
28 |
+
"""
|
29 |
+
|
30 |
+
if (not is_trainable) and model_args.checkpoint_dir is None:
|
31 |
+
logger.info("Checkpoint is not found at evaluation, load the original model.")
|
32 |
+
return model
|
33 |
+
|
34 |
+
if finetuning_args.finetuning_type == "full" and is_trainable:
|
35 |
+
logger.info("Fine-tuning method: Full")
|
36 |
+
model = model.float()
|
37 |
+
|
38 |
+
if finetuning_args.finetuning_type == "freeze" and is_trainable:
|
39 |
+
logger.info("Fine-tuning method: Freeze")
|
40 |
+
num_layers = (
|
41 |
+
getattr(model.config, "num_hidden_layers", None)
|
42 |
+
or getattr(model.config, "num_layers", None)
|
43 |
+
or getattr(model.config, "n_layer", None)
|
44 |
+
)
|
45 |
+
if not num_layers:
|
46 |
+
raise ValueError("Current model does not support freeze tuning.")
|
47 |
+
if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
|
48 |
+
trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
|
49 |
+
else: # fine-tuning the first n layers if num_layer_trainable < 0
|
50 |
+
trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]
|
51 |
+
|
52 |
+
trainable_layers = []
|
53 |
+
for module_name in finetuning_args.name_module_trainable:
|
54 |
+
for idx in trainable_layer_ids:
|
55 |
+
trainable_layers.append("{:d}.{}".format(idx, module_name))
|
56 |
+
|
57 |
+
for name, param in model.named_parameters():
|
58 |
+
if not any(trainable_layer in name for trainable_layer in trainable_layers):
|
59 |
+
param.requires_grad_(False)
|
60 |
+
else:
|
61 |
+
param.data = param.data.to(torch.float32)
|
62 |
+
|
63 |
+
if finetuning_args.finetuning_type == "lora":
|
64 |
+
logger.info("Fine-tuning method: LoRA")
|
65 |
+
checkpoint_to_resume = None
|
66 |
+
|
67 |
+
if model_args.checkpoint_dir is not None:
|
68 |
+
is_mergeable = True
|
69 |
+
if getattr(model, "quantization_method", None) == "gptq":
|
70 |
+
assert len(model_args.checkpoint_dir) == 1, "GPTQ quantized model only accepts a single checkpoint."
|
71 |
+
is_mergeable = False
|
72 |
+
|
73 |
+
if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable):
|
74 |
+
checkpoints_to_merge, checkpoint_to_resume = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
|
75 |
+
else:
|
76 |
+
checkpoints_to_merge = model_args.checkpoint_dir
|
77 |
+
|
78 |
+
for checkpoint in checkpoints_to_merge:
|
79 |
+
model = PeftModel.from_pretrained(model, checkpoint)
|
80 |
+
model = model.merge_and_unload()
|
81 |
+
|
82 |
+
if len(checkpoints_to_merge) > 0:
|
83 |
+
logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
|
84 |
+
|
85 |
+
if checkpoint_to_resume is not None: # resume lora training
|
86 |
+
model = PeftModel.from_pretrained(model, checkpoint_to_resume, is_trainable=is_trainable)
|
87 |
+
|
88 |
+
if is_trainable and checkpoint_to_resume is None: # create new lora weights while training
|
89 |
+
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
90 |
+
target_modules = find_all_linear_modules(model, model_args.quantization_bit)
|
91 |
+
else:
|
92 |
+
target_modules = finetuning_args.lora_target
|
93 |
+
|
94 |
+
lora_config = LoraConfig(
|
95 |
+
task_type=TaskType.CAUSAL_LM,
|
96 |
+
inference_mode=False,
|
97 |
+
r=finetuning_args.lora_rank,
|
98 |
+
lora_alpha=finetuning_args.lora_alpha,
|
99 |
+
lora_dropout=finetuning_args.lora_dropout,
|
100 |
+
target_modules=target_modules,
|
101 |
+
modules_to_save=finetuning_args.additional_target
|
102 |
+
)
|
103 |
+
model = get_peft_model(model, lora_config)
|
104 |
+
|
105 |
+
if model_args.checkpoint_dir is not None:
|
106 |
+
logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
|
107 |
+
|
108 |
+
return model
|
LLM-Detector-V4-11w/src/llmtuner/model/loader.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from types import MethodType
|
4 |
+
from typing import TYPE_CHECKING, Literal, Optional, Tuple
|
5 |
+
|
6 |
+
from transformers import (
|
7 |
+
AutoConfig,
|
8 |
+
AutoModelForCausalLM,
|
9 |
+
AutoTokenizer,
|
10 |
+
BitsAndBytesConfig,
|
11 |
+
PretrainedConfig,
|
12 |
+
PreTrainedModel,
|
13 |
+
PreTrainedTokenizerBase
|
14 |
+
)
|
15 |
+
from transformers.models.llama import modeling_llama as LlamaModule
|
16 |
+
from transformers.utils.versions import require_version
|
17 |
+
from trl import AutoModelForCausalLMWithValueHead
|
18 |
+
|
19 |
+
try:
|
20 |
+
from transformers.integrations import is_deepspeed_zero3_enabled
|
21 |
+
except ImportError: # https://github.com/huggingface/transformers/releases/tag/v4.33.1
|
22 |
+
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
23 |
+
|
24 |
+
from llmtuner.extras.logging import get_logger
|
25 |
+
from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype, try_download_model_from_ms
|
26 |
+
from llmtuner.extras.packages import is_flash_attn2_available
|
27 |
+
from llmtuner.extras.patches import llama_patch as LlamaPatches
|
28 |
+
from llmtuner.hparams import FinetuningArguments
|
29 |
+
from llmtuner.model.adapter import init_adapter
|
30 |
+
from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training
|
31 |
+
|
32 |
+
if TYPE_CHECKING:
|
33 |
+
from transformers import PreTrainedTokenizer
|
34 |
+
from llmtuner.hparams import ModelArguments
|
35 |
+
|
36 |
+
|
37 |
+
logger = get_logger(__name__)
|
38 |
+
|
39 |
+
|
40 |
+
require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
|
41 |
+
require_version("datasets>=2.14.0", "To fix: pip install datasets>=2.14.0")
|
42 |
+
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
|
43 |
+
require_version("peft>=0.6.0", "To fix: pip install peft>=0.6.0")
|
44 |
+
require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4")
|
45 |
+
|
46 |
+
|
47 |
+
def load_model_and_tokenizer(
|
48 |
+
model_args: "ModelArguments",
|
49 |
+
finetuning_args: "FinetuningArguments",
|
50 |
+
is_trainable: Optional[bool] = False,
|
51 |
+
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
|
52 |
+
) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
|
53 |
+
r"""
|
54 |
+
Loads pretrained model and tokenizer.
|
55 |
+
|
56 |
+
Support both training and inference.
|
57 |
+
"""
|
58 |
+
|
59 |
+
try_download_model_from_ms(model_args)
|
60 |
+
|
61 |
+
config_kwargs = {
|
62 |
+
"trust_remote_code": True,
|
63 |
+
"cache_dir": model_args.cache_dir,
|
64 |
+
"revision": model_args.model_revision,
|
65 |
+
"token": model_args.hf_hub_token
|
66 |
+
}
|
67 |
+
|
68 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
69 |
+
model_args.model_name_or_path,
|
70 |
+
use_fast=model_args.use_fast_tokenizer,
|
71 |
+
split_special_tokens=model_args.split_special_tokens,
|
72 |
+
padding_side="right", # training with left-padded tensors in fp16 precision may cause overflow
|
73 |
+
**config_kwargs
|
74 |
+
)
|
75 |
+
|
76 |
+
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
|
77 |
+
logger.info("Use `model_name_or_path` to specify the model trained with full/freeze method.")
|
78 |
+
model_to_load = model_args.checkpoint_dir[0]
|
79 |
+
else:
|
80 |
+
model_to_load = model_args.model_name_or_path
|
81 |
+
|
82 |
+
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
|
83 |
+
|
84 |
+
# Fix tokenizer (for ChatGLM2 and ChatGLM3)
|
85 |
+
if getattr(config, "model_type", None) == "chatglm":
|
86 |
+
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
|
87 |
+
|
88 |
+
# Set model dtype
|
89 |
+
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
|
90 |
+
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
91 |
+
setattr(config, "torch_dtype", model_args.compute_dtype)
|
92 |
+
|
93 |
+
# Fix config (for Qwen)
|
94 |
+
if getattr(config, "model_type", None) == "qwen":
|
95 |
+
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
|
96 |
+
setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
|
97 |
+
|
98 |
+
# Set RoPE scaling
|
99 |
+
if model_args.rope_scaling is not None:
|
100 |
+
if not hasattr(config, "rope_scaling"):
|
101 |
+
logger.warning("Current model does not support RoPE scaling.")
|
102 |
+
else:
|
103 |
+
if is_trainable:
|
104 |
+
if model_args.rope_scaling == "dynamic":
|
105 |
+
logger.warning(
|
106 |
+
"Dynamic NTK may not work well with fine-tuning. "
|
107 |
+
"See: https://github.com/huggingface/transformers/pull/24653"
|
108 |
+
)
|
109 |
+
|
110 |
+
current_max_length = getattr(config, "max_position_embeddings", None)
|
111 |
+
if current_max_length and model_args.model_max_length > current_max_length:
|
112 |
+
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
|
113 |
+
else:
|
114 |
+
logger.warning("Input length is smaller than max length. Consider increase input length.")
|
115 |
+
scaling_factor = 1.0
|
116 |
+
else:
|
117 |
+
scaling_factor = 2.0
|
118 |
+
|
119 |
+
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
|
120 |
+
logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
|
121 |
+
model_args.rope_scaling, scaling_factor
|
122 |
+
))
|
123 |
+
|
124 |
+
# Set FlashAttention-2
|
125 |
+
if model_args.flash_attn:
|
126 |
+
if getattr(config, "model_type", None) == "llama":
|
127 |
+
if is_flash_attn2_available():
|
128 |
+
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
|
129 |
+
LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
|
130 |
+
logger.info("Using FlashAttention-2 for faster training and inference.")
|
131 |
+
else:
|
132 |
+
logger.warning("FlashAttention-2 is not installed.")
|
133 |
+
elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
|
134 |
+
logger.info("Current model automatically enables FlashAttention if installed.")
|
135 |
+
else:
|
136 |
+
logger.warning("Current model does not support FlashAttention.")
|
137 |
+
elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
|
138 |
+
LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
|
139 |
+
logger.warning("Using `--flash_attn` for faster training in large context length.")
|
140 |
+
|
141 |
+
# Set shift short attention (S^2-Attn)
|
142 |
+
if is_trainable and model_args.shift_attn:
|
143 |
+
if getattr(config, "model_type", None) == "llama":
|
144 |
+
setattr(config, "group_size_ratio", 0.25)
|
145 |
+
logger.info("Using shift short attention with group_size_ratio=1/4.")
|
146 |
+
else:
|
147 |
+
logger.warning("Current model does not support shift short attention.")
|
148 |
+
|
149 |
+
# Quantization configurations (using gptq or awq)
|
150 |
+
if getattr(config, "quantization_config", None):
|
151 |
+
if model_args.quantization_bit is not None: # remove bnb quantization
|
152 |
+
model_args.quantization_bit = None
|
153 |
+
config_kwargs["device_map"] = {"": get_current_device()}
|
154 |
+
quantization_config = getattr(config, "quantization_config", None)
|
155 |
+
logger.info("Loading {}-bit quantized model.".format(quantization_config.get("bits", -1)))
|
156 |
+
|
157 |
+
# Quantization configurations (using bitsandbytes library)
|
158 |
+
if model_args.quantization_bit is not None:
|
159 |
+
if is_deepspeed_zero3_enabled():
|
160 |
+
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
161 |
+
|
162 |
+
if model_args.quantization_bit == 8:
|
163 |
+
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
164 |
+
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
165 |
+
|
166 |
+
if model_args.quantization_bit == 4:
|
167 |
+
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
168 |
+
config_kwargs["quantization_config"] = BitsAndBytesConfig(
|
169 |
+
load_in_4bit=True,
|
170 |
+
bnb_4bit_compute_dtype=model_args.compute_dtype,
|
171 |
+
bnb_4bit_use_double_quant=model_args.double_quantization,
|
172 |
+
bnb_4bit_quant_type=model_args.quantization_type
|
173 |
+
)
|
174 |
+
|
175 |
+
config_kwargs["device_map"] = {"": get_current_device()}
|
176 |
+
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
177 |
+
|
178 |
+
# Load pre-trained models (without valuehead)
|
179 |
+
model = AutoModelForCausalLM.from_pretrained(
|
180 |
+
model_to_load,
|
181 |
+
config=config,
|
182 |
+
torch_dtype=model_args.compute_dtype,
|
183 |
+
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
|
184 |
+
**config_kwargs
|
185 |
+
)
|
186 |
+
|
187 |
+
# Disable custom generate method (for Qwen and Baichuan2)
|
188 |
+
if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
|
189 |
+
model.generate = MethodType(PreTrainedModel.generate, model)
|
190 |
+
|
191 |
+
# Fix LM head (for ChatGLM2 and ChatGLM3)
|
192 |
+
if getattr(config, "model_type", None) == "chatglm":
|
193 |
+
setattr(model, "lm_head", model.transformer.output_layer)
|
194 |
+
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
|
195 |
+
|
196 |
+
# Register auto class to save the custom code files
|
197 |
+
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
|
198 |
+
config.__class__.register_for_auto_class()
|
199 |
+
if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
|
200 |
+
model.__class__.register_for_auto_class()
|
201 |
+
if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
|
202 |
+
tokenizer.__class__.register_for_auto_class()
|
203 |
+
|
204 |
+
# Initialize adapters
|
205 |
+
model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
|
206 |
+
model = init_adapter(model, model_args, finetuning_args, is_trainable)
|
207 |
+
model = model.train() if is_trainable else model.eval()
|
208 |
+
|
209 |
+
# Prepare model with valuehead for RLHF
|
210 |
+
if stage in ["rm", "ppo"]:
|
211 |
+
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
212 |
+
setattr(model, "_keys_to_ignore_on_save", [name for name, _ in model.named_parameters() if "pretrained_model" in name])
|
213 |
+
setattr(model, "tie_weights", MethodType(lambda _: None, model)) # use empty method
|
214 |
+
vhead_path = (
|
215 |
+
model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
|
216 |
+
)
|
217 |
+
vhead_params = load_valuehead_params(vhead_path, model_args)
|
218 |
+
if vhead_params is not None:
|
219 |
+
model.load_state_dict(vhead_params, strict=False)
|
220 |
+
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
|
221 |
+
|
222 |
+
# Prepare model for inference
|
223 |
+
if not is_trainable:
|
224 |
+
model.requires_grad_(False) # fix all model params
|
225 |
+
model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model
|
226 |
+
|
227 |
+
trainable_params, all_param = count_parameters(model)
|
228 |
+
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
229 |
+
trainable_params, all_param, 100 * trainable_params / all_param
|
230 |
+
))
|
231 |
+
|
232 |
+
if not is_trainable:
|
233 |
+
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
|
234 |
+
|
235 |
+
return model, tokenizer
|
LLM-Detector-V4-11w/src/llmtuner/model/parser.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import datasets
|
4 |
+
import transformers
|
5 |
+
from typing import Any, Dict, Optional, Tuple
|
6 |
+
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
7 |
+
from transformers.trainer_utils import get_last_checkpoint
|
8 |
+
|
9 |
+
from llmtuner.extras.logging import get_logger
|
10 |
+
from llmtuner.extras.misc import parse_args
|
11 |
+
from llmtuner.hparams import (
|
12 |
+
ModelArguments,
|
13 |
+
DataArguments,
|
14 |
+
EvaluationArguments,
|
15 |
+
FinetuningArguments,
|
16 |
+
GeneratingArguments
|
17 |
+
)
|
18 |
+
|
19 |
+
|
20 |
+
logger = get_logger(__name__)
|
21 |
+
|
22 |
+
|
23 |
+
_TRAIN_ARGS = [
|
24 |
+
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
|
25 |
+
]
|
26 |
+
_TRAIN_CLS = Tuple[
|
27 |
+
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
|
28 |
+
]
|
29 |
+
_INFER_ARGS = [
|
30 |
+
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
31 |
+
]
|
32 |
+
_INFER_CLS = Tuple[
|
33 |
+
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
34 |
+
]
|
35 |
+
_EVAL_ARGS = [
|
36 |
+
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
|
37 |
+
]
|
38 |
+
_EVAL_CLS = Tuple[
|
39 |
+
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
|
40 |
+
]
|
41 |
+
|
42 |
+
|
43 |
+
def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
|
44 |
+
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
45 |
+
raise ValueError("Quantization is only compatible with the LoRA method.")
|
46 |
+
|
47 |
+
if (
|
48 |
+
model_args.checkpoint_dir is not None
|
49 |
+
and len(model_args.checkpoint_dir) != 1
|
50 |
+
and finetuning_args.finetuning_type != "lora"
|
51 |
+
):
|
52 |
+
raise ValueError("Multiple checkpoints are only available for LoRA tuning.")
|
53 |
+
|
54 |
+
|
55 |
+
def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
56 |
+
parser = HfArgumentParser(_TRAIN_ARGS)
|
57 |
+
return parse_args(parser, args)
|
58 |
+
|
59 |
+
|
60 |
+
def parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
61 |
+
parser = HfArgumentParser(_INFER_ARGS)
|
62 |
+
return parse_args(parser, args)
|
63 |
+
|
64 |
+
|
65 |
+
def parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
66 |
+
parser = HfArgumentParser(_EVAL_ARGS)
|
67 |
+
return parse_args(parser, args)
|
68 |
+
|
69 |
+
|
70 |
+
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
71 |
+
model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)
|
72 |
+
|
73 |
+
# Setup logging
|
74 |
+
if training_args.should_log:
|
75 |
+
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
|
76 |
+
transformers.utils.logging.set_verbosity_info()
|
77 |
+
|
78 |
+
log_level = training_args.get_process_log_level()
|
79 |
+
datasets.utils.logging.set_verbosity(log_level)
|
80 |
+
transformers.utils.logging.set_verbosity(log_level)
|
81 |
+
transformers.utils.logging.enable_default_handler()
|
82 |
+
transformers.utils.logging.enable_explicit_format()
|
83 |
+
|
84 |
+
# Check arguments
|
85 |
+
data_args.init_for_training(training_args.seed)
|
86 |
+
|
87 |
+
if finetuning_args.stage != "pt" and data_args.template is None:
|
88 |
+
raise ValueError("Please specify which `template` to use.")
|
89 |
+
|
90 |
+
if finetuning_args.stage != "sft" and training_args.predict_with_generate:
|
91 |
+
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
|
92 |
+
|
93 |
+
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
|
94 |
+
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
|
95 |
+
|
96 |
+
if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
|
97 |
+
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
|
98 |
+
|
99 |
+
if finetuning_args.stage == "ppo" and not training_args.do_train:
|
100 |
+
raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
|
101 |
+
|
102 |
+
if finetuning_args.stage in ["rm", "dpo"] and (not all([data_attr.ranking for data_attr in data_args.dataset_list])):
|
103 |
+
raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
|
104 |
+
|
105 |
+
if finetuning_args.stage == "ppo" and model_args.shift_attn:
|
106 |
+
raise ValueError("PPO training is incompatible with S^2-Attn.")
|
107 |
+
|
108 |
+
if training_args.max_steps == -1 and data_args.streaming:
|
109 |
+
raise ValueError("Please specify `max_steps` in streaming mode.")
|
110 |
+
|
111 |
+
if training_args.do_train and training_args.predict_with_generate:
|
112 |
+
raise ValueError("`predict_with_generate` cannot be set as True while training.")
|
113 |
+
|
114 |
+
if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
|
115 |
+
raise ValueError("Please specify `lora_target` in LoRA training.")
|
116 |
+
|
117 |
+
_verify_model_args(model_args, finetuning_args)
|
118 |
+
|
119 |
+
if training_args.do_train and model_args.quantization_bit is not None and (not finetuning_args.upcast_layernorm):
|
120 |
+
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
|
121 |
+
|
122 |
+
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
|
123 |
+
logger.warning("We recommend enable mixed precision training.")
|
124 |
+
|
125 |
+
if (not training_args.do_train) and model_args.quantization_bit is not None:
|
126 |
+
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
|
127 |
+
|
128 |
+
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
|
129 |
+
logger.warning("Specify `ref_model` for computing rewards at evaluation.")
|
130 |
+
|
131 |
+
# postprocess training_args
|
132 |
+
if (
|
133 |
+
training_args.local_rank != -1
|
134 |
+
and training_args.ddp_find_unused_parameters is None
|
135 |
+
and finetuning_args.finetuning_type == "lora"
|
136 |
+
):
|
137 |
+
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
|
138 |
+
training_args_dict = training_args.to_dict()
|
139 |
+
training_args_dict.update(dict(ddp_find_unused_parameters=False))
|
140 |
+
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
141 |
+
|
142 |
+
if (
|
143 |
+
training_args.resume_from_checkpoint is None
|
144 |
+
and training_args.do_train
|
145 |
+
and os.path.isdir(training_args.output_dir)
|
146 |
+
and not training_args.overwrite_output_dir
|
147 |
+
):
|
148 |
+
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
149 |
+
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
150 |
+
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
|
151 |
+
|
152 |
+
if last_checkpoint is not None:
|
153 |
+
training_args_dict = training_args.to_dict()
|
154 |
+
training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
|
155 |
+
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
156 |
+
logger.info("Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
|
157 |
+
training_args.resume_from_checkpoint
|
158 |
+
))
|
159 |
+
|
160 |
+
if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
|
161 |
+
logger.warning("Add {} to `checkpoint_dir` to resume training from checkpoint.".format(
|
162 |
+
training_args.resume_from_checkpoint
|
163 |
+
))
|
164 |
+
|
165 |
+
# postprocess model_args
|
166 |
+
model_args.compute_dtype = (
|
167 |
+
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
|
168 |
+
)
|
169 |
+
model_args.model_max_length = data_args.cutoff_len
|
170 |
+
|
171 |
+
# Log on each process the small summary:
|
172 |
+
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
|
173 |
+
training_args.local_rank, training_args.device, training_args.n_gpu,
|
174 |
+
bool(training_args.local_rank != -1), str(model_args.compute_dtype)
|
175 |
+
))
|
176 |
+
logger.info(f"Training/evaluation parameters {training_args}")
|
177 |
+
|
178 |
+
# Set seed before initializing model.
|
179 |
+
transformers.set_seed(training_args.seed)
|
180 |
+
|
181 |
+
return model_args, data_args, training_args, finetuning_args, generating_args
|
182 |
+
|
183 |
+
|
184 |
+
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
185 |
+
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
|
186 |
+
|
187 |
+
if data_args.template is None:
|
188 |
+
raise ValueError("Please specify which `template` to use.")
|
189 |
+
|
190 |
+
_verify_model_args(model_args, finetuning_args)
|
191 |
+
|
192 |
+
return model_args, data_args, finetuning_args, generating_args
|
193 |
+
|
194 |
+
|
195 |
+
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
196 |
+
model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)
|
197 |
+
|
198 |
+
if data_args.template is None:
|
199 |
+
raise ValueError("Please specify which `template` to use.")
|
200 |
+
|
201 |
+
_verify_model_args(model_args, finetuning_args)
|
202 |
+
|
203 |
+
transformers.set_seed(eval_args.seed)
|
204 |
+
|
205 |
+
return model_args, data_args, eval_args, finetuning_args
|
LLM-Detector-V4-11w/src/llmtuner/model/utils.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import inspect
|
3 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
|
4 |
+
|
5 |
+
from transformers.utils import cached_file
|
6 |
+
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
7 |
+
|
8 |
+
from llmtuner.extras.constants import LAYERNORM_NAMES
|
9 |
+
from llmtuner.extras.logging import get_logger
|
10 |
+
from llmtuner.hparams import ModelArguments, FinetuningArguments
|
11 |
+
|
12 |
+
if TYPE_CHECKING:
|
13 |
+
from transformers.modeling_utils import PreTrainedModel
|
14 |
+
from llmtuner.hparams import DataArguments
|
15 |
+
|
16 |
+
|
17 |
+
logger = get_logger(__name__)
|
18 |
+
|
19 |
+
|
20 |
+
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
21 |
+
r"""
|
22 |
+
Dispatches a pre-trained model to GPUs with balanced memory.
|
23 |
+
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
|
24 |
+
"""
|
25 |
+
if getattr(model, "quantization_method", None): # already set on current device
|
26 |
+
return model
|
27 |
+
|
28 |
+
if torch.cuda.device_count() > 1:
|
29 |
+
from accelerate import dispatch_model
|
30 |
+
from accelerate.utils import infer_auto_device_map, get_balanced_memory
|
31 |
+
|
32 |
+
if model._no_split_modules is None:
|
33 |
+
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
|
34 |
+
|
35 |
+
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
|
36 |
+
max_memory = get_balanced_memory(model, **kwargs)
|
37 |
+
# Make sure tied weights are tied before creating the device map.
|
38 |
+
model.tie_weights()
|
39 |
+
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
|
40 |
+
return dispatch_model(model, device_map)
|
41 |
+
else:
|
42 |
+
return model.cuda()
|
43 |
+
|
44 |
+
|
45 |
+
def find_all_linear_modules(
|
46 |
+
model: "PreTrainedModel",
|
47 |
+
quantization_bit: Optional[int] = None
|
48 |
+
) -> List[str]:
|
49 |
+
r"""
|
50 |
+
Finds all available modules to apply lora.
|
51 |
+
"""
|
52 |
+
if quantization_bit is not None:
|
53 |
+
import bitsandbytes as bnb
|
54 |
+
linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
|
55 |
+
else:
|
56 |
+
linear_cls = torch.nn.Linear
|
57 |
+
|
58 |
+
output_layer_names = ["lm_head"]
|
59 |
+
if model.config.model_type == "chatglm":
|
60 |
+
output_layer_names.append("output_layer")
|
61 |
+
|
62 |
+
module_names = set()
|
63 |
+
for name, module in model.named_modules():
|
64 |
+
if (
|
65 |
+
isinstance(module, linear_cls)
|
66 |
+
and not any([output_layer in name for output_layer in output_layer_names])
|
67 |
+
):
|
68 |
+
module_names.add(name.split(".")[-1])
|
69 |
+
|
70 |
+
logger.info("Found linear modules: {}".format(",".join(module_names)))
|
71 |
+
return list(module_names)
|
72 |
+
|
73 |
+
|
74 |
+
def get_modelcard_args(
|
75 |
+
model_args: "ModelArguments",
|
76 |
+
data_args: "DataArguments",
|
77 |
+
finetuning_args: "FinetuningArguments"
|
78 |
+
) -> Dict[str, Any]:
|
79 |
+
return {
|
80 |
+
"tasks": "text-generation",
|
81 |
+
"license": "other",
|
82 |
+
"finetuned_from": model_args.model_name_or_path,
|
83 |
+
"dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
|
84 |
+
"tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else [])
|
85 |
+
}
|
86 |
+
|
87 |
+
|
88 |
+
def load_valuehead_params(
|
89 |
+
path_or_repo_id: str,
|
90 |
+
model_args: "ModelArguments"
|
91 |
+
) -> Dict[str, torch.Tensor]:
|
92 |
+
r"""
|
93 |
+
Loads value head parameters from Hugging Face Hub or local disk.
|
94 |
+
|
95 |
+
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
|
96 |
+
"""
|
97 |
+
kwargs = {
|
98 |
+
"path_or_repo_id": path_or_repo_id,
|
99 |
+
"cache_dir": model_args.cache_dir
|
100 |
+
}
|
101 |
+
|
102 |
+
if "token" in inspect.signature(cached_file).parameters:
|
103 |
+
kwargs["token"] = model_args.hf_hub_token
|
104 |
+
elif "use_auth_token" in inspect.signature(cached_file).parameters: # for transformers==4.31.0
|
105 |
+
kwargs["use_auth_token"] = model_args.hf_hub_token
|
106 |
+
else:
|
107 |
+
logger.warning("Ignore `hf_hub_token` since matched parameter is not found.")
|
108 |
+
|
109 |
+
try:
|
110 |
+
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
|
111 |
+
return torch.load(vhead_file, map_location="cpu")
|
112 |
+
except Exception as err:
|
113 |
+
logger.info("Failed to load {}: {}".format(WEIGHTS_NAME, str(err)))
|
114 |
+
|
115 |
+
try:
|
116 |
+
from safetensors import safe_open
|
117 |
+
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
|
118 |
+
with safe_open(vhead_file, framework="pt", device="cpu") as f:
|
119 |
+
return {
|
120 |
+
"v_head.summary.weight": f.get_tensor("v_head.summary.weight"),
|
121 |
+
"v_head.summary.bias": f.get_tensor("v_head.summary.bias")
|
122 |
+
}
|
123 |
+
except Exception as err:
|
124 |
+
logger.info("Failed to load {}: {}".format(SAFE_WEIGHTS_NAME, str(err)))
|
125 |
+
|
126 |
+
logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
|
127 |
+
return None
|
128 |
+
|
129 |
+
|
130 |
+
def prepare_model_for_training(
|
131 |
+
model: "PreTrainedModel",
|
132 |
+
finetuning_args: "FinetuningArguments",
|
133 |
+
output_layer_name: Optional[str] = "lm_head",
|
134 |
+
use_gradient_checkpointing: Optional[bool] = True,
|
135 |
+
layernorm_names: Optional[Set[str]] = LAYERNORM_NAMES
|
136 |
+
) -> "PreTrainedModel":
|
137 |
+
r"""
|
138 |
+
Includes:
|
139 |
+
(1) cast the layernorm in fp32
|
140 |
+
(2) make output embedding layer require grads
|
141 |
+
(3) upcast the lm_head to fp32
|
142 |
+
Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
|
143 |
+
"""
|
144 |
+
if finetuning_args.upcast_layernorm:
|
145 |
+
for name, param in model.named_parameters():
|
146 |
+
if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
|
147 |
+
param.data = param.data.to(torch.float32)
|
148 |
+
logger.info("Upcasting weights in layernorm in float32.")
|
149 |
+
|
150 |
+
if finetuning_args.neft_alpha > 1e-6:
|
151 |
+
def neftune_forward_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
152 |
+
if module.training:
|
153 |
+
dims = torch.tensor(output.size(1) * output.size(2))
|
154 |
+
mag_norm = finetuning_args.neft_alpha / torch.sqrt(dims)
|
155 |
+
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
|
156 |
+
return output
|
157 |
+
|
158 |
+
model.get_input_embeddings().register_forward_hook(neftune_forward_hook)
|
159 |
+
logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha))
|
160 |
+
|
161 |
+
if use_gradient_checkpointing and getattr(model, "supports_gradient_checkpointing", False):
|
162 |
+
if hasattr(model, "enable_input_require_grads"):
|
163 |
+
model.enable_input_require_grads()
|
164 |
+
else:
|
165 |
+
def make_inputs_require_grad(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
166 |
+
output.requires_grad_(True)
|
167 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
168 |
+
|
169 |
+
model.gradient_checkpointing_enable()
|
170 |
+
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
171 |
+
logger.info("Gradient checkpointing enabled.")
|
172 |
+
|
173 |
+
if finetuning_args.finetuning_type != "full" and hasattr(model, output_layer_name):
|
174 |
+
output_layer = getattr(model, output_layer_name)
|
175 |
+
if isinstance(output_layer, torch.nn.Linear):
|
176 |
+
def fp32_forward_pre_hook(module: torch.nn.Module, args: Tuple[torch.Tensor]):
|
177 |
+
return args[0].to(output_layer.weight.dtype)
|
178 |
+
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
179 |
+
return output.to(torch.float32)
|
180 |
+
output_layer.register_forward_pre_hook(fp32_forward_pre_hook)
|
181 |
+
output_layer.register_forward_hook(fp32_forward_post_hook)
|
182 |
+
|
183 |
+
return model
|
LLM-Detector-V4-11w/src/llmtuner/train/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from llmtuner.train.tuner import export_model, run_exp
|
LLM-Detector-V4-11w/src/llmtuner/train/dpo/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from llmtuner.train.dpo.workflow import run_dpo
|
LLM-Detector-V4-11w/src/llmtuner/train/dpo/collator.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import Any, Dict, List, Sequence, Tuple
|
4 |
+
from transformers import DataCollatorForSeq2Seq
|
5 |
+
|
6 |
+
|
7 |
+
@dataclass
|
8 |
+
class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
9 |
+
r"""
|
10 |
+
Data collator for pairwise data.
|
11 |
+
"""
|
12 |
+
|
13 |
+
def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
|
14 |
+
padded_labels = []
|
15 |
+
for feature, (prompt_len, answer_len) in zip(batch, positions):
|
16 |
+
if self.tokenizer.padding_side == "left":
|
17 |
+
start, end = feature.size(0) - answer_len, feature.size(0)
|
18 |
+
else:
|
19 |
+
start, end = prompt_len, prompt_len + answer_len
|
20 |
+
padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
|
21 |
+
padded_tensor[start:end] = feature[start:end]
|
22 |
+
padded_labels.append(padded_tensor)
|
23 |
+
return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
|
24 |
+
|
25 |
+
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
26 |
+
r"""
|
27 |
+
Pads batched data to the longest sequence in the batch.
|
28 |
+
|
29 |
+
We generate 2 * n examples where the first n examples represent chosen examples and
|
30 |
+
the last n examples represent rejected examples.
|
31 |
+
"""
|
32 |
+
concatenated_features = []
|
33 |
+
label_positions = []
|
34 |
+
for key in ("chosen_ids", "rejected_ids"):
|
35 |
+
for feature in features:
|
36 |
+
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
|
37 |
+
concatenated_features.append({
|
38 |
+
"input_ids": feature["prompt_ids"] + feature[key],
|
39 |
+
"attention_mask": [1] * (prompt_len + answer_len)
|
40 |
+
})
|
41 |
+
label_positions.append((prompt_len, answer_len))
|
42 |
+
|
43 |
+
batch = self.tokenizer.pad(
|
44 |
+
concatenated_features,
|
45 |
+
padding=self.padding,
|
46 |
+
max_length=self.max_length,
|
47 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
48 |
+
return_tensors=self.return_tensors,
|
49 |
+
)
|
50 |
+
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
|
51 |
+
return batch
|
LLM-Detector-V4-11w/src/llmtuner/train/dpo/trainer.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from collections import defaultdict
|
3 |
+
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
|
4 |
+
from transformers import BatchEncoding, Trainer
|
5 |
+
from trl import DPOTrainer
|
6 |
+
from trl.trainer.utils import disable_dropout_in_model
|
7 |
+
|
8 |
+
from llmtuner.extras.constants import IGNORE_INDEX
|
9 |
+
|
10 |
+
if TYPE_CHECKING:
|
11 |
+
from transformers import PreTrainedModel
|
12 |
+
|
13 |
+
|
14 |
+
class CustomDPOTrainer(DPOTrainer):
|
15 |
+
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
beta: float,
|
19 |
+
model: Union["PreTrainedModel", torch.nn.Module],
|
20 |
+
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
|
21 |
+
disable_dropout: Optional[bool] = True,
|
22 |
+
loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
|
23 |
+
**kwargs
|
24 |
+
):
|
25 |
+
if disable_dropout:
|
26 |
+
disable_dropout_in_model(model)
|
27 |
+
if ref_model is not None:
|
28 |
+
disable_dropout_in_model(ref_model)
|
29 |
+
|
30 |
+
self.is_encoder_decoder = model.config.is_encoder_decoder
|
31 |
+
self.ref_model = ref_model
|
32 |
+
self.use_dpo_data_collator = True # hack to avoid warning
|
33 |
+
self.generate_during_eval = False # disable at evaluation
|
34 |
+
self.label_pad_token_id = IGNORE_INDEX
|
35 |
+
self.padding_value = 0
|
36 |
+
self.beta = beta
|
37 |
+
self.loss_type = loss_type
|
38 |
+
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
39 |
+
|
40 |
+
Trainer.__init__(self, model=model, **kwargs)
|
41 |
+
if not hasattr(self, "accelerator"):
|
42 |
+
raise AttributeError("Please update `transformers`.")
|
43 |
+
|
44 |
+
if ref_model is not None:
|
45 |
+
if self.is_deepspeed_enabled:
|
46 |
+
if not (
|
47 |
+
getattr(ref_model, "is_loaded_in_8bit", False)
|
48 |
+
or getattr(ref_model, "is_loaded_in_4bit", False)
|
49 |
+
): # quantized models are already set on the correct device
|
50 |
+
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
51 |
+
else:
|
52 |
+
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
53 |
+
|
54 |
+
def concatenated_forward(
|
55 |
+
self,
|
56 |
+
model: Optional[torch.nn.Module] = None,
|
57 |
+
batch: Optional[Dict[str, torch.Tensor]] = None
|
58 |
+
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
59 |
+
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
|
60 |
+
|
61 |
+
all_logits = model(
|
62 |
+
input_ids=batch_copied["input_ids"],
|
63 |
+
attention_mask=batch_copied["attention_mask"],
|
64 |
+
return_dict=True
|
65 |
+
).logits.to(torch.float32)
|
66 |
+
|
67 |
+
all_logps = self._get_batch_logps(
|
68 |
+
all_logits,
|
69 |
+
batch["labels"],
|
70 |
+
average_log_prob=False
|
71 |
+
)
|
72 |
+
batch_size = batch["input_ids"].size(0) // 2
|
73 |
+
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
|
74 |
+
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
|
75 |
+
return chosen_logps, rejected_logps, chosen_logits, rejected_logits
|
LLM-Detector-V4-11w/src/llmtuner/train/dpo/workflow.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
|
2 |
+
|
3 |
+
from typing import TYPE_CHECKING, Optional, List
|
4 |
+
from transformers import Seq2SeqTrainingArguments
|
5 |
+
|
6 |
+
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
|
7 |
+
from llmtuner.extras.constants import IGNORE_INDEX
|
8 |
+
from llmtuner.extras.ploting import plot_loss
|
9 |
+
from llmtuner.hparams import ModelArguments
|
10 |
+
from llmtuner.model import load_model_and_tokenizer
|
11 |
+
from llmtuner.train.dpo.collator import DPODataCollatorWithPadding
|
12 |
+
from llmtuner.train.dpo.trainer import CustomDPOTrainer
|
13 |
+
from llmtuner.train.utils import create_modelcard_and_push, create_ref_model
|
14 |
+
|
15 |
+
if TYPE_CHECKING:
|
16 |
+
from transformers import TrainerCallback
|
17 |
+
from llmtuner.hparams import DataArguments, FinetuningArguments
|
18 |
+
|
19 |
+
|
20 |
+
def run_dpo(
|
21 |
+
model_args: "ModelArguments",
|
22 |
+
data_args: "DataArguments",
|
23 |
+
training_args: "Seq2SeqTrainingArguments",
|
24 |
+
finetuning_args: "FinetuningArguments",
|
25 |
+
callbacks: Optional[List["TrainerCallback"]] = None
|
26 |
+
):
|
27 |
+
dataset = get_dataset(model_args, data_args)
|
28 |
+
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
|
29 |
+
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
|
30 |
+
data_collator = DPODataCollatorWithPadding(
|
31 |
+
tokenizer=tokenizer,
|
32 |
+
pad_to_multiple_of=4,
|
33 |
+
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
34 |
+
)
|
35 |
+
|
36 |
+
# Create reference model
|
37 |
+
if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
|
38 |
+
ref_model = model
|
39 |
+
else:
|
40 |
+
ref_model = create_ref_model(model_args, finetuning_args, stage="dpo")
|
41 |
+
|
42 |
+
# Update arguments
|
43 |
+
training_args_dict = training_args.to_dict()
|
44 |
+
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
|
45 |
+
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
46 |
+
|
47 |
+
# Initialize our Trainer
|
48 |
+
trainer = CustomDPOTrainer(
|
49 |
+
beta=finetuning_args.dpo_beta,
|
50 |
+
model=model,
|
51 |
+
ref_model=ref_model,
|
52 |
+
args=training_args,
|
53 |
+
tokenizer=tokenizer,
|
54 |
+
data_collator=data_collator,
|
55 |
+
callbacks=callbacks,
|
56 |
+
**split_dataset(dataset, data_args, training_args)
|
57 |
+
)
|
58 |
+
|
59 |
+
# Training
|
60 |
+
if training_args.do_train:
|
61 |
+
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
62 |
+
trainer.save_model()
|
63 |
+
trainer.log_metrics("train", train_result.metrics)
|
64 |
+
trainer.save_metrics("train", train_result.metrics)
|
65 |
+
trainer.save_state()
|
66 |
+
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
67 |
+
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
68 |
+
|
69 |
+
# Evaluation
|
70 |
+
if training_args.do_eval:
|
71 |
+
metrics = trainer.evaluate(metric_key_prefix="eval")
|
72 |
+
if id(model) == id(ref_model): # unable to compute rewards without a reference model
|
73 |
+
remove_keys = [key for key in metrics.keys() if "rewards" in key]
|
74 |
+
for key in remove_keys:
|
75 |
+
metrics.pop(key)
|
76 |
+
trainer.log_metrics("eval", metrics)
|
77 |
+
trainer.save_metrics("eval", metrics)
|
78 |
+
|
79 |
+
# Create model card
|
80 |
+
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
|
LLM-Detector-V4-11w/src/llmtuner/train/ppo/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from llmtuner.train.ppo.workflow import run_ppo
|
LLM-Detector-V4-11w/src/llmtuner/train/ppo/trainer.py
ADDED
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import math
|
4 |
+
import torch
|
5 |
+
from tqdm import tqdm
|
6 |
+
from typing import TYPE_CHECKING, List, Optional, Tuple
|
7 |
+
|
8 |
+
from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
|
9 |
+
from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
10 |
+
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
11 |
+
from transformers.trainer_pt_utils import remove_dummy_checkpoint
|
12 |
+
|
13 |
+
from trl import PPOTrainer
|
14 |
+
from trl.core import PPODecorators, logprobs_from_logits
|
15 |
+
|
16 |
+
from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
|
17 |
+
from llmtuner.extras.logging import get_logger
|
18 |
+
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
|
19 |
+
from llmtuner.train.ppo.utils import dump_layernorm, restore_layernorm, replace_model
|
20 |
+
|
21 |
+
if TYPE_CHECKING:
|
22 |
+
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
23 |
+
from trl import AutoModelForCausalLMWithValueHead
|
24 |
+
from llmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments
|
25 |
+
|
26 |
+
|
27 |
+
logger = get_logger(__name__)
|
28 |
+
|
29 |
+
|
30 |
+
class CustomPPOTrainer(PPOTrainer, Trainer):
|
31 |
+
r"""
|
32 |
+
Inherits PPOTrainer.
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
model_args: "ModelArguments",
|
38 |
+
training_args: "Seq2SeqTrainingArguments",
|
39 |
+
finetuning_args: "FinetuningArguments",
|
40 |
+
generating_args: "GeneratingArguments",
|
41 |
+
callbacks: List["TrainerCallback"],
|
42 |
+
reward_model: "AutoModelForCausalLMWithValueHead",
|
43 |
+
**kwargs
|
44 |
+
):
|
45 |
+
PPOTrainer.__init__(self, **kwargs)
|
46 |
+
|
47 |
+
self.args = training_args
|
48 |
+
self.model_args = model_args
|
49 |
+
self.finetuning_args = finetuning_args
|
50 |
+
self.reward_model = reward_model
|
51 |
+
|
52 |
+
self.generation_config = GenerationConfig(
|
53 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
54 |
+
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
55 |
+
**generating_args.to_dict()
|
56 |
+
)
|
57 |
+
|
58 |
+
self.state = TrainerState()
|
59 |
+
self.control = TrainerControl()
|
60 |
+
self.is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
|
61 |
+
self.accelerator.state, "deepspeed_plugin"
|
62 |
+
)
|
63 |
+
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
|
64 |
+
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
|
65 |
+
|
66 |
+
if self.args.max_steps > 0:
|
67 |
+
logger.info("max_steps is given, it will override any value given in num_train_epochs")
|
68 |
+
|
69 |
+
if reward_model is not None:
|
70 |
+
if self.is_deepspeed_enabled:
|
71 |
+
if not (
|
72 |
+
getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
|
73 |
+
or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
|
74 |
+
): # quantized models are already set on the correct device
|
75 |
+
self.reward_model = self._prepare_deepspeed(self.reward_model)
|
76 |
+
else:
|
77 |
+
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
|
78 |
+
|
79 |
+
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
|
80 |
+
r"""
|
81 |
+
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
|
82 |
+
"""
|
83 |
+
if resume_from_checkpoint is not None:
|
84 |
+
raise ValueError("`resume_from_checkpoint` will be supported in the future version.")
|
85 |
+
|
86 |
+
total_train_batch_size = (
|
87 |
+
self.args.per_device_train_batch_size
|
88 |
+
* self.args.gradient_accumulation_steps
|
89 |
+
* self.finetuning_args.ppo_buffer_size
|
90 |
+
* self.args.world_size
|
91 |
+
)
|
92 |
+
if self.args.max_steps > 0:
|
93 |
+
num_examples = total_train_batch_size * self.args.max_steps
|
94 |
+
num_train_epochs = sys.maxsize
|
95 |
+
max_steps = self.args.max_steps
|
96 |
+
steps_in_epoch = self.args.max_steps
|
97 |
+
else:
|
98 |
+
len_dataloader = len(self.dataloader)
|
99 |
+
num_examples = len(self.dataset)
|
100 |
+
num_train_epochs = self.args.num_train_epochs
|
101 |
+
max_steps = math.ceil(num_train_epochs * len_dataloader)
|
102 |
+
steps_in_epoch = len_dataloader
|
103 |
+
|
104 |
+
self.state.max_steps = max_steps
|
105 |
+
self.state.num_train_epochs = num_train_epochs
|
106 |
+
self.state.is_local_process_zero = self.is_local_process_zero()
|
107 |
+
self.state.is_world_process_zero = self.is_world_process_zero()
|
108 |
+
|
109 |
+
if self.is_world_process_zero():
|
110 |
+
logger.info("***** Running training *****")
|
111 |
+
logger.info(" Num examples = {}".format(num_examples))
|
112 |
+
logger.info(" Num Epochs = {}".format(num_train_epochs))
|
113 |
+
logger.info(" Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size))
|
114 |
+
logger.info(" Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
|
115 |
+
total_train_batch_size
|
116 |
+
))
|
117 |
+
logger.info(" Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps))
|
118 |
+
logger.info(" Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs))
|
119 |
+
logger.info(" Total training steps = {}".format(max_steps))
|
120 |
+
logger.info(" Number of trainable parameters = {}".format(count_parameters(self.model)[0]))
|
121 |
+
|
122 |
+
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
123 |
+
dataiter = iter(self.dataloader)
|
124 |
+
loss_meter = AverageMeter()
|
125 |
+
reward_meter = AverageMeter()
|
126 |
+
self.log_callback.on_train_begin(self.args, self.state, self.control)
|
127 |
+
|
128 |
+
for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
|
129 |
+
try:
|
130 |
+
batch = next(dataiter)
|
131 |
+
except StopIteration:
|
132 |
+
dataiter = iter(self.dataloader)
|
133 |
+
batch = next(dataiter)
|
134 |
+
|
135 |
+
# Cast to inference mode
|
136 |
+
unwrapped_model.gradient_checkpointing_disable()
|
137 |
+
unwrapped_model.config.use_cache = True
|
138 |
+
self.model.eval()
|
139 |
+
|
140 |
+
# Get inputs
|
141 |
+
self.tokenizer.padding_side = "right" # change padding side
|
142 |
+
queries, responses, rewards = [], [], []
|
143 |
+
for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
|
144 |
+
mini_batch_queries, mini_batch_responses = self.get_inputs(batch[idx:idx+self.config.mini_batch_size])
|
145 |
+
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
|
146 |
+
queries.extend(mini_batch_queries)
|
147 |
+
responses.extend(mini_batch_responses)
|
148 |
+
rewards.extend(mini_batch_rewards)
|
149 |
+
|
150 |
+
# Cast to training mode
|
151 |
+
unwrapped_model.gradient_checkpointing_enable()
|
152 |
+
unwrapped_model.config.use_cache = False
|
153 |
+
self.model.train()
|
154 |
+
|
155 |
+
# Run PPO step
|
156 |
+
stats = self.step(queries, responses, rewards)
|
157 |
+
self.tokenizer.padding_side = "left" # restore padding side
|
158 |
+
loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
|
159 |
+
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
|
160 |
+
|
161 |
+
if self.config.log_with is not None:
|
162 |
+
try:
|
163 |
+
batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
|
164 |
+
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
|
165 |
+
self.log_stats(stats, batch, rewards)
|
166 |
+
except:
|
167 |
+
logger.warning("Failed to save stats due to unknown errors.")
|
168 |
+
|
169 |
+
self.state.global_step += 1
|
170 |
+
self.log_callback.on_step_end(self.args, self.state, self.control)
|
171 |
+
|
172 |
+
if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0:
|
173 |
+
logs = dict(
|
174 |
+
loss=round(loss_meter.avg, 4),
|
175 |
+
reward=round(reward_meter.avg, 4),
|
176 |
+
learning_rate=stats["ppo/learning_rate"],
|
177 |
+
epoch=round(step / steps_in_epoch, 2)
|
178 |
+
)
|
179 |
+
tqdm.write(str(logs))
|
180 |
+
logs["step"] = step
|
181 |
+
self.state.log_history.append(logs)
|
182 |
+
self.log_callback.on_log(self.args, self.state, self.control)
|
183 |
+
loss_meter.reset()
|
184 |
+
reward_meter.reset()
|
185 |
+
|
186 |
+
if (step+1) % self.args.save_steps == 0: # save checkpoint
|
187 |
+
self.save_model(os.path.join(
|
188 |
+
self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)
|
189 |
+
))
|
190 |
+
self.save_callback.on_save(
|
191 |
+
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
|
192 |
+
)
|
193 |
+
|
194 |
+
if self.control.should_epoch_stop or self.control.should_training_stop:
|
195 |
+
break
|
196 |
+
|
197 |
+
self.log_callback.on_train_end(self.args, self.state, self.control)
|
198 |
+
self.save_callback.on_train_end(
|
199 |
+
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
|
200 |
+
)
|
201 |
+
|
202 |
+
@torch.no_grad()
|
203 |
+
def get_inputs(self, batch: BatchEncoding) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
204 |
+
r"""
|
205 |
+
Generates model's responses given queries.
|
206 |
+
"""
|
207 |
+
if self.finetuning_args.upcast_layernorm:
|
208 |
+
layernorm_params = dump_layernorm(self.model)
|
209 |
+
|
210 |
+
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
211 |
+
response: torch.Tensor = unwrapped_model.generate(
|
212 |
+
generation_config=self.generation_config,
|
213 |
+
logits_processor=get_logits_processor(),
|
214 |
+
**batch
|
215 |
+
)
|
216 |
+
|
217 |
+
if self.finetuning_args.upcast_layernorm:
|
218 |
+
restore_layernorm(self.model, layernorm_params)
|
219 |
+
|
220 |
+
query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
|
221 |
+
queries, responses = [], []
|
222 |
+
for i in range(len(query)):
|
223 |
+
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
|
224 |
+
response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
|
225 |
+
|
226 |
+
if len(response_index) == 0:
|
227 |
+
response_length = 1 # allow empty response
|
228 |
+
else:
|
229 |
+
response_length = response_index[-1].item() + 1
|
230 |
+
|
231 |
+
queries.append(query[i, query_length:]) # remove padding from left
|
232 |
+
responses.append(response[i, :response_length]) # remove padding from right
|
233 |
+
|
234 |
+
return queries, responses
|
235 |
+
|
236 |
+
@torch.no_grad()
|
237 |
+
def get_rewards(
|
238 |
+
self,
|
239 |
+
queries: List[torch.Tensor],
|
240 |
+
responses: List[torch.Tensor],
|
241 |
+
unwrapped_model: "AutoModelForCausalLMWithValueHead"
|
242 |
+
) -> List[torch.Tensor]:
|
243 |
+
r"""
|
244 |
+
Computes scores using given reward model.
|
245 |
+
"""
|
246 |
+
if self.reward_model is None:
|
247 |
+
replace_model(unwrapped_model, target="reward")
|
248 |
+
|
249 |
+
batch = self.prepare_model_inputs(queries, responses)
|
250 |
+
|
251 |
+
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
252 |
+
reward_model = self.reward_model if self.reward_model is not None else self.model
|
253 |
+
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
|
254 |
+
|
255 |
+
if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
|
256 |
+
values = torch.transpose(values, 0, 1)
|
257 |
+
|
258 |
+
rewards = []
|
259 |
+
for i in range(values.size(0)):
|
260 |
+
end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
|
261 |
+
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
262 |
+
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
|
263 |
+
|
264 |
+
if self.reward_model is None:
|
265 |
+
replace_model(unwrapped_model, target="default")
|
266 |
+
|
267 |
+
return rewards
|
268 |
+
|
269 |
+
@PPODecorators.empty_device_cache()
|
270 |
+
def batched_forward_pass(
|
271 |
+
self,
|
272 |
+
model: "AutoModelForCausalLMWithValueHead",
|
273 |
+
queries: torch.Tensor,
|
274 |
+
responses: torch.Tensor,
|
275 |
+
model_inputs: dict,
|
276 |
+
return_logits: Optional[bool] = False,
|
277 |
+
response_masks: Optional[torch.Tensor] = None
|
278 |
+
):
|
279 |
+
r"""
|
280 |
+
Calculates model outputs in multiple batches.
|
281 |
+
|
282 |
+
Subclass and override to inject custom behavior.
|
283 |
+
"""
|
284 |
+
bs = len(queries)
|
285 |
+
fbs = self.config.mini_batch_size
|
286 |
+
all_logprobs = []
|
287 |
+
all_logits = []
|
288 |
+
all_masks = []
|
289 |
+
all_values = []
|
290 |
+
|
291 |
+
for i in range(math.ceil(bs / fbs)):
|
292 |
+
input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
|
293 |
+
query_batch = queries[i * fbs : (i + 1) * fbs]
|
294 |
+
response_batch = responses[i * fbs : (i + 1) * fbs]
|
295 |
+
if response_masks is not None:
|
296 |
+
response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
|
297 |
+
input_ids = input_kwargs["input_ids"]
|
298 |
+
attention_mask = input_kwargs["attention_mask"]
|
299 |
+
|
300 |
+
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
301 |
+
logits, _, values = model(**input_kwargs)
|
302 |
+
|
303 |
+
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
304 |
+
if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
|
305 |
+
values = torch.transpose(values, 0, 1)
|
306 |
+
|
307 |
+
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
|
308 |
+
masks = torch.zeros_like(attention_mask)
|
309 |
+
masks[:, :-1] = attention_mask[:, 1:]
|
310 |
+
|
311 |
+
for j in range(len(query_batch)):
|
312 |
+
start = len(query_batch[j]) - 1
|
313 |
+
if attention_mask[j, 0] == 0: # offset left padding
|
314 |
+
start += attention_mask[j, :].nonzero()[0].item()
|
315 |
+
end = start + len(response_batch[j])
|
316 |
+
|
317 |
+
if response_masks is not None:
|
318 |
+
response_masks_batch = torch.cat(
|
319 |
+
(torch.zeros_like(query_batch[j]), response_masks_batch[j])
|
320 |
+
)[1:]
|
321 |
+
|
322 |
+
masks[j, :start] = 0
|
323 |
+
masks[j, end:] = 0
|
324 |
+
if response_masks is not None:
|
325 |
+
masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
|
326 |
+
|
327 |
+
if return_logits:
|
328 |
+
all_logits.append(logits)
|
329 |
+
else:
|
330 |
+
del logits
|
331 |
+
|
332 |
+
all_values.append(values)
|
333 |
+
all_logprobs.append(logprobs)
|
334 |
+
all_masks.append(masks)
|
335 |
+
|
336 |
+
return (
|
337 |
+
torch.cat(all_logprobs),
|
338 |
+
torch.cat(all_logits)[:, :-1] if return_logits else None,
|
339 |
+
torch.cat(all_values)[:, :-1],
|
340 |
+
torch.cat(all_masks)[:, :-1],
|
341 |
+
)
|
342 |
+
|
343 |
+
def save_model(self, output_dir: Optional[str] = None) -> None:
|
344 |
+
r"""
|
345 |
+
Saves model checkpoint.
|
346 |
+
|
347 |
+
Subclass and override to inject custom behavior.
|
348 |
+
"""
|
349 |
+
if self.args.should_save:
|
350 |
+
try:
|
351 |
+
self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model))
|
352 |
+
except ValueError:
|
353 |
+
logger.warning(
|
354 |
+
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
|
355 |
+
" zero_to_fp32.py to recover weights"
|
356 |
+
)
|
357 |
+
self._save(output_dir, state_dict={})
|
358 |
+
remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
|
359 |
+
self.model.save_checkpoint(output_dir) # wrapped model
|
LLM-Detector-V4-11w/src/llmtuner/train/ppo/utils.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from typing import TYPE_CHECKING, Dict, Literal, Optional
|
3 |
+
|
4 |
+
if TYPE_CHECKING:
|
5 |
+
from transformers import PreTrainedModel
|
6 |
+
from trl import AutoModelForCausalLMWithValueHead
|
7 |
+
|
8 |
+
|
9 |
+
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
|
10 |
+
if target == "reward": # save default head temporarily
|
11 |
+
valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
|
12 |
+
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
|
13 |
+
setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
|
14 |
+
|
15 |
+
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
|
16 |
+
model.v_head.load_state_dict({
|
17 |
+
"summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
|
18 |
+
"summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone()
|
19 |
+
})
|
20 |
+
|
21 |
+
|
22 |
+
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
|
23 |
+
layer_norm_params = {}
|
24 |
+
for name, param in model.named_parameters():
|
25 |
+
if param.data.dtype == torch.float32:
|
26 |
+
layer_norm_params[name] = param.data.detach().clone()
|
27 |
+
param.data = param.data.to(model.config.torch_dtype)
|
28 |
+
|
29 |
+
return layer_norm_params
|
30 |
+
|
31 |
+
|
32 |
+
def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None:
|
33 |
+
for name, param in model.named_parameters():
|
34 |
+
if name in layernorm_params:
|
35 |
+
param.data = layernorm_params[name]
|
LLM-Detector-V4-11w/src/llmtuner/train/ppo/workflow.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
|
2 |
+
|
3 |
+
import math
|
4 |
+
from trl import PPOConfig
|
5 |
+
from torch.optim import AdamW
|
6 |
+
from typing import TYPE_CHECKING, Optional, List
|
7 |
+
from transformers import DataCollatorWithPadding
|
8 |
+
from transformers.optimization import get_scheduler
|
9 |
+
|
10 |
+
from llmtuner.data import get_dataset, preprocess_dataset
|
11 |
+
from llmtuner.extras.callbacks import SavePeftModelCallback
|
12 |
+
from llmtuner.extras.ploting import plot_loss
|
13 |
+
from llmtuner.model import load_model_and_tokenizer
|
14 |
+
from llmtuner.train.utils import create_ref_model, create_reward_model
|
15 |
+
from llmtuner.train.ppo.trainer import CustomPPOTrainer
|
16 |
+
|
17 |
+
if TYPE_CHECKING:
|
18 |
+
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
19 |
+
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
20 |
+
|
21 |
+
|
22 |
+
def run_ppo(
|
23 |
+
model_args: "ModelArguments",
|
24 |
+
data_args: "DataArguments",
|
25 |
+
training_args: "Seq2SeqTrainingArguments",
|
26 |
+
finetuning_args: "FinetuningArguments",
|
27 |
+
generating_args: "GeneratingArguments",
|
28 |
+
callbacks: Optional[List["TrainerCallback"]] = None
|
29 |
+
):
|
30 |
+
dataset = get_dataset(model_args, data_args)
|
31 |
+
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
|
32 |
+
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
|
33 |
+
|
34 |
+
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
|
35 |
+
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
36 |
+
|
37 |
+
# Create reference model and reward model
|
38 |
+
ref_model = create_ref_model(model_args, finetuning_args, stage="ppo")
|
39 |
+
reward_model = create_reward_model(model, model_args, finetuning_args)
|
40 |
+
|
41 |
+
# Create ppo config
|
42 |
+
backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
|
43 |
+
ppo_config = PPOConfig(
|
44 |
+
model_name=model_args.model_name_or_path,
|
45 |
+
learning_rate=training_args.learning_rate,
|
46 |
+
mini_batch_size=training_args.per_device_train_batch_size,
|
47 |
+
batch_size=backward_batch_size * finetuning_args.ppo_buffer_size,
|
48 |
+
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
|
49 |
+
ppo_epochs=finetuning_args.ppo_epochs,
|
50 |
+
max_grad_norm=training_args.max_grad_norm,
|
51 |
+
seed=training_args.seed,
|
52 |
+
optimize_device_cache=True,
|
53 |
+
target=finetuning_args.ppo_target,
|
54 |
+
log_with=finetuning_args.ppo_logger,
|
55 |
+
use_score_scaling=finetuning_args.ppo_score_norm,
|
56 |
+
use_score_norm=finetuning_args.ppo_score_norm,
|
57 |
+
whiten_rewards=finetuning_args.ppo_whiten_rewards,
|
58 |
+
accelerator_kwargs={"step_scheduler_with_optimizer": False}
|
59 |
+
)
|
60 |
+
|
61 |
+
# Create optimizer and scheduler
|
62 |
+
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
|
63 |
+
if training_args.max_steps > 0:
|
64 |
+
num_training_steps = training_args.max_steps
|
65 |
+
else:
|
66 |
+
total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
|
67 |
+
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
|
68 |
+
|
69 |
+
lr_scheduler = get_scheduler(
|
70 |
+
training_args.lr_scheduler_type,
|
71 |
+
optimizer=optimizer,
|
72 |
+
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
|
73 |
+
num_training_steps=num_training_steps
|
74 |
+
)
|
75 |
+
|
76 |
+
# Initialize our Trainer
|
77 |
+
ppo_trainer = CustomPPOTrainer(
|
78 |
+
model_args=model_args,
|
79 |
+
training_args=training_args,
|
80 |
+
finetuning_args=finetuning_args,
|
81 |
+
generating_args=generating_args,
|
82 |
+
callbacks=callbacks + [SavePeftModelCallback()],
|
83 |
+
reward_model=reward_model,
|
84 |
+
config=ppo_config,
|
85 |
+
model=model,
|
86 |
+
ref_model=ref_model,
|
87 |
+
tokenizer=tokenizer,
|
88 |
+
dataset=dataset,
|
89 |
+
data_collator=data_collator,
|
90 |
+
optimizer=optimizer,
|
91 |
+
lr_scheduler=lr_scheduler
|
92 |
+
)
|
93 |
+
|
94 |
+
# Training
|
95 |
+
if training_args.do_train:
|
96 |
+
ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
97 |
+
ppo_trainer.save_model()
|
98 |
+
ppo_trainer.save_state() # must be called after save_model to have a folder
|
99 |
+
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
100 |
+
plot_loss(training_args.output_dir, keys=["loss", "reward"])
|
LLM-Detector-V4-11w/src/llmtuner/train/pt/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from llmtuner.train.pt.workflow import run_pt
|
LLM-Detector-V4-11w/src/llmtuner/train/pt/workflow.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py
|
2 |
+
|
3 |
+
import math
|
4 |
+
from typing import TYPE_CHECKING, Optional, List
|
5 |
+
from transformers import DataCollatorForLanguageModeling, Trainer
|
6 |
+
|
7 |
+
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
|
8 |
+
from llmtuner.extras.ploting import plot_loss
|
9 |
+
from llmtuner.model import load_model_and_tokenizer
|
10 |
+
from llmtuner.train.utils import create_modelcard_and_push
|
11 |
+
|
12 |
+
if TYPE_CHECKING:
|
13 |
+
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
14 |
+
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
15 |
+
|
16 |
+
|
17 |
+
def run_pt(
|
18 |
+
model_args: "ModelArguments",
|
19 |
+
data_args: "DataArguments",
|
20 |
+
training_args: "Seq2SeqTrainingArguments",
|
21 |
+
finetuning_args: "FinetuningArguments",
|
22 |
+
callbacks: Optional[List["TrainerCallback"]] = None
|
23 |
+
):
|
24 |
+
dataset = get_dataset(model_args, data_args)
|
25 |
+
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
|
26 |
+
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
|
27 |
+
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
28 |
+
|
29 |
+
# Initialize our Trainer
|
30 |
+
trainer = Trainer(
|
31 |
+
model=model,
|
32 |
+
args=training_args,
|
33 |
+
tokenizer=tokenizer,
|
34 |
+
data_collator=data_collator,
|
35 |
+
callbacks=callbacks,
|
36 |
+
**split_dataset(dataset, data_args, training_args)
|
37 |
+
)
|
38 |
+
|
39 |
+
# Training
|
40 |
+
if training_args.do_train:
|
41 |
+
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
42 |
+
trainer.save_model()
|
43 |
+
trainer.log_metrics("train", train_result.metrics)
|
44 |
+
trainer.save_metrics("train", train_result.metrics)
|
45 |
+
trainer.save_state()
|
46 |
+
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
47 |
+
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
48 |
+
|
49 |
+
# Evaluation
|
50 |
+
if training_args.do_eval:
|
51 |
+
metrics = trainer.evaluate(metric_key_prefix="eval")
|
52 |
+
try:
|
53 |
+
perplexity = math.exp(metrics["eval_loss"])
|
54 |
+
except OverflowError:
|
55 |
+
perplexity = float("inf")
|
56 |
+
|
57 |
+
metrics["perplexity"] = perplexity
|
58 |
+
trainer.log_metrics("eval", metrics)
|
59 |
+
trainer.save_metrics("eval", metrics)
|
60 |
+
|
61 |
+
# Create model card
|
62 |
+
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
|
LLM-Detector-V4-11w/src/llmtuner/train/rm/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from llmtuner.train.rm.workflow import run_rm
|