PereLluis13 committed on
Commit
ab249d6
·
1 Parent(s): f3a761e

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +103 -0
README.md CHANGED
@@ -5,3 +5,106 @@ tags:
5
  - seq2seq
6
  license: cc-by-nc-sa-4.0
7
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  - seq2seq
6
  license: cc-by-nc-sa-4.0
7
  ---
8
+ To use the model with a pipeline:
9
+ ```python3
10
+ from transformers import pipeline
11
+
12
def extract_triplets(text):
    """Parse linearized model output into (subject, relation, object) tuples.

    The text is read as a stream of whitespace-separated tokens driven by
    three markers: tokens after ``<triplet>`` form the subject, tokens after
    ``<subj>`` form the object, and tokens after ``<obj>`` form the relation.
    One ``<triplet>`` subject may be followed by several
    ``<subj> ... <obj> ...`` pairs, each producing its own tuple.

    Args:
        text: Decoded model output. ``<s>``, ``</s>`` and ``<pad>`` are
            ignored wherever they appear.

    Returns:
        A list of ``(subject, relation, object)`` string tuples in order of
        appearance, with no surrounding whitespace. Incomplete triplets
        (any of the three parts missing) are dropped, so empty or
        marker-free input yields an empty list.
    """
    # Sequence-boundary/padding tokens carry no content and may be glued to a
    # neighbouring token (e.g. "<s><triplet>"), so blank them out first.
    for special in ('<s>', '</s>', '<pad>'):
        text = text.replace(special, ' ')
    # Make sure every marker is its own whitespace-delimited token.
    for marker in ('<triplet>', '<subj>', '<obj>'):
        text = text.replace(marker, f' {marker} ')

    triplets = []
    subject, relation, object_ = [], [], []
    current = None  # list receiving plain tokens; None until a marker is seen

    def emit():
        # Record the triplet under construction only when it is complete.
        if subject and relation and object_:
            triplets.append(
                (' '.join(subject), ' '.join(relation), ' '.join(object_))
            )

    for token in text.split():
        if token == '<triplet>':
            emit()
            # A new subject starts: nothing from the previous one carries over.
            subject.clear()
            relation.clear()
            object_.clear()
            current = subject
        elif token == '<subj>':
            emit()
            # Same subject, new object/relation pair.
            relation.clear()
            object_.clear()
            current = object_
        elif token == '<obj>':
            relation.clear()
            current = relation
        elif current is not None:
            current.append(token)
        # Tokens seen before any marker belong to no field and are skipped.

    emit()
    return triplets
39
+
40
# Build a text2text-generation pipeline backed by the REBEL checkpoint.
triplet_extractor = pipeline('text2text-generation', model='Babelscape/rebel-large', tokenizer='Babelscape/rebel-large')

# The pipeline returns a list with one dict per input, so index [0] first.
# Ask for token ids instead of text and decode them manually: the default
# text output is presumably decoded with special tokens skipped, which would
# drop the <triplet>/<subj>/<obj> markers if they are registered as special
# tokens (NOTE(review): confirm against the tokenizer config).
generated = triplet_extractor(
    "Punta Cana is a resort town in the municipality of Higüey, in La Altagracia Province, the easternmost province of the Dominican Republic.",
    return_tensors=True,
    return_text=False,
)
extracted_text = triplet_extractor.tokenizer.batch_decode([generated[0]["generated_token_ids"]])[0]

# Sequence-boundary/padding tokens are not content; blank them out so the
# first token handed to the parser is a marker.
for special in ("<s>", "</s>", "<pad>"):
    extracted_text = extracted_text.replace(special, " ")

extracted_triplets = extract_triplets(extracted_text)
print(extracted_triplets)
45
+ ```
46
+
47
+ Or using the transformers library directly, without a pipeline:
48
+ ```python3
49
+
50
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
51
+
52
def extract_triplets(text):
    """Parse linearized model output into (subject, relation, object) tuples.

    The text is read as a stream of whitespace-separated tokens driven by
    three markers: tokens after ``<triplet>`` form the subject, tokens after
    ``<subj>`` form the object, and tokens after ``<obj>`` form the relation.
    One ``<triplet>`` subject may be followed by several
    ``<subj> ... <obj> ...`` pairs, each producing its own tuple.

    Args:
        text: Decoded model output. ``<s>``, ``</s>`` and ``<pad>`` are
            ignored wherever they appear.

    Returns:
        A list of ``(subject, relation, object)`` string tuples in order of
        appearance, with no surrounding whitespace. Incomplete triplets
        (any of the three parts missing) are dropped, so empty or
        marker-free input yields an empty list.
    """
    # Sequence-boundary/padding tokens carry no content and may be glued to a
    # neighbouring token (e.g. "<s><triplet>"), so blank them out first.
    for special in ('<s>', '</s>', '<pad>'):
        text = text.replace(special, ' ')
    # Make sure every marker is its own whitespace-delimited token.
    for marker in ('<triplet>', '<subj>', '<obj>'):
        text = text.replace(marker, f' {marker} ')

    triplets = []
    subject, relation, object_ = [], [], []
    current = None  # list receiving plain tokens; None until a marker is seen

    def emit():
        # Record the triplet under construction only when it is complete.
        if subject and relation and object_:
            triplets.append(
                (' '.join(subject), ' '.join(relation), ' '.join(object_))
            )

    for token in text.split():
        if token == '<triplet>':
            emit()
            # A new subject starts: nothing from the previous one carries over.
            subject.clear()
            relation.clear()
            object_.clear()
            current = subject
        elif token == '<subj>':
            emit()
            # Same subject, new object/relation pair.
            relation.clear()
            object_.clear()
            current = object_
        elif token == '<obj>':
            relation.clear()
            current = relation
        elif current is not None:
            current.append(token)
        # Tokens seen before any marker belong to no field and are skipped.

    emit()
    return triplets
79
+
80
# Load model and tokenizer from the same checkpoint the pipeline example uses
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")

# Generation settings: beam search with 3 beams, returning all 3 candidates
gen_kwargs = {
    "max_length": 256,
    "length_penalty": 0,
    "num_beams": 3,
    "num_return_sequences": 3,
}

# Text to extract triplets from
text = 'Punta Cana is a resort town in the municipality of Higüey, in La Altagracia Province, the easternmost province of the Dominican Republic.'

# Tokenize text
model_inputs = tokenizer(text, max_length=256, padding=True, truncation=True, return_tensors = 'pt')

# Generate, keeping the inputs on the same device as the model
generated_tokens = model.generate(
    model_inputs["input_ids"].to(model.device),
    attention_mask=model_inputs["attention_mask"].to(model.device),
    **gen_kwargs,
)

# Decode to text, keeping special tokens so the triplet markers survive
decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

# Extract triplets from each returned candidate sequence
for idx, sentence in enumerate(decoded_preds):
    print(f'Prediction triplets sentence {idx}')
    print(extract_triplets(sentence))
110
+ ```