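# File: DeXFit-TryOn / densepose/data/inference_based_loader.py
# Web file-viewer header captured together with the source (not part of the module):
#   avatar caption: "zhengchong's picture"
#   commit 6eb1d7d: "chore: Update dependencies and code structure"
#   viewer links: raw, history, blame
#   file size: 6.8 kB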
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
import random
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
import torch
from torch import nn
# Alias used in annotations for the entries returned by a data sampler;
# left as `Any`, so no structure is enforced by the type checker.
SampledData = Any
# Alias used in annotations for the model's inference output; also `Any`.
# The code in this module iterates over it and reads an "instances" key
# from each entry.
ModelOutput = Any
def _grouper(iterable: Iterable[Any], n: int, fillvalue=None) -> Iterator[Tuple[Any]]:
"""
Group elements of an iterable by chunks of size `n`, e.g.
grouper(range(9), 4) ->
(0, 1, 2, 3), (4, 5, 6, 7), (8, None, None, None)
"""
it = iter(iterable)
while True:
values = []
for _ in range(n):
try:
value = next(it)
except StopIteration:
if values:
values.extend([fillvalue] * (n - len(values)))
yield tuple(values)
return
values.append(value)
yield tuple(values)
class ScoreBasedFilter:
    """
    Score-threshold filter for model outputs.

    For every entry of the model output, keeps only those instances whose
    score is at least `min_score`. Entries whose instances carry no "scores"
    field are left untouched.
    """

    def __init__(self, min_score: float = 0.8):
        """
        Args:
            min_score (float): lowest score an instance may have to be kept
        """
        self.min_score = min_score

    def __call__(self, model_output: ModelOutput) -> ModelOutput:
        """
        Apply the threshold to each entry of `model_output` in place.

        Args:
            model_output (ModelOutput): iterable of entries, each providing
                an "instances" item
        Returns:
            The same `model_output` object, with the "instances" item of
            scored entries replaced by the filtered selection
        """
        for entry in model_output:
            detections = entry["instances"]
            if detections.has("scores"):
                keep_mask = detections.scores >= self.min_score
                entry["instances"] = detections[keep_mask]
        return model_output
class InferenceBasedLoader:
    """
    Data loader based on results inferred by a model. Consists of:
     - a data loader that provides batches of images
     - a model that is used to infer the results
     - a data sampler that converts inferred results to annotations

    Iterating over the loader runs the model (under `torch.no_grad()`) on the
    images supplied by `data_loader` and yields lists of resulting entries,
    keeping only entries whose "instances" item is non-empty.
    """

    def __init__(
        self,
        model: nn.Module,
        data_loader: Iterable[List[Dict[str, Any]]],
        data_sampler: Optional[Callable[[ModelOutput], List[SampledData]]] = None,
        data_filter: Optional[Callable[[ModelOutput], ModelOutput]] = None,
        shuffle: bool = True,
        batch_size: int = 4,
        inference_batch_size: int = 4,
        drop_last: bool = False,
        category_to_class_mapping: Optional[dict] = None,
    ):
        """
        Constructor

        Args:
          model (torch.nn.Module): model used to produce data
          data_loader (Iterable[List[Dict[str, Any]]]): iterable that provides
            dictionaries with "images" and "categories" fields to perform inference on
          data_sampler (Callable: ModelOutput -> SampledData): functor
              that produces annotation data from inference results;
              (optional, default: None)
          data_filter (Callable: ModelOutput -> ModelOutput): filter
              that selects model outputs for further processing
              (optional, default: None)
          shuffle (bool): if True, the input images get shuffled
          batch_size (int): batch size for the produced annotation data
          inference_batch_size (int): batch size for input images
          drop_last (bool): if True, drop the last batch if it is undersized
          category_to_class_mapping (dict): category to class mapping
        Raises:
          ValueError: if `batch_size` or `inference_batch_size` is smaller
              than 1 (such values could never produce a batch and would make
              iteration spin forever)
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        if inference_batch_size < 1:
            raise ValueError(
                f"inference_batch_size must be a positive integer, got {inference_batch_size}"
            )
        self.model = model
        # The model is only used for inference, so put it into evaluation mode
        self.model.eval()
        self.data_loader = data_loader
        self.data_sampler = data_sampler
        self.data_filter = data_filter
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.inference_batch_size = inference_batch_size
        self.drop_last = drop_last
        # Never share a mutable default between instances
        self.category_to_class_mapping = (
            category_to_class_mapping if category_to_class_mapping is not None else {}
        )

    def __iter__(self) -> Iterator[List[SampledData]]:
        """
        Run inference over everything `data_loader` provides.

        Returns:
            Iterator over batches (lists) of data sampled from model outputs.
            Leftovers are flushed per `data_loader` batch, so undersized
            batches can occur between loader batches unless `drop_last` is set.
        """
        for batch in self.data_loader:
            # batch : List[Dict[str: Tensor[N, C, H, W], str: Optional[str]]]
            # images_batch : Tensor[N, C, H, W]
            # image : Tensor[C, H, W]
            images_and_categories = [
                {"image": image, "category": category}
                for element in batch
                for image, category in zip(element["images"], element["categories"])
            ]
            if not images_and_categories:
                continue
            if self.shuffle:
                random.shuffle(images_and_categories)
            yield from self._produce_data(images_and_categories)

    def _produce_data(
        self, images_and_categories: List[Dict[str, Any]]
    ) -> Iterator[List[SampledData]]:
        """
        Produce batches of data from images

        Args:
            images_and_categories (List[Dict[str, Any]]):
                list of dictionaries with an "image" tensor and the
                corresponding "category" to process
        Returns:
            Iterator over batches of data sampled from model outputs; every
            batch holds exactly `batch_size` entries, except possibly the
            last one, which is emitted only if `drop_last` is False
        """
        data_batches: List[SampledData] = []
        category_to_class_mapping = self.category_to_class_mapping
        batched_images_and_categories = _grouper(images_and_categories, self.inference_batch_size)
        for batch in batched_images_and_categories:
            # The last chunk from `_grouper` is padded with None; skip the padding.
            # NOTE(review): relies on the model exposing a `device` attribute,
            # which plain `nn.Module` does not define - confirm for new models.
            batch = [
                {
                    "image": image_and_category["image"].to(self.model.device),
                    "category": image_and_category["category"],
                }
                for image_and_category in batch
                if image_and_category is not None
            ]
            if not batch:
                continue
            with torch.no_grad():
                model_output = self.model(batch)
            for model_output_i, batch_i in zip(model_output, batch):
                assert len(batch_i["image"].shape) == 3
                model_output_i["image"] = batch_i["image"]
                # Categories missing from the mapping fall back to class 0
                instance_class = category_to_class_mapping.get(batch_i["category"], 0)
                model_output_i["instances"].dataset_classes = torch.tensor(
                    [instance_class] * len(model_output_i["instances"])
                )
            model_output_filtered = (
                model_output if self.data_filter is None else self.data_filter(model_output)
            )
            data = (
                model_output_filtered
                if self.data_sampler is None
                else self.data_sampler(model_output_filtered)
            )
            # Entries without any instances carry no annotation data
            data_batches.extend(data_i for data_i in data if len(data_i["instances"]))
            # Drain every complete batch: one inference step may contribute
            # more than `batch_size` entries, and emitting a single batch per
            # step would let the backlog grow into an oversized final batch.
            while len(data_batches) >= self.batch_size:
                yield data_batches[: self.batch_size]
                data_batches = data_batches[self.batch_size :]
        if not self.drop_last and data_batches:
            yield data_batches