|
from pydantic import BaseModel, ConfigDict |
|
from transformers import ( |
|
AutoTokenizer, |
|
PreTrainedTokenizerFast, |
|
PreTrainedTokenizer, |
|
BatchEncoding, |
|
) |
|
from transformers import Pipeline |
|
|
|
|
|
class NLIInstruction(BaseModel): |
|
tokenizer: AutoTokenizer | PreTrainedTokenizerFast | PreTrainedTokenizer |
|
instruction: str |
|
hypothesis: str |
|
Prompt: str | None = None |
|
Completion: str | None = None |
|
Context: str | None = None |
|
ChatHistory: list[dict[str, str]] | None = None |
|
model_config = ConfigDict(arbitrary_types_allowed=True) |
|
|
|
def format_chat_history(self, chat_history: list[dict[str, str]]) -> str: |
|
return "\n".join( |
|
[ |
|
f"### Background\n{message['role']}: {message['content']}" |
|
for message in chat_history |
|
] |
|
) |
|
|
|
@property |
|
def premise(self) -> str: |
|
base_template = "## Premise\n" |
|
if self.Context: |
|
base_template += f"### Context\n{self.Context}\n" |
|
if self.ChatHistory: |
|
base_template += self.format_chat_history(self.ChatHistory) |
|
if self.Prompt: |
|
base_template += f"### Prompt\n{self.Prompt}\n" |
|
if self.Completion: |
|
base_template += f"### Completion\n{self.Completion}\n" |
|
return base_template |
|
|
|
@property |
|
def as_str(self): |
|
return f"{self.instruction}\n{self.premise}\n{self.hypothesis}" |
|
|
|
@property |
|
def as_model_inputs(self) -> dict[str, list[int]]: |
|
instruction_ids = self.tokenizer( |
|
self.instruction, add_special_tokens=False |
|
).input_ids |
|
premise_ids = self.tokenizer(self.premise, add_special_tokens=False).input_ids |
|
hypothesis_ids = self.tokenizer( |
|
self.hypothesis, add_special_tokens=False |
|
).input_ids |
|
|
|
premise_length = self.tokenizer.model_max_length - len( |
|
instruction_ids + hypothesis_ids |
|
) |
|
premise_ids = premise_ids[:premise_length] |
|
input_ids = ( |
|
[self.tokenizer.cls_token_id] |
|
+ instruction_ids |
|
+ [self.tokenizer.sep_token_id] |
|
+ premise_ids |
|
+ [self.tokenizer.sep_token_id] |
|
+ hypothesis_ids |
|
+ [self.tokenizer.sep_token_id] |
|
) |
|
attention_mask = [1] * len(input_ids) |
|
return BatchEncoding( |
|
data={"input_ids": input_ids, "attention_mask": attention_mask}, |
|
tensor_type="pt", |
|
prepend_batch_axis=True, |
|
) |
|
|
|
|
|
class DetectionResponseFactualConsistency(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the completion is factually consistent with the context.""" |
|
) |
|
hypothesis: str = ( |
|
"""## Hypothesis\nThe completion is factually consistent with the context.""" |
|
) |
|
|
|
|
|
class DetectionContextPromptInjection(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the prompt contains a prompt injection attack.""" |
|
) |
|
hypothesis: str = ( |
|
"""## Hypothesis\nThe prompt contains a prompt injection attack.""" |
|
) |
|
|
|
|
|
class DetectionContextSourceCode(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the given context contains any source code or code snippets.""" |
|
) |
|
hypothesis: str = ( |
|
"""## Hypothesis\nThe given context contains source code or code blocks.""" |
|
) |
|
|
|
|
|
class DetectionContextToxicity(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the context contains any toxic content related to one of the following: disability, gender, identity hate, insult, obscene, race, religion, toxicity, violence.""" |
|
) |
|
hypothesis: str = ( |
|
"""## Hypothesis\nThe context contains some toxic content related to one of the following: disability, gender, identity hate, insult, obscene, race, religion, toxicity, violence.""" |
|
) |
|
|
|
|
|
class DetectionContextToxicityDisability(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the context contains any toxic content related to disability.""" |
|
) |
|
hypothesis: str = ( |
|
"""## Hypothesis\nThe context contains some toxic content related to disability.""" |
|
) |
|
|
|
|
|
class DetectionContextToxicityGender(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the context contains any toxic content related to gender.""" |
|
) |
|
hypothesis: str = ( |
|
"""## Hypothesis\nThe context contains some toxic content related to gender.""" |
|
) |
|
|
|
|
|
class DetectionContextToxicityIdentityHate(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the context contains any toxic content related to identity hate.""" |
|
) |
|
hypothesis: str = ( |
|
"""## Hypothesis\nThe context contains some toxic content related to identity hate.""" |
|
) |
|
|
|
|
|
class DetectionContextToxicityInsult(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the context contains any insulting content.""" |
|
) |
|
hypothesis: str = """## Hypothesis\nThe context contains some insulting content.""" |
|
|
|
|
|
class DetectionContextToxicityObscene(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the context contains any obscene content.""" |
|
) |
|
hypothesis: str = """## Hypothesis\nThe context contains some obscene content.""" |
|
|
|
|
|
class DetectionContextToxicityRace(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the context contains any racist content.""" |
|
) |
|
hypothesis: str = """## Hypothesis\nThe context contains some racist content.""" |
|
|
|
|
|
class DetectionContextToxicityReligion(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the context contains any toxic content related to religion.""" |
|
) |
|
hypothesis: str = ( |
|
"""## Hypothesis\nThe context contains some toxic content related to religion.""" |
|
) |
|
|
|
|
|
class DetectionContextToxicityViolence(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the context contains any violent content.""" |
|
) |
|
hypothesis: str = """## Hypothesis\nThe context contains some violent content.""" |
|
|
|
|
|
class QualityContextDocumentRelevance(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the context contains relevant information used by the completion to answer the question in the given prompt correctly.""" |
|
) |
|
hypothesis: str = ( |
|
"""## Hypothesis\nThe context contains relevant information used by the completion to answer the question in the given prompt correctly.""" |
|
) |
|
|
|
|
|
class QualityContextDocumentUtilization(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the context was utilized in the completion to answer the question in the given prompt correctly.""" |
|
) |
|
hypothesis: str = ( |
|
"""## Hypothesis\nThe context was utilized in the completion to answer the question in the given prompt correctly.""" |
|
) |
|
|
|
|
|
class QualityContextSentenceRelevance(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the context contains relevant information used by the completion to answer the question in the given prompt correctly.""" |
|
) |
|
hypothesis: str = ( |
|
"""## Hypothesis\nThe context contains relevant information used by the completion to answer the question in the given prompt correctly.""" |
|
) |
|
Sentence: str |
|
|
|
@property |
|
def premise(self) -> str: |
|
return super().premise + f"\n### Sentence\n{self.Sentence}\n" |
|
|
|
|
|
class QualityContextSentenceUtilization(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the selected sentence was utilized in the completion to answer the question in the given prompt correctly.""" |
|
) |
|
hypothesis: str = ( |
|
"""## Hypothesis\nThe selected sentence was utilized in the completion to answer the question in the given prompt correctly.""" |
|
) |
|
Sentence: str |
|
|
|
@property |
|
def premise(self) -> str: |
|
return super().premise + f"\n### Sentence\n{self.Sentence}\n" |
|
|
|
|
|
class QualityResponseAdherence(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the completion adheres to the context when answering the question in the given prompt.""" |
|
) |
|
hypothesis: str = ( |
|
"""## Hypothesis\nThe completion adheres to the context when answering the question in the given prompt.""" |
|
) |
|
|
|
|
|
class QualityResponseAttribution(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the completion attributes the context when answering the question in the given prompt.""" |
|
) |
|
hypothesis: str = ( |
|
"""## Hypothesis\nThe completion attributes the context when answering the question in the given prompt.""" |
|
) |
|
|
|
|
|
class QualityResponseCoherence(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the completion is coherent and for the given context.""" |
|
) |
|
hypothesis: str = ( |
|
"""## Hypothesis\nThe completion is coherent and for the given context.""" |
|
) |
|
|
|
|
|
class QualityResponseComplexity(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the completion is complex and contains multiple steps to answer the question.""" |
|
) |
|
hypothesis: str = ( |
|
"""## Hypothesis\nThe completion is complex and contains multiple steps to answer the question.""" |
|
) |
|
|
|
|
|
class QualityResponseCorrectness(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the completion is correct with respect to the given prompt and context.""" |
|
) |
|
hypothesis: str = ( |
|
"""## Hypothesis\nThe completion is correct with respect to the given prompt and context.""" |
|
) |
|
|
|
|
|
class QualityResponseHelpfulness(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the completion is helpful with respect to the given prompt and context.""" |
|
) |
|
hypothesis: str = ( |
|
"""## Hypothesis\nThe completion is helpful with respect to the given prompt and context.""" |
|
) |
|
|
|
|
|
class QualityResponseInstructionFollowing(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the completion follows the instructions provided in the given prompt.""" |
|
) |
|
hypothesis: str = ( |
|
"""## Hypothesis\nThe completion follows the instructions provided in the given prompt.""" |
|
) |
|
|
|
|
|
class QualityResponseRelevance(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the completion is relevant to the given prompt and context.""" |
|
) |
|
hypothesis: str = ( |
|
"""## Hypothesis\nThe completion is relevant to the given prompt and context.""" |
|
) |
|
|
|
|
|
class QualityResponseVerbosity(NLIInstruction): |
|
instruction: str = ( |
|
"""## Task\nDetermine if the completion is too verbose with respect to the given prompt and context.""" |
|
) |
|
hypothesis: str = ( |
|
"""## Hypothesis\nThe completion is too verbose with respect to the given prompt and context.""" |
|
) |
|
|
|
|
|
TASK_CLASSES = { |
|
"Detection/Hallucination/Factual Consistency": DetectionResponseFactualConsistency, |
|
"Detection/Prompt Injection": DetectionContextPromptInjection, |
|
"Detection/Source Code": DetectionContextSourceCode, |
|
"Detection/Toxicity/Disability": DetectionContextToxicityDisability, |
|
"Detection/Toxicity/Gender": DetectionContextToxicityGender, |
|
"Detection/Toxicity/Identity Hate": DetectionContextToxicityIdentityHate, |
|
"Detection/Toxicity/Insult": DetectionContextToxicityInsult, |
|
"Detection/Toxicity/Obscene": DetectionContextToxicityObscene, |
|
"Detection/Toxicity/Race": DetectionContextToxicityRace, |
|
"Detection/Toxicity/Religion": DetectionContextToxicityReligion, |
|
"Detection/Toxicity/Toxicity": DetectionContextToxicity, |
|
"Detection/Toxicity/Toxic": DetectionContextToxicity, |
|
"Detection/Toxicity/Violence": DetectionContextToxicityViolence, |
|
"Quality/Context/Document Relevance": QualityContextDocumentRelevance, |
|
"Quality/Context/Document Utilization": QualityContextDocumentUtilization, |
|
"Quality/Context/Sentence Relevance": QualityContextSentenceRelevance, |
|
"Quality/Context/Sentence Utilization": QualityContextSentenceUtilization, |
|
"Quality/Response/Adherence": QualityResponseAdherence, |
|
"Quality/Response/Attribution": QualityResponseAttribution, |
|
"Quality/Response/Coherence": QualityResponseCoherence, |
|
"Quality/Response/Complexity": QualityResponseComplexity, |
|
"Quality/Response/Correctness": QualityResponseCorrectness, |
|
"Quality/Response/Helpfulness": QualityResponseHelpfulness, |
|
"Quality/Response/Instruction Following": QualityResponseInstructionFollowing, |
|
"Quality/Response/Relevance": QualityResponseRelevance, |
|
"Quality/Response/Verbosity": QualityResponseVerbosity, |
|
} |
|
|
|
TASK_THRESHOLDS = { |
|
"Detection/Hallucination/Factual Consistency": 0.5895, |
|
"Detection/Prompt Injection": 0.4147, |
|
"Detection/Source Code": 0.4001, |
|
"Detection/Toxicity/Disability": 0.5547, |
|
"Detection/Toxicity/Gender": 0.4007, |
|
"Detection/Toxicity/Identity Hate": 0.5502, |
|
"Detection/Toxicity/Insult": 0.4913, |
|
"Detection/Toxicity/Obscene": 0.448, |
|
"Detection/Toxicity/Race": 0.5983, |
|
"Detection/Toxicity/Religion": 0.4594, |
|
"Detection/Toxicity/Toxic": 0.5034, |
|
"Detection/Toxicity/Violence": 0.4031, |
|
"Quality/Context/Document Relevance": 0.5809, |
|
"Quality/Context/Document Utilization": 0.4005, |
|
"Quality/Context/Sentence Relevance": 0.6003, |
|
"Quality/Context/Sentence Utilization": 0.5417, |
|
"Quality/Response/Adherence": 0.59, |
|
"Quality/Response/Attribution": 0.5304, |
|
"Quality/Response/Coherence": 0.6891, |
|
"Quality/Response/Complexity": 0.7235, |
|
"Quality/Response/Correctness": 0.6535, |
|
"Quality/Response/Helpfulness": 0.4445, |
|
"Quality/Response/Instruction Following": 0.5323, |
|
"Quality/Response/Relevance": 0.4011, |
|
"Quality/Response/Verbosity": 0.4243, |
|
} |
|
|
|
|
|
class NLIScorer(Pipeline): |
|
def _sanitize_parameters(self, **kwargs): |
|
preprocess_kwargs = {} |
|
postprocess_kwargs = {} |
|
if "task_type" in kwargs: |
|
preprocess_kwargs["task_type"] = kwargs["task_type"] |
|
postprocess_kwargs["task_type"] = kwargs["task_type"] |
|
return preprocess_kwargs, {}, postprocess_kwargs |
|
|
|
def preprocess(self, inputs, task_type): |
|
TaskClass = TASK_CLASSES[task_type] |
|
task_class = TaskClass(tokenizer=self.tokenizer, **inputs) |
|
return task_class.as_model_inputs |
|
|
|
def _forward(self, model_inputs): |
|
outputs = self.model(**model_inputs) |
|
return outputs |
|
|
|
def postprocess(self, model_outputs, task_type): |
|
threshold = TASK_THRESHOLDS[task_type] |
|
pos_scores = model_outputs["logits"].softmax(-1)[0][1] |
|
best_class = int(pos_scores > threshold) |
|
if best_class == 1: |
|
score = pos_scores |
|
else: |
|
score = 1 - pos_scores |
|
return {"score": score.item(), "label": best_class} |
|
|