|
- import numpy as np
- from .vec_env import AlreadySteppingError, NotSteppingError, VecEnv
- from .env_utils import obs_space_info,dict_to_obs,copy_obs_dict
-
class DummyVecEnv(VecEnv):
    """
    VecEnv that runs multiple environments sequentially, that is,
    the step and reset commands are sent to one environment at a time.
    Useful when debugging and when num_env == 1 (in the latter case,
    avoids communication overhead)
    """

    def __init__(self, env_fns):
        """Build one environment per factory and allocate the result buffers.

        :param env_fns: sequence of zero-argument callables, each returning
            an environment. The first environment's observation and action
            spaces are passed to ``VecEnv.__init__`` together with
            ``len(env_fns)``.
        """
        self.waiting = False
        self.closed = False
        self.envs = [fn() for fn in env_fns]
        env = self.envs[0]
        VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
        obs_space = env.observation_space
        self.keys, shapes, dtypes = obs_space_info(obs_space)

        # One preallocated array per observation key, with a leading axis of
        # length num_envs so each environment writes into its own row.
        self.buf_obs = {
            k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k])
            for k in self.keys
        }
        # Builtin ``bool`` instead of ``np.bool``: the alias was deprecated in
        # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
        self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
        self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
        self.buf_infos = [{} for _ in range(self.num_envs)]
        self.actions = None

    def reset(self):
        """Reset every environment in turn and return the buffered observations."""
        for e in range(self.num_envs):
            obs = self.envs[e].reset()
            self._save_obs(e, obs)
        return self._obs_from_buf()

    def step_async(self, actions):
        """Store the actions to be applied by the next ``step_wait`` call.

        :param actions: either a sized collection whose length equals
            ``num_envs`` (used as-is, one entry per environment), or any
            other value, which is wrapped in a one-element list and is only
            accepted when there is exactly one environment.
        :raises AlreadySteppingError: if a step is already pending.
        :raises AssertionError: if ``actions`` does not match ``num_envs``
            and there is more than one environment.
        """
        if self.waiting:
            raise AlreadySteppingError

        # Objects without a length (e.g. a scalar action) raise TypeError
        # from len(); they are treated as a single un-batched action.
        try:
            batched = len(actions) == self.num_envs
        except TypeError:
            batched = False

        if batched:
            self.actions = actions
        else:
            assert self.num_envs == 1, "actions {} is either not a list or has a wrong size - cannot match to {} environments".format(actions, self.num_envs)
            self.actions = [actions]
        self.waiting = True

    def step_wait(self):
        """Step every environment with its stored action, one at a time.

        An environment that reports done is reset immediately, so the
        observation stored for it is the first observation after the reset
        rather than the one returned by the terminating step.

        :return: tuple ``(observations, rewards, dones, infos)``; rewards and
            dones are copies of the internal arrays and infos is a shallow
            copy of the internal list.
        :raises NotSteppingError: if ``step_async`` was not called first.
        """
        if not self.waiting:
            raise NotSteppingError

        for e in range(self.num_envs):
            action = self.actions[e]
            obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = self.envs[e].step(action)
            if self.buf_dones[e]:
                obs = self.envs[e].reset()
            self._save_obs(e, obs)

        self.waiting = False
        return (
            self._obs_from_buf(),
            np.copy(self.buf_rews),
            np.copy(self.buf_dones),
            self.buf_infos.copy(),
        )

    def get_images(self):
        """Return a list with one ``"rgb_array"`` render per environment."""
        return [env.render("rgb_array") for env in self.envs]

    def close_extras(self):
        """Mark this instance as closed and close every wrapped environment."""
        self.closed = True
        for env in self.envs:
            env.close()

    def _save_obs(self, e, obs):
        """Write the observation of environment index ``e`` into the buffers.

        A key of ``None`` means the observation is stored whole; any other
        key is used to index into ``obs``.
        """
        for k in self.keys:
            if k is None:
                self.buf_obs[k][e] = obs
            else:
                self.buf_obs[k][e] = obs[k]

    def _obs_from_buf(self):
        """Return a copy of the observation buffers converted by ``dict_to_obs``."""
        return dict_to_obs(copy_obs_dict(self.buf_obs))

    def render(self, mode='human'):
        """Render the environments.

        With a single environment the call is forwarded directly to that
        environment's ``render``; otherwise it is delegated to the parent
        class implementation.
        """
        if self.num_envs == 1:
            return self.envs[0].render(mode=mode)
        return super().render(mode=mode)
|