philschmid committed · Commit e124d2a · Parent(s): 02124fa
Update app.py

app.py CHANGED
@@ -31,18 +31,18 @@ else:
     # torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
     # model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, device_map="auto")
     model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)
-
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 prompt_template = f"### Anweisung:\n{{input}}\n\n### Antwort:"
 
 
-def generate(instruction, temperature, max_new_tokens, top_p, length_penalty):
+def generate(instruction, temperature=1, max_new_tokens=256, top_p=0.9, length_penalty=1.0):
     formatted_instruction = prompt_template.format(input=instruction)
     # COMMENT IN FOR NON STREAMING
     # generation_config = GenerationConfig(
     #     do_sample=True,
     #     top_p=top_p,
+    #     top_k=0,
     #     temperature=temperature,
     #     max_new_tokens=max_new_tokens,
     #     early_stopping=True,
@@ -71,7 +71,9 @@ def generate(instruction, temperature, max_new_tokens, top_p, length_penalty):
 
     generate_kwargs = dict(
         top_p=top_p,
+        top_k=0,
         temperature=temperature,
+        do_sample=True,
         max_new_tokens=max_new_tokens,
         early_stopping=True,
         length_penalty=length_penalty,
@@ -142,7 +144,13 @@ with gr.Blocks(theme=theme) as demo:
                 placeholder="Hier Antwort erscheint...",
             )
             submit = gr.Button("Generate", variant="primary")
-            gr.Examples(
+            gr.Examples(
+                examples=examples,
+                inputs=[instruction],
+                # cache_examples=True,
+                # fn=generate,
+                # outputs=[output],
+            )
 
         with gr.Column(scale=1):
             temperature = gr.Slider(
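Taken together, the hunks above change three things: generate() now has default values for its sampling parameters, generate_kwargs gains top_k=0 and do_sample=True, and the gr.Examples call is expanded so a clicked example fills the instruction textbox. Below is a minimal, hedged sketch of how the post-commit pieces could fit together. Only the lines that appear in the diff are taken from the commit; model_id, the examples list, the component names, the click wiring, and the non-streaming generate call are assumptions added to keep the sketch self-contained (the real app appears to stream its output).

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: the real model_id is defined earlier in app.py and is not part of this diff.
model_id = "some-org/some-instruct-model"

# 8-bit loading as shown in the diff (requires bitsandbytes + accelerate).
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

prompt_template = f"### Anweisung:\n{{input}}\n\n### Antwort:"


def generate(instruction, temperature=1, max_new_tokens=256, top_p=0.9, length_penalty=1.0):
    # The new defaults mean the function can be called with just the instruction.
    formatted_instruction = prompt_template.format(input=instruction)
    inputs = tokenizer(formatted_instruction, return_tensors="pt").to(model.device)

    # top_k=0 disables top-k filtering and do_sample=True turns sampling on,
    # so decoding becomes nucleus (top-p) sampling instead of greedy search.
    generate_kwargs = dict(
        top_p=top_p,
        top_k=0,
        temperature=temperature,
        do_sample=True,
        max_new_tokens=max_new_tokens,
        early_stopping=True,
        length_penalty=length_penalty,
    )
    # Simplification: the real app streams tokens; this sketch generates in one call.
    output_ids = model.generate(**inputs, **generate_kwargs)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


# Assumption: the examples list is defined elsewhere in app.py; this entry is a placeholder.
examples = ["Beschreibe die Hauptstadt von Deutschland."]

with gr.Blocks() as demo:
    instruction = gr.Textbox(label="Anweisung")
    output = gr.Textbox(label="Antwort", placeholder="Hier Antwort erscheint...")
    submit = gr.Button("Generate", variant="primary")
    # The expanded gr.Examples call from the diff: clicking an example fills
    # the instruction textbox; caching via fn/outputs stays commented out.
    gr.Examples(examples=examples, inputs=[instruction])
    submit.click(generate, inputs=[instruction], outputs=[output])

if __name__ == "__main__":
    demo.launch()

A plausible reason for adding the defaults is visible in the commented-out cache_examples/fn/outputs arguments: if example caching were enabled, Gradio would call generate with only the example instruction, so the remaining sampling parameters need default values.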