# hscode8 / app.py
# tien314's picture
# Update app.py
# 9f2a90c verified
# raw
# history blame
# 2.22 kB
import streamlit as st
import pandas as pd
import bm25s
from bm25s.hf import BM25HF
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.docstore.document import Document
import torch
from langchain_ollama.llms import OllamaLLM
from langchain.chains import LLMChain
import os
from huggingface_hub import login
from langchain_groq import ChatGroq
@st.cache_resource
def load_data():
    """Load the BM25 retriever published at ``tien314/bm25s-version2``.

    The corpus is loaded together with the index (``load_corpus=True``) and
    memory-mapping is requested (``mmap=True``). The function is wrapped in
    ``st.cache_resource``, so Streamlit reuses the returned object instead of
    calling the loader again on every script rerun.

    Returns:
        The retriever object returned by ``BM25HF.load_from_hub``.
    """
    return BM25HF.load_from_hub(
        "tien314/bm25s-version2",
        load_corpus=True,
        mmap=True,
    )
def load_model():
    """Build the prompt -> Groq chat model chain used to predict HS codes.

    The prompt exposes two variables, ``context`` (retrieved documents) and
    ``description`` (the user's product description), and instructs the model
    to answer with an 8-digit HS code only.

    The Groq API key is read from the ``GROQ_API_KEY`` environment variable
    rather than being embedded in the source, so the secret never ends up in
    version control.

    Returns:
        A runnable chain (``prompt | llm``) that accepts a mapping with the
        keys ``context`` and ``description``.

    Raises:
        RuntimeError: If ``GROQ_API_KEY`` is not set or is empty.
    """
    # Plain (non f-string) template: the single braces are the prompt
    # variables filled in by LangChain at invoke time.
    template = (
        "\n"
        "Extract the appropriate 8-digit HS Code base on the product description and retrieved document by thoroughly analyzing its details and utilizing a reliable and up-to-date HS Code database for accurate results.\n"
        "Only return the HS Code as a 8-digit number .\n"
        # The example must itself be 8 digits, otherwise it contradicts the
        # instruction above and can steer the model to a wrong length.
        "Example: 12345678\n"
        "Context: {context}\n"
        "Description: {description}\n"
        "Answer:\n"
    )
    prompt = ChatPromptTemplate.from_messages(
        [HumanMessagePromptTemplate.from_template(template)]
    )

    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError(
            "GROQ_API_KEY environment variable is not set; "
            "it is required to call the Groq API."
        )

    # temperature=0 keeps the predicted code as deterministic as possible.
    llm = ChatGroq(
        model="llama-3.1-70b-versatile",
        temperature=0,
        api_key=api_key,
    )
    return prompt | llm
def process_input(sentence):
    """Retrieve the documents that best match ``sentence``.

    The sentence is tokenized with ``bm25s.tokenize`` and sent to the
    retriever stored in ``st.session_state.retriever``, asking for 15 hits.
    The ``text`` field of every hit in the first result row is wrapped in a
    ``Document``.

    Args:
        sentence: Free-text product description to search for.

    Returns:
        A list of ``Document`` objects, one per retrieved hit.
    """
    query_tokens = bm25s.tokenize(sentence)
    hits, _scores = st.session_state.retriever.retrieve(query_tokens, k=15)
    # hits[0]: presumably the result row for the single query passed in.
    return [Document(hit['text']) for hit in hits[0]]
# Streamlit reruns this script on every interaction, so the retriever and the
# chain are kept in the session state and only built when missing.
if 'retriever' not in st.session_state or st.session_state.retriever is None:
    st.session_state.retriever = load_data()
if 'chain' not in st.session_state or st.session_state.chain is None:
    st.session_state.chain = load_model()

sentence = st.text_input("please enter description:")

# Ignore empty and whitespace-only input: there is nothing to retrieve for it
# and an LLM call would be wasted.
if sentence.strip():
    documents = process_input(sentence)
    # Pass the retrieved text itself to the prompt; formatting the list of
    # Document objects directly would embed their Python repr instead.
    context = "\n\n".join(doc.page_content for doc in documents)
    response = st.session_state.chain.invoke(
        {'context': context, 'description': sentence}
    )
    # Chat models return a message object; show only its text. Plain-string
    # results (e.g. from a non-chat LLM) are displayed unchanged.
    hscode = getattr(response, "content", response)
    st.write("answer:", hscode)