Jong Wook Kim commited on
Commit
6f40009
·
1 Parent(s): cba7812

detector model

Browse files
detector/README.md ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GPT-2 Output Detector
2
+ =====================
3
+
4
+ This directory contains the code for working with the GPT-2 output detector model, obtained by fine-tuning a
5
+ [RoBERTa model](https://ai.facebook.com/blog/roberta-an-optimized-method-for-pretraining-self-supervised-nlp-systems/)
6
+ with [the outputs of the 1.5B-parameter GPT-2 model](https://github.com/openai/gpt-2-output-dataset).
7
+ For motivations and discussions regarding the release of this detector model, please check out
8
+ [out blog post](https://openai.com/blog/gpt-2-6-month-follow-up/) and [report](https://arxiv.org/abs/1908.09203).
9
+
10
+ ## Downloading a pre-trained detector model
11
+
12
+ Download the weights for the fine-tuned `roberta-base` model (478 MB):
13
+
14
+ ```bash
15
+ wget https://storage.googleapis.com/gpt-2/detector-models/v1/detector-base.pt
16
+ ```
17
+
18
+ or `roberta-large` model (1.5 GB):
19
+
20
+ ```bash
21
+ wget https://storage.googleapis.com/gpt-2/detector-models/v1/detector-large.pt
22
+ ```
23
+
24
+ These RoBERTa-based models are fine-tuned with a mixture of temperature-1 and nucleus sampling outputs,
25
+ which should generalize well to outputs generated using different sampling methods.
26
+
27
+ ## Running a detector model
28
+
29
+ You can launch a web UI in which you can enter a text and see the detector model's prediction
30
+ on whether or not it was generated by a GPT-2 model.
31
+
32
+ ```bash
33
+ # (on the top-level directory of this repository)
34
+ pip install -r requirements.txt
35
+ python -m detector.server detector-base.pt
36
+ ```
37
+
38
+ ## Training a new detector model
39
+
40
+ You can use the provided training script to train a detector model on a new set of datasets.
41
+ We recommend using a GPU machine for this task.
42
+
43
+ ```bash
44
+ # (on the top-level directory of this repository)
45
+ pip install -r requirements.txt
46
+ python -m detector.train
47
+ ```
48
+
49
+ The training script supports a number of different options; append `--help` to the command above for usage.
detector/dataset.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ from typing import List
4
+
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+ from tqdm import tqdm
8
+ from transformers import PreTrainedTokenizer
9
+
10
+ from .download import download
11
+
12
+
13
+ def load_texts(data_file, expected_size=None):
14
+ texts = []
15
+
16
+ for line in tqdm(open(data_file), total=expected_size, desc=f'Loading {data_file}'):
17
+ texts.append(json.loads(line)['text'])
18
+
19
+ return texts
20
+
21
+
22
+ class Corpus:
23
+ def __init__(self, name, data_dir='data', skip_train=False):
24
+ download(name, data_dir=data_dir)
25
+ self.name = name
26
+ self.train = load_texts(f'{data_dir}/{name}.train.jsonl', expected_size=250000) if not skip_train else None
27
+ self.test = load_texts(f'{data_dir}/{name}.test.jsonl', expected_size=5000)
28
+ self.valid = load_texts(f'{data_dir}/{name}.valid.jsonl', expected_size=5000)
29
+
30
+
31
+ class EncodedDataset(Dataset):
32
+ def __init__(self, real_texts: List[str], fake_texts: List[str], tokenizer: PreTrainedTokenizer,
33
+ max_sequence_length: int = None, min_sequence_length: int = None, epoch_size: int = None,
34
+ token_dropout: float = None, seed: int = None):
35
+ self.real_texts = real_texts
36
+ self.fake_texts = fake_texts
37
+ self.tokenizer = tokenizer
38
+ self.max_sequence_length = max_sequence_length
39
+ self.min_sequence_length = min_sequence_length
40
+ self.epoch_size = epoch_size
41
+ self.token_dropout = token_dropout
42
+ self.random = np.random.RandomState(seed)
43
+
44
+ def __len__(self):
45
+ return self.epoch_size or len(self.real_texts) + len(self.fake_texts)
46
+
47
+ def __getitem__(self, index):
48
+ if self.epoch_size is not None:
49
+ label = self.random.randint(2)
50
+ texts = [self.fake_texts, self.real_texts][label]
51
+ text = texts[self.random.randint(len(texts))]
52
+ else:
53
+ if index < len(self.real_texts):
54
+ text = self.real_texts[index]
55
+ label = 1
56
+ else:
57
+ text = self.fake_texts[index - len(self.real_texts)]
58
+ label = 0
59
+
60
+ tokens = self.tokenizer.encode(text)
61
+
62
+ if self.max_sequence_length is None:
63
+ tokens = tokens[:self.tokenizer.max_len - 2]
64
+ else:
65
+ output_length = min(len(tokens), self.max_sequence_length)
66
+ if self.min_sequence_length:
67
+ output_length = self.random.randint(min(self.min_sequence_length, len(tokens)), output_length + 1)
68
+ start_index = 0 if len(tokens) <= output_length else self.random.randint(0, len(tokens) - output_length + 1)
69
+ end_index = start_index + output_length
70
+ tokens = tokens[start_index:end_index]
71
+
72
+ if self.token_dropout:
73
+ dropout_mask = self.random.binomial(1, self.token_dropout, len(tokens)).astype(np.bool)
74
+ tokens = np.array(tokens)
75
+ tokens[dropout_mask] = self.tokenizer.unk_token_id
76
+ tokens = tokens.tolist()
77
+
78
+ if self.max_sequence_length is None or len(tokens) == self.max_sequence_length:
79
+ mask = torch.ones(len(tokens) + 2)
80
+ return torch.tensor([self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]), mask, label
81
+
82
+ padding = [self.tokenizer.pad_token_id] * (self.max_sequence_length - len(tokens))
83
+ tokens = torch.tensor([self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id] + padding)
84
+ mask = torch.ones(tokens.shape[0])
85
+ mask[-len(padding):] = 0
86
+ return tokens, mask, label
detector/download.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import requests
4
+ import torch.distributed as dist
5
+ from tqdm import tqdm
6
+
7
+ from .utils import distributed
8
+
9
+ ALL_DATASETS = [
10
+ 'webtext',
11
+ 'small-117M', 'small-117M-k40', 'small-117M-nucleus',
12
+ 'medium-345M', 'medium-345M-k40', 'medium-345M-nucleus',
13
+ 'large-762M', 'large-762M-k40', 'large-762M-nucleus',
14
+ 'xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus'
15
+ ]
16
+
17
+
18
+ def download(*datasets, data_dir='data'):
19
+ os.makedirs(data_dir, exist_ok=True)
20
+
21
+ if distributed() and dist.get_rank() > 0:
22
+ dist.barrier()
23
+
24
+ for ds in datasets:
25
+ assert ds in ALL_DATASETS, f'Unknown dataset {ds}'
26
+
27
+ for split in ['train', 'valid', 'test']:
28
+ filename = ds + "." + split + '.jsonl'
29
+ output_file = os.path.join(data_dir, filename)
30
+ if os.path.isfile(output_file):
31
+ continue
32
+
33
+ r = requests.get("https://storage.googleapis.com/gpt-2/output-dataset/v1/" + filename, stream=True)
34
+
35
+ with open(output_file, 'wb') as f:
36
+ file_size = int(r.headers["content-length"])
37
+ chunk_size = 1000
38
+ with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar:
39
+ # 1k for chunk_size, since Ethernet packet size is around 1500 bytes
40
+ for chunk in r.iter_content(chunk_size=chunk_size):
41
+ f.write(chunk)
42
+ pbar.update(chunk_size)
43
+
44
+ if distributed() and dist.get_rank() == 0:
45
+ dist.barrier()
46
+
47
+
48
+ if __name__ == '__main__':
49
+ download(*ALL_DATASETS)
detector/index.html ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html>
3
+ <head>
4
+ <title>GPT-2 Output Detector</title>
5
+ <style type="text/css">
6
+ * {
7
+ box-sizing: border-box;
8
+ }
9
+
10
+ body {
11
+ font-family: sans-serif;
12
+ margin: 0;
13
+ }
14
+
15
+ h1 {
16
+ font-weight: lighter;
17
+ }
18
+
19
+ a {
20
+ text-decoration: none;
21
+ color: #666;
22
+ }
23
+
24
+ a:hover {
25
+ text-decoration: underline;
26
+ }
27
+
28
+ #container {
29
+ margin: auto;
30
+ width: 960px;
31
+ }
32
+
33
+ #textbox {
34
+ font-family: serif;
35
+ font-size: 16pt;
36
+ width: 100%;
37
+ height: 480px;
38
+ padding: 20px 30px;
39
+ line-height: 1.6;
40
+ }
41
+
42
+ .bar-row {
43
+ height: 30px;
44
+ }
45
+ #real-percentage {
46
+ width: 80px;
47
+ vertical-align: top;
48
+ }
49
+ #bar-container {
50
+ width: 800px;
51
+ background-color: #ff7674;
52
+ line-height: 0.5;
53
+ position:relative;
54
+ top:6px;
55
+ }
56
+ #fake-percentage {
57
+ width: 80px;
58
+ vertical-align: top;
59
+ }
60
+ #bar {
61
+ display: inline-block;
62
+ height: 30px;
63
+ background-color: #83aaff;
64
+ }
65
+ em {
66
+ font-family: monospace;
67
+ font-style: normal;
68
+ }
69
+ </style>
70
+ </head>
71
+ <body>
72
+ <div id="container">
73
+ <h1>GPT-2 Output Detector Demo</h1>
74
+ <p>
75
+ This is an online demo of the
76
+ <a href="https://github.com/openai/gpt-2-output-dataset/tree/master/detector">GPT-2 output detector</a>
77
+ model. Enter some text in the text box; the predicted probabilities will be displayed below.
78
+ <u>The results start to get reliable after around 50 tokens.</u>
79
+ </p>
80
+ <textarea id="textbox" placeholder="Enter text here"></textarea>
81
+ <div><table cellspacing="0" cellpadding="0">
82
+ <tr class="bar-row" style="vertical-align: bottom;">
83
+ <td style="text-align: left;">Real</td>
84
+ <td id="message" style="text-align: center;"></td>
85
+ <td style="text-align: right;">Fake</td>
86
+ </tr>
87
+ <tr class="bar-row">
88
+ <td id="real-percentage" style="text-align: left; vertical-align: bottom;"></td>
89
+ <td id="bar-container"><div id="bar" style="width: 50%;"></div></td>
90
+ <td id="fake-percentage" style="text-align: right; vertical-align: bottom;"></td>
91
+ </tr>
92
+ </table></div>
93
+ </div>
94
+ <script>
95
+ let textbox = document.getElementById('textbox');
96
+ let last_submit = null;
97
+
98
+ let real_percentage = document.getElementById('real-percentage');
99
+ let fake_percentage = document.getElementById('fake-percentage');
100
+ let bar = document.getElementById('bar');
101
+ let message = document.getElementById('message');
102
+
103
+ function update_graph(result) {
104
+ if (result === null) {
105
+ real_percentage.innerHTML = '';
106
+ fake_percentage.innerHTML = '';
107
+ bar.style.width = '50%';
108
+ message.innerHTML = '';
109
+ } else {
110
+ let percentage = result.real_probability;
111
+ real_percentage.innerHTML = (100 * percentage).toFixed(2) + '%';
112
+ fake_percentage.innerHTML = (100 * (1 - percentage)).toFixed(2) + '%';
113
+ bar.style.width = (100 * percentage).toFixed(2) + '%';
114
+ if (result.used_tokens === result.all_tokens) {
115
+ message.innerHTML = `Prediction based on ${result.used_tokens} tokens`;
116
+ } else {
117
+ message.innerHTML = `Prediction based on the first ${result.used_tokens} tokens among the total ${result.all_tokens}`;
118
+ }
119
+ }
120
+ }
121
+
122
+ textbox.oninput = () => {
123
+ if (last_submit) {
124
+ clearTimeout(last_submit);
125
+ }
126
+ if (textbox.value.length === 0) {
127
+ update_graph(null);
128
+ return;
129
+ }
130
+ message.innerText = 'Predicting ...';
131
+ last_submit = setTimeout(() => {
132
+ let req = new XMLHttpRequest();
133
+ if (textbox.value.length === 0) {
134
+ update_graph(null);
135
+ return;
136
+ }
137
+ req.open('GET', '/?' + textbox.value, true);
138
+ req.onreadystatechange = () => {
139
+ if (req.readyState !== 4) return;
140
+ if (req.status !== 200) throw new Error("HTTP status: " + req.status);
141
+ let result = JSON.parse(req.responseText);
142
+ update_graph(result);
143
+ };
144
+ req.send();
145
+ }, 1000);
146
+
147
+ };
148
+
149
+ window.addEventListener('DOMContentLoaded', () => {
150
+ textbox.focus();
151
+ });
152
+ </script>
153
+ </body>
154
+ </html>
detector/server.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from http.server import HTTPServer, SimpleHTTPRequestHandler
4
+ from multiprocessing import Process
5
+ import subprocess
6
+ from transformers import RobertaForSequenceClassification, RobertaTokenizer
7
+ import json
8
+ import fire
9
+ import torch
10
+ from urllib.parse import urlparse, unquote
11
+
12
+
13
+ model: RobertaForSequenceClassification = None
14
+ tokenizer: RobertaTokenizer = None
15
+ device: str = None
16
+
17
+ def log(*args):
18
+ print(f"[{os.environ.get('RANK', '')}]", *args, file=sys.stderr)
19
+
20
+
21
+ class RequestHandler(SimpleHTTPRequestHandler):
22
+
23
+ def do_GET(self):
24
+ query = unquote(urlparse(self.path).query)
25
+
26
+ if not query:
27
+ self.begin_content('text/html')
28
+
29
+ html = os.path.join(os.path.dirname(__file__), 'index.html')
30
+ self.wfile.write(open(html).read().encode())
31
+ return
32
+
33
+ self.begin_content('application/json;charset=UTF-8')
34
+
35
+ tokens = tokenizer.encode(query)
36
+ all_tokens = len(tokens)
37
+ tokens = tokens[:tokenizer.max_len - 2]
38
+ used_tokens = len(tokens)
39
+ tokens = torch.tensor([tokenizer.bos_token_id] + tokens + [tokenizer.eos_token_id]).unsqueeze(0)
40
+ mask = torch.ones_like(tokens)
41
+
42
+ with torch.no_grad():
43
+ logits = model(tokens.to(device), attention_mask=mask.to(device))[0]
44
+ probs = logits.softmax(dim=-1)
45
+
46
+ fake, real = probs.detach().cpu().flatten().numpy().tolist()
47
+
48
+ self.wfile.write(json.dumps(dict(
49
+ all_tokens=all_tokens,
50
+ used_tokens=used_tokens,
51
+ real_probability=real,
52
+ fake_probability=fake
53
+ )).encode())
54
+
55
+ def begin_content(self, content_type):
56
+ self.send_response(200)
57
+ self.send_header('Content-Type', content_type)
58
+ self.send_header('Access-Control-Allow-Origin', '*')
59
+ self.end_headers()
60
+
61
+ def log_message(self, format, *args):
62
+ log(format % args)
63
+
64
+
65
+ def serve_forever(server, model, tokenizer, device):
66
+ log('Process has started; loading the model ...')
67
+ globals()['model'] = model.to(device)
68
+ globals()['tokenizer'] = tokenizer
69
+ globals()['device'] = device
70
+
71
+ log('Ready to serve')
72
+ server.serve_forever()
73
+
74
+
75
+ def main(checkpoint, port=8080, device='cuda' if torch.cuda.is_available() else 'cpu'):
76
+ if checkpoint.startswith('gs://'):
77
+ print(f'Downloading {checkpoint}', file=sys.stderr)
78
+ subprocess.check_output(['gsutil', 'cp', checkpoint, '.'])
79
+ checkpoint = os.path.basename(checkpoint)
80
+ assert os.path.isfile(checkpoint)
81
+
82
+ print(f'Loading checkpoint from {checkpoint}')
83
+ data = torch.load(checkpoint, map_location='cpu')
84
+
85
+ model_name = 'roberta-large' if data['args']['large'] else 'roberta-base'
86
+ model = RobertaForSequenceClassification.from_pretrained(model_name)
87
+ tokenizer = RobertaTokenizer.from_pretrained(model_name)
88
+
89
+ model.load_state_dict(data['model_state_dict'])
90
+ model.eval()
91
+
92
+ print(f'Starting HTTP server on port {port}', file=sys.stderr)
93
+ server = HTTPServer(('0.0.0.0', port), RequestHandler)
94
+
95
+ # avoid calling CUDA API before forking; doing so in a subprocess is fine.
96
+ num_workers = int(subprocess.check_output(['python', '-c', 'import torch; print(torch.cuda.device_count())']))
97
+
98
+ if num_workers <= 1:
99
+ serve_forever(server, model, tokenizer, device)
100
+ else:
101
+ print(f'Launching {num_workers} worker processes...')
102
+
103
+ subprocesses = []
104
+
105
+ for i in range(num_workers):
106
+ os.environ['RANK'] = f'{i}'
107
+ os.environ['CUDA_VISIBLE_DEVICES'] = f'{i}'
108
+ process = Process(target=serve_forever, args=(server, model, tokenizer, device))
109
+ process.start()
110
+ subprocesses.append(process)
111
+
112
+ del os.environ['RANK']
113
+ del os.environ['CUDA_VISIBLE_DEVICES']
114
+
115
+ for process in subprocesses:
116
+ process.join()
117
+
118
+
119
+ if __name__ == '__main__':
120
+ fire.Fire(main)
detector/train.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Training code for the detector model"""
2
+
3
+ import argparse
4
+ import os
5
+ import subprocess
6
+ import sys
7
+ from itertools import count
8
+ from multiprocessing import Process
9
+
10
+ import torch
11
+ import torch.distributed as dist
12
+ from torch import nn
13
+ from torch.nn.parallel import DistributedDataParallel
14
+ from torch.optim import Adam
15
+ from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
16
+ from tqdm import tqdm
17
+ from transformers import *
18
+
19
+ from .dataset import Corpus, EncodedDataset
20
+ from .download import download
21
+ from .utils import summary, distributed
22
+
23
+
24
+ def setup_distributed(port=29500):
25
+ if not dist.is_available() or not torch.cuda.is_available() or torch.cuda.device_count() <= 1:
26
+ return 0, 1
27
+
28
+ if 'MPIR_CVAR_CH3_INTERFACE_HOSTNAME' in os.environ:
29
+ from mpi4py import MPI
30
+ mpi_rank = MPI.COMM_WORLD.Get_rank()
31
+ mpi_size = MPI.COMM_WORLD.Get_size()
32
+
33
+ os.environ["MASTER_ADDR"] = '127.0.0.1'
34
+ os.environ["MASTER_PORT"] = str(port)
35
+
36
+ dist.init_process_group(backend="nccl", world_size=mpi_size, rank=mpi_rank)
37
+ return mpi_rank, mpi_size
38
+
39
+ dist.init_process_group(backend="nccl", init_method="env://")
40
+ return dist.get_rank(), dist.get_world_size()
41
+
42
+
43
+ def load_datasets(data_dir, real_dataset, fake_dataset, tokenizer, batch_size,
44
+ max_sequence_length, random_sequence_length, epoch_size=None, token_dropout=None, seed=None):
45
+ if fake_dataset == 'TWO':
46
+ download(real_dataset, 'xl-1542M', 'xl-1542M-nucleus', data_dir=data_dir)
47
+ elif fake_dataset == 'THREE':
48
+ download(real_dataset, 'xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus', data_dir=data_dir)
49
+ else:
50
+ download(real_dataset, fake_dataset, data_dir=data_dir)
51
+
52
+ real_corpus = Corpus(real_dataset, data_dir=data_dir)
53
+
54
+ if fake_dataset == "TWO":
55
+ real_train, real_valid = real_corpus.train * 2, real_corpus.valid * 2
56
+ fake_corpora = [Corpus(name, data_dir=data_dir) for name in ['xl-1542M', 'xl-1542M-nucleus']]
57
+ fake_train = sum([corpus.train for corpus in fake_corpora], [])
58
+ fake_valid = sum([corpus.valid for corpus in fake_corpora], [])
59
+ elif fake_dataset == "THREE":
60
+ real_train, real_valid = real_corpus.train * 3, real_corpus.valid * 3
61
+ fake_corpora = [Corpus(name, data_dir=data_dir) for name in
62
+ ['xl-1542M', 'xl-1542M-k40', 'xl-1542M-nucleus']]
63
+ fake_train = sum([corpus.train for corpus in fake_corpora], [])
64
+ fake_valid = sum([corpus.valid for corpus in fake_corpora], [])
65
+ else:
66
+ fake_corpus = Corpus(fake_dataset, data_dir=data_dir)
67
+
68
+ real_train, real_valid = real_corpus.train, real_corpus.valid
69
+ fake_train, fake_valid = fake_corpus.train, fake_corpus.valid
70
+
71
+ Sampler = DistributedSampler if distributed() and dist.get_world_size() > 1 else RandomSampler
72
+
73
+ min_sequence_length = 10 if random_sequence_length else None
74
+ train_dataset = EncodedDataset(real_train, fake_train, tokenizer, max_sequence_length, min_sequence_length,
75
+ epoch_size, token_dropout, seed)
76
+ train_loader = DataLoader(train_dataset, batch_size, sampler=Sampler(train_dataset), num_workers=0)
77
+
78
+ validation_dataset = EncodedDataset(real_valid, fake_valid, tokenizer)
79
+ validation_loader = DataLoader(validation_dataset, batch_size=1, sampler=Sampler(validation_dataset))
80
+
81
+ return train_loader, validation_loader
82
+
83
+
84
+ def accuracy_sum(logits, labels):
85
+ if list(logits.shape) == list(labels.shape) + [2]:
86
+ # 2-d outputs
87
+ classification = (logits[..., 0] < logits[..., 1]).long().flatten()
88
+ else:
89
+ classification = (logits > 0).long().flatten()
90
+ assert classification.shape == labels.shape
91
+ return (classification == labels).float().sum().item()
92
+
93
+
94
+ def train(model: nn.Module, optimizer, device: str, loader: DataLoader, desc='Train'):
95
+ model.train()
96
+
97
+ train_accuracy = 0
98
+ train_epoch_size = 0
99
+ train_loss = 0
100
+
101
+ with tqdm(loader, desc=desc, disable=distributed() and dist.get_rank() > 0) as loop:
102
+ for texts, masks, labels in loop:
103
+
104
+ texts, masks, labels = texts.to(device), masks.to(device), labels.to(device)
105
+ batch_size = texts.shape[0]
106
+
107
+ optimizer.zero_grad()
108
+ loss, logits = model(texts, attention_mask=masks, labels=labels)
109
+ loss.backward()
110
+ optimizer.step()
111
+
112
+ batch_accuracy = accuracy_sum(logits, labels)
113
+ train_accuracy += batch_accuracy
114
+ train_epoch_size += batch_size
115
+ train_loss += loss.item() * batch_size
116
+
117
+ loop.set_postfix(loss=loss.item(), acc=train_accuracy / train_epoch_size)
118
+
119
+ return {
120
+ "train/accuracy": train_accuracy,
121
+ "train/epoch_size": train_epoch_size,
122
+ "train/loss": train_loss
123
+ }
124
+
125
+
126
+ def validate(model: nn.Module, device: str, loader: DataLoader, votes=1, desc='Validation'):
127
+ model.eval()
128
+
129
+ validation_accuracy = 0
130
+ validation_epoch_size = 0
131
+ validation_loss = 0
132
+
133
+ records = [record for v in range(votes) for record in tqdm(loader, desc=f'Preloading data ... {v}',
134
+ disable=dist.is_available() and dist.get_rank() > 0)]
135
+ records = [[records[v * len(loader) + i] for v in range(votes)] for i in range(len(loader))]
136
+
137
+ with tqdm(records, desc=desc, disable=distributed() and dist.get_rank() > 0) as loop, torch.no_grad():
138
+ for example in loop:
139
+ losses = []
140
+ logit_votes = []
141
+
142
+ for texts, masks, labels in example:
143
+ texts, masks, labels = texts.to(device), masks.to(device), labels.to(device)
144
+ batch_size = texts.shape[0]
145
+
146
+ loss, logits = model(texts, attention_mask=masks, labels=labels)
147
+ losses.append(loss)
148
+ logit_votes.append(logits)
149
+
150
+ loss = torch.stack(losses).mean(dim=0)
151
+ logits = torch.stack(logit_votes).mean(dim=0)
152
+
153
+ batch_accuracy = accuracy_sum(logits, labels)
154
+ validation_accuracy += batch_accuracy
155
+ validation_epoch_size += batch_size
156
+ validation_loss += loss.item() * batch_size
157
+
158
+ loop.set_postfix(loss=loss.item(), acc=validation_accuracy / validation_epoch_size)
159
+
160
+ return {
161
+ "validation/accuracy": validation_accuracy,
162
+ "validation/epoch_size": validation_epoch_size,
163
+ "validation/loss": validation_loss
164
+ }
165
+
166
+
167
+ def _all_reduce_dict(d, device):
168
+ # wrap in tensor and use reduce to gpu0 tensor
169
+ output_d = {}
170
+ for (key, value) in sorted(d.items()):
171
+ tensor_input = torch.tensor([[value]]).to(device)
172
+ torch.distributed.all_reduce(tensor_input)
173
+ output_d[key] = tensor_input.item()
174
+ return output_d
175
+
176
+
177
+ def run(max_epochs=None,
178
+ device=None,
179
+ batch_size=24,
180
+ max_sequence_length=128,
181
+ random_sequence_length=False,
182
+ epoch_size=None,
183
+ seed=None,
184
+ data_dir='data',
185
+ real_dataset='webtext',
186
+ fake_dataset='xl-1542M-nucleus',
187
+ token_dropout=None,
188
+ large=False,
189
+ learning_rate=2e-5,
190
+ weight_decay=0,
191
+ **kwargs):
192
+ args = locals()
193
+ rank, world_size = setup_distributed()
194
+
195
+ if device is None:
196
+ device = f'cuda:{rank}' if torch.cuda.is_available() else 'cpu'
197
+
198
+ print('rank:', rank, 'world_size:', world_size, 'device:', device)
199
+
200
+ import torch.distributed as dist
201
+ if distributed() and rank > 0:
202
+ dist.barrier()
203
+
204
+ model_name = 'roberta-large' if large else 'roberta-base'
205
+ tokenization_utils.logger.setLevel('ERROR')
206
+ tokenizer = RobertaTokenizer.from_pretrained(model_name)
207
+ model = RobertaForSequenceClassification.from_pretrained(model_name).to(device)
208
+
209
+ if rank == 0:
210
+ summary(model)
211
+ if distributed():
212
+ dist.barrier()
213
+
214
+ if world_size > 1:
215
+ model = DistributedDataParallel(model, [rank], output_device=rank, find_unused_parameters=True)
216
+
217
+ train_loader, validation_loader = load_datasets(data_dir, real_dataset, fake_dataset, tokenizer, batch_size,
218
+ max_sequence_length, random_sequence_length, epoch_size,
219
+ token_dropout, seed)
220
+
221
+ optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
222
+ epoch_loop = count(1) if max_epochs is None else range(1, max_epochs + 1)
223
+
224
+ logdir = os.environ.get("OPENAI_LOGDIR", "logs")
225
+ os.makedirs(logdir, exist_ok=True)
226
+
227
+ from torch.utils.tensorboard import SummaryWriter
228
+ writer = SummaryWriter(logdir) if rank == 0 else None
229
+ best_validation_accuracy = 0
230
+
231
+ for epoch in epoch_loop:
232
+ if world_size > 1:
233
+ train_loader.sampler.set_epoch(epoch)
234
+ validation_loader.sampler.set_epoch(epoch)
235
+
236
+ train_metrics = train(model, optimizer, device, train_loader, f'Epoch {epoch}')
237
+ validation_metrics = validate(model, device, validation_loader)
238
+
239
+ combined_metrics = _all_reduce_dict({**validation_metrics, **train_metrics}, device)
240
+
241
+ combined_metrics["train/accuracy"] /= combined_metrics["train/epoch_size"]
242
+ combined_metrics["train/loss"] /= combined_metrics["train/epoch_size"]
243
+ combined_metrics["validation/accuracy"] /= combined_metrics["validation/epoch_size"]
244
+ combined_metrics["validation/loss"] /= combined_metrics["validation/epoch_size"]
245
+
246
+ if rank == 0:
247
+ for key, value in combined_metrics.items():
248
+ writer.add_scalar(key, value, global_step=epoch)
249
+
250
+ if combined_metrics["validation/accuracy"] > best_validation_accuracy:
251
+ best_validation_accuracy = combined_metrics["validation/accuracy"]
252
+
253
+ model_to_save = model.module if hasattr(model, 'module') else model
254
+ torch.save(dict(
255
+ epoch=epoch,
256
+ model_state_dict=model_to_save.state_dict(),
257
+ optimizer_state_dict=optimizer.state_dict(),
258
+ args=args
259
+ ),
260
+ os.path.join(logdir, "best-model.pt")
261
+ )
262
+
263
+
264
+ if __name__ == '__main__':
265
+ parser = argparse.ArgumentParser()
266
+
267
+ parser.add_argument('--max-epochs', type=int, default=None)
268
+ parser.add_argument('--device', type=str, default=None)
269
+ parser.add_argument('--batch-size', type=int, default=24)
270
+ parser.add_argument('--max-sequence-length', type=int, default=128)
271
+ parser.add_argument('--random-sequence-length', action='store_true')
272
+ parser.add_argument('--epoch-size', type=int, default=None)
273
+ parser.add_argument('--seed', type=int, default=None)
274
+ parser.add_argument('--data-dir', type=str, default='data')
275
+ parser.add_argument('--real-dataset', type=str, default='webtext')
276
+ parser.add_argument('--fake-dataset', type=str, default='xl-1542M-k40')
277
+ parser.add_argument('--token-dropout', type=float, default=None)
278
+
279
+ parser.add_argument('--large', action='store_true', help='use the roberta-large model instead of roberta-base')
280
+ parser.add_argument('--learning-rate', type=float, default=2e-5)
281
+ parser.add_argument('--weight-decay', type=float, default=0)
282
+ args = parser.parse_args()
283
+
284
+ nproc = int(subprocess.check_output(['python', '-c', "import torch;"
285
+ "print(torch.cuda.device_count() if torch.cuda.is_available() else 1)"]))
286
+ if nproc > 1:
287
+ print(f'Launching {nproc} processes ...', file=sys.stderr)
288
+
289
+ os.environ["MASTER_ADDR"] = '127.0.0.1'
290
+ os.environ["MASTER_PORT"] = str(29500)
291
+ os.environ['WORLD_SIZE'] = str(nproc)
292
+ os.environ['OMP_NUM_THREAD'] = str(1)
293
+ subprocesses = []
294
+
295
+ for i in range(nproc):
296
+ os.environ['RANK'] = str(i)
297
+ os.environ['LOCAL_RANK'] = str(i)
298
+ process = Process(target=run, kwargs=vars(args))
299
+ process.start()
300
+ subprocesses.append(process)
301
+
302
+ for process in subprocesses:
303
+ process.join()
304
+ else:
305
+ run(**vars(args))
detector/utils.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from functools import reduce
3
+
4
+ from torch import nn
5
+ import torch.distributed as dist
6
+
7
+
8
+ def summary(model: nn.Module, file=sys.stdout):
9
+ def repr(model):
10
+ # We treat the extra repr like the sub-module, one item per line
11
+ extra_lines = []
12
+ extra_repr = model.extra_repr()
13
+ # empty string will be split into list ['']
14
+ if extra_repr:
15
+ extra_lines = extra_repr.split('\n')
16
+ child_lines = []
17
+ total_params = 0
18
+ for key, module in model._modules.items():
19
+ mod_str, num_params = repr(module)
20
+ mod_str = nn.modules.module._addindent(mod_str, 2)
21
+ child_lines.append('(' + key + '): ' + mod_str)
22
+ total_params += num_params
23
+ lines = extra_lines + child_lines
24
+
25
+ for name, p in model._parameters.items():
26
+ if hasattr(p, 'shape'):
27
+ total_params += reduce(lambda x, y: x * y, p.shape)
28
+
29
+ main_str = model._get_name() + '('
30
+ if lines:
31
+ # simple one-liner info, which most builtin Modules will use
32
+ if len(extra_lines) == 1 and not child_lines:
33
+ main_str += extra_lines[0]
34
+ else:
35
+ main_str += '\n ' + '\n '.join(lines) + '\n'
36
+
37
+ main_str += ')'
38
+ if file is sys.stdout:
39
+ main_str += ', \033[92m{:,}\033[0m params'.format(total_params)
40
+ else:
41
+ main_str += ', {:,} params'.format(total_params)
42
+ return main_str, total_params
43
+
44
+ string, count = repr(model)
45
+ if file is not None:
46
+ if isinstance(file, str):
47
+ file = open(file, 'w')
48
+ print(string, file=file)
49
+ file.flush()
50
+
51
+ return count
52
+
53
+
54
+ def grad_norm(model: nn.Module):
55
+ total_norm = 0
56
+ for p in model.parameters():
57
+ param_norm = p.grad.data.norm(2)
58
+ total_norm += param_norm.item() ** 2
59
+ return total_norm ** 0.5
60
+
61
+ def distributed():
62
+ return dist.is_available() and dist.is_initialized()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ transformers>=2.0.0
2
+ fire>=0.2.1
3
+ requests>=2.22.0
4
+ tqdm>=4.32.2
5
+ torch>=1.2.0
6
+ tensorboard>=1.14.0