pszemraj commited on
Commit
89b6f9b
·
verified ·
1 Parent(s): 839b37f

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +142 -0
inference.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import torch
3
+ import fire
4
+ import json
5
+ from pathlib import Path
6
+ import sys
7
+
8
+ from nGPT_pytorch import nGPT
9
+
10
+
11
+ def exists(v):
12
+ return v is not None
13
+
14
+
15
+ def decode_token(token):
16
+ return str(chr(max(32, token)))
17
+
18
+
19
+ def decode_tokens(tokens):
20
+ return "".join(list(map(decode_token, tokens)))
21
+
22
+
23
+ def log(t, eps=1e-20):
24
+ return torch.log(t.clamp(min=eps))
25
+
26
+
27
+ def gumbel_noise(t):
28
+ noise = torch.zeros_like(t).uniform_(0, 1)
29
+ return -log(-log(noise))
30
+
31
+
32
+ def gumbel_sample(t, temperature=1.0, dim=-1, keepdim=True):
33
+ return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(
34
+ dim=dim, keepdim=keepdim
35
+ )
36
+
37
+
38
+ def min_p_filter(logits, min_p=0.1):
39
+ probs = logits.softmax(dim=-1)
40
+ max_probs = probs.amax(dim=-1, keepdim=True)
41
+ limit = min_p * max_probs
42
+ return torch.where(probs < limit, float("-inf"), logits)
43
+
44
+
45
+ def base_decoding(
46
+ net,
47
+ prompt: torch.Tensor,
48
+ seq_len: int,
49
+ temperature=1.5,
50
+ min_p=1e-1,
51
+ filter_thres=0.9,
52
+ ):
53
+ prompt_seq_len, out = prompt.shape[-1], prompt.clone()
54
+ sample_num_times = max(0, seq_len - prompt_seq_len)
55
+
56
+ for _ in range(sample_num_times):
57
+ logits = net(out)
58
+ logits = logits[:, -1]
59
+
60
+ logits = min_p_filter(logits, min_p=min_p)
61
+ sample = gumbel_sample(logits, temperature=temperature, dim=-1)
62
+
63
+ out = torch.cat((out, sample), dim=-1)
64
+
65
+ return out[..., prompt_seq_len:]
66
+
67
+
68
+ def main(
69
+ checkpoint_path: str,
70
+ prompt: str,
71
+ max_new_tokens: int = 100,
72
+ temperature: float = 1.0,
73
+ min_p: float = 0.1,
74
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
75
+ ):
76
+ """Generate text using a trained nGPT model."""
77
+
78
+ # Load checkpoint
79
+ checkpoint_path = Path(checkpoint_path)
80
+ if not checkpoint_path.exists():
81
+ print(f"Error: Checkpoint not found at {checkpoint_path}")
82
+ sys.exit(1)
83
+
84
+ print(f"Loading checkpoint from {checkpoint_path}...")
85
+ checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
86
+
87
+ # Get config from checkpoint or file
88
+ config = checkpoint.get("config", {})
89
+ if not config and checkpoint_path.parent.joinpath("config.json").exists():
90
+ with open(checkpoint_path.parent.joinpath("config.json")) as f:
91
+ config = json.load(f)
92
+
93
+ use_parametrize = config.get("use_parametrize", True)
94
+
95
+ # Initialize model
96
+ model = nGPT(
97
+ num_tokens=256,
98
+ dim=512,
99
+ depth=8,
100
+ tied_embedding=True,
101
+ add_value_residual=True,
102
+ attn_norm_qk=False,
103
+ manual_norm_weights=not use_parametrize,
104
+ ).to(device)
105
+
106
+ # Load weights
107
+ model.load_state_dict(checkpoint["model_state_dict"])
108
+ model.eval()
109
+
110
+ print("\nModel loaded successfully. Generating with:")
111
+ print(f" Temperature: {temperature}")
112
+ print(f" Min-p: {min_p}")
113
+ print(f" Max new tokens: {max_new_tokens}")
114
+
115
+ # Convert prompt to tensor
116
+ prompt_tensor = torch.tensor(
117
+ [ord(c) for c in prompt], dtype=torch.long, device=device
118
+ )
119
+ prompt_tensor = prompt_tensor.unsqueeze(0)
120
+
121
+ # Generate
122
+ with torch.no_grad():
123
+ sampled = base_decoding(
124
+ model,
125
+ prompt_tensor,
126
+ seq_len=max_new_tokens,
127
+ temperature=temperature,
128
+ min_p=min_p,
129
+ )
130
+
131
+ generated = decode_tokens(sampled[0])
132
+
133
+ print("\nGenerated text:")
134
+ print("-" * 80)
135
+ print(prompt + generated)
136
+ print("-" * 80)
137
+
138
+ return generated
139
+
140
+
141
+ if __name__ == "__main__":
142
+ fire.Fire(main)