|
--- |
|
license: mit |
|
--- |
|
import json

import numpy as np
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
|
|
|
# Load the 'distilbert-base-uncased' tokenizer and a sequence-classification
# model with two output classes.
# NOTE: the base checkpoint ships no classification head, so transformers
# initializes that head randomly; predictions are meaningless until the
# model has been fine-tuned.
tokenizer = DistilBertTokenizer.from_pretrained(
    'distilbert-base-uncased',
)
model = DistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-uncased',
    num_labels=2,
)
|
|
|
# Liver-related terms recorded in the saved model config. They are not
# consulted at inference time; classification relies solely on the model.
key_words = [
    'ascites',
    'cirrhosis',
    'liver disease',
]

# Class ids emitted by the classifier: 0 -> Kidney Clinic, 1 -> Liver Clinic.
labels = list(range(2))
|
|
|
def preprocess_text(text, max_length=512):
    """Tokenize *text* into PyTorch tensors ready to feed the model.

    Args:
        text: The raw referral text to encode.
        max_length: Maximum number of tokens kept, including the special
            tokens the tokenizer adds. Defaults to 512.

    Returns:
        The tokenizer's encoding, whose 'input_ids' and 'attention_mask'
        entries are PyTorch tensors (``return_tensors='pt'``).
    """
    return tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=max_length,
        # Passing max_length without truncation relies on a deprecated
        # transformers fallback (it warns, then truncates 'longest_first').
        # Make the truncation explicit so long referrals are cut cleanly.
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
|
|
|
def make_prediction(text):
    """Return the predicted class id for *text*.

    Args:
        text: The raw referral text to classify.

    Returns:
        The index (a Python int) of the highest-probability class.
    """
    inputs = preprocess_text(text)
    # Make sure dropout is off so repeated calls on the same text agree,
    # even if the model was left in training mode.
    model.eval()
    # Inference only: don't build an autograd graph we never backprop through.
    with torch.no_grad():
        outputs = model(inputs['input_ids'], attention_mask=inputs['attention_mask'])
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    # Take the argmax over the class axis explicitly rather than over the
    # flattened tensor.
    predicted_class = torch.argmax(probabilities, dim=-1)
    return predicted_class.item()
|
|
|
def get_clinic(text):
    """Return the name of the clinic a referral should be routed to.

    A predicted class of 1 routes to the Liver Clinic; any other class id
    routes to the Kidney Clinic.
    """
    return 'Liver Clinic' if make_prediction(text) == 1 else 'Kidney Clinic'
|
|
|
# Configuration written out next to the weights. num_labels should agree with
# the value passed to from_pretrained when the model was built.
model_config = dict(
    model_type='distilbert',
    num_labels=2,
    key_words=key_words,
    labels=labels,
)
|
|
|
# Descriptive metadata written out next to the weights.
# NOTE(review): the description claims routing is "based on the presence of
# certain key words", but get_clinic relies only on the DistilBERT prediction
# and key_words is never consulted. 'author' is still a placeholder.
model_metadata = dict(
    name='Referral Clinic Classifier',
    description='A model that classifies referrals to either the Liver Clinic or Kidney Clinic based on the presence of certain key words.',
    author='Your Name',
    version='1.0',
)
|
|
|
# Fine-tune the classifier on a tiny set of labelled referral snippets
# (label 1 = liver, 0 = kidney).
train_data = [
    ('Patient has ascites and cirrhosis.', 1),
    ('Patient has liver disease.', 1),
    ('Patient has kidney disease.', 0),
    ('Patient has liver failure.', 1),
    ('Patient has kidney failure.', 0),
]

# Create the optimizer once: Adam keeps running moment estimates per
# parameter, and recreating it every step would throw that state away.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

# from_pretrained returns the model in eval mode; switch to train mode so
# dropout is active while fine-tuning.
model.train()
for text, label in train_data:
    inputs = preprocess_text(text)
    # Use a dedicated name so the module-level `labels` list is not
    # overwritten, and give the target a batch dimension of 1.
    label_tensor = torch.tensor([label])
    outputs = model(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        labels=label_tensor,
    )
    loss = outputs.loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Back to inference behaviour (dropout disabled) for any later predictions.
model.eval()
|
|
|
# Persist the fine-tuned weights plus the JSON config and metadata files.
torch.save(model.state_dict(), 'referral_clinic_classifier.pth')

# Explicit encoding so the output doesn't depend on the platform's locale.
with open('model_config.json', 'w', encoding='utf-8') as f:
    json.dump(model_config, f)

with open('model_metadata.json', 'w', encoding='utf-8') as f:
    json.dump(model_metadata, f)