# Author: takarajordan
# Last commit: a6ac1a4 "spelling" (verified)
import gradio as gr
from transformers import T5TokenizerFast, CLIPTokenizer
# Both tokenizers are created a single time, at import; count_tokens() then
# reuses these module-level instances on every call instead of reloading them.
t5_tokenizer = T5TokenizerFast.from_pretrained(
    "google/t5-v1_1-xxl",
    legacy=False,
)
clip_tokenizer = CLIPTokenizer.from_pretrained(
    "openai/clip-vit-large-patch14",
)
# Visible stand-ins for decoded tokens that would otherwise render as nothing.
_WHITESPACE_LABEL = "␣"
_EMPTY_LABEL = "∅"
# T5 token ID 3 decodes to "" on its own; show it as "▁" rather than the
# generic empty placeholder (presumably the bare SentencePiece word-boundary
# piece -- confirm against the vocab if the model changes).
_T5_EMPTY_TOKEN_LABELS = {3: "▁"}


def _display_label(decoded, token_id, empty_overrides):
    """Return a visible label for one individually decoded token.

    Whitespace-only text becomes "␣"; empty text becomes
    ``empty_overrides[token_id]`` when present, otherwise "∅"; any other text
    is returned unchanged.
    """
    if decoded.isspace():
        return _WHITESPACE_LABEL
    if decoded == "":
        return empty_overrides.get(token_id, _EMPTY_LABEL)
    return decoded


def _decode_for_display(tokenizer, token_ids, empty_overrides=None):
    """Decode each ID on its own (special tokens kept) into a display label."""
    overrides = empty_overrides or {}
    return [
        _display_label(
            # skip_special_tokens=False so markers like BOS/EOS stay visible.
            tokenizer.decode([token_id], skip_special_tokens=False),
            token_id,
            overrides,
        )
        for token_id in token_ids
    ]


def count_tokens(text):
    """Tokenize *text* with the T5 and CLIP tokenizers for display.

    Args:
        text: The prompt to tokenize. ``None`` is treated as an empty string.

    Returns:
        A 6-tuple in the order of the Gradio outputs wired to this function:
        (T5 token count, T5 ``(token, "")`` highlight pairs, T5 IDs as a
        string, CLIP token count, CLIP highlight pairs, CLIP IDs as a string).
        Counts include the special tokens each tokenizer adds.
    """
    # Tokenizers require a str; guard against a None component value.
    if text is None:
        text = ""

    # encode() already returns a plain list[int]; requesting a PyTorch tensor
    # only to convert it back made torch a hard requirement for no benefit.
    t5_tokens = t5_tokenizer.encode(text, add_special_tokens=True)
    clip_tokens = clip_tokenizer.encode(text, add_special_tokens=True)

    t5_decoded = _decode_for_display(t5_tokenizer, t5_tokens, _T5_EMPTY_TOKEN_LABELS)
    clip_decoded = _decode_for_display(clip_tokenizer, clip_tokens)

    # HighlightedText takes (text, label) pairs; every label here is "".
    t5_highlights = [(token, "") for token in t5_decoded]
    clip_highlights = [(token, "") for token in clip_decoded]

    return (
        # T5 outputs
        len(t5_tokens),
        t5_highlights,
        str(t5_tokens),
        # CLIP outputs
        len(clip_tokens),
        clip_highlights,
        str(clip_tokens),
    )
def _results_column(heading, count_label, tokens_label, ids_label):
    """Lay out one tokenizer's results as a new column in the current row.

    Returns the (count, highlighted tokens, token IDs) output components, in
    the same order that count_tokens() emits values for one tokenizer.
    """
    with gr.Column():
        gr.Markdown(heading)
        count_box = gr.Number(label=count_label)
        tokens_box = gr.HighlightedText(label=tokens_label, show_legend=True)
        ids_box = gr.Textbox(label=ids_label, lines=2)
    return count_box, tokens_box, ids_box


# Page layout: prompt box on top, then T5 results (left) beside CLIP results (right).
with gr.Blocks(title="DiffusionTokenizer") as iface:
    gr.Markdown("# DiffusionTokenizer🔢")
    gr.Markdown("A lightning fast visualization of the tokens used in diffusion models. Use it to understand how your prompt is tokenized.")

    with gr.Row():
        text_input = gr.Textbox(label="Diffusion Prompt", placeholder="Enter your prompt here...")

    with gr.Row():
        t5_count, t5_highlights, t5_ids = _results_column(
            "### T5 Tokenizer Results",
            "T5 Token Count",
            "T5 Tokens",
            "T5 Token IDs",
        )
        clip_count, clip_highlights, clip_ids = _results_column(
            "### CLIP Tokenizer Results",
            "CLIP Token Count",
            "CLIP Tokens",
            "CLIP Token IDs",
        )

    # Re-run tokenization and refresh all six outputs whenever the prompt changes.
    text_input.change(
        fn=count_tokens,
        inputs=[text_input],
        outputs=[t5_count, t5_highlights, t5_ids, clip_count, clip_highlights, clip_ids],
    )
# Start the server only when this file is executed as a script
# (e.g. `python app.py`), so importing the module builds the UI without
# launching it.
if __name__ == "__main__":
    iface.launch(show_error=True, ssr_mode=False)