metadata
license: mit
datasets:
  - b-mc2/sql-create-context
  - gretelai/synthetic_text_to_sql
language:
  - en
base_model: google-t5/t5-base
metrics:
  - exact_match
model-index:
  - name: juanfra218/text2sql
    results:
      - task:
          type: text-to-sql
        metrics:
          - name: exact_match
            type: exact_match
            value: 0.4326836917562724
tags:
  - sql
library_name: transformers

Fine-Tuned Google T5 Model for Text to SQL Translation

This model is a fine-tuned version of Google's T5 (google-t5/t5-base) that translates natural language queries into SQL statements.

Model Details

Ongoing Work

Work is in progress to implement PICARD (Parsing Incrementally for Constrained Auto-Regressive Decoding from Language Models), a decoding method that rejects invalid tokens at each generation step, to improve this model's results. More details can be found in the original PICARD paper.
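To give a rough idea of what constrained decoding means, here is a toy sketch. It is not PICARD itself and not part of this model: the schema, the four-token grammar, and the scoring function are all hypothetical, and it uses greedy selection rather than beam search. It only illustrates the core idea of discarding candidate tokens that would make the partial query invalid.

```python
from typing import Callable, Dict, List

# Hypothetical toy schema; a real system would check against the actual database.
TABLES = {"users"}
COLUMNS = {"name", "age"}


def is_valid_prefix(tokens: List[str]) -> bool:
    """Return True if `tokens` can still grow into `SELECT <column> FROM <table>`."""
    expected = [{"SELECT"}, COLUMNS, {"FROM"}, TABLES]
    if len(tokens) > len(expected):
        return False
    return all(tok in allowed for tok, allowed in zip(tokens, expected))


def constrained_greedy_decode(
    score_fn: Callable[[List[str]], Dict[str, float]], max_steps: int = 4
) -> List[str]:
    """Pick, at each step, the best-scoring token that keeps the prefix valid."""
    tokens: List[str] = []
    for _ in range(max_steps):
        scores = score_fn(tokens)
        # Try candidates from best to worst and keep the first admissible one.
        for candidate in sorted(scores, key=scores.get, reverse=True):
            if is_valid_prefix(tokens + [candidate]):
                tokens.append(candidate)
                break
        else:
            break  # no admissible continuation exists
    return tokens


def toy_scores(prefix: List[str]) -> Dict[str, float]:
    """Stand-in for a language model that always prefers the unknown column 'email'."""
    return {"email": 0.9, "SELECT": 0.5, "name": 0.4, "FROM": 0.3, "users": 0.2, "age": 0.1}


print(" ".join(constrained_greedy_decode(toy_scores)))
```

Without the validity check, greedy decoding over `toy_scores` would emit `email` at every step; with it, the column that does not exist in the toy schema is never selected.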

Training Parameters

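from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",  # called eval_strategy in recent transformers releases
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,  # keep only the three most recent checkpoints
    num_train_epochs=3,
    predict_with_generate=True,  # evaluate on generated sequences
    fp16=True,  # mixed-precision training
    push_to_hub=False,
)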
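The metadata reports an exact_match score of about 0.43. The card does not say how predictions were normalized before comparison, so the following is only a minimal sketch of how such a score could be computed. The normalization steps (whitespace collapsing, lowercasing, dropping a trailing semicolon) and the function names are assumptions, not the author's evaluation code.

```python
import re
from typing import List


def normalize_sql(query: str) -> str:
    """Assumed normalization: trim, drop a trailing semicolon, collapse whitespace, lowercase.

    Note that lowercasing also changes the case of string literals.
    """
    query = query.strip().rstrip(";")
    return re.sub(r"\s+", " ", query).strip().lower()


def exact_match(predictions: List[str], references: List[str]) -> float:
    """Fraction of predictions that equal their reference after normalization."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not references:
        return 0.0
    hits = sum(normalize_sql(p) == normalize_sql(r) for p, r in zip(predictions, references))
    return hits / len(references)


preds = ["SELECT name FROM users;", "SELECT age FROM users"]
refs = ["select name  from users", "SELECT name FROM users"]
print(exact_match(preds, refs))
```

Exact match is a strict metric: two queries that return the same rows but differ in alias names or clause order still count as a miss.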

Usage

import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Load the tokenizer and model, then move the model to the GPU when one is available
model_path = 'text2sql_model_path'  # placeholder: replace with the location of the model
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Function to generate SQL queries
def generate_sql(prompt, schema):
    input_text = "translate English to SQL: " + prompt + " " + schema
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True, padding="max_length")

    # Inputs must live on the same device as the model
    inputs = {key: value.to(device) for key, value in inputs.items()}

    max_output_length = 1024
    outputs = model.generate(**inputs, max_length=max_output_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Interactive loop
print("Enter 'quit' to exit.")
while True:
    prompt = input("Insert prompt: ")
    # Check for 'quit' before asking for a schema
    if prompt.strip().lower() == 'quit':
        break
    schema = input("Insert schema: ")

    sql_query = generate_sql(prompt, schema)
    print(f"Generated SQL query: {sql_query}")
    print()
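The card does not specify what the schema string should look like. The sketch below assumes it is given as a CREATE TABLE statement, and shows the single-line input that the `generate_sql` function above would build from it. The helper `build_input`, the `employees` table, and the question are hypothetical examples.

```python
def build_input(prompt: str, schema: str) -> str:
    """Build the model input with the same prefix and ordering as generate_sql."""
    # Collapse newlines and indentation so a multi-line schema becomes one line.
    flat_schema = " ".join(schema.split())
    return "translate English to SQL: " + prompt.strip() + " " + flat_schema


# Hypothetical schema and question, for illustration only.
example_schema = """
CREATE TABLE employees (
    id INT,
    name VARCHAR,
    salary INT
)
"""
example_prompt = "What is the name of the highest paid employee?"

print(build_input(example_prompt, example_schema))
```

Flattening the schema is a convenience for pasting multi-line DDL into the interactive loop, where `input()` reads a single line at a time.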

Files

  • optimizer.pt: State of the optimizer.
  • training_args.bin: Training arguments and hyperparameters.
  • tokenizer.json: Tokenizer vocabulary and settings.
  • spiece.model: SentencePiece model file.
  • special_tokens_map.json: Special tokens mapping.
  • tokenizer_config.json: Tokenizer configuration settings.
  • model.safetensors: Trained model weights.
  • generation_config.json: Configuration for text generation.
  • config.json: Model architecture configuration.