|
import argparse |
|
from fengshen.pipelines.multiplechoice import UniMCPipelines |
|
|
|
|
|
def main():
    """Run the UniMC multiple-choice pipeline on a tiny in-line example.

    Steps, in order:

    1. Build an argument parser, let ``UniMCPipelines.piplines_args`` extend
       it, and parse the command line.
    2. Overwrite a fixed set of attributes on the parsed namespace, so any
       command-line value given for those attributes is replaced.
    3. Construct the pipeline from the
       ``IDEA-CCNL/Erlangshen-UniMC-RoBERTa-110M-Chinese`` model path.
    4. Call ``train`` with one training and one dev record when
       ``args.train`` is truthy (it is always ``'train'`` here because of
       step 2).
    5. Call ``predict`` on one test record and print every returned item.
    """

    def make_sample(text, choices, answer, label, sample_id):
        # Every record in this script shares the same empty "textb" and the
        # same question; only the fields passed in differ.
        return {
            "texta": text,
            "textb": "",
            "question": "下面新闻属于哪一个类别?",
            "choice": choices,
            "answer": answer,
            "label": label,
            "id": sample_id,
        }

    parser = argparse.ArgumentParser("TASK NAME")
    parser = UniMCPipelines.piplines_args(parser)
    args = parser.parse_args()

    checkpoint = 'IDEA-CCNL/Erlangshen-UniMC-RoBERTa-110M-Chinese'

    # Fixed settings applied on top of the parsed arguments, in this order.
    overrides = {
        "learning_rate": 2e-5,
        "max_length": 512,
        "max_epochs": 3,
        "batchsize": 8,
        "train": 'train',
        "default_root_dir": './',
    }
    for option, value in overrides.items():
        setattr(args, option, value)

    pipeline = UniMCPipelines(args, model_path=checkpoint)

    train_data = [
        make_sample(
            "凌云研发的国产两轮电动车怎么样,有什么惊喜?",
            ["教育", "科技", "军事", "旅游", "国际", "股票", "农业", "电竞"],
            "科技",
            1,
            0,
        ),
    ]
    dev_data = [
        make_sample(
            "我四千一个月,老婆一千五一个月,存款八万且有两小孩,是先买房还是先买车?",
            ["故事", "文化", "娱乐", "体育", "财经", "房产", "汽车"],
            "汽车",
            6,
            0,
        ),
    ]
    test_data = [
        make_sample(
            "街头偶遇2018款长安CS35,颜值美炸!或售6万起,还买宝骏510?",
            ["房产", "汽车", "教育", "军事"],
            "汽车",
            1,
            7759,
        ),
    ]

    if args.train:
        pipeline.train(train_data, dev_data)

    predictions = pipeline.predict(test_data)
    for prediction in predictions:
        print(prediction)
|
|
|
|
|
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|