Davidzhangyuanhan commited on
Commit
68f7ba1
·
1 Parent(s): d22eaa3

Add application file

Browse files
Files changed (2) hide show
  1. app.py +25 -0
  2. requirements.txt +6 -0
app.py CHANGED
@@ -2,6 +2,7 @@ import argparse
2
  import requests
3
  import gradio as gr
4
  import numpy as np
 
5
  import torch
6
  import torch.nn as nn
7
  from PIL import Image
@@ -47,6 +48,30 @@ eval_transforms = build_transforms(224)
47
  '''
48
  borrow code from here: https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
49
  '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  def recognize_image(image, texts):
52
  img_t = eval_transforms(image)
 
2
  import requests
3
  import gradio as gr
4
  import numpy as np
5
+ import cv2
6
  import torch
7
  import torch.nn as nn
8
  from PIL import Image
 
48
  '''
49
  borrow code from here: https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
50
  '''
51
+ def show_cam_on_image(img: np.ndarray,
52
+ mask: np.ndarray,
53
+ use_rgb: bool = False,
54
+ colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
55
+ """ This function overlays the cam mask on the image as an heatmap.
56
+ By default the heatmap is in BGR format.
57
+ :param img: The base image in RGB or BGR format.
58
+ :param mask: The cam mask.
59
+ :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.
60
+ :param colormap: The OpenCV colormap to be used.
61
+ :returns: The default image with the cam overlay.
62
+ """
63
+ heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
64
+ if use_rgb:
65
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
66
+ heatmap = np.float32(heatmap) / 255
67
+
68
+ if np.max(img) > 1:
69
+ raise Exception(
70
+ "The input image should np.float32 in the range [0, 1]")
71
+
72
+ cam = 0.7*heatmap + 0.3*img
73
+ # cam = cam / np.max(cam)
74
+ return np.uint8(255 * cam)
75
 
76
  def recognize_image(image, texts):
77
  img_t = eval_transforms(image)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch==1.10.1
2
+ torchvision==0.11.2
3
+ opencv-python-headless==4.5.3.56
4
+ timm==0.4.12
5
+ numpy
6
+