---
base_model: ministral/Ministral-3b-instruct
library_name: peft
---

# Ministral-3b-instruct-PromptEnhancing

Ministral-3b-instruct-PromptEnhancing is an instruction-tuned text-generation model fine-tuned with LoRA.

This model was released alongside three other models in the 2B to 3B parameter range, all trained on the same dataset with the same training arguments.


## Model Details

### Model Description

This model is a LoRA fine-tune of [ministral/Ministral-3b-instruct](https://huggingface.co/ministral/Ministral-3b-instruct).
The goal of this fine-tune is to provide a lightweight prompt-enhancing model for Stable Diffusion (or other diffusion models that share its prompting conventions), making image generation more accessible to everyone.



- **Developed by:** [groloch](https://huggingface.co/groloch)
- **Model type:** LoRA
- **Language(s) (NLP):** English
- **License:** [apache 2.0](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md)
- **Finetuned from model:** [ministral/Ministral-3b-instruct](https://huggingface.co/ministral/Ministral-3b-instruct)
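
For readers new to LoRA, the standard formulation (Hu et al., 2021), which is what PEFT's LoRA adapters implement by default, is summarized below. The base weights $W_0$ stay frozen, and only the low-rank factors $A$ and $B$ are trained and distributed in the adapter repository. The rank $r$ and scaling factor $\alpha$ used for this particular adapter are not documented yet, so they are left symbolic here.

$$
h = W_0 x + \Delta W x = W_0 x + \frac{\alpha}{r} B A x,
\qquad W_0 \in \mathbb{R}^{d \times k},\; B \in \mathbb{R}^{d \times r},\; A \in \mathbb{R}^{r \times k},\; r \ll \min(d, k)
$$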

### Model Sources

- **Paper:** _Coming soon_
- **Demo:** _Coming soon_

## Uses

This model is intended to be used as a prompt enhancer for diffusion models. The simplest way to try it is the official [demo](#) (_coming soon_).


### Direct Use

To run the model locally (the snippet below requires `torch`, `transformers`, and `peft`, and assumes a CUDA GPU), use the following code:
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Note: `load_adapter` relies on the PEFT integration in transformers,
# so the `peft` package must be installed.
base_repo_id = 'ministral/Ministral-3b-instruct'
adapter_repo_id = 'groloch/Ministral-3b-instruct-PromptEnhancing'

tokenizer = AutoTokenizer.from_pretrained(base_repo_id)
model = AutoModelForCausalLM.from_pretrained(base_repo_id, torch_dtype=torch.bfloat16).to('cuda')
# Attach the LoRA adapter on top of the frozen base model
model.load_adapter(adapter_repo_id)

prompt_to_enhance = 'Sinister crocodile eating a jolly rabbit'

chat = [
    {'role': 'user', 'content': prompt_to_enhance}
]

# Render the chat template to a string (no tokenization yet, so no return_tensors)
prompt = tokenizer.apply_chat_template(chat,
                                       tokenize=False,
                                       add_generation_prompt=True)

# Chat templates usually include their own special tokens (such as BOS),
# so don't let the tokenizer add them a second time
encoding = tokenizer(prompt, return_tensors='pt', add_special_tokens=False).to('cuda')

# Sampling settings (see the Recommendations section below)
generation_config = model.generation_config
generation_config.do_sample = True
generation_config.max_new_tokens = 96
generation_config.temperature = 0.3
generation_config.top_p = 0.7
generation_config.num_return_sequences = 1
generation_config.pad_token_id = tokenizer.eos_token_id
generation_config.eos_token_id = tokenizer.eos_token_id
generation_config.repetition_penalty = 2.0

with torch.inference_mode():
    outputs = model.generate(
        input_ids=encoding.input_ids,
        attention_mask=encoding.attention_mask,
        generation_config=generation_config
    )

# Decode only the newly generated tokens, skipping the echoed prompt
new_tokens = outputs[0][encoding.input_ids.shape[-1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))
```

### Out-of-Scope Use

This model is meant to be used as a prompt enhancer. Inputs should be short, concise descriptions of the desired image rather than full, already-detailed prompts.

Using this model for other purposes may yield unexpected behavior.

## Bias, Risks, and Limitations

This model was trained on a dataset that was partially generated by AI and may therefore contain biases.

Because this is a lightweight model (about 3B parameters), it may have significant limitations.

### Recommendations

Use a high repetition penalty (> 2.0) and a low temperature (< 0.4) for generation, and do not generate more than 128 tokens.
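
The recommended ranges above can be checked programmatically before calling `model.generate`. The sketch below uses a hypothetical helper, `check_generation_settings` (not part of `transformers` or `peft`), which takes a dictionary keyed by the same parameter names `generate` accepts and returns one warning per value outside the card's ranges. Because the Direct Use snippet itself uses a repetition penalty of exactly 2.0, this sketch assumes 2.0 is an acceptable lower bound; the thresholds are otherwise copied from the recommendations above.

```python
from typing import Any, Mapping

# Thresholds taken from this card's Recommendations (assumed, not official limits).
MIN_REPETITION_PENALTY = 2.0  # the card's own example uses exactly 2.0
MAX_TEMPERATURE = 0.4         # recommended: temperature < 0.4
MAX_NEW_TOKENS = 128          # recommended: at most 128 generated tokens


def check_generation_settings(settings: Mapping[str, Any]) -> list[str]:
    """Return a warning for each setting outside the recommended range.

    Hypothetical helper (not part of transformers or PEFT).
    Only keys present in `settings` are checked; absent keys are ignored.
    """
    warnings: list[str] = []

    penalty = settings.get("repetition_penalty")
    if penalty is not None and penalty < MIN_REPETITION_PENALTY:
        warnings.append(f"repetition_penalty={penalty} is below {MIN_REPETITION_PENALTY}")

    temperature = settings.get("temperature")
    if temperature is not None and temperature >= MAX_TEMPERATURE:
        warnings.append(f"temperature={temperature} should be below {MAX_TEMPERATURE}")

    max_new = settings.get("max_new_tokens")
    if max_new is not None and max_new > MAX_NEW_TOKENS:
        warnings.append(f"max_new_tokens={max_new} exceeds {MAX_NEW_TOKENS}")

    return warnings


if __name__ == "__main__":
    # Values used in the Direct Use snippet: no warnings.
    example = {"repetition_penalty": 2.0, "temperature": 0.3, "max_new_tokens": 96}
    print(check_generation_settings(example))  # → []

    # Penalty and temperature both outside the recommended ranges: two warnings.
    print(check_generation_settings({"repetition_penalty": 1.0, "temperature": 1.0}))
```

Once the returned list is empty, the same dictionary (together with `do_sample=True`, which sampling parameters such as `temperature` require) can be unpacked into `model.generate(**settings)`.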

## Training Details

### Training Data

This model was trained for one epoch on [groloch/stable_diffusion_prompts_instruct](https://huggingface.co/datasets/groloch/stable_diffusion_prompts_instruct).

### Training Hyperparameters

_Coming soon_

### Framework versions

- PEFT 0.13.2