|
import gradio as gr |
|
from gradio_calendar import Calendar |
|
|
|
from langchain.embeddings.huggingface import HuggingFaceEmbeddings |
|
from langchain_community.vectorstores import Chroma |
|
from langchain_core.output_parsers import StrOutputParser |
|
import torch |
|
from transformers import ( |
|
AutoModelForCausalLM, |
|
AutoTokenizer, |
|
BitsAndBytesConfig, |
|
pipeline, |
|
StoppingCriteria, StoppingCriteriaList |
|
) |
|
|
|
from langchain.prompts import PromptTemplate |
|
from langchain_community.llms import HuggingFacePipeline |
|
from langchain.chains import LLMChain |
|
from langchain_core.runnables import RunnablePassthrough, RunnableParallel |
|
|
|
|
|
|
|
model_name= 'mistralai/Mistral-7B-v0.1' |
|
|
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
tokenizer.pad_token = tokenizer.unk_token |
|
tokenizer.padding_side = "right" |
|
|
|
|
|
|
|
use_4bit = True |
|
|
|
|
|
bnb_4bit_compute_dtype = "float16" |
|
|
|
|
|
bnb_4bit_quant_type = "nf4" |
|
|
|
|
|
use_nested_quant = False |
|
|
|
|
|
|
|
|
|
compute_dtype = getattr(torch, bnb_4bit_compute_dtype) |
|
|
|
bnb_config = BitsAndBytesConfig( |
|
load_in_4bit=use_4bit, |
|
bnb_4bit_quant_type=bnb_4bit_quant_type, |
|
bnb_4bit_compute_dtype=compute_dtype, |
|
bnb_4bit_use_double_quant=use_nested_quant, |
|
) |
|
|
|
|
|
if compute_dtype == torch.float16 and use_4bit: |
|
major, _ = torch.cuda.get_device_capability() |
|
if major >= 8: |
|
print("=" * 80) |
|
print("Your GPU supports bfloat16: accelerate training with bf16=True") |
|
print("=" * 80) |
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
model_name, quantization_config=bnb_config) |
|
stop_list = [" \n\nAnswer:", " \n", " \n\n"] |
|
stop_token_ids = [tokenizer(x, return_tensors='pt', add_special_tokens=False)['input_ids'] for x in stop_list] |
|
stop_token_ids = [torch.LongTensor(x).to("cuda") for x in stop_token_ids] |
|
|
|
class StopOnTokens(StoppingCriteria): |
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: |
|
for stop_ids in stop_token_ids: |
|
if torch.eq(input_ids[0][-len(stop_ids[0])+1:], stop_ids[0][1:]).all(): |
|
return True |
|
return False |
|
|
|
stopping_criteria = StoppingCriteriaList([StopOnTokens()]) |
|
|
|
|
|
text_generation_pipeline = pipeline( |
|
model=model, |
|
tokenizer=tokenizer, |
|
task="text-generation", |
|
temperature=0.01, |
|
repetition_penalty=1.2, |
|
return_full_text=True, |
|
max_new_tokens=750, do_sample=True, |
|
top_k=50, top_p=0.95, |
|
stopping_criteria=stopping_criteria |
|
) |
|
mistral_llm = HuggingFacePipeline(pipeline=text_generation_pipeline) |
|
|
|
instructor_embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large") |
|
db3 = Chroma(persist_directory="chroma/", embedding_function=instructor_embeddings) |
|
|
|
|
|
|
|
|
|
retriever = db3.as_retriever(search_type="similarity_score_threshold", |
|
search_kwargs={"score_threshold": .5, |
|
"k": 20}) |
|
|
|
|
|
|
|
def format_docs(docs): |
|
return "\n\n".join(doc.page_content for doc in docs) |
|
|
|
|
|
template ="""" [INST] Ти асистент для надання відповідей з законодавства України. Використовуй лише вказаний нижче Context максимально точно. Описуй лише події простими словами без формальностей. Пиши чотири речення і будь максимально точним. Якщо контекст пустий - пиши "Я не маю релевантної інформації. Спробуйте ще". |
|
Context: {context} |
|
### QUESTION: |
|
{question} |
|
[/INST] |
|
""" |
|
prompt = PromptTemplate( |
|
input_variables=["context", "question"], |
|
template=template, |
|
) |
|
|
|
rag_chain_from_docs = ( |
|
RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"]))) |
|
| prompt |
|
| mistral_llm |
|
| StrOutputParser() |
|
) |
|
|
|
rag_chain_with_source = RunnableParallel( |
|
{"context": retriever, "question": RunnablePassthrough()} |
|
).assign(answer=rag_chain_from_docs) |
|
|
|
|
|
|
|
def format_result(result): |
|
unique_videos = set((doc.metadata['title'], doc.metadata['act_url']) for doc in result['context']) |
|
|
|
|
|
titles_with_links = [ |
|
f"{title}: {act_url}" for title, act_url in unique_videos |
|
] |
|
|
|
|
|
titles_string = '\n'.join(titles_with_links) |
|
titles_formatted = f"Використані закони:\n{titles_string}" |
|
|
|
|
|
answer = result['answer'] |
|
response = f"{answer}\n\n{titles_formatted}" |
|
|
|
return response |
|
|
|
|
|
|
|
def generate_with_filters(message, subject_input, rubric, date_beg, date_end): |
|
if date_beg == "2010-01-01" and date_end == "2025-01-01": |
|
rag_chain_with_filters = RunnableParallel( |
|
{"context": db3.as_retriever(search_type="mmr", search_kwargs={"k": 10, |
|
"filter": {'$and': [{'subject': { |
|
'$in': subject_input}}, { |
|
'rubric': { |
|
'$in': rubric}}]}}), |
|
"question": RunnablePassthrough()} |
|
).assign(answer=rag_chain_from_docs) |
|
else: |
|
rag_chain_with_filters = RunnableParallel( |
|
{"context": db3.as_retriever(search_type="mmr", search_kwargs={"k": 10, |
|
"filter": {'$and': [{'subject': { |
|
'$in': subject_input}}, { |
|
'rubric': { |
|
'$in': rubric}},{"act_date": {"$gte": date_beg}}, {"act_date": {"$lte": date_end}}] }}), "question": RunnablePassthrough()} |
|
).assign(answer=rag_chain_from_docs) |
|
result = rag_chain_with_filters.invoke(message) |
|
return result |
|
|
|
def summarize_act(message, act_number): |
|
template = """" [INST] Ти асистент для надання відповідей з законодавства України.На вхід ти отримав один закон, підсуму його простими словами, викинь формальності та стандартні фрази. Додай усі зміни, які згадуються і цьому документі і опиши їх трьома реченнями. Якщо контекст пустий - пиши "Я не маю релевантної інформації. Спробуйте ще". |
|
Context: {context} |
|
### QUESTION: |
|
{question} |
|
[/INST] |
|
""" |
|
prompt = PromptTemplate( |
|
input_variables=["context", "question"], |
|
template=template, |
|
) |
|
|
|
rag_chain_from_docs = ( |
|
RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"]))) |
|
| prompt |
|
| mistral_llm |
|
| StrOutputParser() |
|
) |
|
|
|
rag_chain_summarize = RunnableParallel( |
|
{"context": db3.as_retriever( search_kwargs={"k": 150, "filter": {'act_number': { |
|
'$eq': act_number}}}), "question": RunnablePassthrough()} |
|
).assign(answer=rag_chain_from_docs) |
|
return rag_chain_summarize.invoke("") |
|
|
|
|
|
def generate_answer(message, history, checkbox,subject_input, rubric, date_beg, date_end, act_number): |
|
result = "" |
|
if checkbox: |
|
if act_number=="": |
|
return "Будь ласка, введіть номер акту для отримання основної інформації з нього, або зніміть відповідний прапорець." |
|
result = summarize_act(message, act_number) |
|
if subject_input is None and rubric is None and date_beg == "2010-01-01" and date_end == "2025-01-01": |
|
result = rag_chain_with_source.invoke(message) |
|
|
|
else: |
|
if subject_input is None or subject_input ==[]: |
|
subject_input = ["Президент України", "Кабінет міністрів України", "Народний депутат України"] |
|
if rubric is None or rubric== []: |
|
rubric = ['Двосторонні міжнародні угоди', 'Багатосторонні міжнародні угоди', |
|
'Галузевий розвиток', 'Економічна політика', |
|
'Державне будівництво', 'Соціальна політика', 'Правова політика', |
|
'Безпека і оборона', 'Гуманітарна політика'] |
|
result = geherate_with_filters(message) |
|
|
|
result['answer'] =result['answer'].split("[/INST]")[-1].strip() |
|
formatted_results = format_result(result) |
|
return formatted_results |
|
|
|
|
|
|
|
def change_group(check_value): |
|
if check_value : |
|
return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] |
|
else: |
|
return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), |
|
gr.update(visible=False)] |
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Soft()) as demo : |
|
|
|
with gr.Group() as group_components : |
|
date_beg =Calendar(type="string", label="Почакова дата пошуку", info="Натисніть на значок календаря для вибору дати", render = False, value="2010-01-01") |
|
date_end = Calendar(type="string", label="Кінцева дата пошуку", info="Натисніть на значок календаря для вибору дати", render = False, value="2025-01-01") |
|
|
|
subject_input = gr.Dropdown( |
|
["Президент України", "Кабінет міністрів України", "Народний депутат України"], multiselect=True, label="Ініціатор", info="Виберіть ініціатора законопроєкту", render=False) |
|
rubric = gr.Dropdown(['Двосторонні міжнародні угоди', 'Багатосторонні міжнародні угоди', |
|
'Галузевий розвиток', 'Економічна політика', |
|
'Державне будівництво', 'Соціальна політика', 'Правова політика', |
|
'Безпека і оборона', 'Гуманітарна політика'], multiselect=True, label='Тематика', info="Оберіть, яких галузей стосується законопроєкт", render=False) |
|
|
|
|
|
act_number = gr.Textbox(label='Номер законодавчого акту', placeholder="Наприклад: 861-20 ",visible= False, render=False) |
|
|
|
|
|
action_checkbox = gr.Checkbox(label="Хочу отримати підсумок одного документу", value=False, render=False) |
|
action_checkbox.input(fn=change_group, inputs= [action_checkbox], outputs = [subject_input, date_beg, date_end, rubric, act_number]) |
|
gr.ChatInterface( |
|
|
|
generate_answer, |
|
chatbot=gr.Chatbot(height=400, render = False), |
|
textbox = gr.Textbox(placeholder="Ввести питання", container=False, scale=7, render = False), |
|
title="Законодавчий Помічник", |
|
description="Спитай мене про будь-які регуляції в чинних законах України.", |
|
|
|
|
|
cache_examples=False, |
|
retry_btn=None, |
|
undo_btn=None, |
|
clear_btn=None, |
|
submit_btn="Спитати", |
|
stop_btn=None, |
|
additional_inputs=[ |
|
action_checkbox, |
|
subject_input, |
|
rubric, |
|
date_beg, |
|
date_end, |
|
act_number ], |
|
additional_inputs_accordion=gr.Accordion(open=False, label="Додаткові фільтри", render = False) |
|
|
|
) |
|
|
|
|
|
|
|
demo.launch(share = False) |
|
|
|
|
|
|