TensorFlowでUnpooling Layerの実装

やりたいこと

これ

f:id:usagisagi:20180704212410j:plain

[1711.08763] DeepPainter: Painter Classification Using Deep Convolutional Autoencoders

やったこと

全コードは下記参照。

memo/unpool.py at master · usagisagi/memo · GitHub

nn.max_pool_with_argmaxscatter_ndを使う。max_pool_with_argmaxする前のshapeを取って置くことと、scatter_ndを1次元に配置してから行うことがコツ。

def pooling(input: Tensor,
            ksize,
            strides,
            padding: str = 'VALID',
            name: str = 'pooling') -> Tuple[Tensor, Tensor, TensorShape]:
    """
    pooling層。実質pool_with_arg_max。

    :param input:
    :param ksize:
    :param strides:
    :param padding:
    :param name:
    :return:
        次ノードへのTensor, argmax、inputのshape
    """
    with tf.name_scope(name):
        input_shape = input.get_shape()
        output, argmax = \
            tf.nn.max_pool_with_argmax(input, ksize, strides, padding)
        return output, argmax, input_shape


def unpooling(input: Tensor,
              argmax: Tensor,
              output_shape: TensorShape,
              name='unpooling'):
    """
    unpooling層

    :param input:
    :param argmax:
    :param output_shape:
    :return:
        次ノードへのTensor
    """

    with tf.name_scope(name):
        # inputを直線に並べる
        input_stride = tf.reshape(input, [-1])

        # argmaxを直線に並べる。1行1要素の2次元になる。
        argmax_stride = tf.reshape(argmax, [-1, 1])

        # outputの要素数
        num_elem = output_shape.num_elements()

        # inputをargmaxに従い、num_elemの1行Tensorに再配置
        output_stride = tf.scatter_nd(argmax_stride, input_stride, [num_elem])

        output = tf.reshape(output_stride, output_shape)

        return output