Spaces:
Runtime error
Runtime error
File size: 4,652 Bytes
3701fee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import datetime
import string
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
stop_words = stopwords.words('english')
import arxiv
import gradio as gr
import re
from datetime import datetime
import json
def rag_cleaner(inp):
    """Render one retrieval hit as a short, ranked text snippet.

    Args:
        inp: Mapping with ``rank`` and ``content`` entries and a nested
            ``document_metadata`` mapping that holds ``title`` and ``_time``.

    Returns:
        A single string containing the rank, the title wrapped in ``<b>``
        tags, the date and the abstract text.
    """
    position = inp['rank']
    metadata = inp['document_metadata']
    heading = metadata['title']
    abstract = inp['content']
    published = metadata['_time']
    template = "{}. <b> {} </b> \n Date : {} \n Abstract: {}"
    return template.format(position, heading, published, abstract)
def get_prompt_text(question, context, formatted = True, llm_model_picked = 'mistralai/Mistral-7B-Instruct-v0.2'):
    """Build the LLM prompt for answering *question* from *context*.

    Args:
        question: The user's question; interpolated into the prompt.
        context: Text placed after the ``Context:`` header.
        formatted: When true, try to wrap the prompt in the chat markers of
            the selected model family.
        llm_model_picked: Model identifier; only checked for the substrings
            ``'mistralai'`` and ``'gemma'`` (in that order).

    Returns:
        The chat-formatted prompt when ``formatted`` is true and the model
        name matches a known family; otherwise a plain, marker-free prompt.
    """
    if formatted:
        sys_instruction = f"Context:\n {context} \n Given the following scientific paper abstracts, take a deep breath and lets think step by step to answer the question. Cite the titles of your sources when answering, do not cite links or dates."
        message = f"Question: {question}"
        # (substring to look for, (opening markers, closing markers))
        chat_wrappers = (
            ('mistralai', ("<s>[INST] ", "[/INST]")),
            ('gemma', ("<bos><start_of_turn>user\n", "<end_of_turn>\n")),
        )
        for family, (opening, closing) in chat_wrappers:
            if family in llm_model_picked:
                return f"{opening}{sys_instruction} {message}{closing}"
    # Unformatted request, or a model family without a known chat template.
    return f"Context:\n {context} \n Given the following info, take a deep breath and lets think step by step to answer the question: {question}. Cite the titles of your sources when answering.\n\n"
def get_references(question, retriever, k = retrieve_results):
    """Run *question* through *retriever* and return its search output.

    Args:
        question: Query text forwarded as the ``query`` keyword.
        retriever: Any object exposing ``search(query=..., k=...)``.
        k: Forwarded as the ``k`` keyword; presumably the number of hits to
            fetch (confirm against the retriever). The default is read from
            the module-level ``retrieve_results`` when this function is
            defined.

    Returns:
        Whatever ``retriever.search`` returns, unchanged.
    """
    return retriever.search(query=question, k=k)
def get_rag(message):
    """Look up references for *message* using the module-level ``RAG`` retriever.

    Returns:
        The result of ``get_references`` called with its default ``k``.
    """
    return get_references(question=message, retriever=RAG)
def SaveResponseAndRead(result):
    """Wrap *result* in a "Read It Aloud" HTML page and build a ``gr.HTML`` from it.

    The page holds a textarea pre-filled with *result* and a button whose
    click handler speaks the textarea's contents through the browser's
    ``SpeechSynthesisUtterance`` / ``speechSynthesis`` API.

    Args:
        result: Text to place inside the textarea (a ``str``).

    Returns:
        None. The ``gr.HTML`` component is constructed but not returned.
        NOTE(review): nothing is written to disk here despite the name, and
        whether the component is displayed depends on the Gradio context the
        caller invokes this from - confirm against the call site.
    """
    # Imported locally so this block does not depend on a new file-level import.
    import html

    page_head = '''
<!DOCTYPE html>
<html>
<head>
<title>Read It Aloud</title>
<script type="text/javascript">
function readAloud() {
const text = document.getElementById("textArea").value;
const speech = new SpeechSynthesisUtterance(text);
window.speechSynthesis.speak(speech);
}
</script>
</head>
<body>
<h1>🔊 Read It Aloud</h1>
<textarea id="textArea" rows="10" cols="80">
'''
    page_tail = '''
</textarea>
<br>
<button onclick="readAloud()">🔊 Read Aloud</button>
</body>
</html>
'''
    # Escape the text so markup inside it (e.g. "</textarea>" or "<script>")
    # cannot close the textarea and inject HTML into the page. Character
    # references are decoded inside a textarea, so the text that is displayed
    # and read aloud is unchanged.
    documentHTML5 = ''.join((page_head, html.escape(result), page_tail))
    gr.HTML(documentHTML5)
def get_md_text_abstract(rag_answer, source = 'Semantic Search', return_prompt_formatting = False):
    """Format one search hit as a Markdown entry (and optionally as prompt text).

    Args:
        rag_answer: The hit to format. For a semantic-search source it is a
            mapping with ``content``, ``document_id`` and a
            ``document_metadata`` mapping (``title``, ``_time``, ``authors``).
            For an arXiv source it is an object with ``title``, ``updated``
            (supports ``strftime``), ``summary``, ``authors`` (items with a
            ``name``) and ``links`` (at least two items with an ``href``).
        source: Selects the input shape: a value containing
            ``'Semantic Search'`` or, failing that, ``'Arxiv'``
            (e.g. ``'Arxiv Search'``).
        return_prompt_formatting: When true, also return a compact
            title-plus-abstract string intended for the LLM prompt.

    Returns:
        The Markdown string, or a ``(markdown, prompt_text)`` tuple when
        ``return_prompt_formatting`` is true.

    Raises:
        ValueError: If ``source`` matches neither supported kind.
    """
    if 'Semantic Search' in source:
        metadata = rag_answer['document_metadata']
        title = metadata['title'].replace('\n', '')
        date = metadata['_time']
        paper_abs = rag_answer['content']
        authors = metadata['authors'].replace('\n', '')
        doc_id = rag_answer['document_id']
        paper_link = f'https://arxiv.org/abs/{doc_id}'
        download_link = f'https://arxiv.org/pdf/{doc_id}'
    elif 'Arxiv' in source:
        title = rag_answer.title
        date = rag_answer.updated.strftime('%d %b %Y')
        # Collapse hard line breaks so the abstract renders as one paragraph.
        paper_abs = rag_answer.summary.replace('\n', ' ') + '\n'
        authors = ', '.join(author.name for author in rag_answer.authors)
        # Uses the first link as the paper page and the second as the download.
        paper_link = rag_answer.links[0].href
        download_link = rag_answer.links[1].href
    else:
        # ValueError is still an Exception, so existing broad handlers keep working.
        raise ValueError(
            f"Unsupported source {source!r}; expected 'Semantic Search' or 'Arxiv Search'"
        )

    paper_title = f'### {date} | [{title}]({paper_link}) | [⬇️]({download_link})\n'
    authors_formatted = f'*{authors}*' + ' \n\n'
    md_text_formatted = paper_title + authors_formatted + paper_abs + '\n---------------\n' + '\n'
    if return_prompt_formatting:
        prompt_formatted = f"<b> {title} </b> \n Abstract: {paper_abs}"
        return md_text_formatted, prompt_formatted
    return md_text_formatted
def remove_punctuation(text):
    """Strip ASCII punctuation from *text*, keeping apostrophes.

    Every character of ``string.punctuation`` except ``'`` is deleted, so
    contractions such as ``don't`` survive intact.
    """
    to_delete = "".join(ch for ch in string.punctuation if ch != "'")
    deletion_table = str.maketrans("", "", to_delete)
    return text.translate(deletion_table)
def remove_stopwords(text):
    """Drop every space-separated token of *text* found in ``stop_words``.

    The text is split on single spaces (so runs of spaces yield empty tokens,
    which are kept) and the surviving tokens are re-joined with single spaces.
    Matching is exact against the module-level ``stop_words`` collection.
    """
    kept = []
    for token in text.split(' '):
        if token in stop_words:
            continue
        kept.append(token)
    return ' '.join(kept)
def search_cleaner(text):
    """Normalise a free-text query before it is sent to a search backend.

    Steps, in order: lower-case the text, drop stop words, then strip
    punctuation (apostrophes are kept by ``remove_punctuation``).
    """
    lowered = text.lower()
    return remove_punctuation(remove_stopwords(lowered))
# Category filter appended to every live arXiv query: the listed cs.* categories OR-ed together.
q = '(' + ' OR '.join(
    'cat:' + code
    for code in ('cs.CV', 'cs.LG', 'cs.CL', 'cs.AI', 'cs.NE', 'cs.RO')
) + ')'
def get_arxiv_live_search(query, client, max_results = 10):
    """Search arXiv live for *query*, restricted to the categories in ``q``.

    Args:
        query: Free-text user query; cleaned with ``search_cleaner`` first.
        client: Object exposing ``results(search)``; presumably an
            ``arxiv.Client`` (confirm against the caller).
        max_results: Passed to ``arxiv.Search`` as ``max_results``.

    Returns:
        A list with every item yielded by ``client.results(search)``,
        requested in relevance order.
    """
    clean_text = search_cleaner(query)
    search = arxiv.Search(
        # The cleaned text is AND-ed with the module-level category filter.
        query = clean_text + " AND " + q,
        max_results = max_results,
        sort_by = arxiv.SortCriterion.Relevance,
    )
    # Materialise the results so callers get a plain list.
    return list(client.results(search))