Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
kharitonov-ivan committed Oct 31, 2023
1 parent 809e429 commit 49a1282
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 172 deletions.
15 changes: 8 additions & 7 deletions src/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,13 +255,14 @@ def run_tracker(
tracker_motion_model = env_motion_model if tracker_motion_model is None else tracker_motion_model
tracker_sensor_model = env_sensor_model if tracker_sensor_model is None else tracker_sensor_model

tracker_model = tracker(meas_model=tracker_meas_model, sensor_model=tracker_sensor_model, motion_model=tracker_motion_model, gating_size=tracker_P_G, **tracker_params)

filepath = get_images_dir(__file__) + "/" + tracker_model.__class__.__name__ + f"-P_D={env_P_D}-lambda_c={env_lambda_c}"

tracker_estimations = track(object_data, meas_data, tracker_model)
filepath = get_images_dir(__file__) + "/" + f"-P_D={env_P_D}-lambda_c={env_lambda_c}"
tracker_estimations = None
if tracker is not None:
tracker_model = tracker(meas_model=tracker_meas_model, sensor_model=tracker_sensor_model, motion_model=tracker_motion_model, gating_size=tracker_P_G, **tracker_params)
tracker_estimations = track(object_data, meas_data, tracker_model)
visulaize(object_data, meas_data, tracker_estimations, filepath)
gospa = get_gospa(object_data, tracker_estimations) # noqa
motmetrics = get_motmetrics(object_data, tracker_estimations) # noqa F841
if tracker_estimations is not None:
gospa = get_gospa(object_data, tracker_estimations) # noqa
motmetrics = get_motmetrics(object_data, tracker_estimations) # noqa F841
if os.getenv("ANIMATE", "False") == "True":
animate(object_data, meas_data, tracker_estimations, filepath)
100 changes: 0 additions & 100 deletions src/scenarios/scenario_configs.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/MOT/n_mot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from src.scenarios.initial_conditions import all_object_scenarios
from src.trackers.n_object_trackers import GlobalNearestNeighboursTracker

from ..test_trackers import ( # noqa
from ..testing_trackers import ( # noqa
env_clutter_rate,
env_detection_probability,
env_measurement_model,
Expand Down
2 changes: 1 addition & 1 deletion tests/PHD/phd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from src.scenarios.initial_conditions import all_object_scenarios
from src.trackers.multiple_object_trackers.PHD import GMPHD

from ..test_trackers import ( # noqa
from ..testing_trackers import ( # noqa
env_clutter_rate,
env_detection_probability,
env_measurement_model,
Expand Down
2 changes: 1 addition & 1 deletion tests/PMBM/pmbm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
)
from src.trackers.multiple_object_trackers.PMBM.pmbm import PMBM

from ..test_trackers import ( # noqa
from ..testing_trackers import ( # noqa
env_clutter_rate,
env_detection_probability,
env_measurement_model,
Expand Down
2 changes: 1 addition & 1 deletion tests/SOT/sot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from src.utils.get_path import delete_images_dir

from ..test_trackers import ( # noqa
from ..testing_trackers import ( # noqa
env_clutter_rate,
env_detection_probability,
env_measurement_model,
Expand Down
81 changes: 20 additions & 61 deletions tests/motion_modeling/object_data_generator_test.py
Original file line number Diff line number Diff line change
@@ -1,76 +1,35 @@
import os
from dataclasses import asdict

import pytest

from src.configs import GroundTruthConfig, SensorModelConfig
from src.measurement_models import (
ConstantVelocityMeasurementModel,
RangeBearingMeasurementModel,
)
from src.motion_models import ConstantVelocityMotionModel, CoordinateTurnMotionModel
from src.run import animate, visulaize
from src.scenarios.scenario_configs import (
linear_n_mot,
linear_sot,
nonlinear_n_mot,
nonlinear_sot,
)
from src.simulator import MeasurementData
from src.simulator.object_data_generator import ObjectData
from src.utils.get_path import delete_images_dir, get_images_dir
from src.measurement_models import ConstantVelocityMeasurementModel
from src.motion_models import ConstantVelocityMotionModel
from src.scenarios.initial_conditions import all_object_scenarios


dir_path = os.path.dirname(os.path.realpath(__file__))

test_data = [
(
linear_sot,
ConstantVelocityMotionModel,
"linear_sot_CV.png",
ConstantVelocityMeasurementModel,
),
(
nonlinear_sot,
CoordinateTurnMotionModel,
"nonlinear_sot_CT.png",
RangeBearingMeasurementModel,
),
(
linear_n_mot,
ConstantVelocityMotionModel,
"linear_n_mot_CV.png",
ConstantVelocityMeasurementModel,
),
(
nonlinear_n_mot,
CoordinateTurnMotionModel,
"nonlinear_n_mot_CT.png",
RangeBearingMeasurementModel,
),
]

@pytest.fixture(params=[x for x in all_object_scenarios if x.motion_model == ConstantVelocityMotionModel])
def object_motion_fixture(request):
yield request.param

@pytest.fixture(scope="session", autouse=True)
def do_something_before_all_tests():
# prepare something ahead of all tests
delete_images_dir(__file__)

@pytest.fixture(params=[ConstantVelocityMeasurementModel(sigma_r=10.0)])
def env_measurement_model(request):
yield request.param

@pytest.mark.parametrize("config, motion_model, output_image_name, meas_model", test_data)
def test_linear_model_without_noise(config, motion_model, output_image_name, meas_model):
config = asdict(config)
ground_truth = GroundTruthConfig(config["object_configs"], config["total_time"])
motion_model = motion_model(**config)

object_data = ObjectData(ground_truth_config=ground_truth, motion_model=motion_model, if_noisy=False)
@pytest.fixture
def env_clutter_rate(request):
yield 0.0

sensor_model = SensorModelConfig(**config)

meas_model = meas_model(**config)
meas_data_gen = MeasurementData(object_data=object_data, sensor_model=sensor_model, meas_model=meas_model)
meas_data = [next(meas_data_gen) for _ in range(len(object_data))]
output_image_name = get_images_dir(__file__) + "/" + output_image_name
visulaize(object_data, meas_data, None, output_image_name)
if os.getenv("ANIMATE", "False") == "True":
animate(object_data, meas_data, None, output_image_name)
@pytest.fixture
def env_detection_probability(request):
yield 1.0


@pytest.fixture
def tracker():
yield None, None
File renamed without changes.

0 comments on commit 49a1282

Please sign in to comment.