merve HF staff commited on
Commit
749d0fc
1 Parent(s): d688799

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -0
app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import from_pretrained_keras
2
+ import numpy as np
3
+ import pandas as pd
4
+ import tensorflow as tf
5
+ import tensorflow_hub as hub
6
+ import tensorflow_text as text
7
+ from tensorflow import keras
8
+ import gradio as gr
9
+
10
+
11
+ def make_bert_preprocessing_model(sentence_features, seq_length=128):
12
+ """Returns Model mapping string features to BERT inputs.
13
+
14
+ Args:
15
+ sentence_features: A list with the names of string-valued features.
16
+ seq_length: An integer that defines the sequence length of BERT inputs.
17
+
18
+ Returns:
19
+ A Keras Model that can be called on a list or dict of string Tensors
20
+ (with the order or names, resp., given by sentence_features) and
21
+ returns a dict of tensors for input to BERT.
22
+ """
23
+
24
+ input_segments = [
25
+ tf.keras.layers.Input(shape=(), dtype=tf.string, name=ft)
26
+ for ft in sentence_features
27
+ ]
28
+
29
+ # tokenize the text to word pieces
30
+ bert_preprocess = hub.load(bert_preprocess_path)
31
+ tokenizer = hub.KerasLayer(bert_preprocess.tokenize,
32
+ name="tokenizer")
33
+
34
+ segments = [tokenizer(s) for s in input_segments]
35
+
36
+ truncated_segments = segments
37
+
38
+ packer = hub.KerasLayer(bert_preprocess.bert_pack_inputs,
39
+ arguments=dict(seq_length=seq_length),
40
+ name="packer")
41
+ model_inputs = packer(truncated_segments)
42
+ return keras.Model(input_segments, model_inputs)
43
+
44
+
45
+ def preprocess_image(image_path, resize):
46
+ extension = tf.strings.split(image_path)[-1]
47
+
48
+ image = tf.io.read_file(image_path)
49
+ if extension == b"jpg":
50
+ image = tf.image.decode_jpeg(image, 3)
51
+ else:
52
+ image = tf.image.decode_png(image, 3)
53
+
54
+ image = tf.image.resize(image, resize)
55
+ return image
56
+
57
+ def preprocess_text(text_1, text_2):
58
+
59
+ text_1 = tf.convert_to_tensor([text_1])
60
+ text_2 = tf.convert_to_tensor([text_2])
61
+
62
+ output = bert_preprocess_model([text_1, text_2])
63
+
64
+ output = {feature: tf.squeeze(output[feature]) for feature in bert_input_features}
65
+
66
+ return output
67
+
68
+ def preprocess_text_and_image(sample, resize):
69
+
70
+ image_1 = preprocess_image(sample['image_1_path'], resize)
71
+ image_2 = preprocess_image(sample['image_2_path'], resize)
72
+
73
+ text = preprocess_text(sample['text_1'], sample['text_2'])
74
+
75
+ return {"image_1": image_1, "image_2": image_2, "text": text}
76
+
77
+
78
+ def classify_info(image_1, text_1, image_2, text_2):
79
+
80
+ sample = dict()
81
+ sample['image_1_path'] = image_1
82
+ sample['image_2_path'] = image_2
83
+ sample['text_1'] = text_1
84
+ sample['text_2'] = text_2
85
+
86
+ dataframe = pd.DataFrame(sample, index=[0])
87
+
88
+ ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), [0]))
89
+ ds = ds.map(lambda x, y: (preprocess_text_and_image(x, resize), y)).cache()
90
+ batch_size = 1
91
+ auto = tf.data.AUTOTUNE
92
+ ds = ds.batch(batch_size).prefetch(auto)
93
+ output = model.predict(ds)
94
+
95
+ outputs = dict()
96
+
97
+ outputs[labels[0]] = float(output[0][0])
98
+ outputs[labels[1]] = float(output[0][1])
99
+ outputs[labels[2]] = float(output[0][2])
100
+ #label = np.argmax(output)
101
+ return outputs #labels[label]
102
+
103
+
104
+ model = from_pretrained_keras("keras-io/multimodal-entailment")
105
+ resize = (128, 128)
106
+ bert_input_features = ["input_word_ids", "input_type_ids", "input_mask"]
107
+ bert_model_path = ("https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-256_A-4/1")
108
+ bert_preprocess_path = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
109
+ bert_preprocess_model = make_bert_preprocessing_model(['text_1', 'text_2'])
110
+
111
+ labels = {0: "Contradictory", 1: "Implies", 2: "No Entailment"}
112
+
113
+ block = gr.Blocks()
114
+
115
+ examples = [['examples/image_1.png', '#IndiaFightsCorona:\n\nNearly 4.5 million beneficiaries vaccinated against #COVID19 in 19 days.\n\nIndia is the fastest country to cross landmark of vaccinating 4 million beneficiaries in merely 18 days.\n\n#StaySafe #IndiaWillWin #Unite2FightCorona https://t.co/beGDQfd06S', 'examples/image_2.jpg', '#IndiaFightsCorona:\n\nIndia has become the fastest nation to reach 4 million #COVID19 vaccinations ; it took only 18 days to administer the first 4 million #vaccines\n\n:@MoHFW_INDIA Secretary\n\n#StaySafe #IndiaWillWin #Unite2FightCorona https://t.co/9GENQlqtn3']]
116
+
117
+
118
+ with block:
119
+ gr.Markdown("Multimodal Entailment")
120
+ with gr.Tab("Hypothesis"):
121
+ with gr.Row():
122
+ gr.Markdown("Upload hypothesis image:")
123
+ image_1 = gr.inputs.Image(type="filepath")
124
+ text_1 = gr.inputs.Textbox(lines=5)
125
+
126
+ with gr.Tab("Premise"):
127
+ with gr.Row():
128
+ gr.Markdown("Upload premise image:")
129
+ image_2 = gr.inputs.Image(type="filepath")
130
+ text_2 = gr.inputs.Textbox(lines=5)
131
+
132
+ xray_results = gr.outputs.JSON()
133
+ xray_run = gr.Button("Run")
134
+ xray_run.click(xray_model, inputs=[disease, xray_scan], outputs=xray_results)
135
+
136
+ run = gr.Button("Run")
137
+ label = gr.outputs.Label()
138
+ run.click(model, inputs=[image_1, text_1, image_2, text_2], outputs=label)
139
+
140
+ block.launch()