# few-shot-demo / model.py
# spdin
# initial commit
# 333cd19
from setfit import SetFitModel, SetFitTrainer
from sentence_transformers.losses import CosineSimilarityLoss
# Load a previously trained SetFit model from a local directory for text classification
def create_classifier(model_path):
    """Load a saved SetFit model and return it.

    Args:
        model_path: Location of the saved model, handed straight to
            ``SetFitModel.from_pretrained``.

    Returns:
        The object returned by ``SetFitModel.from_pretrained``.
    """
    # ``local_files_only=True`` asks the loader to use only files already on
    # disk, so ``model_path`` is expected to be a local directory.
    return SetFitModel.from_pretrained(model_path, local_files_only=True)
def run_setfit_training(
    session_id, model_id, model_name, train_dataset, batch_size, num_iterations
):
    """Fine-tune a SetFit model on ``train_dataset`` and save it to disk.

    Args:
        session_id: Used as the first directory level under ``./models``.
        model_id: Passed to ``SetFitModel.from_pretrained`` to load the base
            model; also embedded in the save path.
        model_name: Suffix appended to ``model_id`` in the save path.
        train_dataset: Training data with ``text`` and ``label`` columns
            (see ``column_mapping`` below). Must support ``len()``.
        batch_size: Forwarded to ``SetFitTrainer``.
        num_iterations: Forwarded to ``SetFitTrainer``; the number of text
            pairs to generate for contrastive learning.

    Returns:
        The directory path (``./models/<session_id>/<model_id>_<model_name>``)
        the trained model was saved to.
    """
    model = SetFitModel.from_pretrained(model_id)

    trainer = SetFitTrainer(
        model=model,
        train_dataset=train_dataset,
        # NOTE(review): the training set doubles as the eval set, so any
        # metric from trainer.evaluate() would be measured on seen data.
        eval_dataset=train_dataset,
        loss_class=CosineSimilarityLoss,
        metric="accuracy",
        batch_size=batch_size,
        num_iterations=num_iterations,  # text pairs generated for contrastive learning
        num_epochs=1,  # epochs used for contrastive learning
        column_mapping={"text": "text", "label": "label"},
    )
    trainer.train()

    print(f"model used: {model_id}")
    print(f"train dataset: {len(train_dataset)} samples")

    # Build the path once so the saved location and the returned value cannot
    # drift apart. A model_id containing "/" produces nested directories.
    save_model_path = f"./models/{session_id}/{model_id}_{model_name}"
    # Use the public save API rather than the private ``_save_pretrained`` hook.
    trainer.model.save_pretrained(save_model_path)
    return save_model_path