# utilities / app.py
# Reggie's picture
# Update app.py
# f37b81f verified
from flask import Flask, render_template, request, jsonify
from qdrant_client import QdrantClient
from qdrant_client import models
import torch.nn.functional as F
import torch
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from qdrant_client.models import Batch, PointStruct
from pickle import load, dump
import numpy as np
import os, time, sys
from datetime import datetime as dt
from datetime import timedelta
from datetime import timezone
from faster_whisper import WhisperModel
import io
# Flask application object; the route decorators further down register their
# handlers on it.
app = Flask(__name__)
# Faster Whisper setup
# model_size = 'small'
# Beam width passed to wmodel.transcribe() by the /whisper_transcribe handler.
beamsize = 2
# Speech-to-text model, created once at import time (CPU device, int8 compute).
wmodel = WhisperModel("guillaumekln/faster-whisper-small", device="cpu", compute_type="int8")
# Initialize Qdrant Client and other required settings
# The API key and URL are read from the environment; os.environ.get returns
# None for a variable that is not set.
qdrant_api_key = os.environ.get("qdrant_api_key")
qdrant_url = os.environ.get("qdrant_url")
client = QdrantClient(url=qdrant_url, port=443, api_key=qdrant_api_key, prefer_grpc=False)
# Torch device for the embedding model: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean-pool token vectors along dim 1, skipping masked-out positions.

    Positions whose ``attention_mask`` entry is zero are replaced by 0.0
    before summing, and each row's sum is divided by that row's mask total.
    """
    keep = attention_mask.unsqueeze(-1).bool()
    zeroed = last_hidden_states.masked_fill(~keep, 0.0)
    summed = zeroed.sum(dim=1)
    token_counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return summed / token_counts
# Tokenizer and encoder for the 'intfloat/e5-base-v2' checkpoint, loaded once
# at import time. The model is moved to `device` so that e5embed() can feed it
# tensors placed on the same device.
tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-base-v2')
model = AutoModel.from_pretrained('intfloat/e5-base-v2').to(device)
def e5embed(query):
    """Embed *query* with the module-level e5 encoder.

    The text is tokenized (truncated to 512 tokens), passed through the
    model, mean-pooled with ``average_pool`` and L2-normalised.

    Args:
        query: Text to embed. Callers in this file pass a single string.
            NOTE(review): if a list of strings were passed, ``flatten()``
            would concatenate all vectors into one list - confirm no caller
            relies on that.

    Returns:
        The normalised embedding as a flat list of Python floats.
    """
    batch_dict = tokenizer(query, max_length=512, padding=True, truncation=True, return_tensors='pt')
    batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
    # Inference only: disable autograd so no graph is built (and kept alive)
    # for a forward pass whose gradients are never used.
    with torch.no_grad():
        outputs = model(**batch_dict)
        embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings.cpu().detach().numpy().flatten().tolist()
def get_id(collection):
    """Return the next free integer point ID for *collection*.

    Scrolls through every point in the collection (page by page) and returns
    the largest ID plus one. An empty collection yields 0 instead of raising.

    Args:
        collection: Name of the Qdrant collection to inspect.

    Returns:
        int: one more than the highest existing point ID, or 0 when the
        collection holds no points.
    """
    highest = -1
    offset = None
    while True:
        # Only IDs are needed, so skip payloads and vectors. scroll() returns
        # the page of records plus the offset of the next page (None at the end).
        points, offset = client.scroll(
            collection_name=collection,
            limit=10000,
            offset=offset,
            with_payload=False,
            with_vectors=False,
        )
        highest = max([highest, *(point.id for point in points)])
        if offset is None:
            break
    return int(highest) + 1
@app.route("/")
def index():
    """Serve the landing page by rendering the ``index.html`` template."""
    landing_page = render_template("index.html")
    return landing_page
def _split_search_prefix(query):
    """Map a raw query to ``(collection_name, vector_name, text)``.

    A leading ``tilc:`` selects the 'tils' collection searched on its
    "context" vector, a leading ``til:`` selects 'tils' on its "title" vector,
    and anything else goes to 'jks' (unnamed vector, ``vector_name`` is None).
    Only the leading prefix is removed; later occurrences stay in the text.
    """
    stripped = query.strip()
    for prefix, vector_name in (('tilc:', 'context'), ('til:', 'title')):
        if stripped.startswith(prefix):
            return 'tils', vector_name, stripped[len(prefix):].strip()
    return 'jks', None, query


def _format_hit(hit, collection_name):
    """Turn one Qdrant search hit into the JSON-serialisable dict sent to the client."""
    payload = hit.payload
    # Points stored without a date fall back to a fixed placeholder.
    date = payload.get('date', '20200101')
    if collection_name == 'jks':
        return {"text": payload['text'], "date": str(int(date)), "id": hit.id}
    text = payload['title']
    # Truthiness check covers a missing, empty or None context alike.
    context = payload.get('context')
    if context:
        # NOTE(review): this builds an HTML fragment from stored text without
        # escaping - confirm how the front end renders "text".
        text = text + '<br>Context: ' + context
    return {"text": text, "url": payload['url'], "date": date, "id": hit.id}


@app.route("/search", methods=["POST"])
def search():
    """Embed the submitted query and return the nearest stored items as JSON.

    Form fields:
        query: The search text, optionally prefixed with ``til:`` or ``tilc:``
            to search the 'tils' collection by title or by context vector.

    The target collection is derived from the query prefix alone. The
    ``collection`` form field used to be read but was always overwritten,
    so it is no longer required.

    Returns:
        A JSON list of result dicts: up to 200 for 'jks', up to 100 for 'tils'.
    """
    query = request.form["query"]
    print('QUERY: ', query)
    collection_name, qvector, query = _split_search_prefix(query)

    timh = time.time()
    sq = e5embed(query)
    print('EMBEDDING TIME: ', time.time() - timh)

    timh = time.time()
    if qvector is None:
        results = client.search(collection_name=collection_name, query_vector=sq, with_payload=True, limit=200)
    else:
        # 'tils' stores named vectors, so the vector name accompanies the query.
        results = client.search(collection_name=collection_name, query_vector=(qvector, sq), with_payload=True, limit=100)
    print('SEARCH TIME: ', time.time() - timh)

    return jsonify([_format_hit(r, collection_name) for r in results])
def _strip_first_prefix(text, prefixes, ignore_case=False):
    """Return *text* without the first of *prefixes* it starts with.

    The text is returned unchanged when no prefix matches. With
    ``ignore_case=True`` the prefixes must be given in lower case.
    """
    probe = text.lower() if ignore_case else text
    for prefix in prefixes:
        if probe.startswith(prefix):
            return text[len(prefix):]
    return text


@app.route("/add_item", methods=["POST"])
def add_item():
    """Store a new item: a joke in 'jks' when no URL is given, else a TIL in 'tils'.

    Form fields:
        title: The joke text or the TIL headline.
        url: Source URL; blank (or whitespace only) selects the 'jks' branch.

    Returns:
        JSON ``{"success": True, "index": <collection name>}``.
    """
    title = request.form["title"]
    url = request.form["url"].strip()
    if not url:
        collection_name = 'jks'
        cid = get_id(collection_name)
        # Compute the date once so the log line and the payload always agree.
        today = time.strftime("%Y%m%d")
        print('cid', cid, today)
        resp = client.upsert(
            collection_name=collection_name,
            points=Batch(ids=[cid], payloads=[{'text': title, 'date': today}], vectors=[e5embed(title)]),
        )
    else:
        collection_name = 'tils'
        cid = get_id(collection_name)
        print('cid', cid, time.strftime("%Y%m%d"), collection_name)
        til = {
            # Drop only a leading "TIL ..." marker; the same letters later in
            # the headline (e.g. "UNTIL ") are left intact.
            'title': _strip_first_prefix(title.strip(), ('TIL that', 'TIL:', 'TIL ')).strip(),
            # Drop only the leading scheme, so URLs that embed another URL
            # are stored unmangled.
            'url': _strip_first_prefix(url, ('https://', 'http://'), ignore_case=True),
            "date": time.strftime("%Y%m%d_%H%M"),
        }
        resp = client.upsert(
            collection_name=collection_name,
            points=[PointStruct(id=cid, payload=til, vector={"title": e5embed(til['title'])})],
        )
    print('Upsert response:', resp)
    return jsonify({"success": True, "index": collection_name})
@app.route("/delete_joke", methods=["POST"])
def delete_joke():
    """Delete one point, identified by integer ID, from the given collection.

    Form fields:
        id: Integer ID of the point to delete.
        collection: Name of the Qdrant collection holding the point.

    Returns:
        JSON ``{"deleted": True}`` on success, or a JSON error with HTTP 400
        when ``id`` is not an integer.
    """
    joke_id = request.form["id"]
    collection_name = request.form["collection"]
    try:
        point_id = int(joke_id)
    except ValueError:
        # Reject malformed input explicitly instead of failing with a 500.
        return jsonify({'error': 'Invalid id'}), 400
    print('Deleting no.', joke_id, 'from collection', collection_name)
    client.delete(
        collection_name=collection_name,
        points_selector=models.PointIdsList(points=[point_id]),
    )
    return jsonify({"deleted": True})
@app.route("/whisper_transcribe", methods=["POST"])
def whisper_transcribe():
    """Transcribe an uploaded audio file with the module-level Whisper model.

    Expects a multipart upload under the ``audio`` key whose filename ends in
    ``.mp3``, ``.wav``, ``.ogg`` or ``.m4a`` (case-insensitive).

    Returns:
        JSON ``{"transcription": <text>}`` on success, or a JSON error with
        HTTP 400 when the file is missing or has an unsupported extension.
    """
    if 'audio' not in request.files:
        return jsonify({'error': 'No file provided'}), 400
    audio_file = request.files['audio']
    allowed_extensions = {'mp3', 'wav', 'ogg', 'm4a'}
    # Require a real "name.ext" form: a bare filename such as "mp3" has no
    # extension and must not pass the check.
    _, dot, extension = (audio_file.filename or '').lower().rpartition('.')
    if not (audio_file and dot and extension in allowed_extensions):
        return jsonify({'error': 'Invalid file format'}), 400
    print('Transcribing audio')
    audio_stream = io.BytesIO(audio_file.read())
    # Start timing before transcribe() so the call itself is included.
    starttime = time.time()
    segments, info = wmodel.transcribe(audio_stream, beam_size=beamsize)  # beamsize is 2.
    text = ''.join(segment.text for segment in segments)
    print('Time to transcribe:', time.time() - starttime, 'seconds')
    return jsonify({'transcription': text})
if __name__ == "__main__":
    # The server listens on every interface, so the interactive debugger must
    # not be on by default: it lets anyone who can reach the port run code.
    # Opt in for local development with FLASK_DEBUG=1.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true")
    app.run(host="0.0.0.0", debug=debug_enabled, port=7860)