import gradio as gr
import PIL.Image
import transformers
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import os
import string
import functools
import re
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np

hf_token = os.getenv("HF_TOKEN")
model_id = "google/paligemma-3b-mix-448"
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# `token=` replaces the deprecated `use_auth_token=` argument. The same token
# is passed to the processor as well, so both downloads are authenticated.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, token=hf_token
).eval().to(device)
processor = PaliGemmaProcessor.from_pretrained(model_id, token=hf_token)


###### Transformers Inference

def infer(
    image: PIL.Image.Image,
    text: str,
    max_new_tokens: int,
) -> str:
    """Run deterministic (do_sample=False) generation for one image + prompt.

    Args:
      image: input image. None (nothing uploaded in the UI) is rejected.
      text: the prompt, e.g. "What is on the flower?" or "segment cats".
      max_new_tokens: generation budget. It is cast to int because the value
        may arrive from a Gradio slider as a float.

    Returns:
      Only the newly generated text, decoded with special tokens skipped. The
      prompt is not included.

    Raises:
      gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    inputs = processor(text=text, images=image, return_tensors="pt").to(device)
    # Strip the prompt by token count rather than by len(text) characters.
    # The decoded prompt does not necessarily round-trip to exactly `text`
    # (the old character slice left a leading "\n" behind).
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=False,
        )
    return processor.decode(generated_ids[0][prompt_len:], skip_special_tokens=True)


##### Parse segmentation output tokens into masks
##### Also returns bounding boxes with their labels

def parse_segmentation(input_image, input_text):
    """Run a detect/segment prompt and format the result for gr.AnnotatedImage.

    Returns:
      `(input_image, annotations)`. Each annotation is `(mask_or_box, label)`:
      the decoded mask when one exists, otherwise the `(x1, y1, x2, y2)` pixel
      box. Plain-text spans from the model output are dropped.
    """
    # `infer` rejects a missing image before `.size` is read below.
    out = infer(input_image, input_text, max_new_tokens=100)
    width, height = input_image.size
    objs = extract_objs(out.lstrip("\n"), width, height, unique_labels=True)
    annotations = [
        (
            obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
            obj['name'] or '',
        )
        for obj in objs
        if 'mask' in obj or 'xyxy' in obj
    ]
    return input_image, annotations


######## Demo

INTRO_TEXT = """## PaliGemma demo

| [Github](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md) | [Blogpost](https://huggingface.co/blog/paligemma) |

PaliGemma is an open vision-language model by Google, inspired by [PaLI-3](https://arxiv.org/abs/2310.09199) and built with open components such as the [SigLIP](https://arxiv.org/abs/2303.15343) vision model and the [Gemma](https://arxiv.org/abs/2403.08295) language model. PaliGemma is designed as a versatile model for transfer to a wide range of vision-language tasks such as image and short video caption, visual question answering, text reading, object detection and object segmentation.

This space includes models fine-tuned on a mix of downstream tasks, **inferred via 🤗 transformers**. See the [Blogpost](https://huggingface.co/blog/paligemma) and [README](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md) for detailed information on how to use and fine-tune PaliGemma models.

**This is an experimental research model.** Make sure to add appropriate guardrails when using the model for applications.
"""

# Shown under both example galleries.
EXAMPLES_LICENSE = (
    "Example images are licensed CC0 by [akolesnikoff@](https://github.com/akolesnikoff), "
    "[mbosnjak@](https://github.com/mbosnjak), [maximneumann@](https://github.com/maximneumann) "
    "and [merve](https://huggingface.co/merve)."
)

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(INTRO_TEXT)

    with gr.Tab("Conversation"):
        with gr.Column():
            image = gr.Image(type="pil")
            text_input = gr.Text(label="Input Text")
            text_output = gr.Text(label="Text Output")
            chat_btn = gr.Button()
            tokens = gr.Slider(
                label="Max New Tokens",
                info="Set to larger for longer generation.",
                minimum=10,
                maximum=100,
                value=20,
                step=10,
            )

            chat_inputs = [image, text_input, tokens]
            chat_outputs = [text_output]
            chat_btn.click(
                fn=infer,
                inputs=chat_inputs,
                outputs=chat_outputs,
            )

            examples = [
                ["./bee.jpg", "What is on the flower?"],
                ["./examples/billard1.jpg", "How many red balls are there?"],
                ["./examples/bowie.jpg", "Who is this?"],
                ["./examples/emu.jpg", "What animal is this?"],
                ["./howto.jpg", "What does this image show?"],
                ["./examples/password.jpg", "What is the password?"],
                ["./examples/ulges.jpg", "Who is the author of this book?"],
            ]
            gr.Markdown(EXAMPLES_LICENSE)
            # Example rows only fill image + prompt; the slider keeps its value.
            gr.Examples(
                examples=examples,
                inputs=[image, text_input],
            )

    with gr.Tab("Segment/Detect"):
        seg_image = gr.Image(type="pil")
        seg_input = gr.Text(label="Entities to Segment/Detect")
        seg_btn = gr.Button("Submit")
        annotated_image = gr.AnnotatedImage(label="Output")

        seg_examples = [
            ["./cats.png", "segment cats"],
            ["./bee.jpg", "detect bee"],
            ["./examples/barsik.jpg", "segment cat"],
            ["./bird.jpg", "segment bird ; bird ; plant"],
        ]
        gr.Markdown(EXAMPLES_LICENSE)

        seg_inputs = [seg_image, seg_input]
        seg_outputs = [annotated_image]
        gr.Examples(
            examples=seg_examples,
            inputs=seg_inputs,
        )
        seg_btn.click(
            fn=parse_segmentation,
            inputs=seg_inputs,
            outputs=seg_outputs,
        )
### Postprocessing Utils for Segmentation Tokens
### Segmentation tokens are passed to a separate VAE decoder, which turns them into a mask

_MODEL_PATH = 'vae-oid.npz'

# One object in the model output:
# - optional leading free text;
# - four <locNNNN> box tokens (read as y1, x1, y2, x2 and divided by 1024);
# - an optional run of sixteen <segNNN> mask-codebook tokens;
# - an optional label and an optional "; " separator.
# Groups: 1 (prefix) + 4 (loc) + 16 (seg) + 1 (label) = 22.
_SEGMENT_DETECT_RE = re.compile(
    r'(.*?)'
    + r'<loc(\d{4})>' * 4
    + r'\s*'
    + '(?:%s)?' % (r'<seg(\d{3})>' * 16)
    + r'\s*([^;<>]+)? ?(?:; )?',
)


def _get_params(checkpoint):
    """Converts a PyTorch VAE checkpoint (name -> array) to Flax params.

    Only the quantizer embeddings and the decoder layers are kept. Conv kernels
    are permuted with axes (2, 3, 1, 0), moving PyTorch's channel-first
    (C0, C1, H, W) layout to Flax's spatial-first (H, W, C1, C0).
    """

    def transp(kernel):
        return np.transpose(kernel, (2, 3, 1, 0))

    def conv(name):
        return {
            'bias': checkpoint[name + '.bias'],
            'kernel': transp(checkpoint[name + '.weight']),
        }

    def resblock(name):
        return {
            'Conv_0': conv(name + '.0'),
            'Conv_1': conv(name + '.2'),
            'Conv_2': conv(name + '.4'),
        }

    return {
        '_embeddings': checkpoint['_vq_vae._embedding'],
        'Conv_0': conv('decoder.0'),
        'ResBlock_0': resblock('decoder.2.net'),
        'ResBlock_1': resblock('decoder.3.net'),
        'ConvTranspose_0': conv('decoder.4'),
        'ConvTranspose_1': conv('decoder.6'),
        'ConvTranspose_2': conv('decoder.8'),
        'ConvTranspose_3': conv('decoder.10'),
        'Conv_1': conv('decoder.12'),
    }


def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
    """Looks up codebook vectors and arranges them as a 4x4 latent grid.

    Args:
      codebook_indices: integer array of shape `[B, 16]`.
      embeddings: codebook of shape `[num_embeddings, embedding_dim]`.

    Returns:
      Array of shape `[B, 4, 4, embedding_dim]`.
    """
    batch_size, num_tokens = codebook_indices.shape
    assert num_tokens == 16, codebook_indices.shape
    _, embedding_dim = embeddings.shape

    encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
    return encodings.reshape((batch_size, 4, 4, embedding_dim))


@functools.cache
def _get_reconstruct_masks():
    """Builds (once) a jitted CPU function that decodes mask tokens.

    The VAE weights are read from `_MODEL_PATH` on the first call only, because
    the result is cached.

    Returns:
      A function that expects indices shaped `[B, 16]` of dtype int32, each in
      0..127 (inclusive), and returns decoded masks sized `[B, 64, 64, 1]` of
      dtype float32. Values are nominally in [-1, 1]. The decoder's last layer
      has no activation, so callers should clip.
    """
    # Load weights before defining the closure that reads them.
    with open(_MODEL_PATH, 'rb') as f:
        params = _get_params(dict(np.load(f)))

    class ResBlock(nn.Module):
        features: int

        @nn.compact
        def __call__(self, x):
            original_x = x
            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
            x = nn.relu(x)
            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
            x = nn.relu(x)
            x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
            return x + original_x

    class Decoder(nn.Module):
        """Upscales quantized vectors to mask."""

        @nn.compact
        def __call__(self, x):
            num_res_blocks = 2
            dim = 128
            num_upsample_layers = 4

            x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
            x = nn.relu(x)

            for _ in range(num_res_blocks):
                x = ResBlock(features=dim)(x)

            # Four stride-2 upsamplings: the 4x4 latent grid becomes 64x64.
            for _ in range(num_upsample_layers):
                x = nn.ConvTranspose(
                    features=dim,
                    kernel_size=(4, 4),
                    strides=(2, 2),
                    padding=2,
                    transpose_kernel=True,
                )(x)
                x = nn.relu(x)
                dim //= 2

            x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
            return x

    def reconstruct_masks(codebook_indices):
        quantized = _quantized_values_from_codebook_indices(
            codebook_indices, params['_embeddings']
        )
        return Decoder().apply({'params': params}, quantized)

    return jax.jit(reconstruct_masks, backend='cpu')


def extract_objs(text, width, height, unique_labels=False):
    """Returns objs for a string with "<loc>" and "<seg>" tokens.

    Splits decoded PaliGemma output into plain-text spans and located objects.

    Args:
      text: decoded model output, e.g. "<loc0256><loc0512><loc0768><loc1023> cat".
      width: width in pixels of the image the coordinates refer to.
      height: height in pixels of the image the coordinates refer to.
      unique_labels: if True, a repeated name gets "'" appended until it is
        unique.

    Returns:
      A list of dicts, in text order:
      - plain-text spans are `{'content': str}`;
      - each object is `{'content', 'xyxy', 'mask', 'name'}`, where:
        - `xyxy` is an `(x1, y1, x2, y2)` tuple of rounded pixel coordinates;
        - `mask` is a `[height, width]` float array in [0, 1], or None when the
          object had no `<seg>` tokens;
        - `name` is the whitespace-stripped label, or None.
      The concatenated `content` fields reproduce `text`.
    """
    objs = []
    seen = set()
    while text:
        m = _SEGMENT_DETECT_RE.match(text)
        if not m:
            break

        gs = list(m.groups())
        before = gs.pop(0)
        name = gs.pop()
        if name is not None:
            # The label group greedily includes the space before " ; ".
            name = name.strip() or None

        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
        y1, x1, y2, x2 = map(round, (y1 * height, x1 * width, y2 * height, x2 * width))

        seg_indices = gs[4:20]
        if seg_indices[0] is None:
            mask = None
        else:
            seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
            m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
            m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
            m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
            mask = np.zeros([height, width])
            # Paste the 64x64 mask, resized to the box, into a full-size canvas.
            if y2 > y1 and x2 > x1:
                mask[y1:y2, x1:x2] = np.array(m64.resize((x2 - x1, y2 - y1))) / 255.0

        content = m.group()
        if before:
            objs.append(dict(content=before))
            content = content[len(before):]
        while unique_labels and name in seen:
            name = (name or '') + "'"
        seen.add(name)
        objs.append(dict(
            content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
        # The match is anchored at 0 and always consumes four <loc> tokens,
        # so this always makes progress.
        text = text[m.end():]

    if text:
        objs.append(dict(content=text))
    return objs


#########

if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)