# Space page header (scraped): sanjayw's picture
# Duplicate from Fisharp/starcoder-playground
# Commit: c9c9be5
import sys
import os
import logging as log
from typing import Generator
import gradio as gr
from gradio.themes.utils import sizes
from text_generation import Client
from src.request import StarCoderRequest, StarCoderRequestConfig
from src.utils import (
get_file_as_string,
get_sections,
get_url_from_env_or_default_path,
preview
)
from constants import (
FIM_MIDDLE,
FIM_PREFIX,
FIM_SUFFIX,
END_OF_TEXT,
MIN_TEMPERATURE,
)
from settings import (
DEFAULT_PORT,
DEFAULT_STARCODER_API_PATH,
DEFAULT_STARCODER_BASE_API_PATH,
)
HF_TOKEN = os.getenv("HF_TOKEN")

# Every call to the inference endpoints is authenticated with this token, so the
# app cannot do anything useful without it. Rather than raising (and dumping a
# traceback), print a human-readable hint to stderr and stop with exit status 1.
if not HF_TOKEN:
    ERR_MSG = """
Please set the HF_TOKEN environment variable with your Hugging Face API token.
You can get one by signing up at https://huggingface.co/join and then visiting
https://huggingface.co/settings/tokens."""
    print(ERR_MSG, file=sys.stderr)
    raise SystemExit(1)
# Resolve the two inference endpoints (StarCoder and StarCoderBase). Judging by
# the helper's name, each URL comes from the given environment variable or falls
# back to the default path from settings. NOTE(review): confirm in src.utils.
API_URL_STAR = get_url_from_env_or_default_path("STARCODER_API", DEFAULT_STARCODER_API_PATH)
API_URL_BASE = get_url_from_env_or_default_path("STARCODER_BASE_API", DEFAULT_STARCODER_BASE_API_PATH)
# Echo the resolved configuration at startup. The token is passed with
# `ofuscate=True`, presumably so it is masked in the output (helper not shown).
preview("StarCoder Model URL", API_URL_STAR)
preview("StarCoderBase Model URL", API_URL_BASE)
preview("HF Token", HF_TOKEN, ofuscate=True)
# Front-end assets shipped alongside the app: page CSS, the community-share
# script and the two SVG icons used by the share button.
_styles = get_file_as_string("styles.css")
_script = get_file_as_string("community-btn.js")
_sharing_icon_svg = get_file_as_string("community-icon.svg")
_loading_icon_svg = get_file_as_string("loading-icon.svg")

# README.md doubles as the page copy. It is split on "---" separators and the
# four resulting sections are unpacked in order. `manifest` (presumably the
# Space's front matter) is not used anywhere else in this file.
readme_file_content = get_file_as_string("README.md", path="./")
manifest, description, disclaimer, formats = get_sections(readme_file_content, "---", up_to=4)
# Visual theme: Gradio's Monochrome preset re-tinted indigo/blue/slate, with
# large text, small corner radii and IBM Plex Sans (weights 400/600), falling
# back to the platform sans-serif stack.
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    text_size=sizes.text_lg,
    radius_size=sizes.radius_sm,
    font=[
        gr.themes.GoogleFont("IBM Plex Sans", [400, 600]),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
)
# Bearer-token auth shared by both endpoint clients (the very same dict object
# is handed to each one).
HEADERS = dict(Authorization=f"Bearer {HF_TOKEN}")

# One text-generation client per model, built in order: StarCoder, then StarCoderBase.
client_star, client_base = (
    Client(endpoint_url, headers=HEADERS) for endpoint_url in (API_URL_STAR, API_URL_BASE)
)
def get_tokens_collector(request: StarCoderRequest) -> Generator[str, None, None]:
    """Stream the text of each generated token for ``request``.

    The endpoint is chosen from ``request.settings.version``: "StarCoder" uses
    ``client_star``; any other value uses ``client_base``. Generation
    parameters are taken from ``request.settings.kwargs()``.

    Tokens whose id is 0 are skipped. That id is presumably the end-of-text
    token of the StarCoder vocabulary (cf. ``END_OF_TEXT``). TODO: confirm
    against the tokenizer.
    """
    if request.settings.version == "StarCoder":
        selected_client = client_star
    else:
        selected_client = client_base

    for chunk in selected_client.generate_stream(request.prompt, **request.settings.kwargs()):
        if chunk.token.id == 0:
            continue
        yield chunk.token.text
def get_tokens_accumulator(request: StarCoderRequest) -> Generator[str, None, None]:
    """Yield ever-longer snapshots of the complete output text.

    The running text starts from ``request.prefix`` in FIM mode and from the
    whole ``request.prompt`` otherwise. After each streamed token is appended,
    the running text is yielded. In FIM mode ``request.suffix`` is then
    appended and yielded once more. The last snapshot always carries one extra
    trailing newline.
    """
    snapshot = request.prefix if request.fim_mode else request.prompt

    for piece in get_tokens_collector(request=request):
        snapshot = snapshot + piece
        yield snapshot

    if request.fim_mode:
        # Close the fill-in-the-middle gap with the text that followed it.
        snapshot = snapshot + request.suffix
        yield snapshot

    yield snapshot + "\n"
def get_tokens_linker(request: StarCoderRequest) -> str:
    """Run ``request`` to completion and return the generated text as one string.

    Only the streamed tokens are concatenated. Unlike ``get_tokens_accumulator``,
    no prompt, prefix or suffix text is added around them.
    """
    # str.join consumes the generator directly; the former list() wrapper was redundant.
    return "".join(get_tokens_collector(request))
def generate(
    prompt: str,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
    version="StarCoder",
) -> Generator[str, None, None]:
    """Stream a completion for ``prompt``, yielding the growing output text.

    This is the "Generate" button callback; its positional parameters follow
    the order of the UI inputs wired to it.

    Args:
        prompt: Code entered by the user.
        temperature: Sampling temperature. Raised to ``MIN_TEMPERATURE`` when
            lower, because the UI slider goes down to 0.0 while the
            text-generation client only accepts a strictly positive value.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.
        version: "StarCoder" selects the StarCoder endpoint; any other value
            selects StarCoderBase.

    Yields:
        Successive snapshots of the full output text, as produced by
        ``get_tokens_accumulator``.
    """
    # A temperature of 0 would be rejected by the backend. The clamp is
    # harmless if StarCoderRequestConfig already enforces the same minimum.
    temperature = max(float(temperature), MIN_TEMPERATURE)
    request = StarCoderRequest(
        prompt=prompt,
        settings=StarCoderRequestConfig(
            version=version,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        ),
    )
    yield from get_tokens_accumulator(request)
def process_example(
    prompt: str,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
    version="StarCoder",
) -> Generator[str, None, None]:
    """Run one example prompt and yield its completion as a single string.

    Takes the same parameters as ``generate``, but instead of streaming
    snapshots it waits for the whole completion and yields it once. Only the
    generated tokens are included (see ``get_tokens_linker``).

    Args:
        prompt: Example code to complete.
        temperature: Sampling temperature, raised to ``MIN_TEMPERATURE`` when
            lower (the backend rejects a non-positive temperature).
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.
        version: "StarCoder" selects the StarCoder endpoint; any other value
            selects StarCoderBase.

    Yields:
        The complete generated text, exactly once.
    """
    temperature = max(float(temperature), MIN_TEMPERATURE)
    request = StarCoderRequest(
        prompt=prompt,
        settings=StarCoderRequestConfig(
            version=version,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        ),
    )
    # get_tokens_linker returns a str. `yield from` on a str emits it one
    # character at a time, so the output box would end up showing only the
    # last character. Yield the whole string instead.
    yield get_tokens_linker(request)
# Example prompts offered under the editor. The last one is a fill-in-the-middle
# (FIM) prompt: <FILL_HERE> marks the gap the model should complete, so the
# surrounding code must be valid, properly indented Python.
# TODO: move these into the README too, like the rest of the page copy.
examples = [
    "X_train, y_train, X_test, y_test = train_test_split(X, y, test_size=0.1)\n\n# Train a logistic regression model, predict the labels on the test set and compute the accuracy score",
    "// Returns every other value in the array as a new array.\nfunction everyOther(arr) {",
    (
        "def alternating(list1, list2):\n"
        "    results = []\n"
        "    for i in range(min(len(list1), len(list2))):\n"
        "        results.append(list1[i])\n"
        "        results.append(list2[i])\n"
        "    if len(list1) > len(list2):\n"
        "        <FILL_HERE>\n"
        "    else:\n"
        "        results.extend(list2[i+1:])\n"
        "    return results"
    ),
]
# Page layout. Gradio derives the component tree from the nesting of these
# `with` blocks, so statement order and nesting below define the rendered page.
with gr.Blocks(theme=theme, analytics_enabled=False, css=_styles) as demo:
    with gr.Column():
        # Intro copy: the README section unpacked as `description`.
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                # Prompt editor, run button and the box the streamed result is written to.
                instruction = gr.Textbox(
                    placeholder="Enter your code here",
                    label="Code",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                output = gr.Code(elem_id="q-output", lines=30)
                with gr.Row():
                    with gr.Column():
                        # Sampling controls, collapsed by default. Their initial values
                        # (temperature 0.2, top-p 0.90, repetition penalty 1.2) differ
                        # from generate()'s keyword defaults; the slider values are
                        # what the "Generate" button actually sends.
                        with gr.Accordion("Advanced settings", open=False):
                            with gr.Row():
                                # Two side-by-side columns, populated through `with` below.
                                column_1, column_2 = gr.Column(), gr.Column()
                                with column_1:
                                    # NOTE(review): minimum=0.0, but the text-generation
                                    # client expects a strictly positive temperature;
                                    # confirm 0.0 is clamped before the request is sent.
                                    temperature = gr.Slider(
                                        label="Temperature",
                                        value=0.2,
                                        minimum=0.0,
                                        maximum=1.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values produce more diverse outputs",
                                    )
                                    # NOTE(review): minimum=0 allows asking for zero new
                                    # tokens; confirm the backend accepts that value.
                                    max_new_tokens = gr.Slider(
                                        label="Max new tokens",
                                        value=256,
                                        minimum=0,
                                        maximum=8192,
                                        step=64,
                                        interactive=True,
                                        info="The maximum numbers of new tokens",
                                    )
                                with column_2:
                                    top_p = gr.Slider(
                                        label="Top-p (nucleus sampling)",
                                        value=0.90,
                                        minimum=0.0,
                                        maximum=1,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values sample more low-probability tokens",
                                    )
                                    repetition_penalty = gr.Slider(
                                        label="Repetition penalty",
                                        value=1.2,
                                        minimum=1.0,
                                        maximum=2.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Penalize repeated tokens",
                                    )
                    with gr.Column():
                        # Model selector. get_tokens_collector sends "StarCoder" to the
                        # StarCoder client and any other value to the StarCoderBase client.
                        version = gr.Dropdown(
                            ["StarCoderBase", "StarCoder"],
                            value="StarCoder",
                            label="Version",
                            info="",
                        )
                gr.Markdown(disclaimer)
                # Community-share button. Its click behaviour is the JS loaded into
                # `_script` (hooked up at the end of this block).
                with gr.Group(elem_id="share-btn-container"):
                    community_icon = gr.HTML(_sharing_icon_svg, visible=True)
                    loading_icon = gr.HTML(_loading_icon_svg, visible=True)
                    share_button = gr.Button(
                        "Share to community", elem_id="share-btn", visible=True
                    )
                # Clickable example prompts. NOTE(review): with cache_examples=False and
                # Gradio's default run_on_click=False, clicking an example presumably
                # only fills `instruction`; process_example is not invoked.
                gr.Examples(
                    examples=examples,
                    inputs=[instruction],
                    cache_examples=False,
                    fn=process_example,
                    outputs=[output],
                )
                gr.Markdown(formats)

    # Event wiring. The input order matches generate()'s positional parameters
    # (prompt, temperature, max_new_tokens, top_p, repetition_penalty, version).
    submit.click(
        generate,
        inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty, version],
        outputs=[output],
        # preprocess=False,
        # NOTE(review): max_batch_size only takes effect together with batch=True,
        # which is not set here, so it is most likely ignored.
        max_batch_size=8,
        show_progress=True
    )
    # fn=None: no Python callback; only the client-side share script runs.
    share_button.click(None, [], [], _js=_script)
# Turn on the request queue (16 concurrent workers), then start serving on the
# configured port. Blocks.queue() returns the same Blocks object, so calling the
# two methods separately is equivalent to chaining them.
demo.queue(concurrency_count=16)
demo.launch(debug=True, server_port=DEFAULT_PORT)