mirror of
https://github.com/TencentARC/GFPGAN.git
synced 2026-02-07 13:56:55 +00:00
Compare commits
46 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7552a7791c | ||
|
|
2eac203389 | ||
|
|
2f46d95254 | ||
|
|
bc5a5deb95 | ||
|
|
3fd33abc47 | ||
|
|
d226e86f6c | ||
|
|
bb2f916764 | ||
|
|
fe3beac9dc | ||
|
|
126c55c68d | ||
|
|
8d2447a2d9 | ||
|
|
af7569775d | ||
|
|
c6593e7221 | ||
|
|
7272e45887 | ||
|
|
3e27784b1b | ||
|
|
2c420ee565 | ||
|
|
8e7cf5d723 | ||
|
|
c541e97f83 | ||
|
|
86756cba65 | ||
|
|
a9a2e3ae15 | ||
|
|
9c3f2d62cb | ||
|
|
ccd30af837 | ||
|
|
7d657f26b6 | ||
|
|
c7ccc098a7 | ||
|
|
bc3f0c4d91 | ||
|
|
924ce473ab | ||
|
|
09a37ae7fd | ||
|
|
6c544b70e6 | ||
|
|
47983e1767 | ||
|
|
77df6e4fad | ||
|
|
24b1f24ef5 | ||
|
|
c068e4d113 | ||
|
|
d8bf32a816 | ||
|
|
09d82ec683 | ||
|
|
780774d515 | ||
|
|
95101b46d2 | ||
|
|
547e026042 | ||
|
|
83bcb28462 | ||
|
|
942e7b39c6 | ||
|
|
8ba74c99ba | ||
|
|
3241c576ae | ||
|
|
05062fac70 | ||
|
|
3241798723 | ||
|
|
ee3e556f18 | ||
|
|
ad1397180d | ||
|
|
37237da798 | ||
|
|
be73d6d9a4 |
34
.github/workflows/no-response.yml
vendored
34
.github/workflows/no-response.yml
vendored
@@ -1,34 +0,0 @@
|
||||
name: No Response
|
||||
|
||||
# Modified from: https://raw.githubusercontent.com/github/docs/main/.github/workflows/no-response.yaml
|
||||
|
||||
# **What it does**: Closes issues that don't have enough information to be
|
||||
# actionable.
|
||||
# **Why we have it**: To remove the need for maintainers to remember to check
|
||||
# back on issues periodically to see if contributors have
|
||||
# responded.
|
||||
# **Who does it impact**: Everyone that works on docs or docs-internal.
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
|
||||
schedule:
|
||||
# Schedule for five minutes after the hour every hour
|
||||
- cron: '5 * * * *'
|
||||
|
||||
jobs:
|
||||
noResponse:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: lee-dohm/no-response@v0.5.0
|
||||
with:
|
||||
token: ${{ github.token }}
|
||||
closeComment: >
|
||||
This issue has been automatically closed because there has been no response
|
||||
to our request for more information from the original author. With only the
|
||||
information that is currently in the issue, we don't have enough information
|
||||
to take action. Please reach out if you have or find the answers we need so
|
||||
that we can investigate further.
|
||||
If you still have questions, please improve your description and re-open it.
|
||||
Thanks :-)
|
||||
41
.github/workflows/release.yml
vendored
Normal file
41
.github/workflows/release.yml
vendored
Normal file
@@ -0,0 +1,41 @@
|
||||
name: release
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
jobs:
|
||||
build:
|
||||
permissions: write-all
|
||||
name: Create Release
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
- name: Create Release
|
||||
id: create_release
|
||||
uses: actions/create-release@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
tag_name: ${{ github.ref }}
|
||||
release_name: GFPGAN ${{ github.ref }} Release Note
|
||||
body: |
|
||||
🚀 See you again 😸
|
||||
🚀Have a nice day 😸 and happy everyday 😃
|
||||
🚀 Long time no see ☄️
|
||||
|
||||
✨ **Highlights**
|
||||
✅ [Features] Support ...
|
||||
|
||||
🐛 **Bug Fixes**
|
||||
|
||||
🌴 **Improvements**
|
||||
|
||||
📢📢📢
|
||||
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/TencentARC/GFPGAN/master/assets/gfpgan_logo.png" height=150>
|
||||
</p>
|
||||
draft: true
|
||||
prerelease: false
|
||||
128
CODE_OF_CONDUCT.md
Normal file
128
CODE_OF_CONDUCT.md
Normal file
@@ -0,0 +1,128 @@
|
||||
# Contributor Covenant Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
We as members, contributors, and leaders pledge to make participation in our
|
||||
community a harassment-free experience for everyone, regardless of age, body
|
||||
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
||||
identity and expression, level of experience, education, socio-economic status,
|
||||
nationality, personal appearance, race, religion, or sexual identity
|
||||
and orientation.
|
||||
|
||||
We pledge to act and interact in ways that contribute to an open, welcoming,
|
||||
diverse, inclusive, and healthy community.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to a positive environment for our
|
||||
community include:
|
||||
|
||||
* Demonstrating empathy and kindness toward other people
|
||||
* Being respectful of differing opinions, viewpoints, and experiences
|
||||
* Giving and gracefully accepting constructive feedback
|
||||
* Accepting responsibility and apologizing to those affected by our mistakes,
|
||||
and learning from the experience
|
||||
* Focusing on what is best not just for us as individuals, but for the
|
||||
overall community
|
||||
|
||||
Examples of unacceptable behavior include:
|
||||
|
||||
* The use of sexualized language or imagery, and sexual attention or
|
||||
advances of any kind
|
||||
* Trolling, insulting or derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or email
|
||||
address, without their explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Enforcement Responsibilities
|
||||
|
||||
Community leaders are responsible for clarifying and enforcing our standards of
|
||||
acceptable behavior and will take appropriate and fair corrective action in
|
||||
response to any behavior that they deem inappropriate, threatening, offensive,
|
||||
or harmful.
|
||||
|
||||
Community leaders have the right and responsibility to remove, edit, or reject
|
||||
comments, commits, code, wiki edits, issues, and other contributions that are
|
||||
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
||||
decisions when appropriate.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies within all community spaces, and also applies when
|
||||
an individual is officially representing the community in public spaces.
|
||||
Examples of representing our community include using an official e-mail address,
|
||||
posting via an official social media account, or acting as an appointed
|
||||
representative at an online or offline event.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported to the community leaders responsible for enforcement at
|
||||
xintao.wang@outlook.com or xintaowang@tencent.com.
|
||||
All complaints will be reviewed and investigated promptly and fairly.
|
||||
|
||||
All community leaders are obligated to respect the privacy and security of the
|
||||
reporter of any incident.
|
||||
|
||||
## Enforcement Guidelines
|
||||
|
||||
Community leaders will follow these Community Impact Guidelines in determining
|
||||
the consequences for any action they deem in violation of this Code of Conduct:
|
||||
|
||||
### 1. Correction
|
||||
|
||||
**Community Impact**: Use of inappropriate language or other behavior deemed
|
||||
unprofessional or unwelcome in the community.
|
||||
|
||||
**Consequence**: A private, written warning from community leaders, providing
|
||||
clarity around the nature of the violation and an explanation of why the
|
||||
behavior was inappropriate. A public apology may be requested.
|
||||
|
||||
### 2. Warning
|
||||
|
||||
**Community Impact**: A violation through a single incident or series
|
||||
of actions.
|
||||
|
||||
**Consequence**: A warning with consequences for continued behavior. No
|
||||
interaction with the people involved, including unsolicited interaction with
|
||||
those enforcing the Code of Conduct, for a specified period of time. This
|
||||
includes avoiding interactions in community spaces as well as external channels
|
||||
like social media. Violating these terms may lead to a temporary or
|
||||
permanent ban.
|
||||
|
||||
### 3. Temporary Ban
|
||||
|
||||
**Community Impact**: A serious violation of community standards, including
|
||||
sustained inappropriate behavior.
|
||||
|
||||
**Consequence**: A temporary ban from any sort of interaction or public
|
||||
communication with the community for a specified period of time. No public or
|
||||
private interaction with the people involved, including unsolicited interaction
|
||||
with those enforcing the Code of Conduct, is allowed during this period.
|
||||
Violating these terms may lead to a permanent ban.
|
||||
|
||||
### 4. Permanent Ban
|
||||
|
||||
**Community Impact**: Demonstrating a pattern of violation of community
|
||||
standards, including sustained inappropriate behavior, harassment of an
|
||||
individual, or aggression toward or disparagement of classes of individuals.
|
||||
|
||||
**Consequence**: A permanent ban from any sort of public interaction within
|
||||
the community.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
||||
version 2.0, available at
|
||||
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
|
||||
|
||||
Community Impact Guidelines were inspired by [Mozilla's code of conduct
|
||||
enforcement ladder](https://github.com/mozilla/diversity).
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
|
||||
For answers to common questions about this code of conduct, see the FAQ at
|
||||
https://www.contributor-covenant.org/faq. Translations are available at
|
||||
https://www.contributor-covenant.org/translations.
|
||||
24
Comparisons.md
Normal file
24
Comparisons.md
Normal file
@@ -0,0 +1,24 @@
|
||||
# Comparisons
|
||||
|
||||
## Comparisons among different model versions
|
||||
|
||||
Note that V1.3 is not always better than V1.2. You may need to try different models based on your purpose and inputs.
|
||||
|
||||
| Version | Strengths | Weaknesses |
|
||||
| :---: | :---: | :---: |
|
||||
|V1.3 | ✓ natural outputs<br> ✓better results on very low-quality inputs <br> ✓ work on relatively high-quality inputs <br>✓ can have repeated (twice) restorations | ✗ not very sharp <br> ✗ have a slight change on identity |
|
||||
|V1.2 | ✓ sharper output <br> ✓ with beauty makeup | ✗ some outputs are unnatural|
|
||||
|
||||
For the following images, you may need to **zoom in** for comparing details, or **click the image** to see in the full size.
|
||||
|
||||
| Input | V1 | V1.2 | V1.3
|
||||
| :---: | :---: | :---: | :---: |
|
||||
||  |  |  |
|
||||
|  |  |  | |
|
||||
|  |  |  | |
|
||||
|  |  |  | |
|
||||
|  |  |  | |
|
||||
|  |  |  | |
|
||||
|  |  |  | |
|
||||
|
||||
<!-- | ![]() | ![]() | ![]() | ![]()| -->
|
||||
7
FAQ.md
Normal file
7
FAQ.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# FAQ
|
||||
|
||||
1. **How to finetune the GFPGANCleanv1-NoCE-C2 (v1.2) model**
|
||||
|
||||
**A:** 1) The GFPGANCleanv1-NoCE-C2 (v1.2) model uses the *clean* architecture, which is more friendly for deploying.
|
||||
2) This model is not directly trained. Instead, it is converted from another *bilinear* model.
|
||||
3) If you want to finetune the GFPGANCleanv1-NoCE-C2 (v1.2), you need to finetune its original *bilinear* model, and then do the conversion.
|
||||
@@ -60,17 +60,17 @@ wget https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth
|
||||
- Option 1: Load extensions just-in-time(JIT)
|
||||
|
||||
```bash
|
||||
BASICSR_JIT=True python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1
|
||||
BASICSR_JIT=True python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1
|
||||
|
||||
# for aligned images
|
||||
BASICSR_JIT=True python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned
|
||||
BASICSR_JIT=True python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1 --aligned
|
||||
```
|
||||
|
||||
- Option 2: Have successfully compiled extensions during installation
|
||||
|
||||
```bash
|
||||
python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1
|
||||
python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1
|
||||
|
||||
# for aligned images
|
||||
python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned
|
||||
python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1 --aligned
|
||||
```
|
||||
|
||||
81
README.md
81
README.md
@@ -1,4 +1,13 @@
|
||||
# GFPGAN (CVPR 2021)
|
||||
<p align="center">
|
||||
<img src="assets/gfpgan_logo.png" height=130>
|
||||
</p>
|
||||
|
||||
## <div align="center"><b><a href="README.md">English</a> | <a href="README_CN.md">简体中文</a></b></div>
|
||||
|
||||
<div align="center">
|
||||
<!-- <a href="https://twitter.com/_Xintao_" style="text-decoration:none;">
|
||||
<img src="https://user-images.githubusercontent.com/17445847/187162058-c764ced6-952f-404b-ac85-ba95cce18e7b.png" width="4%" alt="" />
|
||||
</a> -->
|
||||
|
||||
[](https://github.com/TencentARC/GFPGAN/releases)
|
||||
[](https://pypi.org/project/gfpgan/)
|
||||
@@ -7,14 +16,28 @@
|
||||
[](https://github.com/TencentARC/GFPGAN/blob/master/LICENSE)
|
||||
[](https://github.com/TencentARC/GFPGAN/blob/master/.github/workflows/pylint.yml)
|
||||
[](https://github.com/TencentARC/GFPGAN/blob/master/.github/workflows/publish-pip.yml)
|
||||
</div>
|
||||
|
||||
1. :boom: **Updated** online demo: [](https://replicate.com/tencentarc/gfpgan). Here is the [backup](https://replicate.com/xinntao/gfpgan).
|
||||
1. :boom: **Updated** online demo: [](https://huggingface.co/spaces/Xintao/GFPGAN)
|
||||
1. [Colab Demo](https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo) for GFPGAN <a href="https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>; (Another [Colab Demo](https://colab.research.google.com/drive/1Oa1WwKB4M4l1GmR7CtswDVgOCOeSLChA?usp=sharing) for the original paper model)
|
||||
1. We provide a *clean* version of GFPGAN, which can run without CUDA extensions. So that it can run in **Windows** or on **CPU mode**.
|
||||
|
||||
GFPGAN aims at developing **Practical Algorithm for Real-world Face Restoration**.<br>
|
||||
<!-- 3. Online demo: [Replicate.ai](https://replicate.com/xinntao/gfpgan) (may need to sign in, return the whole image)
|
||||
4. Online demo: [Baseten.co](https://app.baseten.co/applications/Q04Lz0d/operator_views/8qZG6Bg) (backed by GPU, returns the whole image)
|
||||
5. We provide a *clean* version of GFPGAN, which can run without CUDA extensions. So that it can run in **Windows** or on **CPU mode**. -->
|
||||
|
||||
> :rocket: **Thanks for your interest in our work. You may also want to check our new updates on the *tiny models* for *anime images and videos* in [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN/blob/master/docs/anime_video_model.md)** :blush:
|
||||
|
||||
GFPGAN aims at developing a **Practical Algorithm for Real-world Face Restoration**.<br>
|
||||
It leverages rich and diverse priors encapsulated in a pretrained face GAN (*e.g.*, StyleGAN2) for blind face restoration.
|
||||
|
||||
:question: Frequently Asked Questions can be found in [FAQ.md](FAQ.md).
|
||||
|
||||
:triangular_flag_on_post: **Updates**
|
||||
|
||||
- :white_check_mark: Add [RestoreFormer](https://github.com/wzhouxiff/RestoreFormer) inference codes.
|
||||
- :white_check_mark: Add [V1.4 model](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth), which produces slightly more details and better identity than V1.3.
|
||||
- :white_check_mark: Add **[V1.3 model](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth)**, which produces **more natural** restoration results, and better results on *very low-quality* / *high-quality* inputs. See more in [Model zoo](#european_castle-model-zoo), [Comparisons.md](Comparisons.md)
|
||||
- :white_check_mark: Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/GFPGAN).
|
||||
- :white_check_mark: Support enhancing non-face regions (background) with [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN).
|
||||
- :white_check_mark: We provide a *clean* version of GFPGAN, which does not require CUDA extensions.
|
||||
@@ -25,9 +48,9 @@ It leverages rich and diverse priors encapsulated in a pretrained face GAN (*e.g
|
||||
If GFPGAN is helpful in your photos/projects, please help to :star: this repo or recommend it to your friends. Thanks:blush:
|
||||
Other recommended projects:<br>
|
||||
:arrow_forward: [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN): A practical algorithm for general image restoration<br>
|
||||
:arrow_forward: [BasicSR](https://github.com/xinntao/BasicSR): An ppen-source image and video restoration toolbox<br>
|
||||
:arrow_forward: [facexlib](https://github.com/xinntao/facexlib): A collection that provides useful face-relation functions.<br>
|
||||
:arrow_forward: [HandyView](https://github.com/xinntao/HandyView): A PyQt5-based image viewer that is handy for view and comparison. <br>
|
||||
:arrow_forward: [BasicSR](https://github.com/xinntao/BasicSR): An open-source image and video restoration toolbox<br>
|
||||
:arrow_forward: [facexlib](https://github.com/xinntao/facexlib): A collection that provides useful face-relation functions<br>
|
||||
:arrow_forward: [HandyView](https://github.com/xinntao/HandyView): A PyQt5-based image viewer that is handy for view and comparison<br>
|
||||
|
||||
---
|
||||
|
||||
@@ -53,7 +76,7 @@ Other recommended projects:<br>
|
||||
### Installation
|
||||
|
||||
We now provide a *clean* version of GFPGAN, which does not require customized CUDA extensions. <br>
|
||||
If you want want to use the original model in our paper, please see [PaperModel.md](PaperModel.md) for installation.
|
||||
If you want to use the original model in our paper, please see [PaperModel.md](PaperModel.md) for installation.
|
||||
|
||||
1. Clone repo
|
||||
|
||||
@@ -83,24 +106,54 @@ If you want want to use the original model in our paper, please see [PaperModel.
|
||||
|
||||
## :zap: Quick Inference
|
||||
|
||||
Download pre-trained models: [GFPGANCleanv1-NoCE-C2.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth)
|
||||
We take the v1.3 version for an example. More models can be found [here](#european_castle-model-zoo).
|
||||
|
||||
Download pre-trained models: [GFPGANv1.3.pth](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth)
|
||||
|
||||
```bash
|
||||
wget https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth -P experiments/pretrained_models
|
||||
wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P experiments/pretrained_models
|
||||
```
|
||||
|
||||
**Inference!**
|
||||
|
||||
```bash
|
||||
python inference_gfpgan.py --upscale 2 --test_path inputs/whole_imgs --save_root results
|
||||
python inference_gfpgan.py -i inputs/whole_imgs -o results -v 1.3 -s 2
|
||||
```
|
||||
|
||||
If you want want to use the original model in our paper, please see [PaperModel.md](PaperModel.md) for installation and inference.
|
||||
```console
|
||||
Usage: python inference_gfpgan.py -i inputs/whole_imgs -o results -v 1.3 -s 2 [options]...
|
||||
|
||||
-h show this help
|
||||
-i input Input image or folder. Default: inputs/whole_imgs
|
||||
-o output Output folder. Default: results
|
||||
-v version GFPGAN model version. Option: 1 | 1.2 | 1.3. Default: 1.3
|
||||
-s upscale The final upsampling scale of the image. Default: 2
|
||||
-bg_upsampler background upsampler. Default: realesrgan
|
||||
-bg_tile Tile size for background sampler, 0 for no tile during testing. Default: 400
|
||||
-suffix Suffix of the restored faces
|
||||
-only_center_face Only restore the center face
|
||||
-aligned Input are aligned faces
|
||||
-ext Image extension. Options: auto | jpg | png, auto means using the same extension as inputs. Default: auto
|
||||
```
|
||||
|
||||
If you want to use the original model in our paper, please see [PaperModel.md](PaperModel.md) for installation and inference.
|
||||
|
||||
## :european_castle: Model Zoo
|
||||
|
||||
- [GFPGANCleanv1-NoCE-C2.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth): No colorization; no CUDA extensions are required. It is still in training. Trained with more data with pre-processing.
|
||||
- [GFPGANv1.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth): The paper model, with colorization.
|
||||
| Version | Model Name | Description |
|
||||
| :---: | :---: | :---: |
|
||||
| V1.3 | [GFPGANv1.3.pth](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth) | Based on V1.2; **more natural** restoration results; better results on very low-quality / high-quality inputs. |
|
||||
| V1.2 | [GFPGANCleanv1-NoCE-C2.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth) | No colorization; no CUDA extensions are required. Trained with more data with pre-processing. |
|
||||
| V1 | [GFPGANv1.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth) | The paper model, with colorization. |
|
||||
|
||||
The comparisons are in [Comparisons.md](Comparisons.md).
|
||||
|
||||
Note that V1.3 is not always better than V1.2. You may need to select different models based on your purpose and inputs.
|
||||
|
||||
| Version | Strengths | Weaknesses |
|
||||
| :---: | :---: | :---: |
|
||||
|V1.3 | ✓ natural outputs<br> ✓better results on very low-quality inputs <br> ✓ work on relatively high-quality inputs <br>✓ can have repeated (twice) restorations | ✗ not very sharp <br> ✗ have a slight change on identity |
|
||||
|V1.2 | ✓ sharper output <br> ✓ with beauty makeup | ✗ some outputs are unnatural |
|
||||
|
||||
You can find **more models (such as the discriminators)** here: [[Google Drive](https://drive.google.com/drive/folders/17rLiFzcUMoQuhLnptDsKolegHWwJOnHu?usp=sharing)], OR [[Tencent Cloud 腾讯微云](https://share.weiyun.com/ShYoCCoc)]
|
||||
|
||||
@@ -121,7 +174,7 @@ You could improve it according to your own needs.
|
||||
1. Dataset preparation: [FFHQ](https://github.com/NVlabs/ffhq-dataset)
|
||||
|
||||
1. Download pre-trained models and other data. Put them in the `experiments/pretrained_models` folder.
|
||||
1. [Pretrained StyleGAN2 model: StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth)
|
||||
1. [Pre-trained StyleGAN2 model: StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth)
|
||||
1. [Component locations of FFHQ: FFHQ_eye_mouth_landmarks_512.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/FFHQ_eye_mouth_landmarks_512.pth)
|
||||
1. [A simple ArcFace model: arcface_resnet18.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/arcface_resnet18.pth)
|
||||
|
||||
|
||||
7
README_CN.md
Normal file
7
README_CN.md
Normal file
@@ -0,0 +1,7 @@
|
||||
<p align="center">
|
||||
<img src="assets/gfpgan_logo.png" height=130>
|
||||
</p>
|
||||
|
||||
## <div align="center"><b><a href="README.md">English</a> | <a href="README_CN.md">简体中文</a></b></div>
|
||||
|
||||
还未完工,欢迎贡献!
|
||||
BIN
assets/gfpgan_logo.png
Normal file
BIN
assets/gfpgan_logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 50 KiB |
22
cog.yaml
Normal file
22
cog.yaml
Normal file
@@ -0,0 +1,22 @@
|
||||
# This file is used for constructing replicate env
|
||||
image: "r8.im/tencentarc/gfpgan"
|
||||
|
||||
build:
|
||||
gpu: true
|
||||
python_version: "3.8"
|
||||
system_packages:
|
||||
- "libgl1-mesa-glx"
|
||||
- "libglib2.0-0"
|
||||
python_packages:
|
||||
- "torch==1.7.1"
|
||||
- "torchvision==0.8.2"
|
||||
- "numpy==1.21.1"
|
||||
- "lmdb==1.2.1"
|
||||
- "opencv-python==4.5.3.56"
|
||||
- "PyYAML==5.4.1"
|
||||
- "tqdm==4.62.2"
|
||||
- "yapf==0.31.0"
|
||||
- "basicsr==1.4.2"
|
||||
- "facexlib==0.2.5"
|
||||
|
||||
predict: "cog_predict.py:Predictor"
|
||||
161
cog_predict.py
Normal file
161
cog_predict.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# flake8: noqa
|
||||
# This file is used for deploying replicate models
|
||||
# running: cog predict -i img=@inputs/whole_imgs/10045.png -i version='v1.4' -i scale=2
|
||||
# push: cog push r8.im/tencentarc/gfpgan
|
||||
# push (backup): cog push r8.im/xinntao/gfpgan
|
||||
|
||||
import os
|
||||
|
||||
os.system('python setup.py develop')
|
||||
os.system('pip install realesrgan')
|
||||
|
||||
import cv2
|
||||
import shutil
|
||||
import tempfile
|
||||
import torch
|
||||
from basicsr.archs.srvgg_arch import SRVGGNetCompact
|
||||
|
||||
from gfpgan import GFPGANer
|
||||
|
||||
try:
|
||||
from cog import BasePredictor, Input, Path
|
||||
from realesrgan.utils import RealESRGANer
|
||||
except Exception:
|
||||
print('please install cog and realesrgan package')
|
||||
|
||||
|
||||
class Predictor(BasePredictor):
|
||||
|
||||
def setup(self):
|
||||
os.makedirs('output', exist_ok=True)
|
||||
# download weights
|
||||
if not os.path.exists('gfpgan/weights/realesr-general-x4v3.pth'):
|
||||
os.system(
|
||||
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ./gfpgan/weights'
|
||||
)
|
||||
if not os.path.exists('gfpgan/weights/GFPGANv1.2.pth'):
|
||||
os.system(
|
||||
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth -P ./gfpgan/weights')
|
||||
if not os.path.exists('gfpgan/weights/GFPGANv1.3.pth'):
|
||||
os.system(
|
||||
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P ./gfpgan/weights')
|
||||
if not os.path.exists('gfpgan/weights/GFPGANv1.4.pth'):
|
||||
os.system(
|
||||
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ./gfpgan/weights')
|
||||
if not os.path.exists('gfpgan/weights/RestoreFormer.pth'):
|
||||
os.system(
|
||||
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth -P ./gfpgan/weights'
|
||||
)
|
||||
|
||||
# background enhancer with RealESRGAN
|
||||
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||
model_path = 'gfpgan/weights/realesr-general-x4v3.pth'
|
||||
half = True if torch.cuda.is_available() else False
|
||||
self.upsampler = RealESRGANer(
|
||||
scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
|
||||
|
||||
# Use GFPGAN for face enhancement
|
||||
self.face_enhancer = GFPGANer(
|
||||
model_path='gfpgan/weights/GFPGANv1.4.pth',
|
||||
upscale=2,
|
||||
arch='clean',
|
||||
channel_multiplier=2,
|
||||
bg_upsampler=self.upsampler)
|
||||
self.current_version = 'v1.4'
|
||||
|
||||
def predict(
|
||||
self,
|
||||
img: Path = Input(description='Input'),
|
||||
version: str = Input(
|
||||
description='GFPGAN version. v1.3: better quality. v1.4: more details and better identity.',
|
||||
choices=['v1.2', 'v1.3', 'v1.4', 'RestoreFormer'],
|
||||
default='v1.4'),
|
||||
scale: float = Input(description='Rescaling factor', default=2),
|
||||
) -> Path:
|
||||
weight = 0.5
|
||||
print(img, version, scale, weight)
|
||||
try:
|
||||
extension = os.path.splitext(os.path.basename(str(img)))[1]
|
||||
img = cv2.imread(str(img), cv2.IMREAD_UNCHANGED)
|
||||
if len(img.shape) == 3 and img.shape[2] == 4:
|
||||
img_mode = 'RGBA'
|
||||
elif len(img.shape) == 2:
|
||||
img_mode = None
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
else:
|
||||
img_mode = None
|
||||
|
||||
h, w = img.shape[0:2]
|
||||
if h < 300:
|
||||
img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
|
||||
|
||||
if self.current_version != version:
|
||||
if version == 'v1.2':
|
||||
self.face_enhancer = GFPGANer(
|
||||
model_path='gfpgan/weights/GFPGANv1.2.pth',
|
||||
upscale=2,
|
||||
arch='clean',
|
||||
channel_multiplier=2,
|
||||
bg_upsampler=self.upsampler)
|
||||
self.current_version = 'v1.2'
|
||||
elif version == 'v1.3':
|
||||
self.face_enhancer = GFPGANer(
|
||||
model_path='gfpgan/weights/GFPGANv1.3.pth',
|
||||
upscale=2,
|
||||
arch='clean',
|
||||
channel_multiplier=2,
|
||||
bg_upsampler=self.upsampler)
|
||||
self.current_version = 'v1.3'
|
||||
elif version == 'v1.4':
|
||||
self.face_enhancer = GFPGANer(
|
||||
model_path='gfpgan/weights/GFPGANv1.4.pth',
|
||||
upscale=2,
|
||||
arch='clean',
|
||||
channel_multiplier=2,
|
||||
bg_upsampler=self.upsampler)
|
||||
self.current_version = 'v1.4'
|
||||
elif version == 'RestoreFormer':
|
||||
self.face_enhancer = GFPGANer(
|
||||
model_path='gfpgan/weights/RestoreFormer.pth',
|
||||
upscale=2,
|
||||
arch='RestoreFormer',
|
||||
channel_multiplier=2,
|
||||
bg_upsampler=self.upsampler)
|
||||
|
||||
try:
|
||||
_, _, output = self.face_enhancer.enhance(
|
||||
img, has_aligned=False, only_center_face=False, paste_back=True, weight=weight)
|
||||
except RuntimeError as error:
|
||||
print('Error', error)
|
||||
|
||||
try:
|
||||
if scale != 2:
|
||||
interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
|
||||
h, w = img.shape[0:2]
|
||||
output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
|
||||
except Exception as error:
|
||||
print('wrong scale input.', error)
|
||||
|
||||
if img_mode == 'RGBA': # RGBA images should be saved in png format
|
||||
extension = 'png'
|
||||
# save_path = f'output/out.{extension}'
|
||||
# cv2.imwrite(save_path, output)
|
||||
out_path = Path(tempfile.mkdtemp()) / f'out.{extension}'
|
||||
cv2.imwrite(str(out_path), output)
|
||||
except Exception as error:
|
||||
print('global exception: ', error)
|
||||
finally:
|
||||
clean_folder('output')
|
||||
return out_path
|
||||
|
||||
|
||||
def clean_folder(folder):
|
||||
for filename in os.listdir(folder):
|
||||
file_path = os.path.join(folder, filename)
|
||||
try:
|
||||
if os.path.isfile(file_path) or os.path.islink(file_path):
|
||||
os.unlink(file_path)
|
||||
elif os.path.isdir(file_path):
|
||||
shutil.rmtree(file_path)
|
||||
except Exception as e:
|
||||
print(f'Failed to delete {file_path}. Reason: {e}')
|
||||
@@ -3,4 +3,5 @@ from .archs import *
|
||||
from .data import *
|
||||
from .models import *
|
||||
from .utils import *
|
||||
from .version import __gitsha__, __version__
|
||||
|
||||
# from .version import *
|
||||
|
||||
@@ -2,13 +2,27 @@ import torch.nn as nn
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
def conv3x3(inplanes, outplanes, stride=1):
|
||||
"""A simple wrapper for 3x3 convolution with padding.
|
||||
|
||||
Args:
|
||||
inplanes (int): Channel number of inputs.
|
||||
outplanes (int): Channel number of outputs.
|
||||
stride (int): Stride in convolution. Default: 1.
|
||||
"""
|
||||
return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
"""Basic residual block used in the ResNetArcFace architecture.
|
||||
|
||||
Args:
|
||||
inplanes (int): Channel number of inputs.
|
||||
planes (int): Channel number of outputs.
|
||||
stride (int): Stride in convolution. Default: 1.
|
||||
downsample (nn.Module): The downsample module. Default: None.
|
||||
"""
|
||||
expansion = 1 # output channel expansion ratio
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
@@ -40,7 +54,16 @@ class BasicBlock(nn.Module):
|
||||
|
||||
|
||||
class IRBlock(nn.Module):
|
||||
expansion = 1
|
||||
"""Improved residual block (IR Block) used in the ResNetArcFace architecture.
|
||||
|
||||
Args:
|
||||
inplanes (int): Channel number of inputs.
|
||||
planes (int): Channel number of outputs.
|
||||
stride (int): Stride in convolution. Default: 1.
|
||||
downsample (nn.Module): The downsample module. Default: None.
|
||||
use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
|
||||
"""
|
||||
expansion = 1 # output channel expansion ratio
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
|
||||
super(IRBlock, self).__init__()
|
||||
@@ -78,7 +101,15 @@ class IRBlock(nn.Module):
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
"""Bottleneck block used in the ResNetArcFace architecture.
|
||||
|
||||
Args:
|
||||
inplanes (int): Channel number of inputs.
|
||||
planes (int): Channel number of outputs.
|
||||
stride (int): Stride in convolution. Default: 1.
|
||||
downsample (nn.Module): The downsample module. Default: None.
|
||||
"""
|
||||
expansion = 4 # output channel expansion ratio
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
@@ -116,10 +147,16 @@ class Bottleneck(nn.Module):
|
||||
|
||||
|
||||
class SEBlock(nn.Module):
|
||||
"""The squeeze-and-excitation block (SEBlock) used in the IRBlock.
|
||||
|
||||
Args:
|
||||
channel (int): Channel number of inputs.
|
||||
reduction (int): Channel reduction ration. Default: 16.
|
||||
"""
|
||||
|
||||
def __init__(self, channel, reduction=16):
|
||||
super(SEBlock, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
|
||||
nn.Sigmoid())
|
||||
@@ -133,6 +170,15 @@ class SEBlock(nn.Module):
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class ResNetArcFace(nn.Module):
|
||||
"""ArcFace with ResNet architectures.
|
||||
|
||||
Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
|
||||
|
||||
Args:
|
||||
block (str): Block used in the ArcFace architecture.
|
||||
layers (tuple(int)): Block numbers in each layer.
|
||||
use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
|
||||
"""
|
||||
|
||||
def __init__(self, block, layers, use_se=True):
|
||||
if block == 'IRBlock':
|
||||
@@ -140,6 +186,7 @@ class ResNetArcFace(nn.Module):
|
||||
self.inplanes = 64
|
||||
self.use_se = use_se
|
||||
super(ResNetArcFace, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.prelu = nn.PReLU()
|
||||
@@ -153,6 +200,7 @@ class ResNetArcFace(nn.Module):
|
||||
self.fc5 = nn.Linear(512 * 8 * 8, 512)
|
||||
self.bn5 = nn.BatchNorm1d(512)
|
||||
|
||||
# initialization
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.xavier_normal_(m.weight)
|
||||
@@ -163,7 +211,7 @@ class ResNetArcFace(nn.Module):
|
||||
nn.init.xavier_normal_(m.weight)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1):
|
||||
def _make_layer(self, block, planes, num_blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
@@ -173,7 +221,7 @@ class ResNetArcFace(nn.Module):
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
|
||||
self.inplanes = planes
|
||||
for _ in range(1, blocks):
|
||||
for _ in range(1, num_blocks):
|
||||
layers.append(block(self.inplanes, planes, use_se=self.use_se))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
312
gfpgan/archs/gfpgan_bilinear_arch.py
Normal file
312
gfpgan/archs/gfpgan_bilinear_arch.py
Normal file
@@ -0,0 +1,312 @@
|
||||
import math
|
||||
import random
|
||||
import torch
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
from torch import nn
|
||||
|
||||
from .gfpganv1_arch import ResUpBlock
|
||||
from .stylegan2_bilinear_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
|
||||
StyleGAN2GeneratorBilinear)
|
||||
|
||||
|
||||
class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear):
|
||||
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
|
||||
|
||||
It is the bilinear version. It does not use the complicated UpFirDnSmooth function that is not friendly for
|
||||
deployment. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT.
|
||||
|
||||
Args:
|
||||
out_size (int): The spatial size of outputs.
|
||||
num_style_feat (int): Channel number of style features. Default: 512.
|
||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
||||
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
||||
narrow (float): The narrow ratio for channels. Default: 1.
|
||||
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
out_size,
|
||||
num_style_feat=512,
|
||||
num_mlp=8,
|
||||
channel_multiplier=2,
|
||||
lr_mlp=0.01,
|
||||
narrow=1,
|
||||
sft_half=False):
|
||||
super(StyleGAN2GeneratorBilinearSFT, self).__init__(
|
||||
out_size,
|
||||
num_style_feat=num_style_feat,
|
||||
num_mlp=num_mlp,
|
||||
channel_multiplier=channel_multiplier,
|
||||
lr_mlp=lr_mlp,
|
||||
narrow=narrow)
|
||||
self.sft_half = sft_half
|
||||
|
||||
def forward(self,
|
||||
styles,
|
||||
conditions,
|
||||
input_is_latent=False,
|
||||
noise=None,
|
||||
randomize_noise=True,
|
||||
truncation=1,
|
||||
truncation_latent=None,
|
||||
inject_index=None,
|
||||
return_latents=False):
|
||||
"""Forward function for StyleGAN2GeneratorBilinearSFT.
|
||||
|
||||
Args:
|
||||
styles (list[Tensor]): Sample codes of styles.
|
||||
conditions (list[Tensor]): SFT conditions to generators.
|
||||
input_is_latent (bool): Whether input is latent style. Default: False.
|
||||
noise (Tensor | None): Input noise or None. Default: None.
|
||||
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
||||
truncation (float): The truncation ratio. Default: 1.
|
||||
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
|
||||
inject_index (int | None): The injection index for mixing noise. Default: None.
|
||||
return_latents (bool): Whether to return style latents. Default: False.
|
||||
"""
|
||||
# style codes -> latents with Style MLP layer
|
||||
if not input_is_latent:
|
||||
styles = [self.style_mlp(s) for s in styles]
|
||||
# noises
|
||||
if noise is None:
|
||||
if randomize_noise:
|
||||
noise = [None] * self.num_layers # for each style conv layer
|
||||
else: # use the stored noise
|
||||
noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
|
||||
# style truncation
|
||||
if truncation < 1:
|
||||
style_truncation = []
|
||||
for style in styles:
|
||||
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
|
||||
styles = style_truncation
|
||||
# get style latents with injection
|
||||
if len(styles) == 1:
|
||||
inject_index = self.num_latent
|
||||
|
||||
if styles[0].ndim < 3:
|
||||
# repeat latent code for all the layers
|
||||
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
||||
else: # used for encoder with different latent code for each layer
|
||||
latent = styles[0]
|
||||
elif len(styles) == 2: # mixing noises
|
||||
if inject_index is None:
|
||||
inject_index = random.randint(1, self.num_latent - 1)
|
||||
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
||||
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
|
||||
latent = torch.cat([latent1, latent2], 1)
|
||||
|
||||
# main generation
|
||||
out = self.constant_input(latent.shape[0])
|
||||
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
|
||||
skip = self.to_rgb1(out, latent[:, 1])
|
||||
|
||||
i = 1
|
||||
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
|
||||
noise[2::2], self.to_rgbs):
|
||||
out = conv1(out, latent[:, i], noise=noise1)
|
||||
|
||||
# the conditions may have fewer levels
|
||||
if i < len(conditions):
|
||||
# SFT part to combine the conditions
|
||||
if self.sft_half: # only apply SFT to half of the channels
|
||||
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
|
||||
out_sft = out_sft * conditions[i - 1] + conditions[i]
|
||||
out = torch.cat([out_same, out_sft], dim=1)
|
||||
else: # apply SFT to all the channels
|
||||
out = out * conditions[i - 1] + conditions[i]
|
||||
|
||||
out = conv2(out, latent[:, i + 1], noise=noise2)
|
||||
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
|
||||
i += 2
|
||||
|
||||
image = skip
|
||||
|
||||
if return_latents:
|
||||
return image, latent
|
||||
else:
|
||||
return image, None
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class GFPGANBilinear(nn.Module):
|
||||
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
|
||||
|
||||
It is the bilinear version and it does not use the complicated UpFirDnSmooth function that is not friendly for
|
||||
deployment. It can be easily converted to the clean version: GFPGANv1Clean.
|
||||
|
||||
|
||||
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
|
||||
|
||||
Args:
|
||||
out_size (int): The spatial size of outputs.
|
||||
num_style_feat (int): Channel number of style features. Default: 512.
|
||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
||||
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
|
||||
fix_decoder (bool): Whether to fix the decoder. Default: True.
|
||||
|
||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
||||
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
||||
input_is_latent (bool): Whether input is latent style. Default: False.
|
||||
different_w (bool): Whether to use different latent w for different layers. Default: False.
|
||||
narrow (float): The narrow ratio for channels. Default: 1.
|
||||
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
out_size,
|
||||
num_style_feat=512,
|
||||
channel_multiplier=1,
|
||||
decoder_load_path=None,
|
||||
fix_decoder=True,
|
||||
# for stylegan decoder
|
||||
num_mlp=8,
|
||||
lr_mlp=0.01,
|
||||
input_is_latent=False,
|
||||
different_w=False,
|
||||
narrow=1,
|
||||
sft_half=False):
|
||||
|
||||
super(GFPGANBilinear, self).__init__()
|
||||
self.input_is_latent = input_is_latent
|
||||
self.different_w = different_w
|
||||
self.num_style_feat = num_style_feat
|
||||
|
||||
unet_narrow = narrow * 0.5 # by default, use a half of input channels
|
||||
channels = {
|
||||
'4': int(512 * unet_narrow),
|
||||
'8': int(512 * unet_narrow),
|
||||
'16': int(512 * unet_narrow),
|
||||
'32': int(512 * unet_narrow),
|
||||
'64': int(256 * channel_multiplier * unet_narrow),
|
||||
'128': int(128 * channel_multiplier * unet_narrow),
|
||||
'256': int(64 * channel_multiplier * unet_narrow),
|
||||
'512': int(32 * channel_multiplier * unet_narrow),
|
||||
'1024': int(16 * channel_multiplier * unet_narrow)
|
||||
}
|
||||
|
||||
self.log_size = int(math.log(out_size, 2))
|
||||
first_out_size = 2**(int(math.log(out_size, 2)))
|
||||
|
||||
self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)
|
||||
|
||||
# downsample
|
||||
in_channels = channels[f'{first_out_size}']
|
||||
self.conv_body_down = nn.ModuleList()
|
||||
for i in range(self.log_size, 2, -1):
|
||||
out_channels = channels[f'{2**(i - 1)}']
|
||||
self.conv_body_down.append(ResBlock(in_channels, out_channels))
|
||||
in_channels = out_channels
|
||||
|
||||
self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)
|
||||
|
||||
# upsample
|
||||
in_channels = channels['4']
|
||||
self.conv_body_up = nn.ModuleList()
|
||||
for i in range(3, self.log_size + 1):
|
||||
out_channels = channels[f'{2**i}']
|
||||
self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
|
||||
in_channels = out_channels
|
||||
|
||||
# to RGB
|
||||
self.toRGB = nn.ModuleList()
|
||||
for i in range(3, self.log_size + 1):
|
||||
self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))
|
||||
|
||||
if different_w:
|
||||
linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
|
||||
else:
|
||||
linear_out_channel = num_style_feat
|
||||
|
||||
self.final_linear = EqualLinear(
|
||||
channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
|
||||
|
||||
# the decoder: stylegan2 generator with SFT modulations
|
||||
self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT(
|
||||
out_size=out_size,
|
||||
num_style_feat=num_style_feat,
|
||||
num_mlp=num_mlp,
|
||||
channel_multiplier=channel_multiplier,
|
||||
lr_mlp=lr_mlp,
|
||||
narrow=narrow,
|
||||
sft_half=sft_half)
|
||||
|
||||
# load pre-trained stylegan2 model if necessary
|
||||
if decoder_load_path:
|
||||
self.stylegan_decoder.load_state_dict(
|
||||
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
|
||||
# fix decoder without updating params
|
||||
if fix_decoder:
|
||||
for _, param in self.stylegan_decoder.named_parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# for SFT modulations (scale and shift)
|
||||
self.condition_scale = nn.ModuleList()
|
||||
self.condition_shift = nn.ModuleList()
|
||||
for i in range(3, self.log_size + 1):
|
||||
out_channels = channels[f'{2**i}']
|
||||
if sft_half:
|
||||
sft_out_channels = out_channels
|
||||
else:
|
||||
sft_out_channels = out_channels * 2
|
||||
self.condition_scale.append(
|
||||
nn.Sequential(
|
||||
EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
|
||||
ScaledLeakyReLU(0.2),
|
||||
EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
|
||||
self.condition_shift.append(
|
||||
nn.Sequential(
|
||||
EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
|
||||
ScaledLeakyReLU(0.2),
|
||||
EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
|
||||
|
||||
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
|
||||
"""Forward function for GFPGANBilinear.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input images.
|
||||
return_latents (bool): Whether to return style latents. Default: False.
|
||||
return_rgb (bool): Whether return intermediate rgb images. Default: True.
|
||||
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
||||
"""
|
||||
conditions = []
|
||||
unet_skips = []
|
||||
out_rgbs = []
|
||||
|
||||
# encoder
|
||||
feat = self.conv_body_first(x)
|
||||
for i in range(self.log_size - 2):
|
||||
feat = self.conv_body_down[i](feat)
|
||||
unet_skips.insert(0, feat)
|
||||
|
||||
feat = self.final_conv(feat)
|
||||
|
||||
# style code
|
||||
style_code = self.final_linear(feat.view(feat.size(0), -1))
|
||||
if self.different_w:
|
||||
style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
|
||||
|
||||
# decode
|
||||
for i in range(self.log_size - 2):
|
||||
# add unet skip
|
||||
feat = feat + unet_skips[i]
|
||||
# ResUpLayer
|
||||
feat = self.conv_body_up[i](feat)
|
||||
# generate scale and shift for SFT layers
|
||||
scale = self.condition_scale[i](feat)
|
||||
conditions.append(scale.clone())
|
||||
shift = self.condition_shift[i](feat)
|
||||
conditions.append(shift.clone())
|
||||
# generate rgb images
|
||||
if return_rgb:
|
||||
out_rgbs.append(self.toRGB[i](feat))
|
||||
|
||||
# decoder
|
||||
image, _ = self.stylegan_decoder([style_code],
|
||||
conditions,
|
||||
return_latents=return_latents,
|
||||
input_is_latent=self.input_is_latent,
|
||||
randomize_noise=randomize_noise)
|
||||
|
||||
return image, out_rgbs
|
||||
@@ -10,18 +10,18 @@ from torch.nn import functional as F
|
||||
|
||||
|
||||
class StyleGAN2GeneratorSFT(StyleGAN2Generator):
|
||||
"""StyleGAN2 Generator.
|
||||
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
|
||||
|
||||
Args:
|
||||
out_size (int): The spatial size of outputs.
|
||||
num_style_feat (int): Channel number of style features. Default: 512.
|
||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
||||
channel_multiplier (int): Channel multiplier for large networks of
|
||||
StyleGAN2. Default: 2.
|
||||
resample_kernel (list[int]): A list indicating the 1D resample kernel
|
||||
magnitude. A cross production will be applied to extent 1D resample
|
||||
kernel to 2D resample kernel. Default: [1, 3, 3, 1].
|
||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
||||
resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
|
||||
applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1).
|
||||
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
||||
narrow (float): The narrow ratio for channels. Default: 1.
|
||||
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@@ -53,21 +53,18 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
|
||||
truncation_latent=None,
|
||||
inject_index=None,
|
||||
return_latents=False):
|
||||
"""Forward function for StyleGAN2Generator.
|
||||
"""Forward function for StyleGAN2GeneratorSFT.
|
||||
|
||||
Args:
|
||||
styles (list[Tensor]): Sample codes of styles.
|
||||
input_is_latent (bool): Whether input is latent style.
|
||||
Default: False.
|
||||
conditions (list[Tensor]): SFT conditions to generators.
|
||||
input_is_latent (bool): Whether input is latent style. Default: False.
|
||||
noise (Tensor | None): Input noise or None. Default: None.
|
||||
randomize_noise (bool): Randomize noise, used when 'noise' is
|
||||
False. Default: True.
|
||||
truncation (float): TODO. Default: 1.
|
||||
truncation_latent (Tensor | None): TODO. Default: None.
|
||||
inject_index (int | None): The injection index for mixing noise.
|
||||
Default: None.
|
||||
return_latents (bool): Whether to return style latents.
|
||||
Default: False.
|
||||
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
||||
truncation (float): The truncation ratio. Default: 1.
|
||||
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
|
||||
inject_index (int | None): The injection index for mixing noise. Default: None.
|
||||
return_latents (bool): Whether to return style latents. Default: False.
|
||||
"""
|
||||
# style codes -> latents with Style MLP layer
|
||||
if not input_is_latent:
|
||||
@@ -84,7 +81,7 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
|
||||
for style in styles:
|
||||
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
|
||||
styles = style_truncation
|
||||
# get style latent with injection
|
||||
# get style latents with injection
|
||||
if len(styles) == 1:
|
||||
inject_index = self.num_latent
|
||||
|
||||
@@ -113,15 +110,15 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
|
||||
# the conditions may have fewer levels
|
||||
if i < len(conditions):
|
||||
# SFT part to combine the conditions
|
||||
if self.sft_half:
|
||||
if self.sft_half: # only apply SFT to half of the channels
|
||||
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
|
||||
out_sft = out_sft * conditions[i - 1] + conditions[i]
|
||||
out = torch.cat([out_same, out_sft], dim=1)
|
||||
else:
|
||||
else: # apply SFT to all the channels
|
||||
out = out * conditions[i - 1] + conditions[i]
|
||||
|
||||
out = conv2(out, latent[:, i + 1], noise=noise2)
|
||||
skip = to_rgb(out, latent[:, i + 2], skip)
|
||||
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
|
||||
i += 2
|
||||
|
||||
image = skip
|
||||
@@ -133,17 +130,15 @@ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
|
||||
|
||||
|
||||
class ConvUpLayer(nn.Module):
|
||||
"""Conv Up Layer. Bilinear upsample + Conv.
|
||||
"""Convolutional upsampling layer. It uses bilinear upsampler + Conv.
|
||||
|
||||
Args:
|
||||
in_channels (int): Channel number of the input.
|
||||
out_channels (int): Channel number of the output.
|
||||
kernel_size (int): Size of the convolving kernel.
|
||||
stride (int): Stride of the convolution. Default: 1
|
||||
padding (int): Zero-padding added to both sides of the input.
|
||||
Default: 0.
|
||||
bias (bool): If ``True``, adds a learnable bias to the output.
|
||||
Default: ``True``.
|
||||
padding (int): Zero-padding added to both sides of the input. Default: 0.
|
||||
bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``.
|
||||
bias_init_val (float): Bias initialized value. Default: 0.
|
||||
activate (bool): Whether use activateion. Default: True.
|
||||
"""
|
||||
@@ -163,6 +158,7 @@ class ConvUpLayer(nn.Module):
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
# self.scale is used to scale the convolution weights, which is related to the common initializations.
|
||||
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
|
||||
|
||||
self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
|
||||
@@ -223,7 +219,26 @@ class ResUpBlock(nn.Module):
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class GFPGANv1(nn.Module):
|
||||
"""Unet + StyleGAN2 decoder with SFT."""
|
||||
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
|
||||
|
||||
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
|
||||
|
||||
Args:
|
||||
out_size (int): The spatial size of outputs.
|
||||
num_style_feat (int): Channel number of style features. Default: 512.
|
||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
||||
resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
|
||||
applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1).
|
||||
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
|
||||
fix_decoder (bool): Whether to fix the decoder. Default: True.
|
||||
|
||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
||||
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
||||
input_is_latent (bool): Whether input is latent style. Default: False.
|
||||
different_w (bool): Whether to use different latent w for different layers. Default: False.
|
||||
narrow (float): The narrow ratio for channels. Default: 1.
|
||||
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -246,7 +261,7 @@ class GFPGANv1(nn.Module):
|
||||
self.different_w = different_w
|
||||
self.num_style_feat = num_style_feat
|
||||
|
||||
unet_narrow = narrow * 0.5
|
||||
unet_narrow = narrow * 0.5 # by default, use a half of input channels
|
||||
channels = {
|
||||
'4': int(512 * unet_narrow),
|
||||
'8': int(512 * unet_narrow),
|
||||
@@ -295,6 +310,7 @@ class GFPGANv1(nn.Module):
|
||||
self.final_linear = EqualLinear(
|
||||
channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
|
||||
|
||||
# the decoder: stylegan2 generator with SFT modulations
|
||||
self.stylegan_decoder = StyleGAN2GeneratorSFT(
|
||||
out_size=out_size,
|
||||
num_style_feat=num_style_feat,
|
||||
@@ -305,14 +321,16 @@ class GFPGANv1(nn.Module):
|
||||
narrow=narrow,
|
||||
sft_half=sft_half)
|
||||
|
||||
# load pre-trained stylegan2 model if necessary
|
||||
if decoder_load_path:
|
||||
self.stylegan_decoder.load_state_dict(
|
||||
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
|
||||
# fix decoder without updating params
|
||||
if fix_decoder:
|
||||
for _, param in self.stylegan_decoder.named_parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# for SFT
|
||||
# for SFT modulations (scale and shift)
|
||||
self.condition_scale = nn.ModuleList()
|
||||
self.condition_shift = nn.ModuleList()
|
||||
for i in range(3, self.log_size + 1):
|
||||
@@ -332,13 +350,15 @@ class GFPGANv1(nn.Module):
|
||||
ScaledLeakyReLU(0.2),
|
||||
EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
|
||||
|
||||
def forward(self,
|
||||
x,
|
||||
return_latents=False,
|
||||
save_feat_path=None,
|
||||
load_feat_path=None,
|
||||
return_rgb=True,
|
||||
randomize_noise=True):
|
||||
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs):
|
||||
"""Forward function for GFPGANv1.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input images.
|
||||
return_latents (bool): Whether to return style latents. Default: False.
|
||||
return_rgb (bool): Whether return intermediate rgb images. Default: True.
|
||||
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
||||
"""
|
||||
conditions = []
|
||||
unet_skips = []
|
||||
out_rgbs = []
|
||||
@@ -362,7 +382,7 @@ class GFPGANv1(nn.Module):
|
||||
feat = feat + unet_skips[i]
|
||||
# ResUpLayer
|
||||
feat = self.conv_body_up[i](feat)
|
||||
# generate scale and shift for SFT layer
|
||||
# generate scale and shift for SFT layers
|
||||
scale = self.condition_scale[i](feat)
|
||||
conditions.append(scale.clone())
|
||||
shift = self.condition_shift[i](feat)
|
||||
@@ -371,12 +391,6 @@ class GFPGANv1(nn.Module):
|
||||
if return_rgb:
|
||||
out_rgbs.append(self.toRGB[i](feat))
|
||||
|
||||
if save_feat_path is not None:
|
||||
torch.save(conditions, save_feat_path)
|
||||
if load_feat_path is not None:
|
||||
conditions = torch.load(load_feat_path)
|
||||
conditions = [v.cuda() for v in conditions]
|
||||
|
||||
# decoder
|
||||
image, _ = self.stylegan_decoder([style_code],
|
||||
conditions,
|
||||
@@ -389,10 +403,12 @@ class GFPGANv1(nn.Module):
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class FacialComponentDiscriminator(nn.Module):
|
||||
"""Facial component (eyes, mouth, noise) discriminator used in GFPGAN.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(FacialComponentDiscriminator, self).__init__()
|
||||
|
||||
# It now uses a VGG-style architectrue with fixed model size
|
||||
self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
|
||||
self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
|
||||
self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
|
||||
@@ -400,7 +416,13 @@ class FacialComponentDiscriminator(nn.Module):
|
||||
self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
|
||||
self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
|
||||
|
||||
def forward(self, x, return_feats=False):
|
||||
def forward(self, x, return_feats=False, **kwargs):
|
||||
"""Forward function for FacialComponentDiscriminator.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input images.
|
||||
return_feats (bool): Whether to return intermediate features. Default: False.
|
||||
"""
|
||||
feat = self.conv1(x)
|
||||
feat = self.conv3(self.conv2(feat))
|
||||
rlt_feats = []
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import math
|
||||
import random
|
||||
import torch
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
@@ -8,14 +9,17 @@ from .stylegan2_clean_arch import StyleGAN2GeneratorClean
|
||||
|
||||
|
||||
class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
|
||||
"""StyleGAN2 Generator.
|
||||
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
|
||||
|
||||
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
|
||||
|
||||
Args:
|
||||
out_size (int): The spatial size of outputs.
|
||||
num_style_feat (int): Channel number of style features. Default: 512.
|
||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
||||
channel_multiplier (int): Channel multiplier for large networks of
|
||||
StyleGAN2. Default: 2.
|
||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
||||
narrow (float): The narrow ratio for channels. Default: 1.
|
||||
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
|
||||
@@ -25,7 +29,6 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
|
||||
num_mlp=num_mlp,
|
||||
channel_multiplier=channel_multiplier,
|
||||
narrow=narrow)
|
||||
|
||||
self.sft_half = sft_half
|
||||
|
||||
def forward(self,
|
||||
@@ -38,21 +41,18 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
|
||||
truncation_latent=None,
|
||||
inject_index=None,
|
||||
return_latents=False):
|
||||
"""Forward function for StyleGAN2Generator.
|
||||
"""Forward function for StyleGAN2GeneratorCSFT.
|
||||
|
||||
Args:
|
||||
styles (list[Tensor]): Sample codes of styles.
|
||||
input_is_latent (bool): Whether input is latent style.
|
||||
Default: False.
|
||||
conditions (list[Tensor]): SFT conditions to generators.
|
||||
input_is_latent (bool): Whether input is latent style. Default: False.
|
||||
noise (Tensor | None): Input noise or None. Default: None.
|
||||
randomize_noise (bool): Randomize noise, used when 'noise' is
|
||||
False. Default: True.
|
||||
truncation (float): TODO. Default: 1.
|
||||
truncation_latent (Tensor | None): TODO. Default: None.
|
||||
inject_index (int | None): The injection index for mixing noise.
|
||||
Default: None.
|
||||
return_latents (bool): Whether to return style latents.
|
||||
Default: False.
|
||||
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
||||
truncation (float): The truncation ratio. Default: 1.
|
||||
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
|
||||
inject_index (int | None): The injection index for mixing noise. Default: None.
|
||||
return_latents (bool): Whether to return style latents. Default: False.
|
||||
"""
|
||||
# style codes -> latents with Style MLP layer
|
||||
if not input_is_latent:
|
||||
@@ -69,7 +69,7 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
|
||||
for style in styles:
|
||||
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
|
||||
styles = style_truncation
|
||||
# get style latent with injection
|
||||
# get style latents with injection
|
||||
if len(styles) == 1:
|
||||
inject_index = self.num_latent
|
||||
|
||||
@@ -98,15 +98,15 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
|
||||
# the conditions may have fewer levels
|
||||
if i < len(conditions):
|
||||
# SFT part to combine the conditions
|
||||
if self.sft_half:
|
||||
if self.sft_half: # only apply SFT to half of the channels
|
||||
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
|
||||
out_sft = out_sft * conditions[i - 1] + conditions[i]
|
||||
out = torch.cat([out_same, out_sft], dim=1)
|
||||
else:
|
||||
else: # apply SFT to all the channels
|
||||
out = out * conditions[i - 1] + conditions[i]
|
||||
|
||||
out = conv2(out, latent[:, i + 1], noise=noise2)
|
||||
skip = to_rgb(out, latent[:, i + 2], skip)
|
||||
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
|
||||
i += 2
|
||||
|
||||
image = skip
|
||||
@@ -118,11 +118,12 @@ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
"""Residual block with upsampling/downsampling.
|
||||
"""Residual block with bilinear upsampling/downsampling.
|
||||
|
||||
Args:
|
||||
in_channels (int): Channel number of the input.
|
||||
out_channels (int): Channel number of the output.
|
||||
mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, mode='down'):
|
||||
@@ -148,8 +149,27 @@ class ResBlock(nn.Module):
|
||||
return out
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class GFPGANv1Clean(nn.Module):
|
||||
"""GFPGANv1 Clean version."""
|
||||
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
|
||||
|
||||
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
|
||||
|
||||
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
|
||||
|
||||
Args:
|
||||
out_size (int): The spatial size of outputs.
|
||||
num_style_feat (int): Channel number of style features. Default: 512.
|
||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
||||
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
|
||||
fix_decoder (bool): Whether to fix the decoder. Default: True.
|
||||
|
||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
||||
input_is_latent (bool): Whether input is latent style. Default: False.
|
||||
different_w (bool): Whether to use different latent w for different layers. Default: False.
|
||||
narrow (float): The narrow ratio for channels. Default: 1.
|
||||
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -170,7 +190,7 @@ class GFPGANv1Clean(nn.Module):
|
||||
self.different_w = different_w
|
||||
self.num_style_feat = num_style_feat
|
||||
|
||||
unet_narrow = narrow * 0.5
|
||||
unet_narrow = narrow * 0.5 # by default, use a half of input channels
|
||||
channels = {
|
||||
'4': int(512 * unet_narrow),
|
||||
'8': int(512 * unet_narrow),
|
||||
@@ -218,6 +238,7 @@ class GFPGANv1Clean(nn.Module):
|
||||
|
||||
self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
|
||||
|
||||
# the decoder: stylegan2 generator with SFT modulations
|
||||
self.stylegan_decoder = StyleGAN2GeneratorCSFT(
|
||||
out_size=out_size,
|
||||
num_style_feat=num_style_feat,
|
||||
@@ -226,14 +247,16 @@ class GFPGANv1Clean(nn.Module):
|
||||
narrow=narrow,
|
||||
sft_half=sft_half)
|
||||
|
||||
# load pre-trained stylegan2 model if necessary
|
||||
if decoder_load_path:
|
||||
self.stylegan_decoder.load_state_dict(
|
||||
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
|
||||
# fix decoder without updating params
|
||||
if fix_decoder:
|
||||
for _, param in self.stylegan_decoder.named_parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# for SFT
|
||||
# for SFT modulations (scale and shift)
|
||||
self.condition_scale = nn.ModuleList()
|
||||
self.condition_shift = nn.ModuleList()
|
||||
for i in range(3, self.log_size + 1):
|
||||
@@ -251,13 +274,15 @@ class GFPGANv1Clean(nn.Module):
|
||||
nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
|
||||
|
||||
def forward(self,
|
||||
x,
|
||||
return_latents=False,
|
||||
save_feat_path=None,
|
||||
load_feat_path=None,
|
||||
return_rgb=True,
|
||||
randomize_noise=True):
|
||||
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs):
|
||||
"""Forward function for GFPGANv1Clean.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input images.
|
||||
return_latents (bool): Whether to return style latents. Default: False.
|
||||
return_rgb (bool): Whether return intermediate rgb images. Default: True.
|
||||
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
||||
"""
|
||||
conditions = []
|
||||
unet_skips = []
|
||||
out_rgbs = []
|
||||
@@ -273,13 +298,14 @@ class GFPGANv1Clean(nn.Module):
|
||||
style_code = self.final_linear(feat.view(feat.size(0), -1))
|
||||
if self.different_w:
|
||||
style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
|
||||
|
||||
# decode
|
||||
for i in range(self.log_size - 2):
|
||||
# add unet skip
|
||||
feat = feat + unet_skips[i]
|
||||
# ResUpLayer
|
||||
feat = self.conv_body_up[i](feat)
|
||||
# generate scale and shift for SFT layer
|
||||
# generate scale and shift for SFT layers
|
||||
scale = self.condition_scale[i](feat)
|
||||
conditions.append(scale.clone())
|
||||
shift = self.condition_shift[i](feat)
|
||||
@@ -288,12 +314,6 @@ class GFPGANv1Clean(nn.Module):
|
||||
if return_rgb:
|
||||
out_rgbs.append(self.toRGB[i](feat))
|
||||
|
||||
if save_feat_path is not None:
|
||||
torch.save(conditions, save_feat_path)
|
||||
if load_feat_path is not None:
|
||||
conditions = torch.load(load_feat_path)
|
||||
conditions = [v.cuda() for v in conditions]
|
||||
|
||||
# decoder
|
||||
image, _ = self.stylegan_decoder([style_code],
|
||||
conditions,
|
||||
|
||||
658
gfpgan/archs/restoreformer_arch.py
Normal file
658
gfpgan/archs/restoreformer_arch.py
Normal file
@@ -0,0 +1,658 @@
|
||||
"""Modified from https://github.com/wzhouxiff/RestoreFormer
|
||||
"""
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class VectorQuantizer(nn.Module):
|
||||
"""
|
||||
see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
|
||||
____________________________________________
|
||||
Discretization bottleneck part of the VQ-VAE.
|
||||
Inputs:
|
||||
- n_e : number of embeddings
|
||||
- e_dim : dimension of embedding
|
||||
- beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
||||
_____________________________________________
|
||||
"""
|
||||
|
||||
def __init__(self, n_e, e_dim, beta):
|
||||
super(VectorQuantizer, self).__init__()
|
||||
self.n_e = n_e
|
||||
self.e_dim = e_dim
|
||||
self.beta = beta
|
||||
|
||||
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
||||
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
||||
|
||||
def forward(self, z):
|
||||
"""
|
||||
Inputs the output of the encoder network z and maps it to a discrete
|
||||
one-hot vector that is the index of the closest embedding vector e_j
|
||||
z (continuous) -> z_q (discrete)
|
||||
z.shape = (batch, channel, height, width)
|
||||
quantization pipeline:
|
||||
1. get encoder input (B,C,H,W)
|
||||
2. flatten input to (B*H*W,C)
|
||||
"""
|
||||
# reshape z -> (batch, height, width, channel) and flatten
|
||||
z = z.permute(0, 2, 3, 1).contiguous()
|
||||
z_flattened = z.view(-1, self.e_dim)
|
||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||
|
||||
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
|
||||
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
|
||||
torch.matmul(z_flattened, self.embedding.weight.t())
|
||||
|
||||
# could possible replace this here
|
||||
# #\start...
|
||||
# find closest encodings
|
||||
|
||||
min_value, min_encoding_indices = torch.min(d, dim=1)
|
||||
|
||||
min_encoding_indices = min_encoding_indices.unsqueeze(1)
|
||||
|
||||
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.n_e).to(z)
|
||||
min_encodings.scatter_(1, min_encoding_indices, 1)
|
||||
|
||||
# dtype min encodings: torch.float32
|
||||
# min_encodings shape: torch.Size([2048, 512])
|
||||
# min_encoding_indices.shape: torch.Size([2048, 1])
|
||||
|
||||
# get quantized latent vectors
|
||||
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
|
||||
# .........\end
|
||||
|
||||
# with:
|
||||
# .........\start
|
||||
# min_encoding_indices = torch.argmin(d, dim=1)
|
||||
# z_q = self.embedding(min_encoding_indices)
|
||||
# ......\end......... (TODO)
|
||||
|
||||
# compute loss for embedding
|
||||
loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)
|
||||
|
||||
# preserve gradients
|
||||
z_q = z + (z_q - z).detach()
|
||||
|
||||
# perplexity
|
||||
|
||||
e_mean = torch.mean(min_encodings, dim=0)
|
||||
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
|
||||
|
||||
# reshape back to match original input shape
|
||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
return z_q, loss, (perplexity, min_encodings, min_encoding_indices, d)
|
||||
|
||||
def get_codebook_entry(self, indices, shape):
|
||||
# shape specifying (batch, height, width, channel)
|
||||
# TODO: check for more easy handling with nn.Embedding
|
||||
min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
|
||||
min_encodings.scatter_(1, indices[:, None], 1)
|
||||
|
||||
# get quantized latent vectors
|
||||
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
|
||||
|
||||
if shape is not None:
|
||||
z_q = z_q.view(shape)
|
||||
|
||||
# reshape back to match original input shape
|
||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
return z_q
|
||||
|
||||
|
||||
# pytorch_diffusion + derived encoder decoder
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode='nearest')
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
|
||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class MultiHeadAttnBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, head_size=1):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.head_size = head_size
|
||||
self.att_size = in_channels // head_size
|
||||
assert (in_channels % head_size == 0), 'The size of head should be divided by the number of channels.'
|
||||
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.norm2 = Normalize(in_channels)
|
||||
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.num = 0
|
||||
|
||||
def forward(self, x, y=None):
|
||||
h_ = x
|
||||
h_ = self.norm1(h_)
|
||||
if y is None:
|
||||
y = h_
|
||||
else:
|
||||
y = self.norm2(y)
|
||||
|
||||
q = self.q(y)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, self.head_size, self.att_size, h * w)
|
||||
q = q.permute(0, 3, 1, 2) # b, hw, head, att
|
||||
|
||||
k = k.reshape(b, self.head_size, self.att_size, h * w)
|
||||
k = k.permute(0, 3, 1, 2)
|
||||
|
||||
v = v.reshape(b, self.head_size, self.att_size, h * w)
|
||||
v = v.permute(0, 3, 1, 2)
|
||||
|
||||
q = q.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
k = k.transpose(1, 2).transpose(2, 3)
|
||||
|
||||
scale = int(self.att_size)**(-0.5)
|
||||
q.mul_(scale)
|
||||
w_ = torch.matmul(q, k)
|
||||
w_ = F.softmax(w_, dim=3)
|
||||
|
||||
w_ = w_.matmul(v)
|
||||
|
||||
w_ = w_.transpose(1, 2).contiguous() # [b, h*w, head, att]
|
||||
w_ = w_.view(b, h, w, -1)
|
||||
w_ = w_.permute(0, 3, 1, 2)
|
||||
|
||||
w_ = self.proj_out(w_)
|
||||
|
||||
return x + w_
|
||||
|
||||
|
||||
class MultiHeadEncoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks=2,
|
||||
attn_resolutions=(16, ),
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels=3,
|
||||
resolution=512,
|
||||
z_channels=256,
|
||||
double_z=True,
|
||||
enable_mid=True,
|
||||
head_size=1,
|
||||
**ignore_kwargs):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.enable_mid = enable_mid
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1, ) + tuple(ch_mult)
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(MultiHeadAttnBlock(block_in, head_size))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
if self.enable_mid:
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
||||
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
hs = {}
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
h = self.conv_in(x)
|
||||
hs['in'] = h
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](h, temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
|
||||
if i_level != self.num_resolutions - 1:
|
||||
# hs.append(h)
|
||||
hs['block_' + str(i_level)] = h
|
||||
h = self.down[i_level].downsample(h)
|
||||
|
||||
# middle
|
||||
# h = hs[-1]
|
||||
if self.enable_mid:
|
||||
h = self.mid.block_1(h, temb)
|
||||
hs['block_' + str(i_level) + '_atten'] = h
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
hs['mid_atten'] = h
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
# hs.append(h)
|
||||
hs['out'] = h
|
||||
|
||||
return hs
|
||||
|
||||
|
||||
class MultiHeadDecoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks=2,
|
||||
attn_resolutions=(16, ),
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels=3,
|
||||
resolution=512,
|
||||
z_channels=256,
|
||||
give_pre_end=False,
|
||||
enable_mid=True,
|
||||
head_size=1,
|
||||
**ignorekwargs):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.enable_mid = enable_mid
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2**(self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print('Working with z of shape {} = {} dimensions.'.format(self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# middle
|
||||
if self.enable_mid:
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
||||
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(MultiHeadAttnBlock(block_in, head_size))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z):
|
||||
# assert z.shape[1:] == self.z_shape[1:]
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
if self.enable_mid:
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class MultiHeadDecoderTransformer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks=2,
|
||||
attn_resolutions=(16, ),
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels=3,
|
||||
resolution=512,
|
||||
z_channels=256,
|
||||
give_pre_end=False,
|
||||
enable_mid=True,
|
||||
head_size=1,
|
||||
**ignorekwargs):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.enable_mid = enable_mid
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2**(self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print('Working with z of shape {} = {} dimensions.'.format(self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# middle
|
||||
if self.enable_mid:
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
||||
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(MultiHeadAttnBlock(block_in, head_size))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z, hs):
|
||||
# assert z.shape[1:] == self.z_shape[1:]
|
||||
# self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
if self.enable_mid:
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h, hs['mid_atten'])
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, hs['block_' + str(i_level) + '_atten'])
|
||||
# hfeature = h.clone()
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class RestoreFormer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
n_embed=1024,
|
||||
embed_dim=256,
|
||||
ch=64,
|
||||
out_ch=3,
|
||||
ch_mult=(1, 2, 2, 4, 4, 8),
|
||||
num_res_blocks=2,
|
||||
attn_resolutions=(16, ),
|
||||
dropout=0.0,
|
||||
in_channels=3,
|
||||
resolution=512,
|
||||
z_channels=256,
|
||||
double_z=False,
|
||||
enable_mid=True,
|
||||
fix_decoder=False,
|
||||
fix_codebook=True,
|
||||
fix_encoder=False,
|
||||
head_size=8):
|
||||
super(RestoreFormer, self).__init__()
|
||||
|
||||
self.encoder = MultiHeadEncoder(
|
||||
ch=ch,
|
||||
out_ch=out_ch,
|
||||
ch_mult=ch_mult,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attn_resolutions=attn_resolutions,
|
||||
dropout=dropout,
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
z_channels=z_channels,
|
||||
double_z=double_z,
|
||||
enable_mid=enable_mid,
|
||||
head_size=head_size)
|
||||
self.decoder = MultiHeadDecoderTransformer(
|
||||
ch=ch,
|
||||
out_ch=out_ch,
|
||||
ch_mult=ch_mult,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attn_resolutions=attn_resolutions,
|
||||
dropout=dropout,
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
z_channels=z_channels,
|
||||
enable_mid=enable_mid,
|
||||
head_size=head_size)
|
||||
|
||||
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
|
||||
|
||||
self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
|
||||
|
||||
if fix_decoder:
|
||||
for _, param in self.decoder.named_parameters():
|
||||
param.requires_grad = False
|
||||
for _, param in self.post_quant_conv.named_parameters():
|
||||
param.requires_grad = False
|
||||
for _, param in self.quantize.named_parameters():
|
||||
param.requires_grad = False
|
||||
elif fix_codebook:
|
||||
for _, param in self.quantize.named_parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if fix_encoder:
|
||||
for _, param in self.encoder.named_parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def encode(self, x):
|
||||
|
||||
hs = self.encoder(x)
|
||||
h = self.quant_conv(hs['out'])
|
||||
quant, emb_loss, info = self.quantize(h)
|
||||
return quant, emb_loss, info, hs
|
||||
|
||||
def decode(self, quant, hs):
|
||||
quant = self.post_quant_conv(quant)
|
||||
dec = self.decoder(quant, hs)
|
||||
|
||||
return dec
|
||||
|
||||
def forward(self, input, **kwargs):
|
||||
quant, diff, info, hs = self.encode(input)
|
||||
dec = self.decode(quant, hs)
|
||||
|
||||
return dec, None
|
||||
613
gfpgan/archs/stylegan2_bilinear_arch.py
Normal file
613
gfpgan/archs/stylegan2_bilinear_arch.py
Normal file
@@ -0,0 +1,613 @@
|
||||
import math
|
||||
import random
|
||||
import torch
|
||||
from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class NormStyleCode(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
"""Normalize the style codes.
|
||||
|
||||
Args:
|
||||
x (Tensor): Style codes with shape (b, c).
|
||||
|
||||
Returns:
|
||||
Tensor: Normalized tensor.
|
||||
"""
|
||||
return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
|
||||
|
||||
|
||||
class EqualLinear(nn.Module):
|
||||
"""Equalized Linear as StyleGAN2.
|
||||
|
||||
Args:
|
||||
in_channels (int): Size of each sample.
|
||||
out_channels (int): Size of each output sample.
|
||||
bias (bool): If set to ``False``, the layer will not learn an additive
|
||||
bias. Default: ``True``.
|
||||
bias_init_val (float): Bias initialized value. Default: 0.
|
||||
lr_mul (float): Learning rate multiplier. Default: 1.
|
||||
activation (None | str): The activation after ``linear`` operation.
|
||||
Supported: 'fused_lrelu', None. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
|
||||
super(EqualLinear, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.lr_mul = lr_mul
|
||||
self.activation = activation
|
||||
if self.activation not in ['fused_lrelu', None]:
|
||||
raise ValueError(f'Wrong activation value in EqualLinear: {activation}'
|
||||
"Supported ones are: ['fused_lrelu', None].")
|
||||
self.scale = (1 / math.sqrt(in_channels)) * lr_mul
|
||||
|
||||
self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
def forward(self, x):
|
||||
if self.bias is None:
|
||||
bias = None
|
||||
else:
|
||||
bias = self.bias * self.lr_mul
|
||||
if self.activation == 'fused_lrelu':
|
||||
out = F.linear(x, self.weight * self.scale)
|
||||
out = fused_leaky_relu(out, bias)
|
||||
else:
|
||||
out = F.linear(x, self.weight * self.scale, bias=bias)
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
|
||||
f'out_channels={self.out_channels}, bias={self.bias is not None})')
|
||||
|
||||
|
||||
class ModulatedConv2d(nn.Module):
|
||||
"""Modulated Conv2d used in StyleGAN2.
|
||||
|
||||
There is no bias in ModulatedConv2d.
|
||||
|
||||
Args:
|
||||
in_channels (int): Channel number of the input.
|
||||
out_channels (int): Channel number of the output.
|
||||
kernel_size (int): Size of the convolving kernel.
|
||||
num_style_feat (int): Channel number of style features.
|
||||
demodulate (bool): Whether to demodulate in the conv layer.
|
||||
Default: True.
|
||||
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
|
||||
Default: None.
|
||||
eps (float): A value added to the denominator for numerical stability.
|
||||
Default: 1e-8.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
num_style_feat,
|
||||
demodulate=True,
|
||||
sample_mode=None,
|
||||
eps=1e-8,
|
||||
interpolation_mode='bilinear'):
|
||||
super(ModulatedConv2d, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.demodulate = demodulate
|
||||
self.sample_mode = sample_mode
|
||||
self.eps = eps
|
||||
self.interpolation_mode = interpolation_mode
|
||||
if self.interpolation_mode == 'nearest':
|
||||
self.align_corners = None
|
||||
else:
|
||||
self.align_corners = False
|
||||
|
||||
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
|
||||
# modulation inside each modulated conv
|
||||
self.modulation = EqualLinear(
|
||||
num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)
|
||||
|
||||
self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
|
||||
self.padding = kernel_size // 2
|
||||
|
||||
def forward(self, x, style):
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
x (Tensor): Tensor with shape (b, c, h, w).
|
||||
style (Tensor): Tensor with shape (b, num_style_feat).
|
||||
|
||||
Returns:
|
||||
Tensor: Modulated tensor after convolution.
|
||||
"""
|
||||
b, c, h, w = x.shape # c = c_in
|
||||
# weight modulation
|
||||
style = self.modulation(style).view(b, 1, c, 1, 1)
|
||||
# self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
|
||||
weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
|
||||
|
||||
if self.demodulate:
|
||||
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
|
||||
weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
|
||||
|
||||
weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
|
||||
|
||||
if self.sample_mode == 'upsample':
|
||||
x = F.interpolate(x, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
|
||||
elif self.sample_mode == 'downsample':
|
||||
x = F.interpolate(x, scale_factor=0.5, mode=self.interpolation_mode, align_corners=self.align_corners)
|
||||
|
||||
b, c, h, w = x.shape
|
||||
x = x.view(1, b * c, h, w)
|
||||
# weight: (b*c_out, c_in, k, k), groups=b
|
||||
out = F.conv2d(x, weight, padding=self.padding, groups=b)
|
||||
out = out.view(b, self.out_channels, *out.shape[2:4])
|
||||
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
|
||||
f'out_channels={self.out_channels}, '
|
||||
f'kernel_size={self.kernel_size}, '
|
||||
f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
|
||||
|
||||
|
||||
class StyleConv(nn.Module):
|
||||
"""Style conv.
|
||||
|
||||
Args:
|
||||
in_channels (int): Channel number of the input.
|
||||
out_channels (int): Channel number of the output.
|
||||
kernel_size (int): Size of the convolving kernel.
|
||||
num_style_feat (int): Channel number of style features.
|
||||
demodulate (bool): Whether demodulate in the conv layer. Default: True.
|
||||
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
num_style_feat,
|
||||
demodulate=True,
|
||||
sample_mode=None,
|
||||
interpolation_mode='bilinear'):
|
||||
super(StyleConv, self).__init__()
|
||||
self.modulated_conv = ModulatedConv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
num_style_feat,
|
||||
demodulate=demodulate,
|
||||
sample_mode=sample_mode,
|
||||
interpolation_mode=interpolation_mode)
|
||||
self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
|
||||
self.activate = FusedLeakyReLU(out_channels)
|
||||
|
||||
def forward(self, x, style, noise=None):
|
||||
# modulate
|
||||
out = self.modulated_conv(x, style)
|
||||
# noise injection
|
||||
if noise is None:
|
||||
b, _, h, w = out.shape
|
||||
noise = out.new_empty(b, 1, h, w).normal_()
|
||||
out = out + self.weight * noise
|
||||
# activation (with bias)
|
||||
out = self.activate(out)
|
||||
return out
|
||||
|
||||
|
||||
class ToRGB(nn.Module):
|
||||
"""To RGB from features.
|
||||
|
||||
Args:
|
||||
in_channels (int): Channel number of input.
|
||||
num_style_feat (int): Channel number of style features.
|
||||
upsample (bool): Whether to upsample. Default: True.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, num_style_feat, upsample=True, interpolation_mode='bilinear'):
|
||||
super(ToRGB, self).__init__()
|
||||
self.upsample = upsample
|
||||
self.interpolation_mode = interpolation_mode
|
||||
if self.interpolation_mode == 'nearest':
|
||||
self.align_corners = None
|
||||
else:
|
||||
self.align_corners = False
|
||||
self.modulated_conv = ModulatedConv2d(
|
||||
in_channels,
|
||||
3,
|
||||
kernel_size=1,
|
||||
num_style_feat=num_style_feat,
|
||||
demodulate=False,
|
||||
sample_mode=None,
|
||||
interpolation_mode=interpolation_mode)
|
||||
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
|
||||
|
||||
def forward(self, x, style, skip=None):
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
x (Tensor): Feature tensor with shape (b, c, h, w).
|
||||
style (Tensor): Tensor with shape (b, num_style_feat).
|
||||
skip (Tensor): Base/skip tensor. Default: None.
|
||||
|
||||
Returns:
|
||||
Tensor: RGB images.
|
||||
"""
|
||||
out = self.modulated_conv(x, style)
|
||||
out = out + self.bias
|
||||
if skip is not None:
|
||||
if self.upsample:
|
||||
skip = F.interpolate(
|
||||
skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
|
||||
out = out + skip
|
||||
return out
|
||||
|
||||
|
||||
class ConstantInput(nn.Module):
|
||||
"""Constant input.
|
||||
|
||||
Args:
|
||||
num_channel (int): Channel number of constant input.
|
||||
size (int): Spatial size of constant input.
|
||||
"""
|
||||
|
||||
def __init__(self, num_channel, size):
|
||||
super(ConstantInput, self).__init__()
|
||||
self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
|
||||
|
||||
def forward(self, batch):
|
||||
out = self.weight.repeat(batch, 1, 1, 1)
|
||||
return out
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class StyleGAN2GeneratorBilinear(nn.Module):
|
||||
"""StyleGAN2 Generator.
|
||||
|
||||
Args:
|
||||
out_size (int): The spatial size of outputs.
|
||||
num_style_feat (int): Channel number of style features. Default: 512.
|
||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
||||
channel_multiplier (int): Channel multiplier for large networks of
|
||||
StyleGAN2. Default: 2.
|
||||
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
|
||||
narrow (float): Narrow ratio for channels. Default: 1.0.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
out_size,
|
||||
num_style_feat=512,
|
||||
num_mlp=8,
|
||||
channel_multiplier=2,
|
||||
lr_mlp=0.01,
|
||||
narrow=1,
|
||||
interpolation_mode='bilinear'):
|
||||
super(StyleGAN2GeneratorBilinear, self).__init__()
|
||||
# Style MLP layers
|
||||
self.num_style_feat = num_style_feat
|
||||
style_mlp_layers = [NormStyleCode()]
|
||||
for i in range(num_mlp):
|
||||
style_mlp_layers.append(
|
||||
EqualLinear(
|
||||
num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
|
||||
activation='fused_lrelu'))
|
||||
self.style_mlp = nn.Sequential(*style_mlp_layers)
|
||||
|
||||
channels = {
|
||||
'4': int(512 * narrow),
|
||||
'8': int(512 * narrow),
|
||||
'16': int(512 * narrow),
|
||||
'32': int(512 * narrow),
|
||||
'64': int(256 * channel_multiplier * narrow),
|
||||
'128': int(128 * channel_multiplier * narrow),
|
||||
'256': int(64 * channel_multiplier * narrow),
|
||||
'512': int(32 * channel_multiplier * narrow),
|
||||
'1024': int(16 * channel_multiplier * narrow)
|
||||
}
|
||||
self.channels = channels
|
||||
|
||||
self.constant_input = ConstantInput(channels['4'], size=4)
|
||||
self.style_conv1 = StyleConv(
|
||||
channels['4'],
|
||||
channels['4'],
|
||||
kernel_size=3,
|
||||
num_style_feat=num_style_feat,
|
||||
demodulate=True,
|
||||
sample_mode=None,
|
||||
interpolation_mode=interpolation_mode)
|
||||
self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, interpolation_mode=interpolation_mode)
|
||||
|
||||
self.log_size = int(math.log(out_size, 2))
|
||||
self.num_layers = (self.log_size - 2) * 2 + 1
|
||||
self.num_latent = self.log_size * 2 - 2
|
||||
|
||||
self.style_convs = nn.ModuleList()
|
||||
self.to_rgbs = nn.ModuleList()
|
||||
self.noises = nn.Module()
|
||||
|
||||
in_channels = channels['4']
|
||||
# noise
|
||||
for layer_idx in range(self.num_layers):
|
||||
resolution = 2**((layer_idx + 5) // 2)
|
||||
shape = [1, 1, resolution, resolution]
|
||||
self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
|
||||
# style convs and to_rgbs
|
||||
for i in range(3, self.log_size + 1):
|
||||
out_channels = channels[f'{2**i}']
|
||||
self.style_convs.append(
|
||||
StyleConv(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
num_style_feat=num_style_feat,
|
||||
demodulate=True,
|
||||
sample_mode='upsample',
|
||||
interpolation_mode=interpolation_mode))
|
||||
self.style_convs.append(
|
||||
StyleConv(
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
num_style_feat=num_style_feat,
|
||||
demodulate=True,
|
||||
sample_mode=None,
|
||||
interpolation_mode=interpolation_mode))
|
||||
self.to_rgbs.append(
|
||||
ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode))
|
||||
in_channels = out_channels
|
||||
|
||||
def make_noise(self):
|
||||
"""Make noise for noise injection."""
|
||||
device = self.constant_input.weight.device
|
||||
noises = [torch.randn(1, 1, 4, 4, device=device)]
|
||||
|
||||
for i in range(3, self.log_size + 1):
|
||||
for _ in range(2):
|
||||
noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
|
||||
|
||||
return noises
|
||||
|
||||
def get_latent(self, x):
|
||||
return self.style_mlp(x)
|
||||
|
||||
def mean_latent(self, num_latent):
|
||||
latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
|
||||
latent = self.style_mlp(latent_in).mean(0, keepdim=True)
|
||||
return latent
|
||||
|
||||
def forward(self,
|
||||
styles,
|
||||
input_is_latent=False,
|
||||
noise=None,
|
||||
randomize_noise=True,
|
||||
truncation=1,
|
||||
truncation_latent=None,
|
||||
inject_index=None,
|
||||
return_latents=False):
|
||||
"""Forward function for StyleGAN2Generator.
|
||||
|
||||
Args:
|
||||
styles (list[Tensor]): Sample codes of styles.
|
||||
input_is_latent (bool): Whether input is latent style.
|
||||
Default: False.
|
||||
noise (Tensor | None): Input noise or None. Default: None.
|
||||
randomize_noise (bool): Randomize noise, used when 'noise' is
|
||||
False. Default: True.
|
||||
truncation (float): TODO. Default: 1.
|
||||
truncation_latent (Tensor | None): TODO. Default: None.
|
||||
inject_index (int | None): The injection index for mixing noise.
|
||||
Default: None.
|
||||
return_latents (bool): Whether to return style latents.
|
||||
Default: False.
|
||||
"""
|
||||
# style codes -> latents with Style MLP layer
|
||||
if not input_is_latent:
|
||||
styles = [self.style_mlp(s) for s in styles]
|
||||
# noises
|
||||
if noise is None:
|
||||
if randomize_noise:
|
||||
noise = [None] * self.num_layers # for each style conv layer
|
||||
else: # use the stored noise
|
||||
noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
|
||||
# style truncation
|
||||
if truncation < 1:
|
||||
style_truncation = []
|
||||
for style in styles:
|
||||
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
|
||||
styles = style_truncation
|
||||
# get style latent with injection
|
||||
if len(styles) == 1:
|
||||
inject_index = self.num_latent
|
||||
|
||||
if styles[0].ndim < 3:
|
||||
# repeat latent code for all the layers
|
||||
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
||||
else: # used for encoder with different latent code for each layer
|
||||
latent = styles[0]
|
||||
elif len(styles) == 2: # mixing noises
|
||||
if inject_index is None:
|
||||
inject_index = random.randint(1, self.num_latent - 1)
|
||||
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
||||
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
|
||||
latent = torch.cat([latent1, latent2], 1)
|
||||
|
||||
# main generation
|
||||
out = self.constant_input(latent.shape[0])
|
||||
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
|
||||
skip = self.to_rgb1(out, latent[:, 1])
|
||||
|
||||
i = 1
|
||||
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
|
||||
noise[2::2], self.to_rgbs):
|
||||
out = conv1(out, latent[:, i], noise=noise1)
|
||||
out = conv2(out, latent[:, i + 1], noise=noise2)
|
||||
skip = to_rgb(out, latent[:, i + 2], skip)
|
||||
i += 2
|
||||
|
||||
image = skip
|
||||
|
||||
if return_latents:
|
||||
return image, latent
|
||||
else:
|
||||
return image, None
|
||||
|
||||
|
||||
class ScaledLeakyReLU(nn.Module):
|
||||
"""Scaled LeakyReLU.
|
||||
|
||||
Args:
|
||||
negative_slope (float): Negative slope. Default: 0.2.
|
||||
"""
|
||||
|
||||
def __init__(self, negative_slope=0.2):
|
||||
super(ScaledLeakyReLU, self).__init__()
|
||||
self.negative_slope = negative_slope
|
||||
|
||||
def forward(self, x):
|
||||
out = F.leaky_relu(x, negative_slope=self.negative_slope)
|
||||
return out * math.sqrt(2)
|
||||
|
||||
|
||||
class EqualConv2d(nn.Module):
|
||||
"""Equalized Linear as StyleGAN2.
|
||||
|
||||
Args:
|
||||
in_channels (int): Channel number of the input.
|
||||
out_channels (int): Channel number of the output.
|
||||
kernel_size (int): Size of the convolving kernel.
|
||||
stride (int): Stride of the convolution. Default: 1
|
||||
padding (int): Zero-padding added to both sides of the input.
|
||||
Default: 0.
|
||||
bias (bool): If ``True``, adds a learnable bias to the output.
|
||||
Default: ``True``.
|
||||
bias_init_val (float): Bias initialized value. Default: 0.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
|
||||
super(EqualConv2d, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
|
||||
|
||||
self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.conv2d(
|
||||
x,
|
||||
self.weight * self.scale,
|
||||
bias=self.bias,
|
||||
stride=self.stride,
|
||||
padding=self.padding,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
|
||||
f'out_channels={self.out_channels}, '
|
||||
f'kernel_size={self.kernel_size},'
|
||||
f' stride={self.stride}, padding={self.padding}, '
|
||||
f'bias={self.bias is not None})')
|
||||
|
||||
|
||||
class ConvLayer(nn.Sequential):
|
||||
"""Conv Layer used in StyleGAN2 Discriminator.
|
||||
|
||||
Args:
|
||||
in_channels (int): Channel number of the input.
|
||||
out_channels (int): Channel number of the output.
|
||||
kernel_size (int): Kernel size.
|
||||
downsample (bool): Whether downsample by a factor of 2.
|
||||
Default: False.
|
||||
bias (bool): Whether with bias. Default: True.
|
||||
activate (bool): Whether use activateion. Default: True.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
downsample=False,
|
||||
bias=True,
|
||||
activate=True,
|
||||
interpolation_mode='bilinear'):
|
||||
layers = []
|
||||
self.interpolation_mode = interpolation_mode
|
||||
# downsample
|
||||
if downsample:
|
||||
if self.interpolation_mode == 'nearest':
|
||||
self.align_corners = None
|
||||
else:
|
||||
self.align_corners = False
|
||||
|
||||
layers.append(
|
||||
torch.nn.Upsample(scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners))
|
||||
stride = 1
|
||||
self.padding = kernel_size // 2
|
||||
# conv
|
||||
layers.append(
|
||||
EqualConv2d(
|
||||
in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
|
||||
and not activate))
|
||||
# activation
|
||||
if activate:
|
||||
if bias:
|
||||
layers.append(FusedLeakyReLU(out_channels))
|
||||
else:
|
||||
layers.append(ScaledLeakyReLU(0.2))
|
||||
|
||||
super(ConvLayer, self).__init__(*layers)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
"""Residual block used in StyleGAN2 Discriminator.
|
||||
|
||||
Args:
|
||||
in_channels (int): Channel number of the input.
|
||||
out_channels (int): Channel number of the output.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'):
|
||||
super(ResBlock, self).__init__()
|
||||
|
||||
self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
|
||||
self.conv2 = ConvLayer(
|
||||
in_channels,
|
||||
out_channels,
|
||||
3,
|
||||
downsample=True,
|
||||
interpolation_mode=interpolation_mode,
|
||||
bias=True,
|
||||
activate=True)
|
||||
self.skip = ConvLayer(
|
||||
in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
downsample=True,
|
||||
interpolation_mode=interpolation_mode,
|
||||
bias=False,
|
||||
activate=False)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.conv2(out)
|
||||
skip = self.skip(x)
|
||||
out = (out + skip) / math.sqrt(2)
|
||||
return out
|
||||
@@ -31,12 +31,9 @@ class ModulatedConv2d(nn.Module):
|
||||
out_channels (int): Channel number of the output.
|
||||
kernel_size (int): Size of the convolving kernel.
|
||||
num_style_feat (int): Channel number of style features.
|
||||
demodulate (bool): Whether to demodulate in the conv layer.
|
||||
Default: True.
|
||||
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
|
||||
Default: None.
|
||||
eps (float): A value added to the denominator for numerical stability.
|
||||
Default: 1e-8.
|
||||
demodulate (bool): Whether to demodulate in the conv layer. Default: True.
|
||||
sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
|
||||
eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@@ -87,6 +84,7 @@ class ModulatedConv2d(nn.Module):
|
||||
|
||||
weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
|
||||
|
||||
# upsample or downsample if necessary
|
||||
if self.sample_mode == 'upsample':
|
||||
x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
|
||||
elif self.sample_mode == 'downsample':
|
||||
@@ -101,14 +99,12 @@ class ModulatedConv2d(nn.Module):
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
|
||||
f'out_channels={self.out_channels}, '
|
||||
f'kernel_size={self.kernel_size}, '
|
||||
f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
|
||||
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
|
||||
f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
|
||||
|
||||
|
||||
class StyleConv(nn.Module):
|
||||
"""Style conv.
|
||||
"""Style conv used in StyleGAN2.
|
||||
|
||||
Args:
|
||||
in_channels (int): Channel number of the input.
|
||||
@@ -116,8 +112,7 @@ class StyleConv(nn.Module):
|
||||
kernel_size (int): Size of the convolving kernel.
|
||||
num_style_feat (int): Channel number of style features.
|
||||
demodulate (bool): Whether demodulate in the conv layer. Default: True.
|
||||
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
|
||||
Default: None.
|
||||
sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
|
||||
@@ -144,7 +139,7 @@ class StyleConv(nn.Module):
|
||||
|
||||
|
||||
class ToRGB(nn.Module):
|
||||
"""To RGB from features.
|
||||
"""To RGB (image space) from features.
|
||||
|
||||
Args:
|
||||
in_channels (int): Channel number of input.
|
||||
@@ -204,8 +199,7 @@ class StyleGAN2GeneratorClean(nn.Module):
|
||||
out_size (int): The spatial size of outputs.
|
||||
num_style_feat (int): Channel number of style features. Default: 512.
|
||||
num_mlp (int): Layer number of MLP style layers. Default: 8.
|
||||
channel_multiplier (int): Channel multiplier for large networks of
|
||||
StyleGAN2. Default: 2.
|
||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
||||
narrow (float): Narrow ratio for channels. Default: 1.0.
|
||||
"""
|
||||
|
||||
@@ -222,6 +216,7 @@ class StyleGAN2GeneratorClean(nn.Module):
|
||||
# initialization
|
||||
default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
|
||||
|
||||
# channel list
|
||||
channels = {
|
||||
'4': int(512 * narrow),
|
||||
'8': int(512 * narrow),
|
||||
@@ -309,21 +304,17 @@ class StyleGAN2GeneratorClean(nn.Module):
|
||||
truncation_latent=None,
|
||||
inject_index=None,
|
||||
return_latents=False):
|
||||
"""Forward function for StyleGAN2Generator.
|
||||
"""Forward function for StyleGAN2GeneratorClean.
|
||||
|
||||
Args:
|
||||
styles (list[Tensor]): Sample codes of styles.
|
||||
input_is_latent (bool): Whether input is latent style.
|
||||
Default: False.
|
||||
input_is_latent (bool): Whether input is latent style. Default: False.
|
||||
noise (Tensor | None): Input noise or None. Default: None.
|
||||
randomize_noise (bool): Randomize noise, used when 'noise' is
|
||||
False. Default: True.
|
||||
truncation (float): TODO. Default: 1.
|
||||
truncation_latent (Tensor | None): TODO. Default: None.
|
||||
inject_index (int | None): The injection index for mixing noise.
|
||||
Default: None.
|
||||
return_latents (bool): Whether to return style latents.
|
||||
Default: False.
|
||||
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
|
||||
truncation (float): The truncation ratio. Default: 1.
|
||||
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
|
||||
inject_index (int | None): The injection index for mixing noise. Default: None.
|
||||
return_latents (bool): Whether to return style latents. Default: False.
|
||||
"""
|
||||
# style codes -> latents with Style MLP layer
|
||||
if not input_is_latent:
|
||||
@@ -340,7 +331,7 @@ class StyleGAN2GeneratorClean(nn.Module):
|
||||
for style in styles:
|
||||
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
|
||||
styles = style_truncation
|
||||
# get style latent with injection
|
||||
# get style latents with injection
|
||||
if len(styles) == 1:
|
||||
inject_index = self.num_latent
|
||||
|
||||
@@ -366,7 +357,7 @@ class StyleGAN2GeneratorClean(nn.Module):
|
||||
noise[2::2], self.to_rgbs):
|
||||
out = conv1(out, latent[:, i], noise=noise1)
|
||||
out = conv2(out, latent[:, i + 1], noise=noise2)
|
||||
skip = to_rgb(out, latent[:, i + 2], skip)
|
||||
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
|
||||
i += 2
|
||||
|
||||
image = skip
|
||||
|
||||
@@ -15,6 +15,19 @@ from torchvision.transforms.functional import (adjust_brightness, adjust_contras
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class FFHQDegradationDataset(data.Dataset):
|
||||
"""FFHQ dataset for GFPGAN.
|
||||
|
||||
It reads high resolution images, and then generate low-quality (LQ) images on-the-fly.
|
||||
|
||||
Args:
|
||||
opt (dict): Config for train datasets. It contains the following keys:
|
||||
dataroot_gt (str): Data root path for gt.
|
||||
io_backend (dict): IO backend type and other kwarg.
|
||||
mean (list | tuple): Image mean.
|
||||
std (list | tuple): Image std.
|
||||
use_hflip (bool): Whether to horizontally flip.
|
||||
Please see more options in the codes.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
super(FFHQDegradationDataset, self).__init__()
|
||||
@@ -29,11 +42,13 @@ class FFHQDegradationDataset(data.Dataset):
|
||||
self.out_size = opt['out_size']
|
||||
|
||||
self.crop_components = opt.get('crop_components', False) # facial components
|
||||
self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)
|
||||
self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1) # whether enlarge eye regions
|
||||
|
||||
if self.crop_components:
|
||||
# load component list from a pre-process pth files
|
||||
self.components_list = torch.load(opt.get('component_path'))
|
||||
|
||||
# file client (lmdb io backend)
|
||||
if self.io_backend_opt['type'] == 'lmdb':
|
||||
self.io_backend_opt['db_paths'] = self.gt_folder
|
||||
if not self.gt_folder.endswith('.lmdb'):
|
||||
@@ -41,9 +56,10 @@ class FFHQDegradationDataset(data.Dataset):
|
||||
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
|
||||
self.paths = [line.split('.')[0] for line in fin]
|
||||
else:
|
||||
# disk backend: scan file list from a folder
|
||||
self.paths = paths_from_folder(self.gt_folder)
|
||||
|
||||
# degradations
|
||||
# degradation configurations
|
||||
self.blur_kernel_size = opt['blur_kernel_size']
|
||||
self.kernel_list = opt['kernel_list']
|
||||
self.kernel_prob = opt['kernel_prob']
|
||||
@@ -60,22 +76,20 @@ class FFHQDegradationDataset(data.Dataset):
|
||||
self.gray_prob = opt.get('gray_prob')
|
||||
|
||||
logger = get_root_logger()
|
||||
logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '
|
||||
f'sigma: [{", ".join(map(str, self.blur_sigma))}]')
|
||||
logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
|
||||
logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
|
||||
logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
|
||||
logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
|
||||
|
||||
if self.color_jitter_prob is not None:
|
||||
logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, '
|
||||
f'shift: {self.color_jitter_shift}')
|
||||
logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
|
||||
if self.gray_prob is not None:
|
||||
logger.info(f'Use random gray. Prob: {self.gray_prob}')
|
||||
|
||||
self.color_jitter_shift /= 255.
|
||||
|
||||
@staticmethod
|
||||
def color_jitter(img, shift):
|
||||
"""jitter color: randomly jitter the RGB values, in numpy formats"""
|
||||
jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
|
||||
img = img + jitter_val
|
||||
img = np.clip(img, 0, 1)
|
||||
@@ -83,6 +97,7 @@ class FFHQDegradationDataset(data.Dataset):
|
||||
|
||||
@staticmethod
|
||||
def color_jitter_pt(img, brightness, contrast, saturation, hue):
|
||||
"""jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
|
||||
fn_idx = torch.randperm(4)
|
||||
for fn_id in fn_idx:
|
||||
if fn_id == 0 and brightness is not None:
|
||||
@@ -103,6 +118,7 @@ class FFHQDegradationDataset(data.Dataset):
|
||||
return img
|
||||
|
||||
def get_component_coordinates(self, index, status):
|
||||
"""Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file"""
|
||||
components_bbox = self.components_list[f'{index:08d}']
|
||||
if status[0]: # hflip
|
||||
# exchange right and left eye
|
||||
@@ -131,6 +147,7 @@ class FFHQDegradationDataset(data.Dataset):
|
||||
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
|
||||
|
||||
# load gt image
|
||||
# Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
|
||||
gt_path = self.paths[index]
|
||||
img_bytes = self.file_client.get(gt_path)
|
||||
img_gt = imfrombytes(img_bytes, float32=True)
|
||||
@@ -139,6 +156,7 @@ class FFHQDegradationDataset(data.Dataset):
|
||||
img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
|
||||
h, w, _ = img_gt.shape
|
||||
|
||||
# get facial component coordinates
|
||||
if self.crop_components:
|
||||
locations = self.get_component_coordinates(index, status)
|
||||
loc_left_eye, loc_right_eye, loc_mouth = locations
|
||||
@@ -173,9 +191,9 @@ class FFHQDegradationDataset(data.Dataset):
|
||||
if self.gray_prob and np.random.uniform() < self.gray_prob:
|
||||
img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
|
||||
img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
|
||||
if self.opt.get('gt_gray'):
|
||||
if self.opt.get('gt_gray'): # whether convert GT to gray images
|
||||
img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
|
||||
img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])
|
||||
img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) # repeat the color channels
|
||||
|
||||
# BGR to RGB, HWC to CHW, numpy to tensor
|
||||
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
|
||||
|
||||
@@ -3,7 +3,7 @@ import os.path as osp
|
||||
import torch
|
||||
from basicsr.archs import build_network
|
||||
from basicsr.losses import build_loss
|
||||
from basicsr.losses.losses import r1_penalty
|
||||
from basicsr.losses.gan_loss import r1_penalty
|
||||
from basicsr.metrics import calculate_metric
|
||||
from basicsr.models.base_model import BaseModel
|
||||
from basicsr.utils import get_root_logger, imwrite, tensor2img
|
||||
@@ -16,11 +16,11 @@ from tqdm import tqdm
|
||||
|
||||
@MODEL_REGISTRY.register()
|
||||
class GFPGANModel(BaseModel):
|
||||
"""GFPGAN model for <Towards real-world blind face restoratin with generative facial prior>"""
|
||||
"""The GFPGAN model for Towards real-world blind face restoratin with generative facial prior"""
|
||||
|
||||
def __init__(self, opt):
|
||||
super(GFPGANModel, self).__init__(opt)
|
||||
self.idx = 0
|
||||
self.idx = 0 # it is used for saving data for check
|
||||
|
||||
# define network
|
||||
self.net_g = build_network(opt['network_g'])
|
||||
@@ -51,8 +51,7 @@ class GFPGANModel(BaseModel):
|
||||
self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
|
||||
|
||||
# ----------- define net_g with Exponential Moving Average (EMA) ----------- #
|
||||
# net_g_ema only used for testing on one GPU and saving
|
||||
# There is no need to wrap with DistributedDataParallel
|
||||
# net_g_ema only used for testing on one GPU and saving. There is no need to wrap with DistributedDataParallel
|
||||
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
|
||||
# load pretrained model
|
||||
load_path = self.opt['path'].get('pretrain_network_g', None)
|
||||
@@ -65,7 +64,7 @@ class GFPGANModel(BaseModel):
|
||||
self.net_d.train()
|
||||
self.net_g_ema.eval()
|
||||
|
||||
# ----------- facial components networks ----------- #
|
||||
# ----------- facial component networks ----------- #
|
||||
if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt):
|
||||
self.use_facial_disc = True
|
||||
else:
|
||||
@@ -102,17 +101,19 @@ class GFPGANModel(BaseModel):
|
||||
self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)
|
||||
|
||||
# ----------- define losses ----------- #
|
||||
# pixel loss
|
||||
if train_opt.get('pixel_opt'):
|
||||
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
|
||||
else:
|
||||
self.cri_pix = None
|
||||
|
||||
# perceptual loss
|
||||
if train_opt.get('perceptual_opt'):
|
||||
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
|
||||
else:
|
||||
self.cri_perceptual = None
|
||||
|
||||
# L1 loss used in pyramid loss, component style loss and identity loss
|
||||
# L1 loss is used in pyramid loss, component style loss and identity loss
|
||||
self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)
|
||||
|
||||
# gan loss (wgan)
|
||||
@@ -179,6 +180,7 @@ class GFPGANModel(BaseModel):
|
||||
self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
|
||||
self.optimizers.append(self.optimizer_d)
|
||||
|
||||
# ----------- optimizers for facial component networks ----------- #
|
||||
if self.use_facial_disc:
|
||||
# setup optimizers for facial component discriminators
|
||||
optim_type = train_opt['optim_component'].pop('type')
|
||||
@@ -207,20 +209,21 @@ class GFPGANModel(BaseModel):
|
||||
self.loc_right_eyes = data['loc_right_eye']
|
||||
self.loc_mouths = data['loc_mouth']
|
||||
|
||||
# uncomment to check data
|
||||
# import torchvision
|
||||
# if self.opt['rank'] == 0:
|
||||
# import os
|
||||
# os.makedirs('tmp/gt', exist_ok=True)
|
||||
# os.makedirs('tmp/lq', exist_ok=True)
|
||||
# print(self.idx)
|
||||
# torchvision.utils.save_image(
|
||||
# self.gt, f'tmp/gt/gt_{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
|
||||
# torchvision.utils.save_image(
|
||||
# self.lq, f'tmp/lq/lq{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
|
||||
# self.idx = self.idx + 1
|
||||
# uncomment to check data
|
||||
# import torchvision
|
||||
# if self.opt['rank'] == 0:
|
||||
# import os
|
||||
# os.makedirs('tmp/gt', exist_ok=True)
|
||||
# os.makedirs('tmp/lq', exist_ok=True)
|
||||
# print(self.idx)
|
||||
# torchvision.utils.save_image(
|
||||
# self.gt, f'tmp/gt/gt_{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
|
||||
# torchvision.utils.save_image(
|
||||
# self.lq, f'tmp/lq/lq{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
|
||||
# self.idx = self.idx + 1
|
||||
|
||||
def construct_img_pyramid(self):
|
||||
"""Construct image pyramid for intermediate restoration loss"""
|
||||
pyramid_gt = [self.gt]
|
||||
down_img = self.gt
|
||||
for _ in range(0, self.log_size - 3):
|
||||
@@ -229,7 +232,6 @@ class GFPGANModel(BaseModel):
|
||||
return pyramid_gt
|
||||
|
||||
def get_roi_regions(self, eye_out_size=80, mouth_out_size=120):
|
||||
# hard code
|
||||
face_ratio = int(self.opt['network_g']['out_size'] / 512)
|
||||
eye_out_size *= face_ratio
|
||||
mouth_out_size *= face_ratio
|
||||
@@ -288,6 +290,7 @@ class GFPGANModel(BaseModel):
|
||||
p.requires_grad = False
|
||||
self.optimizer_g.zero_grad()
|
||||
|
||||
# do not update facial component net_d
|
||||
if self.use_facial_disc:
|
||||
for p in self.net_d_left_eye.parameters():
|
||||
p.requires_grad = False
|
||||
@@ -297,10 +300,9 @@ class GFPGANModel(BaseModel):
|
||||
p.requires_grad = False
|
||||
|
||||
# image pyramid loss weight
|
||||
if current_iter < self.opt['train'].get('remove_pyramid_loss', float('inf')):
|
||||
pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 1)
|
||||
else:
|
||||
pyramid_loss_weight = 1e-12 # very small loss
|
||||
pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 0)
|
||||
if pyramid_loss_weight > 0 and current_iter > self.opt['train'].get('remove_pyramid_loss', float('inf')):
|
||||
pyramid_loss_weight = 1e-12 # very small weight to avoid unused param error
|
||||
if pyramid_loss_weight > 0:
|
||||
self.output, out_rgbs = self.net_g(self.lq, return_rgb=True)
|
||||
pyramid_gt = self.construct_img_pyramid()
|
||||
@@ -419,11 +421,12 @@ class GFPGANModel(BaseModel):
|
||||
real_d_pred = self.net_d(self.gt)
|
||||
l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True)
|
||||
loss_dict['l_d'] = l_d
|
||||
# In wgan, real_score should be positive and fake_score should benegative
|
||||
# In WGAN, real_score should be positive and fake_score should be negative
|
||||
loss_dict['real_score'] = real_d_pred.detach().mean()
|
||||
loss_dict['fake_score'] = fake_d_pred.detach().mean()
|
||||
l_d.backward()
|
||||
|
||||
# regularization loss
|
||||
if current_iter % self.net_d_reg_every == 0:
|
||||
self.gt.requires_grad = True
|
||||
real_pred = self.net_d(self.gt)
|
||||
@@ -434,8 +437,9 @@ class GFPGANModel(BaseModel):
|
||||
|
||||
self.optimizer_d.step()
|
||||
|
||||
# optimize facial component discriminators
|
||||
if self.use_facial_disc:
|
||||
# lefe eye
|
||||
# left eye
|
||||
fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach())
|
||||
real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt)
|
||||
l_d_left_eye = self.cri_component(
|
||||
@@ -485,22 +489,32 @@ class GFPGANModel(BaseModel):
|
||||
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
||||
dataset_name = dataloader.dataset.opt['name']
|
||||
with_metrics = self.opt['val'].get('metrics') is not None
|
||||
use_pbar = self.opt['val'].get('pbar', False)
|
||||
|
||||
if with_metrics:
|
||||
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
|
||||
pbar = tqdm(total=len(dataloader), unit='image')
|
||||
if not hasattr(self, 'metric_results'): # only execute in the first run
|
||||
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
|
||||
# initialize the best metric results for each dataset_name (supporting multiple validation datasets)
|
||||
self._initialize_best_metric_results(dataset_name)
|
||||
# zero self.metric_results
|
||||
self.metric_results = {metric: 0 for metric in self.metric_results}
|
||||
|
||||
metric_data = dict()
|
||||
if use_pbar:
|
||||
pbar = tqdm(total=len(dataloader), unit='image')
|
||||
|
||||
for idx, val_data in enumerate(dataloader):
|
||||
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
|
||||
self.feed_data(val_data)
|
||||
self.test()
|
||||
|
||||
visuals = self.get_current_visuals()
|
||||
sr_img = tensor2img([visuals['sr']], min_max=(-1, 1))
|
||||
gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
|
||||
|
||||
if 'gt' in visuals:
|
||||
gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
|
||||
sr_img = tensor2img(self.output.detach().cpu(), min_max=(-1, 1))
|
||||
metric_data['img'] = sr_img
|
||||
if hasattr(self, 'gt'):
|
||||
gt_img = tensor2img(self.gt.detach().cpu(), min_max=(-1, 1))
|
||||
metric_data['img2'] = gt_img
|
||||
del self.gt
|
||||
|
||||
# tentative for out of GPU memory
|
||||
del self.lq
|
||||
del self.output
|
||||
@@ -522,35 +536,38 @@ class GFPGANModel(BaseModel):
|
||||
if with_metrics:
|
||||
# calculate metrics
|
||||
for name, opt_ in self.opt['val']['metrics'].items():
|
||||
metric_data = dict(img1=sr_img, img2=gt_img)
|
||||
self.metric_results[name] += calculate_metric(metric_data, opt_)
|
||||
pbar.update(1)
|
||||
pbar.set_description(f'Test {img_name}')
|
||||
pbar.close()
|
||||
if use_pbar:
|
||||
pbar.update(1)
|
||||
pbar.set_description(f'Test {img_name}')
|
||||
if use_pbar:
|
||||
pbar.close()
|
||||
|
||||
if with_metrics:
|
||||
for metric in self.metric_results.keys():
|
||||
self.metric_results[metric] /= (idx + 1)
|
||||
# update the best metric result
|
||||
self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
|
||||
|
||||
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
|
||||
|
||||
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
|
||||
log_str = f'Validation {dataset_name}\n'
|
||||
for metric, value in self.metric_results.items():
|
||||
log_str += f'\t # {metric}: {value:.4f}\n'
|
||||
log_str += f'\t # {metric}: {value:.4f}'
|
||||
if hasattr(self, 'best_metric_results'):
|
||||
log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
|
||||
f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
|
||||
log_str += '\n'
|
||||
|
||||
logger = get_root_logger()
|
||||
logger.info(log_str)
|
||||
if tb_logger:
|
||||
for metric, value in self.metric_results.items():
|
||||
tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
|
||||
|
||||
def get_current_visuals(self):
|
||||
out_dict = OrderedDict()
|
||||
out_dict['gt'] = self.gt.detach().cpu()
|
||||
out_dict['sr'] = self.output.detach().cpu()
|
||||
return out_dict
|
||||
tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
|
||||
|
||||
def save(self, epoch, current_iter):
|
||||
# save net_g and net_d
|
||||
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
|
||||
self.save_network(self.net_d, 'net_d', current_iter)
|
||||
# save component discriminators
|
||||
@@ -558,4 +575,5 @@ class GFPGANModel(BaseModel):
|
||||
self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter)
|
||||
self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter)
|
||||
self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter)
|
||||
# save training state
|
||||
self.save_training_state(epoch, current_iter)
|
||||
|
||||
@@ -2,11 +2,11 @@ import cv2
|
||||
import os
|
||||
import torch
|
||||
from basicsr.utils import img2tensor, tensor2img
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
from torch.hub import download_url_to_file, get_dir
|
||||
from torchvision.transforms.functional import normalize
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear
|
||||
from gfpgan.archs.gfpganv1_arch import GFPGANv1
|
||||
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
|
||||
|
||||
@@ -14,13 +14,27 @@ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
class GFPGANer():
|
||||
"""Helper for restoration with GFPGAN.
|
||||
|
||||
def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None):
|
||||
It will detect and crop faces, and then resize the faces to 512x512.
|
||||
GFPGAN is used to restored the resized faces.
|
||||
The background is upsampled with the bg_upsampler.
|
||||
Finally, the faces will be pasted back to the upsample background image.
|
||||
|
||||
Args:
|
||||
model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
|
||||
upscale (float): The upscale of the final output. Default: 2.
|
||||
arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
|
||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
||||
bg_upsampler (nn.Module): The upsampler for the background. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None):
|
||||
self.upscale = upscale
|
||||
self.bg_upsampler = bg_upsampler
|
||||
|
||||
# initialize model
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
|
||||
# initialize the GFP-GAN
|
||||
if arch == 'clean':
|
||||
self.gfpgan = GFPGANv1Clean(
|
||||
@@ -34,7 +48,19 @@ class GFPGANer():
|
||||
different_w=True,
|
||||
narrow=1,
|
||||
sft_half=True)
|
||||
else:
|
||||
elif arch == 'bilinear':
|
||||
self.gfpgan = GFPGANBilinear(
|
||||
out_size=512,
|
||||
num_style_feat=512,
|
||||
channel_multiplier=channel_multiplier,
|
||||
decoder_load_path=None,
|
||||
fix_decoder=False,
|
||||
num_mlp=8,
|
||||
input_is_latent=True,
|
||||
different_w=True,
|
||||
narrow=1,
|
||||
sft_half=True)
|
||||
elif arch == 'original':
|
||||
self.gfpgan = GFPGANv1(
|
||||
out_size=512,
|
||||
num_style_feat=512,
|
||||
@@ -46,6 +72,9 @@ class GFPGANer():
|
||||
different_w=True,
|
||||
narrow=1,
|
||||
sft_half=True)
|
||||
elif arch == 'RestoreFormer':
|
||||
from gfpgan.archs.restoreformer_arch import RestoreFormer
|
||||
self.gfpgan = RestoreFormer()
|
||||
# initialize face helper
|
||||
self.face_helper = FaceRestoreHelper(
|
||||
upscale,
|
||||
@@ -53,10 +82,13 @@ class GFPGANer():
|
||||
crop_ratio=(1, 1),
|
||||
det_model='retinaface_resnet50',
|
||||
save_ext='png',
|
||||
device=self.device)
|
||||
use_parse=True,
|
||||
device=self.device,
|
||||
model_rootpath='gfpgan/weights')
|
||||
|
||||
if model_path.startswith('https://'):
|
||||
model_path = load_file_from_url(url=model_path, model_dir='gfpgan/weights', progress=True, file_name=None)
|
||||
model_path = load_file_from_url(
|
||||
url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None)
|
||||
loadnet = torch.load(model_path)
|
||||
if 'params_ema' in loadnet:
|
||||
keyname = 'params_ema'
|
||||
@@ -67,10 +99,10 @@ class GFPGANer():
|
||||
self.gfpgan = self.gfpgan.to(self.device)
|
||||
|
||||
@torch.no_grad()
|
||||
def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
|
||||
def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5):
|
||||
self.face_helper.clean_all()
|
||||
|
||||
if has_aligned:
|
||||
if has_aligned: # the inputs are already aligned
|
||||
img = cv2.resize(img, (512, 512))
|
||||
self.face_helper.cropped_faces = [img]
|
||||
else:
|
||||
@@ -78,6 +110,7 @@ class GFPGANer():
|
||||
# get face landmarks for each face
|
||||
self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
|
||||
# eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
|
||||
# TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
|
||||
# align and warp each face
|
||||
self.face_helper.align_warp_face()
|
||||
|
||||
@@ -89,7 +122,7 @@ class GFPGANer():
|
||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
|
||||
|
||||
try:
|
||||
output = self.gfpgan(cropped_face_t, return_rgb=False)[0]
|
||||
output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0]
|
||||
# convert to image
|
||||
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
|
||||
except RuntimeError as error:
|
||||
@@ -100,9 +133,9 @@ class GFPGANer():
|
||||
self.face_helper.add_restored_face(restored_face)
|
||||
|
||||
if not has_aligned and paste_back:
|
||||
|
||||
# upsample the background
|
||||
if self.bg_upsampler is not None:
|
||||
# Now only support RealESRGAN
|
||||
# Now only support RealESRGAN for upsampling background
|
||||
bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
|
||||
else:
|
||||
bg_img = None
|
||||
@@ -113,23 +146,3 @@ class GFPGANer():
|
||||
return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
|
||||
else:
|
||||
return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
|
||||
|
||||
|
||||
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
|
||||
"""Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
|
||||
"""
|
||||
if model_dir is None:
|
||||
hub_dir = get_dir()
|
||||
model_dir = os.path.join(hub_dir, 'checkpoints')
|
||||
|
||||
os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
|
||||
|
||||
parts = urlparse(url)
|
||||
filename = os.path.basename(parts.path)
|
||||
if file_name is not None:
|
||||
filename = file_name
|
||||
cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
|
||||
if not os.path.exists(cached_file):
|
||||
print(f'Downloading: "{url}" to {cached_file}\n')
|
||||
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
|
||||
return cached_file
|
||||
|
||||
@@ -10,59 +10,119 @@ from gfpgan import GFPGANer
|
||||
|
||||
|
||||
def main():
|
||||
"""Inference demo for GFPGAN (for users).
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'-i',
|
||||
'--input',
|
||||
type=str,
|
||||
default='inputs/whole_imgs',
|
||||
help='Input image or folder. Default: inputs/whole_imgs')
|
||||
parser.add_argument('-o', '--output', type=str, default='results', help='Output folder. Default: results')
|
||||
# we use version to select models, which is more user-friendly
|
||||
parser.add_argument(
|
||||
'-v', '--version', type=str, default='1.3', help='GFPGAN model version. Option: 1 | 1.2 | 1.3. Default: 1.3')
|
||||
parser.add_argument(
|
||||
'-s', '--upscale', type=int, default=2, help='The final upsampling scale of the image. Default: 2')
|
||||
|
||||
parser.add_argument('--upscale', type=int, default=2)
|
||||
parser.add_argument('--arch', type=str, default='clean')
|
||||
parser.add_argument('--channel', type=int, default=2)
|
||||
parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth')
|
||||
parser.add_argument('--bg_upsampler', type=str, default='realesrgan')
|
||||
parser.add_argument('--bg_tile', type=int, default=400)
|
||||
parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
|
||||
parser.add_argument(
|
||||
'--bg_upsampler', type=str, default='realesrgan', help='background upsampler. Default: realesrgan')
|
||||
parser.add_argument(
|
||||
'--bg_tile',
|
||||
type=int,
|
||||
default=400,
|
||||
help='Tile size for background sampler, 0 for no tile during testing. Default: 400')
|
||||
parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
|
||||
parser.add_argument('--only_center_face', action='store_true')
|
||||
parser.add_argument('--aligned', action='store_true')
|
||||
parser.add_argument('--paste_back', action='store_false')
|
||||
parser.add_argument('--save_root', type=str, default='results')
|
||||
parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face')
|
||||
parser.add_argument('--aligned', action='store_true', help='Input are aligned faces')
|
||||
parser.add_argument(
|
||||
'--ext',
|
||||
type=str,
|
||||
default='auto',
|
||||
help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
|
||||
help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs. Default: auto')
|
||||
parser.add_argument('-w', '--weight', type=float, default=0.5, help='Adjustable weights.')
|
||||
args = parser.parse_args()
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.test_path.endswith('/'):
|
||||
args.test_path = args.test_path[:-1]
|
||||
os.makedirs(args.save_root, exist_ok=True)
|
||||
|
||||
# background upsampler
|
||||
# ------------------------ input & output ------------------------
|
||||
if args.input.endswith('/'):
|
||||
args.input = args.input[:-1]
|
||||
if os.path.isfile(args.input):
|
||||
img_list = [args.input]
|
||||
else:
|
||||
img_list = sorted(glob.glob(os.path.join(args.input, '*')))
|
||||
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
|
||||
# ------------------------ set up background upsampler ------------------------
|
||||
if args.bg_upsampler == 'realesrgan':
|
||||
if not torch.cuda.is_available(): # CPU
|
||||
import warnings
|
||||
warnings.warn('The unoptimized RealESRGAN is very slow on CPU. We do not use it. '
|
||||
warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
|
||||
'If you really want to use it, please modify the corresponding codes.')
|
||||
bg_upsampler = None
|
||||
else:
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan import RealESRGANer
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
|
||||
bg_upsampler = RealESRGANer(
|
||||
scale=2,
|
||||
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
|
||||
model=model,
|
||||
tile=args.bg_tile,
|
||||
tile_pad=10,
|
||||
pre_pad=0,
|
||||
half=True) # need to set False in CPU mode
|
||||
else:
|
||||
bg_upsampler = None
|
||||
# set up GFPGAN restorer
|
||||
|
||||
# ------------------------ set up GFPGAN restorer ------------------------
|
||||
if args.version == '1':
|
||||
arch = 'original'
|
||||
channel_multiplier = 1
|
||||
model_name = 'GFPGANv1'
|
||||
url = 'https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth'
|
||||
elif args.version == '1.2':
|
||||
arch = 'clean'
|
||||
channel_multiplier = 2
|
||||
model_name = 'GFPGANCleanv1-NoCE-C2'
|
||||
url = 'https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth'
|
||||
elif args.version == '1.3':
|
||||
arch = 'clean'
|
||||
channel_multiplier = 2
|
||||
model_name = 'GFPGANv1.3'
|
||||
url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
|
||||
elif args.version == '1.4':
|
||||
arch = 'clean'
|
||||
channel_multiplier = 2
|
||||
model_name = 'GFPGANv1.4'
|
||||
url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'
|
||||
elif args.version == 'RestoreFormer':
|
||||
arch = 'RestoreFormer'
|
||||
channel_multiplier = 2
|
||||
model_name = 'RestoreFormer'
|
||||
url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
|
||||
else:
|
||||
raise ValueError(f'Wrong model version {args.version}.')
|
||||
|
||||
# determine model paths
|
||||
model_path = os.path.join('experiments/pretrained_models', model_name + '.pth')
|
||||
if not os.path.isfile(model_path):
|
||||
model_path = os.path.join('gfpgan/weights', model_name + '.pth')
|
||||
if not os.path.isfile(model_path):
|
||||
# download pre-trained models from url
|
||||
model_path = url
|
||||
|
||||
restorer = GFPGANer(
|
||||
model_path=args.model_path,
|
||||
model_path=model_path,
|
||||
upscale=args.upscale,
|
||||
arch=args.arch,
|
||||
channel_multiplier=args.channel,
|
||||
arch=arch,
|
||||
channel_multiplier=channel_multiplier,
|
||||
bg_upsampler=bg_upsampler)
|
||||
|
||||
img_list = sorted(glob.glob(os.path.join(args.test_path, '*')))
|
||||
# ------------------------ restore ------------------------
|
||||
for img_path in img_list:
|
||||
# read image
|
||||
img_name = os.path.basename(img_path)
|
||||
@@ -70,24 +130,29 @@ def main():
|
||||
basename, ext = os.path.splitext(img_name)
|
||||
input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
|
||||
|
||||
# restore faces and background if necessary
|
||||
cropped_faces, restored_faces, restored_img = restorer.enhance(
|
||||
input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=args.paste_back)
|
||||
input_img,
|
||||
has_aligned=args.aligned,
|
||||
only_center_face=args.only_center_face,
|
||||
paste_back=True,
|
||||
weight=args.weight)
|
||||
|
||||
# save faces
|
||||
for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)):
|
||||
# save cropped face
|
||||
save_crop_path = os.path.join(args.save_root, 'cropped_faces', f'{basename}_{idx:02d}.png')
|
||||
save_crop_path = os.path.join(args.output, 'cropped_faces', f'{basename}_{idx:02d}.png')
|
||||
imwrite(cropped_face, save_crop_path)
|
||||
# save restored face
|
||||
if args.suffix is not None:
|
||||
save_face_name = f'{basename}_{idx:02d}_{args.suffix}.png'
|
||||
else:
|
||||
save_face_name = f'{basename}_{idx:02d}.png'
|
||||
save_restore_path = os.path.join(args.save_root, 'restored_faces', save_face_name)
|
||||
save_restore_path = os.path.join(args.output, 'restored_faces', save_face_name)
|
||||
imwrite(restored_face, save_restore_path)
|
||||
# save cmp image
|
||||
# save comparison image
|
||||
cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
|
||||
imwrite(cmp_img, os.path.join(args.save_root, 'cmp', f'{basename}_{idx:02d}.png'))
|
||||
imwrite(cmp_img, os.path.join(args.output, 'cmp', f'{basename}_{idx:02d}.png'))
|
||||
|
||||
# save restored img
|
||||
if restored_img is not None:
|
||||
@@ -97,13 +162,12 @@ def main():
|
||||
extension = args.ext
|
||||
|
||||
if args.suffix is not None:
|
||||
save_restore_path = os.path.join(args.save_root, 'restored_imgs',
|
||||
f'{basename}_{args.suffix}.{extension}')
|
||||
save_restore_path = os.path.join(args.output, 'restored_imgs', f'{basename}_{args.suffix}.{extension}')
|
||||
else:
|
||||
save_restore_path = os.path.join(args.save_root, 'restored_imgs', f'{basename}.{extension}')
|
||||
save_restore_path = os.path.join(args.output, 'restored_imgs', f'{basename}.{extension}')
|
||||
imwrite(restored_img, save_restore_path)
|
||||
|
||||
print(f'Results are in the [{args.save_root}] folder.')
|
||||
print(f'Results are in the [{args.output}] folder.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# general settings
|
||||
name: train_GFPGANv1_512
|
||||
model_type: GFPGANModel
|
||||
num_gpu: 4
|
||||
num_gpu: auto # officially, we use 4 GPUs
|
||||
manual_seed: 0
|
||||
|
||||
# dataset and data loader settings
|
||||
@@ -194,7 +194,7 @@ val:
|
||||
save_img: true
|
||||
|
||||
metrics:
|
||||
psnr: # metric name, can be arbitrary
|
||||
psnr: # metric name
|
||||
type: calculate_psnr
|
||||
crop_border: 0
|
||||
test_y_channel: false
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# general settings
|
||||
name: train_GFPGANv1_512_simple
|
||||
model_type: GFPGANModel
|
||||
num_gpu: 4
|
||||
num_gpu: auto # officially, we use 4 GPUs
|
||||
manual_seed: 0
|
||||
|
||||
# dataset and data loader settings
|
||||
@@ -40,10 +40,6 @@ datasets:
|
||||
# gray_prob: 0.01
|
||||
# gt_gray: True
|
||||
|
||||
# crop_components: false
|
||||
# component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
|
||||
# eye_enlarge_ratio: 1.4
|
||||
|
||||
# data loader
|
||||
use_shuffle: true
|
||||
num_worker_per_gpu: 6
|
||||
@@ -86,20 +82,6 @@ network_d:
|
||||
channel_multiplier: 1
|
||||
resample_kernel: [1, 3, 3, 1]
|
||||
|
||||
# network_d_left_eye:
|
||||
# type: FacialComponentDiscriminator
|
||||
|
||||
# network_d_right_eye:
|
||||
# type: FacialComponentDiscriminator
|
||||
|
||||
# network_d_mouth:
|
||||
# type: FacialComponentDiscriminator
|
||||
|
||||
network_identity:
|
||||
type: ResNetArcFace
|
||||
block: IRBlock
|
||||
layers: [2, 2, 2, 2]
|
||||
use_se: False
|
||||
|
||||
# path
|
||||
path:
|
||||
@@ -107,13 +89,7 @@ path:
|
||||
param_key_g: params_ema
|
||||
strict_load_g: ~
|
||||
pretrain_network_d: ~
|
||||
# pretrain_network_d_left_eye: ~
|
||||
# pretrain_network_d_right_eye: ~
|
||||
# pretrain_network_d_mouth: ~
|
||||
pretrain_network_identity: experiments/pretrained_models/arcface_resnet18.pth
|
||||
# resume
|
||||
resume_state: ~
|
||||
ignore_resume_networks: ['network_identity']
|
||||
|
||||
# training settings
|
||||
train:
|
||||
@@ -173,16 +149,6 @@ train:
|
||||
loss_weight: !!float 1e-1
|
||||
# r1 regularization for discriminator
|
||||
r1_reg_weight: 10
|
||||
# facial component loss
|
||||
# gan_component_opt:
|
||||
# type: GANLoss
|
||||
# gan_type: vanilla
|
||||
# real_label_val: 1.0
|
||||
# fake_label_val: 0.0
|
||||
# loss_weight: !!float 1
|
||||
# comp_style_weight: 200
|
||||
# identity loss
|
||||
identity_weight: 10
|
||||
|
||||
net_d_iters: 1
|
||||
net_d_init_iters: 0
|
||||
@@ -194,7 +160,7 @@ val:
|
||||
save_img: true
|
||||
|
||||
metrics:
|
||||
psnr: # metric name, can be arbitrary
|
||||
psnr: # metric name
|
||||
type: calculate_psnr
|
||||
crop_border: 0
|
||||
test_y_channel: false
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
torch>=1.7
|
||||
numpy<1.21 # numba requires numpy<1.21,>=1.17
|
||||
opencv-python
|
||||
torchvision
|
||||
scipy
|
||||
tqdm
|
||||
basicsr>=1.3.4.0
|
||||
facexlib>=0.2.0.3
|
||||
basicsr>=1.4.2
|
||||
facexlib>=0.2.5
|
||||
lmdb
|
||||
numpy
|
||||
opencv-python
|
||||
pyyaml
|
||||
scipy
|
||||
tb-nightly
|
||||
torch>=1.7
|
||||
torchvision
|
||||
tqdm
|
||||
yapf
|
||||
|
||||
164
scripts/convert_gfpganv_to_clean.py
Normal file
164
scripts/convert_gfpganv_to_clean.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import argparse
|
||||
import math
|
||||
import torch
|
||||
|
||||
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
|
||||
|
||||
|
||||
def modify_checkpoint(checkpoint_bilinear, checkpoint_clean):
|
||||
for ori_k, ori_v in checkpoint_bilinear.items():
|
||||
if 'stylegan_decoder' in ori_k:
|
||||
if 'style_mlp' in ori_k: # style_mlp_layers
|
||||
lr_mul = 0.01
|
||||
prefix, name, idx, var = ori_k.split('.')
|
||||
idx = (int(idx) * 2) - 1
|
||||
crt_k = f'{prefix}.{name}.{idx}.{var}'
|
||||
if var == 'weight':
|
||||
_, c_in = ori_v.size()
|
||||
scale = (1 / math.sqrt(c_in)) * lr_mul
|
||||
crt_v = ori_v * scale * 2**0.5
|
||||
else:
|
||||
crt_v = ori_v * lr_mul * 2**0.5
|
||||
checkpoint_clean[crt_k] = crt_v
|
||||
elif 'modulation' in ori_k: # modulation in StyleConv
|
||||
lr_mul = 1
|
||||
crt_k = ori_k
|
||||
var = ori_k.split('.')[-1]
|
||||
if var == 'weight':
|
||||
_, c_in = ori_v.size()
|
||||
scale = (1 / math.sqrt(c_in)) * lr_mul
|
||||
crt_v = ori_v * scale
|
||||
else:
|
||||
crt_v = ori_v * lr_mul
|
||||
checkpoint_clean[crt_k] = crt_v
|
||||
elif 'style_conv' in ori_k:
|
||||
# StyleConv in style_conv1 and style_convs
|
||||
if 'activate' in ori_k: # FusedLeakyReLU
|
||||
# eg. style_conv1.activate.bias
|
||||
# eg. style_convs.13.activate.bias
|
||||
split_rlt = ori_k.split('.')
|
||||
if len(split_rlt) == 4:
|
||||
prefix, name, _, var = split_rlt
|
||||
crt_k = f'{prefix}.{name}.{var}'
|
||||
elif len(split_rlt) == 5:
|
||||
prefix, name, idx, _, var = split_rlt
|
||||
crt_k = f'{prefix}.{name}.{idx}.{var}'
|
||||
crt_v = ori_v * 2**0.5 # 2**0.5 used in FusedLeakyReLU
|
||||
c = crt_v.size(0)
|
||||
checkpoint_clean[crt_k] = crt_v.view(1, c, 1, 1)
|
||||
elif 'modulated_conv' in ori_k:
|
||||
# eg. style_conv1.modulated_conv.weight
|
||||
# eg. style_convs.13.modulated_conv.weight
|
||||
_, c_out, c_in, k1, k2 = ori_v.size()
|
||||
scale = 1 / math.sqrt(c_in * k1 * k2)
|
||||
crt_k = ori_k
|
||||
checkpoint_clean[crt_k] = ori_v * scale
|
||||
elif 'weight' in ori_k:
|
||||
crt_k = ori_k
|
||||
checkpoint_clean[crt_k] = ori_v * 2**0.5
|
||||
elif 'to_rgb' in ori_k: # StyleConv in to_rgb1 and to_rgbs
|
||||
if 'modulated_conv' in ori_k:
|
||||
# eg. to_rgb1.modulated_conv.weight
|
||||
# eg. to_rgbs.5.modulated_conv.weight
|
||||
_, c_out, c_in, k1, k2 = ori_v.size()
|
||||
scale = 1 / math.sqrt(c_in * k1 * k2)
|
||||
crt_k = ori_k
|
||||
checkpoint_clean[crt_k] = ori_v * scale
|
||||
else:
|
||||
crt_k = ori_k
|
||||
checkpoint_clean[crt_k] = ori_v
|
||||
else:
|
||||
crt_k = ori_k
|
||||
checkpoint_clean[crt_k] = ori_v
|
||||
# end of 'stylegan_decoder'
|
||||
elif 'conv_body_first' in ori_k or 'final_conv' in ori_k:
|
||||
# key name
|
||||
name, _, var = ori_k.split('.')
|
||||
crt_k = f'{name}.{var}'
|
||||
# weight and bias
|
||||
if var == 'weight':
|
||||
c_out, c_in, k1, k2 = ori_v.size()
|
||||
scale = 1 / math.sqrt(c_in * k1 * k2)
|
||||
checkpoint_clean[crt_k] = ori_v * scale * 2**0.5
|
||||
else:
|
||||
checkpoint_clean[crt_k] = ori_v * 2**0.5
|
||||
elif 'conv_body' in ori_k:
|
||||
if 'conv_body_up' in ori_k:
|
||||
ori_k = ori_k.replace('conv2.weight', 'conv2.1.weight')
|
||||
ori_k = ori_k.replace('skip.weight', 'skip.1.weight')
|
||||
name1, idx1, name2, _, var = ori_k.split('.')
|
||||
crt_k = f'{name1}.{idx1}.{name2}.{var}'
|
||||
if name2 == 'skip':
|
||||
c_out, c_in, k1, k2 = ori_v.size()
|
||||
scale = 1 / math.sqrt(c_in * k1 * k2)
|
||||
checkpoint_clean[crt_k] = ori_v * scale / 2**0.5
|
||||
else:
|
||||
if var == 'weight':
|
||||
c_out, c_in, k1, k2 = ori_v.size()
|
||||
scale = 1 / math.sqrt(c_in * k1 * k2)
|
||||
checkpoint_clean[crt_k] = ori_v * scale
|
||||
else:
|
||||
checkpoint_clean[crt_k] = ori_v
|
||||
if 'conv1' in ori_k:
|
||||
checkpoint_clean[crt_k] *= 2**0.5
|
||||
elif 'toRGB' in ori_k:
|
||||
crt_k = ori_k
|
||||
if 'weight' in ori_k:
|
||||
c_out, c_in, k1, k2 = ori_v.size()
|
||||
scale = 1 / math.sqrt(c_in * k1 * k2)
|
||||
checkpoint_clean[crt_k] = ori_v * scale
|
||||
else:
|
||||
checkpoint_clean[crt_k] = ori_v
|
||||
elif 'final_linear' in ori_k:
|
||||
crt_k = ori_k
|
||||
if 'weight' in ori_k:
|
||||
_, c_in = ori_v.size()
|
||||
scale = 1 / math.sqrt(c_in)
|
||||
checkpoint_clean[crt_k] = ori_v * scale
|
||||
else:
|
||||
checkpoint_clean[crt_k] = ori_v
|
||||
elif 'condition' in ori_k:
|
||||
crt_k = ori_k
|
||||
if '0.weight' in ori_k:
|
||||
c_out, c_in, k1, k2 = ori_v.size()
|
||||
scale = 1 / math.sqrt(c_in * k1 * k2)
|
||||
checkpoint_clean[crt_k] = ori_v * scale * 2**0.5
|
||||
elif '0.bias' in ori_k:
|
||||
checkpoint_clean[crt_k] = ori_v * 2**0.5
|
||||
elif '2.weight' in ori_k:
|
||||
c_out, c_in, k1, k2 = ori_v.size()
|
||||
scale = 1 / math.sqrt(c_in * k1 * k2)
|
||||
checkpoint_clean[crt_k] = ori_v * scale
|
||||
elif '2.bias' in ori_k:
|
||||
checkpoint_clean[crt_k] = ori_v
|
||||
|
||||
return checkpoint_clean
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--ori_path', type=str, help='Path to the original model')
|
||||
parser.add_argument('--narrow', type=float, default=1)
|
||||
parser.add_argument('--channel_multiplier', type=float, default=2)
|
||||
parser.add_argument('--save_path', type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
ori_ckpt = torch.load(args.ori_path)['params_ema']
|
||||
|
||||
net = GFPGANv1Clean(
|
||||
512,
|
||||
num_style_feat=512,
|
||||
channel_multiplier=args.channel_multiplier,
|
||||
decoder_load_path=None,
|
||||
fix_decoder=False,
|
||||
# for stylegan decoder
|
||||
num_mlp=8,
|
||||
input_is_latent=True,
|
||||
different_w=True,
|
||||
narrow=args.narrow,
|
||||
sft_half=True)
|
||||
crt_ckpt = net.state_dict()
|
||||
|
||||
crt_ckpt = modify_checkpoint(ori_ckpt, crt_ckpt)
|
||||
print(f'Save to {args.save_path}.')
|
||||
torch.save(dict(params_ema=crt_ckpt), args.save_path, _use_new_zipfile_serialization=False)
|
||||
@@ -1,24 +1,31 @@
|
||||
import cv2
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
from basicsr.utils import FileClient, imfrombytes
|
||||
from collections import OrderedDict
|
||||
|
||||
# ---------------------------- This script is used to parse facial landmarks ------------------------------------- #
|
||||
# Configurations
|
||||
save_img = False
|
||||
scale = 0.5 # 0.5 for official FFHQ (512x512), 1 for others
|
||||
enlarge_ratio = 1.4 # only for eyes
|
||||
json_path = 'ffhq-dataset-v2.json'
|
||||
face_path = 'datasets/ffhq/ffhq_512.lmdb'
|
||||
save_path = './FFHQ_eye_mouth_landmarks_512.pth'
|
||||
|
||||
print('Load JSON metadata...')
|
||||
# use the json file in FFHQ dataset
|
||||
with open('ffhq-dataset-v2.json', 'rb') as f:
|
||||
# use the official json file in FFHQ dataset
|
||||
with open(json_path, 'rb') as f:
|
||||
json_data = json.load(f, object_pairs_hook=OrderedDict)
|
||||
|
||||
print('Open LMDB file...')
|
||||
# read ffhq images
|
||||
file_client = FileClient('lmdb', db_paths='datasets/ffhq/ffhq_512.lmdb')
|
||||
with open('datasets/ffhq/ffhq_512.lmdb/meta_info.txt') as fin:
|
||||
file_client = FileClient('lmdb', db_paths=face_path)
|
||||
with open(os.path.join(face_path, 'meta_info.txt')) as fin:
|
||||
paths = [line.split('.')[0] for line in fin]
|
||||
|
||||
save_img = False
|
||||
scale = 0.5 # 0.5 for official FFHQ (512x512), 1 for others
|
||||
enlarge_ratio = 1.4 # only for eyes
|
||||
save_dict = {}
|
||||
|
||||
for item_idx, item in enumerate(json_data.values()):
|
||||
@@ -34,6 +41,7 @@ for item_idx, item in enumerate(json_data.values()):
|
||||
img_bytes = file_client.get(paths[item_idx])
|
||||
img = imfrombytes(img_bytes, float32=True)
|
||||
|
||||
# get landmarks for each component
|
||||
map_left_eye = list(range(36, 42))
|
||||
map_right_eye = list(range(42, 48))
|
||||
map_mouth = list(range(48, 68))
|
||||
@@ -74,4 +82,4 @@ for item_idx, item in enumerate(json_data.values()):
|
||||
save_dict[f'{item_idx:08d}'] = item_dict
|
||||
|
||||
print('Save...')
|
||||
torch.save(save_dict, './FFHQ_eye_mouth_landmarks_512.pth')
|
||||
torch.save(save_dict, save_path)
|
||||
|
||||
@@ -17,7 +17,7 @@ line_length = 120
|
||||
multi_line_output = 0
|
||||
known_standard_library = pkg_resources,setuptools
|
||||
known_first_party = gfpgan
|
||||
known_third_party = basicsr,cv2,facexlib,numpy,torch,torchvision,tqdm
|
||||
known_third_party = basicsr,cv2,facexlib,numpy,pytest,torch,torchvision,tqdm,yaml
|
||||
no_lines_before = STDLIB,LOCALFOLDER
|
||||
default_section = THIRDPARTY
|
||||
|
||||
@@ -25,3 +25,9 @@ default_section = THIRDPARTY
|
||||
skip = .git,./docs/build
|
||||
count =
|
||||
quiet-level = 3
|
||||
|
||||
[aliases]
|
||||
test=pytest
|
||||
|
||||
[tool:pytest]
|
||||
addopts=tests/
|
||||
|
||||
BIN
tests/data/ffhq_gt.lmdb/data.mdb
Normal file
BIN
tests/data/ffhq_gt.lmdb/data.mdb
Normal file
Binary file not shown.
BIN
tests/data/ffhq_gt.lmdb/lock.mdb
Normal file
BIN
tests/data/ffhq_gt.lmdb/lock.mdb
Normal file
Binary file not shown.
1
tests/data/ffhq_gt.lmdb/meta_info.txt
Normal file
1
tests/data/ffhq_gt.lmdb/meta_info.txt
Normal file
@@ -0,0 +1 @@
|
||||
00000000.png (512,512,3) 1
|
||||
BIN
tests/data/gt/00000000.png
Normal file
BIN
tests/data/gt/00000000.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 429 KiB |
BIN
tests/data/test_eye_mouth_landmarks.pth
Normal file
BIN
tests/data/test_eye_mouth_landmarks.pth
Normal file
Binary file not shown.
24
tests/data/test_ffhq_degradation_dataset.yml
Normal file
24
tests/data/test_ffhq_degradation_dataset.yml
Normal file
@@ -0,0 +1,24 @@
|
||||
name: UnitTest
|
||||
type: FFHQDegradationDataset
|
||||
dataroot_gt: tests/data/gt
|
||||
io_backend:
|
||||
type: disk
|
||||
|
||||
use_hflip: true
|
||||
mean: [0.5, 0.5, 0.5]
|
||||
std: [0.5, 0.5, 0.5]
|
||||
out_size: 512
|
||||
|
||||
blur_kernel_size: 41
|
||||
kernel_list: ['iso', 'aniso']
|
||||
kernel_prob: [0.5, 0.5]
|
||||
blur_sigma: [0.1, 10]
|
||||
downsample_range: [0.8, 8]
|
||||
noise_range: [0, 20]
|
||||
jpeg_range: [60, 100]
|
||||
|
||||
# color jitter and gray
|
||||
color_jitter_prob: 1
|
||||
color_jitter_shift: 20
|
||||
color_jitter_pt_prob: 1
|
||||
gray_prob: 1
|
||||
140
tests/data/test_gfpgan_model.yml
Normal file
140
tests/data/test_gfpgan_model.yml
Normal file
@@ -0,0 +1,140 @@
|
||||
num_gpu: 1
|
||||
manual_seed: 0
|
||||
is_train: True
|
||||
dist: False
|
||||
|
||||
# network structures
|
||||
network_g:
|
||||
type: GFPGANv1
|
||||
out_size: 512
|
||||
num_style_feat: 512
|
||||
channel_multiplier: 1
|
||||
resample_kernel: [1, 3, 3, 1]
|
||||
decoder_load_path: ~
|
||||
fix_decoder: true
|
||||
num_mlp: 8
|
||||
lr_mlp: 0.01
|
||||
input_is_latent: true
|
||||
different_w: true
|
||||
narrow: 0.5
|
||||
sft_half: true
|
||||
|
||||
network_d:
|
||||
type: StyleGAN2Discriminator
|
||||
out_size: 512
|
||||
channel_multiplier: 1
|
||||
resample_kernel: [1, 3, 3, 1]
|
||||
|
||||
network_d_left_eye:
|
||||
type: FacialComponentDiscriminator
|
||||
|
||||
network_d_right_eye:
|
||||
type: FacialComponentDiscriminator
|
||||
|
||||
network_d_mouth:
|
||||
type: FacialComponentDiscriminator
|
||||
|
||||
network_identity:
|
||||
type: ResNetArcFace
|
||||
block: IRBlock
|
||||
layers: [2, 2, 2, 2]
|
||||
use_se: False
|
||||
|
||||
# path
|
||||
path:
|
||||
pretrain_network_g: ~
|
||||
param_key_g: params_ema
|
||||
strict_load_g: ~
|
||||
pretrain_network_d: ~
|
||||
pretrain_network_d_left_eye: ~
|
||||
pretrain_network_d_right_eye: ~
|
||||
pretrain_network_d_mouth: ~
|
||||
pretrain_network_identity: ~
|
||||
# resume
|
||||
resume_state: ~
|
||||
ignore_resume_networks: ['network_identity']
|
||||
|
||||
# training settings
|
||||
train:
|
||||
optim_g:
|
||||
type: Adam
|
||||
lr: !!float 2e-3
|
||||
optim_d:
|
||||
type: Adam
|
||||
lr: !!float 2e-3
|
||||
optim_component:
|
||||
type: Adam
|
||||
lr: !!float 2e-3
|
||||
|
||||
scheduler:
|
||||
type: MultiStepLR
|
||||
milestones: [600000, 700000]
|
||||
gamma: 0.5
|
||||
|
||||
total_iter: 800000
|
||||
warmup_iter: -1 # no warm up
|
||||
|
||||
# losses
|
||||
# pixel loss
|
||||
pixel_opt:
|
||||
type: L1Loss
|
||||
loss_weight: !!float 1e-1
|
||||
reduction: mean
|
||||
# L1 loss used in pyramid loss, component style loss and identity loss
|
||||
L1_opt:
|
||||
type: L1Loss
|
||||
loss_weight: 1
|
||||
reduction: mean
|
||||
|
||||
# image pyramid loss
|
||||
pyramid_loss_weight: 1
|
||||
remove_pyramid_loss: 50000
|
||||
# perceptual loss (content and style losses)
|
||||
perceptual_opt:
|
||||
type: PerceptualLoss
|
||||
layer_weights:
|
||||
# before relu
|
||||
'conv1_2': 0.1
|
||||
'conv2_2': 0.1
|
||||
'conv3_4': 1
|
||||
'conv4_4': 1
|
||||
'conv5_4': 1
|
||||
vgg_type: vgg19
|
||||
use_input_norm: true
|
||||
perceptual_weight: !!float 1
|
||||
style_weight: 50
|
||||
range_norm: true
|
||||
criterion: l1
|
||||
# gan loss
|
||||
gan_opt:
|
||||
type: GANLoss
|
||||
gan_type: wgan_softplus
|
||||
loss_weight: !!float 1e-1
|
||||
# r1 regularization for discriminator
|
||||
r1_reg_weight: 10
|
||||
# facial component loss
|
||||
gan_component_opt:
|
||||
type: GANLoss
|
||||
gan_type: vanilla
|
||||
real_label_val: 1.0
|
||||
fake_label_val: 0.0
|
||||
loss_weight: !!float 1
|
||||
comp_style_weight: 200
|
||||
# identity loss
|
||||
identity_weight: 10
|
||||
|
||||
net_d_iters: 1
|
||||
net_d_init_iters: 0
|
||||
net_d_reg_every: 1
|
||||
|
||||
# validation settings
|
||||
val:
|
||||
val_freq: !!float 5e3
|
||||
save_img: True
|
||||
use_pbar: True
|
||||
|
||||
metrics:
|
||||
psnr: # metric name
|
||||
type: calculate_psnr
|
||||
crop_border: 0
|
||||
test_y_channel: false
|
||||
49
tests/test_arcface_arch.py
Normal file
49
tests/test_arcface_arch.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import torch
|
||||
|
||||
from gfpgan.archs.arcface_arch import BasicBlock, Bottleneck, ResNetArcFace
|
||||
|
||||
|
||||
def test_resnetarcface():
|
||||
"""Test arch: ResNetArcFace."""
|
||||
|
||||
# model init and forward (gpu)
|
||||
if torch.cuda.is_available():
|
||||
net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=True).cuda().eval()
|
||||
img = torch.rand((1, 1, 128, 128), dtype=torch.float32).cuda()
|
||||
output = net(img)
|
||||
assert output.shape == (1, 512)
|
||||
|
||||
# -------------------- without SE block ----------------------- #
|
||||
net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=False).cuda().eval()
|
||||
output = net(img)
|
||||
assert output.shape == (1, 512)
|
||||
|
||||
|
||||
def test_basicblock():
|
||||
"""Test the BasicBlock in arcface_arch"""
|
||||
block = BasicBlock(1, 3, stride=1, downsample=None).cuda()
|
||||
img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
|
||||
output = block(img)
|
||||
assert output.shape == (1, 3, 12, 12)
|
||||
|
||||
# ----------------- use the downsmaple module--------------- #
|
||||
downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda()
|
||||
block = BasicBlock(1, 3, stride=2, downsample=downsample).cuda()
|
||||
img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
|
||||
output = block(img)
|
||||
assert output.shape == (1, 3, 6, 6)
|
||||
|
||||
|
||||
def test_bottleneck():
|
||||
"""Test the Bottleneck in arcface_arch"""
|
||||
block = Bottleneck(1, 1, stride=1, downsample=None).cuda()
|
||||
img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
|
||||
output = block(img)
|
||||
assert output.shape == (1, 4, 12, 12)
|
||||
|
||||
# ----------------- use the downsmaple module--------------- #
|
||||
downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda()
|
||||
block = Bottleneck(1, 1, stride=2, downsample=downsample).cuda()
|
||||
img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
|
||||
output = block(img)
|
||||
assert output.shape == (1, 4, 6, 6)
|
||||
96
tests/test_ffhq_degradation_dataset.py
Normal file
96
tests/test_ffhq_degradation_dataset.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from gfpgan.data.ffhq_degradation_dataset import FFHQDegradationDataset
|
||||
|
||||
|
||||
def test_ffhq_degradation_dataset():
|
||||
|
||||
with open('tests/data/test_ffhq_degradation_dataset.yml', mode='r') as f:
|
||||
opt = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
dataset = FFHQDegradationDataset(opt)
|
||||
assert dataset.io_backend_opt['type'] == 'disk' # io backend
|
||||
assert len(dataset) == 1 # whether to read correct meta info
|
||||
assert dataset.kernel_list == ['iso', 'aniso'] # correct initialization the degradation configurations
|
||||
assert dataset.color_jitter_prob == 1
|
||||
|
||||
# test __getitem__
|
||||
result = dataset.__getitem__(0)
|
||||
# check returned keys
|
||||
expected_keys = ['gt', 'lq', 'gt_path']
|
||||
assert set(expected_keys).issubset(set(result.keys()))
|
||||
# check shape and contents
|
||||
assert result['gt'].shape == (3, 512, 512)
|
||||
assert result['lq'].shape == (3, 512, 512)
|
||||
assert result['gt_path'] == 'tests/data/gt/00000000.png'
|
||||
|
||||
# ------------------ test with probability = 0 -------------------- #
|
||||
opt['color_jitter_prob'] = 0
|
||||
opt['color_jitter_pt_prob'] = 0
|
||||
opt['gray_prob'] = 0
|
||||
opt['io_backend'] = dict(type='disk')
|
||||
dataset = FFHQDegradationDataset(opt)
|
||||
assert dataset.io_backend_opt['type'] == 'disk' # io backend
|
||||
assert len(dataset) == 1 # whether to read correct meta info
|
||||
assert dataset.kernel_list == ['iso', 'aniso'] # correct initialization the degradation configurations
|
||||
assert dataset.color_jitter_prob == 0
|
||||
|
||||
# test __getitem__
|
||||
result = dataset.__getitem__(0)
|
||||
# check returned keys
|
||||
expected_keys = ['gt', 'lq', 'gt_path']
|
||||
assert set(expected_keys).issubset(set(result.keys()))
|
||||
# check shape and contents
|
||||
assert result['gt'].shape == (3, 512, 512)
|
||||
assert result['lq'].shape == (3, 512, 512)
|
||||
assert result['gt_path'] == 'tests/data/gt/00000000.png'
|
||||
|
||||
# ------------------ test lmdb backend -------------------- #
|
||||
opt['dataroot_gt'] = 'tests/data/ffhq_gt.lmdb'
|
||||
opt['io_backend'] = dict(type='lmdb')
|
||||
|
||||
dataset = FFHQDegradationDataset(opt)
|
||||
assert dataset.io_backend_opt['type'] == 'lmdb' # io backend
|
||||
assert len(dataset) == 1 # whether to read correct meta info
|
||||
assert dataset.kernel_list == ['iso', 'aniso'] # correct initialization the degradation configurations
|
||||
assert dataset.color_jitter_prob == 0
|
||||
|
||||
# test __getitem__
|
||||
result = dataset.__getitem__(0)
|
||||
# check returned keys
|
||||
expected_keys = ['gt', 'lq', 'gt_path']
|
||||
assert set(expected_keys).issubset(set(result.keys()))
|
||||
# check shape and contents
|
||||
assert result['gt'].shape == (3, 512, 512)
|
||||
assert result['lq'].shape == (3, 512, 512)
|
||||
assert result['gt_path'] == '00000000'
|
||||
|
||||
# ------------------ test with crop_components -------------------- #
|
||||
opt['crop_components'] = True
|
||||
opt['component_path'] = 'tests/data/test_eye_mouth_landmarks.pth'
|
||||
opt['eye_enlarge_ratio'] = 1.4
|
||||
opt['gt_gray'] = True
|
||||
opt['io_backend'] = dict(type='lmdb')
|
||||
|
||||
dataset = FFHQDegradationDataset(opt)
|
||||
assert dataset.crop_components is True
|
||||
|
||||
# test __getitem__
|
||||
result = dataset.__getitem__(0)
|
||||
# check returned keys
|
||||
expected_keys = ['gt', 'lq', 'gt_path', 'loc_left_eye', 'loc_right_eye', 'loc_mouth']
|
||||
assert set(expected_keys).issubset(set(result.keys()))
|
||||
# check shape and contents
|
||||
assert result['gt'].shape == (3, 512, 512)
|
||||
assert result['lq'].shape == (3, 512, 512)
|
||||
assert result['gt_path'] == '00000000'
|
||||
assert result['loc_left_eye'].shape == (4, )
|
||||
assert result['loc_right_eye'].shape == (4, )
|
||||
assert result['loc_mouth'].shape == (4, )
|
||||
|
||||
# ------------------ lmdb backend should have paths ends with lmdb -------------------- #
|
||||
with pytest.raises(ValueError):
|
||||
opt['dataroot_gt'] = 'tests/data/gt'
|
||||
opt['io_backend'] = dict(type='lmdb')
|
||||
dataset = FFHQDegradationDataset(opt)
|
||||
203
tests/test_gfpgan_arch.py
Normal file
203
tests/test_gfpgan_arch.py
Normal file
@@ -0,0 +1,203 @@
|
||||
import torch
|
||||
|
||||
from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1, StyleGAN2GeneratorSFT
|
||||
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean, StyleGAN2GeneratorCSFT
|
||||
|
||||
|
||||
def test_stylegan2generatorsft():
|
||||
"""Test arch: StyleGAN2GeneratorSFT."""
|
||||
|
||||
# model init and forward (gpu)
|
||||
if torch.cuda.is_available():
|
||||
net = StyleGAN2GeneratorSFT(
|
||||
out_size=32,
|
||||
num_style_feat=512,
|
||||
num_mlp=8,
|
||||
channel_multiplier=1,
|
||||
resample_kernel=(1, 3, 3, 1),
|
||||
lr_mlp=0.01,
|
||||
narrow=1,
|
||||
sft_half=False).cuda().eval()
|
||||
style = torch.rand((1, 512), dtype=torch.float32).cuda()
|
||||
condition1 = torch.rand((1, 512, 8, 8), dtype=torch.float32).cuda()
|
||||
condition2 = torch.rand((1, 512, 16, 16), dtype=torch.float32).cuda()
|
||||
condition3 = torch.rand((1, 512, 32, 32), dtype=torch.float32).cuda()
|
||||
conditions = [condition1, condition1, condition2, condition2, condition3, condition3]
|
||||
output = net([style], conditions)
|
||||
assert output[0].shape == (1, 3, 32, 32)
|
||||
assert output[1] is None
|
||||
|
||||
# -------------------- with return_latents ----------------------- #
|
||||
output = net([style], conditions, return_latents=True)
|
||||
assert output[0].shape == (1, 3, 32, 32)
|
||||
assert len(output[1]) == 1
|
||||
# check latent
|
||||
assert output[1][0].shape == (8, 512)
|
||||
|
||||
# -------------------- with randomize_noise = False ----------------------- #
|
||||
output = net([style], conditions, randomize_noise=False)
|
||||
assert output[0].shape == (1, 3, 32, 32)
|
||||
assert output[1] is None
|
||||
|
||||
# -------------------- with truncation = 0.5 and mixing----------------------- #
|
||||
output = net([style, style], conditions, truncation=0.5, truncation_latent=style)
|
||||
assert output[0].shape == (1, 3, 32, 32)
|
||||
assert output[1] is None
|
||||
|
||||
|
||||
def test_gfpganv1():
|
||||
"""Test arch: GFPGANv1."""
|
||||
|
||||
# model init and forward (gpu)
|
||||
if torch.cuda.is_available():
|
||||
net = GFPGANv1(
|
||||
out_size=32,
|
||||
num_style_feat=512,
|
||||
channel_multiplier=1,
|
||||
resample_kernel=(1, 3, 3, 1),
|
||||
decoder_load_path=None,
|
||||
fix_decoder=True,
|
||||
# for stylegan decoder
|
||||
num_mlp=8,
|
||||
lr_mlp=0.01,
|
||||
input_is_latent=False,
|
||||
different_w=False,
|
||||
narrow=1,
|
||||
sft_half=True).cuda().eval()
|
||||
img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
|
||||
output = net(img)
|
||||
assert output[0].shape == (1, 3, 32, 32)
|
||||
assert len(output[1]) == 3
|
||||
# check out_rgbs for intermediate loss
|
||||
assert output[1][0].shape == (1, 3, 8, 8)
|
||||
assert output[1][1].shape == (1, 3, 16, 16)
|
||||
assert output[1][2].shape == (1, 3, 32, 32)
|
||||
|
||||
# -------------------- with different_w = True ----------------------- #
|
||||
net = GFPGANv1(
|
||||
out_size=32,
|
||||
num_style_feat=512,
|
||||
channel_multiplier=1,
|
||||
resample_kernel=(1, 3, 3, 1),
|
||||
decoder_load_path=None,
|
||||
fix_decoder=True,
|
||||
# for stylegan decoder
|
||||
num_mlp=8,
|
||||
lr_mlp=0.01,
|
||||
input_is_latent=False,
|
||||
different_w=True,
|
||||
narrow=1,
|
||||
sft_half=True).cuda().eval()
|
||||
img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
|
||||
output = net(img)
|
||||
assert output[0].shape == (1, 3, 32, 32)
|
||||
assert len(output[1]) == 3
|
||||
# check out_rgbs for intermediate loss
|
||||
assert output[1][0].shape == (1, 3, 8, 8)
|
||||
assert output[1][1].shape == (1, 3, 16, 16)
|
||||
assert output[1][2].shape == (1, 3, 32, 32)
|
||||
|
||||
|
||||
def test_facialcomponentdiscriminator():
|
||||
"""Test arch: FacialComponentDiscriminator."""
|
||||
|
||||
# model init and forward (gpu)
|
||||
if torch.cuda.is_available():
|
||||
net = FacialComponentDiscriminator().cuda().eval()
|
||||
img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
|
||||
output = net(img)
|
||||
assert len(output) == 2
|
||||
assert output[0].shape == (1, 1, 8, 8)
|
||||
assert output[1] is None
|
||||
|
||||
# -------------------- return intermediate features ----------------------- #
|
||||
output = net(img, return_feats=True)
|
||||
assert len(output) == 2
|
||||
assert output[0].shape == (1, 1, 8, 8)
|
||||
assert len(output[1]) == 2
|
||||
assert output[1][0].shape == (1, 128, 16, 16)
|
||||
assert output[1][1].shape == (1, 256, 8, 8)
|
||||
|
||||
|
||||
def test_stylegan2generatorcsft():
|
||||
"""Test arch: StyleGAN2GeneratorCSFT."""
|
||||
|
||||
# model init and forward (gpu)
|
||||
if torch.cuda.is_available():
|
||||
net = StyleGAN2GeneratorCSFT(
|
||||
out_size=32, num_style_feat=512, num_mlp=8, channel_multiplier=1, narrow=1, sft_half=False).cuda().eval()
|
||||
style = torch.rand((1, 512), dtype=torch.float32).cuda()
|
||||
condition1 = torch.rand((1, 512, 8, 8), dtype=torch.float32).cuda()
|
||||
condition2 = torch.rand((1, 512, 16, 16), dtype=torch.float32).cuda()
|
||||
condition3 = torch.rand((1, 512, 32, 32), dtype=torch.float32).cuda()
|
||||
conditions = [condition1, condition1, condition2, condition2, condition3, condition3]
|
||||
output = net([style], conditions)
|
||||
assert output[0].shape == (1, 3, 32, 32)
|
||||
assert output[1] is None
|
||||
|
||||
# -------------------- with return_latents ----------------------- #
|
||||
output = net([style], conditions, return_latents=True)
|
||||
assert output[0].shape == (1, 3, 32, 32)
|
||||
assert len(output[1]) == 1
|
||||
# check latent
|
||||
assert output[1][0].shape == (8, 512)
|
||||
|
||||
# -------------------- with randomize_noise = False ----------------------- #
|
||||
output = net([style], conditions, randomize_noise=False)
|
||||
assert output[0].shape == (1, 3, 32, 32)
|
||||
assert output[1] is None
|
||||
|
||||
# -------------------- with truncation = 0.5 and mixing----------------------- #
|
||||
output = net([style, style], conditions, truncation=0.5, truncation_latent=style)
|
||||
assert output[0].shape == (1, 3, 32, 32)
|
||||
assert output[1] is None
|
||||
|
||||
|
||||
def test_gfpganv1clean():
|
||||
"""Test arch: GFPGANv1Clean."""
|
||||
|
||||
# model init and forward (gpu)
|
||||
if torch.cuda.is_available():
|
||||
net = GFPGANv1Clean(
|
||||
out_size=32,
|
||||
num_style_feat=512,
|
||||
channel_multiplier=1,
|
||||
decoder_load_path=None,
|
||||
fix_decoder=True,
|
||||
# for stylegan decoder
|
||||
num_mlp=8,
|
||||
input_is_latent=False,
|
||||
different_w=False,
|
||||
narrow=1,
|
||||
sft_half=True).cuda().eval()
|
||||
|
||||
img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
|
||||
output = net(img)
|
||||
assert output[0].shape == (1, 3, 32, 32)
|
||||
assert len(output[1]) == 3
|
||||
# check out_rgbs for intermediate loss
|
||||
assert output[1][0].shape == (1, 3, 8, 8)
|
||||
assert output[1][1].shape == (1, 3, 16, 16)
|
||||
assert output[1][2].shape == (1, 3, 32, 32)
|
||||
|
||||
# -------------------- with different_w = True ----------------------- #
|
||||
net = GFPGANv1Clean(
|
||||
out_size=32,
|
||||
num_style_feat=512,
|
||||
channel_multiplier=1,
|
||||
decoder_load_path=None,
|
||||
fix_decoder=True,
|
||||
# for stylegan decoder
|
||||
num_mlp=8,
|
||||
input_is_latent=False,
|
||||
different_w=True,
|
||||
narrow=1,
|
||||
sft_half=True).cuda().eval()
|
||||
img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda()
|
||||
output = net(img)
|
||||
assert output[0].shape == (1, 3, 32, 32)
|
||||
assert len(output[1]) == 3
|
||||
# check out_rgbs for intermediate loss
|
||||
assert output[1][0].shape == (1, 3, 8, 8)
|
||||
assert output[1][1].shape == (1, 3, 16, 16)
|
||||
assert output[1][2].shape == (1, 3, 32, 32)
|
||||
132
tests/test_gfpgan_model.py
Normal file
132
tests/test_gfpgan_model.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import tempfile
|
||||
import torch
|
||||
import yaml
|
||||
from basicsr.archs.stylegan2_arch import StyleGAN2Discriminator
|
||||
from basicsr.data.paired_image_dataset import PairedImageDataset
|
||||
from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss
|
||||
|
||||
from gfpgan.archs.arcface_arch import ResNetArcFace
|
||||
from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1
|
||||
from gfpgan.models.gfpgan_model import GFPGANModel
|
||||
|
||||
|
||||
def test_gfpgan_model():
|
||||
with open('tests/data/test_gfpgan_model.yml', mode='r') as f:
|
||||
opt = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
# build model
|
||||
model = GFPGANModel(opt)
|
||||
# test attributes
|
||||
assert model.__class__.__name__ == 'GFPGANModel'
|
||||
assert isinstance(model.net_g, GFPGANv1) # generator
|
||||
assert isinstance(model.net_d, StyleGAN2Discriminator) # discriminator
|
||||
# facial component discriminators
|
||||
assert isinstance(model.net_d_left_eye, FacialComponentDiscriminator)
|
||||
assert isinstance(model.net_d_right_eye, FacialComponentDiscriminator)
|
||||
assert isinstance(model.net_d_mouth, FacialComponentDiscriminator)
|
||||
# identity network
|
||||
assert isinstance(model.network_identity, ResNetArcFace)
|
||||
# losses
|
||||
assert isinstance(model.cri_pix, L1Loss)
|
||||
assert isinstance(model.cri_perceptual, PerceptualLoss)
|
||||
assert isinstance(model.cri_gan, GANLoss)
|
||||
assert isinstance(model.cri_l1, L1Loss)
|
||||
# optimizer
|
||||
assert isinstance(model.optimizers[0], torch.optim.Adam)
|
||||
assert isinstance(model.optimizers[1], torch.optim.Adam)
|
||||
|
||||
# prepare data
|
||||
gt = torch.rand((1, 3, 512, 512), dtype=torch.float32)
|
||||
lq = torch.rand((1, 3, 512, 512), dtype=torch.float32)
|
||||
loc_left_eye = torch.rand((1, 4), dtype=torch.float32)
|
||||
loc_right_eye = torch.rand((1, 4), dtype=torch.float32)
|
||||
loc_mouth = torch.rand((1, 4), dtype=torch.float32)
|
||||
data = dict(gt=gt, lq=lq, loc_left_eye=loc_left_eye, loc_right_eye=loc_right_eye, loc_mouth=loc_mouth)
|
||||
model.feed_data(data)
|
||||
# check data shape
|
||||
assert model.lq.shape == (1, 3, 512, 512)
|
||||
assert model.gt.shape == (1, 3, 512, 512)
|
||||
assert model.loc_left_eyes.shape == (1, 4)
|
||||
assert model.loc_right_eyes.shape == (1, 4)
|
||||
assert model.loc_mouths.shape == (1, 4)
|
||||
|
||||
# ----------------- test optimize_parameters -------------------- #
|
||||
model.feed_data(data)
|
||||
model.optimize_parameters(1)
|
||||
assert model.output.shape == (1, 3, 512, 512)
|
||||
assert isinstance(model.log_dict, dict)
|
||||
# check returned keys
|
||||
expected_keys = [
|
||||
'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
|
||||
'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
|
||||
'l_d_right_eye', 'l_d_mouth'
|
||||
]
|
||||
assert set(expected_keys).issubset(set(model.log_dict.keys()))
|
||||
|
||||
# ----------------- remove pyramid_loss_weight-------------------- #
|
||||
model.feed_data(data)
|
||||
model.optimize_parameters(100000) # large than remove_pyramid_loss = 50000
|
||||
assert model.output.shape == (1, 3, 512, 512)
|
||||
assert isinstance(model.log_dict, dict)
|
||||
# check returned keys
|
||||
expected_keys = [
|
||||
'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
|
||||
'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
|
||||
'l_d_right_eye', 'l_d_mouth'
|
||||
]
|
||||
assert set(expected_keys).issubset(set(model.log_dict.keys()))
|
||||
|
||||
# ----------------- test save -------------------- #
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.opt['path']['models'] = tmpdir
|
||||
model.opt['path']['training_states'] = tmpdir
|
||||
model.save(0, 1)
|
||||
|
||||
# ----------------- test the test function -------------------- #
|
||||
model.test()
|
||||
assert model.output.shape == (1, 3, 512, 512)
|
||||
# delete net_g_ema
|
||||
model.__delattr__('net_g_ema')
|
||||
model.test()
|
||||
assert model.output.shape == (1, 3, 512, 512)
|
||||
assert model.net_g.training is True # should back to training mode after testing
|
||||
|
||||
# ----------------- test nondist_validation -------------------- #
|
||||
# construct dataloader
|
||||
dataset_opt = dict(
|
||||
name='Demo',
|
||||
dataroot_gt='tests/data/gt',
|
||||
dataroot_lq='tests/data/gt',
|
||||
io_backend=dict(type='disk'),
|
||||
scale=4,
|
||||
phase='val')
|
||||
dataset = PairedImageDataset(dataset_opt)
|
||||
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
|
||||
assert model.is_train is True
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.opt['path']['visualization'] = tmpdir
|
||||
model.nondist_validation(dataloader, 1, None, save_img=True)
|
||||
assert model.is_train is True
|
||||
# check metric_results
|
||||
assert 'psnr' in model.metric_results
|
||||
assert isinstance(model.metric_results['psnr'], float)
|
||||
|
||||
# validation
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.opt['is_train'] = False
|
||||
model.opt['val']['suffix'] = 'test'
|
||||
model.opt['path']['visualization'] = tmpdir
|
||||
model.opt['val']['pbar'] = True
|
||||
model.nondist_validation(dataloader, 1, None, save_img=True)
|
||||
# check metric_results
|
||||
assert 'psnr' in model.metric_results
|
||||
assert isinstance(model.metric_results['psnr'], float)
|
||||
|
||||
# if opt['val']['suffix'] is None
|
||||
model.opt['val']['suffix'] = None
|
||||
model.opt['name'] = 'demo'
|
||||
model.opt['path']['visualization'] = tmpdir
|
||||
model.nondist_validation(dataloader, 1, None, save_img=True)
|
||||
# check metric_results
|
||||
assert 'psnr' in model.metric_results
|
||||
assert isinstance(model.metric_results['psnr'], float)
|
||||
52
tests/test_stylegan2_clean_arch.py
Normal file
52
tests/test_stylegan2_clean_arch.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
|
||||
from gfpgan.archs.stylegan2_clean_arch import StyleGAN2GeneratorClean
|
||||
|
||||
|
||||
def test_stylegan2generatorclean():
|
||||
"""Test arch: StyleGAN2GeneratorClean."""
|
||||
|
||||
# model init and forward (gpu)
|
||||
if torch.cuda.is_available():
|
||||
net = StyleGAN2GeneratorClean(
|
||||
out_size=32, num_style_feat=512, num_mlp=8, channel_multiplier=1, narrow=0.5).cuda().eval()
|
||||
style = torch.rand((1, 512), dtype=torch.float32).cuda()
|
||||
output = net([style], input_is_latent=False)
|
||||
assert output[0].shape == (1, 3, 32, 32)
|
||||
assert output[1] is None
|
||||
|
||||
# -------------------- with return_latents ----------------------- #
|
||||
output = net([style], input_is_latent=True, return_latents=True)
|
||||
assert output[0].shape == (1, 3, 32, 32)
|
||||
assert len(output[1]) == 1
|
||||
# check latent
|
||||
assert output[1][0].shape == (8, 512)
|
||||
|
||||
# -------------------- with randomize_noise = False ----------------------- #
|
||||
output = net([style], randomize_noise=False)
|
||||
assert output[0].shape == (1, 3, 32, 32)
|
||||
assert output[1] is None
|
||||
|
||||
# -------------------- with truncation = 0.5 and mixing----------------------- #
|
||||
output = net([style, style], truncation=0.5, truncation_latent=style)
|
||||
assert output[0].shape == (1, 3, 32, 32)
|
||||
assert output[1] is None
|
||||
|
||||
# ------------------ test make_noise ----------------------- #
|
||||
out = net.make_noise()
|
||||
assert len(out) == 7
|
||||
assert out[0].shape == (1, 1, 4, 4)
|
||||
assert out[1].shape == (1, 1, 8, 8)
|
||||
assert out[2].shape == (1, 1, 8, 8)
|
||||
assert out[3].shape == (1, 1, 16, 16)
|
||||
assert out[4].shape == (1, 1, 16, 16)
|
||||
assert out[5].shape == (1, 1, 32, 32)
|
||||
assert out[6].shape == (1, 1, 32, 32)
|
||||
|
||||
# ------------------ test get_latent ----------------------- #
|
||||
out = net.get_latent(style)
|
||||
assert out.shape == (1, 512)
|
||||
|
||||
# ------------------ test mean_latent ----------------------- #
|
||||
out = net.mean_latent(2)
|
||||
assert out.shape == (1, 512)
|
||||
43
tests/test_utils.py
Normal file
43
tests/test_utils.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import cv2
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
|
||||
from gfpgan.archs.gfpganv1_arch import GFPGANv1
|
||||
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
|
||||
from gfpgan.utils import GFPGANer
|
||||
|
||||
|
||||
def test_gfpganer():
|
||||
# initialize with the clean model
|
||||
restorer = GFPGANer(
|
||||
model_path='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth',
|
||||
upscale=2,
|
||||
arch='clean',
|
||||
channel_multiplier=2,
|
||||
bg_upsampler=None)
|
||||
# test attribute
|
||||
assert isinstance(restorer.gfpgan, GFPGANv1Clean)
|
||||
assert isinstance(restorer.face_helper, FaceRestoreHelper)
|
||||
|
||||
# initialize with the original model
|
||||
restorer = GFPGANer(
|
||||
model_path='experiments/pretrained_models/GFPGANv1.pth',
|
||||
upscale=2,
|
||||
arch='original',
|
||||
channel_multiplier=1,
|
||||
bg_upsampler=None)
|
||||
# test attribute
|
||||
assert isinstance(restorer.gfpgan, GFPGANv1)
|
||||
assert isinstance(restorer.face_helper, FaceRestoreHelper)
|
||||
|
||||
# ------------------ test enhance ---------------- #
|
||||
img = cv2.imread('tests/data/gt/00000000.png', cv2.IMREAD_COLOR)
|
||||
result = restorer.enhance(img, has_aligned=False, paste_back=True)
|
||||
assert result[0][0].shape == (512, 512, 3)
|
||||
assert result[1][0].shape == (512, 512, 3)
|
||||
assert result[2].shape == (1024, 1024, 3)
|
||||
|
||||
# with has_aligned=True
|
||||
result = restorer.enhance(img, has_aligned=True, paste_back=False)
|
||||
assert result[0][0].shape == (512, 512, 3)
|
||||
assert result[1][0].shape == (512, 512, 3)
|
||||
assert result[2] is None
|
||||
Reference in New Issue
Block a user