Thouph's picture
Create app.py
5ba6e49 verified
raw
history blame
1.81 kB
import json
import random
# Seed the stdlib RNG before the remaining imports run. Nothing in the code
# visible here draws from `random`, so this only matters if an imported
# library or code elsewhere does.
random.seed(1234)
import torch
from transformers import Qwen2ForSequenceClassification, AutoTokenizer
import gradio as gr
from datetime import datetime
# Turn off autograd globally: the model below is only called for forward
# passes, so no gradient bookkeeping is needed.
torch.set_grad_enabled(False)
# Load the classifier on CPU and switch it to evaluation mode.
# NOTE(review): num_labels=9086 presumably equals the tag count in
# tags_9083.json plus the three labels appended below -- confirm if either
# the checkpoint or the tag file changes.
model = Qwen2ForSequenceClassification.from_pretrained(
    "Thouph/danbooru-to-e621-qwen2.5-0.5b",
    num_labels=9086,
    device_map="cpu",
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

# Read the tag vocabulary. The encoding is pinned to UTF-8 so the file is
# decoded the same way on every platform instead of with the locale default.
with open("tags_9083.json", "r", encoding="utf-8") as file:
    allowed_tags = json.load(file)

# create_tags() indexes this list with the model's label positions, so the
# order matters: the file's tags sorted, followed by the three extra labels.
allowed_tags = sorted(allowed_tags)
allowed_tags.extend(["explicit", "questionable", "safe"])
def create_tags(prompt, threshold):
    """Return the tags whose predicted score for ``prompt`` exceeds ``threshold``.

    The prompt is tokenized with the module-level ``tokenizer`` (truncated to
    512 tokens), passed through the module-level ``model``, and the logits
    are mapped through a sigmoid so every label gets its own score.

    Args:
        prompt: Text to classify.
        threshold: A label is kept only when its score is strictly greater
            than this value.

    Returns:
        A ``(tag_string, tag_score)`` tuple. ``tag_string`` is the selected
        tag names joined by single spaces; ``tag_score`` maps each selected
        tag name to its score as a Python float. Both are empty when no
        label passes the threshold.
    """
    inputs = tokenizer(
        prompt,
        padding="do_not_pad",
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    logits = model(**inputs).logits
    scores = torch.sigmoid(logits)

    # torch.where returns one index tensor per dimension; entry [1] holds the
    # label positions of the hits and is what indexes ``allowed_tags``.
    hits = torch.where(scores > threshold)
    # Indexing with the index tuple already yields a flat tensor with one
    # score per hit. It must not be squeezed: with exactly one hit, squeeze()
    # collapses it to a 0-dim tensor and indexing that raises IndexError.
    hit_scores = scores[hits].tolist()
    hit_labels = hits[1].tolist()

    names = [allowed_tags[label] for label in hit_labels]
    tag_score = dict(zip(names, hit_scores))
    text_no_impl = " ".join(names)

    current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    print(f"{current_datetime}: finished.")
    return text_no_impl, tag_score
# Input widgets: free-form prompt text and the score cut-off handed to
# create_tags as its second argument.
prompt_box = gr.TextArea(label="Prompt")
threshold_slider = gr.Slider(
    minimum=0.0,
    maximum=1.0,
    step=0.01,
    value=0.4,
    label="Threshold",
)

# Output widgets: the space-joined tag string and the per-tag scores.
tag_string_box = gr.Textbox(label="Tag String")
tag_scores_label = gr.Label(label="Tag Predictions", num_top_classes=200)

# Wire the widgets to create_tags and start serving the app.
demo = gr.Interface(
    fn=create_tags,
    inputs=[prompt_box, threshold_slider],
    outputs=[tag_string_box, tag_scores_label],
    allow_flagging="never",
)
demo.launch()