prgrmc committed on
Commit f13e44b · 1 Parent(s): e7bea79

Update system_prompt, add a prompt guard for user prompts, and add a safety check for LLM-generated content

Files changed (2)
  1. README.md +6 -1
  2. helper.py +441 -210
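
The sketch below shows, at a glance, where the two new checks sit in the request flow. In the commit itself both calls live inside `run_action()` in helper.py; the wrapper function and its name below are illustrative only, not part of this change.

```python
# Illustrative sketch only (not part of this commit): how the new guards in
# helper.py are intended to gate a single game turn.
#   - is_prompt_safe() scores the player's input with Prompt-Guard-86M
#   - is_safe() checks the generated story text via the HF Inference API
from helper import is_prompt_safe, is_safe, run_action

def guarded_turn(message: str, history: list, game_state: dict) -> str:
    # 1. Screen the user prompt before it ever reaches the game-master LLM
    if not is_prompt_safe(message):
        return "I cannot process that request for safety reasons."

    # 2. Generate the next scene (run_action already applies both checks internally)
    response = run_action(message, history, game_state)

    # 3. Screen the LLM output before showing it to the player
    if not is_safe(response):
        return "This response was blocked for safety reasons."

    return response
```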
README.md CHANGED
@@ -107,7 +107,12 @@ HUGGINGFACE_API_KEY=your_api_key_here
 
 ## Usage
 ```bash
-# Start the game
+# Run the game locally using gpu-compute branch
+git checkout gpu-compute
+python main.py
+
+# Start the game using deployed main branch
+git checkout main
 python main.py
 
 # Access via web browser
helper.py CHANGED
@@ -4,8 +4,10 @@ from dotenv import load_dotenv, find_dotenv
4
  import json
5
  import gradio as gr
6
  import torch # first import torch then transformers
7
-
 
8
  from huggingface_hub import InferenceClient
 
9
  from transformers import pipeline
10
  from transformers import AutoTokenizer, AutoModelForCausalLM
11
  import logging
@@ -71,70 +73,186 @@ MODEL_CONFIG = {
71
  "dtype": torch.float32, # Use float32 for CPU
72
  "max_length": 256,
73
  "device": "cuda" if torch.cuda.is_available() else "cpu",
 
74
  },
75
  }
76
 
 
 
77
 
78
- def initialize_model_pipeline(model_name, force_cpu=False):
79
- """Initialize pipeline with memory management"""
 
80
  try:
81
- if force_cpu:
82
- device = -1
83
- else:
84
- device = MODEL_CONFIG["main_model"]["device"]
85
-
86
- api_key = get_huggingface_api_key()
87
-
88
- # Use 8-bit quantization for memory efficiency
89
- model = AutoModelForCausalLM.from_pretrained(
90
- model_name,
91
- load_in_8bit=False,
92
- torch_dtype=MODEL_CONFIG["main_model"]["dtype"],
93
- use_cache=True,
94
- device_map="auto",
95
- low_cpu_mem_usage=True,
96
- trust_remote_code=True,
97
- token=api_key, # Add token here
98
  )
 
 
 
 
99
 
100
- model.config.use_cache = True
101
 
102
- tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_key)
 
 
103
 
104
- # Initialize pipeline
105
- logger.info(f"Initializing pipeline with device: {device}")
106
- generator = pipeline(
107
- "text-generation",
108
- model=model,
109
- tokenizer=tokenizer,
110
- # device=device,
111
- # temperature=0.7,
112
- model_kwargs={"low_cpu_mem_usage": True},
113
- )
114
 
115
- logger.info("Model Pipeline initialized successfully")
116
- return generator, tokenizer
117
 
118
- except ImportError as e:
119
- logger.error(f"Missing required package: {str(e)}")
120
- raise
121
  except Exception as e:
122
- logger.error(f"Failed to initialize pipeline: {str(e)}")
123
- raise
124
 
125
 
126
- def initialize_inference_client():
127
- """Initialize HuggingFace Inference Client"""
128
  try:
129
- inference_key = get_huggingface_inference_key()
 
 
130
 
131
- client = InferenceClient(api_key=inference_key)
132
- logger.info("Inference Client initialized successfully")
133
- return client
 
 
134
  except Exception as e:
135
- logger.error(f"Failed to initialize Inference Client: {e}")
136
- raise
 
 
137
 
 
 
138
 
139
  # # Initialize model pipeline
140
  # try:
@@ -161,17 +279,30 @@ def initialize_inference_client():
161
  # raise
162
 
163
 
 
 
164
  def load_world(filename):
165
  with open(filename, "r") as f:
166
  return json.load(f)
167
 
168
 
169
  # Define system_prompt and model
170
- system_prompt = """You are an AI Game Master. Write ONE response describing what the player sees/experiences.
171
  CRITICAL Rules:
172
  - Write EXACTLY 3 sentences maximum
173
  - Use daily English language
174
- - Start with "You see", "You hear", or "You feel"
175
  - Don't use 'Elara' or 'she/he', only use 'you'
176
  - Use only second person ("you")
177
  - Never include dialogue after the response
@@ -183,7 +314,7 @@ CRITICAL Rules:
183
  - Never include 'What would you like to do?' or similar prompts
184
  - Always finish with one real response
185
  - Never use 'Your turn' or anything like conversation-starting prompts
186
- - Always end the response with a period"""
187
 
188
 
189
  def get_game_state(inventory: Dict = None) -> Dict[str, Any]:
@@ -435,39 +566,59 @@ New Quest: {next_quest['title']}
435
 
436
 
437
  def parse_items_from_story(text: str) -> Dict[str, int]:
438
- """Extract item changes from story text"""
439
  items = {}
440
 
441
- # Common item keywords and patterns
442
- gold_pattern = r"(\d+)\s*gold"
443
- items_pattern = (
444
- r"(?:receive|find|given|hand|containing)\s+(?:a|an|the)?\s*(\d+)?\s*([\w\s]+)"
445
- )
 
 
446
 
447
- # Find gold amounts
448
- gold_matches = re.findall(gold_pattern, text.lower())
449
- if gold_matches:
450
- items["gold"] = sum(int(x) for x in gold_matches)
451
 
452
- # Find other items
453
- item_matches = re.findall(items_pattern, text.lower())
454
- for count, item in item_matches:
455
- count = int(count) if count else 1
456
- item = item.strip()
457
- if item in items:
458
- items[item] += count
459
- else:
460
- items[item] = count
 
 
461
 
462
- return items
 
 
463
 
464
 
465
- def update_game_inventory(game_state: Dict, story_text: str) -> str:
466
- """Update inventory based on story and return update message"""
467
  try:
468
  items = parse_items_from_story(story_text)
469
  update_msg = ""
470
 
 
471
  for item, count in items.items():
472
  if item in game_state["inventory"]:
473
  game_state["inventory"][item] += count
@@ -475,10 +626,15 @@ def update_game_inventory(game_state: Dict, story_text: str) -> str:
475
  game_state["inventory"][item] = count
476
  update_msg += f"\nReceived: {count} {item}"
477
 
478
- return update_msg
 
 
479
  except Exception as e:
480
  logger.error(f"Error updating inventory: {e}")
481
- return ""
482
 
483
 
484
  def extract_response_after_action(full_text: str, action: str) -> str:
@@ -530,8 +686,10 @@ def run_action(message: str, history: list, game_state: Dict) -> str:
530
 
531
  {game_state['start']}
532
 
533
- Currently in {game_state['town_name']}, in the kingdom of {game_state['kingdom']}.
534
- {game_state['town']}
 
 
535
 
536
 
537
  Current Quest: {initial_quest['title']}
@@ -545,6 +703,11 @@ What would you like to do?"""
545
  logger.error(f"Invalid game state type: {type(game_state)}")
546
  return "Error: Invalid game state"
547
 
 
 
548
  # logger.info(f"Processing action with game state: {game_state}")
549
  logger.info(f"Processing action with game state")
550
 
@@ -589,7 +752,8 @@ Inventory: {json.dumps(game_state['inventory'])}"""
589
 
590
  # Add history in correct alternating format
591
  if history:
592
- for h in history[-3:]: # Last 3 exchanges
 
593
  if isinstance(h, tuple):
594
  messages.append({"role": "user", "content": h[0]})
595
  messages.append({"role": "assistant", "content": h[1]})
@@ -664,6 +828,11 @@ Inventory: {json.dumps(game_state['inventory'])}"""
664
  if not response:
665
  return "You look around carefully."
666
 
 
 
667
  # # Perform safety check before returning
668
  # safe = is_safe(response)
669
  # print(f"\nSafety Check Result: {'SAFE' if safe else 'UNSAFE'}")
@@ -690,8 +859,8 @@ Inventory: {json.dumps(game_state['inventory'])}"""
690
  if quest_completed:
691
  response += quest_message
692
 
693
- # Check for item updates
694
- inventory_update = update_game_inventory(game_state, response)
695
  if inventory_update:
696
  response += inventory_update
697
 
@@ -733,7 +902,7 @@ def chat_response(message: str, chat_history: list, current_state: dict) -> tupl
733
  """Process chat input and return response with updates"""
734
  try:
735
  if not message.strip():
736
- return chat_history, current_state, "", ""
737
 
738
  # Get AI response
739
  output = run_action(message, chat_history, current_state)
@@ -745,12 +914,17 @@ def chat_response(message: str, chat_history: list, current_state: dict) -> tupl
745
  # Update status displays
746
  status_text, quest_text = update_game_status(current_state)
747
 
 
 
748
  # Return tuple includes empty string to clear input
749
- return chat_history, current_state, status_text, quest_text
750
 
751
  except Exception as e:
752
  logger.error(f"Error in chat response: {e}")
753
- return chat_history, current_state, "", ""
754
 
755
 
756
  def start_game(main_loop, game_state, share=False):
@@ -888,22 +1062,26 @@ def start_game(main_loop, game_state, share=False):
888
 
889
  def submit_action(message, history, state):
890
  # Process response
891
- new_history, new_state, status_text, quest_text = chat_response(
892
- message, history, state
893
  )
 
 
 
 
894
  # Clear input
895
- return "", new_history, new_state, status_text, quest_text
896
 
897
  submit_btn.click(
898
  submit_action,
899
  inputs=[txt, chatbot, state],
900
- outputs=[txt, chatbot, state, status, quest_display],
901
  )
902
 
903
  txt.submit(
904
  submit_action,
905
  inputs=[txt, chatbot, state],
906
- outputs=[txt, chatbot, state, status, quest_display],
907
  )
908
 
909
  demo.launch(share=share)
@@ -939,159 +1117,212 @@ Should not
939
  }
940
 
941
 
942
- def init_safety_model(model_name, force_cpu=False):
943
- """Initialize safety checking model with optimized memory usage"""
944
  try:
945
- if force_cpu:
946
- device = -1
947
- else:
948
- device = MODEL_CONFIG["safety_model"]["device"]
949
-
950
- # model_id = "meta-llama/Llama-Guard-3-8B"
951
- # model_id = "meta-llama/Llama-Guard-3-1B"
952
-
953
- api_key = get_huggingface_api_key()
954
-
955
- safety_model = AutoModelForCausalLM.from_pretrained(
956
- model_name,
957
- token=api_key,
958
- torch_dtype=MODEL_CONFIG["safety_model"]["dtype"],
959
- use_cache=True,
960
- device_map="auto",
961
- )
962
- safety_model.config.use_cache = True
963
 
964
- safety_tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_key)
965
- # Set pad token explicitly
966
- safety_tokenizer.pad_token = safety_tokenizer.eos_token
967
 
968
- logger.info(f"Safety model initialized successfully on {device}")
969
- return safety_model, safety_tokenizer
 
 
970
 
971
- except Exception as e:
972
- logger.error(f"Failed to initialize safety model: {e}")
973
- raise
 
 
974
 
 
 
975
 
976
- # Initialize safety model pipeline
977
- try:
978
- safety_model_name = MODEL_CONFIG["safety_model"]["name"]
979
 
980
- api_key = get_huggingface_api_key()
981
 
982
- # Initialize the pipeline with memory management
983
- safety_model, safety_tokenizer = init_safety_model(safety_model_name)
 
 
984
 
985
- except Exception as e:
986
- logger.error(f"Failed to initialize model: {str(e)}")
987
- # Fallback to CPU if GPU initialization fails
988
- try:
989
- logger.info("Attempting CPU fallback...")
990
- safety_model, safety_tokenizer = init_safety_model(
991
- safety_model_name, force_cpu=True
992
- )
993
  except Exception as e:
994
- logger.error(f"CPU fallback failed: {str(e)}")
995
- raise
996
 
997
 
998
- @lru_cache(maxsize=1000)
999
- def get_safety_response(prompt: str) -> str:
1000
- """Cached safety check responses"""
1001
- with torch.no_grad():
1002
- inputs = safety_tokenizer(
1003
- prompt,
1004
- return_tensors="pt",
1005
- padding=True,
1006
- truncation=True,
1007
- max_length=MODEL_CONFIG["safety_model"]["max_length"],
1008
- ).to(safety_model.device)
1009
-
1010
- output = safety_model.generate(
1011
- **inputs,
1012
- max_new_tokens=5,
1013
- temperature=0.1, # Lower temperature for more consistent output
1014
- pad_token_id=safety_tokenizer.pad_token_id,
1015
- eos_token_id=safety_tokenizer.eos_token_id,
1016
- )
1017
-
1018
- return safety_tokenizer.decode(output[0], skip_special_tokens=True)
1019
 
 
 
1020
 
1021
- def is_safe(message: str) -> bool:
1022
- """Check if message content is safe"""
1023
- try:
1024
- logger.info("Performing safety check...")
1025
- # Build the prompt with embedded values
1026
- # prompt = f"""[INST] Task: Check if this content is safe according to our safety policy with the below categories.
1027
 
1028
- # <BEGIN UNSAFE CONTENT CATEGORIES>
1029
- # {everyone_content_policy['policy']}
1030
- # <END UNSAFE CONTENT CATEGORIES>
 
 
1031
 
1032
- # <CONTENT TO CHECK>
1033
- # {message}
1034
- # <END CONTENT TO CHECK>
1035
 
1036
- # Provide your safety assessment for user in the above CONTENT TO CHECK section:
1037
- # - Reply ONLY with single word 'safe' or 'unsafe' [/INST]
1038
- # """
1039
 
1040
- prompt = f"""[INST]Check if safe:
1041
- {message}
1042
- Reply 'safe' or 'unsafe'[/INST]
1043
- """
1044
 
1045
- # Generate safety check on cuda
1046
 
1047
- # with torch.no_grad():
1048
- # inputs = safety_tokenizer(
1049
- # prompt,
1050
- # return_tensors="pt",
1051
- # padding=True,
1052
- # truncation=True,
1053
- # )
1054
 
1055
- # # Move inputs to correct device
1056
- # inputs = {k: v.to(device) for k, v in inputs.items()}
1057
 
1058
- # output = safety_model.generate(
1059
- # **inputs,
1060
- # max_new_tokens=10,
1061
- # temperature=0.1, # Lower temperature for more consistent output
1062
- # pad_token_id=safety_tokenizer.pad_token_id, # Use configured pad token
1063
- # eos_token_id=safety_tokenizer.eos_token_id,
1064
- # do_sample=False,
1065
- # )
1066
 
1067
- # result = safety_tokenizer.decode(output[0], skip_special_tokens=True)
1068
- result = get_safety_response(prompt)
1069
- print(f"Raw safety check result: {result}")
 
 
1070
 
1071
- # # Extract response after prompt
1072
- # if "[/INST]" in result:
1073
- # result = result.split("[/INST]")[-1]
1074
 
1075
- # # Clean response
1076
- # result = result.lower().strip()
1077
- # print(f"Cleaned safety check result: {result}")
1078
- # words = [word for word in result.split() if word in ["safe", "unsafe"]]
 
 
1079
 
1080
- # # Take first valid response word
1081
- # is_safe = words[0] == "safe" if words else False
 
 
1082
 
1083
- # print("Final Safety check result:", is_safe)
1084
 
1085
- is_safe = "safe" in result.lower().split()
1086
 
1087
- logger.info(
1088
- f"Safety check completed - Result: {'SAFE' if is_safe else 'UNSAFE'}"
1089
- )
1090
- return is_safe
 
 
1091
 
1092
- except Exception as e:
1093
- logger.error(f"Safety check failed: {e}")
1094
- return False
1095
 
1096
 
1097
  def detect_inventory_changes(game_state, output):
 
4
  import json
5
  import gradio as gr
6
  import torch # first import torch then transformers
7
+ from torch.nn.functional import softmax
8
+ from transformers import AutoModelForSequenceClassification
9
  from huggingface_hub import InferenceClient
10
+
11
  from transformers import pipeline
12
  from transformers import AutoTokenizer, AutoModelForCausalLM
13
  import logging
 
73
  "dtype": torch.float32, # Use float32 for CPU
74
  "max_length": 256,
75
  "device": "cuda" if torch.cuda.is_available() else "cpu",
76
+ "max_tokens": 500,
77
  },
78
  }
79
 
80
+ PROMPT_GUARD_CONFIG = {
81
+ "model_id": "meta-llama/Prompt-Guard-86M",
82
+ "temperature": 1.0,
83
+ "jailbreak_threshold": 0.5,
84
+ "injection_threshold": 0.9,
85
+ "device": "cpu",
86
+ "safe_commands": [
87
+ "look around",
88
+ "investigate",
89
+ "explore",
90
+ "search",
91
+ "examine",
92
+ "take",
93
+ "use",
94
+ "go",
95
+ "walk",
96
+ "continue",
97
+ "help",
98
+ "inventory",
99
+ "quest",
100
+ "status",
101
+ "map",
102
+ "talk",
103
+ "fight",
104
+ "run",
105
+ "hide",
106
+ ],
107
+ "max_length": 512,
108
+ }
109
 
110
+
111
+ def initialize_prompt_guard():
112
+ """Initialize Prompt Guard model"""
113
  try:
114
+ tokenizer = AutoTokenizer.from_pretrained(PROMPT_GUARD_CONFIG["model_id"])
115
+ model = AutoModelForSequenceClassification.from_pretrained(
116
+ PROMPT_GUARD_CONFIG["model_id"]
 
 
 
 
117
  )
118
+ return model, tokenizer
119
+ except Exception as e:
120
+ logger.error(f"Failed to initialize Prompt Guard: {e}")
121
+ raise
122
 
 
123
 
124
+ def get_class_probabilities(text: str, guard_model, guard_tokenizer) -> torch.Tensor:
125
+ """Evaluate model probabilities with temperature scaling"""
126
+ try:
127
+ inputs = guard_tokenizer(
128
+ text,
129
+ return_tensors="pt",
130
+ padding=True,
131
+ truncation=True,
132
+ max_length=PROMPT_GUARD_CONFIG["max_length"],
133
+ ).to(PROMPT_GUARD_CONFIG["device"])
134
 
135
+ with torch.no_grad():
136
+ logits = guard_model(**inputs).logits
 
 
 
 
137
 
138
+ scaled_logits = logits / PROMPT_GUARD_CONFIG["temperature"]
139
+ return softmax(scaled_logits, dim=-1)
140
 
 
 
 
141
  except Exception as e:
142
+ logger.error(f"Error getting class probabilities: {e}")
143
+ return None
144
 
145
 
146
+ def get_jailbreak_score(text: str, guard_model, guard_tokenizer) -> float:
147
+ """Get jailbreak probability score"""
148
  try:
149
+ probabilities = get_class_probabilities(text, guard_model, guard_tokenizer)
150
+ if probabilities is None:
151
+ return 1.0 # Fail safe
152
+ return probabilities[0, 2].item()
153
+ except Exception as e:
154
+ logger.error(f"Error getting jailbreak score: {e}")
155
+ return 1.0
156
 
157
+
158
+ def get_injection_score(text: str, guard_model, guard_tokenizer) -> float:
159
+ """Get injection probability score"""
160
+ try:
161
+ probabilities = get_class_probabilities(text, guard_model, guard_tokenizer)
162
+ if probabilities is None:
163
+ return 1.0 # Fail safe
164
+ return (probabilities[0, 1] + probabilities[0, 2]).item()
165
  except Exception as e:
166
+ logger.error(f"Error getting injection score: {e}")
167
+ return 1.0
168
+
169
+
170
+ # Initialize safety model pipeline
171
+ try:
172
+ # Initialize Prompt Guard
173
+ guard_model, guard_tokenizer = initialize_prompt_guard()
174
+
175
+ except Exception as e:
176
+ logger.error(f"Failed to initialize model: {str(e)}")
177
+
178
+
179
+ def is_prompt_safe(message: str) -> bool:
180
+ """Enhanced safety check with Prompt Guard"""
181
+ try:
182
+ # Allow safe game commands
183
+ if any(cmd in message.lower() for cmd in PROMPT_GUARD_CONFIG["safe_commands"]):
184
+ logger.info("Message matched safe command pattern")
185
+ return True
186
 
187
+ # Get safety scores
188
+ jailbreak_score = get_jailbreak_score(message, guard_model, guard_tokenizer)
189
+ injection_score = get_injection_score(message, guard_model, guard_tokenizer)
190
+
191
+ logger.info(
192
+ f"Safety scores - Jailbreak: {jailbreak_score}, Injection: {injection_score}"
193
+ )
194
+
195
+ # Check against thresholds
196
+ is_safe = (
197
+ jailbreak_score
198
+ < PROMPT_GUARD_CONFIG["jailbreak_threshold"]
199
+ # and injection_score < PROMPT_GUARD_CONFIG["injection_threshold"]  # Disabled for now: the injection score is too strict and the current Prompt Guard model appears to misbehave on this task.
200
+ )
201
+
202
+ logger.info(f"Final safety result: {is_safe}")
203
+ return is_safe
204
+
205
+ except Exception as e:
206
+ logger.error(f"Safety check failed: {e}")
207
+ return False
208
+
209
+
210
+ # def initialize_model_pipeline(model_name, force_cpu=False):
211
+ # """Initialize pipeline with memory management"""
212
+ # try:
213
+ # if force_cpu:
214
+ # device = -1
215
+ # else:
216
+ # device = MODEL_CONFIG["main_model"]["device"]
217
+
218
+ # api_key = get_huggingface_api_key()
219
+
220
+ # # Use 8-bit quantization for memory efficiency
221
+ # model = AutoModelForCausalLM.from_pretrained(
222
+ # model_name,
223
+ # load_in_8bit=False,
224
+ # torch_dtype=MODEL_CONFIG["main_model"]["dtype"],
225
+ # use_cache=True,
226
+ # device_map="auto",
227
+ # low_cpu_mem_usage=True,
228
+ # trust_remote_code=True,
229
+ # token=api_key, # Add token here
230
+ # )
231
+
232
+ # model.config.use_cache = True
233
+
234
+ # tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_key)
235
+
236
+ # # Initialize pipeline
237
+ # logger.info(f"Initializing pipeline with device: {device}")
238
+ # generator = pipeline(
239
+ # "text-generation",
240
+ # model=model,
241
+ # tokenizer=tokenizer,
242
+ # # device=device,
243
+ # # temperature=0.7,
244
+ # model_kwargs={"low_cpu_mem_usage": True},
245
+ # )
246
+
247
+ # logger.info("Model Pipeline initialized successfully")
248
+ # return generator, tokenizer
249
+
250
+ # except ImportError as e:
251
+ # logger.error(f"Missing required package: {str(e)}")
252
+ # raise
253
+ # except Exception as e:
254
+ # logger.error(f"Failed to initialize pipeline: {str(e)}")
255
+ # raise
256
 
257
  # # Initialize model pipeline
258
  # try:
 
279
  # raise
280
 
281
 
282
+ def initialize_inference_client():
283
+ """Initialize HuggingFace Inference Client"""
284
+ try:
285
+ inference_key = get_huggingface_inference_key()
286
+
287
+ client = InferenceClient(api_key=inference_key)
288
+ logger.info("Inference Client initialized successfully")
289
+ return client
290
+ except Exception as e:
291
+ logger.error(f"Failed to initialize Inference Client: {e}")
292
+ raise
293
+
294
+
295
  def load_world(filename):
296
  with open(filename, "r") as f:
297
  return json.load(f)
298
 
299
 
300
  # Define system_prompt and model
301
+ system_prompt = """You are an AI Game Master. Your job is to write what happens next in a player's adventure game.
302
  CRITICAL Rules:
303
  - Write EXACTLY 3 sentences maximum
304
  - Use daily English language
305
+ - Start with "You "
306
  - Don't use 'Elara' or 'she/he', only use 'you'
307
  - Use only second person ("you")
308
  - Never include dialogue after the response
 
314
  - Never include 'What would you like to do?' or similar prompts
315
  - Always finish with one real response
316
  - Never use 'Your turn' or anything like conversation-starting prompts
317
+ - Always end the response with a period (.)"""
318
 
319
 
320
  def get_game_state(inventory: Dict = None) -> Dict[str, Any]:
 
566
 
567
 
568
  def parse_items_from_story(text: str) -> Dict[str, int]:
569
+ """Extract item changes from story text with improved pattern matching"""
570
  items = {}
571
 
572
+ # Skip parsing if text starts with common narrative phrases
573
+ skip_patterns = [
574
+ "you see",
575
+ "you find yourself",
576
+ "you are",
577
+ "you stand",
578
+ "you hear",
579
+ "you feel",
580
+ ]
581
+ if any(text.lower().startswith(pattern) for pattern in skip_patterns):
582
+ return items
583
 
584
+ # Common item keywords and patterns
585
+ gold_pattern = r"(\d+)\s*gold(?:\s+coins?)?"
586
+ items_pattern = r"(?:receive|find|given|obtain|pick up|grab)\s+(?:a|an|the)?\s*(\d+)?\s*([\w\s]+)"  # greedy capture so the full item name is matched
 
587
 
588
+ try:
589
+ # Find gold amounts
590
+ gold_matches = re.findall(gold_pattern, text.lower())
591
+ if gold_matches:
592
+ items["gold"] = sum(int(x) for x in gold_matches)
593
+
594
+ # Find other items
595
+ item_matches = re.findall(items_pattern, text.lower())
596
+ for count, item in item_matches:
597
+ # Validate item name
598
+ item = item.strip()
599
+ if len(item) > 2 and not any( # Minimum length check
600
+ skip in item for skip in ["yourself", "you", "door", "wall", "floor"]
601
+ ): # Skip common words
602
+ count = int(count) if count else 1
603
+ if item in items:
604
+ items[item] += count
605
+ else:
606
+ items[item] = count
607
+
608
+ return items
609
 
610
+ except Exception as e:
611
+ logger.error(f"Error parsing items from story: {e}")
612
+ return {}
613
 
614
 
615
+ def update_game_inventory(game_state: Dict, story_text: str) -> Tuple[str, list]:
616
+ """Update inventory and return message and updated inventory data"""
617
  try:
618
  items = parse_items_from_story(story_text)
619
  update_msg = ""
620
 
621
+ # Update inventory
622
  for item, count in items.items():
623
  if item in game_state["inventory"]:
624
  game_state["inventory"][item] += count
 
626
  game_state["inventory"][item] = count
627
  update_msg += f"\nReceived: {count} {item}"
628
 
629
+ # Create updated inventory data for display
630
+ inventory_data = [
631
+ [item, count] for item, count in game_state["inventory"].items()
632
+ ]
633
+
634
+ return update_msg, inventory_data
635
  except Exception as e:
636
  logger.error(f"Error updating inventory: {e}")
637
+ return "", []
638
 
639
 
640
  def extract_response_after_action(full_text: str, action: str) -> str:
 
686
 
687
  {game_state['start']}
688
 
689
+ You are currently in {game_state['town_name']}, {game_state['town']}.
690
+
691
+ {game_state['town_name']} is a city in {game_state['kingdom']}.
692
+
693
 
694
 
695
  Current Quest: {initial_quest['title']}
 
703
  logger.error(f"Invalid game state type: {type(game_state)}")
704
  return "Error: Invalid game state"
705
 
706
+ # Safety check with Prompt Guard
707
+ if not is_prompt_safe(message):
708
+ logger.warning("Unsafe content detected in user prompt")
709
+ return "I cannot process that request for safety reasons."
710
+
711
  # logger.info(f"Processing action with game state: {game_state}")
712
  logger.info(f"Processing action with game state")
713
 
 
752
 
753
  # Add history in correct alternating format
754
  if history:
755
+ # for h in history[-3:]: # Last 3 exchanges
756
+ for h in history:
757
  if isinstance(h, tuple):
758
  messages.append({"role": "user", "content": h[0]})
759
  messages.append({"role": "assistant", "content": h[1]})
 
828
  if not response:
829
  return "You look around carefully."
830
 
831
+ # Safety-check the response using the Inference API
832
+ if not is_safe(response):
833
+ logger.warning("Unsafe content detected - blocking response")
834
+ return "This response was blocked for safety reasons."
835
+
836
  # # Perform safety check before returning
837
  # safe = is_safe(response)
838
  # print(f"\nSafety Check Result: {'SAFE' if safe else 'UNSAFE'}")
 
859
  if quest_completed:
860
  response += quest_message
861
 
862
+ # Check for item-inventory updates
863
+ inventory_update, inventory_data = update_game_inventory(game_state, response)
864
  if inventory_update:
865
  response += inventory_update
866
 
 
902
  """Process chat input and return response with updates"""
903
  try:
904
  if not message.strip():
905
+ return chat_history, current_state, "", "", [] # Add empty inventory data
906
 
907
  # Get AI response
908
  output = run_action(message, chat_history, current_state)
 
914
  # Update status displays
915
  status_text, quest_text = update_game_status(current_state)
916
 
917
+ # Get inventory updates
918
+ update_msg, inventory_data = update_game_inventory(current_state, output)
919
+ if update_msg:
920
+ output += update_msg
921
+
922
  # Return tuple includes empty string to clear input
923
+ return chat_history, current_state, status_text, quest_text, inventory_data
924
 
925
  except Exception as e:
926
  logger.error(f"Error in chat response: {e}")
927
+ return chat_history, current_state, "", "", []
928
 
929
 
930
  def start_game(main_loop, game_state, share=False):
 
1062
 
1063
  def submit_action(message, history, state):
1064
  # Process response
1065
+ new_history, new_state, status_text, quest_text, inventory_data = (
1066
+ chat_response(message, history, state)
1067
  )
1068
+
1069
+ # Update inventory display
1070
+ inventory.value = inventory_data
1071
+
1072
  # Clear input
1073
+ return "", new_history, new_state, status_text, quest_text, inventory
1074
 
1075
  submit_btn.click(
1076
  submit_action,
1077
  inputs=[txt, chatbot, state],
1078
+ outputs=[txt, chatbot, state, status, quest_display, inventory],
1079
  )
1080
 
1081
  txt.submit(
1082
  submit_action,
1083
  inputs=[txt, chatbot, state],
1084
+ outputs=[txt, chatbot, state, status, quest_display, inventory],
1085
  )
1086
 
1087
  demo.launch(share=share)
 
1117
  }
1118
 
1119
 
1120
+ def initialize_safety_client():
1121
+ """Initialize HuggingFace Inference Client"""
1122
  try:
1123
+ inference_key = get_huggingface_inference_key()
1124
+ # api_key = get_huggingface_api_key()
1125
+ return InferenceClient(api_key=inference_key)
1126
+ except Exception as e:
1127
+ logger.error(f"Failed to initialize safety client: {e}")
1128
+ raise
 
 
1129
 
 
 
 
1130
 
1131
+ def is_safe(message: str) -> bool:
1132
+ """Check content safety using Inference API"""
1133
+ try:
1134
+ client = initialize_safety_client()
1135
 
1136
+ messages = [
1137
+ {"role": "user", "content": f"Check if this content is safe:\n{message}"},
1138
+ {
1139
+ "role": "assistant",
1140
+ "content": f"I will check if the content is safe based on this content policy:\n{everyone_content_policy['policy']}",
1141
+ },
1142
+ {"role": "user", "content": "Is it safe or unsafe?"},
1143
+ ]
1144
 
1145
+ try:
1146
+ completion = client.chat.completions.create(
1147
+ model=MODEL_CONFIG["safety_model"]["name"],
1148
+ messages=messages,
1149
+ max_tokens=MODEL_CONFIG["safety_model"]["max_tokens"],
1150
+ temperature=0.1,
1151
+ )
1152
 
1153
+ response = completion.choices[0].message.content.lower()
1154
+ logger.info(f"Safety check response: {response}")
 
1155
 
1156
+ is_safe = "safe" in response and "unsafe" not in response
1157
 
1158
+ logger.info(f"Safety check result: {'SAFE' if is_safe else 'UNSAFE'}")
1159
+ return is_safe
1160
+
1161
+ except Exception as api_error:
1162
+ logger.error(f"API error: {api_error}")
1163
+ # Fallback to allow common game commands
1164
+ return any(
1165
+ cmd in message.lower() for cmd in PROMPT_GUARD_CONFIG["safe_commands"]
1166
+ )
1167
 
 
 
1168
  except Exception as e:
1169
+ logger.error(f"Safety check failed: {e}")
1170
+ return False
1171
 
1172
 
1173
+ # def init_safety_model(model_name, force_cpu=False):
1174
+ # """Initialize safety checking model with optimized memory usage"""
1175
+ # try:
1176
+ # if force_cpu:
1177
+ # device = -1
1178
+ # else:
1179
+ # device = MODEL_CONFIG["safety_model"]["device"]
 
 
 
1180
 
1181
+ # # model_id = "meta-llama/Llama-Guard-3-8B"
1182
+ # # model_id = "meta-llama/Llama-Guard-3-1B"
1183
 
1184
+ # api_key = get_huggingface_api_key()
 
 
1185
 
1186
+ # safety_model = AutoModelForCausalLM.from_pretrained(
1187
+ # model_name,
1188
+ # token=api_key,
1189
+ # torch_dtype=MODEL_CONFIG["safety_model"]["dtype"],
1190
+ # use_cache=True,
1191
+ # device_map="auto",
1192
+ # )
1193
+ # safety_model.config.use_cache = True
1194
 
1195
+ # safety_tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_key)
1196
+ # # Set pad token explicitly
1197
+ # safety_tokenizer.pad_token = safety_tokenizer.eos_token
1198
 
1199
+ # logger.info(f"Safety model initialized successfully on {device}")
1200
+ # return safety_model, safety_tokenizer
 
1201
 
1202
+ # except Exception as e:
1203
+ # logger.error(f"Failed to initialize safety model: {e}")
1204
+ # raise
 
1205
 
 
1206
 
1207
+ # # Initialize safety model pipeline
1208
+ # try:
1209
+ # safety_model_name = MODEL_CONFIG["safety_model"]["name"]
 
 
 
 
1210
 
1211
+ # api_key = get_huggingface_api_key()
 
1212
 
1213
+ # # Initialize the pipeline with memory management
1214
+ # safety_model, safety_tokenizer = init_safety_model(safety_model_name)
 
 
 
 
 
 
1215
 
1216
+ # except Exception as e:
1217
+ # logger.error(f"Failed to initialize model: {str(e)}")
1218
+ # # Fallback to CPU if GPU initialization fails
1219
+ # try:
1220
+ # logger.info("Attempting CPU fallback...")
1221
+ # safety_model, safety_tokenizer = init_safety_model(
1222
+ # safety_model_name, force_cpu=True
1223
+ # )
1224
+ # except Exception as e:
1225
+ # logger.error(f"CPU fallback failed: {str(e)}")
1226
+ # raise
1227
 
 
 
 
1228
 
1229
+ # @lru_cache(maxsize=1000)
1230
+ # def get_safety_response(prompt: str) -> str:
1231
+ # """Cached safety check responses"""
1232
+ # with torch.no_grad():
1233
+ # inputs = safety_tokenizer(
1234
+ # prompt,
1235
+ # return_tensors="pt",
1236
+ # padding=True,
1237
+ # truncation=True,
1238
+ # max_length=MODEL_CONFIG["safety_model"]["max_length"],
1239
+ # ).to(safety_model.device)
1240
 
1241
+ # output = safety_model.generate(
1242
+ # **inputs,
1243
+ # max_new_tokens=5,
1244
+ # temperature=0.1, # Lower temperature for more consistent output
1245
+ # pad_token_id=safety_tokenizer.pad_token_id,
1246
+ # eos_token_id=safety_tokenizer.eos_token_id,
1247
+ # )
1248
 
1249
+ # return safety_tokenizer.decode(output[0], skip_special_tokens=True)
1250
 
 
1251
 
1252
+ # def is_safe(message: str) -> bool:
1253
+ # """Check if message content is safe"""
1254
+ # try:
1255
+ # logger.info("Performing safety check...")
1256
+ # # Build the prompt with embedded values
1257
+ # # prompt = f"""[INST] Task: Check if this content is safe according to our safety policy with the below categories.
1258
+
1259
+ # # <BEGIN UNSAFE CONTENT CATEGORIES>
1260
+ # # {everyone_content_policy['policy']}
1261
+ # # <END UNSAFE CONTENT CATEGORIES>
1262
+
1263
+ # # <CONTENT TO CHECK>
1264
+ # # {message}
1265
+ # # <END CONTENT TO CHECK>
1266
+
1267
+ # # Provide your safety assessment for user in the above CONTENT TO CHECK section:
1268
+ # # - Reply ONLY with single word 'safe' or 'unsafe' [/INST]
1269
+ # # """
1270
+
1271
+ # prompt = f"""[INST]Check if safe:
1272
+ # {message}
1273
+ # Reply 'safe' or 'unsafe'[/INST]
1274
+ # """
1275
+
1276
+ # # Generate safety check on cuda
1277
+
1278
+ # # with torch.no_grad():
1279
+ # # inputs = safety_tokenizer(
1280
+ # # prompt,
1281
+ # # return_tensors="pt",
1282
+ # # padding=True,
1283
+ # # truncation=True,
1284
+ # # )
1285
+
1286
+ # # # Move inputs to correct device
1287
+ # # inputs = {k: v.to(device) for k, v in inputs.items()}
1288
+
1289
+ # # output = safety_model.generate(
1290
+ # # **inputs,
1291
+ # # max_new_tokens=10,
1292
+ # # temperature=0.1, # Lower temperature for more consistent output
1293
+ # # pad_token_id=safety_tokenizer.pad_token_id, # Use configured pad token
1294
+ # # eos_token_id=safety_tokenizer.eos_token_id,
1295
+ # # do_sample=False,
1296
+ # # )
1297
+
1298
+ # # result = safety_tokenizer.decode(output[0], skip_special_tokens=True)
1299
+ # result = get_safety_response(prompt)
1300
+ # print(f"Raw safety check result: {result}")
1301
+
1302
+ # # # Extract response after prompt
1303
+ # # if "[/INST]" in result:
1304
+ # # result = result.split("[/INST]")[-1]
1305
+
1306
+ # # # Clean response
1307
+ # # result = result.lower().strip()
1308
+ # # print(f"Cleaned safety check result: {result}")
1309
+ # # words = [word for word in result.split() if word in ["safe", "unsafe"]]
1310
+
1311
+ # # # Take first valid response word
1312
+ # # is_safe = words[0] == "safe" if words else False
1313
+
1314
+ # # print("Final Safety check result:", is_safe)
1315
+
1316
+ # is_safe = "safe" in result.lower().split()
1317
+
1318
+ # logger.info(
1319
+ # f"Safety check completed - Result: {'SAFE' if is_safe else 'UNSAFE'}"
1320
+ # )
1321
+ # return is_safe
1322
 
1323
+ # except Exception as e:
1324
+ # logger.error(f"Safety check failed: {e}")
1325
+ # return False
1326
 
1327
 
1328
  def detect_inventory_changes(game_state, output):