|
|
|
|
|
from __future__ import annotations |
|
|
|
import shutil |
|
import tempfile |
|
|
|
import gradio as gr |
|
from huggingface_hub import HfApi |
|
|
|
# UI texts; passed to gr.Interface as title/description/article at the
# bottom of this file. ``description`` is markdown-style text for the user.
title = 'Model Demo Creation'

description = '''

With this Space, you can create a demo Space for models that are loadable with `gradio.Interface.load` in [Model Hub](https://huggingface.co/models).

The Space will be created under your account and private.

You need a token with write permission (See: https://huggingface.co/settings/tokens).



You can specify multiple model names by listing them separated by commas.

If you specify multiple model names, the resulting Space will show all the outputs of those models side by side for the given inputs.

'''

article = ''

# Example rows for the UI. Each row fills the six inputs of ``run`` in order:
# Space name, comma-separated model names, Hugging Face token (left blank so
# users must supply their own), title, description, article.
examples = [

    [

        'resnet-50',

        'microsoft/resnet-50',

        '',

        'Demo for microsoft/resnet-50',

        '',

        '',

    ],

    [

        'compare-image-classification-models',

        'google/vit-base-patch16-224, microsoft/resnet-50',

        '',

        'Compare Image Classification Models',

        '',

        '',

    ],

    [

        'compare-text-generation-models',

        'EleutherAI/gpt-j-6B, EleutherAI/gpt-neo-1.3B',

        '',

        'Compare Text Generation Models',

        '',

        '',

    ],

]



# Shared Hub client created without a token. Calls made on the user's behalf
# in ``run`` (whoami, create_repo, upload_folder) pass the user's token
# explicitly; ``list_models`` is called without one, so presumably only
# public models are visible to the existence check — TODO confirm.
api = HfApi()
|
|
|
|
|
def check_if_model_exists(model_name: str) -> bool:
    """Return True if the Model Hub lists a model whose id is exactly ``model_name``.

    The Hub search results are filtered down to an exact id match, so a
    partial name does not count as existing.

    An empty or whitespace-only name returns False without querying the Hub.
    No model has such an id, and an empty search string would presumably
    match (and page through) every model on the Hub.
    """
    if not model_name.strip():
        return False
    # NOTE(review): newer huggingface_hub releases expose ``ModelInfo.id``;
    # ``modelId`` may be deprecated there — confirm against the pinned version.
    return any(info.modelId == model_name
               for info in api.list_models(search=model_name))
|
|
|
|
|
def check_if_model_loadable(model_name: str) -> bool:
    """Report whether ``gr.Interface.load`` can build an interface for a Hub model.

    Any exception raised while loading is treated as "not loadable"; the
    exception itself is discarded (deliberate best-effort probe).
    """
    loadable = True
    try:
        gr.Interface.load(model_name, src='models')
    except Exception:
        loadable = False
    return loadable
|
|
|
|
|
def get_model_io_types(
        model_name: str) -> tuple[tuple[str, ...], tuple[str, ...]]:
    """Load a Hub model as a Gradio interface and describe its components.

    Returns:
        A pair ``(inputs, outputs)``, each a tuple holding ``str()`` of the
        interface's input / output components, in declaration order.
    """
    interface = gr.Interface.load(model_name, src='models')
    # NOTE(review): the string form of a component may include per-model
    # details such as labels — confirm it only captures the component type.
    input_types = tuple(str(component)
                        for component in interface.input_components)
    output_types = tuple(str(component)
                         for component in interface.output_components)
    return input_types, output_types
|
|
|
|
|
def check_if_model_io_is_consistent(model_names: list[str]) -> bool:
    """Return True if every model has the same input and output component types.

    Each model's ``(inputs, outputs)`` signature from ``get_model_io_types``
    is compared with the first model's. Comparison stops at the first
    mismatch.

    Zero or one model is trivially consistent: True is returned without
    loading anything. Previously an empty list raised ``IndexError``.
    """
    if len(model_names) <= 1:
        return True

    reference = get_model_io_types(model_names[0])
    return all(
        get_model_io_types(name) == reference for name in model_names[1:])
|
|
|
|
|
def save_space_info(dirname: str, filename: str, content: str) -> None:
    """Write ``content`` to ``dirname/filename``, overwriting any existing file.

    The file is always encoded as UTF-8. User-supplied titles and
    descriptions may contain non-ASCII text, and the platform default
    encoding is not guaranteed to be able to represent it.
    """
    with open(f'{dirname}/{filename}', 'w', encoding='utf-8') as f:
        f.write(content)
|
|
|
|
|
def run(space_name: str, model_names_str: str, hf_token: str, title: str, |
|
description: str, article: str) -> str: |
|
if space_name == '': |
|
return 'Space Name must be specified.' |
|
if model_names_str == '': |
|
return 'Model Names must be specified.' |
|
if hf_token == '': |
|
return 'Hugging Face Token must be specified.' |
|
|
|
model_names = [name.strip() for name in model_names_str.split(',')] |
|
model_names_str = '\n'.join(model_names) |
|
|
|
missing_models = [ |
|
name for name in model_names if not check_if_model_exists(name) |
|
] |
|
if len(missing_models) > 0: |
|
message = 'The following models were not found: ' |
|
for model_name in missing_models: |
|
message += f'\n{model_name}' |
|
return message |
|
|
|
non_loadable_models = [ |
|
name for name in model_names if not check_if_model_loadable(name) |
|
] |
|
if len(non_loadable_models) > 0: |
|
message = 'The following models are not loadable with gradio.Interface.load: ' |
|
for model_name in non_loadable_models: |
|
message += f'\n{model_name}' |
|
return message |
|
|
|
if not check_if_model_io_is_consistent(model_names): |
|
return 'The inputs and outputs of each model must be the same.' |
|
|
|
user_name = api.whoami(token=hf_token)['name'] |
|
repo_id = f'{user_name}/{space_name}' |
|
try: |
|
space_url = api.create_repo(repo_id=repo_id, |
|
repo_type='space', |
|
private=True, |
|
token=hf_token, |
|
space_sdk='gradio') |
|
except Exception as e: |
|
return str(e) |
|
|
|
with tempfile.TemporaryDirectory() as temp_dir: |
|
shutil.copy('assets/template.py', f'{temp_dir}/app.py') |
|
save_space_info(temp_dir, 'TITLE', title) |
|
save_space_info(temp_dir, 'DESCRIPTION', description) |
|
save_space_info(temp_dir, 'ARTICLE', article) |
|
save_space_info(temp_dir, 'MODEL_NAMES', model_names_str) |
|
api.upload_folder(repo_id=repo_id, |
|
folder_path=temp_dir, |
|
path_in_repo='.', |
|
token=hf_token, |
|
repo_type='space') |
|
|
|
return f'Successfully created: {space_url}' |
|
|
|
|
|
# Build and launch the UI. The six text boxes map positionally onto the
# parameters of ``run`` (space_name, model_names_str, hf_token, title,
# description, article). The returned status string is shown in "Output".
space_name_box = gr.Textbox(
    label='Space Name',
    placeholder=
    'e.g. demo-resnet-50. The Space will be created under your account and private.'
)
model_names_box = gr.Textbox(label='Model Names',
                             placeholder='e.g. microsoft/resnet-50')
# The token grants write access to the user's account, so mask it on screen.
hf_token_box = gr.Textbox(
    label='Hugging Face Token',
    placeholder=
    'This should be a token with write permission. See: https://huggingface.co/settings/tokens',
    type='password',
)
title_box = gr.Textbox(label='Title (Optional)')
description_box = gr.Textbox(label='Description (Optional)')
article_box = gr.Textbox(label='Article (Optional)')

demo = gr.Interface(
    fn=run,
    inputs=[
        space_name_box,
        model_names_box,
        hf_token_box,
        title_box,
        description_box,
        article_box,
    ],
    outputs=gr.Textbox(label='Output'),
    title=title,
    description=description,
    article=article,
    examples=examples,
    cache_examples=False,
)
# NOTE(review): ``enable_queue`` is a Gradio 3.x launch option; it must be
# revisited if Gradio is upgraded.
demo.launch(enable_queue=True, share=False)
|
|