"""Streamlit helpers for reporting LLM API token usage and cost."""

import streamlit as st
from langchain_core.messages import AIMessage


def _compute_token_usage(aimessage: AIMessage, model_info: dict) -> dict:
    """Return token counts and cost for ``aimessage`` priced by ``model_info``.

    Args:
        aimessage: Model response whose ``usage_metadata`` provides the
            ``"input_tokens"`` and ``"output_tokens"`` counts.
        model_info: Mapping where ``model_info["cost"]["pmi"]`` and
            ``model_info["cost"]["pmo"]`` are the input and output prices
            per one million tokens (they are scaled by 1e-6 below).

    Returns:
        Dict with keys ``"input_tokens"``, ``"output_tokens"`` and ``"cost"``.

    Raises:
        ValueError: If ``aimessage`` carries no usage metadata.
    """
    usage = aimessage.usage_metadata
    if usage is None:
        # Without this check the failure is an opaque
        # "'NoneType' object is not subscriptable".
        raise ValueError("AIMessage has no usage_metadata; cannot compute token usage")

    input_tokens = usage["input_tokens"]
    output_tokens = usage["output_tokens"]
    prices = model_info["cost"]
    cost = (
        input_tokens * 1e-6 * prices["pmi"]
        + output_tokens * 1e-6 * prices["pmo"]
    )
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "cost": cost,
    }


def get_openai_token_usage(aimessage: AIMessage, model_info: dict) -> dict:
    """Return token usage and cost for an OpenAI response."""
    return _compute_token_usage(aimessage, model_info)


def get_anthropic_token_usage(aimessage: AIMessage, model_info: dict) -> dict:
    """Return token usage and cost for an Anthropic response."""
    return _compute_token_usage(aimessage, model_info)


def get_together_token_usage(aimessage: AIMessage, model_info: dict) -> dict:
    """Return token usage and cost for a Together response."""
    return _compute_token_usage(aimessage, model_info)


def get_google_token_usage(aimessage: AIMessage, model_info: dict) -> dict:
    """Return token usage and cost for a Google response."""
    return _compute_token_usage(aimessage, model_info)


# Provider name -> usage calculator. Every provider currently shares the same
# usage_metadata-based pricing. The per-provider entries are kept so a single
# provider can diverge without changing callers.
_TOKEN_USAGE_BY_PROVIDER = {
    "OpenAI": get_openai_token_usage,
    "Anthropic": get_anthropic_token_usage,
    "Together": get_together_token_usage,
    "Google": get_google_token_usage,
}


def get_token_usage(aimessage: AIMessage, model_info: dict, provider: str) -> dict:
    """Dispatch to the token-usage calculator registered for ``provider``.

    Args:
        aimessage: Model response to inspect.
        model_info: Pricing info; see ``_compute_token_usage``.
        provider: One of ``"OpenAI"``, ``"Anthropic"``, ``"Together"``,
            ``"Google"`` (case-sensitive).

    Returns:
        Dict with keys ``"input_tokens"``, ``"output_tokens"`` and ``"cost"``.

    Raises:
        ValueError: If ``provider`` is not supported, or the message has no
            usage metadata.
    """
    try:
        usage_fn = _TOKEN_USAGE_BY_PROVIDER[provider]
    except KeyError:
        raise ValueError(
            f"Unsupported provider {provider!r}; expected one of "
            f"{sorted(_TOKEN_USAGE_BY_PROVIDER)}"
        ) from None
    return usage_fn(aimessage, model_info)


def display_api_usage(
    aimessage: AIMessage, model_info: dict, provider: str, tag: str | None = None
) -> None:
    """Render a bordered Streamlit panel with token counts, cost and metadata.

    Args:
        aimessage: Model response whose usage is displayed.
        model_info: Pricing info; see ``_compute_token_usage``.
        provider: Provider name passed through to ``get_token_usage``.
        tag: Optional label appended to the panel title as ``(tag)``.

    Raises:
        ValueError: Propagated from ``get_token_usage``.
    """
    with st.container(border=True):
        st.write("API Usage" if tag is None else f"API Usage ({tag})")
        token_usage = get_token_usage(aimessage, model_info, provider)

        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("Input Tokens", token_usage["input_tokens"])
        with col2:
            st.metric("Output Tokens", token_usage["output_tokens"])
        with col3:
            st.metric("Cost", f"${token_usage['cost']:.4f}")

        with st.expander("AIMessage Metadata"):
            # Show every message field except the message text itself.
            # NOTE(review): `.dict()` is the pydantic v1-style API; on pydantic
            # v2-based langchain_core it is deprecated in favour of
            # `model_dump()`. Confirm the installed version before switching.
            dd = {key: val for key, val in aimessage.dict().items() if key != "content"}
            st.write(dd)