clean and add more comments

This commit is contained in:
Xintao
2021-11-27 19:59:23 +08:00
parent 0ff1cf7215
commit be73d6d9a4
13 changed files with 336 additions and 225 deletions

View File

@@ -1,12 +1,11 @@
name: No Response
# TODO: it seems not to work
# Modified from: https://raw.githubusercontent.com/github/docs/main/.github/workflows/no-response.yaml
# **What it does**: Closes issues that don't have enough information to be
# actionable.
# **Why we have it**: To remove the need for maintainers to remember to check
# back on issues periodically to see if contributors have
# responded.
# **What it does**: Closes issues that don't have enough information to be actionable.
# **Why we have it**: To remove the need for maintainers to remember to check back on issues periodically
# to see if contributors have responded.
# **Who does it impact**: Everyone that works on docs or docs-internal.
on:

View File

@@ -3,4 +3,4 @@ from .archs import *
from .data import *
from .models import *
from .utils import *
from .version import __gitsha__, __version__
from .version import *

View File

@@ -2,13 +2,27 @@ import torch.nn as nn
from basicsr.utils.registry import ARCH_REGISTRY
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
def conv3x3(inplanes, outplanes, stride=1):
"""A simple wrapper for 3x3 convolution with padding.
Args:
inplanes (int): Channel number of inputs.
outplanes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
"""
return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
class BasicBlock(nn.Module):
expansion = 1
"""Basic residual block used in the ResNetArcFace architecture.
Args:
inplanes (int): Channel number of inputs.
planes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
downsample (nn.Module): The downsample module. Default: None.
"""
expansion = 1 # output channel expansion ratio
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
@@ -40,7 +54,16 @@ class BasicBlock(nn.Module):
class IRBlock(nn.Module):
expansion = 1
"""Improved residual block (IR Block) used in the ResNetArcFace architecture.
Args:
inplanes (int): Channel number of inputs.
planes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
downsample (nn.Module): The downsample module. Default: None.
use_se (bool): Whether to use the SEBlock (squeeze and excitation block). Default: True.
"""
expansion = 1 # output channel expansion ratio
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
super(IRBlock, self).__init__()
@@ -78,7 +101,15 @@ class IRBlock(nn.Module):
class Bottleneck(nn.Module):
expansion = 4
"""Bottleneck block used in the ResNetArcFace architecture.
Args:
inplanes (int): Channel number of inputs.
planes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
downsample (nn.Module): The downsample module. Default: None.
"""
expansion = 4 # output channel expansion ratio
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
@@ -116,10 +147,16 @@ class Bottleneck(nn.Module):
class SEBlock(nn.Module):
"""The squeeze-and-excitation block (SEBlock) used in the IRBlock.
Args:
channel (int): Channel number of inputs.
reduction (int): Channel reduction ratio. Default: 16.
"""
def __init__(self, channel, reduction=16):
super(SEBlock, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
nn.Sigmoid())
@@ -133,6 +170,15 @@ class SEBlock(nn.Module):
@ARCH_REGISTRY.register()
class ResNetArcFace(nn.Module):
"""ArcFace with ResNet architectures.
Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
Args:
block (str): Block used in the ArcFace architecture.
layers (tuple(int)): Block numbers in each layer.
use_se (bool): Whether to use the SEBlock (squeeze and excitation block). Default: True.
"""
def __init__(self, block, layers, use_se=True):
if block == 'IRBlock':
@@ -140,6 +186,7 @@ class ResNetArcFace(nn.Module):
self.inplanes = 64
self.use_se = use_se
super(ResNetArcFace, self).__init__()
self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.prelu = nn.PReLU()
@@ -153,6 +200,7 @@ class ResNetArcFace(nn.Module):
self.fc5 = nn.Linear(512 * 8 * 8, 512)
self.bn5 = nn.BatchNorm1d(512)
# initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.xavier_normal_(m.weight)
@@ -163,7 +211,7 @@ class ResNetArcFace(nn.Module):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0)
def _make_layer(self, block, planes, blocks, stride=1):
def _make_layer(self, block, planes, num_blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
@@ -173,7 +221,7 @@ class ResNetArcFace(nn.Module):
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
self.inplanes = planes
for _ in range(1, blocks):
for _ in range(1, num_blocks):
layers.append(block(self.inplanes, planes, use_se=self.use_se))
return nn.Sequential(*layers)
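
As a quick, hedged illustration of the documented arguments (not part of this commit), the identity network can be instantiated the way the training options later in this commit configure it, assuming the module path gfpgan.archs.arcface_arch and the 1-channel 128x128 gray-scale faces implied by conv1 and fc5:

import torch
from gfpgan.archs.arcface_arch import ResNetArcFace  # assumed module path

# block / layers / use_se mirror the network_identity entry in the training options below
net_identity = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=False)
# 128x128 gray-scale faces -> 512-d identity embeddings (batch of 2 so BatchNorm1d works in train mode)
embedding = net_identity(torch.randn(2, 1, 128, 128))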

View File

@@ -10,18 +10,18 @@ from torch.nn import functional as F
class StyleGAN2GeneratorSFT(StyleGAN2Generator):
"""StyleGAN2 Generator.
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
num_mlp (int): Layer number of MLP style layers. Default: 8.
channel_multiplier (int): Channel multiplier for large networks of
StyleGAN2. Default: 2.
resample_kernel (list[int]): A list indicating the 1D resample kernel
magnitude. A cross production will be applied to extent 1D resample
kernel to 2D resample kernel. Default: [1, 3, 3, 1].
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross product will be
applied to extend the 1D resample kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
narrow (float): The narrow ratio for channels. Default: 1.
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
"""
def __init__(self,
@@ -53,21 +53,18 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
truncation_latent=None,
inject_index=None,
return_latents=False):
"""Forward function for StyleGAN2Generator.
"""Forward function for StyleGAN2GeneratorSFT.
Args:
styles (list[Tensor]): Sample codes of styles.
input_is_latent (bool): Whether input is latent style.
Default: False.
conditions (list[Tensor]): SFT conditions to generators.
input_is_latent (bool): Whether input is latent style. Default: False.
noise (Tensor | None): Input noise or None. Default: None.
randomize_noise (bool): Randomize noise, used when 'noise' is
False. Default: True.
truncation (float): TODO. Default: 1.
truncation_latent (Tensor | None): TODO. Default: None.
inject_index (int | None): The injection index for mixing noise.
Default: None.
return_latents (bool): Whether to return style latents.
Default: False.
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
truncation (float): The truncation ratio. Default: 1.
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
inject_index (int | None): The injection index for mixing noise. Default: None.
return_latents (bool): Whether to return style latents. Default: False.
"""
# style codes -> latents with Style MLP layer
if not input_is_latent:
@@ -84,7 +81,7 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
for style in styles:
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
styles = style_truncation
# get style latent with injection
# get style latents with injection
if len(styles) == 1:
inject_index = self.num_latent
@@ -113,15 +110,15 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
# the conditions may have fewer levels
if i < len(conditions):
# SFT part to combine the conditions
if self.sft_half:
if self.sft_half: # only apply SFT to half of the channels
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
out_sft = out_sft * conditions[i - 1] + conditions[i]
out = torch.cat([out_same, out_sft], dim=1)
else:
else: # apply SFT to all the channels
out = out * conditions[i - 1] + conditions[i]
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
i += 2
image = skip
@@ -133,17 +130,15 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
class ConvUpLayer(nn.Module):
"""Conv Up Layer. Bilinear upsample + Conv.
"""Convolutional upsampling layer. It uses bilinear upsampler + Conv.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Size of the convolving kernel.
stride (int): Stride of the convolution. Default: 1
padding (int): Zero-padding added to both sides of the input.
Default: 0.
bias (bool): If ``True``, adds a learnable bias to the output.
Default: ``True``.
padding (int): Zero-padding added to both sides of the input. Default: 0.
bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``.
bias_init_val (float): Bias initialized value. Default: 0.
activate (bool): Whether to use activation. Default: True.
"""
@@ -163,6 +158,7 @@ class ConvUpLayer(nn.Module):
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
# self.scale is used to scale the convolution weights, which is related to the common initializations.
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
@@ -223,7 +219,26 @@ class ResUpBlock(nn.Module):
@ARCH_REGISTRY.register()
class GFPGANv1(nn.Module):
"""Unet + StyleGAN2 decoder with SFT."""
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross product will be
applied to extend the 1D resample kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
fix_decoder (bool): Whether to fix the decoder. Default: True.
num_mlp (int): Layer number of MLP style layers. Default: 8.
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
input_is_latent (bool): Whether input is latent style. Default: False.
different_w (bool): Whether to use different latent w for different layers. Default: False.
narrow (float): The narrow ratio for channels. Default: 1.
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
"""
def __init__(
self,
@@ -246,7 +261,7 @@ class GFPGANv1(nn.Module):
self.different_w = different_w
self.num_style_feat = num_style_feat
unet_narrow = narrow * 0.5
unet_narrow = narrow * 0.5 # by default, use a half of input channels
channels = {
'4': int(512 * unet_narrow),
'8': int(512 * unet_narrow),
@@ -295,6 +310,7 @@ class GFPGANv1(nn.Module):
self.final_linear = EqualLinear(
channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
# the decoder: stylegan2 generator with SFT modulations
self.stylegan_decoder = StyleGAN2GeneratorSFT(
out_size=out_size,
num_style_feat=num_style_feat,
@@ -305,14 +321,16 @@ class GFPGANv1(nn.Module):
narrow=narrow,
sft_half=sft_half)
# load pre-trained stylegan2 model if necessary
if decoder_load_path:
self.stylegan_decoder.load_state_dict(
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
# fix decoder without updating params
if fix_decoder:
for _, param in self.stylegan_decoder.named_parameters():
param.requires_grad = False
# for SFT
# for SFT modulations (scale and shift)
self.condition_scale = nn.ModuleList()
self.condition_shift = nn.ModuleList()
for i in range(3, self.log_size + 1):
@@ -332,13 +350,15 @@ class GFPGANv1(nn.Module):
ScaledLeakyReLU(0.2),
EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
def forward(self,
x,
return_latents=False,
save_feat_path=None,
load_feat_path=None,
return_rgb=True,
randomize_noise=True):
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
"""Forward function for GFPGANv1.
Args:
x (Tensor): Input images.
return_latents (bool): Whether to return style latents. Default: False.
return_rgb (bool): Whether to return intermediate rgb images. Default: True.
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
"""
conditions = []
unet_skips = []
out_rgbs = []
@@ -362,7 +382,7 @@ class GFPGANv1(nn.Module):
feat = feat + unet_skips[i]
# ResUpLayer
feat = self.conv_body_up[i](feat)
# generate scale and shift for SFT layer
# generate scale and shift for SFT layers
scale = self.condition_scale[i](feat)
conditions.append(scale.clone())
shift = self.condition_shift[i](feat)
@@ -371,12 +391,6 @@ class GFPGANv1(nn.Module):
if return_rgb:
out_rgbs.append(self.toRGB[i](feat))
if save_feat_path is not None:
torch.save(conditions, save_feat_path)
if load_feat_path is not None:
conditions = torch.load(load_feat_path)
conditions = [v.cuda() for v in conditions]
# decoder
image, _ = self.stylegan_decoder([style_code],
conditions,
@@ -389,10 +403,12 @@ class GFPGANv1(nn.Module):
@ARCH_REGISTRY.register()
class FacialComponentDiscriminator(nn.Module):
"""Facial component (eyes, mouth, noise) discriminator used in GFPGAN.
"""
def __init__(self):
super(FacialComponentDiscriminator, self).__init__()
# It now uses a VGG-style architecture with fixed model size
self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
@@ -401,6 +417,12 @@ class FacialComponentDiscriminator(nn.Module):
self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
def forward(self, x, return_feats=False):
"""Forward function for FacialComponentDiscriminator.
Args:
x (Tensor): Input images.
return_feats (bool): Whether to return intermediate features. Default: False.
"""
feat = self.conv1(x)
feat = self.conv3(self.conv2(feat))
rlt_feats = []
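
For reference, the SFT modulation commented earlier in this hunk ("only apply SFT to half of the channels") boils down to a split, an affine transform, and a concat. A minimal standalone sketch with hypothetical tensor shapes:

import torch

out = torch.randn(1, 64, 32, 32)                               # a decoder feature map
scale = torch.randn(1, 32, 32, 32)                             # condition_scale output (half the channels)
shift = torch.randn(1, 32, 32, 32)                             # condition_shift output
out_same, out_sft = torch.split(out, out.size(1) // 2, dim=1)  # sft_half=True
out_sft = out_sft * scale + shift                              # spatial feature transform
out = torch.cat([out_same, out_sft], dim=1)                    # untouched half + modulated half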

View File

@@ -1,6 +1,7 @@
import math
import random
import torch
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn
from torch.nn import functional as F
@@ -8,14 +9,17 @@ from .stylegan2_clean_arch import StyleGAN2GeneratorClean
class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
"""StyleGAN2 Generator.
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
num_mlp (int): Layer number of MLP style layers. Default: 8.
channel_multiplier (int): Channel multiplier for large networks of
StyleGAN2. Default: 2.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
narrow (float): The narrow ratio for channels. Default: 1.
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
"""
def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
@@ -25,7 +29,6 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
num_mlp=num_mlp,
channel_multiplier=channel_multiplier,
narrow=narrow)
self.sft_half = sft_half
def forward(self,
@@ -38,21 +41,18 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
truncation_latent=None,
inject_index=None,
return_latents=False):
"""Forward function for StyleGAN2Generator.
"""Forward function for StyleGAN2GeneratorCSFT.
Args:
styles (list[Tensor]): Sample codes of styles.
input_is_latent (bool): Whether input is latent style.
Default: False.
conditions (list[Tensor]): SFT conditions to generators.
input_is_latent (bool): Whether input is latent style. Default: False.
noise (Tensor | None): Input noise or None. Default: None.
randomize_noise (bool): Randomize noise, used when 'noise' is
False. Default: True.
truncation (float): TODO. Default: 1.
truncation_latent (Tensor | None): TODO. Default: None.
inject_index (int | None): The injection index for mixing noise.
Default: None.
return_latents (bool): Whether to return style latents.
Default: False.
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
truncation (float): The truncation ratio. Default: 1.
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
inject_index (int | None): The injection index for mixing noise. Default: None.
return_latents (bool): Whether to return style latents. Default: False.
"""
# style codes -> latents with Style MLP layer
if not input_is_latent:
@@ -69,7 +69,7 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
for style in styles:
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
styles = style_truncation
# get style latent with injection
# get style latents with injection
if len(styles) == 1:
inject_index = self.num_latent
@@ -98,15 +98,15 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
# the conditions may have fewer levels
if i < len(conditions):
# SFT part to combine the conditions
if self.sft_half:
if self.sft_half: # only apply SFT to half of the channels
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
out_sft = out_sft * conditions[i - 1] + conditions[i]
out = torch.cat([out_same, out_sft], dim=1)
else:
else: # apply SFT to all the channels
out = out * conditions[i - 1] + conditions[i]
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
i += 2
image = skip
@@ -118,11 +118,12 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
class ResBlock(nn.Module):
"""Residual block with upsampling/downsampling.
"""Residual block with bilinear upsampling/downsampling.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
"""
def __init__(self, in_channels, out_channels, mode='down'):
@@ -148,8 +149,27 @@ class ResBlock(nn.Module):
return out
@ARCH_REGISTRY.register()
class GFPGANv1Clean(nn.Module):
"""GFPGANv1 Clean version."""
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
fix_decoder (bool): Whether to fix the decoder. Default: True.
num_mlp (int): Layer number of MLP style layers. Default: 8.
input_is_latent (bool): Whether input is latent style. Default: False.
different_w (bool): Whether to use different latent w for different layers. Default: False.
narrow (float): The narrow ratio for channels. Default: 1.
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
"""
def __init__(
self,
@@ -170,7 +190,7 @@ class GFPGANv1Clean(nn.Module):
self.different_w = different_w
self.num_style_feat = num_style_feat
unet_narrow = narrow * 0.5
unet_narrow = narrow * 0.5 # by default, use a half of input channels
channels = {
'4': int(512 * unet_narrow),
'8': int(512 * unet_narrow),
@@ -218,6 +238,7 @@ class GFPGANv1Clean(nn.Module):
self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
# the decoder: stylegan2 generator with SFT modulations
self.stylegan_decoder = StyleGAN2GeneratorCSFT(
out_size=out_size,
num_style_feat=num_style_feat,
@@ -226,14 +247,16 @@ class GFPGANv1Clean(nn.Module):
narrow=narrow,
sft_half=sft_half)
# load pre-trained stylegan2 model if necessary
if decoder_load_path:
self.stylegan_decoder.load_state_dict(
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
# fix decoder without updating params
if fix_decoder:
for _, param in self.stylegan_decoder.named_parameters():
param.requires_grad = False
# for SFT
# for SFT modulations (scale and shift)
self.condition_scale = nn.ModuleList()
self.condition_shift = nn.ModuleList()
for i in range(3, self.log_size + 1):
@@ -251,13 +274,15 @@ class GFPGANv1Clean(nn.Module):
nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
def forward(self,
x,
return_latents=False,
save_feat_path=None,
load_feat_path=None,
return_rgb=True,
randomize_noise=True):
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
"""Forward function for GFPGANv1Clean.
Args:
x (Tensor): Input images.
return_latents (bool): Whether to return style latents. Default: False.
return_rgb (bool): Whether to return intermediate rgb images. Default: True.
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
"""
conditions = []
unet_skips = []
out_rgbs = []
@@ -273,13 +298,14 @@ class GFPGANv1Clean(nn.Module):
style_code = self.final_linear(feat.view(feat.size(0), -1))
if self.different_w:
style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
# decode
for i in range(self.log_size - 2):
# add unet skip
feat = feat + unet_skips[i]
# ResUpLayer
feat = self.conv_body_up[i](feat)
# generate scale and shift for SFT layer
# generate scale and shift for SFT layers
scale = self.condition_scale[i](feat)
conditions.append(scale.clone())
shift = self.condition_shift[i](feat)
@@ -288,12 +314,6 @@ class GFPGANv1Clean(nn.Module):
if return_rgb:
out_rgbs.append(self.toRGB[i](feat))
if save_feat_path is not None:
torch.save(conditions, save_feat_path)
if load_feat_path is not None:
conditions = torch.load(load_feat_path)
conditions = [v.cuda() for v in conditions]
# decoder
image, _ = self.stylegan_decoder([style_code],
conditions,

View File

@@ -31,12 +31,9 @@ class ModulatedConv2d(nn.Module):
out_channels (int): Channel number of the output.
kernel_size (int): Size of the convolving kernel.
num_style_feat (int): Channel number of style features.
demodulate (bool): Whether to demodulate in the conv layer.
Default: True.
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
Default: None.
eps (float): A value added to the denominator for numerical stability.
Default: 1e-8.
demodulate (bool): Whether to demodulate in the conv layer. Default: True.
sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
"""
def __init__(self,
@@ -87,6 +84,7 @@ class ModulatedConv2d(nn.Module):
weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
# upsample or downsample if necessary
if self.sample_mode == 'upsample':
x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
elif self.sample_mode == 'downsample':
@@ -101,14 +99,12 @@ class ModulatedConv2d(nn.Module):
return out
def __repr__(self):
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
f'out_channels={self.out_channels}, '
f'kernel_size={self.kernel_size}, '
f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
class StyleConv(nn.Module):
"""Style conv.
"""Style conv used in StyleGAN2.
Args:
in_channels (int): Channel number of the input.
@@ -116,8 +112,7 @@ class StyleConv(nn.Module):
kernel_size (int): Size of the convolving kernel.
num_style_feat (int): Channel number of style features.
demodulate (bool): Whether demodulate in the conv layer. Default: True.
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
Default: None.
sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
"""
def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
@@ -144,7 +139,7 @@ class StyleConv(nn.Module):
class ToRGB(nn.Module):
"""To RGB from features.
"""To RGB (image space) from features.
Args:
in_channels (int): Channel number of input.
@@ -204,8 +199,7 @@ class StyleGAN2GeneratorClean(nn.Module):
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
num_mlp (int): Layer number of MLP style layers. Default: 8.
channel_multiplier (int): Channel multiplier for large networks of
StyleGAN2. Default: 2.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
narrow (float): Narrow ratio for channels. Default: 1.0.
"""
@@ -222,6 +216,7 @@ class StyleGAN2GeneratorClean(nn.Module):
# initialization
default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
# channel list
channels = {
'4': int(512 * narrow),
'8': int(512 * narrow),
@@ -309,21 +304,17 @@ class StyleGAN2GeneratorClean(nn.Module):
truncation_latent=None,
inject_index=None,
return_latents=False):
"""Forward function for StyleGAN2Generator.
"""Forward function for StyleGAN2GeneratorClean.
Args:
styles (list[Tensor]): Sample codes of styles.
input_is_latent (bool): Whether input is latent style.
Default: False.
input_is_latent (bool): Whether input is latent style. Default: False.
noise (Tensor | None): Input noise or None. Default: None.
randomize_noise (bool): Randomize noise, used when 'noise' is
False. Default: True.
truncation (float): TODO. Default: 1.
truncation_latent (Tensor | None): TODO. Default: None.
inject_index (int | None): The injection index for mixing noise.
Default: None.
return_latents (bool): Whether to return style latents.
Default: False.
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
truncation (float): The truncation ratio. Default: 1.
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
inject_index (int | None): The injection index for mixing noise. Default: None.
return_latents (bool): Whether to return style latents. Default: False.
"""
# style codes -> latents with Style MLP layer
if not input_is_latent:
@@ -340,7 +331,7 @@ class StyleGAN2GeneratorClean(nn.Module):
for style in styles:
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
styles = style_truncation
# get style latent with injection
# get style latents with injection
if len(styles) == 1:
inject_index = self.num_latent
@@ -366,7 +357,7 @@ class StyleGAN2GeneratorClean(nn.Module):
noise[2::2], self.to_rgbs):
out = conv1(out, latent[:, i], noise=noise1)
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
i += 2
image = skip
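
The truncation arguments documented in this hunk follow the standard StyleGAN truncation trick: each style latent is pulled towards a mean latent, and truncation=1 leaves the style unchanged. A tiny illustration with made-up values:

import torch

truncation = 0.7                               # the truncation ratio
truncation_latent = torch.zeros(1, 512)        # e.g. a mean latent
style = torch.randn(1, 512)
style_truncated = truncation_latent + truncation * (style - truncation_latent)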

View File

@@ -15,6 +15,19 @@ from torchvision.transforms.functional import (adjust_brightness, adjust_contras
@DATASET_REGISTRY.register()
class FFHQDegradationDataset(data.Dataset):
"""FFHQ dataset for GFPGAN.
It reads high resolution images, and then generates low-quality (LQ) images on-the-fly.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
io_backend (dict): IO backend type and other kwargs.
mean (list | tuple): Image mean.
std (list | tuple): Image std.
use_hflip (bool): Whether to horizontally flip.
Please see more options in the code.
"""
def __init__(self, opt):
super(FFHQDegradationDataset, self).__init__()
@@ -29,11 +42,13 @@ class FFHQDegradationDataset(data.Dataset):
self.out_size = opt['out_size']
self.crop_components = opt.get('crop_components', False) # facial components
self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)
self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)  # the ratio for enlarging eye regions
if self.crop_components:
# load the component list from a pre-processed pth file
self.components_list = torch.load(opt.get('component_path'))
# file client (lmdb io backend)
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = self.gt_folder
if not self.gt_folder.endswith('.lmdb'):
@@ -41,9 +56,10 @@ class FFHQDegradationDataset(data.Dataset):
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
self.paths = [line.split('.')[0] for line in fin]
else:
# disk backend: scan file list from a folder
self.paths = paths_from_folder(self.gt_folder)
# degradations
# degradation configurations
self.blur_kernel_size = opt['blur_kernel_size']
self.kernel_list = opt['kernel_list']
self.kernel_prob = opt['kernel_prob']
@@ -60,22 +76,20 @@ class FFHQDegradationDataset(data.Dataset):
self.gray_prob = opt.get('gray_prob')
logger = get_root_logger()
logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '
f'sigma: [{", ".join(map(str, self.blur_sigma))}]')
logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
if self.color_jitter_prob is not None:
logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, '
f'shift: {self.color_jitter_shift}')
logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
if self.gray_prob is not None:
logger.info(f'Use random gray. Prob: {self.gray_prob}')
self.color_jitter_shift /= 255.
@staticmethod
def color_jitter(img, shift):
"""jitter color: randomly jitter the RGB values, in numpy formats"""
jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
img = img + jitter_val
img = np.clip(img, 0, 1)
@@ -83,6 +97,7 @@ class FFHQDegradationDataset(data.Dataset):
@staticmethod
def color_jitter_pt(img, brightness, contrast, saturation, hue):
"""jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
fn_idx = torch.randperm(4)
for fn_id in fn_idx:
if fn_id == 0 and brightness is not None:
@@ -103,6 +118,7 @@ class FFHQDegradationDataset(data.Dataset):
return img
def get_component_coordinates(self, index, status):
"""Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file"""
components_bbox = self.components_list[f'{index:08d}']
if status[0]: # hflip
# exchange right and left eye
@@ -131,6 +147,7 @@ class FFHQDegradationDataset(data.Dataset):
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
# load gt image
# Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
gt_path = self.paths[index]
img_bytes = self.file_client.get(gt_path)
img_gt = imfrombytes(img_bytes, float32=True)
@@ -139,6 +156,7 @@ class FFHQDegradationDataset(data.Dataset):
img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
h, w, _ = img_gt.shape
# get facial component coordinates
if self.crop_components:
locations = self.get_component_coordinates(index, status)
loc_left_eye, loc_right_eye, loc_mouth = locations
@@ -173,9 +191,9 @@ class FFHQDegradationDataset(data.Dataset):
if self.gray_prob and np.random.uniform() < self.gray_prob:
img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
if self.opt.get('gt_gray'):
if self.opt.get('gt_gray'):  # whether to convert GT to gray images
img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])
img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) # repeat the color channels
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

View File

@@ -16,11 +16,11 @@ from tqdm import tqdm
@MODEL_REGISTRY.register()
class GFPGANModel(BaseModel):
"""GFPGAN model for <Towards real-world blind face restoratin with generative facial prior>"""
"""The GFPGAN model for Towards real-world blind face restoratin with generative facial prior"""
def __init__(self, opt):
super(GFPGANModel, self).__init__(opt)
self.idx = 0
self.idx = 0  # it is used for saving data for checking
# define network
self.net_g = build_network(opt['network_g'])
@@ -51,8 +51,7 @@ class GFPGANModel(BaseModel):
self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
# ----------- define net_g with Exponential Moving Average (EMA) ----------- #
# net_g_ema only used for testing on one GPU and saving
# There is no need to wrap with DistributedDataParallel
# net_g_ema only used for testing on one GPU and saving. There is no need to wrap with DistributedDataParallel
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
# load pretrained model
load_path = self.opt['path'].get('pretrain_network_g', None)
@@ -65,7 +64,7 @@ class GFPGANModel(BaseModel):
self.net_d.train()
self.net_g_ema.eval()
# ----------- facial components networks ----------- #
# ----------- facial component networks ----------- #
if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt):
self.use_facial_disc = True
else:
@@ -102,17 +101,19 @@ class GFPGANModel(BaseModel):
self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)
# ----------- define losses ----------- #
# pixel loss
if train_opt.get('pixel_opt'):
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
else:
self.cri_pix = None
# perceptual loss
if train_opt.get('perceptual_opt'):
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
else:
self.cri_perceptual = None
# L1 loss used in pyramid loss, component style loss and identity loss
# L1 loss is used in pyramid loss, component style loss and identity loss
self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)
# gan loss (wgan)
@@ -179,6 +180,7 @@ class GFPGANModel(BaseModel):
self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
self.optimizers.append(self.optimizer_d)
# ----------- optimizers for facial component networks ----------- #
if self.use_facial_disc:
# setup optimizers for facial component discriminators
optim_type = train_opt['optim_component'].pop('type')
@@ -221,6 +223,7 @@ class GFPGANModel(BaseModel):
# self.idx = self.idx + 1
def construct_img_pyramid(self):
"""Construct image pyramid for intermediate restoration loss"""
pyramid_gt = [self.gt]
down_img = self.gt
for _ in range(0, self.log_size - 3):
@@ -229,7 +232,6 @@ class GFPGANModel(BaseModel):
return pyramid_gt
def get_roi_regions(self, eye_out_size=80, mouth_out_size=120):
# hard code
face_ratio = int(self.opt['network_g']['out_size'] / 512)
eye_out_size *= face_ratio
mouth_out_size *= face_ratio
@@ -288,6 +290,7 @@ class GFPGANModel(BaseModel):
p.requires_grad = False
self.optimizer_g.zero_grad()
# do not update facial component net_d
if self.use_facial_disc:
for p in self.net_d_left_eye.parameters():
p.requires_grad = False
@@ -419,11 +422,12 @@ class GFPGANModel(BaseModel):
real_d_pred = self.net_d(self.gt)
l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True)
loss_dict['l_d'] = l_d
# In wgan, real_score should be positive and fake_score should benegative
# In WGAN, real_score should be positive and fake_score should be negative
loss_dict['real_score'] = real_d_pred.detach().mean()
loss_dict['fake_score'] = fake_d_pred.detach().mean()
l_d.backward()
# regularization loss
if current_iter % self.net_d_reg_every == 0:
self.gt.requires_grad = True
real_pred = self.net_d(self.gt)
@@ -434,8 +438,9 @@ class GFPGANModel(BaseModel):
self.optimizer_d.step()
# optimize facial component discriminators
if self.use_facial_disc:
# lefe eye
# left eye
fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach())
real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt)
l_d_left_eye = self.cri_component(
@@ -485,22 +490,32 @@ class GFPGANModel(BaseModel):
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
dataset_name = dataloader.dataset.opt['name']
with_metrics = self.opt['val'].get('metrics') is not None
use_pbar = self.opt['val'].get('pbar', False)
if with_metrics:
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
pbar = tqdm(total=len(dataloader), unit='image')
if not hasattr(self, 'metric_results'): # only execute in the first run
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
# initialize the best metric results for each dataset_name (supporting multiple validation datasets)
self._initialize_best_metric_results(dataset_name)
# zero self.metric_results
self.metric_results = {metric: 0 for metric in self.metric_results}
metric_data = dict()
if use_pbar:
pbar = tqdm(total=len(dataloader), unit='image')
for idx, val_data in enumerate(dataloader):
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
self.feed_data(val_data)
self.test()
visuals = self.get_current_visuals()
sr_img = tensor2img([visuals['sr']], min_max=(-1, 1))
gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
if 'gt' in visuals:
gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
sr_img = tensor2img(self.output.detach().cpu(), min_max=(-1, 1))
metric_data['img'] = sr_img
if hasattr(self, 'gt'):
gt_img = tensor2img(self.gt.detach().cpu(), min_max=(-1, 1))
metric_data['img2'] = gt_img
del self.gt
# tentative for out of GPU memory
del self.lq
del self.output
@@ -522,35 +537,38 @@ class GFPGANModel(BaseModel):
if with_metrics:
# calculate metrics
for name, opt_ in self.opt['val']['metrics'].items():
metric_data = dict(img1=sr_img, img2=gt_img)
self.metric_results[name] += calculate_metric(metric_data, opt_)
pbar.update(1)
pbar.set_description(f'Test {img_name}')
pbar.close()
if use_pbar:
pbar.update(1)
pbar.set_description(f'Test {img_name}')
if use_pbar:
pbar.close()
if with_metrics:
for metric in self.metric_results.keys():
self.metric_results[metric] /= (idx + 1)
# update the best metric result
self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
log_str = f'Validation {dataset_name}\n'
for metric, value in self.metric_results.items():
log_str += f'\t # {metric}: {value:.4f}\n'
log_str += f'\t # {metric}: {value:.4f}'
if hasattr(self, 'best_metric_results'):
log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
log_str += '\n'
logger = get_root_logger()
logger.info(log_str)
if tb_logger:
for metric, value in self.metric_results.items():
tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
def get_current_visuals(self):
out_dict = OrderedDict()
out_dict['gt'] = self.gt.detach().cpu()
out_dict['sr'] = self.output.detach().cpu()
return out_dict
tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
def save(self, epoch, current_iter):
# save net_g and net_d
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
self.save_network(self.net_d, 'net_d', current_iter)
# save component discriminators
@@ -558,4 +576,5 @@ class GFPGANModel(BaseModel):
self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter)
self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter)
self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter)
# save training state
self.save_training_state(epoch, current_iter)
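
The WGAN comment corrected in this hunk ("real_score should be positive and fake_score should be negative") describes the discriminator objective. A minimal sketch with dummy predictions (a softplus formulation, not the project's GANLoss class):

import torch
import torch.nn.functional as F

real_d_pred = torch.randn(4, 1) + 1.0   # dummy discriminator outputs on real images
fake_d_pred = torch.randn(4, 1) - 1.0   # dummy discriminator outputs on generated images
# push real scores towards positive values and fake scores towards negative values
l_d = F.softplus(-real_d_pred).mean() + F.softplus(fake_d_pred).mean()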

View File

@@ -14,6 +14,20 @@ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
class GFPGANer():
"""Helper for restoration with GFPGAN.
It will detect and crop faces, and then resize the faces to 512x512.
GFPGAN is used to restore the resized faces.
The background is upsampled with the bg_upsampler.
Finally, the faces will be pasted back to the upsampled background image.
Args:
model_path (str): The path to the GFPGAN model. It can be a url (the model will be downloaded automatically first).
upscale (float): The upscale of the final output. Default: 2.
arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
bg_upsampler (nn.Module): The upsampler for the background. Default: None.
"""
def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None):
self.upscale = upscale
@@ -70,7 +84,7 @@ class GFPGANer():
def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
self.face_helper.clean_all()
if has_aligned:
if has_aligned: # the inputs are already aligned
img = cv2.resize(img, (512, 512))
self.face_helper.cropped_faces = [img]
else:
@@ -78,6 +92,7 @@ class GFPGANer():
# get face landmarks for each face
self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
# eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
# TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
# align and warp each face
self.face_helper.align_warp_face()
@@ -100,9 +115,9 @@ class GFPGANer():
self.face_helper.add_restored_face(restored_face)
if not has_aligned and paste_back:
# upsample the background
if self.bg_upsampler is not None:
# Now only support RealESRGAN
# Now only support RealESRGAN for upsampling background
bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
else:
bg_img = None
@@ -116,7 +131,9 @@ class GFPGANer():
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
"""Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
"""Load file form http url, will download models if necessary.
Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
"""
if model_dir is None:
hub_dir = get_dir()

View File

@@ -10,20 +10,22 @@ from gfpgan import GFPGANer
def main():
"""Inference demo for GFPGAN.
"""
parser = argparse.ArgumentParser()
parser.add_argument('--upscale', type=int, default=2)
parser.add_argument('--arch', type=str, default='clean')
parser.add_argument('--channel', type=int, default=2)
parser.add_argument('--upscale', type=int, default=2, help='The final upsampling scale of the image')
parser.add_argument('--arch', type=str, default='clean', help='The GFPGAN architecture. Option: clean | original')
parser.add_argument('--channel', type=int, default=2, help='Channel multiplier for large networks of StyleGAN2')
parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth')
parser.add_argument('--bg_upsampler', type=str, default='realesrgan')
parser.add_argument('--bg_tile', type=int, default=400)
parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
parser.add_argument('--bg_upsampler', type=str, default='realesrgan', help='background upsampler')
parser.add_argument(
'--bg_tile', type=int, default=400, help='Tile size for background sampler, 0 for no tile during testing')
parser.add_argument('--test_path', type=str, default='inputs/whole_imgs', help='Input folder')
parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
parser.add_argument('--only_center_face', action='store_true')
parser.add_argument('--aligned', action='store_true')
parser.add_argument('--paste_back', action='store_false')
parser.add_argument('--save_root', type=str, default='results')
parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face')
parser.add_argument('--aligned', action='store_true', help='Inputs are aligned faces')
parser.add_argument('--paste_back', action='store_false', help='Paste the restored faces back to images')
parser.add_argument('--save_root', type=str, default='results', help='Path to save root')
parser.add_argument(
'--ext',
type=str,
@@ -70,6 +72,7 @@ def main():
basename, ext = os.path.splitext(img_name)
input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
# restore faces and background if necessary
cropped_faces, restored_faces, restored_img = restorer.enhance(
input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=args.paste_back)
@@ -85,7 +88,7 @@ def main():
save_face_name = f'{basename}_{idx:02d}.png'
save_restore_path = os.path.join(args.save_root, 'restored_faces', save_face_name)
imwrite(restored_face, save_restore_path)
# save cmp image
# save comparison image
cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
imwrite(cmp_img, os.path.join(args.save_root, 'cmp', f'{basename}_{idx:02d}.png'))
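
Putting the GFPGANer helper and this demo together, a hedged end-to-end usage sketch (the input file name is a placeholder; the model path is the demo's default):

import cv2
from gfpgan import GFPGANer

restorer = GFPGANer(
    model_path='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth',  # default --model_path above
    upscale=2,
    arch='clean',
    channel_multiplier=2,
    bg_upsampler=None)  # or a RealESRGAN instance for background upsampling

input_img = cv2.imread('inputs/whole_imgs/example.png', cv2.IMREAD_COLOR)  # hypothetical input image
cropped_faces, restored_faces, restored_img = restorer.enhance(
    input_img, has_aligned=False, only_center_face=False, paste_back=True)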

View File

@@ -1,7 +1,7 @@
# general settings
name: train_GFPGANv1_512
model_type: GFPGANModel
num_gpu: 4
num_gpu: auto # officially, we use 4 GPUs
manual_seed: 0
# dataset and data loader settings
@@ -194,7 +194,7 @@ val:
save_img: true
metrics:
psnr: # metric name, can be arbitrary
psnr: # metric name
type: calculate_psnr
crop_border: 0
test_y_channel: false

View File

@@ -1,7 +1,7 @@
# general settings
name: train_GFPGANv1_512_simple
model_type: GFPGANModel
num_gpu: 4
num_gpu: auto # officially, we use 4 GPUs
manual_seed: 0
# dataset and data loader settings
@@ -40,10 +40,6 @@ datasets:
# gray_prob: 0.01
# gt_gray: True
# crop_components: false
# component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
# eye_enlarge_ratio: 1.4
# data loader
use_shuffle: true
num_worker_per_gpu: 6
@@ -86,20 +82,6 @@ network_d:
channel_multiplier: 1
resample_kernel: [1, 3, 3, 1]
# network_d_left_eye:
# type: FacialComponentDiscriminator
# network_d_right_eye:
# type: FacialComponentDiscriminator
# network_d_mouth:
# type: FacialComponentDiscriminator
network_identity:
type: ResNetArcFace
block: IRBlock
layers: [2, 2, 2, 2]
use_se: False
# path
path:
@@ -107,13 +89,7 @@ path:
param_key_g: params_ema
strict_load_g: ~
pretrain_network_d: ~
# pretrain_network_d_left_eye: ~
# pretrain_network_d_right_eye: ~
# pretrain_network_d_mouth: ~
pretrain_network_identity: experiments/pretrained_models/arcface_resnet18.pth
# resume
resume_state: ~
ignore_resume_networks: ['network_identity']
# training settings
train:
@@ -173,16 +149,6 @@ train:
loss_weight: !!float 1e-1
# r1 regularization for discriminator
r1_reg_weight: 10
# facial component loss
# gan_component_opt:
# type: GANLoss
# gan_type: vanilla
# real_label_val: 1.0
# fake_label_val: 0.0
# loss_weight: !!float 1
# comp_style_weight: 200
# identity loss
identity_weight: 10
net_d_iters: 1
net_d_init_iters: 0
@@ -194,7 +160,7 @@ val:
save_img: true
metrics:
psnr: # metric name, can be arbitrary
psnr: # metric name
type: calculate_psnr
crop_border: 0
test_y_channel: false

View File

@@ -1,24 +1,31 @@
import cv2
import json
import numpy as np
import os
import torch
from basicsr.utils import FileClient, imfrombytes
from collections import OrderedDict
# ---------------------------- This script is used to parse facial landmarks ------------------------------------- #
# Configurations
save_img = False
scale = 0.5 # 0.5 for official FFHQ (512x512), 1 for others
enlarge_ratio = 1.4 # only for eyes
json_path = 'ffhq-dataset-v2.json'
face_path = 'datasets/ffhq/ffhq_512.lmdb'
save_path = './FFHQ_eye_mouth_landmarks_512.pth'
print('Load JSON metadata...')
# use the json file in FFHQ dataset
with open('ffhq-dataset-v2.json', 'rb') as f:
# use the official json file in FFHQ dataset
with open(json_path, 'rb') as f:
json_data = json.load(f, object_pairs_hook=OrderedDict)
print('Open LMDB file...')
# read ffhq images
file_client = FileClient('lmdb', db_paths='datasets/ffhq/ffhq_512.lmdb')
with open('datasets/ffhq/ffhq_512.lmdb/meta_info.txt') as fin:
file_client = FileClient('lmdb', db_paths=face_path)
with open(os.path.join(face_path, 'meta_info.txt')) as fin:
paths = [line.split('.')[0] for line in fin]
save_img = False
scale = 0.5 # 0.5 for official FFHQ (512x512), 1 for others
enlarge_ratio = 1.4 # only for eyes
save_dict = {}
for item_idx, item in enumerate(json_data.values()):
@@ -34,6 +41,7 @@ for item_idx, item in enumerate(json_data.values()):
img_bytes = file_client.get(paths[item_idx])
img = imfrombytes(img_bytes, float32=True)
# get landmarks for each component
map_left_eye = list(range(36, 42))
map_right_eye = list(range(42, 48))
map_mouth = list(range(48, 68))
@@ -74,4 +82,4 @@ for item_idx, item in enumerate(json_data.values()):
save_dict[f'{item_idx:08d}'] = item_dict
print('Save...')
torch.save(save_dict, './FFHQ_eye_mouth_landmarks_512.pth')
torch.save(save_dict, save_path)
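
For context on the landmark index ranges above (36-41 for the left eye, 42-47 for the right eye, 48-67 for the mouth) and the eye-only enlarge_ratio, a small illustrative sketch of turning one component's landmark points into a center/half-length box (hypothetical helper, not the parsing code elided from this hunk):

import numpy as np

def component_box(landmarks, index_range, enlarge_ratio=1.0):
    # collect the landmark points belonging to this component
    pts = np.asarray([landmarks[i] for i in index_range], dtype=np.float32)
    center = pts.mean(axis=0)
    # half of the larger box side, optionally enlarged (only done for the eyes above)
    half_len = (pts.max(axis=0) - pts.min(axis=0)).max() / 2 * enlarge_ratio
    return center, half_len

# e.g. left eye: component_box(landmarks_68, range(36, 42), enlarge_ratio=1.4)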