# Captured page header: "Spaces: Sleeping" (appears to be a Hugging Face Spaces status banner).
import asyncio

import gradio as gr
import numpy as np
import pandas as pd
import structlog
from chromadb import Client
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from datasets import Dataset, load_dataset
from llama_index.core import PromptTemplate
from llama_index.core.prompts import ChatMessage
from llama_index.llms.openai import OpenAI
from pydantic import BaseModel, Field
logger = structlog.get_logger()

# Embedding function handed to the Chroma collection in
# load_train_data_and_vectorstore(); the model is instantiated here, at import time.
logger.info('Loading embedding model')
embed_fn = SentenceTransformerEmbeddingFunction(model_name='BAAI/bge-small-en-v1.5')
def load_train_data_and_vectorstore(train_samples_per_class=50):
    """Build a class-balanced training sample and index it in a Chroma collection.

    Draws ``train_samples_per_class`` reviews per label from the train split of
    ``SetFit/amazon_reviews_multi_en`` (fixed ``random_state`` so the sample is
    reproducible), then adds them to the Chroma collection ``'reviews'`` with
    each review's label stored as ``rating`` metadata.

    Args:
        train_samples_per_class: Number of reviews to sample for each label.

    Returns:
        Tuple ``(train, reviews)``: the sampled ``datasets.Dataset`` and the
        Chroma collection holding the same reviews.
    """
    logger.info("Loading dataset")
    train_df = (
        load_dataset('SetFit/amazon_reviews_multi_en', split='train')
        .to_pandas()
        .groupby('label')
        .sample(train_samples_per_class, random_state=1234)
        .reset_index(drop=True)
    )
    train = Dataset.from_pandas(train_df)

    # NOTE(review): Client() is presumably an ephemeral (in-memory) client, so the
    # collection is rebuilt on every start -- confirm against the chromadb version used.
    reviews = Client().create_collection(
        name='reviews',
        embedding_function=embed_fn,
        get_or_create=True,
    )

    logger.info("Adding documents to vector store")
    # Materialize columns as plain lists: newer `datasets` releases return a lazy
    # column object from `train[...]`, while Chroma's add() validates for lists.
    reviews.add(
        documents=list(train['text']),
        metadatas=[{'rating': int(label)} for label in train['label']],
        ids=[str(review_id) for review_id in train['id']],
    )
    return train, reviews
# Few-shot example pool (`train`) and its vector index (`reviews`), built once at import.
(train, reviews) = load_train_data_and_vectorstore()
class Rating(BaseModel):
    """Structured LLM output: the predicted rating for a single review."""

    # The allowed values are published via json_schema_extra; passing `enum=`
    # directly to Field() is the deprecated pydantic-v2 extra-kwarg form and
    # produces the same "enum" entry in the JSON schema.
    rating: int = Field(
        ...,
        description="Rating of the review",
        json_schema_extra={"enum": [0, 1, 2, 3, 4]},
    )
# Default model used by the predictors. No api_key is passed here, so the key
# presumably comes from the environment -- confirm for the deployment.
llm = OpenAI(model="gpt-4o-mini")
# Same model, but its chat responses are parsed into a `Rating` object.
structured_llm = llm.as_structured_llm(output_cls=Rating)
# Zero-shot prompt: only the review itself, no examples.
# Adjacent literals are concatenated; lines without "\n" join onto the next one.
prompt_tmpl_str = (
    "The review text is below.\n"
    "---------------------\n"
    "{review}\n"
    "---------------------\n"
    "Given the review text and not prior knowledge, "
    "please attempt to predict the score of the review.\n"
    "Query: What is the rating of this review?\n"
    "Answer: "
)
prompt_tmpl = PromptTemplate(template=prompt_tmpl_str)
async def zero_shot_predict(text, llm=None):
    """Predict a review's rating with no examples in the prompt.

    Args:
        text: Review text to classify.
        llm: Optional structured LLM (built with ``as_structured_llm(Rating)``).
            Defaults to the module-level ``structured_llm``.

    Returns:
        The ``rating`` field of the parsed ``Rating`` response.
    """
    llm = structured_llm if llm is None else llm
    messages = [ChatMessage.from_str(prompt_tmpl.format(review=text))]
    response = await llm.achat(messages)
    return response.raw.rating


# Seeded generator so the random few-shot draws are reproducible from process
# start; it advances on every draw, so consecutive calls get different examples.
rng = np.random.Generator(np.random.PCG64(1234))


def random_few_shot_examples_fn(**kwargs):
    """Format randomly drawn training reviews as few-shot examples.

    Mapped to ``{random_few_shot_examples}`` in ``few_shot_prompt_tmpl``; it
    reads ``n_samples`` from the keyword arguments given to ``format()``.

    Keyword Args:
        n_samples: Number of examples to draw; missing or falsy means 5.

    Returns:
        One "Text: ... / Rating: ..." entry per example, separated by blank lines.
    """
    n_samples = kwargs.get('n_samples') or 5
    # int(): UI sliders may deliver floats, which cannot be used as slice bounds.
    random_examples = train.shuffle(generator=rng)[:int(n_samples)]
    return "\n\n".join(
        f"Text: {text}\nRating: {rating}"
        for text, rating in zip(random_examples['text'], random_examples['label'])
    )


few_shot_prompt_tmpl_str = """\
The review text is below.
---------------------
{review}
---------------------
Given the review text and not prior knowledge, \
please attempt to predict the review score of the context. \
Here are several examples of reviews and their ratings:
{random_few_shot_examples}
Query: What is the rating of this review?
Answer: \
"""
few_shot_prompt_tmpl = PromptTemplate(
    few_shot_prompt_tmpl_str,
    function_mappings={"random_few_shot_examples": random_few_shot_examples_fn},
)


async def random_few_shot_predict(text, n_examples=5, llm=None):
    """Predict a rating by majority vote over 3 randomly-prompted LLM calls.

    Each call formats the prompt separately, so each gets its own random draw
    of examples, and is sampled at temperature 0.9.

    Args:
        text: Review text to classify.
        n_examples: Number of random examples per prompt.
        llm: Optional structured LLM; defaults to the module-level ``structured_llm``.

    Returns:
        The most common rating as an ``int`` (the lowest one on a tie).
    """
    llm = structured_llm if llm is None else llm
    calls = [
        llm.achat(
            [ChatMessage.from_str(few_shot_prompt_tmpl.format(review=text, n_samples=n_examples))],
            temperature=0.9,
        )
        for _ in range(3)
    ]
    results = await asyncio.gather(*calls)
    ratings = [result.raw.rating for result in results]
    # Series.mode() returns every tied value sorted ascending; [0] takes the lowest.
    # int() converts the numpy scalar so all predictors return the same type.
    return int(pd.Series(ratings).mode()[0])


def dynamic_few_shot_examples_fn(**kwargs):
    """Format the training reviews most similar to the query as few-shot examples.

    Mapped to ``{dynamic_few_shot_examples}`` in ``dynamic_few_shot_prompt_tmpl``;
    it reads ``review`` and ``n_examples`` from the keyword arguments given to
    ``format()``.

    Keyword Args:
        review: Review text used as the similarity query (required).
        n_examples: Number of neighbours to retrieve; defaults to 5.

    Returns:
        One "Text: ... / Rating: ..." entry per retrieved review, separated by blank lines.
    """
    n_examples = int(kwargs.get('n_examples', 5))
    retrievals = reviews.query(query_texts=[kwargs['review']], n_results=n_examples)
    # A single query text was sent, so take the first (only) result list.
    documents = retrievals['documents'][0]
    metadatas = retrievals['metadatas'][0]
    return "\n\n".join(
        f"Text: {document}\nRating: {metadata.get('rating')}"
        for document, metadata in zip(documents, metadatas)
    )


dynamic_few_shot_prompt_tmpl_str = """\
The review text is below.
---------------------
{review}
---------------------
Given the review text and not prior knowledge, \
please attempt to predict the review score of the context. \
Here are several examples of reviews and their ratings:
{dynamic_few_shot_examples}
Query: What is the rating of this review?
Answer: \
"""
dynamic_few_shot_prompt_tmpl = PromptTemplate(
    dynamic_few_shot_prompt_tmpl_str,
    function_mappings={"dynamic_few_shot_examples": dynamic_few_shot_examples_fn},
)


async def dynamic_few_shot_predict(text, n_examples=5, llm=None):
    """Predict a rating using the most similar training reviews as examples.

    Args:
        text: Review text to classify.
        n_examples: Number of nearest-neighbour examples to include.
        llm: Optional structured LLM; defaults to the module-level ``structured_llm``.

    Returns:
        The ``rating`` field of the parsed ``Rating`` response.
    """
    llm = structured_llm if llm is None else llm
    messages = [
        ChatMessage.from_str(dynamic_few_shot_prompt_tmpl.format(review=text, n_examples=n_examples))
    ]
    response = await llm.achat(messages)
    return response.raw.rating


def classify(review, num_examples, api_key=None):
    """Rate one review with all three strategies.

    Args:
        review: Review text to classify.
        num_examples: Number of examples for the two few-shot strategies.
        api_key: Optional OpenAI API key entered in the UI. When blank, the
            module-level ``structured_llm`` is used instead.

    Returns:
        Tuple ``(zero_shot, random_few_shot, dynamic_few_shot)`` of ratings.
    """
    num_examples = int(num_examples)
    api_key = (api_key or "").strip()
    if api_key:
        # Per-request model so the key the user typed is actually used.
        llm = OpenAI(model="gpt-4o-mini", api_key=api_key).as_structured_llm(Rating)
    else:
        llm = structured_llm

    async def _run_all():
        # One event loop for all three predictors; their LLM calls overlap.
        return await asyncio.gather(
            zero_shot_predict(review, llm=llm),
            random_few_shot_predict(review, num_examples, llm=llm),
            dynamic_few_shot_predict(review, num_examples, llm=llm),
        )

    zero_shot, random_few_shot, dynamic_few_shot = asyncio.run(_run_all())
    return zero_shot, random_few_shot, dynamic_few_shot
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # type='password' keeps the key masked in the browser.
            api_key = gr.Textbox(label='Openai API Key', type='password')
            n_examples = gr.Slider(
                minimum=1, maximum=10, value=5, step=1,
                label='Number of examples to retrieve', interactive=True,
            )
            review = gr.Textbox(label='Review', interactive=True)
            submit = gr.Button(value='Submit')
        with gr.Column():
            zero_shot_label = gr.Textbox(label='Zero shot', interactive=False)
            random_few_shot_label = gr.Textbox(label='Random few shot', interactive=False)
            dynamic_few_shot_label = gr.Textbox(label='Dynamic few shot', interactive=False)
    # Inputs are passed positionally to classify(review, num_examples, api_key);
    # omitting api_key made every click fail with a missing-argument TypeError.
    submit.click(
        classify,
        inputs=[review, n_examples, api_key],
        outputs=[zero_shot_label, random_few_shot_label, dynamic_few_shot_label],
    )

if __name__ == "__main__":
    demo.queue().launch()