# Author: param-bharat
# Last commit: b1ae5f6 "Upload NLIScorer" (verified)
# Size: 14.9 kB
from pydantic import BaseModel, ConfigDict
from transformers import (
AutoTokenizer,
PreTrainedTokenizerFast,
PreTrainedTokenizer,
BatchEncoding,
)
from transformers import Pipeline
class NLIInstruction(BaseModel):
    """Base model that turns one example into an NLI-style model input.

    The encoded sequence is
    ``[CLS] instruction [SEP] premise [SEP] hypothesis [SEP]``.
    Subclasses supply task-specific ``instruction`` and ``hypothesis``
    defaults; the premise is assembled from whichever of ``Context``,
    ``ChatHistory``, ``Prompt`` and ``Completion`` are set (empty or ``None``
    sections are skipped).
    """

    # Tokenizer used to encode each segment; it must expose ``cls_token_id``,
    # ``sep_token_id`` and ``model_max_length`` (all read in as_model_inputs).
    tokenizer: AutoTokenizer | PreTrainedTokenizerFast | PreTrainedTokenizer
    instruction: str
    hypothesis: str
    Prompt: str | None = None
    Completion: str | None = None
    Context: str | None = None
    # Messages with "role" and "content" keys (see format_chat_history).
    ChatHistory: list[dict[str, str]] | None = None

    # Tokenizer objects are not pydantic-native types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    def format_chat_history(self, chat_history: list[dict[str, str]]) -> str:
        """Render chat messages as ``### Background`` sections joined by newlines.

        NOTE(review): every message repeats the ``### Background`` header and
        the result has no trailing newline, so a following ``### Prompt``
        header in ``premise`` is glued onto the last message. Left unchanged
        because the model was presumably trained on exactly this layout —
        confirm before altering.
        """
        return "\n".join(
            [
                f"### Background\n{message['role']}: {message['content']}"
                for message in chat_history
            ]
        )

    @property
    def premise(self) -> str:
        """Markdown-style premise built from the optional input fields."""
        base_template = "## Premise\n"
        if self.Context:
            base_template += f"### Context\n{self.Context}\n"
        if self.ChatHistory:
            base_template += self.format_chat_history(self.ChatHistory)
        if self.Prompt:
            base_template += f"### Prompt\n{self.Prompt}\n"
        if self.Completion:
            base_template += f"### Completion\n{self.Completion}\n"
        return base_template

    @property
    def as_str(self) -> str:
        """Instruction, premise and hypothesis joined by newlines.

        ``as_model_inputs`` does not use this; it tokenizes the three
        segments separately.
        """
        return f"{self.instruction}\n{self.premise}\n{self.hypothesis}"

    @property
    def as_model_inputs(self) -> BatchEncoding:
        """Tokenize the example into a batch of one, ready for the model.

        Only the premise is truncated, so that the whole sequence, including
        the four special tokens, fits in ``tokenizer.model_max_length``.

        Returns:
            A ``BatchEncoding`` with ``input_ids`` and ``attention_mask`` as
            PyTorch tensors carrying a leading batch axis of size 1.
        """
        instruction_ids = self.tokenizer(
            self.instruction, add_special_tokens=False
        ).input_ids
        premise_ids = self.tokenizer(self.premise, add_special_tokens=False).input_ids
        hypothesis_ids = self.tokenizer(
            self.hypothesis, add_special_tokens=False
        ).input_ids

        # [CLS] + three [SEP] tokens are added below and must fit the budget
        # too; previously they were not counted and long premises overflowed
        # model_max_length by 4 tokens.
        num_special_tokens = 4
        premise_budget = (
            self.tokenizer.model_max_length
            - len(instruction_ids)
            - len(hypothesis_ids)
            - num_special_tokens
        )
        # Clamp at 0: a negative slice bound would drop tokens from the *end*
        # of the premise instead of removing it entirely. (If instruction and
        # hypothesis alone exceed the limit, no premise truncation can help.)
        premise_ids = premise_ids[: max(0, premise_budget)]

        cls_id = self.tokenizer.cls_token_id
        sep_id = self.tokenizer.sep_token_id
        input_ids = [
            cls_id,
            *instruction_ids,
            sep_id,
            *premise_ids,
            sep_id,
            *hypothesis_ids,
            sep_id,
        ]
        attention_mask = [1] * len(input_ids)
        return BatchEncoding(
            data={"input_ids": input_ids, "attention_mask": attention_mask},
            tensor_type="pt",
            prepend_batch_axis=True,
        )
class DetectionResponseFactualConsistency(NLIInstruction):
    """NLI task: is the completion factually consistent with the context?"""

    instruction: str = "## Task\nDetermine if the completion is factually consistent with the context."
    hypothesis: str = "## Hypothesis\nThe completion is factually consistent with the context."
class DetectionContextPromptInjection(NLIInstruction):
    """NLI task: does the prompt contain a prompt-injection attack?"""

    instruction: str = "## Task\nDetermine if the prompt contains a prompt injection attack."
    hypothesis: str = "## Hypothesis\nThe prompt contains a prompt injection attack."
class DetectionContextSourceCode(NLIInstruction):
    """NLI task: does the context contain source code?"""

    instruction: str = "## Task\nDetermine if the given context contains any source code or code snippets."
    hypothesis: str = "## Hypothesis\nThe given context contains source code or code blocks."
class DetectionContextToxicity(NLIInstruction):
    """NLI task: does the context contain toxic content of any listed category?"""

    instruction: str = "## Task\nDetermine if the context contains any toxic content related to one of the following: disability, gender, identity hate, insult, obscene, race, religion, toxicity, violence."
    hypothesis: str = "## Hypothesis\nThe context contains some toxic content related to one of the following: disability, gender, identity hate, insult, obscene, race, religion, toxicity, violence."
class DetectionContextToxicityDisability(NLIInstruction):
    """NLI task: does the context contain disability-related toxicity?"""

    instruction: str = "## Task\nDetermine if the context contains any toxic content related to disability."
    hypothesis: str = "## Hypothesis\nThe context contains some toxic content related to disability."
class DetectionContextToxicityGender(NLIInstruction):
    """NLI task: does the context contain gender-related toxicity?"""

    instruction: str = "## Task\nDetermine if the context contains any toxic content related to gender."
    hypothesis: str = "## Hypothesis\nThe context contains some toxic content related to gender."
class DetectionContextToxicityIdentityHate(NLIInstruction):
    """NLI task: does the context contain identity-hate toxicity?"""

    instruction: str = "## Task\nDetermine if the context contains any toxic content related to identity hate."
    hypothesis: str = "## Hypothesis\nThe context contains some toxic content related to identity hate."
class DetectionContextToxicityInsult(NLIInstruction):
    """NLI task: does the context contain insulting content?"""

    instruction: str = "## Task\nDetermine if the context contains any insulting content."
    hypothesis: str = "## Hypothesis\nThe context contains some insulting content."
class DetectionContextToxicityObscene(NLIInstruction):
    """NLI task: does the context contain obscene content?"""

    instruction: str = "## Task\nDetermine if the context contains any obscene content."
    hypothesis: str = "## Hypothesis\nThe context contains some obscene content."
class DetectionContextToxicityRace(NLIInstruction):
    """NLI task: does the context contain racist content?"""

    instruction: str = "## Task\nDetermine if the context contains any racist content."
    hypothesis: str = "## Hypothesis\nThe context contains some racist content."
class DetectionContextToxicityReligion(NLIInstruction):
    """NLI task: does the context contain religion-related toxicity?"""

    instruction: str = "## Task\nDetermine if the context contains any toxic content related to religion."
    hypothesis: str = "## Hypothesis\nThe context contains some toxic content related to religion."
class DetectionContextToxicityViolence(NLIInstruction):
    """NLI task: does the context contain violent content?"""

    instruction: str = "## Task\nDetermine if the context contains any violent content."
    hypothesis: str = "## Hypothesis\nThe context contains some violent content."
class QualityContextDocumentRelevance(NLIInstruction):
    """NLI task: is the context relevant to answering the prompt's question?"""

    instruction: str = "## Task\nDetermine if the context contains relevant information used by the completion to answer the question in the given prompt correctly."
    hypothesis: str = "## Hypothesis\nThe context contains relevant information used by the completion to answer the question in the given prompt correctly."
class QualityContextDocumentUtilization(NLIInstruction):
    """NLI task: did the completion make use of the context?"""

    instruction: str = "## Task\nDetermine if the context was utilized in the completion to answer the question in the given prompt correctly."
    hypothesis: str = "## Hypothesis\nThe context was utilized in the completion to answer the question in the given prompt correctly."
class QualityContextSentenceRelevance(NLIInstruction):
    """NLI task: relevance judged for one selected context sentence.

    Uses the same instruction/hypothesis wording as the document-level
    relevance task; the sentence is appended to the premise.
    """

    instruction: str = "## Task\nDetermine if the context contains relevant information used by the completion to answer the question in the given prompt correctly."
    hypothesis: str = "## Hypothesis\nThe context contains relevant information used by the completion to answer the question in the given prompt correctly."
    # The selected sentence under evaluation (required).
    Sentence: str

    @property
    def premise(self) -> str:
        """Base premise followed by a ``### Sentence`` section."""
        base = super().premise
        return f"{base}\n### Sentence\n{self.Sentence}\n"
class QualityContextSentenceUtilization(NLIInstruction):
    """NLI task: was one selected context sentence used by the completion?"""

    instruction: str = "## Task\nDetermine if the selected sentence was utilized in the completion to answer the question in the given prompt correctly."
    hypothesis: str = "## Hypothesis\nThe selected sentence was utilized in the completion to answer the question in the given prompt correctly."
    # The selected sentence under evaluation (required).
    Sentence: str

    @property
    def premise(self) -> str:
        """Base premise followed by a ``### Sentence`` section."""
        base = super().premise
        return f"{base}\n### Sentence\n{self.Sentence}\n"
class QualityResponseAdherence(NLIInstruction):
    """NLI task: does the completion adhere to the context?"""

    instruction: str = "## Task\nDetermine if the completion adheres to the context when answering the question in the given prompt."
    hypothesis: str = "## Hypothesis\nThe completion adheres to the context when answering the question in the given prompt."
class QualityResponseAttribution(NLIInstruction):
    """NLI task: does the completion attribute the context?"""

    instruction: str = "## Task\nDetermine if the completion attributes the context when answering the question in the given prompt."
    hypothesis: str = "## Hypothesis\nThe completion attributes the context when answering the question in the given prompt."
class QualityResponseCoherence(NLIInstruction):
    """NLI task: is the completion coherent for the given context?"""

    # NOTE(review): "coherent and for the given context" reads as if a word
    # is missing; kept verbatim because these strings are model prompts.
    instruction: str = "## Task\nDetermine if the completion is coherent and for the given context."
    hypothesis: str = "## Hypothesis\nThe completion is coherent and for the given context."
class QualityResponseComplexity(NLIInstruction):
    """NLI task: is the completion complex / multi-step?"""

    instruction: str = "## Task\nDetermine if the completion is complex and contains multiple steps to answer the question."
    hypothesis: str = "## Hypothesis\nThe completion is complex and contains multiple steps to answer the question."
class QualityResponseCorrectness(NLIInstruction):
    """NLI task: is the completion correct given the prompt and context?"""

    instruction: str = "## Task\nDetermine if the completion is correct with respect to the given prompt and context."
    hypothesis: str = "## Hypothesis\nThe completion is correct with respect to the given prompt and context."
class QualityResponseHelpfulness(NLIInstruction):
    """NLI task: is the completion helpful given the prompt and context?"""

    instruction: str = "## Task\nDetermine if the completion is helpful with respect to the given prompt and context."
    hypothesis: str = "## Hypothesis\nThe completion is helpful with respect to the given prompt and context."
class QualityResponseInstructionFollowing(NLIInstruction):
    """NLI task: does the completion follow the prompt's instructions?"""

    instruction: str = "## Task\nDetermine if the completion follows the instructions provided in the given prompt."
    hypothesis: str = "## Hypothesis\nThe completion follows the instructions provided in the given prompt."
class QualityResponseRelevance(NLIInstruction):
    """NLI task: is the completion relevant to the prompt and context?"""

    instruction: str = "## Task\nDetermine if the completion is relevant to the given prompt and context."
    hypothesis: str = "## Hypothesis\nThe completion is relevant to the given prompt and context."
class QualityResponseVerbosity(NLIInstruction):
    """NLI task: is the completion too verbose for the prompt and context?"""

    instruction: str = "## Task\nDetermine if the completion is too verbose with respect to the given prompt and context."
    hypothesis: str = "## Hypothesis\nThe completion is too verbose with respect to the given prompt and context."
# Task name -> NLIInstruction subclass that builds that task's model input.
# Entries are grouped by their "Detection/" and "Quality/" prefixes and then
# merged in that order.
_DETECTION_TASK_CLASSES = {
    "Detection/Hallucination/Factual Consistency": DetectionResponseFactualConsistency,
    "Detection/Prompt Injection": DetectionContextPromptInjection,
    "Detection/Source Code": DetectionContextSourceCode,
    "Detection/Toxicity/Disability": DetectionContextToxicityDisability,
    "Detection/Toxicity/Gender": DetectionContextToxicityGender,
    "Detection/Toxicity/Identity Hate": DetectionContextToxicityIdentityHate,
    "Detection/Toxicity/Insult": DetectionContextToxicityInsult,
    "Detection/Toxicity/Obscene": DetectionContextToxicityObscene,
    "Detection/Toxicity/Race": DetectionContextToxicityRace,
    "Detection/Toxicity/Religion": DetectionContextToxicityReligion,
    # Two names for the same aggregate toxicity task.
    "Detection/Toxicity/Toxicity": DetectionContextToxicity,
    "Detection/Toxicity/Toxic": DetectionContextToxicity,
    "Detection/Toxicity/Violence": DetectionContextToxicityViolence,
}
_QUALITY_TASK_CLASSES = {
    "Quality/Context/Document Relevance": QualityContextDocumentRelevance,
    "Quality/Context/Document Utilization": QualityContextDocumentUtilization,
    "Quality/Context/Sentence Relevance": QualityContextSentenceRelevance,
    "Quality/Context/Sentence Utilization": QualityContextSentenceUtilization,
    "Quality/Response/Adherence": QualityResponseAdherence,
    "Quality/Response/Attribution": QualityResponseAttribution,
    "Quality/Response/Coherence": QualityResponseCoherence,
    "Quality/Response/Complexity": QualityResponseComplexity,
    "Quality/Response/Correctness": QualityResponseCorrectness,
    "Quality/Response/Helpfulness": QualityResponseHelpfulness,
    "Quality/Response/Instruction Following": QualityResponseInstructionFollowing,
    "Quality/Response/Relevance": QualityResponseRelevance,
    "Quality/Response/Verbosity": QualityResponseVerbosity,
}
TASK_CLASSES = _DETECTION_TASK_CLASSES | _QUALITY_TASK_CLASSES
# Per-task decision threshold on the class-1 probability: NLIScorer.postprocess
# assigns label 1 when the probability is strictly greater than this value.
# Every task name accepted by TASK_CLASSES needs an entry here, otherwise
# postprocess fails with a KeyError.
TASK_THRESHOLDS = {
    "Detection/Hallucination/Factual Consistency": 0.5895,
    "Detection/Prompt Injection": 0.4147,
    "Detection/Source Code": 0.4001,
    "Detection/Toxicity/Disability": 0.5547,
    "Detection/Toxicity/Gender": 0.4007,
    "Detection/Toxicity/Identity Hate": 0.5502,
    "Detection/Toxicity/Insult": 0.4913,
    "Detection/Toxicity/Obscene": 0.448,
    "Detection/Toxicity/Race": 0.5983,
    "Detection/Toxicity/Religion": 0.4594,
    # "Toxicity" is an alias of "Toxic" in TASK_CLASSES (same task class), so
    # it shares the same threshold; it was previously missing here.
    "Detection/Toxicity/Toxicity": 0.5034,
    "Detection/Toxicity/Toxic": 0.5034,
    "Detection/Toxicity/Violence": 0.4031,
    "Quality/Context/Document Relevance": 0.5809,
    "Quality/Context/Document Utilization": 0.4005,
    "Quality/Context/Sentence Relevance": 0.6003,
    "Quality/Context/Sentence Utilization": 0.5417,
    "Quality/Response/Adherence": 0.59,
    "Quality/Response/Attribution": 0.5304,
    "Quality/Response/Coherence": 0.6891,
    "Quality/Response/Complexity": 0.7235,
    "Quality/Response/Correctness": 0.6535,
    "Quality/Response/Helpfulness": 0.4445,
    "Quality/Response/Instruction Following": 0.5323,
    "Quality/Response/Relevance": 0.4011,
    "Quality/Response/Verbosity": 0.4243,
}
class NLIScorer(Pipeline):
    """Transformers ``Pipeline`` that scores one example for a chosen NLI task.

    Call it with a mapping of extra ``NLIInstruction`` fields (such as
    ``Context``, ``Prompt``, ``Completion``, ``ChatHistory`` or ``Sentence``)
    and a ``task_type`` key from ``TASK_CLASSES``. The result is
    ``{"score": float, "label": int}``: ``label`` is 1 when the class-1
    probability exceeds the task's threshold, and ``score`` is the
    probability of the predicted label.
    """

    def _sanitize_parameters(self, **kwargs):
        """Route ``task_type`` to both ``preprocess`` and ``postprocess``."""
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        if "task_type" in kwargs:
            task_type = kwargs["task_type"]
            preprocess_kwargs["task_type"] = task_type
            postprocess_kwargs["task_type"] = task_type
        # No forward-time parameters.
        return preprocess_kwargs, {}, postprocess_kwargs

    @staticmethod
    def _lookup_task(table, task_type):
        """Return ``table[task_type]`` with a descriptive ``KeyError`` if absent."""
        try:
            return table[task_type]
        except KeyError:
            raise KeyError(
                f"Unknown task_type {task_type!r}; expected one of {sorted(table)}"
            ) from None

    def preprocess(self, inputs, task_type):
        """Build tokenized model inputs for ``inputs`` under ``task_type``.

        Args:
            inputs: mapping of keyword arguments for the task's
                ``NLIInstruction`` subclass (the tokenizer is supplied here).
            task_type: key into ``TASK_CLASSES``.

        Raises:
            KeyError: if ``task_type`` is not a known task.
        """
        task_cls = self._lookup_task(TASK_CLASSES, task_type)
        return task_cls(tokenizer=self.tokenizer, **inputs).as_model_inputs

    def _forward(self, model_inputs):
        """Run the underlying model on the prepared batch."""
        return self.model(**model_inputs)

    def postprocess(self, model_outputs, task_type):
        """Threshold the class-1 probability into a label and confidence score.

        Raises:
            KeyError: if ``task_type`` has no entry in ``TASK_THRESHOLDS``.
        """
        threshold = self._lookup_task(TASK_THRESHOLDS, task_type)
        # Softmax over classes for the single example in the batch; class
        # index 1 is the one compared against the threshold.
        positive_prob = model_outputs["logits"].softmax(-1)[0][1]
        label = int(positive_prob > threshold)
        # Report confidence in the predicted label rather than always P(class 1).
        score = positive_prob if label == 1 else 1 - positive_prob
        return {"score": score.item(), "label": label}