Update README.md
README.md
CHANGED
@@ -1,3 +1,54 @@
---
license: cc-by-nd-4.0
---
## Czech Metrum Validator

A validator for poetic metre (Czech: metrum), trained on Czech poetry from the corpusCzechVerse GitHub project by the Institute of Czech Literature, Czech Academy of Sciences:

https://github.com/versotym/corpusCzechVerse

## Usage

### Loading model

Download `validator.py`, which contains the interface, then download the model and load it with PyTorch:

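```python
import torch
from validator import ValidatorInterface  # must be importable so the pickled model can be restored

# args.metre_model_path_full is the path to the downloaded model file.
# On PyTorch 2.6+, torch.load needs weights_only=False to restore a fully pickled model.
model: ValidatorInterface = torch.load(args.metre_model_path_full, map_location=torch.device('cpu'))
```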
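
The snippets in this README read their settings from an `args` object that is never defined here. A minimal sketch of how it could be built with `argparse` follows. The attribute names are the ones the snippets use (including the `worm_up` spelling); every default value, and the file name `validator.pt`, is a placeholder assumption rather than something the project specifies.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Collect the settings that the README snippets read from `args`."""
    parser = argparse.ArgumentParser(description="Czech metre validator settings")
    # Path handed to torch.load(); the default file name is hypothetical.
    parser.add_argument("--metre_model_path_full", type=str, default="validator.pt")
    # Settings read by the training snippet; all defaults are placeholders.
    parser.add_argument("--pretrained_model", type=str, default="roberta-base")
    parser.add_argument("--tokenizer", type=str, default="roberta-base")
    parser.add_argument("--worm_up", type=int, default=0)  # spelling follows the README
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--batch_size", type=int, default=8)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(args.metre_model_path_full)
```

Running the script with, for example, `--epochs 3 --worm_up 100` overrides the corresponding placeholders.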

Load the base RobeCzech tokenizer and try the model out:

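```python
from transformers import AutoTokenizer

# checkpoint name kept as given; use the RobeCzech checkpoint if the model expects it
tokenizer = AutoTokenizer.from_pretrained('roberta-base')
# datum is one prepared example holding "input_ids" and "metre"
model.validate(input_ids=datum["input_ids"], metre=datum["metre"])['acc']
```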

### Train Model

```python
import torch
from transformers import AutoTokenizer, Trainer, TrainingArguments

# MeterValidator, train_data and collate are assumed to be defined elsewhere in the project
meter_model = MeterValidator(pretrained_model=args.pretrained_model)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)

training_args = TrainingArguments(
    save_strategy="no",
    logging_steps=500,
    warmup_steps=args.worm_up,
    weight_decay=0.0,
    num_train_epochs=args.epochs,
    learning_rate=args.learning_rate,
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
    ddp_backend="nccl",
    lr_scheduler_type="cosine",
    logging_dir='./logs',
    output_dir='./results',
    per_device_train_batch_size=args.batch_size,
)

# Train the metre model created above
Trainer(
    model=meter_model,
    args=training_args,
    train_dataset=train_data.pytorch_dataset_body,
    data_collator=collate,
).train()
```
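
The training arguments combine `warmup_steps` with `lr_scheduler_type="cosine"`. The sketch below shows the shape of such a schedule, assuming the usual form of linear warmup followed by a half-cosine decay to zero. It is a standalone illustration in plain Python, not the library's own code, and the function name `cosine_lr` and its parameters are made up for this example.

```python
import math


def cosine_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate at `step`: linear warmup, then half-cosine decay to zero."""
    if step < warmup_steps:
        # Ramp linearly from 0 up to base_lr over the warmup steps.
        return base_lr * step / max(1, warmup_steps)
    # Fraction of the post-warmup phase that has elapsed, from 0 to 1.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


if __name__ == "__main__":
    # Illustrative values only: 100 warmup steps out of 1000 total.
    for step in (0, 50, 100, 550, 1000):
        print(step, cosine_lr(step, base_lr=1e-3, warmup_steps=100, total_steps=1000))
```

With these illustrative values the rate rises to `base_lr` at step 100, passes through roughly half of it midway through the decay, and reaches zero at the final step.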