---
base_model: ministral/Ministral-3b-instruct
library_name: peft
---
# Ministral-3b-instruct-PromptEnhancing
Ministral-3b-instruct-PromptEnhancing is a LoRA fine-tune of an instruction-tuned text-generation model.
This model was released alongside three other models in the 2B to 3B parameter range, all trained on the same dataset with the same training arguments.
## Model Details
### Model Description
This model is a LoRA fine-tune of [ministral/Ministral-3b-instruct](https://huggingface.co/ministral/Ministral-3b-instruct).
The goal of this fine-tune is to provide a lightweight prompt-enhancing model for Stable Diffusion (or other diffusion models that follow the same prompting conventions), making image generation more accessible to everyone.
- **Developed by:** [groloch](https://huggingface.co/groloch)
- **Model type:** LoRA
- **Language(s) (NLP):** English
- **License:** [Apache 2.0](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md)
- **Finetuned from model:** [ministral/Ministral-3b-instruct](https://huggingface.co/ministral/Ministral-3b-instruct)
### Model Sources
- **Paper:** _Coming soon_
- **Demo:** _Coming soon_
## Uses
This model is intended to be used as a prompt enhancer for diffusion models. The simplest way to try it will be the official demo (_coming soon_).
### Direct Use
If you want to use it locally, refer to the following code snippet:
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

base_repo_id = 'ministral/Ministral-3b-instruct'
adapter_repo_id = 'groloch/Ministral-3b-instruct-PromptEnhancing'

# Load the base model, then attach the LoRA adapter (requires `peft` to be installed)
tokenizer = AutoTokenizer.from_pretrained(base_repo_id)
model = AutoModelForCausalLM.from_pretrained(base_repo_id, torch_dtype=torch.bfloat16).to('cuda')
model.load_adapter(adapter_repo_id)

prompt_to_enhance = 'Sinister crocodile eating a jolly rabbit'

# Wrap the short idea in the base model's chat template
chat = [
    {'role': 'user', 'content': prompt_to_enhance}
]
prompt = tokenizer.apply_chat_template(chat,
                                       tokenize=False,
                                       add_generation_prompt=True)

# The template already inserts special tokens, so don't add them a second time
encoding = tokenizer(prompt, return_tensors='pt', add_special_tokens=False).to('cuda')

# Sampling settings (see Recommendations below)
generation_config = model.generation_config
generation_config.do_sample = True
generation_config.max_new_tokens = 96
generation_config.temperature = 0.3
generation_config.top_p = 0.7
generation_config.num_return_sequences = 1
generation_config.pad_token_id = tokenizer.eos_token_id
generation_config.eos_token_id = tokenizer.eos_token_id
generation_config.repetition_penalty = 2.0

with torch.inference_mode():
    outputs = model.generate(
        input_ids=encoding.input_ids,
        attention_mask=encoding.attention_mask,
        generation_config=generation_config
    )

# Decode only the newly generated tokens, skipping the echoed prompt
new_tokens = outputs[0][encoding.input_ids.shape[1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))
```
### Out-of-Scope Use
This model is meant to be used as a prompt enhancer only. Inputs should be short, concise ideas rather than detailed, complete prompts.
Using this model for other purposes may yield unexpected behavior.
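The input guidance above (short ideas in, not full prompts) can be checked before calling the model. The following is a minimal sketch, not part of the released model: `is_short_idea` is a hypothetical helper, and its word and comma thresholds are illustrative assumptions, chosen on the assumption that full Stable Diffusion prompts tend to be long, comma-separated tag lists.

```python
def is_short_idea(text: str, max_words: int = 20, max_commas: int = 2) -> bool:
    """Heuristically check that `text` is a brief idea rather than a full prompt.

    Hypothetical helper: the thresholds are illustrative defaults, not values
    taken from this model card. Long or comma-heavy inputs are treated as
    already-complete prompts, which are out of scope for this model.
    """
    stripped = text.strip()
    if not stripped:
        # Empty input gives the model nothing to enhance
        return False
    words = stripped.split()
    return len(words) <= max_words and stripped.count(",") <= max_commas


print(is_short_idea("Sinister crocodile eating a jolly rabbit"))  # True
print(is_short_idea("masterpiece, best quality, 8k, ultra detailed, crocodile, rabbit, dramatic lighting"))  # False
```

Inputs rejected by such a check are better passed directly to the diffusion model, since they are already full prompts.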
## Bias, Risks, and Limitations
This model was trained on a dataset that was partially generated by AI and may therefore contain biases.
This is a lightweight model (about 3B parameters), so it may have significant limitations.
### Recommendations
Use a high repetition penalty (2.0 or higher) and a low temperature (below 0.4) for generation, and do not generate more than 128 new tokens.
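These recommendations can be collected into keyword arguments for `model.generate`. The sketch below is a hypothetical helper (`recommended_generation_kwargs` is not part of any library). It reuses the `top_p=0.7` and `max_new_tokens=96` values from the Direct Use snippet and, because that snippet itself uses a repetition penalty of exactly 2.0, accepts 2.0 as the lower bound.

```python
def recommended_generation_kwargs(
    temperature: float = 0.3,
    repetition_penalty: float = 2.0,
    max_new_tokens: int = 96,
) -> dict:
    """Return sampling kwargs for `model.generate(...)` within the recommended ranges.

    Hypothetical helper: raises ValueError if a value falls outside the
    recommendations (temperature below 0.4, repetition penalty of at least 2.0,
    at most 128 new tokens).
    """
    if not 0.0 < temperature < 0.4:
        raise ValueError(f"temperature should be in (0, 0.4), got {temperature}")
    if repetition_penalty < 2.0:
        raise ValueError(f"repetition_penalty should be >= 2.0, got {repetition_penalty}")
    if not 0 < max_new_tokens <= 128:
        raise ValueError(f"max_new_tokens should be in (0, 128], got {max_new_tokens}")
    return {
        "do_sample": True,
        "temperature": temperature,
        "top_p": 0.7,  # value used in the Direct Use snippet
        "repetition_penalty": repetition_penalty,
        "max_new_tokens": max_new_tokens,
    }


print(recommended_generation_kwargs())
```

With the objects from the Direct Use snippet, a call such as `model.generate(**encoding, **recommended_generation_kwargs(), pad_token_id=tokenizer.eos_token_id)` could stand in for editing `model.generation_config` in place.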
## Training Details
### Training Data
This model was trained for one epoch on [groloch/stable_diffusion_prompts_instruct](https://huggingface.co/datasets/groloch/stable_diffusion_prompts_instruct).
### Training Hyperparameters
_coming soon_
### Framework versions
- PEFT 0.13.2