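# Spaces: Sleeping
# (Page-status text captured together with this source, apparently from a
# Hugging Face Spaces banner; it is not part of the program.)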
import io
import logging
import os
from concurrent.futures import ThreadPoolExecutor

import chromadb
import numpy as np
import open_clip
import requests
import torch
from PIL import Image
from tqdm import tqdm
from transformers import pipeline
# Logging setup: INFO-and-above records go both to db_creation.log and to the
# console (logging.StreamHandler() writes to stderr by default).
# Note: logging.basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8 so non-ASCII text in log messages (e.g. URLs or item
        # ids) is written correctly regardless of the platform locale encoding.
        logging.FileHandler('db_creation.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)
def load_models():
    """Load the image-embedding model and the clothes-segmentation pipeline.

    Returns:
        tuple: ``(model, preprocess_val, segmenter, device)`` where
        ``model`` is the open_clip model for ``hf-hub:Marqo/marqo-fashionSigLIP``,
        moved to ``device`` and switched to eval mode; ``preprocess_val`` is
        the validation-time image transform returned alongside it;
        ``segmenter`` is a Hugging Face ``pipeline`` for
        ``mattmdjaga/segformer_b2_clothes`` placed on the same device; and
        ``device`` is a ``torch.device`` (CUDA when available, else CPU).

    Raises:
        Exception: any error while loading/constructing the models is logged
        (with traceback) and re-raised unchanged.
    """
    try:
        logger.info("Loading models...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")

        # Image-embedding (CLIP-style) model.
        model, _, preprocess_val = open_clip.create_model_and_transforms(
            'hf-hub:Marqo/marqo-fashionSigLIP'
        )
        model.to(device)
        # open_clip hands back the model in training mode; inference must run
        # in eval mode so dropout / stochastic-depth / norm layers are frozen.
        model.eval()

        # Segmentation model, run on the same device as the encoder instead of
        # silently defaulting to CPU.
        segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes", device=device)

        return model, preprocess_val, segmenter, device
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Error loading models: {e}")
        raise
def process_segmentation(image, segmenter, exclude_labels=("Background",)):
    """Return the mask of the largest non-excluded segment found in *image*.

    Args:
        image: image passed unchanged to ``segmenter``.
        segmenter: callable (e.g. a Hugging Face image-segmentation pipeline)
            returning a list of dicts, each with a ``'mask'`` entry and
            usually a ``'label'`` entry.
        exclude_labels: labels (compared case-insensitively) ignored when
            picking the largest segment. Defaults to ``("Background",)``: on
            product photos the background is typically the largest region, and
            selecting it would mask out the garment itself. Pass ``()`` to
            consider every segment.

    Returns:
        numpy.ndarray | None: the chosen segment's mask as an array, or
        ``None`` if the segmenter found nothing, found only excluded segments,
        or raised (the error is logged).
    """
    try:
        segments = segmenter(image)
        if not segments:
            return None

        excluded = {str(label).lower() for label in (exclude_labels or ())}
        candidates = [
            seg for seg in segments
            if str(seg.get('label', '')).lower() not in excluded
        ]
        if not candidates:
            return None

        # Largest segment by area, i.e. number of non-zero mask pixels.
        largest_segment = max(
            candidates, key=lambda seg: np.count_nonzero(np.asarray(seg['mask']))
        )
        return np.array(largest_segment['mask'])
    except Exception as e:
        logger.error(f"Segmentation error: {e}")
        return None
def extract_features(image, mask, model, preprocess_val, device):
    """Embed *image*, optionally masked, into an L2-normalised feature vector.

    Args:
        image: PIL image to encode (the downloader in this file passes RGB).
        mask: 2-D array aligned with *image*, or ``None`` to encode the whole
            image. Any non-zero value marks a foreground pixel, so 0/1, boolean
            and 0/255 masks (the format Hugging Face segmentation pipelines
            emit) are all handled the same way.
        model: model exposing ``encode_image(tensor)``.
        preprocess_val: transform turning a PIL image into a CHW tensor.
        device: torch device the input tensor is moved to.

    Returns:
        numpy.ndarray | None: 1-D feature vector with unit L2 norm, or
        ``None`` if anything fails (the error is logged).
    """
    try:
        if mask is not None:
            pixels = np.array(image)
            foreground = np.asarray(mask) > 0
            if pixels.ndim == 3:
                # Add a channel axis so the (H, W) mask broadcasts over RGB.
                foreground = foreground[..., None]
            # Keep foreground pixels as-is and paint the background white.
            # np.where replaces the old `pixels * mask`, which overflowed uint8
            # with 0/255 masks and inverted the garment's colours.
            masked = np.where(foreground, pixels, 255).astype(np.uint8)
            image = Image.fromarray(masked)

        image_tensor = preprocess_val(image).unsqueeze(0).to(device)
        with torch.no_grad():
            features = model.encode_image(image_tensor)
            # Unit-normalise so dot products between stored vectors are cosines.
            features /= features.norm(dim=-1, keepdim=True)
        return features.cpu().numpy().flatten()
    except Exception as e:
        logger.error(f"Feature extraction error: {e}")
        return None
def download_and_process_image(url, metadata_id, model, preprocess_val, segmenter, device):
    """Fetch one image by URL and return its segmentation-masked embedding.

    Pipeline: HTTP GET (10 s timeout) -> decode as RGB -> segmentation mask via
    ``process_segmentation`` -> embedding via ``extract_features``.

    Args:
        url: image URL to download.
        metadata_id: identifier used only in log messages.
        model, preprocess_val, segmenter, device: forwarded to the
            segmentation / feature-extraction helpers.

    Returns:
        The feature vector from ``extract_features``, or ``None`` when the
        download is not HTTP 200, no mask is found, feature extraction fails,
        or any exception is raised. Every failure is logged.
    """
    try:
        reply = requests.get(url, timeout=10)
        if reply.status_code == 200:
            picture = Image.open(io.BytesIO(reply.content)).convert('RGB')
        else:
            logger.error(f"Failed to download image {metadata_id}: HTTP {reply.status_code}")
            return None

        # Without a usable mask there is nothing to embed.
        seg_mask = process_segmentation(picture, segmenter)
        if seg_mask is None:
            logger.warning(f"No valid mask found for image {metadata_id}")
            return None

        embedding = extract_features(picture, seg_mask, model, preprocess_val, device)
        if embedding is None:
            logger.warning(f"Failed to extract features for image {metadata_id}")
        return embedding
    except Exception as err:
        logger.error(f"Error processing image {metadata_id}: {err}")
        return None
def _recreate_collection(client, name, metadata):
    """Delete collection *name* from *client* if present, then create it empty.

    Args:
        client: chromadb client to operate on.
        name: collection name.
        metadata: metadata dict stored on the new collection.

    Returns:
        The newly created collection.
    """
    try:
        client.delete_collection(name)
    except Exception as e:
        # Expected on a first run when the collection does not exist yet. The
        # exception type for a missing collection differs between chromadb
        # versions, so it is not narrowed further; it is logged, not hidden.
        logger.debug(f"Could not delete collection '{name}' (probably absent): {e}")
    return client.create_collection(name=name, metadata=metadata)


def _collect_batch_results(futures):
    """Gather finished per-item futures into arguments for ``collection.add``.

    Args:
        futures: list of ``(item_id, metadata, future)`` tuples; each future
            resolves to a feature vector (anything with ``.tolist()``) or
            ``None``.

    Returns:
        tuple: ``(embeddings, metadatas, ids, failed)``. The first three are
        parallel lists for the items that produced features. ``failed`` is the
        number of items whose future returned ``None`` or raised.
    """
    embeddings, metadatas, ids = [], [], []
    failed = 0
    for item_id, metadata, future in futures:
        try:
            features = future.result()
            vector = None if features is None else features.tolist()
        except Exception as e:
            logger.error(f"Error processing batch item: {e}")
            failed += 1
            continue
        if vector is None:
            failed += 1
            continue
        embeddings.append(vector)
        metadatas.append(metadata)
        ids.append(item_id)
    return embeddings, metadatas, ids, failed


def create_segmented_db(source_path, target_path, batch_size=100):
    """Build a ChromaDB collection of segmentation-masked image embeddings.

    Reads every item's metadata from the ``clothes`` collection of the
    persistent DB at *source_path*. For each item with an ``image_url`` it
    downloads and embeds the image via ``download_and_process_image``. The
    embeddings are stored together with the unchanged metadata in a freshly
    (re)created ``clothes_segmented`` collection at *target_path*; an existing
    collection of that name is deleted first.

    Item ids: ``str(metadata['id'])`` when the metadata has an ``id`` key,
    otherwise the item's id in the source collection (stable and unique, unlike
    the per-process-randomised ``hash()`` previously used as a fallback).

    Args:
        source_path: directory of the source persistent ChromaDB.
        target_path: directory of the target persistent ChromaDB (created if
            missing).
        batch_size: items submitted to the thread pool and written per
            ``add`` call; must be a positive integer.

    Returns:
        bool: ``True`` when the run completes (individual items may still have
        failed; counts are logged), ``False`` when a fatal error aborts it (the
        error is logged with its traceback).
    """
    try:
        # Validate before the expensive model load.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

        logger.info("Loading models...")
        model, preprocess_val, segmenter, device = load_models()

        # Connect to the source DB.
        source_client = chromadb.PersistentClient(path=source_path)
        source_collection = source_client.get_collection(name="clothes")

        # Create the target DB, replacing any previous segmented collection.
        os.makedirs(target_path, exist_ok=True)
        target_client = chromadb.PersistentClient(path=target_path)
        target_collection = _recreate_collection(
            target_client,
            "clothes_segmented",
            {"description": "Clothes collection with segmentation-based features"},
        )

        # Fetch all items; chromadb returns 'ids' alongside the requested fields.
        all_items = source_collection.get(include=['metadatas'])
        source_ids = all_items['ids']
        all_metadatas = all_items['metadatas']
        total_items = len(all_metadatas)
        logger.info(f"Found {total_items} items in source database")

        successful_updates = 0
        failed_updates = 0
        skipped_items = 0

        max_workers = min(10, os.cpu_count() or 4)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Process the whole dataset in consecutive batches.
            for batch_start in tqdm(range(0, total_items, batch_size), desc="Processing batches"):
                batch_end = min(batch_start + batch_size, total_items)

                # Submit one download+embed job per item that has an image URL.
                futures = []
                for source_id, metadata in zip(source_ids[batch_start:batch_end],
                                               all_metadatas[batch_start:batch_end]):
                    # Metadata may be None for items stored without any.
                    if not metadata or 'image_url' not in metadata:
                        skipped_items += 1
                        continue
                    # Chroma ids must be strings; metadata ids may be ints.
                    item_id = str(metadata.get('id', source_id))
                    future = executor.submit(
                        download_and_process_image,
                        metadata['image_url'],
                        metadata.get('id', 'unknown'),
                        model, preprocess_val, segmenter, device,
                    )
                    futures.append((item_id, metadata, future))

                batch_embeddings, batch_metadatas, batch_ids, batch_failed = (
                    _collect_batch_results(futures)
                )
                failed_updates += batch_failed

                # Persist the batch; a failed add counts all of its items as failed.
                if batch_embeddings:
                    try:
                        target_collection.add(
                            embeddings=batch_embeddings,
                            metadatas=batch_metadatas,
                            ids=batch_ids,
                        )
                    except Exception as e:
                        logger.error(f"Error adding batch to collection: {e}")
                        failed_updates += len(batch_embeddings)
                    else:
                        successful_updates += len(batch_embeddings)
                        logger.info(f"Added batch of {len(batch_embeddings)} items")

        # Final summary.
        logger.info("Database creation completed.")
        logger.info(f"Successfully processed: {successful_updates}")
        logger.info(f"Failed: {failed_updates}")
        if skipped_items:
            logger.info(f"Skipped (no image_url): {skipped_items}")
        if total_items:
            logger.info(f"Total completion rate: {(successful_updates/total_items)*100:.2f}%")
        else:
            # Avoid the ZeroDivisionError that used to turn an empty source
            # into a reported failure.
            logger.info("Total completion rate: n/a (source collection is empty)")
        return True
    except Exception as e:
        logger.exception(f"Database creation error: {e}")
        return False
if __name__ == "__main__":
    # Settings
    SOURCE_DB_PATH = "./clothesDB_11GmarketMusinsa"  # source (original) DB path
    TARGET_DB_PATH = "./clothesDB_11GmarketMusinsa_segmented"  # new segmented DB path
    BATCH_SIZE = 50  # number of items processed per batch

    # Build the segmented DB.
    success = create_segmented_db(SOURCE_DB_PATH, TARGET_DB_PATH, BATCH_SIZE)
    if success:
        logger.info("Successfully created segmented database!")
    else:
        logger.error("Failed to create segmented database.")
        # Non-zero exit status so shells, CI and schedulers can detect failure.
        raise SystemExit(1)