gmastrapas
commited on
Commit
·
b845577
1
Parent(s):
d35adf7
feat: customize image processor in __call__
Browse files- processing_clip.py +29 -7
processing_clip.py
CHANGED
@@ -26,6 +26,14 @@ class JinaCLIPProcessor(CLIPProcessor):
|
|
26 |
|
27 |
class JinaCLIPImageProcessor(BaseImageProcessor):
|
28 |
model_input_names = ['pixel_values']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
def __init__(
|
31 |
self,
|
@@ -44,14 +52,17 @@ class JinaCLIPImageProcessor(BaseImageProcessor):
|
|
44 |
self.resize_mode = resize_mode
|
45 |
self.interpolation = interpolation
|
46 |
self.fill_color = fill_color
|
47 |
-
self.transform =
|
48 |
-
|
|
|
|
|
|
|
49 |
is_train=False,
|
50 |
-
mean=mean,
|
51 |
-
std=std,
|
52 |
-
resize_mode=resize_mode,
|
53 |
-
interpolation=interpolation,
|
54 |
-
fill_color=fill_color,
|
55 |
aug_cfg=None,
|
56 |
)
|
57 |
|
@@ -61,6 +72,17 @@ class JinaCLIPImageProcessor(BaseImageProcessor):
|
|
61 |
return output
|
62 |
|
63 |
def preprocess(self, images: ImageInput, **kwargs) -> BatchFeature:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
images = make_list_of_images(images)
|
65 |
out = torch.stack([self.transform(image) for image in images], dim=0)
|
66 |
return BatchFeature(data={'pixel_values': out})
|
|
|
26 |
|
27 |
class JinaCLIPImageProcessor(BaseImageProcessor):
|
28 |
model_input_names = ['pixel_values']
|
29 |
+
_valid_processor_keys = [
|
30 |
+
'size',
|
31 |
+
'mean',
|
32 |
+
'std',
|
33 |
+
'resize_mode',
|
34 |
+
'interpolation',
|
35 |
+
'fill_color',
|
36 |
+
]
|
37 |
|
38 |
def __init__(
|
39 |
self,
|
|
|
52 |
self.resize_mode = resize_mode
|
53 |
self.interpolation = interpolation
|
54 |
self.fill_color = fill_color
|
55 |
+
self.transform = self._build_transform()
|
56 |
+
|
57 |
+
def _build_transform(self):
|
58 |
+
return image_transform(
|
59 |
+
image_size=self.size,
|
60 |
is_train=False,
|
61 |
+
mean=self.mean,
|
62 |
+
std=self.std,
|
63 |
+
resize_mode=self.resize_mode,
|
64 |
+
interpolation=self.interpolation,
|
65 |
+
fill_color=self.fill_color,
|
66 |
aug_cfg=None,
|
67 |
)
|
68 |
|
|
|
72 |
return output
|
73 |
|
74 |
def preprocess(self, images: ImageInput, **kwargs) -> BatchFeature:
|
75 |
+
|
76 |
+
_transform_needs_rebuild = False
|
77 |
+
for k, v in kwargs.items():
|
78 |
+
if k in self._valid_processor_keys:
|
79 |
+
if v != getattr(self, k):
|
80 |
+
setattr(self, k, v)
|
81 |
+
_transform_needs_rebuild = True
|
82 |
+
|
83 |
+
if _transform_needs_rebuild:
|
84 |
+
self.transform = self._build_transform()
|
85 |
+
|
86 |
images = make_list_of_images(images)
|
87 |
out = torch.stack([self.transform(image) for image in images], dim=0)
|
88 |
return BatchFeature(data={'pixel_values': out})
|