import gradio as gr import subprocess import os import time from transformers import AutoTokenizer, AutoModelForCausalLM import logging from starlette.middleware.sessions import SessionMiddleware # Configure logging logging.basicConfig(level=logging.INFO) # Path to the cloned repository BITNET_REPO_PATH = "/home/user/app/BitNet" SETUP_SCRIPT = os.path.join(BITNET_REPO_PATH, "setup_env.py") INFERENCE_SCRIPT = os.path.join(BITNET_REPO_PATH, "run_inference.py") # Function to set up the environment by running setup.py def setup_bitnet(model_name): try: result = subprocess.run( f"python {SETUP_SCRIPT} --hf-repo {model_name} -q i2_s", shell=True, cwd=BITNET_REPO_PATH, capture_output=True, text=True ) if result.returncode == 0: return "Setup completed successfully!" else: return f"Error in setup: {result.stderr}" except Exception as e: return str(e) # Function to run inference using the `run_inference.py` file def run_inference(model_name, input_text, num_tokens=6): try: # Call the `run_inference.py` script with the model and input model_name = model_name.split("/")[1] start_time = time.time() if input_text is None or input_text == "": return "Please provide an input text for the model" result = subprocess.run( f"python run_inference.py -m models/{model_name}/ggml-model-i2_s.gguf -p \"{input_text}\" -n {num_tokens} -temp 0", shell=True, cwd=BITNET_REPO_PATH, capture_output=True, text=True ) end_time = time.time() if result.returncode == 0: inference_time = round(end_time - start_time, 2) return result.stdout, f"Inference took {inference_time} seconds." else: return f"Error during inference: {result.stderr}", None except Exception as e: return str(e), None def run_transformers(model_name, input_text, num_tokens): # if oauth_token is None : # return "Error : To Compare please login to your HF account and make sure you have access to the used Llama models" # Load the model and tokenizer dynamically if needed (commented out for performance) # if model_name=="TinyLlama/TinyLlama-1.1B-Chat-v1.0" : print(input_text) if input_text is None or input_text == "": return "Please provide an input text for the model", None tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(model_name) # Encode the input text input_ids = tokenizer.encode(input_text, return_tensors="pt") # Start time for inference start_time = time.time() # Generate output with the specified number of tokens output = model.generate(input_ids, max_length=len(input_ids[0]) + num_tokens, num_return_sequences=1) # Calculate inference time inference_time = time.time() - start_time # Decode the generated output generated_text = tokenizer.decode(output[0], skip_special_tokens=True) return generated_text, f"{inference_time:.2f} seconds" # Gradio Interface def interface(): with gr.Blocks(theme=gr.themes.Ocean()) as demo: # Header gr.Markdown( """

BitNet.cpp Speed Demonstration 💻

Compare the speed and performance of BitNet with popular Transformer models.

""", elem_id="header" ) # Instructions gr.Markdown( """ ### Instructions for Using the BitNet.cpp Speed Demonstration 1. **Set Up Your Project**: Begin by selecting the model you wish to use. Please note that this process may take a few minutes to complete. 2. **Select Token Count**: Choose the number of tokens you want to generate for your inference. 3. **Input Your Text**: Enter the text you wish to analyze, then compare the performance of BitNet with popular Transformer models. """, elem_id="instructions" ) # Model Selection and Setup with gr.Column(elem_id="container"): gr.Markdown("

Model Selection and Setup

") with gr.Row(): model_dropdown = gr.Dropdown( label="Select Model", choices=[ "HF1BitLLM/Llama3-8B-1.58-100B-tokens", "1bitLLM/bitnet_b1_58-3B", "1bitLLM/bitnet_b1_58-large" ], value="HF1BitLLM/Llama3-8B-1.58-100B-tokens", interactive=True ) setup_button = gr.Button("Run Setup") setup_status = gr.Textbox(label="Setup Status", interactive=False, placeholder="Setup status will appear here...") # Inference Section with gr.Column(elem_id="container"): gr.Markdown("

BitNet Inference

") with gr.Row(): num_tokens = gr.Slider( minimum=1, maximum=100, label="Number of Tokens to Generate", value=50, step=1 ) input_text = gr.Textbox( label="Input Text", placeholder="Enter your input text here...", value="Who is Zeus?" ) with gr.Row(): infer_button = gr.Button("Run Inference") result_output = gr.Textbox(label="Output", interactive=False, placeholder="Inference output will appear here...") time_output = gr.Textbox(label="Inference Time", interactive=False, placeholder="Inference time will appear here...") # Comparison with Transformers Section with gr.Column(elem_id="container"): gr.Markdown("

Compare with Transformers

") with gr.Row(): transformer_model_dropdown = gr.Dropdown( label="Select Transformers Model", choices=["TinyLlama/TinyLlama_v1.1"], value="TinyLlama/TinyLlama_v1.1", interactive=True ) input_text_tr = gr.Textbox(label="Input Text", placeholder="Enter your input text here...", value="Who is Zeus?") with gr.Row(): compare_button = gr.Button("Run Transformers Inference") transformer_result_output = gr.Textbox(label="Transformers Output", interactive=False, placeholder="Transformers output will appear here...") transformer_time_output = gr.Textbox(label="Transformers Inference Time", interactive=False, placeholder="Transformers inference time will appear here...") # Actions setup_button.click(setup_bitnet, inputs=model_dropdown, outputs=setup_status) infer_button.click(run_inference, inputs=[model_dropdown, input_text, num_tokens], outputs=[result_output, time_output]) compare_button.click(run_transformers, inputs=[transformer_model_dropdown, input_text_tr, num_tokens], outputs=[transformer_result_output, transformer_time_output]) return demo demo = interface() # # Access FastAPI app instance from Gradio # fastapi_app = demo.app # # Add SessionMiddleware to enable session management # fastapi_app.add_middleware(SessionMiddleware, secret_key="secret_key") # Use a secure, random secret key # # Launch the app demo.launch() # from fastapi import FastAPI # app = FastAPI() # # Add SessionMiddleware for sessions handling # app.add_middleware(SessionMiddleware, secret_key="secure_secret_key") # # Mount Gradio app to FastAPI at the root # app.mount("/", demo)