28 Commits

Author SHA1 Message Date
Xintao
7552a7791c Delete .github/workflows/no-response.yml 2024-04-03 00:39:30 +08:00
Xintao
2eac203389 v1.3.8 2022-09-16 19:33:26 +08:00
Xintao
2f46d95254 fix pylint 2022-09-16 19:32:42 +08:00
Xintao
bc5a5deb95 remove codeformer 2022-09-16 19:31:20 +08:00
Xintao
3fd33abc47 update cog predict 2022-09-12 23:24:08 +08:00
Xintao
d226e86f6c v1.3.7 2022-09-12 22:40:25 +08:00
Xintao
bb2f916764 update 2022-09-12 22:29:46 +08:00
Xintao
fe3beac9dc v1.3.6 2022-09-12 22:17:48 +08:00
Xintao
126c55c68d add restoreformer and codeformer inference codes 2022-09-12 21:33:06 +08:00
Xintao
8d2447a2d9 update cog predict 2022-09-04 23:27:02 +08:00
Xintao
af7569775d v1.3.5 2022-09-04 22:18:25 +08:00
Xintao
c6593e7221 update cog_predict 2022-09-04 20:28:24 +08:00
Xintao
7272e45887 update replicate (#248)
* update util

* update predict

* update predict

* update predict

* update predict

* update predict

* update predict

* update predict

* update predict

* merge replicate update
2022-09-04 20:12:31 +08:00
Xintao
3e27784b1b update replicate related 2022-08-31 17:36:25 +08:00
Xintao
2c420ee565 update readme 2022-08-31 16:33:30 +08:00
Xintao
8e7cf5d723 update readme 2022-08-30 23:02:22 +08:00
Xintao
c541e97f83 update readme 2022-08-30 23:01:28 +08:00
Xintao
86756cba65 update readme 2022-08-30 22:57:22 +08:00
Chenxi
a9a2e3ae15 Add Docker environment & web demo (#67)
* enable cog

* Update README.md

* Update README.md

* refactor

* fix temp input dir bug

Co-authored-by: CJWBW <70536672+CJWBW@users.noreply.github.com>
Co-authored-by: Chenxi <chenxi@Chenxis-MacBook-Pro-2.local>
Co-authored-by: Xintao <wxt1994@126.com>
2022-08-29 17:28:16 +08:00
Xintao
9c3f2d62cb v1.3.4 2022-07-13 10:21:28 +08:00
Xintao
ccd30af837 add release workflow 2022-07-13 10:19:50 +08:00
AJ
7d657f26b6 fix basicsr losses import (#210) 2022-07-13 10:01:06 +08:00
Xintao
c7ccc098a7 update facelib; use seg to paste back 2022-06-07 16:49:26 +08:00
Xintao
bc3f0c4d91 add device to GFPGANer for multiGPU support 2022-05-04 13:23:54 +08:00
Xintao
924ce473ab v1.3.2 2022-02-16 00:32:50 +08:00
Xintao
09a37ae7fd add logo 2022-02-16 00:11:10 +08:00
Xintao
6c544b70e6 v1.3.1 2022-02-14 15:37:45 +08:00
Xintao
47983e1767 add stylegan2_bilinear_arch 2022-02-14 14:28:27 +08:00
17 changed files with 1568 additions and 61 deletions

View File

@@ -1,33 +0,0 @@
name: No Response
# TODO: it seems not to work
# Modified from: https://raw.githubusercontent.com/github/docs/main/.github/workflows/no-response.yaml
# **What it does**: Closes issues that don't have enough information to be actionable.
# **Why we have it**: To remove the need for maintainers to remember to check back on issues periodically
# to see if contributors have responded.
# **Who does it impact**: Everyone that works on docs or docs-internal.
on:
issue_comment:
types: [created]
schedule:
# Schedule for five minutes after the hour every hour
- cron: '5 * * * *'
jobs:
noResponse:
runs-on: ubuntu-latest
steps:
- uses: lee-dohm/no-response@v0.5.0
with:
token: ${{ github.token }}
closeComment: >
This issue has been automatically closed because there has been no response
to our request for more information from the original author. With only the
information that is currently in the issue, we don't have enough information
to take action. Please reach out if you have or find the answers we need so
that we can investigate further.
If you still have questions, please improve your description and re-open it.
Thanks :-)

41
.github/workflows/release.yml vendored Normal file
View File

@@ -0,0 +1,41 @@
name: release
on:
push:
tags:
- '*'
jobs:
build:
permissions: write-all
name: Create Release
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Create Release
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ github.ref }}
release_name: GFPGAN ${{ github.ref }} Release Note
body: |
🚀 See you again 😸
🚀Have a nice day 😸 and happy everyday 😃
🚀 Long time no see ☄️
✨ **Highlights**
✅ [Features] Support ...
🐛 **Bug Fixes**
🌴 **Improvements**
📢📢📢
<p align="center">
<img src="https://raw.githubusercontent.com/TencentARC/GFPGAN/master/assets/gfpgan_logo.png" height=150>
</p>
draft: true
prerelease: false

View File

@@ -1,4 +1,13 @@
# GFPGAN (CVPR 2021)
<p align="center">
<img src="assets/gfpgan_logo.png" height=130>
</p>
## <div align="center"><b><a href="README.md">English</a> | <a href="README_CN.md">简体中文</a></b></div>
<div align="center">
<!-- <a href="https://twitter.com/_Xintao_" style="text-decoration:none;">
<img src="https://user-images.githubusercontent.com/17445847/187162058-c764ced6-952f-404b-ac85-ba95cce18e7b.png" width="4%" alt="" />
</a> -->
[![download](https://img.shields.io/github/downloads/TencentARC/GFPGAN/total.svg)](https://github.com/TencentARC/GFPGAN/releases)
[![PyPI](https://img.shields.io/pypi/v/gfpgan)](https://pypi.org/project/gfpgan/)
@@ -7,12 +16,15 @@
[![LICENSE](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/TencentARC/GFPGAN/blob/master/LICENSE)
[![python lint](https://github.com/TencentARC/GFPGAN/actions/workflows/pylint.yml/badge.svg)](https://github.com/TencentARC/GFPGAN/blob/master/.github/workflows/pylint.yml)
[![Publish-pip](https://github.com/TencentARC/GFPGAN/actions/workflows/publish-pip.yml/badge.svg)](https://github.com/TencentARC/GFPGAN/blob/master/.github/workflows/publish-pip.yml)
</div>
1. :boom: **Updated** online demo: [![Replicate](https://img.shields.io/static/v1?label=Demo&message=Replicate&color=blue)](https://replicate.com/tencentarc/gfpgan). Here is the [backup](https://replicate.com/xinntao/gfpgan).
1. :boom: **Updated** online demo: [![Huggingface Gradio](https://img.shields.io/static/v1?label=Demo&message=Huggingface%20Gradio&color=orange)](https://huggingface.co/spaces/Xintao/GFPGAN)
1. [Colab Demo](https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo) for GFPGAN <a href="https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>; (Another [Colab Demo](https://colab.research.google.com/drive/1Oa1WwKB4M4l1GmR7CtswDVgOCOeSLChA?usp=sharing) for the original paper model)
2. Online demo: [Huggingface](https://huggingface.co/spaces/akhaliq/GFPGAN) (return only the cropped face)
3. Online demo: [Replicate.ai](https://replicate.com/xinntao/gfpgan) (may need to sign in, return the whole image)
<!-- 3. Online demo: [Replicate.ai](https://replicate.com/xinntao/gfpgan) (may need to sign in, return the whole image)
4. Online demo: [Baseten.co](https://app.baseten.co/applications/Q04Lz0d/operator_views/8qZG6Bg) (backed by GPU, returns the whole image)
5. We provide a *clean* version of GFPGAN, which can run without CUDA extensions. So that it can run in **Windows** or on **CPU mode**.
5. We provide a *clean* version of GFPGAN, which can run without CUDA extensions. So that it can run in **Windows** or on **CPU mode**. -->
> :rocket: **Thanks for your interest in our work. You may also want to check our new updates on the *tiny models* for *anime images and videos* in [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN/blob/master/docs/anime_video_model.md)** :blush:
@@ -23,7 +35,9 @@ It leverages rich and diverse priors encapsulated in a pretrained face GAN (*e.g
:triangular_flag_on_post: **Updates**
- :fire::fire::white_check_mark: Add **[V1.3 model](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth)**, which produces **more natural** restoration results, and better results on *very low-quality* / *high-quality* inputs. See more in [Model zoo](#european_castle-model-zoo), [Comparisons.md](Comparisons.md)
- :white_check_mark: Add [RestoreFormer](https://github.com/wzhouxiff/RestoreFormer) inference codes.
- :white_check_mark: Add [V1.4 model](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth), which produces slightly more details and better identity than V1.3.
- :white_check_mark: Add **[V1.3 model](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth)**, which produces **more natural** restoration results, and better results on *very low-quality* / *high-quality* inputs. See more in [Model zoo](#european_castle-model-zoo), [Comparisons.md](Comparisons.md)
- :white_check_mark: Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/GFPGAN).
- :white_check_mark: Support enhancing non-face regions (background) with [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN).
- :white_check_mark: We provide a *clean* version of GFPGAN, which does not require CUDA extensions.

7
README_CN.md Normal file
View File

@@ -0,0 +1,7 @@
<p align="center">
<img src="assets/gfpgan_logo.png" height=130>
</p>
## <div align="center"><b><a href="README.md">English</a> | <a href="README_CN.md">简体中文</a></b></div>
还未完工,欢迎贡献!

View File

@@ -1 +1 @@
1.3.0
1.3.8

BIN
assets/gfpgan_logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 50 KiB

22
cog.yaml Normal file
View File

@@ -0,0 +1,22 @@
# This file is used for constructing replicate env
image: "r8.im/tencentarc/gfpgan"
build:
gpu: true
python_version: "3.8"
system_packages:
- "libgl1-mesa-glx"
- "libglib2.0-0"
python_packages:
- "torch==1.7.1"
- "torchvision==0.8.2"
- "numpy==1.21.1"
- "lmdb==1.2.1"
- "opencv-python==4.5.3.56"
- "PyYAML==5.4.1"
- "tqdm==4.62.2"
- "yapf==0.31.0"
- "basicsr==1.4.2"
- "facexlib==0.2.5"
predict: "cog_predict.py:Predictor"

161
cog_predict.py Normal file
View File

@@ -0,0 +1,161 @@
# flake8: noqa
# This file is used for deploying replicate models
# running: cog predict -i img=@inputs/whole_imgs/10045.png -i version='v1.4' -i scale=2
# push: cog push r8.im/tencentarc/gfpgan
# push (backup): cog push r8.im/xinntao/gfpgan
import os
os.system('python setup.py develop')
os.system('pip install realesrgan')
import cv2
import shutil
import tempfile
import torch
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from gfpgan import GFPGANer
try:
from cog import BasePredictor, Input, Path
from realesrgan.utils import RealESRGANer
except Exception:
print('please install cog and realesrgan package')
class Predictor(BasePredictor):
def setup(self):
os.makedirs('output', exist_ok=True)
# download weights
if not os.path.exists('gfpgan/weights/realesr-general-x4v3.pth'):
os.system(
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ./gfpgan/weights'
)
if not os.path.exists('gfpgan/weights/GFPGANv1.2.pth'):
os.system(
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth -P ./gfpgan/weights')
if not os.path.exists('gfpgan/weights/GFPGANv1.3.pth'):
os.system(
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P ./gfpgan/weights')
if not os.path.exists('gfpgan/weights/GFPGANv1.4.pth'):
os.system(
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ./gfpgan/weights')
if not os.path.exists('gfpgan/weights/RestoreFormer.pth'):
os.system(
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth -P ./gfpgan/weights'
)
# background enhancer with RealESRGAN
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
model_path = 'gfpgan/weights/realesr-general-x4v3.pth'
half = True if torch.cuda.is_available() else False
self.upsampler = RealESRGANer(
scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
# Use GFPGAN for face enhancement
self.face_enhancer = GFPGANer(
model_path='gfpgan/weights/GFPGANv1.4.pth',
upscale=2,
arch='clean',
channel_multiplier=2,
bg_upsampler=self.upsampler)
self.current_version = 'v1.4'
def predict(
self,
img: Path = Input(description='Input'),
version: str = Input(
description='GFPGAN version. v1.3: better quality. v1.4: more details and better identity.',
choices=['v1.2', 'v1.3', 'v1.4', 'RestoreFormer'],
default='v1.4'),
scale: float = Input(description='Rescaling factor', default=2),
) -> Path:
weight = 0.5
print(img, version, scale, weight)
try:
extension = os.path.splitext(os.path.basename(str(img)))[1]
img = cv2.imread(str(img), cv2.IMREAD_UNCHANGED)
if len(img.shape) == 3 and img.shape[2] == 4:
img_mode = 'RGBA'
elif len(img.shape) == 2:
img_mode = None
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
else:
img_mode = None
h, w = img.shape[0:2]
if h < 300:
img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
if self.current_version != version:
if version == 'v1.2':
self.face_enhancer = GFPGANer(
model_path='gfpgan/weights/GFPGANv1.2.pth',
upscale=2,
arch='clean',
channel_multiplier=2,
bg_upsampler=self.upsampler)
self.current_version = 'v1.2'
elif version == 'v1.3':
self.face_enhancer = GFPGANer(
model_path='gfpgan/weights/GFPGANv1.3.pth',
upscale=2,
arch='clean',
channel_multiplier=2,
bg_upsampler=self.upsampler)
self.current_version = 'v1.3'
elif version == 'v1.4':
self.face_enhancer = GFPGANer(
model_path='gfpgan/weights/GFPGANv1.4.pth',
upscale=2,
arch='clean',
channel_multiplier=2,
bg_upsampler=self.upsampler)
self.current_version = 'v1.4'
elif version == 'RestoreFormer':
self.face_enhancer = GFPGANer(
model_path='gfpgan/weights/RestoreFormer.pth',
upscale=2,
arch='RestoreFormer',
channel_multiplier=2,
bg_upsampler=self.upsampler)
try:
_, _, output = self.face_enhancer.enhance(
img, has_aligned=False, only_center_face=False, paste_back=True, weight=weight)
except RuntimeError as error:
print('Error', error)
try:
if scale != 2:
interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
h, w = img.shape[0:2]
output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
except Exception as error:
print('wrong scale input.', error)
if img_mode == 'RGBA': # RGBA images should be saved in png format
extension = 'png'
# save_path = f'output/out.{extension}'
# cv2.imwrite(save_path, output)
out_path = Path(tempfile.mkdtemp()) / f'out.{extension}'
cv2.imwrite(str(out_path), output)
except Exception as error:
print('global exception: ', error)
finally:
clean_folder('output')
return out_path
def clean_folder(folder):
for filename in os.listdir(folder):
file_path = os.path.join(folder, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
print(f'Failed to delete {file_path}. Reason: {e}')

View File

@@ -1,12 +1,12 @@
import math
import random
import torch
from basicsr.archs.stylegan2_bilinear_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
StyleGAN2GeneratorBilinear)
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn
from .gfpganv1_arch import ResUpBlock
from .stylegan2_bilinear_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
StyleGAN2GeneratorBilinear)
class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear):

View File

@@ -350,7 +350,7 @@ class GFPGANv1(nn.Module):
ScaledLeakyReLU(0.2),
EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs):
"""Forward function for GFPGANv1.
Args:
@@ -416,7 +416,7 @@ class FacialComponentDiscriminator(nn.Module):
self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
def forward(self, x, return_feats=False):
def forward(self, x, return_feats=False, **kwargs):
"""Forward function for FacialComponentDiscriminator.
Args:

View File

@@ -274,7 +274,7 @@ class GFPGANv1Clean(nn.Module):
nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs):
"""Forward function for GFPGANv1Clean.
Args:

View File

@@ -0,0 +1,658 @@
"""Modified from https://github.com/wzhouxiff/RestoreFormer
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class VectorQuantizer(nn.Module):
"""
see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
____________________________________________
Discretization bottleneck part of the VQ-VAE.
Inputs:
- n_e : number of embeddings
- e_dim : dimension of embedding
- beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
_____________________________________________
"""
def __init__(self, n_e, e_dim, beta):
super(VectorQuantizer, self).__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
def forward(self, z):
"""
Inputs the output of the encoder network z and maps it to a discrete
one-hot vector that is the index of the closest embedding vector e_j
z (continuous) -> z_q (discrete)
z.shape = (batch, channel, height, width)
quantization pipeline:
1. get encoder input (B,C,H,W)
2. flatten input to (B*H*W,C)
"""
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
torch.matmul(z_flattened, self.embedding.weight.t())
# could possible replace this here
# #\start...
# find closest encodings
min_value, min_encoding_indices = torch.min(d, dim=1)
min_encoding_indices = min_encoding_indices.unsqueeze(1)
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.n_e).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)
# dtype min encodings: torch.float32
# min_encodings shape: torch.Size([2048, 512])
# min_encoding_indices.shape: torch.Size([2048, 1])
# get quantized latent vectors
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
# .........\end
# with:
# .........\start
# min_encoding_indices = torch.argmin(d, dim=1)
# z_q = self.embedding(min_encoding_indices)
# ......\end......... (TODO)
# compute loss for embedding
loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)
# preserve gradients
z_q = z + (z_q - z).detach()
# perplexity
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q, loss, (perplexity, min_encodings, min_encoding_indices, d)
def get_codebook_entry(self, indices, shape):
# shape specifying (batch, height, width, channel)
# TODO: check for more easy handling with nn.Embedding
min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
min_encodings.scatter_(1, indices[:, None], 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
if shape is not None:
z_q = z_q.view(shape)
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q
# pytorch_diffusion + derived encoder decoder
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode='nearest')
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x):
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class MultiHeadAttnBlock(nn.Module):
def __init__(self, in_channels, head_size=1):
super().__init__()
self.in_channels = in_channels
self.head_size = head_size
self.att_size = in_channels // head_size
assert (in_channels % head_size == 0), 'The size of head should be divided by the number of channels.'
self.norm1 = Normalize(in_channels)
self.norm2 = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.num = 0
def forward(self, x, y=None):
h_ = x
h_ = self.norm1(h_)
if y is None:
y = h_
else:
y = self.norm2(y)
q = self.q(y)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, self.head_size, self.att_size, h * w)
q = q.permute(0, 3, 1, 2) # b, hw, head, att
k = k.reshape(b, self.head_size, self.att_size, h * w)
k = k.permute(0, 3, 1, 2)
v = v.reshape(b, self.head_size, self.att_size, h * w)
v = v.permute(0, 3, 1, 2)
q = q.transpose(1, 2)
v = v.transpose(1, 2)
k = k.transpose(1, 2).transpose(2, 3)
scale = int(self.att_size)**(-0.5)
q.mul_(scale)
w_ = torch.matmul(q, k)
w_ = F.softmax(w_, dim=3)
w_ = w_.matmul(v)
w_ = w_.transpose(1, 2).contiguous() # [b, h*w, head, att]
w_ = w_.view(b, h, w, -1)
w_ = w_.permute(0, 3, 1, 2)
w_ = self.proj_out(w_)
return x + w_
class MultiHeadEncoder(nn.Module):
def __init__(self,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks=2,
attn_resolutions=(16, ),
dropout=0.0,
resamp_with_conv=True,
in_channels=3,
resolution=512,
z_channels=256,
double_z=True,
enable_mid=True,
head_size=1,
**ignore_kwargs):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.enable_mid = enable_mid
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
in_ch_mult = (1, ) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(MultiHeadAttnBlock(block_in, head_size))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
if self.enable_mid:
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
self.mid.block_2 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
hs = {}
# timestep embedding
temb = None
# downsampling
h = self.conv_in(x)
hs['in'] = h
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h, temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1:
# hs.append(h)
hs['block_' + str(i_level)] = h
h = self.down[i_level].downsample(h)
# middle
# h = hs[-1]
if self.enable_mid:
h = self.mid.block_1(h, temb)
hs['block_' + str(i_level) + '_atten'] = h
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
hs['mid_atten'] = h
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
# hs.append(h)
hs['out'] = h
return hs
class MultiHeadDecoder(nn.Module):
def __init__(self,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks=2,
attn_resolutions=(16, ),
dropout=0.0,
resamp_with_conv=True,
in_channels=3,
resolution=512,
z_channels=256,
give_pre_end=False,
enable_mid=True,
head_size=1,
**ignorekwargs):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.enable_mid = enable_mid
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2**(self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print('Working with z of shape {} = {} dimensions.'.format(self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
# middle
if self.enable_mid:
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
self.mid.block_2 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(MultiHeadAttnBlock(block_in, head_size))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
if self.enable_mid:
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class MultiHeadDecoderTransformer(nn.Module):
def __init__(self,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks=2,
attn_resolutions=(16, ),
dropout=0.0,
resamp_with_conv=True,
in_channels=3,
resolution=512,
z_channels=256,
give_pre_end=False,
enable_mid=True,
head_size=1,
**ignorekwargs):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.enable_mid = enable_mid
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2**(self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print('Working with z of shape {} = {} dimensions.'.format(self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
# middle
if self.enable_mid:
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
self.mid.block_2 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(MultiHeadAttnBlock(block_in, head_size))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z, hs):
# assert z.shape[1:] == self.z_shape[1:]
# self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
if self.enable_mid:
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h, hs['mid_atten'])
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, hs['block_' + str(i_level) + '_atten'])
# hfeature = h.clone()
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class RestoreFormer(nn.Module):
def __init__(self,
n_embed=1024,
embed_dim=256,
ch=64,
out_ch=3,
ch_mult=(1, 2, 2, 4, 4, 8),
num_res_blocks=2,
attn_resolutions=(16, ),
dropout=0.0,
in_channels=3,
resolution=512,
z_channels=256,
double_z=False,
enable_mid=True,
fix_decoder=False,
fix_codebook=True,
fix_encoder=False,
head_size=8):
super(RestoreFormer, self).__init__()
self.encoder = MultiHeadEncoder(
ch=ch,
out_ch=out_ch,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions,
dropout=dropout,
in_channels=in_channels,
resolution=resolution,
z_channels=z_channels,
double_z=double_z,
enable_mid=enable_mid,
head_size=head_size)
self.decoder = MultiHeadDecoderTransformer(
ch=ch,
out_ch=out_ch,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions,
dropout=dropout,
in_channels=in_channels,
resolution=resolution,
z_channels=z_channels,
enable_mid=enable_mid,
head_size=head_size)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
if fix_decoder:
for _, param in self.decoder.named_parameters():
param.requires_grad = False
for _, param in self.post_quant_conv.named_parameters():
param.requires_grad = False
for _, param in self.quantize.named_parameters():
param.requires_grad = False
elif fix_codebook:
for _, param in self.quantize.named_parameters():
param.requires_grad = False
if fix_encoder:
for _, param in self.encoder.named_parameters():
param.requires_grad = False
def encode(self, x):
hs = self.encoder(x)
h = self.quant_conv(hs['out'])
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info, hs
def decode(self, quant, hs):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant, hs)
return dec
def forward(self, input, **kwargs):
quant, diff, info, hs = self.encode(input)
dec = self.decode(quant, hs)
return dec, None

View File

@@ -0,0 +1,613 @@
import math
import random
import torch
from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn
from torch.nn import functional as F
class NormStyleCode(nn.Module):
def forward(self, x):
"""Normalize the style codes.
Args:
x (Tensor): Style codes with shape (b, c).
Returns:
Tensor: Normalized tensor.
"""
return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
class EqualLinear(nn.Module):
"""Equalized Linear as StyleGAN2.
Args:
in_channels (int): Size of each sample.
out_channels (int): Size of each output sample.
bias (bool): If set to ``False``, the layer will not learn an additive
bias. Default: ``True``.
bias_init_val (float): Bias initialized value. Default: 0.
lr_mul (float): Learning rate multiplier. Default: 1.
activation (None | str): The activation after ``linear`` operation.
Supported: 'fused_lrelu', None. Default: None.
"""
def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
super(EqualLinear, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.lr_mul = lr_mul
self.activation = activation
if self.activation not in ['fused_lrelu', None]:
raise ValueError(f'Wrong activation value in EqualLinear: {activation}'
"Supported ones are: ['fused_lrelu', None].")
self.scale = (1 / math.sqrt(in_channels)) * lr_mul
self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
if bias:
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
else:
self.register_parameter('bias', None)
def forward(self, x):
if self.bias is None:
bias = None
else:
bias = self.bias * self.lr_mul
if self.activation == 'fused_lrelu':
out = F.linear(x, self.weight * self.scale)
out = fused_leaky_relu(out, bias)
else:
out = F.linear(x, self.weight * self.scale, bias=bias)
return out
def __repr__(self):
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
f'out_channels={self.out_channels}, bias={self.bias is not None})')
class ModulatedConv2d(nn.Module):
"""Modulated Conv2d used in StyleGAN2.
There is no bias in ModulatedConv2d.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Size of the convolving kernel.
num_style_feat (int): Channel number of style features.
demodulate (bool): Whether to demodulate in the conv layer.
Default: True.
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
Default: None.
eps (float): A value added to the denominator for numerical stability.
Default: 1e-8.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
num_style_feat,
demodulate=True,
sample_mode=None,
eps=1e-8,
interpolation_mode='bilinear'):
super(ModulatedConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.demodulate = demodulate
self.sample_mode = sample_mode
self.eps = eps
self.interpolation_mode = interpolation_mode
if self.interpolation_mode == 'nearest':
self.align_corners = None
else:
self.align_corners = False
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
# modulation inside each modulated conv
self.modulation = EqualLinear(
num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)
self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
self.padding = kernel_size // 2
def forward(self, x, style):
"""Forward function.
Args:
x (Tensor): Tensor with shape (b, c, h, w).
style (Tensor): Tensor with shape (b, num_style_feat).
Returns:
Tensor: Modulated tensor after convolution.
"""
b, c, h, w = x.shape # c = c_in
# weight modulation
style = self.modulation(style).view(b, 1, c, 1, 1)
# self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
if self.demodulate:
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
if self.sample_mode == 'upsample':
x = F.interpolate(x, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
elif self.sample_mode == 'downsample':
x = F.interpolate(x, scale_factor=0.5, mode=self.interpolation_mode, align_corners=self.align_corners)
b, c, h, w = x.shape
x = x.view(1, b * c, h, w)
# weight: (b*c_out, c_in, k, k), groups=b
out = F.conv2d(x, weight, padding=self.padding, groups=b)
out = out.view(b, self.out_channels, *out.shape[2:4])
return out
def __repr__(self):
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
f'out_channels={self.out_channels}, '
f'kernel_size={self.kernel_size}, '
f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
class StyleConv(nn.Module):
"""Style conv.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Size of the convolving kernel.
num_style_feat (int): Channel number of style features.
demodulate (bool): Whether demodulate in the conv layer. Default: True.
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
Default: None.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
num_style_feat,
demodulate=True,
sample_mode=None,
interpolation_mode='bilinear'):
super(StyleConv, self).__init__()
self.modulated_conv = ModulatedConv2d(
in_channels,
out_channels,
kernel_size,
num_style_feat,
demodulate=demodulate,
sample_mode=sample_mode,
interpolation_mode=interpolation_mode)
self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
self.activate = FusedLeakyReLU(out_channels)
def forward(self, x, style, noise=None):
# modulate
out = self.modulated_conv(x, style)
# noise injection
if noise is None:
b, _, h, w = out.shape
noise = out.new_empty(b, 1, h, w).normal_()
out = out + self.weight * noise
# activation (with bias)
out = self.activate(out)
return out
class ToRGB(nn.Module):
"""To RGB from features.
Args:
in_channels (int): Channel number of input.
num_style_feat (int): Channel number of style features.
upsample (bool): Whether to upsample. Default: True.
"""
def __init__(self, in_channels, num_style_feat, upsample=True, interpolation_mode='bilinear'):
super(ToRGB, self).__init__()
self.upsample = upsample
self.interpolation_mode = interpolation_mode
if self.interpolation_mode == 'nearest':
self.align_corners = None
else:
self.align_corners = False
self.modulated_conv = ModulatedConv2d(
in_channels,
3,
kernel_size=1,
num_style_feat=num_style_feat,
demodulate=False,
sample_mode=None,
interpolation_mode=interpolation_mode)
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
def forward(self, x, style, skip=None):
"""Forward function.
Args:
x (Tensor): Feature tensor with shape (b, c, h, w).
style (Tensor): Tensor with shape (b, num_style_feat).
skip (Tensor): Base/skip tensor. Default: None.
Returns:
Tensor: RGB images.
"""
out = self.modulated_conv(x, style)
out = out + self.bias
if skip is not None:
if self.upsample:
skip = F.interpolate(
skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
out = out + skip
return out
class ConstantInput(nn.Module):
"""Constant input.
Args:
num_channel (int): Channel number of constant input.
size (int): Spatial size of constant input.
"""
def __init__(self, num_channel, size):
super(ConstantInput, self).__init__()
self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
def forward(self, batch):
out = self.weight.repeat(batch, 1, 1, 1)
return out
@ARCH_REGISTRY.register()
class StyleGAN2GeneratorBilinear(nn.Module):
"""StyleGAN2 Generator.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
num_mlp (int): Layer number of MLP style layers. Default: 8.
channel_multiplier (int): Channel multiplier for large networks of
StyleGAN2. Default: 2.
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
narrow (float): Narrow ratio for channels. Default: 1.0.
"""
def __init__(self,
out_size,
num_style_feat=512,
num_mlp=8,
channel_multiplier=2,
lr_mlp=0.01,
narrow=1,
interpolation_mode='bilinear'):
super(StyleGAN2GeneratorBilinear, self).__init__()
# Style MLP layers
self.num_style_feat = num_style_feat
style_mlp_layers = [NormStyleCode()]
for i in range(num_mlp):
style_mlp_layers.append(
EqualLinear(
num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
activation='fused_lrelu'))
self.style_mlp = nn.Sequential(*style_mlp_layers)
channels = {
'4': int(512 * narrow),
'8': int(512 * narrow),
'16': int(512 * narrow),
'32': int(512 * narrow),
'64': int(256 * channel_multiplier * narrow),
'128': int(128 * channel_multiplier * narrow),
'256': int(64 * channel_multiplier * narrow),
'512': int(32 * channel_multiplier * narrow),
'1024': int(16 * channel_multiplier * narrow)
}
self.channels = channels
self.constant_input = ConstantInput(channels['4'], size=4)
self.style_conv1 = StyleConv(
channels['4'],
channels['4'],
kernel_size=3,
num_style_feat=num_style_feat,
demodulate=True,
sample_mode=None,
interpolation_mode=interpolation_mode)
self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, interpolation_mode=interpolation_mode)
self.log_size = int(math.log(out_size, 2))
self.num_layers = (self.log_size - 2) * 2 + 1
self.num_latent = self.log_size * 2 - 2
self.style_convs = nn.ModuleList()
self.to_rgbs = nn.ModuleList()
self.noises = nn.Module()
in_channels = channels['4']
# noise
for layer_idx in range(self.num_layers):
resolution = 2**((layer_idx + 5) // 2)
shape = [1, 1, resolution, resolution]
self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
# style convs and to_rgbs
for i in range(3, self.log_size + 1):
out_channels = channels[f'{2**i}']
self.style_convs.append(
StyleConv(
in_channels,
out_channels,
kernel_size=3,
num_style_feat=num_style_feat,
demodulate=True,
sample_mode='upsample',
interpolation_mode=interpolation_mode))
self.style_convs.append(
StyleConv(
out_channels,
out_channels,
kernel_size=3,
num_style_feat=num_style_feat,
demodulate=True,
sample_mode=None,
interpolation_mode=interpolation_mode))
self.to_rgbs.append(
ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode))
in_channels = out_channels
def make_noise(self):
"""Make noise for noise injection."""
device = self.constant_input.weight.device
noises = [torch.randn(1, 1, 4, 4, device=device)]
for i in range(3, self.log_size + 1):
for _ in range(2):
noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
return noises
def get_latent(self, x):
return self.style_mlp(x)
def mean_latent(self, num_latent):
latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
latent = self.style_mlp(latent_in).mean(0, keepdim=True)
return latent
def forward(self,
styles,
input_is_latent=False,
noise=None,
randomize_noise=True,
truncation=1,
truncation_latent=None,
inject_index=None,
return_latents=False):
"""Forward function for StyleGAN2Generator.
Args:
styles (list[Tensor]): Sample codes of styles.
input_is_latent (bool): Whether input is latent style.
Default: False.
noise (Tensor | None): Input noise or None. Default: None.
randomize_noise (bool): Randomize noise, used when 'noise' is
False. Default: True.
truncation (float): TODO. Default: 1.
truncation_latent (Tensor | None): TODO. Default: None.
inject_index (int | None): The injection index for mixing noise.
Default: None.
return_latents (bool): Whether to return style latents.
Default: False.
"""
# style codes -> latents with Style MLP layer
if not input_is_latent:
styles = [self.style_mlp(s) for s in styles]
# noises
if noise is None:
if randomize_noise:
noise = [None] * self.num_layers # for each style conv layer
else: # use the stored noise
noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
# style truncation
if truncation < 1:
style_truncation = []
for style in styles:
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
styles = style_truncation
# get style latent with injection
if len(styles) == 1:
inject_index = self.num_latent
if styles[0].ndim < 3:
# repeat latent code for all the layers
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
else: # used for encoder with different latent code for each layer
latent = styles[0]
elif len(styles) == 2: # mixing noises
if inject_index is None:
inject_index = random.randint(1, self.num_latent - 1)
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
latent = torch.cat([latent1, latent2], 1)
# main generation
out = self.constant_input(latent.shape[0])
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
skip = self.to_rgb1(out, latent[:, 1])
i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
noise[2::2], self.to_rgbs):
out = conv1(out, latent[:, i], noise=noise1)
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
i += 2
image = skip
if return_latents:
return image, latent
else:
return image, None
class ScaledLeakyReLU(nn.Module):
"""Scaled LeakyReLU.
Args:
negative_slope (float): Negative slope. Default: 0.2.
"""
def __init__(self, negative_slope=0.2):
super(ScaledLeakyReLU, self).__init__()
self.negative_slope = negative_slope
def forward(self, x):
out = F.leaky_relu(x, negative_slope=self.negative_slope)
return out * math.sqrt(2)
class EqualConv2d(nn.Module):
"""Equalized Linear as StyleGAN2.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Size of the convolving kernel.
stride (int): Stride of the convolution. Default: 1
padding (int): Zero-padding added to both sides of the input.
Default: 0.
bias (bool): If ``True``, adds a learnable bias to the output.
Default: ``True``.
bias_init_val (float): Bias initialized value. Default: 0.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
super(EqualConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
if bias:
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
else:
self.register_parameter('bias', None)
def forward(self, x):
out = F.conv2d(
x,
self.weight * self.scale,
bias=self.bias,
stride=self.stride,
padding=self.padding,
)
return out
def __repr__(self):
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
f'out_channels={self.out_channels}, '
f'kernel_size={self.kernel_size},'
f' stride={self.stride}, padding={self.padding}, '
f'bias={self.bias is not None})')
class ConvLayer(nn.Sequential):
"""Conv Layer used in StyleGAN2 Discriminator.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Kernel size.
downsample (bool): Whether downsample by a factor of 2.
Default: False.
bias (bool): Whether with bias. Default: True.
activate (bool): Whether use activateion. Default: True.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
downsample=False,
bias=True,
activate=True,
interpolation_mode='bilinear'):
layers = []
self.interpolation_mode = interpolation_mode
# downsample
if downsample:
if self.interpolation_mode == 'nearest':
self.align_corners = None
else:
self.align_corners = False
layers.append(
torch.nn.Upsample(scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners))
stride = 1
self.padding = kernel_size // 2
# conv
layers.append(
EqualConv2d(
in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
and not activate))
# activation
if activate:
if bias:
layers.append(FusedLeakyReLU(out_channels))
else:
layers.append(ScaledLeakyReLU(0.2))
super(ConvLayer, self).__init__(*layers)
class ResBlock(nn.Module):
"""Residual block used in StyleGAN2 Discriminator.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
"""
def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'):
super(ResBlock, self).__init__()
self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
self.conv2 = ConvLayer(
in_channels,
out_channels,
3,
downsample=True,
interpolation_mode=interpolation_mode,
bias=True,
activate=True)
self.skip = ConvLayer(
in_channels,
out_channels,
1,
downsample=True,
interpolation_mode=interpolation_mode,
bias=False,
activate=False)
def forward(self, x):
out = self.conv1(x)
out = self.conv2(out)
skip = self.skip(x)
out = (out + skip) / math.sqrt(2)
return out

View File

@@ -3,7 +3,7 @@ import os.path as osp
import torch
from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.losses.losses import r1_penalty
from basicsr.losses.gan_loss import r1_penalty
from basicsr.metrics import calculate_metric
from basicsr.models.base_model import BaseModel
from basicsr.utils import get_root_logger, imwrite, tensor2img

View File

@@ -29,12 +29,12 @@ class GFPGANer():
bg_upsampler (nn.Module): The upsampler for the background. Default: None.
"""
def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None):
def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None):
self.upscale = upscale
self.bg_upsampler = bg_upsampler
# initialize model
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
# initialize the GFP-GAN
if arch == 'clean':
self.gfpgan = GFPGANv1Clean(
@@ -72,6 +72,9 @@ class GFPGANer():
different_w=True,
narrow=1,
sft_half=True)
elif arch == 'RestoreFormer':
from gfpgan.archs.restoreformer_arch import RestoreFormer
self.gfpgan = RestoreFormer()
# initialize face helper
self.face_helper = FaceRestoreHelper(
upscale,
@@ -79,7 +82,9 @@ class GFPGANer():
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
save_ext='png',
device=self.device)
use_parse=True,
device=self.device,
model_rootpath='gfpgan/weights')
if model_path.startswith('https://'):
model_path = load_file_from_url(
@@ -94,7 +99,7 @@ class GFPGANer():
self.gfpgan = self.gfpgan.to(self.device)
@torch.no_grad()
def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5):
self.face_helper.clean_all()
if has_aligned: # the inputs are already aligned
@@ -117,7 +122,7 @@ class GFPGANer():
cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
try:
output = self.gfpgan(cropped_face_t, return_rgb=False)[0]
output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0]
# convert to image
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
except RuntimeError as error:

View File

@@ -41,6 +41,7 @@ def main():
type=str,
default='auto',
help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs. Default: auto')
parser.add_argument('-w', '--weight', type=float, default=0.5, help='Adjustable weights.')
args = parser.parse_args()
args = parser.parse_args()
@@ -82,23 +83,37 @@ def main():
arch = 'original'
channel_multiplier = 1
model_name = 'GFPGANv1'
url = 'https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth'
elif args.version == '1.2':
arch = 'clean'
channel_multiplier = 2
model_name = 'GFPGANCleanv1-NoCE-C2'
url = 'https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth'
elif args.version == '1.3':
arch = 'clean'
channel_multiplier = 2
model_name = 'GFPGANv1.3'
url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
elif args.version == '1.4':
arch = 'clean'
channel_multiplier = 2
model_name = 'GFPGANv1.4'
url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'
elif args.version == 'RestoreFormer':
arch = 'RestoreFormer'
channel_multiplier = 2
model_name = 'RestoreFormer'
url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
else:
raise ValueError(f'Wrong model version {args.version}.')
# determine model paths
model_path = os.path.join('experiments/pretrained_models', model_name + '.pth')
if not os.path.isfile(model_path):
model_path = os.path.join('realesrgan/weights', model_name + '.pth')
model_path = os.path.join('gfpgan/weights', model_name + '.pth')
if not os.path.isfile(model_path):
raise ValueError(f'Model {model_name} does not exist.')
# download pre-trained models from url
model_path = url
restorer = GFPGANer(
model_path=model_path,
@@ -117,7 +132,11 @@ def main():
# restore faces and background if necessary
cropped_faces, restored_faces, restored_img = restorer.enhance(
input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=True)
input_img,
has_aligned=args.aligned,
only_center_face=args.only_center_face,
paste_back=True,
weight=args.weight)
# save faces
for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)):

View File

@@ -1,12 +1,12 @@
torch>=1.7
numpy<1.21 # numba requires numpy<1.21,>=1.17
opencv-python
torchvision
scipy
tqdm
basicsr>=1.3.4.0
facexlib>=0.2.0.3
basicsr>=1.4.2
facexlib>=0.2.5
lmdb
numpy
opencv-python
pyyaml
scipy
tb-nightly
torch>=1.7
torchvision
tqdm
yapf