Spaces:
Running
Running
import math | |
import os | |
import pytest | |
import torch | |
from mmcv import Config | |
from risk_biased.mpc_planner.dynamics import PositionVelocityDoubleIntegrator | |
from risk_biased.mpc_planner.planner_cost import TrackingCost, TrackingCostParams | |
from risk_biased.utils.cost import TTCCostTorch, TTCCostParams | |
from risk_biased.utils.risk import get_risk_estimator | |
from risk_biased.utils.planner_utils import ( | |
to_state, | |
get_interaction_cost, | |
evaluate_risk, | |
evaluate_control_sequence, | |
) | |
def params():
    """Build the configuration shared by the planner-utils tests.

    Seeds torch's global RNG with 0, loads ``learning_config.py``, merges
    ``waymo_config.py`` into it with ``Config.update``, then overrides a set of
    fields with small values so the tests run quickly.

    NOTE(review): ``params`` is requested as an argument by the methods of
    ``TestPlannerUtils``. Pytest only supplies that for a registered fixture,
    but no ``@pytest.fixture`` decorator is visible on this function. Confirm
    it was not lost.

    Returns:
        The merged ``mmcv.Config`` with the test overrides applied.
    """
    torch.manual_seed(0)
    working_dir = os.path.dirname(os.path.realpath(__file__))
    config_dir = os.path.join(working_dir, "..", "..", "..", "risk_biased", "config")
    config_paths = [
        os.path.join(config_dir, "learning_config.py"),
        os.path.join(config_dir, "waymo_config.py"),
    ]
    # The first file is the base; each following file is merged into it in order.
    cfg = Config.fromfile(config_paths[0])
    for path in config_paths[1:]:
        cfg.update(Config.fromfile(path))

    # Small sizes keep the tests fast.
    cfg.num_control_samples = 10
    cfg.dt = 0.1
    cfg.num_steps = 3
    cfg.num_steps_future = 5
    cfg.state_dim = 5
    # Tracking-cost parameters.
    cfg.tracking_cost_scale_longitudinal = 0.1
    cfg.tracking_cost_scale_lateral = 1.0
    cfg.tracking_cost_reduce = "mean"
    # Interaction (TTC) cost parameters.
    cfg.cost_scale = 10
    cfg.cost_reduce = "mean"
    cfg.distance_bandwidth = 2
    cfg.time_bandwidth = 0.5
    cfg.min_velocity_diff = 0.01
    # Attribute assignment replaces any risk_estimator entry from the files.
    cfg.risk_estimator = {"type": "cvar", "eps": 1e-3}
    return cfg
class TestPlannerUtils:
    """Unit tests for the helpers imported from ``risk_biased.utils.planner_utils``.

    NOTE(review): ``setup`` and the ``test_*`` methods take several arguments
    that pytest can only supply through ``@pytest.fixture`` /
    ``@pytest.mark.parametrize``:

    - ``params``, ``ndim``, ``sequence_size``
    - ``with_ado_batch_dim``, ``num_prediction_samples``, ``num_agents``
    - ``risk_level``

    No such decorators are visible in this file. Confirm they were not lost.
    """

    @staticmethod
    def _prepend_unit_dims(tensor: torch.Tensor, ndim: int) -> torch.Tensor:
        """Prepend size-1 dimensions to ``tensor`` until it has ``ndim`` dims.

        A tensor that already has at least ``ndim`` dimensions is returned
        unchanged.
        """
        while tensor.ndim < ndim:
            tensor = tensor.unsqueeze(0)
        return tensor

    def setup(self, params):
        """Create the dynamics model, cost functions and risk estimator under test.

        Args:
            params: test configuration (as built by ``params`` in this module).
        """
        # Use the configured time step instead of repeating the literal 0.1
        # (equal to params.dt here), so the two cannot drift apart.
        # Assumes the constructor argument is the time step; TODO confirm.
        self.dynamics_model = PositionVelocityDoubleIntegrator(params.dt)
        self.interaction_cost_function = TTCCostTorch(TTCCostParams.from_config(params))
        self.tracking_cost_function = TrackingCost(
            TrackingCostParams.from_config(params)
        )
        self.risk_estimator = get_risk_estimator(params.risk_estimator)
        self.dt = params.dt

    def test_translate_position(self, ndim: int, sequence_size: int):
        """Translating a state shifts its positions by the translation vector.

        First checks a 2-component (position-only) state directly. Then checks
        that a 4-component state translated by the same vector agrees with it
        when both are read back through ``get_states(5)``.
        """
        translation_m = torch.tensor([1.0, 2.0])

        state_pos = to_state(
            self._prepend_unit_dims(torch.zeros(sequence_size, 2), ndim), self.dt
        )
        translated_state_pos = state_pos.translate(translation_m)
        assert torch.allclose(
            translated_state_pos.get_states(), state_pos.get_states() + translation_m
        )

        state_double_integrator = to_state(
            self._prepend_unit_dims(torch.zeros(sequence_size, 4), ndim), self.dt
        )
        translated_state_double_integrator = state_double_integrator.translate(
            translation_m
        )
        assert torch.allclose(
            translated_state_double_integrator.get_states(5),
            translated_state_pos.get_states(5),
        )

    def test_rotate_angle(self, ndim: int, sequence_size: int):
        """Rotating a state by pi/2 gives the expected components.

        A position-only state (1, 1) must become (-1, 1). A 4-component state
        (1, 1, -1, 1) must become (-1, 1, -1, -1), and its ``get_states(2)``
        view must match the rotated position-only state.
        """
        rotation_rad = torch.tensor([math.pi / 2])

        state_pos = to_state(
            self._prepend_unit_dims(torch.ones(sequence_size, 2), ndim), self.dt
        )
        rotated_state_pos = state_pos.rotate(rotation_rad)
        assert torch.allclose(
            rotated_state_pos.get_states(),
            torch.tensor([-1.0, 1.0]).expand_as(state_pos.get_states()),
        )

        state_double_integrator = to_state(
            self._prepend_unit_dims(
                torch.tensor([[1.0, 1.0, -1.0, 1.0]]).repeat(sequence_size, 1), ndim
            ),
            self.dt,
        )
        rotated_state_double_integrator = state_double_integrator.rotate(rotation_rad)
        assert torch.allclose(
            rotated_state_double_integrator.get_states(2),
            rotated_state_pos.get_states(2),
        )
        assert torch.allclose(
            rotated_state_double_integrator.get_states(4),
            torch.tensor([-1.0, 1.0, -1.0, -1.0]).expand_as(
                rotated_state_pos.get_states(4)
            ),
        )

    def test_get_interaction_cost(
        self, params, with_ado_batch_dim, num_prediction_samples, num_agents
    ):
        """Check the shape of the ``get_interaction_cost`` result.

        The expected shape is
        (num_control_samples, num_agents, num_prediction_samples).

        NOTE(review): ``with_ado_batch_dim`` does not affect the inputs. The
        ``if not with_ado_batch_dim`` / ``else`` branches built identical ado
        tensors and have been collapsed into one. TODO: add the intended
        batched-ado layout.
        """
        ego_state_future = to_state(
            torch.randn(
                params.num_control_samples, 1, params.num_steps_future, params.state_dim
            ),
            params.dt,
        )
        ado_position_future_samples = to_state(
            torch.randn(
                num_prediction_samples,
                num_agents,
                params.num_steps_future,
                params.state_dim,
            ),
            params.dt,
        )
        cost = get_interaction_cost(
            ego_state_future,
            ado_position_future_samples,
            self.interaction_cost_function,
        )
        assert cost.shape == torch.Size(
            [params.num_control_samples, num_agents, num_prediction_samples]
        )

    def test_evaluate_risk(
        self, params, num_prediction_samples, num_agents, risk_level
    ):
        """Check the shape of ``evaluate_risk`` and its zero-risk value.

        The result must have shape (num_control_samples, num_agents). When
        ``risk_level`` is ``None`` or 0, it must also equal the mean of ``cost``
        over the prediction-sample axis.
        """
        cost = torch.rand(
            params.num_control_samples, num_agents, num_prediction_samples
        )
        weights = (
            torch.rand(params.num_control_samples, num_agents, num_prediction_samples)
            / num_prediction_samples
        )
        risk = evaluate_risk(risk_level, cost, weights, self.risk_estimator)
        if risk_level is None or risk_level == 0.0:
            assert torch.allclose(risk, cost.mean(dim=2))
        assert risk.shape == torch.Size([params.num_control_samples, num_agents])

    def test_evaluate_control_sequence(self, params, risk_level):
        """Check that ``evaluate_control_sequence`` returns positive values.

        Uses random inputs for one control sequence, one ado agent and five
        prediction samples. Both the interaction risk and the tracking cost
        must be positive.
        """
        num_prediction_samples = 5
        num_agents = 1
        control_sequence = torch.randn(
            1, params.num_steps_future, self.dynamics_model.control_dim
        )
        ego_state_history = to_state(
            torch.randn(1, params.num_steps, params.state_dim), self.dt
        )
        ego_state_target_trajectory = to_state(
            torch.randn(1, params.num_steps_future, params.state_dim), self.dt
        )
        # The ado samples carry only 2 components per step.
        ado_state_future_samples = to_state(
            torch.randn(num_prediction_samples, num_agents, params.num_steps_future, 2),
            self.dt,
        )
        weights = (
            torch.rand(num_prediction_samples, num_agents) / num_prediction_samples
        )
        interaction_risk, tracking_cost = evaluate_control_sequence(
            control_sequence,
            self.dynamics_model,
            ego_state_history,
            ego_state_target_trajectory,
            ado_state_future_samples,
            weights,
            self.interaction_cost_function,
            self.tracking_cost_function,
            risk_level,
            self.risk_estimator,
        )
        assert interaction_risk > 0.0
        assert tracking_cost > 0.0