Spaces:
Sleeping
Sleeping
from setfit import SetFitModel, SetFitTrainer | |
from sentence_transformers.losses import CosineSimilarityLoss | |
# Function to create a pipeline for text classification using the trained model | |
def create_classifier(model_path):
    """Load a saved SetFit model to be used as a text classifier.

    Args:
        model_path: Location of the saved model; handed unchanged to
            ``SetFitModel.from_pretrained``.

    Returns:
        Whatever ``SetFitModel.from_pretrained`` returns for ``model_path``.
    """
    # NOTE(review): local_files_only=True is forwarded to the loader;
    # presumably this prevents any download attempt -- confirm against the
    # installed setfit / huggingface_hub version.
    load_options = {"local_files_only": True}
    return SetFitModel.from_pretrained(model_path, **load_options)
def run_setfit_training(
    session_id, model_id, model_name, train_dataset, batch_size, num_iterations
):
    """Fine-tune a SetFit model on ``train_dataset`` and save it to disk.

    Args:
        session_id: Identifier used as the sub-directory under ``./models``.
        model_id: Model identifier passed to ``SetFitModel.from_pretrained``;
            also embedded in the save path.
        model_name: Suffix appended to ``model_id`` in the save path.
        train_dataset: Dataset used for training. It is also passed as the
            evaluation dataset, and must support ``len()``.
        batch_size: Batch size forwarded to ``SetFitTrainer``.
        num_iterations: Forwarded to ``SetFitTrainer``; the number of text
            pairs to generate for contrastive learning.

    Returns:
        The path string ``./models/{session_id}/{model_id}_{model_name}``
        where the trained model was saved.
    """
    model = SetFitModel.from_pretrained(model_id)

    trainer = SetFitTrainer(
        model=model,
        train_dataset=train_dataset,
        # No separate hold-out set is available here, so the training data
        # doubles as the evaluation set. Evaluation is not run by this function.
        eval_dataset=train_dataset,
        loss_class=CosineSimilarityLoss,
        metric="accuracy",
        batch_size=batch_size,
        num_iterations=num_iterations,  # number of text pairs to generate for contrastive learning
        num_epochs=1,  # number of epochs to use for contrastive learning
        column_mapping={"text": "text", "label": "label"},
    )
    trainer.train()

    print(f"model used: {model_id}")
    print(f"train dataset: {len(train_dataset)} samples")

    # Build the destination once so the saved location and the returned
    # value can never drift apart. If model_id contains "/", the path gains
    # an extra directory level.
    save_model_path = f"./models/{session_id}/{model_id}_{model_name}"

    # Use the public save API rather than the private _save_pretrained hook.
    trainer.model.save_pretrained(save_model_path)
    return save_model_path