# Hugging Face Space source file; last commit by asoria (HF staff).
# Commit 6db2521 (verified): "Try Snowflake/snowflake-arctic-embed-s"
# Inspired by https://huggingface.co/spaces/davanstrien/dataset_column_search
import os
from functools import lru_cache
from urllib.parse import quote
import faiss
import gradio as gr
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from httpx import Client
from huggingface_hub import HfApi
from huggingface_hub.utils import logging
from sentence_transformers import SentenceTransformer
from tqdm.contrib.concurrent import thread_map
# Load variables from a local .env file before reading the token.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
assert HF_TOKEN is not None, "You need to set HF_TOKEN in your environment variables"

# Root URL of the datasets server queried by the helper functions in this file.
BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"

logger = logging.get_logger(__name__)

# Bearer-token header. Python f-strings interpolate with "{...}" only, so a
# JS-style "${...}" would put a literal "$" in front of the token.
headers = {
    "authorization": f"Bearer {HF_TOKEN}",
}
# One shared HTTP client so every datasets-server request carries the auth header.
client = Client(headers=headers)
api = HfApi(token=HF_TOKEN)
def get_first_config_name(dataset: str):
    """Return the config name of the first split listed for ``dataset``.

    Calls the datasets server ``/splits`` endpoint and reads the ``"config"``
    field of the first entry in ``"splits"``.

    Returns:
        The config name, or ``None`` if the request fails or the response
        does not have the expected shape (the failure is logged).
    """
    try:
        resp = client.get(f"{BASE_DATASETS_SERVER_URL}/splits?dataset={dataset}")
        data = resp.json()
        # "config" is the config name itself; indexing it again with [0]
        # would return only its first character.
        return data["splits"][0]["config"]
    except Exception as e:
        logger.error(f"Failed to get splits for {dataset}: {e}")
        return None
def datasets_server_valid_rows(dataset: str):
    """Look up the ``"viewer"`` flag of ``dataset`` on the datasets server.

    Returns:
        The value of the ``"viewer"`` field from the ``/is-valid`` response,
        or ``None`` when the request or the lookup raises (the error is logged).
    """
    try:
        payload = client.get(
            f"{BASE_DATASETS_SERVER_URL}/is-valid?dataset={dataset}"
        ).json()
        return payload["viewer"]
    except Exception as e:
        logger.error(f"Failed to get is-valid for {dataset}: {e}")
        return None
def dataset_is_valid(dataset):
    """Return ``dataset`` unchanged if the datasets server reports it valid, else ``None``.

    ``dataset`` must expose an ``id`` attribute; that id is what gets checked.
    """
    if datasets_server_valid_rows(dataset.id):
        return dataset
    return None
def get_first_config_and_split_name(hub_id: str):
    """Return ``(config, split)`` for the first split listed for ``hub_id``.

    Calls the datasets server ``/splits`` endpoint and reads the first entry
    of the ``"splits"`` list.

    Returns:
        A ``(config, split)`` tuple, or ``None`` if the request fails or the
        response does not have the expected shape (the failure is logged).
    """
    try:
        resp = client.get(
            f"https://datasets-server.huggingface.co/splits?dataset={hub_id}"
        )
        data = resp.json()
        first_split = data["splits"][0]
        return first_split["config"], first_split["split"]
    except Exception as e:
        logger.error(f"Failed to get splits for {hub_id}: {e}")
        return None
def get_dataset_info(hub_id: str, config: str | None = None):
    """Fetch the ``/info`` payload of ``hub_id`` from the datasets server.

    Args:
        hub_id: Dataset id on the Hub.
        config: Config to query. When omitted, the config of the first split
            reported by ``get_first_config_and_split_name`` is used.

    Returns:
        The decoded JSON response, or ``None`` when no config was given and
        none could be resolved.

    Raises:
        Whatever ``raise_for_status()`` raises for a non-success response.
    """
    if config is None:
        first = get_first_config_and_split_name(hub_id)
        if first is None:
            return None
        # The helper returns (config, split); only the config is needed here.
        config = first[0]
    info_url = f"{BASE_DATASETS_SERVER_URL}/info?dataset={hub_id}&config={config}"
    response = client.get(info_url)
    response.raise_for_status()
    return response.json()
def dataset_with_info(dataset):
    """Build the index record for one Hub dataset.

    Fetches the dataset's info via ``get_dataset_info`` and combines the
    feature (column) names with a few attributes read from ``dataset``.

    Returns:
        A dict with keys ``dataset``, ``column_names``, ``text``, ``likes``,
        ``downloads``, ``created_at`` and ``tags``, or ``None`` when no info
        is available, the features entry is ``None``, or any step raises
        (the error is logged).
    """
    try:
        info = get_dataset_info(dataset.id)
        if not info:
            return None
        columns = info.get("dataset_info", {}).get("features", {})
        if columns is None:
            return None
        # Join once; the same string feeds both "column_names" and "text".
        column_names = ",".join(columns.keys())
        return {
            "dataset": dataset.id,
            "column_names": column_names,
            # The literal used to define "text" twice and only the second
            # definition took effect: the id without its namespace prefix,
            # followed by the column names.
            "text": f"{str(dataset.id).split('/')[-1]}-{column_names}",
            "likes": dataset.likes,
            "downloads": dataset.downloads,
            "created_at": dataset.created_at,
            "tags": dataset.tags,
        }
    except Exception as e:
        logger.error(f"Failed to get info for {dataset.id}: {e}")
        return None
@lru_cache(maxsize=100)
def prep_data():
    """Collect index records for the Hub datasets that pass validation.

    Lists the datasets through ``api.list_datasets``, keeps those for which
    ``dataset_is_valid`` returns a value, then builds one record per dataset
    with ``dataset_with_info``, dropping the ``None`` results. Progress counts
    are printed after each stage. The result is memoized by ``lru_cache``.

    Returns:
        A list of the record dicts produced by ``dataset_with_info``.
    """
    hub_datasets = list(api.list_datasets(limit=None, sort="createdAt", direction=-1))
    print(f"Found {len(hub_datasets)} datasets in the hub.")

    checked = thread_map(dataset_is_valid, hub_datasets)
    valid_datasets = [candidate for candidate in checked if candidate is not None]
    print(f"Found {len(valid_datasets)} valid datasets.")

    records = [
        record
        for record in thread_map(dataset_with_info, valid_datasets)
        if record is not None
    ]
    print(f"Found {len(records)} datasets with info.")
    return records
# Build the corpus at import time: one row per record returned by prep_data().
all_datasets = prep_data()
all_datasets_df = pd.DataFrame.from_dict(all_datasets)
print(all_datasets_df.head())
# The "text" column is the string that gets embedded for each dataset.
text = all_datasets_df['text']
encoder = SentenceTransformer("Snowflake/snowflake-arctic-embed-s")
vectors = encoder.encode(text)
# The second axis of the embedding matrix gives the size the index is built with.
vector_dimension = vectors.shape[1]
print("Start indexing")
index = faiss.IndexFlatL2(vector_dimension)
# normalize_L2's return value is not used, so it is relied on to modify
# `vectors` in place before they are added; search() normalizes its query
# vector the same way.
faiss.normalize_L2(vectors)
index.add(vectors)
print("Indexing done")
def render_model_hub_link(hub_id):
    """Return an HTML anchor that opens the Hub dataset page of ``hub_id`` in a new tab.

    The id is percent-encoded for the URL and shown verbatim as the link text.
    """
    href = "https://huggingface.co/datasets/" + quote(hub_id)
    return (
        f'<a target="_blank" href="{href}" '
        'style="color: var(--link-text-color); text-decoration: underline;'
        'text-decoration-style: dotted;">'
        f"{hub_id}</a>"
    )
def search(dataset_name, k):
    """Return the datasets whose indexed text is closest to that of ``dataset_name``.

    Args:
        dataset_name: Exact value to look up in the ``dataset`` column of
            ``all_datasets_df``.
        k: Number of nearest neighbors to request from the index. It is
            coerced to ``int`` before the query.

    Returns:
        A DataFrame of the index results merged with the matching dataset
        rows, with ``dataset`` rendered as an HTML link and the ``text``
        column dropped. If ``dataset_name`` is not in the table, a one-row
        DataFrame holding an ``error`` message is returned instead.
    """
    print(f"start search for {dataset_name}")
    try:
        dataset_row = all_datasets_df[all_datasets_df.dataset == dataset_name].iloc[0]
    except IndexError:
        # .iloc[0] on an empty selection: the dataset is not in the table.
        return pd.DataFrame([{"error": "❌ Dataset does not exist or is not supported"}])

    # Embed the stored text again and wrap it into a one-row 2-D array,
    # normalized the same way as the indexed vectors.
    query_vector = np.array([encoder.encode(dataset_row["text"])])
    faiss.normalize_L2(query_vector)

    # The neighbor count must be an integer; a UI widget or API client may
    # pass a float such as 20.0.
    distances, ann = index.search(query_vector, k=int(k))

    results = pd.DataFrame({"distances": distances[0], "ann": ann[0]})
    # "ann" holds row positions in the index, which line up with the
    # DataFrame's index because the vectors were added in row order.
    merge = pd.merge(results, all_datasets_df, left_on="ann", right_index=True)
    merge["dataset"] = merge["dataset"].apply(render_model_hub_link)
    return merge.drop("text", axis=1)
# Gradio UI: a dataset-name box and a top-k slider feed search(); the result
# table renders its cells as markdown so the dataset links are clickable.
with gr.Blocks() as demo:
    gr.Markdown("# Search similar Datasets on Hugging Face")
    gr.Markdown("This space shows similar datasets based on a name and columns. It uses https://github.com/facebookresearch/faiss for vector indexing.")
    gr.Markdown("'Text' column was used for indexing. Where text is a concatenation of 'dataset_name'-'column_names'")
    dataset_name = gr.Textbox(value="sksayril/medicine-info", label="Dataset Name")
    k = gr.Slider(
        minimum=5,
        maximum=200,
        value=20,
        step=5,
        interactive=True,
        label="Top K Nearest Neighbors",
    )
    btn = gr.Button(value="Show similar datasets")
    df = gr.DataFrame(datatype="markdown")
    btn.click(fn=search, inputs=[dataset_name, k], outputs=df)
    gr.Markdown("This space was inspired by https://huggingface.co/spaces/davanstrien/dataset_column_search")

demo.launch()