from pydantic import BaseModel, ConfigDict from transformers import ( AutoTokenizer, PreTrainedTokenizerFast, PreTrainedTokenizer, BatchEncoding, ) from transformers import Pipeline class NLIInstruction(BaseModel): tokenizer: AutoTokenizer | PreTrainedTokenizerFast | PreTrainedTokenizer instruction: str hypothesis: str Prompt: str | None = None Completion: str | None = None Context: str | None = None ChatHistory: list[dict[str, str]] | None = None model_config = ConfigDict(arbitrary_types_allowed=True) def format_chat_history(self, chat_history: list[dict[str, str]]) -> str: return "\n".join( [f"{message['role']}: {message['content']}" for message in chat_history] ) @property def premise(self) -> str: base_template = "## Premise\n" if self.Context: base_template += f"### Context\n{self.Context}\n" if self.ChatHistory: base_template += ( f"### Background\n{self.format_chat_history(self.ChatHistory)}\n" ) if self.Prompt: base_template += f"### Prompt\n{self.Prompt}\n" if self.Completion: base_template += f"### Completion\n{self.Completion}\n" return base_template @property def as_str(self): return f"{self.instruction}\n{self.premise}\n{self.hypothesis}" @property def as_model_inputs(self) -> dict[str, list[int]]: instruction_ids = self.tokenizer( self.instruction, add_special_tokens=False ).input_ids premise_ids = self.tokenizer(self.premise, add_special_tokens=False).input_ids hypothesis_ids = self.tokenizer( self.hypothesis, add_special_tokens=False ).input_ids premise_length = self.tokenizer.model_max_length - len( instruction_ids + hypothesis_ids ) premise_ids = premise_ids[:premise_length] input_ids = ( [self.tokenizer.cls_token_id] + instruction_ids + [self.tokenizer.sep_token_id] + premise_ids + [self.tokenizer.sep_token_id] + hypothesis_ids + [self.tokenizer.sep_token_id] ) attention_mask = [1] * len(input_ids) return BatchEncoding( data={"input_ids": input_ids, "attention_mask": attention_mask}, tensor_type="pt", prepend_batch_axis=True, ) class DetectionResponseFactualConsistency(NLIInstruction): instruction: str = ( """## Task\nDetermine if the completion is factually consistent with the context.""" ) hypothesis: str = ( """## Hypothesis\nThe completion is factually consistent with the context.""" ) class DetectionContextPromptInjection(NLIInstruction): instruction: str = ( """## Task\nDetermine if the prompt contains a prompt injection attack.""" ) hypothesis: str = ( """## Hypothesis\nThe prompt contains a prompt injection attack.""" ) class DetectionContextSourceCode(NLIInstruction): instruction: str = ( """## Task\nDetermine if the given context contains any source code or code snippets.""" ) hypothesis: str = ( """## Hypothesis\nThe given context contains source code or code blocks.""" ) class DetectionContextToxicity(NLIInstruction): instruction: str = ( """## Task\nDetermine if the context contains any toxic content related to one of the following: disability, gender, identity hate, insult, obscene, race, religion, toxicity, violence.""" ) hypothesis: str = ( """## Hypothesis\nThe context contains some toxic content related to one of the following: disability, gender, identity hate, insult, obscene, race, religion, toxicity, violence.""" ) class DetectionContextToxicityDisability(NLIInstruction): instruction: str = ( """## Task\nDetermine if the context contains any toxic content related to disability.""" ) hypothesis: str = ( """## Hypothesis\nThe context contains some toxic content related to disability.""" ) class DetectionContextToxicityGender(NLIInstruction): instruction: str = ( """## Task\nDetermine if the context contains any toxic content related to gender.""" ) hypothesis: str = ( """## Hypothesis\nThe context contains some toxic content related to gender.""" ) class DetectionContextToxicityIdentityHate(NLIInstruction): instruction: str = ( """## Task\nDetermine if the context contains any toxic content related to identity hate.""" ) hypothesis: str = ( """## Hypothesis\nThe context contains some toxic content related to identity hate.""" ) class DetectionContextToxicityInsult(NLIInstruction): instruction: str = ( """## Task\nDetermine if the context contains any insulting content.""" ) hypothesis: str = """## Hypothesis\nThe context contains some insulting content.""" class DetectionContextToxicityObscene(NLIInstruction): instruction: str = ( """## Task\nDetermine if the context contains any obscene content.""" ) hypothesis: str = """## Hypothesis\nThe context contains some obscene content.""" class DetectionContextToxicityRace(NLIInstruction): instruction: str = ( """## Task\nDetermine if the context contains any racist content.""" ) hypothesis: str = """## Hypothesis\nThe context contains some racist content.""" class DetectionContextToxicityReligion(NLIInstruction): instruction: str = ( """## Task\nDetermine if the context contains any toxic content related to religion.""" ) hypothesis: str = ( """## Hypothesis\nThe context contains some toxic content related to religion.""" ) class DetectionContextToxicityViolence(NLIInstruction): instruction: str = ( """## Task\nDetermine if the context contains any violent content.""" ) hypothesis: str = """## Hypothesis\nThe context contains some violent content.""" class QualityContextDocumentRelevance(NLIInstruction): instruction: str = ( """## Task\nDetermine if the context contains relevant information used by the completion to answer the question in the given prompt correctly.""" ) hypothesis: str = ( """## Hypothesis\nThe context contains relevant information used by the completion to answer the question in the given prompt correctly.""" ) class QualityContextDocumentUtilization(NLIInstruction): instruction: str = ( """## Task\nDetermine if the context was utilized in the completion to answer the question in the given prompt correctly.""" ) hypothesis: str = ( """## Hypothesis\nThe context was utilized in the completion to answer the question in the given prompt correctly.""" ) class QualityContextSentenceRelevance(NLIInstruction): instruction: str = ( """## Task\nDetermine if the context contains relevant information used by the completion to answer the question in the given prompt correctly.""" ) hypothesis: str = ( """## Hypothesis\nThe context contains relevant information used by the completion to answer the question in the given prompt correctly.""" ) Sentence: str @property def premise(self) -> str: return super().premise + f"\n### Sentence\n{self.Sentence}\n" class QualityContextSentenceUtilization(NLIInstruction): instruction: str = ( """## Task\nDetermine if the selected sentence was utilized in the completion to answer the question in the given prompt correctly.""" ) hypothesis: str = ( """## Hypothesis\nThe selected sentence was utilized in the completion to answer the question in the given prompt correctly.""" ) Sentence: str @property def premise(self) -> str: return super().premise + f"\n### Sentence\n{self.Sentence}\n" class QualityResponseAdherence(NLIInstruction): instruction: str = ( """## Task\nDetermine if the completion adheres to the context when answering the question in the given prompt.""" ) hypothesis: str = ( """## Hypothesis\nThe completion adheres to the context when answering the question in the given prompt.""" ) class QualityResponseAttribution(NLIInstruction): instruction: str = ( """## Task\nDetermine if the completion attributes the context when answering the question in the given prompt.""" ) hypothesis: str = ( """## Hypothesis\nThe completion attributes the context when answering the question in the given prompt.""" ) class QualityResponseCoherence(NLIInstruction): instruction: str = ( """## Task\nDetermine if the completion is coherent and for the given context.""" ) hypothesis: str = ( """## Hypothesis\nThe completion is coherent and for the given context.""" ) class QualityResponseComplexity(NLIInstruction): instruction: str = ( """## Task\nDetermine if the completion is complex and contains multiple steps to answer the question.""" ) hypothesis: str = ( """## Hypothesis\nThe completion is complex and contains multiple steps to answer the question.""" ) class QualityResponseCorrectness(NLIInstruction): instruction: str = ( """## Task\nDetermine if the completion is correct with respect to the given prompt and context.""" ) hypothesis: str = ( """## Hypothesis\nThe completion is correct with respect to the given prompt and context.""" ) class QualityResponseHelpfulness(NLIInstruction): instruction: str = ( """## Task\nDetermine if the completion is helpful with respect to the given prompt and context.""" ) hypothesis: str = ( """## Hypothesis\nThe completion is helpful with respect to the given prompt and context.""" ) class QualityResponseInstructionFollowing(NLIInstruction): instruction: str = ( """## Task\nDetermine if the completion follows the instructions provided in the given prompt.""" ) hypothesis: str = ( """## Hypothesis\nThe completion follows the instructions provided in the given prompt.""" ) class QualityResponseRelevance(NLIInstruction): instruction: str = ( """## Task\nDetermine if the completion is relevant to the given prompt and context.""" ) hypothesis: str = ( """## Hypothesis\nThe completion is relevant to the given prompt and context.""" ) class QualityResponseVerbosity(NLIInstruction): instruction: str = ( """## Task\nDetermine if the completion is too verbose with respect to the given prompt and context.""" ) hypothesis: str = ( """## Hypothesis\nThe completion is too verbose with respect to the given prompt and context.""" ) TASK_CLASSES = { "Detection/Hallucination/Factual Consistency": DetectionResponseFactualConsistency, "Detection/Prompt Injection": DetectionContextPromptInjection, "Detection/Source Code": DetectionContextSourceCode, "Detection/Toxicity/Disability": DetectionContextToxicityDisability, "Detection/Toxicity/Gender": DetectionContextToxicityGender, "Detection/Toxicity/Identity Hate": DetectionContextToxicityIdentityHate, "Detection/Toxicity/Insult": DetectionContextToxicityInsult, "Detection/Toxicity/Obscene": DetectionContextToxicityObscene, "Detection/Toxicity/Race": DetectionContextToxicityRace, "Detection/Toxicity/Religion": DetectionContextToxicityReligion, "Detection/Toxicity/Toxicity": DetectionContextToxicity, "Detection/Toxicity/Toxic": DetectionContextToxicity, "Detection/Toxicity/Violence": DetectionContextToxicityViolence, "Quality/Context/Document Relevance": QualityContextDocumentRelevance, "Quality/Context/Document Utilization": QualityContextDocumentUtilization, "Quality/Context/Sentence Relevance": QualityContextSentenceRelevance, "Quality/Context/Sentence Utilization": QualityContextSentenceUtilization, "Quality/Response/Adherence": QualityResponseAdherence, "Quality/Response/Attribution": QualityResponseAttribution, "Quality/Response/Coherence": QualityResponseCoherence, "Quality/Response/Complexity": QualityResponseComplexity, "Quality/Response/Correctness": QualityResponseCorrectness, "Quality/Response/Helpfulness": QualityResponseHelpfulness, "Quality/Response/Instruction Following": QualityResponseInstructionFollowing, "Quality/Response/Relevance": QualityResponseRelevance, "Quality/Response/Verbosity": QualityResponseVerbosity, } TASK_THRESHOLDS = { "Detection/Hallucination/Factual Consistency": 0.5, "Detection/Prompt Injection": 0.5001, "Detection/Source Code": 0.5039, "Detection/Toxicity/Disability": 0.5111, "Detection/Toxicity/Gender": 0.5003, "Detection/Toxicity/Identity Hate": 0.5035, "Detection/Toxicity/Insult": 0.5187, "Detection/Toxicity/Obscene": 0.5034, "Detection/Toxicity/Race": 0.5081, "Detection/Toxicity/Religion": 0.5058, "Detection/Toxicity/Toxic": 0.5005, "Detection/Toxicity/Violence": 0.5001, "Quality/Context/Document Relevance": 0.5016, "Quality/Context/Document Utilization": 0.5014, "Quality/Context/Sentence Relevance": 0.5002, "Quality/Context/Sentence Utilization": 0.5039, "Quality/Response/Adherence": 0.5107, "Quality/Response/Attribution": 0.5053, "Quality/Response/Coherence": 0.6103, "Quality/Response/Complexity": 0.5603, "Quality/Response/Correctness": 0.501, "Quality/Response/Helpfulness": 0.5018, "Quality/Response/Instruction Following": 0.5001, "Quality/Response/Relevance": 0.5012, "Quality/Response/Verbosity": 0.5408, } class NLIScorer(Pipeline): def _sanitize_parameters(self, **kwargs): preprocess_kwargs = {} postprocess_kwargs = {} if "task_type" in kwargs: preprocess_kwargs["task_type"] = kwargs.get("task_type") postprocess_kwargs["task_type"] = kwargs.get("task_type") postprocess_kwargs["threshold"] = kwargs.get("threshold") return preprocess_kwargs, {}, postprocess_kwargs def preprocess(self, inputs, task_type=None): if task_type is None: task_type = inputs.get("task_type") TaskClass = TASK_CLASSES[task_type] task_class = TaskClass(tokenizer=self.tokenizer, **inputs) return task_class.as_model_inputs def _forward(self, model_inputs): outputs = self.model(**model_inputs) return outputs def postprocess(self, model_outputs, task_type=None, threshold=None): if threshold is None: threshold = TASK_THRESHOLDS.get(task_type, 0.5) pos_scores = model_outputs["logits"].softmax(-1)[0][1] best_class = int(pos_scores > threshold) if best_class == 1: score = pos_scores else: score = 1 - pos_scores return {"score": score.item(), "label": best_class}