File size: 6,801 Bytes
6eb1d7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

import random
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
import torch
from torch import nn

SampledData = Any
ModelOutput = Any


def _grouper(iterable: Iterable[Any], n: int, fillvalue=None) -> Iterator[Tuple[Any]]:
    """
    Group elements of an iterable by chunks of size `n`, e.g.
    grouper(range(9), 4) ->
        (0, 1, 2, 3), (4, 5, 6, 7), (8, None, None, None)
    """
    it = iter(iterable)
    while True:
        values = []
        for _ in range(n):
            try:
                value = next(it)
            except StopIteration:
                if values:
                    values.extend([fillvalue] * (n - len(values)))
                    yield tuple(values)
                return
            values.append(value)
        yield tuple(values)


class ScoreBasedFilter:
    """
    Keeps only the model-output entries whose score reaches a threshold.

    For every per-image output whose ``instances`` carry a ``scores`` field,
    ``instances`` is replaced by the subset with ``scores >= min_score``.
    Outputs without scores are left untouched. The given sequence is
    modified in place and also returned.
    """

    def __init__(self, min_score: float = 0.8):
        # Inclusive lower bound on an instance's score.
        self.min_score = min_score

    def __call__(self, model_output: ModelOutput) -> ModelOutput:
        threshold = self.min_score
        for per_image_output in model_output:
            detections = per_image_output["instances"]
            if detections.has("scores"):
                keep_mask = detections.scores >= threshold
                per_image_output["instances"] = detections[keep_mask]
        return model_output


class InferenceBasedLoader:
    """
    Data loader based on results inferred by a model. Consists of:
     - a data loader that provides batches of images
     - a model that is used to infer the results
     - a data sampler that converts inferred results to annotations

    Iterating over the loader yields lists of at most `batch_size` sampled
    entries. Entries whose "instances" are empty are discarded.
    """

    def __init__(
        self,
        model: nn.Module,
        data_loader: Iterable[List[Dict[str, Any]]],
        data_sampler: Optional[Callable[[ModelOutput], List[SampledData]]] = None,
        data_filter: Optional[Callable[[ModelOutput], ModelOutput]] = None,
        shuffle: bool = True,
        batch_size: int = 4,
        inference_batch_size: int = 4,
        drop_last: bool = False,
        category_to_class_mapping: Optional[dict] = None,
    ):
        """
        Constructor

        Args:
          model (torch.nn.Module): model used to produce data; it is switched
            to eval mode here
          data_loader (Iterable[List[Dict[str, Any]]]): iterable that provides
            dictionaries with "images" and "categories" fields to perform inference on
          data_sampler (Callable: ModelOutput -> SampledData): functor
              that produces annotation data from inference results;
              (optional, default: None, i.e. model outputs are used as-is)
          data_filter (Callable: ModelOutput -> ModelOutput): filter
              that selects model outputs for further processing
              (optional, default: None, i.e. no filtering)
          shuffle (bool): if True, the images of each data loader batch get shuffled
          batch_size (int): batch size for the produced annotation data
          inference_batch_size (int): number of images passed to the model at once
          drop_last (bool): if True, drop the last batch if it is undersized
          category_to_class_mapping (dict): category to class mapping;
            categories missing from it map to class 0
        """
        self.model = model
        self.model.eval()
        self.data_loader = data_loader
        self.data_sampler = data_sampler
        self.data_filter = data_filter
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.inference_batch_size = inference_batch_size
        self.drop_last = drop_last
        self.category_to_class_mapping = (
            category_to_class_mapping if category_to_class_mapping is not None else {}
        )

    def __iter__(self) -> Iterator[List[SampledData]]:
        for batch in self.data_loader:
            # Each element of `batch` is a dict whose "images" and "categories"
            # are parallel sequences; flatten them into one dict per image.
            # (Images are presumably Tensor[C, H, W]; _infer_and_sample asserts 3 dims.)
            images_and_categories = [
                {"image": image, "category": category}
                for element in batch
                for image, category in zip(element["images"], element["categories"])
            ]
            if not images_and_categories:
                continue
            if self.shuffle:
                random.shuffle(images_and_categories)
            yield from self._produce_data(images_and_categories)

    def _produce_data(
        self, images_and_categories: List[Dict[str, Any]]
    ) -> Iterator[List[SampledData]]:
        """
        Produce batches of data from images

        Args:
          images_and_categories (List[Dict[str, Any]]): list of dicts with
            "image" and "category" keys to process

        Returns:
          Iterator over batches of data sampled from model outputs; every
          batch has exactly `batch_size` entries except possibly the last one
          (which is dropped if `drop_last` is set)
        """
        # Lazily run inference chunk by chunk, so batches are emitted as soon
        # as enough samples have been produced.
        samples = (
            sample
            for chunk in _grouper(images_and_categories, self.inference_batch_size)
            for sample in self._infer_and_sample(chunk)
        )
        yield from self._rebatch(samples)

    def _infer_and_sample(
        self, chunk: Iterable[Optional[Dict[str, Any]]]
    ) -> List[SampledData]:
        """
        Run the model on one chunk of images and return the non-empty samples.

        Args:
          chunk: dicts with "image" and "category" keys; `None` entries
            (padding added by `_grouper`) are skipped

        Returns:
          Sampled (or filtered raw) outputs whose "instances" are non-empty
        """
        # NOTE(review): relies on the model exposing a `device` attribute,
        # which a plain torch.nn.Module does not define.
        batch = [
            {
                "image": item["image"].to(self.model.device),
                "category": item["category"],
            }
            for item in chunk
            if item is not None
        ]
        if not batch:
            return []
        # Gradients are disabled only for the forward pass itself.
        with torch.no_grad():
            model_output = self.model(batch)
        category_to_class_mapping = self.category_to_class_mapping
        for model_output_i, batch_i in zip(model_output, batch):
            assert len(batch_i["image"].shape) == 3
            model_output_i["image"] = batch_i["image"]
            # Tag every predicted instance with the class of its source category.
            instance_class = category_to_class_mapping.get(batch_i["category"], 0)
            model_output_i["instances"].dataset_classes = torch.tensor(
                [instance_class] * len(model_output_i["instances"])
            )
        model_output_filtered = (
            model_output if self.data_filter is None else self.data_filter(model_output)
        )
        data = (
            model_output_filtered
            if self.data_sampler is None
            else self.data_sampler(model_output_filtered)
        )
        return [data_i for data_i in data if len(data_i["instances"])]

    def _rebatch(self, samples: Iterable[SampledData]) -> Iterator[List[SampledData]]:
        """
        Group a stream of samples into lists of `batch_size` elements.

        A batch is emitted as soon as it is full, so no batch ever exceeds
        `batch_size`. The trailing partial batch is yielded unless `drop_last`
        is set.
        """
        data_batch: List[SampledData] = []
        for sample in samples:
            data_batch.append(sample)
            if len(data_batch) >= self.batch_size:
                yield data_batch
                data_batch = []
        if data_batch and not self.drop_last:
            yield data_batch