# K-Sort-Arena: serve/update_skill.py
# Author: ksort · commit 7b853d0 ("enable ucb") · 4.15 kB
import numpy as np
import json
from trueskill import TrueSkill
import paramiko
import io, os
import sys
from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_SKILL
# Module-wide TrueSkill environment built with the library's default
# parameters; used for rating updates and for rebuilding stored ratings.
trueskill_env = TrueSkill()
# Appends the parent of the *current working directory* (not of this file) to
# the import path, presumably so `model.models` resolves when the server is
# launched from inside serve/ — confirm it is still needed. It must stay
# before the import on the next line.
sys.path.append('../')
from model.models import IMAGE_GENERATION_MODELS
# Cached SSH / SFTP connections to the skill-file server. Both start as None
# and are (re)assigned by create_ssh_skill_client().
ssh_skill_client = None
sftp_skill_client = None
def _close_skill_clients():
    """Close the cached SFTP and SSH clients (if any) and reset both globals to None."""
    global ssh_skill_client, sftp_skill_client
    for client in (sftp_skill_client, ssh_skill_client):
        if client is not None:
            try:
                client.close()
            except Exception as e:
                # Best effort: a half-dead connection may fail to close cleanly.
                print(f"Error closing skill client: {e}")
    ssh_skill_client = None
    sftp_skill_client = None


def create_ssh_skill_client(server, port, user, password):
    """(Re)open the SSH connection and SFTP session used for the skill file.

    Any previously cached clients are closed first, so repeated reconnects do
    not leak sockets. The module globals ``ssh_skill_client`` and
    ``sftp_skill_client`` are only set once the connection and the SFTP session
    are both open. If anything fails, they stay None and the exception
    propagates.

    Args:
        server: Hostname or address of the SSH server.
        port: SSH port.
        user: Login user name.
        password: Login password.
    """
    global ssh_skill_client, sftp_skill_client
    _close_skill_clients()

    client = paramiko.SSHClient()
    client.load_system_host_keys()
    # NOTE(review): AutoAddPolicy silently trusts unknown host keys, which
    # gives no protection against a man-in-the-middle. Consider pinning the
    # server key in known_hosts and using RejectPolicy.
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        client.connect(server, port, user, password)
        # Send a keepalive after 60 s of idle time so idle connections
        # are not dropped between votes.
        client.get_transport().set_keepalive(60)
        sftp = client.open_sftp()
    except Exception:
        client.close()
        raise

    ssh_skill_client = client
    sftp_skill_client = sftp
def is_connected():
    """Return True if the cached SSH and SFTP clients are both usable.

    Returns False in these cases:
    - either client was never created;
    - the SSH client has no transport (closed or never connected);
    - the transport is inactive;
    - a trivial SFTP request fails.
    """
    if ssh_skill_client is None or sftp_skill_client is None:
        return False
    # get_transport() returns None once the client is closed or if connect()
    # never succeeded, so check for None before asking whether it is active.
    transport = ssh_skill_client.get_transport()
    if transport is None or not transport.is_active():
        return False
    # Check that the SFTP channel still answers requests.
    try:
        sftp_skill_client.listdir('.')  # list the current remote directory
    except Exception as e:
        print(f"Error checking SFTP connection: {e}")
        return False
    return True
def ucb_score(trueskill_diff, t, n, exploration_weight=1.0):
    """Upper-confidence-bound score.

    Computes ``-trueskill_diff + exploration_weight * sqrt(2 * ln(t) / n)``.
    A 1e-5 offset is added to ``t`` and ``n`` so that ``n == 0`` does not
    divide by zero. The computation is element-wise, so NumPy arrays work
    as well as scalars.

    Args:
        trueskill_diff: Skill difference term; a smaller value gives a higher
            score.
        t: Presumably the total number of comparisons so far — confirm with
            the caller.
        n: Presumably how often this candidate has been compared — confirm
            with the caller.
        exploration_weight: Multiplier on the exploration bonus. The default
            1.0 reproduces the previously hard-coded weight.

    Returns:
        The UCB score (float or ndarray).

    Note:
        When ``t + 1e-5 < 1`` (e.g. ``t == 0``) the log is negative. The
        square root then yields NaN and NumPy emits a RuntimeWarning.
    """
    exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
    return -trueskill_diff + exploration_weight * exploration_term
def update_trueskill(ratings, ranks):
    """Run one rating round in the module's shared TrueSkill environment.

    ``ratings`` and ``ranks`` are handed unchanged to ``trueskill_env.rate``,
    and its result is returned as-is.
    """
    return trueskill_env.rate(ratings, ranks)
def serialize_rating(rating):
    """Turn a rating object into a JSON-friendly ``{'mu': ..., 'sigma': ...}`` dict."""
    mean, spread = rating.mu, rating.sigma
    return dict(mu=mean, sigma=spread)
def deserialize_rating(rating_dict):
    """Rebuild a rating from a ``{'mu': ..., 'sigma': ...}`` dict via the shared environment."""
    mean = rating_dict['mu']
    spread = rating_dict['sigma']
    return trueskill_env.Rating(mu=mean, sigma=spread)
def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
    """Write the skill state as JSON to the remote ``SSH_SKILL`` path.

    Reconnects first if the cached connection is unusable. It then overwrites
    the remote file with an object holding ``ratings``, ``comparison_counts``
    and ``total_comparisons``.

    Args:
        ratings: Iterable of rating objects exposing ``mu`` and ``sigma``.
        comparison_counts: Array-like of comparison counts (presumably an
            N x N matrix). Both a NumPy array and a plain nested list are
            accepted.
        total_comparisons: Total number of comparisons. A NumPy scalar is
            accepted as well as a Python number.
    """
    global sftp_skill_client
    if not is_connected():
        create_ssh_skill_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
    data = {
        'ratings': [serialize_rating(r) for r in ratings],
        # np.asarray(...).tolist() works for ndarrays and nested lists and
        # produces built-in Python numbers that json.dumps can encode.
        'comparison_counts': np.asarray(comparison_counts).tolist(),
        # json.dumps rejects NumPy scalars such as np.int64. A 0-d array's
        # tolist() turns them (and plain ints/floats) into Python numbers.
        'total_comparisons': np.asarray(total_comparisons).tolist(),
    }
    json_data = json.dumps(data)
    # NOTE(review): the file is truncated and rewritten in place. A crash
    # mid-write, or a concurrent reader, can see partial JSON. Consider
    # writing to a temp path and renaming it over SSH_SKILL.
    with sftp_skill_client.open(SSH_SKILL, 'w') as f:
        f.write(json_data)
def load_json_via_sftp():
    """Read the skill state back from the remote ``SSH_SKILL`` JSON file.

    Reconnects first when the cached connection is unusable.

    Returns:
        Tuple ``(ratings, comparison_counts, total_comparisons)``:
        - a list of ratings rebuilt with deserialize_rating;
        - the stored counts converted to a NumPy array;
        - the stored total exactly as decoded from JSON.
    """
    if not is_connected():
        create_ssh_skill_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
    with sftp_skill_client.open(SSH_SKILL, 'r') as remote_file:
        payload = json.load(remote_file)
    restored_ratings = list(map(deserialize_rating, payload['ratings']))
    counts_matrix = np.array(payload['comparison_counts'])
    return restored_ratings, counts_matrix, payload['total_comparisons']
def update_skill(rank, model_names, k_group=4):
    """Apply one arena vote to the TrueSkill state stored on the skill server.

    Steps:
    1. Load ratings and comparison counts with load_json_via_sftp().
    2. Play every pair of voted models as a 1-vs-1 TrueSkill match.
    3. Write the updated state back with save_json_via_sftp().
    4. Remove the first voted model from ``RunningPivot.running_pivot`` if it
       is listed there.

    Args:
        rank: Per-model placements aligned with ``model_names``. A smaller
            value beats a larger one; equal values are a tie.
        model_names: Names of the voted models. Each must appear in
            IMAGE_GENERATION_MODELS.
        k_group: Unused; kept so existing callers keep working.

    Raises:
        ValueError: If a name in ``model_names`` is not in
            IMAGE_GENERATION_MODELS (raised by ``list.index``).
    """
    # NOTE(review): load -> update -> save is not atomic. Two concurrent
    # votes can overwrite each other's update; confirm votes are serialized
    # upstream.
    ratings, comparison_counts, total_comparisons = load_json_via_sftp()
    # Global indices of the voted models in the ratings / count tables.
    group = [IMAGE_GENERATION_MODELS.index(model_name) for model_name in model_names]
    print(group)

    pairwise_comparisons = [
        (i, j) for i in range(len(group)) for j in range(i + 1, len(group))
    ]
    for player1, player2 in pairwise_comparisons:
        idx1, idx2 = group[player1], group[player2]
        # TrueSkill ranks: the team with the lower rank wins the match.
        if rank[player1] < rank[player2]:
            ranks = [0, 1]
        elif rank[player1] > rank[player2]:
            ranks = [1, 0]
        else:
            # Tie: ratings are deliberately left unchanged, but the pair still
            # counts as compared. NOTE(review): TrueSkill can model a draw with
            # ranks [0, 0] if ties should move ratings.
            ranks = None
        if ranks is not None:
            updated_ratings = update_trueskill([[ratings[idx1]], [ratings[idx2]]], ranks)
            ratings[idx1], ratings[idx2] = updated_ratings[0][0], updated_ratings[1][0]
        # The count matrix is kept symmetric.
        comparison_counts[idx1, idx2] += 1
        comparison_counts[idx2, idx1] += 1
        total_comparisons += 1

    save_json_via_sftp(ratings, comparison_counts, total_comparisons)

    # Imported locally, presumably to avoid an import cycle with
    # model.matchmaker — confirm.
    from model.matchmaker import RunningPivot
    # Guard against an empty vote; group[0] would otherwise raise IndexError.
    if group and group[0] in RunningPivot.running_pivot:
        RunningPivot.running_pivot.remove(group[0])