"""Recipe finder: compare full-precision vs binary-quantized vector search.

Recipe names from ``m3hrdadfi/recipe_nlg_lite`` are embedded with the
``all-MiniLM-L6-v2`` sentence-transformer and indexed twice in one in-memory
sqlite-vec database: once as float32 vectors (``vec_sents``) and once as
binary-quantized bit vectors (``bin_vec_sents``). The Solara ``Page`` shows
the nearest recipes for a free-text query from both indexes side by side.
"""

import sqlite3
import threading

import polars as pl
import solara
import sqlite_vec
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

# Number of nearest neighbours shown per index.
TOP_K = 10

dataset = load_dataset("m3hrdadfi/recipe_nlg_lite", trust_remote_code=True)
recipe_names = dataset["train"]["name"]

tfm_base = SentenceTransformer("all-MiniLM-L6-v2")
X_tfm = tfm_base.encode(recipe_names)
n_feats = X_tfm.shape[1]


def build_database(embeddings, n_dims):
    """Create an in-memory sqlite-vec database holding *embeddings* twice.

    Row ``i`` of *embeddings* is stored with ``rowid = i`` in both
    ``vec_sents`` (float vectors) and ``bin_vec_sents`` (vectors passed
    through ``vec_quantize_binary``). A rowid returned by either index is
    therefore a direct index into ``recipe_names``.

    Args:
        embeddings: iterable of vectors, each accepted by
            ``sqlite_vec.serialize_float32``.
        n_dims: dimensionality of every vector.

    Returns:
        An open ``sqlite3.Connection`` with ``row_factory = sqlite3.Row``.
    """
    # The connection is built once at import time. It may later be queried
    # from a different thread than the one that created it, so
    # same-thread checking is disabled and queries go through ``_db_lock``.
    db = sqlite3.connect(":memory:", check_same_thread=False)
    db.enable_load_extension(True)
    sqlite_vec.load(db)
    db.enable_load_extension(False)
    db.row_factory = sqlite3.Row

    db.execute(f"create virtual table vec_sents using vec0(embedding float[{n_dims}])")
    db.execute(f"create virtual table bin_vec_sents using vec0(embedding bit[{n_dims}])")

    # Serialize each vector once and reuse the blobs for both tables.
    rows = [(i, sqlite_vec.serialize_float32(vec)) for i, vec in enumerate(embeddings)]
    with db:  # single transaction for all inserts
        db.executemany("INSERT INTO vec_sents(rowid, embedding) VALUES (?, ?)", rows)
        db.executemany(
            "INSERT INTO bin_vec_sents(rowid, embedding) VALUES (?, vec_quantize_binary(?))",
            rows,
        )
    return db


# Built once per process. Building inside Page would redo every insert on
# each render (i.e. each keystroke) and leak a connection every time.
_db = build_database(X_tfm, n_feats)
_db_lock = threading.Lock()


def _nearest_recipes(db, table, match_expr, query, limit):
    """Return the names of the *limit* rows of *table* nearest to *query*.

    *table* and *match_expr* are fixed SQL fragments from this module (never
    user input). *match_expr* wraps the serialized query vector: ``?`` for
    the float index, ``vec_quantize_binary(?)`` for the bit index.
    Results are ordered by increasing distance.
    """
    sql = (
        f"SELECT rowid, distance FROM {table} "
        f"WHERE embedding MATCH {match_expr} "
        f"ORDER BY distance LIMIT {int(limit)}"
    )
    with _db_lock:
        rows = db.execute(sql, [sqlite_vec.serialize_float32(query)]).fetchall()
    return [recipe_names[rowid] for rowid, _distance in rows]


def _show_results(title, names):
    """Render a titled, paginated one-column table of recipe names."""
    with solara.Column():
        solara.Markdown(title)
        solara.DataFrame(pl.DataFrame({"results": names}), items_per_page=10)


@solara.component
def Display_Full(query, db, limit):
    """Show the *limit* nearest recipes from the full-precision index."""
    names = _nearest_recipes(db, "vec_sents", "?", query, limit)
    _show_results("## Full precision", names)


@solara.component
def Display_Binary(query, db, limit):
    """Show the *limit* nearest recipes from the binary-quantized index."""
    names = _nearest_recipes(db, "bin_vec_sents", "vec_quantize_binary(?)", query, limit)
    _show_results("## Binary quantization", names)


@solara.component
def Page():
    """Top-level page: query box plus side-by-side results from both indexes."""
    with solara.Column(margin=10):
        with solara.Head():
            solara.Title("Recipe finder")
        solara.Markdown("# Recipe finder")
        solara.Markdown(
            "I built this tool to help me get a feeling of binary embedding quantization "
            "in [sqlite-vec](https://alexgarcia.xyz/sqlite-vec/). "
            f"For any given text, it gives the top {TOP_K} results. "
            "The dataset I'm using is "
            "[m3hrdadfi/recipe_nlg_lite](https://hf.co/datasets/m3hrdadfi/recipe_nlg_lite) "
            f"which consists of {len(recipe_names):,} recipes.\n\n"
            "Inspired by [Exploring SQLite-vec](https://www.youtube.com/watch?v=wYU66AjRIAc) "
            "by [@fishnets88](https://twitter.com/fishnets88)"
        )
        q = solara.use_reactive("I would like to have some vegetable soup")
        solara.InputText("Enter a query", value=q, continuous_update=True)
        # Only the query needs embedding per render; the corpus index is shared.
        query = tfm_base.encode([q.value])[0]
        with solara.Row():
            Display_Full(query, _db, TOP_K)
            Display_Binary(query, _db, TOP_K)