From a967370d393a79e0645ecfbae3fe41118b835feb Mon Sep 17 00:00:00 2001 From: Ben Bell Date: Sun, 24 May 2020 23:08:14 -0400 Subject: [PATCH 1/3] WIP: zmq vector env --- rllib/env/base_env.py | 2 +- rllib/env/zmq_remote_env.py | 225 +++++++++++++++++++++++++++++++++ rllib/env/zmq_remote_worker.py | 105 +++++++++++++++ 3 files changed, 331 insertions(+), 1 deletion(-) create mode 100644 rllib/env/zmq_remote_env.py create mode 100644 rllib/env/zmq_remote_worker.py diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py index 29c2e3a9cf43..f012eb2ab15b 100644 --- a/rllib/env/base_env.py +++ b/rllib/env/base_env.py @@ -80,7 +80,7 @@ def to_base_env(env, remote_env_batch_wait_ms=0): """Wraps any env type as needed to expose the async interface.""" - from ray.rllib.env.remote_vector_env import RemoteVectorEnv + from ray.rllib.env.zmq_vector_env import ZMQRemoteVectorEnv as RemoteVectorEnv if remote_envs and num_envs == 1: raise ValueError( "Remote envs only make sense to use if num_envs > 1 " diff --git a/rllib/env/zmq_remote_env.py b/rllib/env/zmq_remote_env.py new file mode 100644 index 000000000000..f5710937c27b --- /dev/null +++ b/rllib/env/zmq_remote_env.py @@ -0,0 +1,225 @@ +import json +import os +import logging +import pickle + +import cloudpickle +import numpy as np +import torch +import zmq +from torch import multiprocessing as mp + +import ray +from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN +from .zmq_remote_worker import zmq_worker + +logger = logging.getLogger(__name__) +ZMQ_CONNECT_METHOD = 'ipc' + + +class WorkerError(BaseException): + pass + + +class ZMQRemoteVectorEnv(BaseEnv): + """Vector env that executes envs in another process using ZMQ. + Waits for all envs to compute, but is generally very fast. + Both single and multi-agent child envs are supported. + """ + def __init__(self, make_env, num_envs, multiagent, + remote_env_batch_wait_ms): + self.make_local_env = make_env + self.num_envs = num_envs + self.multiagent = multiagent + self.poll_timeout = remote_env_batch_wait_ms / 1000 + + self.processes = None # lazy init + + def poll(self): + # Init processes + if self.processes is None: + self._setup() + obs, rewards, dones, infos = self._setup_reset() + return obs, rewards, dones, infos, {} + + # each keyed by env_id in [0, num_remote_envs) + obs, rewards, dones, infos = {}, {}, {}, {} + + # wait for all envs to finish + for e_id, remote in self._zmq_sockets.items(): + result = remote.recv() + self._check_for_errors(result, e_id) + _, rew, done, info = remote.recv() + obs[e_id] = self.shared_memories[e_id] + rewards[e_id] = rew + dones[e_id] = done + infos[e_id] = info + + self.waiting = False + return obs, rewards, dones, infos, {} + + def send_actions(self, action_dict): + for env_id, actions in action_dict.items(): + socket = self._zmq_sockets[env_id] + msg = json.dumps({k: v for k, v in actions.items()}) + socket.send(msg.encode(), zmq.NOBLOCK, copy=False, track=False) + + self.waiting = True + + def try_reset(self, env_id): + raise NotImplementedError('when is this called?' + str(env_id)) + # actor = self.actors[env_id] + # obj_id = actor.reset.remote() + # self.pending[obj_id] = actor + # return ASYNC_RESET_RETURN + + def stop(self): + if self.closed: + return + if self.waiting: + for remote in self._zmq_sockets: + remote.recv() + for socket in self._zmq_sockets: + socket.send('close'.encode()) + for p in self.processes: + p.join() + self.closed = True + + def _setup(self): + self.waiting = False + self.closed = False + self.processes = {} + + self._zmq_context = zmq.Context() + self._zmq_sockets = {} + + # iterate envs to get torch shared memory through pipe then close it + self.shared_memories = {} + + for w_ind in range(self.num_envs): + pipe, w_pipe = mp.Pipe() + socket, port = zmq_robust_bind_socket(self._zmq_context) + + process = mp.Process(target=zmq_worker, args=( + w_pipe, pipe, port, CloudpickleWrapper((self.make_local_env, w_ind)) + )) + process.daemon = True + process.start() + self.processes[w_ind] = process + + self._zmq_sockets[w_ind] = socket + + pipe.send(('get_shared_memory', None)) + self.shared_memories[w_ind] = pipe.recv() + + # switch to zmq socket and close pipes + pipe.send(('switch_zmq', None)) + pipe.close() + w_pipe.close() + + logger.info("All remote envs started") + + def _setup_reset(self): + for _, socket in self._zmq_sockets.items(): + socket.send('reset'.encode()) + + all_obs = {} + all_rew = {} + all_info = {} + all_done = {} + for r_ind, remote in self._zmq_sockets.items(): + wait_for_worker = json.loads(remote.recv().decode()) + all_obs[r_ind] = self.shared_memories[r_ind] + # each keyed by agent_id in the env + all_rew[r_ind] = {agent_id: 0 for agent_id in all_obs[r_ind].keys()} + all_info[r_ind] = {agent_id: {} for agent_id in all_obs[r_ind].keys()} + all_done[r_ind] = {"__all__": False} + + return all_obs, all_rew, all_done, all_info + + def _check_for_errors(self, result, e_id): + if result[:5] == b'error': + error = 'Worker {} has an error {}'.format(e_id, result) + raise WorkerError(error) + + +@ray.remote(num_cpus=0) +class _RemoteMultiAgentEnv: + """Wrapper class for making a multi-agent env a remote actor.""" + + def __init__(self, make_env, i): + self.env = make_env(i) + + def reset(self): + obs = self.env.reset() + # each keyed by agent_id in the env + rew = {agent_id: 0 for agent_id in obs.keys()} + info = {agent_id: {} for agent_id in obs.keys()} + done = {"__all__": False} + return obs, rew, done, info + + def step(self, action_dict): + return self.env.step(action_dict) + + +@ray.remote(num_cpus=0) +class _RemoteSingleAgentEnv: + """Wrapper class for making a gym env a remote actor.""" + + def __init__(self, make_env, i): + self.env = make_env(i) + + def reset(self): + obs = {_DUMMY_AGENT_ID: self.env.reset()} + rew = {agent_id: 0 for agent_id in obs.keys()} + info = {agent_id: {} for agent_id in obs.keys()} + done = {"__all__": False} + return obs, rew, done, info + + def step(self, action): + obs, rew, done, info = self.env.step(action[_DUMMY_AGENT_ID]) + obs, rew, done, info = [{ + _DUMMY_AGENT_ID: x + } for x in [obs, rew, done, info]] + done["__all__"] = done[_DUMMY_AGENT_ID] + return obs, rew, done, info + + +class CloudpickleWrapper(object): + """ + Modified. + MIT License + Copyright (c) 2017 OpenAI (http://openai.com) + """ + + def __init__(self, x): + self.x = x + + def __getstate__(self): + return cloudpickle.dumps(self.x) + + def __setstate__(self, ob): + self.x = pickle.loads(ob) + + +def zmq_robust_bind_socket(zmq_context): + try_count = 0 + while try_count < 3: + try: + socket = zmq_context.socket(zmq.PAIR) + port = np.random.randint(5000, 30000) + if ZMQ_CONNECT_METHOD == 'tcp': + socket.bind("tcp://*:{}".format(port)) + if ZMQ_CONNECT_METHOD == 'ipc': + os.makedirs('/dev/shm/remotezmq/', exist_ok=True) + socket.bind("ipc:///dev/shm/remotezmq/{}".format(port)) + except zmq.error.ZMQError as e: + try_count += 1 + socket = None + last_error = e + continue + break + if socket is None: + raise Exception("ZMQ couldn't bind socket after 3 tries. {}".format(last_error)) + return socket, port + diff --git a/rllib/env/zmq_remote_worker.py b/rllib/env/zmq_remote_worker.py new file mode 100644 index 000000000000..8937fe99b96f --- /dev/null +++ b/rllib/env/zmq_remote_worker.py @@ -0,0 +1,105 @@ +import json +import os +import logging +import pickle + +import cloudpickle +import numpy as np +import torch +import zmq +from torch import multiprocessing as mp + +import ray +from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN + +logger = logging.getLogger(__name__) +ZMQ_CONNECT_METHOD = 'icp' + + +def zmq_worker(remote, parent_remote, port, env_fn_wrapper): + """ + Modified. + MIT License + Copyright (c) 2017 OpenAI (http://openai.com) + """ + parent_remote.close() + env = env_fn_wrapper.x() + + shared_memory = {} + print(env.observation_space) + for name, shape in env.observation_space.items(): + if shape is not None: + if not shape.dtype: + tensor = torch.FloatTensor(*shape) + else: + tensor = torch.zeros(*shape, dtype=shape.dtype) + shared_memory[name] = tensor + + # initial python pipe setup + python_pipe = True + while python_pipe: + cmd, _ = remote.recv() + if cmd == 'get_shared_memory': + remote.send(shared_memory) + elif cmd == 'switch_zmq': + # close python pipes + remote.close() + python_pipe = False + else: + raise NotImplementedError + + # zmq setup + context = zmq.Context() + socket = context.socket(zmq.PAIR) + if ZMQ_CONNECT_METHOD == 'tcp': + socket.connect("tcp://localhost:{}".format(port)) + if ZMQ_CONNECT_METHOD == 'ipc': + socket.connect("ipc:///dev/shm/remotezmq/{}".format(port)) + + running = True + while running: + try: + socket_data = socket.recv() + socket_parsed = socket_data.decode() + + # commands that aren't action dictionaries + if socket_parsed == 'reset': + ob = env.reset() + ob = handle_ob(ob, shared_memory) + # only the non-shared obs are returned here + socket.send(json.dumps(ob).encode(), zmq.NOBLOCK, copy=False, track=False) + elif socket_parsed == 'close': + env.close() + running = False + # else action dictionary + else: + action_dictionary = json.loads(socket_parsed) + ob, reward, done, info = env.step(action_dictionary) + if done: + ob = env.reset() + ob = handle_ob(ob, shared_memory) + # only the non-shared obs are returned here + msg = json.dumps((ob, reward, done, info)) + socket.send(msg.encode(), zmq.NOBLOCK, copy=False, track=False) + except KeyboardInterrupt: + pass + except Exception as e: + running = False + e_str = '{}: {}'.format(type(e).__name__, e) + print('Subprocess environment has an error', e_str) + socket.send('error. {}'.format(e_str).encode(), zmq.NOBLOCK, copy=False, track=False) + + +def handle_ob(ob, shared_memory): + print('obs', ob) + non_shared = {} + for k, v in ob.items(): + if isinstance(v, torch.Tensor): + shared_memory[k].copy_(v) + # support double layer dict + elif isinstance(v, dict): + print(v) + exit() + else: + non_shared[k] = v + return non_shared From b9ef4c18c9b41c0b628384f88e8cf65daae7a612 Mon Sep 17 00:00:00 2001 From: Ben Bell Date: Mon, 25 May 2020 01:39:56 -0400 Subject: [PATCH 2/3] Working multiagent zmq env --- rllib/env/base_env.py | 2 +- rllib/env/zmq_remote_env.py | 86 ++++++++++++---------------------- rllib/env/zmq_remote_worker.py | 52 +++++++++++++------- 3 files changed, 66 insertions(+), 74 deletions(-) diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py index f012eb2ab15b..5f82c7a7b560 100644 --- a/rllib/env/base_env.py +++ b/rllib/env/base_env.py @@ -80,7 +80,7 @@ def to_base_env(env, remote_env_batch_wait_ms=0): """Wraps any env type as needed to expose the async interface.""" - from ray.rllib.env.zmq_vector_env import ZMQRemoteVectorEnv as RemoteVectorEnv + from ray.rllib.env.zmq_remote_env import ZMQRemoteVectorEnv as RemoteVectorEnv if remote_envs and num_envs == 1: raise ValueError( "Remote envs only make sense to use if num_envs > 1 " diff --git a/rllib/env/zmq_remote_env.py b/rllib/env/zmq_remote_env.py index f5710937c27b..afad6ddd3eb5 100644 --- a/rllib/env/zmq_remote_env.py +++ b/rllib/env/zmq_remote_env.py @@ -49,8 +49,8 @@ def poll(self): for e_id, remote in self._zmq_sockets.items(): result = remote.recv() self._check_for_errors(result, e_id) - _, rew, done, info = remote.recv() - obs[e_id] = self.shared_memories[e_id] + _, rew, done, info = json.loads(result.decode()) + obs[e_id] = self._obs_to_numpy(self.shared_memories[e_id]) rewards[e_id] = rew dones[e_id] = done infos[e_id] = info @@ -61,17 +61,16 @@ def poll(self): def send_actions(self, action_dict): for env_id, actions in action_dict.items(): socket = self._zmq_sockets[env_id] - msg = json.dumps({k: v for k, v in actions.items()}) + # TODO: support nested action dict? + msg = json.dumps({k: v.item() for k, v in actions.items()}) socket.send(msg.encode(), zmq.NOBLOCK, copy=False, track=False) self.waiting = True def try_reset(self, env_id): - raise NotImplementedError('when is this called?' + str(env_id)) - # actor = self.actors[env_id] - # obj_id = actor.reset.remote() - # self.pending[obj_id] = actor - # return ASYNC_RESET_RETURN + # only reset the env + self._zmq_sockets[env_id].send('reset'.encode()) + return ASYNC_RESET_RETURN def stop(self): if self.closed: @@ -128,63 +127,40 @@ def _setup_reset(self): all_info = {} all_done = {} for r_ind, remote in self._zmq_sockets.items(): - wait_for_worker = json.loads(remote.recv().decode()) - all_obs[r_ind] = self.shared_memories[r_ind] + result = remote.recv().decode() + self._check_for_errors(result, r_ind) + _, rew, done, info = json.loads(result) + + all_obs[r_ind] = self._obs_to_numpy(self.shared_memories[r_ind]) # each keyed by agent_id in the env - all_rew[r_ind] = {agent_id: 0 for agent_id in all_obs[r_ind].keys()} - all_info[r_ind] = {agent_id: {} for agent_id in all_obs[r_ind].keys()} - all_done[r_ind] = {"__all__": False} + all_rew[r_ind] = rew + all_info[r_ind] = info + all_done[r_ind] = done return all_obs, all_rew, all_done, all_info + @staticmethod + def _obs_to_numpy(obs): + o = {} + for k, v in obs.items(): + if isinstance(v, torch.Tensor): + o[k] = v.numpy() + # support double layer dict + elif isinstance(v, dict): + for nk, nv in v.items(): + if isinstance(nv, dict): + raise NotImplementedError('Nested obs space dict not implemented') + o[k][nk] = nv.numpy() + else: + raise NotImplementedError('Unsupported obs type {}'.format(type(v))) + return o + def _check_for_errors(self, result, e_id): if result[:5] == b'error': error = 'Worker {} has an error {}'.format(e_id, result) raise WorkerError(error) -@ray.remote(num_cpus=0) -class _RemoteMultiAgentEnv: - """Wrapper class for making a multi-agent env a remote actor.""" - - def __init__(self, make_env, i): - self.env = make_env(i) - - def reset(self): - obs = self.env.reset() - # each keyed by agent_id in the env - rew = {agent_id: 0 for agent_id in obs.keys()} - info = {agent_id: {} for agent_id in obs.keys()} - done = {"__all__": False} - return obs, rew, done, info - - def step(self, action_dict): - return self.env.step(action_dict) - - -@ray.remote(num_cpus=0) -class _RemoteSingleAgentEnv: - """Wrapper class for making a gym env a remote actor.""" - - def __init__(self, make_env, i): - self.env = make_env(i) - - def reset(self): - obs = {_DUMMY_AGENT_ID: self.env.reset()} - rew = {agent_id: 0 for agent_id in obs.keys()} - info = {agent_id: {} for agent_id in obs.keys()} - done = {"__all__": False} - return obs, rew, done, info - - def step(self, action): - obs, rew, done, info = self.env.step(action[_DUMMY_AGENT_ID]) - obs, rew, done, info = [{ - _DUMMY_AGENT_ID: x - } for x in [obs, rew, done, info]] - done["__all__"] = done[_DUMMY_AGENT_ID] - return obs, rew, done, info - - class CloudpickleWrapper(object): """ Modified. diff --git a/rllib/env/zmq_remote_worker.py b/rllib/env/zmq_remote_worker.py index 8937fe99b96f..6e1c502f21a0 100644 --- a/rllib/env/zmq_remote_worker.py +++ b/rllib/env/zmq_remote_worker.py @@ -1,3 +1,4 @@ +import gym import json import os import logging @@ -13,7 +14,7 @@ from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN logger = logging.getLogger(__name__) -ZMQ_CONNECT_METHOD = 'icp' +ZMQ_CONNECT_METHOD = 'ipc' def zmq_worker(remote, parent_remote, port, env_fn_wrapper): @@ -23,17 +24,18 @@ def zmq_worker(remote, parent_remote, port, env_fn_wrapper): Copyright (c) 2017 OpenAI (http://openai.com) """ parent_remote.close() - env = env_fn_wrapper.x() + env_fn, env_id = env_fn_wrapper.x + env = env_fn(env_id) shared_memory = {} - print(env.observation_space) - for name, shape in env.observation_space.items(): - if shape is not None: - if not shape.dtype: - tensor = torch.FloatTensor(*shape) - else: - tensor = torch.zeros(*shape, dtype=shape.dtype) - shared_memory[name] = tensor + for name, space in env.observation_space.spaces.items(): + if isinstance(space, gym.spaces.Box): + if space.dtype != np.float32: + raise NotImplementedError('Type not implemented {}'.format(space.dtype)) + tensor = torch.zeros(space.shape, dtype=torch.float32) + elif isinstance(space, gym.spaces.dict.Dict): + raise NotImplementedError('Gym dict spaces not supported') + shared_memory[name] = tensor # initial python pipe setup python_pipe = True @@ -65,9 +67,15 @@ def zmq_worker(remote, parent_remote, port, env_fn_wrapper): # commands that aren't action dictionaries if socket_parsed == 'reset': ob = env.reset() - ob = handle_ob(ob, shared_memory) + # MUST return ob, reward, done, info + # TODO: should be different for not multi agent env + reward = {agent_id: 0 for agent_id in ob.keys()} + done = {"__all__": False} + info = {agent_id: {} for agent_id in ob.keys()} # only the non-shared obs are returned here - socket.send(json.dumps(ob).encode(), zmq.NOBLOCK, copy=False, track=False) + non_shared_ob = handle_ob(ob, shared_memory) + msg = json.dumps((non_shared_ob, reward, done, info)) + socket.send(msg.encode(), zmq.NOBLOCK, copy=False, track=False) elif socket_parsed == 'close': env.close() running = False @@ -75,8 +83,12 @@ def zmq_worker(remote, parent_remote, port, env_fn_wrapper): else: action_dictionary = json.loads(socket_parsed) ob, reward, done, info = env.step(action_dictionary) - if done: - ob = env.reset() + # Done ob handled by reset + # if isinstance(done, dict): + # if done['__all__']: + # ob = env.reset() + # elif done: + # ob = env.reset() ob = handle_ob(ob, shared_memory) # only the non-shared obs are returned here msg = json.dumps((ob, reward, done, info)) @@ -91,15 +103,19 @@ def zmq_worker(remote, parent_remote, port, env_fn_wrapper): def handle_ob(ob, shared_memory): - print('obs', ob) non_shared = {} for k, v in ob.items(): if isinstance(v, torch.Tensor): shared_memory[k].copy_(v) + elif isinstance(v, np.ndarray): + shared_memory[k].copy_(torch.from_numpy(v)) # support double layer dict elif isinstance(v, dict): - print(v) - exit() + for nk, nv in v.items(): + if isinstance(nv, dict): + raise NotImplementedError('Nested obs space dict not implemented') + shared_memory[k][nk] = torch.from_numpy(nv) else: - non_shared[k] = v + raise NotImplementedError('Unsupported obs type {}'.format(type(v))) + # non_shared[k] = v return non_shared From 8ba9f7ceb97df978a653f54ab34038c58927eb1b Mon Sep 17 00:00:00 2001 From: Ben Bell Date: Wed, 27 May 2020 13:16:15 -0400 Subject: [PATCH 3/3] Fix multi layer dicts --- rllib/env/base_env.py | 1 + rllib/env/zmq_remote_env.py | 29 +++++++++++++++++++++++------ rllib/env/zmq_remote_worker.py | 32 +++++++++++++++++++++++++------- 3 files changed, 49 insertions(+), 13 deletions(-) diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py index 5f82c7a7b560..5dfb7c594a8f 100644 --- a/rllib/env/base_env.py +++ b/rllib/env/base_env.py @@ -80,6 +80,7 @@ def to_base_env(env, remote_env_batch_wait_ms=0): """Wraps any env type as needed to expose the async interface.""" + # from ray.rllib.env.remote_vector_env import RemoteVectorEnv from ray.rllib.env.zmq_remote_env import ZMQRemoteVectorEnv as RemoteVectorEnv if remote_envs and num_envs == 1: raise ValueError( diff --git a/rllib/env/zmq_remote_env.py b/rllib/env/zmq_remote_env.py index afad6ddd3eb5..67bfb8a0e855 100644 --- a/rllib/env/zmq_remote_env.py +++ b/rllib/env/zmq_remote_env.py @@ -1,3 +1,4 @@ +from collections import OrderedDict, Iterable import json import os import logging @@ -14,7 +15,7 @@ from .zmq_remote_worker import zmq_worker logger = logging.getLogger(__name__) -ZMQ_CONNECT_METHOD = 'ipc' +ZMQ_CONNECT_METHOD = 'tcp' class WorkerError(BaseException): @@ -61,8 +62,16 @@ def poll(self): def send_actions(self, action_dict): for env_id, actions in action_dict.items(): socket = self._zmq_sockets[env_id] - # TODO: support nested action dict? - msg = json.dumps({k: v.item() for k, v in actions.items()}) + a = {} + for k, v in actions.items(): + # TODO: support nested action dict? + if isinstance(v, np.ndarray): + a[k] = v.item() + elif isinstance(v, Iterable): + a[k] = [x.item() for x in v] + else: + a[k] = v.item() + msg = json.dumps(a) socket.send(msg.encode(), zmq.NOBLOCK, copy=False, track=False) self.waiting = True @@ -141,16 +150,24 @@ def _setup_reset(self): @staticmethod def _obs_to_numpy(obs): - o = {} + o = OrderedDict({}) for k, v in obs.items(): if isinstance(v, torch.Tensor): o[k] = v.numpy() # support double layer dict elif isinstance(v, dict): + o[k] = OrderedDict({}) for nk, nv in v.items(): if isinstance(nv, dict): - raise NotImplementedError('Nested obs space dict not implemented') - o[k][nk] = nv.numpy() + dd = OrderedDict({}) + for dk, dv in nv.items(): + if isinstance(dv, dict): + raise NotImplementedError('Double Nested obs space dict not implemented') + else: + dd[dk] = dv.numpy() + o[k][nk] = dd + else: + o[k][nk] = nv.numpy() else: raise NotImplementedError('Unsupported obs type {}'.format(type(v))) return o diff --git a/rllib/env/zmq_remote_worker.py b/rllib/env/zmq_remote_worker.py index 6e1c502f21a0..b83a16385c55 100644 --- a/rllib/env/zmq_remote_worker.py +++ b/rllib/env/zmq_remote_worker.py @@ -1,7 +1,9 @@ +from collections import OrderedDict import gym import json import os import logging +import traceback import pickle import cloudpickle @@ -14,7 +16,7 @@ from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN logger = logging.getLogger(__name__) -ZMQ_CONNECT_METHOD = 'ipc' +ZMQ_CONNECT_METHOD = 'tcp' def zmq_worker(remote, parent_remote, port, env_fn_wrapper): @@ -27,15 +29,24 @@ def zmq_worker(remote, parent_remote, port, env_fn_wrapper): env_fn, env_id = env_fn_wrapper.x env = env_fn(env_id) - shared_memory = {} + shared_memory = OrderedDict({}) for name, space in env.observation_space.spaces.items(): if isinstance(space, gym.spaces.Box): if space.dtype != np.float32: raise NotImplementedError('Type not implemented {}'.format(space.dtype)) tensor = torch.zeros(space.shape, dtype=torch.float32) + shared_memory[name] = tensor elif isinstance(space, gym.spaces.dict.Dict): - raise NotImplementedError('Gym dict spaces not supported') - shared_memory[name] = tensor + sm = OrderedDict({}) + for nk, ns in space.spaces.items(): + if isinstance(ns, gym.spaces.Box): + if ns.dtype != np.float32: + raise NotImplementedError('Type not implemented {}'.format(ns.dtype)) + tensor = torch.zeros(ns.shape, dtype=torch.float32) + elif isinstance(space, gym.spaces.dict.Dict): + raise NotImplementedError('Gym nested dict spaces not supported') + sm[nk] = tensor + shared_memory[name] = sm # initial python pipe setup python_pipe = True @@ -98,12 +109,13 @@ def zmq_worker(remote, parent_remote, port, env_fn_wrapper): except Exception as e: running = False e_str = '{}: {}'.format(type(e).__name__, e) + e_str += '\n' + traceback.format_exc() print('Subprocess environment has an error', e_str) socket.send('error. {}'.format(e_str).encode(), zmq.NOBLOCK, copy=False, track=False) def handle_ob(ob, shared_memory): - non_shared = {} + non_shared = OrderedDict({}) for k, v in ob.items(): if isinstance(v, torch.Tensor): shared_memory[k].copy_(v) @@ -113,8 +125,14 @@ def handle_ob(ob, shared_memory): elif isinstance(v, dict): for nk, nv in v.items(): if isinstance(nv, dict): - raise NotImplementedError('Nested obs space dict not implemented') - shared_memory[k][nk] = torch.from_numpy(nv) + dd = OrderedDict({}) + for dk, dv in nv.items(): + if isinstance(dv, dict): + raise NotImplementedError('Double Nested obs space dict not implemented') + else: + shared_memory[k][dk].copy_(torch.from_numpy(dv)) + else: + shared_memory[k][nk].copy_(torch.from_numpy(nv)) else: raise NotImplementedError('Unsupported obs type {}'.format(type(v))) # non_shared[k] = v