ajimeno committed
Commit c77986f · 1 Parent(s): ef5fe89

Model in memory

Files changed (1)
  1. app.py +21 -15
app.py CHANGED
@@ -66,27 +66,33 @@ st.image(image, caption='Your target document')
 with st.spinner(f'Processing the document ...'):
     pre_trained_model = "unstructuredio/chipper-fast-fine-tuning"
     processor = DonutProcessor.from_pretrained(pre_trained_model)
-    model = VisionEncoderDecoderModel.from_pretrained(pre_trained_model)

-    from huggingface_hub import hf_hub_download
+    device = "cuda" if torch.cuda.is_available() else "cpu"

-    lm_head_file = hf_hub_download(
-        repo_id=pre_trained_model, filename="lm_head.pth"
-    )
+    if 'model' in st.session_state:
+        model = st.session_state['model']
+    else:
+        model = VisionEncoderDecoderModel.from_pretrained(pre_trained_model)

-    rank = 128
-    model.decoder.lm_head = nn.Sequential(
-        nn.Linear(model.decoder.lm_head.weight.shape[1], rank, bias=False),
-        nn.Linear(rank, rank, bias=False),
-        nn.Linear(rank, model.decoder.lm_head.weight.shape[0], bias=True),
-    )
+        from huggingface_hub import hf_hub_download

-    model.decoder.lm_head.load_state_dict(torch.load(lm_head_file))
+        lm_head_file = hf_hub_download(
+            repo_id=pre_trained_model, filename="lm_head.pth"
+        )

-    device = "cuda" if torch.cuda.is_available() else "cpu"
+        rank = 128
+        model.decoder.lm_head = nn.Sequential(
+            nn.Linear(model.decoder.lm_head.weight.shape[1], rank, bias=False),
+            nn.Linear(rank, rank, bias=False),
+            nn.Linear(rank, model.decoder.lm_head.weight.shape[0], bias=True),
+        )
+
+        model.decoder.lm_head.load_state_dict(torch.load(lm_head_file))

-    model.eval()
-    model.to(device)
+
+        model.eval()
+        model.to(device)
+        st.session_state['model'] = model

     st.info(f'Parsing document')
     parsed_info = run_prediction(image.convert("RGB"), model, processor, prompt)
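In short, the commit keeps the model in memory across Streamlit reruns: the Chipper model (including the low-rank lm_head loaded from lm_head.pth) is now built only when 'model' is missing from st.session_state and is stored there afterwards, so later interactions reuse the already-loaded instance instead of re-downloading and re-initializing the weights. The sketch below isolates that caching pattern under the assumption that only the public Streamlit API is used; the helper name get_cached_model and its loader argument are illustrative and not part of the app.

```python
import streamlit as st


def get_cached_model(key: str, loader):
    """Return a session-cached object, building it only on the first call.

    `loader` is any zero-argument callable (for example, one that downloads
    and prepares a model). Its result is stored in st.session_state under
    `key`, so later reruns in the same browser session reuse it.
    """
    if key not in st.session_state:
        st.session_state[key] = loader()
    return st.session_state[key]


# Hypothetical usage mirroring the commit's if/else on st.session_state:
# model = get_cached_model("model", lambda: build_chipper_model(pre_trained_model))
```

A related option in current Streamlit is the @st.cache_resource decorator, which caches the object across all sessions of the app; the st.session_state approach used in this commit instead keeps one copy per browser session.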