Spaces:
Running
Running
import gradio as gr | |
from transformers import T5TokenizerFast, CLIPTokenizer | |
# The two tokenizers this app compares are loaded once, at import time, and
# shared by every call to count_tokens.
_T5_CHECKPOINT = "google/t5-v1_1-xxl"
_CLIP_CHECKPOINT = "openai/clip-vit-large-patch14"

# legacy=False is forwarded to from_pretrained unchanged (presumably to opt out
# of the legacy SentencePiece behaviour; confirm against transformers docs).
t5_tokenizer = T5TokenizerFast.from_pretrained(_T5_CHECKPOINT, legacy=False)
clip_tokenizer = CLIPTokenizer.from_pretrained(_CLIP_CHECKPOINT)
# On its own, T5 token id 3 decodes to an empty string. The UI shows it as "▁"
# (the SentencePiece word-boundary marker) instead of the generic "∅".
_T5_WORD_BOUNDARY_ID = 3


def _render_token(tokenizer, token_id, empty_fallback="∅"):
    """Return a printable label for a single token id.

    The id is decoded with special tokens kept. A whitespace-only result is
    shown as "␣" and an empty result as *empty_fallback*, so every token
    stays visible in the highlighted-text display.
    """
    decoded = tokenizer.decode([token_id], skip_special_tokens=False)
    if decoded.isspace():
        return "␣"
    if decoded == "":
        return empty_fallback
    return decoded


def count_tokens(text):
    """Tokenize *text* with both the T5 and the CLIP tokenizer.

    Args:
        text: The prompt entered in the UI. ``None`` is treated as "".

    Returns:
        A 6-tuple, in this order:
        (T5 token count, T5 highlight tuples, T5 ids as a string,
         CLIP token count, CLIP highlight tuples, CLIP ids as a string).
        Each highlight tuple is ``(label, "")``, where the empty second
        element means no category label.
    """
    if text is None:
        text = ""

    # Plain encode() already returns list[int]. Requesting a PyTorch tensor
    # and converting it back would make torch a hard requirement for nothing.
    t5_tokens = t5_tokenizer.encode(text, add_special_tokens=True)
    clip_tokens = clip_tokenizer.encode(text, add_special_tokens=True)

    t5_decoded = [
        _render_token(
            t5_tokenizer,
            tok,
            "▁" if tok == _T5_WORD_BOUNDARY_ID else "∅",
        )
        for tok in t5_tokens
    ]
    clip_decoded = [_render_token(clip_tokenizer, tok) for tok in clip_tokens]

    t5_highlights = [(label, "") for label in t5_decoded]
    clip_highlights = [(label, "") for label in clip_decoded]

    return (
        # T5 outputs
        len(t5_tokens),
        t5_highlights,
        str(t5_tokens),
        # CLIP outputs
        len(clip_tokens),
        clip_highlights,
        str(clip_tokens),
    )
def _results_column(heading, count_label, tokens_label, ids_label):
    """Build one tokenizer-results column inside the current layout context.

    Returns the (count, highlighted tokens, token ids) output components,
    in that order.
    """
    with gr.Column():
        gr.Markdown(heading)
        count_box = gr.Number(label=count_label)
        token_view = gr.HighlightedText(label=tokens_label, show_legend=True)
        id_box = gr.Textbox(label=ids_label, lines=2)
    return count_box, token_view, id_box


# Page layout: title, prompt box, then T5 and CLIP results side by side.
with gr.Blocks(title="DiffusionTokenizer") as iface:
    gr.Markdown("# DiffusionTokenizer🔢")
    gr.Markdown("A lightning fast visualization of the tokens used in diffusion models. Use it to understand how your prompt is tokenized.")

    with gr.Row():
        text_input = gr.Textbox(label="Diffusion Prompt", placeholder="Enter your prompt here...")

    with gr.Row():
        t5_count, t5_highlights, t5_ids = _results_column(
            "### T5 Tokenizer Results",
            "T5 Token Count",
            "T5 Tokens",
            "T5 Token IDs",
        )
        clip_count, clip_highlights, clip_ids = _results_column(
            "### CLIP Tokenizer Results",
            "CLIP Token Count",
            "CLIP Tokens",
            "CLIP Token IDs",
        )

    # Recompute on every edit of the prompt. The output order must match the
    # 6-tuple returned by count_tokens.
    result_components = [t5_count, t5_highlights, t5_ids, clip_count, clip_highlights, clip_ids]
    text_input.change(
        fn=count_tokens,
        inputs=[text_input],
        outputs=result_components,
    )

# Start the server.
iface.launch(show_error=True, ssr_mode=False)