"""Gradio demo: rank pretrained language models for a dataset with TransformerRanker.

The UI walks through three steps: load a Hugging Face dataset, pick a list of
language models, and run the transferability ranking on a downsampled slice
of the dataset.
"""

import traceback

import gradio as gr
from datasets import disable_caching, load_dataset
from transformer_ranker import TransformerRanker, prepare_popular_models

from utils import (
    DISABLED_BUTTON_VARIANT,
    ENABLED_BUTTON_VARIANT,
    CSS,
    HEADLINE,
    FOOTER,
    EmbeddingProgressTracker,
    check_dataset_exists,
    check_dataset_is_loaded,
    compute_ratio,
    ensure_one_lm_selected,
    get_dataset_info,
)

disable_caching()

THEME = "pseudolab/huggingface-korea-theme"
DEFAULT_SAMPLES = 1000
MAX_SAMPLES = 5000
LANGUAGE_MODELS = prepare_popular_models('base') + prepare_popular_models('large')

# Add a tiny model for demonstration on CPU
# (dict.fromkeys drops duplicate handles while keeping their first-seen order).
LANGUAGE_MODELS = ['prajjwal1/bert-tiny'] + list(dict.fromkeys(LANGUAGE_MODELS))
# NOTE(review): .index() raises ValueError if "bert-base-cased" is ever missing
# from the prepared model lists — confirm against prepare_popular_models.
LANGUAGE_MODELS.insert(LANGUAGE_MODELS.index("bert-base-cased") + 1, "bert-base-uncased")

# Preselect some small models
DEFAULT_MODELS = [
    "prajjwal1/bert-tiny",
    "google/electra-small-discriminator",
    "distilbert-base-cased",
    "sentence-transformers/all-MiniLM-L12-v2",
]

with gr.Blocks(css=CSS, theme=THEME) as demo:

    ########## STEP 1: Load the Dataset ##########

    gr.Markdown(HEADLINE)

    gr.Markdown("## Step 1: Load a Dataset")

    with gr.Group():
        dataset = gr.State(None)
        dataset_name = gr.Textbox(
            label="Enter the name of your dataset",
            placeholder="Examples: trec, ag_news, sst2, conll2003, leondz/wnut_17",
            max_lines=1,
        )

        select_dataset_button = gr.Button(
            value="Load dataset", interactive=False, variant=DISABLED_BUTTON_VARIANT
        )

        # Activate the "Load dataset" button if dataset was found
        dataset_name.change(
            check_dataset_exists, inputs=dataset_name, outputs=select_dataset_button
        )

        gr.Markdown(
            "*The number of samples that can be used in this demo is limited to save resources. "
            "To run an estimate on the full dataset, check out the "
            "[library](https://github.com/flairNLP/transformer-ranker).*"
        )

    ########## Step 1.1 Dataset preprocessing ##########

    with gr.Accordion("Dataset settings", open=False) as dataset_config:
        with gr.Row() as dataset_details:
            dataset_name_label = gr.Label("", label="Dataset Name")
            num_samples = gr.State(0)
            num_samples_label = gr.Label("", label="Number of Samples")
            # Mirror the numeric sample count into its display label.
            num_samples.change(
                lambda x: str(x), inputs=[num_samples], outputs=[num_samples_label]
            )

        with gr.Row():
            text_column = gr.Dropdown("", label="Text Column")
            text_pair_column = gr.Dropdown("", label="Text Pair Column")

        with gr.Row():
            label_column = gr.Dropdown("", label="Label Column")
            task_category = gr.Dropdown("", label="Task Type")

        with gr.Group():
            downsample_ratio = gr.State(0.0)
            num_samples_to_use = gr.Slider(
                20, MAX_SAMPLES, label="Samples to use", value=DEFAULT_SAMPLES, step=1
            )
            downsample_ratio_label = gr.Label("", label="Ratio of dataset to use")
            # Show the ratio as a percentage with one decimal place.
            downsample_ratio.change(
                lambda x: f"{x:.1%}",
                inputs=[downsample_ratio],
                outputs=[downsample_ratio_label],
            )

        # Recompute the ratio whenever either the slider or the dataset size changes.
        num_samples_to_use.change(
            compute_ratio,
            inputs=[num_samples_to_use, num_samples],
            outputs=downsample_ratio,
        )
        num_samples.change(
            compute_ratio,
            inputs=[num_samples_to_use, num_samples],
            outputs=downsample_ratio,
        )

    # Download the dataset and show details
    def select_dataset(dataset_name):
        """Load the named dataset and return the updates for the Step 1 outputs.

        Args:
            dataset_name: Dataset identifier typed into the textbox.

        Returns:
            A tuple matching the ``outputs`` of ``select_dataset_button.click``:
            the button update, the opened settings accordion, the dataset name,
            the loaded dataset, followed by the values unpacked from
            ``get_dataset_info``.

        Raises:
            gr.Error: If loading or inspecting the dataset raises ``ValueError``.
        """
        try:
            dataset = load_dataset(dataset_name, trust_remote_code=True)
            dataset_info = get_dataset_info(dataset)
        except ValueError as err:
            # Abort the event here: a non-raising gr.Warning would fall through
            # to the return below with `dataset`/`dataset_info` unbound and
            # crash with UnboundLocalError instead of showing this message.
            raise gr.Error(
                "Dataset collections are not supported. Please use a single dataset."
            ) from err

        return (
            gr.update(value="Loaded", interactive=False, variant=DISABLED_BUTTON_VARIANT),
            gr.Accordion(open=True),
            dataset_name,
            dataset,
            *dataset_info,
        )

    select_dataset_button.click(
        select_dataset,
        inputs=[dataset_name],
        outputs=[
            select_dataset_button,
            dataset_config,
            dataset_name_label,
            dataset,
            task_category,
            text_column,
            text_pair_column,
            label_column,
            num_samples,
        ],
        scroll_to_output=True,
    )

    ########## STEP 2 ##########

    gr.Markdown("## Step 2: Select a List of Language Models")

    with gr.Group():
        # Show only the model name in the UI, but keep the full handle as the value.
        model_options = [
            (model_handle.split("/")[-1], model_handle)
            for model_handle in LANGUAGE_MODELS
        ]
        models = gr.CheckboxGroup(
            choices=model_options, label="Select Models", value=DEFAULT_MODELS
        )

    ########## STEP 3: Run Language Model Ranking ##########

    gr.Markdown("## Step 3: Rank LMs")

    with gr.Group():
        with gr.Accordion("Advanced settings", open=False):
            with gr.Row():
                estimator = gr.Dropdown(
                    choices=["hscore", "logme", "knn"],
                    label="Transferability metric",
                    value="hscore",
                )
                layer_pooling_options = ["lastlayer", "layermean", "bestlayer"]
                layer_pooling = gr.Dropdown(
                    choices=layer_pooling_options,
                    label="Layer pooling",
                    value="layermean",
                )

        submit_button = gr.Button(
            "Run Ranking", interactive=False, variant=DISABLED_BUTTON_VARIANT
        )

        # Make button active if the dataset is loaded
        dataset.change(
            check_dataset_is_loaded,
            inputs=[dataset, text_column, label_column, task_category],
            outputs=submit_button,
        )

        label_column.change(
            check_dataset_is_loaded,
            inputs=[dataset, text_column, label_column, task_category],
            outputs=submit_button,
        )

        text_column.change(
            check_dataset_is_loaded,
            inputs=[dataset, text_column, label_column, task_category],
            outputs=submit_button,
        )

        def rank_models(
            dataset,
            downsample_ratio,
            selected_models,
            layer_pooling,
            estimator,
            text_column,
            text_pair_column,
            label_column,
            task_category,
            progress=gr.Progress(),
        ):
            """Rank the selected models on the loaded dataset.

            Args:
                dataset: Dataset held in the ``dataset`` state.
                downsample_ratio: Value passed as ``dataset_downsample``.
                selected_models: Model handles chosen in the checkbox group.
                layer_pooling: Value passed as ``layer_aggregator``.
                estimator: Transferability metric name.
                text_column: Text column name; ``"-"`` means unset.
                text_pair_column: Optional second text column; ``"-"`` is
                    converted to ``None``.
                label_column: Label column name; ``"-"`` means unset.
                task_category: Task type; ``"-"`` means unset.
                progress: Gradio progress tracker (default is the Gradio idiom
                    for requesting progress reporting).

            Returns:
                A list of ``(rank, model, score)`` tuples, best score first.

            Raises:
                gr.Error: If a required column/task is unset, or if ranking
                    fails for any reason.
            """
            if text_column == "-":
                raise gr.Error("Text column is not set.")

            if label_column == "-":
                raise gr.Error("Label column is not set.")

            if task_category == "-":
                raise gr.Error(
                    "Task category is not set. The dataset must support classification or regression tasks."
                )

            if text_pair_column == "-":
                text_pair_column = None

            progress(0.0, "Starting")

            with EmbeddingProgressTracker(progress=progress, model_names=selected_models) as tracker:
                try:
                    ranker = TransformerRanker(
                        dataset,
                        dataset_downsample=downsample_ratio,
                        text_column=text_column,
                        text_pair_column=text_pair_column,
                        label_column=label_column,
                        task_category=task_category,
                    )

                    results = ranker.run(
                        models=selected_models,
                        layer_aggregator=layer_pooling,
                        estimator=estimator,
                        batch_size=64,
                        tracker=tracker,
                    )

                    # Highest transferability score first.
                    sorted_results = sorted(
                        results._results.items(), key=lambda item: item[1], reverse=True
                    )
                    return [
                        (i + 1, model, score)
                        for i, (model, score) in enumerate(sorted_results)
                    ]
                except Exception as e:
                    # Log the underlying failure; the user only sees the
                    # generic message below.
                    traceback.print_exc()
                    # gr.Error must be raised to reach the UI; merely
                    # constructing it silently swallowed the failure.
                    raise gr.Error("The dataset is not supported.") from e

    gr.Markdown("## Results")
    ranking_results = gr.Dataframe(
        headers=["Rank", "Model", "Score"], datatype=["number", "str", "number"]
    )

    submit_button.click(
        rank_models,
        inputs=[
            dataset,
            downsample_ratio,
            models,
            layer_pooling,
            estimator,
            text_column,
            text_pair_column,
            label_column,
            task_category,
        ],
        outputs=ranking_results,
        scroll_to_output=True,
    )

    gr.Markdown(
        "*The results are ranked by their transferability score, with the most suitable model listed first. "
        "This ranking allows focusing on the higher-ranked models for further exploration and fine-tuning.*"
    )

    gr.Markdown(FOOTER)

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=3)
    demo.launch(max_threads=6)