|
from datasets import load_dataset |
|
from openai import AsyncOpenAI |
|
from tqdm import tqdm |
|
import asyncio |
|
import json |
|
import os |
|
|
|
# Async client pointed at a local OpenAI-compatible endpoint (port 8000,
# e.g. a vLLM server); the key is a placeholder since the local server
# does not authenticate — TODO confirm the serving stack.
client = AsyncOpenAI(api_key="no-need", base_url="http://localhost:8000/v1")
|
|
|
async def generate_answer(messages):
    """Request one chat completion for *messages* from the local server.

    Args:
        messages: conversation as a list of ``{'role', 'content'}`` dicts.

    Returns:
        The assistant's reply text, or the sentinel string ``'error'`` if
        the request failed for any reason (best-effort semantics: callers
        store the sentinel instead of crashing the batch).
    """
    try:
        response = await client.chat.completions.create(
            model="outputs/out-alpha",
            messages=messages,
            max_tokens=2048,
        )
        return response.choices[0].message.content
    except Exception as e:
        # Don't swallow failures silently: report what went wrong so bad
        # batches are diagnosable, but keep the 'error' sentinel so the
        # pipeline continues (same contract as before).
        print(f"generate_answer failed: {e!r}")
        return 'error'
|
|
|
async def process_batch(questions, batch_num, all_qa_pairs):
    """Answer one batch of conversations concurrently and checkpoint to disk.

    Args:
        questions: list of conversations (each a list of message dicts).
            Each conversation is mutated in place: the generated assistant
            message is appended to it.
        batch_num: index of this batch (kept for interface compatibility;
            no longer used for a progress bar).
        all_qa_pairs: accumulator of completed conversations; extended in
            place and re-serialized after every batch so progress survives
            an interruption.

    Returns:
        The list of generated answer strings for this batch.
    """
    # Kick off all requests concurrently. The original wrapped the task-list
    # build in a tqdm bar, but creating coroutines is instantaneous — the
    # bar measured nothing, so a plain generator is clearer.
    answers = await asyncio.gather(*(generate_answer(q) for q in questions))

    for conversation, answer in zip(questions, answers):
        conversation.append({'role': 'assistant', 'content': answer})
        all_qa_pairs.append({"conversations": conversation})

    # Checkpoint: rewrite the full output file after every batch.
    with open('qa_pairs_all-alpha1b_2.json', 'w') as f:
        json.dump(all_qa_pairs, f, indent=2)

    return answers
|
|
|
async def main(start_index=21600, batch_size=200):
    """Generate assistant answers for the tail of the sft-r1 dataset.

    Loads any existing checkpoint file so completed pairs are preserved,
    then processes the remaining conversations in concurrent batches.

    Args:
        start_index: dataset rows before this index are skipped (default
            matches the original hard-coded resume point of 21600).
        batch_size: number of conversations sent concurrently per batch
            (default matches the original hard-coded 200).
    """
    dataset = load_dataset('qnguyen3/sft-r1')

    # Resume from an existing checkpoint file if one is present.
    all_qa_pairs = []
    if os.path.exists('qa_pairs_all-alpha1b_2.json'):
        with open('qa_pairs_all-alpha1b_2.json', 'r') as f:
            all_qa_pairs = json.load(f)

    print("Preparing questions...")
    # Drop each conversation's final (assistant) message so the model
    # regenerates it; only rows from start_index onward are kept.
    question_list = [
        item['messages'][:-1]
        for i, item in tqdm(enumerate(dataset['train']), desc="Loading dataset")
        if i >= start_index
    ]

    for i in tqdm(range(0, len(question_list), batch_size), desc="Processing batches"):
        batch_questions = question_list[i:i + batch_size]
        batch_num = i // batch_size
        await process_batch(batch_questions, batch_num, all_qa_pairs)
|
|
|
# Script entry point: run the full async pipeline to completion.
if __name__ == "__main__":

    asyncio.run(main())