Mirror of https://github.com/TencentARC/GFPGAN.git (synced 2026-02-04 05:44:32 +00:00)
clean and add more comments
.github/workflows/no-response.yml (vendored, 9 changed lines)
@@ -1,12 +1,11 @@
name: No Response
# TODO: it seems not to work
# Modified from: https://raw.githubusercontent.com/github/docs/main/.github/workflows/no-response.yaml
# **What it does**: Closes issues that don't have enough information to be
# actionable.
# **Why we have it**: To remove the need for maintainers to remember to check
# back on issues periodically to see if contributors have
# responded.
# **What it does**: Closes issues that don't have enough information to be actionable.
# **Why we have it**: To remove the need for maintainers to remember to check back on issues periodically
# to see if contributors have responded.
# **Who does it impact**: Everyone that works on docs or docs-internal.
on:
@@ -3,4 +3,4 @@ from .archs import *
from .data import *
from .models import *
from .utils import *
from .version import __gitsha__, __version__
from .version import *
@@ -2,13 +2,27 @@ import torch.nn as nn
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
def conv3x3(inplanes, outplanes, stride=1):
|
||||
"""A simple wrapper for 3x3 convolution with padding.
|
||||
|
||||
Args:
|
||||
inplanes (int): Channel number of inputs.
|
||||
outplanes (int): Channel number of outputs.
|
||||
stride (int): Stride in convolution. Default: 1.
|
||||
"""
|
||||
return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
"""Basic residual block used in the ResNetArcFace architecture.
|
||||
|
||||
Args:
|
||||
inplanes (int): Channel number of inputs.
|
||||
planes (int): Channel number of outputs.
|
||||
stride (int): Stride in convolution. Default: 1.
|
||||
downsample (nn.Module): The downsample module. Default: None.
|
||||
"""
|
||||
expansion = 1 # output channel expansion ratio
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
@@ -40,7 +54,16 @@ class BasicBlock(nn.Module):
|
||||
|
||||
|
||||
class IRBlock(nn.Module):
|
||||
expansion = 1
|
||||
"""Improved residual block (IR Block) used in the ResNetArcFace architecture.
|
||||
|
||||
Args:
|
||||
inplanes (int): Channel number of inputs.
|
||||
planes (int): Channel number of outputs.
|
||||
stride (int): Stride in convolution. Default: 1.
|
||||
downsample (nn.Module): The downsample module. Default: None.
|
||||
use_se (bool): Whether to use the SEBlock (squeeze and excitation block). Default: True.
|
||||
"""
|
||||
expansion = 1 # output channel expansion ratio
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
|
||||
super(IRBlock, self).__init__()
|
||||
@@ -78,7 +101,15 @@ class IRBlock(nn.Module):
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
"""Bottleneck block used in the ResNetArcFace architecture.
|
||||
|
||||
Args:
|
||||
inplanes (int): Channel number of inputs.
|
||||
planes (int): Channel number of outputs.
|
||||
stride (int): Stride in convolution. Default: 1.
|
||||
downsample (nn.Module): The downsample module. Default: None.
|
||||
"""
|
||||
expansion = 4 # output channel expansion ratio
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
@@ -116,10 +147,16 @@ class Bottleneck(nn.Module):
|
||||
|
||||
|
||||
class SEBlock(nn.Module):
|
||||
"""The squeeze-and-excitation block (SEBlock) used in the IRBlock.
|
||||
|
||||
Args:
|
||||
channel (int): Channel number of inputs.
|
||||
reduction (int): Channel reduction ratio. Default: 16.
|
||||
"""
|
||||
|
||||
def __init__(self, channel, reduction=16):
|
||||
super(SEBlock, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
|
||||
nn.Sigmoid())
|
||||
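For context, the squeeze-and-excitation module above is a drop-in channel re-weighting layer; a hypothetical smoke test, assuming the SEBlock class defined above is in scope:

import torch

se = SEBlock(channel=256, reduction=16)
x = torch.randn(2, 256, 16, 16)
y = se(x)  # per-channel re-weighting; the output keeps the input shape
assert y.shape == x.shape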
@@ -133,6 +170,15 @@ class SEBlock(nn.Module):
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class ResNetArcFace(nn.Module):
|
||||
"""ArcFace with ResNet architectures.
|
||||
|
||||
Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
|
||||
|
||||
Args:
|
||||
block (str): Block used in the ArcFace architecture.
|
||||
layers (tuple(int)): Block numbers in each layer.
|
||||
use_se (bool): Whether to use the SEBlock (squeeze and excitation block). Default: True.
|
||||
"""
|
||||
|
||||
def __init__(self, block, layers, use_se=True):
|
||||
if block == 'IRBlock':
|
||||
@@ -140,6 +186,7 @@ class ResNetArcFace(nn.Module):
|
||||
self.inplanes = 64
|
||||
self.use_se = use_se
|
||||
super(ResNetArcFace, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.prelu = nn.PReLU()
|
||||
@@ -153,6 +200,7 @@ class ResNetArcFace(nn.Module):
|
||||
self.fc5 = nn.Linear(512 * 8 * 8, 512)
|
||||
self.bn5 = nn.BatchNorm1d(512)
|
||||
|
||||
# initialization
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.xavier_normal_(m.weight)
|
||||
@@ -163,7 +211,7 @@ class ResNetArcFace(nn.Module):
|
||||
nn.init.xavier_normal_(m.weight)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1):
|
||||
def _make_layer(self, block, planes, num_blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
@@ -173,7 +221,7 @@ class ResNetArcFace(nn.Module):
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
|
||||
self.inplanes = planes
|
||||
for _ in range(1, blocks):
|
||||
for _ in range(1, num_blocks):
|
||||
layers.append(block(self.inplanes, planes, use_se=self.use_se))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
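For context, a hypothetical construction of this identity network, mirroring the network_identity options that appear later in this commit (the import path is an assumption):

from gfpgan.archs.arcface_arch import ResNetArcFace  # assumed module path

# block 'IRBlock', layers [2, 2, 2, 2], use_se False, as in the training yml below
net_identity = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=False)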
@@ -10,18 +10,18 @@ from torch.nn import functional as F
|
||||
|
||||
|
||||
class StyleGAN2GeneratorSFT(StyleGAN2Generator):
|
||||
"""StyleGAN2 Generator.
|
||||
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
|
||||
|
||||
Args:
|
||||
out_size (int): The spatial size of outputs.
|
||||
num_style_feat (int): Channel number of style features. Default: 512.
|
||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
||||
channel_multiplier (int): Channel multiplier for large networks of
|
||||
StyleGAN2. Default: 2.
|
||||
resample_kernel (list[int]): A list indicating the 1D resample kernel
|
||||
magnitude. A cross production will be applied to extent 1D resample
|
||||
kernel to 2D resample kernel. Default: [1, 3, 3, 1].
|
||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
||||
resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross product will be
applied to extend the 1D resample kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
|
||||
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
||||
narrow (float): The narrow ratio for channels. Default: 1.
|
||||
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@@ -53,21 +53,18 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
|
||||
truncation_latent=None,
|
||||
inject_index=None,
|
||||
return_latents=False):
|
||||
"""Forward function for StyleGAN2Generator.
|
||||
"""Forward function for StyleGAN2GeneratorSFT.
|
||||
|
||||
Args:
|
||||
styles (list[Tensor]): Sample codes of styles.
|
||||
input_is_latent (bool): Whether input is latent style.
|
||||
Default: False.
|
||||
conditions (list[Tensor]): SFT conditions to generators.
|
||||
input_is_latent (bool): Whether input is latent style. Default: False.
|
||||
noise (Tensor | None): Input noise or None. Default: None.
|
||||
randomize_noise (bool): Randomize noise, used when 'noise' is
|
||||
False. Default: True.
|
||||
truncation (float): TODO. Default: 1.
|
||||
truncation_latent (Tensor | None): TODO. Default: None.
|
||||
inject_index (int | None): The injection index for mixing noise.
|
||||
Default: None.
|
||||
return_latents (bool): Whether to return style latents.
|
||||
Default: False.
|
||||
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
||||
truncation (float): The truncation ratio. Default: 1.
|
||||
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
|
||||
inject_index (int | None): The injection index for mixing noise. Default: None.
|
||||
return_latents (bool): Whether to return style latents. Default: False.
|
||||
"""
|
||||
# style codes -> latents with Style MLP layer
|
||||
if not input_is_latent:
|
||||
@@ -84,7 +81,7 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
|
||||
for style in styles:
|
||||
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
|
||||
styles = style_truncation
|
||||
# get style latent with injection
|
||||
# get style latents with injection
|
||||
if len(styles) == 1:
|
||||
inject_index = self.num_latent
|
||||
|
||||
@@ -113,15 +110,15 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
|
||||
# the conditions may have fewer levels
|
||||
if i < len(conditions):
|
||||
# SFT part to combine the conditions
|
||||
if self.sft_half:
|
||||
if self.sft_half: # only apply SFT to half of the channels
|
||||
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
|
||||
out_sft = out_sft * conditions[i - 1] + conditions[i]
|
||||
out = torch.cat([out_same, out_sft], dim=1)
|
||||
else:
|
||||
else: # apply SFT to all the channels
|
||||
out = out * conditions[i - 1] + conditions[i]
|
||||
|
||||
out = conv2(out, latent[:, i + 1], noise=noise2)
|
||||
skip = to_rgb(out, latent[:, i + 2], skip)
|
||||
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
|
||||
i += 2
|
||||
|
||||
image = skip
|
||||
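For reference, the SFT step in the hunk above reduces to a scale-and-shift on (half of) the feature channels; a minimal sketch with illustrative names:

import torch

def apply_sft(out, scale, shift, sft_half=True):
    # scale and shift play the roles of conditions[i - 1] and conditions[i] above
    if sft_half:  # modulate only the second half of the channels
        out_same, out_sft = torch.split(out, out.size(1) // 2, dim=1)
        out_sft = out_sft * scale + shift
        return torch.cat([out_same, out_sft], dim=1)
    return out * scale + shift  # modulate all channels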
@@ -133,17 +130,15 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
|
||||
|
||||
|
||||
class ConvUpLayer(nn.Module):
|
||||
"""Conv Up Layer. Bilinear upsample + Conv.
|
||||
"""Convolutional upsampling layer. It uses bilinear upsampler + Conv.
|
||||
|
||||
Args:
|
||||
in_channels (int): Channel number of the input.
|
||||
out_channels (int): Channel number of the output.
|
||||
kernel_size (int): Size of the convolving kernel.
|
||||
stride (int): Stride of the convolution. Default: 1
|
||||
padding (int): Zero-padding added to both sides of the input.
|
||||
Default: 0.
|
||||
bias (bool): If ``True``, adds a learnable bias to the output.
|
||||
Default: ``True``.
|
||||
padding (int): Zero-padding added to both sides of the input. Default: 0.
|
||||
bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``.
|
||||
bias_init_val (float): Bias initialized value. Default: 0.
|
||||
activate (bool): Whether to use activation. Default: True.
|
||||
"""
|
||||
@@ -163,6 +158,7 @@ class ConvUpLayer(nn.Module):
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
# self.scale is used to scale the convolution weights, which is related to the common initializations.
|
||||
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
|
||||
|
||||
self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
|
||||
@@ -223,7 +219,26 @@ class ResUpBlock(nn.Module):
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class GFPGANv1(nn.Module):
|
||||
"""Unet + StyleGAN2 decoder with SFT."""
|
||||
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
|
||||
|
||||
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
|
||||
|
||||
Args:
|
||||
out_size (int): The spatial size of outputs.
|
||||
num_style_feat (int): Channel number of style features. Default: 512.
|
||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
||||
resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross product will be
applied to extend the 1D resample kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
|
||||
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
|
||||
fix_decoder (bool): Whether to fix the decoder. Default: True.
|
||||
|
||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
||||
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
||||
input_is_latent (bool): Whether input is latent style. Default: False.
|
||||
different_w (bool): Whether to use different latent w for different layers. Default: False.
|
||||
narrow (float): The narrow ratio for channels. Default: 1.
|
||||
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -246,7 +261,7 @@ class GFPGANv1(nn.Module):
|
||||
self.different_w = different_w
|
||||
self.num_style_feat = num_style_feat
|
||||
|
||||
unet_narrow = narrow * 0.5
|
||||
unet_narrow = narrow * 0.5 # by default, use a half of input channels
|
||||
channels = {
|
||||
'4': int(512 * unet_narrow),
|
||||
'8': int(512 * unet_narrow),
|
||||
@@ -295,6 +310,7 @@ class GFPGANv1(nn.Module):
|
||||
self.final_linear = EqualLinear(
|
||||
channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
|
||||
|
||||
# the decoder: stylegan2 generator with SFT modulations
|
||||
self.stylegan_decoder = StyleGAN2GeneratorSFT(
|
||||
out_size=out_size,
|
||||
num_style_feat=num_style_feat,
|
||||
@@ -305,14 +321,16 @@ class GFPGANv1(nn.Module):
|
||||
narrow=narrow,
|
||||
sft_half=sft_half)
|
||||
|
||||
# load pre-trained stylegan2 model if necessary
|
||||
if decoder_load_path:
|
||||
self.stylegan_decoder.load_state_dict(
|
||||
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
|
||||
# fix decoder without updating params
|
||||
if fix_decoder:
|
||||
for _, param in self.stylegan_decoder.named_parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# for SFT
|
||||
# for SFT modulations (scale and shift)
|
||||
self.condition_scale = nn.ModuleList()
|
||||
self.condition_shift = nn.ModuleList()
|
||||
for i in range(3, self.log_size + 1):
|
||||
@@ -332,13 +350,15 @@ class GFPGANv1(nn.Module):
|
||||
ScaledLeakyReLU(0.2),
|
||||
EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
|
||||
|
||||
def forward(self,
|
||||
x,
|
||||
return_latents=False,
|
||||
save_feat_path=None,
|
||||
load_feat_path=None,
|
||||
return_rgb=True,
|
||||
randomize_noise=True):
|
||||
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
|
||||
"""Forward function for GFPGANv1.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input images.
|
||||
return_latents (bool): Whether to return style latents. Default: False.
|
||||
return_rgb (bool): Whether to return intermediate rgb images. Default: True.
|
||||
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
||||
"""
|
||||
conditions = []
|
||||
unet_skips = []
|
||||
out_rgbs = []
|
||||
@@ -362,7 +382,7 @@ class GFPGANv1(nn.Module):
|
||||
feat = feat + unet_skips[i]
|
||||
# ResUpLayer
|
||||
feat = self.conv_body_up[i](feat)
|
||||
# generate scale and shift for SFT layer
|
||||
# generate scale and shift for SFT layers
|
||||
scale = self.condition_scale[i](feat)
|
||||
conditions.append(scale.clone())
|
||||
shift = self.condition_shift[i](feat)
|
||||
@@ -371,12 +391,6 @@ class GFPGANv1(nn.Module):
|
||||
if return_rgb:
|
||||
out_rgbs.append(self.toRGB[i](feat))
|
||||
|
||||
if save_feat_path is not None:
|
||||
torch.save(conditions, save_feat_path)
|
||||
if load_feat_path is not None:
|
||||
conditions = torch.load(load_feat_path)
|
||||
conditions = [v.cuda() for v in conditions]
|
||||
|
||||
# decoder
|
||||
image, _ = self.stylegan_decoder([style_code],
|
||||
conditions,
|
||||
@@ -389,10 +403,12 @@ class GFPGANv1(nn.Module):
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class FacialComponentDiscriminator(nn.Module):
|
||||
"""Facial component (eyes, mouth, noise) discriminator used in GFPGAN.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(FacialComponentDiscriminator, self).__init__()
|
||||
|
||||
# It now uses a VGG-style architecture with a fixed model size
|
||||
self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
|
||||
self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
|
||||
self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
|
||||
@@ -401,6 +417,12 @@ class FacialComponentDiscriminator(nn.Module):
|
||||
self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
|
||||
|
||||
def forward(self, x, return_feats=False):
|
||||
"""Forward function for FacialComponentDiscriminator.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input images.
|
||||
return_feats (bool): Whether to return intermediate features. Default: False.
|
||||
"""
|
||||
feat = self.conv1(x)
|
||||
feat = self.conv3(self.conv2(feat))
|
||||
rlt_feats = []
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import math
|
||||
import random
|
||||
import torch
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
@@ -8,14 +9,17 @@ from .stylegan2_clean_arch import StyleGAN2GeneratorClean
|
||||
|
||||
|
||||
class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
|
||||
"""StyleGAN2 Generator.
|
||||
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
|
||||
|
||||
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
|
||||
|
||||
Args:
|
||||
out_size (int): The spatial size of outputs.
|
||||
num_style_feat (int): Channel number of style features. Default: 512.
|
||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
||||
channel_multiplier (int): Channel multiplier for large networks of
|
||||
StyleGAN2. Default: 2.
|
||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
||||
narrow (float): The narrow ratio for channels. Default: 1.
|
||||
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
|
||||
@@ -25,7 +29,6 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
|
||||
num_mlp=num_mlp,
|
||||
channel_multiplier=channel_multiplier,
|
||||
narrow=narrow)
|
||||
|
||||
self.sft_half = sft_half
|
||||
|
||||
def forward(self,
|
||||
@@ -38,21 +41,18 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
|
||||
truncation_latent=None,
|
||||
inject_index=None,
|
||||
return_latents=False):
|
||||
"""Forward function for StyleGAN2Generator.
|
||||
"""Forward function for StyleGAN2GeneratorCSFT.
|
||||
|
||||
Args:
|
||||
styles (list[Tensor]): Sample codes of styles.
|
||||
input_is_latent (bool): Whether input is latent style.
|
||||
Default: False.
|
||||
conditions (list[Tensor]): SFT conditions to generators.
|
||||
input_is_latent (bool): Whether input is latent style. Default: False.
|
||||
noise (Tensor | None): Input noise or None. Default: None.
|
||||
randomize_noise (bool): Randomize noise, used when 'noise' is
|
||||
False. Default: True.
|
||||
truncation (float): TODO. Default: 1.
|
||||
truncation_latent (Tensor | None): TODO. Default: None.
|
||||
inject_index (int | None): The injection index for mixing noise.
|
||||
Default: None.
|
||||
return_latents (bool): Whether to return style latents.
|
||||
Default: False.
|
||||
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
||||
truncation (float): The truncation ratio. Default: 1.
|
||||
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
|
||||
inject_index (int | None): The injection index for mixing noise. Default: None.
|
||||
return_latents (bool): Whether to return style latents. Default: False.
|
||||
"""
|
||||
# style codes -> latents with Style MLP layer
|
||||
if not input_is_latent:
|
||||
@@ -69,7 +69,7 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
|
||||
for style in styles:
|
||||
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
|
||||
styles = style_truncation
|
||||
# get style latent with injection
|
||||
# get style latents with injection
|
||||
if len(styles) == 1:
|
||||
inject_index = self.num_latent
|
||||
|
||||
@@ -98,15 +98,15 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
|
||||
# the conditions may have fewer levels
|
||||
if i < len(conditions):
|
||||
# SFT part to combine the conditions
|
||||
if self.sft_half:
|
||||
if self.sft_half: # only apply SFT to half of the channels
|
||||
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
|
||||
out_sft = out_sft * conditions[i - 1] + conditions[i]
|
||||
out = torch.cat([out_same, out_sft], dim=1)
|
||||
else:
|
||||
else: # apply SFT to all the channels
|
||||
out = out * conditions[i - 1] + conditions[i]
|
||||
|
||||
out = conv2(out, latent[:, i + 1], noise=noise2)
|
||||
skip = to_rgb(out, latent[:, i + 2], skip)
|
||||
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
|
||||
i += 2
|
||||
|
||||
image = skip
|
||||
@@ -118,11 +118,12 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
"""Residual block with upsampling/downsampling.
|
||||
"""Residual block with bilinear upsampling/downsampling.
|
||||
|
||||
Args:
|
||||
in_channels (int): Channel number of the input.
|
||||
out_channels (int): Channel number of the output.
|
||||
mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, mode='down'):
|
||||
@@ -148,8 +149,27 @@ class ResBlock(nn.Module):
|
||||
return out
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class GFPGANv1Clean(nn.Module):
|
||||
"""GFPGANv1 Clean version."""
|
||||
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
|
||||
|
||||
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
|
||||
|
||||
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
|
||||
|
||||
Args:
|
||||
out_size (int): The spatial size of outputs.
|
||||
num_style_feat (int): Channel number of style features. Default: 512.
|
||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
||||
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
|
||||
fix_decoder (bool): Whether to fix the decoder. Default: True.
|
||||
|
||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
||||
input_is_latent (bool): Whether input is latent style. Default: False.
|
||||
different_w (bool): Whether to use different latent w for different layers. Default: False.
|
||||
narrow (float): The narrow ratio for channels. Default: 1.
|
||||
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -170,7 +190,7 @@ class GFPGANv1Clean(nn.Module):
|
||||
self.different_w = different_w
|
||||
self.num_style_feat = num_style_feat
|
||||
|
||||
unet_narrow = narrow * 0.5
|
||||
unet_narrow = narrow * 0.5 # by default, use a half of input channels
|
||||
channels = {
|
||||
'4': int(512 * unet_narrow),
|
||||
'8': int(512 * unet_narrow),
|
||||
@@ -218,6 +238,7 @@ class GFPGANv1Clean(nn.Module):
|
||||
|
||||
self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
|
||||
|
||||
# the decoder: stylegan2 generator with SFT modulations
|
||||
self.stylegan_decoder = StyleGAN2GeneratorCSFT(
|
||||
out_size=out_size,
|
||||
num_style_feat=num_style_feat,
|
||||
@@ -226,14 +247,16 @@ class GFPGANv1Clean(nn.Module):
|
||||
narrow=narrow,
|
||||
sft_half=sft_half)
|
||||
|
||||
# load pre-trained stylegan2 model if necessary
|
||||
if decoder_load_path:
|
||||
self.stylegan_decoder.load_state_dict(
|
||||
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
|
||||
# fix decoder without updating params
|
||||
if fix_decoder:
|
||||
for _, param in self.stylegan_decoder.named_parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# for SFT
|
||||
# for SFT modulations (scale and shift)
|
||||
self.condition_scale = nn.ModuleList()
|
||||
self.condition_shift = nn.ModuleList()
|
||||
for i in range(3, self.log_size + 1):
|
||||
@@ -251,13 +274,15 @@ class GFPGANv1Clean(nn.Module):
|
||||
nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
|
||||
|
||||
def forward(self,
|
||||
x,
|
||||
return_latents=False,
|
||||
save_feat_path=None,
|
||||
load_feat_path=None,
|
||||
return_rgb=True,
|
||||
randomize_noise=True):
|
||||
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
|
||||
"""Forward function for GFPGANv1Clean.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input images.
|
||||
return_latents (bool): Whether to return style latents. Default: False.
|
||||
return_rgb (bool): Whether to return intermediate rgb images. Default: True.
|
||||
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
||||
"""
|
||||
conditions = []
|
||||
unet_skips = []
|
||||
out_rgbs = []
|
||||
@@ -273,13 +298,14 @@ class GFPGANv1Clean(nn.Module):
|
||||
style_code = self.final_linear(feat.view(feat.size(0), -1))
|
||||
if self.different_w:
|
||||
style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
|
||||
|
||||
# decode
|
||||
for i in range(self.log_size - 2):
|
||||
# add unet skip
|
||||
feat = feat + unet_skips[i]
|
||||
# ResUpLayer
|
||||
feat = self.conv_body_up[i](feat)
|
||||
# generate scale and shift for SFT layer
|
||||
# generate scale and shift for SFT layers
|
||||
scale = self.condition_scale[i](feat)
|
||||
conditions.append(scale.clone())
|
||||
shift = self.condition_shift[i](feat)
|
||||
@@ -288,12 +314,6 @@ class GFPGANv1Clean(nn.Module):
|
||||
if return_rgb:
|
||||
out_rgbs.append(self.toRGB[i](feat))
|
||||
|
||||
if save_feat_path is not None:
|
||||
torch.save(conditions, save_feat_path)
|
||||
if load_feat_path is not None:
|
||||
conditions = torch.load(load_feat_path)
|
||||
conditions = [v.cuda() for v in conditions]
|
||||
|
||||
# decoder
|
||||
image, _ = self.stylegan_decoder([style_code],
|
||||
conditions,
|
||||
|
||||
@@ -31,12 +31,9 @@ class ModulatedConv2d(nn.Module):
|
||||
out_channels (int): Channel number of the output.
|
||||
kernel_size (int): Size of the convolving kernel.
|
||||
num_style_feat (int): Channel number of style features.
|
||||
demodulate (bool): Whether to demodulate in the conv layer.
|
||||
Default: True.
|
||||
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
|
||||
Default: None.
|
||||
eps (float): A value added to the denominator for numerical stability.
|
||||
Default: 1e-8.
|
||||
demodulate (bool): Whether to demodulate in the conv layer. Default: True.
|
||||
sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
|
||||
eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@@ -87,6 +84,7 @@ class ModulatedConv2d(nn.Module):
|
||||
|
||||
weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
|
||||
|
||||
# upsample or downsample if necessary
|
||||
if self.sample_mode == 'upsample':
|
||||
x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
|
||||
elif self.sample_mode == 'downsample':
|
||||
@@ -101,14 +99,12 @@ class ModulatedConv2d(nn.Module):
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
|
||||
f'out_channels={self.out_channels}, '
|
||||
f'kernel_size={self.kernel_size}, '
|
||||
f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
|
||||
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
|
||||
f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
|
||||
|
||||
|
||||
class StyleConv(nn.Module):
|
||||
"""Style conv.
|
||||
"""Style conv used in StyleGAN2.
|
||||
|
||||
Args:
|
||||
in_channels (int): Channel number of the input.
|
||||
@@ -116,8 +112,7 @@ class StyleConv(nn.Module):
|
||||
kernel_size (int): Size of the convolving kernel.
|
||||
num_style_feat (int): Channel number of style features.
|
||||
demodulate (bool): Whether demodulate in the conv layer. Default: True.
|
||||
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
|
||||
Default: None.
|
||||
sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
|
||||
@@ -144,7 +139,7 @@ class StyleConv(nn.Module):
|
||||
|
||||
|
||||
class ToRGB(nn.Module):
|
||||
"""To RGB from features.
|
||||
"""To RGB (image space) from features.
|
||||
|
||||
Args:
|
||||
in_channels (int): Channel number of input.
|
||||
@@ -204,8 +199,7 @@ class StyleGAN2GeneratorClean(nn.Module):
|
||||
out_size (int): The spatial size of outputs.
|
||||
num_style_feat (int): Channel number of style features. Default: 512.
|
||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
||||
channel_multiplier (int): Channel multiplier for large networks of
|
||||
StyleGAN2. Default: 2.
|
||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
||||
narrow (float): Narrow ratio for channels. Default: 1.0.
|
||||
"""
|
||||
|
||||
@@ -222,6 +216,7 @@ class StyleGAN2GeneratorClean(nn.Module):
|
||||
# initialization
|
||||
default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
|
||||
|
||||
# channel list
|
||||
channels = {
|
||||
'4': int(512 * narrow),
|
||||
'8': int(512 * narrow),
|
||||
@@ -309,21 +304,17 @@ class StyleGAN2GeneratorClean(nn.Module):
|
||||
truncation_latent=None,
|
||||
inject_index=None,
|
||||
return_latents=False):
|
||||
"""Forward function for StyleGAN2Generator.
|
||||
"""Forward function for StyleGAN2GeneratorClean.
|
||||
|
||||
Args:
|
||||
styles (list[Tensor]): Sample codes of styles.
|
||||
input_is_latent (bool): Whether input is latent style.
|
||||
Default: False.
|
||||
input_is_latent (bool): Whether input is latent style. Default: False.
|
||||
noise (Tensor | None): Input noise or None. Default: None.
|
||||
randomize_noise (bool): Randomize noise, used when 'noise' is
|
||||
False. Default: True.
|
||||
truncation (float): TODO. Default: 1.
|
||||
truncation_latent (Tensor | None): TODO. Default: None.
|
||||
inject_index (int | None): The injection index for mixing noise.
|
||||
Default: None.
|
||||
return_latents (bool): Whether to return style latents.
|
||||
Default: False.
|
||||
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
||||
truncation (float): The truncation ratio. Default: 1.
|
||||
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
|
||||
inject_index (int | None): The injection index for mixing noise. Default: None.
|
||||
return_latents (bool): Whether to return style latents. Default: False.
|
||||
"""
|
||||
# style codes -> latents with Style MLP layer
|
||||
if not input_is_latent:
|
||||
@@ -340,7 +331,7 @@ class StyleGAN2GeneratorClean(nn.Module):
|
||||
for style in styles:
|
||||
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
|
||||
styles = style_truncation
|
||||
# get style latent with injection
|
||||
# get style latents with injection
|
||||
if len(styles) == 1:
|
||||
inject_index = self.num_latent
|
||||
|
||||
@@ -366,7 +357,7 @@ class StyleGAN2GeneratorClean(nn.Module):
|
||||
noise[2::2], self.to_rgbs):
|
||||
out = conv1(out, latent[:, i], noise=noise1)
|
||||
out = conv2(out, latent[:, i + 1], noise=noise2)
|
||||
skip = to_rgb(out, latent[:, i + 2], skip)
|
||||
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
|
||||
i += 2
|
||||
|
||||
image = skip
|
||||
|
||||
@@ -15,6 +15,19 @@ from torchvision.transforms.functional import (adjust_brightness, adjust_contras
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class FFHQDegradationDataset(data.Dataset):
|
||||
"""FFHQ dataset for GFPGAN.
|
||||
|
||||
It reads high-resolution images, and then generates low-quality (LQ) images on-the-fly.
|
||||
|
||||
Args:
|
||||
opt (dict): Config for train datasets. It contains the following keys:
|
||||
dataroot_gt (str): Data root path for gt.
|
||||
io_backend (dict): IO backend type and other kwargs.
|
||||
mean (list | tuple): Image mean.
|
||||
std (list | tuple): Image std.
|
||||
use_hflip (bool): Whether to horizontally flip.
|
||||
Please see more options in the codes.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
super(FFHQDegradationDataset, self).__init__()
|
||||
@@ -29,11 +42,13 @@ class FFHQDegradationDataset(data.Dataset):
|
||||
self.out_size = opt['out_size']
|
||||
|
||||
self.crop_components = opt.get('crop_components', False) # facial components
|
||||
self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)
|
||||
self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1) # whether enlarge eye regions
|
||||
|
||||
if self.crop_components:
|
||||
# load the component list from a pre-processed pth file
|
||||
self.components_list = torch.load(opt.get('component_path'))
|
||||
|
||||
# file client (lmdb io backend)
|
||||
if self.io_backend_opt['type'] == 'lmdb':
|
||||
self.io_backend_opt['db_paths'] = self.gt_folder
|
||||
if not self.gt_folder.endswith('.lmdb'):
|
||||
@@ -41,9 +56,10 @@ class FFHQDegradationDataset(data.Dataset):
|
||||
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
|
||||
self.paths = [line.split('.')[0] for line in fin]
|
||||
else:
|
||||
# disk backend: scan file list from a folder
|
||||
self.paths = paths_from_folder(self.gt_folder)
|
||||
|
||||
# degradations
|
||||
# degradation configurations
|
||||
self.blur_kernel_size = opt['blur_kernel_size']
|
||||
self.kernel_list = opt['kernel_list']
|
||||
self.kernel_prob = opt['kernel_prob']
|
||||
@@ -60,22 +76,20 @@ class FFHQDegradationDataset(data.Dataset):
|
||||
self.gray_prob = opt.get('gray_prob')
|
||||
|
||||
logger = get_root_logger()
|
||||
logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '
|
||||
f'sigma: [{", ".join(map(str, self.blur_sigma))}]')
|
||||
logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
|
||||
logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
|
||||
logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
|
||||
logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
|
||||
|
||||
if self.color_jitter_prob is not None:
|
||||
logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, '
|
||||
f'shift: {self.color_jitter_shift}')
|
||||
logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
|
||||
if self.gray_prob is not None:
|
||||
logger.info(f'Use random gray. Prob: {self.gray_prob}')
|
||||
|
||||
self.color_jitter_shift /= 255.
|
||||
|
||||
@staticmethod
|
||||
def color_jitter(img, shift):
|
||||
"""jitter color: randomly jitter the RGB values, in numpy formats"""
|
||||
jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
|
||||
img = img + jitter_val
|
||||
img = np.clip(img, 0, 1)
|
||||
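To make the degradation options logged above concrete, a hedged outline of the on-the-fly blur -> downsample -> noise -> JPEG pipeline (function and argument names are illustrative, not the exact dataset code):

import cv2
import numpy as np

def degrade(img_gt, kernel, down_scale, noise_sigma, jpeg_quality, out_size=512):
    # img_gt: float32 BGR image in [0, 1], as loaded in __getitem__ below
    img = cv2.filter2D(img_gt, -1, kernel)  # blur with a sampled kernel
    h, w = img.shape[:2]
    img = cv2.resize(img, (int(w / down_scale), int(h / down_scale)), interpolation=cv2.INTER_LINEAR)
    img = np.clip(img + np.random.normal(0, noise_sigma / 255., img.shape).astype(np.float32), 0, 1)
    _, enc = cv2.imencode('.jpg', (img * 255).astype(np.uint8), [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality])
    img = cv2.imdecode(enc, cv2.IMREAD_COLOR).astype(np.float32) / 255.
    return cv2.resize(img, (out_size, out_size), interpolation=cv2.INTER_LINEAR)  # back to the GT size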
@@ -83,6 +97,7 @@ class FFHQDegradationDataset(data.Dataset):
|
||||
|
||||
@staticmethod
|
||||
def color_jitter_pt(img, brightness, contrast, saturation, hue):
|
||||
"""jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
|
||||
fn_idx = torch.randperm(4)
|
||||
for fn_id in fn_idx:
|
||||
if fn_id == 0 and brightness is not None:
|
||||
@@ -103,6 +118,7 @@ class FFHQDegradationDataset(data.Dataset):
|
||||
return img
|
||||
|
||||
def get_component_coordinates(self, index, status):
|
||||
"""Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file"""
|
||||
components_bbox = self.components_list[f'{index:08d}']
|
||||
if status[0]: # hflip
|
||||
# exchange right and left eye
|
||||
@@ -131,6 +147,7 @@ class FFHQDegradationDataset(data.Dataset):
|
||||
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
|
||||
|
||||
# load gt image
|
||||
# Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
|
||||
gt_path = self.paths[index]
|
||||
img_bytes = self.file_client.get(gt_path)
|
||||
img_gt = imfrombytes(img_bytes, float32=True)
|
||||
@@ -139,6 +156,7 @@ class FFHQDegradationDataset(data.Dataset):
|
||||
img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
|
||||
h, w, _ = img_gt.shape
|
||||
|
||||
# get facial component coordinates
|
||||
if self.crop_components:
|
||||
locations = self.get_component_coordinates(index, status)
|
||||
loc_left_eye, loc_right_eye, loc_mouth = locations
|
||||
@@ -173,9 +191,9 @@ class FFHQDegradationDataset(data.Dataset):
|
||||
if self.gray_prob and np.random.uniform() < self.gray_prob:
|
||||
img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
|
||||
img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
|
||||
if self.opt.get('gt_gray'):
|
||||
if self.opt.get('gt_gray'): # whether convert GT to gray images
|
||||
img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
|
||||
img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])
|
||||
img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) # repeat the color channels
|
||||
|
||||
# BGR to RGB, HWC to CHW, numpy to tensor
|
||||
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
|
||||
|
||||
@@ -16,11 +16,11 @@ from tqdm import tqdm
|
||||
|
||||
@MODEL_REGISTRY.register()
|
||||
class GFPGANModel(BaseModel):
|
||||
"""GFPGAN model for <Towards real-world blind face restoratin with generative facial prior>"""
|
||||
"""The GFPGAN model for Towards real-world blind face restoratin with generative facial prior"""
|
||||
|
||||
def __init__(self, opt):
|
||||
super(GFPGANModel, self).__init__(opt)
|
||||
self.idx = 0
|
||||
self.idx = 0  # it is used for saving data for checking
|
||||
|
||||
# define network
|
||||
self.net_g = build_network(opt['network_g'])
|
||||
@@ -51,8 +51,7 @@ class GFPGANModel(BaseModel):
|
||||
self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
|
||||
|
||||
# ----------- define net_g with Exponential Moving Average (EMA) ----------- #
|
||||
# net_g_ema only used for testing on one GPU and saving
|
||||
# There is no need to wrap with DistributedDataParallel
|
||||
# net_g_ema only used for testing on one GPU and saving. There is no need to wrap with DistributedDataParallel
|
||||
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
|
||||
# load pretrained model
|
||||
load_path = self.opt['path'].get('pretrain_network_g', None)
|
||||
@@ -65,7 +64,7 @@ class GFPGANModel(BaseModel):
|
||||
self.net_d.train()
|
||||
self.net_g_ema.eval()
|
||||
|
||||
# ----------- facial components networks ----------- #
|
||||
# ----------- facial component networks ----------- #
|
||||
if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt):
|
||||
self.use_facial_disc = True
|
||||
else:
|
||||
@@ -102,17 +101,19 @@ class GFPGANModel(BaseModel):
|
||||
self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)
|
||||
|
||||
# ----------- define losses ----------- #
|
||||
# pixel loss
|
||||
if train_opt.get('pixel_opt'):
|
||||
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
|
||||
else:
|
||||
self.cri_pix = None
|
||||
|
||||
# perceptual loss
|
||||
if train_opt.get('perceptual_opt'):
|
||||
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
|
||||
else:
|
||||
self.cri_perceptual = None
|
||||
|
||||
# L1 loss used in pyramid loss, component style loss and identity loss
|
||||
# L1 loss is used in pyramid loss, component style loss and identity loss
|
||||
self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)
|
||||
|
||||
# gan loss (wgan)
|
||||
@@ -179,6 +180,7 @@ class GFPGANModel(BaseModel):
|
||||
self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
|
||||
self.optimizers.append(self.optimizer_d)
|
||||
|
||||
# ----------- optimizers for facial component networks ----------- #
|
||||
if self.use_facial_disc:
|
||||
# setup optimizers for facial component discriminators
|
||||
optim_type = train_opt['optim_component'].pop('type')
|
||||
@@ -221,6 +223,7 @@ class GFPGANModel(BaseModel):
|
||||
# self.idx = self.idx + 1
|
||||
|
||||
def construct_img_pyramid(self):
|
||||
"""Construct image pyramid for intermediate restoration loss"""
|
||||
pyramid_gt = [self.gt]
|
||||
down_img = self.gt
|
||||
for _ in range(0, self.log_size - 3):
|
||||
@@ -229,7 +232,6 @@ class GFPGANModel(BaseModel):
|
||||
return pyramid_gt
|
||||
|
||||
def get_roi_regions(self, eye_out_size=80, mouth_out_size=120):
|
||||
# hard code
|
||||
face_ratio = int(self.opt['network_g']['out_size'] / 512)
|
||||
eye_out_size *= face_ratio
|
||||
mouth_out_size *= face_ratio
|
||||
@@ -288,6 +290,7 @@ class GFPGANModel(BaseModel):
|
||||
p.requires_grad = False
|
||||
self.optimizer_g.zero_grad()
|
||||
|
||||
# do not update facial component net_d
|
||||
if self.use_facial_disc:
|
||||
for p in self.net_d_left_eye.parameters():
|
||||
p.requires_grad = False
|
||||
@@ -419,11 +422,12 @@ class GFPGANModel(BaseModel):
|
||||
real_d_pred = self.net_d(self.gt)
|
||||
l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True)
|
||||
loss_dict['l_d'] = l_d
|
||||
# In wgan, real_score should be positive and fake_score should benegative
|
||||
# In WGAN, real_score should be positive and fake_score should be negative
|
||||
loss_dict['real_score'] = real_d_pred.detach().mean()
|
||||
loss_dict['fake_score'] = fake_d_pred.detach().mean()
|
||||
l_d.backward()
|
||||
|
||||
# regularization loss
|
||||
if current_iter % self.net_d_reg_every == 0:
|
||||
self.gt.requires_grad = True
|
||||
real_pred = self.net_d(self.gt)
|
||||
@@ -434,8 +438,9 @@ class GFPGANModel(BaseModel):
|
||||
|
||||
self.optimizer_d.step()
|
||||
|
||||
# optimize facial component discriminators
|
||||
if self.use_facial_disc:
|
||||
# lefe eye
|
||||
# left eye
|
||||
fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach())
|
||||
real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt)
|
||||
l_d_left_eye = self.cri_component(
|
||||
@@ -485,22 +490,32 @@ class GFPGANModel(BaseModel):
|
||||
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
||||
dataset_name = dataloader.dataset.opt['name']
|
||||
with_metrics = self.opt['val'].get('metrics') is not None
|
||||
use_pbar = self.opt['val'].get('pbar', False)
|
||||
|
||||
if with_metrics:
|
||||
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
|
||||
pbar = tqdm(total=len(dataloader), unit='image')
|
||||
if not hasattr(self, 'metric_results'): # only execute in the first run
|
||||
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
|
||||
# initialize the best metric results for each dataset_name (supporting multiple validation datasets)
|
||||
self._initialize_best_metric_results(dataset_name)
|
||||
# zero self.metric_results
|
||||
self.metric_results = {metric: 0 for metric in self.metric_results}
|
||||
|
||||
metric_data = dict()
|
||||
if use_pbar:
|
||||
pbar = tqdm(total=len(dataloader), unit='image')
|
||||
|
||||
for idx, val_data in enumerate(dataloader):
|
||||
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
|
||||
self.feed_data(val_data)
|
||||
self.test()
|
||||
|
||||
visuals = self.get_current_visuals()
|
||||
sr_img = tensor2img([visuals['sr']], min_max=(-1, 1))
|
||||
gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
|
||||
|
||||
if 'gt' in visuals:
|
||||
gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
|
||||
sr_img = tensor2img(self.output.detach().cpu(), min_max=(-1, 1))
|
||||
metric_data['img'] = sr_img
|
||||
if hasattr(self, 'gt'):
|
||||
gt_img = tensor2img(self.gt.detach().cpu(), min_max=(-1, 1))
|
||||
metric_data['img2'] = gt_img
|
||||
del self.gt
|
||||
|
||||
# tentative for out of GPU memory
|
||||
del self.lq
|
||||
del self.output
|
||||
@@ -522,35 +537,38 @@ class GFPGANModel(BaseModel):
|
||||
if with_metrics:
|
||||
# calculate metrics
|
||||
for name, opt_ in self.opt['val']['metrics'].items():
|
||||
metric_data = dict(img1=sr_img, img2=gt_img)
|
||||
self.metric_results[name] += calculate_metric(metric_data, opt_)
|
||||
pbar.update(1)
|
||||
pbar.set_description(f'Test {img_name}')
|
||||
pbar.close()
|
||||
if use_pbar:
|
||||
pbar.update(1)
|
||||
pbar.set_description(f'Test {img_name}')
|
||||
if use_pbar:
|
||||
pbar.close()
|
||||
|
||||
if with_metrics:
|
||||
for metric in self.metric_results.keys():
|
||||
self.metric_results[metric] /= (idx + 1)
|
||||
# update the best metric result
|
||||
self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
|
||||
|
||||
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
|
||||
|
||||
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
|
||||
log_str = f'Validation {dataset_name}\n'
|
||||
for metric, value in self.metric_results.items():
|
||||
log_str += f'\t # {metric}: {value:.4f}\n'
|
||||
log_str += f'\t # {metric}: {value:.4f}'
|
||||
if hasattr(self, 'best_metric_results'):
|
||||
log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
|
||||
f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
|
||||
log_str += '\n'
|
||||
|
||||
logger = get_root_logger()
|
||||
logger.info(log_str)
|
||||
if tb_logger:
|
||||
for metric, value in self.metric_results.items():
|
||||
tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
|
||||
|
||||
def get_current_visuals(self):
|
||||
out_dict = OrderedDict()
|
||||
out_dict['gt'] = self.gt.detach().cpu()
|
||||
out_dict['sr'] = self.output.detach().cpu()
|
||||
return out_dict
|
||||
tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
|
||||
|
||||
def save(self, epoch, current_iter):
|
||||
# save net_g and net_d
|
||||
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
|
||||
self.save_network(self.net_d, 'net_d', current_iter)
|
||||
# save component discriminators
|
||||
@@ -558,4 +576,5 @@ class GFPGANModel(BaseModel):
|
||||
self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter)
|
||||
self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter)
|
||||
self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter)
|
||||
# save training state
|
||||
self.save_training_state(epoch, current_iter)
|
||||
|
||||
@@ -14,6 +14,20 @@ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
class GFPGANer():
|
||||
"""Helper for restoration with GFPGAN.
|
||||
|
||||
It will detect and crop faces, and then resize the faces to 512x512.
|
||||
GFPGAN is used to restore the resized faces.
The background is upsampled with the bg_upsampler.
Finally, the faces will be pasted back to the upsampled background image.
|
||||
|
||||
Args:
|
||||
model_path (str): The path to the GFPGAN model. It can also be a url (the model will be downloaded automatically).
|
||||
upscale (float): The upscale of the final output. Default: 2.
|
||||
arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
|
||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
||||
bg_upsampler (nn.Module): The upsampler for the background. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None):
|
||||
self.upscale = upscale
|
||||
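A hypothetical end-to-end call of this helper (paths and file names are placeholders; the API follows the signatures in this file and the inference script below):

import cv2
from gfpgan import GFPGANer

restorer = GFPGANer(
    model_path='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth',  # default of the inference script
    upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None)
img = cv2.imread('inputs/whole_imgs/00.jpg', cv2.IMREAD_COLOR)  # placeholder input image
cropped_faces, restored_faces, restored_img = restorer.enhance(
    img, has_aligned=False, only_center_face=False, paste_back=True)
cv2.imwrite('results/00_restored.png', restored_img)  # placeholder output path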
@@ -70,7 +84,7 @@ class GFPGANer():
|
||||
def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
|
||||
self.face_helper.clean_all()
|
||||
|
||||
if has_aligned:
|
||||
if has_aligned: # the inputs are already aligned
|
||||
img = cv2.resize(img, (512, 512))
|
||||
self.face_helper.cropped_faces = [img]
|
||||
else:
|
||||
@@ -78,6 +92,7 @@ class GFPGANer():
|
||||
# get face landmarks for each face
|
||||
self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
|
||||
# eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
|
||||
# TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
|
||||
# align and warp each face
|
||||
self.face_helper.align_warp_face()
|
||||
|
||||
@@ -100,9 +115,9 @@ class GFPGANer():
|
||||
self.face_helper.add_restored_face(restored_face)
|
||||
|
||||
if not has_aligned and paste_back:
|
||||
|
||||
# upsample the background
|
||||
if self.bg_upsampler is not None:
|
||||
# Now only support RealESRGAN
|
||||
# Now only RealESRGAN is supported for upsampling the background
|
||||
bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
|
||||
else:
|
||||
bg_img = None
|
||||
@@ -116,7 +131,9 @@ class GFPGANer():
|
||||
|
||||
|
||||
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
|
||||
"""Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
|
||||
"""Load file form http url, will download models if necessary.
|
||||
|
||||
Ref: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
|
||||
"""
|
||||
if model_dir is None:
|
||||
hub_dir = get_dir()
|
||||
|
||||
@@ -10,20 +10,22 @@ from gfpgan import GFPGANer
|
||||
|
||||
|
||||
def main():
|
||||
"""Inference demo for GFPGAN.
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('--upscale', type=int, default=2)
|
||||
parser.add_argument('--arch', type=str, default='clean')
|
||||
parser.add_argument('--channel', type=int, default=2)
|
||||
parser.add_argument('--upscale', type=int, default=2, help='The final upsampling scale of the image')
|
||||
parser.add_argument('--arch', type=str, default='clean', help='The GFPGAN architecture. Option: clean | original')
|
||||
parser.add_argument('--channel', type=int, default=2, help='Channel multiplier for large networks of StyleGAN2')
|
||||
parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth')
|
||||
parser.add_argument('--bg_upsampler', type=str, default='realesrgan')
|
||||
parser.add_argument('--bg_tile', type=int, default=400)
|
||||
parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
|
||||
parser.add_argument('--bg_upsampler', type=str, default='realesrgan', help='background upsampler')
|
||||
parser.add_argument(
|
||||
'--bg_tile', type=int, default=400, help='Tile size for background upsampler, 0 for no tile during testing')
|
||||
parser.add_argument('--test_path', type=str, default='inputs/whole_imgs', help='Input folder')
|
||||
parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
|
||||
parser.add_argument('--only_center_face', action='store_true')
|
||||
parser.add_argument('--aligned', action='store_true')
|
||||
parser.add_argument('--paste_back', action='store_false')
|
||||
parser.add_argument('--save_root', type=str, default='results')
|
||||
parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face')
|
||||
parser.add_argument('--aligned', action='store_true', help='Input are aligned faces')
|
||||
parser.add_argument('--paste_back', action='store_false', help='Paste the restored faces back to images')
|
||||
parser.add_argument('--save_root', type=str, default='results', help='Path to save root')
|
||||
parser.add_argument(
|
||||
'--ext',
|
||||
type=str,
|
||||
@@ -70,6 +72,7 @@ def main():
|
||||
basename, ext = os.path.splitext(img_name)
|
||||
input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
|
||||
|
||||
# restore faces and background if necessary
|
||||
cropped_faces, restored_faces, restored_img = restorer.enhance(
|
||||
input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=args.paste_back)
|
||||
|
||||
@@ -85,7 +88,7 @@ def main():
|
||||
save_face_name = f'{basename}_{idx:02d}.png'
|
||||
save_restore_path = os.path.join(args.save_root, 'restored_faces', save_face_name)
|
||||
imwrite(restored_face, save_restore_path)
|
||||
# save cmp image
|
||||
# save comparison image
|
||||
cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
|
||||
imwrite(cmp_img, os.path.join(args.save_root, 'cmp', f'{basename}_{idx:02d}.png'))
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# general settings
|
||||
name: train_GFPGANv1_512
|
||||
model_type: GFPGANModel
|
||||
num_gpu: 4
|
||||
num_gpu: auto # officially, we use 4 GPUs
|
||||
manual_seed: 0
|
||||
|
||||
# dataset and data loader settings
|
||||
@@ -194,7 +194,7 @@ val:
|
||||
save_img: true
|
||||
|
||||
metrics:
|
||||
psnr: # metric name, can be arbitrary
|
||||
psnr: # metric name
|
||||
type: calculate_psnr
|
||||
crop_border: 0
|
||||
test_y_channel: false
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# general settings
|
||||
name: train_GFPGANv1_512_simple
|
||||
model_type: GFPGANModel
|
||||
num_gpu: 4
|
||||
num_gpu: auto # officially, we use 4 GPUs
|
||||
manual_seed: 0
|
||||
|
||||
# dataset and data loader settings
|
||||
@@ -40,10 +40,6 @@ datasets:
|
||||
# gray_prob: 0.01
|
||||
# gt_gray: True
|
||||
|
||||
# crop_components: false
|
||||
# component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
|
||||
# eye_enlarge_ratio: 1.4
|
||||
|
||||
# data loader
|
||||
use_shuffle: true
|
||||
num_worker_per_gpu: 6
|
||||
@@ -86,20 +82,6 @@ network_d:
|
||||
channel_multiplier: 1
|
||||
resample_kernel: [1, 3, 3, 1]
|
||||
|
||||
# network_d_left_eye:
|
||||
# type: FacialComponentDiscriminator
|
||||
|
||||
# network_d_right_eye:
|
||||
# type: FacialComponentDiscriminator
|
||||
|
||||
# network_d_mouth:
|
||||
# type: FacialComponentDiscriminator
|
||||
|
||||
network_identity:
|
||||
type: ResNetArcFace
|
||||
block: IRBlock
|
||||
layers: [2, 2, 2, 2]
|
||||
use_se: False
|
||||
|
||||
# path
|
||||
path:
|
||||
@@ -107,13 +89,7 @@ path:
|
||||
param_key_g: params_ema
|
||||
strict_load_g: ~
|
||||
pretrain_network_d: ~
|
||||
# pretrain_network_d_left_eye: ~
|
||||
# pretrain_network_d_right_eye: ~
|
||||
# pretrain_network_d_mouth: ~
|
||||
pretrain_network_identity: experiments/pretrained_models/arcface_resnet18.pth
|
||||
# resume
|
||||
resume_state: ~
|
||||
ignore_resume_networks: ['network_identity']
|
||||
|
||||
# training settings
|
||||
train:
|
||||
@@ -173,16 +149,6 @@ train:
|
||||
loss_weight: !!float 1e-1
|
||||
# r1 regularization for discriminator
|
||||
r1_reg_weight: 10
|
||||
# facial component loss
|
||||
# gan_component_opt:
|
||||
# type: GANLoss
|
||||
# gan_type: vanilla
|
||||
# real_label_val: 1.0
|
||||
# fake_label_val: 0.0
|
||||
# loss_weight: !!float 1
|
||||
# comp_style_weight: 200
|
||||
# identity loss
|
||||
identity_weight: 10
|
||||
|
||||
net_d_iters: 1
|
||||
net_d_init_iters: 0
|
||||
@@ -194,7 +160,7 @@ val:
|
||||
save_img: true
|
||||
|
||||
metrics:
|
||||
psnr: # metric name, can be arbitrary
|
||||
psnr: # metric name
|
||||
type: calculate_psnr
|
||||
crop_border: 0
|
||||
test_y_channel: false
|
||||
|
||||
@@ -1,24 +1,31 @@
|
||||
import cv2
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
from basicsr.utils import FileClient, imfrombytes
|
||||
from collections import OrderedDict
|
||||
|
||||
# ---------------------------- This script is used to parse facial landmarks ------------------------------------- #
|
||||
# Configurations
|
||||
save_img = False
|
||||
scale = 0.5 # 0.5 for official FFHQ (512x512), 1 for others
|
||||
enlarge_ratio = 1.4 # only for eyes
|
||||
json_path = 'ffhq-dataset-v2.json'
|
||||
face_path = 'datasets/ffhq/ffhq_512.lmdb'
|
||||
save_path = './FFHQ_eye_mouth_landmarks_512.pth'
|
||||
|
||||
print('Load JSON metadata...')
|
||||
# use the json file in FFHQ dataset
|
||||
with open('ffhq-dataset-v2.json', 'rb') as f:
|
||||
# use the official json file in FFHQ dataset
|
||||
with open(json_path, 'rb') as f:
|
||||
json_data = json.load(f, object_pairs_hook=OrderedDict)
|
||||
|
||||
print('Open LMDB file...')
|
||||
# read ffhq images
|
||||
file_client = FileClient('lmdb', db_paths='datasets/ffhq/ffhq_512.lmdb')
|
||||
with open('datasets/ffhq/ffhq_512.lmdb/meta_info.txt') as fin:
|
||||
file_client = FileClient('lmdb', db_paths=face_path)
|
||||
with open(os.path.join(face_path, 'meta_info.txt')) as fin:
|
||||
paths = [line.split('.')[0] for line in fin]
|
||||
|
||||
save_img = False
|
||||
scale = 0.5 # 0.5 for official FFHQ (512x512), 1 for others
|
||||
enlarge_ratio = 1.4 # only for eyes
|
||||
save_dict = {}
|
||||
|
||||
for item_idx, item in enumerate(json_data.values()):
|
||||
@@ -34,6 +41,7 @@ for item_idx, item in enumerate(json_data.values()):
|
||||
img_bytes = file_client.get(paths[item_idx])
|
||||
img = imfrombytes(img_bytes, float32=True)
|
||||
|
||||
# get landmarks for each component
|
||||
map_left_eye = list(range(36, 42))
|
||||
map_right_eye = list(range(42, 48))
|
||||
map_mouth = list(range(48, 68))
|
||||
@@ -74,4 +82,4 @@ for item_idx, item in enumerate(json_data.values()):
|
||||
save_dict[f'{item_idx:08d}'] = item_dict
|
||||
|
||||
print('Save...')
|
||||
torch.save(save_dict, './FFHQ_eye_mouth_landmarks_512.pth')
|
||||
torch.save(save_dict, save_path)
|
||||