ssaroya committed
Commit
320d690
1 Parent(s): 8083870

Create gcp_run.py

Files changed (1)
  1. gcp_run.py +78 -0
gcp_run.py ADDED
@@ -0,0 +1,78 @@
import argparse

import torch
import torch.nn as nn
import quant

from gptq import GPTQ
from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders
import transformers
from transformers import AutoTokenizer


def load_quant(model="Wizard-Vicuna-13B-Uncensored-GPTQ",
               checkpoint="Wizard-Vicuna-13B-Uncensored-GPTQ/Wizard-Vicuna-13B-Uncensored-GPTQ-4bit-128g.compat.no-act-order.safetensors",
               wbits=4,
               groupsize=128,
               fused_mlp=True, eval=True, warmup_autotune=True):
    from transformers import LlamaConfig, LlamaForCausalLM
    config = LlamaConfig.from_pretrained(model)

    def noop(*args, **kwargs):
        pass

    # Skip weight initialization: every tensor is overwritten by the checkpoint below.
    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop

    # Build the model skeleton in half precision without initializing weights.
    torch.set_default_dtype(torch.half)
    transformers.modeling_utils._init_weights = False
    model = LlamaForCausalLM(config)
    torch.set_default_dtype(torch.float)
    if eval:
        model = model.eval()

    # Replace every linear layer except the output head with a quantized linear layer.
    layers = find_layers(model)
    for name in ['lm_head']:
        if name in layers:
            del layers[name]
    quant.make_quant_linear(model, layers, wbits, groupsize)

    del layers

    print('Loading model ...')
    if checkpoint.endswith('.safetensors'):
        from safetensors.torch import load_file as safe_load
        model.load_state_dict(safe_load(checkpoint), strict=False)
    else:
        model.load_state_dict(torch.load(checkpoint), strict=False)

    # Fuse attention, norm, and (optionally) MLP kernels for faster inference.
    if eval:
        quant.make_quant_attn(model)
        quant.make_quant_norm(model)
        if fused_mlp:
            quant.make_fused_mlp(model)
    if warmup_autotune:
        quant.autotune_warmup_linear(model, transpose=not eval)
        if eval and fused_mlp:
            quant.autotune_warmup_fused(model)
    model.seqlen = 2048
    print('Done.')

    return model


# Load the quantized checkpoint, then generate from a single prompt.
model = load_quant()
model.to(DEV)
tokenizer = AutoTokenizer.from_pretrained("Wizard-Vicuna-13B-Uncensored-GPTQ", use_fast=False)
input_ids = tokenizer.encode("TEXT PROMPT GOES HERE", return_tensors="pt").to(DEV)

with torch.no_grad():
    generated_ids = model.generate(
        input_ids,
        do_sample=True,
        min_length=50,
        max_length=200,
        top_p=0.99,
        temperature=0.8,
    )
print(tokenizer.decode([el.item() for el in generated_ids[0]]))
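
The file imports argparse but never builds a parser, so the prompt and checkpoint stay hard-coded. A minimal sketch of how they could be exposed on the command line is below; the flag names and defaults are hypothetical and not part of this commit:

    # Hypothetical CLI wiring (not in this commit): parameterize the prompt and checkpoint.
    parser = argparse.ArgumentParser(description="Run a GPTQ-quantized LLaMA checkpoint")
    parser.add_argument("--prompt", default="TEXT PROMPT GOES HERE")
    parser.add_argument("--checkpoint",
                        default="Wizard-Vicuna-13B-Uncensored-GPTQ/Wizard-Vicuna-13B-Uncensored-GPTQ-4bit-128g.compat.no-act-order.safetensors")
    args = parser.parse_args()
    model = load_quant(checkpoint=args.checkpoint)
    model.to(DEV)
    input_ids = tokenizer.encode(args.prompt, return_tensors="pt").to(DEV)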