Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
# pyre-unsafe | |
from dataclasses import dataclass | |
from typing import Any, Callable, Dict, List, Optional | |
from detectron2.structures import Instances | |
# Type alias for one model output record: a string-keyed dictionary.
# The sampler's `__call__` reads its "instances" entry.
ModelOutput = Dict[str, Any] | |
# Type alias for one sampled (ground-truth-like) record: a string-keyed
# dictionary. The sampler's `__call__` returns a list of these.
SampledData = Dict[str, Any] | |
class _Sampler: | |
""" | |
Sampler registry entry that contains: | |
- src (str): source field to sample from (deleted after sampling) | |
- dst (Optional[str]): destination field to sample to, if not None | |
- func (Optional[Callable: Any -> Any]): function that performs sampling, | |
if None, reference copy is performed | |
""" | |
src: str | |
dst: Optional[str] | |
func: Optional[Callable[[Any], Any]] | |
class PredictionToGroundTruthSampler:
    """
    Converts model predictions into ground truth style data by running a
    registry of per-field samplers over the `Instances` of every output.

    Samplers are keyed by the `(prediction field, GT field)` pair, so
    registering the same pair twice replaces the earlier entry.
    """

    def __init__(self, dataset_name: str = ""):
        """
        Create a sampler with the default field mappings.

        Args:
            dataset_name (str): value stored under the "dataset" key of every
                processed output
        """
        self.dataset_name = dataset_name
        self._samplers = {}
        # boxes and classes are carried over to their GT fields by reference
        for pred_field, gt_field in (
            ("pred_boxes", "gt_boxes"),
            ("pred_classes", "gt_classes"),
        ):
            self.register_sampler(pred_field, gt_field, None)
        # scores have no GT counterpart: registering them without a
        # destination means they are only deleted
        self.register_sampler("scores")

    def __call__(self, model_output: List[ModelOutput]) -> List[SampledData]:
        """
        Transform model output into ground truth data through sampling.

        The dictionaries and their `Instances` are modified in place.

        Args:
            model_output (List[Dict[str, Any]]): model outputs, each holding
                an `Instances` object under the "instances" key
        Returns:
            List[Dict[str, Any]]: the same list, with sampled GT fields set,
                transformed prediction fields removed and the "dataset" key
                set to the dataset name
        """
        for output in model_output:
            instances: Instances = output["instances"]

            # Pass 1: populate destination fields. Nothing is deleted yet, so
            # every sampler still sees all of the prediction fields.
            for sampler in self._samplers.values():
                if instances.has(sampler.src) and sampler.dst is not None:
                    if sampler.func is None:
                        # no sampling function: plain reference copy
                        sampled = instances.get(sampler.src)
                    else:
                        # the function receives the whole `Instances` object
                        sampled = sampler.func(instances)
                    instances.set(sampler.dst, sampled)

            # Pass 2: drop the prediction fields that were consumed, keeping
            # any field that was sampled onto itself.
            for sampler in self._samplers.values():
                if sampler.src == sampler.dst:
                    continue
                if instances.has(sampler.src):
                    instances.remove(sampler.src)

            output["dataset"] = self.dataset_name
        return model_output

    def register_sampler(
        self,
        prediction_attr: str,
        gt_attr: Optional[str] = None,
        func: Optional[Callable[[Any], Any]] = None,
    ):
        """
        Register sampler for a field.

        Args:
            prediction_attr (str): field to replace with a sampled value
            gt_attr (Optional[str]): field to store the sampled value to;
                if None, the prediction field is only deleted
            func (Optional[Callable: Any -> Any]): sampler function;
                if None, the value is copied by reference
        """
        key = (prediction_attr, gt_attr)
        self._samplers[key] = _Sampler(src=prediction_attr, dst=gt_attr, func=func)

    def remove_sampler(
        self,
        prediction_attr: str,
        gt_attr: Optional[str] = None,
    ):
        """
        Remove sampler for a field.

        The pair must match the one used at registration time.

        Args:
            prediction_attr (str): field to replace with a sampled value
            gt_attr (Optional[str]): field to store the sampled value to, if not None
        """
        key = (prediction_attr, gt_attr)
        assert key in self._samplers
        del self._samplers[key]