sudokara commited on
Commit
67158a0
·
1 Parent(s): 6d2998d

switched to pipeline

Browse files
Files changed (1) hide show
  1. app.py +27 -11
app.py CHANGED
@@ -7,19 +7,35 @@ import transformers
7
  # Load the model and tokenizer
8
  model_name = "Artigenz/Artigenz-Coder-DS-6.7B"
9
  tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
10
- model = transformers.AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # Function to generate responses from the model
13
  def generate_response(input_text):
14
- inputs = tokenizer(input_text, return_tensors="pt")
15
- input_ids = inputs["input_ids"].to(model.device)
16
- attention_mask = inputs["attention_mask"].to(model.device)
17
-
18
- # Generate the output
19
- outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=512)
20
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
21
-
22
- return response
23
 
24
  # Define the Gradio interface
25
  iface = gr.Interface(
 
7
  # Load the model and tokenizer
8
  model_name = "Artigenz/Artigenz-Coder-DS-6.7B"
9
  tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
10
+ model = transformers.AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype="auto")
11
+ max_new_tokens:int=1024
12
+ do_sample:bool=True
13
+ num_beams:int=1
14
+ temperature:float=0.5
15
+ top_p:float=0.95
16
+ top_k:float=40
17
+ repetition_penalty:float=1.1
18
+ pipe = transformers.pipeline(
19
+ "text-generation",
20
+ model=model,
21
+ tokenizer=tokenizer,
22
+ max_new_tokens=max_new_tokens,
23
+ do_sample=do_sample,
24
+ num_beams=num_beams,
25
+ temperature=temperature,
26
+ top_p=top_p,
27
+ top_k=top_k,
28
+ repetition_penalty=repetition_penalty,
29
+ )
30
 
 
31
  def generate_response(input_text):
32
+ messages = [
33
+ {
34
+ "role": "system", "content": "You are a helpful coding chatbot. You will answer the user's questions to the best of your ability.",
35
+ "role": "user", "content": input_text,
36
+ },
37
+ ]
38
+ return pipe(messages)[0]['generated_text'][-1]['content'].replace("\\n", "\n")
 
 
39
 
40
  # Define the Gradio interface
41
  iface = gr.Interface(