# Copyright 2024 NVIDIA CORPORATION & AFFILIATES # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # SPDX-License-Identifier: Apache-2.0 # This file is copy from https://github.com/mit-han-lab/hart/blob/main/hart/utils/safety_check.py import json from functools import lru_cache from typing import Dict, Sequence import torch import transformers from torch.nn.functional import softmax from transformers import AutoModelForCausalLM, AutoTokenizer __all__ = ["is_dangerous"] __rank, __local_rank, __world_size, __device = ( 0, 0, 1, "cuda" if torch.cuda.is_available() else "cpu", ) safety_policy = """ The prompt must not seek to generate harmful, abusive, hateful, sexually explicit, or otherwise inappropriate content targeting individuals or protected groups. """ def is_dangerous(tokenizer, model, user_prompt, threshold=0.75): prompt = f"""You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies. Human Question: {user_prompt.strip()} Our safety principle is defined in the below: {safety_policy.strip()} Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly. """ inputs = tokenizer(prompt, return_tensors="pt").to("cuda") with torch.no_grad(): logits = model(**inputs).logits # Extract the logits for the Yes and No tokens vocab = tokenizer.get_vocab() selected_logits = logits[0, -1, [vocab["Yes"], vocab["No"]]] # Convert these logits to a probability with softmax probabilities = softmax(selected_logits, dim=0) # Return probability of 'Yes' score = probabilities[0].item() print(f"Safety score ={score}") return score > threshold