zb9 commited on
Commit
c0f6104
·
verified ·
1 Parent(s): f29450b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -14
app.py CHANGED
@@ -1,10 +1,9 @@
1
  import gradio as gr
2
  from colpali_engine.models import ColQwen2, ColQwen2Processor
3
  import torch
4
- import base64
5
  from PIL import Image
6
- import io
7
  import logging
 
8
 
9
  # Setup logging
10
  logging.basicConfig(level=logging.INFO)
@@ -13,25 +12,37 @@ logger = logging.getLogger("colqwen-api")
13
  # Initialize model
14
  logger.info("Loading ColQwen2 model...")
15
  model = ColQwen2.from_pretrained(
16
- "vidore/colqwen2-v0.1",
17
  torch_dtype=torch.bfloat16,
18
  device_map="auto",
19
  )
20
- processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v0.1")
21
  model = model.eval()
22
  logger.info("Model loaded successfully")
23
 
24
- def process_image(image_data):
25
  try:
26
  logger.info("Processing image")
27
- processed = processor.process_images([image_data])
28
- logger.info("Image processed")
 
29
 
30
- with torch.no_grad():
31
- embeddings = model(**processed)
32
- logger.info(f"Embeddings generated: {embeddings.shape}")
 
 
33
 
34
- return {"embeddings": embeddings.tolist()}
 
 
 
 
 
 
 
 
 
35
  except Exception as e:
36
  logger.error(f"Error: {str(e)}", exc_info=True)
37
  raise
@@ -40,8 +51,9 @@ interface = gr.Interface(
40
  fn=process_image,
41
  inputs=gr.Image(),
42
  outputs="json",
43
- title="ColQwen2 Embedding API"
 
44
  )
45
 
46
- # Add share=True to create public URL
47
- interface.launch()
 
1
  import gradio as gr
2
  from colpali_engine.models import ColQwen2, ColQwen2Processor
3
  import torch
 
4
  from PIL import Image
 
5
  import logging
6
+ import numpy as np
7
 
8
  # Setup logging
9
  logging.basicConfig(level=logging.INFO)
 
12
  # Initialize model
13
  logger.info("Loading ColQwen2 model...")
14
  model = ColQwen2.from_pretrained(
15
+ "vidore/colqwen2-v1.0", # Updated to v1.0
16
  torch_dtype=torch.bfloat16,
17
  device_map="auto",
18
  )
19
+ processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")
20
  model = model.eval()
21
  logger.info("Model loaded successfully")
22
 
23
+ def process_image(image):
24
  try:
25
  logger.info("Processing image")
26
+ # Convert to PIL Image if needed
27
+ if not isinstance(image, Image.Image):
28
+ image = Image.fromarray(image)
29
 
30
+ # Process image
31
+ inputs = processor(
32
+ images=image,
33
+ return_tensors="pt"
34
+ ).to(model.device)
35
 
36
+ logger.info("Generating embeddings")
37
+ with torch.no_grad():
38
+ outputs = model(**inputs)
39
+ embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
40
+
41
+ logger.info(f"Embeddings shape: {embeddings.shape}")
42
+ return {
43
+ "embeddings": embeddings.tolist(),
44
+ "shape": embeddings.shape
45
+ }
46
  except Exception as e:
47
  logger.error(f"Error: {str(e)}", exc_info=True)
48
  raise
 
51
  fn=process_image,
52
  inputs=gr.Image(),
53
  outputs="json",
54
+ title="ColQwen2 Embedding API",
55
+ description="Generate embeddings from images using ColQwen2"
56
  )
57
 
58
+ # Launch with API
59
+ interface.launch(server_name="0.0.0.0", server_port=7861)