VarunGumma commited on
Commit
c6d60bb
·
verified ·
1 Parent(s): d9771c1

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +15 -10
README.md CHANGED
@@ -64,17 +64,22 @@ Please refer to the [github repository](https://github.com/AI4Bharat/IndicTrans2
64
 
65
  ```python
66
  import torch
67
- from transformers import (
68
- AutoModelForSeq2SeqLM,
69
- AutoTokenizer,
70
- )
71
  from IndicTransToolkit import IndicProcessor
 
 
 
72
 
73
-
74
  model_name = "ai4bharat/indictrans2-indic-en-1B"
75
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
76
 
77
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
 
 
 
 
 
78
 
79
  ip = IndicProcessor(inference=True)
80
 
@@ -85,16 +90,12 @@ input_sentences = [
85
  "मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।",
86
  ]
87
 
88
- src_lang, tgt_lang = "hin_Deva", "eng_Latn"
89
-
90
  batch = ip.preprocess_batch(
91
  input_sentences,
92
  src_lang=src_lang,
93
  tgt_lang=tgt_lang,
94
  )
95
 
96
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
97
-
98
  # Tokenize the sentences and generate input encodings
99
  inputs = tokenizer(
100
  batch,
@@ -131,7 +132,11 @@ for input_sentence, translation in zip(input_sentences, translations):
131
  print(f"{tgt_lang}: {translation}")
132
  ```
133
 
 
134
 
 
 
 
135
 
136
  ### Citation
137
 
 
64
 
65
  ```python
66
  import torch
67
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
 
 
68
  from IndicTransToolkit import IndicProcessor
69
+ # recommended to run this on a gpu with flash_attn installed
70
+ # don't set attn_implemetation if you don't have flash_attn
71
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
72
 
73
+ src_lang, tgt_lang = "hin_Deva", "eng_Latn"
74
  model_name = "ai4bharat/indictrans2-indic-en-1B"
75
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
76
 
77
+ model = AutoModelForSeq2SeqLM.from_pretrained(
78
+ model_name,
79
+ trust_remote_code=True,
80
+ torch_dtype=torch.float16, # performance might slightly vary for bfloat16
81
+ attn_implementation="flash_attention_2"
82
+ ).to(DEVICE)
83
 
84
  ip = IndicProcessor(inference=True)
85
 
 
90
  "मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।",
91
  ]
92
 
 
 
93
  batch = ip.preprocess_batch(
94
  input_sentences,
95
  src_lang=src_lang,
96
  tgt_lang=tgt_lang,
97
  )
98
 
 
 
99
  # Tokenize the sentences and generate input encodings
100
  inputs = tokenizer(
101
  batch,
 
132
  print(f"{tgt_lang}: {translation}")
133
  ```
134
 
135
+ ### 📢 Long Context IT2 Models
136
 
137
+ - New RoPE based IndicTrans2 models which are capable of handling sequence lengths **upto 2048 tokens** are available [here](https://huggingface.co/collections/prajdabre/indictrans2-rope-6742ddac669a05db0804db35).
138
+ - These models can be used by just changing the `model_name` parameter. Please read the model card of the RoPE-IT2 models for more information about the generation.
139
+ - It is recommended to run these models with `flash_attention_2` for efficient generation.
140
 
141
  ### Citation
142