|
import streamlit as st |
|
import pandas as pd |
|
import os |
|
from langchain.prompts import PromptTemplate |
|
from datetime import datetime |
|
import random |
|
from pathlib import Path |
|
from openai import OpenAI |
|
from dotenv import load_dotenv |
|
|
|
|
|
|
|
|
|
|
|
# Load variables from a local .env file (if present) into os.environ so the
# token below can be supplied either way.
load_dotenv()

# OpenAI-compatible client pointed at the Hugging Face Inference API. The
# access token is taken from the TEXT_TOKEN environment variable.
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1",
    api_key=os.environ.get('TEXT_TOKEN'),
)
|
|
|
|
|
|
|
# Global CSS overrides injected into the page on every run. The raw <style>
# tag is only rendered because unsafe_allow_html=True is passed.
# NOTE(review): .task-button, .output-container and .status-container are not
# referenced by any markup in the visible code; they may be unused.
_APP_CSS = """
    <style>
    .stButton > button {
        width: 100%;
        margin-bottom: 10px;
        background-color: #4CAF50;
        color: white;
        border: none;
        padding: 10px;
        border-radius: 5px;
    }
    .task-button {
        background-color: #2196F3 !important;
    }
    .stSelectbox {
        margin-bottom: 20px;
    }
    .output-container {
        padding: 20px;
        border-radius: 5px;
        border: 1px solid #ddd;
        margin: 10px 0;
    }
    .status-container {
        padding: 10px;
        border-radius: 5px;
        margin: 10px 0;
    }
    .sidebar-info {
        padding: 10px;
        background-color: #f0f2f6;
        border-radius: 5px;
        margin: 10px 0;
    }
    </style>
    """
st.markdown(_APP_CSS, unsafe_allow_html=True)
|
|
|
|
|
# Ensure the local output directory used by save_to_csv/load_from_csv exists.
# exist_ok=True avoids the check-then-create race of a separate exists() test.
os.makedirs('data', exist_ok=True)
|
|
|
def read_csv_with_encoding(file, encodings=None):
    """Read a CSV, trying text encodings in order until one decodes cleanly.

    Args:
        file: A path or file-like object accepted by ``pandas.read_csv``.
            Seekable file-like objects are rewound to their starting position
            before every attempt, so a failed attempt does not leave the next
            one reading from the end of the stream.
        encodings: Optional sequence of encodings to try, in order. Defaults
            to utf-8, then cp1252, then latin1. latin1 maps every byte to a
            character, so it acts as the final catch-all.

    Returns:
        The DataFrame parsed with the first encoding that succeeds.

    Raises:
        ValueError: if ``encodings`` is empty.
        UnicodeDecodeError: if none of the encodings can decode the input.
    """
    if encodings is None:
        # cp1252 must precede latin1: latin1 (alias iso-8859-1) never fails,
        # so anything listed after it would be unreachable.
        encodings = ['utf-8', 'cp1252', 'latin1']
    encodings = list(encodings)
    if not encodings:
        raise ValueError("at least one encoding is required")

    seekable = callable(getattr(file, "seekable", None)) and file.seekable()
    start = file.tell() if seekable else None

    last_error = None
    for encoding in encodings:
        if start is not None:
            file.seek(start)
        try:
            return pd.read_csv(file, encoding=encoding)
        except UnicodeDecodeError as err:
            last_error = err

    # UnicodeDecodeError needs all five constructor arguments; reuse the
    # details of the last failure and chain it for context.
    raise UnicodeDecodeError(
        last_error.encoding,
        last_error.object,
        last_error.start,
        last_error.end,
        f"Failed to read file with any supported encoding ({', '.join(encodings)})",
    ) from last_error
|
|
|
def save_to_csv(data, filename):
    """Persist tabular ``data`` as ``data/<filename>`` and hand back the frame.

    ``data`` is anything ``pandas.DataFrame`` accepts. The CSV is written
    without the index column; the constructed DataFrame is returned.
    """
    frame = pd.DataFrame(data)
    target_path = 'data/{}'.format(filename)
    frame.to_csv(target_path, index=False)
    return frame
|
|
|
def load_from_csv(filename):
    """Load ``data/<filename>`` as a DataFrame, or an empty one if unavailable.

    This is a deliberate best-effort read: a missing or unreadable file, an
    empty file, or a malformed/undecodable CSV all yield ``pd.DataFrame()``.
    Unlike a bare ``except:``, interpreter signals such as KeyboardInterrupt
    and SystemExit are no longer swallowed.
    """
    try:
        return pd.read_csv(f'data/{filename}')
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: pandas' EmptyDataError
        # and ParserError, plus UnicodeDecodeError, all subclass it.
        return pd.DataFrame()
|
|
|
|
|
def reset_conversation():
    """Empty the stored chat history.

    Wired as the sidebar "Reset Conversation" button's ``on_click`` callback;
    replaces both ``conversation`` and ``messages`` in session state with
    fresh empty lists.
    """
    for state_key in ("conversation", "messages"):
        st.session_state[state_key] = []
|
|
|
|
|
# Seed per-session containers once; later reruns keep whatever is stored.
for _state_key in ("messages", "examples_to_classify"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = []
|
|
|
|
|
st.title("π€ Text Data Labeling and Generation App")

# Models offered in the sidebar selectbox.
_MODEL_CHOICES = ["meta-llama/Meta-Llama-3-8B-Instruct"]

with st.sidebar:
    st.title("βοΈ Settings")

    # Module-level globals read by the task pages below.
    selected_model = st.selectbox(
        "Select Model",
        options=_MODEL_CHOICES,
        key='model_select',
    )
    temperature = st.slider(
        "Temperature",
        min_value=0.0,
        max_value=1.0,
        value=0.5,
        help="Controls randomness in generation",
    )

    st.button("π Reset Conversation", on_click=reset_conversation)

    model_note = f"""
        <div class="sidebar-info">
            <h4>Current Model: {selected_model}</h4>
            <p><em>Note: Generated content may be inaccurate or false.</em></p>
        </div>
        """
    with st.container():
        st.markdown(model_note, unsafe_allow_html=True)
|
|
|
|
|
# Side-by-side task selectors. The choice is stored in session state so it
# survives the reruns triggered by every other widget interaction.
gen_col, label_col = st.columns(2)
for column, caption, button_key, tip, task in (
    (gen_col, "π Data Generation", "gen_button", "Generate new data", "Data Generation"),
    (label_col, "π·οΈ Data Labeling", "label_button", "Label existing data", "Data Labeling"),
):
    with column:
        if st.button(caption, key=button_key, help=tip):
            st.session_state.task_choice = task
|
|
|
# --- Task pages -------------------------------------------------------------
# A Streamlit button is True only on the rerun caused by its own click, so a
# follow-up "Continue" button nested under "Generate"/"Label" could never
# fire. The last model reply is therefore kept in session state, and the
# follow-up controls are rendered whenever such a reply exists.

_CLASSIFICATION_TYPES = ["Sentiment Analysis", "Binary Classification", "Multi-Class Classification"]


def _stream_completion(system_prompt, action):
    """Stream the model's reply to ``system_prompt`` into the page.

    On success the reply is appended to ``st.session_state.messages`` and
    returned. On any client/API error an error message naming ``action``
    (e.g. "generation") is shown and ``None`` is returned.
    """
    try:
        stream = client.chat.completions.create(
            model=selected_model,
            messages=[{"role": "system", "content": system_prompt}],
            temperature=temperature,
            stream=True,
            max_tokens=3000,
        )
        response = st.write_stream(stream)
    except Exception as e:  # network/auth/rate-limit failures from the API
        st.error(f"An error occurred during {action}.")
        st.error(f"Details: {e}")
        return None
    st.session_state.messages.append({"role": "assistant", "content": response})
    return response


def _choose_labels(classification_type, prefix):
    """Render label-definition widgets and return the resulting label list.

    Sentiment analysis uses a fixed label set. Binary and multi-class let the
    user name each class; blank names fall back to the defaults. ``prefix``
    namespaces widget keys so each task keeps its own state.
    """
    if classification_type == "Sentiment Analysis":
        return ["Positive", "Negative", "Neutral"]
    if classification_type == "Binary Classification":
        first_col, second_col = st.columns(2)
        with first_col:
            first = st.text_input("First class", "Positive", key=f"{prefix}first").strip()
        with second_col:
            second = st.text_input("Second class", "Negative", key=f"{prefix}second").strip()
        return [first, second] if first and second else ["Positive", "Negative"]
    num_classes = st.slider("Number of classes", 3, 10, 3, key=f"{prefix}num_classes")
    labels = []
    cols = st.columns(3)
    for i in range(num_classes):
        with cols[i % 3]:
            name = st.text_input(f"Class {i+1}", f"Class_{i+1}", key=f"{prefix}class_{i}")
        labels.append(name.strip() or f"Class_{i+1}")
    return labels


def _collect_few_shot(labels, toggle_text, title, key_prefix):
    """Optionally collect user-written few-shot examples.

    Returns a list of ``"<content>\\nLabel: <label>"`` strings (empty when the
    toggle is off or no example has content).
    """
    examples = []
    if not st.toggle(toggle_text, key=f"{key_prefix}use_few_shot"):
        return examples
    count = st.slider("Number of few-shot examples", 1, 5, 1, key=f"{key_prefix}num_few_shot")
    for i in range(count):
        with st.expander(f"{title} {i+1}"):
            content = st.text_area("Content", key=f"{key_prefix}few_shot_content_{i}")
            label = st.selectbox("Label", labels, key=f"{key_prefix}few_shot_label_{i}")
        if content and label:
            examples.append(f"{content}\nLabel: {label}")
    return examples


def _on_generation_continue():
    """on_click callback: switch to labeling before the next run renders."""
    if st.session_state.get("generation_follow_up") == "Switch to labeling":
        st.session_state.task_choice = "Data Labeling"


def _on_labeling_continue():
    """on_click callback applying the labeling follow-up choice.

    Callbacks run before any widget is created, which is the only point where
    widget-backed session keys (the example inputs) may be reassigned.
    """
    choice = st.session_state.get("labeling_follow_up")
    if choice == "Label more data":
        for key in list(st.session_state.keys()):
            if key.startswith("example_") or key == "label_bulk_examples":
                st.session_state[key] = ""
        st.session_state.examples_to_classify = []
        st.session_state.pop("last_labeling", None)
    elif choice == "Switch to generation":
        st.session_state.task_choice = "Data Generation"


def _render_generation():
    """Render the Data Generation page, its output and follow-up actions."""
    st.header("π Data Generation")

    classification_type = st.selectbox("Classification Type", _CLASSIFICATION_TYPES)
    labels = _choose_labels(classification_type, "gen_")

    domain = st.selectbox("Domain", ["Restaurant reviews", "E-commerce reviews", "Custom"])
    if domain == "Custom":
        domain = st.text_input("Specify custom domain").strip()

    min_col, max_col = st.columns(2)
    with min_col:
        min_words = st.number_input("Min words", 10, 90, 20)
    with max_col:
        # The default must stay >= min_value, otherwise Streamlit raises once
        # "Min words" is set above 50.
        max_words = st.number_input("Max words", min_words, 90, max(min_words, 50))

    few_shot_examples = _collect_few_shot(labels, "Use few-shot examples", "Example", "")
    num_to_generate = st.number_input("Number of examples", 1, 100, 10)
    user_prompt = st.text_area("Additional instructions (optional)")

    few_shot_text = ""
    if few_shot_examples:
        few_shot_text = (
            "Match the style of these reference examples:\n"
            + "\n\n".join(few_shot_examples)
            + "\n"
        )

    prompt_template = PromptTemplate(
        input_variables=[
            "classification_type", "domain", "num_examples", "min_words",
            "max_words", "labels", "few_shot_examples", "user_prompt",
        ],
        template=(
            "You are a professional {classification_type} expert tasked with generating examples for {domain}.\n"
            "Use the following parameters:\n"
            "- Generate exactly {num_examples} examples\n"
            "- Each example MUST be between {min_words} and {max_words} words long\n"
            "- Use these labels: {labels}\n"
            "- Generate the examples in this format: 'Example text. Label: [label]'\n"
            "- Do not include word counts or any additional information\n"
            "{few_shot_examples}"
            "Additional instructions: {user_prompt}\n\n"
            "Generate numbered examples:"
        ),
    )
    system_prompt = prompt_template.format(
        classification_type=classification_type,
        domain=domain,
        num_examples=num_to_generate,
        min_words=min_words,
        max_words=max_words,
        labels=", ".join(labels),
        few_shot_examples=few_shot_text,
        user_prompt=user_prompt,
    )

    def _generate():
        """Stream a fresh batch and remember it; return True on success."""
        if not domain:
            st.warning("Please specify a custom domain.")
            return False
        with st.spinner("Generating examples..."):
            result = _stream_completion(system_prompt, "generation")
        if result is None:
            return False
        st.session_state.last_generation = result
        return True

    shown_now = False
    if st.button("π― Generate Examples"):
        shown_now = _generate()

    last_generation = st.session_state.get("last_generation")
    if not last_generation:
        return
    if not shown_now:
        # st.write_stream output only lives for one run; re-show the last one.
        st.write(last_generation)

    st.markdown("---")
    follow_up = st.radio(
        "What would you like to do next?",
        ["Generate more examples", "Modify parameters and generate again", "Switch to labeling"],
        key="generation_follow_up",
    )
    # "Switch to labeling" is applied by the callback before this run starts.
    if st.button("Continue", key="generation_continue", on_click=_on_generation_continue):
        if follow_up == "Generate more examples":
            _generate()
        elif follow_up == "Modify parameters and generate again":
            st.info("Adjust the parameters above, then click Generate Examples.")


def _render_labeling():
    """Render the Data Labeling page, its output and follow-up actions."""
    st.header("π·οΈ Data Labeling")

    classification_type = st.selectbox(
        "Classification Type", _CLASSIFICATION_TYPES, key="label_class_type"
    )
    labels = _choose_labels(classification_type, "label_")
    few_shot_examples = _collect_few_shot(
        labels, "Use few-shot examples for labeling", "Few-shot Example", "label_"
    )

    num_examples = st.number_input("Number of examples to classify", 1, 100, 1)
    if num_examples <= 20:
        raw_examples = [
            st.text_area(f"Example {i+1}", key=f"example_{i}") for i in range(num_examples)
        ]
    else:
        bulk_text = st.text_area(
            "Enter examples (one per line)",
            height=300,
            help="Enter each example on a new line",
            key="label_bulk_examples",
        )
        raw_examples = bulk_text.split('\n')
    # Drop blank entries and cap at the requested count.
    examples_to_classify = [ex.strip() for ex in raw_examples if ex.strip()][:num_examples]

    user_prompt = st.text_area("Additional instructions (optional)", key="label_instructions")

    few_shot_text = "\n\n".join(few_shot_examples)
    examples_text = "\n".join(f"{i+1}. {ex}" for i, ex in enumerate(examples_to_classify))

    label_prompt_template = PromptTemplate(
        input_variables=["classification_type", "labels", "few_shot_examples", "examples", "user_prompt"],
        template=(
            "You are a professional {classification_type} expert. Classify the following examples using these labels: {labels}.\n"
            "Instructions:\n"
            "- Return the numbered example followed by its classification in the format: 'Example text. Label: [label]'\n"
            "- Do not provide any additional information or explanations\n"
            "{user_prompt}\n\n"
            "Few-shot examples:\n{few_shot_examples}\n\n"
            "Examples to classify:\n{examples}\n\n"
            "Output:\n"
        ),
    )
    system_prompt = label_prompt_template.format(
        classification_type=classification_type,
        labels=", ".join(labels),
        few_shot_examples=few_shot_text,
        examples=examples_text,
        user_prompt=user_prompt,
    )

    shown_now = False
    if st.button("π·οΈ Label Data"):
        if examples_to_classify:
            with st.spinner("Labeling data..."):
                result = _stream_completion(system_prompt, "labeling")
            if result is not None:
                st.session_state.last_labeling = result
                shown_now = True
        else:
            st.warning("Please enter at least one example to classify.")

    last_labeling = st.session_state.get("last_labeling")
    if not last_labeling:
        return
    if not shown_now:
        st.write(last_labeling)

    st.markdown("---")
    follow_up = st.radio(
        "What would you like to do next?",
        ["Label more data", "Modify parameters and label again", "Switch to generation"],
        key="labeling_follow_up",
    )
    # "Label more data" and "Switch to generation" are applied by the callback.
    if st.button("Continue", key="labeling_continue", on_click=_on_labeling_continue):
        if follow_up == "Modify parameters and label again":
            st.info("Adjust the parameters above, then click Label Data.")


if "task_choice" in st.session_state:
    if st.session_state.task_choice == "Data Generation":
        _render_generation()
    elif st.session_state.task_choice == "Data Labeling":
        _render_labeling()