import torch import torch.nn.functional as F from torch.optim import Adam from torch.utils.data import DataLoader, Dataset from pymongo import MongoClient from transformers import BertTokenizer, BertModel import numpy as np # MongoDB Atlas 연결 설정 client = MongoClient( "mongodb+srv://waseoke:rookies3@cluster0.ps7gq.mongodb.net/test?retryWrites=true&w=majority" ) db = client["two_tower_model"] train_dataset = db["train_dataset"] # KoBERT 모델 및 토크나이저 로드 tokenizer = BertTokenizer.from_pretrained('monologg/kobert') model = BertModel.from_pretrained('monologg/kobert') # 상품 임베딩 함수 def embed_product_data(product): """ 상품 데이터를 KoBERT로 임베딩하는 함수. """ text = ( product.get("product_name", "") + " " + product.get("product_description", "") ) inputs = tokenizer( text, return_tensors="pt", truncation=True, padding=True, max_length=128 ) outputs = model(**inputs) embedding = ( outputs.last_hidden_state.mean(dim=1).detach().numpy().flatten() ) # 평균 풀링 return embedding # PyTorch Dataset 정의 class TripletDataset(Dataset): def __init__(self, dataset): self.dataset = dataset def __len__(self): return len(self.dataset) def __getitem__(self, idx): data = self.dataset[idx] anchor = torch.tensor(data["anchor_embedding"], dtype=torch.float32) positive = torch.tensor(data["positive_embedding"], dtype=torch.float32) negative = torch.tensor(data["negative_embedding"], dtype=torch.float32) return anchor, positive, negative # MongoDB에서 데이터셋 로드 및 임베딩 변환 def prepare_training_data(verbose=False): dataset = list(train_dataset.find()) if not dataset: raise ValueError("No training data found in MongoDB.") # Anchor, Positive, Negative 임베딩 생성 embedded_dataset = [] for idx, entry in enumerate(dataset): try: # Anchor, Positive, Negative 데이터 임베딩 anchor_embedding = embed_product_data(entry["anchor"]["product"]) positive_embedding = embed_product_data(entry["positive"]["product"]) negative_embedding = embed_product_data(entry["negative"]["product"]) # 임베딩 확인 (옵션으로 출력) if verbose: print(f"Sample {idx + 1}:") print( f"Anchor Embedding: {anchor_embedding[:5]}... (shape: {anchor_embedding.shape})" ) print( f"Positive Embedding: {positive_embedding[:5]}... (shape: {positive_embedding.shape})" ) print( f"Negative Embedding: {negative_embedding[:5]}... (shape: {negative_embedding.shape})" ) # 임베딩 결과 저장 embedded_dataset.append( { "anchor_embedding": anchor_embedding, "positive_embedding": positive_embedding, "negative_embedding": negative_embedding, } ) except Exception as e: print(f"Error embedding data at sample {idx + 1}: {e}") return TripletDataset(embedded_dataset) # 데이터셋 검증용 함수 def validate_embeddings(): """ 데이터셋 임베딩을 생성하고 각 임베딩의 일부를 출력하여 확인. """ print("Validating embeddings...") triplet_dataset = prepare_training_data(verbose=True) print(f"Total samples: {len(triplet_dataset)}") return triplet_dataset # Triplet Loss를 학습시키는 함수 def train_triplet_model( product_model, train_loader, num_epochs=10, learning_rate=0.001, margin=0.05 ): optimizer = Adam(product_model.parameters(), lr=learning_rate) for epoch in range(num_epochs): product_model.train() total_loss = 0 for anchor, positive, negative in train_loader: optimizer.zero_grad() # Forward pass anchor_vec = product_model(anchor) positive_vec = product_model(positive) negative_vec = product_model(negative) # Triplet loss 계산 positive_distance = F.pairwise_distance(anchor_vec, positive_vec) negative_distance = F.pairwise_distance(anchor_vec, negative_vec) triplet_loss = torch.clamp( positive_distance - negative_distance + margin, min=0 ).mean() # 역전파와 최적화 triplet_loss.backward() optimizer.step() total_loss += triplet_loss.item() print( f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(train_loader):.4f}" ) return product_model # 모델 학습 파이프라인 def main(): # 모델 초기화 (예시 모델) product_model = torch.nn.Sequential( torch.nn.Linear(768, 256), # 768: KoBERT 임베딩 차원 torch.nn.ReLU(), torch.nn.Linear(256, 128), ) # 데이터 준비 triplet_dataset = prepare_training_data() train_loader = DataLoader(triplet_dataset, batch_size=16, shuffle=True) # 모델 학습 trained_model = train_triplet_model(product_model, train_loader) # 학습된 모델 저장 torch.save(trained_model.state_dict(), "product_model.pth") print("Model training completed and saved.") print(validate_embeddings()) if __name__ == "__main__": main()