import streamlit as st
import bm25s
from bm25s.hf import BM25HF
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.docstore.document import Document
import os
from langchain_groq import ChatGroq


@st.cache_resource
def load_data():
    # Load the prebuilt BM25 index and its corpus from the Hugging Face Hub.
    retriever = BM25HF.load_from_hub(
        "tien314/bm25s-version2", load_corpus=True, mmap=True)
    return retriever


def load_model():
    prompt = ChatPromptTemplate.from_messages([
        HumanMessagePromptTemplate.from_template(
            """
            Extract the appropriate 8-digit HS Code based on the product description and the retrieved documents by thoroughly analyzing their details and using a reliable, up-to-date HS Code database for accurate results.
            Only return the HS Code as an 8-digit number.
            Example: 12345678
            Context: {context}
            Description: {description}
            Answer:
            """
        )
    ])
    # Read the Groq API key from the environment rather than hardcoding it in the source.
    api_key = os.environ.get("GROQ_API_KEY")
    llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, api_key=api_key)
    chain = prompt | llm
    return chain


def process_input(sentence):
    # Retrieve the top 15 corpus entries for the query and wrap them as LangChain Documents.
    docs, _ = st.session_state.retriever.retrieve(bm25s.tokenize(sentence), k=15)
    documents = []
    for doc in docs[0]:
        documents.append(Document(page_content=doc['text']))
    return documents


# Keep the retriever and the chain in the Streamlit session state so they are built only once.
if 'retriever' not in st.session_state:
    st.session_state.retriever = None
if 'chain' not in st.session_state:
    st.session_state.chain = None

if st.session_state.retriever is None:
    st.session_state.retriever = load_data()
if st.session_state.chain is None:
    st.session_state.chain = load_model()

sentence = st.text_input("Please enter a product description:")

if sentence != '':
    documents = process_input(sentence)
    hscode = st.session_state.chain.invoke({'context': documents, 'description': sentence})
    st.write("Answer:", hscode.content)
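
# Usage note (a sketch, not part of the original script): assuming this file is saved as
# app.py and the Groq key is exported as GROQ_API_KEY (both names are assumptions here),
# the app can be started locally with:
#
#   export GROQ_API_KEY="..."
#   streamlit run app.py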