Mental Health Text Classification Model v0.2

Accuracy: 69.87%

This model classifies texts into mental health categories. It was fine-tuned on 2% of the dataset from the following study:

@article{low2020natural,
title={Natural Language Processing Reveals Vulnerable Mental Health Support Groups and Heightened Health Anxiety on Reddit During COVID-19: Observational Study},
author={Low, Daniel M and Rumker, Laurie and Torous, John and Cecchi, Guillermo and Ghosh, Satrajit S and Talkar, Tanya},
journal={Journal of medical Internet research},
volume={22},
number={10},
pages={e22635},
year={2020},
publisher={JMIR Publications Inc., Toronto, Canada}
}

Model Details

The model predicts one of the following 15 categories, which are named after the Reddit support groups (subreddits) in the source dataset:

  • EDAnonymous
  • addiction
  • alcoholism
  • adhd
  • anxiety
  • autism
  • bipolarreddit
  • bpd
  • depression
  • healthanxiety
  • lonely
  • ptsd
  • schizophrenia
  • socialanxiety
  • suicidewatch

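The list above defines a 15-way classification task, and the headline accuracy is the fraction of texts whose predicted category matches the true one. The sketch below assumes the list order above as the class-id order (the model's real order is stored in `model.config.id2label` and may differ). It builds the two label mappings and a plain accuracy function. The uniform-guessing baseline it prints (1/15, about 6.67%) is only a rough reference point: the class balance of the 2% subset is not stated here, so a majority-class baseline could be higher.

```python
# Sketch: label mappings and accuracy for the 15 categories listed above.
# ASSUMPTION: class ids follow the order of the list in this card; the
# model's actual order is defined by model.config.id2label and may differ.
CATEGORIES = [
    "EDAnonymous", "addiction", "alcoholism", "adhd", "anxiety",
    "autism", "bipolarreddit", "bpd", "depression", "healthanxiety",
    "lonely", "ptsd", "schizophrenia", "socialanxiety", "suicidewatch",
]

# Map between integer class ids and label strings in both directions.
id2label = {i: name for i, name in enumerate(CATEGORIES)}
label2id = {name: i for i, name in id2label.items()}


def accuracy(predicted: list[str], gold: list[str]) -> float:
    """Fraction of predictions that exactly match the gold label."""
    if len(predicted) != len(gold) or not gold:
        raise ValueError("predicted and gold must be non-empty and of equal length")
    correct = sum(p == g for p, g in zip(predicted, gold))
    return correct / len(gold)


# Uniform random guessing over 15 classes is right about 1/15 of the time.
chance = 1 / len(CATEGORIES)
print(f"{len(CATEGORIES)} classes, chance accuracy = {chance:.2%}")  # 15 classes, chance accuracy = 6.67%
```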
Example Usage

The following example loads the model from the Hugging Face Hub and predicts the category of a single text:

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("tahaenesaslanturk/mental-health-classification-v0.2")
model = AutoModelForSequenceClassification.from_pretrained("tahaenesaslanturk/mental-health-classification-v0.2")

# Encode the input text
input_text = "I struggle with my relationship with food and my body image, often feeling guilt or shame after eating."
inputs = tokenizer(input_text, return_tensors="pt")

# Perform inference
with torch.no_grad():
    outputs = model(**inputs)

# Get the predicted label
predicted_label = torch.argmax(outputs.logits, dim=1).item()
label = model.config.id2label[predicted_label]

print(f"Predicted label: {label}")
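The example above keeps only the single highest-scoring label. Since the reported accuracy is under 70%, it can be useful to also look at the runner-up categories and their probabilities. A minimal sketch, using a hypothetical helper `top_k_labels` (not part of `transformers`) that applies a softmax to one row of logits in plain Python:

```python
import math


def top_k_labels(logits, id2label, k=3):
    """Return the k most probable (label, probability) pairs for one text.

    Hypothetical helper. `logits` is a sequence of raw scores, one per
    class, e.g. outputs.logits[0].tolist() from the example above.
    """
    # Numerically stable softmax: subtract the max before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Sort class ids by probability, highest first, and keep the top k.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id2label[i], probs[i]) for i in ranked[:k]]


# Made-up logits for a 3-class toy example, only to show the call shape.
toy = top_k_labels([2.0, 0.5, 1.0], {0: "anxiety", 1: "lonely", 2: "depression"}, k=2)
for label, p in toy:
    print(f"{label}: {p:.3f}")
```

With the model loaded as in the example, the call would be `top_k_labels(outputs.logits[0].tolist(), model.config.id2label)`. In PyTorch, `torch.softmax(outputs.logits, dim=-1)` computes the same probabilities; the plain-Python version is used here only so the sketch runs without downloading the model.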
Model size: 335M parameters (Safetensors, F32 tensors)