5 Commits

Author SHA1 Message Date
Xintao
ee3e556f18 v0.2.4 2021-12-12 22:54:36 +08:00
Xintao
ad1397180d fix bug in inference: RealESRGAN model is None 2021-12-12 22:46:07 +08:00
Xintao
37237da798 update utils and unittest 2021-11-28 23:09:38 +08:00
Xintao
be73d6d9a4 clean and add more comments 2021-11-27 19:59:23 +08:00
Xintao
0ff1cf7215 update setup.py, V0.2.3 2021-10-22 17:06:29 +08:00
29 changed files with 1087 additions and 255 deletions

View File

@@ -1,12 +1,11 @@
name: No Response
# TODO: it seems not to work
# Modified from: https://raw.githubusercontent.com/github/docs/main/.github/workflows/no-response.yaml
# **What it does**: Closes issues that don't have enough information to be
# actionable.
# **Why we have it**: To remove the need for maintainers to remember to check
# back on issues periodically to see if contributors have
# responded.
# **What it does**: Closes issues that don't have enough information to be actionable.
# **Why we have it**: To remove the need for maintainers to remember to check back on issues periodically
# to see if contributors have responded.
# **Who does it impact**: Everyone that works on docs or docs-internal.
on:

View File

@@ -1 +1 @@
0.2.1
0.2.4

View File

@@ -3,4 +3,4 @@ from .archs import *
from .data import *
from .models import *
from .utils import *
from .version import __gitsha__, __version__
from .version import *

View File

@@ -2,13 +2,27 @@ import torch.nn as nn
from basicsr.utils.registry import ARCH_REGISTRY
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
def conv3x3(inplanes, outplanes, stride=1):
"""A simple wrapper for 3x3 convolution with padding.
Args:
inplanes (int): Channel number of inputs.
outplanes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
"""
return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
class BasicBlock(nn.Module):
expansion = 1
"""Basic residual block used in the ResNetArcFace architecture.
Args:
inplanes (int): Channel number of inputs.
planes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
downsample (nn.Module): The downsample module. Default: None.
"""
expansion = 1 # output channel expansion ratio
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
@@ -40,7 +54,16 @@ class BasicBlock(nn.Module):
class IRBlock(nn.Module):
expansion = 1
"""Improved residual block (IR Block) used in the ResNetArcFace architecture.
Args:
inplanes (int): Channel number of inputs.
planes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
downsample (nn.Module): The downsample module. Default: None.
use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
"""
expansion = 1 # output channel expansion ratio
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
super(IRBlock, self).__init__()
@@ -78,7 +101,15 @@ class IRBlock(nn.Module):
class Bottleneck(nn.Module):
expansion = 4
"""Bottleneck block used in the ResNetArcFace architecture.
Args:
inplanes (int): Channel number of inputs.
planes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
downsample (nn.Module): The downsample module. Default: None.
"""
expansion = 4 # output channel expansion ratio
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
@@ -116,10 +147,16 @@ class Bottleneck(nn.Module):
class SEBlock(nn.Module):
"""The squeeze-and-excitation block (SEBlock) used in the IRBlock.
Args:
channel (int): Channel number of inputs.
reduction (int): Channel reduction ration. Default: 16.
"""
def __init__(self, channel, reduction=16):
super(SEBlock, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
nn.Sigmoid())
@@ -133,6 +170,15 @@ class SEBlock(nn.Module):
@ARCH_REGISTRY.register()
class ResNetArcFace(nn.Module):
"""ArcFace with ResNet architectures.
Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
Args:
block (str): Block used in the ArcFace architecture.
layers (tuple(int)): Block numbers in each layer.
use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
"""
def __init__(self, block, layers, use_se=True):
if block == 'IRBlock':
@@ -140,6 +186,7 @@ class ResNetArcFace(nn.Module):
self.inplanes = 64
self.use_se = use_se
super(ResNetArcFace, self).__init__()
self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.prelu = nn.PReLU()
@@ -153,6 +200,7 @@ class ResNetArcFace(nn.Module):
self.fc5 = nn.Linear(512 * 8 * 8, 512)
self.bn5 = nn.BatchNorm1d(512)
# initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.xavier_normal_(m.weight)
@@ -163,7 +211,7 @@ class ResNetArcFace(nn.Module):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0)
def _make_layer(self, block, planes, blocks, stride=1):
def _make_layer(self, block, planes, num_blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
@@ -173,7 +221,7 @@ class ResNetArcFace(nn.Module):
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
self.inplanes = planes
for _ in range(1, blocks):
for _ in range(1, num_blocks):
layers.append(block(self.inplanes, planes, use_se=self.use_se))
return nn.Sequential(*layers)

View File

@@ -10,18 +10,18 @@ from torch.nn import functional as F
class StyleGAN2GeneratorSFT(StyleGAN2Generator):
"""StyleGAN2 Generator.
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
num_mlp (int): Layer number of MLP style layers. Default: 8.
channel_multiplier (int): Channel multiplier for large networks of
StyleGAN2. Default: 2.
resample_kernel (list[int]): A list indicating the 1D resample kernel
magnitude. A cross production will be applied to extent 1D resample
kernel to 2D resample kernel. Default: [1, 3, 3, 1].
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1).
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
narrow (float): The narrow ratio for channels. Default: 1.
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
"""
def __init__(self,
@@ -53,21 +53,18 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
truncation_latent=None,
inject_index=None,
return_latents=False):
"""Forward function for StyleGAN2Generator.
"""Forward function for StyleGAN2GeneratorSFT.
Args:
styles (list[Tensor]): Sample codes of styles.
input_is_latent (bool): Whether input is latent style.
Default: False.
conditions (list[Tensor]): SFT conditions to generators.
input_is_latent (bool): Whether input is latent style. Default: False.
noise (Tensor | None): Input noise or None. Default: None.
randomize_noise (bool): Randomize noise, used when 'noise' is
False. Default: True.
truncation (float): TODO. Default: 1.
truncation_latent (Tensor | None): TODO. Default: None.
inject_index (int | None): The injection index for mixing noise.
Default: None.
return_latents (bool): Whether to return style latents.
Default: False.
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
truncation (float): The truncation ratio. Default: 1.
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
inject_index (int | None): The injection index for mixing noise. Default: None.
return_latents (bool): Whether to return style latents. Default: False.
"""
# style codes -> latents with Style MLP layer
if not input_is_latent:
@@ -84,7 +81,7 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
for style in styles:
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
styles = style_truncation
# get style latent with injection
# get style latents with injection
if len(styles) == 1:
inject_index = self.num_latent
@@ -113,15 +110,15 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
# the conditions may have fewer levels
if i < len(conditions):
# SFT part to combine the conditions
if self.sft_half:
if self.sft_half: # only apply SFT to half of the channels
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
out_sft = out_sft * conditions[i - 1] + conditions[i]
out = torch.cat([out_same, out_sft], dim=1)
else:
else: # apply SFT to all the channels
out = out * conditions[i - 1] + conditions[i]
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
i += 2
image = skip
@@ -133,17 +130,15 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
class ConvUpLayer(nn.Module):
"""Conv Up Layer. Bilinear upsample + Conv.
"""Convolutional upsampling layer. It uses bilinear upsampler + Conv.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Size of the convolving kernel.
stride (int): Stride of the convolution. Default: 1
padding (int): Zero-padding added to both sides of the input.
Default: 0.
bias (bool): If ``True``, adds a learnable bias to the output.
Default: ``True``.
padding (int): Zero-padding added to both sides of the input. Default: 0.
bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``.
bias_init_val (float): Bias initialized value. Default: 0.
activate (bool): Whether use activateion. Default: True.
"""
@@ -163,6 +158,7 @@ class ConvUpLayer(nn.Module):
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
# self.scale is used to scale the convolution weights, which is related to the common initializations.
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
@@ -223,7 +219,26 @@ class ResUpBlock(nn.Module):
@ARCH_REGISTRY.register()
class GFPGANv1(nn.Module):
"""Unet + StyleGAN2 decoder with SFT."""
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1).
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
fix_decoder (bool): Whether to fix the decoder. Default: True.
num_mlp (int): Layer number of MLP style layers. Default: 8.
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
input_is_latent (bool): Whether input is latent style. Default: False.
different_w (bool): Whether to use different latent w for different layers. Default: False.
narrow (float): The narrow ratio for channels. Default: 1.
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
"""
def __init__(
self,
@@ -246,7 +261,7 @@ class GFPGANv1(nn.Module):
self.different_w = different_w
self.num_style_feat = num_style_feat
unet_narrow = narrow * 0.5
unet_narrow = narrow * 0.5 # by default, use a half of input channels
channels = {
'4': int(512 * unet_narrow),
'8': int(512 * unet_narrow),
@@ -295,6 +310,7 @@ class GFPGANv1(nn.Module):
self.final_linear = EqualLinear(
channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
# the decoder: stylegan2 generator with SFT modulations
self.stylegan_decoder = StyleGAN2GeneratorSFT(
out_size=out_size,
num_style_feat=num_style_feat,
@@ -305,14 +321,16 @@ class GFPGANv1(nn.Module):
narrow=narrow,
sft_half=sft_half)
# load pre-trained stylegan2 model if necessary
if decoder_load_path:
self.stylegan_decoder.load_state_dict(
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
# fix decoder without updating params
if fix_decoder:
for _, param in self.stylegan_decoder.named_parameters():
param.requires_grad = False
# for SFT
# for SFT modulations (scale and shift)
self.condition_scale = nn.ModuleList()
self.condition_shift = nn.ModuleList()
for i in range(3, self.log_size + 1):
@@ -332,13 +350,15 @@ class GFPGANv1(nn.Module):
ScaledLeakyReLU(0.2),
EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
def forward(self,
x,
return_latents=False,
save_feat_path=None,
load_feat_path=None,
return_rgb=True,
randomize_noise=True):
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
"""Forward function for GFPGANv1.
Args:
x (Tensor): Input images.
return_latents (bool): Whether to return style latents. Default: False.
return_rgb (bool): Whether return intermediate rgb images. Default: True.
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
"""
conditions = []
unet_skips = []
out_rgbs = []
@@ -362,7 +382,7 @@ class GFPGANv1(nn.Module):
feat = feat + unet_skips[i]
# ResUpLayer
feat = self.conv_body_up[i](feat)
# generate scale and shift for SFT layer
# generate scale and shift for SFT layers
scale = self.condition_scale[i](feat)
conditions.append(scale.clone())
shift = self.condition_shift[i](feat)
@@ -371,12 +391,6 @@ class GFPGANv1(nn.Module):
if return_rgb:
out_rgbs.append(self.toRGB[i](feat))
if save_feat_path is not None:
torch.save(conditions, save_feat_path)
if load_feat_path is not None:
conditions = torch.load(load_feat_path)
conditions = [v.cuda() for v in conditions]
# decoder
image, _ = self.stylegan_decoder([style_code],
conditions,
@@ -389,10 +403,12 @@ class GFPGANv1(nn.Module):
@ARCH_REGISTRY.register()
class FacialComponentDiscriminator(nn.Module):
"""Facial component (eyes, mouth, noise) discriminator used in GFPGAN.
"""
def __init__(self):
super(FacialComponentDiscriminator, self).__init__()
# It now uses a VGG-style architectrue with fixed model size
self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
@@ -401,6 +417,12 @@ class FacialComponentDiscriminator(nn.Module):
self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
def forward(self, x, return_feats=False):
"""Forward function for FacialComponentDiscriminator.
Args:
x (Tensor): Input images.
return_feats (bool): Whether to return intermediate features. Default: False.
"""
feat = self.conv1(x)
feat = self.conv3(self.conv2(feat))
rlt_feats = []

View File

@@ -1,6 +1,7 @@
import math
import random
import torch
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn
from torch.nn import functional as F
@@ -8,14 +9,17 @@ from .stylegan2_clean_arch import StyleGAN2GeneratorClean
class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
"""StyleGAN2 Generator.
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
num_mlp (int): Layer number of MLP style layers. Default: 8.
channel_multiplier (int): Channel multiplier for large networks of
StyleGAN2. Default: 2.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
narrow (float): The narrow ratio for channels. Default: 1.
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
"""
def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
@@ -25,7 +29,6 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
num_mlp=num_mlp,
channel_multiplier=channel_multiplier,
narrow=narrow)
self.sft_half = sft_half
def forward(self,
@@ -38,21 +41,18 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
truncation_latent=None,
inject_index=None,
return_latents=False):
"""Forward function for StyleGAN2Generator.
"""Forward function for StyleGAN2GeneratorCSFT.
Args:
styles (list[Tensor]): Sample codes of styles.
input_is_latent (bool): Whether input is latent style.
Default: False.
conditions (list[Tensor]): SFT conditions to generators.
input_is_latent (bool): Whether input is latent style. Default: False.
noise (Tensor | None): Input noise or None. Default: None.
randomize_noise (bool): Randomize noise, used when 'noise' is
False. Default: True.
truncation (float): TODO. Default: 1.
truncation_latent (Tensor | None): TODO. Default: None.
inject_index (int | None): The injection index for mixing noise.
Default: None.
return_latents (bool): Whether to return style latents.
Default: False.
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
truncation (float): The truncation ratio. Default: 1.
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
inject_index (int | None): The injection index for mixing noise. Default: None.
return_latents (bool): Whether to return style latents. Default: False.
"""
# style codes -> latents with Style MLP layer
if not input_is_latent:
@@ -69,7 +69,7 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
for style in styles:
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
styles = style_truncation
# get style latent with injection
# get style latents with injection
if len(styles) == 1:
inject_index = self.num_latent
@@ -98,15 +98,15 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
# the conditions may have fewer levels
if i < len(conditions):
# SFT part to combine the conditions
if self.sft_half:
if self.sft_half: # only apply SFT to half of the channels
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
out_sft = out_sft * conditions[i - 1] + conditions[i]
out = torch.cat([out_same, out_sft], dim=1)
else:
else: # apply SFT to all the channels
out = out * conditions[i - 1] + conditions[i]
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
i += 2
image = skip
@@ -118,11 +118,12 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
class ResBlock(nn.Module):
"""Residual block with upsampling/downsampling.
"""Residual block with bilinear upsampling/downsampling.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
"""
def __init__(self, in_channels, out_channels, mode='down'):
@@ -148,8 +149,27 @@ class ResBlock(nn.Module):
return out
@ARCH_REGISTRY.register()
class GFPGANv1Clean(nn.Module):
"""GFPGANv1 Clean version."""
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
fix_decoder (bool): Whether to fix the decoder. Default: True.
num_mlp (int): Layer number of MLP style layers. Default: 8.
input_is_latent (bool): Whether input is latent style. Default: False.
different_w (bool): Whether to use different latent w for different layers. Default: False.
narrow (float): The narrow ratio for channels. Default: 1.
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
"""
def __init__(
self,
@@ -170,7 +190,7 @@ class GFPGANv1Clean(nn.Module):
self.different_w = different_w
self.num_style_feat = num_style_feat
unet_narrow = narrow * 0.5
unet_narrow = narrow * 0.5 # by default, use a half of input channels
channels = {
'4': int(512 * unet_narrow),
'8': int(512 * unet_narrow),
@@ -218,6 +238,7 @@ class GFPGANv1Clean(nn.Module):
self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
# the decoder: stylegan2 generator with SFT modulations
self.stylegan_decoder = StyleGAN2GeneratorCSFT(
out_size=out_size,
num_style_feat=num_style_feat,
@@ -226,14 +247,16 @@ class GFPGANv1Clean(nn.Module):
narrow=narrow,
sft_half=sft_half)
# load pre-trained stylegan2 model if necessary
if decoder_load_path:
self.stylegan_decoder.load_state_dict(
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
# fix decoder without updating params
if fix_decoder:
for _, param in self.stylegan_decoder.named_parameters():
param.requires_grad = False
# for SFT
# for SFT modulations (scale and shift)
self.condition_scale = nn.ModuleList()
self.condition_shift = nn.ModuleList()
for i in range(3, self.log_size + 1):
@@ -251,13 +274,15 @@ class GFPGANv1Clean(nn.Module):
nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
def forward(self,
x,
return_latents=False,
save_feat_path=None,
load_feat_path=None,
return_rgb=True,
randomize_noise=True):
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
"""Forward function for GFPGANv1Clean.
Args:
x (Tensor): Input images.
return_latents (bool): Whether to return style latents. Default: False.
return_rgb (bool): Whether return intermediate rgb images. Default: True.
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
"""
conditions = []
unet_skips = []
out_rgbs = []
@@ -273,13 +298,14 @@ class GFPGANv1Clean(nn.Module):
style_code = self.final_linear(feat.view(feat.size(0), -1))
if self.different_w:
style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
# decode
for i in range(self.log_size - 2):
# add unet skip
feat = feat + unet_skips[i]
# ResUpLayer
feat = self.conv_body_up[i](feat)
# generate scale and shift for SFT layer
# generate scale and shift for SFT layers
scale = self.condition_scale[i](feat)
conditions.append(scale.clone())
shift = self.condition_shift[i](feat)
@@ -288,12 +314,6 @@ class GFPGANv1Clean(nn.Module):
if return_rgb:
out_rgbs.append(self.toRGB[i](feat))
if save_feat_path is not None:
torch.save(conditions, save_feat_path)
if load_feat_path is not None:
conditions = torch.load(load_feat_path)
conditions = [v.cuda() for v in conditions]
# decoder
image, _ = self.stylegan_decoder([style_code],
conditions,

View File

@@ -31,12 +31,9 @@ class ModulatedConv2d(nn.Module):
out_channels (int): Channel number of the output.
kernel_size (int): Size of the convolving kernel.
num_style_feat (int): Channel number of style features.
demodulate (bool): Whether to demodulate in the conv layer.
Default: True.
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
Default: None.
eps (float): A value added to the denominator for numerical stability.
Default: 1e-8.
demodulate (bool): Whether to demodulate in the conv layer. Default: True.
sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
"""
def __init__(self,
@@ -87,6 +84,7 @@ class ModulatedConv2d(nn.Module):
weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
# upsample or downsample if necessary
if self.sample_mode == 'upsample':
x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
elif self.sample_mode == 'downsample':
@@ -101,14 +99,12 @@ class ModulatedConv2d(nn.Module):
return out
def __repr__(self):
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
f'out_channels={self.out_channels}, '
f'kernel_size={self.kernel_size}, '
f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
class StyleConv(nn.Module):
"""Style conv.
"""Style conv used in StyleGAN2.
Args:
in_channels (int): Channel number of the input.
@@ -116,8 +112,7 @@ class StyleConv(nn.Module):
kernel_size (int): Size of the convolving kernel.
num_style_feat (int): Channel number of style features.
demodulate (bool): Whether demodulate in the conv layer. Default: True.
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
Default: None.
sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
"""
def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
@@ -144,7 +139,7 @@ class StyleConv(nn.Module):
class ToRGB(nn.Module):
"""To RGB from features.
"""To RGB (image space) from features.
Args:
in_channels (int): Channel number of input.
@@ -204,8 +199,7 @@ class StyleGAN2GeneratorClean(nn.Module):
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
num_mlp (int): Layer number of MLP style layers. Default: 8.
channel_multiplier (int): Channel multiplier for large networks of
StyleGAN2. Default: 2.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
narrow (float): Narrow ratio for channels. Default: 1.0.
"""
@@ -222,6 +216,7 @@ class StyleGAN2GeneratorClean(nn.Module):
# initialization
default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
# channel list
channels = {
'4': int(512 * narrow),
'8': int(512 * narrow),
@@ -309,21 +304,17 @@ class StyleGAN2GeneratorClean(nn.Module):
truncation_latent=None,
inject_index=None,
return_latents=False):
"""Forward function for StyleGAN2Generator.
"""Forward function for StyleGAN2GeneratorClean.
Args:
styles (list[Tensor]): Sample codes of styles.
input_is_latent (bool): Whether input is latent style.
Default: False.
input_is_latent (bool): Whether input is latent style. Default: False.
noise (Tensor | None): Input noise or None. Default: None.
randomize_noise (bool): Randomize noise, used when 'noise' is
False. Default: True.
truncation (float): TODO. Default: 1.
truncation_latent (Tensor | None): TODO. Default: None.
inject_index (int | None): The injection index for mixing noise.
Default: None.
return_latents (bool): Whether to return style latents.
Default: False.
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
truncation (float): The truncation ratio. Default: 1.
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
inject_index (int | None): The injection index for mixing noise. Default: None.
return_latents (bool): Whether to return style latents. Default: False.
"""
# style codes -> latents with Style MLP layer
if not input_is_latent:
@@ -340,7 +331,7 @@ class StyleGAN2GeneratorClean(nn.Module):
for style in styles:
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
styles = style_truncation
# get style latent with injection
# get style latents with injection
if len(styles) == 1:
inject_index = self.num_latent
@@ -366,7 +357,7 @@ class StyleGAN2GeneratorClean(nn.Module):
noise[2::2], self.to_rgbs):
out = conv1(out, latent[:, i], noise=noise1)
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
i += 2
image = skip

View File

@@ -15,6 +15,19 @@ from torchvision.transforms.functional import (adjust_brightness, adjust_contras
@DATASET_REGISTRY.register()
class FFHQDegradationDataset(data.Dataset):
"""FFHQ dataset for GFPGAN.
It reads high resolution images, and then generate low-quality (LQ) images on-the-fly.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
io_backend (dict): IO backend type and other kwarg.
mean (list | tuple): Image mean.
std (list | tuple): Image std.
use_hflip (bool): Whether to horizontally flip.
Please see more options in the codes.
"""
def __init__(self, opt):
super(FFHQDegradationDataset, self).__init__()
@@ -29,11 +42,13 @@ class FFHQDegradationDataset(data.Dataset):
self.out_size = opt['out_size']
self.crop_components = opt.get('crop_components', False) # facial components
self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)
self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1) # whether enlarge eye regions
if self.crop_components:
# load component list from a pre-process pth files
self.components_list = torch.load(opt.get('component_path'))
# file client (lmdb io backend)
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = self.gt_folder
if not self.gt_folder.endswith('.lmdb'):
@@ -41,9 +56,10 @@ class FFHQDegradationDataset(data.Dataset):
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
self.paths = [line.split('.')[0] for line in fin]
else:
# disk backend: scan file list from a folder
self.paths = paths_from_folder(self.gt_folder)
# degradations
# degradation configurations
self.blur_kernel_size = opt['blur_kernel_size']
self.kernel_list = opt['kernel_list']
self.kernel_prob = opt['kernel_prob']
@@ -60,22 +76,20 @@ class FFHQDegradationDataset(data.Dataset):
self.gray_prob = opt.get('gray_prob')
logger = get_root_logger()
logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '
f'sigma: [{", ".join(map(str, self.blur_sigma))}]')
logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
if self.color_jitter_prob is not None:
logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, '
f'shift: {self.color_jitter_shift}')
logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
if self.gray_prob is not None:
logger.info(f'Use random gray. Prob: {self.gray_prob}')
self.color_jitter_shift /= 255.
@staticmethod
def color_jitter(img, shift):
"""jitter color: randomly jitter the RGB values, in numpy formats"""
jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
img = img + jitter_val
img = np.clip(img, 0, 1)
@@ -83,6 +97,7 @@ class FFHQDegradationDataset(data.Dataset):
@staticmethod
def color_jitter_pt(img, brightness, contrast, saturation, hue):
"""jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
fn_idx = torch.randperm(4)
for fn_id in fn_idx:
if fn_id == 0 and brightness is not None:
@@ -103,6 +118,7 @@ class FFHQDegradationDataset(data.Dataset):
return img
def get_component_coordinates(self, index, status):
"""Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file"""
components_bbox = self.components_list[f'{index:08d}']
if status[0]: # hflip
# exchange right and left eye
@@ -131,6 +147,7 @@ class FFHQDegradationDataset(data.Dataset):
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
# load gt image
# Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
gt_path = self.paths[index]
img_bytes = self.file_client.get(gt_path)
img_gt = imfrombytes(img_bytes, float32=True)
@@ -139,6 +156,7 @@ class FFHQDegradationDataset(data.Dataset):
img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
h, w, _ = img_gt.shape
# get facial component coordinates
if self.crop_components:
locations = self.get_component_coordinates(index, status)
loc_left_eye, loc_right_eye, loc_mouth = locations
@@ -173,9 +191,9 @@ class FFHQDegradationDataset(data.Dataset):
if self.gray_prob and np.random.uniform() < self.gray_prob:
img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
if self.opt.get('gt_gray'):
if self.opt.get('gt_gray'): # whether convert GT to gray images
img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])
img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) # repeat the color channels
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

View File

@@ -16,11 +16,11 @@ from tqdm import tqdm
@MODEL_REGISTRY.register()
class GFPGANModel(BaseModel):
"""GFPGAN model for <Towards real-world blind face restoratin with generative facial prior>"""
"""The GFPGAN model for Towards real-world blind face restoratin with generative facial prior"""
def __init__(self, opt):
super(GFPGANModel, self).__init__(opt)
self.idx = 0
self.idx = 0 # it is used for saving data for check
# define network
self.net_g = build_network(opt['network_g'])
@@ -51,8 +51,7 @@ class GFPGANModel(BaseModel):
self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
# ----------- define net_g with Exponential Moving Average (EMA) ----------- #
# net_g_ema only used for testing on one GPU and saving
# There is no need to wrap with DistributedDataParallel
# net_g_ema only used for testing on one GPU and saving. There is no need to wrap with DistributedDataParallel
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
# load pretrained model
load_path = self.opt['path'].get('pretrain_network_g', None)
@@ -65,7 +64,7 @@ class GFPGANModel(BaseModel):
self.net_d.train()
self.net_g_ema.eval()
# ----------- facial components networks ----------- #
# ----------- facial component networks ----------- #
if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt):
self.use_facial_disc = True
else:
@@ -102,17 +101,19 @@ class GFPGANModel(BaseModel):
self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)
# ----------- define losses ----------- #
# pixel loss
if train_opt.get('pixel_opt'):
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
else:
self.cri_pix = None
# perceptual loss
if train_opt.get('perceptual_opt'):
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
else:
self.cri_perceptual = None
# L1 loss used in pyramid loss, component style loss and identity loss
# L1 loss is used in pyramid loss, component style loss and identity loss
self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)
# gan loss (wgan)
@@ -179,6 +180,7 @@ class GFPGANModel(BaseModel):
self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
self.optimizers.append(self.optimizer_d)
# ----------- optimizers for facial component networks ----------- #
if self.use_facial_disc:
# setup optimizers for facial component discriminators
optim_type = train_opt['optim_component'].pop('type')
@@ -221,6 +223,7 @@ class GFPGANModel(BaseModel):
# self.idx = self.idx + 1
def construct_img_pyramid(self):
"""Construct image pyramid for intermediate restoration loss"""
pyramid_gt = [self.gt]
down_img = self.gt
for _ in range(0, self.log_size - 3):
@@ -229,7 +232,6 @@ class GFPGANModel(BaseModel):
return pyramid_gt
def get_roi_regions(self, eye_out_size=80, mouth_out_size=120):
# hard code
face_ratio = int(self.opt['network_g']['out_size'] / 512)
eye_out_size *= face_ratio
mouth_out_size *= face_ratio
@@ -288,6 +290,7 @@ class GFPGANModel(BaseModel):
p.requires_grad = False
self.optimizer_g.zero_grad()
# do not update facial component net_d
if self.use_facial_disc:
for p in self.net_d_left_eye.parameters():
p.requires_grad = False
@@ -419,11 +422,12 @@ class GFPGANModel(BaseModel):
real_d_pred = self.net_d(self.gt)
l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True)
loss_dict['l_d'] = l_d
# In wgan, real_score should be positive and fake_score should benegative
# In WGAN, real_score should be positive and fake_score should be negative
loss_dict['real_score'] = real_d_pred.detach().mean()
loss_dict['fake_score'] = fake_d_pred.detach().mean()
l_d.backward()
# regularization loss
if current_iter % self.net_d_reg_every == 0:
self.gt.requires_grad = True
real_pred = self.net_d(self.gt)
@@ -434,8 +438,9 @@ class GFPGANModel(BaseModel):
self.optimizer_d.step()
# optimize facial component discriminators
if self.use_facial_disc:
# lefe eye
# left eye
fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach())
real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt)
l_d_left_eye = self.cri_component(
@@ -485,22 +490,32 @@ class GFPGANModel(BaseModel):
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
dataset_name = dataloader.dataset.opt['name']
with_metrics = self.opt['val'].get('metrics') is not None
use_pbar = self.opt['val'].get('pbar', False)
if with_metrics:
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
pbar = tqdm(total=len(dataloader), unit='image')
if not hasattr(self, 'metric_results'): # only execute in the first run
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
# initialize the best metric results for each dataset_name (supporting multiple validation datasets)
self._initialize_best_metric_results(dataset_name)
# zero self.metric_results
self.metric_results = {metric: 0 for metric in self.metric_results}
metric_data = dict()
if use_pbar:
pbar = tqdm(total=len(dataloader), unit='image')
for idx, val_data in enumerate(dataloader):
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
self.feed_data(val_data)
self.test()
visuals = self.get_current_visuals()
sr_img = tensor2img([visuals['sr']], min_max=(-1, 1))
gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
if 'gt' in visuals:
gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
sr_img = tensor2img(self.output.detach().cpu(), min_max=(-1, 1))
metric_data['img'] = sr_img
if hasattr(self, 'gt'):
gt_img = tensor2img(self.gt.detach().cpu(), min_max=(-1, 1))
metric_data['img2'] = gt_img
del self.gt
# tentative for out of GPU memory
del self.lq
del self.output
@@ -522,35 +537,38 @@ class GFPGANModel(BaseModel):
if with_metrics:
# calculate metrics
for name, opt_ in self.opt['val']['metrics'].items():
metric_data = dict(img1=sr_img, img2=gt_img)
self.metric_results[name] += calculate_metric(metric_data, opt_)
pbar.update(1)
pbar.set_description(f'Test {img_name}')
pbar.close()
if use_pbar:
pbar.update(1)
pbar.set_description(f'Test {img_name}')
if use_pbar:
pbar.close()
if with_metrics:
for metric in self.metric_results.keys():
self.metric_results[metric] /= (idx + 1)
# update the best metric result
self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
log_str = f'Validation {dataset_name}\n'
for metric, value in self.metric_results.items():
log_str += f'\t # {metric}: {value:.4f}\n'
log_str += f'\t # {metric}: {value:.4f}'
if hasattr(self, 'best_metric_results'):
log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
log_str += '\n'
logger = get_root_logger()
logger.info(log_str)
if tb_logger:
for metric, value in self.metric_results.items():
tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
def get_current_visuals(self):
out_dict = OrderedDict()
out_dict['gt'] = self.gt.detach().cpu()
out_dict['sr'] = self.output.detach().cpu()
return out_dict
tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
def save(self, epoch, current_iter):
# save net_g and net_d
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
self.save_network(self.net_d, 'net_d', current_iter)
# save component discriminators
@@ -558,4 +576,5 @@ class GFPGANModel(BaseModel):
self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter)
self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter)
self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter)
# save training state
self.save_training_state(epoch, current_iter)

View File

@@ -2,10 +2,9 @@ import cv2
import os
import torch
from basicsr.utils import img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from torch.hub import download_url_to_file, get_dir
from torchvision.transforms.functional import normalize
from urllib.parse import urlparse
from gfpgan.archs.gfpganv1_arch import GFPGANv1
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
@@ -14,6 +13,20 @@ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
class GFPGANer():
"""Helper for restoration with GFPGAN.
It will detect and crop faces, and then resize the faces to 512x512.
GFPGAN is used to restore the resized faces.
The background is upsampled with the bg_upsampler.
Finally, the faces will be pasted back to the upsampled background image.
Args:
model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
upscale (float): The upscale of the final output. Default: 2.
arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
bg_upsampler (nn.Module): The upsampler for the background. Default: None.
"""
def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None):
self.upscale = upscale
@@ -56,7 +69,8 @@ class GFPGANer():
device=self.device)
if model_path.startswith('https://'):
model_path = load_file_from_url(url=model_path, model_dir='gfpgan/weights', progress=True, file_name=None)
model_path = load_file_from_url(
url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None)
loadnet = torch.load(model_path)
if 'params_ema' in loadnet:
keyname = 'params_ema'
@@ -70,7 +84,7 @@ class GFPGANer():
def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
self.face_helper.clean_all()
if has_aligned:
if has_aligned: # the inputs are already aligned
img = cv2.resize(img, (512, 512))
self.face_helper.cropped_faces = [img]
else:
@@ -78,6 +92,7 @@ class GFPGANer():
# get face landmarks for each face
self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
# eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
# TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
# align and warp each face
self.face_helper.align_warp_face()
@@ -100,9 +115,9 @@ class GFPGANer():
self.face_helper.add_restored_face(restored_face)
if not has_aligned and paste_back:
# upsample the background
if self.bg_upsampler is not None:
# Now only support RealESRGAN
# Now only support RealESRGAN for upsampling background
bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
else:
bg_img = None
@@ -113,23 +128,3 @@ class GFPGANer():
return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
else:
return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
"""Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
"""
if model_dir is None:
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
parts = urlparse(url)
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
return cached_file

View File

@@ -10,20 +10,22 @@ from gfpgan import GFPGANer
def main():
"""Inference demo for GFPGAN.
"""
parser = argparse.ArgumentParser()
parser.add_argument('--upscale', type=int, default=2)
parser.add_argument('--arch', type=str, default='clean')
parser.add_argument('--channel', type=int, default=2)
parser.add_argument('--upscale', type=int, default=2, help='The final upsampling scale of the image')
parser.add_argument('--arch', type=str, default='clean', help='The GFPGAN architecture. Option: clean | original')
parser.add_argument('--channel', type=int, default=2, help='Channel multiplier for large networks of StyleGAN2')
parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth')
parser.add_argument('--bg_upsampler', type=str, default='realesrgan')
parser.add_argument('--bg_tile', type=int, default=400)
parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
parser.add_argument('--bg_upsampler', type=str, default='realesrgan', help='background upsampler')
parser.add_argument(
'--bg_tile', type=int, default=400, help='Tile size for background sampler, 0 for no tile during testing')
parser.add_argument('--test_path', type=str, default='inputs/whole_imgs', help='Input folder')
parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
parser.add_argument('--only_center_face', action='store_true')
parser.add_argument('--aligned', action='store_true')
parser.add_argument('--paste_back', action='store_false')
parser.add_argument('--save_root', type=str, default='results')
parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face')
parser.add_argument('--aligned', action='store_true', help='Input are aligned faces')
parser.add_argument('--paste_back', action='store_false', help='Paste the restored faces back to images')
parser.add_argument('--save_root', type=str, default='results', help='Path to save root')
parser.add_argument(
'--ext',
type=str,
@@ -44,10 +46,13 @@ def main():
'If you really want to use it, please modify the corresponding codes.')
bg_upsampler = None
else:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
bg_upsampler = RealESRGANer(
scale=2,
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
model=model,
tile=args.bg_tile,
tile_pad=10,
pre_pad=0,
@@ -70,6 +75,7 @@ def main():
basename, ext = os.path.splitext(img_name)
input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
# restore faces and background if necessary
cropped_faces, restored_faces, restored_img = restorer.enhance(
input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=args.paste_back)
@@ -85,7 +91,7 @@ def main():
save_face_name = f'{basename}_{idx:02d}.png'
save_restore_path = os.path.join(args.save_root, 'restored_faces', save_face_name)
imwrite(restored_face, save_restore_path)
# save cmp image
# save comparison image
cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
imwrite(cmp_img, os.path.join(args.save_root, 'cmp', f'{basename}_{idx:02d}.png'))

View File

@@ -1,7 +1,7 @@
# general settings
name: train_GFPGANv1_512
model_type: GFPGANModel
num_gpu: 4
num_gpu: auto # officially, we use 4 GPUs
manual_seed: 0
# dataset and data loader settings
@@ -194,7 +194,7 @@ val:
save_img: true
metrics:
psnr: # metric name, can be arbitrary
psnr: # metric name
type: calculate_psnr
crop_border: 0
test_y_channel: false

View File

@@ -1,7 +1,7 @@
# general settings
name: train_GFPGANv1_512_simple
model_type: GFPGANModel
num_gpu: 4
num_gpu: auto # officially, we use 4 GPUs
manual_seed: 0
# dataset and data loader settings
@@ -40,10 +40,6 @@ datasets:
# gray_prob: 0.01
# gt_gray: True
# crop_components: false
# component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
# eye_enlarge_ratio: 1.4
# data loader
use_shuffle: true
num_worker_per_gpu: 6
@@ -86,20 +82,6 @@ network_d:
channel_multiplier: 1
resample_kernel: [1, 3, 3, 1]
# network_d_left_eye:
# type: FacialComponentDiscriminator
# network_d_right_eye:
# type: FacialComponentDiscriminator
# network_d_mouth:
# type: FacialComponentDiscriminator
network_identity:
type: ResNetArcFace
block: IRBlock
layers: [2, 2, 2, 2]
use_se: False
# path
path:
@@ -107,13 +89,7 @@ path:
param_key_g: params_ema
strict_load_g: ~
pretrain_network_d: ~
# pretrain_network_d_left_eye: ~
# pretrain_network_d_right_eye: ~
# pretrain_network_d_mouth: ~
pretrain_network_identity: experiments/pretrained_models/arcface_resnet18.pth
# resume
resume_state: ~
ignore_resume_networks: ['network_identity']
# training settings
train:
@@ -173,16 +149,6 @@ train:
loss_weight: !!float 1e-1
# r1 regularization for discriminator
r1_reg_weight: 10
# facial component loss
# gan_component_opt:
# type: GANLoss
# gan_type: vanilla
# real_label_val: 1.0
# fake_label_val: 0.0
# loss_weight: !!float 1
# comp_style_weight: 200
# identity loss
identity_weight: 10
net_d_iters: 1
net_d_init_iters: 0
@@ -194,7 +160,7 @@ val:
save_img: true
metrics:
psnr: # metric name, can be arbitrary
psnr: # metric name
type: calculate_psnr
crop_border: 0
test_y_channel: false

View File

@@ -1,24 +1,31 @@
import cv2
import json
import numpy as np
import os
import torch
from basicsr.utils import FileClient, imfrombytes
from collections import OrderedDict
# ---------------------------- This script is used to parse facial landmarks ------------------------------------- #
# Configurations
save_img = False
scale = 0.5 # 0.5 for official FFHQ (512x512), 1 for others
enlarge_ratio = 1.4 # only for eyes
json_path = 'ffhq-dataset-v2.json'
face_path = 'datasets/ffhq/ffhq_512.lmdb'
save_path = './FFHQ_eye_mouth_landmarks_512.pth'
print('Load JSON metadata...')
# use the json file in FFHQ dataset
with open('ffhq-dataset-v2.json', 'rb') as f:
# use the official json file in FFHQ dataset
with open(json_path, 'rb') as f:
json_data = json.load(f, object_pairs_hook=OrderedDict)
print('Open LMDB file...')
# read ffhq images
file_client = FileClient('lmdb', db_paths='datasets/ffhq/ffhq_512.lmdb')
with open('datasets/ffhq/ffhq_512.lmdb/meta_info.txt') as fin:
file_client = FileClient('lmdb', db_paths=face_path)
with open(os.path.join(face_path, 'meta_info.txt')) as fin:
paths = [line.split('.')[0] for line in fin]
save_img = False
scale = 0.5 # 0.5 for official FFHQ (512x512), 1 for others
enlarge_ratio = 1.4 # only for eyes
save_dict = {}
for item_idx, item in enumerate(json_data.values()):
@@ -34,6 +41,7 @@ for item_idx, item in enumerate(json_data.values()):
img_bytes = file_client.get(paths[item_idx])
img = imfrombytes(img_bytes, float32=True)
# get landmarks for each component
map_left_eye = list(range(36, 42))
map_right_eye = list(range(42, 48))
map_mouth = list(range(48, 68))
@@ -74,4 +82,4 @@ for item_idx, item in enumerate(json_data.values()):
save_dict[f'{item_idx:08d}'] = item_dict
print('Save...')
torch.save(save_dict, './FFHQ_eye_mouth_landmarks_512.pth')
torch.save(save_dict, save_path)

View File

@@ -17,7 +17,7 @@ line_length = 120
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = gfpgan
known_third_party = basicsr,cv2,facexlib,numpy,torch,torchvision,tqdm
known_third_party = basicsr,cv2,facexlib,numpy,pytest,torch,torchvision,tqdm,yaml
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
@@ -25,3 +25,9 @@ default_section = THIRDPARTY
skip = .git,./docs/build
count =
quiet-level = 3
[aliases]
test=pytest
[tool:pytest]
addopts=tests/

View File

@@ -43,12 +43,6 @@ def get_git_hash():
def get_hash():
if os.path.exists('.git'):
sha = get_git_hash()[:7]
elif os.path.exists(version_file):
try:
from gfpgan.version import __version__
sha = __version__.split('+')[-1]
except ImportError:
raise ImportError('Unable to get git version')
else:
sha = 'unknown'

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1 @@
00000000.png (512,512,3) 1

BIN
tests/data/gt/00000000.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 429 KiB

Binary file not shown.

View File

@@ -0,0 +1,24 @@
name: UnitTest
type: FFHQDegradationDataset
dataroot_gt: tests/data/gt
io_backend:
type: disk
use_hflip: true
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
out_size: 512
blur_kernel_size: 41
kernel_list: ['iso', 'aniso']
kernel_prob: [0.5, 0.5]
blur_sigma: [0.1, 10]
downsample_range: [0.8, 8]
noise_range: [0, 20]
jpeg_range: [60, 100]
# color jitter and gray
color_jitter_prob: 1
color_jitter_shift: 20
color_jitter_pt_prob: 1
gray_prob: 1

View File

@@ -0,0 +1,140 @@
num_gpu: 1
manual_seed: 0
is_train: True
dist: False
# network structures
network_g:
type: GFPGANv1
out_size: 512
num_style_feat: 512
channel_multiplier: 1
resample_kernel: [1, 3, 3, 1]
decoder_load_path: ~
fix_decoder: true
num_mlp: 8
lr_mlp: 0.01
input_is_latent: true
different_w: true
narrow: 0.5
sft_half: true
network_d:
type: StyleGAN2Discriminator
out_size: 512
channel_multiplier: 1
resample_kernel: [1, 3, 3, 1]
network_d_left_eye:
type: FacialComponentDiscriminator
network_d_right_eye:
type: FacialComponentDiscriminator
network_d_mouth:
type: FacialComponentDiscriminator
network_identity:
type: ResNetArcFace
block: IRBlock
layers: [2, 2, 2, 2]
use_se: False
# path
path:
pretrain_network_g: ~
param_key_g: params_ema
strict_load_g: ~
pretrain_network_d: ~
pretrain_network_d_left_eye: ~
pretrain_network_d_right_eye: ~
pretrain_network_d_mouth: ~
pretrain_network_identity: ~
# resume
resume_state: ~
ignore_resume_networks: ['network_identity']
# training settings
train:
optim_g:
type: Adam
lr: !!float 2e-3
optim_d:
type: Adam
lr: !!float 2e-3
optim_component:
type: Adam
lr: !!float 2e-3
scheduler:
type: MultiStepLR
milestones: [600000, 700000]
gamma: 0.5
total_iter: 800000
warmup_iter: -1 # no warm up
# losses
# pixel loss
pixel_opt:
type: L1Loss
loss_weight: !!float 1e-1
reduction: mean
# L1 loss used in pyramid loss, component style loss and identity loss
L1_opt:
type: L1Loss
loss_weight: 1
reduction: mean
# image pyramid loss
pyramid_loss_weight: 1
remove_pyramid_loss: 50000
# perceptual loss (content and style losses)
perceptual_opt:
type: PerceptualLoss
layer_weights:
# before relu
'conv1_2': 0.1
'conv2_2': 0.1
'conv3_4': 1
'conv4_4': 1
'conv5_4': 1
vgg_type: vgg19
use_input_norm: true
perceptual_weight: !!float 1
style_weight: 50
range_norm: true
criterion: l1
# gan loss
gan_opt:
type: GANLoss
gan_type: wgan_softplus
loss_weight: !!float 1e-1
# r1 regularization for discriminator
r1_reg_weight: 10
# facial component loss
gan_component_opt:
type: GANLoss
gan_type: vanilla
real_label_val: 1.0
fake_label_val: 0.0
loss_weight: !!float 1
comp_style_weight: 200
# identity loss
identity_weight: 10
net_d_iters: 1
net_d_init_iters: 0
net_d_reg_every: 1
# validation settings
val:
val_freq: !!float 5e3
save_img: True
use_pbar: True
metrics:
psnr: # metric name
type: calculate_psnr
crop_border: 0
test_y_channel: false

View File

@@ -0,0 +1,49 @@
import torch
from gfpgan.archs.arcface_arch import BasicBlock, Bottleneck, ResNetArcFace
def test_resnetarcface():
    """Smoke-test the ResNetArcFace architecture.

    Runs only when a CUDA device is available (silently skipped on CPU).
    The forward pass is exercised both with and without the
    squeeze-and-excitation (SE) blocks; each input image is expected to
    produce a 512-dim embedding.
    """
    if not torch.cuda.is_available():  # GPU-only test; nothing to check on CPU
        return
    img = torch.rand((1, 1, 128, 128), dtype=torch.float32).cuda()
    # first with SE blocks enabled, then without
    for use_se in (True, False):
        model = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=use_se).cuda().eval()
        assert model(img).shape == (1, 512)
def test_basicblock():
    """Test the BasicBlock used in arcface_arch (requires a CUDA device)."""
    x = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()

    # stride 1, no downsample branch: spatial size is preserved
    layer = BasicBlock(1, 3, stride=1, downsample=None).cuda()
    assert layer(x).shape == (1, 3, 12, 12)

    # stride 2 with a downsample module on the shortcut: spatial size halves
    shortcut = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda()
    layer = BasicBlock(1, 3, stride=2, downsample=shortcut).cuda()
    x = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
    assert layer(x).shape == (1, 3, 6, 6)
def test_bottleneck():
    """Test the Bottleneck block used in arcface_arch (requires a CUDA device)."""
    x = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()

    # stride 1, no downsample branch: spatial size preserved, 4x channel expansion
    layer = Bottleneck(1, 1, stride=1, downsample=None).cuda()
    assert layer(x).shape == (1, 4, 12, 12)

    # stride 2 with a downsample module on the shortcut: spatial size halves
    shortcut = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda()
    layer = Bottleneck(1, 1, stride=2, downsample=shortcut).cuda()
    x = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
    assert layer(x).shape == (1, 4, 6, 6)

View File

@@ -0,0 +1,96 @@
import pytest
import yaml
from gfpgan.data.ffhq_degradation_dataset import FFHQDegradationDataset
def test_ffhq_degradation_dataset():
    """Exercise FFHQDegradationDataset: disk/lmdb backends, degradation
    probabilities, and facial-component cropping."""

    def assert_basic_item(result, expected_gt_path):
        # checks shared by every __getitem__ sample
        assert set(['gt', 'lq', 'gt_path']).issubset(set(result.keys()))
        assert result['gt'].shape == (3, 512, 512)
        assert result['lq'].shape == (3, 512, 512)
        assert result['gt_path'] == expected_gt_path

    with open('tests/data/test_ffhq_degradation_dataset.yml', mode='r') as f:
        opt = yaml.load(f, Loader=yaml.FullLoader)

    dataset = FFHQDegradationDataset(opt)
    assert dataset.io_backend_opt['type'] == 'disk'  # io backend
    assert len(dataset) == 1  # whether the meta info was read correctly
    assert dataset.kernel_list == ['iso', 'aniso']  # degradation configuration picked up
    assert dataset.color_jitter_prob == 1
    assert_basic_item(dataset.__getitem__(0), 'tests/data/gt/00000000.png')

    # ------------------ test with probability = 0 -------------------- #
    opt['color_jitter_prob'] = 0
    opt['color_jitter_pt_prob'] = 0
    opt['gray_prob'] = 0
    opt['io_backend'] = dict(type='disk')
    dataset = FFHQDegradationDataset(opt)
    assert dataset.io_backend_opt['type'] == 'disk'  # io backend
    assert len(dataset) == 1
    assert dataset.kernel_list == ['iso', 'aniso']
    assert dataset.color_jitter_prob == 0
    assert_basic_item(dataset.__getitem__(0), 'tests/data/gt/00000000.png')

    # ------------------ test lmdb backend -------------------- #
    opt['dataroot_gt'] = 'tests/data/ffhq_gt.lmdb'
    opt['io_backend'] = dict(type='lmdb')
    dataset = FFHQDegradationDataset(opt)
    assert dataset.io_backend_opt['type'] == 'lmdb'  # io backend
    assert len(dataset) == 1
    assert dataset.kernel_list == ['iso', 'aniso']
    assert dataset.color_jitter_prob == 0
    # lmdb keys carry no extension/folder prefix
    assert_basic_item(dataset.__getitem__(0), '00000000')

    # ------------------ test with crop_components -------------------- #
    opt['crop_components'] = True
    opt['component_path'] = 'tests/data/test_eye_mouth_landmarks.pth'
    opt['eye_enlarge_ratio'] = 1.4
    opt['gt_gray'] = True
    opt['io_backend'] = dict(type='lmdb')
    dataset = FFHQDegradationDataset(opt)
    assert dataset.crop_components is True
    result = dataset.__getitem__(0)
    # samples now also carry the facial-component locations
    expected_keys = ['gt', 'lq', 'gt_path', 'loc_left_eye', 'loc_right_eye', 'loc_mouth']
    assert set(expected_keys).issubset(set(result.keys()))
    assert_basic_item(result, '00000000')
    assert result['loc_left_eye'].shape == (4, )
    assert result['loc_right_eye'].shape == (4, )
    assert result['loc_mouth'].shape == (4, )

    # ---------- lmdb backend requires a dataroot ending with .lmdb ---------- #
    with pytest.raises(ValueError):
        opt['dataroot_gt'] = 'tests/data/gt'
        opt['io_backend'] = dict(type='lmdb')
        dataset = FFHQDegradationDataset(opt)

203
tests/test_gfpgan_arch.py Normal file
View File

@@ -0,0 +1,203 @@
import torch
from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1, StyleGAN2GeneratorSFT
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean, StyleGAN2GeneratorCSFT
def test_stylegan2generatorsft():
    """Smoke-test StyleGAN2GeneratorSFT (GPU only; silently skipped on CPU)."""
    if not torch.cuda.is_available():  # GPU-only test; nothing to check on CPU
        return
    net = StyleGAN2GeneratorSFT(
        out_size=32,
        num_style_feat=512,
        num_mlp=8,
        channel_multiplier=1,
        resample_kernel=(1, 3, 3, 1),
        lr_mlp=0.01,
        narrow=1,
        sft_half=False).cuda().eval()
    style = torch.rand((1, 512), dtype=torch.float32).cuda()
    # two SFT conditions (scale, shift) per resolution: 8x8, 16x16, 32x32
    conditions = []
    for size in (8, 16, 32):
        cond = torch.rand((1, 512, size, size), dtype=torch.float32).cuda()
        conditions.extend([cond, cond])

    image, latents = net([style], conditions)
    assert image.shape == (1, 3, 32, 32)
    assert latents is None

    # -------------------- with return_latents ----------------------- #
    image, latents = net([style], conditions, return_latents=True)
    assert image.shape == (1, 3, 32, 32)
    assert len(latents) == 1
    assert latents[0].shape == (8, 512)  # latent codes for all style layers

    # -------------------- with randomize_noise = False ----------------------- #
    image, latents = net([style], conditions, randomize_noise=False)
    assert image.shape == (1, 3, 32, 32)
    assert latents is None

    # -------------------- with truncation = 0.5 and style mixing ----------------------- #
    image, latents = net([style, style], conditions, truncation=0.5, truncation_latent=style)
    assert image.shape == (1, 3, 32, 32)
    assert latents is None
def test_gfpganv1():
    """Smoke-test GFPGANv1 (GPU only; silently skipped on CPU).

    Runs the forward pass for both ``different_w`` settings and checks the
    restored image plus the intermediate out_rgbs used by the pyramid loss.
    """
    if not torch.cuda.is_available():  # GPU-only test; nothing to check on CPU
        return
    for different_w in (False, True):
        net = GFPGANv1(
            out_size=32,
            num_style_feat=512,
            channel_multiplier=1,
            resample_kernel=(1, 3, 3, 1),
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            lr_mlp=0.01,
            input_is_latent=False,
            different_w=different_w,
            narrow=1,
            sft_half=True).cuda().eval()
        img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
        restored, out_rgbs = net(img)
        assert restored.shape == (1, 3, 32, 32)
        # out_rgbs: intermediate restorations for the pyramid loss
        assert len(out_rgbs) == 3
        for rgb, size in zip(out_rgbs, (8, 16, 32)):
            assert rgb.shape == (1, 3, size, size)
def test_facialcomponentdiscriminator():
    """Test arch: FacialComponentDiscriminator."""
    # model init and forward (gpu)
    if not torch.cuda.is_available():
        return
    net = FacialComponentDiscriminator().cuda().eval()
    face_patch = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
    # default call: a realness score map, no intermediate features
    score, feats = net(face_patch)
    assert score.shape == (1, 1, 8, 8)
    assert feats is None
    # -------------------- return intermediate features ----------------------- #
    score, feats = net(face_patch, return_feats=True)
    assert score.shape == (1, 1, 8, 8)
    assert len(feats) == 2
    assert feats[0].shape == (1, 128, 16, 16)
    assert feats[1].shape == (1, 256, 8, 8)
def test_stylegan2generatorcsft():
    """Test arch: StyleGAN2GeneratorCSFT."""
    # model init and forward (gpu)
    if not torch.cuda.is_available():
        return
    net = StyleGAN2GeneratorCSFT(
        out_size=32, num_style_feat=512, num_mlp=8, channel_multiplier=1, narrow=1, sft_half=False).cuda().eval()
    style = torch.rand((1, 512), dtype=torch.float32).cuda()
    # two SFT conditions (scale and shift) per resolution: 8x8, 16x16, 32x32
    conditions = []
    for size in (8, 16, 32):
        cond = torch.rand((1, 512, size, size), dtype=torch.float32).cuda()
        conditions.extend([cond, cond])
    image, latent = net([style], conditions)
    assert image.shape == (1, 3, 32, 32)
    assert latent is None
    # -------------------- with return_latents ----------------------- #
    image, latents = net([style], conditions, return_latents=True)
    assert image.shape == (1, 3, 32, 32)
    assert len(latents) == 1
    # check latent
    assert latents[0].shape == (8, 512)
    # -------------------- with randomize_noise = False ----------------------- #
    image, latent = net([style], conditions, randomize_noise=False)
    assert image.shape == (1, 3, 32, 32)
    assert latent is None
    # -------------------- with truncation = 0.5 and mixing ----------------------- #
    image, latent = net([style, style], conditions, truncation=0.5, truncation_latent=style)
    assert image.shape == (1, 3, 32, 32)
    assert latent is None
def test_gfpganv1clean():
    """Test arch: GFPGANv1Clean (forward shapes, with and without different_w)."""
    # model init and forward (gpu)
    if not torch.cuda.is_available():
        return
    # exercise both style-code variants: a single shared W and per-layer W
    for use_different_w in (False, True):
        net = GFPGANv1Clean(
            out_size=32,
            num_style_feat=512,
            channel_multiplier=1,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            input_is_latent=False,
            different_w=use_different_w,
            narrow=1,
            sft_half=True).cuda().eval()
        face = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
        restored, out_rgbs = net(face)
        assert restored.shape == (1, 3, 32, 32)
        # out_rgbs: one intermediate rgb per resolution, used for pyramid loss
        assert len(out_rgbs) == 3
        assert out_rgbs[0].shape == (1, 3, 8, 8)
        assert out_rgbs[1].shape == (1, 3, 16, 16)
        assert out_rgbs[2].shape == (1, 3, 32, 32)

132
tests/test_gfpgan_model.py Normal file
View File

@@ -0,0 +1,132 @@
import tempfile
import torch
import yaml
from basicsr.archs.stylegan2_arch import StyleGAN2Discriminator
from basicsr.data.paired_image_dataset import PairedImageDataset
from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss
from gfpgan.archs.arcface_arch import ResNetArcFace
from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1
from gfpgan.models.gfpgan_model import GFPGANModel
def test_gfpgan_model():
    """Integration test for GFPGANModel.

    Builds the model from a yml config, then exercises feed_data,
    optimize_parameters, save, test and nondist_validation in sequence.
    NOTE(review): steps are order-dependent (e.g. optimize_parameters is
    called at iteration 1 and then 100000 to toggle the pyramid loss).
    """
    with open('tests/data/test_gfpgan_model.yml', mode='r') as f:
        opt = yaml.load(f, Loader=yaml.FullLoader)

    # build model
    model = GFPGANModel(opt)
    # test attributes
    assert model.__class__.__name__ == 'GFPGANModel'
    assert isinstance(model.net_g, GFPGANv1)  # generator
    assert isinstance(model.net_d, StyleGAN2Discriminator)  # discriminator
    # facial component discriminators
    assert isinstance(model.net_d_left_eye, FacialComponentDiscriminator)
    assert isinstance(model.net_d_right_eye, FacialComponentDiscriminator)
    assert isinstance(model.net_d_mouth, FacialComponentDiscriminator)
    # identity network
    assert isinstance(model.network_identity, ResNetArcFace)
    # losses
    assert isinstance(model.cri_pix, L1Loss)
    assert isinstance(model.cri_perceptual, PerceptualLoss)
    assert isinstance(model.cri_gan, GANLoss)
    assert isinstance(model.cri_l1, L1Loss)
    # optimizer
    assert isinstance(model.optimizers[0], torch.optim.Adam)
    assert isinstance(model.optimizers[1], torch.optim.Adam)

    # prepare data: gt/lq images plus facial component locations
    gt = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    lq = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    loc_left_eye = torch.rand((1, 4), dtype=torch.float32)
    loc_right_eye = torch.rand((1, 4), dtype=torch.float32)
    loc_mouth = torch.rand((1, 4), dtype=torch.float32)
    data = dict(gt=gt, lq=lq, loc_left_eye=loc_left_eye, loc_right_eye=loc_right_eye, loc_mouth=loc_mouth)
    model.feed_data(data)
    # check data shape
    assert model.lq.shape == (1, 3, 512, 512)
    assert model.gt.shape == (1, 3, 512, 512)
    assert model.loc_left_eyes.shape == (1, 4)
    assert model.loc_right_eyes.shape == (1, 4)
    assert model.loc_mouths.shape == (1, 4)

    # ----------------- test optimize_parameters -------------------- #
    model.feed_data(data)
    model.optimize_parameters(1)
    assert model.output.shape == (1, 3, 512, 512)
    assert isinstance(model.log_dict, dict)
    # check returned keys
    expected_keys = [
        'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
        'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
        'l_d_right_eye', 'l_d_mouth'
    ]
    assert set(expected_keys).issubset(set(model.log_dict.keys()))

    # ----------------- remove pyramid_loss_weight -------------------- #
    model.feed_data(data)
    model.optimize_parameters(100000)  # larger than remove_pyramid_loss = 50000
    assert model.output.shape == (1, 3, 512, 512)
    assert isinstance(model.log_dict, dict)
    # check returned keys (same set is still logged after the pyramid loss is removed)
    expected_keys = [
        'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
        'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
        'l_d_right_eye', 'l_d_mouth'
    ]
    assert set(expected_keys).issubset(set(model.log_dict.keys()))

    # ----------------- test save -------------------- #
    with tempfile.TemporaryDirectory() as tmpdir:
        model.opt['path']['models'] = tmpdir
        model.opt['path']['training_states'] = tmpdir
        model.save(0, 1)

    # ----------------- test the test function -------------------- #
    model.test()
    assert model.output.shape == (1, 3, 512, 512)
    # delete net_g_ema so that test() falls back to net_g
    model.__delattr__('net_g_ema')
    model.test()
    assert model.output.shape == (1, 3, 512, 512)
    assert model.net_g.training is True  # should back to training mode after testing

    # ----------------- test nondist_validation -------------------- #
    # construct dataloader
    dataset_opt = dict(
        name='Demo',
        dataroot_gt='tests/data/gt',
        dataroot_lq='tests/data/gt',
        io_backend=dict(type='disk'),
        scale=4,
        phase='val')
    dataset = PairedImageDataset(dataset_opt)
    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
    assert model.is_train is True
    with tempfile.TemporaryDirectory() as tmpdir:
        model.opt['path']['visualization'] = tmpdir
        model.nondist_validation(dataloader, 1, None, save_img=True)
        assert model.is_train is True
        # check metric_results
        assert 'psnr' in model.metric_results
        assert isinstance(model.metric_results['psnr'], float)

        # validation in non-training mode, with a filename suffix and a pbar
        with tempfile.TemporaryDirectory() as tmpdir:
            model.opt['is_train'] = False
            model.opt['val']['suffix'] = 'test'
            model.opt['path']['visualization'] = tmpdir
            model.opt['val']['pbar'] = True
            model.nondist_validation(dataloader, 1, None, save_img=True)
            # check metric_results
            assert 'psnr' in model.metric_results
            assert isinstance(model.metric_results['psnr'], float)

            # if opt['val']['suffix'] is None, the opt['name'] is used as suffix
            model.opt['val']['suffix'] = None
            model.opt['name'] = 'demo'
            model.opt['path']['visualization'] = tmpdir
            model.nondist_validation(dataloader, 1, None, save_img=True)
            # check metric_results
            assert 'psnr' in model.metric_results
            assert isinstance(model.metric_results['psnr'], float)

View File

@@ -0,0 +1,52 @@
import torch
from gfpgan.archs.stylegan2_clean_arch import StyleGAN2GeneratorClean
def test_stylegan2generatorclean():
    """Test arch: StyleGAN2GeneratorClean."""
    # model init and forward (gpu)
    if not torch.cuda.is_available():
        return
    net = StyleGAN2GeneratorClean(
        out_size=32, num_style_feat=512, num_mlp=8, channel_multiplier=1, narrow=0.5).cuda().eval()
    style = torch.rand((1, 512), dtype=torch.float32).cuda()
    image, latent = net([style], input_is_latent=False)
    assert image.shape == (1, 3, 32, 32)
    assert latent is None
    # -------------------- with return_latents ----------------------- #
    image, latents = net([style], input_is_latent=True, return_latents=True)
    assert image.shape == (1, 3, 32, 32)
    assert len(latents) == 1
    # check latent
    assert latents[0].shape == (8, 512)
    # -------------------- with randomize_noise = False ----------------------- #
    image, latent = net([style], randomize_noise=False)
    assert image.shape == (1, 3, 32, 32)
    assert latent is None
    # -------------------- with truncation = 0.5 and mixing ----------------------- #
    image, latent = net([style, style], truncation=0.5, truncation_latent=style)
    assert image.shape == (1, 3, 32, 32)
    assert latent is None
    # ------------------ test make_noise ----------------------- #
    noises = net.make_noise()
    assert len(noises) == 7
    # one 4x4 noise map, then two per resolution up to out_size 32
    expected_shapes = [(1, 1, 4, 4), (1, 1, 8, 8), (1, 1, 8, 8), (1, 1, 16, 16), (1, 1, 16, 16), (1, 1, 32, 32),
                       (1, 1, 32, 32)]
    for noise, shape in zip(noises, expected_shapes):
        assert noise.shape == shape
    # ------------------ test get_latent ----------------------- #
    assert net.get_latent(style).shape == (1, 512)
    # ------------------ test mean_latent ----------------------- #
    assert net.mean_latent(2).shape == (1, 512)

43
tests/test_utils.py Normal file
View File

@@ -0,0 +1,43 @@
import cv2
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from gfpgan.archs.gfpganv1_arch import GFPGANv1
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
from gfpgan.utils import GFPGANer
def test_gfpganer():
    """Test GFPGANer: initialization with both archs and the enhance API."""
    # initialize with the clean model
    restorer = GFPGANer(
        model_path='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth',
        upscale=2,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=None)
    # test attribute
    assert isinstance(restorer.gfpgan, GFPGANv1Clean)
    assert isinstance(restorer.face_helper, FaceRestoreHelper)

    # initialize with the original model
    restorer = GFPGANer(
        model_path='experiments/pretrained_models/GFPGANv1.pth',
        upscale=2,
        arch='original',
        channel_multiplier=1,
        bg_upsampler=None)
    # test attribute
    assert isinstance(restorer.gfpgan, GFPGANv1)
    assert isinstance(restorer.face_helper, FaceRestoreHelper)

    # ------------------ test enhance ---------------- #
    img = cv2.imread('tests/data/gt/00000000.png', cv2.IMREAD_COLOR)
    cropped_faces, restored_faces, restored_img = restorer.enhance(img, has_aligned=False, paste_back=True)
    assert cropped_faces[0].shape == (512, 512, 3)
    assert restored_faces[0].shape == (512, 512, 3)
    assert restored_img.shape == (1024, 1024, 3)
    # with has_aligned=True: no detection/paste-back, so no whole image returned
    cropped_faces, restored_faces, restored_img = restorer.enhance(img, has_aligned=True, paste_back=False)
    assert cropped_faces[0].shape == (512, 512, 3)
    assert restored_faces[0].shape == (512, 512, 3)
    assert restored_img is None