一 模型介绍
是一个CNN+光流对其+双向时许传播的视频增强模型,适合做
视频超分辨率、视频去噪、视频去模糊、视频压缩伪影修复、一般视频增强
BasicVSR系列的核心思想是;不要只增强单帧,而是利用前后多帧的信息。BasicVSR++论文也明确说,基于recurrent structure 通过双向传播和特征对其来利用整个视频序列的信息,BasicVSR++进一步加入二阶传播和flow-guided deformable alignment,增强了对错位视频帧的时空信息利用。
1.1 为何BasicVSR-like模型合适?
模型 适合程度 原因
BasicVSR-lite 最推荐 结构清楚,CNN为主,适合自己实现
EDVR 也推荐,经典视频增强模型,deformable conv实现稍复杂
FastDVDnet 适合视频去噪,不依赖光流,速度快,结构相对直接
BasicVSR++ 效果更强 但结构比BasicVSR更复杂
RealBasicVSR 真实视频超分,适合真实退化视频,但是训练场领略更复杂
优先级
BasicVSR-lite
EDVR
BasicVSR++ / RealBasicVSR
1.2 BasicVSR-lite的整体结构
输入不是一张图,而是一段连续视频帧
输入低质量视频帧
frame_1, frame_2, frame_3, ...frame_T
张量形状一般是
[B, T, C,H,W]
B = batch size
T 连续帧数量,比如7或15
C = 3,RGB 通道数量
H = 图像高度
W= 图像宽度
整体网络可以写成
连续低质量帧
->每帧CNN提取特征
->光流估计/特征对齐
->反向时序传播
->正向时序传播
->特征融合
->重建网络
->增强后视频帧
1.3 网络结构详细拆解
1.3.1 输入
假设一次输入7帧
x = [frame_1, frame_2, ... frame_7]
shape 是
- shape = [8, 7, 3, H, W]
如果做4倍超分,输入可能是
低清视频帧,[B, 7, 3, 64, 64]
高清目标帧 [B, 7, 3, 256, 256]
如果做去噪,去模糊,压缩伪影修复,输入和输出尺寸通常一样
低质量帧 [B,7,3,H,W]
高质量帧 [B, 7, ,3, H, W]
1.3.2 每帧特征提取CNN
frame_t ->Conv->RsBlocks->feature_t
feature_t.shape = [B, 64, H, W]
这里的CNN可以用
COnv2d ResidualBlock
ReLU / LeakyReLU
这部分和人脸模型ResNet思路类似,输出不是512维向量,而是保留二维特征图
人脸识别
[B, 3, 112, 112]->CNN->[B, 512]
视频增强
[B, 3,H,w]->CNN->[B, 64, H, W]
视频增强不能太早flatten, 因为需要恢复图像细节
1.3.3 光流估计/帧间对齐
视频增强最大的问题是:
相邻帧内容相似,物理会运动
BasicVSR类模型通常使用光流网络,比如SPyNet, 来估计相邻帧之间的运动,BasicVSR++补充材料里也提到使用pretrained SPyNet作为flow network
光流可以理解成
第t帧的每个像素,应该往哪里移动,才能对齐到t+1帧
feature_{t-1}
根据optical flow warp
对齐到feature_t
1.3.4 双向时许传播
这是BasicVSR的核心
看当前帧附近的几帧,让信息沿着时间传播。
反向缠传播
从视频最后一帧往前传
frame_T->frame_{T-1}->...frame_1
得到每一帧的backward feature
backward_feature_t
正向传播
再从第一帧往后传
frame_1->frame_2 ...frame_T
得到每一帧的forward feature
forward feature_t
最后第t帧可以利用
当前帧特征
前面帧传来的信息
后面帧传来的信息
enhanced_feature_t = fuse(
current_feature_t,
forward_feature_t,
backward_feature_t
)
1.3.5 重建网络
融合后的特征再经过CNN重建成图像
如果是去噪/去模糊/压缩增强
[B,64,H,W]->Conv->[B,3,H,W]
如果是视频超分辨率,
[B, 64, H, W]
->PixelShuffle x2
->PixelShuffle x2
->[B, 3, 4H, 4W]
1.4 如果用EDVR
EDVR时CVPRW 2019的视频恢复模型,
EDVR沦为提示两个关键模块
PCD Alignment 金字塔,及联,可变形卷积对齐
TSA Fusion 时许和空间注意力融合
1.5 需要去噪FastDVDnet
去噪->降低ISO噪声->减少暗光噪声->视频画面变干净
FastDVDnet是CVPR 2020的视频去噪模型,官方仓库提供Pytorch实现,说明它不适用光流估计的视频去噪算法。
不用光流->结构相对简单->速度快->适合视频去噪入门
二 代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
def flow_wrap(
x,flow,padding_mode="border",
align_corners=True
):
#使用光流对特征图做wrap对齐
#参数,x: 要背对齐的特征图,shape=[B,C,H,W]
# flow: 光流,shape=[B,2,H,W]
#flow]:, 0, :, :[表示x方向位移,横向位移
#flow[:,1,:,:] 表示y方向位移,就是纵向位移
#padding_mode:
grid_sample 越界采样时的填充方式。
"border" 表示越界时使用边界像素。
# align_corners:
grid_sample的坐标对齐方式
返回 warped_x: 根据flow对齐后的特征图,shape=[B,C,H,W]
#取出输入特征图的batch size,通道数,高,宽
b,c,h,w = x.size()
#确保flow的数据类型和x一致,避免AMP/FP16 时类型冲突
flow = flow.to(dtype=x.dtype)
#生成y坐标网络,范围时0到H-1
#生成x坐标网络,范围是0到W-1
grid_y, grid_x = torch.meshgrid(
torch.arange(0, h, device=x.device, dtype=x.dtype),
torch.arange(0, w, device=x.device, dtype=x.dtype),
indexing="ij"
)
#grid_x原本shape是[H,W]
扩展成[B,H,W] 方便和batch内每张图的flow相加
grid_x = grid_x.unsqueeze(0).expand(b, -1, -1)
#grid_y 原本shape是[H,W]
#扩展成[B,H,W]
grid_y, grid_y.unsqueeze(0).expand(b, -1, -1)
#当前像素为止x坐标 + 光流横向位移
#得到需要从原特征图哪个x为止采样
vgrid_x = grid_x + flow[:,0,;,);]
#当前像素位置坐标y坐标+光流纵向位移
得到需要从原特征图哪个y位置采样
vgrid_y = grid_y + flow[]
#grid_sample要求坐标范围时[-1, 1]
#所以要把像素坐标[0, W-1]转换成[-1, 1]
if w > 1:
vgrid_x = 2.0 * vgrid_x / (w-1) - 1.0
else
vgrid_x = torch.zeros_like(vgrid_x)
#把像素坐标[0, H-1]转换成[-1, 1]
if h > 1:
vgrid_y = 2.0 * vgrid_y / (h - 1) - 1.0
else:
vgrid_y = torch.zeros_like(vgrid_y)
#grid_sample要求最后一维时[x,y ]
#所以这里吧x坐标和y坐标对跌倒最后一维
grid = torch.stack(vgrid_x, vgrid_y), dim=-1
#根据grid 从x中采样,得到warp后的特征图
warped_x = F.grid_sample(
x,grid, mode="bilinear", padding_mode=padding_mode,
align_cornors=align_cornors,
)
#返回对齐后的特征图
return warped_x
class ResidualBlockNoBN(nn.Module):
# 不带BatchNorm 的残差块
#视频增强,超分模型里经常不用BatchNorm
#因为BatchNorm可能影响图像恢复的细节和数值范围
def __init__(self, channels, res_scale=1.0):
#channels 输入和输出通道数
#res_scale 残差缩放系数
#可以让残差分支更稳定
#初始化nn.Module父类
super().__init__()
#第一个3x3卷积,通道数不变
self.conv1 = nn.Conv2d(
channels,
chnanels,
kernel_size=3,
stride=1,
padding=1,
)
#第二个3x3卷积,通道数不变
self.conv2 = nn.Conv2d(
channels,
channels,
kernel_size=3,
stride=1,
padding=1,
)
#使用LeakyReLU作为激活函数
self.relu = nn.LeakyReLU(
negative_slope = 0.1
inplace=True
)
#保存残差缩放系数
self.res_scale = res_scale
def forward(self, x)
#前向传播,输入x shape = [B,C,H,W]
#输出 out shpe = [B,C,H,W]
#保存原始输入,用于残差链接
identity - x
#第一个卷积
out = self.conv1(x)
#激活函数
out = self.relu(out)
#第二个卷积
out = self.conv2(out)
#残差链接,输出 = 原输入 + 残差分支
out = identity + out * self.res_scale
#返回残差块输出
return out
class ResidualBlockWithInputConv(nn.Module):
#信用一个卷积吧输入通道变成mid_channels
再接多个残差块
#
def __init__(
self,in_channels,mid_channels,num_block
);
"""
参数:
in_channels:
输入通道数。
mid_channels:
中间特征通道数。
num_blocks:
残差块数量。
"""
#初始化父类
super().__init__()
#用list存放网络层
layers = []
#输入卷积,吧in_channels变成mid_channels
layers.append(
nn.Conv2d(
in_channels,
mid_channels,
kernel_size=3,
stride=1,
padding=1,
)
)
#激活函数
layers.append(
nn.LeakyReLU(
negative_slope=0.1,inplace=True
)
)
#堆叠多个残差块
for _ in range(num_blocks):
layers.append(
ResidualBlockNoBN(
channels=mid_channels,
)
)
#把所有层组成一个Sequential
self.main = nn.Sequential(*layers)
def forward(self, x):
"""
前向传播。
"""
# 直接把输入送进 Sequential
return self.main(x)
class TinyFlowNet(nn.Module):
#非常简化的光流网络
#不是论文里的BasicVSR里的SpyNet 这是为了吧BasicVSR-like结构跑通
#输入 img_ref 参考帧,shape=[B,3,H,W]
img_supp
支撑帧,相邻帧
shape = [B,3,,H,W]
def __init__(self, max_flow=20.0)
参数,max_flow 限制预测光流的最大像素位移
#初始化父类
super().__init__()
#保存最大光流范围
self.max_flow = max_flo
e#输入是两张RGB图拼接, 所有通道数是6
self.body = nn.Sequential(
#第一层卷积,提取浅层特征
nn.Conv2d(6, 32, kernel_size=7, stride=1, padding=3),
nn.LeakyReLU(0.1, inplace=True),
#下采样一次,扩大感受
nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2),
nn.LeakyReLU(0.1, inplace=True),
#中间卷积
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.1, inplace=True),
#再下次阿阳一次,继续扩大感受
nn.Conv2d(64, 96, kernel_size=3, stride=2, padding=1),
nn.LeakyReLU(0.1, inplace=True),
#中间卷积
nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.1, inplace=True),
#上采样回较高分辨率
nn.ConvTranspose2d(96, 64, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.1, inplace=True),
#再上采样回原图分辨率附近
nn.Conv2d(64, 96, kernel_size=3, stride=2, padding=1),
nn.LeakyReLU(0.1, inplace=True)
#中间卷积
nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.1, inplace=True)
#上采样回校高分辨率
nn.ConvTranspose2d(96, 64, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.1, inplace=True)
#再上采样回调分辨率附近
nn.ConvTranspose2d(64, 32, kernel_size=4,stride=2,padding=1)
nn.LeakyReLU(0.1, inplace=True)
#输出 2通道光流,dx和dy
nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1)
)
def forward(self, img_ref, img_supp):
##前向传播
#记录原始图像的高和宽
#shape [B, 6, H, W]
inp = torch.cat([img_ref, img_supp], dim=1)
#预测光流
flow = self.body(inp)
#如果因为下采样,上采样导致尺寸略有差异,就插值回原尺寸
if flow.shape[-2:] != (h, w):
flow = F.interpolate(
flow,
size(h,w),
mode="bilinear",
align_corners=False
)
#用tanh 把输出限制到[-1,1]
#再乘max_flow 得到像素激光流范围
flow = torch.tanh(flow) * self.max_flow
#返回光流
return flow
class BasicVSRLite(nn.Module):
#教学版,BasicVSR-lite
#整体结构 输入视频帧序列
#CNN提取每帧特征
#估计相邻帧光流
#反向时间传播T->1
正向时间传播1->T
当前帧特征 + 反向传播特征 + 正向传播特征 融合
重建增强帧 / 超分帧
输入:
x shpe = [B, T, 3, H, W]
输出
scale = 1:
out shape=[B, T, 3, H, W]
scale=2:
out shape = [B, T, 3, 2H, 2W]
scale = 4:
out shape = [B, T, 3, 4h, 4w]
def __init__(
self, mid_channels=64,
num_feature_blocks=5,
num_propagation_blocks=7,
scale=1,
max_flow=20.0,
) :
参数
mid_channels:中间特征通道
64 是常见的轻量配置
num_feature_blocks:
每帧特征提取阶段的残差块数量
num_propagation_blocks
正向/反向传播阶段的残差块数量
num_reconstruction_blocks
重建阶段的残差块数量
scale:
放大倍率
scale=1 表示输入输出同尺寸,同于去噪,去模糊,增强
scale=2 表示2倍超分
scale=4 表示4倍超分
max_flow TinyFlowNet 预测光流最大像素位移
#初始化分类
super().__init__()
#只允许1,2,4三种倍率
assert scale in (1,2,4), "scale must be 1,2, or 4"
#保存中间通道数
self.mid_channels = mid_channels
#保存超分倍率
self.scale=scale
#光流网络,用于估计相邻帧之间的运动
self.flow_net = TinyFlowNet(
max_flow = max_flow,
)
#每帧特征提取网络
feature_layers=[]
#第一层卷积,RGB图像3通道->mid_channels
feature_layers.append(
nn.Conv2d(
3, mid_channels,kernel_size=3,
stride=1,padding=1
)
)
#激活函数
feature_layers.append(
nn.LeakyReLU(
negative_slope=0.1, inplace=True
)
)
#堆叠多个残差块,用于提取每一帧的空间特征
for _ in range(num_feature_blocks):
feature_layers.append(
ResidualBlockNoBN(
channels=mid_channels,
)
)
#组成每帧特征提取网络
self.feat_extract = nn.Sequential(*feature_layers)
#反向传播网络
#输入是当前帧特征 + 从未来帧传播过来的特征
#所以输入通道数mid_channels * 2
self.backward_trunk = ResidualBlockWithInputConv(
in_channels = mid_channels * 2,
mid_channels = mid_channels,
num_blocks=nm_propagation_blocks
)
#正向传播网络
#输入是当前帧特征 + 从过去帧传播过来的特征
#所以输入通道数也是mid_channels *2
self.forward_trunk = ResidualBlocksWithInputConv(
in_channels = mid_channels * 2,
mid_channels = mid_channels,
num_blocks = num_propagation_blocks,
)
#重建网络,
#输入是当前帧特征 + 反向传播特征 + 正向传播特征
# 所以输入通道数是mid_channels * 3
self.reconstruction = ResidualBlocksWithInputConv(
in_channels = mid_channels * 3,
mid_channels = mid_channels,
num_blocks = num_reconstruction_blocks,
)
#激活函数
self.lrelu = nn.LeakyReLU(
negative_slope=0.1,
inplace=True,
)
#pixelShuffle用于超分辨率上采样
self.pixel_shuffle = nn.PixelShuffle(
upscale_factor=2,
)
#如果scale >=2 需要一次2倍上采样
if scale >= 2:
self.upconv1 = nn.Conv2d(
mid_channels,
mid_channels * 4,
kernel_size=3,
stride=1,
padding=1,
)
#如果scale==4需要两次2倍上采样
if scale == 4:
self.upconv2 = nn.Conv2d(
mid_channels,
mid_channels *4,
kernel_size=3,
stride=1,
padding=1,
)
#高分辨率空间上的卷积
self.conv_hr = nn.Conv2d(
mid_channels,
mid_channels,
kernel_size=3,
stride=1,
padding=1,
#最后一层卷积,把特征图变回RGB图像
)
self.conv_last = nn.Conv2d(
mid_channels,
3,
kernel_size=3,
stride=1,
padding=1,
)
def compute_flows(self, x):
#计算相邻帧之间的光流
#输入
x shape = [B, T, 3, H ,W]
返回 flow_forward:
用于正向传播
flows_forward[:, i-1]表示第i帧 -> 第i-1帧的光流
用它可以把过去帧特征wrap到当前帧
shape=[B,T-1, 2, H, W]
flows_backward:
用于反向传播,flows_backward[:,i]表示第i帧 第i + 1帧的光流
用它可以把未来帧特征warp到当前帧
shape = [B,T-1,2,H,W]
#取出输入视频的维度
b,t,c,h,w = x.size()
#如果只有1帧,就没有相邻帧光流
if t <=1:
empty = x.new_zeros(b, 0, ,2, h, w)
return empty, empty
#存放反向传播需要的光流
flows_backward = []
#对于反向传播,需要从未来帧传播到的当前帧
#warp future feature 到当前帧时,需要当前帧 ->未来帧 的光流
for i in range(t - 1):
#计算第i帧到第i + 1帧的光流
flow_i_to_next = self.flow_net(
x[:, i, :, :, :],
x[:, i + 1, :, :, :],
)
#保存光流
flows_backward.append(flow_i_to_next)
#把list堆叠成tensor
#shape = [B, T-1, 2, H, W]
flows_backward=torch.stack(
flows_backward,
dim=1,
)
#存档正向传播需要的光流
flows_forward=[]
#对于正向传播,需要从过去帧传播到当前帧
#warp past feature 到当前帧时,需要当前帧-》过去帧的光流
for i in range(1, t):
#计算第i帧到第i-1帧的光流
flow_i_to_prev = self.flow_net(
x[:, i, :, :, :],
x[:, i - 1, :, :, :],
)
#保存光流
flows_forward.append(flow_i_to_prev)
#shape = [B, T-1, 2, H, W]
flow_forwards = torch.stack(
flows_forward,
dim=1
)
#返回正向传播光流和反向传播光流
return flows_forward, flows_backward
def upsample(self, feat):
根据scale对重建特征进行上采样
输入: feat shape = [B, C, H, W]
输出 scale=1
out shape = [B, 3, H, W]
scale=2
out shape=[B, 3, 2H, 2W]
scale=4:
out shape= [B, 3, 4H, 4W]
#如果是2倍或4倍超分,先做一次2倍PixelShuffl
eif self.scale == 2:
#卷积把通道扩展到4倍
feat = self.upconv1(feat)
#pixelShuffle把通道转换为空间分辨率
feat = self.pixel_shuffle(feat)
#激活
feat = self.lrelu(feat)
#如果是4倍超分,需要做两次2倍PixelShuffle
elif self.scale == 4
:#第一次2倍上采样
feat = self.upconv1(feat)
feat = self.pixel_shuffle(feat)
feat = self.lrelu(feat)
#第二次2倍上采样
feat = self.upconv2(feat)
feat = self.pixel_shuffle(feat)
feat = self.lrelu(feat)
#高分辨率卷积
feat = self.conv_hr(feat)
#激活
feat = self.lrelu(fea)t
#输出RGB残差图像
out = self.conv_last(feat)
#返回输出
return out
def get_base_frame(self, lr_frame):
获取残差链接里的base image
#对于scale=1
base就是原输入帧
对于scale=2或scale=4
bas是双线性循环放大后的输入帧
最终输出
enhanced = predicted_residual + base
#如果不做超分,直接返回原图
if self.scale == 1:
return lr_frame
#如果做超分,吧低清晰度输入双线性插值放大
base = F.interpolate(
lr_frame,
scale_factor=self.scale,
mode = "bilinear",
align_corners=False
)
#返回base frame
return base
def forward(self, x):
"""
前向传播。
输入:
x shape = [B, T, 3, H, W]
输出:
out shape = [B, T, 3, H*scale, W*scale]
"""
#检查输入必须是5倍
if x.dim() != 5
#取出输入视频的维度
b,t,c,h,w = x.size()
检查必须是RGB视频
#每帧CNN特征提取
#把[B,T,3,H,W]reshape成[B*t, 3, H ,W]
#这样可以一次性把所有帧送进CNN
x_reshape = x.reshape(b*t, c, h, w)
#提取每帧空间特征
feats = self.feat_extract(x_reshape)
#把特征reshape回视频序列形式
#shape = [B,T,mid_channels, H,W]
feats = feats.reshape(
b,t,self.mid_channels,
h,w
)
#计算相邻帧光流
flows_forward用于正向传播
flows_backward 用于反向传播
flows_forward, flows_backward = self.compute_flows(x)
#反向时间传播 从T-1帧传播到第0帧
#用list存放每一帧反向传播特征
backward_feats = [None] * t
#初始化传播特征为全0
#shape = [B, mid_channels, H, W]
feat_prop = x.new_zeros(b, self.mid_channels,h,w)
#从最后一帧住第一帧遍历
for i in range(t - 1, -1, -1):
#如果不是最后一帧,就需要把未来帧传播特征warp到当前帧
if i < t - 1:
#flows_backward[:, i]是第i帧->第i+1帧的光流
#用它可以把第i + 1帧的传播特征对齐到第i帧
feat_prop = flow_warp (
feat_prop,
flows_backward[:, i,:,:,:],
)
#当前帧特征
curr_feat = feats[L,i,:,:,:]
#拼接当前帧特征和传播特征
#shape = [B, mid_channel *2, H, W]
feat_input = torch.cat(
[curr_feat, feat_prop],
dim = 1,
)
#通过反向传播网络更新传播特征
#shape = [B, mid_channels * 2, H, W]
feat_input = torch.cat(
[curr_feat, feat_prop],
dim=1,
)
#通过反向传播网络更新传播特征
feat_prop = self.backward_trunk(feat_input)
#保存第i帧对应的反响传播特征
backward_feats[i] = feat_prop
#4 正向时间传播,从第0帧传播到T-1帧
#用list存放每一帧的正向传播特征
forwards_feats = [None]*t
#初始化正向传播特征为全0
feat_prop = x.new_zeros(b,self.mid_channels, h,w)
#从第一帧往后一帧遍历
for i in range(t):
#如果不是第一帧,就需要把过去帧传播特征warp到当前帧
if i > 0
#dlows forward[:, i-1]是第i帧 第i-1帧的光流
#用它可以把第i-1帧的传播特征对其道第i帧
feat_prop = flow_warp(
feat_prop,
flow_forward[:,i-1,:,:,:]
)
#当前帧特征
curr_feat = feats[:,i,:,:,:]
#拼接当前帧特征和正向传播特征
feat_input = torch.cat(
[curr_feat, feat_prop],
dim=1,
)
#通过正向传播网更新传播特征
feat_prop = self.forward_trunk(feat_input)
#保存第i帧对应的正向传播特征
forward_feats][i = feat_prop
#融合当前帧特征,反向传播特征,正向传播特征, 并重建输出帧
#存放所有输出帧
outs=[]
#对每一帧分别重建
for i in range()t:
#当前帧的原始空间特征
curr_feat = feats[:,i,:,:,:]
#当前帧的反向传播特征
backward_feat = backward_feats[i]
#当前帧的正向传播特征
forward_feat = forward_feats[i]
#三类特征拼接
#shape = [B,mid_channels * 3, H,W]
feat = torch.cat(
[currefeat, backward_feat, forward_feat],
dim=1,
)
#通过重建网络重建特征
feat = self.reconstruction(feat)
#根据scale输出RGB残差图像
out = self.upsample(feat)
#获取base frame
#scale=1 时就是原输入帧
#scale= 2/4 时是双线性插值放大后的输入帧
base = self.get_base_frame(
x[:,i,:,:,:]
)
#残差学习,最终输出 网络预测残差 + base
out = out + base
#保存当前输出帧
outs.append(out)
#把list里的每一帧堆叠回视频序列
#shape = [B,T,3, H *scale, W*scale]
outs = torch.stack(outs, dim=1)
#返回增强后的视频帧序列
return outs
if __name__ == "__main__":
简单测试代码
直接运行 python basicvsr_lite.py
#构造一个BasicVSR-lite模型
#scale=1 表示输入输出同分辨率
model = BasicVSRLite(
mid_channels=64,
num_feature_blocks=5,
num_propagation_blocks=7,
num_reconstruction_blocks=10,
scale=1,
)
#构造一个假的输入视频batch
#B=2,T=7,C=3,H=64,W=64
x = torch.randn(2,7,3,64,64)
#前向传播
y = model(x)
#打印输入输出尺寸
#训练时一般这样计算损失
#假设gt是清晰视频帧,shape和y一样
gt = torch.randn_like(y)
#视频增强/超分常用L1Loss
loss = F.l1_loss(y, gt)
二 模型总结
这个模型是一个教学版BsicVSR-lite视频增强模型。核心目的不会单张图片,而是处理一段连续视频帧,利用前后帧信息增强当前帧。
输入一段低质量视频帧
提取每一帧的CNN特征
估计相邻之间的光流
用光流吧前后帧特征对齐
做正向和反向时序信息传播
融合当前帧,过去帧,未来帧信息。
1、这个模型能做什么
视频去噪
视频去模糊
视频压缩伪影修复
视频画质增强
视频超分辩率
由scale控制任务类型
scale=1
表示输入输出同尺寸,适合葡萄视频增强
scale=2
表示2倍视频超分
scale=4
表示4倍视频超分
2输入输出格式
x.shape = [B,T,3,H,W]
含义是
B = batch_size 一次训练几个视频片段
T = 每个视频片段有多少帧
3 = RGB通道数量
H 图像高度
W 图像宽度
x = torch.randn(2, 7, 3, 64, 64)
表示
2 个视频片段
每个片段7帧
每帧是RGB图像
每帧大小64x64
如果scale=1 输出是
y.shape = [B,T,3,H,W]
如果scale=4 输出是
y.shape=[B,T,3,4H,4W]
3 模型整体网络结构
class BasicVSRLite(nn.Module)
里面主要有这些模块
feat_extract 每帧CNN特征提取
flow_net 简化光流估计网络
backward_trunk 反向时间传播网络
forward_trunk 正向时间传播网络
reconstruction 特征融合和重建网络
upsample 超分上采样模块
residual base残差输出链接
整体结构可以画成
输入视频x:[B,T,3,H,W]
每帧CNN特征提取
feats [B,T,64,H,W]
计算相邻帧光流
反向传播,T-1->0 得到未来信息backward_feats
正向传播 0->T-1 得到过去信息forward_feats
每一帧融合
当前帧特征 过去信息 未来信息
重建网络
输出增强视频
4 flow_warp是干什么的?
def flow_warp(x, flow)
作用是,根据光流把特征图进行空间对齐
视频里物体会运动,
第一帧 人脸在左边
第二帧,人脸在中间
第三针,人脸在右边
直接把这些帧的特征融合,会无法对齐,结果容易模糊
需要先用光流估计运动
这个额像素从上一帧移动到了哪里
这个特征应该往左还是往右移动
F.grid_sample(...)
重新采样特征图,吧前后特征对齐到当前帧
flow_warp = 根据运动信息移动特征图
5 ResidualBlockNoBN是什么
这是一个不带BatchNorm的残差块
class ResidualBlockNoBN(nn.Module)
结构是
输入x
Conv 3x3
LeakyReLU
Conv 3x3
加回输入x
输出
对应代码
identity=x
out = self.conv1()x
out = self.relu(out)
out = self.conv2(out)
out = identity + out * self.res_scale
作用是增强特征表达能力
为什么没有BatchNorm 归一化一下?
因为视频增强超分,去噪这类图像恢复任务需要保留非常惊喜的像素信息,BatchNorm有时会破坏图像的亮度,颜色,纹理分布,所以很多恢复模型不用BN
6 TinyFlowNet是什么?
class TinyFlowNet(nn.Module)
这是一个简化版本光流估计网络
输入两帧图像
img_ref.shape=[B,3,H,W]
img_supp.shape = [B,3, H,W]
先拼接成
[B,6,H,W]
然后经过一个小CNN,输出
flow.shape = [B,2,H,W]
其中
flow[:, 0,:,)] = x方向位移
flow]:,1,:,:[ = y 方向位移
这个TInyFlowNet是教学版,不是官方BasicVSR里的SpyNet,
7每帧特征提取feat_extract
这一段代码
self.feat_extract = nn.Sequential(*feature_layers)
作用把每一帧RGB图像变成CNN特征图
输入一帧
[B,3,H,W]
输出 [B,64,H,W]
mid_channels = 64
所以每帧被转换成64通道的空间特征
注意这里不是输出一个向量,保留二维空间结构
图片 CNN 512维向量
图片 CNN 64通道特征图
因为视频增强最终需要回复图像,不能把空间信息压扁。
8 compute flows计算什么?
flows_forward, flows_backward = self.compute_flos(x)
这个函数计算相邻帧之间的光流
flows_backward
flows_forward
flows_backward用于反向传播,就是从未来帧往当前帧传信息
代码里面计算是这样的
flow_i_to_next = self.flow_net (
x[:, i]
x[:, i + 1]
)
用它可以把低i+1帧的特征warp到第i帧。
flows_forward
用于正向传播,从过去帧往当前帧传信息
flow_i_to_prev = self.flow_net(
x[:, i]
x[:, i + 1]
)
含义是
第i帧到第i+1帧的光流
9反向时间传播
for i in range(t-1, -1, -1)
从最后一帧往第一帧处理
T-1, T-2 ,0
目的
让当前帧获得未来帧的信息
例如处理第3帧时,利用第4,5,6帧传过来的信息
初始化feat_prop = 0
从最后一帧开始
如果不是最后一帧
用光流warp 未来帧传来的feat_prop
拼接
当前帧特征curr_feat
未来传播特征 feat_prop
送入backward_trunk
得到新的feat_prop
保存为backward_feats[i]
核心代码
feat_prop = flow_warp (
feat_prop,flows_backward[:,i]
)
feat_input= torch.cat(
]curr_feat, feat_prop[
dim=1
)
feat_prop=self.backward_trunk(feat_input)
10 正向时间传播
for i in range(t)
例如处理第3帧时,可以利用第0,1,2帧传过来的信息
初始化feat_prop = 0
从第一帧开始
如果不是第一帧
用光流warp过去帧传来的feat_prop
拼接
当前帧特征curr_feat
过去传播特征 feat_prop
送入forward_tru
nk得到新的feat_prop
保存为forward_feats[i]
核心代码
feat_prop = flow_warp(
feat_prop,
flows_forward[:, i - 1]
)
feat_input = torch.cat(
[curr_feat, feat_prop],
dim=1
)
feat_prop = self.forward_trunk(feat_input)
11 最终融合重建
1curr_feat当前帧呢自己的特征
2 backward_feat 未来帧传来的信息
3 forward_feat过去帧传来的信息
然后拼接
feat = torch.cat([curr_feat, backward_feat, forward_feat],
dim=1)
如果mid_channels=64 那么
curr_feat = 6通道
backward_feat = 64通道
forward_feat = 64通道
拼接后 = 192通道
送入重建网络
feat = self.reconstruction(fe)at
重建网络把192通道融合成64通道,
out = self.upsample(feat)输出RGB图像
12 upsample是什么
def upsample(self, feat)
负责吧特征图转换成最终RGB输出
如果scale=1
不做放大
[B,64,H,W]->[B,3,H,W]
如果scale=2
Conv -> PixelShuffle x2 -> Conv -> RGB
[B,64,H,W] -> [B,3,2H,2W]
如果scale=4
Conv -> PixelShuffle x2
Conv -> PixelShuffle x2
Conv -> RGB
[B,64,H,W] -> [B,3,4H,4W]
PixelShuffle是超分模型里常用的上采样方式,通道维的信息重新排列空间维上。
13 为什么最后要out + base
base = self.get_base_frame(x[:, i])
out = out + base
残差学习
如果scale = 1
base = 原始输入帧
输出=原始输入帧 + 模型预测的修正量
如果scale=4
base = 双线性插值方法后的输入帧
输出=放大后的输入帧 + 模型预测的高频细节
因为模型从不从零生成整张图,
哪里需要去噪 哪里需要变清晰 哪里需要补充细节 哪里需要修复伪影
14 forward函数完整流程
你的forward可以总结成
1检查输入维度
2 把视频帧展开
[B,T,3,H,W]->]B*T,3,H,W[
3 每帧CNN特征提取
]B*T, 3,H,W[->[B*T,64,H,W]
4 reshape回视频形式
[B*T,64,H,W]->[B,T,64,H,W]
5 计算相邻帧光流
flows_forward = [B,T-1,2,H,W]
flows_backward=[B,T-1,2,H,W]
6 反向时间传播
得到backward_feats
7 正向时间传播
得到forward_feat
s8 对每一帧
当前帧呢特征 未来信息 +过去信息
reconstruction
upsample
加base frame
输出增强帧
9 把所有输出帧 stack
[B,T,3,H*scale,W*scale]
15 模型训练时怎么用
普通视频增强
model = BasicVSRLite(scale=1).cuda()
lq = torch.randn(2,7,3,128,128).cuda(
)gt = torch.randn(2,7,3,128,128).cuda()
pred = model(lq)
loss = F.l1_loss(pred, gt)
loss.backward()
倍视频超分
model = BasicVSRLite(scale=4).cuda()
lq = torch.randn(2, 7, 3, 64, 64).cuda()
gt = torch.randn(2, 7, 3, 256, 256).cuda()
pred = model(lq)
loss = F.l1_loss(pred, gt)
loss.backward()
其中,
lq = low quality低质量视频帧
gt = ground truth 高质量视频帧
16这个模型的学习重点
16.1 CNN特征提取
self.feat_extract
16.2 光流对齐
TinyFlowNet + flow_warp
16.3 双向时间传播
backwrd_trunk + forward_trunk
16.4 融合重建
reconstruction + upsample + out +base
当前帧不清楚的地方,可以从前后帧里找信息补回来
17 需要注意的地方
1. TinyFlowNet 光流效果不一定好
2. 没有使用预训练SpyNet / PWC
Net美元BasicVSR+的二阶传播
4没有EDVR的可变性卷积对齐
5 只用了基础L1Loss视觉锐度可能一般

1262

被折叠的 条评论
为什么被折叠?



