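# Fine-tunes dad1909/cybersentinal-2.0 with Unsloth LoRA adapters on the admincybers2/DSV
# dataset, merges the adapters into 16-bit weights, pushes the result to the Hub, and bumps
# the NUM stage counter stored on the admincybers2/CyberController Space.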
import os
import torch
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import load_dataset
import gradio as gr
import json
from huggingface_hub import HfApi
max_seq_length = 4096
dtype = None  # None lets Unsloth pick the dtype automatically (bfloat16 where supported)
load_in_4bit = True  # load the base model in 4-bit to keep memory usage low
hf_token = os.getenv("HF_TOKEN")
current_num = os.getenv("NUM")  # stage counter stored as a Space variable
print(f"stage {current_num}")
api = HfApi(token=hf_token)  # used at the end to update the NUM Space variable
models = "dad1909/cybersentinal-2.0"  # base model to continue fine-tuning from
print("Starting model and tokenizer loading...")
# Load the model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=models,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    token=hf_token,
)
print("Model and tokenizer loaded successfully.")
print("Configuring PEFT model...")
model = FastLanguageModel.get_peft_model(
    model,
    r=16,  # LoRA rank
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",  # Unsloth's memory-efficient gradient checkpointing
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)
print("PEFT model configured.")
# Prompt templates keyed by instruction type; each template has two {} slots
# (instruction text, expected output).
alpaca_prompt = {
    "learning_from": """Below is a CVE definition.
### CVE definition:
{}
### CVE details:
{}""",
    "definition": """Below is a definition of a software vulnerability. Explain it.
### Definition:
{}
### Explanation:
{}""",
    "code_vulnerability": """Below is a code snippet. Identify the vulnerable line of code and describe the type of software vulnerability.
### Code Snippet:
{}
### Vulnerability solution:
{}"""
}
EOS_TOKEN = tokenizer.eos_token  # appended so the model learns where a response ends
def detect_prompt_type(instruction):
    """Classify an instruction so it is matched to the right prompt template.

    Order matters: the more specific "what is code vulnerable..." prefix must be
    checked before the generic "what is" prefix.
    """
    if instruction.startswith("what is code vulnerable of this code:"):
        return "code_vulnerability"
    elif instruction.startswith("Learning from"):
        return "learning_from"
    elif instruction.startswith("what is"):
        return "definition"
    else:
        return "unknown"
def formatting_prompts_func(examples):
    """Render each (instruction, output) pair into a single training text ending with EOS."""
    instructions = examples["instruction"]
    outputs = examples["output"]
    texts = []
    for instruction, output in zip(instructions, outputs):
        prompt_type = detect_prompt_type(instruction)
        if prompt_type in alpaca_prompt:
            prompt = alpaca_prompt[prompt_type].format(instruction, output)
        else:
            # Unknown instruction types fall back to plain concatenation.
            prompt = instruction + "\n\n" + output
        text = prompt + EOS_TOKEN
        texts.append(text)
    return {"text": texts}
print("Loading dataset...")
dataset = load_dataset("admincybers2/DSV", split="train")
print("Dataset loaded successfully.")
print("Applying formatting function to the dataset...")
dataset = dataset.map(formatting_prompts_func, batched=True)
print("Formatting function applied.")
print("Initializing trainer...")
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        per_device_train_batch_size=5,
        gradient_accumulation_steps=5,  # effective batch size of 25 per device
        learning_rate=2e-4,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        warmup_steps=5,
        logging_steps=10,
        max_steps=200,  # train for a fixed 200 optimizer steps
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs",
    ),
)
print("Trainer initialized.")
print("Starting training...")
trainer_stats = trainer.train()
print("Training completed.")
num = int(current_num) + 1  # next stage number to store back on the controller Space
up = "sentinal-2"  # repo name to push the merged model to
print("Saving the trained model...")
model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
print("Model saved successfully.")
print("Pushing the model to the hub...")
model.push_to_hub_merged(
    up,
    tokenizer,
    save_method="merged_16bit",
    token=hf_token,
)
print("Model pushed to hub successfully.")
api.delete_space_variable(repo_id="admincybers2/CyberController", key="NUM")
api.add_space_variable(repo_id="admincybers2/CyberController", key="NUM", value=str(num))