In [None]:
# ! pip install pymilvus==2.3.4
# ! pip install pyarrow==12.0.0
# !pip install -U transformers

In [None]:
from transformers import DistilBertTokenizerFast
from tensorflow.keras.models import load_model, Model
import numpy as np
import tensorflow as tf
from tqdm import tqdm
from dotenv import load_dotenv
import os
import pandas as pd
from pymilvus import connections, utility
from pymilvus import Collection, DataType, FieldSchema, CollectionSchema
import multiprocessing

In [None]:
tf.config.list_physical_devices('GPU')

In [2]:
data = pd.read_csv('labelled_newscatcher_dataset.csv', sep=";", usecols=['title', 'topic'])
json_data=pd.read_json('News_Category_Dataset_v3.json', lines=True)
data.drop_duplicates(subset=['title'], inplace=True)
json_data.drop_duplicates(subset=['headline'], inplace=True)
json_data = json_data[['headline', 'category']].copy()
json_data.rename(columns={'headline': 'title'}, inplace=True)
data.rename(columns={'topic': 'category'}, inplace=True)
data = pd.concat([data, json_data], axis=0)
data.drop_duplicates(subset=['title'], inplace=True)
data.reset_index(drop=True, inplace=True)
data.reset_index(inplace=True)
data.rename(columns={'title': 'short_description'}, inplace=True)
data

Unnamed: 0,index,category,short_description
0,0,SCIENCE,A closer look at water-splitting's solar fuel ...
1,1,SCIENCE,"An irresistible scent makes locusts swarm, stu..."
2,2,SCIENCE,Artificial intelligence warning: AI will know ...
3,3,SCIENCE,Glaciers Could Have Sculpted Mars Valleys: Study
4,4,SCIENCE,Perseid meteor shower 2020: What time and how ...
...,...,...,...
311171,311171,TECH,RIM CEO Thorsten Heins' 'Significant' Plans Fo...
311172,311172,SPORTS,Maria Sharapova Stunned By Victoria Azarenka I...
311173,311173,SPORTS,"Giants Over Patriots, Jets Over Colts Among M..."
311174,311174,SPORTS,Aldon Smith Arrested: 49ers Linebacker Busted ...


In [3]:
any(data['short_description'].duplicated())

False

In [4]:
data.to_csv('news_processed.csv', index=False)

In [5]:
model_checkpoint = "distilbert-base-uncased"
tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)
interpreter = tf.lite.Interpreter(model_path="news_classification_hf_distilbert.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()

INFO: Created TensorFlow Lite XNNPACK delegate for CPU.


In [6]:
class TextVectorizer:
    '''
    sentence transformers to extract sentence embeddings
    '''
    def vectorize(self, text, tokenizer): # need to have tokenizer as argument to prevent tokenizer error while using multiprocessing
        '''
        This code block of initializing tokenizer within the method is essential, else tokenizer will throw an error while using multiprocessing
        START
        '''
        model_checkpoint = "distilbert-base-uncased"
        tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)
        '''
        END
        '''
        tokens = tokenizer(text, max_length=80, padding="max_length", truncation=True, return_tensors="tf")
        attention_mask, input_ids = tokens['attention_mask'], tokens['input_ids']
        interpreter.set_tensor(input_details[0]["index"], attention_mask)
        interpreter.set_tensor(input_details[1]["index"], input_ids)
        interpreter.invoke()
        tflite_embeds = interpreter.get_tensor(711)[0]
        return [*tflite_embeds]

In [7]:
vectorizer = TextVectorizer()

In [8]:
# getting max length of article descriptions to be used for VARCHAR while defining schema
max_desc_len = max([len(s) for s in data['short_description']])
max_desc_len

320

In [9]:
# getting max length of article categories to be used for VARCHAR while defining schema
max_cat_len = max([len(s) for s in data['category']])
max_cat_len

14

In [10]:
# # Reading milvus URI & API token from secrets.env
load_dotenv('secrets.env')
uri = os.environ.get("URI")
token = os.environ.get("TOKEN")

In [11]:
connections.connect("default", uri=uri, token=token)
print(f"Connected to DB")

Connected to DB


In [12]:
collection_name = 'news_collection_full'
check_collection = utility.has_collection(collection_name)

In [13]:
if check_collection:
    drop_result = utility.drop_collection(collection_name)
    print("Droped Existing collection")

Droped Existing collection


In [14]:
# Creating collection schema
dim = 768 # embeddings dim
article_id = FieldSchema(name="article_id", dtype=DataType.INT64, is_primary=True, description="primary id") # primary key
article_embed_field = FieldSchema(name="article_embed", dtype=DataType.FLOAT_VECTOR, dim=dim) # description embeddings
article_desc = FieldSchema(name="article_desc", dtype=DataType.VARCHAR, max_length=(max_desc_len + 50), # using max_desc_len to specify VARCHAR len 
                           is_primary=False, description="short description of the article") # short description of article
article_cat = FieldSchema(name="article_category", dtype=DataType.VARCHAR, max_length=(max_cat_len + 50), # using max_desc_len to specify VARCHAR len 
                           is_primary=False, description="category of the article") # category of article
schema = CollectionSchema(fields=[article_id, article_embed_field, article_desc, article_cat], 
                          auto_id=False, description="collection of news articles")
print(f"Creating the collection")
collection = Collection(name=collection_name, schema=schema)
print(f"Schema: {schema}")
print("Success!")

Creating the collection
Schema: {'auto_id': False, 'description': 'collection of news articles', 'fields': [{'name': 'article_id', 'description': 'primary id', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': False}, {'name': 'article_embed', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 768}}, {'name': 'article_desc', 'description': 'short description of the article', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 370}}, {'name': 'article_category', 'description': 'category of the article', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 64}}]}
Success!


In [15]:
cuts = [*range(0, len(data), 1000)]
cuts.append(len(data))
print(cuts)

[0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000, 11000, 12000, 13000, 14000, 15000, 16000, 17000, 18000, 19000, 20000, 21000, 22000, 23000, 24000, 25000, 26000, 27000, 28000, 29000, 30000, 31000, 32000, 33000, 34000, 35000, 36000, 37000, 38000, 39000, 40000, 41000, 42000, 43000, 44000, 45000, 46000, 47000, 48000, 49000, 50000, 51000, 52000, 53000, 54000, 55000, 56000, 57000, 58000, 59000, 60000, 61000, 62000, 63000, 64000, 65000, 66000, 67000, 68000, 69000, 70000, 71000, 72000, 73000, 74000, 75000, 76000, 77000, 78000, 79000, 80000, 81000, 82000, 83000, 84000, 85000, 86000, 87000, 88000, 89000, 90000, 91000, 92000, 93000, 94000, 95000, 96000, 97000, 98000, 99000, 100000, 101000, 102000, 103000, 104000, 105000, 106000, 107000, 108000, 109000, 110000, 111000, 112000, 113000, 114000, 115000, 116000, 117000, 118000, 119000, 120000, 121000, 122000, 123000, 124000, 125000, 126000, 127000, 128000, 129000, 130000, 131000, 132000, 133000, 134000, 135000, 136000, 137000, 138000, 

In [16]:
multiprocessing.cpu_count()

8

In [None]:
article_id = []
article_desc = []
article_embed = []
article_cat = []
pool = multiprocessing.Pool(processes=multiprocessing.cpu_count()-2)
try:
    for i in tqdm(range(len(cuts)-1)):
        df = data.iloc[cuts[i]: cuts[i+1]].copy()
        article_id = [*df['index']]
        article_desc = [*df['short_description']]
        article_cat = [*df['category']]
        results = []
        for doc in article_desc:
            f = pool.apply_async(vectorizer.vectorize, args=(doc, tokenizer)) # need to pass tokenizer as argument
            results.append(f) # appending result to results
        for f in results:
            emb = f.get(timeout=120)
            article_embed.append(emb)
        docs = [article_id, article_embed, article_desc, article_cat]
        ins_resp = collection.insert(docs)
        print(ins_resp)
        article_id = []
        article_desc = []
        article_embed = []
        article_cat = []
        if i == 0:
            index_params = {"index_type": "AUTOINDEX", "metric_type": "L2", "params": {}} 
            collection.create_index(field_name='article_embed', index_params=index_params)
            collection = Collection(name=collection_name)
            collection.load()
    pool.close()
    pool.join()
except:
    pool.close()
    pool.join()
    raise