from transformers import TextClassificationPipeline, AutoTokenizer


class CustomTextClassificationPipeline(TextClassificationPipeline):
    """Text-classification pipeline that maps model logits onto six fixed
    harm categories.

    The head is treated as *multi-label*: each of the six logits is passed
    through a sigmoid and thresholded at 0.5 independently, and the result
    is returned as a ``{category_name: 0/1}`` dict.
    """

    # Label order must match the model head's output order — TODO confirm
    # against the model's config.id2label.
    CATEGORIES = ["Race/Origin", "Gender/Sex", "Religion", "Ability", "Violence", "Other"]

    def __init__(self, model, tokenizer=None, **kwargs):
        """Build the pipeline, auto-loading a tokenizer when none is given.

        Args:
            model: A loaded sequence-classification model.
            tokenizer: Optional tokenizer; defaults to the one matching
                ``model.config._name_or_path``.
            **kwargs: Forwarded to ``TextClassificationPipeline.__init__``.
        """
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
        # The base Pipeline.__init__ stores the tokenizer on self; no need
        # to assign it manually beforehand.
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        # No runtime-configurable parameters for any stage.
        return {}, {}, {}

    def preprocess(self, inputs):
        """Tokenize raw text into model-ready tensors.

        NOTE(review): truncation is deliberately disabled here; inputs longer
        than the model's max position embeddings will raise at forward time —
        confirm this is intended.
        """
        return self.tokenizer(inputs, return_tensors='pt', truncation=False)

    def _forward(self, model_inputs):
        # Forward the tokenizer's own outputs (including its attention_mask)
        # instead of re-deriving the mask as `input_ids != 0`: pad_token_id
        # is not 0 for every model (e.g. RoBERTa pads with 1), and id 0 can
        # be a legitimate token.
        return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        """Convert logits to a ``{category: 0/1}`` dict.

        Returns:
            dict mapping each category name to 1 if its sigmoid(logit)
            exceeds 0.5, else 0.
        """
        # argmax(dim=-1) would yield a single class index that cannot be
        # zipped against six category names; per-logit sigmoid + threshold
        # gives one binary prediction per category.
        probs = model_outputs.logits.squeeze(0).sigmoid()
        predictions = (probs > 0.5).long().tolist()
        return dict(zip(self.CATEGORIES, predictions))