Florence 2 Medieval Zone Object Detection
This is Microsoft's Florence-2 model, fine-tuned for 10 epochs on the CATMuS Medieval Segmentation dataset with a learning rate of 1e-6. This model would not be possible without the numerous annotators behind the various datasets available on HTR-United (see the dataset card for details). Special thanks to Thibault Clérice, who converted the original CATMuS dataset (for HTR) into a segmentation dataset.
Model Details
- Developed by: William J.B. Mattingly
- License: CC-BY 4.0
- Finetuned from model: Florence-2-base-ft
Labels
The following table lists the labels used to train this model, how many instances of each label appear (an image can contain several), and each label's definition, with a link to the original documentation.
How to Get Started with the Model
Use the code below to get started with the model. All models were trained in float16.
import requests
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
import os
from unittest.mock import patch
import requests
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.dynamic_module_utils import get_imports
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# Mac solution => https://huggingface.co/microsoft/Florence-2-large-ft/discussions/4
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72.

    Used as a drop-in replacement for
    ``transformers.dynamic_module_utils.get_imports`` while the model is
    loaded. For Florence-2's remote ``modeling_florence2.py`` it drops
    ``"flash_attn"`` from the reported imports, so loading does not demand
    that package (the linked discussion describes this as the Mac fix).

    Args:
        filename: Path of the remote-code module being inspected.

    Returns:
        The names reported by ``get_imports(filename)``; for
        ``modeling_florence2.py`` every ``"flash_attn"`` entry is removed.
    """
    imports = get_imports(filename)
    # Compare the basename instead of matching the literal suffix
    # "/modeling_florence2.py": on Windows the path uses "\" separators and
    # the suffix check silently never matched. os.path.basename handles the
    # platform's native separators.
    if os.path.basename(os.fspath(filename)) != "modeling_florence2.py":
        return imports
    # Filter rather than list.remove(): remove() raises ValueError when the
    # entry is absent (e.g. a module revision that no longer imports it).
    return [name for name in imports if name != "flash_attn"]
# Hugging Face Hub repository holding the fine-tuned weights and processor.
MODEL_ID = "medieval-data/florence2-medieval-bbox-zone-detection"

# Load both artifacts while transformers' get_imports is swapped for
# fixed_get_imports, so the remote Florence-2 code is inspected through the
# workaround above.
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
def process_image(url, *, task="<OD>", timeout=30):
    """Download an image and run the fine-tuned Florence-2 model on it.

    Args:
        url: HTTP(S) URL of the page image to analyse.
        task: Florence-2 task token, used both as the text prompt and as the
            post-processing task. Defaults to ``"<OD>"``, the object-detection
            task this example uses.
        timeout: Seconds ``requests`` may wait while connecting or reading
            before giving up.

    Returns:
        ``(result, image)``. ``result`` is the value of
        ``processor.post_process_generation`` (the example indexes it by the
        task token, e.g. ``result["<OD>"]``), and ``image`` is the downloaded
        PIL image converted to RGB.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    from io import BytesIO

    response = requests.get(url, timeout=timeout)
    # Surface HTTP failures clearly instead of letting PIL fail on an error page.
    response.raise_for_status()
    # response.content is decoded (e.g. gzip transfer encoding); response.raw
    # is not. Converting to RGB hands the processor a 3-channel image even for
    # grayscale, palette or RGBA inputs.
    image = Image.open(BytesIO(response.content)).convert("RGB")

    inputs = processor(text=task, images=image, return_tensors="pt")
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        do_sample=False,
        num_beams=3,
    )
    # Special tokens are kept, presumably because post_process_generation
    # parses the task/location tokens out of the raw text.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    result = processor.post_process_generation(
        generated_text, task=task, image_size=(image.width, image.height)
    )
    return result, image
# Run the model on one page of the CATMuS medieval-segmentation dataset.
url = "https://huggingface.co/datasets/CATMuS/medieval-segmentation/resolve/main/data/train/cambridge-corpus-christi-college-ms-111/page-002-of-003.jpg"
result, image = process_image(url)
detections = result["<OD>"]

fig, ax = plt.subplots(1, figsize=(15, 15))
ax.imshow(image)
# Add bounding boxes and labels to the plot. Each box is
# (x_min, y_min, x_max, y_max); Rectangle wants a corner plus width/height.
for bbox, label in zip(detections["bboxes"], detections["labels"]):
    x_min, y_min, x_max, y_max = bbox
    rect = patches.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        linewidth=2,
        edgecolor="r",
        facecolor="none",
    )
    ax.add_patch(rect)
    # Draw the label on the same Axes as the box, not pyplot's implicit current Axes.
    ax.text(x_min, y_min, label, fontsize=12, bbox=dict(facecolor="yellow", alpha=0.5))
# Display the plot
plt.show()
- Downloads last month: 15
Inference API (serverless) does not yet support model repos that contain custom code.