pmbrito commited on
Commit
0f5d290
·
1 Parent(s): c9ff809

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -3
app.py CHANGED
@@ -31,12 +31,44 @@ if raw_image != 'Select image':
31
  image = np.asarray(image)
32
 
33
  with st.spinner('Loading Model...'):
34
- feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large-ade")
35
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
- model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade",ignore_mismatched_sizes=True,num_labels=len(id2label), id2label=id2label, label2id=label2id,reshape_last_stage=True)
37
  model = model.to(device)
38
  model.eval()
39
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  st.success("Success")
42
 
 
31
  image = np.asarray(image)
32
 
33
  with st.spinner('Loading Model...'):
34
+ feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large-ade")
35
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+ model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade",ignore_mismatched_sizes=True,num_labels=len(id2label), id2label=id2label, label2id=label2id,reshape_last_stage=True)
37
  model = model.to(device)
38
  model.eval()
39
+
40
+ with st.spinner('Preparing image...'):
41
+ # prepare the image for the model (aligned resize)
42
+ feature_extractor_inference = DPTFeatureExtractor(do_random_crop=False, do_pad=False)
43
+ pixel_values = feature_extractor_inference(image, return_tensors="pt").pixel_values.to(device)
44
+
45
+ with st.spinner('Running inference...'):
46
+ outputs = model(pixel_values=pixel_values)# logits are of shape (batch_size, num_labels, height/4, width/4)
47
+
48
+ with st.spinner('Postprocessing...'):
49
+ logits = outputs.logits.cpu()
50
+ # First, rescale logits to original image size
51
+ upsampled_logits = nn.functional.interpolate(logits,
52
+ size=image.shape[:-1], # (height, width)
53
+ mode='bilinear',
54
+ align_corners=False)
55
+ # Second, apply argmax on the class dimension
56
+ seg = upsampled_logits.argmax(dim=1)[0]
57
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3\
58
+ all_labels = []
59
+ for label, color in enumerate(palette):
60
+ color_seg[seg == label, :] = color
61
+ if label in seg:
62
+ all_labels.append(id2label[label])
63
+ # Convert to BGR
64
+ color_seg = color_seg[..., ::-1]
65
+ # Show image + mask
66
+ img = np.array(image) * 0.5 + color_seg * 0.5
67
+ img = img.astype(np.uint8)
68
+ st.image(img, caption="Segmented Image")
69
+ st.header("Predicted Labels")
70
+ for idx, label in enumerate(all_labels):
71
+ st.subheader(f'{idx+1}) {label}')
72
 
73
  st.success("Success")
74