Source code for btgym.research.casual_conv.networks

import tensorflow as tf
from tensorflow.contrib.layers import layer_norm as norm_layer

import numpy as np
import math

from btgym.algorithms.nn.layers import conv1d

def conv_1d_casual_encoder(
        x,
        ob_space,
        ac_space,
        conv_1d_num_filters=32,
        conv_1d_filter_size=2,
        conv_1d_activation=tf.nn.elu,
        conv_1d_overlap=1,
        name='casual_encoder',
        keep_prob=None,
        conv_1d_gated=False,
        reuse=False,
        collections=None,
        **kwargs
):
    """
    Tree-shaped convolution stack encoder as a more computationally efficient
    alternative to a dilated one.

    Stage1 casual convolutions network: from 1D input to estimated features.

    Every layer left-pads the time dimension up to a multiple of
    `conv_1d_filter_size`, folds each window of `conv_1d_filter_size` steps
    into the batch dimension ('t2b'), convolves every window down to a single
    step and unfolds the result back ('b2t'). The number of layers is
    floor(log(time_dim, conv_1d_filter_size)).

    Args:
        x:                      input tensor; indexed as [batch, time, channels],
                                or [batch, time, 1, channels] in which case the
                                pseudo-2d dimension is dropped;
        ob_space:               not used by this function;
        ac_space:               not used by this function;
        conv_1d_num_filters:    number of output filters of every layer;
        conv_1d_filter_size:    size of convolution window, should be >= 2;
        conv_1d_activation:     activation callable applied after normalisation, or None;
        conv_1d_overlap:        subtracted from filter size to get convolution stride,
                                also sets how many trailing steps of a layer are kept;
        name:                   variable scope name;
        keep_prob:              dropout keep probability, or None to skip dropout;
        conv_1d_gated:          if True, kept outputs are split in two halves along
                                channels and combined as `first * sigmoid(second)`;
        reuse:                  passed to tf.variable_scope;
        collections:            not used by this function;
        **kwargs:               not used by this function;

    Returns:
        tensor holding state features: trailing
        max(1, conv_1d_overlap // conv_1d_filter_size ** i) time steps of every
        layer i, concatenated along axis 1;

    Raises:
        ValueError:         if filter size is less than 2, if time dimension is not
                            statically known or is shorter than filter size;
        AssertionError:     if conv_1d_filter_size is not greater than conv_1d_overlap.
    """
    if conv_1d_filter_size < 2:
        raise ValueError(
            'conv_1d_filter_size should be >= 2, got: {}'.format(conv_1d_filter_size)
        )
    # Loop-invariant, so it is checked once up front:
    stride = conv_1d_filter_size - conv_1d_overlap
    assert stride > 0, 'conv_1d_filter_size should be greater then conv_1d_overlap'

    with tf.variable_scope(name_or_scope=name, reuse=reuse):
        shape = x.get_shape().as_list()
        if len(shape) > 3:  # remove pseudo 2d dimension
            x = x[:, :, 0, :]

        time_dim = shape[1]
        if time_dim is None:
            raise ValueError('Time dimension of input should be statically defined.')

        # Integer floor-log. Float math.log(n, base) can land just below an exact
        # power (e.g. log(243, 3) -> 4.999...), which silently dropped top layer.
        num_layers = 0
        remaining = time_dim
        while remaining >= conv_1d_filter_size:
            remaining //= conv_1d_filter_size
            num_layers += 1

        if num_layers < 1:
            raise ValueError(
                'Time dimension ({}) should be at least conv_1d_filter_size ({}).'.format(
                    time_dim, conv_1d_filter_size
                )
            )

        layers = []
        slice_depth = []
        y = x

        for i in range(num_layers):
            _, length, channels = y.get_shape().as_list()

            # t2b: left-pad time dimension to a whole number of windows
            tail = length % conv_1d_filter_size
            if tail != 0:
                pad = conv_1d_filter_size - tail
                y = tf.pad(y, [[0, 0], [pad, 0], [0, 0]])
                length += pad

            num_time_batches = length // conv_1d_filter_size

            # Fold every window into batch dimension:
            y = tf.reshape(y, [-1, conv_1d_filter_size, channels], name='layer_{}_t2b'.format(i))

            y = conv1d(
                x=y,
                num_filters=conv_1d_num_filters,
                filter_size=conv_1d_filter_size,
                stride=stride,
                pad='VALID',
                name='conv1d_layer_{}'.format(i)
            )

            # b2t: unfold windows back to time dimension
            y = tf.reshape(y, [-1, num_time_batches, conv_1d_num_filters], name='layer_{}_output'.format(i))

            y = norm_layer(y)

            if conv_1d_activation is not None:
                y = conv_1d_activation(y)

            if keep_prob is not None:
                y = tf.nn.dropout(y, keep_prob=keep_prob, name="_layer_{}_with_dropout".format(i))

            layers.append(y)

            # Number of trailing time steps to keep from this layer, at least one:
            slice_depth.append(max(1, conv_1d_overlap // conv_1d_filter_size ** i))

        sliced_layers = [
            tf.slice(
                h,
                begin=[0, h.get_shape().as_list()[1] - d, 0],
                size=[-1, d, -1]
            ) for h, d in zip(layers, slice_depth)
        ]

        if conv_1d_gated:
            # First half of channels is gated by sigmoid of the second half:
            split_size = conv_1d_num_filters // 2
            output_stack = [
                tf.multiply(
                    h[..., :split_size],
                    tf.nn.sigmoid(h[..., split_size:]),
                    name='gated_conv_output'
                ) for h in sliced_layers
            ]
        else:
            output_stack = sliced_layers

        encoded = tf.concat(output_stack, axis=1, name='encoded_state')

    return encoded
def attention_layer(inputs, attention_ref=tf.contrib.seq2seq.LuongAttention, name='attention_layer', **kwargs):
    """
    Temporal attention layer.

    Uses the final step along axis 1 of `inputs` as attention query and all
    preceding steps as attention memory, then returns alignment-weighted
    sum of the mechanism values.

    Paper:
        Minh-Thang Luong, Hieu Pham, Christopher D. Manning.,
        "Effective Approaches to Attention-based Neural Machine Translation."

    Args:
        inputs:         tensor indexed as [batch, time, features];
        attention_ref:  attention mechanism class, instantiated with
                        `num_units`, `memory`, `name` and any extra `kwargs`;
        name:           name passed to attention mechanism;
        **kwargs:       forwarded to `attention_ref` constructor;

    Returns:
        attention context tensor; presumably of shape [batch, 1, features],
        to be confirmed against the mechanism used.
    """
    feature_depth = inputs.get_shape().as_list()[-1]

    memory_steps = inputs[:, :-1, :]  # everything except final time step
    last_step = inputs[:, -1, :]  # final time step serves as query

    mechanism = attention_ref(
        num_units=feature_depth,
        memory=memory_steps,
        name=name,
        **kwargs
    )

    # There is no previous attention context, so start from initial alignments:
    batch_size = tf.shape(inputs)[0]
    start_alignments = mechanism.initial_alignments(batch_size, dtype=tf.float32)
    weights = mechanism(last_step, start_alignments)

    # Mechanism call may return a tuple; only the first element is used:
    weights = weights[0] if isinstance(weights, tuple) else weights

    # Context vector: alignments applied to mechanism values
    # (presumably the memory steps themselves).
    context = tf.matmul(tf.expand_dims(weights, axis=-2), mechanism.values)

    return context
def conv_1d_casual_attention_encoder(
        x,
        ob_space,
        ac_space,
        conv_1d_num_filters=32,
        conv_1d_filter_size=2,
        conv_1d_activation=tf.nn.elu,
        conv_1d_attention_ref=tf.contrib.seq2seq.LuongAttention,
        name='casual_encoder',
        keep_prob=None,
        conv_1d_gated=False,
        conv_1d_full_hidden=False,
        reuse=False,
        collections=None,
        **kwargs
):
    """
    Tree-shaped convolution stack encoder with self-attention.

    Stage1 casual convolutions network: from 1D input to estimated features.

    Every layer left-pads the time dimension up to a multiple of
    `conv_1d_filter_size`, folds each window of `conv_1d_filter_size` steps
    into the batch dimension ('t2b'), convolves every window down to a single
    step and unfolds the result back ('b2t'). An attention layer is built on
    top of every layer output that is longer than one time step. The number
    of layers is floor(log(time_dim, conv_1d_filter_size)).

    Args:
        x:                      input tensor; indexed as [batch, time, channels],
                                or [batch, time, 1, channels] in which case the
                                pseudo-2d dimension is dropped;
        ob_space:               not used by this function;
        ac_space:               not used by this function;
        conv_1d_num_filters:    number of output filters of every layer;
        conv_1d_filter_size:    size of convolution window, should be >= 2;
        conv_1d_activation:     activation callable applied after normalisation, or None;
        conv_1d_attention_ref:  attention mechanism class passed to `attention_layer`;
        name:                   variable scope name;
        keep_prob:              dropout keep probability, or None to skip dropout;
        conv_1d_gated:          if True, every layer output is split in two halves along
                                channels and combined as `first * sigmoid(second)`;
        conv_1d_full_hidden:    if True, whole layer outputs are concatenated along time,
                                otherwise only the final step of every layer is stacked;
        reuse:                  passed to tf.variable_scope;
        collections:            not used by this function;
        **kwargs:               not used by this function;

    Returns:
        tensor holding state features: convolved stack followed by attention
        outputs (if any), concatenated along axis -2;

    Raises:
        ValueError:     if filter size is less than 2, if time dimension is not
                        statically known or is shorter than filter size.
    """
    if conv_1d_filter_size < 2:
        raise ValueError(
            'conv_1d_filter_size should be >= 2, got: {}'.format(conv_1d_filter_size)
        )

    with tf.variable_scope(name_or_scope=name, reuse=reuse):
        shape = x.get_shape().as_list()
        if len(shape) > 3:  # remove pseudo 2d dimension
            x = x[:, :, 0, :]

        time_dim = shape[1]
        if time_dim is None:
            raise ValueError('Time dimension of input should be statically defined.')

        # Integer floor-log. Float math.log(n, base) can land just below an exact
        # power (e.g. log(243, 3) -> 4.999...), which silently dropped top layer.
        num_layers = 0
        remaining = time_dim
        while remaining >= conv_1d_filter_size:
            remaining //= conv_1d_filter_size
            num_layers += 1

        if num_layers < 1:
            raise ValueError(
                'Time dimension ({}) should be at least conv_1d_filter_size ({}).'.format(
                    time_dim, conv_1d_filter_size
                )
            )

        # Loop-invariant size of each gate half:
        split_size = conv_1d_num_filters // 2

        layers = []
        attention_layers = []
        y = x

        for i in range(num_layers):
            _, length, channels = y.get_shape().as_list()

            # t2b: left-pad time dimension to a whole number of windows
            tail = length % conv_1d_filter_size
            if tail != 0:
                pad = conv_1d_filter_size - tail
                y = tf.pad(y, [[0, 0], [pad, 0], [0, 0]])
                length += pad

            num_time_batches = length // conv_1d_filter_size

            # Fold every window into batch dimension:
            y = tf.reshape(y, [-1, conv_1d_filter_size, channels], name='layer_{}_t2b'.format(i))

            y = conv1d(
                x=y,
                num_filters=conv_1d_num_filters,
                filter_size=conv_1d_filter_size,
                stride=1,
                pad='VALID',
                name='conv1d_layer_{}'.format(i)
            )

            # b2t: unfold windows back to time dimension
            y = tf.reshape(y, [-1, num_time_batches, conv_1d_num_filters], name='layer_{}_output'.format(i))

            y = norm_layer(y)

            if conv_1d_activation is not None:
                y = conv_1d_activation(y)

            if keep_prob is not None:
                y = tf.nn.dropout(y, keep_prob=keep_prob, name="_layer_{}_with_dropout".format(i))

            if conv_1d_gated:
                # First half of channels is gated by sigmoid of the second half:
                y = tf.multiply(
                    y[..., :split_size],
                    tf.nn.sigmoid(y[..., split_size:]),
                    name="_layer_{}_gated".format(i)
                )

            layers.append(y)

            # Insert attention for all but top layer:
            if num_time_batches > 1:
                attention_layers.append(
                    attention_layer(
                        y,
                        attention_ref=conv_1d_attention_ref,
                        name='attention_layer_{}'.format(i)
                    )
                )

        if conv_1d_full_hidden:
            convolved = tf.concat(layers, axis=-2, name='convolved_stack_full')

        else:
            convolved = tf.stack([h[:, -1, :] for h in layers], axis=1, name='convolved_stack')

        encoded_parts = [convolved]

        # When time dimension equals filter size the only layer has length one and
        # no attention is built; concatenating an empty list would fail.
        if attention_layers:
            encoded_parts.append(
                tf.concat(attention_layers, axis=-2, name='attention_stack')
            )

        encoded = tf.concat(encoded_parts, axis=-2, name='encoded_state')

    return encoded