[논문 리뷰] SRCNN - Image Super-Resolution Using Deep Convolutional Networks
ML

[논문 리뷰] SRCNN - Image Super-Resolution Using Deep Convolutional Networks

논문 바로가기 arxiv.org/pdf/1501.00092.pdf

 

이 논문에서는 딥러닝을 이용한 Single Image Super-Resolution 방법을 제안하고 있다. directlow/high-resolution image 사이의 매핑 방법을 end-to-end로 학습시킨다. Mapping 방법은 CNN을 이용하며 low-resolutioninputhigh-resolution의 output으로 만들어낸다.

 

 

SRCNN은 위 그림과 같이 크게 3가지로 구성되어있다.

 

1. Patch extraction and Representation

이 과정에서 우리는 저해상 이미지로부터 patch를 추출해내며 고차원의 벡터로 representation하는 과정을 거친다. 이 때, 이 벡터들은 피쳐맵으로 구성되어있다.

더보기

피쳐맵(feature-map)이란?

피쳐맵은 특정한 가중치의 필터를 통해 나온 Convolution의 결과에서 얻을 수 있으며 필터에서 추출하고자 하는 이미지의 특징을 나타낸다.

 

2. Non-linear mapping

고차원의 벡터를 다른 고차원의 벡터로 non-linear 매핑한다. 여기서 non-linear 매핑이 사용된 이유는 활성함수를 사용하여 매핑하는 과정을 거치기 때문이다. 활성함수는 모두 비선형함수로 이루어져있다. 이는 이전 포스팅에서 확인할 수 있다. 각 매핑된 벡터는 개념적으로 고해상도의 패치를 의미한다.

 

3. Reconstruction

이 연산은 패치로부터 최종 고해상도의 이미지를 만들어낸다.

 

즉, 각 부분들은 Convolutional Layer로 이루어져 있으며, 첫번째 레이어에서는 저해상도 이미지로부터 피쳐를 추출하고, 두번째 레이어에서는 고차원 벡터들 간의 매핑을 하며, 마지막 레이어에서 피쳐로부터 나온 결과를 고해상 이미지로 복원하는 것이다. end-to-end mapping function인 F를 학습시키기 위해서는 파라미터를 추정해야한다. 이 추정치들은 네터워크로부터 나온 Output(고해상 이미지)와 Ground-Truth(정답)간의 Loss를 최소하화도록 설계(최소화하도록 가중치를 최적화해나간다.)되어야한다. 이 논문에서는 Loss Function은 MSE를, Optimization에는 SGD(Stochastic Gradient-Descent)를 사용하였다.

 

Code

구조 요약

  • 3개의 layer
  • 9*1, 1*1, 5*5 사이즈의 커널 (각)
  • activation function : ReLU
# -*- coding: utf-8 -*-
"""Untitled22.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1nWVAaXdbV4zQgGuI21W9nnqWEJ_88TTy
"""

from google.colab import drive 
drive.mount('/gdrive')

from google.colab import files
rc = list(files.upload().values())[0]
!unzip yang91.zip

"""import libraries"""

import keras.backend as K
from keras.models import Sequential, Model
from keras.layers import Dense, Conv2D, Activation, Input
from keras import optimizers
from keras.models import load_model
import numpy as np
import scipy.misc
import scipy.ndimage
import cv2
import math
import glob
import matplotlib.pyplot as plt
import PIL
! sudo pip install --upgrade scipy==1.1.0 # for misc.imread

"""build SRCNN model"""

img_shape = (32,32,1)
input_img = Input(shape=(img_shape))
C1 = Conv2D(64,(9,9),padding='SAME',name='CONV1')(input_img)
A1 = Activation('relu', name='act1')(C1)
C2 = Conv2D(32,(1,1),padding='SAME',name='CONV2')(A1)
A2 = Activation('relu', name='act2')(C2)
C3 = Conv2D(1,(5,5),padding='SAME',name='CONV3')(A2)
A3 = Activation('relu', name='act3')(C3)
model = Model(input_img, A3)
opt = optimizers.Adam(lr=0.0003)
model.compile(optimizer=opt,loss='mean_squared_error')
model.summary()

"""Create function to generate High Resolution from interpolation technique to pass in SRCNN model"""

def modcrop(image, scale=2):
   if len(image.shape) == 3:
      h, w, _ = image.shape
      h = h - np.mod(h, scale)
      w = w - np.mod(w, scale)
      image = image[0:h, 0:w, :]
   else:
      h, w = image.shape
      h = h - np.mod(h, scale)
      w = w - np.mod(w, scale)
      image = image[0:h, 0:w]
   return image

def create_LR(image,scale):
   label_ = modcrop(image, scale)
   # Must be normalized
   label_ = label_ / 255.
   input_ = scipy.ndimage.interpolation.zoom(label_, (1./scale), prefilter=False)
   input_ = scipy.ndimage.interpolation.zoom(input_, (scale/1.), prefilter=False)
   return input_

path = 'yang91/'
files_y = glob.glob(path + '*.bmp')

"""Split data into 3 parts, 21 images for each set."""

trainfiles = files_y[:21]
testfiles = files_y[21:42]
valfiles = files_y[42:]

print(len(trainfiles), len(testfiles), len(valfiles))
print(len(files_y))

img_size = 32
stride = 16
scale = 4
X_train = []
Y_train = []
X_test = []
Y_test = []
X_val = []
Y_val = []

"""Extract patch image for test"""

for file_y in testfiles:
   tmp_y = scipy.misc.imread(file_y,flatten=True, mode='YCbCr').astype(np.float)
   tmp_X = create_LR(tmp_y,scale)
   h,w = tmp_y.shape
   for x in range(0, h-img_size+1, stride):
      for y in range(0, w-img_size+1, stride):
         sub_input = tmp_X[x:x+img_size, y:y+img_size].reshape(img_size,img_size,1) # [32 x 32]
         sub_label = tmp_y[x:x+img_size, y:y+img_size].reshape(img_size,img_size,1) # [32 x 32]
         X_test.append(sub_input)
         Y_test.append(sub_label)

"""Extract patch image for training"""

for file_y in trainfiles:
   tmp_y = scipy.misc.imread(file_y,flatten=True, mode='YCbCr').astype(np.float)
   tmp_X = create_LR(tmp_y,scale)
   h,w = tmp_y.shape
   for x in range(0, h-img_size+1, stride):
      for y in range(0, w-img_size+1, stride):
         sub_input = tmp_X[x:x+img_size, y:y+img_size].reshape(img_size,img_size,1)
         sub_label = tmp_y[x:x+img_size, y:y+img_size].reshape(img_size,img_size,1)
         X_train.append(sub_input)
         Y_train.append(sub_label)

"""Extract patch image for validation"""

for file_y in valfiles:
   tmp_y = scipy.misc.imread(file_y,flatten=True, mode='YCbCr').astype(np.float)
   tmp_X = create_LR(tmp_y,scale)
   h,w = tmp_y.shape
   for x in range(0, h-img_size+1, stride):
      for y in range(0, w-img_size+1, stride):
         sub_input = tmp_X[x:x+img_size, y:y+img_size].reshape(img_size,img_size,1) # [32 x 32]
         sub_label = tmp_y[x:x+img_size, y:y+img_size].reshape(img_size,img_size,1) # [32 x 32]
         X_val.append(sub_input)
         Y_val.append(sub_label)

X_train = np.array(X_train)
Y_train = np.array(Y_train)
X_val = np.array(X_val)
Y_val = np.array(Y_val)
X_test = np.array(X_test)
Y_test = np.array(Y_test)

model.fit(X_train, Y_train, batch_size = 32, epochs = 10, validation_data=(X_val, Y_val))

predictions = model.predict(X_test[0])

print(X_test[0].shape)

print(predictions.shape)