arxiv-RAG / helper.py
jharrison27's picture
Upload 2 files
3dbe475 verified
raw
history blame
4.61 kB
import sys
import gradio as gr
from huggingface_hub import InferenceClient
import datetime
import string
import nltk
# Download the NLTK 'stopwords' corpus when this module is imported.
nltk.download('stopwords')
from nltk.corpus import stopwords
# English stopword list consulted by remove_stopwords() further down.
stop_words = stopwords.words('english')
import arxiv
def rag_cleaner(inp):
    """Render one retrieved document as a numbered, HTML-tagged text snippet.

    Args:
        inp: Mapping with ``rank`` and ``content`` entries plus a
            ``document_metadata`` mapping that holds ``title`` and ``_time``.

    Returns:
        A single string listing the rank, bold title, date and abstract.
    """
    position = inp['rank']
    heading = inp['document_metadata']['title']
    abstract = inp['content']
    published = inp['document_metadata']['_time']
    return "{}. <b> {} </b> \n Date : {} \n Abstract: {}".format(
        position, heading, published, abstract
    )
def get_prompt_text(question, context, formatted=True, llm_model_picked='mistralai/Mistral-7B-Instruct-v0.2'):
    """Build the LLM prompt for a question and its retrieved context.

    Args:
        question: The user's question.
        context: Text of the retrieved abstracts to place in the prompt.
        formatted: When true, wrap the prompt in the chat markers of the
            selected model family.
        llm_model_picked: Model identifier; ``'mistralai'`` and ``'gemma'``
            are the substrings that select a chat template.

    Returns:
        The prompt string. When ``formatted`` is false, or the model name
        matches neither family, a plain un-templated prompt is returned.
    """
    instruction = f"Context:\n {context} \n Given the following scientific paper abstracts, take a deep breath and lets think step by step to answer the question. Cite the titles of your sources when answering, do not cite links or dates."
    user_turn = f"Question: {question}"

    if formatted:
        body = f"{instruction} {user_turn}"
        if 'mistralai' in llm_model_picked:
            return "<s>[INST] " + body + "[/INST]"
        if 'gemma' in llm_model_picked:
            return "<bos><start_of_turn>user\n" + body + "<end_of_turn>\n"

    # Fallback: plain prompt without any model-specific chat markers.
    return f"Context:\n {context} \n Given the following info, take a deep breath and lets think step by step to answer the question: {question}. Cite the titles of your sources when answering.\n\n"
def get_references(question, retriever, k):
    """Query ``retriever`` for the ``k`` best matches to ``question``.

    Args:
        question: Query text passed to the retriever as ``query``.
        retriever: Object exposing ``search(query=..., k=...)``.
        k: Number of results requested.

    Returns:
        Whatever ``retriever.search`` returns, unmodified.
    """
    return retriever.search(query=question, k=k)
def get_rag(message, RAG, RETRIEVE_RESULTS):
    """Fetch references for ``message`` by delegating to ``get_references``.

    Args:
        message: Query text.
        RAG: Retriever object forwarded to ``get_references``.
        RETRIEVE_RESULTS: Number of results to request.

    Returns:
        The result of ``get_references`` for these arguments.
    """
    top_k = RETRIEVE_RESULTS
    return get_references(question=message, retriever=RAG, k=top_k)
def SaveResponseAndRead(result):
    """Wrap ``result`` in a small HTML page with a "Read Aloud" button.

    The text is placed inside a ``<textarea>``; the button's handler passes
    the textarea value to ``window.speechSynthesis``.

    Args:
        result: Text to display and read aloud. ``None`` is treated as
            empty text; other non-string values are converted with ``str``.

    Returns:
        A ``gr.HTML`` component built from the generated markup.
    """
    import html  # local import keeps this block self-contained

    page_head = '''
<!DOCTYPE html>
<html>
<head>
<title>Read It Aloud</title>
<script type="text/javascript">
function readAloud() {
const text = document.getElementById("textArea").value;
const speech = new SpeechSynthesisUtterance(text);
window.speechSynthesis.speak(speech);
}
</script>
</head>
<body>
<h1>🔊 Read It Aloud</h1>
<textarea id="textArea" rows="10" cols="80">
'''
    page_tail = '''
</textarea>
<br>
<button onclick="readAloud()">🔊 Read Aloud</button>
</body>
</html>
'''

    # Escape the text so markup inside it (e.g. a literal "</textarea>" or
    # "<script>") cannot break out of the textarea and inject HTML. The
    # browser decodes the entities again, so the textarea value -- and
    # therefore the spoken text -- is the same as the unescaped input.
    raw_text = "" if result is None else str(result)
    safe_text = html.escape(raw_text, quote=False)

    return gr.HTML(page_head + safe_text + page_tail)
def get_md_text_abstract(rag_answer, source='Semantic Search', return_prompt_formatting=False):
    """Format one search result as a Markdown entry (and optionally prompt text).

    Args:
        rag_answer: The result to format. For a ``'Semantic Search'`` source
            it is indexed like a mapping (``document_metadata``, ``content``,
            ``document_id``); for an ``'Arxiv'`` source it is read through
            attributes (``title``, ``updated``, ``summary``, ``authors``,
            ``links``).
        source: Which kind of result ``rag_answer`` is. Matched by substring:
            ``'Semantic Search'`` first, then ``'Arxiv'``.
        return_prompt_formatting: When true, also return a compact
            title-plus-abstract string for use in an LLM prompt.

    Returns:
        The Markdown string, or a ``(markdown, prompt_text)`` tuple when
        ``return_prompt_formatting`` is true.

    Raises:
        ValueError: If ``source`` matches neither supported kind.
    """
    if 'Semantic Search' in source:
        metadata = rag_answer['document_metadata']
        # Newlines are stripped from title and authors so each stays on one line.
        title = metadata['title'].replace('\n', '')
        date = metadata['_time']
        paper_abs = rag_answer['content']
        authors = metadata['authors'].replace('\n', '')
        doc_id = rag_answer['document_id']
        paper_link = f'https://arxiv.org/abs/{doc_id}'
        download_link = f'https://arxiv.org/pdf/{doc_id}'
    elif 'Arxiv' in source:
        title = rag_answer.title
        date = rag_answer.updated.strftime('%d %b %Y')
        paper_abs = rag_answer.summary.replace('\n', ' ') + '\n'
        authors = ', '.join(author.name for author in rag_answer.authors)
        # The first link is used as the paper page, the second as the download.
        paper_link = rag_answer.links[0].href
        download_link = rag_answer.links[1].href
    else:
        # ValueError subclasses Exception, so existing broad handlers still work.
        raise ValueError(
            f"Unsupported source {source!r}: expected a 'Semantic Search' or 'Arxiv' source"
        )

    paper_title = f'''### {date} | [{title}]({paper_link}) | [⬇️]({download_link})\n'''
    authors_formatted = f'*{authors}*' + ' \n\n'
    md_text_formatted = paper_title + authors_formatted + paper_abs + '\n---------------\n' + '\n'

    if return_prompt_formatting:
        prompt_formatted = f"<b> {title} </b> \n Abstract: {paper_abs}"
        return md_text_formatted, prompt_formatted
    return md_text_formatted
def remove_punctuation(text):
    """Strip ASCII punctuation from ``text`` while keeping apostrophes.

    Args:
        text: Input string.

    Returns:
        ``text`` with every character of ``string.punctuation`` removed,
        except the single quote (``'``).
    """
    # Map each punctuation code point (apart from the apostrophe) to None,
    # which str.translate interprets as "delete this character".
    deletions = {ord(mark): None for mark in string.punctuation if mark != "'"}
    return text.translate(deletions)
def remove_stopwords(text):
    """Drop every token of ``text`` that appears in ``stop_words``.

    Tokens are obtained by splitting on single spaces and are compared
    as-is (no case folding here), then re-joined with single spaces.

    Args:
        text: Input string.

    Returns:
        The filtered string.
    """
    tokens = text.split(' ')
    kept = [token for token in tokens if token not in stop_words]
    return ' '.join(kept)
def search_cleaner(text):
    """Normalise a search query.

    The steps run in this order: lower-case the text, remove stopwords,
    then remove punctuation.

    Args:
        text: Raw query string.

    Returns:
        The cleaned query string.
    """
    lowered = text.lower()
    without_stopwords = remove_stopwords(lowered)
    return remove_punctuation(without_stopwords)
# Category clause that get_arxiv_live_search joins onto the cleaned query
# with " AND ".
q = '(cat:cs.CV OR cat:cs.LG OR cat:cs.CL OR cat:cs.AI OR cat:cs.NE OR cat:cs.RO)'
def get_arxiv_live_search(query, client, max_results = 10):
    """Run a relevance-sorted arXiv search for ``query``.

    The query is cleaned with ``search_cleaner`` and combined with the
    module-level category clause ``q`` using ``" AND "``.

    Args:
        query: Raw user query.
        client: Object whose ``results(search)`` method yields the results
            for an ``arxiv.Search``.
        max_results: Maximum number of results requested.

    Returns:
        A list of the results produced by ``client.results``.
    """
    full_query = f"{search_cleaner(query)} AND {q}"
    request = arxiv.Search(
        query=full_query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance,
    )
    return list(client.results(request))