# Author: Hector Salvador [Fisharp]
# Last commit: "Relocation of static contents to their own files" (32f7b3e)
# Repository viewer residue: raw / history / blame links, file size 8.92 kB
import sys
import os
import gradio as gr
from gradio.themes.utils import sizes
from text_generation import Client
# todo: remove and replace by the actual js file instead
from share_btn import (share_js)
from utils import (
get_file_as_string,
get_sections,
get_url_from_env_or_default_path,
preview
)
from constants import (
DEFAULT_STARCODER_API_PATH,
DEFAULT_STARCODER_BASE_API_PATH,
FIM_MIDDLE,
FIM_PREFIX,
FIM_SUFFIX,
END_OF_TEXT,
MIN_TEMPERATURE,
)
# Hugging Face API token taken from the environment (None when it is unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# The token is sent with every inference request, so stop right away when it
# is missing: explain on stderr how to obtain one and exit with status 1
# rather than raising an exception.
if not HF_TOKEN:
    ERR_MSG = """
Please set the HF_TOKEN environment variable with your Hugging Face API token.
You can get one by signing up at https://huggingface.co/join and then visiting
https://huggingface.co/settings/tokens."""
    sys.stderr.write(f"{ERR_MSG}\n")
    sys.exit(1)
# Inference endpoints for the two model variants. Each one is resolved by
# get_url_from_env_or_default_path from an environment variable name and a
# default path constant.
# NOTE(review): presumably the environment variable takes precedence over the
# default path -- confirm against utils.get_url_from_env_or_default_path.
API_URL = get_url_from_env_or_default_path("STARCODER_API", DEFAULT_STARCODER_API_PATH)
API_URL_BASE = get_url_from_env_or_default_path("STARCODER_BASE_API", DEFAULT_STARCODER_BASE_API_PATH)
# Report the resolved settings at startup.
# NOTE(review): `ofuscate=True` presumably masks the token in the output --
# confirm against utils.preview.
preview("StarCoder Model's URL", API_URL)
preview("StarCoderBase Model's URL", API_URL_BASE)
preview("HF Token", HF_TOKEN, ofuscate=True)
# Port passed to `launch(server_port=...)` at the bottom of the file.
DEFAULT_PORT = 7860
# Marker a user places in the prompt to request fill-in-the-middle generation;
# `generate` splits the prompt on it.
FIM_INDICATOR = "<FILL_HERE>"
# Directory holding the static assets loaded below.
STATIC_PATH = "static"
# Markdown shown at the bottom of the page.
FORMATS = get_file_as_string("formats.md", path=STATIC_PATH)
# Stylesheet handed to gr.Blocks(css=...).
CSS = get_file_as_string("styles.css", path=STATIC_PATH)
# SVG markup rendered through gr.HTML in the share-button group.
community_icon_svg = get_file_as_string("community_icon.svg", path=STATIC_PATH)
loading_icon_svg = get_file_as_string("loading_icon.svg", path=STATIC_PATH)
# todo: evaluate making STATIC_PATH the default path instead of the current one
# The README is loaded without the `path` argument, unlike the static assets.
README = get_file_as_string("README.md")
# Slicing the different sections from the README, using "---" as the marker
# handed to get_sections.
readme_sections = get_sections(README, "---")
# Only the first three sections are used; `description` and `disclaimer` are
# rendered as Markdown in the page. The unpacking raises ValueError when the
# README yields fewer than three sections.
manifest, description, disclaimer = readme_sections[:3]
# Visual theme passed to gr.Blocks: Monochrome base with custom hues, small
# corner radius, large text, and Rubik with generic sans-serif fallbacks.
theme = gr.themes.Monochrome(
    # Colour palette.
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    # Sizing.
    radius_size=sizes.radius_sm,
    text_size=sizes.text_lg,
    # Font stack, in order of preference.
    font=[gr.themes.GoogleFont("Rubik"), "ui-sans-serif", "system-ui", "sans-serif"],
)
# Every request to the inference endpoints carries the token as a bearer
# credential.
HEADERS = dict(Authorization=f"Bearer {HF_TOKEN}")

# One streaming client per model variant, built in the same order as the
# endpoint constants: StarCoder first, then StarCoderBase.
client, client_base = (
    Client(endpoint, headers=HEADERS) for endpoint in (API_URL, API_URL_BASE)
)
def generate(
    prompt,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
    version="StarCoder",
):
    """Stream a completion for ``prompt`` from the selected model.

    This is a generator: nothing runs (including prompt validation) until it
    is first iterated.

    When the prompt contains ``FIM_INDICATOR`` it is treated as a
    fill-in-the-middle request: the text before the marker becomes the prefix,
    the text after it the suffix, and the model is asked for the middle part.

    Args:
        prompt: Code to complete, optionally containing one ``FIM_INDICATOR``.
        temperature: Sampling temperature; values below ``MIN_TEMPERATURE``
            are raised to ``MIN_TEMPERATURE``.
        max_new_tokens: Passed through to the client unchanged.
        top_p: Nucleus sampling threshold, converted to ``float``.
        repetition_penalty: Passed through to the client unchanged.
        version: ``"StarCoder"`` selects ``client``; any other value selects
            ``client_base``.

    Yields:
        The text accumulated so far: the prompt (or the FIM prefix) followed
        by every token received up to that point. In FIM mode the suffix is
        appended once the end-of-text token arrives.

    Returns:
        The final accumulated text, as the generator's return value.

    Raises:
        ValueError: If the prompt contains more than one ``FIM_INDICATOR``.
    """
    # Clamp from below. Using min() here would cap every request at
    # MIN_TEMPERATURE and make the temperature argument ineffective.
    temperature = max(float(temperature), MIN_TEMPERATURE)
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        # Fixed seed: identical inputs produce identical samples.
        seed=42,
    )

    fim_mode = FIM_INDICATOR in prompt
    if fim_mode:
        try:
            # Unpacking fails with ValueError when the marker occurs more
            # than once (the split then yields three or more parts).
            prefix, suffix = prompt.split(FIM_INDICATOR)
        except ValueError as err:
            print(str(err))
            raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt!") from err
        prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"

    model_client = client if version == "StarCoder" else client_base
    stream = model_client.generate_stream(prompt, **generate_kwargs)

    # In FIM mode only the user's prefix is shown, not the prompt with the
    # special FIM tokens wrapped around it.
    output = prefix if fim_mode else prompt
    for response in stream:
        token_text = response.token.text
        if token_text == END_OF_TEXT:
            if not fim_mode:
                return output
            # The generated middle is complete: close it with the suffix.
            output += suffix
        else:
            output += token_text
        # todo: log the token text while in debug mode
        yield output
    return output
# todo: move it into the README too
# Sample prompts offered in the UI: a Python comment-driven completion, a
# JavaScript function stub, and a Python fill-in-the-middle prompt.
examples = [
    "X_train, y_train, X_test, y_test = train_test_split(X, y, test_size=0.1)\n\n# Train a logistic regression model, predict the labels on the test set and compute the accuracy score",
    "// Returns every other value in the array as a new array.\nfunction everyOther(arr) {",
    # The block structure must be real Python indentation: with every line at
    # the same depth the prompt is not valid code once the marker is filled in.
    (
        "def alternating(list1, list2):\n"
        "    results = []\n"
        "    for i in range(min(len(list1), len(list2))):\n"
        "        results.append(list1[i])\n"
        "        results.append(list2[i])\n"
        "    if len(list1) > len(list2):\n"
        "        <FILL_HERE>\n"
        "    else:\n"
        "        results.extend(list2[i+1:])\n"
        "    return results"
    ),
]
def process_example(args):
    """Run ``generate`` to completion on an example prompt.

    Args:
        args: The example prompt, passed to ``generate`` as its ``prompt``
            with every other setting left at its default.

    Returns:
        The last value yielded by ``generate``. If the generator finishes
        without yielding anything, the prompt itself is returned instead of
        failing on an unbound loop variable.
    """
    final_output = args
    for final_output in generate(args):
        pass
    return final_output
# Page layout and event wiring for the playground.
# NOTE(review): the nesting of the layout containers was rebuilt from the
# order of the `with` statements -- confirm it against the rendered page.
with gr.Blocks(theme=theme, analytics_enabled=False, css=CSS) as demo:
    with gr.Column():
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                # Prompt input, trigger button and generated-code output.
                instruction = gr.Textbox(
                    placeholder="Enter your code here",
                    label="Code",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                output = gr.Code(elem_id="q-output", lines=30)
                with gr.Row():
                    with gr.Column():
                        # Sampling controls, collapsed by default.
                        with gr.Accordion("Advanced settings", open=False):
                            with gr.Row():
                                column_1, column_2 = gr.Column(), gr.Column()
                                with column_1:
                                    temperature = gr.Slider(
                                        label="Temperature",
                                        value=0.2,
                                        minimum=0.0,
                                        maximum=1.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values produce more diverse outputs",
                                    )
                                    # The minimum is one step rather than 0:
                                    # a request for zero new tokens is
                                    # rejected by the inference backend.
                                    max_new_tokens = gr.Slider(
                                        label="Max new tokens",
                                        value=256,
                                        minimum=64,
                                        maximum=8192,
                                        step=64,
                                        interactive=True,
                                        info="The maximum numbers of new tokens",
                                    )
                                with column_2:
                                    top_p = gr.Slider(
                                        label="Top-p (nucleus sampling)",
                                        value=0.90,
                                        minimum=0.0,
                                        maximum=1,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values sample more low-probability tokens",
                                    )
                                    repetition_penalty = gr.Slider(
                                        label="Repetition penalty",
                                        value=1.2,
                                        minimum=1.0,
                                        maximum=2.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Penalize repeated tokens",
                                    )
                    with gr.Column():
                        # Model variant; `generate` uses the StarCoder client
                        # for "StarCoder" and the base client otherwise.
                        version = gr.Dropdown(
                            ["StarCoderBase", "StarCoder"],
                            value="StarCoder",
                            label="Version",
                            info="",
                        )
                gr.Markdown(disclaimer)
                with gr.Group(elem_id="share-btn-container"):
                    community_icon = gr.HTML(community_icon_svg, visible=True)
                    loading_icon = gr.HTML(loading_icon_svg, visible=True)
                    share_button = gr.Button(
                        "Share to community", elem_id="share-btn", visible=True
                    )
                # Clickable sample prompts; results are not cached.
                gr.Examples(
                    examples=examples,
                    inputs=[instruction],
                    cache_examples=False,
                    fn=process_example,
                    outputs=[output],
                )
                gr.Markdown(FORMATS)

    # The input order must match the positional parameters of `generate`.
    submit.click(
        generate,
        inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty, version],
        outputs=[output],
    )
    # Sharing has no Python callback; it runs only the JavaScript in share_js.
    share_button.click(None, [], [], _js=share_js)
# Turn on request queuing first, then start the server on the configured
# port with debug mode enabled.
demo.queue(concurrency_count=16)
demo.launch(debug=True, server_port=DEFAULT_PORT)