"""Gradio playground that streams code completions from StarCoder endpoints.

The script reads its configuration from the environment, builds one
``text_generation.Client`` per model version (StarCoder / StarCoderBase),
declares the Gradio UI and launches it.
"""
import sys
import os
import logging as log
from typing import Generator

import gradio as gr
from gradio.themes.utils import sizes
from text_generation import Client

from src.request import StarCoderRequest, StarCoderRequestConfig
from src.utils import (
    get_file_as_string,
    get_sections,
    get_url_from_env_or_default_path,
    preview,
)
from constants import (
    FIM_MIDDLE,
    FIM_PREFIX,
    FIM_SUFFIX,
    END_OF_TEXT,
    MIN_TEMPERATURE,
)
from settings import (
    DEFAULT_PORT,
    DEFAULT_STARCODER_API_PATH,
    DEFAULT_STARCODER_BASE_API_PATH,
)

HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Gracefully exit the app if the HF_TOKEN is not set: print the error and the
# expected setup to stderr (instead of raising an exception) and stop with a
# non-zero status.
if not HF_TOKEN:
    ERR_MSG = """
    Please set the HF_TOKEN environment variable with your Hugging Face API token.
    You can get one by signing up at https://huggingface.co/join and then visiting
    https://huggingface.co/settings/tokens."""
    print(ERR_MSG, file=sys.stderr)
    sys.exit(1)

API_URL_STAR = get_url_from_env_or_default_path("STARCODER_API", DEFAULT_STARCODER_API_PATH)
API_URL_BASE = get_url_from_env_or_default_path("STARCODER_BASE_API", DEFAULT_STARCODER_BASE_API_PATH)

preview("StarCoder Model URL", API_URL_STAR)
preview("StarCoderBase Model URL", API_URL_BASE)
preview("HF Token", HF_TOKEN, ofuscate=True)

# Static assets injected into the page (stylesheet, share-button script, icons).
_styles = get_file_as_string("styles.css")
_script = get_file_as_string("community-btn.js")
_sharing_icon_svg = get_file_as_string("community-icon.svg")
_loading_icon_svg = get_file_as_string("loading-icon.svg")

# Load the whole content of ./README.md and unpack its first four
# "---"-separated sections into their own variables.
readme_file_content = get_file_as_string("README.md", path='./')
(
    manifest,
    description,
    disclaimer,
    formats,
) = get_sections(readme_file_content, "---", up_to=4)

theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=sizes.radius_sm,
    font=[
        gr.themes.GoogleFont("IBM Plex Sans", [400, 600]),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    text_size=sizes.text_lg,
)

HEADERS = {
    "Authorization": f"Bearer {HF_TOKEN}",
}

# One client per selectable model version; both share the same auth header.
client_star = Client(API_URL_STAR, headers=HEADERS)
client_base = Client(API_URL_BASE, headers=HEADERS)


def get_tokens_collector(request: StarCoderRequest) -> Generator[str, None, None]:
    """Stream the text of each generated token for *request*.

    The client is chosen from ``request.settings.version``: ``"StarCoder"``
    uses ``client_star``, anything else uses ``client_base``.

    Yields:
        The ``text`` of every streamed token whose ``id`` is not 0.
    """
    model_client = client_star if request.settings.version == "StarCoder" else client_base
    stream = model_client.generate_stream(request.prompt, **request.settings.kwargs())
    for response in stream:
        # NOTE(review): id 0 is presumably the end-of-text token (END_OF_TEXT);
        # confirm against the tokenizer before switching to a text comparison.
        if response.token.id != 0:
            yield response.token.text


def get_tokens_accumulator(request: StarCoderRequest) -> Generator[str, None, None]:
    """Yield the progressively growing output text for *request*.

    The output starts with ``request.prefix`` when ``request.fim_mode`` is
    truthy, otherwise with ``request.prompt``. Each generated token is appended
    and the accumulated text is yielded after every token, so a UI can refresh
    as tokens arrive.

    Yields:
        The accumulated text after each token; then, in FIM mode, the text with
        ``request.suffix`` appended; finally the text with a trailing newline.
    """
    # Start with the prefix (if in fim_mode), otherwise with the full prompt.
    output = request.prefix if request.fim_mode else request.prompt
    for token in get_tokens_collector(request=request):
        output += token
        yield output

    # After the last token, append the suffix (if in fim_mode).
    if request.fim_mode:
        output += request.suffix
        yield output

    # Append an extra line at the end.
    yield output + '\n'


def get_tokens_linker(request: StarCoderRequest) -> str:
    """Return every generated token of *request* joined into a single string."""
    return "".join(get_tokens_collector(request))


def _build_request(
    prompt: str,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    repetition_penalty: float,
    version: str,
) -> StarCoderRequest:
    """Bundle the UI inputs into a ``StarCoderRequest``."""
    return StarCoderRequest(
        prompt=prompt,
        settings=StarCoderRequestConfig(
            version=version,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        ),
    )


def generate(
    prompt: str,
    temperature: float = 0.9,
    max_new_tokens: int = 256,
    top_p: float = 0.95,
    repetition_penalty: float = 1.0,
    version: str = "StarCoder",
) -> Generator[str, None, None]:
    """Stream the accumulated completion for *prompt* (handler of "Generate").

    Args:
        prompt: Code entered by the user.
        temperature: Sampling temperature forwarded to the request settings.
        max_new_tokens: Maximum number of new tokens to request.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Penalty applied to repeated tokens.
        version: Model version, ``"StarCoder"`` or ``"StarCoderBase"``.

    Yields:
        The progressively growing output, see ``get_tokens_accumulator``.
    """
    request = _build_request(
        prompt, temperature, max_new_tokens, top_p, repetition_penalty, version
    )
    yield from get_tokens_accumulator(request)


def process_example(
    prompt: str,
    temperature: float = 0.9,
    max_new_tokens: int = 256,
    top_p: float = 0.95,
    repetition_penalty: float = 1.0,
    version: str = "StarCoder",
) -> Generator[str, None, None]:
    """Produce the complete generated text for an example *prompt*.

    Takes the same arguments as ``generate``.

    Yields:
        A single item: all generated tokens joined into one string.
    """
    request = _build_request(
        prompt, temperature, max_new_tokens, top_p, repetition_penalty, version
    )
    # ``get_tokens_linker`` returns a plain ``str``; ``yield from`` on it would
    # emit one character per step, so the joined text is yielded once instead.
    yield get_tokens_linker(request)


# TODO: move the examples into the README too
examples = [
    "X_train, y_train, X_test, y_test = train_test_split(X, y, test_size=0.1)\n\n# Train a logistic regression model, predict the labels on the test set and compute the accuracy score",
    "// Returns every other value in the array as a new array.\nfunction everyOther(arr) {",
    "def alternating(list1, list2):\n results = []\n for i in range(min(len(list1), len(list2))):\n results.append(list1[i])\n results.append(list2[i])\n if len(list1) > len(list2):\n \n else:\n results.extend(list2[i+1:])\n return results",
]


with gr.Blocks(theme=theme, analytics_enabled=False, css=_styles) as demo:
    with gr.Column():
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                instruction = gr.Textbox(
                    placeholder="Enter your code here",
                    label="Code",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                output = gr.Code(elem_id="q-output", lines=30)
                with gr.Row():
                    with gr.Column():
                        with gr.Accordion("Advanced settings", open=False):
                            with gr.Row():
                                column_1, column_2 = gr.Column(), gr.Column()
                                with column_1:
                                    temperature = gr.Slider(
                                        label="Temperature",
                                        value=0.2,
                                        minimum=0.0,
                                        maximum=1.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values produce more diverse outputs",
                                    )
                                    max_new_tokens = gr.Slider(
                                        label="Max new tokens",
                                        value=256,
                                        minimum=0,
                                        maximum=8192,
                                        step=64,
                                        interactive=True,
                                        info="The maximum numbers of new tokens",
                                    )
                                with column_2:
                                    top_p = gr.Slider(
                                        label="Top-p (nucleus sampling)",
                                        value=0.90,
                                        minimum=0.0,
                                        maximum=1,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values sample more low-probability tokens",
                                    )
                                    repetition_penalty = gr.Slider(
                                        label="Repetition penalty",
                                        value=1.2,
                                        minimum=1.0,
                                        maximum=2.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Penalize repeated tokens",
                                    )
                    with gr.Column():
                        version = gr.Dropdown(
                            ["StarCoderBase", "StarCoder"],
                            value="StarCoder",
                            label="Version",
                            info="",
                        )
                gr.Markdown(disclaimer)
                with gr.Group(elem_id="share-btn-container"):
                    community_icon = gr.HTML(_sharing_icon_svg, visible=True)
                    loading_icon = gr.HTML(_loading_icon_svg, visible=True)
                    share_button = gr.Button(
                        "Share to community", elem_id="share-btn", visible=True
                    )
                gr.Examples(
                    examples=examples,
                    inputs=[instruction],
                    cache_examples=False,
                    fn=process_example,
                    outputs=[output],
                )
                gr.Markdown(formats)

    # Stream the generation into the code panel when "Generate" is clicked.
    submit.click(
        generate,
        inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty, version],
        outputs=[output],
        max_batch_size=8,
        show_progress=True,
    )
    # The share button has no Python handler; it only runs the front-end script.
    share_button.click(None, [], [], _js=_script)

demo.queue(concurrency_count=16).launch(debug=True, server_port=DEFAULT_PORT)