jmercat's picture
Removed history to avoid any unverified information being released
5769ee4
import math
import os
import pytest
import torch
from mmcv import Config
from risk_biased.mpc_planner.dynamics import PositionVelocityDoubleIntegrator
from risk_biased.mpc_planner.planner_cost import TrackingCost, TrackingCostParams
from risk_biased.utils.cost import TTCCostTorch, TTCCostParams
from risk_biased.utils.risk import get_risk_estimator
from risk_biased.utils.planner_utils import (
to_state,
get_interaction_cost,
evaluate_risk,
evaluate_control_sequence,
)
@pytest.fixture(scope="module")
def params():
torch.manual_seed(0)
working_dir = os.path.dirname(os.path.realpath(__file__))
config_path = os.path.join(
working_dir, "..", "..", "..", "risk_biased", "config", "learning_config.py"
)
waymo_config_path = os.path.join(
working_dir, "..", "..", "..", "risk_biased", "config", "waymo_config.py"
)
paths = [config_path, waymo_config_path]
if isinstance(paths, str):
cfg = Config.fromfile(paths)
else:
cfg = Config.fromfile(paths[0])
for path in paths[1:]:
c = Config.fromfile(path)
cfg.update(c)
cfg.num_control_samples = 10
cfg.dt = 0.1
cfg.num_steps = 3
cfg.num_steps_future = 5
cfg.state_dim = 5
cfg.tracking_cost_scale_longitudinal = 0.1
cfg.tracking_cost_scale_lateral = 1.0
cfg.tracking_cost_reduce = "mean"
cfg.cost_scale = 10
cfg.cost_reduce = "mean"
cfg.distance_bandwidth = 2
cfg.time_bandwidth = 0.5
cfg.min_velocity_diff = 0.01
cfg.risk_estimator = {"type": "cvar", "eps": 1e-3}
return cfg
class TestPlannerUtils:
@pytest.fixture(autouse=True)
def setup(self, params):
self.dynamics_model = PositionVelocityDoubleIntegrator(0.1)
self.interaction_cost_function = TTCCostTorch(TTCCostParams.from_config(params))
self.tracking_cost_function = TrackingCost(
TrackingCostParams.from_config(params)
)
self.risk_estimator = get_risk_estimator(params.risk_estimator)
self.dt = params.dt
@pytest.mark.parametrize("ndim, sequence_size", [(2, 2), (3, 4), (4, 3)])
def test_translate_position(self, ndim: int, sequence_size: int):
translation_m = torch.Tensor([1.0, 2.0])
state_pos = torch.zeros(sequence_size, 2)
while state_pos.ndim < ndim:
state_pos = state_pos.unsqueeze(0)
state_pos = to_state(state_pos, self.dt)
translated_state_pos = state_pos.translate(translation_m)
assert torch.allclose(
translated_state_pos.get_states(), state_pos.get_states() + translation_m
)
state_double_integrator = torch.zeros(sequence_size, 4)
while state_double_integrator.ndim < ndim:
state_double_integrator = state_double_integrator.unsqueeze(0)
state_double_integrator = to_state(state_double_integrator, self.dt)
translated_state_double_integrator = state_double_integrator.translate(
translation_m
)
assert torch.allclose(
translated_state_double_integrator.get_states(5),
translated_state_pos.get_states(5),
)
@pytest.mark.parametrize("ndim, sequence_size", [(2, 3), (3, 5), (4, 2)])
def test_rotate_angle(self, ndim: int, sequence_size: int):
rotation_rad = torch.Tensor([math.pi / 2])
state_pos = torch.ones(sequence_size, 2)
while state_pos.ndim < ndim:
state_pos = state_pos.unsqueeze(0)
state_pos = to_state(state_pos, self.dt)
rotated_state_pos = state_pos.rotate(rotation_rad)
assert torch.allclose(
rotated_state_pos.get_states(),
torch.Tensor([-1.0, 1.0]).expand_as(state_pos.get_states()),
)
state_double_integrator = torch.Tensor([[1.0, 1.0, -1.0, 1.0]]).repeat(
sequence_size, 1
)
while state_double_integrator.ndim < ndim:
state_double_integrator = state_double_integrator.unsqueeze(0)
state_double_integrator = to_state(state_double_integrator, self.dt)
rotated_state_double_integrator = state_double_integrator.rotate(rotation_rad)
assert torch.allclose(
rotated_state_double_integrator.get_states(2),
rotated_state_pos.get_states(2),
)
assert torch.allclose(
rotated_state_double_integrator.get_states(4),
torch.Tensor([-1.0, 1.0, -1.0, -1.0]).expand_as(
rotated_state_pos.get_states(4)
),
)
@pytest.mark.parametrize(
"with_ado_batch_dim, num_prediction_samples, num_agents",
[(True, 1, 1), (False, 1, 2), (True, 5, 1), (False, 5, 2)],
)
def test_get_interaction_cost(
self, params, with_ado_batch_dim, num_prediction_samples, num_agents
):
ego_state_future = to_state(
torch.randn(
params.num_control_samples, 1, params.num_steps_future, params.state_dim
),
params.dt,
)
if not with_ado_batch_dim:
ado_position_future_samples = to_state(
torch.randn(
num_prediction_samples,
num_agents,
params.num_steps_future,
params.state_dim,
),
params.dt,
)
else:
ado_position_future_samples = to_state(
torch.randn(
num_prediction_samples,
num_agents,
params.num_steps_future,
params.state_dim,
),
params.dt,
)
cost = get_interaction_cost(
ego_state_future,
ado_position_future_samples,
self.interaction_cost_function,
)
assert cost.shape == torch.Size(
[params.num_control_samples, num_agents, num_prediction_samples]
)
@pytest.mark.parametrize(
"num_prediction_samples, num_agents, risk_level",
[(1, 1, 0.0), (5, 2, 0.0), (1, 1, 0.9), (5, 2, 0.9)],
)
def test_evaluate_risk(
self, params, num_prediction_samples, num_agents, risk_level
):
cost = torch.rand(
params.num_control_samples, num_agents, num_prediction_samples
)
weights = (
torch.rand(params.num_control_samples, num_agents, num_prediction_samples)
/ num_prediction_samples
)
risk = evaluate_risk(risk_level, cost, weights, self.risk_estimator)
if risk_level is None or risk_level == 0.0:
assert torch.allclose(risk, cost.mean(dim=2))
assert risk.shape == torch.Size([params.num_control_samples, num_agents])
@pytest.mark.parametrize("risk_level", [(0.0), (0.5)])
def test_evaluate_control_sequence(self, params, risk_level):
num_prediction_samples = 5
num_agents = 1
control_sequence = torch.randn(
1, params.num_steps_future, self.dynamics_model.control_dim
)
ego_state_history = to_state(
torch.randn(1, params.num_steps, params.state_dim), self.dt
)
ego_state_target_trajectory = to_state(
torch.randn(1, params.num_steps_future, params.state_dim), self.dt
)
ado_state_future_samples = to_state(
torch.randn(num_prediction_samples, num_agents, params.num_steps_future, 2),
params.dt,
)
weights = (
torch.rand(num_prediction_samples, num_agents) / num_prediction_samples
)
interaction_risk, tracking_cost = evaluate_control_sequence(
control_sequence,
self.dynamics_model,
ego_state_history,
ego_state_target_trajectory,
ado_state_future_samples,
weights,
self.interaction_cost_function,
self.tracking_cost_function,
risk_level,
self.risk_estimator,
)
assert interaction_risk > 0.0
assert tracking_cost > 0.0