# NOTE(review): the lines "Spaces: / Paused / Paused" were Hugging Face Spaces
# UI status residue captured during export, not part of the program.
import os
import torch  # required at runtime by unsloth/transformers for GPU execution
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import load_dataset
import gradio as gr
import json
from huggingface_hub import HfApi

# ---- Training configuration -------------------------------------------------
max_seq_length = 4096
dtype = None          # None lets unsloth auto-select the compute dtype
load_in_4bit = True   # 4-bit quantized base weights to reduce GPU memory

# Credentials and the stage counter come from Space environment variables.
hf_token = os.getenv("HF_TOKEN")
current_num = os.getenv("NUM")

# BUG FIX: original used f"stage ${current_num}", which printed a stray
# literal '$' — f-strings interpolate with {} alone.
print(f"stage {current_num}")

api = HfApi(token=hf_token)

models = "dad1909/cybersentinal-2.0"

print("Starting model and tokenizer loading...")
# Load the base model and its tokenizer from the Hub (4-bit quantized).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=models,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    token=hf_token
)
print("Model and tokenizer loaded successfully.")

print("Configuring PEFT model...")
# Attach LoRA adapters to the attention and MLP projection layers.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,                      # LoRA rank
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",  # unsloth's memory-efficient checkpointing
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)
print("PEFT model configured.")

# Prompt templates keyed by instruction category; each template takes
# (instruction, output) via str.format. Categories are assigned by
# detect_prompt_type() below.
alpaca_prompt = {
    "learning_from": """Below is a CVE definition.
### CVE definition:
{}
### detail CVE:
{}""",
    "definition": """Below is a definition about software vulnerability. Explain it.
### Definition:
{}
### Explanation:
{}""",
    "code_vulnerability": """Below is a code snippet. Identify the line of code that is vulnerable and describe the type of software vulnerability.
### Code Snippet:
{}
### Vulnerability solution:
{}"""
}

# End-of-sequence token appended to every training example.
EOS_TOKEN = tokenizer.eos_token
def detect_prompt_type(instruction):
    """Classify an instruction string by its leading phrase.

    Prefixes are checked from most to least specific — the
    "what is code vulnerable of this code:" case must be tested before
    the plain "what is" case, since it also begins with "what is".
    Returns "unknown" when no known prefix matches.
    """
    prefix_table = (
        ("what is code vulnerable of this code:", "code_vulnerability"),
        ("Learning from", "learning_from"),
        ("what is", "definition"),
    )
    for prefix, category in prefix_table:
        if instruction.startswith(prefix):
            return category
    return "unknown"
def formatting_prompts_func(examples):
    """Build the "text" column for SFT from a batched dataset slice.

    For each (instruction, output) pair, formats the matching template from
    ``alpaca_prompt`` (chosen via ``detect_prompt_type``); unknown categories
    fall back to instruction + blank line + output. Every example ends with
    ``EOS_TOKEN``.
    """
    formatted = []
    for instruction, output in zip(examples["instruction"], examples["output"]):
        template = alpaca_prompt.get(detect_prompt_type(instruction))
        if template is not None:
            body = template.format(instruction, output)
        else:
            body = instruction + "\n\n" + output
        formatted.append(body + EOS_TOKEN)
    return {"text": formatted}
# FIX: parse the stage counter BEFORE training. The original ran
# int(current_num) only after trainer.train(), so a missing or non-numeric
# NUM env var crashed the script after the entire (expensive) training run.
num = int(current_num)
num += 1

print("Loading dataset...")
dataset = load_dataset("admincybers2/DSV", split="train")
print("Dataset loaded successfully.")

print("Applying formatting function to the dataset...")
dataset = dataset.map(formatting_prompts_func, batched=True)
print("Formatting function applied.")

print("Initializing trainer...")
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",   # column produced by formatting_prompts_func
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        per_device_train_batch_size=5,
        gradient_accumulation_steps=5,   # effective batch size 25
        learning_rate=2e-4,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        warmup_steps=5,
        logging_steps=10,
        max_steps=200,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs"
    ),
)
print("Trainer initialized.")

print("Starting training...")
trainer_stats = trainer.train()
print("Training completed.")

# NOTE(review): the push target "sentinal-2" differs from the base repo
# "dad1909/cybersentinal-2.0" loaded above — confirm this is intentional.
up = "sentinal-2"

print("Saving the trained model...")
model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
print("Model saved successfully.")

print("Pushing the model to the hub...")
model.push_to_hub_merged(
    up,
    tokenizer,
    save_method="merged_16bit",
    token=hf_token
)
print("Model pushed to hub successfully.")

# Bump the Space's NUM variable so the next run resumes at the next stage.
api.delete_space_variable(repo_id="admincybers2/CyberController", key="NUM")
api.add_space_variable(repo_id="admincybers2/CyberController", key="NUM", value=str(num))