import argparse

from PIL import Image
from huggingface_hub import hf_hub_download
from ultralytics import YOLO

from models.tools.draw import add_bboxes


class YoloModel:
    """Wrapper around an Ultralytics ``YOLO`` model whose weights are fetched from the Hugging Face Hub."""

    def __init__(self, repo_name: str, file_name: str):
        """Download the weight file (via ``hf_hub_download``) and load it into a YOLO model.

        :param repo_name: Hugging Face Hub repository id, e.g. ``"SHOU-ISD/fire-and-smoke"``.
        :param file_name: weight file inside that repository, e.g. ``"yolov8n.pt"``.
        """
        weight_file = YoloModel.download_weight_file(repo_name, file_name)
        self.model = YOLO(weight_file)

    @staticmethod
    def download_weight_file(repo_name: str, file_name: str) -> str:
        """Fetch ``file_name`` from ``repo_name`` with ``hf_hub_download`` and return its local path."""
        return hf_hub_download(repo_name, file_name)

    def detect(self, im):
        """Run the model on ``im`` and return whatever the YOLO call returns.

        Callers in this file iterate the return value as per-image results.
        """
        return self.model(source=im)

    def preview_detect(self, filename, confidence):
        """Detect objects in the image at ``filename`` and return it with bounding boxes drawn.

        :param filename: path of the image file to open.
        :param confidence: forwarded to ``add_bboxes``; presumably the minimum score a box
            needs in order to be drawn (NOTE(review): confirm against ``add_bboxes``).
        :return: the image returned by the last ``add_bboxes`` call, or the loaded image
            itself when the model produced no results.
        """
        # Copy the pixel data out so the file handle is released immediately instead of
        # staying open until the lazily-loaded Image object is garbage-collected.
        with Image.open(filename) as src:
            image = src.copy()
        res_img = image
        for result in self.detect(image):
            res_img = add_bboxes(res_img, result, confidence)
        return res_img


def test():
    """Smoke test: run the fire-and-smoke model on a sample image and print each result's boxes."""
    model = YoloModel("SHOU-ISD/fire-and-smoke", "yolov8n.pt")
    with Image.open("./tests/fire1.jpg") as im:
        annotated = im
        for result in model.detect(im):
            annotated = add_bboxes(annotated, result, confidence=0.1)
            print(result.boxes)


def argument_parser(argv=None):
    """Parse command-line arguments.

    ``--weight_files`` is required unless ``--test`` is given; otherwise the process
    exits with a usage error (via ``parser.error``).

    :param argv: argument list to parse; ``None`` (the default) means ``sys.argv[1:]``.
    :return: ``argparse.Namespace`` with ``test`` (bool) and ``weight_files``
        (list of ``"repo_name:file_name"`` strings, or ``None`` when ``--test`` is used).
    """
    parser = argparse.ArgumentParser(description='Help for YoloModel')
    parser.add_argument('--test', '-t', action='store_true', help='Run test')
    parser.add_argument(
        '--weight_files', '-w', nargs='+', metavar='REPO_NAME:FILE_NAME',
        help='List of weight files to pre-cache, each given as repo_name:file_name',
    )
    args = parser.parse_args(argv)
    # Without this check, omitting -w crashes later with
    # "TypeError: 'NoneType' object is not iterable" in pre_cache_weight_files.
    if not args.test and not args.weight_files:
        parser.error('--weight_files is required unless --test is given')
    return args


def pre_cache_weight_files(weight_files: list[str]):
    """Download each weight file into the local Hugging Face Hub cache.

    Every spec is validated before any download starts, so a typo in the last entry
    does not waste time downloading the earlier ones.

    :param weight_files: entries of the form ``"repo_name:file_name"``; everything after
        the first ``:`` is taken as the file name.
    :raises ValueError: if an entry lacks the ``:`` separator or either side is empty.
    :return: None
    """
    pairs = []
    for spec in weight_files:
        repo_name, sep, file_name = spec.partition(":")
        if not sep or not repo_name or not file_name:
            raise ValueError(
                f"Invalid weight file spec {spec!r}; expected 'repo_name:file_name'"
            )
        pairs.append((repo_name, file_name))
    for repo_name, file_name in pairs:
        YoloModel.download_weight_file(repo_name, file_name)


if __name__ == '__main__':
    args = argument_parser()
    if args.test:
        test()
    else:
        pre_cache_weight_files(args.weight_files)