lavawolfiee commited on
Commit
4e06d26
·
1 Parent(s): 6bc49a9

Some fixies

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -35,9 +35,9 @@ bos_token_id, eos_token_id = tokenizer.bos_token_id, tokenizer.eos_token_id
35
  bos_token, eos_token = tokenizer.bos_token, tokenizer.eos_token
36
 
37
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
- knn_model = inject_knn_in_gpt2(
39
  model, knn_memory, bos_token_id, eos_token_id, device, layer_ind=8)
40
- knn_model.load_state_dict(torch.load('gpt2_knn_attention.pt'))
41
 
42
 
43
  def generate(text, temperature, max_new_tokens, top_p):
 
35
  bos_token, eos_token = tokenizer.bos_token, tokenizer.eos_token
36
 
37
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
+ inject_knn_in_gpt2(
39
  model, knn_memory, bos_token_id, eos_token_id, device, layer_ind=8)
40
+ model.load_state_dict(torch.load('gpt2_knn_attention.pt'))
41
 
42
 
43
  def generate(text, temperature, max_new_tokens, top_p):