tonyliu404 commited on
Commit
8b020ec
·
verified ·
1 Parent(s): 54c1dec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -0
app.py CHANGED
@@ -183,3 +183,176 @@ def loadModel():
183
  return model
184
 
185
  model = loadModel()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  return model
184
 
185
  model = loadModel()
186
+
187
+
188
+ class_names = [
189
+ "apple_pie", "baby_back_ribs", "baklava", "beef_carpaccio", "beef_tartare", "beet_salad",
190
+ "beignets", "bibimbap", "bread_pudding", "breakfast_burrito", "bruschetta", "caesar_salad",
191
+ "cannoli", "caprese_salad", "carrot_cake", "ceviche", "cheese_plate", "cheesecake", "chicken_curry",
192
+ "chicken_quesadilla", "chicken_wings", "chocolate_cake", "chocolate_mousse", "churros", "clam_chowder",
193
+ "club_sandwich", "crab_cakes", "creme_brulee", "croque_madame", "cup_cakes", "deviled_eggs", "donuts",
194
+ "dumplings", "edamame", "eggs_benedict", "escargots", "falafel", "filet_mignon", "fish_and_chips", "foie_gras",
195
+ "french_fries", "french_onion_soup", "french_toast", "fried_calamari", "fried_rice", "frozen_yogurt",
196
+ "garlic_bread", "gnocchi", "greek_salad", "grilled_cheese_sandwich", "grilled_salmon", "guacamole", "gyoza",
197
+ "hamburger", "hot_and_sour_soup", "hot_dog", "huevos_rancheros", "hummus", "ice_cream", "lasagna",
198
+ "lobster_bisque", "lobster_roll_sandwich", "macaroni_and_cheese", "macarons", "miso_soup", "mussels",
199
+ "nachos", "omelette", "onion_rings", "oysters", "pad_thai", "paella", "pancakes", "panna_cotta", "peking_duck",
200
+ "pho", "pizza", "pork_chop", "poutine", "prime_rib", "pulled_pork_sandwich", "ramen", "ravioli", "red_velvet_cake",
201
+ "risotto", "samosa", "sashimi", "scallops", "seaweed_salad", "shrimp_and_grits", "spaghetti_bolognese",
202
+ "spaghetti_carbonara", "spring_rolls", "steak", "strawberry_shortcake", "sushi", "tacos", "takoyaki", "tiramisu",
203
+ "tuna_tartare", "waffles"
204
+ ]
205
+
206
+ def classifyImage(input_image):
207
+ input_image = input_image.resize((img_size, img_size))
208
+ input_array = tf.keras.utils.img_to_array(input_image)
209
+
210
+ # Add a batch dimension
211
+ input_array = tf.expand_dims(input_array, 0) # (1, 224, 224, 3)
212
+
213
+ predictions = model.predict(input_array)[0]
214
+ print(f"Predictions: {predictions}")
215
+
216
+ # Sort predictions to get top 5
217
+ top_indices = np.argsort(predictions)[-5:][::-1]
218
+
219
+ # Prepare the top 5 predictions with their class names and percentages
220
+ top_predictions = [(class_names[i], predictions[i] * 100) for i in top_indices]
221
+ for i, (class_name, confidence) in enumerate(top_predictions, 1):
222
+ print(f"{i}. Predicted {class_name} with {confidence:.2f}% Confidence")
223
+
224
+ return top_predictions
225
+
226
+ def capitalize_after_number(input_string):
227
+ # Split the string on the first period
228
+ if ". " in input_string:
229
+ num, text = input_string.split(". ", 1)
230
+ return f"{num}. {text.capitalize()}"
231
+ return input_string
232
+ ##############################################
233
+
234
+ #for displaying RAG recipe response
235
+ def display_response(response):
236
+ """
237
+ Function to format a JSON response into Streamlit's `st.write()` format.
238
+ """
239
+ if isinstance(response, str):
240
+ # Convert JSON string to dictionary if necessary
241
+ response = json.loads(response)
242
+
243
+ st.write("### Recipe Details")
244
+ st.write(f"**Name:** {response['name'].capitalize()}")
245
+ st.write(f"**Preparation Time:** {response['minutes']} minutes")
246
+ st.write(f"**Description:** {response['description'].capitalize()}")
247
+ st.write(f"**Tags:** {', '.join(response['tags'])}")
248
+ st.write("### Ingredients")
249
+ st.write(", ".join([ingredient.capitalize() for ingredient in response['ingredients']]))
250
+ st.write(f"**Total Ingredients:** {response['n_ingredients']}")
251
+ st.write("### Nutrition Information (per serving)")
252
+ st.write(", ".join(response['formatted_nutrition']))
253
+ st.write(f"**Number of Steps:** {response['n_steps']}")
254
+ st.write("### Steps")
255
+ for step in response['formatted_steps']:
256
+ st.write(capitalize_after_number(step))
257
+
258
+ def display_dishes_in_grid(dishes, cols=3):
259
+ rows = len(dishes) // cols + int(len(dishes) % cols > 0)
260
+ for i in range(rows):
261
+ cols_data = dishes[i*cols:(i+1)*cols]
262
+ cols_list = st.columns(len(cols_data))
263
+ for col, dish in zip(cols_list, cols_data):
264
+ with col:
265
+ st.sidebar.write(dish.replace("_", " ").capitalize())
266
+ # #Streamlit
267
+
268
+ #Left sidebar title
269
+ st.sidebar.markdown(
270
+ "<h1 style='font-size:32px;'>RAG Recipe</h1>",
271
+ unsafe_allow_html=True
272
+ )
273
+
274
+ st.sidebar.write("Upload an image and/or enter a query to get started! Explore our trained dish types listed below for guidance.")
275
+
276
+ uploaded_image = st.sidebar.file_uploader("Choose an image:", type="jpg")
277
+ query = st.sidebar.text_area("Enter your query:", height=100)
278
+
279
+ # gap
280
+ st.sidebar.markdown("<br><br><br>", unsafe_allow_html=True)
281
+ selected_dish = st.sidebar.selectbox(
282
+ "Search for a dish that our model can classify:",
283
+ options=class_names,
284
+ index=0
285
+ )
286
+
287
+ # Right title
288
+ st.title("Results")
289
+ #################
290
+
291
+ # Image Classification Section
292
+ if uploaded_image and query:
293
+ # Open the image
294
+ input_image = Image.open(uploaded_image)
295
+
296
+ # Display the image
297
+ st.image(input_image, caption="Uploaded Image.", use_container_width=True)
298
+
299
+ predictions = classifyImage(input_image)
300
+ fpredictions = ""
301
+
302
+ # Show the top predictions with percentages
303
+ st.write("Top Predictions:")
304
+ for class_name, confidence in predictions:
305
+ if int(confidence) > 0.05:
306
+ fpredictions += f"{class_name}: {confidence:.2f}%,"
307
+ st.write(f"{class_name}: {confidence:.2f}%")
308
+ print(fpredictions)
309
+
310
+ # call openai to pick the best classification result based on query
311
+ openAICall = [
312
+ SystemMessage(
313
+ content = "You are a helpful assistant that identifies the best match between classified food items and a user's request based on provided classifications and keywords."
314
+ ),
315
+ HumanMessage(
316
+ content = f"""
317
+ Based on the following image classification with percentages of each food:
318
+ {fpredictions}
319
+ And the following user request:
320
+ {query}
321
+ Return to me JUST ONE of the classified images that most relates to the user request, based on the relevance of the user query.
322
+ in the format: [dish]
323
+ """
324
+ ),
325
+ ]
326
+
327
+ # Call the OpenAI API
328
+ openAIresponse = llm.invoke(openAICall)
329
+ print("AI CALL RESPONSE: ", openAIresponse.content)
330
+
331
+ # RAG the openai response and display
332
+ print("RAG INPUT", openAIresponse.content + " " + query)
333
+ RAGresponse = get_response(openAIresponse.content + " " + query)
334
+ display_response(RAGresponse)
335
+ elif uploaded_image is not None:
336
+ # Open the image
337
+ input_image = Image.open(uploaded_image)
338
+
339
+ # Display the image
340
+ st.image(input_image, caption="Uploaded Image.", use_column_width=True)
341
+
342
+ # Classify the image and display the result
343
+ predictions = classifyImage(input_image)
344
+ fpredictions = ""
345
+
346
+ # Show the top predictions with percentages
347
+ st.write("Top Predictions:")
348
+ for class_name, confidence in predictions:
349
+ if int(confidence) > 0.05:
350
+ fpredictions += f"{class_name}: {confidence:.2f}%,"
351
+ st.write(f"{class_name}: {confidence:.2f}%")
352
+ print(fpredictions)
353
+
354
+ elif query:
355
+ response = get_response(query)
356
+ display_response(response)
357
+ else:
358
+ st.write("Please input an image and/or a prompt.")