BasicVSR-lite图像画质增强

一 模型介绍

是一个CNN+光流对其+双向时许传播的视频增强模型,适合做

视频超分辨率、视频去噪、视频去模糊、视频压缩伪影修复、一般视频增强

BasicVSR系列的核心思想是;不要只增强单帧,而是利用前后多帧的信息。BasicVSR++论文也明确说,基于recurrent structure 通过双向传播和特征对其来利用整个视频序列的信息,BasicVSR++进一步加入二阶传播和flow-guided deformable alignment,增强了对错位视频帧的时空信息利用。

1.1 为何BasicVSR-like模型合适?

模型 适合程度 原因

BasicVSR-lite 最推荐 结构清楚,CNN为主,适合自己实现

EDVR 也推荐,经典视频增强模型,deformable conv实现稍复杂

FastDVDnet 适合视频去噪,不依赖光流,速度快,结构相对直接

BasicVSR++ 效果更强 但结构比BasicVSR更复杂

RealBasicVSR 真实视频超分,适合真实退化视频,但是训练场领略更复杂

优先级

BasicVSR-lite

EDVR

BasicVSR++ / RealBasicVSR

1.2 BasicVSR-lite的整体结构

输入不是一张图,而是一段连续视频帧

输入低质量视频帧

frame_1, frame_2, frame_3, ...frame_T

张量形状一般是

[B, T, C,H,W]

B = batch size

T 连续帧数量,比如7或15

C = 3,RGB 通道数量

H = 图像高度

W= 图像宽度

整体网络可以写成

连续低质量帧

->每帧CNN提取特征

->光流估计/特征对齐

->反向时序传播

->正向时序传播

->特征融合

->重建网络

->增强后视频帧

1.3 网络结构详细拆解

1.3.1 输入

假设一次输入7帧

x = [frame_1, frame_2, ... frame_7]

shape 是

  1. shape = [8, 7, 3, H, W]

如果做4倍超分,输入可能是

低清视频帧,[B, 7, 3, 64, 64]

高清目标帧 [B, 7, 3, 256, 256]

如果做去噪,去模糊,压缩伪影修复,输入和输出尺寸通常一样

低质量帧 [B,7,3,H,W]

高质量帧 [B, 7, ,3, H, W]

1.3.2 每帧特征提取CNN

frame_t ->Conv->RsBlocks->feature_t

feature_t.shape = [B, 64, H, W]

这里的CNN可以用

COnv2d ResidualBlock

ReLU / LeakyReLU

这部分和人脸模型ResNet思路类似,输出不是512维向量,而是保留二维特征图

人脸识别

[B, 3, 112, 112]->CNN->[B, 512]

视频增强

[B, 3,H,w]->CNN->[B, 64, H, W]

视频增强不能太早flatten, 因为需要恢复图像细节

1.3.3 光流估计/帧间对齐

视频增强最大的问题是:

相邻帧内容相似,物理会运动

BasicVSR类模型通常使用光流网络,比如SPyNet, 来估计相邻帧之间的运动,BasicVSR++补充材料里也提到使用pretrained SPyNet作为flow network

光流可以理解成

第t帧的每个像素,应该往哪里移动,才能对齐到t+1帧

feature_{t-1}

根据optical flow warp

对齐到feature_t

1.3.4 双向时许传播

这是BasicVSR的核心

看当前帧附近的几帧,让信息沿着时间传播。

反向缠传播

从视频最后一帧往前传

frame_T->frame_{T-1}->...frame_1

得到每一帧的backward feature

backward_feature_t

正向传播

再从第一帧往后传

frame_1->frame_2 ...frame_T

得到每一帧的forward feature

forward feature_t

最后第t帧可以利用

当前帧特征

前面帧传来的信息

后面帧传来的信息

enhanced_feature_t = fuse(

current_feature_t,

forward_feature_t,

backward_feature_t

)

1.3.5 重建网络

融合后的特征再经过CNN重建成图像

如果是去噪/去模糊/压缩增强

[B,64,H,W]->Conv->[B,3,H,W]

如果是视频超分辨率,

[B, 64, H, W]

->PixelShuffle x2

->PixelShuffle x2

->[B, 3, 4H, 4W]

1.4 如果用EDVR

EDVR时CVPRW 2019的视频恢复模型,

EDVR沦为提示两个关键模块

PCD Alignment 金字塔,及联,可变形卷积对齐

TSA Fusion 时许和空间注意力融合

1.5 需要去噪FastDVDnet

去噪->降低ISO噪声->减少暗光噪声->视频画面变干净

FastDVDnet是CVPR 2020的视频去噪模型,官方仓库提供Pytorch实现,说明它不适用光流估计的视频去噪算法。

不用光流->结构相对简单->速度快->适合视频去噪入门

二 代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F

def flow_wrap(
x,flow,padding_mode="border",
align_corners=True
):
#使用光流对特征图做wrap对齐
#参数,x: 要背对齐的特征图,shape=[B,C,H,W]
# flow: 光流,shape=[B,2,H,W]
#flow]:, 0, :, :[表示x方向位移,横向位移
#flow[:,1,:,:] 表示y方向位移,就是纵向位移
#padding_mode:
            grid_sample 越界采样时的填充方式。
            "border" 表示越界时使用边界像素。
# align_corners:
  grid_sample的坐标对齐方式
返回 warped_x: 根据flow对齐后的特征图,shape=[B,C,H,W]
#取出输入特征图的batch size,通道数,高,宽
b,c,h,w = x.size()
#确保flow的数据类型和x一致,避免AMP/FP16 时类型冲突
flow = flow.to(dtype=x.dtype)
#生成y坐标网络,范围时0到H-1
#生成x坐标网络,范围是0到W-1
grid_y, grid_x = torch.meshgrid(
    torch.arange(0, h, device=x.device, dtype=x.dtype),
    torch.arange(0, w, device=x.device, dtype=x.dtype),
    indexing="ij"
)
#grid_x原本shape是[H,W]
扩展成[B,H,W] 方便和batch内每张图的flow相加
grid_x = grid_x.unsqueeze(0).expand(b, -1, -1)
#grid_y 原本shape是[H,W]
#扩展成[B,H,W]
grid_y, grid_y.unsqueeze(0).expand(b, -1, -1)
#当前像素为止x坐标 + 光流横向位移
#得到需要从原特征图哪个x为止采样
vgrid_x = grid_x + flow[:,0,;,);]
#当前像素位置坐标y坐标+光流纵向位移
得到需要从原特征图哪个y位置采样
vgrid_y = grid_y + flow[]

#grid_sample要求坐标范围时[-1, 1]
#所以要把像素坐标[0, W-1]转换成[-1, 1]
if w > 1:
    vgrid_x = 2.0 * vgrid_x / (w-1) - 1.0
else
    vgrid_x = torch.zeros_like(vgrid_x)
    
#把像素坐标[0, H-1]转换成[-1, 1]
if h > 1:
        vgrid_y = 2.0 * vgrid_y / (h - 1) - 1.0
    else:
        vgrid_y = torch.zeros_like(vgrid_y)
#grid_sample要求最后一维时[x,y ]
#所以这里吧x坐标和y坐标对跌倒最后一维
grid = torch.stack(vgrid_x, vgrid_y), dim=-1

#根据grid 从x中采样,得到warp后的特征图
warped_x = F.grid_sample(
x,grid, mode="bilinear", padding_mode=padding_mode,
align_cornors=align_cornors,
)
#返回对齐后的特征图
return warped_x

class ResidualBlockNoBN(nn.Module):
# 不带BatchNorm 的残差块
#视频增强,超分模型里经常不用BatchNorm
#因为BatchNorm可能影响图像恢复的细节和数值范围
def __init__(self, channels, res_scale=1.0):
#channels 输入和输出通道数
#res_scale 残差缩放系数
#可以让残差分支更稳定
#初始化nn.Module父类
super().__init__()
#第一个3x3卷积,通道数不变
self.conv1 = nn.Conv2d(
 channels,
 chnanels, 
 kernel_size=3,
 stride=1,
 padding=1,
)
#第二个3x3卷积,通道数不变
self.conv2 = nn.Conv2d(
            channels,
            channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
#使用LeakyReLU作为激活函数
self.relu = nn.LeakyReLU(
  negative_slope = 0.1
 inplace=True
)
#保存残差缩放系数
self.res_scale = res_scale

def forward(self, x)
#前向传播,输入x shape = [B,C,H,W]
#输出 out shpe = [B,C,H,W]
#保存原始输入,用于残差链接
identity - x
#第一个卷积
out = self.conv1(x)
#激活函数
out = self.relu(out)
#第二个卷积
out = self.conv2(out)
#残差链接,输出 = 原输入 + 残差分支
out = identity + out * self.res_scale
#返回残差块输出
return out

class ResidualBlockWithInputConv(nn.Module):
#信用一个卷积吧输入通道变成mid_channels
再接多个残差块
#
def __init__(
self,in_channels,mid_channels,num_block
);
 """
        参数:
            in_channels:
                输入通道数。

            mid_channels:
                中间特征通道数。

            num_blocks:
                残差块数量。
        """
    #初始化父类
    
    super().__init__()
    #用list存放网络层
    layers = []
    #输入卷积,吧in_channels变成mid_channels
    layers.append(
    nn.Conv2d(
                in_channels,
                mid_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            )
    )
    #激活函数
    layers.append(
     nn.LeakyReLU(
     negative_slope=0.1,inplace=True
     )
    )
    #堆叠多个残差块
    for _ in range(num_blocks):
      layers.append(
          ResidualBlockNoBN(
             channels=mid_channels,
          )
      )
#把所有层组成一个Sequential
  self.main = nn.Sequential(*layers)
def forward(self, x):
        """
        前向传播。
        """

        # 直接把输入送进 Sequential
        return self.main(x)
class TinyFlowNet(nn.Module):
#非常简化的光流网络
#不是论文里的BasicVSR里的SpyNet 这是为了吧BasicVSR-like结构跑通
#输入 img_ref 参考帧,shape=[B,3,H,W]
img_supp
 支撑帧,相邻帧
 shape = [B,3,,H,W]
 
 def __init__(self, max_flow=20.0)
 参数,max_flow 限制预测光流的最大像素位移
 #初始化父类
 super().__init__()
 #保存最大光流范围
 self.max_flow = max_flo
 e#输入是两张RGB图拼接, 所有通道数是6
 self.body = nn.Sequential(
    #第一层卷积,提取浅层特征
 nn.Conv2d(6, 32, kernel_size=7, stride=1, padding=3),
 nn.LeakyReLU(0.1, inplace=True),
 #下采样一次,扩大感受
 nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2),
 nn.LeakyReLU(0.1, inplace=True),
 #中间卷积
 nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.1, inplace=True),
#再下次阿阳一次,继续扩大感受
nn.Conv2d(64, 96, kernel_size=3, stride=2, padding=1),
nn.LeakyReLU(0.1, inplace=True),
#中间卷积
nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.1, inplace=True),
#上采样回较高分辨率
nn.ConvTranspose2d(96, 64, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.1, inplace=True),
#再上采样回原图分辨率附近
nn.Conv2d(64, 96, kernel_size=3, stride=2, padding=1),
nn.LeakyReLU(0.1, inplace=True)
#中间卷积
nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.1, inplace=True)
#上采样回校高分辨率
nn.ConvTranspose2d(96, 64, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.1, inplace=True)
#上采样回调分辨率附近
nn.ConvTranspose2d(64, 32, kernel_size=4,stride=2,padding=1)
nn.LeakyReLU(0.1, inplace=True)
#输出 2通道光流dxdy
nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1)
)        

def forward(self, img_ref, img_supp):
##前向传播
#记录原始图像
#shape [B, 6, H, W]
inp = torch.cat([img_ref, img_supp], dim=1)
#预测光流
flow = self.body(inp)
#如果因为下采样上采样导致尺寸略有差异插值尺寸
if flow.shape[-2:] != (h, w):
   flow = F.interpolate(
      flow,
      size(h,w),
      mode="bilinear",
      align_corners=False
   )
#tanh 输出限制到[-1,1]
#再乘max_flow 得到像素激光流范围
flow = torch.tanh(flow) * self.max_flow
#返回光流
return flow

class BasicVSRLite(nn.Module):
#教学版BasicVSR-lite
#整体结构 输入视频序列
#CNN提取每帧特征
#估计相邻光流
#反向时间传播T->1
正向时间传播1->T
当前特征 + 反向传播特征 + 正向传播特征 融合
重建增强 / 超分帧
输入
x shpe = [B, T, 3, H, W]
输出
scale = 1:
 out shape=[B, T, 3, H, W]
 scale=2:
  out shape = [B, T, 3, 2H, 2W]
  scale = 4:
  out shape = [B, T, 3, 4h, 4w]
def __init__(
  self, mid_channels=64,
  num_feature_blocks=5,
  num_propagation_blocks=7,
scale=1,
max_flow=20.0,
) :
参数
mid_channels:中间特征通道
  64 是常见轻量配置
num_feature_blocks:
 每帧特征提取阶段残差数量
num_propagation_blocks
 正向/反向传播阶段残差数量
 num_reconstruction_blocks
  重建阶段残差数量
scale:
 放大倍率
 scale=1 表示输入输出尺寸同于去噪去模糊增强
 scale=2 表示2超分
 scale=4 表示4超分
 max_flow TinyFlowNet 预测光流最大像素位移
 
 #初始化分类
 super().__init__()
 #只允许124三种倍率
 assert scale in (1,2,4), "scale must be 1,2, or 4"
 #保存中间通道数
 self.mid_channels = mid_channels
 #保存超分倍率
 self.scale=scale
 #光流网络用于估计相邻之间运动
 self.flow_net = TinyFlowNet(
    max_flow = max_flow,
 )
 #每帧特征提取网络
 feature_layers=[]
 #第一层卷积RGB图像3通道->mid_channels
 feature_layers.append(
  nn.Conv2d(
  3, mid_channels,kernel_size=3,
   stride=1,padding=1
  )
 )
#激活函数
feature_layers.append(
  nn.LeakyReLU(
      negative_slope=0.1, inplace=True
  )
)
#堆叠多个残差块用于提取每一帧空间特征
for _ in range(num_feature_blocks):
 feature_layers.append(
     ResidualBlockNoBN(
        channels=mid_channels,
     )
 )
#组成每帧特征提取网络
self.feat_extract = nn.Sequential(*feature_layers)
#反向传播网络
#输入当前特征 + 从未来传播过来特征
#所以输入通道数mid_channels * 2
self.backward_trunk = ResidualBlockWithInputConv(
    in_channels = mid_channels * 2,
    mid_channels = mid_channels,
    num_blocks=nm_propagation_blocks
)
#正向传播网络
#输入当前特征 + 从过去传播过来特征
#所以输入通道数也是mid_channels *2
self.forward_trunk = ResidualBlocksWithInputConv(
   in_channels = mid_channels * 2,
   mid_channels = mid_channels,
   num_blocks = num_propagation_blocks,
)
#重建网络
#输入当前特征 + 反向传播特征 + 正向传播特征
# 所以输入通道数mid_channels * 3
self.reconstruction = ResidualBlocksWithInputConv(
  in_channels = mid_channels * 3,
mid_channels = mid_channels,
num_blocks = num_reconstruction_blocks,
)
#激活函数
self.lrelu = nn.LeakyReLU(
   negative_slope=0.1,
   inplace=True,
)
#pixelShuffle用于分辨率上采样
self.pixel_shuffle = nn.PixelShuffle(
  upscale_factor=2,
)
#如果scale >=2 需要一次2上采样
if scale >= 2:
  self.upconv1 = nn.Conv2d(
     mid_channels,
     mid_channels * 4,
     kernel_size=3,
     stride=1,
     padding=1,
  )
#如果scale==4需要两次2上采样
if scale == 4:
 self.upconv2 = nn.Conv2d(
    mid_channels,
    mid_channels *4,
    kernel_size=3,
    stride=1,
    padding=1,
 )  
#分辨率空间上卷积
self.conv_hr = nn.Conv2d(
  mid_channels,
  mid_channels,
  kernel_size=3,
  stride=1,
  padding=1,
  #最后一层卷积特征变回RGB图像
)
self.conv_last = nn.Conv2d(
  mid_channels,
 3,
 kernel_size=3,
 stride=1,
 padding=1,
)
def compute_flows(self, x):
#计算相邻之间光流
#输入
x shape = [B, T, 3, H ,W]
返回 flow_forward:
用于正向传播
flows_forward[:, i-1]表示i -> i-1光流
用它可以过去特征wrap当前
shape=[B,T-1, 2, H, W]
flows_backward:
 用于反向传播flows_backward[:,i]表示i i + 1光流
 用它可以把未来特征warp当前
 shape = [B,T-1,2,H,W]
 #取出输入视频维度
 b,t,c,h,w = x.size()
 #如果只有1就没有相邻光流
 if t <=1:
  empty = x.new_zeros(b, 0, ,2, h, w)
  return empty, empty

#存放反向传播需要光流
flows_backward = []
#对于反向传播需要从未来传播到当前
#warp future feature 当前需要当前 ->未来 光流
for i in range(t - 1):
#计算ii + 1光流
flow_i_to_next = self.flow_net(
x[:, i, :, :, :],
x[:, i + 1, :, :, :],
)
#保存光流
flows_backward.append(flow_i_to_next)
#list堆叠tensor
#shape = [B, T-1, 2, H, W]
flows_backward=torch.stack(
   flows_backward,
   dim=1,
)
#存档正向传播需要光流
flows_forward=[]
#对于正向传播需要从过去传播当前
#warp past feature 当前帧需要当前-过去光流
for i in range(1, t):
#计算ii-1光流
 flow_i_to_prev = self.flow_net(
 x[:, i, :, :, :],
 x[:, i - 1, :, :, :],
 )
 #保存光流
 flows_forward.append(flow_i_to_prev)
 #shape = [B, T-1, 2, H, W]
 flow_forwards = torch.stack(
 flows_forward,
 dim=1
 )
 #返回正向传播光流反向传播光流
 return flows_forward, flows_backward
 
 def upsample(self, feat):
 根据scale重建特征进行上采样
 输入: feat shape = [B, C, H, W]
 输出 scale=1
  out shape = [B, 3, H, W]
  scale=2
  out shape=[B, 3, 2H, 2W]
  scale=4:
  out shape= [B, 3, 4H, 4W]
  #如果24超分先做一次2PixelShuffl
  eif self.scale == 2:
   #卷积通道扩展到4
   feat = self.upconv1(feat)
   
   #pixelShuffle通道转换为空间分辨率
   feat = self.pixel_shuffle(feat)
   #激活
   feat = self.lrelu(feat)
   #如果4超分需要做两次2PixelShuffle
   elif self.scale == 4
   :#第一次2上采样
   feat = self.upconv1(feat)
   feat = self.pixel_shuffle(feat)
   feat = self.lrelu(feat)

#第二次2上采样
feat = self.upconv2(feat)
feat = self.pixel_shuffle(feat)
feat = self.lrelu(feat)
#高分辨率卷积
feat = self.conv_hr(feat)
#激活
feat = self.lrelu(fea)t
#输出RGB残差图像
out = self.conv_last(feat)
#返回输出
return out

def get_base_frame(self, lr_frame):
获取残差链接base image
#对于scale=1
base就是输入
对于scale=2scale=4
bas双线性循环放大后的输入帧
最终输出
enhanced = predicted_residual + base
#如果不做超分直接返回原图
if self.scale == 1:
  return lr_frame
  #如果超分清晰度输入双线性插值放大
base = F.interpolate(
lr_frame,
scale_factor=self.scale,
mode = "bilinear",
align_corners=False
)
#返回base frame
return base

def forward(self, x):
 """
        前向传播。

        输入:
            x shape = [B, T, 3, H, W]

        输出:
            out shape = [B, T, 3, H*scale, W*scale]
        """
#检查输入必须5
if x.dim() != 5
#取出输入视频维度
b,t,c,h,w = x.size()
检查必须RGB视频

#每帧CNN特征提取
#[B,T,3,H,W]reshape成[B*t, 3, H ,W]
#这样可以一次性所有CNN
x_reshape = x.reshape(b*t, c, h, w)
#提取每帧空间特征
feats = self.feat_extract(x_reshape)
#特征reshape视频序列形式
#shape = [B,T,mid_channels, H,W]
feats = feats.reshape(
b,t,self.mid_channels,
h,w
)
#计算相邻光流
flows_forward用于正向传播
flows_backward 用于反向传播
flows_forward, flows_backward = self.compute_flows(x)
#反向时间传播 T-1传播0
#list存放每一帧反向传播特征
 backward_feats = [None] * t
 #初始化传播特征0
 #shape = [B, mid_channels, H, W]
 feat_prop = x.new_zeros(b, self.mid_channels,h,w)
#从最后一帧第一遍历
for i in range(t - 1, -1, -1):
 #如果不是最后一帧就需要未来传播特征warp当前
  if i < t - 1:
    #flows_backward[:, i]i->i+1光流
    #用它可以i + 1传播特征对齐i
  feat_prop = flow_warp (
      feat_prop,
      flows_backward[:, i,:,:,:],
  )
#当前特征
curr_feat = feats[L,i,:,:,:]
#拼接当前特征传播特征
#shape = [B, mid_channel *2, H, W]
feat_input = torch.cat(
   [curr_feat, feat_prop],
   dim = 1,
)
#通过反向传播网络更新传播特征
#shape = [B, mid_channels * 2, H, W]
feat_input = torch.cat(
   [curr_feat, feat_prop],
   dim=1,
)
#通过反向传播网络更新传播特征
feat_prop = self.backward_trunk(feat_input)
#保存i对应反响传播特征
backward_feats[i] = feat_prop
#4 正向时间传播0传播T-1帧
#list存放每一帧正向传播特征
forwards_feats = [None]*t
#初始化正向传播特征0
feat_prop = x.new_zeros(b,self.mid_channels, h,w)

#从第一帧往后一帧遍历
for i in range(t):
#如果不是第一帧就需要把过去传播特征warp当前
   if i > 0
   #dlows forward[:, i-1]是i i-1光流
   #用它可以i-1传播特征对其道i
   feat_prop = flow_warp(
       feat_prop,
       flow_forward[:,i-1,:,:,:]   
   )
   #当前特征
   curr_feat = feats[:,i,:,:,:]
   #拼接当前特征正向传播特征
   feat_input = torch.cat(
    [curr_feat, feat_prop],
    dim=1,
   )
   #通过正向传播更新传播特征
   feat_prop = self.forward_trunk(feat_input)
   #保存i对应正向传播特征
   forward_feats][i = feat_prop
   #融合当前特征反向传播特征正向传播特征, 重建输出
   #存放所有输出
   outs=[]
   #每一帧分别重建
   for i in range()t:
   #当前原始空间特征
    curr_feat = feats[:,i,:,:,:]
   #当前反向传播特征
   backward_feat = backward_feats[i]
   #当前正向传播特征
   forward_feat = forward_feats[i]
   #三类特征拼接
   #shape = [B,mid_channels * 3, H,W]
   feat = torch.cat(
     [currefeat, backward_feat, forward_feat],
     dim=1,
   )
   #通过重建网络重建特征
   feat = self.reconstruction(feat)
   #根据scale输出RGB残差图像
   out = self.upsample(feat)
   #获取base frame
   #scale=1 就是输入帧
   #scale= 2/4 双线性插值放大输入
   base = self.get_base_frame(
     x[:,i,:,:,:]
   )
   #残差学习最终输出 网络预测残差 + base
   out = out + base
   #保存当前输出
   outs.append(out)
   #list每一帧堆叠视频序列
   #shape = [B,T,3, H *scale, W*scale]
   outs = torch.stack(outs, dim=1)
   #返回增强后视频序列
   return outs
   
   if __name__ == "__main__":
   简单测试代码
   直接运行 python basicvsr_lite.py
   #构造一个BasicVSR-lite模型
   #scale=1 表示输入输出同分辨率
   model = BasicVSRLite(
    mid_channels=64,
    num_feature_blocks=5,
    num_propagation_blocks=7,
    num_reconstruction_blocks=10,
    scale=1,
   )
   #构造一个输入视频batch
   #B=2,T=7,C=3,H=64,W=64
   x = torch.randn(2,7,3,64,64)
#前向传播
y = model(x)
#打印输入输出尺寸
#训练时一般这样计算损失
#假设gt清晰视频shapey一样
gt = torch.randn_like(y)
#视频增强/超分常用L1Loss
loss = F.l1_loss(y, gt)

模型总结

这个模型一个教学版BsicVSR-lite视频增强模型核心目的不会单张图片而是处理一段连续视频利用前后信息增强当前

输入一段质量视频
提取每一帧CNN特征
估计相邻之间光流
光流前后特征对齐
正向反向时序信息传播
融合当前过去未来帧信息
1这个模型能做什么
视频去噪
视频去模糊
视频压缩伪影修复
视频画质增强
视频超分辩率
scale控制任务类型
scale=1
表示输入输出同尺寸适合葡萄视频增强
scale=2
表示2视频超分
scale=4
表示4视频超分

2输入输出格式
x.shape = [B,T,3,H,W]
含义
B = batch_size 一次训练几个视频片段
T = 每个视频片段有多少帧
3 = RGB通道数量
H 图像高度
W 图像宽度
x = torch.randn(2, 7, 3, 64, 64)
表示
2 视频片段
每个片段7
每帧RGB图像
每帧大小64x64
如果scale=1 输出
y.shape = [B,T,3,H,W]
如果scale=4 输出是
y.shape=[B,T,3,4H,4W]
3 模型整体网络结构
class BasicVSRLite(nn.Module)
里面主要这些模块
feat_extract 每帧CNN特征提取
flow_net 简化光流估计网络
backward_trunk 反向时间传播网络
forward_trunk 正向时间传播网络
reconstruction 特征融合重建网络
upsample 超分上采样模块
residual base残差输出链接

整体结构可以画成
输入视频x:[B,T,3,H,W]
每帧CNN特征提取
feats [B,T,64,H,W]
计算相邻光流
反向传播T-1->0 得到未来信息backward_feats
正向传播 0->T-1 得到过去信息forward_feats
每一帧融合
当前帧特征 过去信息 未来信息
重建网络
输出增强视频

4 flow_warp是干什么
def flow_warp(x, flow)
作用是根据光流特征图进行空间对齐
视频物体运动
第一帧 人脸在左边
第二人脸在中间
第三针人脸在右边
直接把这些帧特征融合无法对齐结果容易模糊
需要先用光流估计运动
这个额像素上一帧移动哪里
这个特征应该往左还是往右移动
F.grid_sample(...)
重新采样特征图吧前后特征对齐当前
flow_warp = 根据运动信息移动特征图
5 ResidualBlockNoBN是什么
这是一个不带BatchNorm残差块
class ResidualBlockNoBN(nn.Module)
结构
输入x
Conv 3x3
LeakyReLU
Conv 3x3
加回输入x
输出
对应代码
identity=x
out = self.conv1()x
out = self.relu(out)
out = self.conv2(out)
out = identity + out * self.res_scale
作用是增强特征表达能力
为什么没有BatchNorm 归一化一下
因为视频增强超分去噪这类图像恢复任务需要保留非常惊喜像素信息BatchNorm有时破坏图像亮度颜色纹理分布所以很多恢复模型不用BN
6 TinyFlowNet是什么
class TinyFlowNet(nn.Module)
这是一个简化版本光流估计网络
输入图像
img_ref.shape=[B,3,H,W]
img_supp.shape = [B,3, H,W]
先拼接成
[B,6,H,W]
然后经过一个CNN输出
flow.shape = [B,2,H,W]
其中
flow[:, 0,:,)] = x方向位移
flow]:,1,:,:[ = y 方向位移
这个TInyFlowNet是教学版不是官方BasicVSRSpyNet
7每帧特征提取feat_extract
这一段代码
self.feat_extract = nn.Sequential(*feature_layers)
作用每一帧RGB图像变成CNN特征图
输入一帧
[B,3,H,W]
输出 [B,64,H,W]
mid_channels = 64
所以每帧转换成64通道空间特征
注意这里不是输出一个向量保留二维空间结构

图片 CNN 512维向量
图片 CNN 64通道特征图
因为视频增强最终需要回复图像不能把空间信息压扁
8 compute flows计算什么
flows_forward, flows_backward = self.compute_flos(x)
这个函数计算相邻之间光流
flows_backward
flows_forward

flows_backward用于反向传播就是从未来当前信息
代码里面计算是这样
flow_i_to_next = self.flow_net (
   x[:, i]
   x[:, i + 1]
)
用它可以i+1特征warpi
flows_forward
用于正向传播从过去帧当前传信息
flow_i_to_prev = self.flow_net(
x[:, i]
   x[:, i + 1]
)
含义是
i帧i+1光流
9反向时间传播
for i in range(t-1, -1, -1)
最后一帧第一帧处理
T-1, T-2 ,0
目的
当前获得未来帧信息
例如处理3利用4,5,6传过来信息
初始化feat_prop = 0
从最后一帧开始
 如果不是最后一帧
   用光流warp 未来帧传来feat_prop
   拼接
     当前帧特征curr_feat
   未来传播特征 feat_prop
   送入backward_trunk
   得到新的feat_prop
   保存为backward_feats[i]
核心代码
feat_prop = flow_warp (
 feat_prop,flows_backward[:,i]
)
feat_input= torch.cat(
]curr_feat, feat_prop[
dim=1
)
feat_prop=self.backward_trunk(feat_input)
10 正向时间传播
for i in range(t)
例如处理3可以利用0,1,2过来信息
初始化feat_prop = 0
从第一帧开始
 如果不是第一帧
   用光流warp过去帧传来feat_prop
   拼接
    当前帧特征curr_feat
     过去传播特征 feat_prop
      送入forward_tru
      nk得到新的feat_prop
  保存forward_feats[i]
  核心代码
  feat_prop = flow_warp(
    feat_prop,
    flows_forward[:, i - 1]
)

feat_input = torch.cat(
    [curr_feat, feat_prop],
    dim=1
)

feat_prop = self.forward_trunk(feat_input)
11 最终融合重建
1curr_feat当前帧呢自己的特征
2 backward_feat 未来帧传来信息
3 forward_feat过去传来信息
然后拼接
feat = torch.cat([curr_feat, backward_feat, forward_feat],
dim=1)
如果mid_channels=64 那么
curr_feat = 6通道
backward_feat = 64通道
forward_feat = 64通道
拼接后 = 192通道
送入重建网络
feat = self.reconstruction(fe)at
重建网络192通道融合64通道
out = self.upsample(feat)输出RGB图像
12 upsample是什么
def upsample(self, feat)
负责特征图转换成最终RGB输出
如果scale=1
不做放大
[B,64,H,W]->[B,3,H,W]
如果scale=2
Conv -> PixelShuffle x2 -> Conv -> RGB
[B,64,H,W] -> [B,3,2H,2W]
如果scale=4
Conv -> PixelShuffle x2
Conv -> PixelShuffle x2
Conv -> RGB

[B,64,H,W] -> [B,3,4H,4W]
PixelShuffle超分模型常用上采样方式通道信息重新排列空间上。
13 为什么最后out + base
base = self.get_base_frame(x[:, i])
out = out + base
残差学习
如果scale = 1
base = 原始输入
输出=原始输入帧 + 模型预测修正量
如果scale=4
base = 双线性插值方法后的输入帧
输出=放大后输入 + 模型预测高频细节
因为模型从不从生成整张
哪里需要去噪 哪里需要变清晰 哪里需要补充细节  哪里需要修复伪影
14 forward函数完整流程
你的forward可以总结成
1检查输入维度
2 把视频展开
 [B,T,3,H,W]->]B*T,3,H,W[
3 每帧CNN特征提取
]B*T, 3,H,W[->[B*T,64,H,W]
4 reshape视频形式
[B*T,64,H,W]->[B,T,64,H,W]
5 计算相邻光流
flows_forward = [B,T-1,2,H,W]
flows_backward=[B,T-1,2,H,W]
6 反向时间传播
 得到backward_feats
 7 正向时间传播
 得到forward_feat
 s8 对每一帧
 当前帧呢特征 未来信息 +过去信息
 reconstruction
 upsample
 base frame
 输出增强
 9 所有输出 stack
 [B,T,3,H*scale,W*scale]
 15 模型训练时怎么用
 普通视频增强
 model = BasicVSRLite(scale=1).cuda()
 lq = torch.randn(2,7,3,128,128).cuda(
 )gt = torch.randn(2,7,3,128,128).cuda()
 pred = model(lq)
 loss = F.l1_loss(pred, gt)
 loss.backward()
 视频超分
 model = BasicVSRLite(scale=4).cuda()
lq = torch.randn(2, 7, 3, 64, 64).cuda()
gt = torch.randn(2, 7, 3, 256, 256).cuda()
pred = model(lq)
loss = F.l1_loss(pred, gt)
loss.backward()
其中
lq = low quality低质量视频
gt = ground truth 高质量视频
16这个模型学习重点
16.1 CNN特征提取
self.feat_extract
16.2 光流对齐
TinyFlowNet + flow_warp
16.3 双向时间传播
backwrd_trunk + forward_trunk
16.4 融合重建
reconstruction + upsample + out +base
当前不清楚地方可以从前后找信息回来
17 需要注意的地方
1. TinyFlowNet 光流效果不一定好
2. 没有使用训练SpyNet / PWC
Net美元BasicVSR+二阶传播
4没有EDVR可变性卷积对齐
5 只用了基础L1Loss视觉锐度可能一般



































评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值