liujch1998 committed on
Commit
8185fe8
·
1 Parent(s): 8aa36ac

Add logging

Browse files
Files changed (2) hide show
  1. app.py +64 -30
  2. requirements.txt +2 -1
app.py CHANGED
@@ -2,48 +2,82 @@ import gradio as gr
2
  import os
3
  import torch
4
  import transformers
 
 
 
5
  import shutil
6
- stat = shutil.disk_usage('/home/user/app')
7
- print('Disk usage:')
8
- print(stat)
9
- import os
10
- # execute a shell command and print its output
11
- print(os.popen('df -h').read())
12
- print(os.popen('du -sh ~').read())
13
 
14
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  class Interactive:
17
  def __init__(self):
18
- self.tokenizer = transformers.AutoTokenizer.from_pretrained('liujch1998/cd-pi', use_auth_token=os.environ['HF_TOKEN_DOWNLOAD'])
19
- self.model = transformers.T5EncoderModel.from_pretrained('liujch1998/cd-pi', use_auth_token=os.environ['HF_TOKEN_DOWNLOAD']).to(device)
20
- self.linear = torch.nn.Linear(self.model.shared.embedding_dim, 1).to(device)
21
- self.linear.weight = torch.nn.Parameter(self.model.shared.weight[32099, :].unsqueeze(0)) # (1, D)
22
- self.linear.bias = torch.nn.Parameter(self.model.shared.weight[32098, 0].unsqueeze(0)) # (1)
23
- self.model.eval()
24
- self.t = self.model.shared.weight[32097, 0].item()
25
 
26
  def run(self, statement):
27
- input_ids = self.tokenizer.batch_encode_plus([statement], return_tensors='pt', padding='longest').input_ids.to(device)
28
- with torch.no_grad():
29
- output = self.model(input_ids)
30
- last_hidden_state = output.last_hidden_state.to(device) # (B=1, L, D)
31
- hidden = last_hidden_state[0, -1, :] # (D)
32
- logit = self.linear(hidden).squeeze(-1) # ()
33
- logit_calibrated = logit / self.t
34
- score = logit.sigmoid()
35
- score_calibrated = logit_calibrated.sigmoid()
 
 
 
 
 
 
36
  return {
37
- 'logit': logit.item(),
38
- 'logit_calibrated': logit_calibrated.item(),
39
- 'score': score.item(),
40
- 'score_calibrated': score_calibrated.item(),
41
  }
42
 
43
  interactive = Interactive()
44
 
45
- def predict(statement, model):
46
  result = interactive.run(statement)
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  return {
48
  'True': result['score_calibrated'],
49
  'False': 1 - result['score_calibrated'],
@@ -113,14 +147,14 @@ examples = [
113
  ]
114
 
115
  input_statement = gr.Dropdown(choices=examples, label='Statement:')
116
- input_model = gr.Textbox(label='Commonsense statement verification model:', value='liujch1998/cd-pi', interactive=False)
117
  output = gr.outputs.Label(num_top_classes=2)
118
 
119
  description = '''This is a demo for a commonsense statement verification model. Under development.'''
120
 
121
  gr.Interface(
122
  fn=predict,
123
- inputs=[input_statement, input_model],
124
  outputs=output,
125
  title="cd-pi Demo",
126
  description=description,
 
2
  import os
3
  import torch
4
  import transformers
5
+ import huggingface_hub
6
+ import datetime
7
+ import json
8
  import shutil
 
 
 
 
 
 
 
9
 
10
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
11
 
12
+ HF_TOKEN_DOWNLOAD = os.environ['HF_TOKEN_DOWNLOAD']
13
+ HF_TOKEN_UPLOAD = os.environ['HF_TOKEN_UPLOAD']
14
+
15
+ MODEL_NAME = 'liujch1998/cd-pi'
16
+ DATASET_REPO_URL = "https://huggingface.co/datasets/liujch1998/cd-pi-dataset"
17
+ DATA_DIR = 'data'
18
+ DATA_PATH = os.path.join(DATA_DIR, 'data.jsonl')
19
+
20
+ try:
21
+ shutil.rmtree(DATA_DIR)
22
+ except:
23
+ pass
24
+ repo = huggingface_hub.Repository(
25
+ local_dir=DATA_DIR,
26
+ clone_from=DATASET_REPO_URL,
27
+ use_auth_token=HF_TOKEN_DOWNLOAD,
28
+ )
29
+ repo.git_pull()
30
+
31
  class Interactive:
32
  def __init__(self):
33
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN_DOWNLOAD)
34
+ # self.model = transformers.T5EncoderModel.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN_DOWNLOAD, low_cpu_mem_usage=True, device_map='auto', torch_dtype='auto')
35
+ # self.linear = torch.nn.Linear(self.model.shared.embedding_dim, 1).to(device)
36
+ # self.linear.weight = torch.nn.Parameter(self.model.shared.weight[32099, :].unsqueeze(0)) # (1, D)
37
+ # self.linear.bias = torch.nn.Parameter(self.model.shared.weight[32098, 0].unsqueeze(0)) # (1)
38
+ # self.model.eval()
39
+ # self.t = self.model.shared.weight[32097, 0].item()
40
 
41
  def run(self, statement):
42
+ # input_ids = self.tokenizer.batch_encode_plus([statement], return_tensors='pt', padding='longest').input_ids.to(device)
43
+ # with torch.no_grad():
44
+ # output = self.model(input_ids)
45
+ # last_hidden_state = output.last_hidden_state.to(device) # (B=1, L, D)
46
+ # hidden = last_hidden_state[0, -1, :] # (D)
47
+ # logit = self.linear(hidden).squeeze(-1) # ()
48
+ # logit_calibrated = logit / self.t
49
+ # score = logit.sigmoid()
50
+ # score_calibrated = logit_calibrated.sigmoid()
51
+ # return {
52
+ # 'logit': logit.item(),
53
+ # 'logit_calibrated': logit_calibrated.item(),
54
+ # 'score': score.item(),
55
+ # 'score_calibrated': score_calibrated.item(),
56
+ # }
57
  return {
58
+ 'logit': 0.0,
59
+ 'logit_calibrated': 0.0,
60
+ 'score': 0.5,
61
+ 'score_calibrated': 0.5,
62
  }
63
 
64
  interactive = Interactive()
65
 
66
+ def predict(statement):
67
  result = interactive.run(statement)
68
+ with open(DATA_PATH, 'a') as f:
69
+ row = {
70
+ 'timestamp': datetime.datetime.now().strftime('%Y%m%d-%H%M%S'),
71
+ 'statement': statement,
72
+ 'logit': result['logit'],
73
+ 'logit_calibrated': result['logit_calibrated'],
74
+ 'score': result['score'],
75
+ 'score_calibrated': result['score_calibrated'],
76
+ }
77
+ json.dump(row, f, ensure_ascii=False)
78
+ f.write('\n')
79
+ commit_url = repo.push_to_hub()
80
+ print(commit_url)
81
  return {
82
  'True': result['score_calibrated'],
83
  'False': 1 - result['score_calibrated'],
 
147
  ]
148
 
149
  input_statement = gr.Dropdown(choices=examples, label='Statement:')
150
+ input_model = gr.Textbox(label='Commonsense statement verification model:', value=MODEL_NAME, interactive=False)
151
  output = gr.outputs.Label(num_top_classes=2)
152
 
153
  description = '''This is a demo for a commonsense statement verification model. Under development.'''
154
 
155
  gr.Interface(
156
  fn=predict,
157
+ inputs=[input_statement],
158
  outputs=output,
159
  title="cd-pi Demo",
160
  description=description,
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  torch==1.13.1
2
  transformers==4.23.1
3
  tokenizers==0.13.2
4
- sentencepiece==0.1.96
 
 
1
  torch==1.13.1
2
  transformers==4.23.1
3
  tokenizers==0.13.2
4
+ sentencepiece==0.1.96
5
+ huggingface_hub