
Commit 121992d

Add files via upload
1 parent d05efc8 commit 121992d


49 files changed: +15895 -0 lines changed

LICENSE

Lines changed: 673 additions & 0 deletions
Large diffs are not rendered by default.

README.md

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
# drone_causality

All training, data processing, and analysis code used for the paper "Robust Visual Flight Navigation with Liquid Neural Networks". For code run onboard the drone, see [this repository](https://github.com/GoldenZephyr/rosetta_drone).

## Installation Instructions

For x86-based systems (most computers), set up your Python environment using the conda environment file in config/environment.yml:

~~~
cd drone_causality
conda env create -f config/environment.yml
conda activate causality
~~~

Another environment file is available for ppc64le (PowerPC) based architectures:
~~~
conda env create -f config/satori_environment.yml
conda activate causality
~~~

Alternatively, a Docker image containing all required packages can be found on Docker Hub at dolphonie1/causal_repo:0.1.17.

~~~
docker pull dolphonie1/causal_repo:0.1.17
docker run -it --net=host dolphonie1/causal_repo:0.1.17 /bin/bash
~~~
## Downloading Datasets/Existing Checkpoints
The original hand-collected training dataset can be found [here](http://knightridermit.myqnapcloud.com:8080/share.cgi?ssid=06lMJMN&fid=06lMJMN&path=%2F&filename=devens_snowy_fixed.zip&openfolder=forcedownload&ep=) (filename: devens_snowy_fixed, size: 33.2GB). Additionally, we have a subset of the full `devens_snowy_fixed` dataset that only contains runs with the chair [here](http://knightridermit.myqnapcloud.com:8080/share.cgi?ssid=06lMJMN&fid=06lMJMN&path=%2F&filename=devens_chair.zip&openfolder=forcedownload&ep=) (devens_chair, 2.3GB).

We have also included the exact synthetic datasets we used for our experiments. These datasets were created using the script at `preprocess/closed_loop_augmentation.py`, but with a random seed. We have both a [full dataset](http://knightridermit.myqnapcloud.com:8080/share.cgi?ssid=06lMJMN&fid=06lMJMN&path=%2F&filename=synthetic_small4.zip&openfolder=forcedownload&ep=) (synthetic_small4, 14.7GB) used to train the starting checkpoint and a [chair-only dataset](http://knightridermit.myqnapcloud.com:8080/share.cgi?ssid=06lMJMN&fid=06lMJMN&path=%2F&filename=synthetic_chair4.zip&openfolder=forcedownload&ep=) (synthetic_chair, 4.3GB) used to fine-tune the final models for testing.

To replicate the results of our experiments, first train on the entire `devens_snowy_fixed` dataset together with the full synthetic dataset `synthetic_small4`, or use the checkpoints in `checkpoints/chair4_long_balanced`.

Afterwards, fine-tune models starting from `checkpoints/chair4_long_balanced` on the `devens_chair` dataset with the synthetic dataset `synthetic_chair4`.

All training was done using the best hyperparameters found in the `old_db` folder.
## Training Models
### Training Once
The script tf_data_training.py executes one training run. It loads data and models, sets up the multi-GPU processing strategy, and runs training while checkpointing models. The script's default hyperparameters are static and are _not_ the best hyperparameters found during parameter tuning, so all hyperparameters need to be specified manually.

Example usage:
~~~
python3 tf_data_training.py --model ncp --data_dir /path/to/devens_snowy_fixed --extra_data_dir /path/to/synthetic_small4 --epochs 100 --seq_len 64 --data_stride 1 --data_shift 16
~~~

### Training Multiple Times
The convenience script train_multiple.py automatically manages multiple training runs, saving JSON log files to record the results of each run and determining how many runs have already been completed so that training can be resumed. The script also automatically loads hyperparameters from the best study when given a hyperparameter study database file.

Example usage:
~~~
python train_multiple.py ncp_objective /path/to/devens_snowy_short --n_trains 5 --batch_size 300 --storage_name sqlite:///old_db/ncp_objective.db --storage_type rdb --timeout 72000 --extra_data_dir /path/to/synthetic --hotstart_dir /path/to/chair4_long_balanced --study_name hyperparam_tuning_ --out_dir chair4_fine_targets
~~~

The `storage_name` argument specifies the database file (in the `old_db` folder) that the best hyperparameters should be read from. Unfortunately, because training was conducted on different machines, different objectives have different hyperparameter files. For each type of network, use the following `storage_name`:

- LSTM: sqlite:///old_db/lstm_objective.db
- CFC: sqlite:///old_db/cfc_objective.db
- NCP: sqlite:///old_db/ncp_objective.db
- GRUODE: sqlite:///old_db/hyperparam_tuning.db
- TCN: old_db/tcn_objective.json
- Wiredcfccell (Sparse-CfC): sqlite:///old_db/wiredcfccell_objective.db
- LTC: sqlite:///old_db/hyperparam_tuning.db
- CT-RNN: old_db/ctrnn_objective.json

Note that the `storage_type` argument should be set to `rdb` for sqlite URLs, `json` for JSON files, and `pkl` for PKL files.

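For example, a first-stage training run on the full dataset (the workflow described under Downloading Datasets/Existing Checkpoints) might look like the sketch below. The dataset paths and output directory name are placeholders, `--hotstart_dir` is omitted on the assumption that it is optional when training from scratch, and additional arguments such as `--study_name` may be needed depending on how the study database was created:

~~~
python train_multiple.py ncp_objective /path/to/devens_snowy_fixed --n_trains 5 --batch_size 300 --storage_name sqlite:///old_db/ncp_objective.db --storage_type rdb --timeout 72000 --extra_data_dir /path/to/synthetic_small4 --out_dir full_dataset_models
~~~
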
## Preprocessing Data
This section describes the methodology used to generate the dataset `devens_snowy_fixed`.

If using new data collected on the drone, use the script `preprocess/process_data.py` to format it correctly for the training scripts. Runs should have the red channel as the 0th channel (i.e., images appear unflipped when opened in an image viewer).

The runs that don't have an underscore in their name (e.g., 1628106140.64) are the original long runs that see all 5 targets. The runs with underscores (e.g., 1628106140.64_1) are generated using the script `preprocess/sequence_slice/slice_sequence.py`, which provides a GUI for specifying start and end points and automatically copies the images and control CSV.

To generate new synthetic datasets, use the script `preprocess/closed_loop_augmentation.py`. The directory `preprocess/aug_json` contains JSON files that list the images to be augmented and the pixel location of the target within each image (generated by `preprocess/select_targets.py`; a sketch of an entry is shown after the example below).

Example usage:
The dataset `synthetic_small4` was generated with the following invocation:
~~~
python closed_loop_augmentation.py aug_json/synthetic_full_small.json /path/to/out/dir/synthetic_small4 --num_aug 5 --balance_classes --balance_offsets -10 -70 0 0
~~~

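Judging from `helper_scripts/intersect_aug_json.py` in this commit, each entry of an aug_json file appears to be an `[image_path, target_pixel_coordinates]` pair; the sketch below is hypothetical (the path, file name, and pixel values are made up):

~~~
[
    ["/path/to/devens_snowy_fixed/1628106140.64/000123.png", [320, 180]]
]
~~~
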
## Tuning Hyperparams
The Optuna hyperparameter study database files in the `old_db` directory were generated using the script `hyperparam_tuning.py`. This script is responsible for sampling parameters using Bayesian optimization, running training multiple times using the objective functions in `utils/objective_functions.py`, and logging the results within the Optuna study object.

Example usage:
~~~
python hyperparam_tuning.py ncp_objective /path/to/dataset --n_trials 40 --timeout 64800 --batch_size 300 --extra_data_dir /path/to/synthetic_dataset
~~~

## Analyzing Results

### Stress Tests
The stress test figures used in the paper were generated with the script `analysis/perturb_trajectory.py`.

Example usage:
~~~
python analysis/perturb_trajectory.py dataset_jsons/chair_short_raw.json checkpoints/chair4_fine/train/params.json contrast_perturbation --distance_fxn final_distance --deltas 0.5 1.5 2 2.5 --skip_models ctrnn_mixedcfc --perturb_frac 0.2 --force_even_x
~~~

This file (and most other analysis files) consumes a dataset_json file in the following format:
~~~
{
    "name_of_dataset" : [
        "/path/to/dataset",
        [boolean of whether to flip color channels],
        "path/to/control_csv" or null if no csv desired,
    ], ...
}
~~~

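As a concrete (hypothetical) illustration, a single entry could look like the sketch below; the run name and paths are placeholders rather than files shipped with the repository:

~~~
{
    "devens_chair_run": [
        "/path/to/devens_chair/1628106140.64_1",
        false,
        null
    ]
}
~~~
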
You will most likely have to edit the files in `dataset_jsons` to match the runs you want to analyze on your computer.
### Useful Files

- visualization_runner.py: Used for generating videos of visual backprop, input gradient, SHAP, or other visualization techniques overlaid on the original video sequence, along with a visualization of the controls
- analysis/vis_grid.py: Used for generating multiple images of visual backprop alongside the original camera images. Used in the paper
- analysis/lipschitz_constant.py: Calculates the Lipschitz constant of the RNN hidden state components when seeing a given sequence of inputs (measures the maximum difference in the RNN hidden state across 2 consecutive timesteps)
- analysis/loss_graph.py: Plots training loss curves
- analysis/ssim.py: Calculates the structural similarity index of saliency maps when random noise is added to the image


Contact patrick[dot]d[dot]kao[at]gmail[dot]com with any questions

helper_scripts/flip_channels.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
# Created by Patrick Kao at 5/3/22
import argparse
import os
from pathlib import Path

import cv2
import numpy as np

from keras_models import IMAGE_SHAPE
from utils.data_utils import load_image


def flip_channels(im_dir: str, out_dir: str):
    """Re-saves every image in im_dir to out_dir with its color channels reversed."""
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    for im_path in os.listdir(im_dir):
        img = load_image(os.path.join(im_dir, im_path), IMAGE_SHAPE, reverse_channels=False)  # writing flips channels
        cv2.imwrite(os.path.join(out_dir, im_path), np.squeeze(img, axis=0))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("im_dir")
    parser.add_argument("out_dir")
    args = parser.parse_args()
    flip_channels(args.im_dir, args.out_dir)

helper_scripts/flip_csv.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
import os
from typing import Sequence

import pandas as pd


def flip_csv(csv_file: str, columns: Sequence[str]):
    df = pd.read_csv(csv_file)
    for col in columns:
        df[col] = df[col].apply(lambda x: x * -1)

    df.to_csv(csv_file, index=False)


data = "/home/dolphonie/Desktop/mixed_aug_fixed"
for folder in os.listdir(data):
    flip_csv(os.path.join(data, folder, "data_out.csv"), ["vz", "omega_z"])
Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
import argparse
import json
import os.path
import re
import shutil
from collections import defaultdict
from json import JSONDecodeError
from pathlib import Path
from typing import List, Dict, Any

from utils.model_utils import get_readable_name

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))


def get_checkpoint_props(checkpoint_path: str) -> Dict[str, Any]:
    """
    Given name of checkpoint path, extracts relevant properties from string

    :param checkpoint_path: Path or basename of model checkpoint to be analyzed
    :return: Dict of checkpoint properties: val loss, train loss, and epoch
    """
    props = {}

    val_index = checkpoint_path.index("val")
    val_loss = float(checkpoint_path[val_index + 9:val_index + 15])
    props["val_loss"] = val_loss

    try:
        train_index = checkpoint_path.index("train")
        train_loss = float(checkpoint_path[train_index + 11:train_index + 17])
        props["train_loss"] = train_loss
    except ValueError:
        props["train_loss"] = 999

    epoch_index = checkpoint_path.index("epoch")
    epoch = int(checkpoint_path[epoch_index + 6:epoch_index + 9])
    props["epoch"] = epoch

    # get checkpoint time string
    time_search = re.compile(r".*(\d\d\d\d:\d\d:\d\d:\d\d:\d\d:\d\d).hdf5")
    time_str = time_search.search(checkpoint_path).group(1)
    props["checkpoint_time_str"] = time_str

    # get model name
    name_search = re.compile("model-(.*)_seq-.*")
    model_name = name_search.search(checkpoint_path).group(1)
    props["model_name"] = model_name

    return props


def get_best_checkpoint(candidate_jsons: List[Dict[str, Any]], checkpoint_dir: str, criteria_key: str = "val"):
    assert criteria_key == "val" or criteria_key == "train", "only val and train supported"
    # find the training run with the lowest loss for the given criteria
    best_props = None
    best_cand_value = float("inf")
    for candidate in candidate_jsons:
        cand_value = candidate[f"best_{criteria_key}_loss"]
        if cand_value < best_cand_value:
            best_cand_value = cand_value  # track the current best so later candidates are compared against it
            best_props = {
                f"{criteria_key}_loss": round(cand_value, 4),
                "epoch": candidate[f"best_{criteria_key}_epoch"] + 1,  # checkpoints epoch 1-indexed, jsons 0-indexed
                "model_name": get_readable_name(candidate["model_params"])
            }
            if "checkpoint_time_str" in candidate:
                best_props["checkpoint_time_str"] = candidate["checkpoint_time_str"]

    # locate the checkpoint file whose name matches the best run's properties
    for checkpoint in os.listdir(checkpoint_dir):
        if ".hdf5" not in checkpoint:
            continue
        props = get_checkpoint_props(checkpoint)
        if best_props.items() <= props.items():
            return os.path.join(checkpoint_dir, checkpoint)

    raise ValueError(f"No checkpoint matching props in json {best_props} found")


def read_json(path):
    with open(path, "r") as f:
        return json.load(f)


def process_json_list(json_dir: str, checkpoint_dir: str, out_dir: str):
    json_map = defaultdict(list)
    # separate jsons by class
    re_match = re.compile(r"(?:hyperparam_tuning_)?(.*)_\d_train_results.json")
    for file in os.listdir(json_dir):
        match = re_match.search(file)
        if match is not None:
            model_type = match.group(1)
            # read json data and save
            json_path = os.path.join(json_dir, file)
            try:
                parsed = read_json(json_path)
                json_map[model_type].append(parsed)
            except JSONDecodeError:
                print(f"Could not parse json at {json_path}, skipping")
                continue

    for candidate in ["val", "train"]:
        params_map = {}
        # for each class, get best checkpoint
        dest = os.path.join(out_dir, candidate)
        Path(dest).mkdir(exist_ok=True, parents=True)
        for model_type, json_data in json_map.items():
            checkpoint_path = get_best_checkpoint(candidate_jsons=json_data, checkpoint_dir=checkpoint_dir,
                                                  criteria_key=candidate)
            shutil.copy(checkpoint_path, dest)
            params_map[os.path.basename(checkpoint_path)] = json_data[0]["model_params"]

        with open(os.path.join(dest, "params.json"), "w") as f:
            json.dump(params_map, f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("json_dir", type=str)
    parser.add_argument("checkpoint_dir", type=str)
    parser.add_argument("--out_dir", type=str, default="out_models")
    args = parser.parse_args()
    process_json_list(args.json_dir, args.checkpoint_dir, args.out_dir)
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
# Created by Patrick Kao at 3/11/22
import argparse
import json
import os
import shutil
from pathlib import Path

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))


def get_matching_checkpoints(checkpoint_dir: str, filter_str: str, params_str: str,
                             out_dir: str = "matching_checkpoints"):
    """
    Finds all checkpoints matching filter_str and creates a params.json file for them with the params_str given by
    params_str
    :return:
    """

    model_params = {}
    out_dir = os.path.join(SCRIPT_DIR, out_dir)
    Path(out_dir).mkdir(exist_ok=True, parents=True)
    for checkpoint in sorted(os.listdir(checkpoint_dir)):
        if filter_str in checkpoint and ".hdf5" in checkpoint:
            model_params[checkpoint] = params_str
            shutil.copy(os.path.join(checkpoint_dir, checkpoint), out_dir)

    with open(os.path.join(out_dir, "params.json"), "w") as f:
        json.dump(model_params, f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint_dir", type=str)
    parser.add_argument("filter_str", type=str)
    parser.add_argument("params_str", type=str)
    args = parser.parse_args()
    get_matching_checkpoints(args.checkpoint_dir, args.filter_str, args.params_str)

helper_scripts/intersect_aug_json.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
# Created by Patrick Kao at 3/16/22
"""
Calculates the intersection of the img dirs of the images in a data processing json file and a data directory and
only saves the json with corresponding entries in the data directory
"""
import argparse
import json
import os.path
from typing import Any, List


def get_intersection_json(data_json: str, data_dir: str, out_path: str = "intersect.json") -> List[Any]:
    to_ret = []
    with open(data_json, "r") as f:
        synth_data = json.load(f)

    # keep only entries whose image directory also exists in data_dir
    for img_path, center_coords in synth_data:
        img_dir = os.path.basename(os.path.dirname(img_path))
        if os.path.exists(os.path.join(data_dir, img_dir)):
            to_ret.append([img_path, center_coords])

    with open(out_path, "w") as f:
        json.dump(to_ret, f)

    return to_ret


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("data_json", type=str)
    parser.add_argument("data_dir", type=str)
    args = parser.parse_args()
    get_intersection_json(args.data_json, args.data_dir)

helper_scripts/merge_model_dirs.py

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
# Created by Patrick Kao at 4/18/22
import argparse
import json
import os
import shutil
from pathlib import Path
from typing import Sequence


def merge_model_dirs(merge_dirs: Sequence[str], out_dir: str):
    for model_type in ["train", "val"]:
        out_json = {}
        for model_dir in merge_dirs:
            dir_path = os.path.join(model_dir, model_type)
            type_out = os.path.join(out_dir, model_type)
            Path(type_out).mkdir(parents=True, exist_ok=True)
            contents = os.listdir(dir_path)
            model_names = [file for file in contents if ".hdf5" in file]
            for model in model_names:
                abs_path = os.path.join(dir_path, model)
                shutil.copy(abs_path, type_out)

            with open(os.path.join(dir_path, "params.json"), "r") as f:
                param_data = json.load(f)

            out_json.update(param_data)

        with open(os.path.join(type_out, "params.json"), "w") as f:
            json.dump(out_json, f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("merge_dirs", nargs='+', default=[])
    parser.add_argument("--out_dir", default="merged_models")
    args = parser.parse_args()
    merge_model_dirs(args.merge_dirs, args.out_dir)
