TensorFlowでUnpooling Layerの実装
やりたいこと
これ
[1711.08763] DeepPainter: Painter Classification Using Deep Convolutional Autoencoders
やったこと
全コードは下記参照。
memo/unpool.py at master · usagisagi/memo · GitHub
nn.max_pool_with_argmax
とscatter_nd
を使う。max_pool_with_argmax
する前のshapeを取って置くことと、scatter_nd
を1次元に配置してから行うことがコツ。
def pooling(input: Tensor, ksize, strides, padding: str = 'VALID', name: str = 'pooling') -> Tuple[Tensor, Tensor, TensorShape]: """ pooling層。実質pool_with_arg_max。 :param input: :param ksize: :param strides: :param padding: :param name: :return: 次ノードへのTensor, argmax、inputのshape """ with tf.name_scope(name): input_shape = input.get_shape() output, argmax = \ tf.nn.max_pool_with_argmax(input, ksize, strides, padding) return output, argmax, input_shape def unpooling(input: Tensor, argmax: Tensor, output_shape: TensorShape, name='unpooling'): """ unpooling層 :param input: :param argmax: :param output_shape: :return: 次ノードへのTensor """ with tf.name_scope(name): # inputを直線に並べる input_stride = tf.reshape(input, [-1]) # argmaxを直線に並べる。1行1要素の2次元になる。 argmax_stride = tf.reshape(argmax, [-1, 1]) # outputの要素数 num_elem = output_shape.num_elements() # inputをargmaxに従い、num_elemの1行Tensorに再配置 output_stride = tf.scatter_nd(argmax_stride, input_stride, [num_elem]) output = tf.reshape(output_stride, output_shape) return output