recipe-finder / app.py
alonsosilva's picture
Trust remote code for the dataset
c2318c3
raw
history blame
3.71 kB
import sqlite3
import sqlite_vec
from datasets import load_dataset
# Recipe titles: the only dataset field this app searches over.
# NOTE(review): trust_remote_code=True lets the dataset's Hub-hosted loading
# code run locally; keep it only while that repository is trusted.
dataset = load_dataset(
    "m3hrdadfi/recipe_nlg_lite",
    trust_remote_code=True,
)
recipe_names = dataset["train"]["name"]

from sentence_transformers import SentenceTransformer

# Embed every recipe title once, at import time. Row i of X_tfm is the
# embedding of recipe_names[i]; n_feats (the embedding width) sizes the
# vec0 tables built by the page.
tfm_base = SentenceTransformer("all-MiniLM-L6-v2")
X_tfm = tfm_base.encode(recipe_names)
n_feats = X_tfm.shape[1]
import polars as pl
import solara
@solara.component
def Display_Full(query, db, limit):
    """Render the nearest recipes found with full-precision float vectors.

    Runs a sqlite-vec KNN query against the ``vec_sents`` table. The matching
    recipe names are shown nearest first, in a paginated table under a
    "Full precision" heading.

    Args:
        query: Embedding of the user's query. It is serialized with
            ``sqlite_vec.serialize_float32`` before binding.
        db: Open ``sqlite3`` connection with sqlite-vec loaded and a
            populated ``vec_sents`` table.
        limit: Number of neighbours to fetch. It is coerced to ``int``.
    """
    # The LIMIT is interpolated into the SQL text. int() guarantees that only
    # an integer literal can ever land there.
    rows = db.execute(
        f"""
        SELECT
            rowid,
            distance
        FROM vec_sents
        WHERE embedding MATCH ?
        ORDER BY distance
        LIMIT {int(limit)}
        """,
        [sqlite_vec.serialize_float32(query)],
    ).fetchall()
    # Column 0 is the rowid. Rows were inserted with rowid == position in
    # recipe_names, so the rowid indexes that list directly.
    results = pl.DataFrame({"results": [recipe_names[row[0]] for row in rows]})
    with solara.Column():
        solara.Markdown("## Full precision")
        solara.DataFrame(results, items_per_page=10)
@solara.component
def Display_Binary(query, db, limit):
    """Render the nearest recipes found with binary-quantized vectors.

    Runs a sqlite-vec KNN query against the ``bin_vec_sents`` table. The
    query vector is quantized with ``vec_quantize_binary`` so that it matches
    the stored bit vectors. Matching recipe names are shown nearest first, in
    a paginated table under a "Binary quantization" heading.

    Args:
        query: Embedding of the user's query. It is serialized with
            ``sqlite_vec.serialize_float32`` and quantized inside SQL.
        db: Open ``sqlite3`` connection with sqlite-vec loaded and a
            populated ``bin_vec_sents`` table.
        limit: Number of neighbours to fetch. It is coerced to ``int``.
    """
    # The LIMIT is interpolated into the SQL text. int() guarantees that only
    # an integer literal can ever land there.
    rows = db.execute(
        f"""
        SELECT
            rowid,
            distance
        FROM bin_vec_sents
        WHERE embedding MATCH vec_quantize_binary(?)
        ORDER BY distance
        LIMIT {int(limit)}
        """,
        [sqlite_vec.serialize_float32(query)],
    ).fetchall()
    # Column 0 is the rowid. Rows were inserted with rowid == position in
    # recipe_names, so the rowid indexes that list directly.
    results = pl.DataFrame({"results": [recipe_names[row[0]] for row in rows]})
    with solara.Column():
        solara.Markdown("## Binary quantization")
        solara.DataFrame(results, items_per_page=10)
def _build_db():
    """Create an in-memory sqlite-vec database holding every recipe embedding.

    Two vec0 tables are filled from ``X_tfm``, with rowid ``i`` for the i-th
    recipe:

    - ``vec_sents`` stores the full float32 vectors.
    - ``bin_vec_sents`` stores the same vectors binary-quantized with
      ``vec_quantize_binary``.

    Rows come back as ``sqlite3.Row``.
    """
    # The connection is memoized across renders. Those renders may run on a
    # different thread from the one that created it, so the same-thread check
    # is disabled.
    db = sqlite3.connect(":memory:", check_same_thread=False)
    db.enable_load_extension(True)
    sqlite_vec.load(db)
    db.enable_load_extension(False)
    db.row_factory = sqlite3.Row
    db.execute(f"create virtual table vec_sents using vec0(embedding float[{n_feats}])")
    db.execute(f"create virtual table bin_vec_sents using vec0(embedding bit[{n_feats}])")
    # Serialize each vector once and reuse the same blobs for both tables.
    rows = [(i, sqlite_vec.serialize_float32(vec)) for i, vec in enumerate(X_tfm)]
    with db:
        db.executemany(
            "INSERT INTO vec_sents(rowid, embedding) VALUES (?, ?)",
            rows,
        )
        db.executemany(
            "INSERT INTO bin_vec_sents(rowid, embedding) VALUES (?, vec_quantize_binary(?))",
            rows,
        )
    return db


@solara.component
def Page():
    """Top-level page: a query box above side-by-side result tables.

    The query text is embedded with ``tfm_base``. The same embedding is then
    searched both at full precision and with binary quantization, so the two
    result lists can be compared.
    """
    # Build the vector database once per page instance. The InputText below
    # updates continuously, so rebuilding it per render would re-insert every
    # embedding (and open a new connection) on each keystroke.
    db = solara.use_memo(_build_db, dependencies=[])
    # The effect's return value is its cleanup callback. It closes the
    # connection when the page unmounts.
    solara.use_effect(lambda: db.close, [])
    with solara.Column(margin=10):
        with solara.Head():
            solara.Title("Recipe finder")
        solara.Markdown("# Recipe finder")
        solara.Markdown("I built this tool to help me get a feeling of binary embedding quantization in [sqlite-vec](https://alexgarcia.xyz/sqlite-vec/). For any given text, it gives the top 10 results. The dataset I'm using is [m3hrdadfi/recipe_nlg_lite](https://hf.co/datasets/m3hrdadfi/recipe_nlg_lite) which consists of 6,119 recipes. Inspired by [Exploring SQLite-vec](https://www.youtube.com/watch?v=wYU66AjRIAc) by [@fishnets88](https://twitter.com/fishnets88)")
        q = solara.use_reactive("I would like to have some vegetable soup")
        solara.InputText("Enter a query", value=q, continuous_update=True)
        query = tfm_base.encode([q.value])[0]
        limit = 10
        with solara.Row():
            Display_Full(query, db, limit)
            Display_Binary(query, db, limit)