QINGCHE commited on
Commit
e350168
·
1 Parent(s): f21d2ac
Files changed (5) hide show
  1. abstruct.py +71 -0
  2. classification.py +83 -0
  3. requirements.txt +109 -0
  4. run.py +1 -0
  5. util.py +75 -0
abstruct.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 导入所需的库
2
+ import json
3
+ import paddlenlp
4
+ import gensim
5
+ import sklearn
6
+ from collections import Counter
7
+ from gensim import corpora, models, similarities
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+
11
+
12
+
13
+
14
+
15
+ def build_corpus(sentences):
16
+ # 使用paddlenlp提供的预训练词典
17
+ vocab = paddlenlp.transformers.BertTokenizer.from_pretrained('bert-base-chinese').vocab
18
+
19
+ # 创建分词器
20
+ tokenizer = paddlenlp.data.JiebaTokenizer(vocab)
21
+ # 对每个句子进行分词,并去除停用词,得到一个二维列表
22
+ stopwords = [""]
23
+ words_list = []
24
+ for sentence in sentences:
25
+ words = [word for word in tokenizer.cut(sentence) if word not in stopwords]
26
+ words_list.append(words)
27
+ # print(words_list)
28
+ # 将二维列表转换为一维列表
29
+ words = [word for sentence in words_list for word in sentence]
30
+
31
+ dictionary = corpora.Dictionary(words_list)
32
+ corpus = [dictionary.doc2bow(text) for text in words_list]
33
+
34
+ return corpus,dictionary,words_list
35
+
36
+ def lda(words_list,sentences,corpus,dictionary,num):
37
+ lda = gensim.models.ldamodel.LdaModel(corpus=corpus,id2word=dictionary, num_topics=num)
38
+
39
+ topics = lda.print_topics(num_topics=num, num_words=10)
40
+
41
+ # 根据关键词的匹配度,选择最能代表每个主题的句子,作为中心句
42
+
43
+ central_sentences = []
44
+ for topic in topics:
45
+ topic_id, topic_words = topic
46
+ topic_words = [word.split("*")[1].strip('"') for word in topic_words.split("+")]
47
+ max_score = 0
48
+ candidates = [] # 存储候选中心句
49
+ for sentence, words in zip(sentences, words_list):
50
+ score = 0
51
+ for word in words:
52
+ if word in topic_words:
53
+ score += 1
54
+ if score > max_score:
55
+ max_score = score
56
+ candidates = [sentence] # 如果找到更高的匹配度,更新候选列表
57
+ elif score == max_score:
58
+ candidates.append(sentence) # 如果匹配度相同,添加到候选列表
59
+ for candidate in candidates: # 遍历候选列表
60
+ if candidate not in central_sentences: # 检查是否已经存在相同的句子
61
+ central_sentence = candidate # 如果不存在,选择为中心句
62
+ central_sentences.append(central_sentence)
63
+ break # 跳出循环
64
+
65
+ return central_sentences
66
+
67
+
68
+ def abstruct_main(sentences,num):
69
+ corpus,dictionary,words_list = build_corpus(sentences)
70
+ central_sentences= lda(words_list, sentences, corpus, dictionary,num)
71
+ return central_sentences
classification.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gensim
2
+ import numpy as np
3
+ from sklearn.feature_extraction.text import TfidfVectorizer
4
+ from sklearn.metrics.pairwise import cosine_similarity
5
+ from transformers import AutoTokenizer, AutoModel
6
+ import torch
7
+
8
+
9
+ def classify_by_topic(articles, central_topics):
10
+
11
+ # 计算每篇文章与每个中心主题的相似度,返回一个矩阵
12
+ def compute_similarity(articles, central_topics):
13
+
14
+ model = AutoModel.from_pretrained("distilbert-base-multilingual-cased")
15
+ tokenizer = AutoTokenizer.from_pretrained(
16
+ "distilbert-base-multilingual-cased")
17
+
18
+ # 将一个句子转换为一个向量
19
+ def sentence_to_vector(sentence, context):
20
+ # 分词并添加特殊标记
21
+ sentence = context[0]+context[1]+sentence*4+context[2]+context[3]
22
+ tokens = tokenizer.encode_plus(
23
+ sentence, add_special_tokens=True, return_tensors="pt")
24
+ # 获取每个词的隐藏状态向量
25
+ outputs = model(**tokens)
26
+ hidden_states = outputs.last_hidden_state
27
+ # 计算平均向量作为句子向量
28
+ vector = np.squeeze(torch.mean(
29
+ hidden_states, dim=1).detach().numpy()) # a 1 x d tensor
30
+ return vector
31
+
32
+ # 获取一个句子的上下文
33
+ def get_context(sentences, index):
34
+ if index == 0:
35
+ prev_sentence = ""
36
+ pprev_sentence = ""
37
+ elif index == 1:
38
+ prev_sentence = sentences[index-1]
39
+ pprev_sentence = ""
40
+ else:
41
+ prev_sentence = sentences[index-1]
42
+ pprev_sentence = sentences[index-2]
43
+ if index == len(sentences) - 1:
44
+ next_sentence = ""
45
+ nnext_sentence = ""
46
+ elif index == len(sentences) - 2:
47
+ next_sentence = sentences[index+1]
48
+ nnext_sentence = ""
49
+ else:
50
+ next_sentence = sentences[index+1]
51
+ nnext_sentence = sentences[index+2]
52
+ return (pprev_sentence, prev_sentence, next_sentence, nnext_sentence)
53
+
54
+ # 将每个文章句子和每个中心句子转换为向量
55
+ doc_vectors = [sentence_to_vector(sentence, get_context(
56
+ articles, i)) for i, sentence in enumerate(articles)]
57
+ topic_vectors = [sentence_to_vector(sentence, get_context(
58
+ central_topics, i)) for i, sentence in enumerate(central_topics)]
59
+ # 计算每个文章句子和每个中心句子之间的余弦相似度矩阵
60
+ cos_sim_matrix = cosine_similarity(doc_vectors, topic_vectors)
61
+
62
+ # print(cos_sim_matrix)
63
+ return cos_sim_matrix
64
+
65
+ # 按照相似度矩阵分类文章,返回一个列表
66
+ def group_by_topic(articles, central_topics, similarity_matrix):
67
+ group = []
68
+ original_articles = articles.copy() # 保存一份原始的文章列表
69
+ # 用原始的文章列表替换预处理后的文章列表
70
+ for article, similarity in zip(original_articles, similarity_matrix):
71
+ max_similarity = max(similarity) # 取最高的相似度值
72
+ max_index = similarity.tolist().index(max_similarity) # 取最高相似度值对应的索引
73
+ # print(max_similarity,max_index )
74
+ group.append((article, central_topics[max_index]))
75
+
76
+ return group
77
+
78
+ # 实现分类功能
79
+ similarity_matrix = compute_similarity(articles, central_topics)
80
+ groups = group_by_topic(articles, central_topics, similarity_matrix)
81
+
82
+ # 返回分类后的列表
83
+ return groups
requirements.txt ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohttp==3.8.4
2
+ aiosignal==1.3.1
3
+ anyio==3.7.0
4
+ astor==0.8.1
5
+ async-timeout==4.0.2
6
+ attrs==23.1.0
7
+ Babel==2.12.1
8
+ backoff==2.2.1
9
+ bce-python-sdk==0.8.83
10
+ blinker==1.6.2
11
+ certifi==2023.5.7
12
+ charset-normalizer==3.1.0
13
+ click==8.1.3
14
+ cmake==3.26.3
15
+ colorama==0.4.6
16
+ colorlog==6.7.0
17
+ contourpy==1.0.7
18
+ cycler==0.11.0
19
+ datasets==2.12.0
20
+ decorator==5.1.1
21
+ dill==0.3.4
22
+ exceptiongroup==1.1.1
23
+ fastapi==0.95.2
24
+ filelock==3.12.0
25
+ Flask==2.3.2
26
+ Flask-Babel==2.0.0
27
+ fonttools==4.39.4
28
+ frozenlist==1.3.3
29
+ fsspec==2023.5.0
30
+ future==0.18.3
31
+ gensim==4.3.1
32
+ h11==0.14.0
33
+ huggingface-hub==0.14.1
34
+ idna==3.4
35
+ importlib-metadata==6.6.0
36
+ importlib-resources==5.12.0
37
+ itsdangerous==2.1.2
38
+ jieba==0.42.1
39
+ Jinja2==3.1.2
40
+ joblib==1.2.0
41
+ kiwisolver==1.4.4
42
+ lit==16.0.5
43
+ markdown-it-py==2.2.0
44
+ MarkupSafe==2.1.2
45
+ matplotlib==3.7.1
46
+ mdurl==0.1.2
47
+ mpmath==1.3.0
48
+ multidict==6.0.4
49
+ multiprocess==0.70.12.2
50
+ networkx==3.1
51
+ numpy==1.24.3
52
+ nvidia-cublas-cu11==11.10.3.66
53
+ nvidia-cuda-cupti-cu11==11.7.101
54
+ nvidia-cuda-nvrtc-cu11==11.7.99
55
+ nvidia-cuda-runtime-cu11==11.7.99
56
+ nvidia-cudnn-cu11==8.5.0.96
57
+ nvidia-cufft-cu11==10.9.0.58
58
+ nvidia-curand-cu11==10.2.10.91
59
+ nvidia-cusolver-cu11==11.4.0.1
60
+ nvidia-cusparse-cu11==11.7.4.91
61
+ nvidia-nccl-cu11==2.14.3
62
+ nvidia-nvtx-cu11==11.7.91
63
+ opt-einsum==3.3.0
64
+ packaging==23.1
65
+ paddle-bfloat==0.1.7
66
+ paddle2onnx==1.0.6
67
+ paddlefsl==1.1.0
68
+ paddlenlp==2.5.2
69
+ paddlepaddle==2.4.2
70
+ pandas==2.0.2
71
+ Pillow==9.5.0
72
+ protobuf==3.20.0
73
+ pyarrow==12.0.0
74
+ pycryptodome==3.18.0
75
+ pydantic==1.10.8
76
+ Pygments==2.15.1
77
+ pyparsing==3.0.9
78
+ python-dateutil==2.8.2
79
+ pytz==2023.3
80
+ PyYAML==6.0
81
+ regex==2023.5.5
82
+ requests==2.31.0
83
+ responses==0.18.0
84
+ rich==13.4.1
85
+ scikit-learn==1.2.2
86
+ scipy==1.10.1
87
+ sentencepiece==0.1.99
88
+ seqeval==1.2.2
89
+ six==1.16.0
90
+ smart-open==6.3.0
91
+ sniffio==1.3.0
92
+ starlette==0.27.0
93
+ sympy==1.12
94
+ threadpoolctl==3.1.0
95
+ tokenizers==0.13.3
96
+ torch==2.0.1
97
+ tqdm==4.65.0
98
+ transformers==4.29.2
99
+ triton==2.0.0
100
+ typer==0.9.0
101
+ typing-extensions==4.6.2
102
+ tzdata==2023.3
103
+ urllib3==2.0.2
104
+ uvicorn==0.22.0
105
+ visualdl==2.4.2
106
+ Werkzeug==2.3.4
107
+ xxhash==3.2.0
108
+ yarl==1.9.2
109
+ zipp==3.15.0
run.py CHANGED
@@ -40,3 +40,4 @@ ans = util.generation(groups, max_length)
40
  # {(main_sentence,(Ai_abstruct,paragraph))}
41
  for i in ans.items():
42
  print(i)
 
 
40
  # {(main_sentence,(Ai_abstruct,paragraph))}
41
  for i in ans.items():
42
  print(i)
43
+ ``
util.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import jieba
3
+ import re
4
+ import requests
5
+ import backoff
6
+
7
+
8
+ @backoff.on_exception(backoff.expo, requests.exceptions.RequestException)
9
+ def post_url(url, headers, payload):
10
+ response = requests.request("POST", url, headers=headers, data=payload)
11
+ return response
12
+
13
+
14
+ def seg(text):
15
+ sentences = re.split(r'(?<=[。!?])\s*', text)
16
+ return sentences
17
+
18
+
19
+ def clean_text(text):
20
+ text = text.replace('\n', " ")
21
+ text = re.sub(r"-", " ", text)
22
+ text = re.sub(r"\d+/\d+/\d+", "", text) # 日期
23
+ text = re.sub(r"[0-2]?[0-9]:[0-6][0-9]", "", text) # 时间
24
+ text = re.sub(
25
+ r"/[a-zA-Z]*[:\//\]*[A-Za-z0-9\-_]+\.+[A-Za-z0-9\.\/%&=\?\-_]+/i", "", text) # 网址
26
+ pure_text = ''
27
+ for letter in text:
28
+ if letter.isalpha() or letter == ' ':
29
+ pure_text += letter
30
+
31
+ text = ' '.join(word for word in pure_text.split() if len(word) > 1)
32
+ return text
33
+
34
+
35
+ def article_to_group(groups, topics):
36
+ para = {}
37
+ for i in groups:
38
+ if not i[1] in para:
39
+ para[i[1]] = i[0]
40
+ else:
41
+ para[i[1]] = para[i[1]] + i[0]
42
+ return para
43
+
44
+
45
+ def generation(para, max_length):
46
+ API_KEY = "IZt1uK9PAI0LiqleqT0cE30b"
47
+ SECRET_KEY = "Xv5kHB8eyhNuI1B1G7fRgm2SIPdlxGxs"
48
+
49
+ def get_access_token():
50
+
51
+ url = "https://aip.baidubce.com/oauth/2.0/token"
52
+ params = {"grant_type": "client_credentials",
53
+ "client_id": API_KEY, "client_secret": SECRET_KEY}
54
+ return str(requests.post(url, params=params).json().get("access_token"))
55
+
56
+ url = "https://aip.baidubce.com/rpc/2.0/nlp/v1/news_summary?charset=UTF-8&access_token=" + get_access_token()
57
+ topic = {}
58
+
59
+ for i, (j, k) in enumerate(para.items()):
60
+ input_text = k
61
+ # print(k)
62
+ payload = json.dumps({
63
+ "content": k,
64
+ "max_summary_len": max_length
65
+ })
66
+ headers = {
67
+ 'Content-Type': 'application/json',
68
+ 'Accept': 'application/json'
69
+ }
70
+
71
+ response = post_url(url, headers, payload)
72
+ text_dict = json.loads(response.text)
73
+ # print(text_dict)
74
+ topic[j] = (text_dict['summary'], k)
75
+ return topic