|
import argparse |
|
from fengshen.pipelines.tcbert import TCBertPipelines |
|
from pytorch_lightning import seed_everything |
|
|
|
def main():
    """Run the TCBert topic-classification demo end to end.

    Parses the pipeline's command-line arguments, overrides a handful of
    them with fixed demo values, optionally fine-tunes on a small in-line
    training/dev set, then predicts labels for three unlabeled headlines
    and prints one ``{"content": ..., "label": ...}`` dict per headline.
    """
    seed_everything(123)

    parser = argparse.ArgumentParser("Topic Classification")
    parser = TCBertPipelines.piplines_args(parser)
    args = parser.parse_args()

    model_name = 'IDEA-CCNL/Erlangshen-TCBert-110M-Classification-Chinese'

    # Fixed demo settings; these overwrite whatever was parsed from the
    # command line for the same attribute names.
    overrides = {
        "learning_rate": 2e-5,
        "max_length": 512,
        "max_epochs": 5,
        "batchsize": 4,
        # Non-empty string, so the `if args.train:` check below always passes.
        "train": 'train',
        "default_root_dir": './',
        "fixed_lablen": 2,
    }
    for option, value in overrides.items():
        setattr(args, option, value)

    # One labeled headline per topic, as (text, label) pairs.
    train_pairs = (
        ("真正的放养教育,放的是孩子的思维,养的是孩子的习惯", "故事"),
        ("《唐人街探案》捧红了王宝强跟刘昊然,唯独戏份不少的他发展最差", "娱乐"),
        ("油价攀升 阿曼经济加速增长", "财经"),
        ("日本男篮近期动作频频,中国队的未来劲敌会是他们吗?", "体育"),
        ("教育部:坚决防止因撤并乡村小规模学校导致学生上学困难", "教育"),
        ("LOL设计最完美的三个英雄,玩家们都很认可!", "电竞"),
        ("上联:浅看红楼终是梦,怎么对下联?", "文化"),
        ("楼市再出新政!北京部分限房价项目或转为共有产权房", "房产"),
        ("企业怎样选云服务器?云服务器哪家比较好?", "科技"),
        ("贝纳利的三缸车TRE899K、TRE1130K华丽转身", "汽车"),
        ("如何评价:刘姝威的《严惩做空中国股市者》?", "股票"),
        ("宁夏邀深圳市民共赴“寻找穿越”之旅", "旅游"),
        ("日本自民党又一派系力挺安倍 称会竭尽全力", "国际"),
        ("农村养老保险每年交5000,交满15年退休后能每月领多少钱?", "农业"),
        ("国产舰载机首次现身,进度超过预期,将率先在滑跃航母测试", "军事"),
    )
    train_data = [
        {"content": text, "label": label} for text, label in train_pairs
    ]

    dev_pairs = (
        ("西游记后传中,灵儿最爱的女人是谁?不是碧游!", "故事"),
        ("小李子莱奥纳多有特别的提袋子技能,这些年他还有过哪些神奇的造型?", "娱乐"),
        ("现在手上有钱是投资买房还是存钱,为什么?", "财经"),
        ("迪卡侬的衣服值得购买吗?", "体育"),
        ("黑龙江省旅游委在齐齐哈尔组织举办导游培训班", "教育"),
        ("《王者荣耀》中,哪些英雄的大招最“废柴”?", "电竞"),
        ("上交演绎马勒《复活》,用音乐带来抚慰和希望", "文化"),
        ("All in服务业,58集团在租房、住房市场的全力以赋", "房产"),
        ("为什么有的人宁愿选择骁龙660的X21,也不买骁龙845的小米MIX2S?", "科技"),
        ("众泰大型SUV来袭,售13.98万,2.0T榨出231马力,汉兰达要危险了", "汽车"),
        ("股票放量下趺,大资金出逃谁在接盘?", "股票"),
        ("广西博白最大的特色是什么?", "旅游"),
        ("特朗普退出《伊朗核协议》,对此你怎么看?", "国际"),
        ("卖水果利润怎么样?", "农业"),
        ("特种兵都是身材高大的猛男么?别再被电视骗了,超过1米8都不合格", "军事"),
    )
    dev_data = [
        {"content": text, "label": label} for text, label in dev_pairs
    ]

    # Unlabeled headlines to classify.
    test_texts = (
        "廖凡重出“江湖”再争影帝 亮相戛纳红毯霸气有型",
        "《绝地求生: 刺激战场》越玩越卡?竟是手机厂商没交“保护费”!",
        "买涡轮增压还是自然吸气车?今天终于有答案了!",
    )
    test_data = [{"content": text} for text in test_texts]

    # Key order matters: predictions are used below as positions in this
    # dict's key sequence.
    prompt_label = {
        "体育": "体育",
        "军事": "军事",
        "农业": "农业",
        "国际": "国际",
        "娱乐": "娱乐",
        "房产": "房产",
        "故事": "故事",
        "教育": "教育",
        "文化": "文化",
        "旅游": "旅游",
        "汽车": "汽车",
        "电竞": "电竞",
        "科技": "科技",
        "股票": "股票",
        "财经": "财经",
    }

    # Template with one `{}` placeholder; passed unchanged to train/predict.
    prompt = "下面是一则关于{}的新闻:"

    pipeline = TCBertPipelines(
        args,
        model_path=model_name,
        nlabels=len(prompt_label),
    )

    if args.train:
        pipeline.train(train_data, dev_data, prompt, prompt_label)
    predictions = pipeline.predict(test_data, prompt, prompt_label)

    # Build the index -> label lookup once rather than per prediction.
    label_names = list(prompt_label.keys())
    for position, label_index in enumerate(predictions):
        print({
            "content": test_data[position]["content"],
            "label": label_names[label_index],
        })
|
|
|
|
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|