"""Gradio demo comparing three prompting strategies for review rating prediction.

The app predicts a rating label (0-4) for a review with:

* a zero-shot prompt,
* a few-shot prompt whose examples are sampled at random from the training
  subset (majority vote over several sampled completions), and
* a few-shot prompt whose examples are retrieved from a Chroma collection
  by similarity to the review being classified.
"""
import asyncio

import gradio as gr
import numpy as np
import pandas as pd
import structlog
from chromadb import Client
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from datasets import Dataset, load_dataset
from llama_index.core import PromptTemplate
from llama_index.core.prompts import ChatMessage
from llama_index.llms.openai import OpenAI
from pydantic import BaseModel, Field

MODEL_NAME = "gpt-4o-mini"
DEFAULT_N_EXAMPLES = 5

logger = structlog.get_logger()

logger.info('Loading embedding model')
embed_fn = SentenceTransformerEmbeddingFunction('BAAI/bge-small-en-v1.5')


def load_train_data_and_vectorstore():
    """Load a class-balanced training subset and index it in a Chroma collection.

    Returns:
        A ``(train, reviews)`` tuple: ``train`` is a ``Dataset`` holding 50
        sampled rows per label, ``reviews`` is the Chroma collection that
        contains the same rows (text as document, label as ``rating`` metadata).
    """
    logger.info("Loading dataset")
    ds = load_dataset('SetFit/amazon_reviews_multi_en')

    train_samples_per_class = 50
    # Fixed random_state keeps the sampled subset reproducible across restarts.
    sampled = (
        ds['train']
        .to_pandas()
        .groupby('label')
        .sample(train_samples_per_class, random_state=1234)
        .reset_index(drop=True)
    )
    train = Dataset.from_pandas(sampled)

    reviews = Client().create_collection(
        name='reviews',
        embedding_function=embed_fn,
        get_or_create=True,
    )

    logger.info("Adding documents to vector store")
    reviews.add(
        documents=train['text'],
        metadatas=[{'rating': x} for x in train['label']],
        ids=train['id'],
    )
    return train, reviews


train, reviews = load_train_data_and_vectorstore()


class Rating(BaseModel):
    """Structured output schema: the predicted rating label of a review."""

    rating: int = Field(..., description="Rating of the review", enum=[0, 1, 2, 3, 4])


# Default LLM, used when a predictor is called without an explicit ``llm``.
llm = OpenAI(model=MODEL_NAME)
structured_llm = llm.as_structured_llm(Rating)

prompt_tmpl_str = """\
The review text is below.
---------------------
{review}
---------------------
Given the review text and not prior knowledge, \
please attempt to predict the score of the review.
Query: What is the rating of this review?
Answer: \
"""

prompt_tmpl = PromptTemplate(
    prompt_tmpl_str,
)


async def zero_shot_predict(text, llm=None):
    """Predict the rating of ``text`` without any in-context examples.

    Args:
        text: The review text to classify.
        llm: Optional structured LLM to use; defaults to the module-level
            ``structured_llm``.

    Returns:
        The ``rating`` field of the structured response.
    """
    active_llm = structured_llm if llm is None else llm
    messages = [ChatMessage.from_str(prompt_tmpl.format(review=text))]
    response = await active_llm.achat(messages)
    return response.raw.rating


rng = np.random.Generator(np.random.PCG64(1234))


def random_few_shot_examples_fn(**kwargs):
    """Render randomly drawn training examples for the few-shot prompt.

    Reads ``n_samples`` from the template kwargs; a missing or falsy value
    falls back to 5 examples.

    Returns:
        The examples formatted as ``Text: ...`` / ``Rating: ...`` pairs,
        separated by blank lines.
    """
    n_samples = kwargs.get('n_samples') or DEFAULT_N_EXAMPLES
    random_examples = train.shuffle(generator=rng)[:n_samples]
    return "\n\n".join(
        f"Text: {text}\nRating: {rating}"
        for text, rating in zip(random_examples['text'], random_examples['label'])
    )


few_shot_prompt_tmpl_str = """\
The review text is below.
---------------------
{review}
---------------------
Given the review text and not prior knowledge, \
please attempt to predict the review score of the context. \
Here are several examples of reviews and their ratings:
{random_few_shot_examples}
Query: What is the rating of this review?
Answer: \
"""

few_shot_prompt_tmpl = PromptTemplate(
    few_shot_prompt_tmpl_str,
    function_mappings={"random_few_shot_examples": random_few_shot_examples_fn},
)


async def random_few_shot_predict(text, n_examples=5, llm=None, n_votes=3):
    """Predict the rating of ``text`` by majority vote over random few-shot prompts.

    Args:
        text: The review text to classify.
        n_examples: Number of random examples placed in each prompt.
        llm: Optional structured LLM to use; defaults to the module-level
            ``structured_llm``.
        n_votes: Number of independent completions to request.

    Returns:
        The most frequent predicted rating as a plain ``int``; on a tie the
        first entry of ``pandas.Series.mode()`` is used.
    """
    active_llm = structured_llm if llm is None else llm
    # Each format() call re-invokes the example function, so every vote is
    # prompted with a freshly shuffled set of examples.
    tasks = [
        active_llm.achat(
            [ChatMessage.from_str(few_shot_prompt_tmpl.format(review=text, n_samples=n_examples))],
            temperature=0.9,
        )
        for _ in range(n_votes)
    ]
    results = await asyncio.gather(*tasks)
    ratings = [r.raw.rating for r in results]
    # int() so callers get a builtin int rather than a NumPy scalar.
    return int(pd.Series(ratings).mode()[0])


def dynamic_few_shot_examples_fn(**kwargs):
    """Render examples retrieved from the vector store for the few-shot prompt.

    Reads ``review`` (required) and ``n_examples`` (default 5) from the
    template kwargs and queries the ``reviews`` collection with the review text.

    Returns:
        The retrieved documents formatted as ``Text: ...`` / ``Rating: ...``
        pairs, separated by blank lines.
    """
    n_examples = kwargs.get('n_examples', DEFAULT_N_EXAMPLES)
    retrievals = reviews.query(
        query_texts=[kwargs['review']],
        n_results=n_examples,
    )
    # A single query text was sent, so only the first result list is relevant.
    documents = retrievals['documents'][0]
    metadatas = retrievals['metadatas'][0]
    return "\n\n".join(
        f"Text: {document}\nRating: {metadata.get('rating')}"
        for document, metadata in zip(documents, metadatas)
    )


dynamic_few_shot_prompt_tmpl_str = """\
The review text is below.
---------------------
{review}
---------------------
Given the review text and not prior knowledge, \
please attempt to predict the review score of the context. \
Here are several examples of reviews and their ratings:
{dynamic_few_shot_examples}
Query: What is the rating of this review?
Answer: \
"""

dynamic_few_shot_prompt_tmpl = PromptTemplate(
    dynamic_few_shot_prompt_tmpl_str,
    function_mappings={"dynamic_few_shot_examples": dynamic_few_shot_examples_fn},
)


async def dynamic_few_shot_predict(text, n_examples=5, llm=None):
    """Predict the rating of ``text`` using retrieved few-shot examples.

    Args:
        text: The review text to classify.
        n_examples: Number of examples to retrieve for the prompt.
        llm: Optional structured LLM to use; defaults to the module-level
            ``structured_llm``.

    Returns:
        The ``rating`` field of the structured response.
    """
    active_llm = structured_llm if llm is None else llm
    messages = [
        ChatMessage.from_str(
            dynamic_few_shot_prompt_tmpl.format(review=text, n_examples=n_examples)
        )
    ]
    response = await active_llm.achat(messages)
    return response.raw.rating


def classify(review, num_examples, api_key):
    """Classify ``review`` with all three strategies using the caller's API key.

    Args:
        review: The review text entered in the UI.
        num_examples: Number of few-shot examples to use (slider value).
        api_key: OpenAI API key; an empty value falls back to the library's
            default credential resolution.

    Returns:
        A ``(zero_shot, random_few_shot, dynamic_few_shot)`` tuple of ratings.
    """
    # Build the LLM per request so the key supplied in the UI is actually used.
    request_llm = OpenAI(model=MODEL_NAME, api_key=api_key or None).as_structured_llm(Rating)
    # The slider may deliver a float; slicing and n_results need an int.
    n_examples = int(num_examples)

    async def _predict_all():
        # The three strategies are independent, so run them concurrently.
        return await asyncio.gather(
            zero_shot_predict(review, llm=request_llm),
            random_few_shot_predict(review, n_examples, llm=request_llm),
            dynamic_few_shot_predict(review, n_examples, llm=request_llm),
        )

    zero_shot, random_few_shot, dynamic_few_shot = asyncio.run(_predict_all())
    return zero_shot, random_few_shot, dynamic_few_shot


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Masked input: the key is a secret and should not be shown on screen.
            api_key = gr.Textbox(label='Openai API Key', type='password')
            n_examples = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                step=1,
                label='Number of examples to retrieve',
                interactive=True,
            )
            review = gr.Textbox(label='Review', interactive=True)
            submit = gr.Button(value='Submit')
        with gr.Column():
            zero_shot_label = gr.Textbox(label='Zero shot', interactive=False)
            random_few_shot_label = gr.Textbox(label='Random few shot', interactive=False)
            dynamic_few_shot_label = gr.Textbox(label='Dynamic few shot', interactive=False)

    # Inputs are passed positionally, so they must match classify's three parameters.
    submit.click(
        classify,
        [review, n_examples, api_key],
        [zero_shot_label, random_few_shot_label, dynamic_few_shot_label],
    )

demo.queue().launch()