Spaces:
Runtime error
Runtime error
Commit
·
6621d73
1
Parent(s):
d0819c0
Adding filter for ids
Browse files- src/utilities.py +28 -0
src/utilities.py
CHANGED
@@ -2,6 +2,7 @@ import os
|
|
2 |
|
3 |
import numpy as np
|
4 |
import pandas as pd
|
|
|
5 |
from datasets import Dataset, DownloadMode, load_dataset
|
6 |
from gradio_client import Client
|
7 |
|
@@ -12,6 +13,7 @@ USERNAME = os.environ["USERNAME"]
|
|
12 |
OG_DATASET = f"{USERNAME}/dataset-creator-reddit-{SUBREDDIT}"
|
13 |
PROCESSED_DATASET = os.environ['PROCESSED_DATASET']
|
14 |
embeddings_space = f"{USERNAME}/nomic-embeddings"
|
|
|
15 |
|
16 |
logger = setup_logger(__name__)
|
17 |
|
@@ -36,6 +38,9 @@ def merge_and_update_datasets(dataset, original_dataset):
|
|
36 |
odf = original_dataset['train'].to_pandas()
|
37 |
df = dataset['train'].to_pandas()
|
38 |
|
|
|
|
|
|
|
39 |
# Step 1: Merge df onto odf
|
40 |
# We'll bring in 'content' and 'embedding' from df to compare and possibly update 'embedding'
|
41 |
merged_df = pd.merge(odf, df[['id', 'content', 'embedding']], on='id', how='left', suffixes=('_odf', ''))
|
@@ -60,6 +65,29 @@ def merge_and_update_datasets(dataset, original_dataset):
|
|
60 |
return dataset, updated_row_count
|
61 |
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
def update_embeddings(content, client):
    """Embed *content* with the remote gradio embedding service.

    :param content: Text to embed; the nomic task prefix
        'search_document: ' is prepended before the call.
    :param client: gradio_client.Client exposing an "/embed" endpoint.
    :return: The service's embedding as a numpy array.
    """
    prompt = 'search_document: ' + content
    return np.array(client.predict(prompt, api_name="/embed"))
|
|
|
2 |
|
3 |
import numpy as np
|
4 |
import pandas as pd
|
5 |
+
import requests
|
6 |
from datasets import Dataset, DownloadMode, load_dataset
|
7 |
from gradio_client import Client
|
8 |
|
|
|
13 |
OG_DATASET = f"{USERNAME}/dataset-creator-reddit-{SUBREDDIT}"
|
14 |
PROCESSED_DATASET = os.environ['PROCESSED_DATASET']
|
15 |
embeddings_space = f"{USERNAME}/nomic-embeddings"
|
16 |
+
FILTER_IDS_URL = "https://huggingface.co/spaces/reddit-tools-HF/dataset-creator-reddit-bestofredditorupdates/raw/main/filter_ids.json"
|
17 |
|
18 |
logger = setup_logger(__name__)
|
19 |
|
|
|
38 |
odf = original_dataset['train'].to_pandas()
|
39 |
df = dataset['train'].to_pandas()
|
40 |
|
41 |
+
# Filter ODF in case we missed any
|
42 |
+
odf = remove_filtered_rows(odf, FILTER_IDS_URL)
|
43 |
+
|
44 |
# Step 1: Merge df onto odf
|
45 |
# We'll bring in 'content' and 'embedding' from df to compare and possibly update 'embedding'
|
46 |
merged_df = pd.merge(odf, df[['id', 'content', 'embedding']], on='id', how='left', suffixes=('_odf', ''))
|
|
|
65 |
return dataset, updated_row_count
|
66 |
|
67 |
|
68 |
+
def remove_filtered_rows(df: pd.DataFrame, url: str) -> pd.DataFrame:
    """
    Remove rows from *df* whose 'id' appears in the JSON list fetched from *url*.

    :param df: Input DataFrame to be filtered; must contain an 'id' column.
    :param url: URL of a JSON file containing the list of filter IDs.
    :return: New DataFrame with rows whose ID is in the filter list removed
        (the input DataFrame is not mutated).
    :raises requests.HTTPError: if the filter-ID file cannot be fetched.
    """
    # Load filter IDs from JSON file at the URL. Bound the wait so a stalled
    # host cannot hang the pipeline, and fail loudly on an HTTP error instead
    # of handing an error page to .json() (or silently using a bad ID list).
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    filter_ids = response.json()

    logger.info(f"Loaded {len(filter_ids)} IDs from {url}")

    # Compare IDs as strings so numeric IDs in the DataFrame still match the
    # string entries from the JSON file.
    filtered_df = df[~df['id'].astype(str).isin(filter_ids)]

    logger.info(f"Filtered {len(df) - len(filtered_df)} rows from the DataFrame")

    return filtered_df
|
89 |
+
|
90 |
+
|
91 |
def update_embeddings(content, client):
    """Return the embedding of *content* as a numpy array.

    :param content: Text to embed; prefixed with the nomic task tag
        'search_document: ' before being sent to the service.
    :param client: gradio_client.Client exposing an "/embed" endpoint.
    :return: numpy array wrapping the embedding returned by the client.
    """
    tagged = 'search_document: ' + content
    raw = client.predict(tagged, api_name="/embed")
    return np.array(raw)
|