from tensorflow.contrib.layers import flatten as batch_flatten
from btgym.algorithms.policy.base import BaseAacPolicy
from btgym.algorithms.nn.networks import *
from btgym.algorithms.utils import *
from btgym.spaces import DictSpace, ActionDictSpace
class StackedLstmPolicy(BaseAacPolicy):
"""
Conv.-Stacked_LSTM policy, based on `NAV A3C agent` architecture from
`LEARNING TO NAVIGATE IN COMPLEX ENVIRONMENTS` by Mirowski et al. and
`LEARNING TO REINFORCEMENT LEARN` by JX Wang et al.
Papers:
https://arxiv.org/pdf/1611.03673.pdf
https://arxiv.org/pdf/1611.05763.pdf
"""
def __init__(self,
ob_space,
ac_space,
rp_sequence_size,
state_encoder_class_ref=conv_2d_network,
lstm_class_ref=tf.contrib.rnn.LayerNormBasicLSTMCell,
lstm_layers=(256, 256),
linear_layer_ref=noisy_linear,
share_encoder_params=False,
dropout_keep_prob=1.0,
action_dp_alpha=200.0,
aux_estimate=False,
encode_internal_state=False,
static_rnn=True,
shared_p_v=False,
**kwargs):
"""
Defines [partially shared] on/off-policy networks for estimating action-logits, value function,
reward and state 'pixel_change' predictions.
Expects multi-modal observation as array of shape `ob_space`.
Args:
ob_space: instance of btgym.spaces.DictSpace
ac_space: instance of btgym.spaces.ActionDictSpace
rp_sequence_size: reward prediction sample length
state_encoder_class_ref: state encoder network constructor to use (default: conv_2d_network)
lstm_class_ref: LSTM cell class to use (default: tf.contrib.rnn.LayerNormBasicLSTMCell)
lstm_layers: tuple of LSTM layers sizes
linear_layer_ref: linear layer class to use
share_encoder_params: bool, whether to share encoder parameters for every 'external' data stream
dropout_keep_prob: in (0, 1] dropout regularisation parameter
action_dp_alpha:
aux_estimate: (bool), if True - add auxiliary tasks estimates to self.callback dictionary
encode_internal_state: use encoder over 'internal' part of observation space
static_rnn: (bool), if True - use static rnn graph, dynamic otherwise
**kwargs: passed down to state encoder and LSTM network constructors
"""
assert isinstance(ob_space, DictSpace), \
'Expected observation space to be an instance of btgym.spaces.DictSpace, got: {}'.format(ob_space)
self.ob_space = ob_space
assert isinstance(ac_space, ActionDictSpace), \
'Expected action space to be an instance of btgym.spaces.ActionDictSpace, got: {}'.format(ac_space)
self.ac_space = ac_space
self.rp_sequence_size = rp_sequence_size
self.state_encoder_class_ref = state_encoder_class_ref
self.lstm_class = lstm_class_ref
self.lstm_layers = lstm_layers
self.action_dp_alpha = action_dp_alpha
self.aux_estimate = aux_estimate
self.callback = {}
# self.encode_internal_state = encode_internal_state
self.share_encoder_params = share_encoder_params
if self.share_encoder_params:
self.reuse_encoder_params = tf.AUTO_REUSE
else:
self.reuse_encoder_params = False
self.static_rnn = static_rnn
self.dropout_keep_prob = dropout_keep_prob
assert 0 < self.dropout_keep_prob <= 1, 'Dropout keep_prob value should be in (0, 1]'
self.debug = {}
# Placeholders for obs. state input:
self.on_state_in = nested_placeholders(self.ob_space.shape, batch_dim=None, name='on_policy_state_in')
self.off_state_in = nested_placeholders(self.ob_space.shape, batch_dim=None, name='off_policy_state_in_pl')
self.rp_state_in = nested_placeholders(self.ob_space.shape, batch_dim=None, name='rp_state_in')
# Placeholders for previous step action[multi-categorical vector encoding] and reward [scalar]:
self.on_last_a_in = tf.placeholder(
tf.float32,
[None, self.ac_space.encoded_depth],
name='on_policy_last_action_in_pl'
)
self.on_last_reward_in = tf.placeholder(tf.float32, [None], name='on_policy_last_reward_in_pl')
self.off_last_a_in = tf.placeholder(
tf.float32,
[None, self.ac_space.encoded_depth],
name='off_policy_last_action_in_pl'
)
self.off_last_reward_in = tf.placeholder(tf.float32, [None], name='off_policy_last_reward_in_pl')
# Placeholders for rnn batch and time-step dimensions:
self.on_batch_size = tf.placeholder(tf.int32, name='on_policy_batch_size')
self.on_time_length = tf.placeholder(tf.int32, name='on_policy_sequence_size')
self.off_batch_size = tf.placeholder(tf.int32, name='off_policy_batch_size')
self.off_time_length = tf.placeholder(tf.int32, name='off_policy_sequence_size')
self.debug['on_state_in_keys'] = list(self.on_state_in.keys())
# Dropout related:
try:
if self.train_phase is not None:
pass
except AttributeError:
self.train_phase = tf.placeholder_with_default(
tf.constant(False, dtype=tf.bool),
shape=(),
name='train_phase_flag_pl'
)
self.keep_prob = 1.0 - (1.0 - self.dropout_keep_prob) * tf.cast(self.train_phase, tf.float32)
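# Worked example: with dropout_keep_prob=0.8 the effective keep probability is
# 1 - (1 - 0.8) * 1.0 = 0.8 when train_phase is True, and
# 1 - (1 - 0.8) * 0.0 = 1.0 (dropout disabled) when train_phase is False.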
# Default parameters:
default_kwargs = dict(
conv_2d_filter_size=[3, 1],
conv_2d_stride=[2, 1],
conv_2d_num_filters=[32, 32, 64, 64],
pc_estimator_stride=[2, 1],
duell_pc_x_inner_shape=(6, 1, 32), # [6,3,32] if swapping W-C dims
duell_pc_filter_size=(4, 1),
duell_pc_stride=(2, 1),
keep_prob=self.keep_prob,
)
# Insert if not already:
for key, default_value in default_kwargs.items():
if key not in kwargs.keys():
kwargs[key] = default_value
# Base on-policy AAC network:
self.modes_to_encode = ['external', 'internal']
for mode in self.modes_to_encode:
assert mode in self.on_state_in.keys(), \
'Required top-level mode `{}` not found in state shape specification'.format(mode)
# Separately encode, then concatenate all `external` and `internal` state modes,
# [jointly] encoding every stream within a mode:
self.on_aac_x_encoded = {}
for key in self.modes_to_encode:
if isinstance(self.on_state_in[key], dict): # got dictionary of data streams
if self.share_encoder_params:
layer_name_template = 'encoded_{}_shared'
else:
layer_name_template = 'encoded_{}_{}'
encoded_streams = {
name: tf.layers.flatten(
self.state_encoder_class_ref(
x=stream,
ob_space=self.ob_space.shape[key][name],
ac_space=self.ac_space,
name=layer_name_template.format(key, name),
reuse=self.reuse_encoder_params, # shared params for all streams in mode
**kwargs
)
)
for name, stream in self.on_state_in[key].items()
}
encoded_mode = tf.concat(
list(encoded_streams.values()),
axis=-1,
name='multi_encoded_{}'.format(key)
)
else:
# Got single data stream:
encoded_mode = tf.layers.flatten(
self.state_encoder_class_ref(
x=self.on_state_in[key],
ob_space=self.ob_space.shape[key],
ac_space=self.ac_space,
name='encoded_{}'.format(key),
**kwargs
)
)
self.on_aac_x_encoded[key] = encoded_mode
self.debug['on_state_external_encoded_dict'] = self.on_aac_x_encoded
# on_aac_x = tf.concat(list(self.on_aac_x_encoded.values()), axis=-1, name='on_state_external_encoded')
on_aac_x = self.on_aac_x_encoded['external']
self.debug['on_state_external_encoded'] = on_aac_x
# TODO: for encoder prediction test, output `naive` estimates for logits and value directly from encoder:
[self.on_simple_logits, self.on_simple_value, _] = dense_aac_network(
tf.layers.flatten(on_aac_x),
ac_space_depth=self.ac_space.one_hot_depth,
linear_layer_ref=linear_layer_ref,
name='aac_dense_simple_pi_v'
)
# Reshape rnn inputs for batch training as: [rnn_batch_dim, rnn_time_dim, flattened_depth]:
x_shape_dynamic = tf.shape(on_aac_x)
max_seq_len = tf.cast(x_shape_dynamic[0] / self.on_batch_size, tf.int32)
x_shape_static = on_aac_x.get_shape().as_list()
on_last_action_in = tf.reshape(
self.on_last_a_in,
[self.on_batch_size, max_seq_len, self.ac_space.encoded_depth]
)
on_last_r_in = tf.reshape(self.on_last_reward_in, [self.on_batch_size, max_seq_len, 1])
on_aac_x = tf.reshape(on_aac_x, [self.on_batch_size, max_seq_len, np.prod(x_shape_static[1:])])
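# Shape sketch (hypothetical numbers): if the incoming batch holds 2 rollouts of 20 steps
# each, on_aac_x enters as [40, depth]; max_seq_len = 40 / on_batch_size(=2) = 20,
# so it is reshaped to [2, 20, depth] for the RNN pass.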
# # Prepare `internal` state, if any:
# if 'internal' in list(self.on_state_in.keys()):
# if self.encode_internal_state:
# # Use convolution encoder:
# on_x_internal = self.state_encoder_class_ref(
# x=self.on_state_in['internal'],
# ob_space=self.ob_space.shape['internal'],
# ac_space=self.ac_space,
# name='encoded_internal',
# **kwargs
# )
# x_int_shape_static = on_x_internal.get_shape().as_list()
# on_x_internal = [
# tf.reshape(on_x_internal, [self.on_batch_size, max_seq_len, np.prod(x_int_shape_static[1:])])]
# self.debug['on_state_internal_encoded'] = on_x_internal
#
# else:
# # Feed as is:
# x_int_shape_static = self.on_state_in['internal'].get_shape().as_list()
# on_x_internal = tf.reshape(
# self.on_state_in['internal'],
# [self.on_batch_size, max_seq_len, np.prod(x_int_shape_static[1:])]
# )
# self.debug['on_state_internal_encoded'] = on_x_internal
# on_x_internal = [on_x_internal]
#
# else:
# on_x_internal = []
on_x_internal = self.on_aac_x_encoded['internal']
# Reshape to batch-feed rnn:
x_int_shape_static = on_x_internal.get_shape().as_list()
on_x_internal = tf.reshape(
on_x_internal,
[self.on_batch_size, max_seq_len, np.prod(x_int_shape_static[1:])]
)
self.debug['on_state_internal_encoded'] = on_x_internal
on_x_internal = [on_x_internal]
# Prepare datetime index if any:
if 'datetime' in list(self.on_state_in.keys()):
x_dt_shape_static = self.on_state_in['datetime'].get_shape().as_list()
on_x_dt = tf.reshape(
self.on_state_in['datetime'],
[self.on_batch_size, max_seq_len, np.prod(x_dt_shape_static[1:])]
)
on_x_dt = [on_x_dt]
else:
on_x_dt = []
self.debug['on_state_dt_encoded'] = on_x_dt
self.debug['conv_input_to_lstm1'] = on_aac_x
# Feed last_reward into LSTM_1 layer along with encoded `external` state features and datetime stamp:
# on_stage2_1_input = [on_aac_x, on_last_action_in, on_last_reward_in] + on_x_dt
on_stage2_1_input = [on_aac_x, on_last_r_in] #+ on_x_dt
# Feed last_action, encoded `external` state, `internal` state, datetime stamp into LSTM_2:
# on_stage2_2_input = [on_aac_x, on_last_action_in, on_last_reward_in] + on_x_internal + on_x_dt
on_stage2_2_input = [on_aac_x, on_last_action_in] + on_x_internal #+ on_x_dt
# LSTM_1 full input:
on_aac_x = tf.concat(on_stage2_1_input, axis=-1)
self.debug['concat_input_to_lstm1'] = on_aac_x
# First LSTM layer takes encoded `external` state:
[on_x_lstm_1_out, self.on_lstm_1_init_state, self.on_lstm_1_state_out, self.on_lstm_1_state_pl_flatten] =\
lstm_network(
x=on_aac_x,
lstm_sequence_length=self.on_time_length,
lstm_class=lstm_class_ref,
lstm_layers=(lstm_layers[0],),
static=static_rnn,
name='lstm_1',
**kwargs,
)
# var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
# print('var_list: ', var_list)
self.debug['on_x_lstm_1_out'] = on_x_lstm_1_out
self.debug['self.on_lstm_1_state_out'] = self.on_lstm_1_state_out
self.debug['self.on_lstm_1_state_pl_flatten'] = self.on_lstm_1_state_pl_flatten
# For time_flat only: Reshape on_lstm_1_state_out from [1,2,20,size] -->[20,1,2,size] --> [20,1, 2xsize]:
reshape_lstm_1_state_out = tf.transpose(self.on_lstm_1_state_out, [2, 0, 1, 3])
reshape_lstm_1_state_out_shape_static = reshape_lstm_1_state_out.get_shape().as_list()
# Take policy logits off first LSTM-dense layer:
# Reshape back to [batch, flattened_depth], where batch = rnn_batch_dim * rnn_time_dim:
x_shape_static = on_x_lstm_1_out.get_shape().as_list()
rsh_on_x_lstm_1_out = tf.reshape(on_x_lstm_1_out, [x_shape_dynamic[0], x_shape_static[-1]])
self.debug['reshaped_on_x_lstm_1_out'] = rsh_on_x_lstm_1_out
if not shared_p_v:
# Aac policy output and action-sampling function:
[self.on_logits, _, self.on_sample] = dense_aac_network(
rsh_on_x_lstm_1_out,
ac_space_depth=self.ac_space.one_hot_depth,
linear_layer_ref=linear_layer_ref,
name='aac_dense_pi'
)
# Second LSTM layer takes concatenated encoded 'external' state, LSTM_1 output,
# last_action and `internal_state` (if present) tensors:
on_stage2_2_input += [on_x_lstm_1_out]
# Try: feed context instead of output
#on_stage2_2_input = [reshape_lstm_1_state_out] + on_stage2_1_input
# LSTM_2 full input:
on_aac_x = tf.concat(on_stage2_2_input, axis=-1)
self.debug['on_stage2_2_input'] = on_aac_x
[on_x_lstm_2_out, self.on_lstm_2_init_state, self.on_lstm_2_state_out, self.on_lstm_2_state_pl_flatten] = \
lstm_network(
x=on_aac_x,
lstm_sequence_length=self.on_time_length,
lstm_class=lstm_class_ref,
lstm_layers=(lstm_layers[-1],),
static=static_rnn,
name='lstm_2',
**kwargs,
)
self.debug['on_x_lstm_2_out'] = on_x_lstm_2_out
self.debug['self.on_lstm_2_state_out'] = self.on_lstm_2_state_out
self.debug['self.on_lstm_2_state_pl_flatten'] = self.on_lstm_2_state_pl_flatten
# Reshape back to [batch, flattened_depth], where batch = rnn_batch_dim * rnn_time_dim:
x_shape_static = on_x_lstm_2_out.get_shape().as_list()
rsh_on_x_lstm_2_out = tf.reshape(on_x_lstm_2_out, [x_shape_dynamic[0], x_shape_static[-1]])
self.debug['reshaped_on_x_lstm_2_out'] = rsh_on_x_lstm_2_out
if shared_p_v:
[self.on_logits, self.on_vf, self.on_sample] = dense_aac_network(
rsh_on_x_lstm_2_out,
ac_space_depth=self.ac_space.one_hot_depth,
linear_layer_ref=linear_layer_ref,
name='aac_dense_pi_vfn'
)
else:
# Aac value function:
[_, self.on_vf, _] = dense_aac_network(
rsh_on_x_lstm_2_out,
ac_space_depth=self.ac_space.one_hot_depth,
linear_layer_ref=linear_layer_ref,
name='aac_dense_vfn'
)
# Concatenate LSTM placeholders, init. states and context:
self.on_lstm_init_state = (self.on_lstm_1_init_state, self.on_lstm_2_init_state)
self.on_lstm_state_out = (self.on_lstm_1_state_out, self.on_lstm_2_state_out)
self.on_lstm_state_pl_flatten = self.on_lstm_1_state_pl_flatten + self.on_lstm_2_state_pl_flatten
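# Off-policy pass: the same encoders and LSTM stack are applied to the off-policy batch
# with reuse=True, so no new trainable variables are created below.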
self.off_aac_x_encoded = {}
for key in self.modes_to_encode:
if isinstance(self.off_state_in[key], dict): # got dictionary of data streams
if self.share_encoder_params:
layer_name_template = 'encoded_{}_shared'
else:
layer_name_template = 'encoded_{}_{}'
encoded_streams = {
name: tf.layers.flatten(
self.state_encoder_class_ref(
x=stream,
ob_space=self.ob_space.shape[key][name],
ac_space=self.ac_space,
name=layer_name_template.format(key, name),
reuse=True, # shared params for all streams in mode
**kwargs
)
)
for name, stream in self.off_state_in[key].items()
}
encoded_mode = tf.concat(
list(encoded_streams.values()),
axis=-1,
name='multi_encoded_{}'.format(key)
)
else:
# Got single data stream:
encoded_mode = tf.layers.flatten(
self.state_encoder_class_ref(
x=self.off_state_in[key],
ob_space=self.ob_space.shape[key],
ac_space=self.ac_space,
name='encoded_{}'.format(key),
reuse=True,
**kwargs
)
)
self.off_aac_x_encoded[key] = encoded_mode
# off_aac_x = tf.concat(list(self.off_aac_x_encoded.values()), axis=-1, name='off_state_external_encoded')
off_aac_x = self.off_aac_x_encoded['external']
# Reshape rnn inputs for batch training as [rnn_batch_dim, rnn_time_dim, flattened_depth]:
x_shape_dynamic = tf.shape(off_aac_x)
max_seq_len = tf.cast(x_shape_dynamic[0] / self.off_batch_size, tf.int32)
x_shape_static = off_aac_x.get_shape().as_list()
off_last_action_in = tf.reshape(
self.off_last_a_in,
[self.off_batch_size, max_seq_len, self.ac_space.encoded_depth]
)
off_last_r_in = tf.reshape(self.off_last_reward_in, [self.off_batch_size, max_seq_len, 1])
off_aac_x = tf.reshape( off_aac_x, [self.off_batch_size, max_seq_len, np.prod(x_shape_static[1:])])
# # Prepare `internal` state, if any:
# if 'internal' in list(self.off_state_in.keys()):
# if self.encode_internal_state:
# # Use convolution encoder:
# off_x_internal = self.state_encoder_class_ref(
# x=self.off_state_in['internal'],
# ob_space=self.ob_space.shape['internal'],
# ac_space=self.ac_space,
# name='encoded_internal',
# reuse=True,
# **kwargs
# )
# x_int_shape_static = off_x_internal.get_shape().as_list()
# off_x_internal = [
# tf.reshape(off_x_internal, [self.off_batch_size, max_seq_len, np.prod(x_int_shape_static[1:])])
# ]
# else:
# x_int_shape_static = self.off_state_in['internal'].get_shape().as_list()
# off_x_internal = tf.reshape(
# self.off_state_in['internal'],
# [self.off_batch_size, max_seq_len, np.prod(x_int_shape_static[1:])]
# )
# off_x_internal = [off_x_internal]
#
# else:
# off_x_internal = []
off_x_internal = self.off_aac_x_encoded['internal']
x_int_shape_static = off_x_internal.get_shape().as_list()
# Properly feed LSTM2:
off_x_internal = tf.reshape(
off_x_internal,
[self.off_batch_size, max_seq_len, np.prod(x_int_shape_static[1:])]
)
off_x_internal = [off_x_internal]
if 'datetime' in list(self.off_state_in.keys()):
x_dt_shape_static = self.off_state_in['datetime'].get_shape().as_list()
off_x_dt = tf.reshape(
self.off_state_in['datetime'],
[self.off_batch_size, max_seq_len, np.prod(x_dt_shape_static[1:])]
)
off_x_dt = [off_x_dt]
else:
off_x_dt = []
# off_stage2_1_input = [off_aac_x, off_last_action_in, off_last_reward_in] + off_x_dt
off_stage2_1_input = [off_aac_x, off_last_r_in] # + off_x_dt
# off_stage2_2_input = [off_aac_x, off_last_action_in, off_last_reward_in] + off_x_internal + off_x_dt
off_stage2_2_input = [off_aac_x, off_last_action_in] + off_x_internal # + off_x_dt
off_aac_x = tf.concat(off_stage2_1_input, axis=-1)
[off_x_lstm_1_out, _, _, self.off_lstm_1_state_pl_flatten] =\
lstm_network(
off_aac_x,
self.off_time_length,
lstm_class_ref,
(lstm_layers[0],),
name='lstm_1',
static=static_rnn,
reuse=True,
**kwargs,
)
# Reshape back to [batch, flattened_depth], where batch = rnn_batch_dim * rnn_time_dim:
x_shape_static = off_x_lstm_1_out.get_shape().as_list()
rsh_off_x_lstm_1_out = tf.reshape(off_x_lstm_1_out, [x_shape_dynamic[0], x_shape_static[-1]])
if not shared_p_v:
[self.off_logits, _, _] =\
dense_aac_network(
rsh_off_x_lstm_1_out,
ac_space_depth=self.ac_space.one_hot_depth,
linear_layer_ref=linear_layer_ref,
name='aac_dense_pi',
reuse=True
)
off_stage2_2_input += [off_x_lstm_1_out]
# LSTM_2 full input:
off_aac_x = tf.concat(off_stage2_2_input, axis=-1)
[off_x_lstm_2_out, _, _, self.off_lstm_2_state_pl_flatten] = \
lstm_network(
off_aac_x,
self.off_time_length,
lstm_class_ref,
(lstm_layers[-1],),
name='lstm_2',
static=static_rnn,
reuse=True,
**kwargs,
)
# Reshape back to [batch, flattened_depth], where batch = rnn_batch_dim * rnn_time_dim:
x_shape_static = off_x_lstm_2_out.get_shape().as_list()
rsh_off_x_lstm_2_out = tf.reshape(off_x_lstm_2_out, [x_shape_dynamic[0], x_shape_static[-1]])
if shared_p_v:
[self.off_logits, self.off_vf, _] = dense_aac_network(
rsh_off_x_lstm_2_out,
ac_space_depth=self.ac_space.one_hot_depth,
linear_layer_ref=linear_layer_ref,
name='aac_dense_pi_vfn',
reuse=True
)
else:
# Aac value function:
[_, self.off_vf, _] = dense_aac_network(
rsh_off_x_lstm_2_out,
ac_space_depth=self.ac_space.one_hot_depth,
linear_layer_ref=linear_layer_ref,
name='aac_dense_vfn',
reuse=True
)
# Concatenate LSTM states:
self.off_lstm_state_pl_flatten = self.off_lstm_1_state_pl_flatten + self.off_lstm_2_state_pl_flatten
if False: # TEMP DISABLE
# Aux1:
# `Pixel control` network.
#
# Define pixels-change estimation function:
# Yes, it is rather env-specific, but for the Atari case it is handy to do it here, see self.get_pc_target():
[self.pc_change_state_in, self.pc_change_last_state_in, self.pc_target] =\
pixel_change_2d_estimator(ob_space['external'], **kwargs)
self.pc_batch_size = self.off_batch_size
self.pc_time_length = self.off_time_length
self.pc_state_in = self.off_state_in
self.pc_a_r_in = self.off_a_r_in
self.pc_lstm_state_pl_flatten = self.off_lstm_state_pl_flatten
# Shared conv and lstm nets, same off-policy batch:
pc_x = rsh_off_x_lstm_2_out
# PC duelling Q-network, outputs [None, 20, 20, ac_size] Q-features tensor:
self.pc_q = duelling_pc_network(pc_x, self.ac_space, linear_layer_ref=linear_layer_ref, **kwargs)
# Aux2:
# `Value function replay` network.
#
# VR network is fully shared with ppo network but with `value` only output:
# and has same off-policy batch pass with off_ppo network:
self.vr_batch_size = self.off_batch_size
self.vr_time_length = self.off_time_length
self.vr_state_in = self.off_state_in
self.vr_last_a_in = self.off_last_a_in
self.vr_last_reward_in = self.off_last_reward_in
self.vr_lstm_state_pl_flatten = self.off_lstm_state_pl_flatten
self.vr_value = self.off_vf
# Aux3:
# `Reward prediction` network.
self.rp_batch_size = tf.placeholder(tf.int32, name='rp_batch_size')
# Shared encoded output:
rp_x = {}
for key in self.rp_state_in.keys():
if 'external' in key:
if isinstance(self.rp_state_in[key], dict): # got dictionary of data streams
if self.share_encoder_params:
layer_name_template = 'encoded_{}_shared'
else:
layer_name_template = 'encoded_{}_{}'
encoded_streams = {
name: tf.layers.flatten(
self.state_encoder_class_ref(
x=stream,
ob_space=self.ob_space.shape[key][name],
ac_space=self.ac_space,
name=layer_name_template.format(key, name),
reuse=True, # shared params for all streams in mode
**kwargs
)
)
for name, stream in self.rp_state_in[key].items()
}
encoded_mode = tf.concat(
list(encoded_streams.values()),
axis=-1,
name='multi_encoded_{}'.format(key)
)
else:
# Got single data stream:
encoded_mode = tf.layers.flatten(
self.state_encoder_class_ref(
x=self.rp_state_in[key],
ob_space=self.ob_space.shape[key],
ac_space=self.ac_space,
name='encoded_{}'.format(key),
reuse=True,
**kwargs
)
)
rp_x[key] = encoded_mode
rp_x = tf.concat(list(rp_x.values()), axis=-1, name='rp_state_external_encoded')
# Flatten batch-wise:
rp_x_shape_static = rp_x.get_shape().as_list()
rp_x = tf.reshape(rp_x, [self.rp_batch_size, np.prod(rp_x_shape_static[1:]) * (self.rp_sequence_size-1)])
# RP output:
self.rp_logits = dense_rp_network(rp_x, linear_layer_ref=linear_layer_ref)
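# Shape sketch (hypothetical numbers): with rp_sequence_size=4 the reward prediction
# branch consumes 3 consecutive encoded frames per sample; if each encodes to depth D,
# rp_x is flattened from [rp_batch_size * 3, D] to [rp_batch_size, 3 * D] before the dense head.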
# Batch-norm related:
self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# Add moving averages to save list:
moving_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, tf.get_variable_scope().name + '.*moving.*')
renorm_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, tf.get_variable_scope().name + '.*renorm.*')
# What to save:
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
self.var_list += moving_var_list + renorm_var_list
# Callbacks:
if self.aux_estimate:
pass
# TEMP DISABLE: due to computation costs
# do not use pixel change aux. task; otherwise enable lines 533 - 553 & 640:
if False:
self.callback['pixel_change'] = self.get_pc_target
# print('policy_debug_dict:\n', self.debug)
class AacStackedRL2Policy(StackedLstmPolicy):
"""
Attempt to implement two-level RL^2.
This policy class, in conjunction with DataDomain classes from btgym.datafeed,
is aimed at implementing the RL^2 algorithm by Duan et al.
Paper:
`FAST REINFORCEMENT LEARNING VIA SLOW REINFORCEMENT LEARNING`,
https://arxiv.org/pdf/1611.02779.pdf
The only difference from the base policy is the `get_initial_features()` method, which has been changed
to either reset the RNN context to zero-state or return the context from the end of the previous episode,
depending on the episode metadata received and the `lstm_2_init_period` parameter.
"""
def __init__(self, lstm_2_init_period=50, **kwargs):
super(AacStackedRL2Policy, self).__init__(**kwargs)
self.current_trial_num = -1 # always give initial context at first call
self.lstm_2_init_period = lstm_2_init_period
self.current_ep_num = 0
def get_initial_features(self, state, context=None):
"""
Returns RNN initial context.
RNN_1 (lower) context is reset at every call.
RNN_2 (upper) context is reset:
- every `lstm_2_init_period` episodes;
- episode initial `state` `trial_num` metadata has changed from last call (new train trial started);
- episode metadata `type` is non-zero (test episode);
- no context arg is provided (initial episode of training);
- ... else context is carried on to the new episode;
Episode metadata are provided by DataTrialIterator, which shapes the Trial data distribution in this case,
and delivered through env.strategy as a separate key in the observation dictionary.
Args:
state: initial episode state (result of env.reset())
context: last previous episode RNN state (last_context of runner)
Returns:
Tuple of two RNN contexts (zero-state, or carried over from the previous episode).
Raises:
KeyError if [`metadata`]:[`trial_num`,`type`] keys not found
"""
try:
sess = tf.get_default_session()
new_context = list(sess.run(self.on_lstm_init_state))
if state['metadata']['trial_num'] != self.current_trial_num\
or context is None\
or state['metadata']['type']\
or self.current_ep_num % self.lstm_2_init_period == 0:
# Assume new/initial trial or test sample: reset both LSTM_1 and LSTM_2 context:
pass #print('RL^2 policy context 1, 2 reset')
else:
# Assume same training trial, keep context_2 same as received:
new_context[-1] = context[-1]
#print('RL^2 policy context 1, reset')
# Back to tuple:
new_context = tuple(new_context)
# Keep trial number:
self.current_trial_num = state['metadata']['trial_num']
except KeyError:
raise KeyError(
'RL^2 policy: expected observation state dict. to have keys [`metadata`]:[`trial_num`,`type`]; got: {}'.
format(state.keys())
)
self.current_ep_num += 1
return new_context
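# Usage sketch (assumptions: `env` and `last_context` come from an A3C-style runner loop;
# names are illustrative, not a verified runner API):
#
#   state = env.reset()
#   context = policy.get_initial_features(state=state, context=last_context)
#   # `context` is then fed as the initial RNN state for the new episode rollout.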