waseoke committed on
Commit
45b2ac8
·
verified ·
1 Parent(s): 2dfec3c

Update train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +74 -24
train_model.py CHANGED
@@ -7,25 +7,35 @@ from transformers import BertTokenizer, BertModel
7
  import numpy as np
8
 
9
  # MongoDB Atlas 연결 설정
10
- client = MongoClient("mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority")
 
 
11
  db = client["two_tower_model"]
12
  train_dataset = db["train_dataset"]
13
 
14
- # BERT 모델 및 토크나이저 로드 (예: klue/bert-base)
15
- tokenizer = BertTokenizer.from_pretrained("klue/bert-base")
16
- bert_model = BertModel.from_pretrained("klue/bert-base")
 
17
 
18
  # 상품 임베딩 함수
19
  def embed_product_data(product):
20
  """
21
- 상품 데이터를 임베딩하는 함수.
22
  """
23
- text = product.get("product_name", "") + " " + product.get("product_description", "")
24
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
25
- outputs = bert_model(**inputs)
26
- embedding = outputs.last_hidden_state.mean(dim=1).detach().numpy().flatten() # 평균 풀링
 
 
 
 
 
 
27
  return embedding
28
 
 
29
  # PyTorch Dataset 정의
30
  class TripletDataset(Dataset):
31
  def __init__(self, dataset):
@@ -41,31 +51,64 @@ class TripletDataset(Dataset):
41
  negative = torch.tensor(data["negative_embedding"], dtype=torch.float32)
42
  return anchor, positive, negative
43
 
 
44
  # MongoDB에서 데이터셋 로드 및 임베딩 변환
45
- def prepare_training_data():
46
- dataset = list(train_dataset.find()) # MongoDB에서 데이터를 가져옵니다.
47
  if not dataset:
48
  raise ValueError("No training data found in MongoDB.")
49
 
50
  # Anchor, Positive, Negative 임베딩 생성
51
  embedded_dataset = []
52
- for entry in dataset:
53
  try:
 
54
  anchor_embedding = embed_product_data(entry["anchor"]["product"])
55
  positive_embedding = embed_product_data(entry["positive"]["product"])
56
  negative_embedding = embed_product_data(entry["negative"]["product"])
57
- embedded_dataset.append({
58
- "anchor_embedding": anchor_embedding,
59
- "positive_embedding": positive_embedding,
60
- "negative_embedding": negative_embedding,
61
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  except Exception as e:
63
- print(f"Error embedding data: {e}")
64
-
65
  return TripletDataset(embedded_dataset)
66
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  # Triplet Loss를 학습시키는 함수
68
- def train_triplet_model(product_model, train_loader, num_epochs=10, learning_rate=0.001, margin=1.0):
 
 
69
  optimizer = Adam(product_model.parameters(), lr=learning_rate)
70
 
71
  for epoch in range(num_epochs):
@@ -83,7 +126,9 @@ def train_triplet_model(product_model, train_loader, num_epochs=10, learning_rat
83
  # Triplet loss 계산
84
  positive_distance = F.pairwise_distance(anchor_vec, positive_vec)
85
  negative_distance = F.pairwise_distance(anchor_vec, negative_vec)
86
- triplet_loss = torch.clamp(positive_distance - negative_distance + margin, min=0).mean()
 
 
87
 
88
  # 역전파와 최적화
89
  triplet_loss.backward()
@@ -91,17 +136,20 @@ def train_triplet_model(product_model, train_loader, num_epochs=10, learning_rat
91
 
92
  total_loss += triplet_loss.item()
93
 
94
- print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(train_loader):.4f}")
 
 
95
 
96
  return product_model
97
 
 
98
  # 모델 학습 파이프라인
99
  def main():
100
  # 모델 초기화 (예시 모델)
101
  product_model = torch.nn.Sequential(
102
- torch.nn.Linear(768, 256), # 768: BERT 임베딩 차원
103
  torch.nn.ReLU(),
104
- torch.nn.Linear(256, 128)
105
  )
106
 
107
  # 데이터 준비
@@ -114,6 +162,8 @@ def main():
114
  # 학습된 모델 저장
115
  torch.save(trained_model.state_dict(), "product_model.pth")
116
  print("Model training completed and saved.")
 
 
117
 
118
  if __name__ == "__main__":
119
  main()
 
7
  import numpy as np
8
 
9
  # MongoDB Atlas 연결 설정
10
# MongoDB Atlas connection setup.
# The connection string is taken from the MONGODB_URI environment variable
# when it is set, so credentials can be supplied without editing this file.
# NOTE(review): the fallback literal embeds a username and password in source
# control - rotate that password and remove the literal once MONGODB_URI is
# configured everywhere this script runs.
import os

client = MongoClient(
    os.environ.get(
        "MONGODB_URI",
        "mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority",
    )
)
db = client["two_tower_model"]
train_dataset = db["train_dataset"]
15
 
16
# Load the KoBERT tokenizer and encoder once at import time; both are fetched
# from the same pretrained checkpoint name.
# NOTE(review): verify that BertTokenizer reads this checkpoint's vocabulary
# as intended - TODO confirm.
_KOBERT_CHECKPOINT = "monologg/kobert"
tokenizer = BertTokenizer.from_pretrained(_KOBERT_CHECKPOINT)
model = BertModel.from_pretrained(_KOBERT_CHECKPOINT)
19
+
20
 
21
  # 상품 임베딩 함수
22
  def embed_product_data(product):
23
  """
24
+ 상품 데이터를 KoBERT로 임베딩하는 함수.
25
  """
26
+ text = (
27
+ product.get("product_name", "") + " " + product.get("product_description", "")
28
+ )
29
+ inputs = tokenizer(
30
+ text, return_tensors="pt", truncation=True, padding=True, max_length=128
31
+ )
32
+ outputs = model(**inputs)
33
+ embedding = (
34
+ outputs.last_hidden_state.mean(dim=1).detach().numpy().flatten()
35
+ ) # 평균 풀링
36
  return embedding
37
 
38
+
39
  # PyTorch Dataset 정의
40
  class TripletDataset(Dataset):
41
  def __init__(self, dataset):
 
51
  negative = torch.tensor(data["negative_embedding"], dtype=torch.float32)
52
  return anchor, positive, negative
53
 
54
+
55
  # MongoDB에서 데이터셋 로드 및 임베딩 변환
56
def prepare_training_data(verbose=False):
    """
    Load triplets from MongoDB and turn them into an embedded TripletDataset.

    Every document returned by ``train_dataset.find()`` is read as
    ``entry[role]["product"]`` for the roles ``anchor``, ``positive`` and
    ``negative``; each product is embedded with ``embed_product_data``.
    Documents that raise while being embedded are reported and skipped.

    Args:
        verbose: When true, print the first five values and the shape of
            each embedding for every successfully embedded sample.

    Returns:
        A ``TripletDataset`` built from the successfully embedded samples.

    Raises:
        ValueError: If the collection returns no documents, or if none of
            the documents could be embedded.
    """
    dataset = list(train_dataset.find())
    if not dataset:
        raise ValueError("No training data found in MongoDB.")

    roles = ("anchor", "positive", "negative")
    embedded_dataset = []
    for idx, entry in enumerate(dataset):
        try:
            embeddings = {
                role: embed_product_data(entry[role]["product"]) for role in roles
            }
        except Exception as e:
            # Best effort: one malformed document must not abort the run.
            print(f"Error embedding data at sample {idx + 1}: {e}")
            continue

        # Optional inspection output, one line per role.
        if verbose:
            print(f"Sample {idx + 1}:")
            for role in roles:
                vector = embeddings[role]
                print(
                    f"{role.capitalize()} Embedding: {vector[:5]}... (shape: {vector.shape})"
                )

        embedded_dataset.append(
            {f"{role}_embedding": embeddings[role] for role in roles}
        )

    if not embedded_dataset:
        # An empty dataset gives the training code nothing to iterate over;
        # fail here with a clear message instead of returning it silently.
        raise ValueError("All training samples failed to embed.")

    return TripletDataset(embedded_dataset)
95
 
96
+
97
+ # 데이터셋 검증용 함수
98
def validate_embeddings():
    """
    Build the embedded dataset in verbose mode and report how many samples it holds.

    Returns:
        The dataset returned by ``prepare_training_data(verbose=True)``.
    """
    print("Validating embeddings...")
    checked = prepare_training_data(verbose=True)
    sample_count = len(checked)
    print(f"Total samples: {sample_count}")
    return checked
106
+
107
+
108
  # Triplet Loss를 학습시키는 함수
109
+ def train_triplet_model(
110
+ product_model, train_loader, num_epochs=10, learning_rate=0.001, margin=0.05
111
+ ):
112
  optimizer = Adam(product_model.parameters(), lr=learning_rate)
113
 
114
  for epoch in range(num_epochs):
 
126
  # Triplet loss 계산
127
  positive_distance = F.pairwise_distance(anchor_vec, positive_vec)
128
  negative_distance = F.pairwise_distance(anchor_vec, negative_vec)
129
+ triplet_loss = torch.clamp(
130
+ positive_distance - negative_distance + margin, min=0
131
+ ).mean()
132
 
133
  # 역전파와 최적화
134
  triplet_loss.backward()
 
136
 
137
  total_loss += triplet_loss.item()
138
 
139
+ print(
140
+ f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(train_loader):.4f}"
141
+ )
142
 
143
  return product_model
144
 
145
+
146
  # 모델 학습 파이프라인
147
  def main():
148
  # 모델 초기화 (예시 모델)
149
  product_model = torch.nn.Sequential(
150
+ torch.nn.Linear(768, 256), # 768: KoBERT 임베딩 차원
151
  torch.nn.ReLU(),
152
+ torch.nn.Linear(256, 128),
153
  )
154
 
155
  # 데이터 준비
 
162
  # 학습된 모델 저장
163
  torch.save(trained_model.state_dict(), "product_model.pth")
164
  print("Model training completed and saved.")
165
+ print(validate_embeddings())
166
+
167
 
168
  if __name__ == "__main__":
169
  main()