Spaces:
Runtime error
Runtime error
initial commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .dockerignore +1 -0
- Dockerfile +17 -0
- README.md +11 -1
- backend/API.py +165 -0
- backend/GPUHandler.py +159 -0
- backend/Model.py +299 -0
- backend/PPLM.py +723 -0
- backend/README.md +69 -0
- backend/Utils.py +57 -0
- backend/id/.gitignore +2 -0
- backend/id/id.ts +17 -0
- backend/id/package.json +12 -0
- backend/id/tsconfig.json +10 -0
- backend/id/yarn.lock +426 -0
- backend/install.sh +6 -0
- backend/launch.sh +50 -0
- backend/machine_configurations/neuralgenv2.json +13 -0
- backend/machine_configurations/transformer-autocomplete.json +11 -0
- backend/requirements.txt +4 -0
- backend/run_pplm_discrim_train.py +582 -0
- entrypoint.sh +7 -0
- front/.vscode/settings.json +7 -0
- front/assets/Icon-info.svg +9 -0
- front/assets/Salesforce_logo.svg +83 -0
- front/assets/Uber_logo.svg +11 -0
- front/assets/cross-collab.svg +14 -0
- front/assets/github-buttons.js +9 -0
- front/assets/huggingface_logo.svg +47 -0
- front/assets/icon-back.svg +15 -0
- front/assets/icon-publish.svg +16 -0
- front/assets/iconmonstr-download-14.svg +13 -0
- front/assets/iconmonstr-media-control-55.svg +13 -0
- front/assets/iconmonstr-share-11-purple.svg +10 -0
- front/assets/iconmonstr-share-11.svg +13 -0
- front/assets/oval.svg +17 -0
- front/assets/tail-spin.svg +32 -0
- front/assets/thumbnail-large-distilgpt2.png +0 -0
- front/assets/thumbnail-large-pplm.png +0 -0
- front/assets/thumbnail-large.png +0 -0
- front/assets/unicorn-tweaked.svg +1 -0
- front/favicon.ico +0 -0
- front/js-src/Api.ts +153 -0
- front/js-src/Mention.ts +441 -0
- front/js-src/controller.ts +319 -0
- front/js-src/lib/Log.ts +1 -0
- front/js-src/lib/Utils.ts +76 -0
- front/js-src/modals.ts +134 -0
- front/js-src/quill.d.ts +181 -0
- front/js-src/vanilla-tilt.ts +371 -0
- front/less/mixins/bfc.less +3 -0
.dockerignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Dockerfile
|
Dockerfile
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM node:14
|
2 |
+
|
3 |
+
RUN apt-get update && \
|
4 |
+
apt-get install -y nginx gettext-base && \
|
5 |
+
rm -rf /var/lib/apt/lists/* && \
|
6 |
+
chown -R 1000:1000 /etc/nginx && \
|
7 |
+
chown -R 1000:1000 /var/log/nginx && \
|
8 |
+
chown -R 1000:1000 /var/lib/nginx
|
9 |
+
|
10 |
+
WORKDIR /app/transformer-autocomplete
|
11 |
+
ADD . .
|
12 |
+
|
13 |
+
RUN cd front && npm install && npx tsc && npm run build:prod
|
14 |
+
RUN cd grunt && npm install && npx grunt
|
15 |
+
RUN cd server && npm install && npx tsc
|
16 |
+
|
17 |
+
ENTRYPOINT ["./entrypoint.sh"]
|
README.md
CHANGED
@@ -4,7 +4,17 @@ emoji: 🏆
|
|
4 |
colorFrom: green
|
5 |
colorTo: gray
|
6 |
sdk: docker
|
|
|
7 |
pinned: false
|
8 |
---
|
9 |
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
colorFrom: green
|
5 |
colorTo: gray
|
6 |
sdk: docker
|
7 |
+
app_port: 8080
|
8 |
pinned: false
|
9 |
---
|
10 |
|
11 |
+
# transformer-autocomplete
|
12 |
+
|
13 |
+
Autocompletion based on GPT-2
|
14 |
+
|
15 |
+
#### How to compile the front (to test the front with any server)
|
16 |
+
|
17 |
+
1. Update the API endpoint in `front/js-src/Api.ts`
|
18 |
+
2. compile the TS to pure JS with `cd front; tsc` or through vscode (you can launch it in watch mode if needed)
|
19 |
+
3. pack the js into a single file (we use rollup) with `npm run watch`
|
20 |
+
|
backend/API.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from threading import Thread
|
2 |
+
import falcon
|
3 |
+
from falcon.http_status import HTTPStatus
|
4 |
+
import json
|
5 |
+
import requests
|
6 |
+
import time
|
7 |
+
from Model import generate_completion
|
8 |
+
import sys
|
9 |
+
|
10 |
+
|
11 |
+
class AutoComplete(object):
|
12 |
+
def on_post(self, req, resp, single_endpoint=True, x=None, y=None):
|
13 |
+
json_data = json.loads(req.bounded_stream.read())
|
14 |
+
|
15 |
+
resp.status = falcon.HTTP_200
|
16 |
+
|
17 |
+
start = time.time()
|
18 |
+
|
19 |
+
try:
|
20 |
+
context = json_data["context"].rstrip()
|
21 |
+
except KeyError:
|
22 |
+
resp.body = "The context field is required"
|
23 |
+
resp.status = falcon.HTTP_422
|
24 |
+
return
|
25 |
+
|
26 |
+
try:
|
27 |
+
n_samples = json_data['samples']
|
28 |
+
except KeyError:
|
29 |
+
n_samples = 3
|
30 |
+
|
31 |
+
try:
|
32 |
+
length = json_data['gen_length']
|
33 |
+
except KeyError:
|
34 |
+
length = 20
|
35 |
+
|
36 |
+
try:
|
37 |
+
max_time = json_data['max_time']
|
38 |
+
except KeyError:
|
39 |
+
max_time = -1
|
40 |
+
|
41 |
+
try:
|
42 |
+
model_name = json_data['model_size']
|
43 |
+
except KeyError:
|
44 |
+
model_name = "small"
|
45 |
+
|
46 |
+
try:
|
47 |
+
temperature = json_data['temperature']
|
48 |
+
except KeyError:
|
49 |
+
temperature = 0.7
|
50 |
+
|
51 |
+
try:
|
52 |
+
max_tokens = json_data['max_tokens']
|
53 |
+
except KeyError:
|
54 |
+
max_tokens = 256
|
55 |
+
|
56 |
+
try:
|
57 |
+
top_p = json_data['top_p']
|
58 |
+
except KeyError:
|
59 |
+
top_p = 0.95
|
60 |
+
|
61 |
+
try:
|
62 |
+
top_k = json_data['top_k']
|
63 |
+
except KeyError:
|
64 |
+
top_k = 40
|
65 |
+
|
66 |
+
|
67 |
+
# CTRL
|
68 |
+
try:
|
69 |
+
repetition_penalty = json_data['repetition_penalty']
|
70 |
+
except KeyError:
|
71 |
+
repetition_penalty = 0.02
|
72 |
+
|
73 |
+
# PPLM
|
74 |
+
try:
|
75 |
+
stepsize = json_data['step_size']
|
76 |
+
except KeyError:
|
77 |
+
stepsize = 0.02
|
78 |
+
|
79 |
+
try:
|
80 |
+
gm_scale = json_data['gm_scale']
|
81 |
+
except KeyError:
|
82 |
+
gm_scale = None
|
83 |
+
|
84 |
+
try:
|
85 |
+
kl_scale = json_data['kl_scale']
|
86 |
+
except KeyError:
|
87 |
+
kl_scale = None
|
88 |
+
|
89 |
+
try:
|
90 |
+
num_iterations = json_data['num_iterations']
|
91 |
+
except KeyError:
|
92 |
+
num_iterations = None
|
93 |
+
|
94 |
+
try:
|
95 |
+
use_sampling = json_data['use_sampling']
|
96 |
+
except KeyError:
|
97 |
+
use_sampling = None
|
98 |
+
|
99 |
+
try:
|
100 |
+
bag_of_words_or_discrim = json_data['bow_or_discrim']
|
101 |
+
except KeyError:
|
102 |
+
bag_of_words_or_discrim = "kitchen"
|
103 |
+
|
104 |
+
print(json_data)
|
105 |
+
|
106 |
+
sentences = generate_completion(
|
107 |
+
context,
|
108 |
+
length=length,
|
109 |
+
max_time=max_time,
|
110 |
+
model_name=model_name,
|
111 |
+
temperature=temperature,
|
112 |
+
max_tokens=max_tokens,
|
113 |
+
top_p=top_p,
|
114 |
+
top_k=top_k,
|
115 |
+
|
116 |
+
# CTRL
|
117 |
+
repetition_penalty=repetition_penalty,
|
118 |
+
|
119 |
+
# PPLM
|
120 |
+
stepsize=stepsize,
|
121 |
+
bag_of_words_or_discrim=bag_of_words_or_discrim,
|
122 |
+
gm_scale=gm_scale,
|
123 |
+
kl_scale=kl_scale,
|
124 |
+
num_iterations=num_iterations,
|
125 |
+
use_sampling=use_sampling
|
126 |
+
)
|
127 |
+
|
128 |
+
resp.body = json.dumps({"sentences": sentences, 'time': time.time() - start})
|
129 |
+
|
130 |
+
resp.status = falcon.HTTP_200
|
131 |
+
sys.stdout.flush()
|
132 |
+
|
133 |
+
|
134 |
+
class Request(Thread):
|
135 |
+
def __init__(self, end_point, data):
|
136 |
+
Thread.__init__(self)
|
137 |
+
self.end_point = end_point
|
138 |
+
self.data = data
|
139 |
+
self.ret = None
|
140 |
+
|
141 |
+
def run(self):
|
142 |
+
print("Requesting with url", self.end_point)
|
143 |
+
self.ret = requests.post(url=self.end_point, json=self.data)
|
144 |
+
|
145 |
+
def join(self):
|
146 |
+
Thread.join(self)
|
147 |
+
return self.ret.text
|
148 |
+
|
149 |
+
|
150 |
+
class HandleCORS(object):
|
151 |
+
def process_request(self, req, resp):
|
152 |
+
resp.set_header('Access-Control-Allow-Origin', '*')
|
153 |
+
resp.set_header('Access-Control-Allow-Methods', '*')
|
154 |
+
resp.set_header('Access-Control-Allow-Headers', '*')
|
155 |
+
if req.method == 'OPTIONS':
|
156 |
+
raise HTTPStatus(falcon.HTTP_200, body='\n')
|
157 |
+
|
158 |
+
|
159 |
+
autocomplete = AutoComplete()
|
160 |
+
app = falcon.API(middleware=[HandleCORS()])
|
161 |
+
app.add_route('/autocomplete', autocomplete)
|
162 |
+
app.add_route('/autocomplete/{x}', autocomplete)
|
163 |
+
app.add_route('/autocomplete/{x}/{y}', autocomplete)
|
164 |
+
|
165 |
+
application = app
|
backend/GPUHandler.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import torch
|
3 |
+
from transformers import (GPT2LMHeadModel, GPT2Tokenizer, GPT2Config,
|
4 |
+
OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
|
5 |
+
XLNetLMHeadModel, XLNetTokenizer,
|
6 |
+
TransfoXLLMHeadModel, TransfoXLTokenizer,
|
7 |
+
CTRLLMHeadModel, CTRLTokenizer)
|
8 |
+
|
9 |
+
model_metadata = {
|
10 |
+
"gpt2/small": {
|
11 |
+
"tokenizer": GPT2Tokenizer,
|
12 |
+
"model": GPT2LMHeadModel,
|
13 |
+
"size": 550,
|
14 |
+
"checkpoint": "gpt2",
|
15 |
+
"identifier": "gpt2/small"
|
16 |
+
}, "gpt": {
|
17 |
+
"tokenizer": OpenAIGPTTokenizer,
|
18 |
+
"model": OpenAIGPTLMHeadModel,
|
19 |
+
"size": 550,
|
20 |
+
"checkpoint": "openai-gpt",
|
21 |
+
"identifier": "gpt"
|
22 |
+
}, "xlnet": {
|
23 |
+
"tokenizer": XLNetTokenizer,
|
24 |
+
"model": XLNetLMHeadModel,
|
25 |
+
"size": 550,
|
26 |
+
"checkpoint": "xlnet-base-cased",
|
27 |
+
"identifier": "xlnet"
|
28 |
+
}, "gpt2/arxiv-nlp": {
|
29 |
+
"tokenizer": GPT2Tokenizer,
|
30 |
+
"model": GPT2LMHeadModel,
|
31 |
+
"size": 550,
|
32 |
+
"checkpoint": "arxiv-nlp-v1",
|
33 |
+
"identifier": "gpt2/arxiv-nlp"
|
34 |
+
}, "gpt2/medium": {
|
35 |
+
"tokenizer": GPT2Tokenizer,
|
36 |
+
"model": GPT2LMHeadModel,
|
37 |
+
"size": 1500,
|
38 |
+
"checkpoint": "gpt2-medium",
|
39 |
+
"identifier": "gpt2/medium"
|
40 |
+
}, "gpt2/large": {
|
41 |
+
"tokenizer": GPT2Tokenizer,
|
42 |
+
"model": GPT2LMHeadModel,
|
43 |
+
"size": 3300,
|
44 |
+
"checkpoint": "gpt2-large",
|
45 |
+
"identifier": "gpt2/large"
|
46 |
+
}, "distilgpt2/small": {
|
47 |
+
"tokenizer": GPT2Tokenizer,
|
48 |
+
"model": GPT2LMHeadModel,
|
49 |
+
"size": 350,
|
50 |
+
"checkpoint": "distilgpt2",
|
51 |
+
"identifier": "distilgpt2/small"
|
52 |
+
}, "ctrl": {
|
53 |
+
"tokenizer": CTRLTokenizer,
|
54 |
+
"model": CTRLLMHeadModel,
|
55 |
+
"size": 6300,
|
56 |
+
"checkpoint": "ctrl",
|
57 |
+
"identifier": "ctrl"
|
58 |
+
}, "pplm": {
|
59 |
+
"tokenizer": GPT2Tokenizer,
|
60 |
+
"model": GPT2LMHeadModel,
|
61 |
+
"size": 3000,
|
62 |
+
"checkpoint": "gpt2-large",
|
63 |
+
"identifier": "pplm"
|
64 |
+
}, "gpt2/xl": {
|
65 |
+
"tokenizer": GPT2Tokenizer,
|
66 |
+
"model": GPT2LMHeadModel,
|
67 |
+
"size": 7000,
|
68 |
+
"checkpoint": "gpt2-xl",
|
69 |
+
"identifier": "gpt2/xl"
|
70 |
+
}, "pplm": {
|
71 |
+
"tokenizer": GPT2Tokenizer,
|
72 |
+
"model": GPT2LMHeadModel,
|
73 |
+
"size": 4000,
|
74 |
+
"checkpoint": "gpt2-medium",
|
75 |
+
"identifier": "pplm",
|
76 |
+
"configuration_options": {
|
77 |
+
"config": GPT2Config,
|
78 |
+
"options": {
|
79 |
+
"output_hidden_states": True
|
80 |
+
}
|
81 |
+
}
|
82 |
+
}
|
83 |
+
}
|
84 |
+
|
85 |
+
memory_overhead = 500
|
86 |
+
|
87 |
+
class GPU:
|
88 |
+
def __init__(self, id):
|
89 |
+
self.id = id
|
90 |
+
self.models = []
|
91 |
+
self.total_memory = torch.cuda.get_device_properties(
|
92 |
+
"cuda:{}".format(id)).total_memory / 1_000_000 - 1_000
|
93 |
+
|
94 |
+
print("INIT GPU WITH DEVICE", "cuda:{}".format(id))
|
95 |
+
|
96 |
+
def register_model(self, model, cached_path=None):
|
97 |
+
if self.total_memory_used() + model["size"] < self.total_memory:
|
98 |
+
model["device"] = "cuda:{}".format(self.id)
|
99 |
+
|
100 |
+
if cached_path:
|
101 |
+
model["cached_path"] = cached_path
|
102 |
+
|
103 |
+
self.models.append(model)
|
104 |
+
return True
|
105 |
+
else:
|
106 |
+
return False
|
107 |
+
|
108 |
+
def total_memory_used(self):
|
109 |
+
return sum([model["size"] for model in self.models]) + memory_overhead
|
110 |
+
|
111 |
+
def __repr__(self):
|
112 |
+
return str(
|
113 |
+
[(model["checkpoint"], model["size"]) for model in self.models] +
|
114 |
+
[str(round(100 * (self.total_memory_used() / self.total_memory))) + "%"] +
|
115 |
+
["cuda:{}".format(self.id)]
|
116 |
+
)
|
117 |
+
|
118 |
+
|
119 |
+
class GPUHandler:
|
120 |
+
def __init__(self, ids, model_list, gpu_ids, cached_models=None):
|
121 |
+
if cached_models is None:
|
122 |
+
cached_models = {}
|
123 |
+
|
124 |
+
self.gpus = [GPU(id) for id in gpu_ids]
|
125 |
+
print("GPU handler initiated with {} gpus.".format(len(self.gpus)))
|
126 |
+
|
127 |
+
self.sanity_check([model_metadata[model] for model in model_list])
|
128 |
+
|
129 |
+
for model in model_list:
|
130 |
+
self.register_model(model_metadata[model], cached_models.get(model))
|
131 |
+
|
132 |
+
def register_model(self, model, cached_path=None):
|
133 |
+
for index, gpu in enumerate(self.gpus):
|
134 |
+
if gpu.register_model(model, cached_path):
|
135 |
+
print("Registered model", model, "in GPU", gpu)
|
136 |
+
break
|
137 |
+
|
138 |
+
if index >= len(self.gpus):
|
139 |
+
raise ValueError("Could not load model", model["checkpoint"])
|
140 |
+
|
141 |
+
def sanity_check(self, model_list):
|
142 |
+
temp_gpus = [GPU(id) for id in range(len(self.gpus))]
|
143 |
+
|
144 |
+
for model in model_list:
|
145 |
+
|
146 |
+
current_gpu_index = 0
|
147 |
+
while current_gpu_index < len(temp_gpus):
|
148 |
+
if not temp_gpus[current_gpu_index].register_model(model):
|
149 |
+
current_gpu_index += 1
|
150 |
+
else:
|
151 |
+
break
|
152 |
+
|
153 |
+
if current_gpu_index >= len(temp_gpus):
|
154 |
+
raise RuntimeError("SANITY CHECK FAILED")
|
155 |
+
|
156 |
+
print("Current layout", temp_gpus)
|
157 |
+
|
158 |
+
def __repr__(self):
|
159 |
+
return f"NO. GPUS: {len(self.gpus)}.\n{self.gpus}"
|
backend/Model.py
ADDED
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from transformers import (GPT2LMHeadModel, GPT2Tokenizer,
|
3 |
+
OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
|
4 |
+
XLNetLMHeadModel, XLNetTokenizer,
|
5 |
+
TransfoXLLMHeadModel, TransfoXLTokenizer,
|
6 |
+
CTRLLMHeadModel, CTRLTokenizer)
|
7 |
+
|
8 |
+
from Utils import forward, create_context
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from math import floor
|
12 |
+
import requests
|
13 |
+
import json
|
14 |
+
import os
|
15 |
+
from PPLM import run_model as run_pplm, DISCRIMINATOR_MODELS_PARAMS
|
16 |
+
from GPUHandler import GPUHandler
|
17 |
+
|
18 |
+
PADDING_TEXT = """With eyes for the most part downcast and, if ever they lighted on a fellow creature, at once and
|
19 |
+
furtively averted, Bernard hastened across the roof. He was like a man pursued, but pursued by enemies he does not
|
20 |
+
wish to see, lest they should seem more hostile even than he had supposed, and he himself be made to feel guiltier
|
21 |
+
and even more helplessly alone. That horrible Benito Hoover!’ And yet the man had meant well enough. Which only made
|
22 |
+
it, in a way, much worse. Those who meant well behaved in the same way as those who meant badly. Even Lenina was making
|
23 |
+
him suffer. He remembered those weeks of timid indecision, during which he had looked and longed and despaired of ever
|
24 |
+
having the courage to ask her. Dared he face the risk of being humiliated by a contemptuous refusal? But if she were to
|
25 |
+
say yes, what rapture! Well, now she had said it and he was still wretched—wretched that she should have thought it
|
26 |
+
such a perfect afternoon for Obstacle Golf, that she should have trotted away to join Henry Foster, that she should
|
27 |
+
have found him funny for not wanting to talk of their most private affairs in public. Wretched, in a word, because she
|
28 |
+
had behaved as any healthy and virtuous English girl ought to behave and not in some other, abnormal, extraordinary
|
29 |
+
way. <eod> </s> <eos>"""
|
30 |
+
|
31 |
+
try:
|
32 |
+
PID = int(requests.get(url="http://localhost:3000").json())
|
33 |
+
N_GPU = torch.cuda.device_count()
|
34 |
+
GPU_PER_WORKER = int(os.getenv("GPU_PER_WORKER"))
|
35 |
+
GPU_IDS = list(range(PID * GPU_PER_WORKER, (PID + 1) * GPU_PER_WORKER))
|
36 |
+
print("Successfully init thread with id {}. The GPU ids attributed are: {}".format(PID, GPU_IDS))
|
37 |
+
|
38 |
+
with open(os.getenv("FILE")) as json_file:
|
39 |
+
data = json.load(json_file)
|
40 |
+
models = data["models_to_load"]
|
41 |
+
cached_models = data.get("cached_models")
|
42 |
+
except requests.exceptions.ConnectionError or TypeError:
|
43 |
+
if __name__ == "__main__":
|
44 |
+
PID = 0
|
45 |
+
N_GPU = torch.cuda.device_count()
|
46 |
+
GPU_PER_WORKER = 1
|
47 |
+
GPU_IDS = [0]
|
48 |
+
print("Successfully init development thread with id {}. The GPU ids attributed are: {}".format(PID, GPU_IDS))
|
49 |
+
models = ["pplm"]
|
50 |
+
cached_models = None
|
51 |
+
pass
|
52 |
+
else:
|
53 |
+
raise requests.exceptions.ConnectionError("The PID server is not running.")
|
54 |
+
|
55 |
+
|
56 |
+
handler = GPUHandler(int(), models, GPU_IDS, cached_models)
|
57 |
+
models = {}
|
58 |
+
|
59 |
+
for gpu in handler.gpus:
|
60 |
+
for model in gpu.models:
|
61 |
+
model_name = model["identifier"]
|
62 |
+
print(f"Loading {model_name} model and tokenizer")
|
63 |
+
models[model_name] = model
|
64 |
+
|
65 |
+
if model.get("cached_path"):
|
66 |
+
print("Loading {} from local path.".format(model_name))
|
67 |
+
model_checkpoint_path = model["cached_path"]
|
68 |
+
else:
|
69 |
+
model_checkpoint_path = model["checkpoint"]
|
70 |
+
|
71 |
+
if "configuration_options" in models[model_name]:
|
72 |
+
configuration_options = models[model_name]["configuration_options"]
|
73 |
+
print("Specific configuration options", configuration_options["options"])
|
74 |
+
|
75 |
+
config = configuration_options["config"].from_pretrained(model_checkpoint_path)
|
76 |
+
|
77 |
+
for option_key, option_value in configuration_options["options"].items():
|
78 |
+
setattr(config, option_key, option_value)
|
79 |
+
|
80 |
+
models[model_name]["model"] = models[model_name]["model"].from_pretrained(model_checkpoint_path, config=config).to(models[model_name]["device"])
|
81 |
+
else:
|
82 |
+
models[model_name]["model"] = models[model_name]["model"].from_pretrained(model_checkpoint_path).to(models[model_name]["device"])
|
83 |
+
|
84 |
+
models[model_name]["tokenizer"] = models[model_name]["tokenizer"].from_pretrained(models[model_name]["checkpoint"])
|
85 |
+
models[model_name]["model"].eval()
|
86 |
+
|
87 |
+
print("All models successfully loaded.")
|
88 |
+
|
89 |
+
|
90 |
+
def top_k_top_p_filtering(batch_logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
|
91 |
+
"""
|
92 |
+
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
93 |
+
|
94 |
+
:param batch_logits: logits output by the model
|
95 |
+
:param top_k: >0: keep only top k tokens with highest probability (top-k filtering).
|
96 |
+
:param top_p: >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
97 |
+
:param filter_value:
|
98 |
+
:return: A top_p/top_k filtered tensor of logits
|
99 |
+
"""
|
100 |
+
|
101 |
+
for i in range(batch_logits.size(0)):
|
102 |
+
logits = batch_logits[i]
|
103 |
+
assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
|
104 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
105 |
+
if top_k and top_k > 0:
|
106 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
107 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
108 |
+
logits[indices_to_remove] = filter_value
|
109 |
+
|
110 |
+
if top_p and top_p > 0.0:
|
111 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
112 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
113 |
+
|
114 |
+
# Remove tokens with cumulative probability above the threshold
|
115 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
116 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
117 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
118 |
+
sorted_indices_to_remove[..., 0] = 0
|
119 |
+
|
120 |
+
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
121 |
+
logits[indices_to_remove] = filter_value
|
122 |
+
|
123 |
+
if 'batched_logits' in locals():
|
124 |
+
batched_logits = torch.cat((batched_logits, logits.unsqueeze(0)), dim=0)
|
125 |
+
else:
|
126 |
+
batched_logits = logits.unsqueeze(0)
|
127 |
+
|
128 |
+
return batched_logits
|
129 |
+
|
130 |
+
|
131 |
+
def check_tensor_for_eot(output, eot_token, dot_token):
|
132 |
+
return all([(eot_token in output_item or dot_token in output_item) for output_item in output.tolist()])
|
133 |
+
|
134 |
+
|
135 |
+
def truncate_after_eot(output, eot_tokens):
|
136 |
+
result = []
|
137 |
+
for i in range(output.size(0)):
|
138 |
+
if any([eot_token in output[i] for eot_token in eot_tokens]):
|
139 |
+
item = output[i].tolist()
|
140 |
+
index = find_min_value_in_array(item, eot_tokens)
|
141 |
+
result.append(item[:index] + [eot_tokens[0]])
|
142 |
+
else:
|
143 |
+
result.append(output[i].tolist())
|
144 |
+
return result
|
145 |
+
|
146 |
+
|
147 |
+
def find_min_value_in_array(array, values):
|
148 |
+
indexes = []
|
149 |
+
for value in values:
|
150 |
+
try:
|
151 |
+
indexes.append(array.index(value))
|
152 |
+
except ValueError:
|
153 |
+
"" # Couldn't find value in array
|
154 |
+
|
155 |
+
return min(indexes)
|
156 |
+
|
157 |
+
|
158 |
+
# @lru_cache()
|
159 |
+
def generate_completion(
|
160 |
+
raw_text,
|
161 |
+
length=-1,
|
162 |
+
max_time=-1,
|
163 |
+
model_name="small",
|
164 |
+
temperature=1,
|
165 |
+
max_tokens=256,
|
166 |
+
top_p=0.0,
|
167 |
+
top_k=0,
|
168 |
+
batch_size=3,
|
169 |
+
repetition_penalty=1.2,
|
170 |
+
|
171 |
+
# PPLM
|
172 |
+
bag_of_words_or_discrim=None,
|
173 |
+
stepsize=0.02,
|
174 |
+
gamma=1.5,
|
175 |
+
num_iterations=3,
|
176 |
+
window_length=5,
|
177 |
+
kl_scale=0.01,
|
178 |
+
gm_scale=0.95,
|
179 |
+
use_sampling=False
|
180 |
+
):
|
181 |
+
start = time.time()
|
182 |
+
|
183 |
+
try:
|
184 |
+
print("Running with model", model_name)
|
185 |
+
model, tokenizer, device = models[model_name]["model"], models[model_name]["tokenizer"], models[model_name]["device"]
|
186 |
+
except KeyError:
|
187 |
+
print("Error. Defaulting to small model.")
|
188 |
+
model, tokenizer, device = models["gpt2/small"]["model"], models["gpt2/small"]["tokenizer"], models["gpt2/small"]["device"]
|
189 |
+
|
190 |
+
if "pplm" in model_name:
|
191 |
+
if ":" in bag_of_words_or_discrim:
|
192 |
+
discrim, discrim_label = bag_of_words_or_discrim.split(":")
|
193 |
+
discrim_label = DISCRIMINATOR_MODELS_PARAMS[discrim]["class_id"][int(discrim_label)]
|
194 |
+
bag_of_words = None
|
195 |
+
|
196 |
+
# Hardcoded parameters for the discriminator
|
197 |
+
gamma = 1.0
|
198 |
+
|
199 |
+
print("Running PPLM with discriminator:", discrim, discrim_label)
|
200 |
+
else:
|
201 |
+
bag_of_words = bag_of_words_or_discrim
|
202 |
+
discrim = None
|
203 |
+
discrim_label = None
|
204 |
+
|
205 |
+
# Hardcoded parameters for the BOW
|
206 |
+
gamma = 1.5
|
207 |
+
window_length = 5
|
208 |
+
|
209 |
+
print("Running PPLM with bag of words:", bag_of_words)
|
210 |
+
|
211 |
+
print("kl", kl_scale, "gm", gm_scale, "sampling", use_sampling, "window length", window_length, "gamma", gamma, "temperature", temperature)
|
212 |
+
|
213 |
+
return run_pplm(
|
214 |
+
model, tokenizer, device, raw_text,
|
215 |
+
max_time=max_time,
|
216 |
+
discrim=discrim,
|
217 |
+
discrim_label=discrim_label,
|
218 |
+
num_samples=batch_size,
|
219 |
+
bag_of_words=bag_of_words,
|
220 |
+
length=length,
|
221 |
+
temperature=temperature,
|
222 |
+
top_k=top_k,
|
223 |
+
stepsize=stepsize,
|
224 |
+
gamma=gamma,
|
225 |
+
num_iterations=num_iterations,
|
226 |
+
window_length=window_length,
|
227 |
+
kl_scale=kl_scale,
|
228 |
+
gm_scale=gm_scale,
|
229 |
+
use_sampling=use_sampling
|
230 |
+
)
|
231 |
+
|
232 |
+
|
233 |
+
context_tokens, eot_token, dot_token = create_context(model_name, tokenizer, raw_text, PADDING_TEXT, max_tokens=max_tokens)
|
234 |
+
|
235 |
+
if length == -1:
|
236 |
+
length = 100
|
237 |
+
|
238 |
+
context = torch.tensor(context_tokens, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
|
239 |
+
prev = context
|
240 |
+
past = None
|
241 |
+
|
242 |
+
with torch.no_grad():
|
243 |
+
for _ in range(length):
|
244 |
+
try:
|
245 |
+
output = forward(model_name, model, prev, past, device=device)
|
246 |
+
except RuntimeError:
|
247 |
+
return "ERROR 500: OOM. TransfoXL asked for too much memory."
|
248 |
+
|
249 |
+
logits, past = output if len(output) > 2 else output[0], None
|
250 |
+
|
251 |
+
logits = logits[:, -1, :] / max(temperature, 0.001)
|
252 |
+
|
253 |
+
if "ctrl" in model_name:
|
254 |
+
for i in range(batch_size):
|
255 |
+
for j in set(prev[i].tolist()):
|
256 |
+
logits[i, j] /= repetition_penalty
|
257 |
+
|
258 |
+
logits = top_k_top_p_filtering(logits, top_p=top_p, top_k=top_k)
|
259 |
+
log_probs = F.softmax(logits, dim=-1)
|
260 |
+
token = torch.multinomial(log_probs, num_samples=1)
|
261 |
+
|
262 |
+
prev = torch.cat((prev, token), dim=1)
|
263 |
+
|
264 |
+
# Check that there is no eot token in all of the sentence, else breaks.
|
265 |
+
if check_tensor_for_eot(prev[:, len(context_tokens):], eot_token, dot_token) or (max_time != -1 and time.time() - start + 0.1 > max_time):
|
266 |
+
break
|
267 |
+
|
268 |
+
out = prev[:, len(context_tokens):]
|
269 |
+
# Remove the words following the eot tokens.
|
270 |
+
out = truncate_after_eot(out, list(filter(lambda t: t is not None, [dot_token, eot_token])))
|
271 |
+
end = time.time()
|
272 |
+
|
273 |
+
# Remove empty sentences and duplicates
|
274 |
+
generations = list(set(filter(lambda x: len(x) > 0, [" " + tokenizer.decode(single_generation).strip() for single_generation in out])))
|
275 |
+
|
276 |
+
sentences = [
|
277 |
+
{"value": generations[i], "time": end - start, "tokens": len(out[i])} for i in range(len(generations))
|
278 |
+
]
|
279 |
+
|
280 |
+
|
281 |
+
# print(end - start, [len(out[i]) for i in range(len(generations))])
|
282 |
+
|
283 |
+
return sentences
|
284 |
+
|
285 |
+
|
286 |
+
if __name__ == "__main__":
|
287 |
+
print(generate_completion(
|
288 |
+
"My dog died",
|
289 |
+
length=30, model_name="pplm", batch_size=3, top_k=10, top_p=0.9,
|
290 |
+
bag_of_words_or_discrim="sentiment:2",
|
291 |
+
stepsize=0.03,
|
292 |
+
gamma=1,
|
293 |
+
num_iterations=3,
|
294 |
+
window_length=5,
|
295 |
+
kl_scale=0.01,
|
296 |
+
gm_scale=0.95,
|
297 |
+
max_time=-1,
|
298 |
+
use_sampling=False
|
299 |
+
))
|
backend/PPLM.py
ADDED
@@ -0,0 +1,723 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2018 The Uber AI Team Authors.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
"""
|
18 |
+
Example command with bag of words:
|
19 |
+
python examples/run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95
|
20 |
+
|
21 |
+
Example command with discriminator:
|
22 |
+
python examples/run_pplm.py -D sentiment --class_label 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
|
23 |
+
"""
|
24 |
+
|
25 |
+
import json
|
26 |
+
from operator import add
|
27 |
+
from typing import List, Optional, Tuple, Union
|
28 |
+
|
29 |
+
import numpy as np
|
30 |
+
import torch
|
31 |
+
import torch.nn.functional as F
|
32 |
+
from torch.autograd import Variable
|
33 |
+
from tqdm import trange
|
34 |
+
from transformers.file_utils import cached_path
|
35 |
+
import time
|
36 |
+
|
37 |
+
from run_pplm_discrim_train import ClassificationHead
|
38 |
+
|
39 |
+
PPLM_BOW = 1
|
40 |
+
PPLM_DISCRIM = 2
|
41 |
+
PPLM_BOW_DISCRIM = 3
|
42 |
+
SMALL_CONST = 1e-15
|
43 |
+
BIG_CONST = 1e10
|
44 |
+
|
45 |
+
BAG_OF_WORDS_ARCHIVE_MAP = {
|
46 |
+
'kitchen': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/kitchen.txt",
|
47 |
+
'legal': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/legal.txt",
|
48 |
+
'military': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/military.txt",
|
49 |
+
'monsters': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/monsters.txt",
|
50 |
+
'politics': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/politics.txt",
|
51 |
+
'positive_words': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/positive_words.txt",
|
52 |
+
'religion': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/religion.txt",
|
53 |
+
'science': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/science.txt",
|
54 |
+
'space': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/space.txt",
|
55 |
+
'technology': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/technology.txt",
|
56 |
+
}
|
57 |
+
|
58 |
+
DISCRIMINATOR_MODELS_PARAMS = {
|
59 |
+
"clickbait": {
|
60 |
+
"url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/clickbait_classifierhead.pt",
|
61 |
+
"class_size": 2,
|
62 |
+
"embed_size": 1024,
|
63 |
+
"class_vocab": {"non_clickbait": 0, "clickbait": 1},
|
64 |
+
"class_id": {0: "non_clickbait", 1: "clickbait"},
|
65 |
+
"default_class": 1,
|
66 |
+
"pretrained_model": "gpt2-medium",
|
67 |
+
},
|
68 |
+
"sentiment": {
|
69 |
+
"url": "http://s.yosinski.com/SST_classifier_head.pt",
|
70 |
+
"class_size": 5,
|
71 |
+
"embed_size": 1024,
|
72 |
+
"class_vocab": {"very_positive": 2, "very_negative": 3},
|
73 |
+
"class_id": {2: "very_positive", 3: "very_negative"},
|
74 |
+
"default_class": 3,
|
75 |
+
"pretrained_model": "gpt2-medium",
|
76 |
+
},
|
77 |
+
"toxicity": {
|
78 |
+
"url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/toxicity_classifierhead.pt",
|
79 |
+
"class_size": 2,
|
80 |
+
"embed_size": 1024,
|
81 |
+
"class_vocab": {"non_toxic": 0, "toxic": 1},
|
82 |
+
"class_id": {0: "non_toxic", 1: "toxic"},
|
83 |
+
"default_class": 0,
|
84 |
+
"pretrained_model": "gpt2-medium",
|
85 |
+
},
|
86 |
+
}
|
87 |
+
|
88 |
+
|
89 |
+
def to_var(x, requires_grad=False, volatile=False, device='cuda'):
|
90 |
+
if torch.cuda.is_available() and device == 'cuda':
|
91 |
+
x = x.cuda()
|
92 |
+
elif device != 'cuda':
|
93 |
+
x = x.to(device)
|
94 |
+
return Variable(x, requires_grad=requires_grad, volatile=volatile)
|
95 |
+
|
96 |
+
|
97 |
+
def top_k_filter(logits, k, probs=False):
|
98 |
+
"""
|
99 |
+
Masks everything but the k top entries as -infinity (1e10).
|
100 |
+
Used to mask logits such that e^-infinity -> 0 won't contribute to the
|
101 |
+
sum of the denominator.
|
102 |
+
"""
|
103 |
+
if k == 0:
|
104 |
+
return logits
|
105 |
+
else:
|
106 |
+
values = torch.topk(logits, k)[0]
|
107 |
+
batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
|
108 |
+
if probs:
|
109 |
+
return torch.where(logits < batch_mins,
|
110 |
+
torch.ones_like(logits) * 0.0, logits)
|
111 |
+
return torch.where(logits < batch_mins,
|
112 |
+
torch.ones_like(logits) * -BIG_CONST,
|
113 |
+
logits)
|
114 |
+
|
115 |
+
|
116 |
+
def perturb_past(
|
117 |
+
past,
|
118 |
+
model,
|
119 |
+
last,
|
120 |
+
unpert_past=None,
|
121 |
+
unpert_logits=None,
|
122 |
+
accumulated_hidden=None,
|
123 |
+
grad_norms=None,
|
124 |
+
stepsize=0.01,
|
125 |
+
one_hot_bows_vectors=None,
|
126 |
+
classifier=None,
|
127 |
+
class_label=None,
|
128 |
+
loss_type=0,
|
129 |
+
num_iterations=3,
|
130 |
+
horizon_length=1,
|
131 |
+
window_length=0,
|
132 |
+
decay=False,
|
133 |
+
gamma=1.5,
|
134 |
+
kl_scale=0.01,
|
135 |
+
device='cuda',
|
136 |
+
):
|
137 |
+
# Generate inital perturbed past
|
138 |
+
grad_accumulator = [
|
139 |
+
(np.zeros(p.shape).astype("float32"))
|
140 |
+
for p in past
|
141 |
+
]
|
142 |
+
|
143 |
+
if accumulated_hidden is None:
|
144 |
+
accumulated_hidden = 0
|
145 |
+
|
146 |
+
if decay:
|
147 |
+
decay_mask = torch.arange(
|
148 |
+
0.,
|
149 |
+
1.0 + SMALL_CONST,
|
150 |
+
1.0 / (window_length)
|
151 |
+
)[1:]
|
152 |
+
else:
|
153 |
+
decay_mask = 1.0
|
154 |
+
|
155 |
+
# TODO fix this comment (SUMANTH)
|
156 |
+
# Generate a mask is gradient perturbated is based on a past window
|
157 |
+
_, batch_size, _, curr_length, _ = past[0].shape
|
158 |
+
|
159 |
+
if curr_length > window_length and window_length > 0:
|
160 |
+
ones_key_val_shape = (
|
161 |
+
tuple(past[0].shape[:-2])
|
162 |
+
+ tuple([window_length])
|
163 |
+
+ tuple(past[0].shape[-1:])
|
164 |
+
)
|
165 |
+
|
166 |
+
zeros_key_val_shape = (
|
167 |
+
tuple(past[0].shape[:-2])
|
168 |
+
+ tuple([curr_length - window_length])
|
169 |
+
+ tuple(past[0].shape[-1:])
|
170 |
+
)
|
171 |
+
|
172 |
+
ones_mask = torch.ones(ones_key_val_shape)
|
173 |
+
ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
|
174 |
+
ones_mask = ones_mask.permute(0, 1, 2, 4, 3)
|
175 |
+
|
176 |
+
window_mask = torch.cat(
|
177 |
+
(ones_mask, torch.zeros(zeros_key_val_shape)),
|
178 |
+
dim=-2
|
179 |
+
).to(device)
|
180 |
+
else:
|
181 |
+
window_mask = torch.ones_like(past[0]).to(device)
|
182 |
+
|
183 |
+
# accumulate perturbations for num_iterations
|
184 |
+
loss_per_iter = []
|
185 |
+
losses_per_iter = []
|
186 |
+
new_accumulated_hidden = None
|
187 |
+
for i in range(num_iterations):
|
188 |
+
curr_perturbation = [
|
189 |
+
to_var(torch.from_numpy(p_), requires_grad=True, device=device)
|
190 |
+
for p_ in grad_accumulator
|
191 |
+
]
|
192 |
+
|
193 |
+
# Compute hidden using perturbed past
|
194 |
+
perturbed_past = list(map(add, past, curr_perturbation))
|
195 |
+
_, _, _, curr_length, _ = curr_perturbation[0].shape
|
196 |
+
all_logits, _, all_hidden = model(last, past=perturbed_past)
|
197 |
+
hidden = all_hidden[-1]
|
198 |
+
new_accumulated_hidden = accumulated_hidden + torch.sum(
|
199 |
+
hidden,
|
200 |
+
dim=1
|
201 |
+
).detach()
|
202 |
+
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
|
203 |
+
logits = all_logits[:, -1, :]
|
204 |
+
probs = F.softmax(logits, dim=-1)
|
205 |
+
|
206 |
+
loss = 0.0
|
207 |
+
losses = torch.zeros(batch_size, device=device)
|
208 |
+
loss_list = []
|
209 |
+
if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM:
|
210 |
+
for one_hot_bow in one_hot_bows_vectors:
|
211 |
+
bow_logits = torch.mm(probs, torch.t(one_hot_bow))
|
212 |
+
bow_losses = -torch.log(torch.sum(bow_logits, dim=-1))
|
213 |
+
losses += bow_losses
|
214 |
+
bow_loss = torch.sum(bow_losses) # sum over batches
|
215 |
+
loss += bow_loss
|
216 |
+
loss_list.append(bow_loss)
|
217 |
+
|
218 |
+
if loss_type == 2 or loss_type == 3:
|
219 |
+
ce_loss = torch.nn.CrossEntropyLoss(reduction='none')
|
220 |
+
# TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
|
221 |
+
curr_unpert_past = unpert_past
|
222 |
+
curr_probs = torch.unsqueeze(probs, dim=1)
|
223 |
+
wte = model.resize_token_embeddings()
|
224 |
+
for _ in range(horizon_length):
|
225 |
+
inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
|
226 |
+
_, curr_unpert_past, curr_all_hidden = model(
|
227 |
+
past=curr_unpert_past,
|
228 |
+
inputs_embeds=inputs_embeds
|
229 |
+
)
|
230 |
+
curr_hidden = curr_all_hidden[-1]
|
231 |
+
new_accumulated_hidden = new_accumulated_hidden + torch.sum(
|
232 |
+
curr_hidden, dim=1)
|
233 |
+
|
234 |
+
prediction = classifier(new_accumulated_hidden /
|
235 |
+
(curr_length + 1 + horizon_length))
|
236 |
+
|
237 |
+
label = torch.tensor(batch_size * [class_label],
|
238 |
+
device=device,
|
239 |
+
dtype=torch.long)
|
240 |
+
discrim_losses = ce_loss(prediction, label)
|
241 |
+
losses += discrim_losses
|
242 |
+
discrim_loss = discrim_losses.sum(-1)
|
243 |
+
loss += discrim_loss
|
244 |
+
loss_list.append(discrim_loss)
|
245 |
+
|
246 |
+
kl_loss = 0.0
|
247 |
+
if kl_scale > 0.0:
|
248 |
+
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
249 |
+
unpert_probs = (
|
250 |
+
unpert_probs + SMALL_CONST *
|
251 |
+
(unpert_probs <= SMALL_CONST).float().to(device).detach()
|
252 |
+
)
|
253 |
+
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(
|
254 |
+
device).detach()
|
255 |
+
corrected_probs = probs + correction.detach()
|
256 |
+
kl_losses = kl_scale * (
|
257 |
+
(corrected_probs * (corrected_probs / unpert_probs).log()).sum(-1)
|
258 |
+
)
|
259 |
+
losses += kl_losses
|
260 |
+
kl_loss = kl_losses.sum()
|
261 |
+
loss += kl_loss
|
262 |
+
|
263 |
+
loss_per_iter.append(loss.data.cpu().numpy())
|
264 |
+
losses_per_iter.append(losses.data.cpu().numpy())
|
265 |
+
|
266 |
+
# compute gradients
|
267 |
+
loss.backward()
|
268 |
+
|
269 |
+
# calculate gradient norms
|
270 |
+
if grad_norms is not None and loss_type == PPLM_BOW:
|
271 |
+
grad_norms = [
|
272 |
+
torch.max(grad_norms[index],
|
273 |
+
torch.norm_except_dim(p_.grad * window_mask, dim=1))
|
274 |
+
#torch.norm(p_.grad * window_mask))
|
275 |
+
for index, p_ in enumerate(curr_perturbation)
|
276 |
+
]
|
277 |
+
else:
|
278 |
+
grad_norms = [
|
279 |
+
(torch.norm_except_dim(p_.grad * window_mask, dim=1) + SMALL_CONST)
|
280 |
+
for index, p_ in enumerate(curr_perturbation)
|
281 |
+
]
|
282 |
+
|
283 |
+
# normalize gradients
|
284 |
+
grad = [
|
285 |
+
-stepsize *
|
286 |
+
(p_.grad * window_mask / grad_norms[
|
287 |
+
index] ** gamma).data.cpu().numpy()
|
288 |
+
for index, p_ in enumerate(curr_perturbation)
|
289 |
+
]
|
290 |
+
|
291 |
+
# accumulate gradient
|
292 |
+
grad_accumulator = list(map(add, grad, grad_accumulator))
|
293 |
+
|
294 |
+
# reset gradients, just to make sure
|
295 |
+
for p_ in curr_perturbation:
|
296 |
+
p_.grad.data.zero_()
|
297 |
+
|
298 |
+
# removing past from the graph
|
299 |
+
new_past = []
|
300 |
+
for p_ in past:
|
301 |
+
new_past.append(p_.detach())
|
302 |
+
past = new_past
|
303 |
+
|
304 |
+
# apply the accumulated perturbations to the past
|
305 |
+
grad_accumulator = [
|
306 |
+
to_var(torch.from_numpy(p_), requires_grad=True, device=device)
|
307 |
+
for p_ in grad_accumulator
|
308 |
+
]
|
309 |
+
pert_past = list(map(add, past, grad_accumulator))
|
310 |
+
|
311 |
+
return pert_past, new_accumulated_hidden, grad_norms, losses_per_iter
|
312 |
+
|
313 |
+
|
314 |
+
def get_classifier(
|
315 |
+
name: Optional[str], class_label: Union[str, int],
|
316 |
+
device: str
|
317 |
+
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
|
318 |
+
if name is None:
|
319 |
+
return None, None
|
320 |
+
|
321 |
+
params = DISCRIMINATOR_MODELS_PARAMS[name]
|
322 |
+
classifier = ClassificationHead(
|
323 |
+
class_size=params['class_size'],
|
324 |
+
embed_size=params['embed_size']
|
325 |
+
).to(device)
|
326 |
+
if "url" in params:
|
327 |
+
resolved_archive_file = cached_path(params["url"])
|
328 |
+
elif "path" in params:
|
329 |
+
resolved_archive_file = params["path"]
|
330 |
+
else:
|
331 |
+
raise ValueError("Either url or path have to be specified "
|
332 |
+
"in the discriminator model parameters")
|
333 |
+
classifier.load_state_dict(
|
334 |
+
torch.load(resolved_archive_file, map_location=device))
|
335 |
+
classifier.eval()
|
336 |
+
|
337 |
+
if isinstance(class_label, str):
|
338 |
+
if class_label in params["class_vocab"]:
|
339 |
+
label_id = params["class_vocab"][class_label]
|
340 |
+
else:
|
341 |
+
label_id = params["default_class"]
|
342 |
+
|
343 |
+
|
344 |
+
elif isinstance(class_label, int):
|
345 |
+
if class_label in set(params["class_vocab"].values()):
|
346 |
+
label_id = class_label
|
347 |
+
else:
|
348 |
+
label_id = params["default_class"]
|
349 |
+
|
350 |
+
else:
|
351 |
+
label_id = params["default_class"]
|
352 |
+
|
353 |
+
return classifier, label_id
|
354 |
+
|
355 |
+
|
356 |
+
def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> \
|
357 |
+
List[List[List[int]]]:
|
358 |
+
bow_indices = []
|
359 |
+
for id_or_path in bag_of_words_ids_or_paths:
|
360 |
+
if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
|
361 |
+
filepath = cached_path(BAG_OF_WORDS_ARCHIVE_MAP[id_or_path])
|
362 |
+
else:
|
363 |
+
filepath = id_or_path
|
364 |
+
with open(filepath, "r") as f:
|
365 |
+
words = f.read().strip().split("\n")
|
366 |
+
bow_indices.append(
|
367 |
+
[tokenizer.encode(word.strip(), add_prefix_space=True,
|
368 |
+
add_special_tokens=False) for word in
|
369 |
+
words])
|
370 |
+
return bow_indices
|
371 |
+
|
372 |
+
|
373 |
+
def build_bows_one_hot_vectors(bow_indices, tokenizer, device='cuda'):
|
374 |
+
if bow_indices is None:
|
375 |
+
return None
|
376 |
+
|
377 |
+
one_hot_bows_vectors = []
|
378 |
+
for single_bow in bow_indices:
|
379 |
+
single_bow = list(filter(lambda x: len(x) <= 1, single_bow))
|
380 |
+
single_bow = torch.tensor(single_bow).to(device)
|
381 |
+
num_words = single_bow.shape[0]
|
382 |
+
one_hot_bow = torch.zeros(num_words, tokenizer.vocab_size).to(device)
|
383 |
+
one_hot_bow.scatter_(1, single_bow, 1)
|
384 |
+
one_hot_bows_vectors.append(one_hot_bow)
|
385 |
+
return one_hot_bows_vectors
|
386 |
+
|
387 |
+
|
388 |
+
def full_text_generation(
|
389 |
+
model,
|
390 |
+
tokenizer,
|
391 |
+
context=None,
|
392 |
+
num_samples=1,
|
393 |
+
device="cuda",
|
394 |
+
max_time=5,
|
395 |
+
sample=False,
|
396 |
+
discrim=None,
|
397 |
+
class_label=None,
|
398 |
+
bag_of_words=None,
|
399 |
+
length=100,
|
400 |
+
grad_length=10000,
|
401 |
+
stepsize=0.02,
|
402 |
+
num_iterations=3,
|
403 |
+
temperature=1.0,
|
404 |
+
gm_scale=0.9,
|
405 |
+
kl_scale=0.01,
|
406 |
+
top_k=10,
|
407 |
+
window_length=0,
|
408 |
+
horizon_length=1,
|
409 |
+
decay=False,
|
410 |
+
gamma=1.5,
|
411 |
+
):
|
412 |
+
classifier, class_id = get_classifier(
|
413 |
+
discrim,
|
414 |
+
class_label,
|
415 |
+
device
|
416 |
+
)
|
417 |
+
|
418 |
+
bow_indices = []
|
419 |
+
if bag_of_words:
|
420 |
+
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
|
421 |
+
tokenizer)
|
422 |
+
|
423 |
+
if bag_of_words and classifier:
|
424 |
+
loss_type = PPLM_BOW_DISCRIM
|
425 |
+
|
426 |
+
elif bag_of_words:
|
427 |
+
loss_type = PPLM_BOW
|
428 |
+
|
429 |
+
elif classifier is not None:
|
430 |
+
loss_type = PPLM_DISCRIM
|
431 |
+
|
432 |
+
else:
|
433 |
+
raise Exception("Specify either a bag of words or a discriminator")
|
434 |
+
|
435 |
+
# unpert_gen_tok_text = generate_text_pplm(
|
436 |
+
# model=model,
|
437 |
+
# tokenizer=tokenizer,
|
438 |
+
# context=context,
|
439 |
+
# device=device,
|
440 |
+
# length=length,
|
441 |
+
# perturb=False
|
442 |
+
# )
|
443 |
+
# if device == 'cuda':
|
444 |
+
# torch.cuda.empty_cache()
|
445 |
+
|
446 |
+
print(context, bow_indices, top_k, gm_scale, kl_scale)
|
447 |
+
|
448 |
+
pert_gen_tok_text, last_losses = generate_text_pplm(
|
449 |
+
model=model,
|
450 |
+
context=context,
|
451 |
+
tokenizer=tokenizer,
|
452 |
+
device=device,
|
453 |
+
max_time=max_time,
|
454 |
+
sample=sample,
|
455 |
+
perturb=True,
|
456 |
+
bow_indices=bow_indices,
|
457 |
+
classifier=classifier,
|
458 |
+
class_label=class_id,
|
459 |
+
loss_type=loss_type,
|
460 |
+
length=length,
|
461 |
+
grad_length=grad_length,
|
462 |
+
stepsize=stepsize,
|
463 |
+
num_iterations=num_iterations,
|
464 |
+
temperature=temperature,
|
465 |
+
gm_scale=gm_scale,
|
466 |
+
kl_scale=kl_scale,
|
467 |
+
top_k=top_k,
|
468 |
+
window_length=window_length,
|
469 |
+
horizon_length=horizon_length,
|
470 |
+
decay=decay,
|
471 |
+
gamma=gamma,
|
472 |
+
)
|
473 |
+
|
474 |
+
if device == 'cuda':
|
475 |
+
torch.cuda.empty_cache()
|
476 |
+
|
477 |
+
return pert_gen_tok_text, last_losses
|
478 |
+
|
479 |
+
|
480 |
+
def generate_text_pplm(
|
481 |
+
model,
|
482 |
+
tokenizer,
|
483 |
+
context=None,
|
484 |
+
past=None,
|
485 |
+
device="cuda",
|
486 |
+
max_time=5,
|
487 |
+
perturb=True,
|
488 |
+
bow_indices=None,
|
489 |
+
classifier=None,
|
490 |
+
class_label=None,
|
491 |
+
loss_type=0,
|
492 |
+
length=100,
|
493 |
+
stepsize=0.02,
|
494 |
+
temperature=1.0,
|
495 |
+
top_k=10,
|
496 |
+
sample=False,
|
497 |
+
num_iterations=3,
|
498 |
+
grad_length=10000,
|
499 |
+
horizon_length=1,
|
500 |
+
window_length=0,
|
501 |
+
decay=False,
|
502 |
+
gamma=1.5,
|
503 |
+
gm_scale=0.9,
|
504 |
+
kl_scale=0.01,
|
505 |
+
):
|
506 |
+
output_so_far = None
|
507 |
+
if context:
|
508 |
+
context_t = torch.tensor(context, device=device, dtype=torch.long)
|
509 |
+
while len(context_t.shape) < 2:
|
510 |
+
context_t = context_t.unsqueeze(0)
|
511 |
+
output_so_far = context_t
|
512 |
+
|
513 |
+
# collect one hot vectors for bags of words
|
514 |
+
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer,
|
515 |
+
device)
|
516 |
+
|
517 |
+
start = time.time()
|
518 |
+
|
519 |
+
grad_norms = None
|
520 |
+
last = None
|
521 |
+
losses_this_iter = None
|
522 |
+
losses_in_time = []
|
523 |
+
for i in trange(length, ascii=True):
|
524 |
+
|
525 |
+
# Get past/probs for current output, except for last word
|
526 |
+
# Note that GPT takes 2 inputs: past + current_token
|
527 |
+
|
528 |
+
# run model forward to obtain unperturbed
|
529 |
+
if past is None and output_so_far is not None:
|
530 |
+
last = output_so_far[:, -1:]
|
531 |
+
if output_so_far.shape[1] > 1:
|
532 |
+
_, past, _ = model(output_so_far[:, :-1])
|
533 |
+
|
534 |
+
unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
|
535 |
+
unpert_last_hidden = unpert_all_hidden[-1]
|
536 |
+
|
537 |
+
# check if we are abowe grad max length
|
538 |
+
if i >= grad_length:
|
539 |
+
current_stepsize = stepsize * 0
|
540 |
+
else:
|
541 |
+
current_stepsize = stepsize
|
542 |
+
|
543 |
+
# modify the past if necessary
|
544 |
+
if not perturb or num_iterations == 0:
|
545 |
+
pert_past = past
|
546 |
+
|
547 |
+
else:
|
548 |
+
accumulated_hidden = unpert_last_hidden[:, :-1, :]
|
549 |
+
accumulated_hidden = torch.sum(accumulated_hidden, dim=1)
|
550 |
+
|
551 |
+
if past is not None:
|
552 |
+
pert_past, _, grad_norms, losses_this_iter = perturb_past(
|
553 |
+
past,
|
554 |
+
model,
|
555 |
+
last,
|
556 |
+
unpert_past=unpert_past,
|
557 |
+
unpert_logits=unpert_logits,
|
558 |
+
accumulated_hidden=accumulated_hidden,
|
559 |
+
grad_norms=grad_norms,
|
560 |
+
stepsize=current_stepsize,
|
561 |
+
one_hot_bows_vectors=one_hot_bows_vectors,
|
562 |
+
classifier=classifier,
|
563 |
+
class_label=class_label,
|
564 |
+
loss_type=loss_type,
|
565 |
+
num_iterations=num_iterations,
|
566 |
+
horizon_length=horizon_length,
|
567 |
+
window_length=window_length,
|
568 |
+
decay=decay,
|
569 |
+
gamma=gamma,
|
570 |
+
kl_scale=kl_scale,
|
571 |
+
device=device,
|
572 |
+
)
|
573 |
+
losses_in_time.append(losses_this_iter)
|
574 |
+
else:
|
575 |
+
pert_past = past
|
576 |
+
|
577 |
+
pert_logits, past, pert_all_hidden = model(last, past=pert_past)
|
578 |
+
pert_logits = pert_logits[:, -1, :] / temperature # + SMALL_CONST
|
579 |
+
pert_probs = F.softmax(pert_logits, dim=-1)
|
580 |
+
|
581 |
+
# Fuse the modified model and original model
|
582 |
+
if perturb:
|
583 |
+
|
584 |
+
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
585 |
+
|
586 |
+
pert_probs = ((pert_probs ** gm_scale) * (
|
587 |
+
unpert_probs ** (1 - gm_scale))) # + SMALL_CONST
|
588 |
+
pert_probs = top_k_filter(pert_probs, k=top_k,
|
589 |
+
probs=True) # + SMALL_CONST
|
590 |
+
|
591 |
+
# rescale
|
592 |
+
if torch.sum(pert_probs) <= 1:
|
593 |
+
pert_probs = pert_probs / torch.sum(pert_probs)
|
594 |
+
|
595 |
+
else:
|
596 |
+
pert_logits = top_k_filter(pert_logits, k=top_k) # + SMALL_CONST
|
597 |
+
pert_probs = F.softmax(pert_logits, dim=-1)
|
598 |
+
|
599 |
+
# sample or greedy
|
600 |
+
if sample:
|
601 |
+
last = torch.multinomial(pert_probs, num_samples=1)
|
602 |
+
|
603 |
+
else:
|
604 |
+
_, last = torch.topk(pert_probs, k=1, dim=-1)
|
605 |
+
|
606 |
+
# update context/output_so_far appending the new token
|
607 |
+
output_so_far = (
|
608 |
+
last if output_so_far is None
|
609 |
+
else torch.cat((output_so_far, last), dim=1)
|
610 |
+
)
|
611 |
+
|
612 |
+
if time.time() - start > max_time and max_time != -1:
|
613 |
+
break
|
614 |
+
|
615 |
+
final_losses = losses_this_iter[-1] if losses_this_iter else None
|
616 |
+
return output_so_far, final_losses
|
617 |
+
|
618 |
+
|
619 |
+
def set_generic_model_params(discrim_weights, discrim_meta):
|
620 |
+
if discrim_weights is None:
|
621 |
+
raise ValueError('When using a generic discriminator, '
|
622 |
+
'discrim_weights need to be specified')
|
623 |
+
if discrim_meta is None:
|
624 |
+
raise ValueError('When using a generic discriminator, '
|
625 |
+
'discrim_meta need to be specified')
|
626 |
+
|
627 |
+
with open(discrim_meta, 'r') as discrim_meta_file:
|
628 |
+
meta = json.load(discrim_meta_file)
|
629 |
+
meta['path'] = discrim_weights
|
630 |
+
DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
|
631 |
+
|
632 |
+
|
633 |
+
def run_model(
|
634 |
+
model,
|
635 |
+
tokenizer,
|
636 |
+
device,
|
637 |
+
raw_text,
|
638 |
+
max_time,
|
639 |
+
bag_of_words=None,
|
640 |
+
discrim=None,
|
641 |
+
discrim_weights=None,
|
642 |
+
discrim_meta=None,
|
643 |
+
discrim_label=-1,
|
644 |
+
stepsize=0.02,
|
645 |
+
length=10,
|
646 |
+
seed=None,
|
647 |
+
temperature=1.0,
|
648 |
+
top_k=10,
|
649 |
+
gm_scale=0.9,
|
650 |
+
kl_scale=0.01,
|
651 |
+
uncond=False,
|
652 |
+
num_iterations=3,
|
653 |
+
grad_length=10000,
|
654 |
+
num_samples=1,
|
655 |
+
horizon_length=1,
|
656 |
+
window_length=0,
|
657 |
+
decay=False,
|
658 |
+
gamma=1.5,
|
659 |
+
use_sampling=False
|
660 |
+
):
|
661 |
+
print(seed)
|
662 |
+
if seed is not None:
|
663 |
+
# set Random seed
|
664 |
+
torch.manual_seed(seed)
|
665 |
+
np.random.seed(seed)
|
666 |
+
|
667 |
+
if discrim == 'generic':
|
668 |
+
set_generic_model_params(discrim_weights, discrim_meta)
|
669 |
+
|
670 |
+
tokenized_cond_text = [tokenizer.encode(
|
671 |
+
tokenizer.bos_token + raw_text, max_length=512 - length - 1)] * num_samples
|
672 |
+
|
673 |
+
# Freeze GPT-2 weights
|
674 |
+
for param in model.parameters():
|
675 |
+
param.requires_grad = False
|
676 |
+
|
677 |
+
# generate unperturbed and perturbed texts
|
678 |
+
|
679 |
+
# full_text_generation returns:
|
680 |
+
# unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
|
681 |
+
|
682 |
+
pert_gen_tok_text, last_losses = full_text_generation(
|
683 |
+
model=model,
|
684 |
+
tokenizer=tokenizer,
|
685 |
+
context=tokenized_cond_text,
|
686 |
+
device=device,
|
687 |
+
max_time=max_time,
|
688 |
+
num_samples=num_samples,
|
689 |
+
discrim=discrim,
|
690 |
+
class_label=discrim_label,
|
691 |
+
bag_of_words=bag_of_words,
|
692 |
+
length=length,
|
693 |
+
grad_length=grad_length,
|
694 |
+
stepsize=stepsize,
|
695 |
+
num_iterations=num_iterations,
|
696 |
+
temperature=temperature,
|
697 |
+
gm_scale=gm_scale,
|
698 |
+
kl_scale=kl_scale,
|
699 |
+
top_k=top_k,
|
700 |
+
window_length=window_length,
|
701 |
+
horizon_length=horizon_length,
|
702 |
+
decay=decay,
|
703 |
+
gamma=gamma,
|
704 |
+
sample=use_sampling
|
705 |
+
)
|
706 |
+
|
707 |
+
generated_texts = []
|
708 |
+
|
709 |
+
# iterate through the perturbed texts
|
710 |
+
for sample, loss in zip(pert_gen_tok_text.tolist(), last_losses.tolist()):
|
711 |
+
generated_part = sample[len(tokenized_cond_text[0]):]
|
712 |
+
pert_gen_text = tokenizer.decode(generated_part)
|
713 |
+
|
714 |
+
# keep the prefix, perturbed seq, original seq for each index
|
715 |
+
generated_texts.append(
|
716 |
+
{
|
717 |
+
"value": pert_gen_text,
|
718 |
+
"tokens": len(generated_part),
|
719 |
+
"loss": loss
|
720 |
+
}
|
721 |
+
)
|
722 |
+
|
723 |
+
return generated_texts
|
backend/README.md
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Python backend
|
2 |
+
|
3 |
+
## Setup
|
4 |
+
|
5 |
+
```
|
6 |
+
pip install -r requirements.txt
|
7 |
+
chmod +x launch.sh
|
8 |
+
```
|
9 |
+
|
10 |
+
## Execution
|
11 |
+
|
12 |
+
|
13 |
+
`./launch.sh`
|
14 |
+
|
15 |
+
## Usage
|
16 |
+
|
17 |
+
The API listens to the port `6006` and the route `autocomplete`. It listens to `POST` requests.
|
18 |
+
Query it like this: `{POST}http://<url>:6006/autocomplete`
|
19 |
+
|
20 |
+
The necessary argument is `context` which is a string of characters (ideally a sentence) which will be converted in tokens and fed to GPT-2.
|
21 |
+
|
22 |
+
The optional arguments are detailed below:
|
23 |
+
|
24 |
+
`length` is an unsigned int which sets the maximum length (in tokens) of the generated sentence __default: 100__
|
25 |
+
|
26 |
+
`n_samples` is an int `0 < n_samples <= 3` which sets the maximum amount of samples generated. __default: 3__
|
27 |
+
|
28 |
+
`max_time` is an unsigned float which sets an heuristic for the maximum time spent generating sentences. It is a heuristic because it is not exact, it can slightly overflow. __default: infinite__
|
29 |
+
|
30 |
+
`model_size` takes `"small"` or `"medium"` as input and corresponds to the GPT model size __default: small__
|
31 |
+
|
32 |
+
`temperature` float - temperature of the model __default: 1__
|
33 |
+
|
34 |
+
`max_tokens` int - maximum amount of tokens that will be fed into the model. __default: 256__
|
35 |
+
|
36 |
+
`top_p` float - 0 < top_p < 1, nucleus sampling; only tokens with a cumulative probability of top_p will be selected for multinomial sampling __default: 0.9__
|
37 |
+
|
38 |
+
`top_k` int - Only top k tokens will be selected for multinomial sampling. __default: 256__
|
39 |
+
|
40 |
+
## Return format
|
41 |
+
|
42 |
+
The server returns a set of sentences according to the context. Their format is:
|
43 |
+
```
|
44 |
+
{sentences: {value: string, time: number}[], time: number}
|
45 |
+
```
|
46 |
+
|
47 |
+
Example:
|
48 |
+
|
49 |
+
With POST parameters as:
|
50 |
+
|
51 |
+
```json
|
52 |
+
{
|
53 |
+
"context": "That man is just another",
|
54 |
+
"samples": 3
|
55 |
+
}
|
56 |
+
```
|
57 |
+
|
58 |
+
The response is as follows:
|
59 |
+
|
60 |
+
```json
|
61 |
+
{
|
62 |
+
"sentences": [
|
63 |
+
{"value": " handicapped working man.", "time": 0.15415167808532715},
|
64 |
+
{"value": " guy, doing everything his manly nature requires.", "time": 0.2581148147583008},
|
65 |
+
{"value": " guy, Mohr said.", "time": 0.17547011375427246}
|
66 |
+
],
|
67 |
+
"time": 0.264873743057251
|
68 |
+
}
|
69 |
+
```
|
backend/Utils.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def forward(model_name, model, input_ids, past, device='cpu'):
|
5 |
+
if "gpt2" in model_name or "ctrl" in model_name:
|
6 |
+
if past is not None:
|
7 |
+
return model(input_ids[:, -1], past=past)
|
8 |
+
return model(input_ids)
|
9 |
+
elif "xlnet" in model_name:
|
10 |
+
input_ids = torch.cat((
|
11 |
+
input_ids,
|
12 |
+
torch.zeros((input_ids.shape[0], 1), dtype=torch.long, device=device)
|
13 |
+
), dim=1)
|
14 |
+
|
15 |
+
perm_mask = torch.zeros(
|
16 |
+
(input_ids.shape[0], input_ids.shape[1], input_ids.shape[1]),
|
17 |
+
dtype=torch.float,
|
18 |
+
device=device
|
19 |
+
)
|
20 |
+
perm_mask[:, :, -1] = 1.0
|
21 |
+
|
22 |
+
target_mapping = torch.zeros(
|
23 |
+
(input_ids.shape[0], 1, input_ids.shape[1]),
|
24 |
+
dtype=torch.float,
|
25 |
+
device=device)
|
26 |
+
target_mapping[:, 0, -1] = 1.0
|
27 |
+
|
28 |
+
return model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
|
29 |
+
elif "transfo-xl" in model_name:
|
30 |
+
return model(input_ids, mems=past)
|
31 |
+
else:
|
32 |
+
return model(input_ids)
|
33 |
+
|
34 |
+
|
35 |
+
def create_context(model_name, tokenizer, initial_text="", padding_text=None, max_tokens=512):
|
36 |
+
if not len(initial_text) and "gpt2" in model_name:
|
37 |
+
initial_text = "<|endoftext|>"
|
38 |
+
if 'xlnet' in model_name or "transfo-xl" in model_name:
|
39 |
+
initial_text = padding_text + initial_text
|
40 |
+
|
41 |
+
if 'transfo-xl' in model_name:
|
42 |
+
max_tokens = int(max_tokens / 2)
|
43 |
+
|
44 |
+
context_tokens = tokenizer.encode(initial_text)[-max_tokens:]
|
45 |
+
|
46 |
+
if "gpt2" in model_name:
|
47 |
+
eot_token = tokenizer.encoder["<|endoftext|>"]
|
48 |
+
if len(context_tokens) == 0:
|
49 |
+
context_tokens = [tokenizer.encoder["<|endoftext|>"]]
|
50 |
+
elif "xlnet" in model_name:
|
51 |
+
eot_token = tokenizer.convert_tokens_to_ids('<eop>')
|
52 |
+
else:
|
53 |
+
eot_token = None
|
54 |
+
dot_token = tokenizer.encode(".")[-1]
|
55 |
+
|
56 |
+
return context_tokens, eot_token, dot_token
|
57 |
+
|
backend/id/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
/node_modules
|
2 |
+
/dist
|
backend/id/id.ts
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import * as Koa from "koa";
|
2 |
+
import * as Router from "koa-router";
|
3 |
+
|
4 |
+
let id = 0;
|
5 |
+
|
6 |
+
const app = new Koa();
|
7 |
+
const router = new Router();
|
8 |
+
|
9 |
+
router.get("/*", async ctx => {
|
10 |
+
ctx.body = id++;
|
11 |
+
});
|
12 |
+
|
13 |
+
app.use(router.routes());
|
14 |
+
|
15 |
+
app.listen(3000);
|
16 |
+
|
17 |
+
console.log("Server running on port 3000");
|
backend/id/package.json
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dependencies": {
|
3 |
+
"@types/koa": "^2.0.48",
|
4 |
+
"@types/koa-router": "^7.0.40",
|
5 |
+
"koa": "^2.7.0",
|
6 |
+
"koa-router": "^7.4.0",
|
7 |
+
"typescript": "^3.5.1"
|
8 |
+
},
|
9 |
+
"scripts": {
|
10 |
+
"start": "tsc && node dist/id.js"
|
11 |
+
}
|
12 |
+
}
|
backend/id/tsconfig.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"compilerOptions": {
|
3 |
+
"target": "esnext",
|
4 |
+
"module": "commonjs",
|
5 |
+
"outDir": "dist/",
|
6 |
+
"strictNullChecks": true,
|
7 |
+
"strict": true,
|
8 |
+
"lib": ["esnext", "dom", "es6", "es2016", "es2017", "es2018"]
|
9 |
+
}
|
10 |
+
}
|
backend/id/yarn.lock
ADDED
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY.
|
2 |
+
# yarn lockfile v1
|
3 |
+
|
4 |
+
|
5 |
+
"@types/accepts@*":
|
6 |
+
version "1.3.5"
|
7 |
+
resolved "https://registry.yarnpkg.com/@types/accepts/-/accepts-1.3.5.tgz#c34bec115cfc746e04fe5a059df4ce7e7b391575"
|
8 |
+
integrity sha512-jOdnI/3qTpHABjM5cx1Hc0sKsPoYCp+DP/GJRGtDlPd7fiV9oXGGIcjW/ZOxLIvjGz8MA+uMZI9metHlgqbgwQ==
|
9 |
+
dependencies:
|
10 |
+
"@types/node" "*"
|
11 |
+
|
12 |
+
"@types/body-parser@*":
|
13 |
+
version "1.17.0"
|
14 |
+
resolved "https://registry.yarnpkg.com/@types/body-parser/-/body-parser-1.17.0.tgz#9f5c9d9bd04bb54be32d5eb9fc0d8c974e6cf58c"
|
15 |
+
integrity sha512-a2+YeUjPkztKJu5aIF2yArYFQQp8d51wZ7DavSHjFuY1mqVgidGyzEQ41JIVNy82fXj8yPgy2vJmfIywgESW6w==
|
16 |
+
dependencies:
|
17 |
+
"@types/connect" "*"
|
18 |
+
"@types/node" "*"
|
19 |
+
|
20 |
+
"@types/connect@*":
|
21 |
+
version "3.4.32"
|
22 |
+
resolved "https://registry.yarnpkg.com/@types/connect/-/connect-3.4.32.tgz#aa0e9616b9435ccad02bc52b5b454ffc2c70ba28"
|
23 |
+
integrity sha512-4r8qa0quOvh7lGD0pre62CAb1oni1OO6ecJLGCezTmhQ8Fz50Arx9RUszryR8KlgK6avuSXvviL6yWyViQABOg==
|
24 |
+
dependencies:
|
25 |
+
"@types/node" "*"
|
26 |
+
|
27 |
+
"@types/cookies@*":
|
28 |
+
version "0.7.2"
|
29 |
+
resolved "https://registry.yarnpkg.com/@types/cookies/-/cookies-0.7.2.tgz#5e0560d46ed9998082dce799af1058dd6a49780a"
|
30 |
+
integrity sha512-jnihWgshWystcJKrz8C9hV+Ot9lqOUyAh2RF+o3BEo6K6AS2l4zYCb9GYaBuZ3C6Il59uIGqpE3HvCun4KKeJA==
|
31 |
+
dependencies:
|
32 |
+
"@types/connect" "*"
|
33 |
+
"@types/express" "*"
|
34 |
+
"@types/keygrip" "*"
|
35 |
+
"@types/node" "*"
|
36 |
+
|
37 |
+
"@types/express-serve-static-core@*":
|
38 |
+
version "4.16.7"
|
39 |
+
resolved "https://registry.yarnpkg.com/@types/express-serve-static-core/-/express-serve-static-core-4.16.7.tgz#50ba6f8a691c08a3dd9fa7fba25ef3133d298049"
|
40 |
+
integrity sha512-847KvL8Q1y3TtFLRTXcVakErLJQgdpFSaq+k043xefz9raEf0C7HalpSY7OW5PyjCnY8P7bPW5t/Co9qqp+USg==
|
41 |
+
dependencies:
|
42 |
+
"@types/node" "*"
|
43 |
+
"@types/range-parser" "*"
|
44 |
+
|
45 |
+
"@types/express@*":
|
46 |
+
version "4.17.0"
|
47 |
+
resolved "https://registry.yarnpkg.com/@types/express/-/express-4.17.0.tgz#49eaedb209582a86f12ed9b725160f12d04ef287"
|
48 |
+
integrity sha512-CjaMu57cjgjuZbh9DpkloeGxV45CnMGlVd+XpG7Gm9QgVrd7KFq+X4HY0vM+2v0bczS48Wg7bvnMY5TN+Xmcfw==
|
49 |
+
dependencies:
|
50 |
+
"@types/body-parser" "*"
|
51 |
+
"@types/express-serve-static-core" "*"
|
52 |
+
"@types/serve-static" "*"
|
53 |
+
|
54 |
+
"@types/http-assert@*":
|
55 |
+
version "1.4.0"
|
56 |
+
resolved "https://registry.yarnpkg.com/@types/http-assert/-/http-assert-1.4.0.tgz#41d173466e396e99a14d75f7160cc997f2f9ed8b"
|
57 |
+
integrity sha512-TZDqvFW4nQwL9DVSNJIJu4lPLttKgzRF58COa7Vs42Ki/MrhIqUbeIw0MWn4kGLiZLXB7oCBibm7nkSjPkzfKQ==
|
58 |
+
|
59 |
+
"@types/keygrip@*":
|
60 |
+
version "1.0.1"
|
61 |
+
resolved "https://registry.yarnpkg.com/@types/keygrip/-/keygrip-1.0.1.tgz#ff540462d2fb4d0a88441ceaf27d287b01c3d878"
|
62 |
+
integrity sha1-/1QEYtL7TQqIRBzq8n0oewHD2Hg=
|
63 |
+
|
64 |
+
"@types/koa-compose@*":
|
65 |
+
version "3.2.4"
|
66 |
+
resolved "https://registry.yarnpkg.com/@types/koa-compose/-/koa-compose-3.2.4.tgz#76a461634a59c3e13449831708bb9b355fb1548e"
|
67 |
+
integrity sha512-ioou0rxkuWL+yBQYsHUQAzRTfVxAg8Y2VfMftU+Y3RA03/MzuFL0x/M2sXXj3PkfnENbHsjeHR1aMdezLYpTeA==
|
68 |
+
dependencies:
|
69 |
+
"@types/koa" "*"
|
70 |
+
|
71 |
+
"@types/koa-router@^7.0.40":
|
72 |
+
version "7.0.40"
|
73 |
+
resolved "https://registry.yarnpkg.com/@types/koa-router/-/koa-router-7.0.40.tgz#9654dbc43375a0380c44c49c4504b4dbfc3e4e6a"
|
74 |
+
integrity sha512-YK4+WGXch6Ig9PreZ9jlHZb2onm0S1szGw0oQxWvPhoyjSHo1Tq+CpjxMmthEUIQUc9KznOGgehFarOx8XwsFw==
|
75 |
+
dependencies:
|
76 |
+
"@types/koa" "*"
|
77 |
+
|
78 |
+
"@types/koa@*", "@types/koa@^2.0.48":
|
79 |
+
version "2.0.48"
|
80 |
+
resolved "https://registry.yarnpkg.com/@types/koa/-/koa-2.0.48.tgz#29162783029d3e5df8b58c55f6bf0d35f78fc39f"
|
81 |
+
integrity sha512-CiIUYhHlOFJhSCTmsFoFkV2t9ij1JwW26nt0W9XZoWTvmAw6zTE0+k3IAoGICtjzIfhZpZcO323NHmI1LGmdDw==
|
82 |
+
dependencies:
|
83 |
+
"@types/accepts" "*"
|
84 |
+
"@types/cookies" "*"
|
85 |
+
"@types/http-assert" "*"
|
86 |
+
"@types/keygrip" "*"
|
87 |
+
"@types/koa-compose" "*"
|
88 |
+
"@types/node" "*"
|
89 |
+
|
90 |
+
"@types/mime@*":
|
91 |
+
version "2.0.1"
|
92 |
+
resolved "https://registry.yarnpkg.com/@types/mime/-/mime-2.0.1.tgz#dc488842312a7f075149312905b5e3c0b054c79d"
|
93 |
+
integrity sha512-FwI9gX75FgVBJ7ywgnq/P7tw+/o1GUbtP0KzbtusLigAOgIgNISRK0ZPl4qertvXSIE8YbsVJueQ90cDt9YYyw==
|
94 |
+
|
95 |
+
"@types/node@*":
|
96 |
+
version "12.0.7"
|
97 |
+
resolved "https://registry.yarnpkg.com/@types/node/-/node-12.0.7.tgz#4f2563bad652b2acb1722d7e7aae2b0ff62d192c"
|
98 |
+
integrity sha512-1YKeT4JitGgE4SOzyB9eMwO0nGVNkNEsm9qlIt1Lqm/tG2QEiSMTD4kS3aO6L+w5SClLVxALmIBESK6Mk5wX0A==
|
99 |
+
|
100 |
+
"@types/range-parser@*":
|
101 |
+
version "1.2.3"
|
102 |
+
resolved "https://registry.yarnpkg.com/@types/range-parser/-/range-parser-1.2.3.tgz#7ee330ba7caafb98090bece86a5ee44115904c2c"
|
103 |
+
integrity sha512-ewFXqrQHlFsgc09MK5jP5iR7vumV/BYayNC6PgJO2LPe8vrnNFyjQjSppfEngITi0qvfKtzFvgKymGheFM9UOA==
|
104 |
+
|
105 |
+
"@types/serve-static@*":
|
106 |
+
version "1.13.2"
|
107 |
+
resolved "https://registry.yarnpkg.com/@types/serve-static/-/serve-static-1.13.2.tgz#f5ac4d7a6420a99a6a45af4719f4dcd8cd907a48"
|
108 |
+
integrity sha512-/BZ4QRLpH/bNYgZgwhKEh+5AsboDBcUdlBYgzoLX0fpj3Y2gp6EApyOlM3bK53wQS/OE1SrdSYBAbux2D1528Q==
|
109 |
+
dependencies:
|
110 |
+
"@types/express-serve-static-core" "*"
|
111 |
+
"@types/mime" "*"
|
112 |
+
|
113 |
+
accepts@^1.3.5:
|
114 |
+
version "1.3.7"
|
115 |
+
resolved "https://registry.yarnpkg.com/accepts/-/accepts-1.3.7.tgz#531bc726517a3b2b41f850021c6cc15eaab507cd"
|
116 |
+
integrity sha512-Il80Qs2WjYlJIBNzNkK6KYqlVMTbZLXgHx2oT0pU/fjRHyEp+PEfEPY0R3WCwAGVOtauxh1hOxNgIf5bv7dQpA==
|
117 |
+
dependencies:
|
118 |
+
mime-types "~2.1.24"
|
119 |
+
negotiator "0.6.2"
|
120 |
+
|
121 |
+
any-promise@^1.1.0:
|
122 |
+
version "1.3.0"
|
123 |
+
resolved "https://registry.yarnpkg.com/any-promise/-/any-promise-1.3.0.tgz#abc6afeedcea52e809cdc0376aed3ce39635d17f"
|
124 |
+
integrity sha1-q8av7tzqUugJzcA3au0845Y10X8=
|
125 |
+
|
126 |
+
cache-content-type@^1.0.0:
|
127 |
+
version "1.0.1"
|
128 |
+
resolved "https://registry.yarnpkg.com/cache-content-type/-/cache-content-type-1.0.1.tgz#035cde2b08ee2129f4a8315ea8f00a00dba1453c"
|
129 |
+
integrity sha512-IKufZ1o4Ut42YUrZSo8+qnMTrFuKkvyoLXUywKz9GJ5BrhOFGhLdkx9sG4KAnVvbY6kEcSFjLQul+DVmBm2bgA==
|
130 |
+
dependencies:
|
131 |
+
mime-types "^2.1.18"
|
132 |
+
ylru "^1.2.0"
|
133 |
+
|
134 |
+
co@^4.6.0:
|
135 |
+
version "4.6.0"
|
136 |
+
resolved "https://registry.yarnpkg.com/co/-/co-4.6.0.tgz#6ea6bdf3d853ae54ccb8e47bfa0bf3f9031fb184"
|
137 |
+
integrity sha1-bqa989hTrlTMuOR7+gvz+QMfsYQ=
|
138 |
+
|
139 |
+
content-disposition@~0.5.2:
|
140 |
+
version "0.5.3"
|
141 |
+
resolved "https://registry.yarnpkg.com/content-disposition/-/content-disposition-0.5.3.tgz#e130caf7e7279087c5616c2007d0485698984fbd"
|
142 |
+
integrity sha512-ExO0774ikEObIAEV9kDo50o+79VCUdEB6n6lzKgGwupcVeRlhrj3qGAfwq8G6uBJjkqLrhT0qEYFcWng8z1z0g==
|
143 |
+
dependencies:
|
144 |
+
safe-buffer "5.1.2"
|
145 |
+
|
146 |
+
content-type@^1.0.4:
|
147 |
+
version "1.0.4"
|
148 |
+
resolved "https://registry.yarnpkg.com/content-type/-/content-type-1.0.4.tgz#e138cc75e040c727b1966fe5e5f8c9aee256fe3b"
|
149 |
+
integrity sha512-hIP3EEPs8tB9AT1L+NUqtwOAps4mk2Zob89MWXMHjHWg9milF/j4osnnQLXBCBFBk/tvIG/tUc9mOUJiPBhPXA==
|
150 |
+
|
151 |
+
cookies@~0.7.1:
|
152 |
+
version "0.7.3"
|
153 |
+
resolved "https://registry.yarnpkg.com/cookies/-/cookies-0.7.3.tgz#7912ce21fbf2e8c2da70cf1c3f351aecf59dadfa"
|
154 |
+
integrity sha512-+gixgxYSgQLTaTIilDHAdlNPZDENDQernEMiIcZpYYP14zgHsCt4Ce1FEjFtcp6GefhozebB6orvhAAWx/IS0A==
|
155 |
+
dependencies:
|
156 |
+
depd "~1.1.2"
|
157 |
+
keygrip "~1.0.3"
|
158 |
+
|
159 |
+
debug@^3.1.0:
|
160 |
+
version "3.2.6"
|
161 |
+
resolved "https://registry.yarnpkg.com/debug/-/debug-3.2.6.tgz#e83d17de16d8a7efb7717edbe5fb10135eee629b"
|
162 |
+
integrity sha512-mel+jf7nrtEl5Pn1Qx46zARXKDpBbvzezse7p7LqINmdoIk8PYP5SySaxEmYv6TZ0JyEKA1hsCId6DIhgITtWQ==
|
163 |
+
dependencies:
|
164 |
+
ms "^2.1.1"
|
165 |
+
|
166 |
+
debug@~3.1.0:
|
167 |
+
version "3.1.0"
|
168 |
+
resolved "https://registry.yarnpkg.com/debug/-/debug-3.1.0.tgz#5bb5a0672628b64149566ba16819e61518c67261"
|
169 |
+
integrity sha512-OX8XqP7/1a9cqkxYw2yXss15f26NKWBpDXQd0/uK/KPqdQhxbPa994hnzjcE2VqQpDslf55723cKPUOGSmMY3g==
|
170 |
+
dependencies:
|
171 |
+
ms "2.0.0"
|
172 |
+
|
173 |
+
deep-equal@~1.0.1:
|
174 |
+
version "1.0.1"
|
175 |
+
resolved "https://registry.yarnpkg.com/deep-equal/-/deep-equal-1.0.1.tgz#f5d260292b660e084eff4cdbc9f08ad3247448b5"
|
176 |
+
integrity sha1-9dJgKStmDghO/0zbyfCK0yR0SLU=
|
177 |
+
|
178 |
+
delegates@^1.0.0:
|
179 |
+
version "1.0.0"
|
180 |
+
resolved "https://registry.yarnpkg.com/delegates/-/delegates-1.0.0.tgz#84c6e159b81904fdca59a0ef44cd870d31250f9a"
|
181 |
+
integrity sha1-hMbhWbgZBP3KWaDvRM2HDTElD5o=
|
182 |
+
|
183 |
+
depd@^1.1.2, depd@~1.1.2:
|
184 |
+
version "1.1.2"
|
185 |
+
resolved "https://registry.yarnpkg.com/depd/-/depd-1.1.2.tgz#9bcd52e14c097763e749b274c4346ed2e560b5a9"
|
186 |
+
integrity sha1-m81S4UwJd2PnSbJ0xDRu0uVgtak=
|
187 |
+
|
188 |
+
destroy@^1.0.4:
|
189 |
+
version "1.0.4"
|
190 |
+
resolved "https://registry.yarnpkg.com/destroy/-/destroy-1.0.4.tgz#978857442c44749e4206613e37946205826abd80"
|
191 |
+
integrity sha1-l4hXRCxEdJ5CBmE+N5RiBYJqvYA=
|
192 |
+
|
193 | |
194 |
+
version "1.1.1"
|
195 |
+
resolved "https://registry.yarnpkg.com/ee-first/-/ee-first-1.1.1.tgz#590c61156b0ae2f4f0255732a158b266bc56b21d"
|
196 |
+
integrity sha1-WQxhFWsK4vTwJVcyoViyZrxWsh0=
|
197 |
+
|
198 |
+
error-inject@^1.0.0:
|
199 |
+
version "1.0.0"
|
200 |
+
resolved "https://registry.yarnpkg.com/error-inject/-/error-inject-1.0.0.tgz#e2b3d91b54aed672f309d950d154850fa11d4f37"
|
201 |
+
integrity sha1-4rPZG1Su1nLzCdlQ0VSFD6EdTzc=
|
202 |
+
|
203 |
+
escape-html@^1.0.3:
|
204 |
+
version "1.0.3"
|
205 |
+
resolved "https://registry.yarnpkg.com/escape-html/-/escape-html-1.0.3.tgz#0258eae4d3d0c0974de1c169188ef0051d1d1988"
|
206 |
+
integrity sha1-Aljq5NPQwJdN4cFpGI7wBR0dGYg=
|
207 |
+
|
208 |
+
fresh@~0.5.2:
|
209 |
+
version "0.5.2"
|
210 |
+
resolved "https://registry.yarnpkg.com/fresh/-/fresh-0.5.2.tgz#3d8cadd90d976569fa835ab1f8e4b23a105605a7"
|
211 |
+
integrity sha1-PYyt2Q2XZWn6g1qx+OSyOhBWBac=
|
212 |
+
|
213 |
+
http-assert@^1.3.0:
|
214 |
+
version "1.4.1"
|
215 |
+
resolved "https://registry.yarnpkg.com/http-assert/-/http-assert-1.4.1.tgz#c5f725d677aa7e873ef736199b89686cceb37878"
|
216 |
+
integrity sha512-rdw7q6GTlibqVVbXr0CKelfV5iY8G2HqEUkhSk297BMbSpSL8crXC+9rjKoMcZZEsksX30le6f/4ul4E28gegw==
|
217 |
+
dependencies:
|
218 |
+
deep-equal "~1.0.1"
|
219 |
+
http-errors "~1.7.2"
|
220 |
+
|
221 |
+
http-errors@^1.3.1, http-errors@^1.6.3, http-errors@~1.7.2:
|
222 |
+
version "1.7.2"
|
223 |
+
resolved "https://registry.yarnpkg.com/http-errors/-/http-errors-1.7.2.tgz#4f5029cf13239f31036e5b2e55292bcfbcc85c8f"
|
224 |
+
integrity sha512-uUQBt3H/cSIVfch6i1EuPNy/YsRSOUBXTVfZ+yR7Zjez3qjBz6i9+i4zjNaoqcoFVI4lQJ5plg63TvGfRSDCRg==
|
225 |
+
dependencies:
|
226 |
+
depd "~1.1.2"
|
227 |
+
inherits "2.0.3"
|
228 |
+
setprototypeof "1.1.1"
|
229 |
+
statuses ">= 1.5.0 < 2"
|
230 |
+
toidentifier "1.0.0"
|
231 |
+
|
232 | |
233 |
+
version "2.0.3"
|
234 |
+
resolved "https://registry.yarnpkg.com/inherits/-/inherits-2.0.3.tgz#633c2c83e3da42a502f52466022480f4208261de"
|
235 |
+
integrity sha1-Yzwsg+PaQqUC9SRmAiSA9CCCYd4=
|
236 |
+
|
237 |
+
is-generator-function@^1.0.7:
|
238 |
+
version "1.0.7"
|
239 |
+
resolved "https://registry.yarnpkg.com/is-generator-function/-/is-generator-function-1.0.7.tgz#d2132e529bb0000a7f80794d4bdf5cd5e5813522"
|
240 |
+
integrity sha512-YZc5EwyO4f2kWCax7oegfuSr9mFz1ZvieNYBEjmukLxgXfBUbxAWGVF7GZf0zidYtoBl3WvC07YK0wT76a+Rtw==
|
241 |
+
|
242 | |
243 |
+
version "0.0.1"
|
244 |
+
resolved "https://registry.yarnpkg.com/isarray/-/isarray-0.0.1.tgz#8a18acfca9a8f4177e09abfc6038939b05d1eedf"
|
245 |
+
integrity sha1-ihis/Kmo9Bd+Cav8YDiTmwXR7t8=
|
246 |
+
|
247 |
+
keygrip@~1.0.3:
|
248 |
+
version "1.0.3"
|
249 |
+
resolved "https://registry.yarnpkg.com/keygrip/-/keygrip-1.0.3.tgz#399d709f0aed2bab0a059e0cdd3a5023a053e1dc"
|
250 |
+
integrity sha512-/PpesirAIfaklxUzp4Yb7xBper9MwP6hNRA6BGGUFCgbJ+BM5CKBtsoxinNXkLHAr+GXS1/lSlF2rP7cv5Fl+g==
|
251 |
+
|
252 |
+
koa-compose@^3.0.0:
|
253 |
+
version "3.2.1"
|
254 |
+
resolved "https://registry.yarnpkg.com/koa-compose/-/koa-compose-3.2.1.tgz#a85ccb40b7d986d8e5a345b3a1ace8eabcf54de7"
|
255 |
+
integrity sha1-qFzLQLfZhtjlo0Wzoazo6rz1Tec=
|
256 |
+
dependencies:
|
257 |
+
any-promise "^1.1.0"
|
258 |
+
|
259 |
+
koa-compose@^4.1.0:
|
260 |
+
version "4.1.0"
|
261 |
+
resolved "https://registry.yarnpkg.com/koa-compose/-/koa-compose-4.1.0.tgz#507306b9371901db41121c812e923d0d67d3e877"
|
262 |
+
integrity sha512-8ODW8TrDuMYvXRwra/Kh7/rJo9BtOfPc6qO8eAfC80CnCvSjSl0bkRM24X6/XBBEyj0v1nRUQ1LyOy3dbqOWXw==
|
263 |
+
|
264 |
+
koa-convert@^1.2.0:
|
265 |
+
version "1.2.0"
|
266 |
+
resolved "https://registry.yarnpkg.com/koa-convert/-/koa-convert-1.2.0.tgz#da40875df49de0539098d1700b50820cebcd21d0"
|
267 |
+
integrity sha1-2kCHXfSd4FOQmNFwC1CCDOvNIdA=
|
268 |
+
dependencies:
|
269 |
+
co "^4.6.0"
|
270 |
+
koa-compose "^3.0.0"
|
271 |
+
|
272 |
+
koa-is-json@^1.0.0:
|
273 |
+
version "1.0.0"
|
274 |
+
resolved "https://registry.yarnpkg.com/koa-is-json/-/koa-is-json-1.0.0.tgz#273c07edcdcb8df6a2c1ab7d59ee76491451ec14"
|
275 |
+
integrity sha1-JzwH7c3Ljfaiwat9We52SRRR7BQ=
|
276 |
+
|
277 |
+
koa-router@^7.4.0:
|
278 |
+
version "7.4.0"
|
279 |
+
resolved "https://registry.yarnpkg.com/koa-router/-/koa-router-7.4.0.tgz#aee1f7adc02d5cb31d7d67465c9eacc825e8c5e0"
|
280 |
+
integrity sha512-IWhaDXeAnfDBEpWS6hkGdZ1ablgr6Q6pGdXCyK38RbzuH4LkUOpPqPw+3f8l8aTDrQmBQ7xJc0bs2yV4dzcO+g==
|
281 |
+
dependencies:
|
282 |
+
debug "^3.1.0"
|
283 |
+
http-errors "^1.3.1"
|
284 |
+
koa-compose "^3.0.0"
|
285 |
+
methods "^1.0.1"
|
286 |
+
path-to-regexp "^1.1.1"
|
287 |
+
urijs "^1.19.0"
|
288 |
+
|
289 |
+
koa@^2.7.0:
|
290 |
+
version "2.7.0"
|
291 |
+
resolved "https://registry.yarnpkg.com/koa/-/koa-2.7.0.tgz#7e00843506942b9d82c6cc33749f657c6e5e7adf"
|
292 |
+
integrity sha512-7ojD05s2Q+hFudF8tDLZ1CpCdVZw8JQELWSkcfG9bdtoTDzMmkRF6BQBU7JzIzCCOY3xd3tftiy/loHBUYaY2Q==
|
293 |
+
dependencies:
|
294 |
+
accepts "^1.3.5"
|
295 |
+
cache-content-type "^1.0.0"
|
296 |
+
content-disposition "~0.5.2"
|
297 |
+
content-type "^1.0.4"
|
298 |
+
cookies "~0.7.1"
|
299 |
+
debug "~3.1.0"
|
300 |
+
delegates "^1.0.0"
|
301 |
+
depd "^1.1.2"
|
302 |
+
destroy "^1.0.4"
|
303 |
+
error-inject "^1.0.0"
|
304 |
+
escape-html "^1.0.3"
|
305 |
+
fresh "~0.5.2"
|
306 |
+
http-assert "^1.3.0"
|
307 |
+
http-errors "^1.6.3"
|
308 |
+
is-generator-function "^1.0.7"
|
309 |
+
koa-compose "^4.1.0"
|
310 |
+
koa-convert "^1.2.0"
|
311 |
+
koa-is-json "^1.0.0"
|
312 |
+
on-finished "^2.3.0"
|
313 |
+
only "~0.0.2"
|
314 |
+
parseurl "^1.3.2"
|
315 |
+
statuses "^1.5.0"
|
316 |
+
type-is "^1.6.16"
|
317 |
+
vary "^1.1.2"
|
318 |
+
|
319 | |
320 |
+
version "0.3.0"
|
321 |
+
resolved "https://registry.yarnpkg.com/media-typer/-/media-typer-0.3.0.tgz#8710d7af0aa626f8fffa1ce00168545263255748"
|
322 |
+
integrity sha1-hxDXrwqmJvj/+hzgAWhUUmMlV0g=
|
323 |
+
|
324 |
+
methods@^1.0.1:
|
325 |
+
version "1.1.2"
|
326 |
+
resolved "https://registry.yarnpkg.com/methods/-/methods-1.1.2.tgz#5529a4d67654134edcc5266656835b0f851afcee"
|
327 |
+
integrity sha1-VSmk1nZUE07cxSZmVoNbD4Ua/O4=
|
328 |
+
|
329 | |
330 |
+
version "1.40.0"
|
331 |
+
resolved "https://registry.yarnpkg.com/mime-db/-/mime-db-1.40.0.tgz#a65057e998db090f732a68f6c276d387d4126c32"
|
332 |
+
integrity sha512-jYdeOMPy9vnxEqFRRo6ZvTZ8d9oPb+k18PKoYNYUe2stVEBPPwsln/qWzdbmaIvnhZ9v2P+CuecK+fpUfsV2mA==
|
333 |
+
|
334 |
+
mime-types@^2.1.18, mime-types@~2.1.24:
|
335 |
+
version "2.1.24"
|
336 |
+
resolved "https://registry.yarnpkg.com/mime-types/-/mime-types-2.1.24.tgz#b6f8d0b3e951efb77dedeca194cff6d16f676f81"
|
337 |
+
integrity sha512-WaFHS3MCl5fapm3oLxU4eYDw77IQM2ACcxQ9RIxfaC3ooc6PFuBMGZZsYpvoXS5D5QTWPieo1jjLdAm3TBP3cQ==
|
338 |
+
dependencies:
|
339 |
+
mime-db "1.40.0"
|
340 |
+
|
341 | |
342 |
+
version "2.0.0"
|
343 |
+
resolved "https://registry.yarnpkg.com/ms/-/ms-2.0.0.tgz#5608aeadfc00be6c2901df5f9861788de0d597c8"
|
344 |
+
integrity sha1-VgiurfwAvmwpAd9fmGF4jeDVl8g=
|
345 |
+
|
346 |
+
ms@^2.1.1:
|
347 |
+
version "2.1.2"
|
348 |
+
resolved "https://registry.yarnpkg.com/ms/-/ms-2.1.2.tgz#d09d1f357b443f493382a8eb3ccd183872ae6009"
|
349 |
+
integrity sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==
|
350 |
+
|
351 | |
352 |
+
version "0.6.2"
|
353 |
+
resolved "https://registry.yarnpkg.com/negotiator/-/negotiator-0.6.2.tgz#feacf7ccf525a77ae9634436a64883ffeca346fb"
|
354 |
+
integrity sha512-hZXc7K2e+PgeI1eDBe/10Ard4ekbfrrqG8Ep+8Jmf4JID2bNg7NvCPOZN+kfF574pFQI7mum2AUqDidoKqcTOw==
|
355 |
+
|
356 |
+
on-finished@^2.3.0:
|
357 |
+
version "2.3.0"
|
358 |
+
resolved "https://registry.yarnpkg.com/on-finished/-/on-finished-2.3.0.tgz#20f1336481b083cd75337992a16971aa2d906947"
|
359 |
+
integrity sha1-IPEzZIGwg811M3mSoWlxqi2QaUc=
|
360 |
+
dependencies:
|
361 |
+
ee-first "1.1.1"
|
362 |
+
|
363 |
+
only@~0.0.2:
|
364 |
+
version "0.0.2"
|
365 |
+
resolved "https://registry.yarnpkg.com/only/-/only-0.0.2.tgz#2afde84d03e50b9a8edc444e30610a70295edfb4"
|
366 |
+
integrity sha1-Kv3oTQPlC5qO3EROMGEKcCle37Q=
|
367 |
+
|
368 |
+
parseurl@^1.3.2:
|
369 |
+
version "1.3.3"
|
370 |
+
resolved "https://registry.yarnpkg.com/parseurl/-/parseurl-1.3.3.tgz#9da19e7bee8d12dff0513ed5b76957793bc2e8d4"
|
371 |
+
integrity sha512-CiyeOxFT/JZyN5m0z9PfXw4SCBJ6Sygz1Dpl0wqjlhDEGGBP1GnsUVEL0p63hoG1fcj3fHynXi9NYO4nWOL+qQ==
|
372 |
+
|
373 |
+
path-to-regexp@^1.1.1:
|
374 |
+
version "1.7.0"
|
375 |
+
resolved "https://registry.yarnpkg.com/path-to-regexp/-/path-to-regexp-1.7.0.tgz#59fde0f435badacba103a84e9d3bc64e96b9937d"
|
376 |
+
integrity sha1-Wf3g9DW62suhA6hOnTvGTpa5k30=
|
377 |
+
dependencies:
|
378 |
+
isarray "0.0.1"
|
379 |
+
|
380 | |
381 |
+
version "5.1.2"
|
382 |
+
resolved "https://registry.yarnpkg.com/safe-buffer/-/safe-buffer-5.1.2.tgz#991ec69d296e0313747d59bdfd2b745c35f8828d"
|
383 |
+
integrity sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==
|
384 |
+
|
385 | |
386 |
+
version "1.1.1"
|
387 |
+
resolved "https://registry.yarnpkg.com/setprototypeof/-/setprototypeof-1.1.1.tgz#7e95acb24aa92f5885e0abef5ba131330d4ae683"
|
388 |
+
integrity sha512-JvdAWfbXeIGaZ9cILp38HntZSFSo3mWg6xGcJJsd+d4aRMOqauag1C63dJfDw7OaMYwEbHMOxEZ1lqVRYP2OAw==
|
389 |
+
|
390 |
+
"statuses@>= 1.5.0 < 2", statuses@^1.5.0:
|
391 |
+
version "1.5.0"
|
392 |
+
resolved "https://registry.yarnpkg.com/statuses/-/statuses-1.5.0.tgz#161c7dac177659fd9811f43771fa99381478628c"
|
393 |
+
integrity sha1-Fhx9rBd2Wf2YEfQ3cfqZOBR4Yow=
|
394 |
+
|
395 | |
396 |
+
version "1.0.0"
|
397 |
+
resolved "https://registry.yarnpkg.com/toidentifier/-/toidentifier-1.0.0.tgz#7e1be3470f1e77948bc43d94a3c8f4d7752ba553"
|
398 |
+
integrity sha512-yaOH/Pk/VEhBWWTlhI+qXxDFXlejDGcQipMlyxda9nthulaxLZUNcUqFxokp0vcYnvteJln5FNQDRrxj3YcbVw==
|
399 |
+
|
400 |
+
type-is@^1.6.16:
|
401 |
+
version "1.6.18"
|
402 |
+
resolved "https://registry.yarnpkg.com/type-is/-/type-is-1.6.18.tgz#4e552cd05df09467dcbc4ef739de89f2cf37c131"
|
403 |
+
integrity sha512-TkRKr9sUTxEH8MdfuCSP7VizJyzRNMjj2J2do2Jr3Kym598JVdEksuzPQCnlFPW4ky9Q+iA+ma9BGm06XQBy8g==
|
404 |
+
dependencies:
|
405 |
+
media-typer "0.3.0"
|
406 |
+
mime-types "~2.1.24"
|
407 |
+
|
408 |
+
typescript@^3.5.1:
|
409 |
+
version "3.5.1"
|
410 |
+
resolved "https://registry.yarnpkg.com/typescript/-/typescript-3.5.1.tgz#ba72a6a600b2158139c5dd8850f700e231464202"
|
411 |
+
integrity sha512-64HkdiRv1yYZsSe4xC1WVgamNigVYjlssIoaH2HcZF0+ijsk5YK2g0G34w9wJkze8+5ow4STd22AynfO6ZYYLw==
|
412 |
+
|
413 |
+
urijs@^1.19.0:
|
414 |
+
version "1.19.1"
|
415 |
+
resolved "https://registry.yarnpkg.com/urijs/-/urijs-1.19.1.tgz#5b0ff530c0cbde8386f6342235ba5ca6e995d25a"
|
416 |
+
integrity sha512-xVrGVi94ueCJNrBSTjWqjvtgvl3cyOTThp2zaMaFNGp3F542TR6sM3f2o8RqZl+AwteClSVmoCyt0ka4RjQOQg==
|
417 |
+
|
418 |
+
vary@^1.1.2:
|
419 |
+
version "1.1.2"
|
420 |
+
resolved "https://registry.yarnpkg.com/vary/-/vary-1.1.2.tgz#2299f02c6ded30d4a5961b0b9f74524a18f634fc"
|
421 |
+
integrity sha1-IpnwLG3tMNSllhsLn3RSShj2NPw=
|
422 |
+
|
423 |
+
ylru@^1.2.0:
|
424 |
+
version "1.2.1"
|
425 |
+
resolved "https://registry.yarnpkg.com/ylru/-/ylru-1.2.1.tgz#f576b63341547989c1de7ba288760923b27fe84f"
|
426 |
+
integrity sha512-faQrqNMzcPCHGVC2aaOINk13K+aaBDUPjGWl0teOXywElLjyVAB6Oe2jj62jHYtwsU49jXhScYbvPENK+6zAvQ==
|
backend/install.sh
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
sudo apt install jq -y
|
4 |
+
pip install -r requirements.txt
|
5 |
+
cd id
|
6 |
+
npm install
|
backend/launch.sh
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
pgrep -f gunicorn | xargs kill -9
|
3 |
+
kill $(lsof -t -i:3000)
|
4 |
+
|
5 |
+
cd id
|
6 |
+
npm run start &
|
7 |
+
sleep 5
|
8 |
+
cd -
|
9 |
+
|
10 |
+
if [[ "$1" == "" ]]; then
|
11 |
+
echo "JSON file argument not supplied. Exiting." 1>&2
|
12 |
+
exit 1
|
13 |
+
fi
|
14 |
+
|
15 |
+
# Number of GPUs
|
16 |
+
N_GPU=$(nvidia-smi -L | wc -l)
|
17 |
+
|
18 |
+
export FILE=$1
|
19 |
+
export GPU_PER_WORKER=`cat "$FILE" | jq -r .gpu_per_worker`
|
20 |
+
|
21 |
+
# Are there enough GPUs ?
|
22 |
+
if [[ $(($N_GPU / $GPU_PER_WORKER)) -eq 0 ]]; then
|
23 |
+
echo "Not enough GPUs to run this." 1>&2
|
24 |
+
exit 1
|
25 |
+
fi
|
26 |
+
|
27 |
+
N_WORKERS=$(($N_GPU / $GPU_PER_WORKER))
|
28 |
+
|
29 |
+
echo "File $FILE"
|
30 |
+
echo "Available GPUs $N_GPU"
|
31 |
+
echo "GPUs per worker $GPU_PER_WORKER"
|
32 |
+
echo "Total workers $N_WORKERS"
|
33 |
+
|
34 |
+
function sys_exit ()
|
35 |
+
{
|
36 |
+
echo "Ctrl-C caught...performing clean up"
|
37 |
+
echo "Cleaning up the servers."
|
38 |
+
echo $INST1
|
39 |
+
kill -9 $INST1
|
40 |
+
exit 2
|
41 |
+
|
42 |
+
}
|
43 |
+
|
44 |
+
trap "sys_exit" INT
|
45 |
+
|
46 |
+
echo "Running server with" ${N_WORKERS} "workers."
|
47 |
+
gunicorn --statsd-host=localhost:8125 -w ${N_WORKERS} API --bind=0.0.0.0:6006 --statsd-prefix=transformer-autocomplete -t 600 &
|
48 |
+
INST1=$!
|
49 |
+
|
50 |
+
while true; do sleep 1000; done
|
backend/machine_configurations/neuralgenv2.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"models_to_load": [
|
3 |
+
"gpt2/small",
|
4 |
+
"gpt2/medium",
|
5 |
+
"gpt2/large",
|
6 |
+
"gpt2/arxiv-nlp",
|
7 |
+
|
8 |
+
"gpt",
|
9 |
+
"xlnet",
|
10 |
+
"distilgpt2/small"
|
11 |
+
],
|
12 |
+
"gpu_per_worker": 1
|
13 |
+
}
|
backend/machine_configurations/transformer-autocomplete.json
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"models_to_load": [
|
3 |
+
"ctrl",
|
4 |
+
"gpt2/xl"
|
5 |
+
],
|
6 |
+
"gpu_per_worker": 2,
|
7 |
+
"cached_models": {
|
8 |
+
"gpt2/xl": "/datadrive/transformer-autocomplete/backend/gpt2-xl-local",
|
9 |
+
"ctrl": "/datadrive/transformer-autocomplete/backend/ctrl-local"
|
10 |
+
}
|
11 |
+
}
|
backend/requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
falcon
|
2 |
+
gunicorn
|
3 |
+
torch
|
4 |
+
transformers
|
backend/run_pplm_discrim_train.py
ADDED
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
# This code is licensed under a non-commercial license.
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import csv
|
8 |
+
import json
|
9 |
+
import math
|
10 |
+
import time
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
import torch.optim
|
16 |
+
import torch.optim as optim
|
17 |
+
import torch.utils.data as data
|
18 |
+
from nltk.tokenize.treebank import TreebankWordDetokenizer
|
19 |
+
from torchtext import data as torchtext_data
|
20 |
+
from torchtext import datasets
|
21 |
+
from tqdm import tqdm, trange
|
22 |
+
|
23 |
+
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
24 |
+
|
25 |
+
torch.manual_seed(0)
|
26 |
+
np.random.seed(0)
|
27 |
+
EPSILON = 1e-10
|
28 |
+
device = "cpu"
|
29 |
+
example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
|
30 |
+
max_length_seq = 100
|
31 |
+
|
32 |
+
|
33 |
+
class ClassificationHead(torch.nn.Module):
|
34 |
+
"""Classification Head for transformer encoders"""
|
35 |
+
|
36 |
+
def __init__(self, class_size, embed_size):
|
37 |
+
super(ClassificationHead, self).__init__()
|
38 |
+
self.class_size = class_size
|
39 |
+
self.embed_size = embed_size
|
40 |
+
# self.mlp1 = torch.nn.Linear(embed_size, embed_size)
|
41 |
+
# self.mlp2 = (torch.nn.Linear(embed_size, class_size))
|
42 |
+
self.mlp = torch.nn.Linear(embed_size, class_size)
|
43 |
+
|
44 |
+
def forward(self, hidden_state):
|
45 |
+
# hidden_state = F.relu(self.mlp1(hidden_state))
|
46 |
+
# hidden_state = self.mlp2(hidden_state)
|
47 |
+
logits = self.mlp(hidden_state)
|
48 |
+
return logits
|
49 |
+
|
50 |
+
|
51 |
+
class Discriminator(torch.nn.Module):
|
52 |
+
"""Transformer encoder followed by a Classification Head"""
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
class_size,
|
57 |
+
pretrained_model="gpt2-medium",
|
58 |
+
cached_mode=False
|
59 |
+
):
|
60 |
+
super(Discriminator, self).__init__()
|
61 |
+
self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
|
62 |
+
self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model)
|
63 |
+
self.embed_size = self.encoder.transformer.config.hidden_size
|
64 |
+
self.classifier_head = ClassificationHead(
|
65 |
+
class_size=class_size,
|
66 |
+
embed_size=self.embed_size
|
67 |
+
)
|
68 |
+
self.cached_mode = cached_mode
|
69 |
+
|
70 |
+
def get_classifier(self):
|
71 |
+
return self.classifier_head
|
72 |
+
|
73 |
+
def train_custom(self):
|
74 |
+
for param in self.encoder.parameters():
|
75 |
+
param.requires_grad = False
|
76 |
+
self.classifier_head.train()
|
77 |
+
|
78 |
+
def avg_representation(self, x):
|
79 |
+
mask = x.ne(0).unsqueeze(2).repeat(
|
80 |
+
1, 1, self.embed_size
|
81 |
+
).float().to(device).detach()
|
82 |
+
hidden, _ = self.encoder.transformer(x)
|
83 |
+
masked_hidden = hidden * mask
|
84 |
+
avg_hidden = torch.sum(masked_hidden, dim=1) / (
|
85 |
+
torch.sum(mask, dim=1).detach() + EPSILON
|
86 |
+
)
|
87 |
+
return avg_hidden
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
if self.cached_mode:
|
91 |
+
avg_hidden = x.to(device)
|
92 |
+
else:
|
93 |
+
avg_hidden = self.avg_representation(x.to(device))
|
94 |
+
|
95 |
+
logits = self.classifier_head(avg_hidden)
|
96 |
+
probs = F.log_softmax(logits, dim=-1)
|
97 |
+
|
98 |
+
return probs
|
99 |
+
|
100 |
+
|
101 |
+
class Dataset(data.Dataset):
|
102 |
+
def __init__(self, X, y):
|
103 |
+
"""Reads source and target sequences from txt files."""
|
104 |
+
self.X = X
|
105 |
+
self.y = y
|
106 |
+
|
107 |
+
def __len__(self):
|
108 |
+
return len(self.X)
|
109 |
+
|
110 |
+
def __getitem__(self, index):
|
111 |
+
"""Returns one data pair (source and target)."""
|
112 |
+
data = {}
|
113 |
+
data["X"] = self.X[index]
|
114 |
+
data["y"] = self.y[index]
|
115 |
+
return data
|
116 |
+
|
117 |
+
|
118 |
+
def collate_fn(data):
|
119 |
+
def pad_sequences(sequences):
|
120 |
+
lengths = [len(seq) for seq in sequences]
|
121 |
+
|
122 |
+
padded_sequences = torch.zeros(
|
123 |
+
len(sequences),
|
124 |
+
max(lengths)
|
125 |
+
).long() # padding value = 0
|
126 |
+
|
127 |
+
for i, seq in enumerate(sequences):
|
128 |
+
end = lengths[i]
|
129 |
+
padded_sequences[i, :end] = seq[:end]
|
130 |
+
|
131 |
+
return padded_sequences, lengths
|
132 |
+
|
133 |
+
item_info = {}
|
134 |
+
for key in data[0].keys():
|
135 |
+
item_info[key] = [d[key] for d in data]
|
136 |
+
|
137 |
+
x_batch, _ = pad_sequences(item_info["X"])
|
138 |
+
y_batch = torch.tensor(item_info["y"], dtype=torch.long)
|
139 |
+
|
140 |
+
return x_batch, y_batch
|
141 |
+
|
142 |
+
|
143 |
+
def cached_collate_fn(data):
|
144 |
+
item_info = {}
|
145 |
+
for key in data[0].keys():
|
146 |
+
item_info[key] = [d[key] for d in data]
|
147 |
+
|
148 |
+
x_batch = torch.cat(item_info["X"], 0)
|
149 |
+
y_batch = torch.tensor(item_info["y"], dtype=torch.long)
|
150 |
+
|
151 |
+
return x_batch, y_batch
|
152 |
+
|
153 |
+
|
154 |
+
def train_epoch(data_loader, discriminator, optimizer,
|
155 |
+
epoch=0, log_interval=10):
|
156 |
+
samples_so_far = 0
|
157 |
+
discriminator.train_custom()
|
158 |
+
for batch_idx, (input_t, target_t) in enumerate(data_loader):
|
159 |
+
input_t, target_t = input_t.to(device), target_t.to(device)
|
160 |
+
|
161 |
+
optimizer.zero_grad()
|
162 |
+
|
163 |
+
output_t = discriminator(input_t)
|
164 |
+
loss = F.nll_loss(output_t, target_t)
|
165 |
+
loss.backward(retain_graph=True)
|
166 |
+
optimizer.step()
|
167 |
+
|
168 |
+
samples_so_far += len(input_t)
|
169 |
+
|
170 |
+
if batch_idx % log_interval == 0:
|
171 |
+
print(
|
172 |
+
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
|
173 |
+
epoch + 1,
|
174 |
+
samples_so_far, len(data_loader.dataset),
|
175 |
+
100 * samples_so_far / len(data_loader.dataset), loss.item()
|
176 |
+
)
|
177 |
+
)
|
178 |
+
|
179 |
+
|
180 |
+
def evaluate_performance(data_loader, discriminator):
|
181 |
+
discriminator.eval()
|
182 |
+
test_loss = 0
|
183 |
+
correct = 0
|
184 |
+
with torch.no_grad():
|
185 |
+
for input_t, target_t in data_loader:
|
186 |
+
input_t, target_t = input_t.to(device), target_t.to(device)
|
187 |
+
output_t = discriminator(input_t)
|
188 |
+
# sum up batch loss
|
189 |
+
test_loss += F.nll_loss(output_t, target_t, reduction="sum").item()
|
190 |
+
# get the index of the max log-probability
|
191 |
+
pred_t = output_t.argmax(dim=1, keepdim=True)
|
192 |
+
correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()
|
193 |
+
|
194 |
+
test_loss /= len(data_loader.dataset)
|
195 |
+
|
196 |
+
print(
|
197 |
+
"Performance on test set: "
|
198 |
+
"Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
|
199 |
+
test_loss, correct, len(data_loader.dataset),
|
200 |
+
100. * correct / len(data_loader.dataset)
|
201 |
+
)
|
202 |
+
)
|
203 |
+
|
204 |
+
|
205 |
+
def predict(input_sentence, model, classes, cached=False):
|
206 |
+
input_t = model.tokenizer.encode(input_sentence)
|
207 |
+
input_t = torch.tensor([input_t], dtype=torch.long, device=device)
|
208 |
+
if cached:
|
209 |
+
input_t = model.avg_representation(input_t)
|
210 |
+
|
211 |
+
log_probs = model(input_t).data.cpu().numpy().flatten().tolist()
|
212 |
+
print("Input sentence:", input_sentence)
|
213 |
+
print("Predictions:", ", ".join(
|
214 |
+
"{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in
|
215 |
+
zip(classes, log_probs)
|
216 |
+
))
|
217 |
+
|
218 |
+
|
219 |
+
def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False):
|
220 |
+
data_loader = torch.utils.data.DataLoader(dataset=dataset,
|
221 |
+
batch_size=batch_size,
|
222 |
+
collate_fn=collate_fn)
|
223 |
+
|
224 |
+
xs = []
|
225 |
+
ys = []
|
226 |
+
for batch_idx, (x, y) in enumerate(tqdm(data_loader, ascii=True)):
|
227 |
+
with torch.no_grad():
|
228 |
+
x = x.to(device)
|
229 |
+
avg_rep = discriminator.avg_representation(x).cpu().detach()
|
230 |
+
avg_rep_list = torch.unbind(avg_rep.unsqueeze(1))
|
231 |
+
xs += avg_rep_list
|
232 |
+
ys += y.cpu().numpy().tolist()
|
233 |
+
|
234 |
+
data_loader = torch.utils.data.DataLoader(
|
235 |
+
dataset=Dataset(xs, ys),
|
236 |
+
batch_size=batch_size,
|
237 |
+
shuffle=shuffle,
|
238 |
+
collate_fn=cached_collate_fn)
|
239 |
+
|
240 |
+
return data_loader
|
241 |
+
|
242 |
+
|
243 |
+
def train_discriminator(
|
244 |
+
dataset, dataset_fp=None, pretrained_model="gpt2-medium",
|
245 |
+
epochs=10, batch_size=64, log_interval=10,
|
246 |
+
save_model=False, cached=False, no_cuda=False):
|
247 |
+
global device
|
248 |
+
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
|
249 |
+
|
250 |
+
print("Preprocessing {} dataset...".format(dataset))
|
251 |
+
start = time.time()
|
252 |
+
|
253 |
+
if dataset == "SST":
|
254 |
+
idx2class = ["positive", "negative", "very positive", "very negative",
|
255 |
+
"neutral"]
|
256 |
+
class2idx = {c: i for i, c in enumerate(idx2class)}
|
257 |
+
|
258 |
+
discriminator = Discriminator(
|
259 |
+
class_size=len(idx2class),
|
260 |
+
pretrained_model=pretrained_model,
|
261 |
+
cached_mode=cached
|
262 |
+
).to(device)
|
263 |
+
|
264 |
+
text = torchtext_data.Field()
|
265 |
+
label = torchtext_data.Field(sequential=False)
|
266 |
+
train_data, val_data, test_data = datasets.SST.splits(
|
267 |
+
text,
|
268 |
+
label,
|
269 |
+
fine_grained=True,
|
270 |
+
train_subtrees=True,
|
271 |
+
)
|
272 |
+
|
273 |
+
x = []
|
274 |
+
y = []
|
275 |
+
for i in trange(len(train_data), ascii=True):
|
276 |
+
seq = TreebankWordDetokenizer().detokenize(
|
277 |
+
vars(train_data[i])["text"]
|
278 |
+
)
|
279 |
+
seq = discriminator.tokenizer.encode(seq)
|
280 |
+
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
281 |
+
x.append(seq)
|
282 |
+
y.append(class2idx[vars(train_data[i])["label"]])
|
283 |
+
train_dataset = Dataset(x, y)
|
284 |
+
|
285 |
+
test_x = []
|
286 |
+
test_y = []
|
287 |
+
for i in trange(len(test_data), ascii=True):
|
288 |
+
seq = TreebankWordDetokenizer().detokenize(
|
289 |
+
vars(test_data[i])["text"]
|
290 |
+
)
|
291 |
+
seq = discriminator.tokenizer.encode(seq)
|
292 |
+
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
|
293 |
+
test_x.append(seq)
|
294 |
+
test_y.append(class2idx[vars(test_data[i])["label"]])
|
295 |
+
test_dataset = Dataset(test_x, test_y)
|
296 |
+
|
297 |
+
discriminator_meta = {
|
298 |
+
"class_size": len(idx2class),
|
299 |
+
"embed_size": discriminator.embed_size,
|
300 |
+
"pretrained_model": pretrained_model,
|
301 |
+
"class_vocab": class2idx,
|
302 |
+
"default_class": 2,
|
303 |
+
}
|
304 |
+
|
305 |
+
elif dataset == "clickbait":
|
306 |
+
idx2class = ["non_clickbait", "clickbait"]
|
307 |
+
class2idx = {c: i for i, c in enumerate(idx2class)}
|
308 |
+
|
309 |
+
discriminator = Discriminator(
|
310 |
+
class_size=len(idx2class),
|
311 |
+
pretrained_model=pretrained_model,
|
312 |
+
cached_mode=cached
|
313 |
+
).to(device)
|
314 |
+
|
315 |
+
with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
|
316 |
+
data = []
|
317 |
+
for i, line in enumerate(f):
|
318 |
+
try:
|
319 |
+
data.append(eval(line))
|
320 |
+
except:
|
321 |
+
print("Error evaluating line {}: {}".format(
|
322 |
+
i, line
|
323 |
+
))
|
324 |
+
continue
|
325 |
+
x = []
|
326 |
+
y = []
|
327 |
+
with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
|
328 |
+
for i, line in enumerate(tqdm(f, ascii=True)):
|
329 |
+
try:
|
330 |
+
d = eval(line)
|
331 |
+
seq = discriminator.tokenizer.encode(d["text"])
|
332 |
+
|
333 |
+
if len(seq) < max_length_seq:
|
334 |
+
seq = torch.tensor(
|
335 |
+
[50256] + seq, device=device, dtype=torch.long
|
336 |
+
)
|
337 |
+
else:
|
338 |
+
print("Line {} is longer than maximum length {}".format(
|
339 |
+
i, max_length_seq
|
340 |
+
))
|
341 |
+
continue
|
342 |
+
x.append(seq)
|
343 |
+
y.append(d["label"])
|
344 |
+
except:
|
345 |
+
print("Error evaluating / tokenizing"
|
346 |
+
" line {}, skipping it".format(i))
|
347 |
+
pass
|
348 |
+
|
349 |
+
full_dataset = Dataset(x, y)
|
350 |
+
train_size = int(0.9 * len(full_dataset))
|
351 |
+
test_size = len(full_dataset) - train_size
|
352 |
+
train_dataset, test_dataset = torch.utils.data.random_split(
|
353 |
+
full_dataset, [train_size, test_size]
|
354 |
+
)
|
355 |
+
|
356 |
+
discriminator_meta = {
|
357 |
+
"class_size": len(idx2class),
|
358 |
+
"embed_size": discriminator.embed_size,
|
359 |
+
"pretrained_model": pretrained_model,
|
360 |
+
"class_vocab": class2idx,
|
361 |
+
"default_class": 1,
|
362 |
+
}
|
363 |
+
|
364 |
+
elif dataset == "toxic":
|
365 |
+
idx2class = ["non_toxic", "toxic"]
|
366 |
+
class2idx = {c: i for i, c in enumerate(idx2class)}
|
367 |
+
|
368 |
+
discriminator = Discriminator(
|
369 |
+
class_size=len(idx2class),
|
370 |
+
pretrained_model=pretrained_model,
|
371 |
+
cached_mode=cached
|
372 |
+
).to(device)
|
373 |
+
|
374 |
+
x = []
|
375 |
+
y = []
|
376 |
+
with open("datasets/toxic/toxic_train.txt") as f:
|
377 |
+
for i, line in enumerate(tqdm(f, ascii=True)):
|
378 |
+
try:
|
379 |
+
d = eval(line)
|
380 |
+
seq = discriminator.tokenizer.encode(d["text"])
|
381 |
+
|
382 |
+
if len(seq) < max_length_seq:
|
383 |
+
seq = torch.tensor(
|
384 |
+
[50256] + seq, device=device, dtype=torch.long
|
385 |
+
)
|
386 |
+
else:
|
387 |
+
print("Line {} is longer than maximum length {}".format(
|
388 |
+
i, max_length_seq
|
389 |
+
))
|
390 |
+
continue
|
391 |
+
x.append(seq)
|
392 |
+
y.append(int(np.sum(d["label"]) > 0))
|
393 |
+
except:
|
394 |
+
print("Error evaluating / tokenizing"
|
395 |
+
" line {}, skipping it".format(i))
|
396 |
+
pass
|
397 |
+
|
398 |
+
full_dataset = Dataset(x, y)
|
399 |
+
train_size = int(0.9 * len(full_dataset))
|
400 |
+
test_size = len(full_dataset) - train_size
|
401 |
+
train_dataset, test_dataset = torch.utils.data.random_split(
|
402 |
+
full_dataset, [train_size, test_size]
|
403 |
+
)
|
404 |
+
|
405 |
+
discriminator_meta = {
|
406 |
+
"class_size": len(idx2class),
|
407 |
+
"embed_size": discriminator.embed_size,
|
408 |
+
"pretrained_model": pretrained_model,
|
409 |
+
"class_vocab": class2idx,
|
410 |
+
"default_class": 0,
|
411 |
+
}
|
412 |
+
|
413 |
+
else: # if dataset == "generic":
|
414 |
+
# This assumes the input dataset is a TSV with the following structure:
|
415 |
+
# class \t text
|
416 |
+
|
417 |
+
if dataset_fp is None:
|
418 |
+
raise ValueError("When generic dataset is selected, "
|
419 |
+
"dataset_fp needs to be specified aswell.")
|
420 |
+
|
421 |
+
classes = set()
|
422 |
+
with open(dataset_fp) as f:
|
423 |
+
csv_reader = csv.reader(f, delimiter="\t")
|
424 |
+
for row in tqdm(csv_reader, ascii=True):
|
425 |
+
if row:
|
426 |
+
classes.add(row[0])
|
427 |
+
|
428 |
+
idx2class = sorted(classes)
|
429 |
+
class2idx = {c: i for i, c in enumerate(idx2class)}
|
430 |
+
|
431 |
+
discriminator = Discriminator(
|
432 |
+
class_size=len(idx2class),
|
433 |
+
pretrained_model=pretrained_model,
|
434 |
+
cached_mode=cached
|
435 |
+
).to(device)
|
436 |
+
|
437 |
+
x = []
|
438 |
+
y = []
|
439 |
+
with open(dataset_fp) as f:
|
440 |
+
csv_reader = csv.reader(f, delimiter="\t")
|
441 |
+
for i, row in enumerate(tqdm(csv_reader, ascii=True)):
|
442 |
+
if row:
|
443 |
+
label = row[0]
|
444 |
+
text = row[1]
|
445 |
+
|
446 |
+
try:
|
447 |
+
seq = discriminator.tokenizer.encode(text)
|
448 |
+
if (len(seq) < max_length_seq):
|
449 |
+
seq = torch.tensor(
|
450 |
+
[50256] + seq,
|
451 |
+
device=device,
|
452 |
+
dtype=torch.long
|
453 |
+
)
|
454 |
+
|
455 |
+
else:
|
456 |
+
print(
|
457 |
+
"Line {} is longer than maximum length {}".format(
|
458 |
+
i, max_length_seq
|
459 |
+
))
|
460 |
+
continue
|
461 |
+
|
462 |
+
x.append(seq)
|
463 |
+
y.append(class2idx[label])
|
464 |
+
|
465 |
+
except:
|
466 |
+
print("Error tokenizing line {}, skipping it".format(i))
|
467 |
+
pass
|
468 |
+
|
469 |
+
full_dataset = Dataset(x, y)
|
470 |
+
train_size = int(0.9 * len(full_dataset))
|
471 |
+
test_size = len(full_dataset) - train_size
|
472 |
+
train_dataset, test_dataset = torch.utils.data.random_split(
|
473 |
+
full_dataset,
|
474 |
+
[train_size, test_size]
|
475 |
+
)
|
476 |
+
|
477 |
+
discriminator_meta = {
|
478 |
+
"class_size": len(idx2class),
|
479 |
+
"embed_size": discriminator.embed_size,
|
480 |
+
"pretrained_model": pretrained_model,
|
481 |
+
"class_vocab": class2idx,
|
482 |
+
"default_class": 0,
|
483 |
+
}
|
484 |
+
|
485 |
+
end = time.time()
|
486 |
+
print("Preprocessed {} data points".format(
|
487 |
+
len(train_dataset) + len(test_dataset))
|
488 |
+
)
|
489 |
+
print("Data preprocessing took: {:.3f}s".format(end - start))
|
490 |
+
|
491 |
+
if cached:
|
492 |
+
print("Building representation cache...")
|
493 |
+
|
494 |
+
start = time.time()
|
495 |
+
|
496 |
+
train_loader = get_cached_data_loader(
|
497 |
+
train_dataset, batch_size, discriminator, shuffle=True
|
498 |
+
)
|
499 |
+
|
500 |
+
test_loader = get_cached_data_loader(
|
501 |
+
test_dataset, batch_size, discriminator
|
502 |
+
)
|
503 |
+
|
504 |
+
end = time.time()
|
505 |
+
print("Building representation cache took: {:.3f}s".format(end - start))
|
506 |
+
|
507 |
+
else:
|
508 |
+
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
|
509 |
+
batch_size=batch_size,
|
510 |
+
shuffle=True,
|
511 |
+
collate_fn=collate_fn)
|
512 |
+
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
|
513 |
+
batch_size=batch_size,
|
514 |
+
collate_fn=collate_fn)
|
515 |
+
|
516 |
+
if save_model:
|
517 |
+
with open("{}_classifier_head_meta.json".format(dataset),
|
518 |
+
"w") as meta_file:
|
519 |
+
json.dump(discriminator_meta, meta_file)
|
520 |
+
|
521 |
+
optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)
|
522 |
+
|
523 |
+
for epoch in range(epochs):
|
524 |
+
start = time.time()
|
525 |
+
print("\nEpoch", epoch + 1)
|
526 |
+
|
527 |
+
train_epoch(
|
528 |
+
discriminator=discriminator,
|
529 |
+
data_loader=train_loader,
|
530 |
+
optimizer=optimizer,
|
531 |
+
epoch=epoch,
|
532 |
+
log_interval=log_interval
|
533 |
+
)
|
534 |
+
evaluate_performance(
|
535 |
+
data_loader=test_loader,
|
536 |
+
discriminator=discriminator
|
537 |
+
)
|
538 |
+
|
539 |
+
end = time.time()
|
540 |
+
print("Epoch took: {:.3f}s".format(end - start))
|
541 |
+
|
542 |
+
print("\nExample prediction")
|
543 |
+
predict(example_sentence, discriminator, idx2class, cached)
|
544 |
+
|
545 |
+
if save_model:
|
546 |
+
# torch.save(discriminator.state_dict(),
|
547 |
+
# "{}_discriminator_{}.pt".format(
|
548 |
+
# args.dataset, epoch + 1
|
549 |
+
# ))
|
550 |
+
torch.save(discriminator.get_classifier().state_dict(),
|
551 |
+
"{}_classifier_head_epoch_{}.pt".format(dataset,
|
552 |
+
epoch + 1))
|
553 |
+
|
554 |
+
|
555 |
+
if __name__ == "__main__":
|
556 |
+
parser = argparse.ArgumentParser(
|
557 |
+
description="Train a discriminator on top of GPT-2 representations")
|
558 |
+
parser.add_argument("--dataset", type=str, default="SST",
|
559 |
+
choices=("SST", "clickbait", "toxic", "generic"),
|
560 |
+
help="dataset to train the discriminator on."
|
561 |
+
"In case of generic, the dataset is expected"
|
562 |
+
"to be a TSBV file with structure: class \\t text")
|
563 |
+
parser.add_argument("--dataset_fp", type=str, default="",
|
564 |
+
help="File path of the dataset to use. "
|
565 |
+
"Needed only in case of generic datadset")
|
566 |
+
parser.add_argument("--pretrained_model", type=str, default="gpt2-medium",
|
567 |
+
help="Pretrained model to use as encoder")
|
568 |
+
parser.add_argument("--epochs", type=int, default=10, metavar="N",
|
569 |
+
help="Number of training epochs")
|
570 |
+
parser.add_argument("--batch_size", type=int, default=64, metavar="N",
|
571 |
+
help="input batch size for training (default: 64)")
|
572 |
+
parser.add_argument("--log_interval", type=int, default=10, metavar="N",
|
573 |
+
help="how many batches to wait before logging training status")
|
574 |
+
parser.add_argument("--save_model", action="store_true",
|
575 |
+
help="whether to save the model")
|
576 |
+
parser.add_argument("--cached", action="store_true",
|
577 |
+
help="whether to cache the input representations")
|
578 |
+
parser.add_argument("--no_cuda", action="store_true",
|
579 |
+
help="use to turn off cuda")
|
580 |
+
args = parser.parse_args()
|
581 |
+
|
582 |
+
train_discriminator(**(vars(args)))
|
entrypoint.sh
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/sh
|
2 |
+
|
3 |
+
defined_envs=$(printf '${%s} ' $(awk "END { for (name in ENVIRON) { print ( name ~ /NGINX_/ ) ? name : \"\" } }" < /dev/null ))
|
4 |
+
|
5 |
+
envsubst "$defined_envs" < nginx.conf > /etc/nginx/nginx.conf
|
6 |
+
|
7 |
+
nginx && node server/dist/server.js
|
front/.vscode/settings.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
// Configure glob patterns for excluding files and folders in searches. Inherits all glob patterns from the files.exclude setting.
|
3 |
+
"search.exclude": {
|
4 |
+
"dist": true,
|
5 |
+
"build": true,
|
6 |
+
}
|
7 |
+
}
|
front/assets/Icon-info.svg
ADDED
front/assets/Salesforce_logo.svg
ADDED
front/assets/Uber_logo.svg
ADDED
front/assets/cross-collab.svg
ADDED
front/assets/github-buttons.js
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*!
|
2 |
+
* github-buttons v2.2.10
|
3 |
+
* (c) 2019 なつき
|
4 |
+
* @license BSD-2-Clause
|
5 |
+
*/
|
6 |
+
/**
|
7 |
+
* Julien: just modified to add a `transform: scale(1.5);` on the .widget
|
8 |
+
*/
|
9 |
+
!function(){"use strict";var e=window.document,t=e.location,o=window.encodeURIComponent,r=window.decodeURIComponent,n=window.Math,a=window.HTMLElement,i=window.XMLHttpRequest,l="https://unpkg.com/[email protected]/dist/buttons.html",c=i&&i.prototype&&"withCredentials"in i.prototype,d=c&&a&&a.prototype.attachShadow&&!a.prototype.attachShadow.prototype,s=function(e,t,o){e.addEventListener?e.addEventListener(t,o):e.attachEvent("on"+t,o)},u=function(e,t,o){e.removeEventListener?e.removeEventListener(t,o):e.detachEvent("on"+t,o)},h=function(e,t,o){var r=function(n){return u(e,t,r),o(n)};s(e,t,r)},f=function(e,t,o){var r=function(n){if(t.test(e.readyState))return u(e,"readystatechange",r),o(n)};s(e,"readystatechange",r)},p=function(e){return function(t,o,r){var n=e.createElement(t);if(o)for(var a in o){var i=o[a];null!=i&&(null!=n[a]?n[a]=i:n.setAttribute(a,i))}if(r)for(var l=0,c=r.length;l<c;l++){var d=r[l];n.appendChild("string"==typeof d?e.createTextNode(d):d)}return n}},g=p(e),b=function(e){var t;return function(){t||(t=1,e.apply(this,arguments))}},m="body{margin:0}a{color:#24292e;text-decoration:none;outline:0}.octicon{display:inline-block;vertical-align:text-top;fill:currentColor}.widget{ transform: scale(1.5); display:inline-block;overflow:hidden;font-family:-apple-system, BlinkMacSystemFont, \"Segoe UI\", Helvetica, Arial, sans-serif;font-size:0;white-space:nowrap;-webkit-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none}.btn,.social-count{display:inline-block;height:14px;padding:2px 5px;font-size:11px;font-weight:600;line-height:14px;vertical-align:bottom;cursor:pointer;border:1px solid #c5c9cc;border-radius:0.25em}.btn{background-color:#eff3f6;background-image:-webkit-linear-gradient(top, #fafbfc, #eff3f6 90%);background-image:-moz-linear-gradient(top, #fafbfc, #eff3f6 90%);background-image:linear-gradient(180deg, #fafbfc, #eff3f6 90%);background-position:-1px -1px;background-repeat:repeat-x;background-size:110% 110%;border-color:rgba(27,31,35,0.2);-ms-filter:\"progid:DXImageTransform.Microsoft.Gradient(startColorstr='#FFFAFBFC', endColorstr='#FFEEF2F5')\";*filter:progid:DXImageTransform.Microsoft.Gradient(startColorstr='#FFFAFBFC', endColorstr='#FFEEF2F5')}.btn:active{background-color:#e9ecef;background-image:none;border-color:#a5a9ac;border-color:rgba(27,31,35,0.35);box-shadow:inset 0 0.15em 0.3em rgba(27,31,35,0.15)}.btn:focus,.btn:hover{background-color:#e6ebf1;background-image:-webkit-linear-gradient(top, #f0f3f6, #e6ebf1 90%);background-image:-moz-linear-gradient(top, #f0f3f6, #e6ebf1 90%);background-image:linear-gradient(180deg, #f0f3f6, #e6ebf1 90%);border-color:#a5a9ac;border-color:rgba(27,31,35,0.35);-ms-filter:\"progid:DXImageTransform.Microsoft.Gradient(startColorstr='#FFF0F3F6', endColorstr='#FFE5EAF0')\";*filter:progid:DXImageTransform.Microsoft.Gradient(startColorstr='#FFF0F3F6', endColorstr='#FFE5EAF0')}.social-count{position:relative;margin-left:5px;background-color:#fff}.social-count:focus,.social-count:hover{color:#0366d6}.social-count b,.social-count i{position:absolute;top:50%;left:0;display:block;width:0;height:0;margin:-4px 0 0 -4px;border:solid transparent;border-width:4px 4px 4px 0;_line-height:0;_border-top-color:red !important;_border-bottom-color:red !important;_border-left-color:red !important;_filter:chroma(color=red)}.social-count b{border-right-color:#c5c9cc}.social-count i{margin-left:-3px;border-right-color:#fff}.lg .btn,.lg .social-count{height:16px;padding:5px 10px;font-size:12px;line-height:16px}.lg .social-count{margin-left:6px}.lg .social-count b,.lg .social-count i{margin:-5px 0 0 -5px;border-width:5px 5px 5px 0}.lg .social-count i{margin-left:-4px}\n",v={"mark-github":{width:16,height:16,path:'<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"/>'},eye:{width:16,height:16,path:'<path fill-rule="evenodd" d="M8.06 2C3 2 0 8 0 8s3 6 8.06 6C13 14 16 8 16 8s-3-6-7.94-6zM8 12c-2.2 0-4-1.78-4-4 0-2.2 1.8-4 4-4 2.22 0 4 1.8 4 4 0 2.22-1.78 4-4 4zm2-4c0 1.11-.89 2-2 2-1.11 0-2-.89-2-2 0-1.11.89-2 2-2 1.11 0 2 .89 2 2z"/>'},star:{width:14,height:16,path:'<path fill-rule="evenodd" d="M14 6l-4.9-.64L7 1 4.9 5.36 0 6l3.6 3.26L2.67 14 7 11.67 11.33 14l-.93-4.74L14 6z"/>'},"repo-forked":{width:10,height:16,path:'<path fill-rule="evenodd" d="M8 1a1.993 1.993 0 0 0-1 3.72V6L5 8 3 6V4.72A1.993 1.993 0 0 0 2 1a1.993 1.993 0 0 0-1 3.72V6.5l3 3v1.78A1.993 1.993 0 0 0 5 15a1.993 1.993 0 0 0 1-3.72V9.5l3-3V4.72A1.993 1.993 0 0 0 8 1zM2 4.2C1.34 4.2.8 3.65.8 3c0-.65.55-1.2 1.2-1.2.65 0 1.2.55 1.2 1.2 0 .65-.55 1.2-1.2 1.2zm3 10c-.66 0-1.2-.55-1.2-1.2 0-.65.55-1.2 1.2-1.2.65 0 1.2.55 1.2 1.2 0 .65-.55 1.2-1.2 1.2zm3-10c-.66 0-1.2-.55-1.2-1.2 0-.65.55-1.2 1.2-1.2.65 0 1.2.55 1.2 1.2 0 .65-.55 1.2-1.2 1.2z"/>'},"issue-opened":{width:14,height:16,path:'<path fill-rule="evenodd" d="M7 2.3c3.14 0 5.7 2.56 5.7 5.7s-2.56 5.7-5.7 5.7A5.71 5.71 0 0 1 1.3 8c0-3.14 2.56-5.7 5.7-5.7zM7 1C3.14 1 0 4.14 0 8s3.14 7 7 7 7-3.14 7-7-3.14-7-7-7zm1 3H6v5h2V4zm0 6H6v2h2v-2z"/>'},"cloud-download":{width:16,height:16,path:'<path fill-rule="evenodd" d="M9 12h2l-3 3-3-3h2V7h2v5zm3-8c0-.44-.91-3-4.5-3C5.08 1 3 2.92 3 5 1.02 5 0 6.52 0 8c0 1.53 1 3 3 3h3V9.7H3C1.38 9.7 1.3 8.28 1.3 8c0-.17.05-1.7 1.7-1.7h1.3V5c0-1.39 1.56-2.7 3.2-2.7 2.55 0 3.13 1.55 3.2 1.8v1.2H12c.81 0 2.7.22 2.7 2.2 0 2.09-2.25 2.2-2.7 2.2h-2V11h2c2.08 0 4-1.16 4-3.5C16 5.06 14.08 4 12 4z"/>'}},w={},x=function(e,t,o){var r=p(e.ownerDocument),n=e.appendChild(r("style",{type:"text/css"}));n.styleSheet?n.styleSheet.cssText=m:n.appendChild(e.ownerDocument.createTextNode(m));var a,l,d=r("a",{className:"btn",href:t.href,target:"_blank",innerHTML:(a=t["data-icon"],l=/^large$/i.test(t["data-size"])?16:14,a=(""+a).toLowerCase().replace(/^octicon-/,""),{}.hasOwnProperty.call(v,a)||(a="mark-github"),'<svg version="1.1" width="'+l*v[a].width/v[a].height+'" height="'+l+'" viewBox="0 0 '+v[a].width+" "+v[a].height+'" class="octicon octicon-'+a+'" aria-hidden="true">'+v[a].path+"</svg>"),"aria-label":t["aria-label"]||void 0},[" ",r("span",{},[t["data-text"]||""])]);/\.github\.com$/.test("."+d.hostname)?/^https?:\/\/((gist\.)?github\.com\/[^\/?#]+\/[^\/?#]+\/archive\/|github\.com\/[^\/?#]+\/[^\/?#]+\/releases\/download\/|codeload\.github\.com\/)/.test(d.href)&&(d.target="_top"):(d.href="#",d.target="_self");var u,h,g,x,y=e.appendChild(r("div",{className:"widget"+(/^large$/i.test(t["data-size"])?" lg":"")},[d]));/^(true|1)$/i.test(t["data-show-count"])&&"github.com"===d.hostname&&(u=d.pathname.replace(/^(?!\/)/,"/").match(/^\/([^\/?#]+)(?:\/([^\/?#]+)(?:\/(?:(subscription)|(fork)|(issues)|([^\/?#]+)))?)?(?:[\/?#]|$)/))&&!u[6]?(u[2]?(h="/repos/"+u[1]+"/"+u[2],u[3]?(x="subscribers_count",g="watchers"):u[4]?(x="forks_count",g="network"):u[5]?(x="open_issues_count",g="issues"):(x="stargazers_count",g="stargazers")):(h="/users/"+u[1],g=x="followers"),function(e,t){var o=w[e]||(w[e]=[]);if(!(o.push(t)>1)){var r=b(function(){for(delete w[e];t=o.shift();)t.apply(null,arguments)});if(c){var n=new i;s(n,"abort",r),s(n,"error",r),s(n,"load",function(){var e;try{e=JSON.parse(n.responseText)}catch(e){return void r(e)}r(200!==n.status,e)}),n.open("GET",e),n.send()}else{var a=this||window;a._=function(e){a._=null,r(200!==e.meta.status,e.data)};var l=p(a.document)("script",{async:!0,src:e+(/\?/.test(e)?"&":"?")+"callback=_"}),d=function(){a._&&a._({meta:{}})};s(l,"load",d),s(l,"error",d),l.readyState&&f(l,/de|m/,d),a.document.getElementsByTagName("head")[0].appendChild(l)}}}.call(this,"https://api.github.com"+h,function(e,t){if(!e){var n=t[x];y.appendChild(r("a",{className:"social-count",href:t.html_url+"/"+g,target:"_blank","aria-label":n+" "+x.replace(/_count$/,"").replace("_"," ").slice(0,n<2?-1:void 0)+" on GitHub"},[r("b"),r("i"),r("span",{},[(""+n).replace(/\B(?=(\d{3})+(?!\d))/g,",")])]))}o&&o(y)})):o&&o(y)},y=window.devicePixelRatio||1,C=function(e){return(y>1?n.ceil(n.round(e*y)/y*2)/2:n.ceil(e))||0},F=function(e,t){e.style.width=t[0]+"px",e.style.height=t[1]+"px"},k=function(t,r){if(null!=t&&null!=r)if(t.getAttribute&&(t=function(e){for(var t={href:e.href,title:e.title,"aria-label":e.getAttribute("aria-label")},o=["icon","text","size","show-count"],r=0,n=o.length;r<n;r++){var a="data-"+o[r];t[a]=e.getAttribute(a)}return null==t["data-text"]&&(t["data-text"]=e.textContent||e.innerText),t}(t)),d){var a=g("span",{title:t.title||void 0});x(a.attachShadow({mode:"closed"}),t,function(){r(a)})}else{var i=g("iframe",{src:"javascript:0",title:t.title||void 0,allowtransparency:!0,scrolling:"no",frameBorder:0});F(i,[0,0]),i.style.border="none";var c=function(){var a,d=i.contentWindow;try{a=d.document.body}catch(t){return void e.body.appendChild(i.parentNode.removeChild(i))}u(i,"load",c),x.call(d,a,t,function(e){var a=function(e){var t=e.offsetWidth,o=e.offsetHeight;if(e.getBoundingClientRect){var r=e.getBoundingClientRect();t=n.max(t,C(r.width)),o=n.max(o,C(r.height))}return[t,o]}(e);i.parentNode.removeChild(i),h(i,"load",function(){F(i,a)}),i.src=l+"#"+(i.name=function(e){var t=[];for(var r in e){var n=e[r];null!=n&&t.push(o(r)+"="+o(n))}return t.join("&")}(t)),r(i)})};s(i,"load",c),e.body.appendChild(i)}};t.protocol+"//"+t.host+t.pathname===l?x(e.body,function(e){for(var t={},o=e.split("&"),n=0,a=o.length;n<a;n++){var i=o[n];if(""!==i){var l=i.split("=");t[r(l[0])]=null!=l[1]?r(l.slice(1).join("=")):void 0}}return t}(window.name||t.hash.replace(/^#/,""))):function(t){if(/m/.test(e.readyState)||!/g/.test(e.readyState)&&!e.documentElement.doScroll)setTimeout(t);else if(e.addEventListener){var o=b(t);h(e,"DOMContentLoaded",o),h(window,"load",o)}else f(e,/m/,t)}(function(){for(var t=e.querySelectorAll?e.querySelectorAll("a.github-button"):function(){for(var t=[],o=e.getElementsByTagName("a"),r=0,n=o.length;r<n;r++)~(" "+o[r].className+" ").replace(/[ \t\n\f\r]+/g," ").indexOf(" github-button ")&&t.push(o[r]);return t}(),o=0,r=t.length;o<r;o++)!function(e){k(e,function(t){e.parentNode.replaceChild(t,e)})}(t[o])})}();
|
front/assets/huggingface_logo.svg
ADDED
front/assets/icon-back.svg
ADDED
front/assets/icon-publish.svg
ADDED
front/assets/iconmonstr-download-14.svg
ADDED
front/assets/iconmonstr-media-control-55.svg
ADDED
front/assets/iconmonstr-share-11-purple.svg
ADDED
front/assets/iconmonstr-share-11.svg
ADDED
front/assets/oval.svg
ADDED
front/assets/tail-spin.svg
ADDED
front/assets/thumbnail-large-distilgpt2.png
ADDED
front/assets/thumbnail-large-pplm.png
ADDED
front/assets/thumbnail-large.png
ADDED
front/assets/unicorn-tweaked.svg
ADDED
front/favicon.ico
ADDED
front/js-src/Api.ts
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { c } from './lib/Log';
|
2 |
+
|
3 |
+
|
4 |
+
interface AutocompleteOutput {
|
5 |
+
sentences: {
|
6 |
+
value: string;
|
7 |
+
time: number;
|
8 |
+
}[];
|
9 |
+
time: number;
|
10 |
+
}
|
11 |
+
|
12 |
+
export class Api {
|
13 |
+
|
14 |
+
private static ENDPOINT =
|
15 |
+
// `http://coconut-proxy.huggingface.test`
|
16 |
+
// `http://coconuthf.eastus.cloudapp.azure.com:6006`
|
17 |
+
// "http://localhost:6006"
|
18 |
+
`https://transformer.huggingface.co`
|
19 |
+
;
|
20 |
+
static shared = new Api();
|
21 |
+
|
22 |
+
private path(p: string): string {
|
23 |
+
return `${Api.ENDPOINT}/${p}`;
|
24 |
+
}
|
25 |
+
|
26 |
+
private async postAutocomplete(
|
27 |
+
params: {
|
28 |
+
context: string;
|
29 |
+
model_size?: string; /// 'small' | 'medium',
|
30 |
+
top_p?: number; /// float between 0 and 1
|
31 |
+
temperature?: number; /// float between 0 and 100
|
32 |
+
step_size?: number;
|
33 |
+
kl_scale?: number;
|
34 |
+
gm_scale?: number;
|
35 |
+
num_iterations?: number;
|
36 |
+
gen_length?: number;
|
37 |
+
max_time?: number; /// <- if we want to limit the response time. (in sec)
|
38 |
+
bow_or_discrim?: string;
|
39 |
+
use_sampling?: boolean;
|
40 |
+
}
|
41 |
+
): Promise<AutocompleteOutput> {
|
42 |
+
|
43 |
+
const path = this.path(`autocomplete/${params.model_size || ""}`);
|
44 |
+
|
45 |
+
const response = await fetch(path, {
|
46 |
+
method: 'POST',
|
47 |
+
headers: { 'Content-Type': 'application/json' },
|
48 |
+
body: JSON.stringify(params),
|
49 |
+
});
|
50 |
+
return await response.json() as AutocompleteOutput;
|
51 |
+
}
|
52 |
+
|
53 |
+
/**
|
54 |
+
* Demo-specific helpers
|
55 |
+
*/
|
56 |
+
async postWithSettings(
|
57 |
+
params: {
|
58 |
+
context: string;
|
59 |
+
}
|
60 |
+
): Promise<AutocompleteOutput> {
|
61 |
+
/// Retrieve all settings params then launch the request.
|
62 |
+
const model_size =
|
63 |
+
document.querySelector('.decoder-settings .setting.model_size .js-val')!.textContent
|
64 |
+
|| undefined;
|
65 |
+
|
66 |
+
const parseSliderVal = (sel: string): number | undefined => {
|
67 |
+
const x = document.querySelector(sel);
|
68 |
+
if (x && x.textContent) {
|
69 |
+
return Number(x.textContent);
|
70 |
+
}
|
71 |
+
return undefined;
|
72 |
+
};
|
73 |
+
|
74 |
+
const top_p = parseSliderVal('.decoder-settings .setting.top_p .js-val');
|
75 |
+
const temperature = parseSliderVal('.decoder-settings .setting.temperature .js-val');
|
76 |
+
const step_size = parseSliderVal('.decoder-settings .setting.step_size .js-val');
|
77 |
+
const kl_scale = parseSliderVal('.decoder-settings .setting.kl_scale .js-val');
|
78 |
+
const gm_scale = parseSliderVal('.decoder-settings .setting.gm_scale .js-val');
|
79 |
+
const num_iterations = parseSliderVal('.decoder-settings .setting.num_iterations .js-val');
|
80 |
+
const gen_length = parseSliderVal('.decoder-settings .setting.gen_length .js-val');
|
81 |
+
const max_time = parseSliderVal('.decoder-settings .setting.max_time .js-val');
|
82 |
+
|
83 |
+
const bow_or_discrim = (
|
84 |
+
document.querySelector<HTMLInputElement>('.decoder-settings input[name=bow_or_discrim]:checked') || {}
|
85 |
+
).value;
|
86 |
+
const use_sampling = (
|
87 |
+
document.querySelector<HTMLInputElement>('.decoder-settings input[name=use_sampling]') || {}
|
88 |
+
).checked;
|
89 |
+
|
90 |
+
return this.postAutocomplete({
|
91 |
+
...params,
|
92 |
+
model_size,
|
93 |
+
top_p,
|
94 |
+
temperature,
|
95 |
+
step_size,
|
96 |
+
kl_scale,
|
97 |
+
gm_scale,
|
98 |
+
num_iterations,
|
99 |
+
gen_length,
|
100 |
+
max_time,
|
101 |
+
bow_or_discrim,
|
102 |
+
use_sampling,
|
103 |
+
});
|
104 |
+
}
|
105 |
+
|
106 |
+
/**
|
107 |
+
* Edit AJAX endpoint
|
108 |
+
*
|
109 |
+
* Contrary to the autocomplete endpoint,
|
110 |
+
* this is on server,
|
111 |
+
* not on backend.
|
112 |
+
*/
|
113 |
+
async postEdit(body: any): Promise<boolean> {
|
114 |
+
const doc = (<any>window).doc as { [index: string]: string };
|
115 |
+
if (!doc || !doc.longId) {
|
116 |
+
throw new Error(`invalid doc`);
|
117 |
+
}
|
118 |
+
|
119 |
+
const path = `/edit/${doc.model}/${doc.longId}/${doc.shortId}`;
|
120 |
+
|
121 |
+
const response = await fetch(path, {
|
122 |
+
method: 'POST',
|
123 |
+
headers: { 'Content-Type': 'application/json' },
|
124 |
+
body: JSON.stringify(body),
|
125 |
+
});
|
126 |
+
return response.ok;
|
127 |
+
}
|
128 |
+
|
129 |
+
/**
|
130 |
+
* Duplicate AJAX endpoint
|
131 |
+
*
|
132 |
+
* Contrary to the autocomplete endpoint,
|
133 |
+
* this is on server,
|
134 |
+
* not on backend.
|
135 |
+
*/
|
136 |
+
async postDuplicate(): Promise<string> {
|
137 |
+
const doc = (<any>window).doc as { [index: string]: string };
|
138 |
+
if (!doc || !doc.shortId) {
|
139 |
+
throw new Error(`invalid doc`);
|
140 |
+
}
|
141 |
+
|
142 |
+
const path = `/duplicate/${doc.shortId}`;
|
143 |
+
const response = await fetch(path, {
|
144 |
+
method: 'POST',
|
145 |
+
headers: { 'Content-Type': 'application/json' },
|
146 |
+
});
|
147 |
+
const url = await response.text();
|
148 |
+
c.log('[new url]', url);
|
149 |
+
|
150 |
+
return url;
|
151 |
+
}
|
152 |
+
}
|
153 |
+
|
front/js-src/Mention.ts
ADDED
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
interface Datum {
|
3 |
+
id: string;
|
4 |
+
value: string;
|
5 |
+
}
|
6 |
+
|
7 |
+
|
8 |
+
export class Mention {
|
9 |
+
static Keys = {
|
10 |
+
TAB: 9,
|
11 |
+
ENTER: 13,
|
12 |
+
ESCAPE: 27,
|
13 |
+
UP: 38,
|
14 |
+
DOWN: 40,
|
15 |
+
};
|
16 |
+
static numberIsNaN = (x: any) => x !== x;
|
17 |
+
private isOpen = false;
|
18 |
+
/**
|
19 |
+
* index of currently selected item.
|
20 |
+
*/
|
21 |
+
private itemIndex = 0;
|
22 |
+
private mentionCharPos: number | undefined = undefined;
|
23 |
+
private cursorPos: number | undefined = undefined;
|
24 |
+
private values = [] as Datum[];
|
25 |
+
private suspendMouseEnter = false;
|
26 |
+
private options = {
|
27 |
+
source: (searchTerm: string, renderList: Function, mentionChar: string) => {},
|
28 |
+
renderItem: (item: Datum, searchTerm: string) => {
|
29 |
+
return `${item.value}`;
|
30 |
+
},
|
31 |
+
onSelect: (item: DOMStringMap, insertItem: (item: DOMStringMap) => void) => {
|
32 |
+
insertItem(item);
|
33 |
+
},
|
34 |
+
mentionDenotationChars: ['@'],
|
35 |
+
showDenotationChar: true,
|
36 |
+
allowedChars: /^[a-zA-Z0-9_]*$/,
|
37 |
+
minChars: 0,
|
38 |
+
maxChars: 31,
|
39 |
+
offsetTop: 2,
|
40 |
+
offsetLeft: 0,
|
41 |
+
/**
|
42 |
+
* Whether or not the denotation character(s) should be isolated. For example, to avoid mentioning in an email.
|
43 |
+
*/
|
44 |
+
isolateCharacter: false,
|
45 |
+
fixMentionsToQuill: false,
|
46 |
+
defaultMenuOrientation: 'bottom',
|
47 |
+
dataAttributes: ['id', 'value', 'denotationChar', 'link', 'target'],
|
48 |
+
linkTarget: '_blank',
|
49 |
+
onOpen: () => true,
|
50 |
+
onClose: () => true,
|
51 |
+
// Style options
|
52 |
+
listItemClass: 'ql-mention-list-item',
|
53 |
+
mentionContainerClass: 'ql-mention-list-container',
|
54 |
+
mentionListClass: 'ql-mention-list',
|
55 |
+
};
|
56 |
+
/// HTML elements
|
57 |
+
private mentionContainer = document.createElement('div');
|
58 |
+
private mentionList = document.createElement('ul');
|
59 |
+
|
60 |
+
|
61 |
+
constructor(
|
62 |
+
private quill: Quill,
|
63 |
+
) {
|
64 |
+
this.mentionContainer.className = this.options.mentionContainerClass;
|
65 |
+
this.mentionContainer.style.cssText = 'display: none; position: absolute;';
|
66 |
+
this.mentionContainer.onmousemove = this.onContainerMouseMove.bind(this);
|
67 |
+
|
68 |
+
if (this.options.fixMentionsToQuill) {
|
69 |
+
this.mentionContainer.style.width = 'auto';
|
70 |
+
}
|
71 |
+
|
72 |
+
this.mentionList.className = this.options.mentionListClass;
|
73 |
+
this.mentionContainer.appendChild(this.mentionList);
|
74 |
+
|
75 |
+
this.quill.container.appendChild(this.mentionContainer);
|
76 |
+
|
77 |
+
quill.on('text-change', this.onTextChange.bind(this));
|
78 |
+
quill.on('selection-change', this.onSelectionChange.bind(this));
|
79 |
+
|
80 |
+
quill.keyboard.addBinding({
|
81 |
+
key: Mention.Keys.ENTER,
|
82 |
+
}, this.selectHandler.bind(this));
|
83 |
+
quill.keyboard.bindings[Mention.Keys.ENTER].unshift(
|
84 |
+
quill.keyboard.bindings[Mention.Keys.ENTER].pop()
|
85 |
+
);
|
86 |
+
/// ^^ place it at beginning of bindings.
|
87 |
+
|
88 |
+
quill.keyboard.addBinding({
|
89 |
+
key: Mention.Keys.ESCAPE,
|
90 |
+
}, this.escapeHandler.bind(this));
|
91 |
+
|
92 |
+
quill.keyboard.addBinding({
|
93 |
+
key: Mention.Keys.UP,
|
94 |
+
}, this.upHandler.bind(this));
|
95 |
+
|
96 |
+
quill.keyboard.addBinding({
|
97 |
+
key: Mention.Keys.DOWN,
|
98 |
+
}, this.downHandler.bind(this));
|
99 |
+
|
100 |
+
document.addEventListener("keypress", e => {
|
101 |
+
/// Quick’n’dirty hack.
|
102 |
+
if (! this.quill.hasFocus()) {
|
103 |
+
return ;
|
104 |
+
}
|
105 |
+
setTimeout(() => {
|
106 |
+
this.setCursorPos();
|
107 |
+
this.quill.removeFormat(this.cursorPos! - 1, 1, 'silent');
|
108 |
+
}, 0);
|
109 |
+
});
|
110 |
+
}
|
111 |
+
|
112 |
+
selectHandler() {
|
113 |
+
if (this.isOpen) {
|
114 |
+
this.selectItem();
|
115 |
+
return false;
|
116 |
+
}
|
117 |
+
return true;
|
118 |
+
}
|
119 |
+
|
120 |
+
escapeHandler() {
|
121 |
+
if (this.isOpen) {
|
122 |
+
this.hideMentionList();
|
123 |
+
return false;
|
124 |
+
}
|
125 |
+
return true;
|
126 |
+
}
|
127 |
+
|
128 |
+
upHandler() {
|
129 |
+
if (this.isOpen) {
|
130 |
+
this.prevItem();
|
131 |
+
return false;
|
132 |
+
}
|
133 |
+
return true;
|
134 |
+
}
|
135 |
+
|
136 |
+
downHandler() {
|
137 |
+
if (this.isOpen) {
|
138 |
+
this.nextItem();
|
139 |
+
return false;
|
140 |
+
}
|
141 |
+
return true;
|
142 |
+
}
|
143 |
+
|
144 |
+
showMentionList() {
|
145 |
+
this.mentionContainer.style.visibility = 'hidden';
|
146 |
+
this.mentionContainer.style.display = '';
|
147 |
+
this.setMentionContainerPosition();
|
148 |
+
this.setIsOpen(true);
|
149 |
+
}
|
150 |
+
|
151 |
+
hideMentionList() {
|
152 |
+
this.mentionContainer.style.display = 'none';
|
153 |
+
this.setIsOpen(false);
|
154 |
+
}
|
155 |
+
|
156 |
+
|
157 |
+
private highlightItem(scrollItemInView = true) {
|
158 |
+
const childNodes = Array.from(this.mentionList.childNodes) as HTMLLIElement[];
|
159 |
+
for (const node of childNodes) {
|
160 |
+
node.classList.remove('selected');
|
161 |
+
}
|
162 |
+
childNodes[this.itemIndex].classList.add('selected');
|
163 |
+
|
164 |
+
if (scrollItemInView) {
|
165 |
+
const itemHeight = childNodes[this.itemIndex].offsetHeight;
|
166 |
+
const itemPos = this.itemIndex * itemHeight;
|
167 |
+
const containerTop = this.mentionContainer.scrollTop;
|
168 |
+
const containerBottom = containerTop + this.mentionContainer.offsetHeight;
|
169 |
+
|
170 |
+
if (itemPos < containerTop) {
|
171 |
+
// Scroll up if the item is above the top of the container
|
172 |
+
this.mentionContainer.scrollTop = itemPos;
|
173 |
+
} else if (itemPos > (containerBottom - itemHeight)) {
|
174 |
+
// scroll down if any part of the element is below the bottom of the container
|
175 |
+
this.mentionContainer.scrollTop += (itemPos - containerBottom) + itemHeight;
|
176 |
+
}
|
177 |
+
}
|
178 |
+
}
|
179 |
+
|
180 |
+
private getItemData(): DOMStringMap {
|
181 |
+
const node = this.mentionList.childNodes[this.itemIndex] as HTMLElement;
|
182 |
+
const { link } = node.dataset;
|
183 |
+
const itemTarget = node.dataset.target;
|
184 |
+
if (link !== undefined) {
|
185 |
+
node.dataset.value = `<a href="${link}" target=${itemTarget || this.options.linkTarget}>${node.dataset.value}`;
|
186 |
+
}
|
187 |
+
return node.dataset;
|
188 |
+
}
|
189 |
+
|
190 |
+
onContainerMouseMove() {
|
191 |
+
this.suspendMouseEnter = false;
|
192 |
+
}
|
193 |
+
|
194 |
+
selectItem() {
|
195 |
+
const data = this.getItemData();
|
196 |
+
this.options.onSelect(data, (asyncData) => {
|
197 |
+
this.insertItem(asyncData);
|
198 |
+
});
|
199 |
+
this.hideMentionList();
|
200 |
+
}
|
201 |
+
|
202 |
+
insertItem(data: DOMStringMap) {
|
203 |
+
const render = data;
|
204 |
+
if (render === null) {
|
205 |
+
return ;
|
206 |
+
}
|
207 |
+
if (!this.options.showDenotationChar) {
|
208 |
+
render.denotationChar = '';
|
209 |
+
}
|
210 |
+
if (this.cursorPos === undefined) {
|
211 |
+
throw new Error(`Invalid this.cursorPos`);
|
212 |
+
}
|
213 |
+
if (!render.value) {
|
214 |
+
throw new Error(`Didn't receive value from server.`);
|
215 |
+
}
|
216 |
+
|
217 |
+
this.quill.insertText(this.cursorPos, render.value, 'bold', Quill.sources.USER);
|
218 |
+
this.quill.setSelection(this.cursorPos + render.value.length, 0);
|
219 |
+
this.setCursorPos();
|
220 |
+
this.hideMentionList();
|
221 |
+
}
|
222 |
+
|
223 |
+
onItemMouseEnter(e: MouseEvent) {
|
224 |
+
if (this.suspendMouseEnter) {
|
225 |
+
return ;
|
226 |
+
}
|
227 |
+
const index = Number(
|
228 |
+
(e.target as HTMLLIElement).dataset.index
|
229 |
+
);
|
230 |
+
if (! Mention.numberIsNaN(index) && index !== this.itemIndex) {
|
231 |
+
this.itemIndex = index;
|
232 |
+
this.highlightItem(false);
|
233 |
+
}
|
234 |
+
}
|
235 |
+
|
236 |
+
onItemClick(e: MouseEvent) {
|
237 |
+
e.stopImmediatePropagation();
|
238 |
+
e.preventDefault();
|
239 |
+
this.itemIndex = Number(
|
240 |
+
(e.currentTarget as HTMLElement).dataset.index
|
241 |
+
);
|
242 |
+
this.highlightItem();
|
243 |
+
this.selectItem();
|
244 |
+
}
|
245 |
+
|
246 |
+
private attachDataValues(element: HTMLLIElement, data: Datum): HTMLLIElement {
|
247 |
+
for (const [key, value] of Object.entries(data)) {
|
248 |
+
if (this.options.dataAttributes.includes(key)) {
|
249 |
+
element.dataset[key] = value;
|
250 |
+
} else {
|
251 |
+
delete element.dataset[key];
|
252 |
+
}
|
253 |
+
}
|
254 |
+
return element;
|
255 |
+
}
|
256 |
+
|
257 |
+
renderList(mentionChar: string, data: Datum[], searchTerm: string = "") {
|
258 |
+
if (data.length > 0) {
|
259 |
+
this.values = data;
|
260 |
+
this.mentionList.innerHTML = '';
|
261 |
+
|
262 |
+
for (const [i, datum] of data.entries()) {
|
263 |
+
const li = document.createElement('li');
|
264 |
+
li.className = this.options.listItemClass;
|
265 |
+
li.dataset.index = `${i}`;
|
266 |
+
// li.innerHTML = this.options.renderItem(datum, searchTerm);
|
267 |
+
li.innerText = datum.value.replace(/\n/g, "↵");
|
268 |
+
/// ^^
|
269 |
+
li.onmouseenter = this.onItemMouseEnter.bind(this);
|
270 |
+
li.dataset.denotationChar = mentionChar;
|
271 |
+
li.onclick = this.onItemClick.bind(this);
|
272 |
+
this.mentionList.appendChild(
|
273 |
+
this.attachDataValues(li, datum)
|
274 |
+
);
|
275 |
+
}
|
276 |
+
this.itemIndex = 0;
|
277 |
+
this.highlightItem();
|
278 |
+
this.showMentionList();
|
279 |
+
} else {
|
280 |
+
this.hideMentionList();
|
281 |
+
}
|
282 |
+
}
|
283 |
+
|
284 |
+
nextItem() {
|
285 |
+
this.itemIndex = (this.itemIndex + 1) % this.values.length;
|
286 |
+
this.suspendMouseEnter = true;
|
287 |
+
this.highlightItem();
|
288 |
+
}
|
289 |
+
|
290 |
+
prevItem() {
|
291 |
+
this.itemIndex = ((this.itemIndex + this.values.length) - 1) % this.values.length;
|
292 |
+
this.suspendMouseEnter = true;
|
293 |
+
this.highlightItem();
|
294 |
+
}
|
295 |
+
|
296 |
+
private hasValidChars(s: string) {
|
297 |
+
return this.options.allowedChars.test(s);
|
298 |
+
}
|
299 |
+
|
300 |
+
private containerBottomIsNotVisible(topPos: number, containerPos: ClientRect | DOMRect) {
|
301 |
+
const mentionContainerBottom = topPos + this.mentionContainer.offsetHeight + containerPos.top;
|
302 |
+
return mentionContainerBottom > window.pageYOffset + window.innerHeight;
|
303 |
+
}
|
304 |
+
|
305 |
+
private containerRightIsNotVisible(leftPos: number, containerPos: ClientRect | DOMRect) {
|
306 |
+
if (this.options.fixMentionsToQuill) {
|
307 |
+
return false;
|
308 |
+
}
|
309 |
+
const rightPos = leftPos + this.mentionContainer.offsetWidth + containerPos.left;
|
310 |
+
const browserWidth = window.pageXOffset + document.documentElement.clientWidth;
|
311 |
+
return rightPos > browserWidth;
|
312 |
+
}
|
313 |
+
|
314 |
+
private setIsOpen(isOpen: boolean) {
|
315 |
+
if (this.isOpen !== isOpen) {
|
316 |
+
if (isOpen) {
|
317 |
+
this.options.onOpen();
|
318 |
+
} else {
|
319 |
+
this.options.onClose();
|
320 |
+
}
|
321 |
+
this.isOpen = isOpen;
|
322 |
+
}
|
323 |
+
}
|
324 |
+
|
325 |
+
private setMentionContainerPosition() {
|
326 |
+
const containerPos = this.quill.container.getBoundingClientRect();
|
327 |
+
/// vv Here we always trigger from the cursor.
|
328 |
+
if (this.cursorPos === undefined) {
|
329 |
+
throw new Error(`Invalid this.cursorPos`);
|
330 |
+
}
|
331 |
+
const mentionCharPos = this.quill.getBounds(this.cursorPos);
|
332 |
+
const containerHeight = this.mentionContainer.offsetHeight;
|
333 |
+
|
334 |
+
let topPos = this.options.offsetTop;
|
335 |
+
let leftPos = this.options.offsetLeft;
|
336 |
+
|
337 |
+
// handle horizontal positioning
|
338 |
+
if (this.options.fixMentionsToQuill) {
|
339 |
+
const rightPos = 0;
|
340 |
+
this.mentionContainer.style.right = `${rightPos}px`;
|
341 |
+
} else {
|
342 |
+
leftPos += mentionCharPos.left;
|
343 |
+
}
|
344 |
+
|
345 |
+
if (this.containerRightIsNotVisible(leftPos, containerPos)) {
|
346 |
+
const containerWidth = this.mentionContainer.offsetWidth + this.options.offsetLeft;
|
347 |
+
const quillWidth = containerPos.width;
|
348 |
+
leftPos = quillWidth - containerWidth;
|
349 |
+
}
|
350 |
+
|
351 |
+
// handle vertical positioning
|
352 |
+
if (this.options.defaultMenuOrientation === 'top') {
|
353 |
+
// Attempt to align the mention container with the top of the quill editor
|
354 |
+
if (this.options.fixMentionsToQuill) {
|
355 |
+
topPos = -1 * (containerHeight + this.options.offsetTop);
|
356 |
+
} else {
|
357 |
+
topPos = mentionCharPos.top - (containerHeight + this.options.offsetTop);
|
358 |
+
}
|
359 |
+
|
360 |
+
// default to bottom if the top is not visible
|
361 |
+
if (topPos + containerPos.top <= 0) {
|
362 |
+
let overMentionCharPos = this.options.offsetTop;
|
363 |
+
|
364 |
+
if (this.options.fixMentionsToQuill) {
|
365 |
+
overMentionCharPos += containerPos.height;
|
366 |
+
} else {
|
367 |
+
overMentionCharPos += mentionCharPos.bottom;
|
368 |
+
}
|
369 |
+
|
370 |
+
topPos = overMentionCharPos;
|
371 |
+
}
|
372 |
+
} else {
|
373 |
+
// Attempt to align the mention container with the bottom of the quill editor
|
374 |
+
if (this.options.fixMentionsToQuill) {
|
375 |
+
topPos += containerPos.height;
|
376 |
+
} else {
|
377 |
+
topPos += mentionCharPos.bottom;
|
378 |
+
}
|
379 |
+
|
380 |
+
// default to the top if the bottom is not visible
|
381 |
+
if (this.containerBottomIsNotVisible(topPos, containerPos)) {
|
382 |
+
let overMentionCharPos = this.options.offsetTop * -1;
|
383 |
+
|
384 |
+
if (!this.options.fixMentionsToQuill) {
|
385 |
+
overMentionCharPos += mentionCharPos.top;
|
386 |
+
}
|
387 |
+
|
388 |
+
topPos = overMentionCharPos - containerHeight;
|
389 |
+
}
|
390 |
+
}
|
391 |
+
|
392 |
+
this.mentionContainer.style.top = `${topPos}px`;
|
393 |
+
this.mentionContainer.style.left = `${leftPos}px`;
|
394 |
+
this.mentionContainer.style.visibility = 'visible';
|
395 |
+
}
|
396 |
+
|
397 |
+
|
398 |
+
/**
|
399 |
+
* HF Helpers for manual trigger
|
400 |
+
*/
|
401 |
+
setCursorPos() {
|
402 |
+
const range = this.quill.getSelection();
|
403 |
+
if (range) {
|
404 |
+
this.cursorPos = range.index;
|
405 |
+
} else {
|
406 |
+
this.quill.setSelection(this.quill.getLength(), 0);
|
407 |
+
/// ^^ place cursor at the end of input by default.
|
408 |
+
this.cursorPos = this.quill.getLength();
|
409 |
+
}
|
410 |
+
}
|
411 |
+
getCursorPos(): number {
|
412 |
+
return this.cursorPos!;
|
413 |
+
}
|
414 |
+
trigger(values: string[]) {
|
415 |
+
this.renderList("", values.map(x => {
|
416 |
+
return { id: x, value: x };
|
417 |
+
}), "");
|
418 |
+
}
|
419 |
+
|
420 |
+
onSomethingChange() {
|
421 |
+
/// We trigger manually so here we can _probably_ just always close.
|
422 |
+
this.hideMentionList();
|
423 |
+
}
|
424 |
+
|
425 |
+
onTextChange(delta: Delta, oldDelta: Delta, source: Sources) {
|
426 |
+
if (source === 'user') {
|
427 |
+
this.onSomethingChange();
|
428 |
+
}
|
429 |
+
}
|
430 |
+
|
431 |
+
onSelectionChange(range: RangeStatic) {
|
432 |
+
if (range && range.length === 0) {
|
433 |
+
this.onSomethingChange();
|
434 |
+
} else {
|
435 |
+
this.hideMentionList();
|
436 |
+
}
|
437 |
+
}
|
438 |
+
}
|
439 |
+
|
440 |
+
|
441 |
+
Quill.register('modules/mention', Mention);
|
front/js-src/controller.ts
ADDED
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { Api } from './Api';
|
2 |
+
import { Mention } from './Mention';
|
3 |
+
import { c } from './lib/Log';
|
4 |
+
import { Utils } from './lib/Utils';
|
5 |
+
import { VanillaTilt } from './vanilla-tilt';
|
6 |
+
import { ShareScreenshotModal, SavePublishModal } from './modals';
|
7 |
+
|
8 |
+
/// We experimented with a couple of different build systems
|
9 |
+
/// to integrate Quill (for instance module-then-postprocessing
|
10 |
+
/// like in `web3d`) but none worked really well so we just
|
11 |
+
/// hotlink the js and basically copy/paste the @types/quill
|
12 |
+
/// declaration here.
|
13 |
+
/// Update: we now use rollup (for html2canvas), but quill is
|
14 |
+
/// still a pain so it's still not in the same bundle.
|
15 |
+
|
16 |
+
const DEBUG = false;
|
17 |
+
/// ^^ when debugging the quill integration, add the quill.snow.css to layout.hbs
|
18 |
+
/// <link href="/front/node_modules/quill/dist/quill.snow.css" rel="stylesheet">
|
19 |
+
/// <link href="/front/node_modules/quill/dist/quill.core.css" rel="stylesheet">
|
20 |
+
/// We tried doing it programmatically here but it's a bit slow.
|
21 |
+
if (DEBUG) {
|
22 |
+
document.head.insertAdjacentHTML(
|
23 |
+
'beforeend',
|
24 |
+
`<link href="/front/node_modules/quill/dist/quill.snow.css" rel="stylesheet">`
|
25 |
+
);
|
26 |
+
/// ^^ add css to debug. Do it as early as possible.
|
27 |
+
}
|
28 |
+
|
29 |
+
enum Page {
|
30 |
+
app, landing, model
|
31 |
+
}
|
32 |
+
const App = {
|
33 |
+
page:
|
34 |
+
(document.body.classList.contains('app')) ? Page.app
|
35 |
+
: (document.body.classList.contains('landing')) ? Page.landing
|
36 |
+
: Page.model
|
37 |
+
,
|
38 |
+
editable: document.body.dataset.editable === 'true',
|
39 |
+
header: {
|
40 |
+
shuffleBtn: document.querySelector('header .js-shuffle') as HTMLAnchorElement,
|
41 |
+
triggerBtn: document.querySelector('header .js-trigger') as HTMLAnchorElement,
|
42 |
+
mainInfoBtn: document.querySelector('header .title .info') as HTMLImageElement,
|
43 |
+
shareBtn: document.querySelector<HTMLAnchorElement>('header .js-share'),
|
44 |
+
saveBtn: document.querySelector<HTMLAnchorElement>('header .js-save'),
|
45 |
+
duplicateBtn: document.querySelector<HTMLAnchorElement>('header .js-duplicate'),
|
46 |
+
},
|
47 |
+
shareScreenBtn: document.querySelector('.page-container .js-share') as HTMLAnchorElement,
|
48 |
+
loaderEditor: document.querySelector('.page-container .js-loader') as HTMLImageElement,
|
49 |
+
sliders: Array.from(
|
50 |
+
document.querySelectorAll('.decoder-settings input.slider')
|
51 |
+
) as HTMLInputElement[],
|
52 |
+
INITIAL_CONTENT: {} as Delta,
|
53 |
+
/**
|
54 |
+
* Helper function to more cleanly route different page types.
|
55 |
+
*/
|
56 |
+
onLoad: (p: Page, callback: () => void) => {
|
57 |
+
if (p === App.page) {
|
58 |
+
document.addEventListener('DOMContentLoaded', () => {
|
59 |
+
callback();
|
60 |
+
});
|
61 |
+
}
|
62 |
+
},
|
63 |
+
};
|
64 |
+
|
65 |
+
const PROMPTS = [
|
66 |
+
`Before boarding your rocket to Mars, remember to pack these items`,
|
67 |
+
`In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English.`,
|
68 |
+
`Legolas and Gimli advanced on the orcs, raising their weapons with a harrowing war cry.`,
|
69 |
+
`Today, scientists confirmed the worst possible outcome: the massive asteroid will collide with Earth`,
|
70 |
+
`
|
71 |
+
Thor: The Tesseract belongs on Asgard, no human is a match for it.
|
72 |
+
Tony turns to leave, but Steve stops him.
|
73 |
+
Steve: You're not going alone!
|
74 |
+
Tony: You gonna stop me?
|
75 |
+
`.replace(/\t/g, "").trim().concat("\n"),
|
76 |
+
];
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
App.onLoad(Page.app, () => {
|
82 |
+
const modalScreenshot = new ShareScreenshotModal;
|
83 |
+
|
84 |
+
const opts: QuillOptionsStatic = DEBUG
|
85 |
+
? {
|
86 |
+
theme: 'snow',
|
87 |
+
modules: {
|
88 |
+
mention: {},
|
89 |
+
},
|
90 |
+
}
|
91 |
+
: {
|
92 |
+
theme: undefined,
|
93 |
+
// formats: [],
|
94 |
+
modules: {
|
95 |
+
toolbar: [],
|
96 |
+
mention: {},
|
97 |
+
},
|
98 |
+
}
|
99 |
+
;
|
100 |
+
if (! App.editable) {
|
101 |
+
opts.readOnly = true;
|
102 |
+
}
|
103 |
+
const quill = new Quill('div.editor', opts);
|
104 |
+
const mention = quill.getModule('mention') as Mention;
|
105 |
+
(<any>window).quill = quill;
|
106 |
+
const QUILL_C = (<any>window).QUILL_C;
|
107 |
+
if (QUILL_C) {
|
108 |
+
quill.setContents(QUILL_C);
|
109 |
+
}
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
+
quill.container.appendChild(App.loaderEditor);
|
114 |
+
quill.container.appendChild(App.shareScreenBtn);
|
115 |
+
|
116 |
+
//
|
117 |
+
// div.editor .ql-container <-- quill.container
|
118 |
+
// +--------------------------------+
|
119 |
+
// | div.ql-editor contenteditable | <-- quill.root
|
120 |
+
// | +----------------------------+ |
|
121 |
+
// | | | |
|
122 |
+
// | | | |
|
123 |
+
// | +----------------------------+ |
|
124 |
+
// +--------------------------------+
|
125 |
+
//
|
126 |
+
|
127 |
+
quill.keyboard.addBinding({ key: Mention.Keys.TAB }, () => {
|
128 |
+
triggerAutocomplete();
|
129 |
+
});
|
130 |
+
quill.keyboard.bindings[Mention.Keys.TAB].unshift(
|
131 |
+
quill.keyboard.bindings[Mention.Keys.TAB].pop()
|
132 |
+
);
|
133 |
+
/// ^^ important.
|
134 |
+
/// ^^ place it at beginning of bindings.
|
135 |
+
|
136 |
+
|
137 |
+
const triggerAutocomplete = async () => {
|
138 |
+
/// vv position loader
|
139 |
+
mention.setCursorPos();
|
140 |
+
const cursorBbox = quill.getBounds(mention.getCursorPos());
|
141 |
+
App.loaderEditor.style.top = `${cursorBbox.top - 4}px`;
|
142 |
+
App.loaderEditor.style.left = `${cursorBbox.left + 4}px`;
|
143 |
+
App.loaderEditor.classList.remove('hide');
|
144 |
+
|
145 |
+
/// vv Launch api request.
|
146 |
+
const text = quill.getText(0, mention.getCursorPos());
|
147 |
+
// ^^ That is so much simpler that what we used to do
|
148 |
+
// when we were embbedding objects like in `quill-mention`.
|
149 |
+
c.debug(
|
150 |
+
`%c[About to launch autocomplete for]`,
|
151 |
+
`color: green;`,
|
152 |
+
text,
|
153 |
+
);
|
154 |
+
const o = await Api.shared.postWithSettings({ context: text });
|
155 |
+
App.loaderEditor.classList.add('hide');
|
156 |
+
|
157 |
+
/// vv Trigger mention module.
|
158 |
+
for (const x of o.sentences) {
|
159 |
+
c.log(x.value);
|
160 |
+
}
|
161 |
+
mention.trigger(
|
162 |
+
o.sentences.map(x => x.value)
|
163 |
+
);
|
164 |
+
};
|
165 |
+
|
166 |
+
|
167 |
+
App.header.duplicateBtn?.addEventListener('click', async (e) => {
|
168 |
+
e.preventDefault();
|
169 |
+
const url = await Api.shared.postDuplicate();
|
170 |
+
window.location.href = url;
|
171 |
+
});
|
172 |
+
|
173 |
+
|
174 |
+
if (! App.editable) {
|
175 |
+
return ;
|
176 |
+
}
|
177 |
+
/**
|
178 |
+
* vv Below is only in editable mode.
|
179 |
+
*/
|
180 |
+
|
181 |
+
const modalSave = new SavePublishModal(quill);
|
182 |
+
|
183 |
+
App.header.shuffleBtn.addEventListener('click', (e) => {
|
184 |
+
e.preventDefault();
|
185 |
+
quill.setText(
|
186 |
+
Utils.randomItem(PROMPTS)
|
187 |
+
);
|
188 |
+
quill.setSelection(quill.getLength(), 0);
|
189 |
+
/// ^^ github.com/quilljs/quill/issues/2635
|
190 |
+
triggerAutocomplete();
|
191 |
+
});
|
192 |
+
App.header.triggerBtn.addEventListener('click', (e) => {
|
193 |
+
e.preventDefault();
|
194 |
+
triggerAutocomplete();
|
195 |
+
});
|
196 |
+
App.header.shareBtn?.addEventListener('click', async (e) => {
|
197 |
+
e.preventDefault();
|
198 |
+
const text = `Write With Transformer via @huggingface`;
|
199 |
+
window.open(`https://twitter.com/share?url=${ encodeURIComponent(window.location.href) }&text=${ encodeURIComponent(text) }`);
|
200 |
+
});
|
201 |
+
App.header.saveBtn?.addEventListener('click', (e) => {
|
202 |
+
e.preventDefault();
|
203 |
+
mention.hideMentionList();
|
204 |
+
modalSave.show();
|
205 |
+
});
|
206 |
+
|
207 |
+
App.shareScreenBtn.addEventListener('click', async (e) => {
|
208 |
+
e.preventDefault();
|
209 |
+
mention.hideMentionList();
|
210 |
+
modalScreenshot.show();
|
211 |
+
});
|
212 |
+
quill.on('text-change', () => {
|
213 |
+
App.shareScreenBtn.classList.remove('hide'); /// <- we use a fadeout effect.
|
214 |
+
const hasTextFromAI = quill.getContents()
|
215 |
+
.ops
|
216 |
+
.some(op => op.attributes && op.attributes.bold === true)
|
217 |
+
;
|
218 |
+
App.shareScreenBtn.classList.toggle('fadeout', ! hasTextFromAI);
|
219 |
+
});
|
220 |
+
document.addEventListener('click', (e) => {
|
221 |
+
/// Handle clicks on links inside the editor.
|
222 |
+
if (! (
|
223 |
+
e.target instanceof HTMLAnchorElement
|
224 |
+
&& e.target.closest('div.ql-editor') !== null
|
225 |
+
)) {
|
226 |
+
return ;
|
227 |
+
}
|
228 |
+
/// Ok, let's do this.
|
229 |
+
e.preventDefault();
|
230 |
+
e.stopPropagation();
|
231 |
+
const href = e.target.getAttribute('href'); /// <- caution, get the original string.
|
232 |
+
c.debug(`[click]`, href);
|
233 |
+
if (href === '#js-shuffle') {
|
234 |
+
App.header.shuffleBtn.click();
|
235 |
+
} else {
|
236 |
+
window.open(e.target.href);
|
237 |
+
}
|
238 |
+
});
|
239 |
+
document.addEventListener("scroll", e => {
|
240 |
+
const trigger = document.getElementsByClassName("js-trigger")[0] as HTMLAnchorElement;
|
241 |
+
if (scrollY > 100) {
|
242 |
+
trigger.style.position = "fixed";
|
243 |
+
trigger.style.top = "10px";
|
244 |
+
trigger.style.border = "1px solid blue";
|
245 |
+
trigger.style.backgroundColor = "white";
|
246 |
+
trigger.style.borderRadius = "100px";
|
247 |
+
trigger.style.padding = "5px";
|
248 |
+
trigger.style.zIndex = "1";
|
249 |
+
trigger.style.left = "50%";
|
250 |
+
trigger.style.transform = "translateX(-50%)";
|
251 |
+
} else {
|
252 |
+
trigger.style.position = "relative";
|
253 |
+
trigger.style.top = "auto";
|
254 |
+
trigger.style.border = "none";
|
255 |
+
trigger.style.backgroundColor = "white";
|
256 |
+
trigger.style.borderRadius = "0";
|
257 |
+
trigger.style.padding = "0";
|
258 |
+
trigger.style.zIndex = "1";
|
259 |
+
trigger.style.left = "auto"
|
260 |
+
}
|
261 |
+
});
|
262 |
+
|
263 |
+
/**
|
264 |
+
* Settings
|
265 |
+
*/
|
266 |
+
const handleSliderChange = (slider: HTMLInputElement) => {
|
267 |
+
const div = slider.parentNode as HTMLDivElement;
|
268 |
+
const spanVal = div.querySelector('.js-val') as HTMLSpanElement;
|
269 |
+
const value = Number.isInteger(slider.valueAsNumber)
|
270 |
+
? slider.valueAsNumber
|
271 |
+
: Number(slider.valueAsNumber.toFixed(2))
|
272 |
+
;
|
273 |
+
const valueKey = `value-${value}`;
|
274 |
+
if (slider.dataset[valueKey]) {
|
275 |
+
spanVal.innerText = slider.dataset[valueKey]!;
|
276 |
+
} else {
|
277 |
+
spanVal.innerText = value.toString();
|
278 |
+
}
|
279 |
+
const min = Number(slider.getAttribute('min'));
|
280 |
+
const max = Number(slider.getAttribute('max'));
|
281 |
+
if (value < min + (max - min) / 3) {
|
282 |
+
spanVal.className = "js-val green";
|
283 |
+
} else if (value < min + 2 * (max - min) / 3) {
|
284 |
+
spanVal.className = "js-val orange";
|
285 |
+
} else {
|
286 |
+
spanVal.className = "js-val red";
|
287 |
+
}
|
288 |
+
const isInverted = slider.classList.contains('js-inverted');
|
289 |
+
if (isInverted) {
|
290 |
+
if (spanVal.classList.contains('green')) {
|
291 |
+
spanVal.classList.remove('green');
|
292 |
+
spanVal.classList.add('red');
|
293 |
+
} else if (spanVal.classList.contains('red')) {
|
294 |
+
spanVal.classList.remove('red');
|
295 |
+
spanVal.classList.add('green');
|
296 |
+
}
|
297 |
+
}
|
298 |
+
};
|
299 |
+
for (const slider of App.sliders) {
|
300 |
+
handleSliderChange(slider);
|
301 |
+
slider.addEventListener('input', () => {
|
302 |
+
handleSliderChange(slider);
|
303 |
+
});
|
304 |
+
}
|
305 |
+
});
|
306 |
+
|
307 |
+
|
308 |
+
|
309 |
+
App.onLoad(Page.landing, () => {
|
310 |
+
/**
|
311 |
+
* VanillaTilt
|
312 |
+
*/
|
313 |
+
VanillaTilt.init(document.querySelectorAll("[data-tilt]"), {
|
314 |
+
glare: true,
|
315 |
+
scale: 1.06,
|
316 |
+
'max-glare': 0.3,
|
317 |
+
speed: 400,
|
318 |
+
});
|
319 |
+
});
|
front/js-src/lib/Log.ts
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
export const c = console;
|
front/js-src/lib/Utils.ts
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
export class Utils {
|
3 |
+
private static escapeMap = {
|
4 |
+
/// From underscore.js
|
5 |
+
'&': '&',
|
6 |
+
'<': '<',
|
7 |
+
'>': '>',
|
8 |
+
'"': '"',
|
9 |
+
"'": ''',
|
10 |
+
'`': '`'
|
11 |
+
};
|
12 |
+
|
13 |
+
/**
|
14 |
+
* Escape a message's content for insertion into html.
|
15 |
+
*/
|
16 |
+
static escape(s: string): string {
|
17 |
+
let x = s;
|
18 |
+
for (const [k, v] of Object.entries(this.escapeMap)) {
|
19 |
+
x = x.replace(new RegExp(k, 'g'), v);
|
20 |
+
}
|
21 |
+
return x.replace(/\n/g, '<br>');
|
22 |
+
}
|
23 |
+
|
24 |
+
/**
|
25 |
+
* Opposite of escape.
|
26 |
+
*/
|
27 |
+
static unescape(s: string): string {
|
28 |
+
let x = s.replace(/<br>/g, '\n');
|
29 |
+
for (const [k, v] of Object.entries(this.escapeMap)) {
|
30 |
+
x = x.replace(new RegExp(v, 'g'), k);
|
31 |
+
}
|
32 |
+
return x;
|
33 |
+
}
|
34 |
+
|
35 |
+
/**
|
36 |
+
* "Real" modulo (always >= 0), not remainder.
|
37 |
+
*/
|
38 |
+
static mod(a: number, n: number): number {
|
39 |
+
return ((a % n) + n) % n;
|
40 |
+
}
|
41 |
+
|
42 |
+
/**
|
43 |
+
* Noop object with arbitrary number of nested attributes that are also noop.
|
44 |
+
*/
|
45 |
+
static deepNoop() {
|
46 |
+
const noop = new Proxy(() => {}, {
|
47 |
+
get: () => noop
|
48 |
+
});
|
49 |
+
return noop;
|
50 |
+
}
|
51 |
+
|
52 |
+
/**
|
53 |
+
* Capitalize
|
54 |
+
*/
|
55 |
+
static capitalize(s: string): string {
|
56 |
+
return s.charAt(0).toUpperCase() + s.slice(1);
|
57 |
+
}
|
58 |
+
|
59 |
+
/**
|
60 |
+
* Returns a promise that will resolve after the specified time
|
61 |
+
* @param ms Number of ms to wait
|
62 |
+
*/
|
63 |
+
static delay(ms: number): Promise<void> {
|
64 |
+
return new Promise((resolve, reject) => {
|
65 |
+
setTimeout(() => resolve(), ms);
|
66 |
+
});
|
67 |
+
}
|
68 |
+
|
69 |
+
/**
|
70 |
+
* Random element from array
|
71 |
+
*/
|
72 |
+
static randomItem<T>(arr: T[]): T {
|
73 |
+
return arr[Math.floor(Math.random()*arr.length)];
|
74 |
+
}
|
75 |
+
}
|
76 |
+
|
front/js-src/modals.ts
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { Utils } from './lib/Utils';
|
2 |
+
import html2canvas from 'html2canvas';
|
3 |
+
import { c } from './lib/Log';
|
4 |
+
import { Api } from './Api';
|
5 |
+
|
6 |
+
abstract class Modal {
|
7 |
+
protected div: HTMLDivElement;
|
8 |
+
protected doneBtn: HTMLAnchorElement | null;
|
9 |
+
protected loader: HTMLImageElement;
|
10 |
+
constructor(className: string) {
|
11 |
+
this.div = document.querySelector(`div.modal.${className}`) as HTMLDivElement;
|
12 |
+
this.doneBtn = this.div.querySelector<HTMLAnchorElement>('.js-close');
|
13 |
+
this.loader = this.div.querySelector('.js-loader') as HTMLImageElement;
|
14 |
+
|
15 |
+
this.doneBtn?.addEventListener('click', (e) => {
|
16 |
+
e.preventDefault();
|
17 |
+
this.hide();
|
18 |
+
});
|
19 |
+
this.div.addEventListener('click', (e) => {
|
20 |
+
if (e.target === this.div) {
|
21 |
+
c.debug(`modal:background.click`);
|
22 |
+
this.hide();
|
23 |
+
}
|
24 |
+
});
|
25 |
+
}
|
26 |
+
/**
|
27 |
+
* Hooks: Implement those to perform the actual work done on show and hide.
|
28 |
+
*/
|
29 |
+
abstract performBeforeShow(): Promise<void>;
|
30 |
+
abstract performShow(): Promise<void>;
|
31 |
+
abstract performHide(): Promise<void>;
|
32 |
+
async show() {
|
33 |
+
await this.performBeforeShow();
|
34 |
+
this.div.classList.add('fadeout');
|
35 |
+
this.div.classList.remove('hide');
|
36 |
+
await Utils.delay(100);
|
37 |
+
this.div.classList.remove('fadeout');
|
38 |
+
await this.performShow();
|
39 |
+
this.loader.classList.add('hide');
|
40 |
+
}
|
41 |
+
async hide() {
|
42 |
+
this.div.classList.add('fadeout');
|
43 |
+
await Utils.delay(200);
|
44 |
+
this.div.classList.add('hide');
|
45 |
+
this.div.classList.remove('fadeout');
|
46 |
+
await this.performHide();
|
47 |
+
}
|
48 |
+
}
|
49 |
+
|
50 |
+
export class ShareScreenshotModal extends Modal {
|
51 |
+
private imResult = this.div.querySelector('.js-result') as HTMLImageElement;
|
52 |
+
|
53 |
+
constructor() {
|
54 |
+
super(`share-screenshot`);
|
55 |
+
}
|
56 |
+
async performBeforeShow() {
|
57 |
+
this.loader.classList.remove('hide');
|
58 |
+
}
|
59 |
+
async performShow() {
|
60 |
+
await Utils.delay(800); /// <- for good ux
|
61 |
+
const el = document.querySelector('div.page-inner') as HTMLDivElement;
|
62 |
+
const canvas = await html2canvas(el, {
|
63 |
+
logging: false, /// <- inoperant in our version of html2canvas.
|
64 |
+
onclone: (doc) => {
|
65 |
+
const clonedEl = doc.querySelector('div.page-inner') as HTMLDivElement;
|
66 |
+
clonedEl.classList.add('html2canvas');
|
67 |
+
const watermark = doc.querySelector('div.watermark') as HTMLDivElement;
|
68 |
+
watermark.style.visibility = `visible`;
|
69 |
+
}
|
70 |
+
});
|
71 |
+
this.imResult.src = canvas.toDataURL();
|
72 |
+
}
|
73 |
+
async performHide() {
|
74 |
+
this.imResult.src = "";
|
75 |
+
}
|
76 |
+
}
|
77 |
+
|
78 |
+
export class SavePublishModal extends Modal {
|
79 |
+
private saveBtn = this.div.querySelector('.js-save') as HTMLAnchorElement;
|
80 |
+
private form = this.div.querySelector('form') as HTMLFormElement;
|
81 |
+
constructor(
|
82 |
+
private quill: Quill
|
83 |
+
) {
|
84 |
+
super(`save-publish`);
|
85 |
+
|
86 |
+
/// vv Url fields auto-select.
|
87 |
+
const urlInputs = Array.from(
|
88 |
+
this.div.querySelectorAll('.doc-url')
|
89 |
+
) as HTMLInputElement[];
|
90 |
+
for (const x of urlInputs) {
|
91 |
+
x.addEventListener('focus', () => {
|
92 |
+
x.select();
|
93 |
+
});
|
94 |
+
}
|
95 |
+
|
96 |
+
this.saveBtn.addEventListener('click', (e) => {
|
97 |
+
e.preventDefault();
|
98 |
+
if (! this.form.reportValidity()) {
|
99 |
+
/// Form is invalid.
|
100 |
+
return ;
|
101 |
+
}
|
102 |
+
this.save();
|
103 |
+
});
|
104 |
+
this.form.addEventListener('submit', (e) => {
|
105 |
+
e.preventDefault();
|
106 |
+
this.saveBtn.click();
|
107 |
+
});
|
108 |
+
}
|
109 |
+
async performBeforeShow() {}
|
110 |
+
async performShow() {}
|
111 |
+
async performHide() {}
|
112 |
+
async save() {
|
113 |
+
this.loader.classList.remove('hide');
|
114 |
+
|
115 |
+
const inputTitle = this.div.querySelector('.doc-title') as HTMLInputElement;
|
116 |
+
const title = inputTitle.value;
|
117 |
+
const contents = this.quill.getContents();
|
118 |
+
c.log(JSON.stringify({ title, contents }));
|
119 |
+
|
120 |
+
const success = await Api.shared.postEdit({ title, contents });
|
121 |
+
await Utils.delay(800); /// <- for good ux
|
122 |
+
|
123 |
+
if (success) {
|
124 |
+
this.loader.classList.add('hide');
|
125 |
+
this.hide();
|
126 |
+
/// For now we always redirect to the edit url here:
|
127 |
+
/// vv
|
128 |
+
const inputEditUrl = this.div.querySelector('.doc-edit-url') as HTMLInputElement;
|
129 |
+
window.location.href = inputEditUrl.value;
|
130 |
+
} else {
|
131 |
+
window.alert(`did not manage to save`);
|
132 |
+
}
|
133 |
+
}
|
134 |
+
}
|
front/js-src/quill.d.ts
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
// import { Blot } from "../node_modules/parchment/dist/src/blot/abstract/blot";
|
3 |
+
interface Blot {}
|
4 |
+
interface Delta {
|
5 |
+
ops: DeltaOperation[];
|
6 |
+
}
|
7 |
+
|
8 |
+
/**
|
9 |
+
* A stricter type definition would be:
|
10 |
+
*
|
11 |
+
* type DeltaOperation ({ insert: any } | { delete: number } | { retain: number }) & OptionalAttributes;
|
12 |
+
*
|
13 |
+
* But this would break a lot of existing code as it would require manual discrimination of the union types.
|
14 |
+
*/
|
15 |
+
type DeltaOperation = { insert?: any, delete?: number, retain?: number } & OptionalAttributes;
|
16 |
+
type Sources = "api" | "user" | "silent";
|
17 |
+
|
18 |
+
interface Key {
|
19 |
+
key: string | number;
|
20 |
+
shortKey?: boolean;
|
21 |
+
}
|
22 |
+
|
23 |
+
interface StringMap {
|
24 |
+
[key: string]: any;
|
25 |
+
}
|
26 |
+
|
27 |
+
interface OptionalAttributes {
|
28 |
+
attributes?: StringMap;
|
29 |
+
}
|
30 |
+
|
31 |
+
type TextChangeHandler = (delta: Delta, oldContents: Delta, source: Sources) => any;
|
32 |
+
type SelectionChangeHandler = (range: RangeStatic, oldRange: RangeStatic, source: Sources) => any;
|
33 |
+
type EditorChangeHandler = ((name: "text-change", delta: Delta, oldContents: Delta, source: Sources) => any)
|
34 |
+
| ((name: "selection-change", range: RangeStatic, oldRange: RangeStatic, source: Sources) => any);
|
35 |
+
|
36 |
+
interface KeyboardStatic {
|
37 |
+
addBinding(key: Key, callback: (range: RangeStatic, context: any) => void): void;
|
38 |
+
addBinding(key: Key, context: any, callback: (range: RangeStatic, context: any) => void): void;
|
39 |
+
bindings: { [index: number]: any[] };
|
40 |
+
}
|
41 |
+
|
42 |
+
interface ClipboardStatic {
|
43 |
+
convert(html?: string): Delta;
|
44 |
+
addMatcher(selectorOrNodeType: string|number, callback: (node: any, delta: Delta) => Delta): void;
|
45 |
+
dangerouslyPasteHTML(html: string, source?: Sources): void;
|
46 |
+
dangerouslyPasteHTML(index: number, html: string, source?: Sources): void;
|
47 |
+
}
|
48 |
+
|
49 |
+
interface QuillOptionsStatic {
|
50 |
+
debug?: string | boolean;
|
51 |
+
modules?: StringMap;
|
52 |
+
placeholder?: string;
|
53 |
+
readOnly?: boolean;
|
54 |
+
theme?: string;
|
55 |
+
formats?: string[];
|
56 |
+
bounds?: HTMLElement | string;
|
57 |
+
scrollingContainer?: HTMLElement | string;
|
58 |
+
strict?: boolean;
|
59 |
+
}
|
60 |
+
|
61 |
+
interface BoundsStatic {
|
62 |
+
bottom: number;
|
63 |
+
left: number;
|
64 |
+
right: number;
|
65 |
+
top: number;
|
66 |
+
height: number;
|
67 |
+
width: number;
|
68 |
+
}
|
69 |
+
|
70 |
+
declare interface RangeStatic {
|
71 |
+
index: number;
|
72 |
+
length: number;
|
73 |
+
}
|
74 |
+
|
75 |
+
declare class RangeStatic implements RangeStatic {
|
76 |
+
constructor();
|
77 |
+
index: number;
|
78 |
+
length: number;
|
79 |
+
}
|
80 |
+
|
81 |
+
interface EventEmitter {
|
82 |
+
on(eventName: "text-change", handler: TextChangeHandler): EventEmitter;
|
83 |
+
on(eventName: "selection-change", handler: SelectionChangeHandler): EventEmitter;
|
84 |
+
on(eventName: "editor-change", handler: EditorChangeHandler): EventEmitter;
|
85 |
+
once(eventName: "text-change", handler: TextChangeHandler): EventEmitter;
|
86 |
+
once(eventName: "selection-change", handler: SelectionChangeHandler): EventEmitter;
|
87 |
+
once(eventName: "editor-change", handler: EditorChangeHandler): EventEmitter;
|
88 |
+
off(eventName: "text-change", handler: TextChangeHandler): EventEmitter;
|
89 |
+
off(eventName: "selection-change", handler: SelectionChangeHandler): EventEmitter;
|
90 |
+
off(eventName: "editor-change", handler: EditorChangeHandler): EventEmitter;
|
91 |
+
}
|
92 |
+
|
93 |
+
declare class Quill {
|
94 |
+
/**
|
95 |
+
* @private Internal API
|
96 |
+
*/
|
97 |
+
root: HTMLDivElement;
|
98 |
+
container: HTMLElement; /// <- used by quill-mention
|
99 |
+
clipboard: ClipboardStatic;
|
100 |
+
scroll: Blot;
|
101 |
+
keyboard: KeyboardStatic;
|
102 |
+
constructor(container: string | Element, options?: QuillOptionsStatic);
|
103 |
+
deleteText(index: number, length: number, source?: Sources): Delta;
|
104 |
+
disable(): void;
|
105 |
+
enable(enabled?: boolean): void;
|
106 |
+
getContents(index?: number, length?: number): Delta;
|
107 |
+
getLength(): number;
|
108 |
+
getText(index?: number, length?: number): string;
|
109 |
+
insertEmbed(index: number, type: string, value: any, source?: Sources): Delta;
|
110 |
+
insertText(index: number, text: string, source?: Sources): Delta;
|
111 |
+
insertText(index: number, text: string, format: string, value: any, source?: Sources): Delta;
|
112 |
+
insertText(index: number, text: string, formats: StringMap, source?: Sources): Delta;
|
113 |
+
/**
|
114 |
+
* @deprecated Remove in 2.0. Use clipboard.dangerouslyPasteHTML(index: number, html: string, source: Sources)
|
115 |
+
*/
|
116 |
+
pasteHTML(index: number, html: string, source?: Sources): string;
|
117 |
+
/**
|
118 |
+
* @deprecated Remove in 2.0. Use clipboard.dangerouslyPasteHTML(html: string, source: Sources): void;
|
119 |
+
*/
|
120 |
+
pasteHTML(html: string, source?: Sources): string;
|
121 |
+
setContents(delta: Delta, source?: Sources): Delta;
|
122 |
+
setText(text: string, source?: Sources): Delta;
|
123 |
+
update(source?: Sources): void;
|
124 |
+
updateContents(delta: Delta, source?: Sources): Delta;
|
125 |
+
|
126 |
+
format(name: string, value: any, source?: Sources): Delta;
|
127 |
+
formatLine(index: number, length: number, source?: Sources): Delta;
|
128 |
+
formatLine(index: number, length: number, format: string, value: any, source?: Sources): Delta;
|
129 |
+
formatLine(index: number, length: number, formats: StringMap, source?: Sources): Delta;
|
130 |
+
formatText(index: number, length: number, source?: Sources): Delta;
|
131 |
+
formatText(index: number, length: number, format: string, value: any, source?: Sources): Delta;
|
132 |
+
formatText(index: number, length: number, formats: StringMap, source?: Sources): Delta;
|
133 |
+
formatText(range: RangeStatic, format: string, value: any, source?: Sources): Delta;
|
134 |
+
formatText(range: RangeStatic, formats: StringMap, source?: Sources): Delta;
|
135 |
+
getFormat(range?: RangeStatic): StringMap;
|
136 |
+
getFormat(index: number, length?: number): StringMap;
|
137 |
+
removeFormat(index: number, length: number, source?: Sources): Delta;
|
138 |
+
|
139 |
+
blur(): void;
|
140 |
+
focus(): void;
|
141 |
+
getBounds(index: number, length?: number): BoundsStatic;
|
142 |
+
getSelection(focus: true): RangeStatic;
|
143 |
+
getSelection(focus?: false): RangeStatic | null;
|
144 |
+
hasFocus(): boolean;
|
145 |
+
setSelection(index: number, length: number, source?: Sources): void;
|
146 |
+
setSelection(range: RangeStatic, source?: Sources): void;
|
147 |
+
|
148 |
+
// static methods: debug, import, register, find
|
149 |
+
static debug(level: string|boolean): void;
|
150 |
+
static import(path: string): any;
|
151 |
+
static register(path: string, def: any, suppressWarning?: boolean): void;
|
152 |
+
static register(defs: StringMap, suppressWarning?: boolean): void;
|
153 |
+
static find(domNode: Node, bubble?: boolean): Quill | any;
|
154 |
+
|
155 |
+
addContainer(classNameOrDomNode: string|Node, refNode?: Node): any;
|
156 |
+
getModule(name: string): any;
|
157 |
+
|
158 |
+
// Blot interface is not exported on Parchment
|
159 |
+
getIndex(blot: any): number;
|
160 |
+
getLeaf(index: number): any;
|
161 |
+
getLine(index: number): [any, number];
|
162 |
+
getLines(index?: number, length?: number): any[];
|
163 |
+
getLines(range: RangeStatic): any[];
|
164 |
+
|
165 |
+
// EventEmitter methods
|
166 |
+
on(eventName: "text-change", handler: TextChangeHandler): EventEmitter;
|
167 |
+
on(eventName: "selection-change", handler: SelectionChangeHandler): EventEmitter;
|
168 |
+
on(eventName: "editor-change", handler: EditorChangeHandler): EventEmitter;
|
169 |
+
once(eventName: "text-change", handler: TextChangeHandler): EventEmitter;
|
170 |
+
once(eventName: "selection-change", handler: SelectionChangeHandler): EventEmitter;
|
171 |
+
once(eventName: "editor-change", handler: EditorChangeHandler): EventEmitter;
|
172 |
+
off(eventName: "text-change", handler: TextChangeHandler): EventEmitter;
|
173 |
+
off(eventName: "selection-change", handler: SelectionChangeHandler): EventEmitter;
|
174 |
+
off(eventName: "editor-change", handler: EditorChangeHandler): EventEmitter;
|
175 |
+
|
176 |
+
static sources: {
|
177 |
+
API: 'api',
|
178 |
+
SILENT: 'silent',
|
179 |
+
USER: 'user',
|
180 |
+
};
|
181 |
+
}
|
front/js-src/vanilla-tilt.ts
ADDED
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export namespace VanillaTilt {
|
2 |
+
/**
|
3 |
+
* Options which configures the tilting
|
4 |
+
*/
|
5 |
+
export interface TiltOptions {
|
6 |
+
/**
|
7 |
+
* Reverse the tilt direction
|
8 |
+
*/
|
9 |
+
reverse?: boolean;
|
10 |
+
/**
|
11 |
+
* Max tilt rotation (degrees)
|
12 |
+
*/
|
13 |
+
max?: number;
|
14 |
+
/**
|
15 |
+
* Transform perspective, the lower the more extreme the tilt gets.
|
16 |
+
*/
|
17 |
+
perspective?: number;
|
18 |
+
/**
|
19 |
+
* 2 = 200%, 1.5 = 150%, etc..
|
20 |
+
*/
|
21 |
+
scale?: number;
|
22 |
+
/**
|
23 |
+
* Speed of the enter/exit transition
|
24 |
+
*/
|
25 |
+
speed?: number;
|
26 |
+
/**
|
27 |
+
* Set a transition on enter/exit.
|
28 |
+
*/
|
29 |
+
transition?: boolean;
|
30 |
+
/**
|
31 |
+
* What axis should be disabled. Can be X or Y.
|
32 |
+
*/
|
33 |
+
axis?: null | "x" | "y";
|
34 |
+
/**
|
35 |
+
* If the tilt effect has to be reset on exit.
|
36 |
+
*/
|
37 |
+
reset?: boolean;
|
38 |
+
/**
|
39 |
+
* Easing on enter/exit.
|
40 |
+
*/
|
41 |
+
easing?: string;
|
42 |
+
/**
|
43 |
+
* Added (@julien-c)
|
44 |
+
*/
|
45 |
+
glare?: boolean;
|
46 |
+
'max-glare'?: number;
|
47 |
+
}
|
48 |
+
|
49 |
+
export interface TiltValues {
|
50 |
+
/**
|
51 |
+
* The current tilt on the X axis
|
52 |
+
*/
|
53 |
+
tiltX: number;
|
54 |
+
/**
|
55 |
+
* The current tilt on the Y axis
|
56 |
+
*/
|
57 |
+
tiltY: number;
|
58 |
+
/**
|
59 |
+
* The current percentage on the X axis
|
60 |
+
*/
|
61 |
+
percentageX: number;
|
62 |
+
/**
|
63 |
+
* The current percentage on the Y axis
|
64 |
+
*/
|
65 |
+
percentageY: number;
|
66 |
+
}
|
67 |
+
|
68 |
+
export interface HTMLVanillaTiltElement extends HTMLElement {
|
69 |
+
vanillaTilt: VanillaTilt
|
70 |
+
}
|
71 |
+
}
|
72 |
+
|
73 |
+
|
74 |
+
export class VanillaTilt {
|
75 |
+
width: number | null;
|
76 |
+
height: number | null;
|
77 |
+
left: number | null;
|
78 |
+
top: number | null;
|
79 |
+
element: VanillaTilt.HTMLVanillaTiltElement;
|
80 |
+
settings: VanillaTilt.TiltOptions;
|
81 |
+
reverse : -1 | 1;
|
82 |
+
glare: boolean;
|
83 |
+
glarePrerender: boolean;
|
84 |
+
transitionTimeout: number | null;
|
85 |
+
updateCall: number | null;
|
86 |
+
glareElementWrapper: HTMLElement;
|
87 |
+
glareElement: HTMLElement;
|
88 |
+
updateBind: () => void;
|
89 |
+
resetBind: () => void;
|
90 |
+
onMouseEnterBind: (e: Event) => void;
|
91 |
+
onMouseMoveBind: (e: Event) => void;
|
92 |
+
onMouseLeaveBind: (e: Event) => void;
|
93 |
+
event: MouseEvent;
|
94 |
+
|
95 |
+
constructor(element, settings: VanillaTilt.TiltOptions = {}) {
|
96 |
+
if (!(element instanceof Node)) {
|
97 |
+
throw ("Can't initialize VanillaTilt because " + element + " is not a Node.");
|
98 |
+
}
|
99 |
+
|
100 |
+
this.width = null;
|
101 |
+
this.height = null;
|
102 |
+
this.left = null;
|
103 |
+
this.top = null;
|
104 |
+
this.transitionTimeout = null;
|
105 |
+
this.updateCall = null;
|
106 |
+
|
107 |
+
this.updateBind = this.update.bind(this);
|
108 |
+
this.resetBind = this.reset.bind(this);
|
109 |
+
|
110 |
+
this.element = element as VanillaTilt.HTMLVanillaTiltElement;
|
111 |
+
this.settings = this.extendSettings(settings);
|
112 |
+
|
113 |
+
this.reverse = this.settings.reverse ? -1 : 1;
|
114 |
+
|
115 |
+
this.glare = this.isSettingTrue(this.settings.glare);
|
116 |
+
this.glarePrerender = this.isSettingTrue(this.settings["glare-prerender"]);
|
117 |
+
|
118 |
+
if (this.glare) {
|
119 |
+
this.prepareGlare();
|
120 |
+
}
|
121 |
+
|
122 |
+
this.addEventListeners();
|
123 |
+
}
|
124 |
+
|
125 |
+
isSettingTrue(setting) {
|
126 |
+
return setting === "" || setting === true || setting === 1;
|
127 |
+
}
|
128 |
+
|
129 |
+
addEventListeners() {
|
130 |
+
this.onMouseEnterBind = this.onMouseEnter.bind(this);
|
131 |
+
this.onMouseMoveBind = this.onMouseMove.bind(this);
|
132 |
+
this.onMouseLeaveBind = this.onMouseLeave.bind(this);
|
133 |
+
this.onWindowResizeBind = this.onWindowResizeBind.bind(this);
|
134 |
+
|
135 |
+
this.element.addEventListener("mouseenter", this.onMouseEnterBind);
|
136 |
+
this.element.addEventListener("mousemove", this.onMouseMoveBind);
|
137 |
+
this.element.addEventListener("mouseleave", this.onMouseLeaveBind);
|
138 |
+
if (this.glare) {
|
139 |
+
window.addEventListener("resize", this.onWindowResizeBind);
|
140 |
+
}
|
141 |
+
}
|
142 |
+
|
143 |
+
|
144 |
+
onMouseEnter(event) {
|
145 |
+
this.updateElementPosition();
|
146 |
+
(<any>this.element.style).willChange = "transform";
|
147 |
+
this.setTransition();
|
148 |
+
}
|
149 |
+
|
150 |
+
onMouseMove(event) {
|
151 |
+
if (this.updateCall !== null) {
|
152 |
+
cancelAnimationFrame(this.updateCall);
|
153 |
+
}
|
154 |
+
|
155 |
+
this.event = event;
|
156 |
+
this.updateCall = requestAnimationFrame(this.updateBind);
|
157 |
+
}
|
158 |
+
|
159 |
+
onMouseLeave(event) {
|
160 |
+
this.setTransition();
|
161 |
+
|
162 |
+
if (this.settings.reset) {
|
163 |
+
requestAnimationFrame(this.resetBind);
|
164 |
+
}
|
165 |
+
}
|
166 |
+
|
167 |
+
reset() {
|
168 |
+
this.event = {
|
169 |
+
pageX: this.left! + this.width! / 2,
|
170 |
+
pageY: this.top! + this.height! / 2
|
171 |
+
} as MouseEvent;
|
172 |
+
|
173 |
+
this.element.style.transform = "perspective(" + this.settings.perspective + "px) " +
|
174 |
+
"rotateX(0deg) " +
|
175 |
+
"rotateY(0deg) " +
|
176 |
+
"scale3d(1, 1, 1)"
|
177 |
+
;
|
178 |
+
|
179 |
+
if (this.glare) {
|
180 |
+
this.glareElement.style.transform = 'rotate(180deg) translate(-50%, -50%)';
|
181 |
+
this.glareElement.style.opacity = '0';
|
182 |
+
}
|
183 |
+
}
|
184 |
+
|
185 |
+
getValues() {
|
186 |
+
let x = (this.event.clientX - this.left!) / this.width!;
|
187 |
+
let y = (this.event.clientY - this.top!) / this.height!;
|
188 |
+
|
189 |
+
x = Math.min(Math.max(x, 0), 1);
|
190 |
+
y = Math.min(Math.max(y, 0), 1);
|
191 |
+
|
192 |
+
let tiltX = (this.reverse * (this.settings.max! / 2 - x * this.settings.max!)).toFixed(2);
|
193 |
+
let tiltY = (this.reverse * (y * this.settings.max! - this.settings.max! / 2)).toFixed(2);
|
194 |
+
let angle = Math.atan2(this.event.clientX - (this.left! + this.width! / 2), -(this.event.clientY - (this.top! + this.height! / 2))) * (180 / Math.PI);
|
195 |
+
|
196 |
+
return {
|
197 |
+
tiltX: tiltX,
|
198 |
+
tiltY: tiltY,
|
199 |
+
percentageX: x * 100,
|
200 |
+
percentageY: y * 100,
|
201 |
+
angle: angle
|
202 |
+
};
|
203 |
+
}
|
204 |
+
|
205 |
+
updateElementPosition() {
|
206 |
+
let rect = this.element.getBoundingClientRect();
|
207 |
+
|
208 |
+
this.width = this.element.offsetWidth;
|
209 |
+
this.height = this.element.offsetHeight;
|
210 |
+
this.left = rect.left;
|
211 |
+
this.top = rect.top;
|
212 |
+
}
|
213 |
+
|
214 |
+
update() {
|
215 |
+
const values = this.getValues();
|
216 |
+
|
217 |
+
this.element.style.transform = [
|
218 |
+
"perspective(" + this.settings.perspective + "px) ",
|
219 |
+
"rotateX(" + (this.settings.axis === "x" ? 0 : values.tiltY) + "deg) ",
|
220 |
+
"rotateY(" + (this.settings.axis === "y" ? 0 : values.tiltX) + "deg) ",
|
221 |
+
"scale3d(" + this.settings.scale + ", " + this.settings.scale + ", " + this.settings.scale + ")",
|
222 |
+
].join(" ");
|
223 |
+
|
224 |
+
if (this.glare) {
|
225 |
+
this.glareElement.style.transform = `rotate(${values.angle}deg) translate(-50%, -50%)`;
|
226 |
+
this.glareElement.style.opacity = `${values.percentageY * this.settings["max-glare"]! / 100}`;
|
227 |
+
}
|
228 |
+
|
229 |
+
this.element.dispatchEvent(new CustomEvent("tiltChange", {
|
230 |
+
"detail": values
|
231 |
+
}));
|
232 |
+
|
233 |
+
this.updateCall = null;
|
234 |
+
}
|
235 |
+
|
236 |
+
/**
|
237 |
+
* Appends the glare element (if glarePrerender equals false)
|
238 |
+
* and sets the default style
|
239 |
+
*/
|
240 |
+
prepareGlare() {
|
241 |
+
// If option pre-render is enabled we assume all html/css is present for an optimal glare effect.
|
242 |
+
if (!this.glarePrerender) {
|
243 |
+
// Create glare element
|
244 |
+
const jsTiltGlare = document.createElement("div");
|
245 |
+
jsTiltGlare.classList.add("js-tilt-glare");
|
246 |
+
|
247 |
+
const jsTiltGlareInner = document.createElement("div");
|
248 |
+
jsTiltGlareInner.classList.add("js-tilt-glare-inner");
|
249 |
+
|
250 |
+
jsTiltGlare.appendChild(jsTiltGlareInner);
|
251 |
+
this.element.appendChild(jsTiltGlare);
|
252 |
+
}
|
253 |
+
|
254 |
+
this.glareElementWrapper = this.element.querySelector(".js-tilt-glare") as HTMLElement;
|
255 |
+
this.glareElement = this.element.querySelector(".js-tilt-glare-inner") as HTMLElement;
|
256 |
+
|
257 |
+
if (this.glarePrerender) {
|
258 |
+
return ;
|
259 |
+
}
|
260 |
+
|
261 |
+
Object.assign(this.glareElementWrapper.style, {
|
262 |
+
"position": "absolute",
|
263 |
+
"top": "0",
|
264 |
+
"left": "0",
|
265 |
+
"width": "100%",
|
266 |
+
"height": "100%",
|
267 |
+
"overflow": "hidden",
|
268 |
+
'pointer-events': 'none',
|
269 |
+
});
|
270 |
+
|
271 |
+
Object.assign(this.glareElement.style, {
|
272 |
+
'position': 'absolute',
|
273 |
+
'top': '50%',
|
274 |
+
'left': '50%',
|
275 |
+
'pointer-events': 'none',
|
276 |
+
'background-image': `linear-gradient(0deg, rgba(255,255,255,0) 0%, rgba(255,255,255,1) 100%)`,
|
277 |
+
'width': `${this.element.offsetWidth * 2}px`,
|
278 |
+
'height': `${this.element.offsetWidth * 2}px`,
|
279 |
+
'transform': 'rotate(180deg) translate(-50%, -50%)',
|
280 |
+
'transform-origin': '0% 0%',
|
281 |
+
'opacity': '0',
|
282 |
+
});
|
283 |
+
}
|
284 |
+
|
285 |
+
updateGlareSize() {
|
286 |
+
Object.assign(this.glareElement.style, {
|
287 |
+
'width': `${this.element.offsetWidth * 2}`,
|
288 |
+
'height': `${this.element.offsetWidth * 2}`,
|
289 |
+
});
|
290 |
+
}
|
291 |
+
|
292 |
+
onWindowResizeBind() {
|
293 |
+
this.updateGlareSize();
|
294 |
+
}
|
295 |
+
|
296 |
+
setTransition() {
|
297 |
+
if (this.transitionTimeout) {
|
298 |
+
clearTimeout(this.transitionTimeout);
|
299 |
+
}
|
300 |
+
// this.element.style.transition = `${this.settings.speed}ms ${this.settings.easing}`;
|
301 |
+
/// From openai:
|
302 |
+
this.element.style.transition = `transform .4s cubic-bezier(0,0,.2,1)`;
|
303 |
+
if (this.glare) {
|
304 |
+
this.glareElement.style.transition = `opacity ${this.settings.speed}ms ${this.settings.easing}`;
|
305 |
+
}
|
306 |
+
|
307 |
+
this.transitionTimeout = setTimeout(() => {
|
308 |
+
this.element.style.transition = "";
|
309 |
+
if (this.glare) {
|
310 |
+
this.glareElement.style.transition = "";
|
311 |
+
}
|
312 |
+
}, this.settings.speed);
|
313 |
+
|
314 |
+
}
|
315 |
+
|
316 |
+
extendSettings(settings) {
|
317 |
+
let defaultSettings = {
|
318 |
+
reverse: false,
|
319 |
+
max: 35,
|
320 |
+
perspective: 1000,
|
321 |
+
easing: "cubic-bezier(.03,.98,.52,.99)",
|
322 |
+
scale: "1",
|
323 |
+
speed: "300",
|
324 |
+
transition: true,
|
325 |
+
axis: null,
|
326 |
+
glare: false,
|
327 |
+
"max-glare": 1,
|
328 |
+
"glare-prerender": false,
|
329 |
+
reset: true,
|
330 |
+
};
|
331 |
+
|
332 |
+
let newSettings = {};
|
333 |
+
for (var property in defaultSettings) {
|
334 |
+
if (property in settings) {
|
335 |
+
newSettings[property] = settings[property];
|
336 |
+
} else if (this.element.hasAttribute("data-tilt-" + property)) {
|
337 |
+
let attribute = this.element.getAttribute("data-tilt-" + property);
|
338 |
+
try {
|
339 |
+
newSettings[property] = JSON.parse(<any>attribute);
|
340 |
+
} catch (e) {
|
341 |
+
newSettings[property] = attribute;
|
342 |
+
}
|
343 |
+
} else {
|
344 |
+
newSettings[property] = defaultSettings[property];
|
345 |
+
}
|
346 |
+
}
|
347 |
+
|
348 |
+
return newSettings;
|
349 |
+
}
|
350 |
+
|
351 |
+
static init(elements, settings: VanillaTilt.TiltOptions = {}) {
|
352 |
+
if (elements instanceof Node) {
|
353 |
+
elements = [elements];
|
354 |
+
}
|
355 |
+
|
356 |
+
if (elements instanceof NodeList) {
|
357 |
+
elements = [].slice.call(elements);
|
358 |
+
}
|
359 |
+
|
360 |
+
if (!(elements instanceof Array)) {
|
361 |
+
return ;
|
362 |
+
}
|
363 |
+
|
364 |
+
elements.forEach((element) => {
|
365 |
+
if (!("vanillaTilt" in element)) {
|
366 |
+
element.vanillaTilt = new VanillaTilt(element, settings);
|
367 |
+
}
|
368 |
+
});
|
369 |
+
}
|
370 |
+
}
|
371 |
+
|
front/less/mixins/bfc.less
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
.bfc {
|
2 |
+
overflow: hidden;
|
3 |
+
}
|