pytorch 代码

https://github.com/JunnYu/GAU-alpha-pytorch

bert4keras代码

https://github.com/ZhuiyiTechnology/GAU-alpha

Install

pip install git+https://github.com/JunnYu/GAU-alpha-pytorch.git
or
pip install gau_alpha

评测对比

CLUE-dev榜单分类任务结果,base版本。

iflytek tnews afqmc cmnli ocnli wsc csl
BERT 60.06 56.80 72.41 79.56 73.93 78.62 83.93
RoBERTa 60.64 58.06 74.05 81.24 76.00 87.50 84.50
RoFormer 60.91 57.54 73.52 80.92 76.07 86.84 84.63
RoFormerV2* 60.87 56.54 72.75 80.34 75.36 80.92 84.67
GAU-α 61.41 57.76 74.17 81.82 75.86 79.93 85.67
RoFormerV2-pytorch 62.87 59.03 76.20 80.85 79.73 87.82 91.87
GAU-α-pytorch(Adafactor) 61.18 57.52 73.42 80.91 75.69 80.59 85.5
GAU-α-pytorch(AdamW wd0.01 warmup0.1) 60.68 57.95 73.08 81.02 75.36 81.25 83.93

CLUE-test榜单分类任务结果,base版本。

iflytek tnews afqmc cmnli ocnli wsc csl
RoFormerV2-pytorch 63.15 58.24 75.42 80.59 74.17 83.79 83.73
GAU-α-pytorch(Adafactor) 61.38 57.08 74.05 80.37 73.53 74.83 85.6
GAU-α-pytorch(AdamW wd0.01 warmup0.1) 60.54 57.67 72.44 80.32 72.97 76.55 84.13

CLUE-dev集榜单阅读理解和NER结果

cmrc2018 c3 chid cluener
BERT 56.17 60.54 85.69 79.45
RoBERTa 56.54 67.66 86.71 79.47
RoFormer 56.26 67.24 86.57 79.72
RoFormerV2* 57.91 64.62 85.09 81.08
GAU-α 58.09 68.24 87.91 80.01

注:

  • 其中RoFormerV2*表示的是未进行多任务学习的RoFormerV2模型,该模型苏神并未开源,感谢苏神的提醒。
  • 其中不带有pytorch后缀结果都是从GAU-alpha仓库复制过来的。
  • 其中带有pytorch后缀的结果都是自己训练得出的。

Usage

import torch

from gau_alpha import GAUAlphaForMaskedLM, GAUAlphaTokenizer

text = "今天[MASK]很好,我[MASK]去公园玩。"
tokenizer = GAUAlphaTokenizer.from_pretrained(
    "junnyu/chinese_GAU-alpha-char_L-24_H-768"
)
pt_model = GAUAlphaForMaskedLM.from_pretrained(
    "junnyu/chinese_GAU-alpha-char_L-24_H-768"
)

pt_inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    pt_outputs = pt_model(**pt_inputs).logits[0]
pt_outputs_sentence = "pytorch: "
for i, id in enumerate(tokenizer.encode(text)):
    if id == tokenizer.mask_token_id:
        val, idx = pt_outputs[i].softmax(-1).topk(k=5)
        tokens = tokenizer.convert_ids_to_tokens(idx)
        new_tokens = []
        for v, t in zip(val.cpu(), tokens):
            new_tokens.append(f"{t}+{round(v.item(),4)}")
        pt_outputs_sentence += "[" + "||".join(new_tokens) + "]"
    else:
        pt_outputs_sentence += "".join(
            tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True)
        )
print(pt_outputs_sentence)
# pytorch: 今天[天+0.8657||气+0.0535||阳+0.0165||,+0.0126||晴+0.0111]很好,我[要+0.4619||想+0.4352||又+0.0252||就+0.0157||跑+0.0064]去公园玩。

Reference

Bibtex:

@techreport{gau-alpha,
  title={GAU-α: GAU-based Transformers for NLP - ZhuiyiAI},
  author={Jianlin Su, Shengfeng Pan, Bo Wen, Yunfeng Liu},
  year={2022},
  url="https://github.com/ZhuiyiTechnology/GAU-alpha",
}
Downloads last month
6
Inference Examples
Inference API (serverless) has been turned off for this model.