-
Notifications
You must be signed in to change notification settings - Fork 0
/
kuka_env.py
73 lines (55 loc) · 2.41 KB
/
kuka_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#!/usr/bin/env python3
import os
import threading
import time

import gym
import numpy as np
import pybullet_envs  # noqa: F401  -- imported for side effects (presumably registers the Kuka envs used by gym.make)
from gym import Wrapper
from gym.wrappers.monitor import video_recorder
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import VecTransposeImage, DummyVecEnv
class AssignTypeWrapper(gym.Wrapper):
    """Re-declare the wrapped env's observation space with dtype ``np.uint8``.

    The bounds (``low``/``high``) and ``shape`` are copied from the wrapped
    env's ``observation_space``; only the declared dtype changes. The
    observations themselves are passed through untouched -- this wrapper
    alters the advertised space, not the data.

    NOTE(review): presumably this exists so Stable-Baselines3 treats the
    observations as images (e.g. for ``VecTransposeImage``) -- confirm.
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)
        source_space = env.observation_space
        lower_bound = source_space.low
        upper_bound = source_space.high
        dims = source_space.shape
        self.observation_space = gym.spaces.Box(
            low=lower_bound,
            high=upper_bound,
            shape=dims,
            dtype=np.uint8,
        )
class KukaVideoRecorder(Wrapper):
    """Record a video of the wrapped env from a background capture thread.

    A daemon thread repeatedly calls ``VideoRecorder.capture_frame()`` while
    ``in_playing`` is True. The thread is started by the first ``reset()``.
    ``in_playing`` is cleared while the underlying env resets, so frames are
    (best-effort) not captured mid-reset. A capture already in progress may
    still finish.

    Call ``close()`` when finished. It stops the capture thread, then
    finalizes the video, then closes the wrapped env.

    ``__del__`` is only a fallback. While the capture thread is alive it holds
    a reference to this wrapper (its target is a bound method), so garbage
    collection cannot trigger ``__del__`` before the thread stops.

    Args:
        env: environment to wrap. It is also the environment the recorder
            captures frames from.
        filename: base name of the video, joined onto ``video_folder``.
        video_folder: output directory. It is created if missing.
    """

    # Seconds the worker sleeps between checks while not capturing, so an
    # idle worker does not spin at full CPU.
    _IDLE_POLL_INTERVAL = 0.001

    def __init__(self, env, filename, video_folder):
        super().__init__(env)
        self.recording = False  # True while the capture thread should keep running
        self.in_playing = False  # True while frames should actually be captured
        os.makedirs(video_folder, exist_ok=True)
        self.file_path = os.path.join(video_folder, filename)
        self.video_recorder = video_recorder.VideoRecorder(env=self.env, base_path=self.file_path)
        # Set last: _stop_recording() treats a missing flag as "nothing to
        # clean up", which keeps __del__ safe if __init__ failed part-way.
        self._recorder_closed = False
        self.capture_runner_thread = threading.Thread(target=self._snapshot_worker, daemon=True)

    def _snapshot_worker(self):
        """Capture frames until ``recording`` is cleared."""
        while self.recording:
            if self.in_playing:
                self.video_recorder.capture_frame()
            else:
                time.sleep(self._IDLE_POLL_INTERVAL)

    def reset(self, **kwargs):
        """Reset the wrapped env and start the capture thread on first use."""
        # Pause capturing while the underlying env resets.
        self.in_playing = False
        observation = self.env.reset(**kwargs)
        self.in_playing = True
        # A Thread can only be started once, and a closed recorder must not
        # be written to, so only start recording if close() has not run yet.
        if not self.recording and not self._recorder_closed:
            self.recording = True
            self.capture_runner_thread.start()
        return observation

    def _stop_recording(self):
        """Stop the capture thread, then close the video recorder (idempotent)."""
        # getattr default: the flag is missing if __init__ failed before it was set.
        if getattr(self, "_recorder_closed", True):
            return
        self._recorder_closed = True
        if self.recording:
            self.recording = False
            # Join before closing so capture_frame() never runs on a closed recorder.
            self.capture_runner_thread.join()
        self.video_recorder.close()

    def close(self):
        """Finalize the recording, then close the wrapped env."""
        try:
            self._stop_recording()
        finally:
            result = super().close()
        return result

    def __del__(self):
        # Fallback cleanup only. Like the original, this does not close the
        # wrapped env.
        self._stop_recording()
def get_train_env(filename: str):
    """Build the headless training env for the Kuka grasping task.

    The ``KukaDiverseObjectGrasping-v0`` env is created with ``renders=False``.
    Its observation space is then re-declared as uint8 by
    ``AssignTypeWrapper``, and the result is wrapped in SB3's ``Monitor``.
    Finally it is vectorized with ``DummyVecEnv`` and ``VecTransposeImage``.

    Args:
        filename: passed through to ``Monitor(filename=...)``.

    Returns:
        The ``VecTransposeImage``-wrapped single-env ``DummyVecEnv``.
    """

    def make_env():
        # Build the whole wrapper stack inside the factory. DummyVecEnv then
        # receives a self-contained callable, not a lambda closing over a
        # local variable that gets rebound on the same line.
        env = gym.make(
            "KukaDiverseObjectGrasping-v0",
            maxSteps=20,
            isDiscrete=False,
            renders=False,
            removeHeightHack=True,
        )
        env = AssignTypeWrapper(env)
        return Monitor(env, filename=filename)

    return VecTransposeImage(DummyVecEnv([make_env]))
def get_test_env(filename: str, dir: str):
    """Build the rendering evaluation env, with video recording.

    The ``KukaDiverseObjectGrasping-v0`` env is created with ``renders=True``
    and ``isTest=True``. Its observation space is re-declared as uint8 by
    ``AssignTypeWrapper``, and it is wrapped in ``KukaVideoRecorder``.
    Finally it is vectorized with ``DummyVecEnv`` and ``VecTransposeImage``.

    Args:
        filename: base name of the video file, written inside ``dir``.
        dir: folder for the video. The parameter name shadows the builtin
            ``dir`` but is kept for keyword-argument compatibility.

    Returns:
        The ``VecTransposeImage``-wrapped single-env ``DummyVecEnv``.
    """

    def make_env():
        # Build the whole wrapper stack inside the factory. DummyVecEnv then
        # receives a self-contained callable, not a lambda closing over a
        # local variable that gets rebound on the same line.
        env = gym.make(
            "KukaDiverseObjectGrasping-v0",
            maxSteps=20,
            isDiscrete=False,
            renders=True,
            removeHeightHack=True,
            isTest=True,
        )
        env = AssignTypeWrapper(env)
        return KukaVideoRecorder(env=env, filename=filename, video_folder=dir)

    return VecTransposeImage(DummyVecEnv([make_env]))