# Spaces: Running
import os | |
import csv | |
import json | |
# import concurrent.futures | |
import random | |
# import gradio as gr | |
# import requests | |
import io, base64, json | |
#import spaces | |
from PIL import Image | |
# from .models import IMAGE_GENERATION_MODELS, load_pipeline | |
# from .fetch_museum_results import draw_from_imagen_museum, draw2_from_imagen_museum | |
# import time | |
# import threading | |
from . import CASE_NAMES, MODEL_INFO_CSV, DATASET_PATH, OUTPUT_PATH | |
from typing import Optional, List | |
from datetime import datetime | |
import pandas as pd | |
class Model:
    """Metadata of one model plus access to the results stored for it."""

    def __init__(self, name: str, upload_date: str, description: str, parameter_count: str, creator: str, result_path: str, license: str, link: Optional[str] = None):
        """
        Initializes the Model object. Every argument is stored exactly as given;
        in particular upload_date is kept as a string and is not parsed here.

        :param name: Name of the model
        :param upload_date: Upload date (string format)
        :param description: Model description
        :param parameter_count: Number of parameters of the model (kept as a string)
        :param creator: Creator of the model
        :param result_path: Path of the result CSV file, relative to OUTPUT_PATH
        :param license: License of the model
        :param link: Link to the model (if it's open source)
        """
        self.name = name
        self.upload_date = upload_date
        self.description = description
        self.parameter_count = parameter_count
        self.creator = creator
        self.result_path = result_path
        self.license = license
        self.link = link

    def __repr__(self):
        return f"Model(name={self.name}, upload_date={self.upload_date}, description={self.description}, parameter_count={self.parameter_count}, creator={self.creator}, result_path={self.result_path}, license={self.license}, link={self.link})"

    def get_result(self, case_name: str) -> list:
        """
        Returns the 'pc_url' values recorded for a case in this model's result CSV.

        The CSV at OUTPUT_PATH/result_path is read on every call. Rows whose
        'name' column starts with case_name are selected and sorted by 'name'.

        :param case_name: Prefix that the 'name' column of a row must start with
        :return: The 'pc_url' values of the matching rows, ordered by 'name'
        """
        csv_file = os.path.join(OUTPUT_PATH, self.result_path)
        df = pd.read_csv(csv_file)
        # na=False: a row with an empty 'name' would otherwise put NaN into the
        # mask, and a mask containing NaN cannot be used for boolean indexing.
        is_match = df['name'].str.startswith(case_name, na=False)
        sorted_matching_rows = df[is_match].sort_values(by='name')
        return sorted_matching_rows['pc_url'].tolist()
class ModelManager:
    """Holds the Model objects listed in MODEL_INFO_CSV and serves random cases with their results."""

    def __init__(self):
        # Start empty; load_models_from_csv appends to this list.
        self.model_list: List[Model] = []
        # Load model data from the provided CSV file
        self.load_models_from_csv(MODEL_INFO_CSV)

    def load_models_from_csv(self, csv_file: str):
        """
        Loads model data from a CSV file and appends one Model per valid row to model_list.

        The first row is treated as a header and skipped. Every following row
        must have exactly these 8 columns, in this order:
        name, upload_date, description, parameter_count, creator, result_path, license, link

        All values are kept as strings; an empty link is stored as None. Blank
        lines are ignored and rows with any other column count are skipped with
        a warning. Loading is best-effort: errors are printed, not raised.

        :param csv_file: Path to the CSV file containing model information
        """
        try:
            with open(csv_file, 'r', newline='', encoding='utf-8') as file:
                csv_reader = csv.reader(file)
                # A default of None keeps an empty file from raising StopIteration.
                if next(csv_reader, None) is None:
                    print(f"Error: The file {csv_file} is empty.")
                    return
                for row in csv_reader:
                    if not row:  # csv.reader yields [] for blank lines
                        continue
                    if len(row) != 8:
                        print(f"Warning: skipping line {csv_reader.line_num} of {csv_file}: expected 8 columns, got {len(row)}.")
                        continue
                    name, upload_date, description, parameter_count, creator, result_path, license, link = row
                    self.model_list.append(
                        Model(
                            name=name,
                            upload_date=upload_date,
                            description=description,
                            parameter_count=parameter_count,
                            creator=creator,
                            result_path=result_path,
                            license=license,
                            link=link if link else None,
                        )
                    )
        except FileNotFoundError:
            print(f"Error: The file {csv_file} was not found.")
        except Exception as e:
            print(f"An error occurred while loading the CSV file: {e}")

    def choose_case_randomly(self):
        """
        Picks a random entry of CASE_NAMES and loads its meta.json.

        :return: Tuple (case_name, case_info), where case_info is the parsed content of meta.json
        """
        random_case = random.choice(CASE_NAMES)
        # Resolved against DATASET_PATH, like every other per-case file in this class.
        case_meta_path = os.path.join(DATASET_PATH, random_case, "meta.json")
        with open(case_meta_path, 'r', encoding='utf-8') as file:
            case_info = json.load(file)
        return random_case, case_info

    def get_model_from_name(self, model_name: str) -> Optional["Model"]:
        """
        Given the model name, this function retrieves the corresponding Model object from the model list.

        :param model_name: The name of the model to find
        :return: The corresponding Model instance or None if not found
        """
        for model in self.model_list:
            if model.name == model_name:
                return model
        return None

    def get_name_list(self):
        """
        :return: The names of all loaded models, in model_list order
        """
        return [model.name for model in self.model_list]

    def get_model_info_md(self):
        """
        Builds a Markdown table with one row per loaded model.

        upload_date values in the form 'YYYY.MM.DD.HH.MM.SS' are shown as
        'YYYY-MM-DD HH:MM'; any other value is shown unchanged. The model name
        is rendered as a link only when the model has one.

        :return: The Markdown text of the table
        """
        header = (
            "\n"
            "| name | description | creator | upload time |\n"
            "| ---- | ---- | ---- | ---- |\n"
        )
        rows = []
        for model in self.model_list:
            try:
                upload_date = datetime.strptime(model.upload_date, "%Y.%m.%d.%H.%M.%S")
                formatted_date = upload_date.strftime("%Y-%m-%d %H:%M")
            except ValueError:
                formatted_date = model.upload_date  # If parsing fails, keep the original date
            # Without a link, "[name](None)" would render as a broken hyperlink.
            name_cell = f"[{model.name}]({model.link})" if model.link else model.name
            rows.append(f"| {name_cell} | {model.description} | {model.creator} | {formatted_date} |\n")
        return header + "".join(rows)

    def _load_case_inputs(self, case_name: str):
        """
        Reads the prompt and the input image URLs stored for one case.

        :param case_name: Name of the case folder inside DATASET_PATH
        :return: Tuple (prompt, input_images); input_images is empty when the case has no images.txt
        """
        case_folder = os.path.join(DATASET_PATH, case_name)
        # images.txt is optional; each non-empty line is one image URL.
        input_images = []
        images_txt_path = os.path.join(case_folder, "images.txt")
        if os.path.exists(images_txt_path):
            with open(images_txt_path, 'r', encoding='utf-8') as file:
                input_images = [line.strip() for line in file if line.strip()]
        instruction_path = os.path.join(case_folder, "instruction.txt")
        with open(instruction_path, 'r', encoding='utf-8') as file:
            prompt = file.read()
        return prompt, input_images

    def get_result_of_random_case_anony(self):
        """
        This function selects a random case, reads its prompt from instruction.txt and its
        input image URLs from images.txt, and returns the results of two randomly selected models.

        :return: Tuple (model_A, model_B, prompt, input_images, output_images_A, output_images_B)
        :raises ValueError: If fewer than two models are loaded
        """
        case_name, _case_info = self.choose_case_randomly()
        prompt, input_images = self._load_case_inputs(case_name)
        # Two distinct models, drawn without replacement.
        model_A, model_B = random.sample(self.model_list, 2)
        output_images_A = model_A.get_result(case_name)
        output_images_B = model_B.get_result(case_name)
        return model_A, model_B, prompt, input_images, output_images_A, output_images_B

    def get_result_of_random_case(self, model_name_A, model_name_B):
        """
        This function allows you to specify the names of the models, and it will return their results for a randomly chosen case.

        :param model_name_A: Name of the first model
        :param model_name_B: Name of the second model
        :return: Tuple (model_A, model_B, prompt, input_images, output_images_A, output_images_B)
        :raises ValueError: If either name does not match a loaded model
        """
        # Resolve the models first so an unknown name fails before any file I/O.
        model_A = self.get_model_from_name(model_name_A)
        model_B = self.get_model_from_name(model_name_B)
        unknown = [
            repr(requested)
            for requested, found in ((model_name_A, model_A), (model_name_B, model_B))
            if found is None
        ]
        if unknown:
            raise ValueError(f"Unknown model name(s): {', '.join(unknown)}")
        case_name, _case_info = self.choose_case_randomly()
        prompt, input_images = self._load_case_inputs(case_name)
        output_images_A = model_A.get_result(case_name)
        output_images_B = model_B.get_result(case_name)
        return model_A, model_B, prompt, input_images, output_images_A, output_images_B