Skip to content

Commit 76d94ae

Browse files
authored
optimize spatial data augmentation (#3)
* optimize data augmentation * enlarge norm range for translation * add data visualization for config validation
1 parent 4abcd0c commit 76d94ae

File tree

6 files changed

+52
-10
lines changed

6 files changed

+52
-10
lines changed

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,17 @@
88

99
## 🔥 News
1010

11+
- **[Feb 16, 2025]** Optimize spatial data augmentation. Add data visualization for config check. Add tips for setting workspace range and normalization range.
1112
- **[Dec 26, 2024]** Fix several potential installation issues. Add support for CUDA 12.1.
1213
- **[May 11, 2024]** Initial release.
1314

1415
## 🛫 Getting Started
1516

1617
### 💻 Installation
1718

18-
Please following the [installation guide](assets/docs/INSTALL.md) to install the `rise` conda environments and the dependencies, as well as the real robot environments. Also, remember to adjust the constant parameters in `dataset/constants.py` and `utils/constants.py` according to your own environment.
19+
Please follow the [installation guide](assets/docs/INSTALL.md) to install the `rise` conda environments and the dependencies, as well as the real robot environments. Also, remember to adjust the constant parameters in `dataset/constants.py` and `utils/constants.py` according to your own environment.
20+
21+
**Make sure that `TRANS_MIN/MAX` and `WORKSPACE_MIN/MAX` are correctly set in the camera coordinates, or you may obtain meaningless output.** We recommend expanding `TRANS_MIN/MAX` by 0.15 - 0.3 meters on both sides of the actual translation range to accommodate spatial data augmentation. You could follow [command_train.sh](command_train.sh) for data visualization and parameter check.
1922

2023
### 📷 Calibration
2124

assets/docs/DEPLOY.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
- `IMG_MEAN` and `IMG_STD` are the image normalization constants. Here we use ImageNet normalization coefficients.
55
- `TRANS_MIN` and `TRANS_MAX` are the tcp normalization range in the camera coordinate.
66
- `MAX_GRIPPER_WIDTH` indicates the gripper width normalization range (in meter).
7-
- `WORKSPACE_MIN` and `WORKSPACE_MAX` are the workspace range in the camera coordinate.
7+
- `WORKSPACE_MIN` and `WORKSPACE_MAX` are the workspace range in the camera coordinate and used for point cloud cropping.
88
- `SAFE_WORKSPACE_MIN` and `SAFE_WORKSPACE_MAX` are the safe workspace range in the base coordinate (used for evaluation).
99
- `SAFE_EPS` denotes the safe epsilon of the safe workspace range. Therefore, the real range should be [min + eps, max - eps].
1010
- `GRIPPER_THRESHOLD` denotes the gripper moving threshold (in meter) to avoid gripper action too frequently during evaluation.

command_train.sh

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1-
torchrun --master_addr 192.168.3.50 --master_port 14522 --nproc_per_node 2 --nnodes 1 --node_rank 0 train.py --data_path data/collect_pens --aug --aug_jitter --num_action 20 --voxel_size 0.005 --obs_feature_dim 512 --hidden_dim 512 --nheads 8 --num_encoder_layers 4 --num_decoder_layers 1 --dim_feedforward 2048 --dropout 0.1 --ckpt_dir logs/collect_pens --batch_size 240 --num_epochs 1000 --save_epochs 50 --num_workers 24 --seed 233
1+
# example for policy training
2+
torchrun --master_addr 192.168.3.50 --master_port 14522 --nproc_per_node 2 --nnodes 1 --node_rank 0 train.py --data_path data/collect_pens --aug --aug_jitter --num_action 20 --voxel_size 0.005 --obs_feature_dim 512 --hidden_dim 512 --nheads 8 --num_encoder_layers 4 --num_decoder_layers 1 --dim_feedforward 2048 --dropout 0.1 --ckpt_dir logs/collect_pens --batch_size 240 --num_epochs 1000 --save_epochs 50 --num_workers 24 --seed 233
3+
4+
# example for data visualization & parameter check
5+
torchrun --master_addr 192.168.3.50 --master_port 14522 --nproc_per_node 1 --nnodes 1 --node_rank 0 train.py --data_path data/collect_pens --aug --aug_jitter --num_action 20 --voxel_size 0.005 --obs_feature_dim 512 --hidden_dim 512 --nheads 8 --num_encoder_layers 4 --num_decoder_layers 1 --dim_feedforward 2048 --dropout 0.1 --ckpt_dir logs/collect_pens --batch_size 1 --num_epochs 1 --save_epochs 1 --num_workers 1 --seed 233 --vis_data

dataset/realworld.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ def __init__(
3636
aug_jitter = False,
3737
aug_jitter_params = [0.4, 0.4, 0.2, 0.1],
3838
aug_jitter_prob = 0.2,
39-
with_cloud = False
39+
with_cloud = False,
40+
vis = False
4041
):
4142
assert split in ['train', 'val', 'all']
4243

@@ -56,6 +57,7 @@ def __init__(
5657
self.aug_jitter_params = np.array(aug_jitter_params)
5758
self.aug_jitter_prob = aug_jitter_prob
5859
self.with_cloud = with_cloud
60+
self.vis = vis
5961

6062
self.all_demos = sorted(os.listdir(self.data_path))
6163
self.num_demos = len(self.all_demos)
@@ -115,9 +117,17 @@ def _augmentation(self, clouds, tcps):
115117
rotation_angles = np.random.rand(3) * (self.aug_rot_max - self.aug_rot_min) + self.aug_rot_min
116118
rotation_angles = rotation_angles / 180 * np.pi # tranform from degree to radius
117119
aug_mat = rot_trans_mat(translation_offsets, rotation_angles)
118-
for cloud in clouds:
119-
cloud = apply_mat_to_pcd(cloud, aug_mat)
120+
center = clouds[-1][..., :3].mean(axis = 0)
121+
122+
for i in range(len(clouds)):
123+
clouds[i][..., :3] -= center
124+
clouds[i] = apply_mat_to_pcd(clouds[i], aug_mat)
125+
clouds[i][..., :3] += center
126+
127+
tcps[..., :3] -= center
120128
tcps = apply_mat_to_pose(tcps, aug_mat, rotation_rep = "quaternion")
129+
tcps[..., :3] += center
130+
121131
return clouds, tcps
122132

123133
def _normalize_tcp(self, tcp_list):
@@ -221,6 +231,27 @@ def __getitem__(self, index):
221231
# point augmentations
222232
if self.split == 'train' and self.aug:
223233
clouds, action_tcps = self._augmentation(clouds, action_tcps)
234+
235+
# visualization
236+
if self.vis:
237+
points = clouds[-1][..., :3]
238+
print("point range", points.min(axis=0), points.max(axis=0))
239+
pcd = o3d.geometry.PointCloud()
240+
pcd.points = o3d.utility.Vector3dVector(points)
241+
pcd.colors = o3d.utility.Vector3dVector(colors * IMG_STD + IMG_MEAN)
242+
traj = []
243+
# red box stands for the workspace range
244+
bbox3d_1 = o3d.geometry.AxisAlignedBoundingBox(WORKSPACE_MIN, WORKSPACE_MAX)
245+
bbox3d_1.color = [1, 0, 0]
246+
# green box stands for the translation normalization range
247+
bbox3d_2 = o3d.geometry.AxisAlignedBoundingBox(TRANS_MIN, TRANS_MAX)
248+
bbox3d_2.color = [0, 1, 0]
249+
action_tcps_vis = xyz_rot_transform(action_tcps, from_rep = "quaternion", to_rep = "matrix")
250+
for i in range(len(action_tcps_vis)):
251+
action = action_tcps_vis[i]
252+
frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.03).transform(action)
253+
traj.append(frame)
254+
o3d.visualization.draw_geometries([pcd.voxel_down_sample(self.voxel_size), bbox3d_1, bbox3d_2, *traj])
224255

225256
# rotation transformation (to 6d)
226257
action_tcps = xyz_rot_transform(action_tcps, from_rep = "quaternion", to_rep = "rotation_6d")

train.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@
3939
"num_epochs": 1000,
4040
"save_epochs": 50,
4141
"num_workers": 24,
42-
"seed": 233
42+
"seed": 233,
43+
"vis_data": False
4344
})
4445

4546

@@ -72,7 +73,8 @@ def train(args_override):
7273
voxel_size = args.voxel_size,
7374
aug = args.aug,
7475
aug_jitter = args.aug_jitter,
75-
with_cloud = False
76+
with_cloud = False,
77+
vis = args.vis_data
7678
)
7779
sampler = torch.utils.data.distributed.DistributedSampler(
7880
dataset,
@@ -85,7 +87,8 @@ def train(args_override):
8587
batch_size = args.batch_size // WORLD_SIZE,
8688
num_workers = args.num_workers,
8789
collate_fn = collate_fn,
88-
sampler = sampler
90+
sampler = sampler,
91+
drop_last = True
8992
)
9093

9194
# policy
@@ -203,5 +206,6 @@ def train(args_override):
203206
parser.add_argument('--save_epochs', action = 'store', type = int, help = 'saving epochs', required = False, default = 50)
204207
parser.add_argument('--num_workers', action = 'store', type = int, help = 'number of workers', required = False, default = 24)
205208
parser.add_argument('--seed', action = 'store', type = int, help = 'seed', required = False, default = 233)
209+
parser.add_argument('--vis_data', action = 'store_true', help = 'whether to visualize the input data and ground truth actions.')
206210

207211
train(vars(parser.parse_args()))

utils/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
IMG_STD = np.array([0.229, 0.224, 0.225])
66

77
# tcp normalization and gripper width normalization
8-
TRANS_MIN, TRANS_MAX = np.array([-0.35, -0.35, 0]), np.array([0.35, 0.35, 0.7])
8+
TRANS_MIN, TRANS_MAX = np.array([-0.5, -0.5, 0]), np.array([0.5, 0.5, 1.0])
99
MAX_GRIPPER_WIDTH = 0.11 # meter
1010

1111
# workspace in camera coordinate

0 commit comments

Comments
 (0)