yhavinga committed on
Commit
332e951
·
1 Parent(s): 31857b8

Update flax to pytorch script

Browse files
Files changed (1) hide show
  1. flax_to_pytorch.py +5 -47
flax_to_pytorch.py CHANGED
@@ -1,47 +1,5 @@
1
- # from transformers import T5ForConditionalGeneration, TFT5ForConditionalGeneration, FlaxT5ForConditionalGeneration
2
- # import numpy as np
3
- # import torch
4
- #
5
- # fx_model = FlaxT5ForConditionalGeneration.from_pretrained(".")
6
- #
7
- # pt_model = T5ForConditionalGeneration.from_pretrained(".", from_flax=True)
8
- # pt_model.save_pretrained(".")
9
- #
10
- #
11
- # # tf_model = TFT5ForConditionalGeneration.from_pretrained(".", from_pt=True)
12
- # # tf_model.save_pretrained(".")
13
- #
14
-
15
- #!/usr/bin/env python
16
- import tempfile
17
- import jax
18
- import numpy as np
19
- import torch
20
- from jax import numpy as jnp
21
- from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration, T5ForConditionalGeneration
22
-
23
- def to_f32(t):
24
- return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
25
-
26
- def main():
27
- # Saving extra files from config.json and tokenizer.json files
28
- tokenizer = AutoTokenizer.from_pretrained("./")
29
- tokenizer.save_pretrained("./")
30
- # Temporary saving bfloat16 Flax model into float32
31
- tmp = tempfile.mkdtemp()
32
- flax_model = FlaxT5ForConditionalGeneration.from_pretrained("./")
33
- flax_model.params = to_f32(flax_model.params)
34
- flax_model.save_pretrained(tmp)
35
- # Converting float32 Flax to PyTorch
36
- pt_model = T5ForConditionalGeneration.from_pretrained(tmp, from_flax=True)
37
- pt_model.save_pretrained("./", save_config=False)
38
-
39
- input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
40
- input_ids_pt = torch.tensor(input_ids)
41
- logits_pt = pt_model(input_ids_pt).logits
42
- print(logits_pt)
43
- logits_fx = flax_model(input_ids).logits
44
- print(logits_fx)
45
-
46
- if __name__ == "__main__":
47
- main()
 
1
+ from transformers import T5ForConditionalGeneration, TFT5ForConditionalGeneration
2
+ pt_model = T5ForConditionalGeneration.from_pretrained(".", from_flax=True)
3
+ pt_model.save_pretrained(".")
4
+ tf_model = TFT5ForConditionalGeneration.from_pretrained(".", from_pt=True)
5
+ tf_model.save_pretrained(".")