from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, TrainingArguments
import torch

# Load the model and tokenizer
model_name = 'Siddharth63/medELM-diseases'
# NOTE(review): trust_remote_code=True executes Python code shipped inside the
# model repository at load time. Keep it only if this repo genuinely needs
# custom modelling code, and consider pinning `revision=` to a reviewed commit.
original_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    device_map="auto",  # automatic device placement (needs `accelerate`)
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    model_max_length=3072,
    padding_side="right",
    use_fast=True,
)

# Inputs must live on the same device as the model. Because the model was
# loaded with device_map="auto", ask the model where its weights are instead
# of guessing; `device` is kept in case other code relies on it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

prompt = "Given a text, extract disease entities and their disease types from the text.\n\n### Text:Patient has been vomiting over the past 6 - 8 weeks, since before she was diagnosed with metastatic cancer. She also reports pain over her upper abdomen and has very poor PO intake. She has been feeling progressively weak over this time period. Her vomiting and abdominal pain has not increased from the past weeks, but she just feels more fatigued. She has a chronic non-productive cough as well. No URI symptoms, no urinary complaints. She has been constipated, which improves when she stops her anti-emetics. Last bowel movement was yesterday. She is passing gas. She has lower extremity edema, which has been present for the past several weeks. Of note, she was supposed to have one of her liver mets biopsied in the past several weeks, but she was taking ibuprofen so the biopsy had to be postponed. In the ED, initial VS were: 97.6 117 128/74 18 95% RA. Labs were significant for WBC of 18.7, with 77% polys. UA was significant for ketones. Patient received zofran, NS. She had a CXR that showed new left sided opacity that may reflect PNA superimposed on metastatic diseae vs. lymphangiitic spread of cancer. She received vanc and cefepime for pneumonia.\n\n### Entities:"
# Tokenize the prompt and move the tensors onto the model's device.
inputs = tokenizer(prompt, return_tensors="pt").to(original_model.device)

# Generate a completion. For a decoder-only (causal LM) model, `generate`
# returns the prompt tokens followed by the newly generated tokens.
# Greedy decoding (do_sample=False) is the deterministic form of the
# near-zero-temperature sampling (temperature=0.01) used previously.
# NOTE: max_length counts the prompt tokens as well as the generated ones;
# use max_new_tokens instead if the prompt grows.
outputs = original_model.generate(
    **inputs,
    max_length=1024,
    do_sample=False,
)

# Decode only the newly generated tokens. Slicing the decoded string by
# len(prompt) is fragile: detokenization does not always reproduce the
# prompt text character-for-character.
prompt_length = inputs["input_ids"].shape[1]
completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

# Treat the first "}]" as the end of the entity list. Keep everything up to
# and including it, but do not fabricate the terminator if the model never
# emitted one.
head, terminator, _ = completion.partition("}]")
print(head + terminator)
# Model card info (copied from the Hugging Face model page):
#   Downloads last month: 2
#   Weights format: Safetensors
#   Model size: 272M params
#   Tensor type: F32
#   Inference API: Unable to determine this model's library. Check the docs.