from datasets import load_dataset
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
# Load dataset
dataset = load_dataset("nielsr/funsd")
# Load pre-trained model and processor
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
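# The processor bundles the image processor (resize + normalize) and the
# decoder tokenizer. "microsoft/trocr-base-handwritten" is an alternative
# starting checkpoint if the target text is handwritten rather than printed.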
# Preprocess the dataset
def preprocess_images(examples):
    # Open each page image and run it through the TrOCR image processor,
    # which resizes and normalizes to the model's expected input.
    images = [Image.open(path).convert("RGB") for path in examples["image_path"]]
    pixel_values = processor(images=images, return_tensors="pt").pixel_values
    return {"pixel_values": pixel_values}
encoded_dataset = dataset.map(preprocess_images, batched=True)
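# Optional sanity check: each example should now carry a 3 x 384 x 384
# pixel_values array for the base TrOCR checkpoint (stored as nested lists).
print(len(encoded_dataset["train"][0]["pixel_values"]),     # channels -> 3
      len(encoded_dataset["train"][0]["pixel_values"][0]))  # height   -> 384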
# Preprocess the labels
max_length = 64
def preprocess_labels(examples):
    tokenized = processor.tokenizer(examples["words"], is_split_into_words=True, padding="max_length", max_length=max_length, truncation=True)
    # The trainer reads targets from a "labels" column; mask padding ids with
    # -100 so they are ignored by the cross-entropy loss.
    labels = [[tok if tok != processor.tokenizer.pad_token_id else -100 for tok in seq] for seq in tokenized.input_ids]
    return {"labels": labels}
encoded_dataset = encoded_dataset.map(preprocess_labels, batched=True)
# Prepare for training
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
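# The generation defaults below mirror common VisionEncoderDecoder fine-tuning
# recipes (an assumption on top of the original script, not required by it).
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.vocab_size = model.config.decoder.vocab_size
model.config.max_length = max_length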
# Define training arguments
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
output_dir="./trocr-finetuned-funsd",
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
learning_rate=5e-5,
num_train_epochs=3,
weight_decay=0.01,
logging_dir="./trocr-finetuned-funsd/logs",
logging_steps=10,
evaluation_strategy="epoch",
save_strategy="epoch",
)
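# Hedged notes: newer transformers releases spell this argument eval_strategy;
# setting predict_with_generate=True (with a compute_metrics function, e.g.
# CER from the `evaluate` library) would score generated text during
# evaluation instead of only reporting the loss.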
from transformers import default_data_collator

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
    tokenizer=processor.tokenizer,
    # Passing only a tokenizer makes the Trainer fall back to
    # DataCollatorWithPadding, which does not handle pixel_values;
    # default_data_collator simply stacks the already-padded features.
    data_collator=default_data_collator,
)
# Train the model
trainer.train()
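# Persist the fine-tuned weights and the processor so the checkpoint can be
# reloaded with from_pretrained (paths below reuse the output_dir above).
trainer.save_model("./trocr-finetuned-funsd")
processor.save_pretrained("./trocr-finetuned-funsd")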