|
from flask import Flask, render_template, request, jsonify |
|
from qdrant_client import QdrantClient |
|
from qdrant_client import models |
|
import torch.nn.functional as F |
|
import torch |
|
from torch import Tensor |
|
from transformers import AutoTokenizer, AutoModel |
|
from qdrant_client.models import Batch, PointStruct |
|
from pickle import load, dump |
|
import numpy as np |
|
import os, time, sys |
|
from datetime import datetime as dt |
|
from datetime import timedelta |
|
from datetime import timezone |
|
from faster_whisper import WhisperModel |
|
import io |
|
|
|
# Flask application object; every route below is registered on it.
app = Flask(__name__)

# Beam width handed to faster-whisper's transcribe() in /whisper_transcribe.
beamsize = 2

# Speech-to-text model, loaded once at import time; runs on CPU with int8 compute.
wmodel = WhisperModel(
    "guillaumekln/faster-whisper-small",
    device="cpu",
    compute_type="int8",
)

# Qdrant connection settings are read from the environment; either value is
# None when the corresponding variable is unset.
qdrant_api_key = os.getenv("qdrant_api_key")
qdrant_url = os.getenv("qdrant_url")

client = QdrantClient(
    url=qdrant_url,
    port=443,
    api_key=qdrant_api_key,
    prefer_grpc=False,
)

# Embedding model placement: GPU when torch can see one, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean-pool token vectors along dim 1, skipping masked-out positions.

    Args:
        last_hidden_states: per-token hidden states with the token axis at
            dim 1 (presumably (batch, seq, hidden) — TODO confirm with caller).
        attention_mask: 1 for positions to keep, 0 for positions to ignore;
            indexed like the first two dims of ``last_hidden_states``.

    Returns:
        Sum of the kept token vectors divided by the number of kept tokens.
    """
    keep = attention_mask.unsqueeze(-1).bool()
    # Zero the ignored positions so they contribute nothing to the sum.
    summed = last_hidden_states.masked_fill(~keep, 0.0).sum(dim=1)
    counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return summed / counts
|
|
|
# Hugging Face checkpoint shared by the tokenizer and the encoder below.
_E5_CHECKPOINT = 'intfloat/e5-base-v2'

tokenizer = AutoTokenizer.from_pretrained(_E5_CHECKPOINT)
# Encoder weights live on `device` so e5embed() can run its inputs there.
model = AutoModel.from_pretrained(_E5_CHECKPOINT).to(device)
|
|
|
def e5embed(query):
    """Embed text with the e5-base-v2 encoder and return it as a flat list.

    Args:
        query: text handed directly to the tokenizer (callers in this file
            pass a single string). Input is truncated to 512 tokens.

    Returns:
        A list of Python floats: the mean-pooled, L2-normalised embedding.
        Because of ``flatten()``, a list of strings would yield the per-item
        vectors concatenated into one list.
    """
    batch_dict = tokenizer(query, max_length=512, padding=True,
                           truncation=True, return_tensors='pt')
    batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
    # Inference only: disabling autograd avoids building a computation graph
    # (and keeping activations alive) on every request.
    with torch.no_grad():
        outputs = model(**batch_dict)
        embeddings = average_pool(outputs.last_hidden_state,
                                  batch_dict['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings.cpu().numpy().flatten().tolist()
|
|
|
def get_id(collection):
    """Return the next free integer point id for a Qdrant collection.

    Pages through every point in the collection (ids only — no payloads or
    vectors are fetched) and returns ``max(id) + 1``. An empty collection
    yields ``0`` instead of raising.

    Args:
        collection: name of the Qdrant collection.

    Returns:
        int: one more than the largest existing point id.

    Note:
        Assumes every point id is an integer (``int()`` is applied to each).
        NOTE(review): two concurrent callers can receive the same id.
    """
    max_id = -1
    offset = None
    while True:
        # qdrant-client's scroll() returns (records, next_page_offset); the
        # offset is None once the last page has been read.
        records, offset = client.scroll(
            collection_name=collection,
            limit=10000,
            offset=offset,
            with_payload=False,
            with_vectors=False,
        )
        for record in records:
            max_id = max(max_id, int(record.id))
        if offset is None:
            break
    return max_id + 1
|
|
|
@app.route("/")
def index():
    """Render the ``index.html`` template for the site root."""
    page = "index.html"
    return render_template(page)
|
|
|
# Fallback date for stored points that have no 'date' payload field.
_DEFAULT_DATE = '20200101'

# Query prefixes that route a search to the 'tils' collection, paired with the
# named vector to search there.
_TIL_QUERY_PREFIXES = (('tilc:', 'context'), ('til:', 'title'))


def _route_query(query):
    """Split a raw query into (collection name, vector name, text to embed).

    A 'tilc:' or 'til:' prefix (after leading whitespace) selects the 'tils'
    collection and its 'context' or 'title' vector. Only that leading prefix
    is removed, so the marker appearing later in the text is kept. Any other
    query targets 'jks' with vector name None and the text unchanged.
    """
    stripped = query.strip()
    for prefix, vector_name in _TIL_QUERY_PREFIXES:
        if stripped.startswith(prefix):
            return 'tils', vector_name, stripped[len(prefix):]
    return 'jks', None, query


def _format_joke(point):
    """Convert a 'jks' search hit into the dict returned to the client."""
    payload = point.payload
    return {
        "text": payload['text'],
        # int() normalises the stored value, e.g. 20240101 or '20240101' -> '20240101'.
        "date": str(int(payload.get('date', _DEFAULT_DATE))),
        "id": point.id,
    }


def _format_til(point):
    """Convert a 'tils' search hit into the dict returned to the client.

    A non-empty 'context' payload is appended to the title after
    '<br>Context: '.
    """
    payload = point.payload
    text = payload['title']
    context = payload.get('context')
    if context:
        text += '<br>Context: ' + context
    return {
        "text": text,
        "url": payload['url'],
        "date": payload.get('date', _DEFAULT_DATE),
        "id": point.id,
    }


@app.route("/search", methods=["POST"])
def search():
    """Embed the posted query and return the nearest stored items as JSON.

    Form fields:
        query: search text. A leading 'tilc:' or 'til:' searches the 'tils'
            collection by its 'context' or 'title' vector (top 100);
            anything else searches the 'jks' collection (top 200).

    The collection is chosen from the query prefix alone; a posted
    'collection' field is not consulted.

    Returns:
        JSON list of {"text", "date", "id"} dicts, plus "url" for 'tils' hits.
    """
    topN = 200
    query = request.form["query"]
    print('QUERY: ', query)

    collection_name, qvector, query = _route_query(query)

    timh = time.time()
    sq = e5embed(query)
    print('EMBEDDING TIME: ', time.time() - timh)

    timh = time.time()
    if collection_name == 'jks':
        results = client.search(collection_name=collection_name, query_vector=sq,
                                with_payload=True, limit=topN)
        formatter = _format_joke
    else:
        results = client.search(collection_name=collection_name, query_vector=(qvector, sq),
                                with_payload=True, limit=100)
        formatter = _format_til
    print('SEARCH TIME: ', time.time() - timh)

    return jsonify([formatter(r) for r in results])
|
|
|
|
|
|
|
# Leading markers removed from TIL titles, tried in this order.
_TIL_TITLE_PREFIXES = ('TIL that', 'TIL:', 'TIL ')
# URL schemes dropped from the front of a stored TIL url.
_URL_SCHEMES = ('https://', 'http://')


def _clean_til_title(title):
    """Remove leading 'TIL that' / 'TIL:' / 'TIL ' markers and outer whitespace.

    Only markers at the start are removed, so words that merely contain
    'TIL ' (e.g. 'UNTIL ') are left intact.
    """
    cleaned = title.strip()
    for prefix in _TIL_TITLE_PREFIXES:
        if cleaned.startswith(prefix):
            cleaned = cleaned[len(prefix):].lstrip()
    return cleaned.strip()


def _strip_url_scheme(url):
    """Trim whitespace, then drop a leading 'https://' or 'http://'.

    A scheme appearing later in the URL (e.g. inside a query string) is kept.
    """
    cleaned = url.strip()
    for scheme in _URL_SCHEMES:
        if cleaned.startswith(scheme):
            return cleaned[len(scheme):]
    return cleaned


@app.route("/add_item", methods=["POST"])
def add_item():
    """Store a posted item in Qdrant.

    Form fields:
        title: item text.
        url: blank -> the title is stored in 'jks' with a {'text', 'date'}
            payload (date as %Y%m%d). Non-blank -> stored in 'tils' with a
            {'title', 'url', 'date'} payload (date as %Y%m%d_%H%M) and a
            'title' named vector.

    Returns:
        JSON {"success": True, "index": <collection name>}.
    """
    title = request.form["title"]
    url = request.form["url"]

    if not url.strip():
        collection_name = 'jks'
        cid = get_id(collection_name)
        # Read the clock once so the log line and the stored date agree.
        today = time.strftime("%Y%m%d")
        print('cid', cid, today)
        resp = client.upsert(
            collection_name=collection_name,
            points=Batch(ids=[cid], payloads=[{'text': title, 'date': today}],
                         vectors=[e5embed(title)]),
        )
    else:
        collection_name = 'tils'
        cid = get_id(collection_name)
        print('cid', cid, time.strftime("%Y%m%d"), collection_name)
        til = {
            'title': _clean_til_title(title),
            'url': _strip_url_scheme(url),
            "date": time.strftime("%Y%m%d_%H%M"),
        }
        resp = client.upsert(
            collection_name=collection_name,
            points=[PointStruct(id=cid, payload=til,
                                vector={"title": e5embed(til['title'])})],
        )

    print('Upsert response:', resp)
    return jsonify({"success": True, "index": collection_name})
|
|
|
|
|
# Collections this app writes to; deletions are restricted to these.
_DELETABLE_COLLECTIONS = frozenset({'jks', 'tils'})


@app.route("/delete_joke", methods=["POST"])
def delete_joke():
    """Delete one point by id from the 'jks' or 'tils' collection.

    Form fields:
        id: integer point id.
        collection: 'jks' or 'tils'.

    Returns:
        JSON {"deleted": True} on success, or {"error": ...} with status 400
        when the id is not an integer or the collection is not one this app
        manages.
    """
    joke_id = request.form["id"]
    collection_name = request.form["collection"]

    try:
        point_id = int(joke_id)
    except ValueError:
        return jsonify({'error': 'Invalid id'}), 400

    # The collection name comes from the client; don't let it address
    # arbitrary collections on the Qdrant server.
    if collection_name not in _DELETABLE_COLLECTIONS:
        return jsonify({'error': 'Unknown collection'}), 400

    print('Deleting no.', joke_id, 'from collection', collection_name)
    client.delete(
        collection_name=collection_name,
        points_selector=models.PointIdsList(points=[point_id]),
    )
    return jsonify({"deleted": True})
|
|
|
# Upload extensions accepted by /whisper_transcribe (compared lower-case).
_ALLOWED_AUDIO_EXTENSIONS = frozenset({'mp3', 'wav', 'ogg', 'm4a'})


@app.route("/whisper_transcribe", methods=["POST"])
def whisper_transcribe():
    """Transcribe an uploaded audio file with the faster-whisper model.

    Expects a multipart upload in the 'audio' field whose filename ends in
    .mp3, .wav, .ogg or .m4a (case-insensitive).

    Returns:
        JSON {"transcription": text} on success, or {"error": ...} with status
        400 when no file was sent or the extension is not allowed.
    """
    if 'audio' not in request.files:
        return jsonify({'error': 'No file provided'}), 400

    audio_file = request.files['audio']
    # Guard against a missing filename before calling .lower() on it.
    filename = (audio_file.filename or '').lower()
    extension = filename.rsplit('.', 1)[-1]
    if not (audio_file and extension in _ALLOWED_AUDIO_EXTENSIONS):
        return jsonify({'error': 'Invalid file format'}), 400

    print('Transcribing audio')
    audio_stream = io.BytesIO(audio_file.read())

    # transcribe() does part of the work before returning (per faster-whisper,
    # e.g. audio decoding) and the rest lazily while `segments` is iterated,
    # so start the clock before the call to time the whole transcription.
    starttime = time.time()
    segments, _info = wmodel.transcribe(audio_stream, beam_size=beamsize)
    text = ''.join(segment.text for segment in segments)
    print('Time to transcribe:', time.time() - starttime, 'seconds')

    return jsonify({'transcription': text})
|
|
|
|
|
if __name__ == "__main__":
    # The Werkzeug interactive debugger can execute code sent from the browser,
    # so with host 0.0.0.0 it must not be on by default. Set FLASK_DEBUG=1 for
    # local development; PORT overrides the default port 7860.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    port = int(os.environ.get("PORT", "7860"))
    app.run(host="0.0.0.0", debug=debug_enabled, port=port)