language:
  - en
license: openrail
library_name: peft
tags:
  - code
  - sql
datasets:
  - bugdaryan/spider-natsql-wikisql-instruct
base_model: WizardLM/WizardCoder-15B-V1.0

LoRA adapters for the WizardCoderSQL model

Overview

Description

This repository contains LoRA adapter weights for WizardCoder-15B-V1.0, fine-tuned with QLoRA to translate natural-language questions into SQLite queries given a database schema. Training used the bugdaryan/spider-natsql-wikisql-instruct dataset and the LoRA, quantization, and training settings listed below. The adapter can be applied on top of the base model as-is, or merged into the base weights to produce a standalone text-to-SQL model.
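For reference, the standard LoRA formulation that PEFT implements by default (scaling factor lora_alpha / lora_r) is sketched below, using the lora_r and lora_alpha values listed under Model Details. The card does not say which modules were adapted, so the matrix dimensions d and k are left symbolic. During training, lora_dropout is applied only to the input of the adapter branch, so it does not affect the merged weights that merge_and_unload produces.

```latex
% LoRA forward pass for one adapted weight matrix W_0 \in \mathbb{R}^{d \times k}
h = W_0 x + \frac{\alpha}{r} B A x,
\qquad B \in \mathbb{R}^{d \times r},\; A \in \mathbb{R}^{r \times k}

% Merging folds the low-rank update into the frozen weights
W_{\text{merged}} = W_0 + \frac{\alpha}{r} B A,
\qquad \frac{\alpha}{r} = \frac{\texttt{lora\_alpha}}{\texttt{lora\_r}} = \frac{16}{64} = 0.25

% Trainable parameters added per adapted matrix
|A| + |B| = r\,(d + k) = 64\,(d + k)
```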

Model Details

  • Base Model: WizardCoder-15B-V1.0 (WizardLM/WizardCoder-15B-V1.0)
  • Fine-Tuned Model Name: WizardCoderSQL-15B-V1.0-QLoRA
  • Fine-Tuning Parameters:
    • QLoRA Parameters:
      • LoRA Rank (lora_r): 64
      • LoRA Alpha Parameter (lora_alpha): 16
      • LoRA Dropout Probability (lora_dropout): 0.1
    • bitsandbytes Parameters:
      • Load Base Model in 4-bit Precision (use_4bit): True
      • Compute Dtype for the 4-bit Base Model (bnb_4bit_compute_dtype): float16
      • Quantization Type (bnb_4bit_quant_type): nf4
      • Nested (Double) Quantization (use_nested_quant): False
    • TrainingArguments Parameters:
      • Number of Training Epochs (num_train_epochs): 1
      • Enable FP16/BF16 Training (fp16/bf16): False/True
      • Batch Size per GPU for Training (per_device_train_batch_size): 48
      • Batch Size per GPU for Evaluation (per_device_eval_batch_size): 4
      • Gradient Accumulation Steps (gradient_accumulation_steps): 1
      • Enable Gradient Checkpointing (gradient_checkpointing): True
      • Maximum Gradient Norm (max_grad_norm): 0.3
      • Initial Learning Rate (learning_rate): 2e-4
      • Weight Decay (weight_decay): 0.001
      • Optimizer (optim): paged_adamw_32bit
      • Learning Rate Scheduler Type (lr_scheduler_type): cosine
      • Maximum Training Steps (max_steps): -1 (disabled; training length is set by num_train_epochs)
      • Warmup Ratio (warmup_ratio): 0.03
      • Group Sequences of Similar Length into Batches (group_by_length): True
      • Save Checkpoint Every X Update Steps (save_steps): 0
      • Log Every X Update Steps (logging_steps): 25
    • SFT Parameters:
      • Maximum Sequence Length (max_seq_length): 500
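The use_4bit setting matters mainly for memory. The snippet below is a rough, weight-only estimate, assuming about 15 billion parameters (as the model name suggests) and ignoring activations, optimizer state, the KV cache, and the small per-block scaling constants that 4-bit quantization also stores. The function name weight_memory_gib is hypothetical.

```python
# Rough weight-only memory estimate for a ~15B-parameter model.
# Assumptions: 15e9 parameters (from the "15B" in the model name); activations,
# optimizer state, KV cache and 4-bit quantization constants are ignored.

BYTES_PER_PARAM = {
    "float32": 4.0,
    "float16": 2.0,  # bfloat16 has the same size
    "nf4": 0.5,      # 4 bits per weight, before per-block scaling constants
}


def weight_memory_gib(num_params: float, dtype: str) -> float:
    """Return the approximate size of the weights in GiB for the given dtype."""
    return num_params * BYTES_PER_PARAM[dtype] / 2**30


if __name__ == "__main__":
    n_params = 15e9
    for dtype in BYTES_PER_PARAM:
        print(f"{dtype:>8}: ~{weight_memory_gib(n_params, dtype):.1f} GiB")
```

The printed values are about 55.9, 27.9 and 7.0 GiB respectively, which is why the base model was quantized to 4 bits for fine-tuning.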

Usage

To use the adapter, load the base model with the Hugging Face Transformers library, attach the adapter with PEFT, and optionally merge it into the base weights. For example:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel

model_name = 'WizardLM/WizardCoder-15B-V1.0'
adapter_name = 'bugdaryan/WizardCoderSQL-15B-V1.0-QLoRA'

# Load the base model in float16, which needs half the memory of float32.
base_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
# Attach the LoRA adapter, then fold its weights into the base model.
model = PeftModel.from_pretrained(base_model, adapter_name)
model = model.merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained(model_name)

pipe = pipeline('text-generation', model=model, tokenizer=tokenizer)

# Schema passed to the model as context (one CREATE TABLE per table).
tables = "CREATE TABLE sales ( sale_id number PRIMARY KEY, product_id number, customer_id number, salesperson_id number, sale_date DATE, quantity number, FOREIGN KEY (product_id) REFERENCES products(product_id), FOREIGN KEY (customer_id) REFERENCES customers(customer_id), FOREIGN KEY (salesperson_id) REFERENCES salespeople(salesperson_id)); CREATE TABLE product_suppliers ( supplier_id number PRIMARY KEY, product_id number, supply_price number, FOREIGN KEY (product_id) REFERENCES products(product_id)); CREATE TABLE customers ( customer_id number PRIMARY KEY, name text, address text ); CREATE TABLE salespeople ( salesperson_id number PRIMARY KEY, name text, region text );"

question = 'Find the salesperson who made the most sales.'

prompt = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. ### Instruction: Convert text to SQLite query: {question} {tables} ### Response:"

ans = pipe(prompt, max_new_tokens=200)
print(ans[0]['generated_text'])
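When querying the model repeatedly, it can be convenient to wrap the prompt format in small helpers. The sketch below rebuilds the exact template used in the example and pulls the SQL out of the pipeline output; build_prompt, extract_sql and PROMPT_TEMPLATE are hypothetical names, not part of this repository. It relies on the text-generation pipeline returning the prompt followed by the completion by default.

```python
# Hypothetical helpers (not part of this repository) for the prompt format above.

PROMPT_TEMPLATE = (
    "Below is an instruction that describes a task, paired with an input that "
    "provides further context. Write a response that appropriately completes "
    "the request. ### Instruction: Convert text to SQLite query: "
    "{question} {tables} ### Response:"
)

RESPONSE_MARKER = "### Response:"


def build_prompt(question: str, tables: str) -> str:
    """Fill the instruction template with a question and its CREATE TABLE schema."""
    return PROMPT_TEMPLATE.format(question=question, tables=tables)


def extract_sql(generated_text: str) -> str:
    """Return the stripped text after the last '### Response:' marker.

    If the marker is missing, the whole text is returned (stripped).
    """
    _, _, completion = generated_text.rpartition(RESPONSE_MARKER)
    return completion.strip()
```

With the pipeline from the example, `extract_sql(pipe(build_prompt(question, tables), max_new_tokens=200)[0]['generated_text'])` yields only the generated query. Alternatively, passing `return_full_text=False` to the pipeline call makes it return the completion alone.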

Disclaimer

The WizardCoderSQL model follows the same license as WizardCoder. The content produced by any version of WizardCoderSQL is influenced by uncontrollable factors such as sampling randomness; therefore, this project cannot guarantee the accuracy of its output. This project accepts no legal liability for the content of the model output and assumes no responsibility for any losses incurred through the use of the associated resources and output results.
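Because generated queries can be wrong, one cheap safeguard is to check that a query at least compiles against the schema before running it on real data. The sketch below uses Python's built-in sqlite3 module and SQLite's EXPLAIN, which compiles a statement and lists its bytecode instead of executing it. It assumes the model emits a single SQLite statement; check_sql and SCHEMA are hypothetical names, and SCHEMA is the example schema with one definition per table. A passing check catches syntax errors and unknown tables or columns, not queries that answer the wrong question.

```python
import sqlite3
from typing import Optional

# Example schema from the usage section, one definition per table.
# The foreign keys point at a products table that is not defined; SQLite does
# not check that referenced tables exist when a table is created.
SCHEMA = """
CREATE TABLE sales ( sale_id number PRIMARY KEY, product_id number,
  customer_id number, salesperson_id number, sale_date DATE, quantity number,
  FOREIGN KEY (product_id) REFERENCES products(product_id),
  FOREIGN KEY (customer_id) REFERENCES customers(customer_id),
  FOREIGN KEY (salesperson_id) REFERENCES salespeople(salesperson_id));
CREATE TABLE product_suppliers ( supplier_id number PRIMARY KEY,
  product_id number, supply_price number,
  FOREIGN KEY (product_id) REFERENCES products(product_id));
CREATE TABLE customers ( customer_id number PRIMARY KEY, name text, address text );
CREATE TABLE salespeople ( salesperson_id number PRIMARY KEY, name text, region text );
"""


def check_sql(query: str, schema: str = SCHEMA) -> Optional[str]:
    """Compile `query` against an empty in-memory copy of `schema`.

    Returns None if SQLite accepts the query, otherwise the error message.
    """
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(schema)        # create the (empty) tables
        conn.execute("EXPLAIN " + query)  # compile the query without running it
        return None
    except sqlite3.Error as exc:
        return str(exc)
    finally:
        conn.close()
```

A return value of None only means SQLite could compile the query; reviewing the result against the question is still necessary.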