nishantguvvada commited on
Commit
19c08c2
·
1 Parent(s): dead93e

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -0
app.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ import cv2
4
+ import numpy as np
5
+ from PIL import Image, ImageOps
6
+ import imageio.v3 as iio
7
+ import time
8
+ from textwrap import wrap
9
+
10
+ import matplotlib.pylab as plt
11
+ import numpy as np
12
+ import tensorflow as tf
13
+ import tensorflow_datasets as tfds
14
+ import tensorflow_hub as hub
15
+ from tensorflow.keras import Input
16
+ from tensorflow.keras.layers import (
17
+ GRU,
18
+ Add,
19
+ AdditiveAttention,
20
+ Attention,
21
+ Concatenate,
22
+ Dense,
23
+ Embedding,
24
+ LayerNormalization,
25
+ Reshape,
26
+ StringLookup,
27
+ TextVectorization,
28
+ )
29
+
30
+ @st.cache_resource()
31
+ def load_image_model():
32
+ image_model=tf.keras.models.load_model('./image_caption_model.h5')
33
+ return image_model
34
+
35
+ @st.cache_resource()
36
+ def load_decoder_model():
37
+ decoder_model=tf.keras.models.load_model('./decoder_pred_model.h5')
38
+ return decoder_model
39
+
40
+ @st.cache_resource()
41
+ def load_encoder_model():
42
+ encoder=tf.keras.models.load_model('./encoder_model.h5')
43
+ return encoder
44
+
45
+ st.title(":blue[Nishant Guvvada's] :red[AI Journey] Image Caption Generation")
46
+ image = Image.open('./title.jpg')
47
+ st.image(image)
48
+ st.write("""
49
+ # Multi-Modal Machine Learning
50
+ """
51
+ )
52
+
53
+ file = st.file_uploader("Upload any image and the model will try to provide a caption to it!", type= ['png', 'jpg'])
54
+
55
+ MAX_CAPTION_LEN = 64
56
+ MINIMUM_SENTENCE_LENGTH = 5
57
+ IMG_HEIGHT = 299
58
+ IMG_WIDTH = 299
59
+ IMG_CHANNELS = 3
60
+ ATTENTION_DIM = 512 # size of dense layer in Attention
61
+ VOCAB_SIZE = 20000
62
+
63
+
64
+ # We will override the default standardization of TextVectorization to preserve
65
+ # "<>" characters, so we preserve the tokens for the <start> and <end>.
66
+ def standardize(inputs):
67
+ inputs = tf.strings.lower(inputs)
68
+ return tf.strings.regex_replace(
69
+ inputs, r"[!\"#$%&\(\)\*\+.,-/:;=?@\[\\\]^_`{|}~]?", ""
70
+ )
71
+
72
+
73
+ # Choose the most frequent words from the vocabulary & remove punctuation etc.
74
+ tokenizer = TextVectorization(
75
+ max_tokens=VOCAB_SIZE,
76
+ standardize=standardize,
77
+ output_sequence_length=MAX_CAPTION_LEN,
78
+ )
79
+
80
+ # Lookup table: Word -> Index
81
+ word_to_index = StringLookup(
82
+ mask_token="", vocabulary=tokenizer.get_vocabulary()
83
+ )
84
+
85
+ # Lookup table: Index -> Word
86
+ index_to_word = StringLookup(
87
+ mask_token="", vocabulary=tokenizer.get_vocabulary(), invert=True
88
+ )
89
+
90
+
91
+ ## Probabilistic prediction using the trained model
92
+ def predict_caption(file):
93
+ gru_state = tf.zeros((1, ATTENTION_DIM))
94
+
95
+ img = tf.image.decode_jpeg(tf.io.read_file(filename), channels=IMG_CHANNELS)
96
+ img = tf.image.resize(img, (IMG_HEIGHT, IMG_WIDTH))
97
+ img = img / 255
98
+
99
+ encoder = load_encoder_model()
100
+ features = encoder(tf.expand_dims(img, axis=0))
101
+ dec_input = tf.expand_dims([word_to_index("<start>")], 1)
102
+ result = []
103
+ decoder_pred_model = load_decoder_model()
104
+ for i in range(MAX_CAPTION_LEN):
105
+ predictions, gru_state = decoder_pred_model(
106
+ [dec_input, gru_state, features]
107
+ )
108
+
109
+ # draws from log distribution given by predictions
110
+ top_probs, top_idxs = tf.math.top_k(
111
+ input=predictions[0][0], k=10, sorted=False
112
+ )
113
+ chosen_id = tf.random.categorical([top_probs], 1)[0].numpy()
114
+ predicted_id = top_idxs.numpy()[chosen_id][0]
115
+
116
+ result.append(tokenizer.get_vocabulary()[predicted_id])
117
+
118
+ if predicted_id == word_to_index("<end>"):
119
+ return img, result
120
+
121
+ dec_input = tf.expand_dims([predicted_id], 1)
122
+
123
+ return img, result
124
+
125
+
126
+ filename = "../sample_images/surf.jpeg" # you can also try surf.jpeg
127
+
128
+ for i in range(5):
129
+ image, caption = predict_caption(filename)
130
+ print(" ".join(caption[:-1]) + ".")
131
+
132
+ img = tf.image.decode_jpeg(tf.io.read_file(filename), channels=IMG_CHANNELS)
133
+ plt.imshow(img)
134
+ plt.axis("off")
135
+
136
+
137
+ filename = np.array(Image.open(file).convert('RGB'))
138
+
139
+ def model_prediction(path):
140
+ resize = tf.image.resize(path, (256,256))
141
+ with st.spinner('Model is being loaded..'):
142
+ model=load_image_model()
143
+ yhat = model.predict(np.expand_dims(resize/255, 0))
144
+ return yhat
145
+
146
+ def on_click():
147
+ if file is None:
148
+ st.text("Please upload an image file")
149
+ else:
150
+ image = Image.open(file)
151
+ st.image(image, use_column_width=True)
152
+ image = image.convert('RGB')
153
+ predictions = model_prediction(np.array(image))
154
+ if (predictions>0.5):
155
+ st.write("""# Prediction : Implant is loose""")
156
+ else:
157
+ st.write("""# Prediction : Implant is in control""")
158
+
159
+ st.button('Predict', on_click=on_click)