leedoming commited on
Commit
b4b084f
ยท
verified ยท
1 Parent(s): badc60b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -23
app.py CHANGED
@@ -63,23 +63,30 @@ def process_segmentation(image):
63
 
64
  processed_items = []
65
  for segment in output:
66
- mask = segment['mask']
67
- # ๋งˆ์Šคํฌ๊ฐ€ numpy array๊ฐ€ ์•„๋‹Œ ๊ฒฝ์šฐ ๋ณ€ํ™˜
68
- if not isinstance(mask, np.ndarray):
69
- mask = np.array(mask)
 
 
 
 
 
 
 
 
70
 
71
- # ๋งˆ์Šคํฌ๊ฐ€ 2D๊ฐ€ ์•„๋‹Œ ๊ฒฝ์šฐ ์ฒซ ๋ฒˆ์งธ ์ฑ„๋„ ์‚ฌ์šฉ
72
- if len(mask.shape) > 2:
73
- mask = mask[:, :, 0]
74
 
75
- # bool ๋งˆ์Šคํฌ๋ฅผ float๋กœ ๋ณ€ํ™˜
76
- mask = mask.astype(float)
 
 
 
77
 
78
- processed_items.append({
79
- 'mask': mask,
80
- 'label': segment.get('label', 'Unknown'),
81
- 'score': segment.get('score', 0.0)
82
- })
83
 
84
  logger.info(f"Successfully processed {len(processed_items)} segments")
85
  return processed_items
@@ -380,17 +387,35 @@ def main():
380
  cols = st.columns(2)
381
  for idx, item in enumerate(st.session_state.detected_items):
382
  with cols[idx % 2]:
383
- mask = item['mask']
384
- masked_img = np.array(st.session_state.image) * np.expand_dims(mask, axis=2)
385
- st.image(masked_img.astype(np.uint8), caption=f"Detected {item['label']}")
386
- st.write(f"Item {idx + 1}: {item['label']}")
387
- st.write(f"Confidence: {item['score']*100:.1f}%")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
 
 
 
 
 
389
  # ์•„์ดํ…œ ์„ ํƒ
390
  selected_idx = st.selectbox(
391
  "Select item to search:",
392
- range(len(st.session_state.detected_items)),
393
- format_func=lambda i: f"{st.session_state.detected_items[i]['label']}",
394
  key='item_selector'
395
  )
396
 
@@ -410,11 +435,15 @@ def main():
410
  # ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ ์ฒ˜๋ฆฌ
411
  if search_clicked or st.session_state.get('search_clicked', False):
412
  st.session_state.search_clicked = True
413
- selected_mask = st.session_state.detected_items[selected_idx]['mask']
 
 
 
 
414
 
415
  # ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋ฅผ ์„ธ์…˜ ์ƒํƒœ์— ์ €์žฅ
416
  if 'search_results' not in st.session_state:
417
- similar_items = process_search(st.session_state.image, selected_mask, num_results)
418
  st.session_state.search_results = similar_items
419
 
420
  # ์ €์žฅ๋œ ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ ํ‘œ์‹œ
 
63
 
64
  processed_items = []
65
  for segment in output:
66
+ # ๊ธฐ๋ณธ๊ฐ’์„ ํฌํ•จํ•˜์—ฌ ๋”•์…”๋„ˆ๋ฆฌ ์ƒ์„ฑ
67
+ processed_segment = {
68
+ 'label': segment.get('label', 'Unknown'),
69
+ 'score': segment.get('score', 1.0), # score๊ฐ€ ์—†์œผ๋ฉด 1.0์„ ๊ธฐ๋ณธ๊ฐ’์œผ๋กœ ์‚ฌ์šฉ
70
+ 'mask': None
71
+ }
72
+
73
+ mask = segment.get('mask')
74
+ if mask is not None:
75
+ # ๋งˆ์Šคํฌ๊ฐ€ numpy array๊ฐ€ ์•„๋‹Œ ๊ฒฝ์šฐ ๋ณ€ํ™˜
76
+ if not isinstance(mask, np.ndarray):
77
+ mask = np.array(mask)
78
 
79
+ # ๋งˆ์Šคํฌ๊ฐ€ 2D๊ฐ€ ์•„๋‹Œ ๊ฒฝ์šฐ ์ฒซ ๋ฒˆ์งธ ์ฑ„๋„ ์‚ฌ์šฉ
80
+ if len(mask.shape) > 2:
81
+ mask = mask[:, :, 0]
82
 
83
+ # bool ๋งˆ์Šคํฌ๋ฅผ float๋กœ ๋ณ€ํ™˜
84
+ processed_segment['mask'] = mask.astype(float)
85
+ else:
86
+ logger.warning(f"No mask found for segment with label {processed_segment['label']}")
87
+ continue # ๋งˆ์Šคํฌ๊ฐ€ ์—†๋Š” ์„ธ๊ทธ๋จผํŠธ๋Š” ๊ฑด๋„ˆ๋œ€
88
 
89
+ processed_items.append(processed_segment)
 
 
 
 
90
 
91
  logger.info(f"Successfully processed {len(processed_items)} segments")
92
  return processed_items
 
387
  cols = st.columns(2)
388
  for idx, item in enumerate(st.session_state.detected_items):
389
  with cols[idx % 2]:
390
+ try:
391
+ if item.get('mask') is not None:
392
+ masked_img = np.array(st.session_state.image) * np.expand_dims(item['mask'], axis=2)
393
+ st.image(masked_img.astype(np.uint8), caption=f"Detected {item.get('label', 'Unknown')}")
394
+
395
+ st.write(f"Item {idx + 1}: {item.get('label', 'Unknown')}")
396
+
397
+ # score ๊ฐ’์ด ์žˆ๊ณ  ์ˆซ์ž์ธ ๊ฒฝ์šฐ์—๋งŒ ํ‘œ์‹œ
398
+ score = item.get('score')
399
+ if score is not None and isinstance(score, (int, float)):
400
+ st.write(f"Confidence: {score*100:.1f}%")
401
+ else:
402
+ st.write("Confidence: N/A")
403
+ except Exception as e:
404
+ logger.error(f"Error displaying item {idx}: {str(e)}")
405
+ st.error(f"Error displaying item {idx}")
406
+
407
+ valid_items = [i for i in range(len(st.session_state.detected_items))
408
+ if st.session_state.detected_items[i].get('mask') is not None]
409
 
410
+ if not valid_items:
411
+ st.warning("No valid items detected for search.")
412
+ return
413
+
414
  # ์•„์ดํ…œ ์„ ํƒ
415
  selected_idx = st.selectbox(
416
  "Select item to search:",
417
+ valid_items,
418
+ format_func=lambda i: f"{st.session_state.detected_items[i].get('label', 'Unknown')}",
419
  key='item_selector'
420
  )
421
 
 
435
  # ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ ์ฒ˜๋ฆฌ
436
  if search_clicked or st.session_state.get('search_clicked', False):
437
  st.session_state.search_clicked = True
438
+ selected_item = st.session_state.detected_items[selected_idx]
439
+
440
+ if selected_item.get('mask') is None:
441
+ st.error("Selected item has no valid mask for search.")
442
+ return
443
 
444
  # ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋ฅผ ์„ธ์…˜ ์ƒํƒœ์— ์ €์žฅ
445
  if 'search_results' not in st.session_state:
446
+ similar_items = process_search(st.session_state.image, selected_item['mask'], num_results)
447
  st.session_state.search_results = similar_items
448
 
449
  # ์ €์žฅ๋œ ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ ํ‘œ์‹œ