mclear1's picture
Update README.md
d25492d verified
metadata
license: mit

import json

import numpy as np
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

# Define the model and tokenizer

# Pretrained DistilBERT tokenizer and a sequence classifier with two output
# classes: 0 -> 'Kidney Clinic', 1 -> 'Liver Clinic' (see get_clinic).
# NOTE(review): 'distilbert-base-uncased' is a base checkpoint, so the
# classification head is presumably newly initialized and needs fine-tuning.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-uncased',
    num_labels=2,
)

# Define the key words and the class labels

# Liver-related key words. NOTE(review): in the code shown here they are only
# recorded in model_config; the classifier itself never reads them.
key_words = ['ascites', 'cirrhosis', 'liver disease']

# Class indices produced by the classifier: 0 -> Kidney Clinic, 1 -> Liver Clinic.
labels = [0, 1]

# Define a function to tokenize the input text

def preprocess_text(text, max_length=512):
    """Tokenize a single referral text into PyTorch tensors for the model.

    Args:
        text: The raw referral text.
        max_length: Maximum number of tokens, including special tokens.
            Longer inputs are truncated. The default of 512 is the value the
            script always used (presumably DistilBERT's input limit).

    Returns:
        The tokenizer output, with 'input_ids' and 'attention_mask' PyTorch
        tensors (presumably shaped (1, seq_len) for a single text).
    """
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=max_length,
        # Make truncation explicit; transformers warns when max_length is
        # given without a truncation setting.
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    return inputs

# Define a function to make predictions

def make_prediction(text):
    """Return the predicted class index for a single referral text.

    With the two-label model defined in this file the result is 0 or 1.
    The forward pass runs in eval mode (dropout disabled) without gradient
    tracking. The model's previous train/eval mode is restored afterwards,
    so calling this mid-training does not change the training setup.
    """
    inputs = preprocess_text(text)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
            )
    finally:
        model.train(was_training)
    # Softmax is monotonic, so argmax over the raw logits picks the same class.
    predicted_class = torch.argmax(outputs.logits, dim=-1)
    return predicted_class.item()

# Define a function that returns the clinic a referral should be directed to

def get_clinic(text):
    """Return the clinic name that a referral text should be directed to.

    Predicted class 1 maps to 'Liver Clinic'; any other class maps to
    'Kidney Clinic' (class 0 in the training data).
    """
    predicted_class = make_prediction(text)
    return 'Liver Clinic' if predicted_class == 1 else 'Kidney Clinic'

# Define the model's configuration

# Configuration written to model_config.json when the model is saved.
model_config = dict(
    model_type='distilbert',
    num_labels=2,
    key_words=key_words,
    labels=labels,
)

# Define the model's metadata

# Human-readable metadata written to model_metadata.json.
# NOTE(review): 'author' is still a placeholder, and the description mentions
# key words although predictions come from the DistilBERT classifier;
# confirm the intended wording before publishing.
model_metadata = dict(
    name='Referral Clinic Classifier',
    description=(
        'A model that classifies referrals to either the Liver Clinic or '
        'Kidney Clinic based on the presence of certain key words.'
    ),
    author='Your Name',
    version='1.0',
)

# Fine-tune the model on a small set of example referrals

# Tiny labelled training set of (referral text, class index) pairs,
# where 1 = Liver Clinic and 0 = Kidney Clinic (matching get_clinic).
train_data = [
    ('Patient has ascites and cirrhosis.', 1),
    ('Patient has liver disease.', 1),
    ('Patient has kidney disease.', 0),
    ('Patient has liver failure.', 1),
    ('Patient has kidney failure.', 0),
]

# Fine-tune on train_data. The optimizer is created once, before the loop,
# so Adam's running moment estimates persist across steps instead of being
# reset on every example.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
num_epochs = 1  # one pass over train_data, as in the original script
# Enable dropout for training (transformers' from_pretrained leaves the
# model in eval mode).
model.train()
for _epoch in range(num_epochs):
    for text, label in train_data:
        inputs = preprocess_text(text)
        # Shape (1,) matches the batch of one example. A distinct name avoids
        # overwriting the module-level `labels` list.
        label_tensor = torch.tensor([label])
        outputs = model(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            labels=label_tensor,
        )
        loss = outputs.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
# Return to inference mode for any subsequent predictions.
model.eval()

# Save the model weights, configuration and metadata to files

# Persist the fine-tuned weights, then the JSON configuration and metadata.
torch.save(model.state_dict(), 'referral_clinic_classifier.pth')
with open('model_config.json', 'w', encoding='utf-8') as f:
    json.dump(model_config, f)
with open('model_metadata.json', 'w', encoding='utf-8') as f:
    json.dump(model_metadata, f)