# UncertaintAI / app.py
import asyncio
import aiohttp
import urllib.parse
import time
import re
import json
from typing import List, Dict, Any
import os
import numpy as np
import random
import tqdm
import pandas as pd
from sentencex import segment
from setfit import SetFitModel
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired, PartOfSpeech
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import CountVectorizer
from umap import UMAP
from sklearn.cluster import KMeans
from hdbscan import HDBSCAN
from huggingface_hub import InferenceClient
from bertopic.vectorizers import ClassTfidfTransformer
from pydantic import BaseModel
from sentence_transformers import CrossEncoder
from fastapi import FastAPI, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from typing import AsyncGenerator
from dataclasses import dataclass
from collections import defaultdict
import uuid
app = FastAPI(title="Scientific Literature Topic Analyzer")
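
# Runtime configuration: concurrency cap for OpenAlex requests, the Hugging Face
# Inference API token (read from the HF_SECRET environment variable), and the
# contact email sent with OpenAlex requests.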
class Config:
CONCURRENCY_LIMIT = 9
HF_API_KEY = os.environ['HF_SECRET']
OPENALEX_EMAIL = "[email protected]"
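
# SetFit-based rhetorical sentence classifier. `mapping` groups the model's
# fine-grained label ids into coarse rhetorical classes (Hypothesis, Method,
# Uncertainty, ...) used for filtering.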
class SentenceClassifier:
def __init__(self, model_path: str = "Corran/SciGenSetfit24"):
self.model = SetFitModel.from_pretrained(model_path)
self.mapping={'Hypothesis':[42,84,85,88],
'Aims':[68,79],
'Purpose':[33,34,35,67,76,83,86],
'Keywords':[26],
'Importance':[21,22,23,24,25,29,30],
'Background':[31,32,58,59,60,61,62,63,64,65,66,69,70,71,73,74,75,87],
'Limitations':[0,1,10,20,36,45,49],
'Method':[7,8,9,11,12,13,14,15,16,17,18,19,27,28,44,47,48,94,95,96,97,98,99],
'Uncertainty':[39,40,46,52,100],
'Result':[2,3,4,5,6,37,38,43,53,56,72,77,78,80,81,82,92,93,101,102],
            'Recommendations':[41,50,90],
'Implications':[51,89],
'Other':[54,55,57,91]}
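    # Keep only sentences whose predicted label falls under the requested coarse
    # classes; currently collapsed to the binary "Uncertainty" label below.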
def filter_for_class(self, target_classes):
targets = []
for target_class in target_classes:
targets.extend(self.mapping[target_class])
target_st = [self.model.id2label[i] for i in targets]
#Temp solution for binary class
target_st = ["Uncertainty"]
topic_sentence_indices = self.sentences["prediction"].isin(target_st)
class_sentences = self.sentences[topic_sentence_indices]
self.class_sentences = class_sentences
def _predict(self, sentences):
predictions = self.model.predict(sentences)
return predictions
def classify(self, corpus: Any, batch_size: int = 500):
sentences = pd.DataFrame(
[
{"article_id": article['id'], "sentence_index": si, "sentence": s}
for article_index, article in corpus.iterrows()
if len(article["abstract_sentences"]) > 0
for si, s in enumerate(article["abstract_sentences"])
]
)
sentences['batch'] = sentences.index // batch_size
predictions = []
for batch, group in sentences.groupby('batch'):
batch_sentences = group['sentence'].tolist()
batch_predictions = self._predict(batch_sentences)
predictions.extend(batch_predictions)
sentences['prediction'] = predictions
self.sentences = sentences.reset_index(drop=True)
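
# BERTopic pipeline: MiniLM sentence embeddings -> UMAP dimensionality reduction
# -> HDBSCAN clustering, with a seeded class-TF-IDF and a part-of-speech based
# topic representation.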
class TopicAnalyzer:
def __init__(self, seed_words, sentence_model = "all-MiniLM-L6-v2"):
self.umap_model = UMAP(n_neighbors=3, n_components=5, min_dist=0.0, metric='cosine')
self.vectorizer_model = CountVectorizer(
ngram_range=(2, 5)
)
        # self.representation_model = KeyBERTInspired(top_n_words=10, nr_repr_docs=5)
        self.representation_model = PartOfSpeech("en_core_web_sm")
self.hdbscan_model = HDBSCAN(
min_cluster_size=3,max_cluster_size=30, metric="euclidean", prediction_data=True
)
self.ctfidf_model = ClassTfidfTransformer(
seed_words=seed_words,
seed_multiplier=2
)
#self.hdbscan_model = KMeans(n_clusters=10)
self.sentence_model = SentenceTransformer(sentence_model)
def predict_topics(self, class_sentences):
retries=0
while retries<3:
try:
topic_model = BERTopic(
nr_topics=30,
min_topic_size=2,
top_n_words=15,
umap_model = self.umap_model,
embedding_model= self.sentence_model,
vectorizer_model=self.vectorizer_model,
hdbscan_model=self.hdbscan_model,
representation_model=self.representation_model,
ctfidf_model=self.ctfidf_model,
calculate_probabilities=True
)
topics, probs = topic_model.fit_transform(
class_sentences['sentence'].values
)
return topic_model, topics, probs
            except Exception:
                # Count the failure and retry fitting; BERTopic/HDBSCAN can fail on
                # degenerate inputs. Give up after three attempts.
                retries += 1
        return None, None, None
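
# Thin async client for the OpenAlex /works endpoint: builds paged search URLs,
# fetches them concurrently under a semaphore, and returns the per-page meta
# blocks together with the flattened list of article records.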
class OpenAlexSearchClient:
@staticmethod
    async def search(term: str, num_pages: int) -> tuple:
term = urllib.parse.quote(term)
pages = [
f"https://api.openalex.org/works?per-page=50&page={i}&mailto={Config.OPENALEX_EMAIL}&select=id,title,publication_year,abstract_inverted_index,doi,authorships,locations&search={term}&filter=primary_location.source.type:journal"
for i in range(1, num_pages)
]
start_time = time.perf_counter()
meta, results = await OpenAlexSearchClient._download_all_pages(pages)
duration = time.perf_counter() - start_time
return meta,results
@staticmethod
    async def _download_all_pages(pages: List[str]) -> tuple:
semaphore = asyncio.Semaphore(Config.CONCURRENCY_LIMIT)
async with aiohttp.ClientSession() as session:
tasks = [
OpenAlexSearchClient._download_site(url, session, semaphore)
for url in pages
]
results = []
meta = []
for i in range(0, len(tasks), Config.CONCURRENCY_LIMIT):
chunk = tasks[i : i + Config.CONCURRENCY_LIMIT]
                response = await asyncio.gather(*chunk, return_exceptions=True)
                # Drop failed requests; collect article records and the per-page meta blocks.
                pages_ok = [r for r in response if isinstance(r, dict)]
                results.extend([article for r in pages_ok for article in r.get('results', [])])
                meta.extend([r.get('meta', {}) for r in pages_ok])
if i + Config.CONCURRENCY_LIMIT < len(tasks):
await asyncio.sleep(1.2)
return meta,results
@staticmethod
async def _download_site(
url: str, session: aiohttp.ClientSession, semaphore: asyncio.Semaphore
) -> Dict:
async with semaphore:
async with session.get(url) as response:
return await response.json()
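
# Turns raw OpenAlex records into a tidy DataFrame: reconstructs each abstract
# from its inverted index, splits it into sentences with sentencex, and keeps
# title, authors, journal and publication year alongside it.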
class AbstractData:
def __init__(self, results):
self.abstracts = results
corpus_data = []
for article in results:
abstract = self._process_abstract(article["abstract_inverted_index"])
authorships = article["authorships"]
authors = [
{
"display_name": item['author']["display_name"],
"orcid": item['author'].get(
"orcid", None
), # Use .get() to handle cases where 'orcid' might be missing
}
for item in authorships
]
            # Journal info: 'locations', its first entry, or its 'source' block may be
            # missing or None in the OpenAlex record, so fall back to empty values.
            journal = {"display_name": None, "url": None}
            locations = article.get("locations") or []
            if isinstance(locations, list) and len(locations) > 0:
                source = locations[0].get("source") or {}
                journal = {
                    "display_name": source.get("display_name"),
                    "url": locations[0].get("landing_page_url"),
                }
article_data = {
"id": article['id'],
"title": article["title"],
"authors": authors,
"abstract": abstract,
"abstract_sentences": list(segment(text=abstract,language="en")),
"publication_year": article["publication_year"],
"journal": journal,
}
corpus_data.append(article_data)
        self.data = pd.DataFrame(corpus_data)
@staticmethod
def _process_abstract(inverted_index: Dict) -> str:
if inverted_index is None:
return ""
        # Rebuild the abstract by placing each word at its recorded position(s)
        # in the inverted index, then joining the words in position order.
        position_to_word = {
            position: word
            for word, positions in inverted_index.items()
            for position in positions
        }
        return " ".join(position_to_word[i] for i in sorted(position_to_word))
class TopicData:
def __init__(self, class_sentences, topic_model, topics, probs, target_classes):
topic_distr, _ = topic_model.approximate_distribution(class_sentences["sentence"].values, batch_size=1000)
data = topic_model.get_document_info(class_sentences["sentence"].values).values
topics = [i[2] for i in data]
class_sentences["topic"] = topics
if isinstance(probs, list):
class_sentences["probs"] = [max(p) for p in probs]
self.topic_model = topic_model
self.class_sentences = class_sentences
self.topics = topics
self.target_classes = target_classes
def humanize_topics(self, api_key: str = Config.HF_API_KEY):
self.client = InferenceClient(api_key=api_key)
topics = self.class_sentences["topic"].values
topics = list(set(topics))
target_classes = " or ".join(self.target_classes)
        topic_map = {}
        i, retries = 0, 0
        while i < len(topics):
            topic = topics[i]
            # BERTopic marks outlier sentences with topic "-1_..."; label them "ood".
            if topic[:2] == "-1":
                topic_map[topic] = "ood"
                i += 1
                continue
            topic_str = "_".join(topic.split("_")[1:])
            if retries < 3:
                response = self._call_llm(target_classes, topic_str)
                try:
                    # Pull the first {...} block out of the LLM reply and parse it as JSON.
                    struct = "{" + response.split("{")[1].split("}")[0] + "}"
                    result = json.loads(struct)
                    topic_map[topic] = list(result.items())[0][1]
                except Exception:
                    retries += 1
                    continue
            else:
                # After three failed LLM attempts, fall back to the raw topic label.
                topic_map[topic] = topic
            retries = 0
            i += 1
        self.class_sentences["humanized_topic"] = self.class_sentences["topic"].map(topic_map)
def parse_for_display(self, corpus, search_term):
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
class_sentences = self.class_sentences
target_classes = " or ".join(self.target_classes)
query = f"Where do the key {target_classes} in {search_term} lie?"
humanized_topics = [str(i) for i in set(class_sentences['humanized_topic'].values)]
ranks = model.rank(query, humanized_topics)
ix = [rank['corpus_id'] for rank in ranks]
display_topics = []
for i in ix:
topic = humanized_topics[i]
if str(topic) in ["ood","nan"] or len(topic)<4:
continue
filtered_sentences = class_sentences[class_sentences['humanized_topic'] == topic]
article_ids = filtered_sentences['article_id'].values
articles = corpus.data[corpus.data['id'].isin(article_ids)].copy()
articles = articles.drop('abstract_sentences', axis=1)
articles = [a.to_dict() for _, a in articles.iterrows()]
id_to_sentence = dict(zip(filtered_sentences['article_id'], filtered_sentences['sentence']))
for article in articles:
article_id = article['id']
article['topic_sentence'] = id_to_sentence.get(article_id, None)
tr = {'topic': topic, 'articles': articles}
display_topics.append(tr)
return display_topics
def _call_llm(self,target_classes,topic):
messages = [{
"role": "user",
"content": """**disable conversational mode** **disable chain of thought reasoning**. You are simply a text preprocessor which outputs raw json. Ensure the output is valid JSON as it will be parsed
using `json.loads()` in Python.
It should be in the schema:
<output>
{
"""+target_classes+""": <Human Readable Theme/>
}"""+f"""</output>
. Summarize the one key {target_classes} theme in studies given these keywords related to the {target_classes}:{topic}"""+"""
Be concise and accurate, ensuring you intelligently include all relevant keywords within it. The uncertainty theme must be listed in shorthand note form, and do not extrapolate. Please reply with a dict with the key and then the human readable topic notes as the item.""",
}]
completion = self.client.chat.completions.create(
model="microsoft/Phi-3.5-mini-instruct",
messages=messages,
temperature=0.4,
seed=random.randint(3, 10*10*99*34851),
max_tokens=90,
)
return completion.choices[0].message.content
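
# Convenience wrapper tying the pieces together: fit topics on the filtered
# sentences and return a TopicData object ready for humanization and display.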
def extract_topics(class_sentences, target_classes,seed_words,sentence_model="all-MiniLM-L6-v2"):
# Perform search and topic extraction
topic_analyzer = TopicAnalyzer(sentence_model=sentence_model,seed_words=seed_words)
# Extract topics
topic_model, topics, probs = topic_analyzer.predict_topics(class_sentences)
topic_data = TopicData(class_sentences, topic_model, topics, probs,target_classes)
return topic_data
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
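
# Progress is tracked in an in-memory, per-process dict keyed by job id. This is
# fine for a single-worker Space; a multi-worker deployment would need a shared
# store (e.g. Redis) instead.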
@dataclass
class ProgressState:
"""Stores progress information for a specific job"""
status: str = "pending"
current_stage: str = ""
progress: float = 0
message: str = ""
result: dict = None
# Store progress for different jobs
progress_stores = defaultdict(ProgressState)
async def update_progress(job_id: str, stage: str, progress: float, message: str):
"""Helper function to update progress and allow for yielding control"""
progress_state = progress_stores[job_id]
progress_state.current_stage = stage
progress_state.progress = progress
progress_state.message = message
# Give control back to event loop to allow progress to be sent
await asyncio.sleep(0.1)
async def process_with_progress(job_id: str, search_term: str, target_classes: list):
"""Main processing function with progress updates"""
progress = progress_stores[job_id]
try:
# Initial state
await update_progress(job_id, "initializing", 0.0, "Starting search process...")
# Search stage
await update_progress(job_id, "searching", 0.1, "Connecting to OpenAlex database...")
search_client = OpenAlexSearchClient()
await update_progress(job_id, "searching", 0.15, "Searching for relevant papers...")
meta, results = await search_client.search(search_term, num_pages=20)
await update_progress(job_id, "searching", 0.2, "Search complete")
# Corpus processing stage
await update_progress(job_id, "processing_corpus", 0.3, "Processing corpus data...")
corpus = AbstractData(results)
await update_progress(job_id, "processing_corpus", 0.4, "Corpus processing complete")
# Classification stage
await update_progress(job_id, "classifying", 0.5, "Initializing sentence classifier...")
sentence_classifier = SentenceClassifier()
await update_progress(job_id, "classifying", 0.55, "Classifying sentences...")
sentence_classifier.classify(corpus.data)
await update_progress(job_id, "classifying", 0.6, "Classification complete")
# Filtering stage
await update_progress(job_id, "filtering", 0.7, "Filtering for target classes...")
sentence_classifier.filter_for_class(target_classes=target_classes)
class_sentences = sentence_classifier.class_sentences
await update_progress(job_id, "filtering", 0.8, "Filtering complete")
# Topic extraction stage
await update_progress(job_id, "extracting_topics", 0.85, "Preparing topic extraction...")
seed_words = search_term.split()
await update_progress(job_id, "extracting_topics", 0.9, "Extracting topics...")
topic_data = extract_topics(
class_sentences=class_sentences,
seed_words=seed_words,
target_classes=target_classes,
sentence_model='Corran/SciGenSetfit24Binary'
)
await update_progress(job_id, "extracting_topics", 0.95, "Processing topics...")
topic_data.humanize_topics()
display_topics = topic_data.parse_for_display(corpus=corpus, search_term=search_term)
# Complete
progress.status = "complete"
progress.progress = 1.0
progress.message = "Processing complete"
progress.result = display_topics
await asyncio.sleep(0.1) # Final yield to ensure completion message is sent
except Exception as e:
await update_progress(job_id, "error", 0, f"Error occurred: {str(e)}")
progress.status = "failed"
await asyncio.sleep(0.1) # Ensure error message is sent
raise e
async def progress_generator(job_id: str) -> AsyncGenerator[str, None]:
"""Generates SSE events for progress updates"""
try:
while True:
progress = progress_stores[job_id]
# Send current progress
data = {
"status": progress.status,
"stage": progress.current_stage,
"progress": progress.progress,
"message": progress.message
}
yield f"data: {json.dumps(data)}\n\n"
# If process is complete or failed, send final message and end stream
if progress.status in ["complete", "failed"]:
if progress.status == "complete" and progress.result:
yield f"data: {json.dumps({'status': 'complete', 'result': progress.result})}\n\n"
break
await asyncio.sleep(0.5) # Check for updates every 500ms
finally:
# Clean up progress store after streaming ends
if job_id in progress_stores:
del progress_stores[job_id]
@app.get("/get_topics")
async def get_topics(
background_tasks: BackgroundTasks,
search_term: str
):
"""Endpoint to initiate topic processing with progress tracking"""
job_id = str(uuid.uuid4())
target_classes = ["Uncertainty"]
# Start processing in background
background_tasks.add_task(
process_with_progress,
job_id=job_id,
search_term=search_term,
target_classes=target_classes
)
return {"job_id": job_id}
@app.get("/progress/{job_id}")
async def progress(job_id: str):
"""SSE endpoint to stream progress updates"""
return StreamingResponse(
progress_generator(job_id),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
}
)
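
# Example usage (assuming the app is served locally, e.g. via
# `uvicorn app:app --host 0.0.0.0 --port 7860`; host and port are just examples):
#
#   1. Start a job:
#        curl "http://localhost:7860/get_topics?search_term=ocean%20acidification"
#      -> {"job_id": "..."}
#
#   2. Stream progress and the final result as server-sent events:
#        curl -N "http://localhost:7860/progress/<job_id>"
#      Each SSE `data:` line carries {"status", "stage", "progress", "message"},
#      and the final event for a completed job includes the "result" payload.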