Source code for btgym.algorithms.worker

#
# Original A3C code comes from OpenAI repository under MIT licence:
# https://github.com/openai/universe-starter-agent
#
# Papers:
# https://arxiv.org/abs/1602.01783
# https://arxiv.org/abs/1611.05397

from logbook import Logger, StreamHandler
import sys
import os
import random
import multiprocessing
import datetime

import tensorflow as tf

sys.path.insert(0, '..')
tf.logging.set_verbosity(tf.logging.INFO)


class FastSaver(tf.train.Saver):
    """
    Disables the `write_meta_graph` argument, which freezes the entire process
    and is mostly useless.
    """
    def save(
            self,
            sess,
            save_path,
            global_step=None,
            latest_filename=None,
            meta_graph_suffix="meta",
            write_meta_graph=True,
            write_state=True,
            strip_default_attrs=False
    ):
        super(FastSaver, self).save(
            sess,
            save_path,
            global_step,
            latest_filename,
            meta_graph_suffix,
            write_meta_graph=False,
        )
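
# Note (illustrative, not part of the original module): because FastSaver always forces
# `write_meta_graph=False`, checkpoints written by it contain variable values only.
# If a meta graph is ever needed (e.g. to restore the model without rebuilding the graph
# in code), it could be exported once by hand, roughly like this:
#
#     saver = FastSaver(var_list=variables_to_save)
#     saver.export_meta_graph('/tmp/model.meta')  # method inherited from tf.train.Saver
#
# Here `variables_to_save` and the output path are placeholders, not names defined at
# this point in the module.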
class Worker(multiprocessing.Process):
    """
    Distributed tf worker class.
    Sets up environment, trainer and starts training process in supervised session.
    """
    env_list = None

    def __init__(self,
                 env_config,
                 policy_config,
                 trainer_config,
                 cluster_spec,
                 job_name,
                 task,
                 log_dir,
                 log_ckpt_subdir,
                 initial_ckpt_dir,
                 save_secs,
                 log_level,
                 max_env_steps,
                 random_seed=None,
                 render_last_env=False,
                 test_mode=False):
        """
        Args:
            env_config:         environment class_config_dict.
            policy_config:      model policy estimator class_config_dict.
            trainer_config:     algorithm class_config_dict.
            cluster_spec:       tf.cluster specification.
            job_name:           worker or parameter server.
            task:               integer number, 0 is chief worker.
            log_dir:            path for tb summaries and current checkpoints.
            log_ckpt_subdir:    log_dir subdirectory to store current checkpoints.
            initial_ckpt_dir:   path for checkpoint to load as pre-trained model.
            save_secs:          int, save model checkpoint every N secs.
            log_level:          int, logbook.level
            max_env_steps:      number of environment steps to run training on.
            random_seed:        int or None
            render_last_env:    bool, if True - render enabled for last environment in a list; first otherwise.
            test_mode:          if True - use Atari mode, BTGym otherwise.

        Note:
            - Conventional `self.global_step` refers to the number of environment steps,
              summarized over all environment instances, not to the number of policy
              optimizer train steps.

            - Every worker can run several environments in parallel, as specified by
              `cluster_config['num_envs']`. If 4 workers are used and num_envs=4,
              the total number of environments is 16. Every env instance has its own
              ThreadRunner process.

            - When using replay memory, keep in mind that every ThreadRunner keeps its
              own replay memory. If memory_size=2000, num_workers=4, num_envs=4,
              the total replay memory size equals 32 000 frames.
        """
        super(Worker, self).__init__()
        self.env_class = env_config['class_ref']
        self.env_kwargs = env_config['kwargs']
        self.policy_config = policy_config
        self.trainer_class = trainer_config['class_ref']
        self.trainer_kwargs = trainer_config['kwargs']
        self.cluster_spec = cluster_spec
        self.job_name = job_name
        self.task = task
        self.is_chief = (self.task == 0)
        self.log_dir = log_dir
        self.save_secs = save_secs
        self.max_env_steps = max_env_steps
        self.log_level = log_level
        self.log = None
        self.test_mode = test_mode
        self.random_seed = random_seed
        self.render_last_env = render_last_env

        # Saver and summaries path:
        self.current_ckpt_dir = self.log_dir + log_ckpt_subdir
        self.initial_ckpt_dir = initial_ckpt_dir
        self.summary_dir = self.log_dir + '/worker_{}'.format(self.task)

        self.summary_writer = None
        self.config = None
        self.saver = None

    def _restore_model_params(self, sess, save_path):
        """
        Restores model parameters from specified location.

        Args:
            sess:       tf.Session obj.
            save_path:  path where parameters were previously saved.

        Returns:
            True if model has been successfully loaded, False otherwise.
        """
        if save_path is None:
            return False

        assert self.saver is not None, 'FastSaver has not been configured.'
        try:
            # Look for valid checkpoint:
            ckpt_state = tf.train.get_checkpoint_state(save_path)
            if ckpt_state is not None and ckpt_state.model_checkpoint_path:
                self.saver.restore(sess, ckpt_state.model_checkpoint_path)
            else:
                self.log.notice('no saved model parameters found in:\n{}'.format(save_path))
                return False

        except (ValueError, tf.errors.NotFoundError, tf.errors.InvalidArgumentError) as e:
            self.log.notice('failed to restore model parameters from:\n{}'.format(save_path))
            return False

        return True

    def _save_model_params(self, sess, global_step):
        """
        Saves model checkpoint to predefined location.

        Args:
            sess:           tf.Session obj.
            global_step:    global step number, appended to save_path to create the checkpoint filenames.
        """
        assert self.saver is not None, 'FastSaver has not been configured.'
        self.saver.save(
            sess,
            save_path=self.current_ckpt_dir + '/model_parameters',
            global_step=global_step
        )
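
    # Note (illustrative, not part of the original module): with the settings used here
    # (FastSaver with `max_to_keep=1`, prefix 'model_parameters'), a save at, e.g.,
    # global step 120000 is expected to leave files of roughly this form under
    # `current_ckpt_dir`:
    #
    #     checkpoint
    #     model_parameters-120000.index
    #     model_parameters-120000.data-00000-of-00001
    #
    # `_restore_model_params()` then locates the latest of these via
    # `tf.train.get_checkpoint_state()`.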
    def run(self):
        """Worker runtime body.
        """
        # Logging:
        StreamHandler(sys.stdout).push_application()
        self.log = Logger('Worker_{}'.format(self.task), level=self.log_level)
        try:
            tf.reset_default_graph()

            if self.test_mode:
                import gym

            # Define cluster:
            cluster = tf.train.ClusterSpec(self.cluster_spec).as_cluster_def()

            # Start tf.server:
            if self.job_name in 'ps':
                server = tf.train.Server(
                    cluster,
                    job_name=self.job_name,
                    task_index=self.task,
                    config=tf.ConfigProto(device_filters=["/job:ps"])
                )
                self.log.debug('parameters_server started.')

                # Just block here:
                server.join()

            else:
                server = tf.train.Server(
                    cluster,
                    job_name='worker',
                    task_index=self.task,
                    config=tf.ConfigProto(
                        intra_op_parallelism_threads=1,  # original was: 1
                        inter_op_parallelism_threads=2   # original was: 2
                    )
                )
                self.log.debug('tf.server started.')

                self.log.debug('making environments:')
                # Make as many environments as there are entries in the env_config `port` list:
                # TODO: Hacky-II: only one example over all parallel environments can be data-master [and renderer]
                # TODO: measure data_server lags, maybe launch several instances
                self.env_list = []
                env_kwargs = self.env_kwargs.copy()
                env_kwargs['log_level'] = self.log_level
                port_list = env_kwargs.pop('port')
                data_port_list = env_kwargs.pop('data_port')
                data_master = env_kwargs.pop('data_master')
                render_enabled = env_kwargs.pop('render_enabled')

                render_list = [False for entry in port_list]
                if render_enabled:
                    if self.render_last_env:
                        render_list[-1] = True
                    else:
                        render_list[0] = True

                data_master_list = [False for entry in port_list]
                if data_master:
                    data_master_list[0] = True

                # Parallel envs. numbering:
                if len(port_list) > 1:
                    task_id = 0.0
                else:
                    task_id = 0

                for port, data_port, is_render, is_master in zip(port_list, data_port_list, render_list, data_master_list):
                    # Get random seed for environments:
                    env_kwargs['random_seed'] = random.randint(0, 2 ** 30)

                    if not self.test_mode:
                        # Assume BTGym env. class:
                        self.log.debug('setting env at port_{} is data_master: {}'.format(port, data_master))
                        self.log.debug('env_kwargs:')
                        for k, v in env_kwargs.items():
                            self.log.debug('{}: {}'.format(k, v))
                        try:
                            self.env_list.append(
                                self.env_class(
                                    port=port,
                                    data_port=data_port,
                                    data_master=is_master,
                                    render_enabled=is_render,
                                    task=self.task + task_id,
                                    **env_kwargs
                                )
                            )
                            data_master = False
                            self.log.info(
                                'set BTGym environment {} @ port:{}, data_port:{}'.format(
                                    self.task + task_id, port, data_port
                                )
                            )
                            task_id += 0.01

                        except Exception as e:
                            self.log.exception(
                                'failed to make BTGym environment at port_{}.'.format(port)
                            )
                            raise e

                    else:
                        # Assume Atari testing:
                        try:
                            self.env_list.append(self.env_class(env_kwargs['gym_id']))
                            self.log.debug('set Gym/Atari environment.')

                        except Exception as e:
                            self.log.exception('failed to make Gym/Atari environment')
                            raise e

                self.log.debug('Defining trainer...')

                # Define trainer:
                trainer = self.trainer_class(
                    env=self.env_list,
                    task=self.task,
                    policy_config=self.policy_config,
                    log_level=self.log_level,
                    cluster_spec=self.cluster_spec,
                    random_seed=self.random_seed,
                    **self.trainer_kwargs
                )
                self.log.debug('trainer ok.')

                # Saver-related:
                variables_to_save = [v for v in tf.global_variables() if not 'local' in v.name]
                local_variables = [v for v in tf.global_variables() if 'local' in v.name] + tf.local_variables()
                init_op = tf.initializers.variables(variables_to_save)
                local_init_op = tf.initializers.variables(local_variables)
                init_all_op = tf.global_variables_initializer()

                def init_fn(_sess):
                    self.log.notice("initializing all parameters...")
                    _sess.run(init_all_op)

                self.saver = FastSaver(var_list=variables_to_save, max_to_keep=1, save_relative_paths=True)

                self.config = tf.ConfigProto(device_filters=["/job:ps", "/job:worker/task:{}/cpu:0".format(self.task)])

                sess_manager = tf.train.SessionManager(
                    local_init_op=local_init_op,
                    ready_op=None,
                    ready_for_local_init_op=tf.report_uninitialized_variables(variables_to_save),
                    graph=None,
                    recovery_wait_secs=90,
                )
                with sess_manager.prepare_session(
                        master=server.target,
                        init_op=init_op,
                        config=self.config,
                        init_fn=init_fn,
                ) as sess:

                    # Try to restore a pre-trained model:
                    pre_trained_restored = self._restore_model_params(sess, self.initial_ckpt_dir)
                    _ = sess.run(trainer.reset_global_step)

                    if not pre_trained_restored:
                        # If not, try to recover the current checkpoint:
                        current_restored = self._restore_model_params(sess, self.current_ckpt_dir)
                    else:
                        current_restored = False

                    if not pre_trained_restored and not current_restored:
                        self.log.notice('training from scratch...')

                    self.log.info("connecting to the parameter server... ")

                    self.summary_writer = tf.summary.FileWriter(self.summary_dir, sess.graph)
                    trainer.start(sess, self.summary_writer)

                    # Note: `self.global_step` refers to the number of environment steps,
                    # summarized over all environment instances, not to the number of
                    # policy optimizer train steps.
                    global_step = sess.run(trainer.global_step)
                    self.log.notice("started training at step: {}".format(global_step))

                    last_saved_time = datetime.datetime.now()
                    last_saved_step = global_step

                    while global_step < self.max_env_steps:
                        trainer.process(sess)
                        global_step = sess.run(trainer.global_step)

                        time_delta = datetime.datetime.now() - last_saved_time
                        if self.is_chief and time_delta.total_seconds() > self.save_secs:
                            self._save_model_params(sess, global_step)
                            train_speed = (global_step - last_saved_step) / (time_delta.total_seconds() + 1)
                            self.log.notice(
                                'train step: {}; cluster speed: {:.0f} step/sec; checkpoint saved.'.format(
                                    global_step,
                                    train_speed
                                )
                            )
                            last_saved_time = datetime.datetime.now()
                            last_saved_step = global_step

                    # Ask for all the services to stop:
                    for env in self.env_list:
                        env.close()

                    self.log.notice('reached {} steps, exiting.'.format(global_step))

        except Exception as e:
            self.log.exception(e)
            raise e
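
# Usage sketch (illustrative only, not part of the original module): in practice Worker
# processes are built and started by the package launcher rather than by hand, but a
# minimal manual setup would look roughly like the following. The class references,
# kwargs, ports and cluster layout below are placeholders/assumptions, not values taken
# from this module; only the `class_ref`/`kwargs` config structure is implied by the
# constructor above.
#
#     if __name__ == '__main__':
#         from logbook import INFO
#         from btgym import BTgymEnv
#         from btgym.algorithms import A3C  # assumed trainer class
#
#         cluster_spec = {'ps': ['localhost:12230'], 'worker': ['localhost:12231']}
#
#         env_config = dict(
#             class_ref=BTgymEnv,
#             kwargs=dict(
#                 filename='./data/some_data.csv',   # placeholder dataset
#                 port=[5000],
#                 data_port=[4999],
#                 data_master=True,
#                 render_enabled=False,
#             )
#         )
#         trainer_config = dict(class_ref=A3C, kwargs={})
#
#         worker = Worker(
#             env_config=env_config,
#             policy_config={},                      # model policy estimator class_config_dict
#             trainer_config=trainer_config,
#             cluster_spec=cluster_spec,
#             job_name='worker',
#             task=0,
#             log_dir='./tmp/test_log',
#             log_ckpt_subdir='/current_train_checkpoint',
#             initial_ckpt_dir=None,
#             save_secs=300,
#             log_level=INFO,
#             max_env_steps=10 ** 6,
#         )
#         worker.start()
#         worker.join()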