import torch

from components.vector_db_operations import get_collection_from_vector_db
from components.vector_db_operations import retrieval
from components.english_information_extraction import english_information_extraction
from components.multi_lingual_model import MDFEND, loading_model_and_tokenizer
from components.data_loading import preparing_data, loading_data
from components.language_identification import language_identification
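

# Sketch (an assumption, not part of the original pipeline): run_pipeline uses
# the first key of the dict returned by language_identification as the detected
# language. That key is the most likely language only if the dict is ordered by
# score. Assuming the dict maps language codes to probabilities, this
# hypothetical helper picks the highest-scoring code explicitly instead.
def top_language(language_scores: dict) -> str:
    """Return the language code with the highest score (hypothetical helper)."""
    if not language_scores:
        raise ValueError("language_scores is empty")
    # max over the keys, ranked by their associated score
    return max(language_scores, key=language_scores.get)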



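# Sketch (hypothetical helper, not called by run_pipeline): the domain step
# treats the distance of the single nearest neighbour returned by retrieval()
# as a confidence signal. Assuming smaller distances mean closer matches, a
# match farther than the cutoff (1.45 in run_pipeline) is reported as
# "undetermined" instead of the retrieved domain. A distance exactly equal to
# the cutoff keeps the domain, as in run_pipeline.
def gate_domain(domain: str, distance: float, max_distance: float = 1.45) -> str:
    """Return ``domain`` if ``distance`` <= ``max_distance``, else "undetermined"."""
    return domain if distance <= max_distance else "undetermined"

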
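def run_pipeline(input_text: str):
    """Classify ``input_text`` with the pipeline that matches its language.

    English text is handed to the English information-extraction component.
    Text in any other language is assigned the closest general domain from
    the vector database and then scored by the multilingual model.
    """
    # language_identification returns a dict keyed by language code; the
    # first key is used as the detected language.
    language_dict = language_identification(input_text)
    language_code = next(iter(language_dict))

    if language_code == "en":
        return english_information_extraction(input_text)

    # Retrieve the single closest domain from the vector database.
    num_results = 1
    path = "/content/drive/MyDrive/general_domains/vector_database"
    collection_name = "general_domains"
    collection = get_collection_from_vector_db(path, collection_name)

    domain, label_domain, distance = retrieval(input_text, num_results, collection)

    # A nearest neighbour farther than 1.45 is not trusted as a domain match.
    if distance > 1.45:
        domain = "undetermined"

    tokenizer, model = loading_model_and_tokenizer()
    model.eval()  # inference mode: disables dropout

    df = preparing_data(input_text, label_domain)
    input_ids, input_masks, input_domains = loading_data(tokenizer, df)

    # No gradients are needed for inference.
    with torch.no_grad():
        # Expected to hold one probability per prepared input row.
        pred = model(input_ids, input_masks, input_domains)

    # Threshold each probability at 0.5 to get a binary label.
    labels = [1 if output.item() >= 0.5 else 0 for output in pred]
    discrimination_class = [
        "discriminative" if label == 1 else "not discriminative" for label in labels
    ]

    # Report the label and the score of the last prediction row.
    return {
        "domain_label": domain,
        "domain_score": distance,
        "discrimination_label": discrimination_class[-1],
        "discrimination_score": pred[-1].item(),
    }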
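

# Sketch (hypothetical helper, not called by run_pipeline): the inference step
# turns each probability produced by the model into a binary label with a 0.5
# threshold and then into a readable class name. This applies the same rule to
# a plain list of floats, so it can be checked without torch or the model.
def probabilities_to_classes(probabilities, threshold=0.5):
    """Map each probability to "discriminative" (>= threshold) or "not discriminative"."""
    return [
        "discriminative" if p >= threshold else "not discriminative"
        for p in probabilities
    ]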