"""Streamlit app: extract an 8-digit HS code from a product description.

Pipeline: BM25 lexical retrieval over a prebuilt index -> prompt a local
Ollama model (gemma2) with the retrieved context -> display the code.
"""

import bm25s
import pandas as pd  # NOTE(review): unused here — remove if nothing else needs it
import streamlit as st  # was missing: `st` is used throughout the script
import torch
from langchain.chains import LLMChain  # NOTE(review): unused import
from langchain.docstore.document import Document
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_ollama.llms import OllamaLLM


@st.cache_data
def load_data():
    """Load the prebuilt BM25 index (memory-mapped, with its corpus)."""
    retriever = bm25s.BM25.load("bm25s_very_big_index", mmap=True, load_corpus=True)
    return retriever


def load_model():
    """Build the prompt | LLM chain used to extract the HS code."""
    # Plain (non-f) string: {context} / {description} are template variables
    # that ChatPromptTemplate fills in at invoke time.
    prompt = ChatPromptTemplate.from_messages([
        HumanMessagePromptTemplate.from_template(
            """
            Extract the appropriate 8-digit HS Code based on the product
            description and retrieved document by thoroughly analyzing its
            details and utilizing a reliable and up-to-date HS Code database
            for accurate results.
            Only return the HS Code as an 8-digit number.
            Example: 12345678
            Context: {context}
            Description: {description}
            Answer:
            """
        )
    ])
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): `device` is not a documented OllamaLLM parameter —
    # Ollama manages GPU placement server-side; verify this kwarg is accepted.
    llm = OllamaLLM(model="gemma2", temperature=0, device=device)
    chain = prompt | llm
    return chain


def process_input(sentence):
    """Retrieve the top-15 corpus documents matching *sentence*.

    Returns a list of langchain Document objects built from each hit's
    'text' field.
    """
    # Fix: the retriever lives in st.session_state, not in a module global.
    retriever = st.session_state.retriever
    docs, _ = retriever.retrieve(bm25s.tokenize(sentence), k=15)
    documents = []
    for doc in docs[0]:
        documents.append(Document(doc['text']))
    return documents


# --- Streamlit page flow (runs top-to-bottom on every interaction) ---

if 'retriever' not in st.session_state:
    st.session_state.retriever = None
if 'chain' not in st.session_state:
    st.session_state.chain = None

if st.session_state.retriever is None:
    st.session_state.retriever = load_data()
if st.session_state.chain is None:
    st.session_state.chain = load_model()

sentence = st.text_input("please enter description:")
if sentence != '':
    documents = process_input(sentence)
    # Fix: the chain is stored in session state; bare `chain` was undefined here.
    hscode = st.session_state.chain.invoke(
        {'context': documents, 'description': sentence}
    )
    st.write("answer:", hscode)