ksort commited on
Commit
02f8ed6
·
1 Parent(s): f2e397c

Update ssh

Browse files
model/matchmaker.py CHANGED
@@ -6,9 +6,19 @@ import io, os
6
  import sys
7
  sys.path.append('../')
8
  from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_SKILL
 
9
 
 
 
10
 
11
- trueskill_env = TrueSkill()
 
 
 
 
 
 
 
12
 
13
  def ucb_score(trueskill_diff, t, n):
14
  exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
@@ -25,29 +35,20 @@ def serialize_rating(rating):
25
  def deserialize_rating(rating_dict):
26
  return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
27
 
28
- def create_ssh_client(server, port, user, password):
29
- ssh = paramiko.SSHClient()
30
- ssh.load_system_host_keys()
31
- ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
32
- ssh.connect(server, port, user, password)
33
- return ssh
34
-
35
  def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
36
- ssh = create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
37
  data = {
38
  'ratings': [serialize_rating(r) for r in ratings],
39
  'comparison_counts': comparison_counts.tolist(),
40
  'total_comparisons': total_comparisons
41
  }
42
  json_data = json.dumps(data)
43
- sftp = ssh.open_sftp()
44
- with sftp.open(SSH_SKILL, 'w') as f:
45
  f.write(json_data)
46
 
47
  def load_json_via_sftp():
48
- ssh = create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
49
- sftp = ssh.open_sftp()
50
- with sftp.open(SSH_SKILL, 'r') as f:
51
  data = json.load(f)
52
  ratings = [deserialize_rating(r) for r in data['ratings']]
53
  comparison_counts = np.array(data['comparison_counts'])
 
6
  import sys
7
  sys.path.append('../')
8
  from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_SKILL
9
+ trueskill_env = TrueSkill()
10
 
11
+ ssh_matchmaker_client = None
12
+ sftp_matchmaker_client = None
13
 
14
+ def create_ssh_matchmaker_client(server, port, user, password):
15
+ global ssh_matchmaker_client, sftp_matchmaker_client
16
+ ssh_matchmaker_client = paramiko.SSHClient()
17
+ ssh_matchmaker_client.load_system_host_keys()
18
+ ssh_matchmaker_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
19
+ ssh_matchmaker_client.connect(server, port, user, password)
20
+
21
+ sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()
22
 
23
  def ucb_score(trueskill_diff, t, n):
24
  exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
 
35
  def deserialize_rating(rating_dict):
36
  return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
37
 
 
 
 
 
 
 
 
38
  def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
39
+ global sftp_matchmaker_client
40
  data = {
41
  'ratings': [serialize_rating(r) for r in ratings],
42
  'comparison_counts': comparison_counts.tolist(),
43
  'total_comparisons': total_comparisons
44
  }
45
  json_data = json.dumps(data)
46
+ with sftp_matchmaker_client.open(SSH_SKILL, 'w') as f:
 
47
  f.write(json_data)
48
 
49
  def load_json_via_sftp():
50
+ global sftp_matchmaker_client
51
+ with sftp_matchmaker_client.open(SSH_SKILL, 'r') as f:
 
52
  data = json.load(f)
53
  ratings = [deserialize_rating(r) for r in data['ratings']]
54
  comparison_counts = np.array(data['comparison_counts'])
model/model_manager.py CHANGED
@@ -59,12 +59,6 @@ class ModelManager:
59
  else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
60
  results = [future.result() for future in futures]
61
 
62
- # with concurrent.futures.ThreadPoolExecutor() as executor:
63
- # futures = [executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
64
- # results = [future.result() for future in futures]
65
-
66
- # results = [self.generate_image_ig(prompt, model) for model in model_names]
67
-
68
  return results[0], results[1], results[2], results[3], \
69
  model_names[0], model_names[1], model_names[2], model_names[3]
70
 
@@ -81,27 +75,12 @@ class ModelManager:
81
 
82
  prompt = get_random_mscoco_prompt()
83
  print(prompt)
84
- # with concurrent.futures.ThreadPoolExecutor() as executor:
85
- # model_1 = model_names[0].split('_')[1]
86
- # model_2 = model_names[1].split('_')[1]
87
- # model_3 = model_names[2].split('_')[1]
88
- # model_4 = model_names[3].split('_')[1]
89
-
90
- # result_list = draw2_from_imagen_museum("t2i", model_1, model_2, model_3, model_4)
91
- # image_links = result_list[0]
92
- # prompt_list = result_list[1]
93
- # print(prompt_list[0])
94
 
95
  with concurrent.futures.ThreadPoolExecutor() as executor:
96
  futures = [executor.submit(self.generate_image_ig, prompt, model) if model.startswith("huggingface")
97
  else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
98
  results = [future.result() for future in futures]
99
- # with concurrent.futures.ThreadPoolExecutor() as executor:
100
- # futures = [executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
101
- # results = [future.result() for future in futures]
102
 
103
- # results = [self.generate_image_ig_api(prompt, model) for model in model_names]
104
- # results = [future.result() for future in futures]
105
  return results[0], results[1], results[2], results[3], \
106
  model_names[0], model_names[1], model_names[2], model_names[3], prompt
107
 
 
59
  else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
60
  results = [future.result() for future in futures]
61
 
 
 
 
 
 
 
62
  return results[0], results[1], results[2], results[3], \
63
  model_names[0], model_names[1], model_names[2], model_names[3]
64
 
 
75
 
76
  prompt = get_random_mscoco_prompt()
77
  print(prompt)
 
 
 
 
 
 
 
 
 
 
78
 
79
  with concurrent.futures.ThreadPoolExecutor() as executor:
80
  futures = [executor.submit(self.generate_image_ig, prompt, model) if model.startswith("huggingface")
81
  else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
82
  results = [future.result() for future in futures]
 
 
 
83
 
 
 
84
  return results[0], results[1], results[2], results[3], \
85
  model_names[0], model_names[1], model_names[2], model_names[3], prompt
86
 
model/models/huggingface_models.py CHANGED
@@ -6,13 +6,20 @@ import torch
6
 
7
 
8
  def load_huggingface_model(model_name, model_type):
 
 
 
 
 
 
 
9
  if model_name == "SD-turbo":
10
- pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo", torch_dtype=torch.float16, variant="fp16")
11
  elif model_name == "SDXL-turbo":
12
- pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
13
  else:
14
  raise NotImplementedError
15
- pipe = pipe.to("cuda")
16
  return pipe
17
 
18
 
 
6
 
7
 
8
  def load_huggingface_model(model_name, model_type):
9
+ # if model_name == "SD-turbo":
10
+ # pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo", torch_dtype=torch.float16, variant="fp16")
11
+ # elif model_name == "SDXL-turbo":
12
+ # pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
13
+ # else:
14
+ # raise NotImplementedError
15
+ # pipe = pipe.to("cuda")
16
  if model_name == "SD-turbo":
17
+ pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo")
18
  elif model_name == "SDXL-turbo":
19
+ pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
20
  else:
21
  raise NotImplementedError
22
+ pipe = pipe.to("cpu")
23
  return pipe
24
 
25
 
serve/Ksort.py CHANGED
@@ -5,7 +5,7 @@ from .constants import KSORT_IMAGE_DIR
5
  from .constants import COLOR1, COLOR2, COLOR3, COLOR4
6
  from .vote_utils import save_any_image
7
  from .utils import disable_btn, enable_btn, invisible_btn
8
- from .upload import create_remote_directory, upload_image, upload_informance, upload_ssh_all
9
  import json
10
 
11
  def reset_level(Top_btn):
 
5
  from .constants import COLOR1, COLOR2, COLOR3, COLOR4
6
  from .vote_utils import save_any_image
7
  from .utils import disable_btn, enable_btn, invisible_btn
8
+ from .upload import create_remote_directory, upload_ssh_all
9
  import json
10
 
11
  def reset_level(Top_btn):
serve/gradio_web.py CHANGED
@@ -33,8 +33,11 @@ from .Ksort import (
33
  reset_vote_text,
34
  text_response_rank_igm,
35
  )
36
- from serve.upload import get_random_mscoco_prompt
 
 
37
  from functools import partial
 
38
 
39
  def build_side_by_side_ui_anony(models):
40
  notice_markdown = """
@@ -50,6 +53,9 @@ def build_side_by_side_ui_anony(models):
50
  """
51
 
52
  model_list = models.model_ig_list
 
 
 
53
 
54
  state0 = gr.State()
55
  state1 = gr.State()
 
33
  reset_vote_text,
34
  text_response_rank_igm,
35
  )
36
+ from .upload import get_random_mscoco_prompt, create_ssh_client
37
+ from .update_skill import create_ssh_skill_client
38
+ from model.matchmaker import create_ssh_matchmaker_client
39
  from functools import partial
40
+ from .constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD
41
 
42
  def build_side_by_side_ui_anony(models):
43
  notice_markdown = """
 
53
  """
54
 
55
  model_list = models.model_ig_list
56
+ create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
57
+ create_ssh_skill_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
58
+ create_ssh_matchmaker_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
59
 
60
  state0 = gr.State()
61
  state1 = gr.State()
serve/update_skill.py CHANGED
@@ -9,6 +9,17 @@ trueskill_env = TrueSkill()
9
  sys.path.append('../')
10
  from model.models import IMAGE_GENERATION_MODELS
11
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  def ucb_score(trueskill_diff, t, n):
14
  exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
@@ -25,29 +36,21 @@ def serialize_rating(rating):
25
  def deserialize_rating(rating_dict):
26
  return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
27
 
28
- def create_ssh_client(server, port, user, password):
29
- ssh = paramiko.SSHClient()
30
- ssh.load_system_host_keys()
31
- ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
32
- ssh.connect(server, port, user, password)
33
- return ssh
34
 
35
  def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
36
- ssh = create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
37
  data = {
38
  'ratings': [serialize_rating(r) for r in ratings],
39
  'comparison_counts': comparison_counts.tolist(),
40
  'total_comparisons': total_comparisons
41
  }
42
  json_data = json.dumps(data)
43
- sftp = ssh.open_sftp()
44
- with sftp.open(SSH_SKILL, 'w') as f:
45
  f.write(json_data)
46
 
47
  def load_json_via_sftp():
48
- ssh = create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
49
- sftp = ssh.open_sftp()
50
- with sftp.open(SSH_SKILL, 'r') as f:
51
  data = json.load(f)
52
  ratings = [deserialize_rating(r) for r in data['ratings']]
53
  comparison_counts = np.array(data['comparison_counts'])
 
9
  sys.path.append('../')
10
  from model.models import IMAGE_GENERATION_MODELS
11
 
12
+ ssh_skill_client = None
13
+ sftp_skill_client = None
14
+
15
+ def create_ssh_skill_client(server, port, user, password):
16
+ global ssh_skill_client, sftp_skill_client
17
+ ssh_skill_client = paramiko.SSHClient()
18
+ ssh_skill_client.load_system_host_keys()
19
+ ssh_skill_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
20
+ ssh_skill_client.connect(server, port, user, password)
21
+
22
+ sftp_skill_client = ssh_skill_client.open_sftp()
23
 
24
  def ucb_score(trueskill_diff, t, n):
25
  exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
 
36
  def deserialize_rating(rating_dict):
37
  return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
38
 
 
 
 
 
 
 
39
 
40
  def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
41
+ global sftp_skill_client
42
  data = {
43
  'ratings': [serialize_rating(r) for r in ratings],
44
  'comparison_counts': comparison_counts.tolist(),
45
  'total_comparisons': total_comparisons
46
  }
47
  json_data = json.dumps(data)
48
+ with sftp_skill_client.open(SSH_SKILL, 'w') as f:
 
49
  f.write(json_data)
50
 
51
  def load_json_via_sftp():
52
+ global sftp_skill_client
53
+ with sftp_skill_client.open(SSH_SKILL, 'r') as f:
 
54
  data = json.load(f)
55
  ratings = [deserialize_rating(r) for r in data['ratings']]
56
  comparison_counts = np.array(data['comparison_counts'])
serve/upload.py CHANGED
@@ -7,81 +7,51 @@ import json
7
  import random
8
  from .constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_LOG, SSH_MSCOCO
9
 
 
 
 
10
  def create_ssh_client(server, port, user, password):
11
- ssh = paramiko.SSHClient()
12
- ssh.load_system_host_keys()
13
- ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
14
- ssh.connect(server, port, user, password)
15
- return ssh
 
 
16
 
17
  def get_image_from_url(image_url):
18
  response = requests.get(image_url)
19
  response.raise_for_status() # success
20
  return Image.open(io.BytesIO(response.content))
21
 
22
- # def upload_image_via_sftp(ssh, image, remote_image_path):
23
- # if isinstance(image, str):
24
- # print("get url image")
25
- # image = get_image_from_url(image)
26
- # with ssh.open_sftp() as sftp:
27
- # with io.BytesIO() as image_byte_stream:
28
- # image.save(image_byte_stream, format='JPEG')
29
- # image_byte_stream.seek(0)
30
- # sftp.putfo(image_byte_stream, remote_image_path)
31
- # print(f"Successfully uploaded image to {remote_image_path}")
32
-
33
- # def upload_json_via_sftp(ssh, data, remote_json_path):
34
- # json_data = json.dumps(data, indent=4)
35
- # with ssh.open_sftp() as sftp:
36
- # with io.BytesIO(json_data.encode('utf-8')) as json_byte_stream:
37
- # sftp.putfo(json_byte_stream, remote_json_path)
38
- # print(f"Successfully uploaded JSON data to {remote_json_path}")
39
-
40
- def upload_image(states, output_dir):
41
- pass
42
- # ssh = create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
43
- # for i in range(len(states)):
44
- # output_file = os.path.join(output_dir, f"{i}.jpg")
45
- # upload_image_via_sftp(ssh, states[i].output, output_file)
46
-
47
- def upload_informance(data, data_path):
48
- pass
49
- # ssh = create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
50
- # upload_json_via_sftp(ssh, data, data_path)
51
  def get_random_mscoco_prompt():
52
- ssh = create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
53
- sftp = ssh.open_sftp()
54
- files = sftp.listdir(SSH_MSCOCO)
55
  txt_files = [file for file in files if file.endswith('.txt')]
56
 
57
  selected_files = random.sample(txt_files, 1) # get one prompt
58
 
59
  for file in selected_files:
60
  remote_file_path = os.path.join(SSH_MSCOCO, file)
61
- with sftp.file(remote_file_path, 'r') as f:
62
  content = f.read().decode('utf-8')
63
  print(f"Content of {file}:")
64
  print("\n")
65
- sftp.close()
66
- ssh.close()
67
-
68
  return content
69
 
70
 
71
  def create_remote_directory(remote_directory):
72
- ssh = create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
73
-
74
- stdin, stdout, stderr = ssh.exec_command(f'mkdir -p {SSH_LOG}/{remote_directory}')
75
  error = stderr.read().decode('utf-8')
76
  if error:
77
  print(f"Error: {error}")
78
  else:
79
  print(f"Directory {remote_directory} created successfully.")
80
- # ssh.close()
81
  return f'{SSH_LOG}/{remote_directory}'
82
 
83
  def upload_ssh_all(states, output_dir, data, data_path):
84
- ssh = create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
85
  output_file_list = []
86
  image_list = []
87
  for i in range(len(states)):
@@ -89,17 +59,17 @@ def upload_ssh_all(states, output_dir, data, data_path):
89
  output_file_list.append(output_file)
90
  image_list.append(states[i].output)
91
 
92
- with ssh.open_sftp() as sftp:
93
- for i in range(len(output_file_list)):
94
- if isinstance(image_list[i], str):
95
- print("get url image")
96
- image_list[i] = get_image_from_url(image_list[i])
97
- with io.BytesIO() as image_byte_stream:
98
- image_list[i].save(image_byte_stream, format='JPEG')
99
- image_byte_stream.seek(0)
100
- sftp.putfo(image_byte_stream, output_file_list[i])
101
- print(f"Successfully uploaded image to {output_file_list[i]}")
102
- json_data = json.dumps(data, indent=4)
103
- with io.BytesIO(json_data.encode('utf-8')) as json_byte_stream:
104
- sftp.putfo(json_byte_stream, data_path)
105
- print(f"Successfully uploaded JSON data to {data_path}")
 
7
  import random
8
  from .constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_LOG, SSH_MSCOCO
9
 
10
+ ssh_client = None
11
+ sftp_client = None
12
+
13
  def create_ssh_client(server, port, user, password):
14
+ global ssh_client, sftp_client
15
+ ssh_client = paramiko.SSHClient()
16
+ ssh_client.load_system_host_keys()
17
+ ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
18
+ ssh_client.connect(server, port, user, password)
19
+
20
+ sftp_client = ssh_client.open_sftp()
21
 
22
  def get_image_from_url(image_url):
23
  response = requests.get(image_url)
24
  response.raise_for_status() # success
25
  return Image.open(io.BytesIO(response.content))
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def get_random_mscoco_prompt():
28
+ global sftp_client
29
+ files = sftp_client.listdir(SSH_MSCOCO)
 
30
  txt_files = [file for file in files if file.endswith('.txt')]
31
 
32
  selected_files = random.sample(txt_files, 1) # get one prompt
33
 
34
  for file in selected_files:
35
  remote_file_path = os.path.join(SSH_MSCOCO, file)
36
+ with sftp_client.file(remote_file_path, 'r') as f:
37
  content = f.read().decode('utf-8')
38
  print(f"Content of {file}:")
39
  print("\n")
 
 
 
40
  return content
41
 
42
 
43
  def create_remote_directory(remote_directory):
44
+ global ssh_client
45
+ stdin, stdout, stderr = ssh_client.exec_command(f'mkdir -p {SSH_LOG}/{remote_directory}')
 
46
  error = stderr.read().decode('utf-8')
47
  if error:
48
  print(f"Error: {error}")
49
  else:
50
  print(f"Directory {remote_directory} created successfully.")
 
51
  return f'{SSH_LOG}/{remote_directory}'
52
 
53
  def upload_ssh_all(states, output_dir, data, data_path):
54
+ global sftp_client
55
  output_file_list = []
56
  image_list = []
57
  for i in range(len(states)):
 
59
  output_file_list.append(output_file)
60
  image_list.append(states[i].output)
61
 
62
+ # with sftp_client as sftp:
63
+ for i in range(len(output_file_list)):
64
+ if isinstance(image_list[i], str):
65
+ print("get url image")
66
+ image_list[i] = get_image_from_url(image_list[i])
67
+ with io.BytesIO() as image_byte_stream:
68
+ image_list[i].save(image_byte_stream, format='JPEG')
69
+ image_byte_stream.seek(0)
70
+ sftp_client.putfo(image_byte_stream, output_file_list[i])
71
+ print(f"Successfully uploaded image to {output_file_list[i]}")
72
+ json_data = json.dumps(data, indent=4)
73
+ with io.BytesIO(json_data.encode('utf-8')) as json_byte_stream:
74
+ sftp_client.putfo(json_byte_stream, data_path)
75
+ print(f"Successfully uploaded JSON data to {data_path}")