Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,069 Bytes
3427608 0e99a0b 3427608 b177a48 3427608 a4b32da 02f8ed6 a4b32da 02f8ed6 745f608 3427608 745f608 a4b32da 745f608 533658a 745f608 a4b32da 745f608 a4b32da 3427608 a4b32da 3427608 a4b32da 3427608 a4b32da 3427608 02f8ed6 745f608 3427608 02f8ed6 3427608 a4b32da 3427608 02f8ed6 745f608 02f8ed6 3427608 b177a48 3427608 b177a48 3427608 7b853d0 b177a48 3427608 afa1318 7b853d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import numpy as np
import json
from trueskill import TrueSkill
import paramiko
import io, os
import sys
from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_SKILL
# Shared TrueSkill environment, constructed with the library's default
# parameters; used below both to rate matches and to rebuild Rating objects.
trueskill_env = TrueSkill()
# Add the parent directory to the import path so `model.models` can be found.
# NOTE(review): '../' is resolved against the process working directory, so this
# only works when launched from the expected directory — confirm launch location.
sys.path.append('../')
from model.models import IMAGE_GENERATION_MODELS
# Shared SSH connection and its SFTP session used by the helpers in this module;
# both stay None until create_ssh_skill_client() has been called.
ssh_skill_client = None
sftp_skill_client = None
def create_ssh_skill_client(server, port, user, password):
    """Open (or re-open) the shared SSH/SFTP connection used for the skill file.

    Any previously created clients are closed first so that reconnecting does
    not leak the old transport. The module-level ``ssh_skill_client`` and
    ``sftp_skill_client`` are only replaced once the new SSH connection and its
    SFTP session have both been opened successfully; on failure they are left
    as ``None`` so ``is_connected()`` reports False and a later call retries.

    Args:
        server: Hostname or address passed to ``SSHClient.connect``.
        port: SSH port passed to ``SSHClient.connect``.
        user: Login user name.
        password: Login password.

    Raises:
        Any exception raised by paramiko while connecting, authenticating or
        opening the SFTP session (the half-open client is closed first).
    """
    global ssh_skill_client, sftp_skill_client

    # Release the stale connection before replacing it. Best-effort: it is
    # usually already dead when we get here, so a failing close is only logged.
    for stale in (sftp_skill_client, ssh_skill_client):
        if stale is not None:
            try:
                stale.close()
            except Exception as e:
                print(f"Error closing stale SSH/SFTP client: {e}")
    ssh_skill_client = None
    sftp_skill_client = None

    client = paramiko.SSHClient()
    client.load_system_host_keys()
    # NOTE(review): AutoAddPolicy silently trusts unknown host keys (MITM risk);
    # consider RejectPolicy with a pinned known_hosts entry.
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        client.connect(server, port, user, password)
        # Send a keepalive packet after 60 s without traffic on the transport.
        client.get_transport().set_keepalive(60)
        sftp = client.open_sftp()
    except Exception:
        # Do not leave a half-initialised connection behind.
        client.close()
        raise

    ssh_skill_client = client
    sftp_skill_client = sftp
def is_connected():
    """Return True if the shared SSH/SFTP clients exist and are usable.

    The check has three steps:

    1. Both module-level clients must have been created.
    2. The SSH client must still have an active transport.
    3. A cheap SFTP round-trip (listing ``'.'``) must succeed.

    An SFTP error is printed and reported as "not connected" so callers can
    simply reconnect.

    Returns:
        bool: True when the connection can be used, False otherwise.
    """
    if ssh_skill_client is None or sftp_skill_client is None:
        return False
    # get_transport() returns None for a closed or never-connected client;
    # treat that like an inactive transport instead of raising AttributeError.
    transport = ssh_skill_client.get_transport()
    if transport is None or not transport.is_active():
        return False
    try:
        sftp_skill_client.listdir('.')
    except Exception as e:
        print(f"Error checking SFTP connection: {e}")
        return False
    return True
def ucb_score(trueskill_diff, t, n, exploration=1.0):
    """Upper-confidence-bound score: negated difference plus an exploration bonus.

    Computes ``-trueskill_diff + exploration * sqrt(2 * ln(t) / n)``. Small
    epsilons keep ``t == 0`` and ``n == 0`` from dividing by zero or taking
    ``log(0)``. Scalars and NumPy arrays (broadcast elementwise) are accepted.

    Args:
        trueskill_diff: Rating difference; a larger value lowers the score.
        t: Presumably the global comparison count — confirm against caller.
        n: Presumably the comparison count for this candidate — confirm.
        exploration: Weight of the exploration bonus. The default 1.0 is the
            value that was previously hard-coded.

    Returns:
        The UCB score (NumPy scalar or array).
    """
    # log(t) is negative for t < 1 (notably the t == 0 cold start), which used to
    # make the square root NaN; clamp it at 0 so the bonus is simply 0 there.
    # For t >= 1 the clamp is a no-op, so results are unchanged.
    log_t = np.maximum(np.log(t + 1e-5), 0.0)
    exploration_term = np.sqrt((2 * log_t) / (n + 1e-5))
    return -trueskill_diff + exploration * exploration_term
def update_trueskill(ratings, ranks):
    """Re-rate rating groups using the module's shared TrueSkill environment.

    Both arguments are forwarded unchanged to ``TrueSkill.rate`` and its result
    is handed straight back to the caller.
    """
    return trueskill_env.rate(ratings, ranks=ranks)
def serialize_rating(rating):
    """Turn a rating object into a JSON-friendly dict holding its mu and sigma."""
    mean, spread = rating.mu, rating.sigma
    return dict(mu=mean, sigma=spread)
def deserialize_rating(rating_dict):
    """Rebuild a TrueSkill Rating from a dict with 'mu' and 'sigma' entries."""
    mean = rating_dict['mu']
    spread = rating_dict['sigma']
    return trueskill_env.Rating(mu=mean, sigma=spread)
def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
    """Write the skill state as JSON to the remote SSH_SKILL file over SFTP.

    Reconnects first when the shared connection is not usable. Ratings are
    stored via serialize_rating, and comparison_counts via ``.tolist()``, so it
    must be array-like with that method, e.g. a NumPy array.
    """
    if not is_connected():
        create_ssh_skill_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
    payload = json.dumps({
        'ratings': list(map(serialize_rating, ratings)),
        'comparison_counts': comparison_counts.tolist(),
        'total_comparisons': total_comparisons,
    })
    with sftp_skill_client.open(SSH_SKILL, 'w') as remote:
        remote.write(payload)
def load_json_via_sftp():
    """Read the skill state JSON from the remote SSH_SKILL file over SFTP.

    Reconnects first when the shared connection is not usable.

    Returns:
        tuple: ``(ratings, comparison_counts, total_comparisons)``, where:

        - ``ratings`` are rebuilt with deserialize_rating;
        - ``comparison_counts`` is converted to a NumPy array;
        - ``total_comparisons`` is returned exactly as stored.
    """
    if not is_connected():
        create_ssh_skill_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
    with sftp_skill_client.open(SSH_SKILL, 'r') as remote:
        state = json.load(remote)
    return (
        list(map(deserialize_rating, state['ratings'])),
        np.array(state['comparison_counts']),
        state['total_comparisons'],
    )
def update_skill(rank, model_names, k_group=4):
    """Apply one ranked comparison of several models to the stored skill state.

    The update proceeds in these steps:

    1. Load the state over SFTP.
    2. For every pair of entries in ``model_names``, in index order, run a
       two-player TrueSkill update and bump the symmetric comparison counters
       and the global total by one.
    3. Save the state back.
    4. Remove the first model's index from ``RunningPivot.running_pivot`` if
       it is present.

    Args:
        rank: Per-model rank values aligned with ``model_names``. A smaller
            value is the better result. Tied pairs still count as compared but
            leave the ratings unchanged.
        model_names: Names of the compared models. Each must appear in
            ``IMAGE_GENERATION_MODELS``.
        k_group: Unused. Kept only for backward compatibility with callers.

    Raises:
        ValueError: If a name is not in ``IMAGE_GENERATION_MODELS``
            (raised by ``list.index``).
    """
    ratings, comparison_counts, total_comparisons = load_json_via_sftp()
    # Map each model name to its index in the shared ratings / count arrays.
    group = [IMAGE_GENERATION_MODELS.index(model_name) for model_name in model_names]
    print(group)

    for player1 in range(len(group)):
        for player2 in range(player1 + 1, len(group)):
            idx1, idx2 = group[player1], group[player2]
            # TrueSkill ranks: lower is better. Equal (or incomparable) rank
            # values leave the ratings untouched, as before.
            if rank[player1] < rank[player2]:
                ranks = [0, 1]
            elif rank[player1] > rank[player2]:
                ranks = [1, 0]
            else:
                ranks = None
            if ranks is not None:
                updated = update_trueskill([[ratings[idx1]], [ratings[idx2]]], ranks)
                ratings[idx1], ratings[idx2] = updated[0][0], updated[1][0]
            comparison_counts[idx1, idx2] += 1
            comparison_counts[idx2, idx1] += 1
            total_comparisons += 1

    save_json_via_sftp(ratings, comparison_counts, total_comparisons)

    # Imported here rather than at module level.
    # NOTE(review): presumably to avoid a circular import with
    # model.matchmaker — confirm before moving it.
    from model.matchmaker import RunningPivot
    # Guard on `group`: an empty model list used to raise IndexError here.
    if group and group[0] in RunningPivot.running_pivot:
        RunningPivot.running_pivot.remove(group[0])