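# Metallica Song Finder Bot: a Solara app that embeds Metallica lyrics into a
# LanceDB vector table and answers questions about the songs with an LLM served
# via OpenRouter (RAG over the retrieved lyrics).
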
# Optionally load environment variables (e.g. OPENROUTER_API_KEY) from a local .env file:
# from dotenv import find_dotenv, load_dotenv
# _ = load_dotenv(find_dotenv())

import solara

# Remove any existing LanceDB data directory so the table is rebuilt from scratch
import shutil

shutil.rmtree("./data", ignore_errors=True)

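# Load the song lyrics dataset (Song, Album, Artist, Lyrics columns) from a shared Google Drive CSV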
import polars as pl

df = pl.read_csv(
    "https://drive.google.com/uc?export=download&id=1uD3h7xYxr9EoZ0Ggoh99JtQXa3AxtxyU"
)

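# Normalize album and song titles to title case and replace missing lyrics with the string 'None'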
import string

df = df.with_columns(
    pl.Series("Album", [string.capwords(album) for album in df["Album"]])
)
df = df.with_columns(pl.Series("Song", [string.capwords(song) for song in df["Song"]]))
df = df.with_columns(pl.col("Lyrics").fill_null("None"))

# Build the text that will be embedded: a "# Album: Song" header followed by the lyrics
df = df.with_columns(
    text=pl.lit("# ")
    + pl.col("Album")
    + pl.lit(": ")
    + pl.col("Song")
    + pl.lit("\n\n")
    + pl.col("Lyrics")
)

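# Create a local LanceDB database under ./data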
import lancedb

db = lancedb.connect("data/")

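# Embed the `text` column with a small sentence-transformers model (TaylorAI/gte-tiny) on the CPU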
from lancedb.embeddings import get_registry

embeddings = (
    get_registry()
    .get("sentence-transformers")
    .create(name="TaylorAI/gte-tiny", device="cpu")
)

from lancedb.pydantic import LanceModel, Vector


# Table schema: `text` is the source field that gets embedded automatically,
# and `vector` stores the resulting embedding.
class Songs(LanceModel):
    Song: str
    Lyrics: str
    Album: str
    Artist: str
    text: str = embeddings.SourceField()
    vector: Vector(embeddings.ndims()) = embeddings.VectorField()

table = db.create_table("Songs", schema=Songs)
table.add(data=df)  # embeddings for the `text` field are computed automatically on insert

import os
from typing import Optional

# ChatOpenAI accepts a custom API base, so it can talk to any OpenAI-compatible
# endpoint such as OpenRouter (newer LangChain releases expose the same class as
# langchain_openai.ChatOpenAI).
from langchain_community.chat_models import ChatOpenAI

class ChatOpenRouter(ChatOpenAI):
    """ChatOpenAI client configured for the OpenRouter API."""

    openai_api_base: str
    openai_api_key: str
    model_name: str

    def __init__(
        self,
        model_name: str,
        openai_api_key: Optional[str] = None,
        openai_api_base: str = "https://openrouter.ai/api/v1",
        **kwargs,
    ):
        # Fall back to the OPENROUTER_API_KEY environment variable if no key is passed
        openai_api_key = openai_api_key or os.getenv("OPENROUTER_API_KEY")
        super().__init__(
            openai_api_base=openai_api_base,
            openai_api_key=openai_api_key,
            model_name=model_name,
            **kwargs,
        )

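# Answer questions with Llama 3.1 405B Instruct served through OpenRouter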
llm_openrouter = ChatOpenRouter(model_name="meta-llama/llama-3.1-405b-instruct")

def get_relevant_texts(query, table):
    # Vector search over the embedded `text` column, returning the 5 closest songs
    results = (
        table.search(query)
        .limit(5)
        .to_polars()
    )
    # Join the retrieved texts, separated by horizontal rules, to form the context
    return " ".join([text + "\n\n---\n\n" for text in results["text"]])

def generate_prompt(query, table):
    # Prepend the retrieved lyrics as context, then append the user's question
    return (
        "Answer the question based only on the following context:\n\n"
        + get_relevant_texts(query, table)
        + "\n\nQuestion: "
        + query
    )

def generate_response(query, table):
    # Build the RAG prompt and ask the OpenRouter-hosted model for an answer
    prompt = generate_prompt(query, table)
    response = llm_openrouter.invoke(input=prompt)
    return response.content

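# Solara UI: a reactive query input, the generated answer, and the retrieved context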
query = solara.reactive("Which song is about a boy who is having nightmares?")


@solara.component
def Page():
    with solara.Column(margin=10):
        solara.Markdown("# Metallica Song Finder Bot")
        solara.InputText("Enter some query:", query, continuous_update=False)
        if query.value != "":
            # Show the generated answer plus the 5 closest matches as context
            df_results = table.search(query.value).limit(5).to_polars()
            df_results = df_results.select(
                ["Song", "Album", "_distance", "Lyrics", "Artist"]
            )
            solara.Markdown("## Answer:")
            solara.Markdown(generate_response(query.value, table))
            solara.Markdown("## Context:")
            solara.DataFrame(df_results, items_per_page=5)
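
# To try this app locally (assuming this file is saved as app.py and
# OPENROUTER_API_KEY is set in the environment), run:
#   solara run app.py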