gmastrapas commited on
Commit
b845577
·
1 Parent(s): d35adf7

feat: customize image processor in __call__

Browse files
Files changed (1) hide show
  1. processing_clip.py +29 -7
processing_clip.py CHANGED
@@ -26,6 +26,14 @@ class JinaCLIPProcessor(CLIPProcessor):
26
 
27
  class JinaCLIPImageProcessor(BaseImageProcessor):
28
  model_input_names = ['pixel_values']
 
 
 
 
 
 
 
 
29
 
30
  def __init__(
31
  self,
@@ -44,14 +52,17 @@ class JinaCLIPImageProcessor(BaseImageProcessor):
44
  self.resize_mode = resize_mode
45
  self.interpolation = interpolation
46
  self.fill_color = fill_color
47
- self.transform = image_transform(
48
- image_size=size,
 
 
 
49
  is_train=False,
50
- mean=mean,
51
- std=std,
52
- resize_mode=resize_mode,
53
- interpolation=interpolation,
54
- fill_color=fill_color,
55
  aug_cfg=None,
56
  )
57
 
@@ -61,6 +72,17 @@ class JinaCLIPImageProcessor(BaseImageProcessor):
61
  return output
62
 
63
  def preprocess(self, images: ImageInput, **kwargs) -> BatchFeature:
 
 
 
 
 
 
 
 
 
 
 
64
  images = make_list_of_images(images)
65
  out = torch.stack([self.transform(image) for image in images], dim=0)
66
  return BatchFeature(data={'pixel_values': out})
 
26
 
27
  class JinaCLIPImageProcessor(BaseImageProcessor):
28
  model_input_names = ['pixel_values']
29
+ _valid_processor_keys = [
30
+ 'size',
31
+ 'mean',
32
+ 'std',
33
+ 'resize_mode',
34
+ 'interpolation',
35
+ 'fill_color',
36
+ ]
37
 
38
  def __init__(
39
  self,
 
52
  self.resize_mode = resize_mode
53
  self.interpolation = interpolation
54
  self.fill_color = fill_color
55
+ self.transform = self._build_transform()
56
+
57
+ def _build_transform(self):
58
+ return image_transform(
59
+ image_size=self.size,
60
  is_train=False,
61
+ mean=self.mean,
62
+ std=self.std,
63
+ resize_mode=self.resize_mode,
64
+ interpolation=self.interpolation,
65
+ fill_color=self.fill_color,
66
  aug_cfg=None,
67
  )
68
 
 
72
  return output
73
 
74
  def preprocess(self, images: ImageInput, **kwargs) -> BatchFeature:
75
+
76
+ _transform_needs_rebuild = False
77
+ for k, v in kwargs.items():
78
+ if k in self._valid_processor_keys:
79
+ if v != getattr(self, k):
80
+ setattr(self, k, v)
81
+ _transform_needs_rebuild = True
82
+
83
+ if _transform_needs_rebuild:
84
+ self.transform = self._build_transform()
85
+
86
  images = make_list_of_images(images)
87
  out = torch.stack([self.transform(image) for image in images], dim=0)
88
  return BatchFeature(data={'pixel_values': out})