File size: 1,081 Bytes
8afa9a1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import torch
from transformers import AutoModel, AutoTokenizer
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
import os
class MinerUModelLoader:
    """Load the layout, formula and table models from a local model directory."""

    @staticmethod
    def load_models(base_path, device=None):
        """Load every model under ``<base_path>/models`` and return them by role.

        Args:
            base_path: Root directory containing the ``models/`` tree
                (``Layout/``, ``MFD/``, ``MFR/UniMERNet``,
                ``TabRec/StructEqTable``).
            device: Optional torch device string such as ``"cpu"`` or
                ``"cuda:0"``. When ``None`` (the default), each loader keeps
                its own default device, matching the previous behaviour.
                When given, the following all use that device:
                the Detectron2 predictor is built on it,
                the formula-detector checkpoint is mapped onto it, and
                the Hugging Face models are moved to it.

        Returns:
            A dict with these keys, in insertion order:

            - ``"layout"``: a Detectron2 ``DefaultPredictor``.
            - ``"formula_detector"``: whatever ``torch.load`` returns for
              ``MFD/weights.pt``.
            - ``"formula_recognizer"`` and ``"table_recognizer"``:
              ``AutoModel`` instances.
        """
        models_dir = os.path.join(base_path, "models")
        models = {}

        # Layout model (Detectron2).
        cfg = get_cfg()
        # NOTE(review): merge_from_file parses the file as YAML. A .json config
        # works only insofar as its content is also valid YAML -- confirm this
        # file is intended.
        cfg.merge_from_file(os.path.join(models_dir, "Layout", "config.json"))
        cfg.MODEL.WEIGHTS = os.path.join(models_dir, "Layout", "model_final.pth")
        if device is not None:
            # Detectron2 defaults MODEL.DEVICE to "cuda". Override it so the
            # predictor can also be built on CPU-only hosts.
            cfg.MODEL.DEVICE = device
        models["layout"] = DefaultPredictor(cfg)

        # Formula detection model.
        # NOTE(review): torch.load unpickles the file and can execute arbitrary
        # code, so only load trusted weights. Since PyTorch 2.6 the default is
        # weights_only=True, which rejects pickled nn.Module objects -- confirm
        # what weights.pt contains.
        # map_location=None is torch.load's own default, so device=None keeps
        # the original behaviour.
        models["formula_detector"] = torch.load(
            os.path.join(models_dir, "MFD", "weights.pt"),
            map_location=device,
        )

        # Formula recognition model, then table recognition model (Hugging Face).
        hf_models = (
            ("formula_recognizer", ("MFR", "UniMERNet")),
            ("table_recognizer", ("TabRec", "StructEqTable")),
        )
        for key, subdir in hf_models:
            model = AutoModel.from_pretrained(os.path.join(models_dir, *subdir))
            if device is not None:
                model = model.to(device)
            models[key] = model

        return models