raccoote committed (verified)
Commit: e3b6619 · Parent: 910ab21

Update app.py

Files changed (1): app.py (+26 -26)
app.py CHANGED
@@ -1,35 +1,35 @@
- import gradio as gr
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+ from transformers import AutoModelForCausalLM, AutoTokenizer
  from peft import PeftModel
+ import torch

- # Define the base model and configuration
- base_model_name = "raccoote/angry-birds-v2"
-
- # Load the tokenizer
- tokenizer = AutoTokenizer.from_pretrained(base_model_name)
-
- # Load the model with 8-bit precision
- quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-
- base_model = AutoModelForCausalLM.from_pretrained(
-     base_model_name,
-     quantization_config=quantization_config,
-     device_map="auto"  # This will ensure the model is distributed to available hardware
- )
-
- # Load the LoRA adapter from the repository
- adapter_model = PeftModel.from_pretrained(base_model, base_model_name)
-
- def generate_text(prompt):
+ # Load the base model and tokenizer
+ model_id = "unsloth/Meta-Llama-3.1-8B"  # Use the appropriate LLaMA 3.1 8b model ID
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)  # Use torch.float32 for CPU
+ model.to("cpu")  # Ensure the model is loaded on CPU
+
+ # Load your LoRA adapter
+ adapter_repo = "raccoote/angry-birds-v2"  # Your repository path
+ adapter_weight_name = "adapter_model.safetensors"  # The weight file name
+
+ # Load LoRA weights
+ peft_model = PeftModel.from_pretrained(model, adapter_repo, weight_name=adapter_weight_name, adapter_name="angry_birds")
+
+ # Prepare for inference
+ def generate_text(prompt, model, tokenizer, peft_model, max_length=50):
      inputs = tokenizer(prompt, return_tensors="pt")
-     outputs = adapter_model.generate(**inputs, max_new_tokens=50)
+     outputs = peft_model.generate(
+         **inputs,
+         max_length=max_length,
+         num_return_sequences=1,
+         do_sample=True,  # or use `do_sample=False` for deterministic outputs
+         top_p=0.95,  # or other sampling parameters
+         temperature=0.7
+     )
      return tokenizer.decode(outputs[0], skip_special_tokens=True)

- # Create the Gradio interface
- iface = gr.Interface(fn=generate_text,
-                      inputs="text",
-                      outputs="text",
-                      title="LLaMA 3.1 with LoRA Adapters",
-                      description="Enter a prompt and get the model's output.")
-
- iface.launch()
+ # Generate text with the loaded LoRA adapter
+ prompt = "large piggy on wooden tower"
+ generated_text = generate_text(prompt, model, tokenizer, peft_model)
+
+ print(generated_text)
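
For reference, below is a minimal standalone sketch of the updated app.py with two adjustments that are assumptions, not part of the commit: `weight_name` is not a documented parameter of `PeftModel.from_pretrained` (PEFT resolves the default `adapter_model.safetensors` in the adapter repo on its own, so the argument is dropped), and the Gradio interface that this commit removes is restored so the Space still serves a UI.

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Base model and adapter repos, as in the committed script.
model_id = "unsloth/Meta-Llama-3.1-8B"
adapter_repo = "raccoote/angry-birds-v2"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)  # float32 for CPU inference
model.to("cpu")

# Assumption: no weight_name argument; PEFT locates adapter_model.safetensors itself.
peft_model = PeftModel.from_pretrained(model, adapter_repo, adapter_name="angry_birds")

def generate_text(prompt, max_new_tokens=50):
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = peft_model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,  # bounds generated tokens only, unlike max_length
        do_sample=True,
        top_p=0.95,
        temperature=0.7,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Restores the interface the commit dropped (an assumption, not in this commit).
iface = gr.Interface(fn=generate_text,
                     inputs="text",
                     outputs="text",
                     title="LLaMA 3.1 with LoRA Adapters",
                     description="Enter a prompt and get the model's output.")
iface.launch()

Note that an 8B-parameter model at float32 needs roughly 32 GB for weights alone and generates slowly on a CPU-only Space, which is why `max_new_tokens` is used here in place of `max_length`: it bounds only the generated text rather than prompt plus output.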