File size: 1,081 Bytes
8afa9a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from transformers import AutoModel, AutoTokenizer
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
import os

class MinerUModelLoader:
    """Namespace for loading the model set stored beneath one directory."""

    @staticmethod
    def load_models(base_path):
        """Load every model found under ``base_path`` and return them.

        Args:
            base_path: Directory containing the ``models/`` tree. It is
                combined with each relative model path via
                ``os.path.join``.

        Returns:
            dict: The loaded objects keyed by ``"layout"``,
            ``"formula_detector"``, ``"formula_recognizer"`` and
            ``"table_recognizer"`` (inserted in that order).
        """

        def under_base(relative):
            # Resolve a path relative to the caller-supplied base directory.
            return os.path.join(base_path, relative)

        # Layout model: start from a fresh detectron2 config, merge the
        # on-disk config file into it, then point it at the weights file.
        # NOTE(review): the merged file has a .json extension -- confirm that
        # detectron2's config loader accepts this file as intended.
        layout_cfg = get_cfg()
        layout_cfg.merge_from_file(under_base("models/Layout/config.json"))
        layout_cfg.MODEL.WEIGHTS = under_base("models/Layout/model_final.pth")

        # The dict literal is evaluated top to bottom, so the models are
        # loaded in the order listed here.
        return {
            "layout": DefaultPredictor(layout_cfg),
            # Formula detection model.
            # NOTE(review): torch.load relies on pickle deserialization; only
            # load checkpoints that come from a trusted source.
            "formula_detector": torch.load(
                under_base("models/MFD/weights.pt")
            ),
            # Formula recognition model.
            "formula_recognizer": AutoModel.from_pretrained(
                under_base("models/MFR/UniMERNet")
            ),
            # Table recognition model.
            "table_recognizer": AutoModel.from_pretrained(
                under_base("models/TabRec/StructEqTable")
            ),
        }