MohamedRashad committed
Commit 00ec273 · 1 Parent(s): 5925564

Refactor model ID handling in app.py

Files changed (1)
  1. app.py +33 -33
app.py CHANGED
@@ -20,47 +20,47 @@ def load_model_a(model_id):
     global tokenizer_a, model_a
     tokenizer_a = AutoTokenizer.from_pretrained(model_id)
     print(f"model A: {tokenizer_a.eos_token}")
-    try:
-        model_a = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-            attn_implementation="flash_attention_2",
-            trust_remote_code=True,
-        ).eval()
-    except:
-        print(f"Using default attention implementation in {model_id}")
-        model_a = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-            trust_remote_code=True,
-        ).eval()
+    model_a = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+        attn_implementation="flash_attention_2",
+        trust_remote_code=True,
+    ).eval()
+    # try:
+    # except:
+    #     print(f"Using default attention implementation in {model_id}")
+    #     model_a = AutoModelForCausalLM.from_pretrained(
+    #         model_id,
+    #         torch_dtype=torch.bfloat16,
+    #         device_map="auto",
+    #         trust_remote_code=True,
+    #     ).eval()
     return gr.update(label=model_id)
 
 def load_model_b(model_id):
     global tokenizer_b, model_b
     tokenizer_b = AutoTokenizer.from_pretrained(model_id)
     print(f"model B: {tokenizer_b.eos_token}")
-    try:
-        model_b = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-            attn_implementation="flash_attention_2",
-            trust_remote_code=True,
-        ).eval()
-    except:
-        print(f"Using default attention implementation in {model_id}")
-        model_b = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-            trust_remote_code=True,
-        ).eval()
+    model_b = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+        attn_implementation="flash_attention_2",
+        trust_remote_code=True,
+    ).eval()
+    # try:
+    # except:
+    #     print(f"Using default attention implementation in {model_id}")
+    #     model_b = AutoModelForCausalLM.from_pretrained(
+    #         model_id,
+    #         torch_dtype=torch.bfloat16,
+    #         device_map="auto",
+    #         trust_remote_code=True,
+    #     ).eval()
     return gr.update(label=model_id)
 
-@spaces.GPU(duration=120)
+@spaces.GPU()
 def generate_both(system_prompt, input_text, chatbot_a, chatbot_b, max_new_tokens=2048, temperature=0.2, top_p=0.9, repetition_penalty=1.1):
     text_streamer_a = TextIteratorStreamer(tokenizer_a, skip_prompt=True)
     text_streamer_b = TextIteratorStreamer(tokenizer_b, skip_prompt=True)
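Context on the change: the deleted try/except was the fallback path for environments without a flash-attn build, so the new code now assumes flash_attention_2 is always available (and `@spaces.GPU()` with no argument presumably falls back to the default ZeroGPU duration instead of requesting 120 seconds). If the fallback is ever wanted again, a minimal sketch of an explicit availability check rather than a bare `except:` is shown below — this is not part of the commit, and the helper name `load_causal_lm` plus the assumption that flash_attention_2 requires the flash_attn package are illustrative only:

```python
# Sketch only (not in this commit): choose the attention implementation
# by checking whether the flash_attn package is importable, instead of
# swallowing every error with a bare `except:`.
import importlib.util

import torch
from transformers import AutoModelForCausalLM


def load_causal_lm(model_id: str):
    kwargs = dict(
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    # Assumption: flash_attention_2 needs the flash-attn wheel installed.
    if importlib.util.find_spec("flash_attn") is not None:
        kwargs["attn_implementation"] = "flash_attention_2"
    else:
        print(f"Using default attention implementation in {model_id}")
    return AutoModelForCausalLM.from_pretrained(model_id, **kwargs).eval()
```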