justinsiow
committed
Uploaded Utils, Pycache and Python Files
- __pycache__/schema_filter.cpython-38.pyc +0 -0
- eval_mode.py +182 -0
- schema_filter.py +339 -0
- training_mode.py +194 -0
- utils/__pycache__/classifier_model.cpython-38.pyc +0 -0
- utils/classifier_model.py +186 -0
__pycache__/schema_filter.cpython-38.pyc
ADDED
Binary file (11 kB).
eval_mode.py
ADDED
@@ -0,0 +1,182 @@

from schema_filter import filter_func, SchemaItemClassifierInference

# In eval mode, the "sql" field does not need to be provided
data = {
    "text": "Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.",
    "sql": "",
    "schema": {
        "schema_items": [
            {
                "table_name": "lists",
                "table_comment": "",
                "column_names": [
                    "user_id", "list_id", "list_title", "list_movie_number",
                    "list_update_timestamp_utc", "list_creation_timestamp_utc",
                    "list_followers", "list_url", "list_comments", "list_description",
                    "list_cover_image_url", "list_first_image_url",
                    "list_second_image_url", "list_third_image_url"
                ],
                "column_comments": [""] * 14  # one (empty) comment per column
            },
            {
                "table_name": "movies",
                "table_comment": "",
                "column_names": [
                    "movie_id", "movie_title", "movie_release_year", "movie_url",
                    "movie_title_language", "movie_popularity", "movie_image_url",
                    "director_id", "director_name", "director_url"
                ],
                "column_comments": [""] * 10
            },
            {
                "table_name": "ratings_users",
                "table_comment": "",
                "column_names": [
                    "user_id", "rating_date_utc", "user_trialist", "user_subscriber",
                    "user_avatar_image_url", "user_cover_image_url",
                    "user_eligible_for_trial", "user_has_payment_method"
                ],
                "column_comments": [""] * 8
            },
            {
                "table_name": "lists_users",
                "table_comment": "",
                "column_names": [
                    "user_id", "list_id", "list_update_date_utc", "list_creation_date_utc",
                    "user_trialist", "user_subscriber", "user_avatar_image_url",
                    "user_cover_image_url", "user_eligible_for_trial", "user_has_payment_method"
                ],
                "column_comments": [""] * 10
            },
            {
                "table_name": "ratings",
                "table_comment": "",
                "column_names": [
                    "movie_id", "rating_id", "rating_url", "rating_score",
                    "rating_timestamp_utc", "critic", "critic_likes", "critic_comments",
                    "user_id", "user_trialist", "user_subscriber",
                    "user_eligible_for_trial", "user_has_payment_method"
                ],
                "column_comments": [""] * 13
            }
        ]
    }
}

dataset = [data]

# Keep at most 7 tables of the database
num_top_k_tables = 7
# For each kept table, keep at most 10 columns, so the prompt contains at most 7 * 10 = 70 columns
num_top_k_columns = 10

# Load the classifier model
sic = SchemaItemClassifierInference("sic_merged")

# For test data, we need the trained classifier to score the tables and columns against the user question
dataset = filter_func(
    dataset = dataset,
    dataset_type = "eval",
    sic = sic,
    num_top_k_tables = num_top_k_tables,
    num_top_k_columns = num_top_k_columns
)

print(dataset)
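
Note: filter_func edits each sample in place and returns the dataset, so the printed result has the same structure as the input, only with a pruned schema. A sketch of the shape (illustrative only; which tables and columns survive depends on the trained classifier's scores):

# [{
#     "text": "Name movie titles released in year 1945. ...",
#     "sql": "",
#     "schema": {"schema_items": [
#         # at most num_top_k_tables entries, in descending order of predicted table probability;
#         # each keeps at most num_top_k_columns of its highest-scoring columns
#         {"table_name": "movies", "table_comment": "",
#          "column_names": ["...", "..."], "column_comments": ["...", "..."]}
#     ]}
# }]
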
schema_filter.py
ADDED
@@ -0,0 +1,339 @@

import numpy as np
import random
import torch

from tqdm import tqdm
from transformers import AutoTokenizer
from utils.classifier_model import SchemaItemClassifier
from transformers.trainer_utils import set_seed

def prepare_inputs_and_labels(sample, tokenizer):
    table_names = [table["table_name"] for table in sample["schema"]["schema_items"]]
    column_names = [table["column_names"] for table in sample["schema"]["schema_items"]]
    column_num_in_each_table = [len(table["column_names"]) for table in sample["schema"]["schema_items"]]

    # `column_name_word_indices` and `table_name_word_indices` record the word indices of each column and table in `input_words`; each element is an integer
    column_name_word_indices, table_name_word_indices = [], []

    input_words = [sample["text"]]
    for table_id, table_name in enumerate(table_names):
        input_words.append("|")
        input_words.append(table_name)
        table_name_word_indices.append(len(input_words) - 1)
        input_words.append(":")

        for column_name in column_names[table_id]:
            input_words.append(column_name)
            column_name_word_indices.append(len(input_words) - 1)
            input_words.append(",")

    # remove the last ","
    input_words = input_words[:-1]

    tokenized_inputs = tokenizer(
        input_words,
        return_tensors="pt",
        is_split_into_words = True,
        padding = "max_length",
        max_length = 512,
        truncation = True
    )

    # after tokenizing, one table name or column name may be split into multiple tokens (i.e., sub-words)
    # `column_name_token_indices` and `table_name_token_indices` record the token indices of each column and table in `input_ids`; each element is a list of integers
    column_name_token_indices, table_name_token_indices = [], []
    word_indices = tokenized_inputs.word_ids(batch_index = 0)

    # obtain token indices of each column in `input_ids`
    for column_name_word_index in column_name_word_indices:
        column_name_token_indices.append([token_id for token_id, word_index in enumerate(word_indices) if column_name_word_index == word_index])

    # obtain token indices of each table in `input_ids`
    for table_name_word_index in table_name_word_indices:
        table_name_token_indices.append([token_id for token_id, word_index in enumerate(word_indices) if table_name_word_index == word_index])

    encoder_input_ids = tokenized_inputs["input_ids"]
    encoder_input_attention_mask = tokenized_inputs["attention_mask"]

    # print("\n".join(tokenizer.batch_decode(encoder_input_ids, skip_special_tokens = True)))

    if torch.cuda.is_available():
        encoder_input_ids = encoder_input_ids.cuda()
        encoder_input_attention_mask = encoder_input_attention_mask.cuda()

    return encoder_input_ids, encoder_input_attention_mask, \
        column_name_token_indices, table_name_token_indices, column_num_in_each_table

def get_schema(tables_and_columns):
    schema_items = []
    table_names = list(dict.fromkeys([t for t, c in tables_and_columns]))
    for table_name in table_names:
        schema_items.append(
            {
                "table_name": table_name,
                "column_names": [c for t, c in tables_and_columns if t == table_name]
            }
        )

    return {"schema_items": schema_items}

def get_sequence_length(text, tables_and_columns, tokenizer):
    table_names = [t for t, c in tables_and_columns]
    # deduplicate `table_names` while preserving order
    table_names = list(dict.fromkeys(table_names))

    column_names = []
    for table_name in table_names:
        column_names.append([c for t, c in tables_and_columns if t == table_name])

    input_words = [text]
    for table_id, table_name in enumerate(table_names):
        input_words.append("|")
        input_words.append(table_name)
        input_words.append(":")
        for column_name in column_names[table_id]:
            input_words.append(column_name)
            input_words.append(",")
    # remove the last ","
    input_words = input_words[:-1]

    tokenized_inputs = tokenizer(input_words, is_split_into_words = True)

    return len(tokenized_inputs["input_ids"])

# handle extremely long schemas: split one sample into several chunks whose
# serialized length stays under the 512-token encoder limit
def split_sample(sample, tokenizer):
    text = sample["text"]

    table_names = []
    column_names = []
    for table in sample["schema"]["schema_items"]:
        table_names.append(table["table_name"] + " ( " + table["table_comment"] + " ) " \
            if table["table_comment"] != "" else table["table_name"])
        column_names.append([column_name + " ( " + column_comment + " ) " \
            if column_comment != "" else column_name \
            for column_name, column_comment in zip(table["column_names"], table["column_comments"])])

    splitted_samples = []
    recorded_tables_and_columns = []

    for table_idx, table_name in enumerate(table_names):
        for column_name in column_names[table_idx]:
            if get_sequence_length(text, recorded_tables_and_columns + [[table_name, column_name]], tokenizer) < 500:
                recorded_tables_and_columns.append([table_name, column_name])
            else:
                splitted_samples.append(
                    {
                        "text": text,
                        "schema": get_schema(recorded_tables_and_columns)
                    }
                )
                recorded_tables_and_columns = [[table_name, column_name]]

    splitted_samples.append(
        {
            "text": text,
            "schema": get_schema(recorded_tables_and_columns)
        }
    )

    return splitted_samples

def merge_pred_results(sample, pred_results):
    table_names = []
    column_names = []
    for table in sample["schema"]["schema_items"]:
        table_names.append(table["table_name"] + " ( " + table["table_comment"] + " ) " \
            if table["table_comment"] != "" else table["table_name"])
        column_names.append([column_name + " ( " + column_comment + " ) " \
            if column_comment != "" else column_name \
            for column_name, column_comment in zip(table["column_names"], table["column_comments"])])

    merged_results = []
    for table_id, table_name in enumerate(table_names):
        table_prob = 0
        column_probs = []
        for result_dict in pred_results:
            if table_name in result_dict:
                # keep the highest table probability across chunks and collect all column probabilities
                if table_prob < result_dict[table_name]["table_prob"]:
                    table_prob = result_dict[table_name]["table_prob"]
                column_probs += result_dict[table_name]["column_probs"]

        merged_results.append(
            {
                "table_name": table_name,
                "table_prob": table_prob,
                "column_names": column_names[table_id],
                "column_probs": column_probs
            }
        )

    return merged_results
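
A quick illustration (hypothetical question and schema, not from the repo) of the flattened word sequence that prepare_inputs_and_labels and get_sequence_length build before tokenization:

words = ["How many movies were released in 1945?"]
for table, cols in [("movies", ["movie_id", "movie_title"]), ("ratings", ["rating_id"])]:
    words += ["|", table, ":"]
    for col in cols:
        words += [col, ","]
words = words[:-1]  # only the final trailing "," is removed, matching the code above
print(" ".join(words))
# How many movies were released in 1945? | movies : movie_id , movie_title , | ratings : rating_id
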
def filter_func(dataset, dataset_type, sic, num_top_k_tables = 5, num_top_k_columns = 5):
    for data in tqdm(dataset, desc = "filtering schema items for the dataset"):
        filtered_schema = dict()
        filtered_schema["schema_items"] = []

        table_names = [table["table_name"] for table in data["schema"]["schema_items"]]
        table_comments = [table["table_comment"] for table in data["schema"]["schema_items"]]
        column_names = [table["column_names"] for table in data["schema"]["schema_items"]]
        column_comments = [table["column_comments"] for table in data["schema"]["schema_items"]]

        if dataset_type == "eval":
            # predict a score for each table and column
            pred_results = sic.predict(data)
            # keep the top-k1 tables of the database and the top-k2 columns of each kept table
            table_probs = [pred_result["table_prob"] for pred_result in pred_results]
            table_indices = np.argsort(-np.array(table_probs), kind="stable")[:num_top_k_tables].tolist()
        elif dataset_type == "train":
            table_indices = [table_idx for table_idx, table_label in enumerate(data["table_labels"]) if table_label == 1]
            if len(table_indices) < num_top_k_tables:
                # pad with randomly sampled unused tables, then shuffle
                unused_table_indices = [table_idx for table_idx, table_label in enumerate(data["table_labels"]) if table_label == 0]
                table_indices += random.sample(unused_table_indices, min(len(unused_table_indices), num_top_k_tables - len(table_indices)))
            random.shuffle(table_indices)

        for table_idx in table_indices:
            if dataset_type == "eval":
                column_probs = pred_results[table_idx]["column_probs"]
                column_indices = np.argsort(-np.array(column_probs), kind="stable")[:num_top_k_columns].tolist()
            elif dataset_type == "train":
                column_indices = [column_idx for column_idx, column_label in enumerate(data["column_labels"][table_idx]) if column_label == 1]
                if len(column_indices) < num_top_k_columns:
                    # pad with randomly sampled unused columns, then shuffle
                    unused_column_indices = [column_idx for column_idx, column_label in enumerate(data["column_labels"][table_idx]) if column_label == 0]
                    column_indices += random.sample(unused_column_indices, min(len(unused_column_indices), num_top_k_columns - len(column_indices)))
                random.shuffle(column_indices)

            filtered_schema["schema_items"].append(
                {
                    "table_name": table_names[table_idx],
                    "table_comment": table_comments[table_idx],
                    "column_names": [column_names[table_idx][column_idx] for column_idx in column_indices],
                    "column_comments": [column_comments[table_idx][column_idx] for column_idx in column_indices]
                }
            )

        # replace the old schema with the filtered schema
        data["schema"] = filtered_schema

        if dataset_type == "train":
            del data["table_labels"]
            del data["column_labels"]

    return dataset

def lista_contains_listb(lista, listb):
    # return 1 if every element of `listb` is in `lista`, else 0
    for b in listb:
        if b not in lista:
            return 0

    return 1

class SchemaItemClassifierInference():
    def __init__(self, model_save_path):
        set_seed(42)
        # load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_save_path, add_prefix_space = True)
        # initialize model
        self.model = SchemaItemClassifier(model_save_path, "test")
        # load fine-tuned params
        self.model.load_state_dict(torch.load(model_save_path + "/dense_classifier.pt", map_location=torch.device('cpu')), strict=False)
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        self.model.eval()

    def predict_one(self, sample):
        encoder_input_ids, encoder_input_attention_mask, column_name_token_indices,\
            table_name_token_indices, column_num_in_each_table = prepare_inputs_and_labels(sample, self.tokenizer)

        with torch.no_grad():
            model_outputs = self.model(
                encoder_input_ids,
                encoder_input_attention_mask,
                [column_name_token_indices],
                [table_name_token_indices],
                [column_num_in_each_table]
            )

        # P(used) for each table, read off the softmax over the binary logits
        table_logits = model_outputs["batch_table_name_cls_logits"][0]
        table_pred_probs = torch.nn.functional.softmax(table_logits, dim = 1)[:, 1].cpu().tolist()

        column_logits = model_outputs["batch_column_info_cls_logits"][0]
        column_pred_probs = torch.nn.functional.softmax(column_logits, dim = 1)[:, 1].cpu().tolist()

        splitted_column_pred_probs = []
        # split the flat list of predicted column probs back into per-table lists
        for table_id, column_num in enumerate(column_num_in_each_table):
            splitted_column_pred_probs.append(column_pred_probs[sum(column_num_in_each_table[:table_id]): sum(column_num_in_each_table[:table_id]) + column_num])
        column_pred_probs = splitted_column_pred_probs

        result_dict = dict()
        for table_idx, table in enumerate(sample["schema"]["schema_items"]):
            result_dict[table["table_name"]] = {
                "table_name": table["table_name"],
                "table_prob": table_pred_probs[table_idx],
                "column_names": table["column_names"],
                "column_probs": column_pred_probs[table_idx],
            }

        return result_dict

    def predict(self, test_sample):
        splitted_samples = split_sample(test_sample, self.tokenizer)
        pred_results = []
        for splitted_sample in splitted_samples:
            pred_results.append(self.predict_one(splitted_sample))

        return merge_pred_results(test_sample, pred_results)

    def evaluate_coverage(self, dataset):
        max_k = 100
        total_num_for_table_coverage, total_num_for_column_coverage = 0, 0
        table_coverage_results = [0] * max_k
        column_coverage_results = [0] * max_k

        for data in dataset:
            indices_of_used_tables = [idx for idx, label in enumerate(data["table_labels"]) if label == 1]
            # was `sic.predict(data)`, which relied on the module-level `sic` defined under __main__
            pred_results = self.predict(data)
            table_probs = [res["table_prob"] for res in pred_results]
            for k in range(max_k):
                indices_of_top_k_tables = np.argsort(-np.array(table_probs), kind="stable")[:k+1].tolist()
                if lista_contains_listb(indices_of_top_k_tables, indices_of_used_tables):
                    table_coverage_results[k] += 1
            total_num_for_table_coverage += 1

            for table_idx in range(len(data["table_labels"])):
                indices_of_used_columns = [idx for idx, label in enumerate(data["column_labels"][table_idx]) if label == 1]
                if len(indices_of_used_columns) == 0:
                    continue
                column_probs = pred_results[table_idx]["column_probs"]
                for k in range(max_k):
                    indices_of_top_k_columns = np.argsort(-np.array(column_probs), kind="stable")[:k+1].tolist()
                    if lista_contains_listb(indices_of_top_k_columns, indices_of_used_columns):
                        column_coverage_results[k] += 1

                total_num_for_column_coverage += 1

                indices_of_top_10_columns = np.argsort(-np.array(column_probs), kind="stable")[:10].tolist()
                if lista_contains_listb(indices_of_top_10_columns, indices_of_used_columns) == 0:
                    print(pred_results[table_idx])
                    print(data["column_labels"][table_idx])
                    print(data["question"])

        print(total_num_for_table_coverage)
        print(table_coverage_results)
        print(total_num_for_column_coverage)
        print(column_coverage_results)

if __name__ == "__main__":
    dataset_name = "bird_with_evidence"
    # dataset_name = "bird"
    # dataset_name = "spider"
    sic = SchemaItemClassifierInference("sic_ckpts/sic_{}".format(dataset_name))
    import json
    dataset = json.load(open("./data/sft_eval_{}_text2sql.json".format(dataset_name)))

    sic.evaluate_coverage(dataset)
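
For clarity, a toy demonstration (probability values invented) of the stable top-k selection used in filter_func's eval branch:

import numpy as np

table_probs = [0.12, 0.95, 0.95, 0.40]
# negate so that argsort (ascending) yields the highest-probability indices first;
# kind="stable" breaks ties by original position (table 1 ranks before table 2)
top_3 = np.argsort(-np.array(table_probs), kind="stable")[:3].tolist()
print(top_3)  # [1, 2, 3]
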
training_mode.py
ADDED
@@ -0,0 +1,194 @@

from schema_filter import filter_func

data = {
    "text": "Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.",
    "sql": "SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1",
    "schema": {
        "schema_items": [
            {
                "table_name": "lists",
                "table_comment": "",
                "column_names": [
                    "user_id", "list_id", "list_title", "list_movie_number",
                    "list_update_timestamp_utc", "list_creation_timestamp_utc",
                    "list_followers", "list_url", "list_comments", "list_description",
                    "list_cover_image_url", "list_first_image_url",
                    "list_second_image_url", "list_third_image_url"
                ],
                "column_comments": [""] * 14  # one (empty) comment per column
            },
            {
                "table_name": "movies",
                "table_comment": "",
                "column_names": [
                    "movie_id", "movie_title", "movie_release_year", "movie_url",
                    "movie_title_language", "movie_popularity", "movie_image_url",
                    "director_id", "director_name", "director_url"
                ],
                "column_comments": [""] * 10
            },
            {
                "table_name": "ratings_users",
                "table_comment": "",
                "column_names": [
                    "user_id", "rating_date_utc", "user_trialist", "user_subscriber",
                    "user_avatar_image_url", "user_cover_image_url",
                    "user_eligible_for_trial", "user_has_payment_method"
                ],
                "column_comments": [""] * 8
            },
            {
                "table_name": "lists_users",
                "table_comment": "",
                "column_names": [
                    "user_id", "list_id", "list_update_date_utc", "list_creation_date_utc",
                    "user_trialist", "user_subscriber", "user_avatar_image_url",
                    "user_cover_image_url", "user_eligible_for_trial", "user_has_payment_method"
                ],
                "column_comments": [""] * 10
            },
            {
                "table_name": "ratings",
                "table_comment": "",
                "column_names": [
                    "movie_id", "rating_id", "rating_url", "rating_score",
                    "rating_timestamp_utc", "critic", "critic_likes", "critic_comments",
                    "user_id", "user_trialist", "user_subscriber",
                    "user_eligible_for_trial", "user_has_payment_method"
                ],
                "column_comments": [""] * 13
            }
        ]
    }
}

def find_used_tables_and_columns(dataset):
    for data in dataset:
        sql = data["sql"].lower()
        data["table_labels"] = []
        data["column_labels"] = []

        for table_info in data["schema"]["schema_items"]:
            table_name = table_info["table_name"]
            data["table_labels"].append(1 if table_name.lower() in sql else 0)
            data["column_labels"].append([1 if column_name.lower() in sql else 0 \
                for column_name in table_info["column_names"]])
    return dataset

dataset = [data]

# Derive the used tables and columns from the SQL query
dataset = find_used_tables_and_columns(dataset)

# Keep at most 6 tables of the database
num_top_k_tables = 6
# For each kept table, keep at most 6 columns, so the prompt contains at most 6 * 6 = 36 columns
num_top_k_columns = 6

# For training data, we can simulate the filtering process from the SQL itself; in this case
# sic (the schema item classifier) can simply be None, since no model is needed
dataset = filter_func(
    dataset = dataset,
    dataset_type = "train",
    sic = None,
    num_top_k_tables = num_top_k_tables,
    num_top_k_columns = num_top_k_columns
)

print(dataset)
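
For the sample above, the substring heuristic in find_used_tables_and_columns yields the following labels (they follow directly from the SQL string; note the heuristic can over-label a column whose name merely appears as a substring of another used identifier):

# table order: lists, movies, ratings_users, lists_users, ratings
# table_labels           -> [0, 1, 0, 0, 0]   # only "movies" occurs in the SQL
# column_labels["movies"] -> [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]
#   i.e. movie_title, movie_release_year, movie_popularity are marked as used;
#   conversely, "movie_title" would also be marked 1 if the SQL only used
#   "movie_title_language", since the check is a plain substring test
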
utils/__pycache__/classifier_model.cpython-38.pyc
ADDED
Binary file (4.01 kB).
utils/classifier_model.py
ADDED
@@ -0,0 +1,186 @@

import torch
import torch.nn as nn

from transformers import AutoConfig, XLMRobertaXLModel

class SchemaItemClassifier(nn.Module):
    def __init__(self, model_name_or_path, mode):
        super(SchemaItemClassifier, self).__init__()
        if mode in ["eval", "test"]:
            # load config
            config = AutoConfig.from_pretrained(model_name_or_path)
            # randomly initialize the model's parameters according to the config
            # (the fine-tuned weights are loaded afterwards via load_state_dict)
            self.plm_encoder = XLMRobertaXLModel(config)
        elif mode == "train":
            self.plm_encoder = XLMRobertaXLModel.from_pretrained(model_name_or_path)
        else:
            raise ValueError("mode must be one of 'train', 'eval', or 'test'")

        self.plm_hidden_size = self.plm_encoder.config.hidden_size

        # column cls head
        self.column_info_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256)
        self.column_info_cls_head_linear2 = nn.Linear(256, 2)

        # column bi-lstm layer
        self.column_info_bilstm = nn.LSTM(
            input_size = self.plm_hidden_size,
            hidden_size = int(self.plm_hidden_size/2),
            num_layers = 2,
            dropout = 0,
            bidirectional = True
        )

        # linear layer after column bi-lstm layer
        self.column_info_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size)

        # table cls head
        self.table_name_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256)
        self.table_name_cls_head_linear2 = nn.Linear(256, 2)

        # table bi-lstm pooling layer
        self.table_name_bilstm = nn.LSTM(
            input_size = self.plm_hidden_size,
            hidden_size = int(self.plm_hidden_size/2),
            num_layers = 2,
            dropout = 0,
            bidirectional = True
        )
        # linear layer after table bi-lstm layer
        self.table_name_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size)

        # activation functions
        self.leakyrelu = nn.LeakyReLU()
        self.tanh = nn.Tanh()

        # table-column cross-attention layer
        self.table_column_cross_attention_layer = nn.MultiheadAttention(embed_dim = self.plm_hidden_size, num_heads = 8)

        # dropout layer; p = 0.2 means 20% of activations are randomly zeroed during training
        self.dropout = nn.Dropout(p = 0.2)

    def table_column_cross_attention(
        self,
        table_name_embeddings_in_one_db,
        column_info_embeddings_in_one_db,
        column_number_in_each_table
    ):
        table_num = table_name_embeddings_in_one_db.shape[0]
        table_name_embedding_attn_list = []
        for table_id in range(table_num):
            table_name_embedding = table_name_embeddings_in_one_db[[table_id], :]
            # attend only over the columns belonging to the current table
            column_info_embeddings_in_one_table = column_info_embeddings_in_one_db[
                sum(column_number_in_each_table[:table_id]) : sum(column_number_in_each_table[:table_id+1]), :]

            table_name_embedding_attn, _ = self.table_column_cross_attention_layer(
                table_name_embedding,
                column_info_embeddings_in_one_table,
                column_info_embeddings_in_one_table
            )

            table_name_embedding_attn_list.append(table_name_embedding_attn)

        # residual connection
        table_name_embeddings_in_one_db = table_name_embeddings_in_one_db + torch.cat(table_name_embedding_attn_list, dim = 0)
        # row-wise L2 norm
        table_name_embeddings_in_one_db = torch.nn.functional.normalize(table_name_embeddings_in_one_db, p=2.0, dim=1)

        return table_name_embeddings_in_one_db

    def table_column_cls(
        self,
        encoder_input_ids,
        encoder_input_attention_mask,
        batch_aligned_column_info_ids,
        batch_aligned_table_name_ids,
        batch_column_number_in_each_table
    ):
        batch_size = encoder_input_ids.shape[0]

        encoder_output = self.plm_encoder(
            input_ids = encoder_input_ids,
            attention_mask = encoder_input_attention_mask,
            return_dict = True
        ) # encoder_output["last_hidden_state"].shape = (batch_size x seq_length x hidden_size)

        batch_table_name_cls_logits, batch_column_info_cls_logits = [], []

        # handle each data point in the current batch
        for batch_id in range(batch_size):
            column_number_in_each_table = batch_column_number_in_each_table[batch_id]
            sequence_embeddings = encoder_output["last_hidden_state"][batch_id, :, :] # (seq_length x hidden_size)

            # obtain the token indices of each table name
            aligned_table_name_ids = batch_aligned_table_name_ids[batch_id]
            # obtain the token indices of each column name
            aligned_column_info_ids = batch_aligned_column_info_ids[batch_id]

            table_name_embedding_list, column_info_embedding_list = [], []

            # obtain table embedding via bi-lstm pooling + a non-linear layer
            for table_name_ids in aligned_table_name_ids:
                table_name_embeddings = sequence_embeddings[table_name_ids, :]

                # BiLSTM pooling
                output_t, (hidden_state_t, cell_state_t) = self.table_name_bilstm(table_name_embeddings)
                table_name_embedding = hidden_state_t[-2:, :].view(1, self.plm_hidden_size)
                table_name_embedding_list.append(table_name_embedding)
            table_name_embeddings_in_one_db = torch.cat(table_name_embedding_list, dim = 0)
            # non-linear mlp layer
            table_name_embeddings_in_one_db = self.leakyrelu(self.table_name_linear_after_pooling(table_name_embeddings_in_one_db))

            # obtain column embedding via bi-lstm pooling + a non-linear layer
            for column_info_ids in aligned_column_info_ids:
                column_info_embeddings = sequence_embeddings[column_info_ids, :]

                # BiLSTM pooling
                output_c, (hidden_state_c, cell_state_c) = self.column_info_bilstm(column_info_embeddings)
                column_info_embedding = hidden_state_c[-2:, :].view(1, self.plm_hidden_size)
                column_info_embedding_list.append(column_info_embedding)
            column_info_embeddings_in_one_db = torch.cat(column_info_embedding_list, dim = 0)
            # non-linear mlp layer
            column_info_embeddings_in_one_db = self.leakyrelu(self.column_info_linear_after_pooling(column_info_embeddings_in_one_db))

            # table-column (tc) cross-attention
            table_name_embeddings_in_one_db = self.table_column_cross_attention(
                table_name_embeddings_in_one_db,
                column_info_embeddings_in_one_db,
                column_number_in_each_table
            )

            # calculate binary (used / not used) logits for tables
            table_name_embeddings_in_one_db = self.table_name_cls_head_linear1(table_name_embeddings_in_one_db)
            table_name_embeddings_in_one_db = self.dropout(self.leakyrelu(table_name_embeddings_in_one_db))
            table_name_cls_logits = self.table_name_cls_head_linear2(table_name_embeddings_in_one_db)

            # calculate binary (used / not used) logits for columns
            column_info_embeddings_in_one_db = self.column_info_cls_head_linear1(column_info_embeddings_in_one_db)
            column_info_embeddings_in_one_db = self.dropout(self.leakyrelu(column_info_embeddings_in_one_db))
            column_info_cls_logits = self.column_info_cls_head_linear2(column_info_embeddings_in_one_db)

            batch_table_name_cls_logits.append(table_name_cls_logits)
            batch_column_info_cls_logits.append(column_info_cls_logits)

        return batch_table_name_cls_logits, batch_column_info_cls_logits

    def forward(
        self,
        encoder_input_ids,
        encoder_attention_mask,
        batch_aligned_column_info_ids,
        batch_aligned_table_name_ids,
        batch_column_number_in_each_table,
    ):
        batch_table_name_cls_logits, batch_column_info_cls_logits \
            = self.table_column_cls(
                encoder_input_ids,
                encoder_attention_mask,
                batch_aligned_column_info_ids,
                batch_aligned_table_name_ids,
                batch_column_number_in_each_table
            )

        return {
            "batch_table_name_cls_logits" : batch_table_name_cls_logits,
            "batch_column_info_cls_logits": batch_column_info_cls_logits
        }
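
A minimal sketch of the two pooling/attention tricks used above, with a toy hidden size (the real model uses the PLM hidden size; both modules are called with unbatched 2-D inputs, which recent PyTorch versions support):

import torch
import torch.nn as nn

hidden = 16
# BiLSTM pooling: feed the sub-word embeddings of one name, then concatenate
# the last forward and backward hidden states into a single vector
bilstm = nn.LSTM(input_size=hidden, hidden_size=hidden // 2, num_layers=2, bidirectional=True)
token_embeddings = torch.randn(5, hidden)   # 5 sub-word tokens of one table name
_, (h_n, _) = bilstm(token_embeddings)      # h_n: (num_layers * 2, hidden // 2)
pooled = h_n[-2:, :].view(1, hidden)        # concat last fwd/bwd states -> (1, hidden)
assert pooled.shape == (1, hidden)

# Cross-attention: one table-name query attends over that table's column embeddings
attn = nn.MultiheadAttention(embed_dim=hidden, num_heads=4)
q = torch.randn(1, hidden)                  # one pooled table-name embedding
kv = torch.randn(3, hidden)                 # embeddings of that table's 3 columns
out, _ = attn(q, kv, kv)
assert out.shape == (1, hidden)
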