Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion rllib/env/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def to_base_env(env,
remote_env_batch_wait_ms=0):
"""Wraps any env type as needed to expose the async interface."""

from ray.rllib.env.remote_vector_env import RemoteVectorEnv
# from ray.rllib.env.remote_vector_env import RemoteVectorEnv
from ray.rllib.env.zmq_remote_env import ZMQRemoteVectorEnv as RemoteVectorEnv
if remote_envs and num_envs == 1:
raise ValueError(
"Remote envs only make sense to use if num_envs > 1 "
Expand Down
218 changes: 218 additions & 0 deletions rllib/env/zmq_remote_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
from collections import OrderedDict, Iterable
import json
import os
import logging
import pickle

import cloudpickle
import numpy as np
import torch
import zmq
from torch import multiprocessing as mp

import ray
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
from .zmq_remote_worker import zmq_worker

logger = logging.getLogger(__name__)
ZMQ_CONNECT_METHOD = 'tcp'


class WorkerError(BaseException):
pass


class ZMQRemoteVectorEnv(BaseEnv):
"""Vector env that executes envs in another process using ZMQ.
Waits for all envs to compute, but is generally very fast.
Both single and multi-agent child envs are supported.
"""
def __init__(self, make_env, num_envs, multiagent,
remote_env_batch_wait_ms):
self.make_local_env = make_env
self.num_envs = num_envs
self.multiagent = multiagent
self.poll_timeout = remote_env_batch_wait_ms / 1000

self.processes = None # lazy init

def poll(self):
# Init processes
if self.processes is None:
self._setup()
obs, rewards, dones, infos = self._setup_reset()
return obs, rewards, dones, infos, {}

# each keyed by env_id in [0, num_remote_envs)
obs, rewards, dones, infos = {}, {}, {}, {}

# wait for all envs to finish
for e_id, remote in self._zmq_sockets.items():
result = remote.recv()
self._check_for_errors(result, e_id)
_, rew, done, info = json.loads(result.decode())
obs[e_id] = self._obs_to_numpy(self.shared_memories[e_id])
rewards[e_id] = rew
dones[e_id] = done
infos[e_id] = info

self.waiting = False
return obs, rewards, dones, infos, {}

def send_actions(self, action_dict):
for env_id, actions in action_dict.items():
socket = self._zmq_sockets[env_id]
a = {}
for k, v in actions.items():
# TODO: support nested action dict?
if isinstance(v, np.ndarray):
a[k] = v.item()
elif isinstance(v, Iterable):
a[k] = [x.item() for x in v]
else:
a[k] = v.item()
msg = json.dumps(a)
socket.send(msg.encode(), zmq.NOBLOCK, copy=False, track=False)

self.waiting = True

def try_reset(self, env_id):
# only reset the env
self._zmq_sockets[env_id].send('reset'.encode())
return ASYNC_RESET_RETURN

def stop(self):
if self.closed:
return
if self.waiting:
for remote in self._zmq_sockets:
remote.recv()
for socket in self._zmq_sockets:
socket.send('close'.encode())
for p in self.processes:
p.join()
self.closed = True

def _setup(self):
self.waiting = False
self.closed = False
self.processes = {}

self._zmq_context = zmq.Context()
self._zmq_sockets = {}

# iterate envs to get torch shared memory through pipe then close it
self.shared_memories = {}

for w_ind in range(self.num_envs):
pipe, w_pipe = mp.Pipe()
socket, port = zmq_robust_bind_socket(self._zmq_context)

process = mp.Process(target=zmq_worker, args=(
w_pipe, pipe, port, CloudpickleWrapper((self.make_local_env, w_ind))
))
process.daemon = True
process.start()
self.processes[w_ind] = process

self._zmq_sockets[w_ind] = socket

pipe.send(('get_shared_memory', None))
self.shared_memories[w_ind] = pipe.recv()

# switch to zmq socket and close pipes
pipe.send(('switch_zmq', None))
pipe.close()
w_pipe.close()

logger.info("All remote envs started")

def _setup_reset(self):
for _, socket in self._zmq_sockets.items():
socket.send('reset'.encode())

all_obs = {}
all_rew = {}
all_info = {}
all_done = {}
for r_ind, remote in self._zmq_sockets.items():
result = remote.recv().decode()
self._check_for_errors(result, r_ind)
_, rew, done, info = json.loads(result)

all_obs[r_ind] = self._obs_to_numpy(self.shared_memories[r_ind])
# each keyed by agent_id in the env
all_rew[r_ind] = rew
all_info[r_ind] = info
all_done[r_ind] = done

return all_obs, all_rew, all_done, all_info

@staticmethod
def _obs_to_numpy(obs):
o = OrderedDict({})
for k, v in obs.items():
if isinstance(v, torch.Tensor):
o[k] = v.numpy()
# support double layer dict
elif isinstance(v, dict):
o[k] = OrderedDict({})
for nk, nv in v.items():
if isinstance(nv, dict):
dd = OrderedDict({})
for dk, dv in nv.items():
if isinstance(dv, dict):
raise NotImplementedError('Double Nested obs space dict not implemented')
else:
dd[dk] = dv.numpy()
o[k][nk] = dd
else:
o[k][nk] = nv.numpy()
else:
raise NotImplementedError('Unsupported obs type {}'.format(type(v)))
return o

def _check_for_errors(self, result, e_id):
if result[:5] == b'error':
error = 'Worker {} has an error {}'.format(e_id, result)
raise WorkerError(error)


class CloudpickleWrapper(object):
"""
Modified.
MIT License
Copyright (c) 2017 OpenAI (http://openai.com)
"""

def __init__(self, x):
self.x = x

def __getstate__(self):
return cloudpickle.dumps(self.x)

def __setstate__(self, ob):
self.x = pickle.loads(ob)


def zmq_robust_bind_socket(zmq_context):
try_count = 0
while try_count < 3:
try:
socket = zmq_context.socket(zmq.PAIR)
port = np.random.randint(5000, 30000)
if ZMQ_CONNECT_METHOD == 'tcp':
socket.bind("tcp://*:{}".format(port))
if ZMQ_CONNECT_METHOD == 'ipc':
os.makedirs('/dev/shm/remotezmq/', exist_ok=True)
socket.bind("ipc:///dev/shm/remotezmq/{}".format(port))
except zmq.error.ZMQError as e:
try_count += 1
socket = None
last_error = e
continue
break
if socket is None:
raise Exception("ZMQ couldn't bind socket after 3 tries. {}".format(last_error))
return socket, port

139 changes: 139 additions & 0 deletions rllib/env/zmq_remote_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from collections import OrderedDict
import gym
import json
import os
import logging
import traceback
import pickle

import cloudpickle
import numpy as np
import torch
import zmq
from torch import multiprocessing as mp

import ray
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN

logger = logging.getLogger(__name__)
ZMQ_CONNECT_METHOD = 'tcp'


def zmq_worker(remote, parent_remote, port, env_fn_wrapper):
"""
Modified.
MIT License
Copyright (c) 2017 OpenAI (http://openai.com)
"""
parent_remote.close()
env_fn, env_id = env_fn_wrapper.x
env = env_fn(env_id)

shared_memory = OrderedDict({})
for name, space in env.observation_space.spaces.items():
if isinstance(space, gym.spaces.Box):
if space.dtype != np.float32:
raise NotImplementedError('Type not implemented {}'.format(space.dtype))
tensor = torch.zeros(space.shape, dtype=torch.float32)
shared_memory[name] = tensor
elif isinstance(space, gym.spaces.dict.Dict):
sm = OrderedDict({})
for nk, ns in space.spaces.items():
if isinstance(ns, gym.spaces.Box):
if ns.dtype != np.float32:
raise NotImplementedError('Type not implemented {}'.format(ns.dtype))
tensor = torch.zeros(ns.shape, dtype=torch.float32)
elif isinstance(space, gym.spaces.dict.Dict):
raise NotImplementedError('Gym nested dict spaces not supported')
sm[nk] = tensor
shared_memory[name] = sm

# initial python pipe setup
python_pipe = True
while python_pipe:
cmd, _ = remote.recv()
if cmd == 'get_shared_memory':
remote.send(shared_memory)
elif cmd == 'switch_zmq':
# close python pipes
remote.close()
python_pipe = False
else:
raise NotImplementedError

# zmq setup
context = zmq.Context()
socket = context.socket(zmq.PAIR)
if ZMQ_CONNECT_METHOD == 'tcp':
socket.connect("tcp://localhost:{}".format(port))
if ZMQ_CONNECT_METHOD == 'ipc':
socket.connect("ipc:///dev/shm/remotezmq/{}".format(port))

running = True
while running:
try:
socket_data = socket.recv()
socket_parsed = socket_data.decode()

# commands that aren't action dictionaries
if socket_parsed == 'reset':
ob = env.reset()
# MUST return ob, reward, done, info
# TODO: should be different for not multi agent env
reward = {agent_id: 0 for agent_id in ob.keys()}
done = {"__all__": False}
info = {agent_id: {} for agent_id in ob.keys()}
# only the non-shared obs are returned here
non_shared_ob = handle_ob(ob, shared_memory)
msg = json.dumps((non_shared_ob, reward, done, info))
socket.send(msg.encode(), zmq.NOBLOCK, copy=False, track=False)
elif socket_parsed == 'close':
env.close()
running = False
# else action dictionary
else:
action_dictionary = json.loads(socket_parsed)
ob, reward, done, info = env.step(action_dictionary)
# Done ob handled by reset
# if isinstance(done, dict):
# if done['__all__']:
# ob = env.reset()
# elif done:
# ob = env.reset()
ob = handle_ob(ob, shared_memory)
# only the non-shared obs are returned here
msg = json.dumps((ob, reward, done, info))
socket.send(msg.encode(), zmq.NOBLOCK, copy=False, track=False)
except KeyboardInterrupt:
pass
except Exception as e:
running = False
e_str = '{}: {}'.format(type(e).__name__, e)
e_str += '\n' + traceback.format_exc()
print('Subprocess environment has an error', e_str)
socket.send('error. {}'.format(e_str).encode(), zmq.NOBLOCK, copy=False, track=False)


def handle_ob(ob, shared_memory):
non_shared = OrderedDict({})
for k, v in ob.items():
if isinstance(v, torch.Tensor):
shared_memory[k].copy_(v)
elif isinstance(v, np.ndarray):
shared_memory[k].copy_(torch.from_numpy(v))
# support double layer dict
elif isinstance(v, dict):
for nk, nv in v.items():
if isinstance(nv, dict):
dd = OrderedDict({})
for dk, dv in nv.items():
if isinstance(dv, dict):
raise NotImplementedError('Double Nested obs space dict not implemented')
else:
shared_memory[k][dk].copy_(torch.from_numpy(dv))
else:
shared_memory[k][nk].copy_(torch.from_numpy(nv))
else:
raise NotImplementedError('Unsupported obs type {}'.format(type(v)))
# non_shared[k] = v
return non_shared