---
base_model: unsloth/mistral-7b-bnb-4bit
library_name: peft
license: mit
datasets:
- yahma/alpaca-cleaned
language:
- en
pipeline_tag: text-generation
tags:
- physics
- conversational
---
How to use:
```python
# Install dependencies (transformers added so the imports below resolve)
!pip install peft accelerate bitsandbytes transformers
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the configuration for the fine-tuned model
model_id = "Vijayendra/QST-Mistral-7b"
config = PeftConfig.from_pretrained(model_id)

# Load the 4-bit base model and attach the fine-tuned adapter
# (device_map="auto" keeps the quantized weights on the GPU; .to("cuda") is not supported for 4-bit models)
base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, device_map="auto")
model = PeftModel.from_pretrained(base_model, model_id)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Prepare the input for inference
prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

instruction = "Explain the significance of cyclic operators in machine learning theory."
input_text = "Provide a detailed explanation suitable for a beginner in quantum machine learning."
formatted_prompt = prompt.format(instruction, input_text, "")

# Tokenize the input and move it to the model's device
inputs = tokenizer(
    formatted_prompt,
    return_tensors="pt",
    max_length=2048,
    truncation=True
).to(model.device)

# Run inference
outputs = model.generate(
    **inputs,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7,
    top_k=50
)

# Decode and print the output
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
```