# Uploaded by fclong -- commit 8ebda9e ("Upload 396 files").
import argparse
from fengshen.pipelines.multiplechoice import UniMCPipelines
def main():
total_parser = argparse.ArgumentParser("TASK NAME")
total_parser = UniMCPipelines.piplines_args(total_parser)
args = total_parser.parse_args()
pretrained_model_path = 'IDEA-CCNL/Erlangshen-UniMC-RoBERTa-110M-Chinese'
args.learning_rate = 2e-5
args.max_length = 512
args.max_epochs = 3
args.batchsize = 8
args.train = 'train'
args.default_root_dir = './'
model = UniMCPipelines(args, model_path=pretrained_model_path)
train_data = [ # 训练数据
{
"texta": "凌云研发的国产两轮电动车怎么样,有什么惊喜?",
"textb": "",
"question": "下面新闻属于哪一个类别?",
"choice": [
"教育",
"科技",
"军事",
"旅游",
"国际",
"股票",
"农业",
"电竞"
],
"answer": "科技",
"label": 1,
"id": 0
}
]
dev_data = [ # 验证数据
{
"texta": "我四千一个月,老婆一千五一个月,存款八万且有两小孩,是先买房还是先买车?",
"textb": "",
"question": "下面新闻属于哪一个类别?",
"choice": [
"故事",
"文化",
"娱乐",
"体育",
"财经",
"房产",
"汽车"
],
"answer": "汽车",
"label": 6,
"id": 0
}
]
test_data = [ # 测试数据
{"texta": "街头偶遇2018款长安CS35,颜值美炸!或售6万起,还买宝骏510?",
"textb": "",
"question": "下面新闻属于哪一个类别?",
"choice": [
"房产",
"汽车",
"教育",
"军事"
],
"answer": "汽车",
"label": 1,
"id": 7759}
]
if args.train:
model.train(train_data, dev_data)
result = model.predict(test_data)
for line in result:
print(line)
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()