File size: 2,017 Bytes
57a2c61 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import torch

from components.vector_db_operations import get_collection_from_vector_db
from components.vector_db_operations import retrieval
from components.english_information_extraction import english_information_extraction
from components.multi_lingual_model import MDFEND , loading_model_and_tokenizer
from components.data_loading import preparing_data , loading_data
from components.language_identification import language_identification
def run_pipeline(
    input_text: str,
    *,
    vector_db_path: str = "/content/drive/MyDrive/general_domains/vector_database",
    collection_name: str = "general_domains",
    max_domain_distance: float = 1.45,
):
    """Analyse ``input_text`` with the English or the multilingual pipeline.

    The language code is the first item yielded when iterating over the
    result of ``language_identification``. Text identified as ``"en"`` is
    handed to ``english_information_extraction`` and that result is returned
    unchanged. Any other text is matched against the domain vector database
    and then scored by the multilingual model.

    Args:
        input_text: Text to analyse.
        vector_db_path: Location passed to ``get_collection_from_vector_db``.
            Defaults to the path this pipeline originally hard-coded.
        collection_name: Name of the collection to query.
        max_domain_distance: Retrieval distances strictly greater than this
            value are reported as the ``"undetermined"`` domain.

    Returns:
        For English text, whatever ``english_information_extraction``
        returns. Otherwise a dict with the keys ``"domain_label"``,
        ``"domain_score"`` (the retrieval distance),
        ``"discrimination_label"`` (``"discriminative"`` or
        ``"not discriminative"``) and ``"discrimination_score"`` (the raw
        model output the label was derived from).
    """
    language_dict = language_identification(input_text)
    # NOTE(review): presumably the first key is the most likely language;
    # confirm against language_identification. An empty result would raise
    # StopIteration here.
    language_code = next(iter(language_dict))
    if language_code == "en":
        return english_information_extraction(input_text)

    # Only the single nearest entry is requested.
    # NOTE(review): the three-way unpacking below presumably relies on that.
    num_results = 1
    collection = get_collection_from_vector_db(vector_db_path, collection_name)
    domain, label_domain, distance = retrieval(input_text, num_results, collection)
    if distance > max_domain_distance:
        # Only the reported label changes; label_domain is still fed to the
        # model below. NOTE(review): confirm this is intended.
        domain = "undetermined"

    # NOTE(review): the model and tokenizer are loaded on every call;
    # consider caching if loading_model_and_tokenizer is expensive.
    tokenizer, model = loading_model_and_tokenizer()
    df = preparing_data(input_text, label_domain)
    input_ids, input_masks, input_domains = loading_data(tokenizer, df)

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        pred = model.forward(input_ids, input_masks, input_domains)

    # Label and score are both taken from the LAST prediction so they always
    # describe the same output. The previous code sliced the score with
    # ``[1:]``, which is the same element only when ``pred`` holds exactly two
    # predictions and raises otherwise.
    score = pred[-1].item()
    discrimination_label = "discriminative" if score >= 0.5 else "not discriminative"

    return {
        "domain_label": domain,
        "domain_score": distance,
        "discrimination_label": discrimination_label,
        "discrimination_score": score,
    }
|