ksort commited on
Commit
e1665ec
·
1 Parent(s): 8b4d7b7

add image_cache

Browse files
.gitignore CHANGED
@@ -172,7 +172,8 @@ ksort-logs/
172
 
173
  *.mp4
174
  cache_video/
 
175
 
176
- *generate_image_cache.py
177
- *generate_video_cache.py
178
- *get_webvid_prompt.py
 
172
 
173
  *.mp4
174
  cache_video/
175
+ cache_image/
176
 
177
+ /model/models/generate_image_cache.py
178
+ /model/models/generate_video_cache.py
179
+ /get_webvid_prompt.py
README.md CHANGED
@@ -31,14 +31,4 @@ pip install -r requirements.txt
31
  ## Start Hugging Face UI
32
  ```bash
33
  python app.py
34
- ```
35
-
36
- ## Start Log server
37
- ```bash
38
- uvicorn serve.log_server:app --reload --port 22005 --host 0.0.0.0
39
- ```
40
-
41
- ## Update leaderboard
42
- ```bash
43
- cd arena_elo && bash update_leaderboard.sh
44
  ```
 
31
  ## Start Hugging Face UI
32
  ```bash
33
  python app.py
 
 
 
 
 
 
 
 
 
 
34
  ```
model/fetch_museum_results/__init__.py DELETED
@@ -1,62 +0,0 @@
1
- from .imagen_museum import TASK_DICT, DOMAIN
2
- from .imagen_museum import fetch_indexes
3
- import random
4
-
5
- ARENA_TO_IG_MUSEUM = {"LCM(v1.5/XL)":"LCM",
6
- "PlayGroundV2.5": "PlayGroundV2_5"}
7
-
8
- def draw2_from_imagen_museum(task, model_name1, model_name2, model_name3, model_name4):
9
- task_name = TASK_DICT[task]
10
- model_name1 = ARENA_TO_IG_MUSEUM[model_name1] if model_name1 in ARENA_TO_IG_MUSEUM else model_name1
11
- model_name2 = ARENA_TO_IG_MUSEUM[model_name2] if model_name2 in ARENA_TO_IG_MUSEUM else model_name2
12
- model_name3 = ARENA_TO_IG_MUSEUM[model_name3] if model_name3 in ARENA_TO_IG_MUSEUM else model_name3
13
- model_name4 = ARENA_TO_IG_MUSEUM[model_name4] if model_name4 in ARENA_TO_IG_MUSEUM else model_name4
14
-
15
- domain = DOMAIN
16
- baselink = domain + task_name
17
-
18
- matched_results = fetch_indexes(baselink)
19
- r = random.Random()
20
- uid, value = r.choice(list(matched_results.items()))
21
- image_link_1 = baselink + "/" + model_name1 + "/" + uid
22
- image_link_2 = baselink + "/" + model_name2 + "/" + uid
23
- image_link_3 = baselink + "/" + model_name3 + "/" + uid
24
- image_link_4 = baselink + "/" + model_name4 + "/" + uid
25
-
26
- if task == "t2i": # Image Gen
27
- prompt = value['prompt']
28
- return [[image_link_1, image_link_2, image_link_3, image_link_4], [prompt]]
29
- if task == "tie": # Image Edit
30
- instruction = value['instruction']
31
- input_caption = value['source_global_caption']
32
- output_caption = value['target_global_caption']
33
- source_image_link = baselink + "/" + "input" + "/" + uid
34
- return [[source_image_link, image_link_1, image_link_2, image_link_3, image_link_4], [input_caption, output_caption, instruction]]
35
- else:
36
- raise ValueError("Task not supported")
37
-
38
- def draw_from_imagen_museum(task, model_name):
39
- task_name = TASK_DICT[task]
40
- model_name = ARENA_TO_IG_MUSEUM[model_name] if model_name in ARENA_TO_IG_MUSEUM else model_name
41
-
42
- domain = DOMAIN
43
- baselink = domain + task_name
44
-
45
- matched_results = fetch_indexes(baselink)
46
- r = random.Random()
47
- uid, value = r.choice(list(matched_results.items()))
48
- model = model_name
49
- image_link = baselink + "/" + model + "/" + uid
50
- print(image_link)
51
-
52
- if task == "t2i": # Image Gen
53
- prompt = value['prompt']
54
- return [image_link, prompt]
55
- if task == "tie": # Image Edit
56
- instruction = value['instruction']
57
- input_caption = value['source_global_caption']
58
- output_caption = value['target_global_caption']
59
- source_image_link = baselink + "/" + "input" + "/" + uid
60
- return [[source_image_link, image_link], [input_caption, output_caption, instruction]]
61
- else:
62
- raise ValueError("Task not supported")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/fetch_museum_results/imagen_museum/__init__.py DELETED
@@ -1,129 +0,0 @@
1
- import csv
2
- import requests
3
- from io import StringIO
4
- from typing import Union, Optional, Tuple
5
- from PIL import Image
6
- import random
7
-
8
- __version__ = "0.0.1_GenAI_Arena"
9
-
10
- DOMAIN = "https://chromaica.github.io/Museum/"
11
-
12
- TASK_DICT = {
13
- "t2i": "ImagenHub_Text-Guided_IG",
14
- "tie": "ImagenHub_Text-Guided_IE",
15
- "mie": "ImagenHub_Control-Guided_IG",
16
- "cig": "ImagenHub_Control-Guided_IE",
17
- "msdig": "ImagenHub_Multi-Concept_IC",
18
- "sdig": "ImagenHub_Subject-Driven_IG",
19
- "sdie": "ImagenHub_Subject-Driven_IE"
20
- }
21
-
22
- t2i_models= [
23
- "SD",
24
- "SDXL",
25
- "OpenJourney",
26
- "DeepFloydIF",
27
- "DALLE2"
28
- ]
29
-
30
- mie_models = [
31
- "Glide",
32
- "SDInpaint",
33
- "BlendedDiffusion",
34
- "SDXLInpaint"
35
- ]
36
-
37
- tie_models = [
38
- "DiffEdit",
39
- "MagicBrush",
40
- "InstructPix2Pix",
41
- "Prompt2prompt",
42
- "Text2Live",
43
- "SDEdit",
44
- "CycleDiffusion",
45
- "Pix2PixZero"
46
- ]
47
-
48
- sdig_models = [
49
- "DreamBooth",
50
- "DreamBoothLora",
51
- "TextualInversion",
52
- "BLIPDiffusion_Gen"
53
- ]
54
-
55
- sdie_models = [
56
- "PhotoSwap",
57
- "DreamEdit",
58
- "BLIPDiffusion_Edit"
59
- ]
60
-
61
- msdig_models = [
62
- "DreamBooth",
63
- "CustomDiffusion",
64
- "TextualInversion"
65
- ]
66
-
67
- cig_models = [
68
- "ControlNet",
69
- "UniControl"
70
- ]
71
-
72
- def fetch_csv_keys(url):
73
- """
74
- Fetches a CSV file from a given URL and parses it into a list of keys,
75
- ignoring the header line.
76
- """
77
- response = requests.get(url)
78
- response.raise_for_status() # Ensure we notice bad responses
79
-
80
- # Use StringIO to turn the fetched text data into a file-like object
81
- csv_file = StringIO(response.text)
82
-
83
- # Create a CSV reader
84
- csv_reader = csv.reader(csv_file)
85
-
86
- # Skip the header
87
- next(csv_reader, None)
88
-
89
- # Return the list of keys
90
- return [row[0] for row in csv_reader if row]
91
-
92
- def fetch_json_data(url):
93
- """
94
- Fetches JSON data from a given URL.
95
- """
96
- response = requests.get(url)
97
- response.raise_for_status()
98
- return response.json()
99
-
100
- def fetch_data_and_match(csv_url, json_url):
101
- """
102
- Fetches a list of keys from a CSV and then fetches JSON data and matches the keys to the JSON.
103
- """
104
- # Fetch keys from CSV
105
- keys = fetch_csv_keys(csv_url)
106
-
107
- # Fetch JSON data
108
- json_data = fetch_json_data(json_url)
109
-
110
- # Extract relevant data using keys
111
- matched_data = {key: json_data.get(key) for key in keys if key in json_data}
112
-
113
- return matched_data
114
-
115
- def fetch_indexes(baselink):
116
- matched_results = fetch_data_and_match(baselink+"/dataset_lookup.csv", baselink+"/dataset_lookup.json")
117
- return matched_results
118
-
119
- if __name__ == "__main__":
120
- domain = "https://chromaica.github.io/Museum/"
121
- baselink = domain + "ImagenHub_Text-Guided_IE"
122
- matched_results = fetch_indexes(baselink)
123
- for uid, value in matched_results.items():
124
- print(uid)
125
- model = "CycleDiffusion"
126
- image_link = baselink + "/" + model + "/" + uid
127
- print(image_link)
128
- instruction = value['instruction']
129
- print(instruction)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/matchmaker.py CHANGED
@@ -81,9 +81,8 @@ def load_json_via_sftp():
81
  class RunningPivot(object):
82
  running_pivot = []
83
 
84
- not_run = [12,13,14,15,16,17,18,19,20,21,22, 25,26] #23,24,
85
 
86
- def matchmaker(num_players, k_group=4):
87
  trueskill_env = TrueSkill()
88
 
89
  ratings, comparison_counts, total_comparisons = load_json_via_sftp()
 
81
  class RunningPivot(object):
82
  running_pivot = []
83
 
 
84
 
85
+ def matchmaker(num_players, k_group=4, not_run=[]):
86
  trueskill_env = TrueSkill()
87
 
88
  ratings, comparison_counts, total_comparisons = load_json_via_sftp()
model/model_manager.py CHANGED
@@ -9,7 +9,7 @@ from PIL import Image
9
  from openai import OpenAI
10
  from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS, load_pipeline
11
  from .fetch_museum_results import draw_from_imagen_museum, draw2_from_imagen_museum
12
- from serve.upload import get_random_mscoco_prompt, get_random_video_prompt, get_ssh_random_video_prompt
13
  from serve.constants import SSH_CACHE_OPENSOURCE, SSH_CACHE_ADVANCE, SSH_CACHE_PIKA, SSH_CACHE_SORA, SSH_CACHE_IMAGE
14
 
15
 
@@ -73,14 +73,9 @@ class ModelManager:
73
 
74
  def generate_image_ig_parallel_anony(self, prompt, model_A, model_B, model_C, model_D):
75
  if model_A == "" and model_B == "" and model_C == "" and model_D == "":
76
- # not_run = [11, 12, 13, 14, 15, 16, 17, 18, 19]
77
- # filtered_models = [model for i, model in enumerate(self.model_ig_list) if i not in not_run]
78
- # model_names = random.sample([model for model in filtered_models], 4)
79
-
80
- # model_names = random.sample([model for model in self.model_ig_list], 4)
81
-
82
  from .matchmaker import matchmaker
83
- model_ids = matchmaker(num_players=len(self.model_ig_list))
 
84
  print(model_ids)
85
  model_names = [self.model_ig_list[i] for i in model_ids]
86
  print(model_names)
@@ -97,6 +92,26 @@ class ModelManager:
97
  return results[0], results[1], results[2], results[3], \
98
  model_names[0], model_names[1], model_names[2], model_names[3]
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  def generate_video_vg_parallel_anony(self, model_A, model_B, model_C, model_D):
101
  if model_A == "" and model_B == "" and model_C == "" and model_D == "":
102
  # model_names = random.sample([model for model in self.model_vg_list], 4)
@@ -123,23 +138,6 @@ class ModelManager:
123
  os.makedirs(local_dir)
124
  prompt, results = get_ssh_random_video_prompt(root_dir, local_dir, model_names)
125
  cache_dir = local_dir
126
- # cache_dir, prompt = get_random_video_prompt(root_dir)
127
- # results = []
128
- # for name in model_names:
129
- # model_source, model_name, model_type = name.split("_")
130
- # # if model_name in ["Runway-Gen3", "Pika-beta", "Pika-v1.0"]:
131
- # # file_name = cache_dir.split("/")[-1]
132
- # # video_path = os.path.join(cache_dir, f'{file_name}.mp4')
133
- # # else:
134
- # # video_path = os.path.join(cache_dir, f'{model_name}.mp4')
135
- # video_path = os.path.join(cache_dir, f'{model_name}.mp4')
136
- # print(video_path)
137
- # results.append(video_path)
138
-
139
- # with concurrent.futures.ThreadPoolExecutor() as executor:
140
- # futures = [executor.submit(self.generate_image_ig, prompt, model) if model.startswith("huggingface")
141
- # else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
142
- # results = [future.result() for future in futures]
143
 
144
  return results[0], results[1], results[2], results[3], \
145
  model_names[0], model_names[1], model_names[2], model_names[3], prompt, cache_dir
 
9
  from openai import OpenAI
10
  from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS, load_pipeline
11
  from .fetch_museum_results import draw_from_imagen_museum, draw2_from_imagen_museum
12
+ from serve.upload import get_random_mscoco_prompt, get_random_video_prompt, get_ssh_random_video_prompt, get_ssh_random_image_prompt
13
  from serve.constants import SSH_CACHE_OPENSOURCE, SSH_CACHE_ADVANCE, SSH_CACHE_PIKA, SSH_CACHE_SORA, SSH_CACHE_IMAGE
14
 
15
 
 
73
 
74
  def generate_image_ig_parallel_anony(self, prompt, model_A, model_B, model_C, model_D):
75
  if model_A == "" and model_B == "" and model_C == "" and model_D == "":
 
 
 
 
 
 
76
  from .matchmaker import matchmaker
77
+ not_run = [12,13,14,15,16,17,18,19,20,21,22, 25,26] #23,24,
78
+ model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run)
79
  print(model_ids)
80
  model_names = [self.model_ig_list[i] for i in model_ids]
81
  print(model_names)
 
92
  return results[0], results[1], results[2], results[3], \
93
  model_names[0], model_names[1], model_names[2], model_names[3]
94
 
95
+ def generate_image_ig_cache_anony(self, model_A, model_B, model_C, model_D):
96
+ if model_A == "" and model_B == "" and model_C == "" and model_D == "":
97
+ from .matchmaker import matchmaker
98
+ not_run = [20,21,22]
99
+ model_ids = matchmaker(num_players=len(self.model_ig_list), not_run=not_run)
100
+ print(model_ids)
101
+ model_names = [self.model_ig_list[i] for i in model_ids]
102
+ print(model_names)
103
+ else:
104
+ model_names = [model_A, model_B, model_C, model_D]
105
+
106
+ root_dir = SSH_CACHE_IMAGE
107
+ local_dir = "./cache_image"
108
+ if not os.path.exists(local_dir):
109
+ os.makedirs(local_dir)
110
+ prompt, results = get_ssh_random_image_prompt(root_dir, local_dir, model_names)
111
+
112
+ return results[0], results[1], results[2], results[3], \
113
+ model_names[0], model_names[1], model_names[2], model_names[3], prompt
114
+
115
  def generate_video_vg_parallel_anony(self, model_A, model_B, model_C, model_D):
116
  if model_A == "" and model_B == "" and model_C == "" and model_D == "":
117
  # model_names = random.sample([model for model in self.model_vg_list], 4)
 
138
  os.makedirs(local_dir)
139
  prompt, results = get_ssh_random_video_prompt(root_dir, local_dir, model_names)
140
  cache_dir = local_dir
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  return results[0], results[1], results[2], results[3], \
143
  model_names[0], model_names[1], model_names[2], model_names[3], prompt, cache_dir
model/models/__init__.py CHANGED
@@ -1,7 +1,3 @@
1
- # from .imagenhub_models import load_imagenhub_model
2
- # from .playground_api import load_playground_model
3
- # from .fal_api_models import load_fal_model
4
- # from .videogenhub_models import load_videogenhub_model
5
  from .huggingface_models import load_huggingface_model
6
  from .replicate_api_models import load_replicate_model
7
  from .openai_api_models import load_openai_model
 
 
 
 
 
1
  from .huggingface_models import load_huggingface_model
2
  from .replicate_api_models import load_replicate_model
3
  from .openai_api_models import load_openai_model
model/models/fal_api_models.py DELETED
@@ -1,103 +0,0 @@
1
- import fal_client
2
- from PIL import Image
3
- import requests
4
- import io
5
- import os
6
- import base64
7
-
8
- FAL_MODEl_NAME_MAP = {"SDXL": "fast-sdxl", "SDXLTurbo": "fast-turbo-diffusion", "SDXLLightning": "fast-lightning-sdxl",
9
- "LCM(v1.5/XL)": "fast-lcm-diffusion", "PixArtSigma": "pixart-sigma", "StableCascade": "stable-cascade"}
10
-
11
- class FalModel():
12
- def __init__(self, model_name, model_type):
13
- self.model_name = model_name
14
- self.model_type = model_type
15
- os.environ['FAL_KEY'] = os.environ['FalAPI']
16
-
17
- def __call__(self, *args, **kwargs):
18
- def decode_data_url(data_url):
19
- # Find the start of the Base64 encoded data
20
- base64_start = data_url.find(",") + 1
21
- if base64_start == 0:
22
- raise ValueError("Invalid data URL provided")
23
-
24
- # Extract the Base64 encoded data
25
- base64_string = data_url[base64_start:]
26
-
27
- # Decode the Base64 string
28
- decoded_bytes = base64.b64decode(base64_string)
29
-
30
- return decoded_bytes
31
-
32
- if self.model_type == "text2image":
33
- assert "prompt" in kwargs, "prompt is required for text2image model"
34
- handler = fal_client.submit(
35
- f"fal-ai/{FAL_MODEl_NAME_MAP[self.model_name]}",
36
- arguments={
37
- "prompt": kwargs["prompt"]
38
- },
39
- )
40
- for event in handler.iter_events(with_logs=True):
41
- if isinstance(event, fal_client.InProgress):
42
- print('Request in progress')
43
- print(event.logs)
44
- result = handler.get()
45
- print(result)
46
- result_url = result['images'][0]['url']
47
- if self.model_name in ["SDXLTurbo", "LCM(v1.5/XL)"]:
48
- result_url = io.BytesIO(decode_data_url(result_url))
49
- result = Image.open(result_url)
50
- else:
51
- response = requests.get(result_url)
52
- result = Image.open(io.BytesIO(response.content))
53
- return result
54
- elif self.model_type == "image2image":
55
- raise NotImplementedError("image2image model is not implemented yet")
56
- # assert "image" in kwargs or "image_url" in kwargs, "image or image_url is required for image2image model"
57
- # if "image" in kwargs:
58
- # image_url = None
59
- # pass
60
- # handler = fal_client.submit(
61
- # f"fal-ai/{self.model_name}",
62
- # arguments={
63
- # "image_url": image_url
64
- # },
65
- # )
66
- #
67
- # for event in handler.iter_events():
68
- # if isinstance(event, fal_client.InProgress):
69
- # print('Request in progress')
70
- # print(event.logs)
71
- #
72
- # result = handler.get()
73
- # return result
74
- elif self.model_type == "text2video":
75
- assert "prompt" in kwargs, "prompt is required for text2video model"
76
- if self.model_name == 'AnimateDiff':
77
- fal_model_name = 'fast-animatediff/text-to-video'
78
- elif self.model_name == 'AnimateDiffTurbo':
79
- fal_model_name = 'fast-animatediff/turbo/text-to-video'
80
- else:
81
- raise NotImplementedError(f"text2video model of {self.model_name} in fal is not implemented yet")
82
- handler = fal_client.submit(
83
- f"fal-ai/{fal_model_name}",
84
- arguments={
85
- "prompt": kwargs["prompt"]
86
- },
87
- )
88
-
89
- for event in handler.iter_events(with_logs=True):
90
- if isinstance(event, fal_client.InProgress):
91
- print('Request in progress')
92
- print(event.logs)
93
-
94
- result = handler.get()
95
- print("result video: ====")
96
- print(result)
97
- result_url = result['video']['url']
98
- return result_url
99
- else:
100
- raise ValueError("model_type must be text2image or image2image")
101
-
102
- def load_fal_model(model_name, model_type):
103
- return FalModel(model_name, model_type)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/models/imagenhub_models.py DELETED
@@ -1,50 +0,0 @@
1
- import imagen_hub
2
-
3
- class ImagenHubModel():
4
- def __init__(self, model_name):
5
- self.model = imagen_hub.load(model_name)
6
-
7
- def __call__(self, *args, **kwargs):
8
- return self.model.infer_one_image(*args, **kwargs)
9
-
10
- class PNP(ImagenHubModel):
11
- def __init__(self):
12
- super().__init__('PNP')
13
-
14
- def __call__(self, *args, **kwargs):
15
- if "num_inversion_steps" not in kwargs:
16
- kwargs["num_inversion_steps"] = 200
17
- return super().__call__(*args, **kwargs)
18
-
19
- class Prompt2prompt(ImagenHubModel):
20
- def __init__(self):
21
- super().__init__('Prompt2prompt')
22
-
23
- def __call__(self, *args, **kwargs):
24
- if "num_inner_steps" not in kwargs:
25
- kwargs["num_inner_steps"] = 3
26
- return super().__call__(*args, **kwargs)
27
-
28
- def load_imagenhub_model(model_name, model_type=None):
29
- if model_name == 'PNP':
30
- return PNP()
31
- if model_name == 'Prompt2prompt':
32
- return Prompt2prompt()
33
- return ImagenHubModel(model_name)
34
-
35
-
36
- # for name in ['DeepFloydIF', 'PixArtAlpha', 'Kandinsky']: #, 'OpenJourney', 'LCM', 'SD' 'SDXL'
37
- # #
38
- # pipe = ImagenHubModel(name)
39
- # result = pipe(prompt='a cute dog is playing a ball')
40
- # print(result)
41
-
42
- # for name in ['SD']:
43
- # from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
44
- # import torch
45
- # pipe = DiffusionPipeline.from_pretrained(
46
- # "stabilityai/stable-diffusion-2-base",
47
- # torch_dtype=torch.float16,
48
- # safety_checker=None,
49
- # ).to("cuda")
50
- # pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/models/playground_api.py DELETED
@@ -1,35 +0,0 @@
1
- import os
2
- import json
3
- import requests
4
- from PIL import Image
5
- import io
6
- import base64
7
- class PlayGround():
8
- def __init__(self, model_name, model_type=None):
9
- self.model_name = model_name
10
- self.model_type = model_type
11
- self.api_key = os.environ['PlaygroundAPI']
12
- if model_name == "PlayGroundV2":
13
- self._model_name = "Playground_v2"
14
- elif model_name == "PlayGroundV2.5":
15
- self._model_name = "Playground_v2.5"
16
-
17
-
18
- def __call__(self, prompt):
19
- headers = {
20
- 'Content-Type': 'application/json',
21
- 'Authorization': "Bearer " + self.api_key,
22
- }
23
-
24
- data = json.dumps({"prompt": prompt, "filter_model": self._model_name, "scheduler": "DPMPP_2M_K", "guidance_scale": 3})
25
-
26
- response = requests.post('https://playground.com/api/models/external/v1', headers=headers, data=data)
27
- response.raise_for_status()
28
- json_obj = response.json()
29
- image_base64 = json_obj['images'][0]
30
- img = Image.open(io.BytesIO(base64.decodebytes(bytes(image_base64, "utf-8"))))
31
-
32
- return img
33
-
34
- def load_playground_model(model_name, model_type="generation"):
35
- return PlayGround(model_name, model_type)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
serve/gradio_web.py CHANGED
@@ -16,6 +16,7 @@ from .vote_utils import (
16
  generate_igm_museum,
17
  generate_igm_annoy,
18
  generate_igm_annoy_museum,
 
19
  share_js
20
  )
21
  from .Ksort import (
@@ -68,7 +69,9 @@ def build_side_by_side_ui_anony(models):
68
  state3 = gr.State()
69
 
70
  gen_func = partial(generate_igm_annoy, models.generate_image_ig_parallel_anony)
71
- gen_func_random = partial(generate_igm_annoy_museum, models.generate_image_ig_museum_parallel_anony)
 
 
72
 
73
  gr.Markdown(notice_markdown, elem_id="notice_markdown")
74
 
@@ -216,6 +219,8 @@ def build_side_by_side_ui_anony(models):
216
  # draw_btn = gr.Button(value="🎲 Random sample", variant="primary", scale=0, elem_id="btnblue", elem_classes="send-button")
217
  send_btn = gr.Button(value="Send", variant="primary", scale=0, elem_id="btnblue")
218
  draw_btn = gr.Button(value="🎲 Random Prompt", variant="primary", scale=0, elem_id="btnblue")
 
 
219
  with gr.Row():
220
  clear_btn = gr.Button(value="🎲 New Round", interactive=False)
221
  # regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
@@ -238,7 +243,7 @@ def build_side_by_side_ui_anony(models):
238
  ["The bathroom with green tile and a red shower curtain.", os.path.join("./examples", "example4.jpg")]],
239
  inputs = [textbox, dummy_img_output])
240
 
241
- order_btn_list = [textbox, send_btn, draw_btn, clear_btn]
242
  vote_order_list = [leftvote_btn, left1vote_btn, rightvote_btn, right1vote_btn, tie_btn, \
243
  A1_btn, A2_btn, A3_btn, A4_btn, B1_btn, B2_btn, B3_btn, B4_btn, C1_btn, C2_btn, C3_btn, C4_btn, D1_btn, D2_btn, D3_btn, D4_btn, \
244
  vote_textbox, vote_submit_btn, vote_mode_btn]
@@ -281,6 +286,7 @@ def build_side_by_side_ui_anony(models):
281
  # Top4_btn.click(reset_level, inputs=[Top4_text], outputs=[vote_level])
282
  vote_mode = gr.Textbox(value="Rank", visible=False, interactive=False)
283
  right_vote_text = gr.Textbox(value="wrong", visible=False, interactive=False)
 
284
 
285
  textbox.submit(
286
  disable_order_buttons,
@@ -314,21 +320,22 @@ def build_side_by_side_ui_anony(models):
314
  outputs=vote_order_list
315
  )
316
 
317
- # draw_btn.click(
318
- # gen_func_random,
319
- # inputs=[state0, state1, state2, state3, model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
320
- # outputs=[state0, state1, state2, state3, generate_ig0, generate_ig1, generate_ig2, generate_ig3, chatbot_left, chatbot_left1, chatbot_right, chatbot_right1, \
321
- # textbox, model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
322
- # api_name="draw_btn_annony"
323
- # ).then(
324
- # disable_order_buttons,
325
- # inputs=None,
326
- # outputs=order_btn_list
327
- # ).then(
328
- # enable_vote_mode_buttons,
329
- # inputs=[vote_mode],
330
- # outputs=vote_order_list
331
- # )
 
332
  draw_btn.click(
333
  get_random_mscoco_prompt,
334
  inputs=None,
 
16
  generate_igm_museum,
17
  generate_igm_annoy,
18
  generate_igm_annoy_museum,
19
+ generate_igm_cache_annoy,
20
  share_js
21
  )
22
  from .Ksort import (
 
69
  state3 = gr.State()
70
 
71
  gen_func = partial(generate_igm_annoy, models.generate_image_ig_parallel_anony)
72
+ gen_cache_func = partial(generate_igm_cache_annoy, models.generate_image_ig_cache_anony)
73
+
74
+ # gen_func_random = partial(generate_igm_annoy_museum, models.generate_image_ig_museum_parallel_anony)
75
 
76
  gr.Markdown(notice_markdown, elem_id="notice_markdown")
77
 
 
219
  # draw_btn = gr.Button(value="🎲 Random sample", variant="primary", scale=0, elem_id="btnblue", elem_classes="send-button")
220
  send_btn = gr.Button(value="Send", variant="primary", scale=0, elem_id="btnblue")
221
  draw_btn = gr.Button(value="🎲 Random Prompt", variant="primary", scale=0, elem_id="btnblue")
222
+ with gr.Row():
223
+ cache_btn = gr.Button(value="🎲 Random Sample", interactive=True)
224
  with gr.Row():
225
  clear_btn = gr.Button(value="🎲 New Round", interactive=False)
226
  # regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
 
243
  ["The bathroom with green tile and a red shower curtain.", os.path.join("./examples", "example4.jpg")]],
244
  inputs = [textbox, dummy_img_output])
245
 
246
+ order_btn_list = [textbox, send_btn, draw_btn, cache_btn, clear_btn]
247
  vote_order_list = [leftvote_btn, left1vote_btn, rightvote_btn, right1vote_btn, tie_btn, \
248
  A1_btn, A2_btn, A3_btn, A4_btn, B1_btn, B2_btn, B3_btn, B4_btn, C1_btn, C2_btn, C3_btn, C4_btn, D1_btn, D2_btn, D3_btn, D4_btn, \
249
  vote_textbox, vote_submit_btn, vote_mode_btn]
 
286
  # Top4_btn.click(reset_level, inputs=[Top4_text], outputs=[vote_level])
287
  vote_mode = gr.Textbox(value="Rank", visible=False, interactive=False)
288
  right_vote_text = gr.Textbox(value="wrong", visible=False, interactive=False)
289
+ cache_mode = gr.Textbox(value="True", visible=False, interactive=False)
290
 
291
  textbox.submit(
292
  disable_order_buttons,
 
320
  outputs=vote_order_list
321
  )
322
 
323
+ cache_btn.click(
324
+ disable_order_buttons,
325
+ inputs=[textbox, cache_mode],
326
+ outputs=order_btn_list
327
+ ).then(
328
+ gen_cache_func,
329
+ inputs=[state0, state1, state2, state3, model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
330
+ outputs=[state0, state1, state2, state3, generate_ig0, generate_ig1, generate_ig2, generate_ig3, chatbot_left, chatbot_left1, chatbot_right, chatbot_right1, \
331
+ model_selector_left, model_selector_left1, model_selector_right, model_selector_right1, textbox],
332
+ api_name="send_btn_annony"
333
+ ).then(
334
+ enable_vote_mode_buttons,
335
+ inputs=[vote_mode, textbox],
336
+ outputs=vote_order_list
337
+ )
338
+
339
  draw_btn.click(
340
  get_random_mscoco_prompt,
341
  inputs=None,
serve/gradio_web_video.py CHANGED
@@ -310,7 +310,7 @@ def build_side_by_side_video_ui_anony(models):
310
  model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
311
  api_name="clear_btn_annony"
312
  ).then(
313
- enable_order_buttons,
314
  inputs=None,
315
  outputs=order_btn_list
316
  ).then(
 
310
  model_selector_left, model_selector_left1, model_selector_right, model_selector_right1],
311
  api_name="clear_btn_annony"
312
  ).then(
313
+ enable_video_order_buttons,
314
  inputs=None,
315
  outputs=order_btn_list
316
  ).then(
serve/upload.py CHANGED
@@ -146,8 +146,64 @@ def get_ssh_random_video_prompt(root_dir, local_dir, model_names):
146
  except Exception as e:
147
  print(f"An error occurred: {e}")
148
  raise NotImplementedError
 
 
149
  return prompt, local_path[1:]
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  def create_remote_directory(remote_directory, video=False):
152
  global ssh_client
153
  if not is_connected():
 
146
  except Exception as e:
147
  print(f"An error occurred: {e}")
148
  raise NotImplementedError
149
+ sftp.close()
150
+ ssh.close()
151
  return prompt, local_path[1:]
152
 
153
+ def get_ssh_random_image_prompt(root_dir, local_dir, model_names):
154
+ def is_directory(sftp, path):
155
+ try:
156
+ return stat.S_ISDIR(sftp.stat(path).st_mode)
157
+ except IOError:
158
+ return False
159
+ ssh = paramiko.SSHClient()
160
+ ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
161
+ try:
162
+ ssh.connect(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
163
+ sftp = ssh.open_sftp()
164
+
165
+ remote_subdirs = sftp.listdir(root_dir)
166
+ remote_subdirs = [d for d in remote_subdirs if is_directory(sftp, os.path.join(root_dir, d))]
167
+
168
+ if not remote_subdirs:
169
+ print(f"No subdirectories found in {root_dir}")
170
+ raise NotImplementedError
171
+
172
+ chosen_subdir = random.choice(remote_subdirs)
173
+ chosen_subdir_path = os.path.join(root_dir, chosen_subdir)
174
+ print(f"Chosen subdirectory: {chosen_subdir_path}")
175
+
176
+ prompt_path = 'prompt.txt'
177
+ results = [prompt_path]
178
+ for name in model_names:
179
+ model_source, model_name, model_type = name.split("_")
180
+ image_path = f'{model_name}.jpg'
181
+ print(image_path)
182
+ results.append(image_path)
183
+
184
+ local_path = []
185
+ for tar_file in results:
186
+ remote_file_path = os.path.join(chosen_subdir_path, tar_file)
187
+ local_file_path = os.path.join(local_dir, tar_file)
188
+ sftp.get(remote_file_path, local_file_path)
189
+ local_path.append(local_file_path)
190
+ print(f"Downloaded {remote_file_path} to {local_file_path}")
191
+
192
+ if os.path.exists(local_path[0]):
193
+ str_list = []
194
+ with open(local_path[0], 'r', encoding='utf-8') as file:
195
+ for line in file:
196
+ str_list.append(line.strip())
197
+ prompt = str_list[0]
198
+ else:
199
+ raise NotImplementedError
200
+ except Exception as e:
201
+ print(f"An error occurred: {e}")
202
+ raise NotImplementedError
203
+ sftp.close()
204
+ ssh.close()
205
+ return prompt, [Image.open(path) for path in local_path[1:]]
206
+
207
  def create_remote_directory(remote_directory, video=False):
208
  global ssh_client
209
  if not is_connected():
serve/utils.py CHANGED
@@ -203,12 +203,17 @@ def disable_vote_mode_buttons():
203
 
204
 
205
  def enable_order_buttons():
206
- return tuple(gr.update(interactive=True) for _ in range(4))
207
- def disable_order_buttons(textbox, video=False):
 
 
208
  if not textbox.strip():
209
- return (gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True))
210
  else:
211
- return (gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True))
 
 
 
212
  def disable_video_order_buttons():
213
  return (gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True))
214
 
 
203
 
204
 
205
  def enable_order_buttons():
206
+ return tuple(gr.update(interactive=True) for _ in range(5))
207
+ def disable_order_buttons(textbox, cache="False"):
208
+ if cache=="True":
209
+ return (gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True))
210
  if not textbox.strip():
211
+ return (gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True))
212
  else:
213
+ return (gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True))
214
+
215
+ def enable_video_order_buttons():
216
+ return tuple(gr.update(interactive=True) for _ in range(4))
217
  def disable_video_order_buttons():
218
  return (gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True))
219
 
serve/vote_utils.py CHANGED
@@ -871,6 +871,46 @@ def generate_igm_annoy(gen_func, state0, state1, state2, state3, text, model_nam
871
  # save_any_image(state.output, f)
872
  # save_image_file_on_log_server(output_file)
873
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
874
  def generate_vg_annoy(gen_func, state0, state1, state2, state3, model_name0, model_name1, model_name2, model_name3, request: gr.Request):
875
 
876
  if state0 is None:
 
871
  # save_any_image(state.output, f)
872
  # save_image_file_on_log_server(output_file)
873
 
874
+ def generate_igm_cache_annoy(gen_func, state0, state1, state2, state3, model_name0, model_name1, model_name2, model_name3, request: gr.Request):
875
+ if state0 is None:
876
+ state0 = ImageStateIG(model_name0)
877
+ if state1 is None:
878
+ state1 = ImageStateIG(model_name1)
879
+ if state2 is None:
880
+ state2 = ImageStateIG(model_name2)
881
+ if state3 is None:
882
+ state3 = ImageStateIG(model_name3)
883
+
884
+ ip = get_ip(request)
885
+ igm_logger.info(f"generate. ip: {ip}")
886
+ start_tstamp = time.time()
887
+ model_name0 = ""
888
+ model_name1 = ""
889
+ model_name2 = ""
890
+ model_name3 = ""
891
+
892
+ generated_image0, generated_image1, generated_image2, generated_image3, model_name0, model_name1, model_name2, model_name3, text \
893
+ = gen_func(model_name0, model_name1, model_name2, model_name3)
894
+ state0.prompt = text
895
+ state1.prompt = text
896
+ state2.prompt = text
897
+ state3.prompt = text
898
+
899
+ state0.output = generated_image0
900
+ state1.output = generated_image1
901
+ state2.output = generated_image2
902
+ state3.output = generated_image3
903
+
904
+ state0.model_name = model_name0
905
+ state1.model_name = model_name1
906
+ state2.model_name = model_name2
907
+ state3.model_name = model_name3
908
+
909
+ yield state0, state1, state2, state3, generated_image0, generated_image1, generated_image2, generated_image3, \
910
+ generated_image0, generated_image1, generated_image2, generated_image3, \
911
+ gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False), \
912
+ gr.Markdown(f"### Model C: {model_name2}", visible=False), gr.Markdown(f"### Model D: {model_name3}", visible=False), text
913
+
914
  def generate_vg_annoy(gen_func, state0, state1, state2, state3, model_name0, model_name1, model_name2, model_name3, request: gr.Request):
915
 
916
  if state0 is None: