Spaces:
Running
Running
import datetime | |
import oss2 | |
import time | |
import json | |
import uuid | |
import gradio as gr | |
import regex as re | |
from .utils import * | |
from .log_utils import build_logger, get_oss_bucket | |
from .constants import OSS_FILE_PREFIX | |
# igm = image generation multi: this logger is shared by the side-by-side
# and battle vote handlers in this module.
igm_logger = build_logger(
    "gradio_web_server_image_generation_multi",
    "gr_web_image_generation_multi.log",
)

# Bucket handle that vote records are written to.
bucket = get_oss_bucket()
# Function to append content to a file in OSS | |
def append_to_oss_file(bucket, file_name, content):
    """
    Add one line of text to the end of an object stored in OSS.

    The current object is downloaded in full, ``content`` plus a trailing
    newline is concatenated onto it, and the result is uploaded again under
    the same key. A key that does not exist yet is treated as an empty
    object, so the first call creates it.

    :param bucket: bucket object providing ``get_object`` and ``put_object``
    :param file_name: full OSS key of the object (prefix included)
    :param content: text to append; the newline is added by this function
    """
    try:
        previous_text = bucket.get_object(file_name).read().decode("utf-8")
    except oss2.exceptions.NoSuchKey:
        # Nothing is stored under this key yet: start from an empty document.
        previous_text = ""

    # Whole-object rewrite: old text, then the new line, then a newline.
    bucket.put_object(file_name, "".join((previous_text, content, "\n")))
# Record a vote by appending one JSON line to the conversation log file in OSS
def vote_last_response_igm(states, vote_type, anony, request: gr.Request):
    """
    Serialize one vote as JSON and append it to the conversation log in OSS.

    Failures while writing are logged through ``igm_logger`` and not
    re-raised, so a storage problem does not propagate to the caller.

    :param states: per-model state objects; only their ``name`` is read
    :param vote_type: label stored under the ``type`` key
    :param anony: value stored unchanged under the ``anony`` key
    :param request: Gradio request, passed to ``get_ip``
    """
    log_name = get_conv_log_filename()
    print(log_name)
    oss_key = OSS_FILE_PREFIX + log_name

    payload = json.dumps(
        {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "models": [state.name for state in states],
            # One empty placeholder per state; state contents are not logged.
            "states": [{} for _ in states],
            "anony": anony,
            "ip": get_ip(request),
        }
    )

    try:
        append_to_oss_file(bucket, oss_key, payload)
        igm_logger.info(f"Successfully wrote to OSS: {oss_key}")
    except Exception as e:
        igm_logger.error(f"Failed to write to OSS: {e}")
## Image Generation Multi (IGM) Side-by-Side and Battle | |
def leftvote_last_response_igm(
    state0, state1, request: gr.Request
):
    """
    Handle a "left" vote in the named (non-anonymous) comparison.

    Logs the event, records a ``leftvote`` with ``anony=False`` and returns
    the UI updates.

    :param state0: state of Model A; its ``name`` is shown in the first label
    :param state1: state of Model B; its ``name`` is shown in the second label
    :param request: Gradio request, used for the caller's IP
    :return: 5-tuple of three ``disable_btn`` values followed by the two
        Markdown labels for Model A and Model B
    """
    igm_logger.info(f"leftvote (named). ip: {get_ip(request)}")
    vote_last_response_igm([state0, state1], "leftvote", False, request)

    label_a = gr.Markdown(f"### ⬆ Model A: {state0.name}", visible=True)
    label_b = gr.Markdown(f"### ⬇ Model B: {state1.name}", visible=True)
    return (disable_btn, disable_btn, disable_btn, label_a, label_b)
def rightvote_last_response_igm(
    state0, state1, request: gr.Request
):
    """
    Handle a "right" vote in the named (non-anonymous) comparison.

    Logs the event, records a ``rightvote`` with ``anony=False`` and returns
    the UI updates.

    :param state0: state of Model A; its ``name`` is shown in the first label
    :param state1: state of Model B; its ``name`` is shown in the second label
    :param request: Gradio request, used for the caller's IP
    :return: 5-tuple of three ``disable_btn`` values followed by the two
        Markdown labels for Model A and Model B
    """
    igm_logger.info(f"rightvote (named). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "rightvote", False, request
    )
    return (disable_btn,) * 3 + (
        # state0 is Model A: the first label must say "Model A", not "Model B".
        gr.Markdown(f"### ⬇ Model A: {state0.name}", visible=True),
        gr.Markdown(f"### ⬆ Model B: {state1.name}", visible=True),
    )
def bothbad_vote_last_response_igm(
    state0, state1, request: gr.Request
):
    """
    Handle a "both are bad" vote in the named (non-anonymous) comparison.

    Logs the event, records a ``bothbad_vote`` with ``anony=False`` and
    returns the UI updates.

    :param state0: state of Model A; its ``name`` is shown in the first label
    :param state1: state of Model B; its ``name`` is shown in the second label
    :param request: Gradio request, used for the caller's IP
    :return: 5-tuple of three ``disable_btn`` values followed by the two
        Markdown labels for Model A and Model B
    """
    igm_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
    vote_last_response_igm([state0, state1], "bothbad_vote", False, request)

    label_a = gr.Markdown(f"### ⬇ Model A: {state0.name}", visible=True)
    label_b = gr.Markdown(f"### ⬇ Model B: {state1.name}", visible=True)
    return (disable_btn, disable_btn, disable_btn, label_a, label_b)
def leftvote_last_response_igm_anony(
    state0, state1, request: gr.Request
):
    """
    Handle a "left" vote in the anonymous comparison.

    Logs the event, records a ``leftvote`` with ``anony=True`` and returns
    the UI updates.

    :param state0: state of Model A; its ``name`` is shown in the first label
    :param state1: state of Model B; its ``name`` is shown in the second label
    :param request: Gradio request, used for the caller's IP
    :return: 5-tuple of three ``disable_btn`` values followed by the two
        Markdown labels for Model A and Model B
    """
    # Tag the log line as "anony" so it matches the anony=True vote record.
    igm_logger.info(f"leftvote (anony). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "leftvote", True, request
    )
    return (disable_btn,) * 3 + (
        gr.Markdown(f"### ⬆ Model A: {state0.name}", visible=True),
        gr.Markdown(f"### ⬇ Model B: {state1.name}", visible=True),
    )
def rightvote_last_response_igm_anony(
    state0, state1, request: gr.Request
):
    """
    Handle a "right" vote in the anonymous comparison.

    Logs the event, records a ``rightvote`` with ``anony=True`` and returns
    the UI updates.

    :param state0: state of Model A; its ``name`` is shown in the first label
    :param state1: state of Model B; its ``name`` is shown in the second label
    :param request: Gradio request, used for the caller's IP
    :return: 5-tuple of three ``disable_btn`` values followed by the two
        Markdown labels for Model A and Model B
    """
    # Tag the log line as "anony" so it matches the anony=True vote record.
    igm_logger.info(f"rightvote (anony). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "rightvote", True, request
    )
    return (disable_btn,) * 3 + (
        # state0 is Model A: the first label must say "Model A", not "Model B".
        gr.Markdown(f"### ⬇ Model A: {state0.name}", visible=True),
        gr.Markdown(f"### ⬆ Model B: {state1.name}", visible=True),
    )
def bothbad_vote_last_response_igm_anony(
    state0, state1, request: gr.Request
):
    """
    Handle a "both are bad" vote in the anonymous comparison.

    Logs the event, records a ``bothbad_vote`` with ``anony=True`` and
    returns the UI updates.

    :param state0: state of Model A; its ``name`` is shown in the first label
    :param state1: state of Model B; its ``name`` is shown in the second label
    :param request: Gradio request, used for the caller's IP
    :return: 5-tuple of three ``disable_btn`` values followed by the two
        Markdown labels for Model A and Model B
    """
    # Tag the log line as "anony" so it matches the anony=True vote record.
    igm_logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}")
    vote_last_response_igm(
        [state0, state1], "bothbad_vote", True, request
    )
    return (disable_btn,) * 3 + (
        gr.Markdown(f"### ⬇ Model A: {state0.name}", visible=True),
        gr.Markdown(f"### ⬇ Model B: {state1.name}", visible=True),
    )
# JavaScript callback source, kept as a string.
# It renders the element with id "share-region-named" through html2canvas,
# triggers a download of the result as "chatbot-arena.png", and returns its
# four arguments unchanged.
# NOTE(review): html2canvas is presumably loaded by the page elsewhere — confirm.
share_js = """
function (a, b, c, d) {
    const captureElement = document.querySelector('#share-region-named');
    html2canvas(captureElement)
        .then(canvas => {
            canvas.style.display = 'none'
            document.body.appendChild(canvas)
            return canvas
        })
        .then(canvas => {
            const image = canvas.toDataURL('image/png')
            const a = document.createElement('a')
            a.setAttribute('download', 'chatbot-arena.png')
            a.setAttribute('href', image)
            a.click()
            canvas.remove()
        });
    return [a, b, c, d];
}
"""
def share_click_igm(state0, state1, model_selector0, model_selector1, request: gr.Request):
    """
    Log a click on the share button and, when both states exist, record it.

    :param state0: first model state, or ``None`` if nothing was generated
    :param state1: second model state, or ``None`` if nothing was generated
    :param model_selector0: first model selector value
    :param model_selector1: second model selector value
    :param request: Gradio request, used for the caller's IP
    """
    igm_logger.info(f"share (anony). ip: {get_ip(request)}")

    # Nothing to record unless both sides have a state.
    if state0 is None or state1 is None:
        return

    # NOTE(review): the selector pair is passed in the position that
    # vote_last_response_igm names ``anony`` — confirm this is intended.
    selectors = [model_selector0, model_selector1]
    vote_last_response_igm([state0, state1], "share", selectors, request)
## All Generation Gradio Interface | |
class ImageStateIG:
    """State of one image-generation conversation for a single model."""

    def __init__(self, model_name):
        """
        Create an empty state with a fresh random conversation id.

        :param model_name: name of the model this state belongs to
        """
        self.model_name = model_name
        self.conv_id = uuid.uuid4().hex
        # Prompt and output start unset.
        self.prompt = self.output = None

    def dict(self):
        """
        Return the loggable fields as a new dictionary.

        ``output`` is not part of the result.
        """
        exported = ("conv_id", "model_name", "prompt")
        return {field: getattr(self, field) for field in exported}
class ImageStateIE:
    """State of one image-editing conversation for a single model."""

    def __init__(self, model_name):
        """
        Create an empty state with a fresh random conversation id.

        :param model_name: name of the model this state belongs to
        """
        self.model_name = model_name
        self.conv_id = uuid.uuid4().hex
        # All prompts, the source image and the output start unset.
        self.source_prompt = self.target_prompt = self.instruct_prompt = None
        self.source_image = self.output = None

    def dict(self):
        """
        Return the loggable fields as a new dictionary.

        ``source_image`` and ``output`` are not part of the result.
        """
        exported = (
            "conv_id",
            "model_name",
            "source_prompt",
            "target_prompt",
            "instruct_prompt",
        )
        return {field: getattr(self, field) for field in exported}
class VideoStateVG:
    """State of one video-generation conversation for a single model."""

    def __init__(self, model_name):
        """
        Create an empty state with a fresh random conversation id.

        :param model_name: name of the model this state belongs to
        """
        self.model_name = model_name
        self.conv_id = uuid.uuid4().hex
        # Prompt and output start unset.
        self.prompt = self.output = None

    def dict(self):
        """
        Return the loggable fields as a new dictionary.

        ``output`` is not part of the result.
        """
        exported = ("conv_id", "model_name", "prompt")
        return {field: getattr(self, field) for field in exported}