# IDEA-Bench-Arena / model /model_manager.py
# JasiLiang's picture
# initial commit
# 62d106f verified
# raw
# history blame
# 9.16 kB
import os
import csv
import json
# import concurrent.futures
import random
# import gradio as gr
# import requests
import io, base64, json
#import spaces
from PIL import Image
# from .models import IMAGE_GENERATION_MODELS, load_pipeline
# from .fetch_museum_results import draw_from_imagen_museum, draw2_from_imagen_museum
# import time
# import threading
from . import CASE_NAMES, MODEL_INFO_CSV, DATASET_PATH, OUTPUT_PATH
from typing import Optional, List
from datetime import datetime
import pandas as pd
class Model:
    """Metadata for one image-generation model, plus access to its stored results."""

    def __init__(
        self,
        name: str,
        upload_date: str,
        description: str,
        parameter_count: str,
        creator: str,
        result_path: str,
        license: str,
        link: Optional[str] = None,
    ):
        """
        Initializes the Model object. Every field is stored exactly as passed in;
        no parsing or type conversion is performed (``upload_date`` stays a string).

        :param name: Name of the model
        :param upload_date: Upload date (string format)
        :param description: Model description
        :param parameter_count: Number of parameters of the model (kept as a string)
        :param creator: Creator of the model
        :param result_path: Path of the CSV listing this model's results; it is
            joined onto OUTPUT_PATH when results are read
        :param license: License of the model
        :param link: Link to the model (if it's open source)
        """
        self.name = name
        self.upload_date = upload_date
        self.description = description
        self.parameter_count = parameter_count
        self.creator = creator
        self.result_path = result_path
        self.license = license
        self.link = link

    def __repr__(self):
        return f"Model(name={self.name}, upload_date={self.upload_date}, description={self.description}, parameter_count={self.parameter_count}, creator={self.creator}, result_path={self.result_path}, license={self.license}, link={self.link})"

    def get_result(self, case_name):
        """
        Return the result URLs this model produced for ``case_name``.

        The CSV at ``os.path.join(OUTPUT_PATH, self.result_path)`` is read and must
        contain ``name`` and ``pc_url`` columns. Rows whose ``name`` starts with
        ``case_name`` are selected and sorted by ``name``, and their ``pc_url`` values
        are returned.

        :param case_name: Case identifier, used as a prefix filter on ``name``
        :return: List of ``pc_url`` values (empty if no row matches)
        """
        csv_file = os.path.join(OUTPUT_PATH, self.result_path)
        df = pd.read_csv(csv_file)
        # na=False: an empty ``name`` cell is read as NaN; without this the mask would
        # contain NaN and boolean indexing would raise instead of skipping the row.
        # NOTE(review): this is a plain prefix match, so "case1" also matches
        # "case10..."; confirm the results-CSV naming scheme rules that out.
        mask = df["name"].str.startswith(case_name, na=False)
        return df.loc[mask].sort_values(by="name")["pc_url"].tolist()
class ModelManager:
    """Registry of the arena's models plus helpers for drawing random comparison cases."""

    def __init__(self):
        """Create the manager and populate ``model_list`` from ``MODEL_INFO_CSV``."""
        self.model_list: List["Model"] = []
        self.load_models_from_csv(MODEL_INFO_CSV)

    def load_models_from_csv(self, csv_file: str):
        """
        Loads model data from a CSV file and appends Model instances to ``model_list``.

        The first row is treated as a header and skipped. Every other row must have
        exactly these 8 columns, in this order:
            name, upload_date, description, parameter_count, creator, result_path, license, link
        All values are passed through as strings; an empty ``link`` becomes None.
        Blank rows are ignored; rows with any other column count are reported with
        ``print`` and skipped. A missing file or any other error while reading is
        reported with ``print`` instead of being raised; models parsed before the
        error stay in ``model_list``.

        :param csv_file: Path to the CSV file containing model information
        """
        try:
            with open(csv_file, 'r', newline='', encoding='utf-8') as file:
                csv_reader = csv.reader(file)
                next(csv_reader, None)  # Skip the header; tolerates an empty file
                for row in csv_reader:
                    if not row:
                        continue  # csv.reader yields [] for blank lines
                    if len(row) != 8:
                        print(f"Warning: skipping malformed row {csv_reader.line_num} in {csv_file}: expected 8 columns, got {len(row)}.")
                        continue
                    name, upload_date, description, parameter_count, creator, result_path, license_, link = row
                    self.model_list.append(
                        Model(
                            name=name,
                            upload_date=upload_date,
                            description=description,
                            parameter_count=parameter_count,
                            creator=creator,
                            result_path=result_path,
                            license=license_,
                            link=link if link else None,
                        )
                    )
        except FileNotFoundError:
            print(f"Error: The file {csv_file} was not found.")
        except Exception as e:
            print(f"An error occurred while loading the CSV file: {e}")

    def choose_case_randomly(self):
        """
        Pick a random case name from ``CASE_NAMES`` and load its ``meta.json``.

        :return: ``(case_name, case_info)`` where ``case_info`` is the parsed JSON
        """
        random_case = random.choice(CASE_NAMES)
        # Resolve against DATASET_PATH (as the other case loaders do) rather than a
        # hard-coded "dataset" directory relative to the current working directory.
        case_meta_path = os.path.join(DATASET_PATH, random_case, "meta.json")
        with open(case_meta_path, 'r', encoding='utf-8') as file:
            case_info = json.load(file)
        return random_case, case_info

    def get_model_from_name(self, model_name: str) -> Optional["Model"]:
        """
        Given the model name, this function retrieves the corresponding Model object from the model list.

        :param model_name: The name of the model to find
        :return: The first Model whose ``name`` equals ``model_name``, or None if not found
        """
        return next((model for model in self.model_list if model.name == model_name), None)

    def get_name_list(self):
        """Return the names of all loaded models, in ``model_list`` order."""
        return [model.name for model in self.model_list]

    @staticmethod
    def _md_cell(value) -> str:
        """Make ``value`` safe to place inside a single markdown table cell."""
        return str(value).replace("|", "\\|").replace("\r", " ").replace("\n", " ")

    def get_model_info_md(self):
        """
        Render ``model_list`` as a markdown table (name, description, creator, upload time).

        ``upload_date`` values in ``%Y.%m.%d.%H.%M.%S`` form are shown as
        ``YYYY-MM-DD HH:MM``; any other value is shown unchanged. The name links to
        ``model.link`` when one is set and is plain text otherwise. Pipes and line
        breaks inside cells are escaped so they cannot break the table layout.
        """
        parts = [
            "\n"
            "| name | description | creator | upload time |\n"
            "| ---- | ---- | ---- | ---- |\n"
        ]
        for model in self.model_list:
            try:
                upload_date = datetime.strptime(model.upload_date, "%Y.%m.%d.%H.%M.%S")
                formatted_date = upload_date.strftime("%Y-%m-%d %H:%M")
            except ValueError:
                formatted_date = model.upload_date  # If parsing fails, keep the original date
            name_cell = self._md_cell(model.name)
            if model.link:
                name_cell = f"[{name_cell}]({model.link})"
            parts.append(
                f"| {name_cell} | {self._md_cell(model.description)} | "
                f"{self._md_cell(model.creator)} | {self._md_cell(formatted_date)} |\n"
            )
        return "".join(parts)

    def _load_case_inputs(self, case_name):
        """Read ``(prompt, input_images)`` for ``case_name`` from its folder under DATASET_PATH."""
        case_folder = os.path.join(DATASET_PATH, case_name)
        images_txt_path = os.path.join(case_folder, "images.txt")
        input_images = []
        # images.txt is optional; each non-empty (stripped) line is one input image URL.
        if os.path.exists(images_txt_path):
            with open(images_txt_path, 'r', encoding='utf-8') as file:
                input_images = [line.strip() for line in file if line.strip()]
        instruction_path = os.path.join(case_folder, "instruction.txt")
        with open(instruction_path, 'r', encoding='utf-8') as file:
            prompt = file.read()
        return prompt, input_images

    def get_result_of_random_case_anony(self):
        """
        Pick a random case and two different random entries of ``model_list``,
        and return their results for that case.

        :return: ``(model_A, model_B, prompt, input_images, output_images_A, output_images_B)``
        :raises ValueError: if fewer than two models are loaded (from ``random.sample``)
        """
        case_name, _case_info = self.choose_case_randomly()
        prompt, input_images = self._load_case_inputs(case_name)
        model_A, model_B = random.sample(self.model_list, 2)
        output_images_A = model_A.get_result(case_name)
        output_images_B = model_B.get_result(case_name)
        return model_A, model_B, prompt, input_images, output_images_A, output_images_B

    def get_result_of_random_case(self, model_name_A, model_name_B):
        """
        Return the results of the two named models for a randomly chosen case.

        :return: ``(model_A, model_B, prompt, input_images, output_images_A, output_images_B)``
        :raises ValueError: if either model name is not present in ``model_list``
        """
        # Resolve the models first so an unknown name fails clearly, before any file I/O.
        model_A = self.get_model_from_name(model_name_A)
        model_B = self.get_model_from_name(model_name_B)
        missing = [
            name
            for name, model in ((model_name_A, model_A), (model_name_B, model_B))
            if model is None
        ]
        if missing:
            raise ValueError(f"Unknown model name(s): {', '.join(map(str, missing))}")
        case_name, _case_info = self.choose_case_randomly()
        prompt, input_images = self._load_case_inputs(case_name)
        output_images_A = model_A.get_result(case_name)
        output_images_B = model_B.get_result(case_name)
        return model_A, model_B, prompt, input_images, output_images_A, output_images_B