# UncertaintAI app.py - Scientific Literature Topic Analyzer
import asyncio
import json
import os
import random
import time
import urllib.parse
import uuid
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple

import aiohttp
import pandas as pd
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired, PartOfSpeech  # PartOfSpeech kept for the commented-out alternative below
from bertopic.vectorizers import ClassTfidfTransformer
from fastapi import BackgroundTasks, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from hdbscan import HDBSCAN
from huggingface_hub import InferenceClient
from sentence_transformers import CrossEncoder, SentenceTransformer
from sentencex import segment
from setfit import SetFitModel
from sklearn.cluster import KMeans  # kept for the commented-out KMeans alternative below
from sklearn.feature_extraction.text import CountVectorizer
from umap import UMAP

app = FastAPI(title="Scientific Literature Topic Analyzer")

class Config:
    CONCURRENCY_LIMIT = 9
    HF_API_KEY = os.environ["HF_SECRET"]  # read at import time; fails fast if the HF_SECRET env var is missing
    OPENALEX_EMAIL = "[email protected]"

class SentenceClassifier:
    def __init__(self, model_path: str = "Corran/SciGenAllMiniLMSetFit"):
        self.model = SetFitModel.from_pretrained(model_path)
        # Coarse rhetorical classes mapped to the fine-grained label ids below
        self.mapping = {
            'Hypothesis': [42, 84, 85, 88],
            'Aims': [68, 79],
            'Purpose': [33, 34, 35, 67, 76, 83, 86],
            'Keywords': [26],
            'Importance': [21, 22, 23, 24, 25, 29, 30],
            'Background': [31, 32, 58, 59, 60, 61, 62, 63, 64, 65, 66, 69, 70, 71, 73, 74, 75, 87],
            'Limitations': [0, 1, 10, 20, 36, 45, 49],
            'Method': [7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 27, 28, 44, 47, 48, 94, 95, 96, 97, 98, 99],
            'Uncertainty': [39, 40, 46, 52, 100],
            'Result': [2, 3, 4, 5, 6, 37, 38, 43, 53, 56, 72, 77, 78, 80, 81, 82, 92, 93, 101, 102],
            'Reccomendations': [41, 50, 90],  # spelling kept as-is: callers may rely on this key
            'Implications': [51, 89],
            'Other': [54, 55, 57, 91],
        }
        # Fine-grained SetFit label strings. These must match the model's label
        # set verbatim (including the original typo in label 25), because
        # filter_for_class compares them against raw model predictions.
        self.id2label = {
            0: 'Acknowledging limitation(s) whilst stating a finding or contribution',
            1: 'Advising cautious interpretation of the findings',
            2: 'Commenting on the findings',
            3: 'Commenting on the strengths of the current study',
            4: 'Comparing the result: contradicting previous findings',
            5: 'Comparing the result: supporting previous findings',
            6: 'Contrasting sources with ‘however’ for emphasis',
            7: 'Describing previously used methods',
            8: 'Describing questionnaire design',
            9: 'Describing the characteristics of the participants',
            10: 'Describing the limitations of the current study',
            11: 'Describing the process: adverbs of manner',
            12: 'Describing the process: expressing purpose with for',
            13: 'Describing the process: infinitive of purpose',
            14: 'Describing the process: sequence words',
            15: 'Describing the process: statistical procedures',
            16: 'Describing the process: typical verbs in the passive form',
            17: 'Describing the process: using + instrument',
            18: 'Describing the research design and the methods used',
            19: 'Describing what other writers do in their published work',
            20: 'Detailing specific limitations',
            21: 'Establishing the importance of the topic for the discipline',
            22: 'Establishing the importance of the topic for the discipline: time frame given',
            23: 'Establishing the importance of the topic for the world or society',
            24: 'Establishing the importance of the topic for the world or society: time frame given',
            25: 'Establising the importance of the topic as a problem to be addressed',
            26: 'Explaining keywords (also refer to Defining Terms)',
            27: 'Explaining the provenance of articles for review',
            28: 'Explaining the provenance of the participants',
            29: 'Explaining the significance of the current study',
            30: 'Explaining the significance of the findings or contribution of the study',
            31: 'General comments on the relevant literature',
            32: 'General reference to previous research or scholarship: highlighting negative outcomes',
            33: 'Giving reasons for personal interest in the research (sometimes found in the humanities, and the applied human sciences)',
            34: 'Giving reasons why a particular method was adopted',
            35: 'Giving reasons why a particular method was rejected',
            36: 'Highlighting inadequacies or weaknesses of previous studies (also refer to Being Critical)',
            37: 'Highlighting interesting or surprising results',
            38: 'Highlighting significant data in a table or chart',
            39: 'Identifying a controversy within the field of study',
            40: 'Identifying a knowledge gap in the field of study',
            41: 'Implications and/or recommendations for practice or policy',
            42: 'Indicating an expected outcome',
            43: 'Indicating an unexpected outcome',
            44: 'Indicating criteria for selection or inclusion in the study',
            45: 'Indicating methodological problems or limitations',
            46: 'Indicating missing, weak, or contradictory evidence',
            47: 'Indicating the methodology for the current research',
            48: 'Indicating the use of an established method',
            49: 'Introducing the limitations of the current study',
            50: 'Making recommendations for further research work',
            51: 'Noting implications of the findings',
            52: 'Noting the lack of or paucity of previous research',
            53: 'Offering an explanation for the findings',
            54: 'Outlining the structure of a short paper',
            55: 'Outlining the structure of a thesis or dissertation',
            56: 'Pointing out interesting or important findings',
            57: 'Previewing a chapter',
            58: 'Previous research: A historic perspective',
            59: 'Previous research: Approaches taken',
            60: 'Previous research: What has been established or proposed',
            61: 'Previous research: area investigated as the sentence object',
            62: 'Previous research: area investigated as the sentence subject',
            63: 'Previous research: highlighting negative outcomes',
            64: 'Providing background information: reference to the literature',
            65: 'Providing background information: reference to the purpose of the study',
            66: 'Reference to previous research: important studies',
            67: 'Referring back to the purpose of the paper or study',
            68: 'Referring back to the research aims or procedures',
            69: 'Referring to a single investigation in the past: investigation prominent',
            70: 'Referring to a single investigation in the past: researcher prominent',
            71: 'Referring to another writer’s idea(s) or position',
            72: 'Referring to data in a table or chart',
            73: 'Referring to important texts in the area of interest',
            74: 'Referring to previous work to establish what is already known',
            75: 'Referring to secondary sources',
            76: 'Referring to the literature to justify a method or approach ',
            77: 'Reporting positive and negative reactions',
            78: 'Restating a result or one of several results',
            79: 'Setting out the research questions or hypotheses',
            80: 'Some ways of introducing quotations',
            81: 'Stating a negative result',
            82: 'Stating a positive result',
            83: 'Stating purpose of the current research with reference to gaps or issues in the literature',
            84: 'Stating the aims of the current research (note frequent use of past tense)',
            85: 'Stating the focus, aim, or argument of a short paper',
            86: 'Stating the purpose of the thesis, dissertation, or research article (note use of present tense)',
            87: 'Stating what is currently known about the topic',
            88: 'Suggesting general hypotheses',
            89: 'Suggesting implications for what is already known',
            90: 'Suggestions for future work',
            91: 'Summarising the literature review',
            92: 'Summarising the main research findings',
            93: 'Summarising the results section',
            94: 'Summarising the studies reviewed',
            95: 'Surveys and interviews: Introducing excerpts from interview data',
            96: 'Surveys and interviews: Reporting participants’ views',
            97: 'Surveys and interviews: Reporting proportions',
            98: 'Surveys and interviews: Reporting response rates',
            99: 'Surveys and interviews: Reporting themes',
            100: 'Synthesising sources: contrasting evidence or ideas',
            101: 'Synthesising sources: supporting evidence or ideas',
            102: 'Transition: moving to the next result',
        }

    def filter_for_class(self, target_classes):
        """Keep only sentences whose predicted label falls under the given coarse classes."""
        targets = []
        for target_class in target_classes:
            targets.extend(self.mapping[target_class])
        target_st = [self.id2label[i] for i in targets]
        topic_sentence_indices = self.sentences["prediction"].isin(target_st)
        self.class_sentences = self.sentences[topic_sentence_indices]

    def _predict(self, sentences):
        return self.model.predict(sentences)

    def classify(self, corpus: Any, batch_size: int = 500):
        """Split abstracts into sentences and predict a rhetorical label for each."""
        sentences = pd.DataFrame(
            [
                {"article_id": article["id"], "sentence_index": si, "sentence": s}
                for _, article in corpus.iterrows()
                if len(article["abstract_sentences"]) > 0
                for si, s in enumerate(article["abstract_sentences"])
            ]
        )
        sentences["batch"] = sentences.index // batch_size
        predictions = []
        for _, group in sentences.groupby("batch"):
            predictions.extend(self._predict(group["sentence"].tolist()))
        sentences["prediction"] = predictions
        self.sentences = sentences.reset_index(drop=True)
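
# Usage sketch (illustrative, not executed; assumes `corpus` is an AbstractData
# instance, defined further below):
#   clf = SentenceClassifier()
#   clf.classify(corpus.data)              # sentence-level predictions
#   clf.filter_for_class(["Uncertainty"])  # any key of clf.mapping works here
#   clf.class_sentences                    # DataFrame: article_id, sentence_index, sentence, prediction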

class TopicAnalyzer:
    def __init__(self, seed_words, sentence_model="all-MiniLM-L6-v2"):
        self.umap_model = UMAP(n_neighbors=3, n_components=7, min_dist=0.0, metric='cosine')
        self.vectorizer_model = CountVectorizer(ngram_range=(2, 5))
        self.representation_model = KeyBERTInspired(top_n_words=10, nr_repr_docs=5)
        # self.representation_model = PartOfSpeech("en_core_web_sm")
        self.hdbscan_model = HDBSCAN(
            min_cluster_size=4, max_cluster_size=30, metric="euclidean", prediction_data=True
        )
        self.ctfidf_model = ClassTfidfTransformer(
            seed_words=seed_words,
            seed_multiplier=2,
        )
        # self.hdbscan_model = KMeans(n_clusters=10)
        self.sentence_model = SentenceTransformer(sentence_model)

    def predict_topics(self, class_sentences):
        retries = 0
        while retries < 3:
            try:
                topic_model = BERTopic(
                    nr_topics=30,
                    min_topic_size=2,
                    top_n_words=15,
                    umap_model=self.umap_model,
                    embedding_model=self.sentence_model,
                    vectorizer_model=self.vectorizer_model,
                    hdbscan_model=self.hdbscan_model,
                    # pass the seeded c-TF-IDF model so seed_words take effect
                    # (it was previously constructed but never wired in)
                    ctfidf_model=self.ctfidf_model,
                    representation_model=self.representation_model,
                    calculate_probabilities=True,
                )
                topics, probs = topic_model.fit_transform(
                    class_sentences["sentence"].values
                )
                return topic_model, topics, probs
            except Exception:
                retries += 1  # was `retries+1`, which never incremented and could loop forever
        return None, None, None
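
# Usage sketch (hypothetical `class_sentences` DataFrame with a "sentence" column):
#   analyzer = TopicAnalyzer(seed_words=["soil", "carbon"])
#   topic_model, topics, probs = analyzer.predict_topics(class_sentences)
# HDBSCAN leaves hard-to-cluster sentences in outlier topic -1 rather than
# forcing them into a cluster, which is why downstream code handles "-1" labels.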

class OpenAlexSearchClient:
    @staticmethod
    async def search(term: str, num_pages: int) -> Tuple[List[Dict], List[Dict]]:
        term = urllib.parse.quote(term)
        pages = [
            f"https://api.openalex.org/works?per-page=50&page={i}&mailto={Config.OPENALEX_EMAIL}"
            "&select=id,title,publication_year,abstract_inverted_index,doi,authorships,locations"
            f"&search={term}&filter=primary_location.source.type:journal"
            # was range(1, num_pages), which fetched one page fewer than requested
            for i in range(1, num_pages + 1)
        ]
        meta, results = await OpenAlexSearchClient._download_all_pages(pages)
        return meta, results

    @staticmethod
    async def _download_all_pages(pages: List[str]) -> Tuple[List[Dict], List[Dict]]:
        semaphore = asyncio.Semaphore(Config.CONCURRENCY_LIMIT)
        async with aiohttp.ClientSession() as session:
            tasks = [
                OpenAlexSearchClient._download_site(url, session, semaphore)
                for url in pages
            ]
            results = []
            meta = []
            # Fetch in chunks and pause between them to stay inside OpenAlex rate limits
            for i in range(0, len(tasks), Config.CONCURRENCY_LIMIT):
                chunk = tasks[i : i + Config.CONCURRENCY_LIMIT]
                responses = await asyncio.gather(*chunk, return_exceptions=True)
                # Skip failed requests instead of indexing into the exception objects
                responses = [r for r in responses if not isinstance(r, BaseException)]
                results.extend([article for r in responses for article in r.get("results", [])])
                # collect one meta dict per page (the original iterated the dict's keys)
                meta.extend([r.get("meta", {}) for r in responses])
                if i + Config.CONCURRENCY_LIMIT < len(tasks):
                    await asyncio.sleep(1.2)
        return meta, results

    @staticmethod
    async def _download_site(
        url: str, session: aiohttp.ClientSession, semaphore: asyncio.Semaphore
    ) -> Dict:
        async with semaphore:
            async with session.get(url) as response:
                return await response.json()
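
# Usage sketch (standalone, outside the FastAPI flow; the query is illustrative):
#   meta, results = asyncio.run(OpenAlexSearchClient.search("soil carbon", num_pages=20))
#   # `results` is a list of OpenAlex work records with inverted-index abstracts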

class AbstractData:
    def __init__(self, results):
        self.abstracts = results
        corpus_data = []
        for article in results:
            abstract = self._process_abstract(article["abstract_inverted_index"])
            authors = [
                {
                    "display_name": item["author"]["display_name"],
                    # .get() handles records where 'orcid' is missing
                    "orcid": item["author"].get("orcid", None),
                }
                for item in article["authorships"]
            ]
            # Default journal info; the original left `journal` undefined when
            # 'locations' was empty and called .get() on a string default
            journal = {"display_name": None, "url": None}
            locations = article.get("locations") or []
            if locations:
                source = locations[0].get("source") or {}
                journal = {
                    "display_name": source.get("display_name"),
                    "url": locations[0].get("landing_page_url", ""),
                }
            article_data = {
                "id": article["id"],
                "title": article["title"],
                "authors": authors,
                "abstract": abstract,
                "abstract_sentences": list(segment(text=abstract, language="en")),
                "publication_year": article["publication_year"],
                "journal": journal,
            }
            corpus_data.append(article_data)
        self.data = pd.DataFrame(corpus_data)

    @staticmethod
    def _process_abstract(inverted_index: Dict) -> str:
        """Rebuild abstract text from OpenAlex's {word: [positions]} inverted index."""
        if inverted_index is None:
            return ""
        all_positions = [p for plist in inverted_index.values() for p in plist]
        if not all_positions:
            return ""
        # Size the buffer to the actual abstract length instead of a fixed 10000,
        # which previously produced thousands of trailing spaces
        words = [""] * (max(all_positions) + 1)
        for word, positions in inverted_index.items():
            for position in positions:
                words[position] = word
        return " ".join(words)

class TopicData:
    def __init__(self, class_sentences, topic_model, topics, probs, target_classes):
        # Use the topic name BERTopic assigned to each document ("Name" column)
        # rather than indexing get_document_info() positionally
        doc_info = topic_model.get_document_info(class_sentences["sentence"].values)
        topics = doc_info["Name"].tolist()
        class_sentences["topic"] = topics
        if probs is not None:
            # One probability vector per sentence; keep the best topic's probability
            class_sentences["probs"] = [max(p) for p in probs]
        self.topic_model = topic_model
        self.class_sentences = class_sentences
        self.topics = topics
        self.target_classes = target_classes

    def humanize_topics(self, api_key: str = Config.HF_API_KEY):
        """Ask an LLM to turn each raw BERTopic label into a short human-readable theme."""
        self.client = InferenceClient(api_key=api_key)
        topics = list(set(self.class_sentences["topic"].values))
        target_classes = " or ".join(self.target_classes)
        topic_map = {}
        i, retries = 0, 0  # was `i, retries = 3, 0`, which silently skipped the first three topics
        while i < len(topics):
            topic = topics[i]
            if topic[:2] == "-1":
                topic_map[topic] = "ood"  # BERTopic's outlier topic
                i += 1
                continue
            topic_str = "_".join(topic.split("_")[1:])
            if retries < 3:
                response = self._call_llm(target_classes, topic_str)
                try:
                    struct = "{" + response.split("{")[1].split("}")[0] + "}"
                    result = json.loads(struct)
                    topic_map[topic] = list(result.values())[0]
                except (IndexError, json.JSONDecodeError):
                    retries += 1
                    continue
            else:
                topic_map[topic] = topic  # give up after three attempts and keep the raw label
            retries = 0
            i += 1
        self.class_sentences["humanized_topic"] = self.class_sentences["topic"].map(topic_map)

    def parse_for_display(self, corpus, search_term):
        """Rank humanized topics against the query and attach the supporting articles."""
        model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
        class_sentences = self.class_sentences
        target_classes = " or ".join(self.target_classes)
        query = f"Where do the key {target_classes} in {search_term} lie?"
        humanized_topics = [str(t) for t in set(class_sentences["humanized_topic"].values)]
        ranks = model.rank(query, humanized_topics)
        ix = [rank["corpus_id"] for rank in ranks]
        display_topics = []
        for i in ix:
            topic = humanized_topics[i]
            # Skip outliers, failed parses, and degenerate labels
            if str(topic) in ["ood", "nan"] or len(topic) < 4:
                continue
            filtered_sentences = class_sentences[class_sentences["humanized_topic"] == topic]
            article_ids = filtered_sentences["article_id"].values
            articles = corpus.data[corpus.data["id"].isin(article_ids)].copy()
            articles = articles.drop("abstract_sentences", axis=1)
            articles = [a.to_dict() for _, a in articles.iterrows()]
            id_to_sentence = dict(zip(filtered_sentences["article_id"], filtered_sentences["sentence"]))
            for article in articles:
                article["topic_sentence"] = id_to_sentence.get(article["id"], None)
            display_topics.append({"topic": topic, "articles": articles})
        return display_topics

    def _call_llm(self, target_classes, topic):
        messages = [{
            "role": "user",
            "content": """**disable conversational mode** **disable chain of thought reasoning**. You are simply a text preprocessor which outputs raw json. Ensure the output is valid JSON as it will be parsed
using `json.loads()` in Python.
It should be in the schema:
<output>
{
""" + target_classes + """: <Human Readable Theme/>
}""" + f"""</output>
. Summarize the one key {target_classes} theme in studies given these keywords related to the {target_classes}: {topic}""" + """
Be concise and accurate, ensuring you intelligently include all relevant keywords within it. The uncertainty theme must be listed in note shorthand form, and do not extrapolate. Please reply with a dict with the key and then the human readable topic notes as the item.""",
        }]
        completion = self.client.chat.completions.create(
            model="microsoft/Phi-3.5-mini-instruct",
            messages=messages,
            temperature=0.4,
            seed=random.randint(3, 10 * 10 * 99 * 34851),
            max_tokens=90,
        )
        return completion.choices[0].message.content
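
# Expected reply shape from _call_llm (illustrative, not a guaranteed format):
#   {"Uncertainty": "long-term efficacy unclear; conflicting dose-response evidence"}
# humanize_topics extracts the first JSON value as the human-readable theme.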

def extract_topics(class_sentences, target_classes, seed_words, sentence_model="all-MiniLM-L6-v2"):
    """Fit BERTopic over the filtered sentences and wrap the results in a TopicData."""
    topic_analyzer = TopicAnalyzer(sentence_model=sentence_model, seed_words=seed_words)
    topic_model, topics, probs = topic_analyzer.predict_topics(class_sentences)
    if topic_model is None:
        # predict_topics returns (None, None, None) after three failed attempts
        raise RuntimeError("Topic modelling failed after three attempts")
    return TopicData(class_sentences, topic_model, topics, probs, target_classes)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@dataclass
class ProgressState:
    """Stores progress information for a specific job"""
    status: str = "pending"
    current_stage: str = ""
    progress: float = 0.0
    message: str = ""
    meta: Optional[dict] = None
    result: Optional[Any] = None  # holds the display_topics list once complete

# Per-job progress, keyed by job_id; defaultdict creates a fresh state on first access
progress_stores = defaultdict(ProgressState)

async def update_progress(job_id: str, stage: str, progress: float, message: str, meta: Dict[str, Any], result: Any):
    """Helper function to update progress and allow for yielding control"""
    progress_state = progress_stores[job_id]
    progress_state.current_stage = stage
    progress_state.progress = progress
    progress_state.message = message
    progress_state.meta = meta
    progress_state.result = result
    # Give control back to the event loop so the SSE generator can send the update
    await asyncio.sleep(0.1)

async def process_with_progress(job_id: str, search_term: str, target_classes: list):
    """Main processing function with progress updates"""
    progress = progress_stores[job_id]
    meta = {}
    try:
        # Initial state
        await update_progress(job_id, "initializing", 0.0, "Starting search process...", meta, None)

        # Search stage
        await update_progress(job_id, "searching", 0.1, "Connecting to OpenAlex database...", meta, None)
        search_client = OpenAlexSearchClient()
        await update_progress(job_id, "searching", 0.15, "Searching for relevant papers...", meta, None)
        _, results = await search_client.search(search_term, num_pages=20)
        meta["num_abstracts"] = len(results)
        await update_progress(job_id, "searching", 0.2, "Search complete", meta, None)

        # Corpus processing stage
        await update_progress(job_id, "processing_corpus", 0.3, "Processing corpus data...", meta, None)
        corpus = AbstractData(results)
        await update_progress(job_id, "processing_corpus", 0.4, "Corpus processing complete", meta, None)

        # Classification stage
        await update_progress(job_id, "classifying", 0.5, "Initializing sentence classifier...", meta, None)
        sentence_classifier = SentenceClassifier()
        await update_progress(job_id, "classifying", 0.55, "Classifying sentences...", meta, None)
        sentence_classifier.classify(corpus.data)
        await update_progress(job_id, "classifying", 0.6, "Classification complete", meta, None)

        # Filtering stage
        await update_progress(job_id, "filtering", 0.7, "Filtering for target classes...", meta, None)
        sentence_classifier.filter_for_class(target_classes=target_classes)
        class_sentences = sentence_classifier.class_sentences
        meta["num_class_sentences"] = len(class_sentences)
        await update_progress(job_id, "filtering", 0.8, "Filtering complete", meta, None)

        # Topic extraction stage
        await update_progress(job_id, "extracting_topics", 0.85, "Preparing topic extraction...", meta, None)
        seed_words = search_term.split()
        await update_progress(job_id, "extracting_topics", 0.9, "Extracting topics...", meta, None)
        topic_data = extract_topics(
            class_sentences=class_sentences,
            seed_words=seed_words,
            target_classes=target_classes,
            sentence_model="sentence-transformers/all-MiniLM-L6-v2",
        )
        meta["num_topics"] = len(topic_data.topics)
        await update_progress(job_id, "extracting_topics", 0.95, "Processing topics...", meta, None)
        topic_data.humanize_topics()
        display_topics = topic_data.parse_for_display(corpus=corpus, search_term=search_term)

        # Complete
        progress.status = "complete"
        progress.progress = 1.0
        progress.message = "Processing complete"
        progress.result = display_topics
        await asyncio.sleep(0.1)  # Final yield to ensure the completion message is sent
    except Exception as e:
        await update_progress(job_id, "error", 0, f"Error occurred: {str(e)}", {}, None)
        progress.status = "failed"
        await asyncio.sleep(0.1)  # Ensure the error message is sent
        raise

async def progress_generator(job_id: str) -> AsyncGenerator[str, None]:
    """Generates SSE events for progress updates"""
    try:
        while True:
            progress = progress_stores[job_id]
            # Send current progress
            data = {
                "status": progress.status,
                "stage": progress.current_stage,
                "progress": progress.progress,
                "message": progress.message,
                "meta": progress.meta,
                "result": progress.result,
            }
            yield f"data: {json.dumps(data)}\n\n"
            # If the job is complete or failed, send a final message and end the stream
            if progress.status in ["complete", "failed"]:
                if progress.status == "complete" and progress.result:
                    yield f"data: {json.dumps({'status': 'complete', 'result': progress.result})}\n\n"
                break
            await asyncio.sleep(0.5)  # Check for updates every 500ms
    finally:
        # Clean up the progress store after streaming ends
        if job_id in progress_stores:
            del progress_stores[job_id]
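
# Each SSE message is one JSON object on a `data:` line, e.g. (illustrative values):
#   data: {"status": "pending", "stage": "searching", "progress": 0.15,
#          "message": "Searching for relevant papers...", "meta": {"num_abstracts": 950}, "result": null}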
@app.get("/get_topics")
async def get_topics(
background_tasks: BackgroundTasks,
search_term: str
):
"""Endpoint to initiate topic processing with progress tracking"""
job_id = str(uuid.uuid4())
target_classes = ["Uncertainty"]
# Start processing in background
background_tasks.add_task(
process_with_progress,
job_id=job_id,
search_term=search_term,
target_classes=target_classes
)
return {"job_id": job_id}
@app.get("/progress/{job_id}")
async def progress(job_id: str):
"""SSE endpoint to stream progress updates."""
return StreamingResponse(
progress_generator(job_id),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
}
)
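
# Client sketch (not part of the app): start a job, then stream its progress.
#   curl "http://localhost:7860/get_topics?search_term=soil%20carbon"   -> {"job_id": "..."}
#   curl -N "http://localhost:7860/progress/<job_id>"                   -> SSE stream
# Local entry point: port 7860 is the Hugging Face Spaces convention (an
# assumption here); adjust host/port for other deployments.
if __name__ == "__main__":
    import uvicorn  # assumed available alongside FastAPI

    uvicorn.run(app, host="0.0.0.0", port=7860)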