技术专栏
风格迁移 Style transfer
一、介绍
- 将一张图片的艺术风格应用在另外一张图片上
- 使用深度卷积网络CNN提取一张图片的内容和提取一张图片的风格, 然后将两者结合起来得到最后的结果
二、 方法
- 我们知道
CNN
可以捕捉图像的高层次特征,如上图所示,内容图片经过CNN
可以得到对应的图像表述(representation
, 就是经过卷积操作的feature map
),然后经过重构可以得到近似原图的效果- 特别是前面几层经过重构得到的结果和原图更接近,也说明前几层保留的图片细节会更多,因为后面还有
pooling
层,自然会丢弃调一些信息 - 这里的网络使用的是
VGG-16
(如下图),包含13
个卷积层,3
个全连接层
- 特别是前面几层经过重构得到的结果和原图更接近,也说明前几层保留的图片细节会更多,因为后面还有
1、内容损失
- 假设一个卷积层包含 ${N_l}$ 个过滤器
filters
,则可以得到 ${N_l}$ 个feature maps
,假设feature map
的大小是 $M_l$ (长乘宽),则可以通过一个矩阵来存储l
层的数据 $$F^l \in R^{N_l \times M_l} $$- $F^l_{i,j}$ 表示第
l
层的第i
个filter
在j
位置上的激活值
- $F^l_{i,j}$ 表示第
- 所以现在一张内容图片$\overrightarrow p$,一张生成图片$\overrightarrow x$(初始值为高斯分布), 经过一层卷积层l可以得到其对应的特征表示:$P^l$ 和 $Fl$, 则对应的损失采用均方误差: $$L{content}(\overrightarrow p, \overrightarrow x, l) = {1 \over 2} \sum{ij}(F^l{ij}-P^l_{ij})^2$$
- $F$ 和 $P$是两个矩阵,大小是$N_l \times M_l$,即
l
层过滤器的个数 和feature map
的长乘宽的值
- $F$ 和 $P$是两个矩阵,大小是$N_l \times M_l$,即
2、风格损失
-
风格的表示这里采用格拉姆矩阵(
Gram Matrix
): $G^l \in R^{N_l \times Nl}$ $$G^l{ij} = {\sumk F^l{ik}F^l_{jk}}$$- 格拉姆矩阵计算的是两两特征的相关性 , 即哪两个特征是同时出现的,哪两个特征是此消彼长的等,能够保留图像的风格
- ( 比如一幅画中有人和树,它们可以出现在任意位置,格拉姆矩阵可以衡量它们之间的关系,可以认为是这幅画的风格信息 )
-
假设$\overrightarrow a$是风格图像,$\overrightarrow x$是生成图像,$A^l$ 和 $G^l$ 表示在 $l$ 层的格拉姆矩阵,则这一层的损失为:$$E_l = {1 \over 4N^2_lM^2l}{\sum{i,j} (G^l{ij}-A^l{ij})^2}$$
-
提取风格信息是我们会使用多个卷积层的输出,所以总损失为:$$L_{style}(\overrightarrow a, \overrightarrow x) = {\sum^L_lw_lE_l}$$
- 这里$w_l$是每一层损失的权重
3、总损失函数
- 通过白噪声初始化(就是高斯分布)一个输出的结果,然后通过网络对这个结果进行风格和内容两方面的约束进行修正
$$L{total}(\overrightarrow p,\overrightarrow a,\overrightarrow x)=\alpha L{content}(\overrightarrow p, \overrightarrow x) +\beta L_{style}(\overrightarrow a, \overrightarrow x)$$
三、代码实现
1、说明
- 全部代码:点击查看
- 图像使用一张建筑图和梵高的星空
2、加载并预处理图片和初始化输出图片
- 输出图片采用高斯分布初始化
import numpy as np
from keras import backend as K
from keras.applications.vgg16 import preprocess_input
from keras.preprocessing.image import load_img, img_to_array
from keras.applications import VGG16
from scipy.optimize import fmin_l_bfgs_b
from matplotlib import pyplot as plt
'''图片路径'''
content_image_path = './data/buildings.jpg'
style_image_path = './data/starry-sky.jpg'
generate_image_path = './data/output.jpg'
'''加载图片并初始化输出图片'''
target_height = 512
target_width = 512
target_size = (target_height, target_width)
content_image = load_img(content_image_path, target_size=target_size)
content_image_array = img_to_array(content_image)
content_image_array = K.variable(preprocess_input(np.expand_dims(content_image_array, 0)), dtype='float32')
style_image = load_img(style_image_path, target_size=target_size)
style_image_array = img_to_array(style_image)
style_image_array = K.variable(preprocess_input(np.expand_dims(style_image_array, 0)), dtype='float32')
generate_image = np.random.randint(256, size=(target_width, target_height, 3)).astype('float64')
generate_image = preprocess_input(np.expand_dims(generate_image, 0))
generate_image_placeholder = K.placeholder(shape=(1, target_width, target_height, 3))
3、获取网络中对应层的输出
def get_feature_represent(x, layer_names, model):
'''图片的特征图表示
参数
----------------------------------------------
x : 输入,
这里并没有使用,可以看作一个输入的标识
layer_names : list
CNN网络层的名字
model : CNN模型
返回值
----------------------------------------------
feature_matrices : list
经过CNN卷积层的特征表示,这里大小是(filter个数, feature map的长*宽)
'''
feature_matrices = []
for ln in layer_names:
select_layer = model.get_layer(ln)
feature_raw = select_layer.output
feature_raw_shape = K.shape(feature_raw).eval(session=tf_session)
N_l = feature_raw_shape[-1]
M_l = feature_raw_shape[1]*feature_raw_shape[2]
feature_matrix = K.reshape(feature_raw, (M_l, N_l))
feature_matrix = K.transpose(feature_matrix)
feature_matrices.append(feature_matrix)
return feature_matrices
4、内容损失函数
def get_content_loss(F, P):
'''计算内容损失
参数
---------------------------------------
F : tensor, float32
生成图片特征图矩阵
P : tensor, float32
内容图片特征图矩阵
返回值
---------------------------------------
content_loss : tensor, float32
内容损失
'''
content_loss = 0.5*K.sum(K.square(F-P))
return content_loss
5、Gram矩阵和风格损失
def get_gram_matrix(F):
'''计算gram矩阵'''
G = K.dot(F, K.transpose(F))
return G
def get_style_loss(ws, Gs, As):
'''计算风格损失
参数
---------------------------------------
ws : array
每一层layer的权重
Gs : list
生成图片每一层得到的特征表示组成的list
As : list
风格图片每一层得到的特征表示组成的list
'''
style_loss = K.variable(0.)
for w, G, A in zip(ws, Gs, As):
M_l = K.int_shape(G)[1]
N_l = K.int_shape(G)[0]
G_gram = get_gram_matrix(G)
A_gram = get_gram_matrix(A)
style_loss += w*0.25*K.sum(K.square(G_gram-A_gram))/(N_l**2*M_l**2)
return style_loss
6、总损失
def get_total_loss(generate_image_placeholder, alpha=1.0, beta=10000.0):
'''总损失
'''
F = get_feature_represent(generate_image_placeholder, layer_names=[content_layer_name], model=gModel)[0]
Gs = get_feature_represent(generate_image_placeholder, layer_names=style_layer_names, model=gModel)
content_loss = get_content_loss(F, P)
style_loss = get_style_loss(ws, Gs, As)
total_loss = alpha*content_loss + beta*style_loss
return total_loss
def calculate_loss(gen_image_array):
'''调用总损失函数,计算得到总损失数值'''
if gen_image_array != (1, target_width, target_height, 3):
gen_image_array = gen_image_array.reshape((1, target_width, target_height, 3))
loss_fn = K.function(inputs=[gModel.input], outputs=[get_total_loss(gModel.input)])
return loss_fn([gen_image_array])[0].astype('float64')
7、损失函数梯度
def get_grad(gen_image_array):
'''计算损失函数的梯度'''
if gen_image_array != (1, target_width, target_height, 3):
gen_image_array = gen_image_array.reshape((1, target_width, target_height, 3))
grad_fn = K.function([gModel.input], K.gradients(get_total_loss(gModel.input), [gModel.input]))
grad = grad_fn([gen_image_array])[0].flatten().astype('float64')
return grad
8、生成结果后处理
-
因为之
前preprocess_input
函数中做了处理,这里进行逆处理还原def postprocess_array(x): '''生成图片后处理,因为之前preprocess_input函数中做了处理,这里进行逆处理还原 ''' if x.shape != (target_width, target_height, 3): x = x.reshape((target_width, target_height, 3)) x[..., 0] += 103.939 x[..., 1] += 116.779 x[..., 2] += 123.68 x = x[..., ::-1] # BGR-->RGB x = np.clip(x, 0, 255) x = x.astype('uint8') return x
9、定义模型并优化
'''定义VGG模型'''
tf_session = K.get_session()
cModel = VGG16(include_top=False, input_tensor=content_image_array)
sModel = VGG16(include_top=False, input_tensor=style_image_array)
gModel = VGG16(include_top=False, input_tensor=generate_image_placeholder)
content_layer_name = 'block4_conv2'
style_layer_names = [
'block1_conv1',
'block2_conv1',
'block3_conv1',
'block4_conv1'
]
'''得到对应的representation矩阵'''
P = get_feature_represent(x=content_image_array, layer_names=[content_layer_name], model=cModel)[0]
As = get_feature_represent(x=style_image_array, layer_names=style_layer_names, model=sModel)
ws = np.ones(len(style_layer_names))/float(len(style_layer_names))
'''使用fmin_l_bfgs_b进行损失函数优化'''
iterations = 600
x_val = generate_image.flatten()
xopt, f_val, info = fmin_l_bfgs_b(func=calculate_loss, x0=x_val, fprime=get_grad, maxiter=iterations, disp=True)
x_out = postprocess_array(xopt)
10、输出结果
- 初始化输出图片
-
迭代200次,${\beta \over \alpha} = 10^3$
-
迭代
500
轮,${\beta \over \alpha} = 10^4$
四、总结
style tranfer
通过白噪声初始化(就是高斯分布)一个输出的结果,然后通过优化损失对这个结果进行风格和内容两方面的约束修正- 图片的风格信息使用的是 Gram矩阵来表示
- 其中超参数风格损失的权重
ws
、内容损失和风格损失的权重$\alpha$, $\beta$可以进行调整查看结果- 论文给出的${\beta \over \alpha} = 10^3或10^4$结果较好,可以自己适当增加看看最后的结果
Reference
声明:本文内容由易百纳平台入驻作者撰写,文章观点仅代表作者本人,不代表易百纳立场。如有内容侵权或者其他问题,请联系本站进行删除。
红包
点赞
收藏
评论
打赏
- 分享
- 举报
评论
0个
手气红包
暂无数据
相关专栏
-
浏览量:9844次2021-12-31 09:00:12
-
浏览量:656次2023-12-14 16:38:19
-
浏览量:768次2023-07-05 10:11:45
-
浏览量:2432次2017-11-23 18:22:37
-
浏览量:1959次2020-04-08 11:00:23
-
浏览量:1002次2021-12-02 10:10:05
-
浏览量:1011次2023-02-13 10:57:00
-
浏览量:7144次2021-07-12 11:03:00
-
浏览量:3876次2021-04-02 15:53:41
-
浏览量:1804次2019-12-12 15:38:37
-
浏览量:8607次2021-04-26 17:26:00
-
浏览量:1476次2020-06-22 19:13:17
-
浏览量:3891次2021-08-23 15:13:18
-
浏览量:2650次2022-01-13 09:00:13
-
浏览量:4086次2024-05-28 16:26:51
-
浏览量:285次2023-08-03 15:28:32
-
浏览量:167次2023-08-16 18:03:17
-
浏览量:2156次2017-10-16 20:50:06
-
浏览量:1984次2020-07-02 16:19:22
置顶时间设置
结束时间
删除原因
-
广告/SPAM
-
恶意灌水
-
违规内容
-
文不对题
-
重复发帖
打赏作者
lawlite19
您的支持将鼓励我继续创作!
打赏金额:
¥1
¥5
¥10
¥50
¥100
支付方式:
微信支付
打赏成功!
感谢您的打赏,如若您也想被打赏,可前往 发表专栏 哦~
举报反馈
举报类型
- 内容涉黄/赌/毒
- 内容侵权/抄袭
- 政治相关
- 涉嫌广告
- 侮辱谩骂
- 其他
详细说明
审核成功
发布时间设置
发布时间:
请选择发布时间设置
是否关联周任务-专栏模块
审核失败
失败原因
请选择失败原因
备注
请输入备注