sam749 commited on
Commit
3a89850
·
verified ·
1 Parent(s): cd2a958

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ IndicTransTokenizer/IndicTransTokenizer/en-indic/model.TGT filter=lfs diff=lfs merge=lfs -text
37
+ IndicTransTokenizer/IndicTransTokenizer/indic-en/model.SRC filter=lfs diff=lfs merge=lfs -text
38
+ IndicTransTokenizer/IndicTransTokenizer/indic-indic/model.SRC filter=lfs diff=lfs merge=lfs -text
39
+ IndicTransTokenizer/IndicTransTokenizer/indic-indic/model.TGT filter=lfs diff=lfs merge=lfs -text
IndicTransTokenizer/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ dist/
2
+ IndicTransTokenizer.egg-info
3
+ IndicTransTokenizer/__pycache__/
IndicTransTokenizer/IndicTransTokenizer/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .tokenizer import IndicTransTokenizer
2
+ from .utils import IndicProcessor
IndicTransTokenizer/IndicTransTokenizer/en-indic/dict.SRC.json ADDED
The diff for this file is too large to render. See raw diff
 
IndicTransTokenizer/IndicTransTokenizer/en-indic/dict.TGT.json ADDED
The diff for this file is too large to render. See raw diff
 
IndicTransTokenizer/IndicTransTokenizer/en-indic/model.SRC ADDED
Binary file (759 kB). View file
 
IndicTransTokenizer/IndicTransTokenizer/en-indic/model.TGT ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac9257c8e76b8b607705b959cc3d075656ea33032f7a974e467b8941df6e98d4
3
+ size 3256903
IndicTransTokenizer/IndicTransTokenizer/indic-en/dict.SRC.json ADDED
The diff for this file is too large to render. See raw diff
 
IndicTransTokenizer/IndicTransTokenizer/indic-en/dict.TGT.json ADDED
The diff for this file is too large to render. See raw diff
 
IndicTransTokenizer/IndicTransTokenizer/indic-en/model.SRC ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac9257c8e76b8b607705b959cc3d075656ea33032f7a974e467b8941df6e98d4
3
+ size 3256903
IndicTransTokenizer/IndicTransTokenizer/indic-en/model.TGT ADDED
Binary file (759 kB). View file
 
IndicTransTokenizer/IndicTransTokenizer/indic-indic/dict.SRC.json ADDED
The diff for this file is too large to render. See raw diff
 
IndicTransTokenizer/IndicTransTokenizer/indic-indic/dict.TGT.json ADDED
The diff for this file is too large to render. See raw diff
 
IndicTransTokenizer/IndicTransTokenizer/indic-indic/model.SRC ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac9257c8e76b8b607705b959cc3d075656ea33032f7a974e467b8941df6e98d4
3
+ size 3256903
IndicTransTokenizer/IndicTransTokenizer/indic-indic/model.TGT ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac9257c8e76b8b607705b959cc3d075656ea33032f7a974e467b8941df6e98d4
3
+ size 3256903
IndicTransTokenizer/IndicTransTokenizer/tokenizer.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ from transformers import BatchEncoding
5
+ from typing import Dict, List, Tuple, Union
6
+ from sentencepiece import SentencePieceProcessor
7
+
8
+ _PATH = os.path.dirname(os.path.realpath(__file__))
9
+
10
+
11
+ class IndicTransTokenizer:
12
+ def __init__(
13
+ self,
14
+ direction=None,
15
+ model_name=None,
16
+ unk_token="<unk>",
17
+ bos_token="<s>",
18
+ eos_token="</s>",
19
+ pad_token="<pad>",
20
+ model_max_length=256,
21
+ ):
22
+ self.model_max_length = model_max_length
23
+
24
+ self.supported_langs = [
25
+ "asm_Beng",
26
+ "awa_Deva",
27
+ "ben_Beng",
28
+ "bho_Deva",
29
+ "brx_Deva",
30
+ "doi_Deva",
31
+ "eng_Latn",
32
+ "gom_Deva",
33
+ "gon_Deva",
34
+ "guj_Gujr",
35
+ "hin_Deva",
36
+ "hne_Deva",
37
+ "kan_Knda",
38
+ "kas_Arab",
39
+ "kas_Deva",
40
+ "kha_Latn",
41
+ "lus_Latn",
42
+ "mag_Deva",
43
+ "mai_Deva",
44
+ "mal_Mlym",
45
+ "mar_Deva",
46
+ "mni_Beng",
47
+ "mni_Mtei",
48
+ "npi_Deva",
49
+ "ory_Orya",
50
+ "pan_Guru",
51
+ "san_Deva",
52
+ "sat_Olck",
53
+ "snd_Arab",
54
+ "snd_Deva",
55
+ "tam_Taml",
56
+ "tel_Telu",
57
+ "urd_Arab",
58
+ "unr_Deva",
59
+ ]
60
+
61
+ if model_name is None and direction is None:
62
+ raise ValueError("Either model_name or direction must be provided!")
63
+
64
+ if model_name is not None:
65
+ direction = self.get_direction(model_name) # model_name overrides direction
66
+
67
+ self.src_vocab_fp = os.path.join(_PATH, direction, "dict.SRC.json")
68
+ self.tgt_vocab_fp = os.path.join(_PATH, direction, "dict.TGT.json")
69
+ self.src_spm_fp = os.path.join(_PATH, direction, "model.SRC")
70
+ self.tgt_spm_fp = os.path.join(_PATH, direction, "model.TGT")
71
+
72
+ self.unk_token = unk_token
73
+ self.pad_token = pad_token
74
+ self.eos_token = eos_token
75
+ self.bos_token = bos_token
76
+
77
+ self.encoder = self._load_json(self.src_vocab_fp)
78
+ if self.unk_token not in self.encoder:
79
+ raise KeyError("<unk> token must be in vocab")
80
+ assert self.pad_token in self.encoder
81
+ self.encoder_rev = {v: k for k, v in self.encoder.items()}
82
+
83
+ self.decoder = self._load_json(self.tgt_vocab_fp)
84
+ if self.unk_token not in self.encoder:
85
+ raise KeyError("<unk> token must be in vocab")
86
+ assert self.pad_token in self.encoder
87
+ self.decoder_rev = {v: k for k, v in self.decoder.items()}
88
+
89
+ # load SentencePiece model for pre-processing
90
+ self.src_spm = self._load_spm(self.src_spm_fp)
91
+ self.tgt_spm = self._load_spm(self.tgt_spm_fp)
92
+
93
+ self.unk_token_id = self.encoder[self.unk_token]
94
+ self.pad_token_id = self.encoder[self.pad_token]
95
+ self.eos_token_id = self.encoder[self.eos_token]
96
+ self.bos_token_id = self.encoder[self.bos_token]
97
+
98
+ def get_direction(self, model_name: str) -> str:
99
+ pieces = model_name.split("/")[-1].split("-")
100
+ return f"{pieces[1]}-{pieces[2]}"
101
+
102
+ def is_special_token(self, x: str):
103
+ return (x == self.pad_token) or (x == self.bos_token) or (x == self.eos_token)
104
+
105
+ def get_vocab_size(self, src: bool) -> int:
106
+ """Returns the size of the vocabulary"""
107
+ return len(self.encoder) if src else len(self.decoder)
108
+
109
+ def _load_spm(self, path: str) -> SentencePieceProcessor:
110
+ return SentencePieceProcessor(model_file=path)
111
+
112
+ def _save_json(self, data, path: str) -> None:
113
+ with open(path, "w", encoding="utf-8") as f:
114
+ json.dump(data, f, indent=2)
115
+
116
+ def _load_json(self, path: str) -> Union[Dict, List]:
117
+ with open(path, "r", encoding="utf-8") as f:
118
+ return json.load(f)
119
+
120
+ def _convert_token_to_id(self, token: str, src: bool) -> int:
121
+ """Converts an token (str) into an index (integer) using the source/target vocabulary map."""
122
+ return (
123
+ self.encoder.get(token, self.encoder[self.unk_token])
124
+ if src
125
+ else self.decoder.get(token, self.encoder[self.unk_token])
126
+ )
127
+
128
+ def _convert_id_to_token(self, index: int, src: bool) -> str:
129
+ """Converts an index (integer) into a token (str) using the source/target vocabulary map."""
130
+ return (
131
+ self.encoder_rev.get(index, self.unk_token)
132
+ if src
133
+ else self.decoder_rev.get(index, self.unk_token)
134
+ )
135
+
136
+ def _convert_tokens_to_string(self, tokens: List[str], src: bool) -> str:
137
+ """Uses sentencepiece model for detokenization"""
138
+ if src:
139
+ if tokens[0] in self.supported_langs and tokens[1] in self.supported_langs:
140
+ tokens = tokens[2:]
141
+ return " ".join(tokens)
142
+ else:
143
+ return " ".join(tokens)
144
+
145
+ def _remove_translation_tags(self, text: str) -> Tuple[List, str]:
146
+ """Removes the translation tags before text normalization and tokenization."""
147
+ tokens = text.split(" ")
148
+ return tokens[:2], " ".join(tokens[2:])
149
+
150
+ def _tokenize_src_line(self, line: str) -> List[str]:
151
+ """Tokenizes a source line."""
152
+ tags, text = self._remove_translation_tags(line)
153
+ tokens = self.src_spm.encode(text, out_type=str)
154
+ return tags + tokens
155
+
156
+ def _tokenize_tgt_line(self, line: str) -> List[str]:
157
+ """Tokenizes a target line."""
158
+ return self.tgt_spm.encode(line, out_type=str)
159
+
160
+ def tokenize(self, text: str, src: bool) -> List[str]:
161
+ """Tokenizes a string into tokens using the source/target vocabulary."""
162
+ return self._tokenize_src_line(text) if src else self._tokenize_tgt_line(text)
163
+
164
+ def batch_tokenize(self, batch: List[str], src: bool) -> List[List[str]]:
165
+ """Tokenizes a list of strings into tokens using the source/target vocabulary."""
166
+ return [self.tokenize(line, src) for line in batch]
167
+
168
+ def _create_attention_mask(self, ids: List[int], max_seq_len: int, src: bool) -> List[int]:
169
+ """Creates a attention mask for the input sequence."""
170
+ if src:
171
+ return [0] * (max_seq_len - len(ids)) + [1] * (len(ids) + 1)
172
+ else:
173
+ return [1] * (len(ids) + 1) + [0] * (max_seq_len - len(ids))
174
+
175
+ def _pad_batch(self, tokens: List[str], max_seq_len: int, src: bool) -> List[str]:
176
+ """Pads a batch of tokens and adds BOS/EOS tokens."""
177
+ if src:
178
+ return [self.pad_token] * (max_seq_len - len(tokens)) + tokens + [self.eos_token]
179
+ else:
180
+ return tokens + [self.eos_token] + [self.pad_token] * (max_seq_len - len(tokens))
181
+
182
+ def _decode_line(self, ids: List[int], src: bool) -> List[str]:
183
+ return [self._convert_id_to_token(_id, src) for _id in ids]
184
+
185
+ def _encode_line(self, tokens: List[str], src: bool) -> List[int]:
186
+ return [self._convert_token_to_id(token, src) for token in tokens]
187
+
188
+ def _strip_special_tokens(self, tokens: List[str]) -> List[str]:
189
+ return [token for token in tokens if not self.is_special_token(token)]
190
+
191
+ def _single_input_preprocessing(
192
+ self, tokens: List[str], src: bool, max_seq_len: int
193
+ ) -> Tuple[List[int], List[int], int]:
194
+ """Tokenizes a string into tokens and also converts them into integers using source/target vocabulary map."""
195
+ attention_mask = self._create_attention_mask(tokens, max_seq_len, src)
196
+ padded_tokens = self._pad_batch(tokens, max_seq_len, src)
197
+ input_ids = self._encode_line(padded_tokens, src)
198
+ return input_ids, attention_mask
199
+
200
+ def _single_output_postprocessing(self, ids: List[int], src: bool) -> str:
201
+ """Detokenizes a list of integer ids into a string using the source/target vocabulary."""
202
+ tokens = self._decode_line(ids, src)
203
+ tokens = self._strip_special_tokens(tokens)
204
+ return (
205
+ self._convert_tokens_to_string(tokens, src).replace(" ", "").replace("▁", " ").strip()
206
+ )
207
+
208
+ def __call__(
209
+ self,
210
+ batch: Union[list, str],
211
+ src: bool,
212
+ truncation: bool = False,
213
+ padding: str = "longest",
214
+ max_length: int = None,
215
+ return_tensors: str = "pt",
216
+ return_attention_mask: bool = True,
217
+ return_length: bool = False,
218
+ ) -> BatchEncoding:
219
+ """Tokenizes a string into tokens and also converts them into integers using source/target vocabulary map."""
220
+ assert padding in [
221
+ "longest",
222
+ "max_length",
223
+ ], "Padding should be either 'longest' or 'max_length'"
224
+
225
+ if not isinstance(batch, list):
226
+ raise TypeError(f"Batch must be a list, but current batch is of type {type(batch)}")
227
+
228
+ # tokenize the source sentences
229
+ batch = self.batch_tokenize(batch, src)
230
+
231
+ # truncate the sentences if needed
232
+ if truncation and max_length is not None:
233
+ batch = [ids[:max_length] for ids in batch]
234
+
235
+ lengths = [len(ids) for ids in batch]
236
+
237
+ max_seq_len = max(lengths) if padding == "longest" else max_length
238
+
239
+ input_ids, attention_mask = zip(
240
+ *[
241
+ self._single_input_preprocessing(tokens=tokens, src=src, max_seq_len=max_seq_len)
242
+ for tokens in batch
243
+ ]
244
+ )
245
+
246
+ _data = {"input_ids": input_ids}
247
+
248
+ if return_attention_mask:
249
+ _data["attention_mask"] = attention_mask
250
+
251
+ if return_length:
252
+ _data["lengths"] = lengths
253
+
254
+ return BatchEncoding(_data, tensor_type=return_tensors)
255
+
256
+ def batch_decode(self, batch: Union[list, torch.Tensor], src: bool) -> List[List[str]]:
257
+ """Detokenizes a list of integer ids or a tensor into a list of strings using the source/target vocabulary."""
258
+
259
+ if isinstance(batch, torch.Tensor):
260
+ batch = batch.detach().cpu().tolist()
261
+
262
+ return [self._single_output_postprocessing(ids=ids, src=src) for ids in batch]
IndicTransTokenizer/IndicTransTokenizer/utils.py ADDED
@@ -0,0 +1,530 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import List, Tuple, Union
3
+
4
+ from indicnlp.tokenize import indic_tokenize, indic_detokenize
5
+ from indicnlp.normalize.indic_normalize import IndicNormalizerFactory
6
+ from sacremoses import MosesPunctNormalizer, MosesTokenizer, MosesDetokenizer
7
+ from indicnlp.transliterate.unicode_transliterate import UnicodeIndicTransliterator
8
+
9
+
10
+ class IndicProcessor:
11
+ def __init__(self, inference=True):
12
+ self.inference = inference
13
+
14
+ self._flores_codes = {
15
+ "asm_Beng": "as",
16
+ "awa_Deva": "hi",
17
+ "ben_Beng": "bn",
18
+ "bho_Deva": "hi",
19
+ "brx_Deva": "hi",
20
+ "doi_Deva": "hi",
21
+ "eng_Latn": "en",
22
+ "gom_Deva": "kK",
23
+ "gon_Deva": "hi",
24
+ "guj_Gujr": "gu",
25
+ "hin_Deva": "hi",
26
+ "hne_Deva": "hi",
27
+ "kan_Knda": "kn",
28
+ "kas_Arab": "ur",
29
+ "kas_Deva": "hi",
30
+ "kha_Latn": "en",
31
+ "lus_Latn": "en",
32
+ "mag_Deva": "hi",
33
+ "mai_Deva": "hi",
34
+ "mal_Mlym": "ml",
35
+ "mar_Deva": "mr",
36
+ "mni_Beng": "bn",
37
+ "mni_Mtei": "hi",
38
+ "npi_Deva": "ne",
39
+ "ory_Orya": "or",
40
+ "pan_Guru": "pa",
41
+ "san_Deva": "hi",
42
+ "sat_Olck": "or",
43
+ "snd_Arab": "ur",
44
+ "snd_Deva": "hi",
45
+ "tam_Taml": "ta",
46
+ "tel_Telu": "te",
47
+ "urd_Arab": "ur",
48
+ "unr_Deva": "hi",
49
+ }
50
+
51
+ self._indic_num_map = {
52
+ "\u09e6": "0",
53
+ "0": "0",
54
+ "\u0ae6": "0",
55
+ "\u0ce6": "0",
56
+ "\u0966": "0",
57
+ "\u0660": "0",
58
+ "\uabf0": "0",
59
+ "\u0b66": "0",
60
+ "\u0a66": "0",
61
+ "\u1c50": "0",
62
+ "\u06f0": "0",
63
+ "\u09e7": "1",
64
+ "1": "1",
65
+ "\u0ae7": "1",
66
+ "\u0967": "1",
67
+ "\u0ce7": "1",
68
+ "\u06f1": "1",
69
+ "\uabf1": "1",
70
+ "\u0b67": "1",
71
+ "\u0a67": "1",
72
+ "\u1c51": "1",
73
+ "\u0c67": "1",
74
+ "\u09e8": "2",
75
+ "2": "2",
76
+ "\u0ae8": "2",
77
+ "\u0968": "2",
78
+ "\u0ce8": "2",
79
+ "\u06f2": "2",
80
+ "\uabf2": "2",
81
+ "\u0b68": "2",
82
+ "\u0a68": "2",
83
+ "\u1c52": "2",
84
+ "\u0c68": "2",
85
+ "\u09e9": "3",
86
+ "3": "3",
87
+ "\u0ae9": "3",
88
+ "\u0969": "3",
89
+ "\u0ce9": "3",
90
+ "\u06f3": "3",
91
+ "\uabf3": "3",
92
+ "\u0b69": "3",
93
+ "\u0a69": "3",
94
+ "\u1c53": "3",
95
+ "\u0c69": "3",
96
+ "\u09ea": "4",
97
+ "4": "4",
98
+ "\u0aea": "4",
99
+ "\u096a": "4",
100
+ "\u0cea": "4",
101
+ "\u06f4": "4",
102
+ "\uabf4": "4",
103
+ "\u0b6a": "4",
104
+ "\u0a6a": "4",
105
+ "\u1c54": "4",
106
+ "\u0c6a": "4",
107
+ "\u09eb": "5",
108
+ "5": "5",
109
+ "\u0aeb": "5",
110
+ "\u096b": "5",
111
+ "\u0ceb": "5",
112
+ "\u06f5": "5",
113
+ "\uabf5": "5",
114
+ "\u0b6b": "5",
115
+ "\u0a6b": "5",
116
+ "\u1c55": "5",
117
+ "\u0c6b": "5",
118
+ "\u09ec": "6",
119
+ "6": "6",
120
+ "\u0aec": "6",
121
+ "\u096c": "6",
122
+ "\u0cec": "6",
123
+ "\u06f6": "6",
124
+ "\uabf6": "6",
125
+ "\u0b6c": "6",
126
+ "\u0a6c": "6",
127
+ "\u1c56": "6",
128
+ "\u0c6c": "6",
129
+ "\u09ed": "7",
130
+ "7": "7",
131
+ "\u0aed": "7",
132
+ "\u096d": "7",
133
+ "\u0ced": "7",
134
+ "\u06f7": "7",
135
+ "\uabf7": "7",
136
+ "\u0b6d": "7",
137
+ "\u0a6d": "7",
138
+ "\u1c57": "7",
139
+ "\u0c6d": "7",
140
+ "\u09ee": "8",
141
+ "8": "8",
142
+ "\u0aee": "8",
143
+ "\u096e": "8",
144
+ "\u0cee": "8",
145
+ "\u06f8": "8",
146
+ "\uabf8": "8",
147
+ "\u0b6e": "8",
148
+ "\u0a6e": "8",
149
+ "\u1c58": "8",
150
+ "\u0c6e": "8",
151
+ "\u09ef": "9",
152
+ "9": "9",
153
+ "\u0aef": "9",
154
+ "\u096f": "9",
155
+ "\u0cef": "9",
156
+ "\u06f9": "9",
157
+ "\uabf9": "9",
158
+ "\u0b6f": "9",
159
+ "\u0a6f": "9",
160
+ "\u1c59": "9",
161
+ "\u0c6f": "9",
162
+ }
163
+
164
+ self._placeholder_entity_maps = []
165
+
166
+ self._en_tok = MosesTokenizer(lang="en")
167
+ self._en_normalizer = MosesPunctNormalizer()
168
+ self._en_detok = MosesDetokenizer(lang="en")
169
+ self._xliterator = UnicodeIndicTransliterator()
170
+
171
+ self._multispace_regex = re.compile("[ ]{2,}")
172
+ self._digit_space_percent = re.compile(r"(\d) %")
173
+ self._double_quot_punc = re.compile(r"\"([,\.]+)")
174
+ self._digit_nbsp_digit = re.compile(r"(\d) (\d)")
175
+ self._end_bracket_space_punc_regex = re.compile(r"\) ([\.!:?;,])")
176
+
177
+ self._URL_PATTERN = r"\b(?<![\w/.])(?:(?:https?|ftp)://)?(?:(?:[\w-]+\.)+(?!\.))(?:[\w/\-?#&=%.]+)+(?!\.\w+)\b"
178
+ self._NUMERAL_PATTERN = r"(~?\d+\.?\d*\s?%?\s?-?\s?~?\d+\.?\d*\s?%|~?\d+%|\d+[-\/.,:']\d+[-\/.,:'+]\d+(?:\.\d+)?|\d+[-\/.:'+]\d+(?:\.\d+)?)"
179
+ self._EMAIL_PATTERN = r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}"
180
+ self._OTHER_PATTERN = r"[A-Za-z0-9]*[#|@]\w+"
181
+
182
+ def _add_placeholder_entity_map(self, placeholder_entity_map):
183
+ self._placeholder_entity_maps.append(placeholder_entity_map)
184
+
185
+ def get_placeholder_entity_maps(self):
186
+ return self._placeholder_entity_maps
187
+
188
+ def _punc_norm(self, text) -> str:
189
+ text = (
190
+ text.replace("\r", "")
191
+ .replace("(", " (")
192
+ .replace(")", ") ")
193
+ .replace("( ", "(")
194
+ .replace(" )", ")")
195
+ .replace(" :", ":")
196
+ .replace(" ;", ";")
197
+ .replace("`", "'")
198
+ .replace("„", '"')
199
+ .replace("“", '"')
200
+ .replace("”", '"')
201
+ .replace("–", "-")
202
+ .replace("—", " - ")
203
+ .replace("´", "'")
204
+ .replace("‘", "'")
205
+ .replace("‚", "'")
206
+ .replace("’", "'")
207
+ .replace("''", '"')
208
+ .replace("´´", '"')
209
+ .replace("…", "...")
210
+ .replace(" « ", ' "')
211
+ .replace("« ", '"')
212
+ .replace("«", '"')
213
+ .replace(" » ", '" ')
214
+ .replace(" »", '"')
215
+ .replace("»", '"')
216
+ .replace(" %", "%")
217
+ .replace("nº ", "nº ")
218
+ .replace(" :", ":")
219
+ .replace(" ºC", " ºC")
220
+ .replace(" cm", " cm")
221
+ .replace(" ?", "?")
222
+ .replace(" !", "!")
223
+ .replace(" ;", ";")
224
+ .replace(", ", ", ")
225
+ )
226
+
227
+ text = self._multispace_regex.sub(" ", text)
228
+ text = self._end_bracket_space_punc_regex.sub(r")\1", text)
229
+ text = self._digit_space_percent.sub(r"\1%", text)
230
+ text = self._double_quot_punc.sub(r'\1"', text)
231
+ text = self._digit_nbsp_digit.sub(r"\1.\2", text)
232
+ return text.strip()
233
+
234
+ def _normalize_indic_numerals(self, line: str) -> str:
235
+ """
236
+ Normalize the numerals in Indic languages from native script to Roman script (if present).
237
+
238
+ Args:
239
+ line (str): an input string with Indic numerals to be normalized.
240
+
241
+ Returns:
242
+ str: an input string with the all Indic numerals normalized to Roman script.
243
+ """
244
+ return "".join([self._indic_num_map.get(c, c) for c in line])
245
+
246
+ def _wrap_with_placeholders(self, text: str, patterns: list) -> str:
247
+ """
248
+ Wraps substrings with matched patterns in the given text with placeholders and returns
249
+ the modified text along with a mapping of the placeholders to their original value.
250
+
251
+ Args:
252
+ text (str): an input string which needs to be wrapped with the placeholders.
253
+ pattern (list): list of patterns to search for in the input string.
254
+
255
+ Returns:
256
+ text (str): a modified text.
257
+ """
258
+
259
+ serial_no = 1
260
+
261
+ placeholder_entity_map = dict()
262
+
263
+ indic_failure_cases = [
264
+ "آی ڈی ",
265
+ "ꯑꯥꯏꯗꯤ",
266
+ "आईडी",
267
+ "आई . डी . ",
268
+ "आई . डी .",
269
+ "आई. डी. ",
270
+ "आई. डी.",
271
+ "ऐटि",
272
+ "آئی ڈی ",
273
+ "ᱟᱭᱰᱤ ᱾",
274
+ "आयडी",
275
+ "ऐडि",
276
+ "आइडि",
277
+ "ᱟᱭᱰᱤ",
278
+ ]
279
+
280
+ for pattern in patterns:
281
+ matches = set(re.findall(pattern, text))
282
+
283
+ # wrap common match with placeholder tags
284
+ for match in matches:
285
+ if pattern == self._URL_PATTERN:
286
+ # Avoids false positive URL matches for names with initials.
287
+ if len(match.replace(".", "")) < 4:
288
+ continue
289
+ if pattern == self._NUMERAL_PATTERN:
290
+ # Short numeral patterns do not need placeholder based handling.
291
+ if (
292
+ len(match.replace(" ", "").replace(".", "").replace(":", ""))
293
+ < 4
294
+ ):
295
+ continue
296
+
297
+ # Set of Translations of "ID" in all the suppported languages have been collated.
298
+ # This has been added to deal with edge cases where placeholders might get translated.
299
+ base_placeholder = f"<ID{serial_no}>"
300
+
301
+ placeholder_entity_map[f"<ID{serial_no}]"] = match
302
+ placeholder_entity_map[f"< ID{serial_no} ]"] = match
303
+ placeholder_entity_map[f"<ID{serial_no}>"] = match
304
+ placeholder_entity_map[f"< ID{serial_no} >"] = match
305
+
306
+ for i in indic_failure_cases:
307
+ placeholder_entity_map[f"<{i}{serial_no}>"] = match
308
+ placeholder_entity_map[f"< {i}{serial_no} >"] = match
309
+ placeholder_entity_map[f"< {i} {serial_no} >"] = match
310
+ placeholder_entity_map[f"<{i} {serial_no}]"] = match
311
+ placeholder_entity_map[f"< {i} {serial_no} ]"] = match
312
+ placeholder_entity_map[f"[{i} {serial_no}]"] = match
313
+ placeholder_entity_map[f"[ {i} {serial_no} ]"] = match
314
+
315
+ text = text.replace(match, base_placeholder)
316
+ serial_no += 1
317
+
318
+ text = re.sub("\s+", " ", text).replace(">/", ">").replace("]/", "]")
319
+ self._add_placeholder_entity_map(placeholder_entity_map)
320
+ return text
321
+
322
+ def _normalize(
323
+ self,
324
+ text: str,
325
+ ) -> Tuple[str, dict]:
326
+ """
327
+ Normalizes and wraps the spans of input string with placeholder tags. It first normalizes
328
+ the Indic numerals in the input string to Roman script. Later, it uses the input string with normalized
329
+ Indic numerals to wrap the spans of text matching the pattern with placeholder tags.
330
+
331
+ Args:
332
+ text (str): input string.
333
+ pattern (list): list of patterns to search for in the input string.
334
+
335
+ Returns:
336
+ text (str): the modified text
337
+ """
338
+ patterns = [
339
+ self._EMAIL_PATTERN,
340
+ self._URL_PATTERN,
341
+ self._NUMERAL_PATTERN,
342
+ self._OTHER_PATTERN,
343
+ ]
344
+
345
+ text = self._normalize_indic_numerals(text.strip())
346
+
347
+ if self.inference:
348
+ text = self._wrap_with_placeholders(text, patterns)
349
+
350
+ return text
351
+
352
+ def _apply_lang_tags(
353
+ self, sents: List[str], src_lang: str, tgt_lang: str, delimiter=" "
354
+ ) -> List[str]:
355
+ """
356
+ Add special tokens indicating source and target language to the start of the each input sentence.
357
+ Each resulting input sentence will have the format: "`{src_lang} {tgt_lang} {input_sentence}`".
358
+
359
+ Args:
360
+ sent (str): input sentence to be translated.
361
+ src_lang (str): flores lang code of the input sentence.
362
+ tgt_lang (str): flores lang code in which the input sentence will be translated.
363
+
364
+ Returns:
365
+ List[str]: list of input sentences with the special tokens added to the start.
366
+ """
367
+ return [f"{src_lang}{delimiter}{tgt_lang}{delimiter}{x.strip()}" for x in sents]
368
+
369
+ def _preprocess(
370
+ self,
371
+ sent: str,
372
+ lang: str,
373
+ normalizer: Union[MosesPunctNormalizer, IndicNormalizerFactory],
374
+ ) -> str:
375
+ """
376
+ Preprocess an input text sentence by normalizing, tokenization, and possibly transliterating it.
377
+
378
+ Args:
379
+ sent (str): input text sentence to preprocess.
380
+ normalizer (Union[MosesPunctNormalizer, IndicNormalizerFactory]): an object that performs normalization on the text.
381
+ lang (str): flores language code of the input text sentence.
382
+
383
+ Returns:
384
+ sent (str): a preprocessed input text sentence
385
+ """
386
+ iso_lang = self._flores_codes[lang]
387
+ sent = self._punc_norm(sent)
388
+ sent = self._normalize(sent)
389
+
390
+ transliterate = True
391
+ if lang.split("_")[1] in ["Arab", "Aran", "Olck", "Mtei", "Latn"]:
392
+ transliterate = False
393
+
394
+ if iso_lang == "en":
395
+ processed_sent = " ".join(
396
+ self._en_tok.tokenize(
397
+ self._en_normalizer.normalize(sent.strip()), escape=False
398
+ )
399
+ )
400
+ elif transliterate:
401
+ # transliterates from the any specific language to devanagari
402
+ # which is why we specify lang2_code as "hi".
403
+ processed_sent = self._xliterator.transliterate(
404
+ " ".join(
405
+ indic_tokenize.trivial_tokenize(
406
+ normalizer.normalize(sent.strip()), iso_lang
407
+ )
408
+ ),
409
+ iso_lang,
410
+ "hi",
411
+ ).replace(" ् ", "्")
412
+ else:
413
+ # we only need to transliterate for joint training
414
+ processed_sent = " ".join(
415
+ indic_tokenize.trivial_tokenize(
416
+ normalizer.normalize(sent.strip()), iso_lang
417
+ )
418
+ )
419
+
420
+ return processed_sent
421
+
422
+ def preprocess_batch(
423
+ self, batch: List[str], src_lang: str, tgt_lang: str, is_target: bool = False
424
+ ) -> List[str]:
425
+ """
426
+ Preprocess an array of sentences by normalizing, tokenization, and possibly transliterating it. It also tokenizes the
427
+ normalized text sequences using sentence piece tokenizer and also adds language tags.
428
+
429
+ Args:
430
+ batch (List[str]): input list of sentences to preprocess.
431
+ src_lang (str): flores language code of the input text sentences.
432
+ tgt_lang (str): flores language code of the output text sentences.
433
+ is_target (bool): add language tags if false otherwise skip it.
434
+
435
+ Returns:
436
+ List[str]: a list of preprocessed input text sentences.
437
+ """
438
+ # reset the placeholder entity map for each batch
439
+
440
+ normalizer = (
441
+ IndicNormalizerFactory().get_normalizer(self._flores_codes[src_lang])
442
+ if src_lang != "eng_Latn"
443
+ else None
444
+ )
445
+
446
+ preprocessed_sents = [
447
+ self._preprocess(sent, src_lang, normalizer) for sent in batch
448
+ ]
449
+
450
+ tagged_sents = (
451
+ self._apply_lang_tags(preprocessed_sents, src_lang, tgt_lang)
452
+ if not is_target
453
+ else preprocessed_sents
454
+ )
455
+
456
+ return tagged_sents
457
+
458
+ def _postprocess(
459
+ self,
460
+ sent: str,
461
+ placeholder_entity_map: dict,
462
+ lang: str = "hin_Deva",
463
+ ):
464
+ """
465
+ Postprocesses a single input sentence after the translation generation.
466
+
467
+ Args:
468
+ sent (str): input sentence to postprocess.
469
+ placeholder_entity_map (dict): dictionary mapping placeholders to the original entity values.
470
+ lang (str): flores language code of the input sentence.
471
+
472
+ Returns:
473
+ text (str): postprocessed input sentence.
474
+ """
475
+
476
+ lang_code, script_code = lang.split("_")
477
+ iso_lang = self._flores_codes[lang]
478
+
479
+ # Fixes for Perso-Arabic scripts
480
+ if script_code in ["Arab", "Aran"]:
481
+ sent = (
482
+ sent.replace(" ؟", "؟")
483
+ .replace(" ۔", "۔")
484
+ .replace(" ،", "،")
485
+ .replace("ٮ۪", "ؠ")
486
+ )
487
+
488
+ if lang_code == "ory":
489
+ sent = sent.replace("ଯ଼", "ୟ")
490
+
491
+ for k, v in placeholder_entity_map.items():
492
+ sent = sent.replace(k, v)
493
+
494
+ return (
495
+ self._en_detok.detokenize(sent.split(" "))
496
+ if lang == "eng_Latn"
497
+ else indic_detokenize.trivial_detokenize(
498
+ self._xliterator.transliterate(sent, "hi", iso_lang),
499
+ iso_lang,
500
+ )
501
+ )
502
+
503
+ def postprocess_batch(
504
+ self,
505
+ sents: List[str],
506
+ lang: str = "hin_Deva",
507
+ ) -> List[str]:
508
+ """
509
+ Postprocesses a batch of input sentences after the translation generations.
510
+
511
+ Args:
512
+ sents (List[str]): batch of translated sentences to postprocess.
513
+ placeholder_entity_map (List[dict]): dictionary mapping placeholders to the original entity values.
514
+ lang (str): flores language code of the input sentences.
515
+
516
+ Returns:
517
+ List[str]: postprocessed batch of input sentences.
518
+ """
519
+
520
+ placeholder_entity_maps = self.get_placeholder_entity_maps()
521
+
522
+ postprocessed_sents = [
523
+ self._postprocess(sent, placeholder_entity_map, lang)
524
+ for sent, placeholder_entity_map in zip(sents, placeholder_entity_maps)
525
+ ]
526
+
527
+ # reset the placeholder entity map after each batch
528
+ self._placeholder_entity_maps.clear()
529
+
530
+ return postprocessed_sents
IndicTransTokenizer/IndicTransTokenizer/version.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = "0.1.1"
IndicTransTokenizer/IndicTransTokenizer/version.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 0.1.1
IndicTransTokenizer/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) Varun Gumma.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE
IndicTransTokenizer/README.md ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # IndicTransTokenizer
2
+
3
+ The goal of this repository is to provide a simple, modular, and extendable tokenizer for [IndicTrans2](https://github.com/AI4Bharat/IndicTrans2) and be compatible with the HuggingFace models released.
4
+
5
+ ## Pre-requisites
6
+ - `Python 3.8+`
7
+ - [Indic NLP Library](https://github.com/VarunGumma/indic_nlp_library)
8
+ - Other requirements as listed in `requirements.txt`
9
+
10
+ ## Configuration
11
+ - Editable installation (Note, this may take a while):
12
+ ```bash
13
+ git clone https://github.com/VarunGumma/IndicTransTokenizer
14
+ cd IndicTransTokenizer
15
+
16
+ pip install --editable ./
17
+ ```
18
+
19
+ ## Usage
20
+ ```python
21
+ import torch
22
+ from transformers import AutoModelForSeq2SeqLM
23
+ from IndicTransTokenizer import IndicProcessor, IndicTransTokenizer
24
+
25
+ tokenizer = IndicTransTokenizer(direction="en-indic")
26
+ ip = IndicProcessor(inference=True)
27
+ model = AutoModelForSeq2SeqLM.from_pretrained("ai4bharat/indictrans2-en-indic-dist-200M", trust_remote_code=True)
28
+
29
+ sentences = [
30
+ "This is a test sentence.",
31
+ "This is another longer different test sentence.",
32
+ "Please send an SMS to 9876543210 and an email on [email protected] by 15th October, 2023.",
33
+ ]
34
+
35
+ batch = ip.preprocess_batch(sentences, src_lang="eng_Latn", tgt_lang="hin_Deva")
36
+ batch = tokenizer(batch, src=True, return_tensors="pt")
37
+
38
+ with torch.inference_mode():
39
+ outputs = model.generate(**batch, num_beams=5, num_return_sequences=1, max_length=256)
40
+
41
+ outputs = tokenizer.batch_decode(outputs, src=False)
42
+ outputs = ip.postprocess_batch(outputs, lang="hin_Deva")
43
+ print(outputs)
44
+
45
+ >>> ['यह एक परीक्षण वाक्य है।', 'यह एक और लंबा अलग परीक्षण वाक्य है।', 'कृपया 9876543210 पर एक एस. एम. एस. भेजें और 15 अक्टूबर, 2023 तक [email protected] पर एक ईमेल भेजें।']
46
+ ```
47
+
48
+ For using the tokenizer to train/fine-tune the model, just set the `inference` argument of IndicProcessor to `False`.
49
+
50
+ ## Authors
51
+ - Varun Gumma ([email protected])
52
+ - Jay Gala ([email protected])
53
+ - Pranjal Agadh Chitale ([email protected])
54
+ - Raj Dabre ([email protected])
55
+
56
+
57
+ ## Bugs and Contribution
58
+ Since this a bleeding-edge module, you may encounter broken stuff and import issues once in a while. In case you encounter any bugs or want additional functionalities, please feel free to raise `Issues`/`Pull Requests` or contact the authors.
59
+
60
+
61
+ ## Citation
62
+ If you use our codebase, models or tokenizer, please do cite the following paper:
63
+ ```bibtex
64
+ @article{
65
+ gala2023indictrans,
66
+ title={IndicTrans2: Towards High-Quality and Accessible Machine Translation Models for all 22 Scheduled Indian Languages},
67
+ author={Jay Gala and Pranjal A Chitale and A K Raghavan and Varun Gumma and Sumanth Doddapaneni and Aswanth Kumar M and Janki Atul Nawale and Anupama Sujatha and Ratish Puduppully and Vivek Raghavan and Pratyush Kumar and Mitesh M Khapra and Raj Dabre and Anoop Kunchukuttan},
68
+ journal={Transactions on Machine Learning Research},
69
+ issn={2835-8856},
70
+ year={2023},
71
+ url={https://openreview.net/forum?id=vfT4YuzAYA},
72
+ note={}
73
+ }
74
+ ```
75
+
76
+ ## Note
77
+ This tokenizer module is currently **not** compatible with the [PreTrainedTokenizer](https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/tokenizer#transformers.PreTrainedTokenizer) module from HuggingFace. Hence, we are actively looking for `Pull Requests` to port this tokenizer to HF. Any leads on that front are welcome!
IndicTransTokenizer/requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ setuptools==68.2.2
2
+ torch
3
+ sacremoses
4
+ sentencepiece
5
+ transformers
6
+ indic-nlp-library-IT2 @ git+https://github.com/VarunGumma/indic_nlp_library
IndicTransTokenizer/setup.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pathlib
3
+ from sys import version_info, exit
4
+ from setuptools import setup, find_packages
5
+ from pkg_resources import parse_requirements
6
+
7
+
8
+ def write_version_py():
9
+ with open(os.path.join("IndicTransTokenizer", "version.txt"), "r") as f:
10
+ version = f.read().strip()
11
+
12
+ with open(os.path.join("IndicTransTokenizer", "version.py"), "w") as f:
13
+ f.write(f'__version__ = "{version}"\n')
14
+ return version
15
+
16
+
17
+ if version_info < (3, 8):
18
+ exit("Sorry, Python >= 3.8 is required for IndicTransTokenizer.")
19
+
20
+
21
+ with open("README.md", "r", errors="ignore", encoding="utf-8") as fh:
22
+ long_description = fh.read().strip()
23
+
24
+ version = write_version_py()
25
+
26
+ setup(
27
+ name="IndicTransTokenizer",
28
+ version=version,
29
+ author="Varun Gumma",
30
+ author_email="[email protected]",
31
+ description="A simple, consistent, and extendable module for IndicTrans2 tokenizer compatible with the HuggingFace models",
32
+ long_description=long_description,
33
+ long_description_content_type="text/markdown",
34
+ url="https://github.com/VarunGumma/IndicTransTokenizer",
35
+ packages=find_packages(),
36
+ license="MIT",
37
+ classifiers=[
38
+ "Programming Language :: Python :: 3",
39
+ "License :: OSI Approved :: MIT License",
40
+ "Operating System :: OS Independent",
41
+ ],
42
+ python_requires=">=3.8",
43
+ install_requires=[
44
+ str(requirement)
45
+ for requirement in parse_requirements(pathlib.Path(f"requirements.txt").open())
46
+ ],
47
+ )
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
- title: Indictrans2 Conversation
3
- emoji: 🚀
4
- colorFrom: pink
5
- colorTo: gray
6
  sdk: gradio
7
  sdk_version: 4.21.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: IndicTrans2 for Conversation
3
+ emoji: 📚
4
+ colorFrom: purple
5
+ colorTo: red
6
  sdk: gradio
7
  sdk_version: 4.21.0
8
  app_file: app.py
9
  pinned: false
10
+ license: apache-2.0
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import time
3
+ from config import model_repo_id, src_lang, tgt_lang
4
+ from indictrans2 import initialize_model_and_tokenizer, batch_translate
5
+ from examples import example_sentences
6
+
7
+
8
+ def load_models():
9
+ model_dict = {}
10
+
11
+ print("\tLoading model: %s" % model_repo_id)
12
+
13
+ # build model and tokenizer
14
+ en_indic_tokenizer, en_indic_model, en_indic_lora_model = (
15
+ initialize_model_and_tokenizer()
16
+ )
17
+
18
+ model_dict["_tokenizer"] = en_indic_tokenizer
19
+ model_dict["_model"] = en_indic_model
20
+ model_dict["_lora_model"] = en_indic_lora_model
21
+
22
+ return model_dict
23
+
24
+
25
+ def translation(text):
26
+
27
+ start_time = time.time()
28
+
29
+ tokenizer = model_dict["_tokenizer"]
30
+ model = model_dict["_model"]
31
+ lora_model = model_dict["_lora_model"]
32
+
33
+ # org translation
34
+ org_translation = batch_translate(
35
+ [text],
36
+ model=model,
37
+ tokenizer=tokenizer,
38
+ )
39
+ org_output = org_translation[0]
40
+ end_time = time.time()
41
+
42
+ # lora translation
43
+ lora_translation = batch_translate(
44
+ [text],
45
+ model=lora_model,
46
+ tokenizer=tokenizer,
47
+ )
48
+ lora_output = lora_translation[0]
49
+ end_time2 = time.time()
50
+
51
+ result = {
52
+ "source": src_lang,
53
+ "target": tgt_lang,
54
+ "input": text,
55
+ "it2_result": org_output,
56
+ "it2_conv_result": lora_output,
57
+ "it2_inference_time": end_time - start_time,
58
+ "it2_conv_inference_time": end_time2 - end_time,
59
+ }
60
+
61
+ return result
62
+
63
+
64
+ print("\tinit models")
65
+
66
+ global model_dict
67
+
68
+ model_dict = load_models()
69
+
70
+ inputs = gr.Textbox(lines=5, label="Input text")
71
+ outputs = gr.JSON(container=True)
72
+ submit_btn = gr.Button("Translate", variant="primary")
73
+
74
+ title = "IndicTrans2 fine-tuned on conversation"
75
+ description = f"Note: LoRA is trained only on En-Hi pair.\nDetails: https://github.com/AI4Bharat/IndicTrans2.\nLoRA Model: https://huggingface.co/sam749/IndicTrans2-Conv"
76
+
77
+ gr.Interface(
78
+ fn=translation,
79
+ inputs=inputs,
80
+ outputs=outputs,
81
+ title=title,
82
+ description=description,
83
+ submit_btn=submit_btn,
84
+ examples=example_sentences,
85
+ examples_per_page=10,
86
+ cache_examples=False,
87
+ ).launch(share=True)
config.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ model_repo_id = "ai4bharat/indictrans2-en-indic-dist-200M"
2
+ lora_repo_id = "sam749/IndicTrans2-Conv"
3
+ src_lang = "eng_Latn"
4
+ tgt_lang = "hin_Deva"
5
+ batch_size = 8
examples.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ example_sentences = [
2
+ ['Avantika to Prakash: Did you mean "I play cricket"? What position do you play?'],
3
+ ["'do you eat pizza?', Manoj said to Jaya"],
4
+ ["Ankita to Avantika: can you come with me to tour?"],
5
+ [
6
+ 'Sudha to Sakshi: Did you mean "I\'ll grab some coffee before the meeting starts."? Can I join you too?'
7
+ ],
8
+ [
9
+ 'Anil to Sakshi: Did you mean "I\'ll grab some coffee before the meeting starts."? Can I join you too?'
10
+ ],
11
+ ]
indictrans2.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig
3
+ from IndicTransTokenizer.IndicTransTokenizer.utils import IndicProcessor
4
+ from IndicTransTokenizer.IndicTransTokenizer.tokenizer import IndicTransTokenizer
5
+ from peft import PeftModel
6
+ from config import lora_repo_id, model_repo_id, batch_size, src_lang, tgt_lang
7
+
8
+
9
+ DIRECTION = "en-indic"
10
+ QUANTIZATION = None
11
+ IP = IndicProcessor(inference=True)
12
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
+ HALF = True if torch.cuda.is_available() else False
14
+
15
+
16
+ def initialize_model_and_tokenizer():
17
+
18
+ if QUANTIZATION == "4-bit":
19
+ qconfig = BitsAndBytesConfig(
20
+ load_in_4bit=True,
21
+ bnb_4bit_use_double_quant=True,
22
+ bnb_4bit_compute_dtype=torch.bfloat16,
23
+ )
24
+ elif QUANTIZATION == "8-bit":
25
+ qconfig = BitsAndBytesConfig(
26
+ load_in_8bit=True,
27
+ bnb_8bit_use_double_quant=True,
28
+ bnb_8bit_compute_dtype=torch.bfloat16,
29
+ )
30
+ else:
31
+ qconfig = None
32
+
33
+ tokenizer = IndicTransTokenizer(direction=DIRECTION)
34
+ model = AutoModelForSeq2SeqLM.from_pretrained(
35
+ model_repo_id,
36
+ trust_remote_code=True,
37
+ low_cpu_mem_usage=True,
38
+ quantization_config=qconfig,
39
+ )
40
+ model2 = AutoModelForSeq2SeqLM.from_pretrained(
41
+ model_repo_id,
42
+ trust_remote_code=True,
43
+ low_cpu_mem_usage=True,
44
+ quantization_config=qconfig,
45
+ )
46
+
47
+ if qconfig == None:
48
+ model = model.to(DEVICE)
49
+ model2 = model2.to(DEVICE)
50
+
51
+ model.eval()
52
+ model2.eval()
53
+
54
+ lora_model = PeftModel.from_pretrained(model2, lora_repo_id)
55
+
56
+ return tokenizer, model, lora_model
57
+
58
+
59
+ def batch_translate(input_sentences, model, tokenizer):
60
+ translations = []
61
+ for i in range(0, len(input_sentences), batch_size):
62
+ batch = input_sentences[i : i + batch_size]
63
+
64
+ # Preprocess the batch and extract entity mappings
65
+ batch = IP.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)
66
+
67
+ # Tokenize the batch and generate input encodings
68
+ inputs = tokenizer(
69
+ batch,
70
+ src=True,
71
+ truncation=True,
72
+ padding="longest",
73
+ return_tensors="pt",
74
+ return_attention_mask=True,
75
+ ).to(DEVICE)
76
+
77
+ # Generate translations using the model
78
+ with torch.inference_mode():
79
+ generated_tokens = model.generate(
80
+ **inputs,
81
+ use_cache=True,
82
+ min_length=0,
83
+ max_length=256,
84
+ num_beams=5,
85
+ num_return_sequences=1,
86
+ )
87
+
88
+ # Decode the generated tokens into text
89
+ generated_tokens = tokenizer.batch_decode(
90
+ generated_tokens.detach().cpu().tolist(), src=False
91
+ )
92
+
93
+ # Postprocess the translations, including entity replacement
94
+ translations += IP.postprocess_batch(generated_tokens, lang=tgt_lang)
95
+
96
+ del inputs
97
+
98
+ return translations
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ indic-nlp-library-IT2 @ git+https://github.com/VarunGumma/indic_nlp_library
2
+ setuptools==68.2.2
3
+ transformers
4
+ gradio
5
+ torch
6
+ peft
7
+ sacremoses
8
+ sentencepiece