Spaces:
Running
Running
update system_prompt, add prompt guard for user prompt and safety check for LLM generated content
Browse files
README.md
CHANGED
@@ -107,7 +107,12 @@ HUGGINGFACE_API_KEY=your_api_key_here
|
|
107 |
|
108 |
## Usage
|
109 |
```bash
|
110 |
-
#
|
|
|
|
|
|
|
|
|
|
|
111 |
python main.py
|
112 |
|
113 |
# Access via web browser
|
|
|
107 |
|
108 |
## Usage
|
109 |
```bash
|
110 |
+
# Run the game locally using gpu-compute branch
|
111 |
+
git checkout gpu-compute
|
112 |
+
python main.py
|
113 |
+
|
114 |
+
# Start the game using deployed main branch
|
115 |
+
git checkout main
|
116 |
python main.py
|
117 |
|
118 |
# Access via web browser
|
helper.py
CHANGED
@@ -4,8 +4,10 @@ from dotenv import load_dotenv, find_dotenv
|
|
4 |
import json
|
5 |
import gradio as gr
|
6 |
import torch # first import torch then transformers
|
7 |
-
|
|
|
8 |
from huggingface_hub import InferenceClient
|
|
|
9 |
from transformers import pipeline
|
10 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
11 |
import logging
|
@@ -71,70 +73,186 @@ MODEL_CONFIG = {
|
|
71 |
"dtype": torch.float32, # Use float32 for CPU
|
72 |
"max_length": 256,
|
73 |
"device": "cuda" if torch.cuda.is_available() else "cpu",
|
|
|
74 |
},
|
75 |
}
|
76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
-
|
79 |
-
|
|
|
80 |
try:
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
device = MODEL_CONFIG["main_model"]["device"]
|
85 |
-
|
86 |
-
api_key = get_huggingface_api_key()
|
87 |
-
|
88 |
-
# Use 8-bit quantization for memory efficiency
|
89 |
-
model = AutoModelForCausalLM.from_pretrained(
|
90 |
-
model_name,
|
91 |
-
load_in_8bit=False,
|
92 |
-
torch_dtype=MODEL_CONFIG["main_model"]["dtype"],
|
93 |
-
use_cache=True,
|
94 |
-
device_map="auto",
|
95 |
-
low_cpu_mem_usage=True,
|
96 |
-
trust_remote_code=True,
|
97 |
-
token=api_key, # Add token here
|
98 |
)
|
|
|
|
|
|
|
|
|
99 |
|
100 |
-
model.config.use_cache = True
|
101 |
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
-
|
105 |
-
|
106 |
-
generator = pipeline(
|
107 |
-
"text-generation",
|
108 |
-
model=model,
|
109 |
-
tokenizer=tokenizer,
|
110 |
-
# device=device,
|
111 |
-
# temperature=0.7,
|
112 |
-
model_kwargs={"low_cpu_mem_usage": True},
|
113 |
-
)
|
114 |
|
115 |
-
|
116 |
-
return
|
117 |
|
118 |
-
except ImportError as e:
|
119 |
-
logger.error(f"Missing required package: {str(e)}")
|
120 |
-
raise
|
121 |
except Exception as e:
|
122 |
-
logger.error(f"
|
123 |
-
|
124 |
|
125 |
|
126 |
-
def
|
127 |
-
"""
|
128 |
try:
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
|
131 |
-
|
132 |
-
|
133 |
-
|
|
|
|
|
|
|
|
|
|
|
134 |
except Exception as e:
|
135 |
-
logger.error(f"
|
136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
|
139 |
# # Initialize model pipeline
|
140 |
# try:
|
@@ -161,17 +279,30 @@ def initialize_inference_client():
|
|
161 |
# raise
|
162 |
|
163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
def load_world(filename):
|
165 |
with open(filename, "r") as f:
|
166 |
return json.load(f)
|
167 |
|
168 |
|
169 |
# Define system_prompt and model
|
170 |
-
system_prompt = """You are an AI Game
|
171 |
CRITICAL Rules:
|
172 |
- Write EXACTLY 3 sentences maximum
|
173 |
- Use daily English language
|
174 |
-
- Start with "You
|
175 |
- Don't use 'Elara' or 'she/he', only use 'you'
|
176 |
- Use only second person ("you")
|
177 |
- Never include dialogue after the response
|
@@ -183,7 +314,7 @@ CRITICAL Rules:
|
|
183 |
- Never include 'What would you like to do?' or similar prompts
|
184 |
- Always finish with one real response
|
185 |
- Never use 'Your turn' or or anything like conversation starting prompts
|
186 |
-
- Always end the response with a period"""
|
187 |
|
188 |
|
189 |
def get_game_state(inventory: Dict = None) -> Dict[str, Any]:
|
@@ -435,39 +566,59 @@ New Quest: {next_quest['title']}
|
|
435 |
|
436 |
|
437 |
def parse_items_from_story(text: str) -> Dict[str, int]:
|
438 |
-
"""Extract item changes from story text"""
|
439 |
items = {}
|
440 |
|
441 |
-
#
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
446 |
|
447 |
-
#
|
448 |
-
|
449 |
-
|
450 |
-
items["gold"] = sum(int(x) for x in gold_matches)
|
451 |
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
461 |
|
462 |
-
|
|
|
|
|
463 |
|
464 |
|
465 |
-
def update_game_inventory(game_state: Dict, story_text: str) -> str:
|
466 |
-
"""Update inventory
|
467 |
try:
|
468 |
items = parse_items_from_story(story_text)
|
469 |
update_msg = ""
|
470 |
|
|
|
471 |
for item, count in items.items():
|
472 |
if item in game_state["inventory"]:
|
473 |
game_state["inventory"][item] += count
|
@@ -475,10 +626,15 @@ def update_game_inventory(game_state: Dict, story_text: str) -> str:
|
|
475 |
game_state["inventory"][item] = count
|
476 |
update_msg += f"\nReceived: {count} {item}"
|
477 |
|
478 |
-
|
|
|
|
|
|
|
|
|
|
|
479 |
except Exception as e:
|
480 |
logger.error(f"Error updating inventory: {e}")
|
481 |
-
return ""
|
482 |
|
483 |
|
484 |
def extract_response_after_action(full_text: str, action: str) -> str:
|
@@ -530,8 +686,10 @@ def run_action(message: str, history: list, game_state: Dict) -> str:
|
|
530 |
|
531 |
{game_state['start']}
|
532 |
|
533 |
-
|
534 |
-
|
|
|
|
|
535 |
|
536 |
|
537 |
Current Quest: {initial_quest['title']}
|
@@ -545,6 +703,11 @@ What would you like to do?"""
|
|
545 |
logger.error(f"Invalid game state type: {type(game_state)}")
|
546 |
return "Error: Invalid game state"
|
547 |
|
|
|
|
|
|
|
|
|
|
|
548 |
# logger.info(f"Processing action with game state: {game_state}")
|
549 |
logger.info(f"Processing action with game state")
|
550 |
|
@@ -589,7 +752,8 @@ Inventory: {json.dumps(game_state['inventory'])}"""
|
|
589 |
|
590 |
# Add history in correct alternating format
|
591 |
if history:
|
592 |
-
for h in history[-3:]: # Last 3 exchanges
|
|
|
593 |
if isinstance(h, tuple):
|
594 |
messages.append({"role": "user", "content": h[0]})
|
595 |
messages.append({"role": "assistant", "content": h[1]})
|
@@ -664,6 +828,11 @@ Inventory: {json.dumps(game_state['inventory'])}"""
|
|
664 |
if not response:
|
665 |
return "You look around carefully."
|
666 |
|
|
|
|
|
|
|
|
|
|
|
667 |
# # Perform safety check before returning
|
668 |
# safe = is_safe(response)
|
669 |
# print(f"\nSafety Check Result: {'SAFE' if safe else 'UNSAFE'}")
|
@@ -690,8 +859,8 @@ Inventory: {json.dumps(game_state['inventory'])}"""
|
|
690 |
if quest_completed:
|
691 |
response += quest_message
|
692 |
|
693 |
-
# Check for item updates
|
694 |
-
inventory_update = update_game_inventory(game_state, response)
|
695 |
if inventory_update:
|
696 |
response += inventory_update
|
697 |
|
@@ -733,7 +902,7 @@ def chat_response(message: str, chat_history: list, current_state: dict) -> tupl
|
|
733 |
"""Process chat input and return response with updates"""
|
734 |
try:
|
735 |
if not message.strip():
|
736 |
-
return chat_history, current_state, "", ""
|
737 |
|
738 |
# Get AI response
|
739 |
output = run_action(message, chat_history, current_state)
|
@@ -745,12 +914,17 @@ def chat_response(message: str, chat_history: list, current_state: dict) -> tupl
|
|
745 |
# Update status displays
|
746 |
status_text, quest_text = update_game_status(current_state)
|
747 |
|
|
|
|
|
|
|
|
|
|
|
748 |
# Return tuple includes empty string to clear input
|
749 |
-
return chat_history, current_state, status_text, quest_text
|
750 |
|
751 |
except Exception as e:
|
752 |
logger.error(f"Error in chat response: {e}")
|
753 |
-
return chat_history, current_state, "", ""
|
754 |
|
755 |
|
756 |
def start_game(main_loop, game_state, share=False):
|
@@ -888,22 +1062,26 @@ def start_game(main_loop, game_state, share=False):
|
|
888 |
|
889 |
def submit_action(message, history, state):
|
890 |
# Process response
|
891 |
-
new_history, new_state, status_text, quest_text =
|
892 |
-
message, history, state
|
893 |
)
|
|
|
|
|
|
|
|
|
894 |
# Clear input
|
895 |
-
return "", new_history, new_state, status_text, quest_text
|
896 |
|
897 |
submit_btn.click(
|
898 |
submit_action,
|
899 |
inputs=[txt, chatbot, state],
|
900 |
-
outputs=[txt, chatbot, state, status, quest_display],
|
901 |
)
|
902 |
|
903 |
txt.submit(
|
904 |
submit_action,
|
905 |
inputs=[txt, chatbot, state],
|
906 |
-
outputs=[txt, chatbot, state, status, quest_display],
|
907 |
)
|
908 |
|
909 |
demo.launch(share=share)
|
@@ -939,159 +1117,212 @@ Should not
|
|
939 |
}
|
940 |
|
941 |
|
942 |
-
def
|
943 |
-
"""Initialize
|
944 |
try:
|
945 |
-
|
946 |
-
|
947 |
-
|
948 |
-
|
949 |
-
|
950 |
-
|
951 |
-
# model_id = "meta-llama/Llama-Guard-3-1B"
|
952 |
-
|
953 |
-
api_key = get_huggingface_api_key()
|
954 |
-
|
955 |
-
safety_model = AutoModelForCausalLM.from_pretrained(
|
956 |
-
model_name,
|
957 |
-
token=api_key,
|
958 |
-
torch_dtype=MODEL_CONFIG["safety_model"]["dtype"],
|
959 |
-
use_cache=True,
|
960 |
-
device_map="auto",
|
961 |
-
)
|
962 |
-
safety_model.config.use_cache = True
|
963 |
|
964 |
-
safety_tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_key)
|
965 |
-
# Set pad token explicitly
|
966 |
-
safety_tokenizer.pad_token = safety_tokenizer.eos_token
|
967 |
|
968 |
-
|
969 |
-
|
|
|
|
|
970 |
|
971 |
-
|
972 |
-
|
973 |
-
|
|
|
|
|
|
|
|
|
|
|
974 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
975 |
|
976 |
-
|
977 |
-
|
978 |
-
safety_model_name = MODEL_CONFIG["safety_model"]["name"]
|
979 |
|
980 |
-
|
981 |
|
982 |
-
|
983 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
984 |
|
985 |
-
except Exception as e:
|
986 |
-
logger.error(f"Failed to initialize model: {str(e)}")
|
987 |
-
# Fallback to CPU if GPU initialization fails
|
988 |
-
try:
|
989 |
-
logger.info("Attempting CPU fallback...")
|
990 |
-
safety_model, safety_tokenizer = init_safety_model(
|
991 |
-
safety_model_name, force_cpu=True
|
992 |
-
)
|
993 |
except Exception as e:
|
994 |
-
logger.error(f"
|
995 |
-
|
996 |
|
997 |
|
998 |
-
|
999 |
-
|
1000 |
-
|
1001 |
-
|
1002 |
-
|
1003 |
-
|
1004 |
-
|
1005 |
-
padding=True,
|
1006 |
-
truncation=True,
|
1007 |
-
max_length=MODEL_CONFIG["safety_model"]["max_length"],
|
1008 |
-
).to(safety_model.device)
|
1009 |
-
|
1010 |
-
output = safety_model.generate(
|
1011 |
-
**inputs,
|
1012 |
-
max_new_tokens=5,
|
1013 |
-
temperature=0.1, # Lower temperature for more consistent output
|
1014 |
-
pad_token_id=safety_tokenizer.pad_token_id,
|
1015 |
-
eos_token_id=safety_tokenizer.eos_token_id,
|
1016 |
-
)
|
1017 |
-
|
1018 |
-
return safety_tokenizer.decode(output[0], skip_special_tokens=True)
|
1019 |
|
|
|
|
|
1020 |
|
1021 |
-
|
1022 |
-
"""Check if message content is safe"""
|
1023 |
-
try:
|
1024 |
-
logger.info("Performing safety check...")
|
1025 |
-
# Build the prompt with embedded values
|
1026 |
-
# prompt = f"""[INST] Task: Check if this content is safe according to our safety policy with the below categories.
|
1027 |
|
1028 |
-
|
1029 |
-
|
1030 |
-
|
|
|
|
|
|
|
|
|
|
|
1031 |
|
1032 |
-
|
1033 |
-
|
1034 |
-
|
1035 |
|
1036 |
-
|
1037 |
-
|
1038 |
-
# """
|
1039 |
|
1040 |
-
|
1041 |
-
{
|
1042 |
-
|
1043 |
-
"""
|
1044 |
|
1045 |
-
# Generate safety check on cuda
|
1046 |
|
1047 |
-
|
1048 |
-
|
1049 |
-
|
1050 |
-
# return_tensors="pt",
|
1051 |
-
# padding=True,
|
1052 |
-
# truncation=True,
|
1053 |
-
# )
|
1054 |
|
1055 |
-
|
1056 |
-
# inputs = {k: v.to(device) for k, v in inputs.items()}
|
1057 |
|
1058 |
-
|
1059 |
-
|
1060 |
-
# max_new_tokens=10,
|
1061 |
-
# temperature=0.1, # Lower temperature for more consistent output
|
1062 |
-
# pad_token_id=safety_tokenizer.pad_token_id, # Use configured pad token
|
1063 |
-
# eos_token_id=safety_tokenizer.eos_token_id,
|
1064 |
-
# do_sample=False,
|
1065 |
-
# )
|
1066 |
|
1067 |
-
|
1068 |
-
|
1069 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1070 |
|
1071 |
-
# # Extract response after prompt
|
1072 |
-
# if "[/INST]" in result:
|
1073 |
-
# result = result.split("[/INST]")[-1]
|
1074 |
|
1075 |
-
|
1076 |
-
|
1077 |
-
|
1078 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1079 |
|
1080 |
-
|
1081 |
-
|
|
|
|
|
|
|
|
|
|
|
1082 |
|
1083 |
-
|
1084 |
|
1085 |
-
is_safe = "safe" in result.lower().split()
|
1086 |
|
1087 |
-
|
1088 |
-
|
1089 |
-
|
1090 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1091 |
|
1092 |
-
|
1093 |
-
|
1094 |
-
|
1095 |
|
1096 |
|
1097 |
def detect_inventory_changes(game_state, output):
|
|
|
4 |
import json
|
5 |
import gradio as gr
|
6 |
import torch # first import torch then transformers
|
7 |
+
from torch.nn.functional import softmax
|
8 |
+
from transformers import AutoModelForSequenceClassification
|
9 |
from huggingface_hub import InferenceClient
|
10 |
+
|
11 |
from transformers import pipeline
|
12 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
13 |
import logging
|
|
|
73 |
"dtype": torch.float32, # Use float32 for CPU
|
74 |
"max_length": 256,
|
75 |
"device": "cuda" if torch.cuda.is_available() else "cpu",
|
76 |
+
"max_tokens": 500,
|
77 |
},
|
78 |
}
|
79 |
|
80 |
+
PROMPT_GUARD_CONFIG = {
|
81 |
+
"model_id": "meta-llama/Prompt-Guard-86M",
|
82 |
+
"temperature": 1.0,
|
83 |
+
"jailbreak_threshold": 0.5,
|
84 |
+
"injection_threshold": 0.9,
|
85 |
+
"device": "cpu",
|
86 |
+
"safe_commands": [
|
87 |
+
"look around",
|
88 |
+
"investigate",
|
89 |
+
"explore",
|
90 |
+
"search",
|
91 |
+
"examine",
|
92 |
+
"take",
|
93 |
+
"use",
|
94 |
+
"go",
|
95 |
+
"walk",
|
96 |
+
"continue",
|
97 |
+
"help",
|
98 |
+
"inventory",
|
99 |
+
"quest",
|
100 |
+
"status",
|
101 |
+
"map",
|
102 |
+
"talk",
|
103 |
+
"fight",
|
104 |
+
"run",
|
105 |
+
"hide",
|
106 |
+
],
|
107 |
+
"max_length": 512,
|
108 |
+
}
|
109 |
|
110 |
+
|
111 |
+
def initialize_prompt_guard():
|
112 |
+
"""Initialize Prompt Guard model"""
|
113 |
try:
|
114 |
+
tokenizer = AutoTokenizer.from_pretrained(PROMPT_GUARD_CONFIG["model_id"])
|
115 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
116 |
+
PROMPT_GUARD_CONFIG["model_id"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
)
|
118 |
+
return model, tokenizer
|
119 |
+
except Exception as e:
|
120 |
+
logger.error(f"Failed to initialize Prompt Guard: {e}")
|
121 |
+
raise
|
122 |
|
|
|
123 |
|
124 |
+
def get_class_probabilities(text: str, guard_model, guard_tokenizer) -> torch.Tensor:
|
125 |
+
"""Evaluate model probabilities with temperature scaling"""
|
126 |
+
try:
|
127 |
+
inputs = guard_tokenizer(
|
128 |
+
text,
|
129 |
+
return_tensors="pt",
|
130 |
+
padding=True,
|
131 |
+
truncation=True,
|
132 |
+
max_length=PROMPT_GUARD_CONFIG["max_length"],
|
133 |
+
).to(PROMPT_GUARD_CONFIG["device"])
|
134 |
|
135 |
+
with torch.no_grad():
|
136 |
+
logits = guard_model(**inputs).logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
|
138 |
+
scaled_logits = logits / PROMPT_GUARD_CONFIG["temperature"]
|
139 |
+
return softmax(scaled_logits, dim=-1)
|
140 |
|
|
|
|
|
|
|
141 |
except Exception as e:
|
142 |
+
logger.error(f"Error getting class probabilities: {e}")
|
143 |
+
return None
|
144 |
|
145 |
|
146 |
+
def get_jailbreak_score(text: str, guard_model, guard_tokenizer) -> float:
|
147 |
+
"""Get jailbreak probability score"""
|
148 |
try:
|
149 |
+
probabilities = get_class_probabilities(text, guard_model, guard_tokenizer)
|
150 |
+
if probabilities is None:
|
151 |
+
return 1.0 # Fail safe
|
152 |
+
return probabilities[0, 2].item()
|
153 |
+
except Exception as e:
|
154 |
+
logger.error(f"Error getting jailbreak score: {e}")
|
155 |
+
return 1.0
|
156 |
|
157 |
+
|
158 |
+
def get_injection_score(text: str, guard_model, guard_tokenizer) -> float:
|
159 |
+
"""Get injection probability score"""
|
160 |
+
try:
|
161 |
+
probabilities = get_class_probabilities(text, guard_model, guard_tokenizer)
|
162 |
+
if probabilities is None:
|
163 |
+
return 1.0 # Fail safe
|
164 |
+
return (probabilities[0, 1] + probabilities[0, 2]).item()
|
165 |
except Exception as e:
|
166 |
+
logger.error(f"Error getting injection score: {e}")
|
167 |
+
return 1.0
|
168 |
+
|
169 |
+
|
170 |
+
# Initialize safety model pipeline
|
171 |
+
try:
|
172 |
+
# Initialize Prompt Guard
|
173 |
+
guard_model, guard_tokenizer = initialize_prompt_guard()
|
174 |
+
|
175 |
+
except Exception as e:
|
176 |
+
logger.error(f"Failed to initialize model: {str(e)}")
|
177 |
+
|
178 |
+
|
179 |
+
def is_prompt_safe(message: str) -> bool:
|
180 |
+
"""Enhanced safety check with Prompt Guard"""
|
181 |
+
try:
|
182 |
+
# Allow safe game commands
|
183 |
+
if any(cmd in message.lower() for cmd in PROMPT_GUARD_CONFIG["safe_commands"]):
|
184 |
+
logger.info("Message matched safe command pattern")
|
185 |
+
return True
|
186 |
|
187 |
+
# Get safety scores
|
188 |
+
jailbreak_score = get_jailbreak_score(message, guard_model, guard_tokenizer)
|
189 |
+
injection_score = get_injection_score(message, guard_model, guard_tokenizer)
|
190 |
+
|
191 |
+
logger.info(
|
192 |
+
f"Safety scores - Jailbreak: {jailbreak_score}, Injection: {injection_score}"
|
193 |
+
)
|
194 |
+
|
195 |
+
# Check against thresholds
|
196 |
+
is_safe = (
|
197 |
+
jailbreak_score
|
198 |
+
< PROMPT_GUARD_CONFIG["jailbreak_threshold"]
|
199 |
+
# and injection_score < PROMPT_GUARD_CONFIG["injection_threshold"] # Disable for now because injection is too strict and current prompt guard model seems malfunctioning for now.
|
200 |
+
)
|
201 |
+
|
202 |
+
logger.info(f"Final safety result: {is_safe}")
|
203 |
+
return is_safe
|
204 |
+
|
205 |
+
except Exception as e:
|
206 |
+
logger.error(f"Safety check failed: {e}")
|
207 |
+
return False
|
208 |
+
|
209 |
+
|
210 |
+
# def initialize_model_pipeline(model_name, force_cpu=False):
|
211 |
+
# """Initialize pipeline with memory management"""
|
212 |
+
# try:
|
213 |
+
# if force_cpu:
|
214 |
+
# device = -1
|
215 |
+
# else:
|
216 |
+
# device = MODEL_CONFIG["main_model"]["device"]
|
217 |
+
|
218 |
+
# api_key = get_huggingface_api_key()
|
219 |
+
|
220 |
+
# # Use 8-bit quantization for memory efficiency
|
221 |
+
# model = AutoModelForCausalLM.from_pretrained(
|
222 |
+
# model_name,
|
223 |
+
# load_in_8bit=False,
|
224 |
+
# torch_dtype=MODEL_CONFIG["main_model"]["dtype"],
|
225 |
+
# use_cache=True,
|
226 |
+
# device_map="auto",
|
227 |
+
# low_cpu_mem_usage=True,
|
228 |
+
# trust_remote_code=True,
|
229 |
+
# token=api_key, # Add token here
|
230 |
+
# )
|
231 |
+
|
232 |
+
# model.config.use_cache = True
|
233 |
+
|
234 |
+
# tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_key)
|
235 |
+
|
236 |
+
# # Initialize pipeline
|
237 |
+
# logger.info(f"Initializing pipeline with device: {device}")
|
238 |
+
# generator = pipeline(
|
239 |
+
# "text-generation",
|
240 |
+
# model=model,
|
241 |
+
# tokenizer=tokenizer,
|
242 |
+
# # device=device,
|
243 |
+
# # temperature=0.7,
|
244 |
+
# model_kwargs={"low_cpu_mem_usage": True},
|
245 |
+
# )
|
246 |
+
|
247 |
+
# logger.info("Model Pipeline initialized successfully")
|
248 |
+
# return generator, tokenizer
|
249 |
+
|
250 |
+
# except ImportError as e:
|
251 |
+
# logger.error(f"Missing required package: {str(e)}")
|
252 |
+
# raise
|
253 |
+
# except Exception as e:
|
254 |
+
# logger.error(f"Failed to initialize pipeline: {str(e)}")
|
255 |
+
# raise
|
256 |
|
257 |
# # Initialize model pipeline
|
258 |
# try:
|
|
|
279 |
# raise
|
280 |
|
281 |
|
282 |
+
def initialize_inference_client():
|
283 |
+
"""Initialize HuggingFace Inference Client"""
|
284 |
+
try:
|
285 |
+
inference_key = get_huggingface_inference_key()
|
286 |
+
|
287 |
+
client = InferenceClient(api_key=inference_key)
|
288 |
+
logger.info("Inference Client initialized successfully")
|
289 |
+
return client
|
290 |
+
except Exception as e:
|
291 |
+
logger.error(f"Failed to initialize Inference Client: {e}")
|
292 |
+
raise
|
293 |
+
|
294 |
+
|
295 |
def load_world(filename):
|
296 |
with open(filename, "r") as f:
|
297 |
return json.load(f)
|
298 |
|
299 |
|
300 |
# Define system_prompt and model
|
301 |
+
system_prompt = """You are an AI Game master. Your job is to write what happens next in a player's adventure game.
|
302 |
CRITICAL Rules:
|
303 |
- Write EXACTLY 3 sentences maximum
|
304 |
- Use daily English language
|
305 |
+
- Start with "You "
|
306 |
- Don't use 'Elara' or 'she/he', only use 'you'
|
307 |
- Use only second person ("you")
|
308 |
- Never include dialogue after the response
|
|
|
314 |
- Never include 'What would you like to do?' or similar prompts
|
315 |
- Always finish with one real response
|
316 |
- Never use 'Your turn' or or anything like conversation starting prompts
|
317 |
+
- Always end the response with a period(.)"""
|
318 |
|
319 |
|
320 |
def get_game_state(inventory: Dict = None) -> Dict[str, Any]:
|
|
|
566 |
|
567 |
|
568 |
def parse_items_from_story(text: str) -> Dict[str, int]:
|
569 |
+
"""Extract item changes from story text with improved pattern matching"""
|
570 |
items = {}
|
571 |
|
572 |
+
# Skip parsing if text starts with common narrative phrases
|
573 |
+
skip_patterns = [
|
574 |
+
"you see",
|
575 |
+
"you find yourself",
|
576 |
+
"you are",
|
577 |
+
"you stand",
|
578 |
+
"you hear",
|
579 |
+
"you feel",
|
580 |
+
]
|
581 |
+
if any(text.lower().startswith(pattern) for pattern in skip_patterns):
|
582 |
+
return items
|
583 |
|
584 |
+
# Common item keywords and patterns
|
585 |
+
gold_pattern = r"(\d+)\s*gold(?:\s+coins?)?"
|
586 |
+
items_pattern = r"(?:receive|find|given|obtain|pick up|grab)\s+(?:a|an|the)?\s*(\d+)?\s*([\w\s]+?)"
|
|
|
587 |
|
588 |
+
try:
|
589 |
+
# Find gold amounts
|
590 |
+
gold_matches = re.findall(gold_pattern, text.lower())
|
591 |
+
if gold_matches:
|
592 |
+
items["gold"] = sum(int(x) for x in gold_matches)
|
593 |
+
|
594 |
+
# Find other items
|
595 |
+
item_matches = re.findall(items_pattern, text.lower())
|
596 |
+
for count, item in item_matches:
|
597 |
+
# Validate item name
|
598 |
+
item = item.strip()
|
599 |
+
if len(item) > 2 and not any( # Minimum length check
|
600 |
+
skip in item for skip in ["yourself", "you", "door", "wall", "floor"]
|
601 |
+
): # Skip common words
|
602 |
+
count = int(count) if count else 1
|
603 |
+
if item in items:
|
604 |
+
items[item] += count
|
605 |
+
else:
|
606 |
+
items[item] = count
|
607 |
+
|
608 |
+
return items
|
609 |
|
610 |
+
except Exception as e:
|
611 |
+
logger.error(f"Error parsing items from story: {e}")
|
612 |
+
return {}
|
613 |
|
614 |
|
615 |
+
def update_game_inventory(game_state: Dict, story_text: str) -> Tuple[str, list]:
|
616 |
+
"""Update inventory and return message and updated inventory data"""
|
617 |
try:
|
618 |
items = parse_items_from_story(story_text)
|
619 |
update_msg = ""
|
620 |
|
621 |
+
# Update inventory
|
622 |
for item, count in items.items():
|
623 |
if item in game_state["inventory"]:
|
624 |
game_state["inventory"][item] += count
|
|
|
626 |
game_state["inventory"][item] = count
|
627 |
update_msg += f"\nReceived: {count} {item}"
|
628 |
|
629 |
+
# Create updated inventory data for display
|
630 |
+
inventory_data = [
|
631 |
+
[item, count] for item, count in game_state["inventory"].items()
|
632 |
+
]
|
633 |
+
|
634 |
+
return update_msg, inventory_data
|
635 |
except Exception as e:
|
636 |
logger.error(f"Error updating inventory: {e}")
|
637 |
+
return "", []
|
638 |
|
639 |
|
640 |
def extract_response_after_action(full_text: str, action: str) -> str:
|
|
|
686 |
|
687 |
{game_state['start']}
|
688 |
|
689 |
+
You are currently in {game_state['town_name']}, {game_state['town']}.
|
690 |
+
|
691 |
+
{game_state['town_name']} is a city in {game_state['kingdom']}.
|
692 |
+
|
693 |
|
694 |
|
695 |
Current Quest: {initial_quest['title']}
|
|
|
703 |
logger.error(f"Invalid game state type: {type(game_state)}")
|
704 |
return "Error: Invalid game state"
|
705 |
|
706 |
+
# Safety check with Prompt Guard
|
707 |
+
if not is_prompt_safe(message):
|
708 |
+
logger.warning("Unsafe content detected in user prompt")
|
709 |
+
return "I cannot process that request for safety reasons."
|
710 |
+
|
711 |
# logger.info(f"Processing action with game state: {game_state}")
|
712 |
logger.info(f"Processing action with game state")
|
713 |
|
|
|
752 |
|
753 |
# Add history in correct alternating format
|
754 |
if history:
|
755 |
+
# for h in history[-3:]: # Last 3 exchanges
|
756 |
+
for h in history:
|
757 |
if isinstance(h, tuple):
|
758 |
messages.append({"role": "user", "content": h[0]})
|
759 |
messages.append({"role": "assistant", "content": h[1]})
|
|
|
828 |
if not response:
|
829 |
return "You look around carefully."
|
830 |
|
831 |
+
# Safety check the responce using inference API
|
832 |
+
if not is_safe(response):
|
833 |
+
logger.warning("Unsafe content detected - blocking response")
|
834 |
+
return "This response was blocked for safety reasons."
|
835 |
+
|
836 |
# # Perform safety check before returning
|
837 |
# safe = is_safe(response)
|
838 |
# print(f"\nSafety Check Result: {'SAFE' if safe else 'UNSAFE'}")
|
|
|
859 |
if quest_completed:
|
860 |
response += quest_message
|
861 |
|
862 |
+
# Check for item-inventory updates
|
863 |
+
inventory_update, inventory_data = update_game_inventory(game_state, response)
|
864 |
if inventory_update:
|
865 |
response += inventory_update
|
866 |
|
|
|
902 |
"""Process chat input and return response with updates"""
|
903 |
try:
|
904 |
if not message.strip():
|
905 |
+
return chat_history, current_state, "", "", [] # Add empty inventory data
|
906 |
|
907 |
# Get AI response
|
908 |
output = run_action(message, chat_history, current_state)
|
|
|
914 |
# Update status displays
|
915 |
status_text, quest_text = update_game_status(current_state)
|
916 |
|
917 |
+
# Get inventory updates
|
918 |
+
update_msg, inventory_data = update_game_inventory(current_state, output)
|
919 |
+
if update_msg:
|
920 |
+
output += update_msg
|
921 |
+
|
922 |
# Return tuple includes empty string to clear input
|
923 |
+
return chat_history, current_state, status_text, quest_text, inventory_data
|
924 |
|
925 |
except Exception as e:
|
926 |
logger.error(f"Error in chat response: {e}")
|
927 |
+
return chat_history, current_state, "", "", []
|
928 |
|
929 |
|
930 |
def start_game(main_loop, game_state, share=False):
|
|
|
1062 |
|
1063 |
def submit_action(message, history, state):
|
1064 |
# Process response
|
1065 |
+
new_history, new_state, status_text, quest_text, inventory_data = (
|
1066 |
+
chat_response(message, history, state)
|
1067 |
)
|
1068 |
+
|
1069 |
+
# Update inventory display
|
1070 |
+
inventory.value = inventory_data
|
1071 |
+
|
1072 |
# Clear input
|
1073 |
+
return "", new_history, new_state, status_text, quest_text, inventory
|
1074 |
|
1075 |
submit_btn.click(
|
1076 |
submit_action,
|
1077 |
inputs=[txt, chatbot, state],
|
1078 |
+
outputs=[txt, chatbot, state, status, quest_display, inventory],
|
1079 |
)
|
1080 |
|
1081 |
txt.submit(
|
1082 |
submit_action,
|
1083 |
inputs=[txt, chatbot, state],
|
1084 |
+
outputs=[txt, chatbot, state, status, quest_display, inventory],
|
1085 |
)
|
1086 |
|
1087 |
demo.launch(share=share)
|
|
|
1117 |
}
|
1118 |
|
1119 |
|
1120 |
+
def initialize_safety_client():
|
1121 |
+
"""Initialize HuggingFace Inference Client"""
|
1122 |
try:
|
1123 |
+
inference_key = get_huggingface_inference_key()
|
1124 |
+
# api_key = get_huggingface_api_key()
|
1125 |
+
return InferenceClient(api_key=inference_key)
|
1126 |
+
except Exception as e:
|
1127 |
+
logger.error(f"Failed to initialize safety client: {e}")
|
1128 |
+
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1129 |
|
|
|
|
|
|
|
1130 |
|
1131 |
+
def is_safe(message: str) -> bool:
|
1132 |
+
"""Check content safety using Inference API"""
|
1133 |
+
try:
|
1134 |
+
client = initialize_safety_client()
|
1135 |
|
1136 |
+
messages = [
|
1137 |
+
{"role": "user", "content": f"Check if this content is safe:\n{message}"},
|
1138 |
+
{
|
1139 |
+
"role": "assistant",
|
1140 |
+
"content": f"I will check if the content is safe based on this content policy:\n{everyone_content_policy['policy']}",
|
1141 |
+
},
|
1142 |
+
{"role": "user", "content": "Is it safe or unsafe?"},
|
1143 |
+
]
|
1144 |
|
1145 |
+
try:
|
1146 |
+
completion = client.chat.completions.create(
|
1147 |
+
model=MODEL_CONFIG["safety_model"]["name"],
|
1148 |
+
messages=messages,
|
1149 |
+
max_tokens=MODEL_CONFIG["safety_model"]["max_tokens"],
|
1150 |
+
temperature=0.1,
|
1151 |
+
)
|
1152 |
|
1153 |
+
response = completion.choices[0].message.content.lower()
|
1154 |
+
logger.info(f"Safety check response: {response}")
|
|
|
1155 |
|
1156 |
+
is_safe = "safe" in response and "unsafe" not in response
|
1157 |
|
1158 |
+
logger.info(f"Safety check result: {'SAFE' if is_safe else 'UNSAFE'}")
|
1159 |
+
return is_safe
|
1160 |
+
|
1161 |
+
except Exception as api_error:
|
1162 |
+
logger.error(f"API error: {api_error}")
|
1163 |
+
# Fallback to allow common game commands
|
1164 |
+
return any(
|
1165 |
+
cmd in message.lower() for cmd in PROMPT_GUARD_CONFIG["safe_commands"]
|
1166 |
+
)
|
1167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1168 |
except Exception as e:
|
1169 |
+
logger.error(f"Safety check failed: {e}")
|
1170 |
+
return False
|
1171 |
|
1172 |
|
1173 |
+
# def init_safety_model(model_name, force_cpu=False):
|
1174 |
+
# """Initialize safety checking model with optimized memory usage"""
|
1175 |
+
# try:
|
1176 |
+
# if force_cpu:
|
1177 |
+
# device = -1
|
1178 |
+
# else:
|
1179 |
+
# device = MODEL_CONFIG["safety_model"]["device"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1180 |
|
1181 |
+
# # model_id = "meta-llama/Llama-Guard-3-8B"
|
1182 |
+
# # model_id = "meta-llama/Llama-Guard-3-1B"
|
1183 |
|
1184 |
+
# api_key = get_huggingface_api_key()
|
|
|
|
|
|
|
|
|
|
|
1185 |
|
1186 |
+
# safety_model = AutoModelForCausalLM.from_pretrained(
|
1187 |
+
# model_name,
|
1188 |
+
# token=api_key,
|
1189 |
+
# torch_dtype=MODEL_CONFIG["safety_model"]["dtype"],
|
1190 |
+
# use_cache=True,
|
1191 |
+
# device_map="auto",
|
1192 |
+
# )
|
1193 |
+
# safety_model.config.use_cache = True
|
1194 |
|
1195 |
+
# safety_tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_key)
|
1196 |
+
# # Set pad token explicitly
|
1197 |
+
# safety_tokenizer.pad_token = safety_tokenizer.eos_token
|
1198 |
|
1199 |
+
# logger.info(f"Safety model initialized successfully on {device}")
|
1200 |
+
# return safety_model, safety_tokenizer
|
|
|
1201 |
|
1202 |
+
# except Exception as e:
|
1203 |
+
# logger.error(f"Failed to initialize safety model: {e}")
|
1204 |
+
# raise
|
|
|
1205 |
|
|
|
1206 |
|
1207 |
+
# # Initialize safety model pipeline
|
1208 |
+
# try:
|
1209 |
+
# safety_model_name = MODEL_CONFIG["safety_model"]["name"]
|
|
|
|
|
|
|
|
|
1210 |
|
1211 |
+
# api_key = get_huggingface_api_key()
|
|
|
1212 |
|
1213 |
+
# # Initialize the pipeline with memory management
|
1214 |
+
# safety_model, safety_tokenizer = init_safety_model(safety_model_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
1215 |
|
1216 |
+
# except Exception as e:
|
1217 |
+
# logger.error(f"Failed to initialize model: {str(e)}")
|
1218 |
+
# # Fallback to CPU if GPU initialization fails
|
1219 |
+
# try:
|
1220 |
+
# logger.info("Attempting CPU fallback...")
|
1221 |
+
# safety_model, safety_tokenizer = init_safety_model(
|
1222 |
+
# safety_model_name, force_cpu=True
|
1223 |
+
# )
|
1224 |
+
# except Exception as e:
|
1225 |
+
# logger.error(f"CPU fallback failed: {str(e)}")
|
1226 |
+
# raise
|
1227 |
|
|
|
|
|
|
|
1228 |
|
1229 |
+
# @lru_cache(maxsize=1000)
|
1230 |
+
# def get_safety_response(prompt: str) -> str:
|
1231 |
+
# """Cached safety check responses"""
|
1232 |
+
# with torch.no_grad():
|
1233 |
+
# inputs = safety_tokenizer(
|
1234 |
+
# prompt,
|
1235 |
+
# return_tensors="pt",
|
1236 |
+
# padding=True,
|
1237 |
+
# truncation=True,
|
1238 |
+
# max_length=MODEL_CONFIG["safety_model"]["max_length"],
|
1239 |
+
# ).to(safety_model.device)
|
1240 |
|
1241 |
+
# output = safety_model.generate(
|
1242 |
+
# **inputs,
|
1243 |
+
# max_new_tokens=5,
|
1244 |
+
# temperature=0.1, # Lower temperature for more consistent output
|
1245 |
+
# pad_token_id=safety_tokenizer.pad_token_id,
|
1246 |
+
# eos_token_id=safety_tokenizer.eos_token_id,
|
1247 |
+
# )
|
1248 |
|
1249 |
+
# return safety_tokenizer.decode(output[0], skip_special_tokens=True)
|
1250 |
|
|
|
1251 |
|
1252 |
+
# def is_safe(message: str) -> bool:
|
1253 |
+
# """Check if message content is safe"""
|
1254 |
+
# try:
|
1255 |
+
# logger.info("Performing safety check...")
|
1256 |
+
# # Build the prompt with embedded values
|
1257 |
+
# # prompt = f"""[INST] Task: Check if this content is safe according to our safety policy with the below categories.
|
1258 |
+
|
1259 |
+
# # <BEGIN UNSAFE CONTENT CATEGORIES>
|
1260 |
+
# # {everyone_content_policy['policy']}
|
1261 |
+
# # <END UNSAFE CONTENT CATEGORIES>
|
1262 |
+
|
1263 |
+
# # <CONTENT TO CHECK>
|
1264 |
+
# # {message}
|
1265 |
+
# # <END CONTENT TO CHECK>
|
1266 |
+
|
1267 |
+
# # Provide your safety assessment for user in the above CONTENT TO CHECK section:
|
1268 |
+
# # - Reply ONLY with single word 'safe' or 'unsafe' [/INST]
|
1269 |
+
# # """
|
1270 |
+
|
1271 |
+
# prompt = f"""[INST]Check if safe:
|
1272 |
+
# {message}
|
1273 |
+
# Reply 'safe' or 'unsafe'[/INST]
|
1274 |
+
# """
|
1275 |
+
|
1276 |
+
# # Generate safety check on cuda
|
1277 |
+
|
1278 |
+
# # with torch.no_grad():
|
1279 |
+
# # inputs = safety_tokenizer(
|
1280 |
+
# # prompt,
|
1281 |
+
# # return_tensors="pt",
|
1282 |
+
# # padding=True,
|
1283 |
+
# # truncation=True,
|
1284 |
+
# # )
|
1285 |
+
|
1286 |
+
# # # Move inputs to correct device
|
1287 |
+
# # inputs = {k: v.to(device) for k, v in inputs.items()}
|
1288 |
+
|
1289 |
+
# # output = safety_model.generate(
|
1290 |
+
# # **inputs,
|
1291 |
+
# # max_new_tokens=10,
|
1292 |
+
# # temperature=0.1, # Lower temperature for more consistent output
|
1293 |
+
# # pad_token_id=safety_tokenizer.pad_token_id, # Use configured pad token
|
1294 |
+
# # eos_token_id=safety_tokenizer.eos_token_id,
|
1295 |
+
# # do_sample=False,
|
1296 |
+
# # )
|
1297 |
+
|
1298 |
+
# # result = safety_tokenizer.decode(output[0], skip_special_tokens=True)
|
1299 |
+
# result = get_safety_response(prompt)
|
1300 |
+
# print(f"Raw safety check result: {result}")
|
1301 |
+
|
1302 |
+
# # # Extract response after prompt
|
1303 |
+
# # if "[/INST]" in result:
|
1304 |
+
# # result = result.split("[/INST]")[-1]
|
1305 |
+
|
1306 |
+
# # # Clean response
|
1307 |
+
# # result = result.lower().strip()
|
1308 |
+
# # print(f"Cleaned safety check result: {result}")
|
1309 |
+
# # words = [word for word in result.split() if word in ["safe", "unsafe"]]
|
1310 |
+
|
1311 |
+
# # # Take first valid response word
|
1312 |
+
# # is_safe = words[0] == "safe" if words else False
|
1313 |
+
|
1314 |
+
# # print("Final Safety check result:", is_safe)
|
1315 |
+
|
1316 |
+
# is_safe = "safe" in result.lower().split()
|
1317 |
+
|
1318 |
+
# logger.info(
|
1319 |
+
# f"Safety check completed - Result: {'SAFE' if is_safe else 'UNSAFE'}"
|
1320 |
+
# )
|
1321 |
+
# return is_safe
|
1322 |
|
1323 |
+
# except Exception as e:
|
1324 |
+
# logger.error(f"Safety check failed: {e}")
|
1325 |
+
# return False
|
1326 |
|
1327 |
|
1328 |
def detect_inventory_changes(game_state, output):
|