kenken999's picture
fda
0f43f8a
raw
history blame
3.37 kB
import psycopg2
from sentence_transformers import SentenceTransformer
class ProductDatabase:
def __init__(self, database_url):
self.database_url = database_url
self.conn = None
self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
def connect(self):
self.conn = psycopg2.connect(self.database_url)
def close(self):
if self.conn:
self.conn.close()
def setup_vector_extension_and_column(self):
with self.conn.cursor() as cursor:
# pgvector拡張機能のインストール
cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
# ベクトルカラムの追加
cursor.execute("ALTER TABLE products ADD COLUMN IF NOT EXISTS vector_col vector(384);")
self.conn.commit()
def get_embedding(self, text):
embedding = self.model.encode(text)
return embedding
def insert_vector(self, product_id, text):
vector = self.get_embedding(text).tolist() # ndarray をリストに変換
with self.conn.cursor() as cursor:
cursor.execute("UPDATE products SET vector_col = %s WHERE id = %s", (vector, product_id))
self.conn.commit()
def search_similar_vectors(self, query_text, top_k=5):
query_vector = self.get_embedding(query_text).tolist() # ndarray をリストに変換
with self.conn.cursor() as cursor:
cursor.execute("""
SELECT id, vector_col <=> %s::vector AS distance
FROM products
ORDER BY distance
LIMIT %s;
""", (query_vector, top_k))
results = cursor.fetchall()
return results
def main():
# データベース接続情報
DATABASE_URL = os.getenv("postgre_url")
# ProductDatabaseクラスのインスタンスを作成
db = ProductDatabase(DATABASE_URL)
# データベースに接続
db.connect()
try:
# pgvector拡張機能のインストールとカラムの追加
db.setup_vector_extension_and_column()
print("Vector extension installed and column added successfully.")
# サンプルデータの挿入
sample_text = """検査にはどのぐらい時間かかりますか?⇒当日に分かります。
法人取引やってますか?⇒大丈夫ですよ。成約時に必要な書類の説明
LINEで金粉送って、査定はできますか?⇒できますが、今お話した内容と同様で、検査が必要な旨を返すだけなので、金粉ではなく、他のお品物でLINE査定くださいと。
分かりました、またどうするか検討して連絡しますと"""
sample_product_id = 1 # 実際の製品IDを使用
db.insert_vector(sample_product_id, sample_text)
db.insert_vector(2, sample_text)
print(f"Vector inserted for product ID {sample_product_id}.")
# ベクトル検索
query_text = "今お話した内容と同様で"
results = db.search_similar_vectors(query_text)
print("Search results:")
for result in results:
print(result)
finally:
# 接続を閉じる
db.close()
if __name__ == "__main__":
main()