|
import streamlit as st |
|
import pandas as pd |
|
import bm25s |
|
from bm25s.hf import BM25HF |
|
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate |
|
from langchain.docstore.document import Document |
|
import torch |
|
from langchain_ollama.llms import OllamaLLM |
|
from langchain.chains import LLMChain |
|
import os |
|
from huggingface_hub import login |
|
from langchain_groq import ChatGroq |
|
|
|
|
|
@st.cache_resource
def load_data():
    """Load the BM25 retriever from the Hugging Face Hub.

    The index is pulled from the ``tien314/bm25s-version2`` repository
    together with its corpus (``load_corpus=True``) and opened
    memory-mapped (``mmap=True``). The function is wrapped in
    ``st.cache_resource``.

    Returns:
        The ``BM25HF`` retriever returned by ``BM25HF.load_from_hub``.
    """
    return BM25HF.load_from_hub(
        "tien314/bm25s-version2",
        load_corpus=True,
        mmap=True,
    )
|
|
|
def load_model():
    """Build the ``prompt | llm`` chain that predicts an 8-digit HS Code.

    The chain expects two input variables when invoked:
    ``context`` (the retrieved documents) and ``description`` (the product
    description entered by the user).

    The Groq credential is read from the ``GROQ_API_KEY`` environment
    variable. It must never be committed to source control; the key that
    used to be hard-coded here should be considered leaked and revoked.

    Returns:
        The runnable produced by piping the chat prompt into ``ChatGroq``.

    Raises:
        RuntimeError: If ``GROQ_API_KEY`` is unset or empty.
    """
    # Fail fast, before building anything, so a missing credential yields a
    # clear message instead of an opaque authentication error on first use.
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError(
            "GROQ_API_KEY environment variable is not set; "
            "export it before starting the app."
        )

    # Plain (non-f) string: {context} and {description} are template
    # variables filled in by LangChain at invoke time.
    # The example is 8 digits so it agrees with the instruction above it.
    prompt = ChatPromptTemplate.from_messages([
        HumanMessagePromptTemplate.from_template(
            """
            Extract the appropriate 8-digit HS Code base on the product description and retrieved document by thoroughly analyzing its details and utilizing a reliable and up-to-date HS Code database for accurate results.
            Only return the HS Code as a 8-digit number .
            Example: 12345678
            Context: {context}
            Description: {description}
            Answer:
            """
        )
    ])

    # temperature=0 keeps the answer as deterministic as the backend allows.
    # NOTE(review): confirm this model id is still served by Groq.
    llm = ChatGroq(
        model="llama-3.1-70b-versatile",
        temperature=0,
        api_key=api_key,
    )
    return prompt | llm
|
|
|
def process_input(sentence):
    """Retrieve the documents that best match ``sentence``.

    The sentence is tokenized with ``bm25s.tokenize`` and sent to the
    retriever stored in ``st.session_state.retriever`` with ``k=15``. The
    ``'text'`` field of each hit for the (single) query is wrapped in a
    LangChain ``Document``.

    Args:
        sentence: Product description typed by the user.

    Returns:
        A list of ``Document`` objects, one per retrieved hit.
    """
    query_tokens = bm25s.tokenize(sentence)
    hits, _ = st.session_state.retriever.retrieve(query_tokens, k=15)
    # Only one query was submitted, so its hits are in the first row.
    return [Document(hit['text']) for hit in hits[0]]
|
|
|
# Streamlit re-executes this script on every interaction, so the retriever and
# the chain are kept in session state and only built when absent.
if st.session_state.get('retriever') is None:
    st.session_state.retriever = load_data()

if st.session_state.get('chain') is None:
    st.session_state.chain = load_model()

sentence = st.text_input("please enter description:")

# Ignore empty and whitespace-only input: there is nothing to retrieve and no
# reason to spend an LLM call on it.
description = sentence.strip()
if description:
    documents = process_input(description)

    # Feed the model the raw document text rather than the Python repr of a
    # list of Document objects.
    context = "\n".join(doc.page_content for doc in documents)

    response = st.session_state.chain.invoke(
        {'context': context, 'description': description}
    )

    # A chat model returns a message object; show only its text. Fall back to
    # the raw value in case the chain yields a plain string.
    hscode = getattr(response, "content", response)
    st.write("answer:", hscode)
|
|