KoichiYasuoka committed on
Commit
4ea29e2
·
1 Parent(s): 6c0dc23

algorithm improved

Browse files
Files changed (1) hide show
  1. upos.py +5 -6
upos.py CHANGED
@@ -1,29 +1,28 @@
 
1
  from transformers import TokenClassificationPipeline
2
 
3
  class BellmanFordTokenClassificationPipeline(TokenClassificationPipeline):
4
  def __init__(self,**kwargs):
5
- import numpy
6
  super().__init__(**kwargs)
7
  x=self.model.config.label2id
8
  y=[k for k in x if not k.startswith("I-")]
9
- self.transition=numpy.full((len(x),len(x)),numpy.nan)
10
  for k,v in x.items():
11
  for j in ["I-"+k[2:]] if k.startswith("B-") else [k]+y if k.startswith("I-") else y:
12
  self.transition[v,x[j]]=0
13
  def check_model_type(self,supported_models):
14
  pass
15
  def postprocess(self,model_outputs,**kwargs):
16
- import numpy
17
  if "logits" not in model_outputs:
18
  return self.postprocess(model_outputs[0],**kwargs)
19
  m=model_outputs["logits"][0].numpy()
20
  e=numpy.exp(m-numpy.max(m,axis=-1,keepdims=True))
21
  z=e/e.sum(axis=-1,keepdims=True)
22
  for i in range(m.shape[0]-1,0,-1):
23
- m[i-1]+=numpy.nanmax(m[i]+self.transition,axis=1)
24
- k=[numpy.nanargmax(m[0]+self.transition[0])]
25
  for i in range(1,m.shape[0]):
26
- k.append(numpy.nanargmax(m[i]+self.transition[k[-1]]))
27
  w=[{"entity":self.model.config.id2label[j],"start":s,"end":e,"score":z[i,j]} for i,((s,e),j) in enumerate(zip(model_outputs["offset_mapping"][0].tolist(),k)) if s<e]
28
  if "aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none":
29
  for i,t in reversed(list(enumerate(w))):
 
1
+ import numpy
2
  from transformers import TokenClassificationPipeline
3
 
4
  class BellmanFordTokenClassificationPipeline(TokenClassificationPipeline):
5
  def __init__(self,**kwargs):
 
6
  super().__init__(**kwargs)
7
  x=self.model.config.label2id
8
  y=[k for k in x if not k.startswith("I-")]
9
+ self.transition=numpy.full((len(x),len(x)),-numpy.inf)
10
  for k,v in x.items():
11
  for j in ["I-"+k[2:]] if k.startswith("B-") else [k]+y if k.startswith("I-") else y:
12
  self.transition[v,x[j]]=0
13
  def check_model_type(self,supported_models):
14
  pass
15
  def postprocess(self,model_outputs,**kwargs):
 
16
  if "logits" not in model_outputs:
17
  return self.postprocess(model_outputs[0],**kwargs)
18
  m=model_outputs["logits"][0].numpy()
19
  e=numpy.exp(m-numpy.max(m,axis=-1,keepdims=True))
20
  z=e/e.sum(axis=-1,keepdims=True)
21
  for i in range(m.shape[0]-1,0,-1):
22
+ m[i-1]+=numpy.max(m[i]+self.transition,axis=1)
23
+ k=[numpy.argmax(m[0]+self.transition[0])]
24
  for i in range(1,m.shape[0]):
25
+ k.append(numpy.argmax(m[i]+self.transition[k[-1]]))
26
  w=[{"entity":self.model.config.id2label[j],"start":s,"end":e,"score":z[i,j]} for i,((s,e),j) in enumerate(zip(model_outputs["offset_mapping"][0].tolist(),k)) if s<e]
27
  if "aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none":
28
  for i,t in reversed(list(enumerate(w))):