# raccoote's picture
# Update app.py
# commit e3b6619 (verified)
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
# --- Model setup (runs at import time; downloads weights on first run) ---
model_id = "unsloth/Meta-Llama-3.1-8B"  # Base LLaMA 3.1 8B checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)

# float32 for CPU inference: fp16/bf16 support on CPU is spotty and there is
# no GPU in this environment.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
model.to("cpu")  # explicit, even though from_pretrained defaults to CPU here

# Load the LoRA adapter on top of the base model.
adapter_repo = "raccoote/angry-birds-v2"  # adapter repository path
adapter_weight_name = "adapter_model.safetensors"  # kept for reference; see note below
# NOTE(review): `weight_name` is not a parameter of PeftModel.from_pretrained —
# the adapter file name is resolved automatically from the repo — so the
# original `weight_name=...` kwarg was silently ignored and has been dropped.
peft_model = PeftModel.from_pretrained(model, adapter_repo, adapter_name="angry_birds")
# Prepare for inference
def generate_text(prompt, model, tokenizer, peft_model, max_length=50):
    """Generate a completion for *prompt* with the LoRA-adapted model.

    Args:
        prompt: Text prompt to complete.
        model: Base model. Kept for interface compatibility; generation goes
            through ``peft_model``, which wraps it, so it is not used directly.
        tokenizer: Tokenizer matching the base model.
        peft_model: PEFT-wrapped model used for generation.
        max_length: Cap on total output length in tokens (prompt included).

    Returns:
        The decoded generated text with special tokens stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    # Inference only: disable autograd so no gradient state is tracked
    # (saves memory and time; sampled outputs are unaffected).
    with torch.no_grad():
        outputs = peft_model.generate(
            **inputs,
            max_length=max_length,
            num_return_sequences=1,
            do_sample=True,  # stochastic sampling; use False for deterministic output
            top_p=0.95,      # nucleus sampling threshold
            temperature=0.7,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Generate text with the loaded LoRA adapter
# Demo run at import/startup time: one completion for a fixed prompt,
# printed to stdout (visible in the Space's logs).
prompt = "large piggy on wooden tower"
generated_text = generate_text(prompt, model, tokenizer, peft_model)
print(generated_text)