Russian
English
zjkarina committed on
Commit
22abae5
·
verified ·
1 Parent(s): d9473b4

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +77 -0
README.md ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ```python
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
3
+ import json
4
+ with open('generation_config.json', 'w') as fp:
5
+ json.dump({
6
+ "pad_token_id": 0,
7
+ "bos_token_id": 1,
8
+ "eos_token_id": 2,
9
+ "temperature": 0.3,
10
+ "top_p": 0.9,
11
+ "top_k": 50,
12
+ "do_sample": True,
13
+ "max_new_tokens": 1536,
14
+ "repetition_penalty": 1.1,
15
+ "no_repeat_ngram_size": 15,
16
+ }, fp, indent=4)
17
+
18
+ MODEL_NAME = "Vikhrmodels/Vikhr_instruct"
19
+ TEMPLATE = "<s>{role}\n{content}</s>\n"
20
+ SYSTEM_PROMPT = "Ты – полезный помощник по имени Вихрь. Ты разговариваешь с людьми и помогаешь им."
21
+
22
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
23
+ model.to('cuda')
24
+ model.eval()
25
+
26
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
27
+ generation_config = GenerationConfig.from_pretrained("generation_config.json")
28
+
29
+ class Conversation:
30
+ def __init__(
31
+ self,
32
+ message_template=TEMPLATE,
33
+ system_prompt=SYSTEM_PROMPT,
34
+ ):
35
+ self.message_template = message_template
36
+ self.messages = [{
37
+ "role": "system",
38
+ "content": system_prompt
39
+ }]
40
+
41
+ def add_user_message(self, message):
42
+ self.messages.append({
43
+ "role": "user",
44
+ "content": message
45
+ })
46
+
47
+ def get_prompt(self, tokenizer):
48
+ final_text = ""
49
+ for message in self.messages:
50
+ message_text = self.message_template.format(**message)
51
+ final_text += message_text
52
+ final_text += 'bot'
53
+ return final_text.strip()
54
+
55
+
56
+ def generate(model, tokenizer, prompt, generation_config):
57
+ data = tokenizer(prompt, return_tensors="pt")
58
+ data = {k: v.to(model.device) for k, v in data.items()}
59
+ output_ids = model.generate(
60
+ **data,
61
+ generation_config=generation_config
62
+ )[0]
63
+ output_ids = output_ids[len(data["input_ids"][0]):]
64
+ output = tokenizer.decode(output_ids, skip_special_tokens=True)
65
+ return output.strip()
66
+
67
+ inputs = ["Как тебя зовут?", "Кто такой Колмогоров?"]
68
+
69
+ for inp in inputs:
70
+ conversation = Conversation()
71
+ conversation.add_user_message(inp)
72
+ prompt = conversation.get_prompt(tokenizer)
73
+
74
+ output = generate(model, tokenizer, prompt, generation_config)
75
+ print(inp)
76
+ print(output)
77
+ ```