Spaces:
Runtime error
Runtime error
import os
import shutil
import tempfile

import gradio as gr
import requests
import torch
from safetensors.torch import save_file, save_model
def _extract_state_dict(obj):
    """Return the tensor mapping inside a loaded checkpoint, or None if unrecognized."""
    if isinstance(obj, dict):
        # Some checkpoints nest the weights under 'state_dict' or 'model';
        # otherwise assume the dict itself is the state dict.
        for key in ('state_dict', 'model'):
            if key in obj:
                obj = obj[key]
                break
    if not isinstance(obj, dict) and hasattr(obj, 'state_dict'):
        # A model object (top-level, or stored under one of the keys above).
        obj = obj.state_dict()
    return obj if isinstance(obj, dict) else None


def _tensors_for_safetensors(state_dict):
    """Return a ``{name: tensor}`` dict that ``save_file`` can serialize.

    Entries with non-string keys or non-tensor values (e.g. training
    counters) are dropped. Tensors are made contiguous, and any tensor whose
    storage was already claimed by an earlier entry (tied / shared weights)
    is cloned, because ``save_file`` refuses tensors that share memory.
    """
    tensors = {}
    seen_storage = set()
    for name, value in state_dict.items():
        if not isinstance(name, str) or not isinstance(value, torch.Tensor):
            continue
        tensor = value.detach()
        if not tensor.is_contiguous():
            # .contiguous() copies into fresh storage, so it cannot alias anything.
            tensor = tensor.contiguous()
        else:
            storage = (tensor.untyped_storage() if hasattr(tensor, 'untyped_storage')
                       else tensor.storage())
            ptr = storage.data_ptr()
            if ptr in seen_storage:
                tensor = tensor.clone()
            else:
                seen_storage.add(ptr)
        tensors[name] = tensor
    return tensors


def convert_ckpt_to_safetensors(input_path, output_path):
    """Convert a PyTorch ``.ckpt`` checkpoint into a ``.safetensors`` file.

    The checkpoint may be a bare state dict, a dict wrapping one under a
    ``'state_dict'`` or ``'model'`` key, or an object exposing
    ``state_dict()`` (including one stored under those keys). Non-tensor
    entries are dropped because safetensors can only store tensors.

    Args:
        input_path: Path of the checkpoint to read.
        output_path: Path the ``.safetensors`` file is written to.

    Returns:
        ``"Success"`` on success, otherwise a human-readable error message.

    Raises:
        Whatever ``torch.load`` raises for an unreadable file; the caller
        is expected to catch it.
    """
    # SECURITY WARNING: torch.load unpickles the file, and unpickling an
    # untrusted .ckpt can execute arbitrary code. PyTorch 2.6+ defaults to
    # weights_only=True, which rejects arbitrary objects; older versions do
    # not. Only load files from trusted sources.
    obj = torch.load(input_path, map_location='cpu')

    state_dict = _extract_state_dict(obj)
    if state_dict is None:
        return "Unsupported checkpoint format."

    tensors = _tensors_for_safetensors(state_dict)
    if not tensors:
        return "No tensors found in checkpoint."

    # save_file takes a plain {name: tensor} dict. save_model (used before)
    # expects an nn.Module and calls .state_dict() on it, so it failed here.
    try:
        save_file(tensors, output_path)
    except Exception as e:
        return f"An error occurred during saving: {e}"
    return "Success"
def _download(url, destination, timeout=60):
    """Stream the resource at ``url`` into the file ``destination``.

    ``timeout`` (seconds) bounds connecting and each socket read, so a
    stalled server cannot hang the request forever. Raises ``requests``
    exceptions on network or HTTP errors.
    """
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(destination, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)


def _uploaded_path(uploaded_file):
    """Return the on-disk path of a ``gr.File`` upload.

    Gradio 4+ passes a path string; Gradio 3 passes a temp-file wrapper
    whose ``.name`` attribute is the path. Neither should be ``.read()``.
    """
    if isinstance(uploaded_file, (str, os.PathLike)):
        return os.fspath(uploaded_file)
    return uploaded_file.name


def process(url, uploaded_file):
    """Gradio handler: fetch or accept a ``.ckpt`` and return the converted file.

    Args:
        url: Text from the URL box. It is used when non-blank and takes
            priority over an upload.
        uploaded_file: Value of the ``gr.File`` input, or ``None``.

    Returns:
        Path of the generated ``.safetensors`` file, which the ``gr.File``
        output offers for download.

    Raises:
        gr.Error: On missing input, download failure or conversion failure.
            Gradio shows the message to the user. Returning an error string
            to a ``gr.File`` output would be treated as a file path instead.
    """
    url = (url or "").strip()
    if not url and uploaded_file is None:
        raise gr.Error("Please provide a URL or upload a .ckpt file.")

    # One private directory per request. Fixed names in the working
    # directory would let concurrent users overwrite each other's files.
    work_dir = tempfile.mkdtemp(prefix="ckpt2safetensors_")
    succeeded = False
    try:
        if url:
            input_path = os.path.join(work_dir, "model.ckpt")
            output_path = os.path.join(work_dir, "model.safetensors")
            try:
                _download(url, input_path)
            except Exception as e:
                raise gr.Error(f"Failed to download file: {e}") from e
        else:
            # Read the upload in place; copying it first only wastes disk.
            input_path = _uploaded_path(uploaded_file)
            output_path = os.path.join(work_dir, "uploaded_model.safetensors")

        try:
            result = convert_ckpt_to_safetensors(input_path, output_path)
        except Exception as e:
            raise gr.Error(f"An exception occurred: {e}") from e
        if result != "Success":
            raise gr.Error(f"An error occurred during conversion: {result}")

        succeeded = True
        return output_path
    finally:
        if not succeeded:
            # Removes any partial download or output together with the directory.
            shutil.rmtree(work_dir, ignore_errors=True)
        elif url:
            # Keep the output for Gradio to serve; the downloaded input is no
            # longer needed. An uploaded file belongs to Gradio and is left alone.
            os.remove(input_path)
# Web UI: a URL box and an upload box feed `process`, whose returned path is
# offered for download through the gr.File output.
iface = gr.Interface(
    fn=process,
    inputs=[
        gr.Textbox(label="URL of .ckpt file", placeholder="Enter the URL here"),
        gr.File(label="Or upload a .ckpt file", file_types=['.ckpt']),
    ],
    outputs=gr.File(label="Converted .safetensors file"),
    title="CKPT to SafeTensors Converter",
    # The blank line keeps the security warning in its own Markdown paragraph
    # instead of running it onto the instructions.
    description="""
Convert .ckpt files to .safetensors format. Provide a URL or upload your .ckpt file.

**Security Warning:** Loading .ckpt files can execute arbitrary code. Only use files from trusted sources.
""",
)

# Only start the server when run as a script, so importing this module
# (e.g. from tests or another script) has no side effects.
if __name__ == "__main__":
    iface.launch()