Pix2Pix로 스케치그림을 사진으로 바꾸기 -2- Generator
Pix2Pix가 대충 어떤 구조로 되어 있는지 알았으니 이제 이미지를 생성할 생성모델을 만들어보자.
입력 데이터로 이미지를 넣어 이미지 데이터를 출력하는 것에 encoder-decoder 구조를 많이 사용한다. 이미지를 분석하고 처리하는 CNN의 Convolution layer는 이미지의 픽셀을 잘라내는 방식으로 연산을 하기 때문에 다시 이미지를 출력해 내기 위해서는 이러한 인코딩 된 이미지를 다시 디코딩 해야 한다.
정리하자면, 일반적인 CNN 레이어의 진행 방식인 다운 샘플링된 이미지 데이터를 다시 되돌리기 위해 업샘플링을 해야 한다. 이에 편리한 함수가 Upsampling()이다. 여기에 Conv2D를 이어 붙여 주면 잘라내어 졌던 가로 세로 길이는 다시 늘어난다. 이 두가지 레이어를 합친 기능을 가진 것이 Conv2DTranspose이다.
문제는 이런 인코딩-디코딩 방식은 이미지를 잘라내어 연산하고 다시 증폭하는 과정에서 필연적인 데이터 손실이 발생한다는 것이다. 계산하기 편하려고 잘라낸 이미지를, 컴퓨터가 알아서 친절하게 사람이 알아보기 쉬운 방식으로 다시 이어붙여줄 리는 없지 않은가? 하지만 디코딩을 할 때, 인코딩에서 잘라낸 레이어를 그대로 다시 가져온다면 데이터 손실을 방지할 수 있을 것이다. 이런 인코딩-디코딩 방식을 U-Net 구조라고 한다.
(유명한 Encoder-Decoder 방식과 U-Net 방식의 이미지 데이터 손실 관련 비교짤)
인코딩 레이어와 디코딩 레이어는 각각 레이어를 반환하는 함수로 만들었다.
이걸 이제 생성자에 하나하나 쌓을 것이다.
# Encoder (downsampling)
def encoder_layer(layer_in, n_filters, batchnorm=True):
# 가중치 초기화
init = RandomNormal(stddev=0.02)
# layer 추가
g = Conv2D(n_filters, (4,4), strides=(2,2), padding='same',
kernel_initializer=init)(layer_in)
# BatchNormalization이 True일 경우에만 추가
if batchnorm:
g = BatchNormalization()(g, training=True)
g = LeakyReLU(alpha=0.2)(g)
return g
# Decoder (upsampling)
def decoder_layer(layer_in, skip_in, n_filters, dropout=True):
# 가중치 초기화
init = RandomNormal(stddev=0.02)
# layer 추가
g = Conv2DTranspose(n_filters, (4,4), strides=(2,2), padding='same',
kernel_initializer=init)(layer_in)
g = BatchNormalization()(g, training=True)
# Dropout이 True일 경우에만 추가
if dropout:
g = Dropout(0.5)(g, training=True)
# Encoder layer에서 Activation 거치기 전의 output 복사 & merge
g = Concatenate()([g, skip_in])
g = Activation('relu')(g)
return g
Decoder에 Concatenate 함수가 보이는가? 이게 매개변수로 받아 온 인코딩 된 레이어와 디코딩 된 레이어를 병합해 줄 것이다. 인코딩 된 레이어를 그대로 병합했으니, 우리는 업샘플링을 할 때 우리가 원하는 방식으로 데이터를 복원할 수 있다.
모델 서머리를 통해 형태를 보면 다음과 같다....
인코딩 레이어를 이어붙여 이미지를 다운샘플링 한 결과이다. 256x256이던 이미지의 픽셀이 갈수록 줄어들어 1x1이 된 걸 볼 수 있다.
이건 디코딩 레이어를 쌓아 업샘플링한 결과다. shape는 늘어나고 늘어날 때마다 인코딩 레이어와의 병합이 이루어진다.
이제 이 레이어들을 쌓아서 생성자 모델을 완성시킬 수 있다.
# 생성자 (Generator)
def generator(image_shape=(256,256,3)):
#가중치 초기화
init = RandomNormal(stddev=0.02)
# image input
image_in = Input(shape=image_shape)
# Encoder
e1 = encoder_layer(image_in, 64, batchnorm=False)
e2 = encoder_layer(e1, 128)
e3 = encoder_layer(e2, 256)
e4 = encoder_layer(e3, 512)
e5 = encoder_layer(e4, 512)
e6 = encoder_layer(e5, 512)
e7 = encoder_layer(e6, 512)
# 병목 현상 방지
b = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(e7)
b = Activation('relu')(b)
# Decoder
d1 = decoder_layer(b, e7, 512)
d2 = decoder_layer(d1, e6, 512)
d3 = decoder_layer(d2, e5, 512)
d4 = decoder_layer(d3, e4, 512, dropout=False)
d5 = decoder_layer(d4, e3, 256, dropout=False)
d6 = decoder_layer(d5, e2, 128, dropout=False)
d7 = decoder_layer(d6, e1, 64, dropout=False)
# output
g = Conv2DTranspose(3, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d7)
image_out = Activation('tanh')(g)
# 모델 정의
model = Model(image_in, image_out)
return model
이제 생성자 모델을 완성했으니 판별자를 만들어 이 둘을 싸움 붙여야 한다!
제목에 생성자를 써버렸으므로 판별자는 다음 포스팅에.... 하지만 판별자는 자신에게 입력 된 이미지가 가짜냐 진짜냐를 판별해서 이진분류로 결과를 내기만 하면 되기에 정말 쉽다! 중요한 건 이 둘을 어떻게 붙여서 적대적인 관계를 형성시키느냐 하는 것이다.