import logging import os import streamlit as st from langchain.schema import AIMessage, HumanMessage from langchain_core.pydantic_v1 import SecretStr from src.models.model_factory import EvaluationChatModelFactory, GeneratorModelFactory from src.utils import convert_json_to_str, read_json, read_plain_text from src.whisper_transcription import whisper_stt logging.basicConfig(level=os.getenv("ENV", "INFO")) start_container = st.container() with start_container: col1_head, col2_head = st.columns([0.15, 0.85]) col1_head.image("./docs/images/logo.png", width=80) col2_head.title("LinguAIcoach") st.subheader("_Your :red[AI] English Teacher_") st.divider() model_factory = EvaluationChatModelFactory() gen_factory = GeneratorModelFactory() exam_guides = read_json("./exam_guides/lessons.json") config = read_json("./src/config.json") input_text = None input_voice = None with st.sidebar as sidebar: exam_selection = st.sidebar.selectbox("Select Exam type", list(exam_guides.keys())) exam_selection = exam_selection if exam_selection else str(next(iter(exam_guides.keys()))) exam_info = exam_guides[exam_selection] st.session_state.openai_api_key = SecretStr( os.getenv( "OPENAI_API_KEY", st.text_input("OpenAI API Key", key="chatbot_api_key", type="password"), ) ) st.markdown("[:red[Get your OpenAI API key]](https://platform.openai.com/account/api-keys)") st.markdown( "[![GitHub](https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white)](https://github.com/alvaroalon2/LinguAIcoach)" ) logging.debug(f"Selected exam: {exam_selection}") logging.debug(f"Exam guides: {exam_guides}") exam_content = read_plain_text(exam_info["file"]) exam_desc = exam_info["description"] init_prompt = exam_info["init_prompt"] level = exam_info["level"] if "current_exam" not in st.session_state: st.session_state["current_exam"] = "" if st.session_state["current_exam"] != exam_selection: logging.debug("Resetting state") input_text = None input_voice = None st.session_state["current_exam"] = exam_selection st.session_state["messages"] = [AIMessage(content=init_prompt)] st.session_state.first_run = True st.session_state["image_url"] = None st.session_state["question"] = [] st.session_state["exam_type"] = exam_info["type"] response_container = st.container() start_container = st.container() with start_container: col1_start, col2_start, col3_start = st.columns([0.4, 0.2, 0.4]) voice_container = st.container() with voice_container: col1_voice, col2_voice, col3_voice = st.columns([0.39, 0.22, 0.39]) if not st.session_state.openai_api_key: logging.warning("Please add your OpenAI API key to continue.") st.info("Please add your OpenAI API key to continue.") st.stop() model_eval = model_factory.create_model( model_class=st.session_state["exam_type"], openai_api_key=st.session_state.openai_api_key, level=level, chat_temperature=config["OPENAI_TEMPERATURE_EVAL"], ) generator = gen_factory.create_model( model_class=st.session_state["exam_type"], openai_api_key=st.session_state.openai_api_key, exam_prompt=exam_content, level=level, description=exam_desc, chat_temperature=config["OPENAI_TEMPERATURE_GEN"], history_chat=[ AIMessage(content=f"Previous question (Don't repeat): {q.content}") for q in st.session_state["question"][-config["N_MAX_HISTORY"] :] ], img_size=config["IMG_GEN_SIZE"], ) if "messages" in st.session_state: logging.debug(f"Starting exercises for exam_type: {st.session_state['exam_type']}") placeholder_start = col2_start.empty() start_button = placeholder_start.button("Start exercises!", disabled=not (st.session_state.first_run)) if start_button: logging.debug("Start button clicked, running exercise") st.session_state.first_run = False if st.session_state["exam_type"] == "qa": start_response = generator.generate() logging.info(f"Generated first question: {start_response}") st.session_state["question"].append(AIMessage(content=start_response)) st.session_state["messages"].append(st.session_state["question"][-1]) elif st.session_state["exam_type"] == "img_desc": st.session_state["image_url"] = generator.generate() logging.debug(f"Generated first image URL: {st.session_state['image_url']}") st.session_state["messages"].append( AIMessage( content=st.session_state["image_url"], response_metadata={"type": "image"}, ) ) if not st.session_state.first_run: placeholder_start.empty() for msg in st.session_state.messages: if isinstance(msg, HumanMessage): if getattr(msg, "response_metadata", None) and msg.response_metadata["type"] == "image": with response_container.chat_message("user"): response_container.image(str(msg.content), caption="AI generated image") else: response_container.chat_message("user").write(msg.content) elif getattr(msg, "response_metadata", None) and msg.response_metadata["type"] == "image": with response_container.chat_message("assistant"): response_container.write("Describe what you can see in the following image: ") response_container.image(msg.content, caption="AI generated image") else: response_container.chat_message("assistant").write(msg.content) placeholder_input = st.empty() if st.session_state.first_run: placeholder_input.empty() col2_voice.empty() else: input_text = placeholder_input.chat_input(disabled=st.session_state.first_run) or None logging.debug(f"Input text: {input_text}") with col2_voice: input_voice = whisper_stt(language="en", n_max_retry=config["N_MAX_RETRY"]) logging.debug(f"Input voice: {input_voice}") input_prompt = input_text or input_voice if input_prompt := input_text or input_voice: logging.info(f"Processing input: {input_prompt}") match st.session_state["exam_type"]: case "qa": response = model_eval.predict(input_prompt, st.session_state["question"][-1].content) logging.info(f"QA model response: {response}") case "img_desc": response = model_eval.predict(input_prompt, st.session_state["image_url"]) logging.info(f"Image description model response: {response}") response_str = convert_json_to_str(response) with response_container: st.session_state.messages.append(HumanMessage(content=input_prompt)) logging.debug(f"Adding user message to session state: {input_prompt}") response_container.chat_message("user").write(input_prompt) st.session_state.messages.append(AIMessage(content=response_str)) logging.debug(f"Adding AI message to session state: {response_str}") response_container.chat_message("assistant").write(response_str) if st.session_state["exam_type"] == "qa": new_question = generator.generate() logging.info(f"Generated new question: {new_question}") st.session_state["question"].append(AIMessage(content=new_question)) st.session_state.messages.append(AIMessage(content=new_question)) response_container.chat_message("assistant").write(new_question) elif st.session_state["exam_type"] == "img_desc": new_image = generator.generate() st.session_state["image_url"] = new_image st.session_state.messages.append(AIMessage(content=new_image)) logging.info(f"Generated new image URL: {st.session_state['image_url']}") response_container.chat_message("assistant").write("Describe what you can see in the following image: ") with response_container.chat_message("assistant"): response_container.image(st.session_state["image_url"], caption="AI generated image")