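"""Streamlit entry point for LinguAIcoach, an AI-powered English teacher app."""
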
import logging
import os

import streamlit as st
from langchain.schema import AIMessage, HumanMessage
from langchain_core.pydantic_v1 import SecretStr

from src.models.model_factory import EvaluationChatModelFactory, GeneratorModelFactory
from src.utils import convert_json_to_str, read_json, read_plain_text
from src.whisper_transcription import whisper_stt

# The log level is read from the ENV variable and must be a valid logging level
# name such as "DEBUG" or "INFO"; it defaults to INFO when unset.
logging.basicConfig(level=os.getenv("ENV", "INFO"))
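
# Page header: logo, title and tagline.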
start_container = st.container()
with start_container:
    col1_head, col2_head = st.columns([0.15, 0.85])
    col1_head.image("./docs/images/logo.png", width=80)
    col2_head.title("LinguAIcoach")
    st.subheader("_Your :red[AI] English Teacher_")
    st.divider()
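
# Factories that build the answer-evaluation and exercise-generation models.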
model_factory = EvaluationChatModelFactory()
gen_factory = GeneratorModelFactory()

exam_guides = read_json("./exam_guides/lessons.json")
config = read_json("./src/config.json")

input_text = None
input_voice = None
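
# Sidebar: exam type selection, OpenAI API key entry and project links.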
with st.sidebar:
    exam_selection = st.selectbox("Select Exam type", list(exam_guides.keys()))
    exam_selection = exam_selection if exam_selection else str(next(iter(exam_guides.keys())))
    exam_info = exam_guides[exam_selection]
    # Prefer the OPENAI_API_KEY env var; only render the text input as a fallback,
    # instead of evaluating it eagerly as a getenv() default argument.
    st.session_state.openai_api_key = SecretStr(
        os.getenv("OPENAI_API_KEY")
        or st.text_input("OpenAI API Key", key="chatbot_api_key", type="password")
    )
    st.markdown("[:red[Get your OpenAI API key]](https://platform.openai.com/account/api-keys)")
    st.markdown(
        "[![GitHub](https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white)](https://github.com/alvaroalon2/LinguAIcoach)"
    )

logging.debug(f"Selected exam: {exam_selection}")
logging.debug(f"Exam guides: {exam_guides}")
exam_content = read_plain_text(exam_info["file"])
exam_desc = exam_info["description"]
init_prompt = exam_info["init_prompt"]
level = exam_info["level"]
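
# Initialise per-session state and reset the conversation when the exam type changes.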
if "current_exam" not in st.session_state:
st.session_state["current_exam"] = ""
if st.session_state["current_exam"] != exam_selection:
logging.debug("Resetting state")
input_text = None
input_voice = None
st.session_state["current_exam"] = exam_selection
st.session_state["messages"] = [AIMessage(content=init_prompt)]
st.session_state.first_run = True
st.session_state["image_url"] = None
st.session_state["question"] = []
st.session_state["exam_type"] = exam_info["type"]
response_container = st.container()
start_container = st.container()
with start_container:
    col1_start, col2_start, col3_start = st.columns([0.4, 0.2, 0.4])
voice_container = st.container()
with voice_container:
    col1_voice, col2_voice, col3_voice = st.columns([0.39, 0.22, 0.39])

# SecretStr instances are always truthy, so check the wrapped value instead.
if not st.session_state.openai_api_key.get_secret_value():
    logging.warning("Please add your OpenAI API key to continue.")
    st.info("Please add your OpenAI API key to continue.")
    st.stop()
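
# Build the evaluation model (grades the user's answers) and the generator
# (produces the next question or image) for the selected exam type.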
model_eval = model_factory.create_model(
    model_class=st.session_state["exam_type"],
    openai_api_key=st.session_state.openai_api_key,
    level=level,
    chat_temperature=config["OPENAI_TEMPERATURE_EVAL"],
)

generator = gen_factory.create_model(
    model_class=st.session_state["exam_type"],
    openai_api_key=st.session_state.openai_api_key,
    exam_prompt=exam_content,
    level=level,
    description=exam_desc,
    chat_temperature=config["OPENAI_TEMPERATURE_GEN"],
    history_chat=[
        AIMessage(content=f"Previous question (Don't repeat): {q.content}")
        for q in st.session_state["question"][-config["N_MAX_HISTORY"] :]
    ],
    img_size=config["IMG_GEN_SIZE"],
)
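
# Main exercise flow: start button, history replay, user input and evaluation.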
if "messages" in st.session_state:
logging.debug(f"Starting exercises for exam_type: {st.session_state['exam_type']}")
placeholder_start = col2_start.empty()
start_button = placeholder_start.button("Start exercises!", disabled=not (st.session_state.first_run))
if start_button:
logging.debug("Start button clicked, running exercise")
st.session_state.first_run = False
if st.session_state["exam_type"] == "qa":
start_response = generator.generate()
logging.info(f"Generated first question: {start_response}")
st.session_state["question"].append(AIMessage(content=start_response))
st.session_state["messages"].append(st.session_state["question"][-1])
elif st.session_state["exam_type"] == "img_desc":
st.session_state["image_url"] = generator.generate()
logging.debug(f"Generated first image URL: {st.session_state['image_url']}")
st.session_state["messages"].append(
AIMessage(
content=st.session_state["image_url"],
response_metadata={"type": "image"},
)
)

    if not st.session_state.first_run:
        placeholder_start.empty()
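
    # Replay the conversation history, rendering image messages inline.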
    for msg in st.session_state.messages:
        if isinstance(msg, HumanMessage):
            if getattr(msg, "response_metadata", None) and msg.response_metadata.get("type") == "image":
                with response_container.chat_message("user"):
                    response_container.image(str(msg.content), caption="AI generated image")
            else:
                response_container.chat_message("user").write(msg.content)
        elif getattr(msg, "response_metadata", None) and msg.response_metadata.get("type") == "image":
            with response_container.chat_message("assistant"):
                response_container.write("Describe what you can see in the following image: ")
                response_container.image(msg.content, caption="AI generated image")
        else:
            response_container.chat_message("assistant").write(msg.content)
    placeholder_input = st.empty()
    if st.session_state.first_run:
        placeholder_input.empty()
        col2_voice.empty()
    else:
        input_text = placeholder_input.chat_input(disabled=st.session_state.first_run) or None
        logging.debug(f"Input text: {input_text}")
        with col2_voice:
            input_voice = whisper_stt(language="en", n_max_retry=config["N_MAX_RETRY"])
            logging.debug(f"Input voice: {input_voice}")

    if input_prompt := input_text or input_voice:
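        # Route the answer to the evaluator for the current exercise type.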
logging.info(f"Processing input: {input_prompt}")
match st.session_state["exam_type"]:
case "qa":
response = model_eval.predict(input_prompt, st.session_state["question"][-1].content)
logging.info(f"QA model response: {response}")
case "img_desc":
response = model_eval.predict(input_prompt, st.session_state["image_url"])
logging.info(f"Image description model response: {response}")
response_str = convert_json_to_str(response)
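
        # Render this exchange and generate the next exercise.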
        with response_container:
            st.session_state.messages.append(HumanMessage(content=input_prompt))
            logging.debug(f"Adding user message to session state: {input_prompt}")
            response_container.chat_message("user").write(input_prompt)
            st.session_state.messages.append(AIMessage(content=response_str))
            logging.debug(f"Adding AI message to session state: {response_str}")
            response_container.chat_message("assistant").write(response_str)

            if st.session_state["exam_type"] == "qa":
                new_question = generator.generate()
                logging.info(f"Generated new question: {new_question}")
                st.session_state["question"].append(AIMessage(content=new_question))
                st.session_state.messages.append(AIMessage(content=new_question))
                response_container.chat_message("assistant").write(new_question)
            elif st.session_state["exam_type"] == "img_desc":
                new_image = generator.generate()
                st.session_state["image_url"] = new_image
                # Tag the message as an image so the history replay renders it correctly.
                st.session_state.messages.append(
                    AIMessage(content=new_image, response_metadata={"type": "image"})
                )
                logging.info(f"Generated new image URL: {st.session_state['image_url']}")
                response_container.chat_message("assistant").write("Describe what you can see in the following image: ")
                with response_container.chat_message("assistant"):
                    response_container.image(st.session_state["image_url"], caption="AI generated image")