w3robotics committed on
Commit
7d7d7ee
·
verified ·
1 Parent(s): 5720b6c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -25
app.py CHANGED
@@ -18,30 +18,42 @@ model.to(device)
18
  #print(np.array(image).shape)
19
 
20
 
21
- # load document image
22
- dataset = load_dataset("hf-internal-testing/example-documents", split="test")
23
- image = dataset[2]["image"]
24
-
25
-
26
- task_prompt = "<s_rvlcdip>"
27
- decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
28
-
29
- pixel_values = processor(image, return_tensors="pt").pixel_values
30
-
31
- outputs = model.generate(
32
- pixel_values.to(device),
33
- decoder_input_ids=decoder_input_ids.to(device),
34
- max_length=model.decoder.config.max_position_embeddings,
35
- pad_token_id=processor.tokenizer.pad_token_id,
36
- eos_token_id=processor.tokenizer.eos_token_id,
37
- use_cache=True,
38
- bad_words_ids=[[processor.tokenizer.unk_token_id]],
39
- return_dict_in_generate=True,
40
- )
41
-
42
- sequence = processor.batch_decode(outputs.sequences)[0]
43
- sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
44
- sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
45
- print(processor.token2json(sequence))
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
 
 
18
  #print(np.array(image).shape)
19
 
20
 
21
# --- Streamlit UI: classify an uploaded document image with Donut (RVL-CDIP task) ---
st.title("Classify Document Image")

# File uploader returns None until the user picks a file.
file_name = st.file_uploader("Upload a candidate image")

if file_name is not None:
    # NOTE(review): col1/col2 are created but never used in this fragment —
    # presumably intended for an image/result side-by-side layout; confirm.
    col1, col2 = st.columns(2)

    # Donut's processor expects an RGB image; uploads may be RGBA/grayscale.
    image = Image.open(file_name)
    image = image.convert("RGB")

    # Previous behavior (kept for reference): a sample document from the hub.
    # dataset = load_dataset("hf-internal-testing/example-documents", split="test")
    # image = dataset[2]["image"]

    # The task prompt selects Donut's document-classification task (RVL-CDIP).
    task_prompt = "<s_rvlcdip>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    pixel_values = processor(image, return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],  # never emit <unk>
        return_dict_in_generate=True,
    )

    # Decode, strip special tokens, drop the leading task token, parse to JSON.
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token

    # FIX: the original used print(), which writes to the server console only —
    # render the prediction in the Streamlit page instead.
    st.write(processor.token2json(sequence))
59