|
"""Demo UI to show different levels of LLM security.""" |
|
|
|
import re |
|
|
|
import pandas as pd |
|
from llm_guard.input_scanners import PromptInjection |
|
import streamlit as st |
|
|
|
import config |
|
import utils |
|
import llm |
|
from card import card |
|
|
|
|
|
# Background tints (RGBA strings) passed to `card(color=...)` for the three
# hint cards and for the explanation/info card respectively.
hint_color = "rgba(225, 166, 28, 0.1)"
info_color = "rgba(54, 225, 28, 0.1)"

# Browser-tab title, wide layout and an initially expanded sidebar.
st.set_page_config(
    page_title="Secret Agent Guardrail Challenge",
    layout="wide",
    initial_sidebar_state="expanded",
)

# App header: logo, title and the premise of the game.
st.logo("images/ML6_logo.png")
# NOTE(review): the emoji in this title looks mis-encoded in the reviewed copy
# (its bytes suggest "🕵️"); confirm the file is saved as UTF-8.
st.title("π΅οΈ Secret Agent Guardrail Challenge")
st.info(
    "You are a secret agent meeting your informant in a bar. Convince him to give you his secret! But be prepared, with every new level the informant will be more cautious.",
    # NOTE(review): this icon glyph looks mis-encoded in the reviewed copy and
    # the intended emoji cannot be recovered; st.info expects an emoji or a
    # Material icon here, so confirm the character.
    icon="π",
)

# One tab per configured level ("Level 0", "Level 1", ...); each tab is filled
# by the per-level loop further down in this script.
level_tabs = st.tabs([f"Level {i}" for i in range(len(config.LEVELS))])
|
|
|
|
|
def init_session_state(state_level: str, default_value: object) -> None:
    """Set ``st.session_state[state_level]`` to ``default_value`` if it is unset.

    Existing values are never overwritten, so calling this on every Streamlit
    rerun keeps the state that earlier interactions produced.

    Args:
        state_level: Session-state key to initialise.
        default_value: Value stored only when the key does not exist yet.
    """
    # `object` (not the builtin function `any`) is the correct "any value"
    # annotation without importing typing.
    if state_level not in st.session_state:
        st.session_state[state_level] = default_value
|
|
|
|
|
# Extracts the final answer of the chain-of-thought level. Non-greedy with
# DOTALL so multi-line answers, and answers containing ";", are captured.
_ANSWER_PATTERN = re.compile(r"<ANSWER>(.*?)</ANSWER>", re.DOTALL)


def _mark_hint_opened(level: str, hint_idx: int) -> None:
    """Flag card ``hint_idx`` of ``level`` as opened before the level was solved.

    A flag that is already set stays set; opening a card after the level is
    solved does not set it, so it does not count against the player.
    """
    key = f"opened_hint_{level}_{hint_idx}"
    st.session_state[key] = (
        st.session_state[key] or not st.session_state[f"solved_{level}"]
    )


def _show_checked_output(output: str, invalid: bool) -> None:
    """Show ``output`` unless an output guardrail flagged it as leaking the secret."""
    if invalid:
        st.error("Secret detected in the response. Try again.")
    else:
        st.write(output)


def _process_prompt(level: str, secret: str, txt: str):
    """Run the guarded LLM pipeline of ``level`` on ``txt`` and render the result.

    The intermediate result of the level's guardrail is stored under
    ``intermediate_output_holder_{level}`` so the backend hint can show it.

    Returns:
        The full LLM response for levels that buffer it before checking it,
        otherwise None (input-guard levels and the unguarded default level).
    """
    holder_key = f"intermediate_output_holder_{level}"
    output = None

    if level == "llm_judge_input":
        invalid, output_raw = llm.run_judge(level, {"user_input": txt})
        st.session_state[holder_key] = (invalid, output_raw)
        if invalid:
            st.error("Malicious user input detected. Try again.")
        else:
            st.write_stream(llm.stream_request(level, secret, txt))

    elif level == "basic_check_secret_in_output":
        output = "".join(llm.stream_request(level, secret, txt))
        invalid = secret.lower() in output.lower()
        st.session_state[holder_key] = invalid
        _show_checked_output(output, invalid)

    elif level == "advanced_check_secret_in_output":
        output = "".join(llm.stream_request(level, secret, txt))
        invalid = utils.is_subsequence(output, secret)
        st.session_state[holder_key] = invalid
        _show_checked_output(output, invalid)

    elif level == "llm_judge_output":
        output = "".join(llm.stream_request(level, secret, txt))
        invalid, output_raw = llm.run_judge(
            level, {"secret": secret, "message": output}
        )
        st.session_state[holder_key] = (invalid, output_raw)
        _show_checked_output(output, invalid)

    elif level == "chain_of_thought":
        output = "".join(llm.stream_request(level, secret, txt))
        # Store the raw output (including the reasoning) first, so the backend
        # hint can show it even when no final answer could be extracted.
        st.session_state[holder_key] = output
        answer = _ANSWER_PATTERN.search(output)
        if answer is None:
            # Never fall back to the raw output: the reasoning may contain the secret.
            st.error("The response did not contain a final answer. Try again.")
        else:
            st.write(answer.group(1))

    elif level == "guard_framework":
        is_valid, risk_score = utils.is_malicious(txt)
        st.session_state[holder_key] = (is_valid, risk_score)
        if is_valid:
            st.write_stream(llm.stream_request(level, secret, txt))
        else:
            st.error("Malicious user input detected. Try again.")

    elif level == "preflight_prompt":
        valid, output_raw = llm.run_judge(
            level, {"user_input": txt}, expected_output="dog"
        )
        st.session_state[holder_key] = (valid, output_raw)
        if valid:
            st.write_stream(llm.stream_request(level, secret, txt))
        else:
            st.error("Malicious user input detected. Try again.")

    else:
        st.write_stream(llm.stream_request(level, secret, txt))

    return output


def _show_backend_hint(level: str, hint_cont, txt: str, output) -> None:
    """Explain inside ``hint_cont`` how ``level`` processes a prompt in the backend.

    Args:
        level: Level identifier from ``config.LEVELS``.
        hint_cont: Card container the explanation is written into.
        txt: Current content of the prompt text area (may be empty).
        output: LLM response produced in this run, or None if no prompt was sent.
    """
    user_input_holder = txt or None
    prompts = llm.get_full_prompt(level, {"user_input": user_input_holder})
    intermediate_output = st.session_state[f"intermediate_output_holder_{level}"]

    def show_base_prompt():
        # Render every part of the level's main prompt (e.g. system/user prompt).
        for key, val in prompts.items():
            desc = key.replace("_", " ").capitalize()
            hint_cont.write(f"*{desc}:*")
            hint_cont.code(val, language=None)

    if level == "llm_judge_input":
        special_prompt = llm.get_full_prompt(
            llm.secondary_llm_call[level],
            {"user_input": user_input_holder},
        )
        hint_cont.write(
            "*Step 1:* A **LLM judge** reviews the user input and determines if it is malicious or not."
        )
        hint_cont.write("**LLM judge prompt:**")
        for val in special_prompt.values():
            hint_cont.code(val, language=None)
        hint_cont.write("The response of the LLM judge:")
        if intermediate_output is None:
            hint_cont.warning("Please submit a prompt first.")
        else:
            invalid, output_raw = intermediate_output
            hint_cont.code(output_raw, language=None)
            # Step 2 only runs for non-malicious input; say so accordingly.
            hint_cont.write(
                f"The prompt was determined as **{'malicious' if invalid else 'not malicious'}** "
                f"and therefore step 2 is {'not ' if invalid else ''}executed."
            )
        hint_cont.write(
            "*Step 2:* If the user input is not classified as malicious, the prompt containing the actual secret is executed and the response is shown."
        )
        hint_cont.write("**Actual prompt:**")
        show_base_prompt()

    elif level == "basic_check_secret_in_output":
        hint_cont.write("*Step 1:* The following prompt is executed:")
        show_base_prompt()
        hint_cont.write(
            "*Step 2:* In the response text of the LLM call, the secret is searched for with a simple python expression. If it contains the secret, the output will not be shown."
        )
        hint_cont.write("The code execution looks like this:")
        if intermediate_output is None:
            hint_cont.warning("Please submit a prompt first.")
        else:
            hint_cont.code(
                f"secret.lower() in output.lower() = {intermediate_output}"
            )
            hint_cont.write(
                f"And the output **{'was not' if intermediate_output else 'was'}** shown to the user."
            )

    elif level == "advanced_check_secret_in_output":
        hint_cont.write("*Step 1:* The following prompt is executed:")
        show_base_prompt()
        hint_cont.write(
            "*Step 2:* In the response text of the LLM call, the secret is searched for with a python function `is_subsequence` which looks for substrings. If it contains the secret, the output will not be shown."
        )
        with hint_cont:
            # Explicit replacement for the former bare-expression "magic" line:
            # renders the signature/docstring of the check function.
            st.help(utils.is_subsequence)
        hint_cont.write("The code execution looks like this:")
        if intermediate_output is None:
            hint_cont.warning("Please submit a prompt first.")
        else:
            hint_cont.code(
                f"utils.is_subsequence(output, secret) = {intermediate_output}"
            )
            hint_cont.write(
                f"And the output **{'was not' if intermediate_output else 'was'}** shown to the user."
            )

    elif level == "llm_judge_output":
        hint_cont.write("*Step 1:* The following prompt is executed:")
        show_base_prompt()
        hint_cont.write(
            "*Step 2:* The response of the LLM call is checked by a **LLM judge**. The judge checks if the secret is hidden in the response."
        )
        special_prompt = llm.get_full_prompt(
            llm.secondary_llm_call[level],
            {"message": output},
        )
        for val in special_prompt.values():
            hint_cont.code(val, language=None)
        hint_cont.write("The response of the LLM judge:")
        if intermediate_output is None:
            hint_cont.warning("Please submit a prompt first.")
        else:
            invalid, output_raw = intermediate_output
            hint_cont.code(output_raw, language=None)
            hint_cont.write(
                f"The LLM-judge **{'did' if invalid else 'did not'}** find the secret in the answer."
            )

    elif level == "chain_of_thought":
        hint_cont.write(
            "*Step 1:* The following prompt with Chain-of-thought reasoning is executed. But only the final answer is displayed to the user:"
        )
        show_base_prompt()
        hint_cont.write("The full model output, including the reasoning:")
        if intermediate_output is None:
            hint_cont.warning("Please submit a prompt first.")
        else:
            hint_cont.code(intermediate_output, language=None)

    elif level == "guard_framework":
        hint_cont.write(
            "*Step 1:* The user input is reviewed with the pre-built framework `LLM Guard` to check for prompt injections. It uses a [Huggingface model](https://huggingface.co/protectai/deberta-v3-base-prompt-injection-v2) specialized in detecting prompt injections."
        )
        with hint_cont:
            # Explicit replacement for the former bare-expression "magic" line:
            # renders the documentation of the LLM Guard scanner class.
            st.help(PromptInjection)
        hint_cont.write("The output of the guard looks like this:")
        if intermediate_output is None:
            hint_cont.warning("Please submit a prompt first.")
        else:
            is_valid, risk_score = intermediate_output
            # Built without a triple-quoted literal so no stray leading newline
            # or source indentation ends up in the code box.
            hint_cont.code(
                f"prompt is valid: {is_valid}\n"
                f"Prompt has a risk score of: {risk_score}",
                language=None,
            )
            hint_cont.write(
                f"The Huggingface model **{'did not' if is_valid else 'did'}** predict a prompt injection."
            )
        hint_cont.write(
            "*Step 2:* If the user input is valid, the following prompt is executed and the response is shown to the user:"
        )
        show_base_prompt()

    elif level == "preflight_prompt":
        hint_cont.write(
            "*Step 1:* The following pre-flight prompt is executed to see if the user input changes the expected output:"
        )
        special_prompt = llm.get_full_prompt(
            llm.secondary_llm_call[level],
            {"user_input": user_input_holder},
        )
        hint_cont.code(special_prompt["user_prompt"], language=None)
        hint_cont.write("The output of the pre-flight prompt is:")
        if intermediate_output is None:
            hint_cont.warning("Please submit a prompt first.")
        else:
            is_valid, output_raw = intermediate_output
            hint_cont.code(output_raw, language=None)
            hint_cont.write(
                f"The output of the pre-flight prompt **{'was' if is_valid else 'was not'}** as expected."
            )
        hint_cont.write(
            "*Step 2:* If the output of the pre-flight prompt is as expected, the following prompt is executed and the response is shown to the user:"
        )
        show_base_prompt()

    else:
        hint_cont.write(
            "*Step 1:* The following prompt is executed and the full response is shown to the user:"
        )
        show_base_prompt()


for idx, level in enumerate(config.LEVELS):
    secret = config.SECRETS[idx]

    # Per-level session state, created once and kept across reruns.
    init_session_state(f"solved_{level}", False)
    init_session_state(f"prompt_try_count_{level}", 0)
    init_session_state(f"secret_guess_count_{level}", 0)
    init_session_state(f"intermediate_output_holder_{level}", None)
    init_session_state(f"show_benefits_drawbacks_{level}", False)
    # One "opened" flag per card: hints 1-3 (indices 0-2) and the info card (3).
    for i in range(4):
        init_session_state(f"opened_hint_{level}_{i}", False)

    with level_tabs[idx]:
        header_col1, header_col2 = st.columns(2, gap="medium")
        header_col1.subheader(f"{config.LEVEL_EMOJIS[idx]} Level {idx}")
        header_col2.subheader("Need help ...")

        col1, col2 = st.columns(2, gap="medium")

        # Left column: prompt input/response and the secret-guess form.
        with col1:
            with st.container(height=600, border=False):
                with st.container(border=True):
                    txt = st.text_area(
                        "Provide your prompt here:",
                        key=f"txt_{level}",
                        label_visibility="visible",
                        height=200,
                        placeholder="Your prompt",
                        max_chars=config.MAX_INPUT_CHARS,
                    )
                    btn_submit_prompt = st.button(
                        "Send prompt", key=f"submit_prompt_{level}"
                    )
                    # LLM response of this run; read by the backend hint below.
                    output = None
                    if txt and btn_submit_prompt:
                        st.session_state[f"prompt_try_count_{level}"] += 1
                        with st.container(border=True):
                            st.write("Response:")
                            output = _process_prompt(level, secret, txt)

                with st.container(border=True):
                    secret_guess = st.text_input(
                        "What is the secret?",
                        key=f"guess_{level}",
                        placeholder="Your guess",
                    )
                    btn_submit_guess = st.button(
                        "Submit guess", key=f"submit_guess_{level}"
                    )
                    if btn_submit_guess:
                        st.session_state[f"secret_guess_count_{level}"] += 1
                        # Case-insensitive, and tolerant of stray surrounding spaces.
                        if secret_guess.strip().lower() == secret.lower():
                            st.success("You found the secret!")
                            st.session_state[f"solved_{level}"] = True
                        else:
                            st.error("Wrong guess. Try again.")

        # Right column: hint cards and the explanation card.
        with col2:
            with st.container(border=True, height=600):
                st.info(
                    "There are three levels of hints and a full explanation available to you. But be careful, if you open them before solving the secret, it will show up in your record.",
                    icon="ℹ️",
                )

                hint_1_cont = card(color=hint_color)
                if hint_1_cont.toggle(
                    "Show hint 1 - **Basic description of security strategy**",
                    key=f"hint1_checkbox_{level}",
                ):
                    _mark_hint_opened(level, 0)
                    hint_1_cont.write(config.LEVEL_DESCRIPTIONS[level]["hint1"])

                hint_2_cont = card(color=hint_color)
                if hint_2_cont.toggle(
                    "Show hint 2 - **Backend code execution**",
                    key=f"hint2_checkbox_{level}",
                ):
                    _mark_hint_opened(level, 1)
                    _show_backend_hint(level, hint_2_cont, txt, output)

                hint_3_cont = card(color=hint_color)
                if hint_3_cont.toggle(
                    "Show hint 3 - **Prompt solution example**",
                    key=f"hint3_checkbox_{level}",
                ):
                    _mark_hint_opened(level, 2)
                    hint_3_cont.code(
                        config.LEVEL_DESCRIPTIONS[level]["hint3"],
                        language=None,
                    )
                    hint_3_cont.info("*May not always work")

                info_cont = card(color=info_color)
                if info_cont.toggle(
                    "Show info - **Explanation and real-life usage**",
                    key=f"info_checkbox_{level}",
                ):
                    _mark_hint_opened(level, 3)
                    description = config.LEVEL_DESCRIPTIONS[level]
                    info_cont.write("### " + description["name"])
                    info_cont.write("##### Explanation")
                    info_cont.write(description["explanation"])
                    info_cont.write("##### Real-life usage")
                    info_cont.write(description["real_life"])
                    df = pd.DataFrame(
                        {
                            "Benefits": [description["benefits"]],
                            "Drawbacks": [description["drawbacks"]],
                        },
                    )
                    info_cont.markdown(
                        df.style.hide(axis="index").to_html(),
                        unsafe_allow_html=True,
                    )
|
|
|
|
|
def build_hint_status(level: str) -> str:
    """Return an HTML snippet listing the cards of ``level`` that were flagged as opened.

    Checks the four ``opened_hint_{level}_{i}`` session flags (i = 0..3) and
    emits one marked entry per set flag, numbered 1-4 and separated by
    ``<br>`` tags. Returns an empty string when no flag is set.
    """
    hint_status = ""
    for i in range(4):
        if st.session_state[f"opened_hint_{level}_{i}"]:
            # NOTE(review): the marker glyph in this literal looks mis-encoded in
            # the reviewed copy and cannot be recovered; confirm the intended
            # character and that the file is saved as UTF-8.
            hint_status += f"β {i+1}<br>"
    return hint_status
|
|
|
|
|
# Record: per-level summary of prompt tries, guesses, opened hints and, once
# unlocked, the mitigation technique with its benefits and drawbacks.
# NOTE(review): the glyph in this label looks mis-encoded in the reviewed copy;
# confirm the intended emoji.
with st.expander("π Record", expanded=True):
    show_mitigation_toggle = st.toggle(
        "[SPOILER] Show all mitigation techniques with their benefits and drawbacks",
        key="show_mitigation",
    )
    if show_mitigation_toggle:
        # NOTE(review): icon glyph looks mis-encoded in the reviewed copy;
        # st.warning expects an emoji or Material icon, confirm the character.
        st.warning("All mitigation techniques are shown.", icon="π")

    table_data = []
    for idx, level in enumerate(config.LEVELS):
        solved = st.session_state[f"solved_{level}"]
        info_key = f"opened_hint_{level}_3"

        if show_mitigation_toggle:
            # Revealing all mitigations counts as opening the info card, unless
            # the level was already solved (same rule as the in-tab toggles).
            st.session_state[info_key] = st.session_state[info_key] or not solved

        # Benefits/drawbacks are visible once the info card was opened or the
        # spoiler toggle is on.
        info_revealed = st.session_state[info_key] or show_mitigation_toggle
        # The mitigation name is visible once any card was opened.
        any_card_opened = show_mitigation_toggle or any(
            st.session_state[f"opened_hint_{level}_{i}"] for i in range(4)
        )

        table_data.append(
            [
                idx,
                config.LEVEL_EMOJIS[idx],
                st.session_state[f"prompt_try_count_{level}"],
                st.session_state[f"secret_guess_count_{level}"],
                build_hint_status(level),
                # Restored: this literal was split by a mis-decoded U+0085 byte.
                "✅" if solved else "❌",
                config.SECRETS[idx] if solved else "...",
                (
                    f"<b>{config.LEVEL_DESCRIPTIONS[level]['name']}</b>"
                    if any_card_opened
                    else "..."
                ),
                (
                    config.LEVEL_DESCRIPTIONS[level]["benefits"]
                    if info_revealed
                    else "..."
                ),
                (
                    config.LEVEL_DESCRIPTIONS[level]["drawbacks"]
                    if info_revealed
                    else "..."
                ),
            ]
        )

    # Rendered as raw HTML so the <b>/<br> markup in the cells takes effect.
    st.markdown(
        pd.DataFrame(
            table_data,
            columns=[
                "lvl",
                "emoji",
                "Prompt tries",
                "Secret guesses",
                "Hint used",
                "Solved",
                "Secret",
                "Mitigation",
                "Benefits",
                "Drawbacks",
            ],
        )
        .style.hide(axis="index")
        .to_html(),
        unsafe_allow_html=True,
    )
|
|
|
|
|
|
|
|
|
|