law_bot / app.py
Darka001's picture
Update app.py
de880c5 verified
raw
history blame
13.4 kB
import gradio as gr
from gradio_calendar import Calendar
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
pipeline,
StoppingCriteria, StoppingCriteriaList
)
from langchain.prompts import PromptTemplate
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import LLMChain
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
model_name= 'mistralai/Mistral-7B-v0.1'
#model_name='SherlockAssistant/Mistral-7B-Instruct-Ukrainian'
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.padding_side = "right"
# Activate 4-bit precision base model loading
use_4bit = True
# Compute dtype for 4-bit base models
bnb_4bit_compute_dtype = "float16"
# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"
# Activate nested quantization for 4-bit base models (double quantization)
use_nested_quant = False
#################################################################
# Set up quantization config
#################################################################
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
bnb_config = BitsAndBytesConfig(
load_in_4bit=use_4bit,
bnb_4bit_quant_type=bnb_4bit_quant_type,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=use_nested_quant,
)
# Check GPU compatibility with bfloat16
if compute_dtype == torch.float16 and use_4bit:
major, _ = torch.cuda.get_device_capability()
if major >= 8:
print("=" * 80)
print("Your GPU supports bfloat16: accelerate training with bf16=True")
print("=" * 80)
model = AutoModelForCausalLM.from_pretrained(
model_name, quantization_config=bnb_config)
stop_list = [" \n\nAnswer:", " \n", " \n\n"]
stop_token_ids = [tokenizer(x, return_tensors='pt', add_special_tokens=False)['input_ids'] for x in stop_list]
stop_token_ids = [torch.LongTensor(x).to("cuda") for x in stop_token_ids]
class StopOnTokens(StoppingCriteria):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
for stop_ids in stop_token_ids:
if torch.eq(input_ids[0][-len(stop_ids[0])+1:], stop_ids[0][1:]).all():
return True
return False
stopping_criteria = StoppingCriteriaList([StopOnTokens()])
text_generation_pipeline = pipeline(
model=model,
tokenizer=tokenizer,
task="text-generation",
temperature=0.01,
repetition_penalty=1.2,
return_full_text=True,
max_new_tokens=750, do_sample=True,
top_k=50, top_p=0.95,
stopping_criteria=stopping_criteria
)
mistral_llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
# # # load chroma from disk
instructor_embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")
db3 = Chroma(persist_directory="chroma/", embedding_function=instructor_embeddings)
retriever = db3.as_retriever(search_type="similarity_score_threshold",
search_kwargs={"score_threshold": .5,
"k": 20})
#retriever = db3.as_retriever(search_kwargs={"k":15})
# Get pre-written rag prompt
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
template ="""" [INST] Ти асистент для надання відповідей з законодавства України. Використовуй лише вказаний нижче Context максимально точно. Описуй лише події простими словами без формальностей. Пиши чотири речення і будь максимально точним. Якщо контекст пустий - пиши "Я не маю релевантної інформації. Спробуйте ще".
Context: {context}
### QUESTION:
{question}
[/INST]
"""
prompt = PromptTemplate(
input_variables=["context", "question"],
template=template,
)
rag_chain_from_docs = (
RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
| prompt
| mistral_llm
| StrOutputParser()
)
rag_chain_with_source = RunnableParallel(
{"context": retriever, "question": RunnablePassthrough()}
).assign(answer=rag_chain_from_docs)
def format_result(result):
unique_videos = set((doc.metadata['title'], doc.metadata['act_url']) for doc in result['context'])
# Create a plain text string where each title is followed by its URL
titles_with_links = [
f"{title}: {act_url}" for title, act_url in unique_videos
]
# Join these entries with line breaks to form a clear list
titles_string = '\n'.join(titles_with_links)
titles_formatted = f"Використані закони:\n{titles_string}"
# Combine the answer from the result with the formatted list of video links
answer = result['answer']
response = f"{answer}\n\n{titles_formatted}"
return response
def generate_with_filters(message, subject_input, rubric, date_beg, date_end):
if date_beg == "2010-01-01" and date_end == "2025-01-01":
rag_chain_with_filters = RunnableParallel(
{"context": db3.as_retriever(search_type="mmr", search_kwargs={"k": 10,
"filter": {'$and': [{'subject': {
'$in': subject_input}}, {
'rubric': {
'$in': rubric}}]}}),
"question": RunnablePassthrough()}
).assign(answer=rag_chain_from_docs)
else:
rag_chain_with_filters = RunnableParallel(
{"context": db3.as_retriever(search_type="mmr", search_kwargs={"k": 10,
"filter": {'$and': [{'subject': {
'$in': subject_input}}, {
'rubric': {
'$in': rubric}},{"act_date": {"$gte": date_beg}}, {"act_date": {"$lte": date_end}}] }}), "question": RunnablePassthrough()}
).assign(answer=rag_chain_from_docs)
result = rag_chain_with_filters.invoke(message)
return result
def summarize_act(message, act_number):
template = """" [INST] Ти асистент для надання відповідей з законодавства України.На вхід ти отримав один закон, підсуму його простими словами, викинь формальності та стандартні фрази. Додай усі зміни, які згадуються і цьому документі і опиши їх трьома реченнями. Якщо контекст пустий - пиши "Я не маю релевантної інформації. Спробуйте ще".
Context: {context}
### QUESTION:
{question}
[/INST]
"""
prompt = PromptTemplate(
input_variables=["context", "question"],
template=template,
)
rag_chain_from_docs = (
RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
| prompt
| mistral_llm
| StrOutputParser()
)
rag_chain_summarize = RunnableParallel(
{"context": db3.as_retriever( search_kwargs={"k": 150, "filter": {'act_number': {
'$eq': act_number}}}), "question": RunnablePassthrough()}
).assign(answer=rag_chain_from_docs)
return rag_chain_summarize.invoke("")
def generate_answer(message, history, checkbox,subject_input, rubric, date_beg, date_end, act_number):
result = ""
if checkbox:
if act_number=="":
return "Будь ласка, введіть номер акту для отримання основної інформації з нього, або зніміть відповідний прапорець."
result = summarize_act(message, act_number)
if subject_input is None and rubric is None and date_beg == "2010-01-01" and date_end == "2025-01-01":
result = rag_chain_with_source.invoke(message)
else:
if subject_input is None or subject_input ==[]:
subject_input = ["Президент України", "Кабінет міністрів України", "Народний депутат України"]
if rubric is None or rubric== []:
rubric = ['Двосторонні міжнародні угоди', 'Багатосторонні міжнародні угоди',
'Галузевий розвиток', 'Економічна політика',
'Державне будівництво', 'Соціальна політика', 'Правова політика',
'Безпека і оборона', 'Гуманітарна політика']
result = geherate_with_filters(message)
result['answer'] =result['answer'].split("[/INST]")[-1].strip()
formatted_results = format_result(result)
return formatted_results
def change_group(check_value):
if check_value :
return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
else:
return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True),
gr.update(visible=False)]
with gr.Blocks(theme=gr.themes.Soft()) as demo :
with gr.Group() as group_components :
date_beg =Calendar(type="string", label="Почакова дата пошуку", info="Натисніть на значок календаря для вибору дати", render = False, value="2010-01-01")
date_end = Calendar(type="string", label="Кінцева дата пошуку", info="Натисніть на значок календаря для вибору дати", render = False, value="2025-01-01")
subject_input = gr.Dropdown(
["Президент України", "Кабінет міністрів України", "Народний депутат України"], multiselect=True, label="Ініціатор", info="Виберіть ініціатора законопроєкту", render=False)
rubric = gr.Dropdown(['Двосторонні міжнародні угоди', 'Багатосторонні міжнародні угоди',
'Галузевий розвиток', 'Економічна політика',
'Державне будівництво', 'Соціальна політика', 'Правова політика',
'Безпека і оборона', 'Гуманітарна політика'], multiselect=True, label='Тематика', info="Оберіть, яких галузей стосується законопроєкт", render=False)
act_number = gr.Textbox(label='Номер законодавчого акту', placeholder="Наприклад: 861-20 ",visible= False, render=False)
action_checkbox = gr.Checkbox(label="Хочу отримати підсумок одного документу", value=False, render=False)
action_checkbox.input(fn=change_group, inputs= [action_checkbox], outputs = [subject_input, date_beg, date_end, rubric, act_number])
gr.ChatInterface(
generate_answer,
chatbot=gr.Chatbot(height=400, render = False),
textbox = gr.Textbox(placeholder="Ввести питання", container=False, scale=7, render = False),
title="Законодавчий Помічник",
description="Спитай мене про будь-які регуляції в чинних законах України.",
# # examples=["мобілізація", "земельна реформа", "екологія"],
cache_examples=False,
retry_btn=None,
undo_btn=None,
clear_btn=None,
submit_btn="Спитати",
stop_btn=None,
additional_inputs=[
action_checkbox,
subject_input,
rubric,
date_beg,
date_end,
act_number ],
additional_inputs_accordion=gr.Accordion(open=False, label="Додаткові фільтри", render = False)
)
demo.launch(share = False)