"""Flask app for semantic search over Qdrant collections and Whisper transcription.

Routes:
    /                    render the index page
    /search              embed a query with E5 and search the 'jks' or 'tils' collection
    /add_item            embed and upsert a new joke ('jks') or TIL ('tils')
    /delete_joke         delete one point by id from a collection
    /whisper_transcribe  transcribe an uploaded audio file with faster-whisper
"""
import io
import os
import sys
import time
from datetime import datetime as dt
from datetime import timedelta
from datetime import timezone
from pickle import load, dump

import numpy as np
import torch
import torch.nn.functional as F
from faster_whisper import WhisperModel
from flask import Flask, render_template, request, jsonify
from qdrant_client import QdrantClient
from qdrant_client import models
from qdrant_client.models import Batch, PointStruct
from torch import Tensor
from transformers import AutoTokenizer, AutoModel

app = Flask(__name__)

# Collection names and search limits used by the routes below.
JOKES_COLLECTION = 'jks'
TILS_COLLECTION = 'tils'
JOKES_TOP_N = 200
TILS_TOP_N = 100
# Date reported for stored points that carry no 'date' payload field.
DEFAULT_DATE = '20200101'
# Page size used when scanning a collection for its highest point id.
SCROLL_PAGE_SIZE = 10000
# Query prefixes that route a search to the TIL collection, mapped to the
# named vector that should be searched.
TIL_QUERY_PREFIXES = (('tilc:', 'context'), ('til:', 'title'))
ALLOWED_AUDIO_EXTENSIONS = frozenset({'mp3', 'wav', 'ogg', 'm4a'})

# Faster Whisper setup
# model_size = 'small'
beamsize = 2
wmodel = WhisperModel("guillaumekln/faster-whisper-small", device="cpu", compute_type="int8")

# Initialize Qdrant Client and other required settings
qdrant_api_key = os.environ.get("qdrant_api_key")
qdrant_url = os.environ.get("qdrant_url")
client = QdrantClient(url=qdrant_url, port=443, api_key=qdrant_api_key, prefer_grpc=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Mean-pool token states over dim 1, ignoring positions masked out by *attention_mask*."""
    # Zero out masked positions so they do not contribute to the sum.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-base-v2')
model = AutoModel.from_pretrained('intfloat/e5-base-v2').to(device)


def e5embed(query):
    """Embed *query* with the E5 model and return the result as a flat list of floats.

    The text is tokenized (truncated to 512 tokens), run through the model,
    mean-pooled with the attention mask, L2-normalized and flattened.
    """
    batch_dict = tokenizer(query, max_length=512, padding=True, truncation=True, return_tensors='pt')
    batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**batch_dict)
        embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings.cpu().detach().numpy().flatten().tolist()


def get_id(collection: str) -> int:
    """Return the next free integer point id for *collection* (highest id + 1).

    Scans the whole collection page by page, so collections larger than one
    page are handled. Returns 0 when the collection is empty.

    NOTE(review): two concurrent callers can still receive the same id; the
    id is not reserved by this function.
    """
    max_id = -1
    offset = None
    while True:
        # Payloads and vectors are not needed to read the ids.
        records, offset = client.scroll(
            collection_name=collection,
            limit=SCROLL_PAGE_SIZE,
            offset=offset,
            with_payload=False,
            with_vectors=False,
        )
        max_id = max([max_id, *(r.id for r in records)])
        # scroll() returns no next-page offset once the last page is reached.
        if offset is None:
            break
    return int(max_id) + 1


@app.route("/")
def index():
    """Render the main page."""
    return render_template("index.html")


def _route_query(raw_query):
    """Return (collection_name, vector_name, query_text) for a raw search query.

    A leading 'tilc:' or 'til:' (after surrounding whitespace is stripped)
    selects the TIL collection and its 'context' or 'title' named vector;
    only that leading marker is removed from the query. Anything else goes to
    the jokes collection unchanged, with vector_name None.
    """
    stripped = raw_query.strip()
    for prefix, vector_name in TIL_QUERY_PREFIXES:
        if stripped.startswith(prefix):
            return TILS_COLLECTION, vector_name, stripped[len(prefix):]
    return JOKES_COLLECTION, None, raw_query


def _format_joke(point):
    """Build the JSON-ready dict for one search hit from the jokes collection."""
    payload = point.payload
    return {
        "text": payload['text'],
        "date": str(int(payload.get('date', DEFAULT_DATE))),
        "id": point.id,
    }


def _format_til(point):
    """Build the JSON-ready dict for one search hit from the TIL collection."""
    payload = point.payload
    text = payload['title']
    context = payload.get('context')
    # Missing, empty or None context: show the title alone.
    if context:
        text = text + '\nContext: ' + context
    return {
        "text": text,
        "url": payload['url'],
        "date": payload.get('date', DEFAULT_DATE),
        "id": point.id,
    }


@app.route("/search", methods=["POST"])
def search():
    """Embed the posted 'query' and return matching items as a JSON list.

    The target collection is derived from the query prefix (see _route_query);
    a posted 'collection' field is not consulted.
    """
    query = request.form["query"]
    print('QUERY: ', query)
    collection_name, vector_name, query = _route_query(query)

    timh = time.time()
    sq = e5embed(query)
    print('EMBEDDING TIME: ', time.time() - timh)

    timh = time.time()
    if vector_name is None:
        results = client.search(collection_name=collection_name, query_vector=sq,
                                with_payload=True, limit=JOKES_TOP_N)
        formatter = _format_joke
    else:
        # Named-vector search: (vector name, embedding).
        results = client.search(collection_name=collection_name, query_vector=(vector_name, sq),
                                with_payload=True, limit=TILS_TOP_N)
        formatter = _format_til
    print('SEARCH TIME: ', time.time() - timh)

    return jsonify([formatter(r) for r in results])


@app.route("/add_item", methods=["POST"])
def add_item():
    """Store a new item: a joke when 'url' is blank, otherwise a TIL entry.

    Returns JSON with 'success' and the name of the collection written to.
    """
    title = request.form["title"]
    url = request.form["url"]
    today = time.strftime("%Y%m%d")
    if url.strip() == '':
        collection_name = JOKES_COLLECTION
        cid = get_id(collection_name)
        print('cid', cid, today)
        resp = client.upsert(
            collection_name=collection_name,
            points=Batch(ids=[cid], payloads=[{'text': title, 'date': today}],
                         vectors=[e5embed(title)]),
        )
    else:
        collection_name = TILS_COLLECTION
        cid = get_id(collection_name)
        print('cid', cid, today, collection_name)
        til = {
            'title': title.replace('TIL that', '').replace('TIL:', '').replace('TIL ', '').strip(),
            'url': url.replace('https://', '').replace('http://', ''),
            "date": time.strftime("%Y%m%d_%H%M"),
        }
        # Only the 'title' named vector is written for new TIL entries.
        resp = client.upsert(
            collection_name=collection_name,
            points=[PointStruct(id=cid, payload=til, vector={"title": e5embed(til['title'])})],
        )
    print('Upsert response:', resp)
    return jsonify({"success": True, "index": collection_name})


@app.route("/delete_joke", methods=["POST"])
def delete_joke():
    """Delete the point with the posted integer 'id' from the posted 'collection'.

    Responds with HTTP 400 and a JSON error when 'id' is not an integer.
    """
    joke_id = request.form["id"]
    collection_name = request.form["collection"]
    try:
        point_id = int(joke_id)
    except ValueError:
        return jsonify({"deleted": False, "error": "Invalid id"}), 400
    # NOTE(review): collection_name comes straight from the client; consider
    # restricting it to the known collections.
    print('Deleting no.', joke_id, 'from collection', collection_name)
    client.delete(collection_name=collection_name,
                  points_selector=models.PointIdsList(points=[point_id]))
    return jsonify({"deleted": True})


@app.route("/whisper_transcribe", methods=["POST"])
def whisper_transcribe():
    """Transcribe the uploaded 'audio' file and return {'transcription': text}.

    Responds with HTTP 400 when no file is posted or when its filename does
    not end in one of ALLOWED_AUDIO_EXTENSIONS.
    """
    if 'audio' not in request.files:
        return jsonify({'error': 'No file provided'}), 400
    audio_file = request.files['audio']
    # Require a real extension: a bare name such as "mp3" has none.
    _, dot, extension = (audio_file.filename or '').lower().rpartition('.')
    if not dot or extension not in ALLOWED_AUDIO_EXTENSIONS:
        return jsonify({'error': 'Invalid file format'}), 400

    print('Transcribing audio')
    audio_buffer = io.BytesIO(audio_file.read())
    starttime = time.time()
    segments, info = wmodel.transcribe(audio_buffer, beam_size=beamsize)
    # Joining consumes the segments iterable returned by transcribe().
    text = ''.join(segment.text for segment in segments)
    print('Time to transcribe:', time.time() - starttime, 'seconds')
    return jsonify({'transcription': text})


if __name__ == "__main__":
    # The interactive debugger allows arbitrary code execution, so it must not
    # be on by default while listening on all interfaces. Opt in with
    # FLASK_DEBUG=1 for local development.
    debug_enabled = os.environ.get("FLASK_DEBUG", "0") == "1"
    app.run(host="0.0.0.0", debug=debug_enabled, port=7860)