tonyliu404 commited on
Commit
15be7c5
·
verified ·
1 Parent(s): 6a2d62f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +183 -2
app.py CHANGED
@@ -168,6 +168,187 @@ rag_chain = (
168
 
169
  def get_response(query):
170
  return rag_chain.invoke(query)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
- print("HELLO WORLD")
173
- st.title("RESSSSULTS")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
  def get_response(query):
170
  return rag_chain.invoke(query)
171
+
172
+
173
+ ##############################################
174
+ # Classifier
175
+ img_size = 224
176
+
177
+ @st.cache_resource
178
+ def loadModel():
179
+ model = load_model('efficientnet-fine-6.keras')
180
+ return model
181
+
182
+ model = loadModel()
183
+
184
+ class_names = [
185
+ "apple_pie", "baby_back_ribs", "baklava", "beef_carpaccio", "beef_tartare", "beet_salad",
186
+ "beignets", "bibimbap", "bread_pudding", "breakfast_burrito", "bruschetta", "caesar_salad",
187
+ "cannoli", "caprese_salad", "carrot_cake", "ceviche", "cheese_plate", "cheesecake", "chicken_curry",
188
+ "chicken_quesadilla", "chicken_wings", "chocolate_cake", "chocolate_mousse", "churros", "clam_chowder",
189
+ "club_sandwich", "crab_cakes", "creme_brulee", "croque_madame", "cup_cakes", "deviled_eggs", "donuts",
190
+ "dumplings", "edamame", "eggs_benedict", "escargots", "falafel", "filet_mignon", "fish_and_chips", "foie_gras",
191
+ "french_fries", "french_onion_soup", "french_toast", "fried_calamari", "fried_rice", "frozen_yogurt",
192
+ "garlic_bread", "gnocchi", "greek_salad", "grilled_cheese_sandwich", "grilled_salmon", "guacamole", "gyoza",
193
+ "hamburger", "hot_and_sour_soup", "hot_dog", "huevos_rancheros", "hummus", "ice_cream", "lasagna",
194
+ "lobster_bisque", "lobster_roll_sandwich", "macaroni_and_cheese", "macarons", "miso_soup", "mussels",
195
+ "nachos", "omelette", "onion_rings", "oysters", "pad_thai", "paella", "pancakes", "panna_cotta", "peking_duck",
196
+ "pho", "pizza", "pork_chop", "poutine", "prime_rib", "pulled_pork_sandwich", "ramen", "ravioli", "red_velvet_cake",
197
+ "risotto", "samosa", "sashimi", "scallops", "seaweed_salad", "shrimp_and_grits", "spaghetti_bolognese",
198
+ "spaghetti_carbonara", "spring_rolls", "steak", "strawberry_shortcake", "sushi", "tacos", "takoyaki", "tiramisu",
199
+ "tuna_tartare", "waffles"
200
+ ]
201
+
202
+ def classifyImage(input_image):
203
+ input_image = input_image.resize((img_size, img_size))
204
+ input_array = tf.keras.utils.img_to_array(input_image)
205
+
206
+ # Add a batch dimension
207
+ input_array = tf.expand_dims(input_array, 0) # (1, 224, 224, 3)
208
 
209
+ predictions = model.predict(input_array)[0]
210
+ print(f"Predictions: {predictions}")
211
+
212
+ # Sort predictions to get top 5
213
+ top_indices = np.argsort(predictions)[-5:][::-1]
214
+
215
+ # Prepare the top 5 predictions with their class names and percentages
216
+ top_predictions = [(class_names[i], predictions[i] * 100) for i in top_indices]
217
+ for i, (class_name, confidence) in enumerate(top_predictions, 1):
218
+ print(f"{i}. Predicted {class_name} with {confidence:.2f}% Confidence")
219
+
220
+ return top_predictions
221
+
222
+ def capitalize_after_number(input_string):
223
+ # Split the string on the first period
224
+ if ". " in input_string:
225
+ num, text = input_string.split(". ", 1)
226
+ return f"{num}. {text.capitalize()}"
227
+ return input_string
228
+ ##############################################
229
+
230
+ #for displaying RAG recipe response
231
+ def display_response(response):
232
+ """
233
+ Function to format a JSON response into Streamlit's `st.write()` format.
234
+ """
235
+ if isinstance(response, str):
236
+ # Convert JSON string to dictionary if necessary
237
+ response = json.loads(response)
238
+
239
+ st.write("### Recipe Details")
240
+ st.write(f"**Name:** {response['name'].capitalize()}")
241
+ st.write(f"**Preparation Time:** {response['minutes']} minutes")
242
+ st.write(f"**Description:** {response['description'].capitalize()}")
243
+ st.write(f"**Tags:** {', '.join(response['tags'])}")
244
+ st.write("### Ingredients")
245
+ st.write(", ".join([ingredient.capitalize() for ingredient in response['ingredients']]))
246
+ st.write(f"**Total Ingredients:** {response['n_ingredients']}")
247
+ st.write("### Nutrition Information (per serving)")
248
+ st.write(", ".join(response['formatted_nutrition']))
249
+ st.write(f"**Number of Steps:** {response['n_steps']}")
250
+ st.write("### Steps")
251
+ for step in response['formatted_steps']:
252
+ st.write(capitalize_after_number(step))
253
+
254
+ def display_dishes_in_grid(dishes, cols=3):
255
+ rows = len(dishes) // cols + int(len(dishes) % cols > 0)
256
+ for i in range(rows):
257
+ cols_data = dishes[i*cols:(i+1)*cols]
258
+ cols_list = st.columns(len(cols_data))
259
+ for col, dish in zip(cols_list, cols_data):
260
+ with col:
261
+ st.sidebar.write(dish.replace("_", " ").capitalize())
262
+ # #Streamlit
263
+
264
+ #Left sidebar title
265
+ st.sidebar.markdown(
266
+ "<h1 style='font-size:32px;'>RAG Recipe</h1>",
267
+ unsafe_allow_html=True
268
+ )
269
+
270
+ st.sidebar.write("Upload an image and/or enter a query to get started! Explore our trained dish types listed below for guidance.")
271
+
272
+ uploaded_image = st.sidebar.file_uploader("Choose an image:", type="jpg")
273
+ query = st.sidebar.text_area("Enter your query:", height=100)
274
+
275
+ # gap
276
+ st.sidebar.markdown("<br><br><br>", unsafe_allow_html=True)
277
+ selected_dish = st.sidebar.selectbox(
278
+ "Search for a dish that our model can classify:",
279
+ options=class_names,
280
+ index=0
281
+ )
282
+
283
+ # Right title
284
+ st.title("Results")
285
+ #################
286
+
287
+ # Image Classification Section
288
+ if uploaded_image and query:
289
+ # Open the image
290
+ input_image = Image.open(uploaded_image)
291
+
292
+ # Display the image
293
+ st.image(input_image, caption="Uploaded Image.", use_container_width=True)
294
+
295
+ predictions = classifyImage(input_image)
296
+ fpredictions = ""
297
+
298
+ # Show the top predictions with percentages
299
+ st.write("Top Predictions:")
300
+ for class_name, confidence in predictions:
301
+ if int(confidence) > 0.05:
302
+ fpredictions += f"{class_name}: {confidence:.2f}%,"
303
+ st.write(f"{class_name}: {confidence:.2f}%")
304
+ print(fpredictions)
305
+
306
+ # call openai to pick the best classification result based on query
307
+ openAICall = [
308
+ SystemMessage(
309
+ content = "You are a helpful assistant that identifies the best match between classified food items and a user's request based on provided classifications and keywords."
310
+ ),
311
+ HumanMessage(
312
+ content = f"""
313
+ Based on the following image classification with percentages of each food:
314
+ {fpredictions}
315
+ And the following user request:
316
+ {query}
317
+ Return to me JUST ONE of the classified images that most relates to the user request, based on the relevance of the user query.
318
+ in the format: [dish]
319
+ """
320
+ ),
321
+ ]
322
+
323
+ # Call the OpenAI API
324
+ openAIresponse = llm.invoke(openAICall)
325
+ print("AI CALL RESPONSE: ", openAIresponse.content)
326
+
327
+ # RAG the openai response and display
328
+ print("RAG INPUT", openAIresponse.content + " " + query)
329
+ RAGresponse = get_response(openAIresponse.content + " " + query)
330
+ display_response(RAGresponse)
331
+ elif uploaded_image is not None:
332
+ # Open the image
333
+ input_image = Image.open(uploaded_image)
334
+
335
+ # Display the image
336
+ st.image(input_image, caption="Uploaded Image.", use_column_width=True)
337
+
338
+ # Classify the image and display the result
339
+ predictions = classifyImage(input_image)
340
+ fpredictions = ""
341
+
342
+ # Show the top predictions with percentages
343
+ st.write("Top Predictions:")
344
+ for class_name, confidence in predictions:
345
+ if int(confidence) > 0.05:
346
+ fpredictions += f"{class_name}: {confidence:.2f}%,"
347
+ st.write(f"{class_name}: {confidence:.2f}%")
348
+ print(fpredictions)
349
+
350
+ elif query:
351
+ response = get_response(query)
352
+ display_response(response)
353
+ else:
354
+ st.write("Please input an image and/or a prompt.")