breezedeus commited on
Commit
5315596
·
1 Parent(s): a98529e
Files changed (5) hide show
  1. README.md +1 -1
  2. app.py +244 -0
  3. hf_config.yaml +20 -0
  4. local_config.yaml +20 -0
  5. requirements.txt +8 -0
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Coin CLIP Retrieval
3
- emoji: 💻
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
 
1
  ---
2
  title: Coin CLIP Retrieval
3
+ emoji: 🪙
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
app.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (C) 2023, [Breezedeus](https://github.com/breezedeus).
3
+ # Licensed to the Apache Software Foundation (ASF) under one
4
+ # or more contributor license agreements. See the NOTICE file
5
+ # distributed with this work for additional information
6
+ # regarding copyright ownership. The ASF licenses this file
7
+ # to you under the Apache License, Version 2.0 (the
8
+ # "License"); you may not use this file except in compliance
9
+ # with the License. You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing,
14
+ # software distributed under the License is distributed on an
15
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
+ # KIND, either express or implied. See the License for the
17
+ # specific language governing permissions and limitations
18
+ # under the License.
19
+
20
+ import os
21
+ import sys
22
+ import logging
23
+ from typing import List
24
+
25
+ import yaml
26
+
27
+ import gradio as gr
28
+ from PIL import Image
29
+ import numpy as np
30
+ from datasets import load_dataset
31
+ import chromadb
32
+ from chromadb import Settings
33
+
34
+ from coin_clip.utils import resize_img
35
+ from coin_clip.chroma_embedding import ChromaEmbeddingFunction
36
+ from coin_clip.detect import Detector
37
+
38
+
39
+ logging.basicConfig(level=logging.INFO)
40
+ logger = logging.getLogger(__name__)
41
+ env = os.environ.get('COIN_ENV', 'local')
42
+ if env == 'hf':
43
+ config_fp = 'hf_config.yaml'
44
+ else:
45
+ config_fp = 'local_config.yaml'
46
+ logger.info(f'Use config file: {config_fp}')
47
+
48
+ total_config = yaml.safe_load(open(config_fp))
49
+ DETECTOR = Detector(
50
+ model_name=total_config['detector']['model_name'],
51
+ device=total_config['detector']['device'],
52
+ )
53
+ # USE_REMOVE_BG = total_config['use_remove_bg']
54
+ RESIZED_TO_BEFORE_DETECT = total_config['detector'].get('resized_to', 300)
55
+
56
+
57
+ def prepare_chromadb():
58
+ if env == 'local':
59
+ return
60
+ from huggingface_hub import snapshot_download
61
+ snapshot_download(repo_type='model', repo_id='breezedeus/usa-coins-chromadb', local_dir='./')
62
+
63
+
64
+ def load_dataset(data_path):
65
+ logger.info('Load dataset from %s', data_path)
66
+
67
+ if env == 'hf':
68
+ dataset = load_dataset(data_path, split='train')
69
+ else:
70
+ dataset = load_dataset("imagefolder", data_dir=data_path, split='train')
71
+ return dataset
72
+
73
+
74
+ def detect(images):
75
+ outs = []
76
+ for idx, img in enumerate(images):
77
+ img = resize_img(img, RESIZED_TO_BEFORE_DETECT)
78
+ out = DETECTOR.detect(np.array(img))
79
+ if not out:
80
+ out = {'position': None, 'scores': 0.0}
81
+ else:
82
+ out = out[0]
83
+ out.pop('label')
84
+ out['position'] = out.pop('box')
85
+ out['from_image_idx'] = idx
86
+ outs.append(out)
87
+
88
+ box_images = []
89
+ for out, img in zip(outs, images):
90
+ if out['position'] is None:
91
+ box_images.append(None)
92
+ else:
93
+ # box 比例值转化为绝对位置值
94
+ w, h = img.size
95
+ box = out['position']
96
+ box = (int(box[0] * w), int(box[1] * h), int(box[2] * w), int(box[3] * h))
97
+ box_images.append(img.crop(box))
98
+
99
+ return outs, box_images
100
+
101
+
102
+ def load_chroma_db(db_dir, collection_name, model_name, device='cpu'):
103
+ logger.info('Load chroma db from %s', db_dir)
104
+ client = chromadb.PersistentClient(
105
+ path=db_dir, settings=Settings(anonymized_telemetry=False)
106
+ )
107
+
108
+ embedding_function = ChromaEmbeddingFunction(model_name, device)
109
+ collection = client.get_collection(
110
+ name=collection_name,
111
+ embedding_function=embedding_function,
112
+ )
113
+ return collection
114
+
115
+
116
+ def retrieve(query_image: Image.Image, collection, top_k=20) -> List[Image.Image]:
117
+ query_image = np.array(query_image)
118
+ retrieved = collection.query(
119
+ query_images=[query_image], include=['metadatas', 'distances'], n_results=top_k,
120
+ )
121
+ logger.info('retrieved ids: %s', retrieved['ids'][0])
122
+ logger.info('retrieved distances: %s', retrieved['distances'][0])
123
+ return [ds_dict[id]['image'] for id in retrieved['ids'][0]]
124
+
125
+
126
+ dataset = load_dataset(**total_config['dataset'])
127
+ ds_dict = {_d['id']: _d for _d in dataset}
128
+
129
+ prepare_chromadb()
130
+ cc_collection = load_chroma_db(**total_config['coin_clip_db'])
131
+ clip_collection = load_chroma_db(**total_config['clip_db'])
132
+
133
+
134
+ def search(image_file: Image.Image):
135
+ images = [image_file.convert('RGB')]
136
+ detected_outs, box_images = detect(images)
137
+ box_images = [img for img in box_images if img is not None]
138
+ if len(box_images) == 0:
139
+ return [
140
+ gr.update(visible=False),
141
+ gr.update(visible=True),
142
+ gr.update(visible=False),
143
+ gr.update(visible=False),
144
+ ]
145
+
146
+ box_image = box_images[0]
147
+ # breakpoint()
148
+ cc_results = retrieve(box_image, cc_collection, top_k=30)
149
+ clip_results = retrieve(box_image, clip_collection, top_k=30)
150
+ return [
151
+ gr.update(value=box_image, visible=True),
152
+ gr.update(visible=False),
153
+ gr.update(value=cc_results, visible=True),
154
+ gr.update(value=clip_results, visible=True),
155
+ ]
156
+
157
+
158
+ def main():
159
+ title = 'USA Coin Retrieval by'
160
+ desc = (
161
+ '<p style="text-align: center">Coin-CLIP: '
162
+ '<a href="https://huggingface.co/breezedeus/coin-clip-vit-base-patch32" target="_blank">Model</a>, '
163
+ '<a href="https://github.com/breezedeus/coin-clip" target="_blank">Github</a>; '
164
+ 'Author: <a href="https://www.breezedeus.com" target="_blank">Breezedeus</a> , '
165
+ '<a href="https://github.com/breezedeus" target="_blank">Github</a> </p>'
166
+ )
167
+ examples = [
168
+ 'examples/c2.jpeg',
169
+ 'examples/c20.jpg',
170
+ 'examples/c21.jpg',
171
+ 'examples/c22.png',
172
+ 'examples/c1.jpg',
173
+ 'examples/c11.jpg',
174
+ 'examples/c3.png',
175
+ 'examples/c4.jpg',
176
+ 'examples/c5.jpeg',
177
+ 'examples/c6.jpeg',
178
+ 'examples/c7.jpg',
179
+ 'examples/c8.jpeg',
180
+ ]
181
+
182
+ with gr.Blocks() as demo:
183
+ gr.Markdown(
184
+ f'<h1 style="text-align: center; margin-bottom: 1rem;">{title} <a href="https://github.com/breezedeus/coin-clip" target="_blank">Coin-CLIP</a></h1>'
185
+ )
186
+ gr.Markdown(desc)
187
+ with gr.Row(equal_height=False):
188
+ with gr.Column(variant='compact', scale=1):
189
+ gr.Markdown('### Image within a coin')
190
+ image_file = gr.Image(
191
+ label='Coin Image to Search',
192
+ type="pil",
193
+ image_mode='RGB',
194
+ height=400,
195
+ )
196
+ sub_btn = gr.Button("Submit", variant="primary")
197
+ with gr.Column(variant='compact', scale=1):
198
+ gr.Markdown('### Detected Coin')
199
+ detected_image = gr.Image(
200
+ label='Detected Coin',
201
+ type="pil",
202
+ interactive=False,
203
+ image_mode='RGB',
204
+ height=400,
205
+ )
206
+ no_detect_warn = gr.Markdown(
207
+ '**⚠️ Warning**: No coins detected in image', visible=False
208
+ )
209
+
210
+ with gr.Row(equal_height=False):
211
+ with gr.Column(variant='compact', scale=1):
212
+ gr.Markdown('### Results from Coin-CLIP')
213
+ cc_results = gr.Gallery(
214
+ label='Coin-CLIP Results', columns=3, height=2200, show_share_button=True, visible=False
215
+ )
216
+
217
+ with gr.Column(variant='compact', scale=1):
218
+ gr.Markdown('### Results from CLIP')
219
+ coin_results = gr.Gallery(
220
+ label='CLIP Results', columns=3, height=2200, show_share_button=True, visible=False
221
+ )
222
+
223
+ sub_btn.click(
224
+ search,
225
+ inputs=[image_file,],
226
+ outputs=[detected_image, no_detect_warn, cc_results, coin_results],
227
+ )
228
+
229
+ gr.Examples(
230
+ label='Examples',
231
+ examples=examples,
232
+ inputs=image_file,
233
+ outputs=[detected_image, no_detect_warn, cc_results, coin_results],
234
+ fn=search,
235
+ examples_per_page=12,
236
+ cache_examples=True,
237
+ )
238
+
239
+ demo.queue(max_size=20)
240
+ demo.launch()
241
+
242
+
243
+ if __name__ == '__main__':
244
+ main()
hf_config.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ data_path: "breezedeus/usa-coins"
3
+
4
+ detector:
5
+ model_name: "google/owlvit-base-patch32"
6
+ device: "cpu"
7
+ resized_to: 300
8
+
9
+ coin_clip_db:
10
+ db_dir: "data/coin_clip_chroma.db"
11
+ model_name: "breezedeus/coin-clip-vit-base-patch32"
12
+ collection_name: "coin_clip_collection"
13
+ device: "cpu"
14
+
15
+ clip_db:
16
+ db_dir: "data/clip_chroma.db"
17
+ model_name: "openai/clip-vit-base-patch32"
18
+ collection_name: "clip_collection"
19
+ device: "cpu"
20
+ device: "cpu"
local_config.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ data_path: "data/coin_usa"
3
+
4
+ detector:
5
+ model_name: "google/owlvit-base-patch32"
6
+ device: "cpu"
7
+ resized_to: 300
8
+
9
+ coin_clip_db:
10
+ db_dir: "data/coin_clip_chroma.db"
11
+ model_name: "../coin-clip-vit-base-patch32"
12
+ collection_name: "coin_clip_collection"
13
+ device: "cpu"
14
+
15
+ clip_db:
16
+ db_dir: "data/clip_chroma.db"
17
+ model_name: "openai/clip-vit-base-patch32"
18
+ collection_name: "clip_collection"
19
+ device: "cpu"
20
+ device: "cpu"
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://pypi.org/simple
2
+
3
+ coin_clip==0.1
4
+ huggingface_hub
5
+ matplotlib
6
+ chromadb
7
+ datasets
8
+ numpy