Triton is slower?
#21 by doguaraci - opened
I've been trying the model with Triton on and off using the code below, and Triton is almost 3x slower in my environment (A10G). Do you have any guidance on this?
```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer

config = AutoConfig.from_pretrained(
    "replit/replit-code-v1-3b",
    trust_remote_code=True,
)
config.attn_config['attn_impl'] = 'triton'  # I'm commenting this out to try the 'torch' implementation

model = AutoModelForCausalLM.from_pretrained('replit/replit-code-v1-3b', config=config, trust_remote_code=True)
model.to(device='cuda:0', dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)
x = tokenizer.encode('def hello():\n print("hello world")\n', return_tensors='pt').to('cuda')

start = time.time()
y = model.generate(x, max_new_tokens=64)
end = time.time()
print(end - start)
```
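One caveat with timing a single `generate` call: Triton kernels are JIT-compiled (and autotuned) on first use, so the first call pays a one-off compilation cost that the `torch` implementation does not. A hedged sketch of a fairer measurement, using a hypothetical `benchmark` helper (not part of the repo) that adds warm-up runs and a synchronization hook:

```python
import time

def benchmark(fn, warmup=2, iters=5, sync=None):
    """Average wall-clock time of fn() over several iterations.

    Warm-up runs exclude one-off costs such as Triton's first-call
    kernel compilation; `sync` (e.g. torch.cuda.synchronize) is called
    before reading the clock because CUDA kernel launches are async.
    """
    for _ in range(warmup):
        fn()
    if sync:
        sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if sync:
        sync()
    return (time.perf_counter() - start) / iters
```

With the setup above, something like `benchmark(lambda: model.generate(x, max_new_tokens=64), sync=torch.cuda.synchronize)` would compare the two `attn_impl` settings on steady-state throughput rather than including compilation time.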
doguaraci changed discussion status to closed