Add pytorch model
Browse files- config.json +2 -1
- flax_to_pytorch.py +26 -0
- pytorch_model.bin +3 -0
config.json
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
{
|
2 |
-
"_name_or_path": "
|
3 |
"architectures": [
|
4 |
"T5ForConditionalGeneration"
|
5 |
],
|
@@ -21,6 +21,7 @@
|
|
21 |
"pad_token_id": 0,
|
22 |
"relative_attention_num_buckets": 32,
|
23 |
"tie_word_embeddings": false,
|
|
|
24 |
"transformers_version": "4.13.0",
|
25 |
"use_cache": true,
|
26 |
"vocab_size": 32103
|
|
|
1 |
{
|
2 |
+
"_name_or_path": ".",
|
3 |
"architectures": [
|
4 |
"T5ForConditionalGeneration"
|
5 |
],
|
|
|
21 |
"pad_token_id": 0,
|
22 |
"relative_attention_num_buckets": 32,
|
23 |
"tie_word_embeddings": false,
|
24 |
+
"torch_dtype": "float32",
|
25 |
"transformers_version": "4.13.0",
|
26 |
"use_cache": true,
|
27 |
"vocab_size": 32103
|
flax_to_pytorch.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import jax.numpy as jnp
|
4 |
+
from transformers import AutoTokenizer
|
5 |
+
from transformers import FlaxT5ForConditionalGeneration
|
6 |
+
from transformers import T5ForConditionalGeneration
|
7 |
+
tokenizer = AutoTokenizer.from_pretrained(".")
|
8 |
+
model_fx = FlaxT5ForConditionalGeneration.from_pretrained(".")
|
9 |
+
model_pt = T5ForConditionalGeneration.from_pretrained(".", from_flax=True)
|
10 |
+
model_pt.save_pretrained("./")
|
11 |
+
text = "Hoe gaat het?"
|
12 |
+
e_input_ids_fx = tokenizer(text, return_tensors="np", padding=True, max_length=128, truncation=True)
|
13 |
+
d_input_ids_fx = jnp.ones((e_input_ids_fx.input_ids.shape[0], 1), dtype="i4") * model_fx.config.decoder_start_token_id
|
14 |
+
e_input_ids_pt = tokenizer(text, return_tensors="pt", padding=True, max_length=128, truncation=True)
|
15 |
+
d_input_ids_pt = np.ones((e_input_ids_pt.input_ids.shape[0], 1), dtype="i4") * model_pt.config.decoder_start_token_id
|
16 |
+
print(e_input_ids_fx)
|
17 |
+
print(d_input_ids_fx)
|
18 |
+
print()
|
19 |
+
encoder_pt = model_fx.encode(**e_input_ids_pt)
|
20 |
+
decoder_pt = model_fx.decode(d_input_ids_pt, encoder_pt)
|
21 |
+
logits_pt = decoder_pt.logits
|
22 |
+
print(logits_pt)
|
23 |
+
encoder_fx = model_fx.encode(**e_input_ids_fx)
|
24 |
+
decoder_fx = model_fx.decode(d_input_ids_fx, encoder_fx)
|
25 |
+
logits_fx = decoder_fx.logits
|
26 |
+
print(logits_fx)
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fa8b87f8bb924ddaf9823ed6c9ed8f57adbee415b398049da58ddbe36997cf9a
|
3 |
+
size 990280781
|