fullstuckdev committed on
Commit 8c2f469 · 1 Parent(s): 1317aa0

first init

Files changed (5)
  1. .gitignore +5 -0
  2. README.md +1 -0
  3. app.py +138 -0
  4. generate_dataset.py +72 -0
  5. train.py +53 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+ .env
+ *.ipynb
+ *.pyc
+ __pycache__/
+ *.DS_Store
README.md CHANGED
@@ -6,6 +6,7 @@ colorTo: indigo
  sdk: docker
  pinned: false
  license: apache-2.0
+ app_port: 8000
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,147 @@
+ import os
+ from fastapi import FastAPI, HTTPException, BackgroundTasks
+ from pydantic import BaseModel
+ from typing import List
+ import torch
+ from datasets import load_dataset
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     TrainingArguments,
+     Trainer,
+     DataCollatorForLanguageModeling
+ )
+ import uvicorn
+
+ app = FastAPI(title="Medical LLaMA API")
+
+ model = None
+ tokenizer = None
+ model_output_path = "./model/medical_llama_3b"
+
+ class TrainRequest(BaseModel):
+     dataset_path: str
+     num_epochs: int = 3
+     batch_size: int = 4
+     learning_rate: float = 2e-5
+
+ class Query(BaseModel):
+     text: str
+     max_length: int = 512
+     temperature: float = 0.7
+     num_return_sequences: int = 1
+
+ class Response(BaseModel):
+     generated_text: List[str]
+
+ def train_model(dataset_path: str, num_epochs: int, batch_size: int, learning_rate: float):
+     global model, tokenizer
+
+     os.makedirs(model_output_path, exist_ok=True)
+
+     # Use the standard PyTorch checkpoint; an ONNX INT4 export cannot be
+     # loaded by AutoModelForCausalLM or fine-tuned with Trainer.
+     model_name = "meta-llama/Llama-3.2-3B-Instruct"
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token  # LLaMA tokenizers ship without a pad token
+     model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
+
+     dataset = load_dataset("json", data_files=dataset_path)
+
+     def preprocess_function(examples):
+         return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)
+
+     tokenized_dataset = dataset.map(
+         preprocess_function,
+         batched=True,
+         remove_columns=dataset["train"].column_names
+     )
+
+     training_args = TrainingArguments(
+         output_dir=f"{model_output_path}/checkpoints",
+         per_device_train_batch_size=batch_size,
+         gradient_accumulation_steps=4,
+         num_train_epochs=num_epochs,
+         learning_rate=learning_rate,
+         fp16=True,
+         save_steps=500,
+         logging_steps=100,
+     )
+
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=tokenized_dataset["train"],
+         data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
+     )
+
+     # Start training
+     trainer.train()
+
+     # Save the final model and tokenizer
+     model.save_pretrained(model_output_path)
+     tokenizer.save_pretrained(model_output_path)
+
+     print(f"Model and tokenizer saved to: {model_output_path}")
+
+ @app.post("/train")
+ async def train(request: TrainRequest, background_tasks: BackgroundTasks):
+     background_tasks.add_task(train_model, request.dataset_path, request.num_epochs, request.batch_size, request.learning_rate)
+     return {"message": "Training started in the background"}
+
+ @app.post("/generate", response_model=Response)
+ async def generate_text(query: Query):
+     global model, tokenizer
+
+     if model is None or tokenizer is None:
+         try:
+             tokenizer = AutoTokenizer.from_pretrained(model_output_path)
+             model = AutoModelForCausalLM.from_pretrained(
+                 model_output_path,
+                 torch_dtype=torch.float16,
+                 device_map="auto"
+             )
+         except Exception as e:
+             raise HTTPException(status_code=500, detail=f"Error loading model: {str(e)}")
+
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token  # required before padding below
+
+     try:
+         inputs = tokenizer(
+             query.text,
+             return_tensors="pt",
+             padding=True,
+             truncation=True,
+             max_length=query.max_length
+         ).to(model.device)
+
+         with torch.no_grad():
+             generated_ids = model.generate(
+                 inputs.input_ids,
+                 attention_mask=inputs.attention_mask,
+                 max_length=query.max_length,
+                 do_sample=True,  # temperature is ignored under greedy decoding
+                 num_return_sequences=query.num_return_sequences,
+                 temperature=query.temperature,
+                 pad_token_id=tokenizer.pad_token_id,
+                 eos_token_id=tokenizer.eos_token_id,
+             )
+
+         generated_texts = [
+             tokenizer.decode(g, skip_special_tokens=True)
+             for g in generated_ids
+         ]
+
+         return Response(generated_text=generated_texts)
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.get("/health")
+ async def health_check():
+     return {"status": "healthy"}
+
+ if __name__ == "__main__":
+     uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=False)
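
For reference, a minimal client sketch against the endpoints above, assuming the API is reachable locally on port 8000 (the app_port set in README.md) and that medical_dataset.json from generate_dataset.py sits next to the server; the prompt text is illustrative only:

import requests

BASE_URL = "http://localhost:8000"  # assumed local run; substitute the deployed Space URL

# Liveness check
print(requests.get(f"{BASE_URL}/health").json())

# Kick off fine-tuning in the background
print(requests.post(f"{BASE_URL}/train", json={"dataset_path": "medical_dataset.json"}).json())

# Query the model once training has saved a checkpoint
resp = requests.post(f"{BASE_URL}/generate", json={
    "text": "Question: What are the symptoms of Hypertension?\nAnswer:",
    "max_length": 256,
    "temperature": 0.7,
})
print(resp.json()["generated_text"][0])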
generate_dataset.py ADDED
@@ -0,0 +1,72 @@
+ import json
+ import random
+
+ conditions = [
+     "Hypertension", "Diabetes", "Asthma", "Arthritis", "Depression",
+     "Anxiety", "Obesity", "Migraine", "Allergies", "Influenza"
+ ]
+
+ medications = [
+     "Lisinopril", "Metformin", "Albuterol", "Ibuprofen", "Sertraline",
+     "Alprazolam", "Orlistat", "Sumatriptan", "Cetirizine", "Oseltamivir"
+ ]
+
+ def generate_question(condition):
+     questions = [
+         f"What are the symptoms of {condition}?",
+         f"How is {condition} typically diagnosed?",
+         f"What are the common treatments for {condition}?",
+         f"Can you explain the causes of {condition}?",
+         f"What lifestyle changes can help manage {condition}?",
+         f"Are there any complications associated with {condition}?",
+         f"How can {condition} be prevented?",
+         f"What's the long-term outlook for someone with {condition}?",
+         f"Are there any new treatments being developed for {condition}?",
+         f"How does {condition} affect daily life?"
+     ]
+     return random.choice(questions)
+
+ # Function to generate an answer (simplified for this example)
+ def generate_answer(condition, question):
+     return f"Here's some information about {condition} related to your question: '{question}' [Detailed medical explanation would go here.]"
+
+ # Function to generate a medication question
+ def generate_medication_question(medication):
+     questions = [
+         f"What is {medication} used for?",
+         f"What are the common side effects of {medication}?",
+         f"How should {medication} be taken?",
+         f"Are there any drug interactions with {medication}?",
+         f"What should I know before starting {medication}?",
+         f"How long does it take for {medication} to start working?",
+         f"Can {medication} be taken during pregnancy?",
+         f"What should I do if I miss a dose of {medication}?",
+         f"Is {medication} habit-forming?",
+         f"Are there any alternatives to {medication}?"
+     ]
+     return random.choice(questions)
+
+ def generate_medication_answer(medication, question):
+     return f"Regarding {medication} and your question: '{question}' [Detailed medication information would go here.]"
+
+ dataset = []
+ for _ in range(5000):
+     if random.choice([True, False]):
+         condition = random.choice(conditions)
+         question = generate_question(condition)
+         answer = generate_answer(condition, question)
+     else:
+         medication = random.choice(medications)
+         question = generate_medication_question(medication)
+         answer = generate_medication_answer(medication, question)
+
+     dataset.append({
+         "question": question,
+         "answer": answer,
+         "text": f"Question: {question}\nAnswer: {answer}"
+     })
+
+ with open("medical_dataset.json", "w") as f:
+     json.dump(dataset, f, indent=2)
+
+ print("Dataset generated and saved to medical_dataset.json")
train.py ADDED
@@ -0,0 +1,57 @@
+ from datasets import load_dataset
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     TrainingArguments,
+     Trainer,
+     DataCollatorForLanguageModeling
+ )
+ import torch
+ import os
+
+ model_output_path = "./model/medical_llama_3b"
+ os.makedirs(model_output_path, exist_ok=True)
+
+ # Use the standard PyTorch checkpoint; an ONNX INT4 export cannot be
+ # loaded by AutoModelForCausalLM or fine-tuned with Trainer.
+ model_name = "meta-llama/Llama-3.2-3B-Instruct"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ if tokenizer.pad_token is None:
+     tokenizer.pad_token = tokenizer.eos_token  # LLaMA tokenizers ship without a pad token
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
+
+ dataset = load_dataset("json", data_files="medical_dataset.json")
+
+ def preprocess_function(examples):
+     return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)
+
+ tokenized_dataset = dataset.map(
+     preprocess_function,
+     batched=True,
+     remove_columns=dataset["train"].column_names
+ )
+
+ training_args = TrainingArguments(
+     output_dir="./model/medical_llama_3b/checkpoints",
+     per_device_train_batch_size=4,
+     gradient_accumulation_steps=4,
+     num_train_epochs=3,
+     learning_rate=2e-5,
+     fp16=True,
+     save_steps=500,
+     logging_steps=100,
+ )
+
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=tokenized_dataset["train"],
+     data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
+ )
+
+ trainer.train()
+
+ model.save_pretrained(model_output_path)
+ tokenizer.save_pretrained(model_output_path)
+
+ print(f"Model and tokenizer saved to: {model_output_path}")