xuxw98 commited on
Commit
85e24d4
·
1 Parent(s): a29d76a

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +11 -11
  2. generate.py +28 -35
app.py CHANGED
@@ -91,17 +91,17 @@ def instruct_generate(
91
  encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
92
  # prompt_length = encoded.size(0)
93
 
94
- # y = generate(
95
- # model,
96
- # idx=encoded,
97
- # max_seq_length=max_new_tokens,
98
- # max_new_tokens=max_new_tokens,
99
- # temperature=temperature,
100
- # top_k=top_k,
101
- # eos_id=tokenizer.eos_id
102
- # )
103
-
104
- y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
105
 
106
  output = tokenizer.decode(y)
107
  output = output.split("### Response:")[1].strip()
 
91
  encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
92
  # prompt_length = encoded.size(0)
93
 
94
+ y = generate(
95
+ model,
96
+ idx=encoded,
97
+ max_seq_length=max_new_tokens,
98
+ max_new_tokens=max_new_tokens,
99
+ temperature=temperature,
100
+ top_k=top_k,
101
+ eos_id=tokenizer.eos_id
102
+ )
103
+
104
+ # y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
105
 
106
  output = tokenizer.decode(y)
107
  output = output.split("### Response:")[1].strip()
generate.py CHANGED
@@ -12,16 +12,15 @@ wd = Path(__file__).parent.parent.resolve()
12
  sys.path.append(str(wd))
13
 
14
  from lit_llama import LLaMA, Tokenizer
15
- from lit_llama.utils import lazy_load, llama_model_lookup, quantization
16
 
17
 
18
  @torch.no_grad()
19
  def generate(
20
- model: LLaMA,
21
  idx: torch.Tensor,
22
  max_new_tokens: int,
23
- *,
24
- max_seq_length: Optional[int] = None,
25
  temperature: float = 1.0,
26
  top_k: Optional[int] = None,
27
  eos_id: Optional[int] = None,
@@ -42,49 +41,35 @@ def generate(
42
  # create an empty tensor of the expected final shape and fill in the current tokens
43
  T = idx.size(0)
44
  T_new = T + max_new_tokens
45
- if max_seq_length is None:
46
- max_seq_length = min(T_new, model.config.block_size)
47
-
48
- device, dtype = idx.device, idx.dtype
49
- # create an empty tensor of the expected final shape and fill in the current tokens
50
- empty = torch.empty(T_new, dtype=dtype, device=device)
51
  empty[:T] = idx
52
  idx = empty
53
- input_pos = torch.arange(0, T, device=device)
54
-
55
- if idx.device.type == "xla":
56
- import torch_xla.core.xla_model as xm
57
-
58
- xm.mark_step()
59
 
60
  # generate max_new_tokens tokens
61
- for _ in range(max_new_tokens):
62
- x = idx.index_select(0, input_pos).view(1, -1)
 
 
 
63
 
64
  # forward
65
- logits = model(x, max_seq_length, input_pos)
66
  logits = logits[0, -1] / temperature
67
 
68
  # optionally crop the logits to only the top k options
69
  if top_k is not None:
70
  v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
71
- logits = torch.where(logits < v[[-1]], -float("Inf"), logits)
72
 
73
  probs = torch.nn.functional.softmax(logits, dim=-1)
74
- idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)
75
-
76
- # advance
77
- input_pos = input_pos[-1:] + 1
78
-
79
- if idx.device.type == "xla":
80
- xm.mark_step()
81
 
82
  # concatenate the new generation
83
- idx = idx.index_copy(0, input_pos, idx_next)
84
 
85
  # if <eos> token is triggered, return the output (stop generation)
86
  if idx_next == eos_id:
87
- return idx[:input_pos] # include the EOS token
88
 
89
  return idx
90
 
@@ -118,22 +103,24 @@ def main(
118
  assert checkpoint_path.is_file(), checkpoint_path
119
  assert tokenizer_path.is_file(), tokenizer_path
120
 
121
- precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
122
- fabric = L.Fabric(devices=1, precision=precision)
123
 
124
  print("Loading model ...", file=sys.stderr)
125
  t0 = time.time()
126
  with lazy_load(checkpoint_path) as checkpoint:
127
  name = llama_model_lookup(checkpoint)
128
 
129
- with fabric.init_module(empty_init=True), quantization(mode=quantize):
 
 
130
  model = LLaMA.from_name(name)
131
 
132
  model.load_state_dict(checkpoint)
133
  print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
134
 
135
  model.eval()
136
- model = fabric.setup(model)
137
 
138
  tokenizer = Tokenizer(tokenizer_path)
139
  encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
@@ -142,10 +129,16 @@ def main(
142
  L.seed_everything(1234)
143
  for i in range(num_samples):
144
  t0 = time.perf_counter()
145
- y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
 
 
 
 
 
 
 
146
  t = time.perf_counter() - t0
147
 
148
- model.reset_cache()
149
  print(tokenizer.decode(y))
150
  tokens_generated = y.size(0) - prompt_length
151
  print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
 
12
  sys.path.append(str(wd))
13
 
14
  from lit_llama import LLaMA, Tokenizer
15
+ from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
16
 
17
 
18
  @torch.no_grad()
19
  def generate(
20
+ model: torch.nn.Module,
21
  idx: torch.Tensor,
22
  max_new_tokens: int,
23
+ max_seq_length: int,
 
24
  temperature: float = 1.0,
25
  top_k: Optional[int] = None,
26
  eos_id: Optional[int] = None,
 
41
  # create an empty tensor of the expected final shape and fill in the current tokens
42
  T = idx.size(0)
43
  T_new = T + max_new_tokens
44
+ empty = torch.empty(T_new, dtype=idx.dtype, device=idx.device)
 
 
 
 
 
45
  empty[:T] = idx
46
  idx = empty
 
 
 
 
 
 
47
 
48
  # generate max_new_tokens tokens
49
+ for t in range(T, T_new):
50
+ # ignore the not-filled-yet tokens
51
+ idx_cond = idx[:t]
52
+ # if the sequence context is growing too long we must crop it at max_seq_length
53
+ idx_cond = idx_cond if t <= max_seq_length else idx_cond[-max_seq_length:]
54
 
55
  # forward
56
+ logits = model(idx_cond.view(1, -1))
57
  logits = logits[0, -1] / temperature
58
 
59
  # optionally crop the logits to only the top k options
60
  if top_k is not None:
61
  v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
62
+ logits[logits < v[[-1]]] = -float("Inf")
63
 
64
  probs = torch.nn.functional.softmax(logits, dim=-1)
65
+ idx_next = torch.multinomial(probs, num_samples=1)
 
 
 
 
 
 
66
 
67
  # concatenate the new generation
68
+ idx[t] = idx_next
69
 
70
  # if <eos> token is triggered, return the output (stop generation)
71
  if idx_next == eos_id:
72
+ return idx[:t + 1] # include the EOS token
73
 
74
  return idx
75
 
 
103
  assert checkpoint_path.is_file(), checkpoint_path
104
  assert tokenizer_path.is_file(), tokenizer_path
105
 
106
+ fabric = L.Fabric(devices=1)
107
+ dtype = torch.bfloat16 if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32
108
 
109
  print("Loading model ...", file=sys.stderr)
110
  t0 = time.time()
111
  with lazy_load(checkpoint_path) as checkpoint:
112
  name = llama_model_lookup(checkpoint)
113
 
114
+ with EmptyInitOnDevice(
115
+ device=fabric.device, dtype=dtype, quantization_mode=quantize
116
+ ):
117
  model = LLaMA.from_name(name)
118
 
119
  model.load_state_dict(checkpoint)
120
  print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
121
 
122
  model.eval()
123
+ model = fabric.setup_module(model)
124
 
125
  tokenizer = Tokenizer(tokenizer_path)
126
  encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
 
129
  L.seed_everything(1234)
130
  for i in range(num_samples):
131
  t0 = time.perf_counter()
132
+ y = generate(
133
+ model,
134
+ encoded,
135
+ max_new_tokens,
136
+ model.config.block_size, # type: ignore[union-attr,arg-type]
137
+ temperature=temperature,
138
+ top_k=top_k,
139
+ )
140
  t = time.perf_counter() - t0
141
 
 
142
  print(tokenizer.decode(y))
143
  tokens_generated = y.size(0) - prompt_length
144
  print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)