# itda-segment / db_segmentation.py
# Uploaded by leedoming (commit 2dba380, "Upload 14 files"); original file size 8.65 kB.
import chromadb
import logging
import open_clip
import torch
from PIL import Image
import numpy as np
from transformers import pipeline
import requests
import io
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import os
# Logging setup: records at INFO and above go to db_creation.log and to the
# default StreamHandler stream (stderr).
logging.basicConfig(
    handlers=[logging.FileHandler('db_creation.log'), logging.StreamHandler()],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def load_models():
    """Load the FashionSigLIP embedding model and the clothes segmentation pipeline.

    Returns:
        A tuple ``(model, preprocess_val, segmenter, device)``:
        the open_clip model moved to ``device`` and switched to eval mode,
        its validation-time preprocessing transform, the Hugging Face
        pipeline built for ``mattmdjaga/segformer_b2_clothes``, and the
        torch device chosen (CUDA when available, otherwise CPU).

    Raises:
        Re-raises any exception hit while loading, after logging it with
        its traceback.
    """
    try:
        logger.info("Loading models...")
        # CLIP-style embedding model; only the validation transform is kept.
        model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
        # Segmentation model (pipeline task is inferred from the model id).
        # NOTE(review): no `device=` is passed, so this pipeline presumably
        # runs on CPU even when CUDA is available — confirm if that is intended.
        segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")
        model.to(device)
        # open_clip documents that models come back in train mode; switch to
        # inference mode so dropout / stochastic depth cannot perturb embeddings.
        model.eval()
        return model, preprocess_val, segmenter, device
    except Exception as e:
        logger.exception(f"Error loading models: {e}")
        raise
def process_segmentation(image, segmenter, exclude_labels=("Background",)):
    """Segment ``image`` and return the mask of its largest usable segment.

    Args:
        image: Input handed unchanged to ``segmenter``.
        segmenter: Callable returning a list of dicts, each with a ``'mask'``
            (array-like; nonzero means "inside the segment") and optionally a
            ``'label'``.
        exclude_labels: Segment labels that are never selected. The default
            skips the "Background" class, which would otherwise usually be
            the largest region and get embedded instead of the garment.

    Returns:
        A numpy array copy of the chosen mask, or ``None`` when no eligible
        segment exists or segmentation raises (the error is logged).
    """
    try:
        segments = segmenter(image) or []
        excluded = set(exclude_labels or ())
        candidates = [s for s in segments if s.get('label') not in excluded]
        if not candidates:
            return None
        # Select the largest remaining segment by number of mask pixels.
        largest_segment = max(
            candidates, key=lambda s: np.count_nonzero(np.asarray(s['mask']))
        )
        return np.array(largest_segment['mask'])
    except Exception as e:
        logger.error(f"Segmentation error: {e}")
        return None
def extract_features(image, mask, model, preprocess_val, device):
    """Embed ``image`` with the CLIP model, optionally keeping only masked pixels.

    Args:
        image: PIL image to embed (the caller converts it to RGB).
        mask: ``None`` to embed the whole image, or a 2-D array-like whose
            nonzero entries mark pixels to keep (0/1, 0/255 and bool masks
            all work). Assumed to match the image's height and width.
        model: Model exposing ``encode_image``.
        preprocess_val: Transform mapping a PIL image to a CHW tensor.
        device: Torch device the input tensor is moved to.

    Returns:
        A 1-D numpy array holding the L2-normalized embedding, or ``None``
        if any step raises (the error is logged).
    """
    try:
        if mask is not None:
            img_array = np.asarray(image)
            keep = np.asarray(mask) > 0
            # Keep foreground pixels untouched and paint the background white.
            # Multiplying the uint8 image by a uint8 0/255 mask would overflow
            # and colour-invert the foreground, so select pixels instead.
            masked_img = np.where(keep[..., None], img_array, 255).astype(np.uint8)
            image = Image.fromarray(masked_img)
        image_tensor = preprocess_val(image).unsqueeze(0).to(device)
        with torch.no_grad():
            features = model.encode_image(image_tensor)
            features /= features.norm(dim=-1, keepdim=True)
        return features.cpu().numpy().flatten()
    except Exception as e:
        logger.error(f"Feature extraction error: {e}")
        return None
def download_and_process_image(url, metadata_id, model, preprocess_val, segmenter, device):
    """Fetch one image by URL and turn it into a segmentation-masked embedding.

    ``metadata_id`` is used only in log messages. Returns the feature vector
    produced by ``extract_features``, or ``None`` when the HTTP status is not
    200, no mask is found, feature extraction fails, or anything raises.
    Every failure is logged.
    """
    try:
        resp = requests.get(url, timeout=10)
        status = resp.status_code
        if status != 200:
            logger.error(f"Failed to download image {metadata_id}: HTTP {status}")
            return None
        rgb_image = Image.open(io.BytesIO(resp.content)).convert('RGB')

        seg_mask = process_segmentation(rgb_image, segmenter)
        if seg_mask is None:
            logger.warning(f"No valid mask found for image {metadata_id}")
            return None

        embedding = extract_features(rgb_image, seg_mask, model, preprocess_val, device)
        if embedding is None:
            logger.warning(f"Failed to extract features for image {metadata_id}")
        return embedding
    except Exception as e:
        logger.error(f"Error processing image {metadata_id}: {e}")
        return None
def _reset_collection(client, name, metadata):
    """Drop collection ``name`` from ``client`` if present, then create it anew."""
    try:
        client.delete_collection(name)
    except Exception as e:
        # Best-effort: on a first run the collection does not exist yet.
        logger.debug(f"Collection '{name}' was not deleted: {e}")
    return client.create_collection(name=name, metadata=metadata)


def _process_batch(executor, target_collection, metadatas, source_ids,
                   model, preprocess_val, segmenter, device):
    """Embed one batch of source items in parallel and add successes to the target.

    Returns ``(successful, failed, skipped)``. Skipped items have no
    metadata or no ``'image_url'``.
    """
    # Submit one download + segment + embed job per image in the batch.
    futures = []
    skipped = 0
    for source_id, metadata in zip(source_ids, metadatas):
        if not metadata or 'image_url' not in metadata:
            skipped += 1
            continue
        future = executor.submit(
            download_and_process_image,
            metadata['image_url'],
            metadata.get('id', 'unknown'),
            model, preprocess_val, segmenter, device,
        )
        futures.append((source_id, metadata, future))

    # Collect the batch results.
    batch_embeddings, batch_metadatas, batch_ids = [], [], []
    failed = 0
    for source_id, metadata, future in futures:
        try:
            features = future.result()
        except Exception as e:
            logger.error(f"Error processing batch item: {e}")
            failed += 1
            continue
        if features is None:
            failed += 1
            continue
        batch_embeddings.append(features.tolist())
        batch_metadatas.append(metadata)
        # Prefer the item's own 'id'. Otherwise fall back to the source
        # record's id, which is stable across runs, unlike str hash().
        # Chroma ids must be strings.
        batch_ids.append(str(metadata.get('id', source_id)))

    if not batch_embeddings:
        return 0, failed, skipped

    # Store the batch.
    try:
        target_collection.add(
            embeddings=batch_embeddings,
            metadatas=batch_metadatas,
            ids=batch_ids,
        )
    except Exception as e:
        logger.error(f"Error adding batch to collection: {e}")
        return 0, failed + len(batch_embeddings), skipped
    logger.info(f"Added batch of {len(batch_embeddings)} items")
    return len(batch_embeddings), failed, skipped


def create_segmented_db(source_path, target_path, batch_size=100):
    """Build a "clothes_segmented" collection from the "clothes" collection.

    Each source item's ``image_url`` is downloaded and segmented, then
    re-embedded. The result is stored in a fresh collection under
    ``target_path``, replacing any previous "clothes_segmented" collection.

    Args:
        source_path: Directory of the source Chroma persistent DB.
        target_path: Directory for the target DB (created if missing).
        batch_size: Number of source items submitted and stored per batch.

    Returns:
        True when the run completes, even if some items failed. False when
        an unexpected error aborts it (logged).
    """
    try:
        model, preprocess_val, segmenter, device = load_models()

        # Connect to the source DB.
        source_client = chromadb.PersistentClient(path=source_path)
        source_collection = source_client.get_collection(name="clothes")

        # Create the target DB.
        os.makedirs(target_path, exist_ok=True)
        target_client = chromadb.PersistentClient(path=target_path)
        target_collection = _reset_collection(
            target_client,
            "clothes_segmented",
            {"description": "Clothes collection with segmentation-based features"},
        )

        # Check the total number of items.
        all_items = source_collection.get(include=['metadatas'])
        all_metadatas = all_items['metadatas']
        source_ids = all_items['ids']
        total_items = len(all_metadatas)
        logger.info(f"Found {total_items} items in source database")

        successful_updates = 0
        failed_updates = 0
        skipped_items = 0
        # NOTE(review): all worker threads share the same model and
        # segmentation pipeline objects; confirm concurrent calls are safe.
        max_workers = min(10, os.cpu_count() or 4)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Split the whole dataset into batches.
            for batch_start in tqdm(range(0, total_items, batch_size), desc="Processing batches"):
                batch_end = batch_start + batch_size
                ok, failed, skipped = _process_batch(
                    executor,
                    target_collection,
                    all_metadatas[batch_start:batch_end],
                    source_ids[batch_start:batch_end],
                    model, preprocess_val, segmenter, device,
                )
                successful_updates += ok
                failed_updates += failed
                skipped_items += skipped

        # Report the final results.
        logger.info("Database creation completed.")
        logger.info(f"Successfully processed: {successful_updates}")
        logger.info(f"Failed: {failed_updates}")
        logger.info(f"Skipped (no image_url): {skipped_items}")
        if total_items:
            logger.info(f"Total completion rate: {(successful_updates/total_items)*100:.2f}%")
        else:
            # Avoid ZeroDivisionError turning an empty-but-successful run into a failure.
            logger.info("Total completion rate: n/a (source collection is empty)")
        return True
    except Exception as e:
        logger.error(f"Database creation error: {e}")
        return False
if __name__ == "__main__":
    # Settings
    SOURCE_DB_PATH = "./clothesDB_11GmarketMusinsa"  # path of the original DB
    TARGET_DB_PATH = "./clothesDB_11GmarketMusinsa_segmented"  # path of the new DB
    BATCH_SIZE = 50  # number of items processed per batch

    # Build the segmented DB and report the outcome.
    success = create_segmented_db(
        source_path=SOURCE_DB_PATH,
        target_path=TARGET_DB_PATH,
        batch_size=BATCH_SIZE,
    )
    if not success:
        logger.error("Failed to create segmented database.")
    else:
        logger.info("Successfully created segmented database!")