File size: 8,551 Bytes
16ee7cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4683e7b
16ee7cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d6706f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import gradio as gr
from transformers import AutoConfig, AutoTokenizer
from bert_graph import BertForMultipleChoice
import torch
import copy
from itertools import chain

# "Comparison of mean diurnal measurements with latanoprost and timolol showed a statistical significant (P < 0.001) difference at 3, 6, and 12 months.",
# "in patients with pigmentary glaucoma, 0.005% latanoprost taken once daily was well tolerated and more effective in reducing IOP than 0.5% timolol taken twice daily."

def preprocess_function_exp(examples, tokenizer):
    """Tokenize tab-separated sentence pairs, preserving the group structure.

    Args:
        examples: list of choice groups; each group is a list of strings of
            the form "first_sentence\tsecond_sentence".
        tokenizer: a HuggingFace-style tokenizer callable.

    Returns:
        dict mapping tokenizer output keys (e.g. "input_ids") to nested lists
        that mirror ``examples`` (one sub-list per choice group).
    """
    # Remember each group's size so we can un-flatten after tokenizing.
    group_sizes = [len(group) for group in examples]

    # Flatten all groups into parallel lists of first/second sentences.
    first_sentences = []
    second_sentences = []
    for group in examples:
        for line in group:
            sent_item = line.strip().split('\t')
            first_sentences.append(sent_item[0].strip())
            second_sentences.append(sent_item[1].strip())

    # Tokenize all pairs in one batch; padding is deferred to the collator.
    tokenized_examples = tokenizer(
        first_sentences,
        second_sentences,
        max_length=512,
        padding=False,
        truncation=True,
    )

    # Un-flatten: slice each flat output list back into per-group sub-lists.
    tokenized_inputs = {}
    for key, values in tokenized_examples.items():
        regrouped = []
        start = 0
        for size in group_sizes:
            regrouped.append(values[start:start + size])
            start += size
        tokenized_inputs[key] = regrouped

    return tokenized_inputs

def DCForMultipleChoice(features, tokenizer):
    """Pad one group of tokenized choices into a model-ready batch of size 1.

    Only the first (and assumed only) example in ``features`` is used, so the
    resulting batch always has batch dimension 1.

    Args:
        features: dict mapping tokenizer keys to nested lists indexed as
            [example][choice] -> token sequence.
        tokenizer: HuggingFace-style tokenizer providing ``pad``.

    Returns:
        dict of tensors, each shaped (1, num_choices, seq_len).
    """
    # Number of choices in the first example (4 in this app: a-a, a-b, b-a, b-b).
    # Generalized from the previous hard-coded 4 so other pair counts work too.
    num_choices = len(next(iter(features.values()))[0])

    # One flat feature dict per choice, so tokenizer.pad sees a plain batch.
    flattened_features = [
        {key: value[0][idx] for key, value in features.items()}
        for idx in range(num_choices)
    ]

    batch = tokenizer.pad(
        flattened_features,
        padding=True,
        max_length=512,
        return_tensors="pt",
    )

    # Regroup the padded tensors as one example with num_choices choices.
    return {key: tensor.view(1, num_choices, -1) for key, tensor in batch.items()}

def post_process_diag(predictions):
    """Zero out self-pair predictions (sentence i compared with itself).

    ``predictions`` is a flat tensor of length n*n holding pairwise labels in
    row-major order; diagonal entries (row == column) are forced to 0.

    Note: mutates ``predictions`` in place (``reshape`` returns a view for a
    contiguous tensor) and returns a flattened view of the result — same
    behavior as the original nested-loop version, but O(n) instead of O(n^2).
    """
    num_sentences = int(len(predictions) ** 0.5)
    predictions_mtx = predictions.reshape(num_sentences, num_sentences)

    # A sentence never relates to itself, so the diagonal is always 0.
    predictions_mtx.fill_diagonal_(0)

    return predictions_mtx.view(-1)

def max_vote(logits1, logits2, pred1, pred2):
    """Combine two voter branches by confidence margin.

    For each pair, the branch whose softmax margin (top-1 minus top-2
    probability) is larger wins; that margin is reported as the confidence.

    Args:
        logits1, logits2: per-pair class logits from the two branches.
        pred1, pred2: argmax predictions from each branch; diagonal
            (self-pair) entries are zeroed via ``post_process_diag`` first.

    Returns:
        (pred_res, confidence_res): chosen integer labels and their
        float confidence margins, as parallel Python lists.
    """
    pred1 = post_process_diag(pred1)
    pred2 = post_process_diag(pred2)

    pred_res = []
    confidence_res = []
    for i in range(len(logits1)):
        # dim=-1 made explicit: softmax over the class dimension (the
        # implicit-dim form is deprecated and only warned).
        probs1 = torch.nn.functional.softmax(logits1[i], dim=-1)
        probs2 = torch.nn.functional.softmax(logits2[i], dim=-1)

        # Margin between the two most likely classes = branch confidence.
        top1, _ = probs1.topk(k=2)
        top2, _ = probs2.topk(k=2)
        margin1 = top1[0] - top1[1]
        margin2 = top2[0] - top2[1]

        if margin1 >= margin2:
            pred_res.append(int(pred1[i].item()))
            confidence_res.append(float(margin1.item()))
        else:
            pred_res.append(int(pred2[i].item()))
            confidence_res.append(float(margin2.item()))

    return pred_res, confidence_res

def model_infer(input_a, input_b):
    """Predict the argumentative relation between two arguments.

    Builds the four ordered sentence pairs (a-a, a-b, b-a, b-b), scores them
    with the dual-branch BERT multiple-choice model, and reports the relation
    in both directions.

    Args:
        input_a: head argument text.
        input_b: tail argument text.

    Returns:
        A human-readable string with the head->tail and tail->head relations.

    NOTE(review): the model and tokenizer are re-loaded on every call, which
    is slow; consider caching them at module level.
    """
    config = AutoConfig.from_pretrained('michiyasunaga/BioLinkBERT-base')
    config.win_size = 13
    config.model_mode = 'bert_mtl_1d'
    config.dataset_domain = 'absRCT'
    config.voter_branch = 'dual'
    config.destroy = False

    model = BertForMultipleChoice.from_pretrained(
                'michiyasunaga/BioLinkBERT-base',
                config=config,
            )
    # Fine-tuned weights, loaded onto CPU.
    p_sum = torch.load('best.pth', map_location=torch.device('cpu'))
    model.load_state_dict(p_sum)
    tokenizer = AutoTokenizer.from_pretrained('michiyasunaga/BioLinkBERT-base')

    # Four ordered tab-separated pairs: a-a, a-b, b-a, b-b.
    examples = [[input_a + '\t' + input_a,
                 input_a + '\t' + input_b,
                 input_b + '\t' + input_a,
                 input_b + '\t' + input_b]]
    tokenized_inputs = preprocess_function_exp(examples, tokenizer)
    tokenized_inputs = DCForMultipleChoice(tokenized_inputs, tokenizer)
    outputs = model(**tokenized_inputs)
    predictions, scores = max_vote(
        outputs.logits[0],
        outputs.logits[1],
        outputs.logits[0].argmax(dim=-1),
        outputs.logits[1].argmax(dim=-1),
    )

    # Index 1 = pair (a, b); index 2 = pair (b, a).
    prediction_a_b = predictions[1]
    prediction_b_a = predictions[2]

    label_space = {0: 'not relates', 1: 'supports', 2: 'attack'}
    label_a_b = label_space[prediction_a_b]
    label_b_a = label_space[prediction_b_a]

    # BUG FIX: the original format string had one placeholder for two
    # arguments, silently dropping the tail->head relation.
    return 'Head Argument {} Tail Argument; Tail Argument {} Head Argument'.format(
        label_a_b, label_b_a)


# Shared example sentences (from an abstract comparing latanoprost and
# timolol), offered for both input boxes. Defined once instead of being
# duplicated verbatim for each textbox.
_EXAMPLE_SENTENCES = [
    "Compared with baseline measurements, both latanoprost and timolol caused a significant (P < 0.001) reduction of IOP at each hour of diurnal curve throughout the duration of therapy.",
    "Reduction of IOP was 6.0 +/- 4.5 and 5.9 +/- 4.6 with latanoprost and 4.8 +/- 3.0 and 4.6 +/- 3.1 with timolol after 6 and 12 months, respectively.",
    "Comparison of mean diurnal measurements with latanoprost and timolol showed a statistical significant (P < 0.001) difference at 3, 6, and 12 months.",
    "Mean C was found to be significantly enhanced (+30%) only in the latanoprost-treated group compared with the baseline (P = 0.017).",
    "Mean conjunctival hyperemia was graded at 0.3 in latanoprost-treated eyes and 0.2 in timolol-treated eyes.",
    "A remarkable change in iris color was observed in both eyes of 1 of the 18 patients treated with latanoprost and none of the 18 patients who received timolol.",
    "In the timolol group, heart rate was significantly reduced from 72 +/- 9 at baseline to 67 +/- 10 beats per minute at 12 months.",
    "in patients with pigmentary glaucoma, 0.005% latanoprost taken once daily was well tolerated and more effective in reducing IOP than 0.5% timolol taken twice daily.",
    "further studies may need to confirm these data on a larger sample and to evaluate the side effect of increased iris pigmentation on long-term follow-up,",
]

with gr.Blocks() as demo:
    # Input components.
    arg_1 = gr.Textbox(label="Head Argument")
    arg_2 = gr.Textbox(label="Tail Argument")

    # The same example sentences can seed either argument box.
    gr.Examples(_EXAMPLE_SENTENCES, arg_1)
    gr.Examples(_EXAMPLE_SENTENCES, arg_2)

    # Output component.
    output = gr.Textbox(label="Output Box")
    # Run button, wired to the inference function on click.
    greet_btn = gr.Button("Run")
    greet_btn.click(fn=model_infer, inputs=[arg_1, arg_2], outputs=output)


demo.launch(share=True)