Spaces:
Sleeping
Sleeping
import sqlite3

import polars as pl
import solara
import sqlite_vec
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

# Recipe titles that the app searches over. Their list positions are used as
# the rowids of the sqlite-vec tables, so results map back by plain indexing.
# NOTE(review): trust_remote_code=True permits the dataset repository's own
# loading script to run locally -- confirm this is still required.
dataset = load_dataset("m3hrdadfi/recipe_nlg_lite", trust_remote_code=True)
recipe_names = dataset["train"]["name"]

# Embed every recipe name once at import time; queries are encoded later with
# the same model so that both live in the same vector space.
tfm_base = SentenceTransformer("all-MiniLM-L6-v2")
X_tfm = tfm_base.encode(recipe_names)
n_feats = X_tfm.shape[1]  # embedding width, used to declare the vec0 columns
def Display_Full(query, db, limit):
    """Render the nearest recipes according to the full-precision vectors.

    Args:
        query: Embedding of the search text; it is serialized with
            ``sqlite_vec.serialize_float32`` before being bound.
        db: SQLite connection on which the ``vec_sents`` vec0 table exists.
        limit: Maximum number of neighbours to fetch.
    """
    # `k = ?` is sqlite-vec's KNN constraint. Binding it keeps `limit` out of
    # the SQL text and, per the sqlite-vec docs, does not rely on the LIMIT
    # push-down that only SQLite >= 3.41 provides.
    rows = db.execute(
        """
        SELECT
            rowid,
            distance
        FROM vec_sents
        WHERE embedding MATCH ?
            AND k = ?
        ORDER BY distance
        """,
        [sqlite_vec.serialize_float32(query), limit],
    ).fetchall()

    # rowid is the first selected column; positional access works with plain
    # tuples as well as sqlite3.Row, so no particular row_factory is required.
    names = [recipe_names[row[0]] for row in rows]
    df1 = pl.DataFrame({"results": names})

    with solara.Column():
        solara.Markdown("## Full precision")
        solara.DataFrame(df1, items_per_page=10)
def Display_Binary(query, db, limit):
    """Render the nearest recipes according to the binary-quantized vectors.

    Args:
        query: Float embedding of the search text; it is serialized with
            ``sqlite_vec.serialize_float32`` and quantized inside the SQL
            statement by ``vec_quantize_binary``.
        db: SQLite connection on which the ``bin_vec_sents`` vec0 table exists.
        limit: Maximum number of neighbours to fetch.
    """
    # `k = ?` is sqlite-vec's KNN constraint. Binding it keeps `limit` out of
    # the SQL text and, per the sqlite-vec docs, does not rely on the LIMIT
    # push-down that only SQLite >= 3.41 provides.
    rows = db.execute(
        """
        SELECT
            rowid,
            distance
        FROM bin_vec_sents
        WHERE embedding MATCH vec_quantize_binary(?)
            AND k = ?
        ORDER BY distance
        """,
        [sqlite_vec.serialize_float32(query), limit],
    ).fetchall()

    # rowid is the first selected column; positional access works with plain
    # tuples as well as sqlite3.Row, so no particular row_factory is required.
    names = [recipe_names[row[0]] for row in rows]
    df2 = pl.DataFrame({"results": names})

    with solara.Column():
        solara.Markdown("## Binary quantization")
        solara.DataFrame(df2, items_per_page=10)
def Page():
    """Render the recipe finder page.

    Shows a text input, encodes its current value with ``tfm_base`` and
    displays the top matches from the full-precision and the binary-quantized
    sqlite-vec tables side by side.
    """
    with solara.Column(margin=10):
        with solara.Head():
            solara.Title("Recipe finder")
        solara.Markdown("# Recipe finder")
        solara.Markdown("I built this tool to help me get a feeling of binary embedding quantization in [sqlite-vec](https://alexgarcia.xyz/sqlite-vec/). For any given text, it gives the top 10 results. The dataset I'm using is [m3hrdadfi/recipe_nlg_lite](https://hf.co/datasets/m3hrdadfi/recipe_nlg_lite) which consists of 6,119 recipes. Inspired by [Exploring SQLite-vec](https://www.youtube.com/watch?v=wYU66AjRIAc) by [@fishnets88](https://twitter.com/fishnets88)")

        q = solara.use_reactive("I would like to have some vegetable soup")
        solara.InputText("Enter a query", value=q, continuous_update=True)
        query = tfm_base.encode([q.value])[0]
        limit = 10

        # The input updates on every keystroke, so the page re-renders often.
        # Memoizing with an empty dependency list builds the index once
        # instead of re-inserting every embedding on each render.
        db = solara.use_memo(_build_vector_db, dependencies=[])

        with solara.Row():
            Display_Full(query, db, limit)
            Display_Binary(query, db, limit)


def _build_vector_db():
    """Create the in-memory sqlite-vec database holding both embedding tables.

    Returns:
        A ``sqlite3.Connection`` with ``vec_sents`` (float vectors) and
        ``bin_vec_sents`` (binary-quantized vectors) filled from ``X_tfm``,
        using each vector's position in ``X_tfm`` as its rowid.
    """
    # The connection now outlives a single render. Disable the same-thread
    # check in case later renders run on a different thread than the one
    # that built it -- NOTE(review): confirm against the deployment's server.
    db = sqlite3.connect(":memory:", check_same_thread=False)

    # Extension loading is enabled only for as long as it takes to load sqlite-vec.
    db.enable_load_extension(True)
    sqlite_vec.load(db)
    db.enable_load_extension(False)
    db.row_factory = sqlite3.Row

    db.execute(f"create virtual table vec_sents using vec0(embedding float[{n_feats}])")
    db.execute(f"create virtual table bin_vec_sents using vec0(embedding bit[{n_feats}])")

    # One transaction for all inserts; each vector is serialized once and
    # reused for both tables.
    with db:
        for i, vector in enumerate(X_tfm):
            blob = sqlite_vec.serialize_float32(vector)
            db.execute(
                "INSERT INTO vec_sents(rowid, embedding) VALUES (?, ?)",
                [i, blob],
            )
            db.execute(
                "INSERT INTO bin_vec_sents(rowid, embedding) VALUES (?, vec_quantize_binary(?))",
                [i, blob],
            )
    return db