42 Commits

Author SHA1 Message Date
Xintao
7552a7791c Delete .github/workflows/no-response.yml 2024-04-03 00:39:30 +08:00
Xintao
2eac203389 v1.3.8 2022-09-16 19:33:26 +08:00
Xintao
2f46d95254 fix pylint 2022-09-16 19:32:42 +08:00
Xintao
bc5a5deb95 remove codeformer 2022-09-16 19:31:20 +08:00
Xintao
3fd33abc47 update cog predict 2022-09-12 23:24:08 +08:00
Xintao
d226e86f6c v1.3.7 2022-09-12 22:40:25 +08:00
Xintao
bb2f916764 update 2022-09-12 22:29:46 +08:00
Xintao
fe3beac9dc v1.3.6 2022-09-12 22:17:48 +08:00
Xintao
126c55c68d add restoreformer and codeformer inference codes 2022-09-12 21:33:06 +08:00
Xintao
8d2447a2d9 update cog predict 2022-09-04 23:27:02 +08:00
Xintao
af7569775d v1.3.5 2022-09-04 22:18:25 +08:00
Xintao
c6593e7221 update cog_predict 2022-09-04 20:28:24 +08:00
Xintao
7272e45887 update replicate (#248)
* update util

* update predict

* update predict

* update predict

* update predict

* update predict

* update predict

* update predict

* update predict

* merge replicate update
2022-09-04 20:12:31 +08:00
Xintao
3e27784b1b update replicate related 2022-08-31 17:36:25 +08:00
Xintao
2c420ee565 update readme 2022-08-31 16:33:30 +08:00
Xintao
8e7cf5d723 update readme 2022-08-30 23:02:22 +08:00
Xintao
c541e97f83 update readme 2022-08-30 23:01:28 +08:00
Xintao
86756cba65 update readme 2022-08-30 22:57:22 +08:00
Chenxi
a9a2e3ae15 Add Docker environment & web demo (#67)
* enable cog

* Update README.md

* Update README.md

* refactor

* fix temp input dir bug

Co-authored-by: CJWBW <70536672+CJWBW@users.noreply.github.com>
Co-authored-by: Chenxi <chenxi@Chenxis-MacBook-Pro-2.local>
Co-authored-by: Xintao <wxt1994@126.com>
2022-08-29 17:28:16 +08:00
Xintao
9c3f2d62cb v1.3.4 2022-07-13 10:21:28 +08:00
Xintao
ccd30af837 add release workflow 2022-07-13 10:19:50 +08:00
AJ
7d657f26b6 fix basicsr losses import (#210) 2022-07-13 10:01:06 +08:00
Xintao
c7ccc098a7 update facelib; use seg to paste back 2022-06-07 16:49:26 +08:00
Xintao
bc3f0c4d91 add device to GFPGANer for multiGPU support 2022-05-04 13:23:54 +08:00
Xintao
924ce473ab v1.3.2 2022-02-16 00:32:50 +08:00
Xintao
09a37ae7fd add logo 2022-02-16 00:11:10 +08:00
Xintao
6c544b70e6 v1.3.1 2022-02-14 15:37:45 +08:00
Xintao
47983e1767 add stylegan2_bilinear_arch 2022-02-14 14:28:27 +08:00
Xintao
77df6e4fad v1.3.0 2022-02-14 11:21:59 +08:00
Xintao
24b1f24ef5 Add V1.3 model (#158)
* add gfpgan bilinear arch

* add v1.3

* update readme

* update readme

* update readme

* rename
2022-02-14 11:21:03 +08:00
Tuhin Srivastava
c068e4d113 Add baseten.co demo to list of demos (#149) 2022-02-14 11:06:33 +08:00
Xintao
d8bf32a816 Add CODE_OF_CONDUCT.md 2022-01-08 19:50:14 +08:00
Xintao
09d82ec683 update readme 2021-12-17 01:57:58 +08:00
Xintao
780774d515 update readme 2021-12-17 01:38:13 +08:00
Xintao
95101b46d2 fix pylint 2021-12-17 01:33:09 +08:00
Xintao
547e026042 update reamdme 2021-12-17 01:31:33 +08:00
DARBAZ
83bcb28462 [improvement] No module named 'gfpgan.version' (#117) 2021-12-17 01:24:47 +08:00
Enoyao
942e7b39c6 Update README.md (#121) 2021-12-17 01:23:23 +08:00
Xintao
8ba74c99ba update readme 2021-12-14 14:19:40 +08:00
Xintao
3241c576ae update readme 2021-12-14 14:17:07 +08:00
Bram
05062fac70 Minor linguistic edits (#112)
Co-authored-by: Xintao <wxt1994@126.com>
2021-12-14 13:16:50 +08:00
Mostafa Vatanpour
3241798723 (fix typo) Update README.md (#107)
Some typo corrections.
2021-12-14 13:13:51 +08:00
23 changed files with 2349 additions and 116 deletions

View File

@@ -1,33 +0,0 @@
name: No Response
# TODO: it seems not to work
# Modified from: https://raw.githubusercontent.com/github/docs/main/.github/workflows/no-response.yaml
# **What it does**: Closes issues that don't have enough information to be actionable.
# **Why we have it**: To remove the need for maintainers to remember to check back on issues periodically
# to see if contributors have responded.
# **Who does it impact**: Everyone that works on docs or docs-internal.
on:
issue_comment:
types: [created]
schedule:
# Schedule for five minutes after the hour every hour
- cron: '5 * * * *'
jobs:
noResponse:
runs-on: ubuntu-latest
steps:
- uses: lee-dohm/no-response@v0.5.0
with:
token: ${{ github.token }}
closeComment: >
This issue has been automatically closed because there has been no response
to our request for more information from the original author. With only the
information that is currently in the issue, we don't have enough information
to take action. Please reach out if you have or find the answers we need so
that we can investigate further.
If you still have questions, please improve your description and re-open it.
Thanks :-)

41
.github/workflows/release.yml vendored Normal file
View File

@@ -0,0 +1,41 @@
name: release
on:
push:
tags:
- '*'
jobs:
build:
permissions: write-all
name: Create Release
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Create Release
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ github.ref }}
release_name: GFPGAN ${{ github.ref }} Release Note
body: |
🚀 See you again 😸
🚀Have a nice day 😸 and happy everyday 😃
🚀 Long time no see ☄️
✨ **Highlights**
✅ [Features] Support ...
🐛 **Bug Fixes**
🌴 **Improvements**
📢📢📢
<p align="center">
<img src="https://raw.githubusercontent.com/TencentARC/GFPGAN/master/assets/gfpgan_logo.png" height=150>
</p>
draft: true
prerelease: false

128
CODE_OF_CONDUCT.md Normal file
View File

@@ -0,0 +1,128 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
xintao.wang@outlook.com or xintaowang@tencent.com.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

24
Comparisons.md Normal file
View File

@@ -0,0 +1,24 @@
# Comparisons
## Comparisons among different model versions
Note that V1.3 is not always better than V1.2. You may need to try different models based on your purpose and inputs.
| Version | Strengths | Weaknesses |
| :---: | :---: | :---: |
|V1.3 | ✓ natural outputs<br> ✓better results on very low-quality inputs <br> ✓ work on relatively high-quality inputs <br>✓ can have repeated (twice) restorations | ✗ not very sharp <br> ✗ have a slight change on identity |
|V1.2 | ✓ sharper output <br> ✓ with beauty makeup | ✗ some outputs are unnatural|
For the following images, you may need to **zoom in** for comparing details, or **click the image** to see in the full size.
| Input | V1 | V1.2 | V1.3
| :---: | :---: | :---: | :---: |
|![019_Anne_Hathaway_01_00](https://user-images.githubusercontent.com/17445847/153762146-96b25999-4ddd-42a5-a3fe-bb90565f4c4f.png)| ![](https://user-images.githubusercontent.com/17445847/153762256-ef41e749-5a27-495c-8a9c-d8403be55869.png) | ![](https://user-images.githubusercontent.com/17445847/153762297-d41582fc-6253-4e7e-a1ce-4dc237ae3bf3.png) | ![](https://user-images.githubusercontent.com/17445847/153762215-e0535e94-b5ba-426e-97b5-35c00873604d.png) |
| ![106_Harry_Styles_00_00](https://user-images.githubusercontent.com/17445847/153789040-632c0eda-c15a-43e9-a63c-9ead64f92d4a.png) | ![](https://user-images.githubusercontent.com/17445847/153789172-93cd4980-5318-4633-a07e-1c8f8064ff89.png) | ![](https://user-images.githubusercontent.com/17445847/153789185-f7b268a7-d1db-47b0-ae4a-335e5d657a18.png) | ![](https://user-images.githubusercontent.com/17445847/153789198-7c7f3bca-0ef0-4494-92f0-20aa6f7d7464.png)|
| ![076_Paris_Hilton_00_00](https://user-images.githubusercontent.com/17445847/153789607-86387770-9db8-441f-b08a-c9679b121b85.png) | ![](https://user-images.githubusercontent.com/17445847/153789619-e56b438a-78a0-425d-8f44-ec4692a43dda.png) | ![](https://user-images.githubusercontent.com/17445847/153789633-5b28f778-3b7f-4e08-8a1d-740ca6e82d8a.png) | ![](https://user-images.githubusercontent.com/17445847/153789645-bc623f21-b32d-4fc3-bfe9-61203407a180.png)|
| ![008_George_Clooney_00_00](https://user-images.githubusercontent.com/17445847/153790017-0c3ca94d-1c9d-4a0e-b539-ab12d4da98ff.png) | ![](https://user-images.githubusercontent.com/17445847/153790028-fb0d38ab-399d-4a30-8154-2dcd72ca90e8.png) | ![](https://user-images.githubusercontent.com/17445847/153790044-1ef68e34-6120-4439-a5d9-0b6cdbe9c3d0.png) | ![](https://user-images.githubusercontent.com/17445847/153790059-a8d3cece-8989-4e9a-9ffe-903e1690cfd6.png)|
| ![057_Madonna_01_00](https://user-images.githubusercontent.com/17445847/153790624-2d0751d0-8fb4-4806-be9d-71b833c2c226.png) | ![](https://user-images.githubusercontent.com/17445847/153790639-7eb870e5-26b2-41dc-b139-b698bb40e6e6.png) | ![](https://user-images.githubusercontent.com/17445847/153790651-86899b7a-a1b6-4242-9e8a-77b462004998.png) | ![](https://user-images.githubusercontent.com/17445847/153790655-c8f6c25b-9b4e-4633-b16f-c43da86cff8f.png)|
| ![044_Amy_Schumer_01_00](https://user-images.githubusercontent.com/17445847/153790811-3fb4fc46-5b4f-45fe-8fcb-a128de2bfa60.png) | ![](https://user-images.githubusercontent.com/17445847/153790817-d45aa4ff-bfc4-4163-b462-75eef9426fab.png) | ![](https://user-images.githubusercontent.com/17445847/153790824-5f93c3a0-fe5a-42f6-8b4b-5a5de8cd0ac3.png) | ![](https://user-images.githubusercontent.com/17445847/153790835-0edf9944-05c7-41c4-8581-4dc5ffc56c9d.png)|
| ![012_Jackie_Chan_01_00](https://user-images.githubusercontent.com/17445847/153791176-737b016a-e94f-4898-8db7-43e7762141c9.png) | ![](https://user-images.githubusercontent.com/17445847/153791183-2f25a723-56bf-4cd5-aafe-a35513a6d1c5.png) | ![](https://user-images.githubusercontent.com/17445847/153791194-93416cf9-2b58-4e70-b806-27e14c58d4fd.png) | ![](https://user-images.githubusercontent.com/17445847/153791202-aa98659c-b702-4bce-9c47-a2fa5eccc5ae.png)|
<!-- | ![]() | ![]() | ![]() | ![]()| -->

7
FAQ.md Normal file
View File

@@ -0,0 +1,7 @@
# FAQ
1. **How to finetune the GFPGANCleanv1-NoCE-C2 (v1.2) model**
**A:** 1) The GFPGANCleanv1-NoCE-C2 (v1.2) model uses the *clean* architecture, which is more friendly for deploying.
2) This model is not directly trained. Instead, it is converted from another *bilinear* model.
3) If you want to finetune the GFPGANCleanv1-NoCE-C2 (v1.2), you need to finetune its original *bilinear* model, and then do the conversion.

View File

@@ -60,17 +60,17 @@ wget https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth
- Option 1: Load extensions just-in-time(JIT)
```bash
BASICSR_JIT=True python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1
BASICSR_JIT=True python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1
# for aligned images
BASICSR_JIT=True python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned
BASICSR_JIT=True python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1 --aligned
```
- Option 2: Have successfully compiled extensions during installation
```bash
python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1
python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1
# for aligned images
python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned
python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1 --aligned
```

View File

@@ -1,4 +1,13 @@
# GFPGAN (CVPR 2021)
<p align="center">
<img src="assets/gfpgan_logo.png" height=130>
</p>
## <div align="center"><b><a href="README.md">English</a> | <a href="README_CN.md">简体中文</a></b></div>
<div align="center">
<!-- <a href="https://twitter.com/_Xintao_" style="text-decoration:none;">
<img src="https://user-images.githubusercontent.com/17445847/187162058-c764ced6-952f-404b-ac85-ba95cce18e7b.png" width="4%" alt="" />
</a> -->
[![download](https://img.shields.io/github/downloads/TencentARC/GFPGAN/total.svg)](https://github.com/TencentARC/GFPGAN/releases)
[![PyPI](https://img.shields.io/pypi/v/gfpgan)](https://pypi.org/project/gfpgan/)
@@ -7,14 +16,28 @@
[![LICENSE](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/TencentARC/GFPGAN/blob/master/LICENSE)
[![python lint](https://github.com/TencentARC/GFPGAN/actions/workflows/pylint.yml/badge.svg)](https://github.com/TencentARC/GFPGAN/blob/master/.github/workflows/pylint.yml)
[![Publish-pip](https://github.com/TencentARC/GFPGAN/actions/workflows/publish-pip.yml/badge.svg)](https://github.com/TencentARC/GFPGAN/blob/master/.github/workflows/publish-pip.yml)
</div>
1. :boom: **Updated** online demo: [![Replicate](https://img.shields.io/static/v1?label=Demo&message=Replicate&color=blue)](https://replicate.com/tencentarc/gfpgan). Here is the [backup](https://replicate.com/xinntao/gfpgan).
1. :boom: **Updated** online demo: [![Huggingface Gradio](https://img.shields.io/static/v1?label=Demo&message=Huggingface%20Gradio&color=orange)](https://huggingface.co/spaces/Xintao/GFPGAN)
1. [Colab Demo](https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo) for GFPGAN <a href="https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>; (Another [Colab Demo](https://colab.research.google.com/drive/1Oa1WwKB4M4l1GmR7CtswDVgOCOeSLChA?usp=sharing) for the original paper model)
1. We provide a *clean* version of GFPGAN, which can run without CUDA extensions. So that it can run in **Windows** or on **CPU mode**.
GFPGAN aims at developing **Practical Algorithm for Real-world Face Restoration**.<br>
<!-- 3. Online demo: [Replicate.ai](https://replicate.com/xinntao/gfpgan) (may need to sign in, return the whole image)
4. Online demo: [Baseten.co](https://app.baseten.co/applications/Q04Lz0d/operator_views/8qZG6Bg) (backed by GPU, returns the whole image)
5. We provide a *clean* version of GFPGAN, which can run without CUDA extensions. So that it can run in **Windows** or on **CPU mode**. -->
> :rocket: **Thanks for your interest in our work. You may also want to check our new updates on the *tiny models* for *anime images and videos* in [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN/blob/master/docs/anime_video_model.md)** :blush:
GFPGAN aims at developing a **Practical Algorithm for Real-world Face Restoration**.<br>
It leverages rich and diverse priors encapsulated in a pretrained face GAN (*e.g.*, StyleGAN2) for blind face restoration.
:question: Frequently Asked Questions can be found in [FAQ.md](FAQ.md).
:triangular_flag_on_post: **Updates**
- :white_check_mark: Add [RestoreFormer](https://github.com/wzhouxiff/RestoreFormer) inference codes.
- :white_check_mark: Add [V1.4 model](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth), which produces slightly more details and better identity than V1.3.
- :white_check_mark: Add **[V1.3 model](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth)**, which produces **more natural** restoration results, and better results on *very low-quality* / *high-quality* inputs. See more in [Model zoo](#european_castle-model-zoo), [Comparisons.md](Comparisons.md)
- :white_check_mark: Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/GFPGAN).
- :white_check_mark: Support enhancing non-face regions (background) with [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN).
- :white_check_mark: We provide a *clean* version of GFPGAN, which does not require CUDA extensions.
@@ -25,9 +48,9 @@ It leverages rich and diverse priors encapsulated in a pretrained face GAN (*e.g
If GFPGAN is helpful in your photos/projects, please help to :star: this repo or recommend it to your friends. Thanks:blush:
Other recommended projects:<br>
:arrow_forward: [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN): A practical algorithm for general image restoration<br>
:arrow_forward: [BasicSR](https://github.com/xinntao/BasicSR): An ppen-source image and video restoration toolbox<br>
:arrow_forward: [facexlib](https://github.com/xinntao/facexlib): A collection that provides useful face-relation functions.<br>
:arrow_forward: [HandyView](https://github.com/xinntao/HandyView): A PyQt5-based image viewer that is handy for view and comparison. <br>
:arrow_forward: [BasicSR](https://github.com/xinntao/BasicSR): An open-source image and video restoration toolbox<br>
:arrow_forward: [facexlib](https://github.com/xinntao/facexlib): A collection that provides useful face-relation functions<br>
:arrow_forward: [HandyView](https://github.com/xinntao/HandyView): A PyQt5-based image viewer that is handy for view and comparison<br>
---
@@ -53,7 +76,7 @@ Other recommended projects:<br>
### Installation
We now provide a *clean* version of GFPGAN, which does not require customized CUDA extensions. <br>
If you want want to use the original model in our paper, please see [PaperModel.md](PaperModel.md) for installation.
If you want to use the original model in our paper, please see [PaperModel.md](PaperModel.md) for installation.
1. Clone repo
@@ -83,24 +106,54 @@ If you want want to use the original model in our paper, please see [PaperModel.
## :zap: Quick Inference
Download pre-trained models: [GFPGANCleanv1-NoCE-C2.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth)
We take the v1.3 version for an example. More models can be found [here](#european_castle-model-zoo).
Download pre-trained models: [GFPGANv1.3.pth](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth)
```bash
wget https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth -P experiments/pretrained_models
wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P experiments/pretrained_models
```
**Inference!**
```bash
python inference_gfpgan.py --upscale 2 --test_path inputs/whole_imgs --save_root results
python inference_gfpgan.py -i inputs/whole_imgs -o results -v 1.3 -s 2
```
If you want want to use the original model in our paper, please see [PaperModel.md](PaperModel.md) for installation and inference.
```console
Usage: python inference_gfpgan.py -i inputs/whole_imgs -o results -v 1.3 -s 2 [options]...
-h show this help
-i input Input image or folder. Default: inputs/whole_imgs
-o output Output folder. Default: results
-v version GFPGAN model version. Option: 1 | 1.2 | 1.3. Default: 1.3
-s upscale The final upsampling scale of the image. Default: 2
-bg_upsampler background upsampler. Default: realesrgan
-bg_tile Tile size for background sampler, 0 for no tile during testing. Default: 400
-suffix Suffix of the restored faces
-only_center_face Only restore the center face
-aligned Input are aligned faces
-ext Image extension. Options: auto | jpg | png, auto means using the same extension as inputs. Default: auto
```
If you want to use the original model in our paper, please see [PaperModel.md](PaperModel.md) for installation and inference.
## :european_castle: Model Zoo
- [GFPGANCleanv1-NoCE-C2.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth): No colorization; no CUDA extensions are required. It is still in training. Trained with more data with pre-processing.
- [GFPGANv1.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth): The paper model, with colorization.
| Version | Model Name | Description |
| :---: | :---: | :---: |
| V1.3 | [GFPGANv1.3.pth](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth) | Based on V1.2; **more natural** restoration results; better results on very low-quality / high-quality inputs. |
| V1.2 | [GFPGANCleanv1-NoCE-C2.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth) | No colorization; no CUDA extensions are required. Trained with more data with pre-processing. |
| V1 | [GFPGANv1.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth) | The paper model, with colorization. |
The comparisons are in [Comparisons.md](Comparisons.md).
Note that V1.3 is not always better than V1.2. You may need to select different models based on your purpose and inputs.
| Version | Strengths | Weaknesses |
| :---: | :---: | :---: |
|V1.3 | ✓ natural outputs<br> ✓better results on very low-quality inputs <br> ✓ work on relatively high-quality inputs <br>✓ can have repeated (twice) restorations | ✗ not very sharp <br> ✗ have a slight change on identity |
|V1.2 | ✓ sharper output <br> ✓ with beauty makeup | ✗ some outputs are unnatural |
You can find **more models (such as the discriminators)** here: [[Google Drive](https://drive.google.com/drive/folders/17rLiFzcUMoQuhLnptDsKolegHWwJOnHu?usp=sharing)], OR [[Tencent Cloud 腾讯微云](https://share.weiyun.com/ShYoCCoc)]
@@ -121,7 +174,7 @@ You could improve it according to your own needs.
1. Dataset preparation: [FFHQ](https://github.com/NVlabs/ffhq-dataset)
1. Download pre-trained models and other data. Put them in the `experiments/pretrained_models` folder.
1. [Pretrained StyleGAN2 model: StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth)
1. [Pre-trained StyleGAN2 model: StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth)
1. [Component locations of FFHQ: FFHQ_eye_mouth_landmarks_512.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/FFHQ_eye_mouth_landmarks_512.pth)
1. [A simple ArcFace model: arcface_resnet18.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/arcface_resnet18.pth)

7
README_CN.md Normal file
View File

@@ -0,0 +1,7 @@
<p align="center">
<img src="assets/gfpgan_logo.png" height=130>
</p>
## <div align="center"><b><a href="README.md">English</a> | <a href="README_CN.md">简体中文</a></b></div>
还未完工,欢迎贡献!

View File

@@ -1 +1 @@
0.2.4
1.3.8

BIN
assets/gfpgan_logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 50 KiB

22
cog.yaml Normal file
View File

@@ -0,0 +1,22 @@
# This file is used for constructing replicate env
image: "r8.im/tencentarc/gfpgan"
build:
gpu: true
python_version: "3.8"
system_packages:
- "libgl1-mesa-glx"
- "libglib2.0-0"
python_packages:
- "torch==1.7.1"
- "torchvision==0.8.2"
- "numpy==1.21.1"
- "lmdb==1.2.1"
- "opencv-python==4.5.3.56"
- "PyYAML==5.4.1"
- "tqdm==4.62.2"
- "yapf==0.31.0"
- "basicsr==1.4.2"
- "facexlib==0.2.5"
predict: "cog_predict.py:Predictor"

161
cog_predict.py Normal file
View File

@@ -0,0 +1,161 @@
# flake8: noqa
# This file is used for deploying replicate models
# running: cog predict -i img=@inputs/whole_imgs/10045.png -i version='v1.4' -i scale=2
# push: cog push r8.im/tencentarc/gfpgan
# push (backup): cog push r8.im/xinntao/gfpgan
import os
os.system('python setup.py develop')
os.system('pip install realesrgan')
import cv2
import shutil
import tempfile
import torch
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from gfpgan import GFPGANer
try:
from cog import BasePredictor, Input, Path
from realesrgan.utils import RealESRGANer
except Exception:
print('please install cog and realesrgan package')
class Predictor(BasePredictor):
def setup(self):
os.makedirs('output', exist_ok=True)
# download weights
if not os.path.exists('gfpgan/weights/realesr-general-x4v3.pth'):
os.system(
'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ./gfpgan/weights'
)
if not os.path.exists('gfpgan/weights/GFPGANv1.2.pth'):
os.system(
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth -P ./gfpgan/weights')
if not os.path.exists('gfpgan/weights/GFPGANv1.3.pth'):
os.system(
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P ./gfpgan/weights')
if not os.path.exists('gfpgan/weights/GFPGANv1.4.pth'):
os.system(
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ./gfpgan/weights')
if not os.path.exists('gfpgan/weights/RestoreFormer.pth'):
os.system(
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth -P ./gfpgan/weights'
)
# background enhancer with RealESRGAN
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
model_path = 'gfpgan/weights/realesr-general-x4v3.pth'
half = True if torch.cuda.is_available() else False
self.upsampler = RealESRGANer(
scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
# Use GFPGAN for face enhancement
self.face_enhancer = GFPGANer(
model_path='gfpgan/weights/GFPGANv1.4.pth',
upscale=2,
arch='clean',
channel_multiplier=2,
bg_upsampler=self.upsampler)
self.current_version = 'v1.4'
def predict(
self,
img: Path = Input(description='Input'),
version: str = Input(
description='GFPGAN version. v1.3: better quality. v1.4: more details and better identity.',
choices=['v1.2', 'v1.3', 'v1.4', 'RestoreFormer'],
default='v1.4'),
scale: float = Input(description='Rescaling factor', default=2),
) -> Path:
weight = 0.5
print(img, version, scale, weight)
try:
extension = os.path.splitext(os.path.basename(str(img)))[1]
img = cv2.imread(str(img), cv2.IMREAD_UNCHANGED)
if len(img.shape) == 3 and img.shape[2] == 4:
img_mode = 'RGBA'
elif len(img.shape) == 2:
img_mode = None
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
else:
img_mode = None
h, w = img.shape[0:2]
if h < 300:
img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
if self.current_version != version:
if version == 'v1.2':
self.face_enhancer = GFPGANer(
model_path='gfpgan/weights/GFPGANv1.2.pth',
upscale=2,
arch='clean',
channel_multiplier=2,
bg_upsampler=self.upsampler)
self.current_version = 'v1.2'
elif version == 'v1.3':
self.face_enhancer = GFPGANer(
model_path='gfpgan/weights/GFPGANv1.3.pth',
upscale=2,
arch='clean',
channel_multiplier=2,
bg_upsampler=self.upsampler)
self.current_version = 'v1.3'
elif version == 'v1.4':
self.face_enhancer = GFPGANer(
model_path='gfpgan/weights/GFPGANv1.4.pth',
upscale=2,
arch='clean',
channel_multiplier=2,
bg_upsampler=self.upsampler)
self.current_version = 'v1.4'
elif version == 'RestoreFormer':
self.face_enhancer = GFPGANer(
model_path='gfpgan/weights/RestoreFormer.pth',
upscale=2,
arch='RestoreFormer',
channel_multiplier=2,
bg_upsampler=self.upsampler)
try:
_, _, output = self.face_enhancer.enhance(
img, has_aligned=False, only_center_face=False, paste_back=True, weight=weight)
except RuntimeError as error:
print('Error', error)
try:
if scale != 2:
interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
h, w = img.shape[0:2]
output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
except Exception as error:
print('wrong scale input.', error)
if img_mode == 'RGBA': # RGBA images should be saved in png format
extension = 'png'
# save_path = f'output/out.{extension}'
# cv2.imwrite(save_path, output)
out_path = Path(tempfile.mkdtemp()) / f'out.{extension}'
cv2.imwrite(str(out_path), output)
except Exception as error:
print('global exception: ', error)
finally:
clean_folder('output')
return out_path
def clean_folder(folder):
for filename in os.listdir(folder):
file_path = os.path.join(folder, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
print(f'Failed to delete {file_path}. Reason: {e}')

View File

@@ -3,4 +3,5 @@ from .archs import *
from .data import *
from .models import *
from .utils import *
from .version import *
# from .version import *

View File

@@ -0,0 +1,312 @@
import math
import random
import torch
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn
from .gfpganv1_arch import ResUpBlock
from .stylegan2_bilinear_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
StyleGAN2GeneratorBilinear)
class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear):
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
It is the bilinear version. It does not use the complicated UpFirDnSmooth function that is not friendly for
deployment. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
num_mlp (int): Layer number of MLP style layers. Default: 8.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
narrow (float): The narrow ratio for channels. Default: 1.
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
"""
def __init__(self,
out_size,
num_style_feat=512,
num_mlp=8,
channel_multiplier=2,
lr_mlp=0.01,
narrow=1,
sft_half=False):
super(StyleGAN2GeneratorBilinearSFT, self).__init__(
out_size,
num_style_feat=num_style_feat,
num_mlp=num_mlp,
channel_multiplier=channel_multiplier,
lr_mlp=lr_mlp,
narrow=narrow)
self.sft_half = sft_half
def forward(self,
styles,
conditions,
input_is_latent=False,
noise=None,
randomize_noise=True,
truncation=1,
truncation_latent=None,
inject_index=None,
return_latents=False):
"""Forward function for StyleGAN2GeneratorBilinearSFT.
Args:
styles (list[Tensor]): Sample codes of styles.
conditions (list[Tensor]): SFT conditions to generators.
input_is_latent (bool): Whether input is latent style. Default: False.
noise (Tensor | None): Input noise or None. Default: None.
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
truncation (float): The truncation ratio. Default: 1.
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
inject_index (int | None): The injection index for mixing noise. Default: None.
return_latents (bool): Whether to return style latents. Default: False.
"""
# style codes -> latents with Style MLP layer
if not input_is_latent:
styles = [self.style_mlp(s) for s in styles]
# noises
if noise is None:
if randomize_noise:
noise = [None] * self.num_layers # for each style conv layer
else: # use the stored noise
noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
# style truncation
if truncation < 1:
style_truncation = []
for style in styles:
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
styles = style_truncation
# get style latents with injection
if len(styles) == 1:
inject_index = self.num_latent
if styles[0].ndim < 3:
# repeat latent code for all the layers
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
else: # used for encoder with different latent code for each layer
latent = styles[0]
elif len(styles) == 2: # mixing noises
if inject_index is None:
inject_index = random.randint(1, self.num_latent - 1)
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
latent = torch.cat([latent1, latent2], 1)
# main generation
out = self.constant_input(latent.shape[0])
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
skip = self.to_rgb1(out, latent[:, 1])
i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
noise[2::2], self.to_rgbs):
out = conv1(out, latent[:, i], noise=noise1)
# the conditions may have fewer levels
if i < len(conditions):
# SFT part to combine the conditions
if self.sft_half: # only apply SFT to half of the channels
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
out_sft = out_sft * conditions[i - 1] + conditions[i]
out = torch.cat([out_same, out_sft], dim=1)
else: # apply SFT to all the channels
out = out * conditions[i - 1] + conditions[i]
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
i += 2
image = skip
if return_latents:
return image, latent
else:
return image, None
@ARCH_REGISTRY.register()
class GFPGANBilinear(nn.Module):
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
It is the bilinear version and it does not use the complicated UpFirDnSmooth function that is not friendly for
deployment. It can be easily converted to the clean version: GFPGANv1Clean.
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
fix_decoder (bool): Whether to fix the decoder. Default: True.
num_mlp (int): Layer number of MLP style layers. Default: 8.
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
input_is_latent (bool): Whether input is latent style. Default: False.
different_w (bool): Whether to use different latent w for different layers. Default: False.
narrow (float): The narrow ratio for channels. Default: 1.
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
"""
def __init__(
self,
out_size,
num_style_feat=512,
channel_multiplier=1,
decoder_load_path=None,
fix_decoder=True,
# for stylegan decoder
num_mlp=8,
lr_mlp=0.01,
input_is_latent=False,
different_w=False,
narrow=1,
sft_half=False):
super(GFPGANBilinear, self).__init__()
self.input_is_latent = input_is_latent
self.different_w = different_w
self.num_style_feat = num_style_feat
unet_narrow = narrow * 0.5 # by default, use a half of input channels
channels = {
'4': int(512 * unet_narrow),
'8': int(512 * unet_narrow),
'16': int(512 * unet_narrow),
'32': int(512 * unet_narrow),
'64': int(256 * channel_multiplier * unet_narrow),
'128': int(128 * channel_multiplier * unet_narrow),
'256': int(64 * channel_multiplier * unet_narrow),
'512': int(32 * channel_multiplier * unet_narrow),
'1024': int(16 * channel_multiplier * unet_narrow)
}
self.log_size = int(math.log(out_size, 2))
first_out_size = 2**(int(math.log(out_size, 2)))
self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)
# downsample
in_channels = channels[f'{first_out_size}']
self.conv_body_down = nn.ModuleList()
for i in range(self.log_size, 2, -1):
out_channels = channels[f'{2**(i - 1)}']
self.conv_body_down.append(ResBlock(in_channels, out_channels))
in_channels = out_channels
self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)
# upsample
in_channels = channels['4']
self.conv_body_up = nn.ModuleList()
for i in range(3, self.log_size + 1):
out_channels = channels[f'{2**i}']
self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
in_channels = out_channels
# to RGB
self.toRGB = nn.ModuleList()
for i in range(3, self.log_size + 1):
self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))
if different_w:
linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
else:
linear_out_channel = num_style_feat
self.final_linear = EqualLinear(
channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
# the decoder: stylegan2 generator with SFT modulations
self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT(
out_size=out_size,
num_style_feat=num_style_feat,
num_mlp=num_mlp,
channel_multiplier=channel_multiplier,
lr_mlp=lr_mlp,
narrow=narrow,
sft_half=sft_half)
# load pre-trained stylegan2 model if necessary
if decoder_load_path:
self.stylegan_decoder.load_state_dict(
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
# fix decoder without updating params
if fix_decoder:
for _, param in self.stylegan_decoder.named_parameters():
param.requires_grad = False
# for SFT modulations (scale and shift)
self.condition_scale = nn.ModuleList()
self.condition_shift = nn.ModuleList()
for i in range(3, self.log_size + 1):
out_channels = channels[f'{2**i}']
if sft_half:
sft_out_channels = out_channels
else:
sft_out_channels = out_channels * 2
self.condition_scale.append(
nn.Sequential(
EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
ScaledLeakyReLU(0.2),
EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
self.condition_shift.append(
nn.Sequential(
EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
ScaledLeakyReLU(0.2),
EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
"""Forward function for GFPGANBilinear.
Args:
x (Tensor): Input images.
return_latents (bool): Whether to return style latents. Default: False.
return_rgb (bool): Whether return intermediate rgb images. Default: True.
randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
"""
conditions = []
unet_skips = []
out_rgbs = []
# encoder
feat = self.conv_body_first(x)
for i in range(self.log_size - 2):
feat = self.conv_body_down[i](feat)
unet_skips.insert(0, feat)
feat = self.final_conv(feat)
# style code
style_code = self.final_linear(feat.view(feat.size(0), -1))
if self.different_w:
style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
# decode
for i in range(self.log_size - 2):
# add unet skip
feat = feat + unet_skips[i]
# ResUpLayer
feat = self.conv_body_up[i](feat)
# generate scale and shift for SFT layers
scale = self.condition_scale[i](feat)
conditions.append(scale.clone())
shift = self.condition_shift[i](feat)
conditions.append(shift.clone())
# generate rgb images
if return_rgb:
out_rgbs.append(self.toRGB[i](feat))
# decoder
image, _ = self.stylegan_decoder([style_code],
conditions,
return_latents=return_latents,
input_is_latent=self.input_is_latent,
randomize_noise=randomize_noise)
return image, out_rgbs

View File

@@ -350,7 +350,7 @@ class GFPGANv1(nn.Module):
ScaledLeakyReLU(0.2),
EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs):
"""Forward function for GFPGANv1.
Args:
@@ -416,7 +416,7 @@ class FacialComponentDiscriminator(nn.Module):
self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
def forward(self, x, return_feats=False):
def forward(self, x, return_feats=False, **kwargs):
"""Forward function for FacialComponentDiscriminator.
Args:

View File

@@ -274,7 +274,7 @@ class GFPGANv1Clean(nn.Module):
nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs):
"""Forward function for GFPGANv1Clean.
Args:

View File

@@ -0,0 +1,658 @@
"""Modified from https://github.com/wzhouxiff/RestoreFormer
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class VectorQuantizer(nn.Module):
"""
see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
____________________________________________
Discretization bottleneck part of the VQ-VAE.
Inputs:
- n_e : number of embeddings
- e_dim : dimension of embedding
- beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
_____________________________________________
"""
def __init__(self, n_e, e_dim, beta):
super(VectorQuantizer, self).__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
def forward(self, z):
"""
Inputs the output of the encoder network z and maps it to a discrete
one-hot vector that is the index of the closest embedding vector e_j
z (continuous) -> z_q (discrete)
z.shape = (batch, channel, height, width)
quantization pipeline:
1. get encoder input (B,C,H,W)
2. flatten input to (B*H*W,C)
"""
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
torch.matmul(z_flattened, self.embedding.weight.t())
# could possible replace this here
# #\start...
# find closest encodings
min_value, min_encoding_indices = torch.min(d, dim=1)
min_encoding_indices = min_encoding_indices.unsqueeze(1)
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.n_e).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)
# dtype min encodings: torch.float32
# min_encodings shape: torch.Size([2048, 512])
# min_encoding_indices.shape: torch.Size([2048, 1])
# get quantized latent vectors
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
# .........\end
# with:
# .........\start
# min_encoding_indices = torch.argmin(d, dim=1)
# z_q = self.embedding(min_encoding_indices)
# ......\end......... (TODO)
# compute loss for embedding
loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)
# preserve gradients
z_q = z + (z_q - z).detach()
# perplexity
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q, loss, (perplexity, min_encodings, min_encoding_indices, d)
def get_codebook_entry(self, indices, shape):
# shape specifying (batch, height, width, channel)
# TODO: check for more easy handling with nn.Embedding
min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
min_encodings.scatter_(1, indices[:, None], 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
if shape is not None:
z_q = z_q.view(shape)
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q
# pytorch_diffusion + derived encoder decoder
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode='nearest')
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x):
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class MultiHeadAttnBlock(nn.Module):
def __init__(self, in_channels, head_size=1):
super().__init__()
self.in_channels = in_channels
self.head_size = head_size
self.att_size = in_channels // head_size
assert (in_channels % head_size == 0), 'The size of head should be divided by the number of channels.'
self.norm1 = Normalize(in_channels)
self.norm2 = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.num = 0
def forward(self, x, y=None):
h_ = x
h_ = self.norm1(h_)
if y is None:
y = h_
else:
y = self.norm2(y)
q = self.q(y)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, self.head_size, self.att_size, h * w)
q = q.permute(0, 3, 1, 2) # b, hw, head, att
k = k.reshape(b, self.head_size, self.att_size, h * w)
k = k.permute(0, 3, 1, 2)
v = v.reshape(b, self.head_size, self.att_size, h * w)
v = v.permute(0, 3, 1, 2)
q = q.transpose(1, 2)
v = v.transpose(1, 2)
k = k.transpose(1, 2).transpose(2, 3)
scale = int(self.att_size)**(-0.5)
q.mul_(scale)
w_ = torch.matmul(q, k)
w_ = F.softmax(w_, dim=3)
w_ = w_.matmul(v)
w_ = w_.transpose(1, 2).contiguous() # [b, h*w, head, att]
w_ = w_.view(b, h, w, -1)
w_ = w_.permute(0, 3, 1, 2)
w_ = self.proj_out(w_)
return x + w_
class MultiHeadEncoder(nn.Module):
def __init__(self,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks=2,
attn_resolutions=(16, ),
dropout=0.0,
resamp_with_conv=True,
in_channels=3,
resolution=512,
z_channels=256,
double_z=True,
enable_mid=True,
head_size=1,
**ignore_kwargs):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.enable_mid = enable_mid
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
in_ch_mult = (1, ) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(MultiHeadAttnBlock(block_in, head_size))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
if self.enable_mid:
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
self.mid.block_2 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
hs = {}
# timestep embedding
temb = None
# downsampling
h = self.conv_in(x)
hs['in'] = h
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h, temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1:
# hs.append(h)
hs['block_' + str(i_level)] = h
h = self.down[i_level].downsample(h)
# middle
# h = hs[-1]
if self.enable_mid:
h = self.mid.block_1(h, temb)
hs['block_' + str(i_level) + '_atten'] = h
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
hs['mid_atten'] = h
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
# hs.append(h)
hs['out'] = h
return hs
class MultiHeadDecoder(nn.Module):
def __init__(self,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks=2,
attn_resolutions=(16, ),
dropout=0.0,
resamp_with_conv=True,
in_channels=3,
resolution=512,
z_channels=256,
give_pre_end=False,
enable_mid=True,
head_size=1,
**ignorekwargs):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.enable_mid = enable_mid
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2**(self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print('Working with z of shape {} = {} dimensions.'.format(self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
# middle
if self.enable_mid:
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
self.mid.block_2 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(MultiHeadAttnBlock(block_in, head_size))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
if self.enable_mid:
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class MultiHeadDecoderTransformer(nn.Module):
def __init__(self,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks=2,
attn_resolutions=(16, ),
dropout=0.0,
resamp_with_conv=True,
in_channels=3,
resolution=512,
z_channels=256,
give_pre_end=False,
enable_mid=True,
head_size=1,
**ignorekwargs):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.enable_mid = enable_mid
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2**(self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print('Working with z of shape {} = {} dimensions.'.format(self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
# middle
if self.enable_mid:
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
self.mid.block_2 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(MultiHeadAttnBlock(block_in, head_size))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z, hs):
# assert z.shape[1:] == self.z_shape[1:]
# self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
if self.enable_mid:
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h, hs['mid_atten'])
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, hs['block_' + str(i_level) + '_atten'])
# hfeature = h.clone()
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class RestoreFormer(nn.Module):
def __init__(self,
n_embed=1024,
embed_dim=256,
ch=64,
out_ch=3,
ch_mult=(1, 2, 2, 4, 4, 8),
num_res_blocks=2,
attn_resolutions=(16, ),
dropout=0.0,
in_channels=3,
resolution=512,
z_channels=256,
double_z=False,
enable_mid=True,
fix_decoder=False,
fix_codebook=True,
fix_encoder=False,
head_size=8):
super(RestoreFormer, self).__init__()
self.encoder = MultiHeadEncoder(
ch=ch,
out_ch=out_ch,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions,
dropout=dropout,
in_channels=in_channels,
resolution=resolution,
z_channels=z_channels,
double_z=double_z,
enable_mid=enable_mid,
head_size=head_size)
self.decoder = MultiHeadDecoderTransformer(
ch=ch,
out_ch=out_ch,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions,
dropout=dropout,
in_channels=in_channels,
resolution=resolution,
z_channels=z_channels,
enable_mid=enable_mid,
head_size=head_size)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
if fix_decoder:
for _, param in self.decoder.named_parameters():
param.requires_grad = False
for _, param in self.post_quant_conv.named_parameters():
param.requires_grad = False
for _, param in self.quantize.named_parameters():
param.requires_grad = False
elif fix_codebook:
for _, param in self.quantize.named_parameters():
param.requires_grad = False
if fix_encoder:
for _, param in self.encoder.named_parameters():
param.requires_grad = False
def encode(self, x):
hs = self.encoder(x)
h = self.quant_conv(hs['out'])
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info, hs
def decode(self, quant, hs):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant, hs)
return dec
def forward(self, input, **kwargs):
quant, diff, info, hs = self.encode(input)
dec = self.decode(quant, hs)
return dec, None

View File

@@ -0,0 +1,613 @@
import math
import random
import torch
from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn
from torch.nn import functional as F
class NormStyleCode(nn.Module):
def forward(self, x):
"""Normalize the style codes.
Args:
x (Tensor): Style codes with shape (b, c).
Returns:
Tensor: Normalized tensor.
"""
return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
class EqualLinear(nn.Module):
"""Equalized Linear as StyleGAN2.
Args:
in_channels (int): Size of each sample.
out_channels (int): Size of each output sample.
bias (bool): If set to ``False``, the layer will not learn an additive
bias. Default: ``True``.
bias_init_val (float): Bias initialized value. Default: 0.
lr_mul (float): Learning rate multiplier. Default: 1.
activation (None | str): The activation after ``linear`` operation.
Supported: 'fused_lrelu', None. Default: None.
"""
def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
super(EqualLinear, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.lr_mul = lr_mul
self.activation = activation
if self.activation not in ['fused_lrelu', None]:
raise ValueError(f'Wrong activation value in EqualLinear: {activation}'
"Supported ones are: ['fused_lrelu', None].")
self.scale = (1 / math.sqrt(in_channels)) * lr_mul
self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
if bias:
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
else:
self.register_parameter('bias', None)
def forward(self, x):
if self.bias is None:
bias = None
else:
bias = self.bias * self.lr_mul
if self.activation == 'fused_lrelu':
out = F.linear(x, self.weight * self.scale)
out = fused_leaky_relu(out, bias)
else:
out = F.linear(x, self.weight * self.scale, bias=bias)
return out
def __repr__(self):
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
f'out_channels={self.out_channels}, bias={self.bias is not None})')
class ModulatedConv2d(nn.Module):
"""Modulated Conv2d used in StyleGAN2.
There is no bias in ModulatedConv2d.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Size of the convolving kernel.
num_style_feat (int): Channel number of style features.
demodulate (bool): Whether to demodulate in the conv layer.
Default: True.
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
Default: None.
eps (float): A value added to the denominator for numerical stability.
Default: 1e-8.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
num_style_feat,
demodulate=True,
sample_mode=None,
eps=1e-8,
interpolation_mode='bilinear'):
super(ModulatedConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.demodulate = demodulate
self.sample_mode = sample_mode
self.eps = eps
self.interpolation_mode = interpolation_mode
if self.interpolation_mode == 'nearest':
self.align_corners = None
else:
self.align_corners = False
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
# modulation inside each modulated conv
self.modulation = EqualLinear(
num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)
self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
self.padding = kernel_size // 2
def forward(self, x, style):
"""Forward function.
Args:
x (Tensor): Tensor with shape (b, c, h, w).
style (Tensor): Tensor with shape (b, num_style_feat).
Returns:
Tensor: Modulated tensor after convolution.
"""
b, c, h, w = x.shape # c = c_in
# weight modulation
style = self.modulation(style).view(b, 1, c, 1, 1)
# self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
if self.demodulate:
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
if self.sample_mode == 'upsample':
x = F.interpolate(x, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
elif self.sample_mode == 'downsample':
x = F.interpolate(x, scale_factor=0.5, mode=self.interpolation_mode, align_corners=self.align_corners)
b, c, h, w = x.shape
x = x.view(1, b * c, h, w)
# weight: (b*c_out, c_in, k, k), groups=b
out = F.conv2d(x, weight, padding=self.padding, groups=b)
out = out.view(b, self.out_channels, *out.shape[2:4])
return out
def __repr__(self):
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
f'out_channels={self.out_channels}, '
f'kernel_size={self.kernel_size}, '
f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
class StyleConv(nn.Module):
"""Style conv.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Size of the convolving kernel.
num_style_feat (int): Channel number of style features.
demodulate (bool): Whether demodulate in the conv layer. Default: True.
sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
Default: None.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
num_style_feat,
demodulate=True,
sample_mode=None,
interpolation_mode='bilinear'):
super(StyleConv, self).__init__()
self.modulated_conv = ModulatedConv2d(
in_channels,
out_channels,
kernel_size,
num_style_feat,
demodulate=demodulate,
sample_mode=sample_mode,
interpolation_mode=interpolation_mode)
self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
self.activate = FusedLeakyReLU(out_channels)
def forward(self, x, style, noise=None):
# modulate
out = self.modulated_conv(x, style)
# noise injection
if noise is None:
b, _, h, w = out.shape
noise = out.new_empty(b, 1, h, w).normal_()
out = out + self.weight * noise
# activation (with bias)
out = self.activate(out)
return out
class ToRGB(nn.Module):
"""To RGB from features.
Args:
in_channels (int): Channel number of input.
num_style_feat (int): Channel number of style features.
upsample (bool): Whether to upsample. Default: True.
"""
def __init__(self, in_channels, num_style_feat, upsample=True, interpolation_mode='bilinear'):
super(ToRGB, self).__init__()
self.upsample = upsample
self.interpolation_mode = interpolation_mode
if self.interpolation_mode == 'nearest':
self.align_corners = None
else:
self.align_corners = False
self.modulated_conv = ModulatedConv2d(
in_channels,
3,
kernel_size=1,
num_style_feat=num_style_feat,
demodulate=False,
sample_mode=None,
interpolation_mode=interpolation_mode)
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
def forward(self, x, style, skip=None):
"""Forward function.
Args:
x (Tensor): Feature tensor with shape (b, c, h, w).
style (Tensor): Tensor with shape (b, num_style_feat).
skip (Tensor): Base/skip tensor. Default: None.
Returns:
Tensor: RGB images.
"""
out = self.modulated_conv(x, style)
out = out + self.bias
if skip is not None:
if self.upsample:
skip = F.interpolate(
skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
out = out + skip
return out
class ConstantInput(nn.Module):
"""Constant input.
Args:
num_channel (int): Channel number of constant input.
size (int): Spatial size of constant input.
"""
def __init__(self, num_channel, size):
super(ConstantInput, self).__init__()
self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
def forward(self, batch):
out = self.weight.repeat(batch, 1, 1, 1)
return out
@ARCH_REGISTRY.register()
class StyleGAN2GeneratorBilinear(nn.Module):
"""StyleGAN2 Generator.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
num_mlp (int): Layer number of MLP style layers. Default: 8.
channel_multiplier (int): Channel multiplier for large networks of
StyleGAN2. Default: 2.
lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
narrow (float): Narrow ratio for channels. Default: 1.0.
"""
def __init__(self,
out_size,
num_style_feat=512,
num_mlp=8,
channel_multiplier=2,
lr_mlp=0.01,
narrow=1,
interpolation_mode='bilinear'):
super(StyleGAN2GeneratorBilinear, self).__init__()
# Style MLP layers
self.num_style_feat = num_style_feat
style_mlp_layers = [NormStyleCode()]
for i in range(num_mlp):
style_mlp_layers.append(
EqualLinear(
num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
activation='fused_lrelu'))
self.style_mlp = nn.Sequential(*style_mlp_layers)
channels = {
'4': int(512 * narrow),
'8': int(512 * narrow),
'16': int(512 * narrow),
'32': int(512 * narrow),
'64': int(256 * channel_multiplier * narrow),
'128': int(128 * channel_multiplier * narrow),
'256': int(64 * channel_multiplier * narrow),
'512': int(32 * channel_multiplier * narrow),
'1024': int(16 * channel_multiplier * narrow)
}
self.channels = channels
self.constant_input = ConstantInput(channels['4'], size=4)
self.style_conv1 = StyleConv(
channels['4'],
channels['4'],
kernel_size=3,
num_style_feat=num_style_feat,
demodulate=True,
sample_mode=None,
interpolation_mode=interpolation_mode)
self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, interpolation_mode=interpolation_mode)
self.log_size = int(math.log(out_size, 2))
self.num_layers = (self.log_size - 2) * 2 + 1
self.num_latent = self.log_size * 2 - 2
self.style_convs = nn.ModuleList()
self.to_rgbs = nn.ModuleList()
self.noises = nn.Module()
in_channels = channels['4']
# noise
for layer_idx in range(self.num_layers):
resolution = 2**((layer_idx + 5) // 2)
shape = [1, 1, resolution, resolution]
self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
# style convs and to_rgbs
for i in range(3, self.log_size + 1):
out_channels = channels[f'{2**i}']
self.style_convs.append(
StyleConv(
in_channels,
out_channels,
kernel_size=3,
num_style_feat=num_style_feat,
demodulate=True,
sample_mode='upsample',
interpolation_mode=interpolation_mode))
self.style_convs.append(
StyleConv(
out_channels,
out_channels,
kernel_size=3,
num_style_feat=num_style_feat,
demodulate=True,
sample_mode=None,
interpolation_mode=interpolation_mode))
self.to_rgbs.append(
ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode))
in_channels = out_channels
def make_noise(self):
"""Make noise for noise injection."""
device = self.constant_input.weight.device
noises = [torch.randn(1, 1, 4, 4, device=device)]
for i in range(3, self.log_size + 1):
for _ in range(2):
noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
return noises
def get_latent(self, x):
return self.style_mlp(x)
def mean_latent(self, num_latent):
latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
latent = self.style_mlp(latent_in).mean(0, keepdim=True)
return latent
def forward(self,
styles,
input_is_latent=False,
noise=None,
randomize_noise=True,
truncation=1,
truncation_latent=None,
inject_index=None,
return_latents=False):
"""Forward function for StyleGAN2Generator.
Args:
styles (list[Tensor]): Sample codes of styles.
input_is_latent (bool): Whether input is latent style.
Default: False.
noise (Tensor | None): Input noise or None. Default: None.
randomize_noise (bool): Randomize noise, used when 'noise' is
False. Default: True.
truncation (float): TODO. Default: 1.
truncation_latent (Tensor | None): TODO. Default: None.
inject_index (int | None): The injection index for mixing noise.
Default: None.
return_latents (bool): Whether to return style latents.
Default: False.
"""
# style codes -> latents with Style MLP layer
if not input_is_latent:
styles = [self.style_mlp(s) for s in styles]
# noises
if noise is None:
if randomize_noise:
noise = [None] * self.num_layers # for each style conv layer
else: # use the stored noise
noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
# style truncation
if truncation < 1:
style_truncation = []
for style in styles:
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
styles = style_truncation
# get style latent with injection
if len(styles) == 1:
inject_index = self.num_latent
if styles[0].ndim < 3:
# repeat latent code for all the layers
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
else: # used for encoder with different latent code for each layer
latent = styles[0]
elif len(styles) == 2: # mixing noises
if inject_index is None:
inject_index = random.randint(1, self.num_latent - 1)
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
latent = torch.cat([latent1, latent2], 1)
# main generation
out = self.constant_input(latent.shape[0])
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
skip = self.to_rgb1(out, latent[:, 1])
i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
noise[2::2], self.to_rgbs):
out = conv1(out, latent[:, i], noise=noise1)
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
i += 2
image = skip
if return_latents:
return image, latent
else:
return image, None
class ScaledLeakyReLU(nn.Module):
"""Scaled LeakyReLU.
Args:
negative_slope (float): Negative slope. Default: 0.2.
"""
def __init__(self, negative_slope=0.2):
super(ScaledLeakyReLU, self).__init__()
self.negative_slope = negative_slope
def forward(self, x):
out = F.leaky_relu(x, negative_slope=self.negative_slope)
return out * math.sqrt(2)
class EqualConv2d(nn.Module):
"""Equalized Linear as StyleGAN2.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Size of the convolving kernel.
stride (int): Stride of the convolution. Default: 1
padding (int): Zero-padding added to both sides of the input.
Default: 0.
bias (bool): If ``True``, adds a learnable bias to the output.
Default: ``True``.
bias_init_val (float): Bias initialized value. Default: 0.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
super(EqualConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
if bias:
self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
else:
self.register_parameter('bias', None)
def forward(self, x):
out = F.conv2d(
x,
self.weight * self.scale,
bias=self.bias,
stride=self.stride,
padding=self.padding,
)
return out
def __repr__(self):
return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
f'out_channels={self.out_channels}, '
f'kernel_size={self.kernel_size},'
f' stride={self.stride}, padding={self.padding}, '
f'bias={self.bias is not None})')
class ConvLayer(nn.Sequential):
"""Conv Layer used in StyleGAN2 Discriminator.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Kernel size.
downsample (bool): Whether downsample by a factor of 2.
Default: False.
bias (bool): Whether with bias. Default: True.
activate (bool): Whether use activateion. Default: True.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
downsample=False,
bias=True,
activate=True,
interpolation_mode='bilinear'):
layers = []
self.interpolation_mode = interpolation_mode
# downsample
if downsample:
if self.interpolation_mode == 'nearest':
self.align_corners = None
else:
self.align_corners = False
layers.append(
torch.nn.Upsample(scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners))
stride = 1
self.padding = kernel_size // 2
# conv
layers.append(
EqualConv2d(
in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
and not activate))
# activation
if activate:
if bias:
layers.append(FusedLeakyReLU(out_channels))
else:
layers.append(ScaledLeakyReLU(0.2))
super(ConvLayer, self).__init__(*layers)
class ResBlock(nn.Module):
"""Residual block used in StyleGAN2 Discriminator.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
"""
def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'):
super(ResBlock, self).__init__()
self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
self.conv2 = ConvLayer(
in_channels,
out_channels,
3,
downsample=True,
interpolation_mode=interpolation_mode,
bias=True,
activate=True)
self.skip = ConvLayer(
in_channels,
out_channels,
1,
downsample=True,
interpolation_mode=interpolation_mode,
bias=False,
activate=False)
def forward(self, x):
out = self.conv1(x)
out = self.conv2(out)
skip = self.skip(x)
out = (out + skip) / math.sqrt(2)
return out

View File

@@ -3,7 +3,7 @@ import os.path as osp
import torch
from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.losses.losses import r1_penalty
from basicsr.losses.gan_loss import r1_penalty
from basicsr.metrics import calculate_metric
from basicsr.models.base_model import BaseModel
from basicsr.utils import get_root_logger, imwrite, tensor2img
@@ -209,18 +209,18 @@ class GFPGANModel(BaseModel):
self.loc_right_eyes = data['loc_right_eye']
self.loc_mouths = data['loc_mouth']
# uncomment to check data
# import torchvision
# if self.opt['rank'] == 0:
# import os
# os.makedirs('tmp/gt', exist_ok=True)
# os.makedirs('tmp/lq', exist_ok=True)
# print(self.idx)
# torchvision.utils.save_image(
# self.gt, f'tmp/gt/gt_{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
# torchvision.utils.save_image(
# self.lq, f'tmp/lq/lq{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
# self.idx = self.idx + 1
# uncomment to check data
# import torchvision
# if self.opt['rank'] == 0:
# import os
# os.makedirs('tmp/gt', exist_ok=True)
# os.makedirs('tmp/lq', exist_ok=True)
# print(self.idx)
# torchvision.utils.save_image(
# self.gt, f'tmp/gt/gt_{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
# torchvision.utils.save_image(
# self.lq, f'tmp/lq/lq{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
# self.idx = self.idx + 1
def construct_img_pyramid(self):
"""Construct image pyramid for intermediate restoration loss"""
@@ -300,10 +300,9 @@ class GFPGANModel(BaseModel):
p.requires_grad = False
# image pyramid loss weight
if current_iter < self.opt['train'].get('remove_pyramid_loss', float('inf')):
pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 1)
else:
pyramid_loss_weight = 1e-12 # very small loss
pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 0)
if pyramid_loss_weight > 0 and current_iter > self.opt['train'].get('remove_pyramid_loss', float('inf')):
pyramid_loss_weight = 1e-12 # very small weight to avoid unused param error
if pyramid_loss_weight > 0:
self.output, out_rgbs = self.net_g(self.lq, return_rgb=True)
pyramid_gt = self.construct_img_pyramid()

View File

@@ -6,6 +6,7 @@ from basicsr.utils.download_util import load_file_from_url
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from torchvision.transforms.functional import normalize
from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear
from gfpgan.archs.gfpganv1_arch import GFPGANv1
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
@@ -28,12 +29,12 @@ class GFPGANer():
bg_upsampler (nn.Module): The upsampler for the background. Default: None.
"""
def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None):
def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None):
self.upscale = upscale
self.bg_upsampler = bg_upsampler
# initialize model
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
# initialize the GFP-GAN
if arch == 'clean':
self.gfpgan = GFPGANv1Clean(
@@ -47,7 +48,19 @@ class GFPGANer():
different_w=True,
narrow=1,
sft_half=True)
else:
elif arch == 'bilinear':
self.gfpgan = GFPGANBilinear(
out_size=512,
num_style_feat=512,
channel_multiplier=channel_multiplier,
decoder_load_path=None,
fix_decoder=False,
num_mlp=8,
input_is_latent=True,
different_w=True,
narrow=1,
sft_half=True)
elif arch == 'original':
self.gfpgan = GFPGANv1(
out_size=512,
num_style_feat=512,
@@ -59,6 +72,9 @@ class GFPGANer():
different_w=True,
narrow=1,
sft_half=True)
elif arch == 'RestoreFormer':
from gfpgan.archs.restoreformer_arch import RestoreFormer
self.gfpgan = RestoreFormer()
# initialize face helper
self.face_helper = FaceRestoreHelper(
upscale,
@@ -66,7 +82,9 @@ class GFPGANer():
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
save_ext='png',
device=self.device)
use_parse=True,
device=self.device,
model_rootpath='gfpgan/weights')
if model_path.startswith('https://'):
model_path = load_file_from_url(
@@ -81,7 +99,7 @@ class GFPGANer():
self.gfpgan = self.gfpgan.to(self.device)
@torch.no_grad()
def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5):
self.face_helper.clean_all()
if has_aligned: # the inputs are already aligned
@@ -104,7 +122,7 @@ class GFPGANer():
cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
try:
output = self.gfpgan(cropped_face_t, return_rgb=False)[0]
output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0]
# convert to image
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
except RuntimeError as error:

View File

@@ -10,39 +10,57 @@ from gfpgan import GFPGANer
def main():
"""Inference demo for GFPGAN.
"""Inference demo for GFPGAN (for users).
"""
parser = argparse.ArgumentParser()
parser.add_argument('--upscale', type=int, default=2, help='The final upsampling scale of the image')
parser.add_argument('--arch', type=str, default='clean', help='The GFPGAN architecture. Option: clean | original')
parser.add_argument('--channel', type=int, default=2, help='Channel multiplier for large networks of StyleGAN2')
parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth')
parser.add_argument('--bg_upsampler', type=str, default='realesrgan', help='background upsampler')
parser.add_argument(
'--bg_tile', type=int, default=400, help='Tile size for background sampler, 0 for no tile during testing')
parser.add_argument('--test_path', type=str, default='inputs/whole_imgs', help='Input folder')
'-i',
'--input',
type=str,
default='inputs/whole_imgs',
help='Input image or folder. Default: inputs/whole_imgs')
parser.add_argument('-o', '--output', type=str, default='results', help='Output folder. Default: results')
# we use version to select models, which is more user-friendly
parser.add_argument(
'-v', '--version', type=str, default='1.3', help='GFPGAN model version. Option: 1 | 1.2 | 1.3. Default: 1.3')
parser.add_argument(
'-s', '--upscale', type=int, default=2, help='The final upsampling scale of the image. Default: 2')
parser.add_argument(
'--bg_upsampler', type=str, default='realesrgan', help='background upsampler. Default: realesrgan')
parser.add_argument(
'--bg_tile',
type=int,
default=400,
help='Tile size for background sampler, 0 for no tile during testing. Default: 400')
parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face')
parser.add_argument('--aligned', action='store_true', help='Input are aligned faces')
parser.add_argument('--paste_back', action='store_false', help='Paste the restored faces back to images')
parser.add_argument('--save_root', type=str, default='results', help='Path to save root')
parser.add_argument(
'--ext',
type=str,
default='auto',
help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs. Default: auto')
parser.add_argument('-w', '--weight', type=float, default=0.5, help='Adjustable weights.')
args = parser.parse_args()
args = parser.parse_args()
if args.test_path.endswith('/'):
args.test_path = args.test_path[:-1]
os.makedirs(args.save_root, exist_ok=True)
# background upsampler
# ------------------------ input & output ------------------------
if args.input.endswith('/'):
args.input = args.input[:-1]
if os.path.isfile(args.input):
img_list = [args.input]
else:
img_list = sorted(glob.glob(os.path.join(args.input, '*')))
os.makedirs(args.output, exist_ok=True)
# ------------------------ set up background upsampler ------------------------
if args.bg_upsampler == 'realesrgan':
if not torch.cuda.is_available(): # CPU
import warnings
warnings.warn('The unoptimized RealESRGAN is very slow on CPU. We do not use it. '
warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
'If you really want to use it, please modify the corresponding codes.')
bg_upsampler = None
else:
@@ -59,15 +77,52 @@ def main():
half=True) # need to set False in CPU mode
else:
bg_upsampler = None
# set up GFPGAN restorer
# ------------------------ set up GFPGAN restorer ------------------------
if args.version == '1':
arch = 'original'
channel_multiplier = 1
model_name = 'GFPGANv1'
url = 'https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth'
elif args.version == '1.2':
arch = 'clean'
channel_multiplier = 2
model_name = 'GFPGANCleanv1-NoCE-C2'
url = 'https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth'
elif args.version == '1.3':
arch = 'clean'
channel_multiplier = 2
model_name = 'GFPGANv1.3'
url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
elif args.version == '1.4':
arch = 'clean'
channel_multiplier = 2
model_name = 'GFPGANv1.4'
url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'
elif args.version == 'RestoreFormer':
arch = 'RestoreFormer'
channel_multiplier = 2
model_name = 'RestoreFormer'
url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
else:
raise ValueError(f'Wrong model version {args.version}.')
# determine model paths
model_path = os.path.join('experiments/pretrained_models', model_name + '.pth')
if not os.path.isfile(model_path):
model_path = os.path.join('gfpgan/weights', model_name + '.pth')
if not os.path.isfile(model_path):
# download pre-trained models from url
model_path = url
restorer = GFPGANer(
model_path=args.model_path,
model_path=model_path,
upscale=args.upscale,
arch=args.arch,
channel_multiplier=args.channel,
arch=arch,
channel_multiplier=channel_multiplier,
bg_upsampler=bg_upsampler)
img_list = sorted(glob.glob(os.path.join(args.test_path, '*')))
# ------------------------ restore ------------------------
for img_path in img_list:
# read image
img_name = os.path.basename(img_path)
@@ -77,23 +132,27 @@ def main():
# restore faces and background if necessary
cropped_faces, restored_faces, restored_img = restorer.enhance(
input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=args.paste_back)
input_img,
has_aligned=args.aligned,
only_center_face=args.only_center_face,
paste_back=True,
weight=args.weight)
# save faces
for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)):
# save cropped face
save_crop_path = os.path.join(args.save_root, 'cropped_faces', f'{basename}_{idx:02d}.png')
save_crop_path = os.path.join(args.output, 'cropped_faces', f'{basename}_{idx:02d}.png')
imwrite(cropped_face, save_crop_path)
# save restored face
if args.suffix is not None:
save_face_name = f'{basename}_{idx:02d}_{args.suffix}.png'
else:
save_face_name = f'{basename}_{idx:02d}.png'
save_restore_path = os.path.join(args.save_root, 'restored_faces', save_face_name)
save_restore_path = os.path.join(args.output, 'restored_faces', save_face_name)
imwrite(restored_face, save_restore_path)
# save comparison image
cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
imwrite(cmp_img, os.path.join(args.save_root, 'cmp', f'{basename}_{idx:02d}.png'))
imwrite(cmp_img, os.path.join(args.output, 'cmp', f'{basename}_{idx:02d}.png'))
# save restored img
if restored_img is not None:
@@ -103,13 +162,12 @@ def main():
extension = args.ext
if args.suffix is not None:
save_restore_path = os.path.join(args.save_root, 'restored_imgs',
f'{basename}_{args.suffix}.{extension}')
save_restore_path = os.path.join(args.output, 'restored_imgs', f'{basename}_{args.suffix}.{extension}')
else:
save_restore_path = os.path.join(args.save_root, 'restored_imgs', f'{basename}.{extension}')
save_restore_path = os.path.join(args.output, 'restored_imgs', f'{basename}.{extension}')
imwrite(restored_img, save_restore_path)
print(f'Results are in the [{args.save_root}] folder.')
print(f'Results are in the [{args.output}] folder.')
if __name__ == '__main__':

View File

@@ -1,12 +1,12 @@
torch>=1.7
numpy<1.21 # numba requires numpy<1.21,>=1.17
opencv-python
torchvision
scipy
tqdm
basicsr>=1.3.4.0
facexlib>=0.2.0.3
basicsr>=1.4.2
facexlib>=0.2.5
lmdb
numpy
opencv-python
pyyaml
scipy
tb-nightly
torch>=1.7
torchvision
tqdm
yapf

View File

@@ -0,0 +1,164 @@
import argparse
import math
import torch
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
def modify_checkpoint(checkpoint_bilinear, checkpoint_clean):
for ori_k, ori_v in checkpoint_bilinear.items():
if 'stylegan_decoder' in ori_k:
if 'style_mlp' in ori_k: # style_mlp_layers
lr_mul = 0.01
prefix, name, idx, var = ori_k.split('.')
idx = (int(idx) * 2) - 1
crt_k = f'{prefix}.{name}.{idx}.{var}'
if var == 'weight':
_, c_in = ori_v.size()
scale = (1 / math.sqrt(c_in)) * lr_mul
crt_v = ori_v * scale * 2**0.5
else:
crt_v = ori_v * lr_mul * 2**0.5
checkpoint_clean[crt_k] = crt_v
elif 'modulation' in ori_k: # modulation in StyleConv
lr_mul = 1
crt_k = ori_k
var = ori_k.split('.')[-1]
if var == 'weight':
_, c_in = ori_v.size()
scale = (1 / math.sqrt(c_in)) * lr_mul
crt_v = ori_v * scale
else:
crt_v = ori_v * lr_mul
checkpoint_clean[crt_k] = crt_v
elif 'style_conv' in ori_k:
# StyleConv in style_conv1 and style_convs
if 'activate' in ori_k: # FusedLeakyReLU
# eg. style_conv1.activate.bias
# eg. style_convs.13.activate.bias
split_rlt = ori_k.split('.')
if len(split_rlt) == 4:
prefix, name, _, var = split_rlt
crt_k = f'{prefix}.{name}.{var}'
elif len(split_rlt) == 5:
prefix, name, idx, _, var = split_rlt
crt_k = f'{prefix}.{name}.{idx}.{var}'
crt_v = ori_v * 2**0.5 # 2**0.5 used in FusedLeakyReLU
c = crt_v.size(0)
checkpoint_clean[crt_k] = crt_v.view(1, c, 1, 1)
elif 'modulated_conv' in ori_k:
# eg. style_conv1.modulated_conv.weight
# eg. style_convs.13.modulated_conv.weight
_, c_out, c_in, k1, k2 = ori_v.size()
scale = 1 / math.sqrt(c_in * k1 * k2)
crt_k = ori_k
checkpoint_clean[crt_k] = ori_v * scale
elif 'weight' in ori_k:
crt_k = ori_k
checkpoint_clean[crt_k] = ori_v * 2**0.5
elif 'to_rgb' in ori_k: # StyleConv in to_rgb1 and to_rgbs
if 'modulated_conv' in ori_k:
# eg. to_rgb1.modulated_conv.weight
# eg. to_rgbs.5.modulated_conv.weight
_, c_out, c_in, k1, k2 = ori_v.size()
scale = 1 / math.sqrt(c_in * k1 * k2)
crt_k = ori_k
checkpoint_clean[crt_k] = ori_v * scale
else:
crt_k = ori_k
checkpoint_clean[crt_k] = ori_v
else:
crt_k = ori_k
checkpoint_clean[crt_k] = ori_v
# end of 'stylegan_decoder'
elif 'conv_body_first' in ori_k or 'final_conv' in ori_k:
# key name
name, _, var = ori_k.split('.')
crt_k = f'{name}.{var}'
# weight and bias
if var == 'weight':
c_out, c_in, k1, k2 = ori_v.size()
scale = 1 / math.sqrt(c_in * k1 * k2)
checkpoint_clean[crt_k] = ori_v * scale * 2**0.5
else:
checkpoint_clean[crt_k] = ori_v * 2**0.5
elif 'conv_body' in ori_k:
if 'conv_body_up' in ori_k:
ori_k = ori_k.replace('conv2.weight', 'conv2.1.weight')
ori_k = ori_k.replace('skip.weight', 'skip.1.weight')
name1, idx1, name2, _, var = ori_k.split('.')
crt_k = f'{name1}.{idx1}.{name2}.{var}'
if name2 == 'skip':
c_out, c_in, k1, k2 = ori_v.size()
scale = 1 / math.sqrt(c_in * k1 * k2)
checkpoint_clean[crt_k] = ori_v * scale / 2**0.5
else:
if var == 'weight':
c_out, c_in, k1, k2 = ori_v.size()
scale = 1 / math.sqrt(c_in * k1 * k2)
checkpoint_clean[crt_k] = ori_v * scale
else:
checkpoint_clean[crt_k] = ori_v
if 'conv1' in ori_k:
checkpoint_clean[crt_k] *= 2**0.5
elif 'toRGB' in ori_k:
crt_k = ori_k
if 'weight' in ori_k:
c_out, c_in, k1, k2 = ori_v.size()
scale = 1 / math.sqrt(c_in * k1 * k2)
checkpoint_clean[crt_k] = ori_v * scale
else:
checkpoint_clean[crt_k] = ori_v
elif 'final_linear' in ori_k:
crt_k = ori_k
if 'weight' in ori_k:
_, c_in = ori_v.size()
scale = 1 / math.sqrt(c_in)
checkpoint_clean[crt_k] = ori_v * scale
else:
checkpoint_clean[crt_k] = ori_v
elif 'condition' in ori_k:
crt_k = ori_k
if '0.weight' in ori_k:
c_out, c_in, k1, k2 = ori_v.size()
scale = 1 / math.sqrt(c_in * k1 * k2)
checkpoint_clean[crt_k] = ori_v * scale * 2**0.5
elif '0.bias' in ori_k:
checkpoint_clean[crt_k] = ori_v * 2**0.5
elif '2.weight' in ori_k:
c_out, c_in, k1, k2 = ori_v.size()
scale = 1 / math.sqrt(c_in * k1 * k2)
checkpoint_clean[crt_k] = ori_v * scale
elif '2.bias' in ori_k:
checkpoint_clean[crt_k] = ori_v
return checkpoint_clean
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--ori_path', type=str, help='Path to the original model')
parser.add_argument('--narrow', type=float, default=1)
parser.add_argument('--channel_multiplier', type=float, default=2)
parser.add_argument('--save_path', type=str)
args = parser.parse_args()
ori_ckpt = torch.load(args.ori_path)['params_ema']
net = GFPGANv1Clean(
512,
num_style_feat=512,
channel_multiplier=args.channel_multiplier,
decoder_load_path=None,
fix_decoder=False,
# for stylegan decoder
num_mlp=8,
input_is_latent=True,
different_w=True,
narrow=args.narrow,
sft_half=True)
crt_ckpt = net.state_dict()
crt_ckpt = modify_checkpoint(ori_ckpt, crt_ckpt)
print(f'Save to {args.save_path}.')
torch.save(dict(params_ema=crt_ckpt), args.save_path, _use_new_zipfile_serialization=False)