mirror of
https://github.com/TencentARC/GFPGAN.git
synced 2026-05-07 04:36:22 +00:00
Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
996d1e3df9 | ||
|
|
77dc85b882 | ||
|
|
805e62af29 | ||
|
|
b529c9789d | ||
|
|
401cf51a73 | ||
|
|
d507febad8 | ||
|
|
41be5d43d4 | ||
|
|
7f67e12999 | ||
|
|
cc3c881f85 | ||
|
|
7023b5cbdd | ||
|
|
cd37764741 | ||
|
|
1083f910dd | ||
|
|
eafc847101 | ||
|
|
373bf723eb | ||
|
|
c466b9bfdd | ||
|
|
2bbdcc1c84 | ||
|
|
8041b69fbc | ||
|
|
4f50996feb |
30
.github/workflows/publish-pip.yml
vendored
Normal file
30
.github/workflows/publish-pip.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
name: PyPI Publish
|
||||
|
||||
on: push
|
||||
|
||||
jobs:
|
||||
build-n-publish:
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.event.ref, 'refs/tags')
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.8
|
||||
uses: actions/setup-python@v1
|
||||
with:
|
||||
python-version: 3.8
|
||||
- name: Upgrade pip
|
||||
run: pip install pip --upgrade
|
||||
- name: Install PyTorch (cpu)
|
||||
run: pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- name: Install dependencies
|
||||
run: pip install -r requirements.txt
|
||||
- name: Build and install
|
||||
run: rm -rf .eggs && pip install -e .
|
||||
- name: Build for distribution
|
||||
# remove bdist_wheel for pip installation with compiling cuda extensions
|
||||
run: python setup.py sdist
|
||||
- name: Publish distribution to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@master
|
||||
with:
|
||||
password: ${{ secrets.PYPI_API_TOKEN }}
|
||||
6
.github/workflows/pylint.yml
vendored
6
.github/workflows/pylint.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: Python Lint
|
||||
name: PyLint
|
||||
|
||||
on: [push, pull_request]
|
||||
|
||||
@@ -25,5 +25,5 @@ jobs:
|
||||
- name: Lint
|
||||
run: |
|
||||
flake8 .
|
||||
isort --check-only --diff data/ archs/ models/ train.py inference_gfpgan_full.py
|
||||
yapf -r -d data/ archs/ models/ train.py inference_gfpgan_full.py
|
||||
isort --check-only --diff gfpgan/ scripts/ inference_gfpgan.py setup.py
|
||||
yapf -r -d gfpgan/ scripts/ inference_gfpgan.py setup.py
|
||||
|
||||
48
.gitignore
vendored
48
.gitignore
vendored
@@ -1,18 +1,13 @@
|
||||
.vscode
|
||||
# ignored folders
|
||||
datasets/*
|
||||
experiments/*
|
||||
results/*
|
||||
tb_logger/*
|
||||
wandb/*
|
||||
tmp/*
|
||||
|
||||
# ignored files
|
||||
version.py
|
||||
|
||||
# ignored files with suffix
|
||||
*.html
|
||||
*.png
|
||||
*.jpeg
|
||||
*.jpg
|
||||
*.gif
|
||||
*.pth
|
||||
*.zip
|
||||
|
||||
# template
|
||||
.vscode
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
@@ -36,6 +31,8 @@ parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
@@ -54,12 +51,14 @@ pip-delete-this-directory.txt
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
@@ -71,6 +70,7 @@ coverage.xml
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
@@ -88,11 +88,26 @@ target/
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# celery beat schedule file
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
@@ -118,3 +133,8 @@ venv.bak/
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
8
MANIFEST.in
Normal file
8
MANIFEST.in
Normal file
@@ -0,0 +1,8 @@
|
||||
include assets/*
|
||||
include inputs/*
|
||||
include scripts/*.py
|
||||
include inference_gfpgan.py
|
||||
include VERSION
|
||||
include LICENSE
|
||||
include requirements.txt
|
||||
include gfpgan/weights/README.md
|
||||
76
PaperModel.md
Normal file
76
PaperModel.md
Normal file
@@ -0,0 +1,76 @@
|
||||
# Installation
|
||||
|
||||
We now provide a *clean* version of GFPGAN, which does not require customized CUDA extensions. See [here](README.md#installation) for this easier installation.<br>
|
||||
If you want to use the original model in our paper, please follow the instructions below.
|
||||
|
||||
1. Clone repo
|
||||
|
||||
```bash
|
||||
git clone https://github.com/xinntao/GFPGAN.git
|
||||
cd GFPGAN
|
||||
```
|
||||
|
||||
1. Install dependent packages
|
||||
|
||||
As StyleGAN2 uses customized PyTorch C++ extensions, you need to **compile them during installation** or **load them just-in-time(JIT)**.
|
||||
You can refer to [BasicSR-INSTALL.md](https://github.com/xinntao/BasicSR/blob/master/INSTALL.md) for more details.
|
||||
|
||||
**Option 1: Load extensions just-in-time(JIT)** (For those who just want to run simple inference; may have fewer issues)
|
||||
|
||||
```bash
|
||||
# Install basicsr - https://github.com/xinntao/BasicSR
|
||||
# We use BasicSR for both training and inference
|
||||
pip install basicsr
|
||||
|
||||
# Install facexlib - https://github.com/xinntao/facexlib
|
||||
# We use face detection and face restoration helper in the facexlib package
|
||||
pip install facexlib
|
||||
|
||||
pip install -r requirements.txt
|
||||
python setup.py develop
|
||||
|
||||
# remember to set BASICSR_JIT=True before your running commands
|
||||
```
|
||||
|
||||
**Option 2: Compile extensions during installation** (For those who need to train or run inference many times)
|
||||
|
||||
```bash
|
||||
# Install basicsr - https://github.com/xinntao/BasicSR
|
||||
# We use BasicSR for both training and inference
|
||||
# Set BASICSR_EXT=True to compile the cuda extensions in the BasicSR - It may take several minutes to compile, please be patient
|
||||
# Add -vvv for detailed log prints
|
||||
BASICSR_EXT=True pip install basicsr -vvv
|
||||
|
||||
# Install facexlib - https://github.com/xinntao/facexlib
|
||||
# We use face detection and face restoration helper in the facexlib package
|
||||
pip install facexlib
|
||||
|
||||
pip install -r requirements.txt
|
||||
python setup.py develop
|
||||
```
|
||||
|
||||
## :zap: Quick Inference
|
||||
|
||||
Download pre-trained models: [GFPGANv1.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth)
|
||||
|
||||
```bash
|
||||
wget https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth -P experiments/pretrained_models
|
||||
```
|
||||
|
||||
- Option 1: Load extensions just-in-time(JIT)
|
||||
|
||||
```bash
|
||||
BASICSR_JIT=True python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1
|
||||
|
||||
# for aligned images
|
||||
BASICSR_JIT=True python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned
|
||||
```
|
||||
|
||||
- Option 2: Have successfully compiled extensions during installation
|
||||
|
||||
```bash
|
||||
python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1
|
||||
|
||||
# for aligned images
|
||||
python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned
|
||||
```
|
||||
108
README.md
108
README.md
@@ -1,27 +1,29 @@
|
||||
# GFPGAN (CVPR 2021)
|
||||
|
||||
[**Paper**](https://arxiv.org/abs/2101.04061) **|** [**Project Page**](https://xinntao.github.io/projects/gfpgan)    [English](README.md) **|** [简体中文](README_CN.md)
|
||||
[](https://github.com/TencentARC/GFPGAN/releases)
|
||||
[](https://pypi.org/project/gfpgan/)
|
||||
[](https://github.com/TencentARC/GFPGAN/issues)
|
||||
[](https://github.com/TencentARC/GFPGAN/blob/master/LICENSE)
|
||||
[](https://github.com/TencentARC/GFPGAN/blob/master/.github/workflows/pylint.yml)
|
||||
[](https://github.com/TencentARC/GFPGAN/blob/master/.github/workflows/publish-pip.yml)
|
||||
|
||||
GFPGAN is a blind face restoration algorithm towards real-world face images.
|
||||
1. [Colab Demo](https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo) for GFPGAN <a href="https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>; (Another [Colab Demo](https://colab.research.google.com/drive/1Oa1WwKB4M4l1GmR7CtswDVgOCOeSLChA?usp=sharing) for the original paper model)
|
||||
1. We provide a *clean* version of GFPGAN, which can run without CUDA extensions. So that it can run in **Windows** or on **CPU mode**.
|
||||
|
||||
GFPGAN aims at developing **Practical Algorithm for Real-world Face Restoration**.<br>
|
||||
It leverages rich and diverse priors encapsulated in a pretrained face GAN (*e.g.*, StyleGAN2) for blind face restoration.
|
||||
|
||||
:triangular_flag_on_post: **Updates**
|
||||
|
||||
- :white_check_mark: We provide a *clean* version of GFPGAN, which does not require CUDA extensions.
|
||||
- :white_check_mark: We provide an updated model without colorizing faces.
|
||||
|
||||
### :book: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior
|
||||
> [[Paper](https://arxiv.org/abs/2101.04061)]   [[Project Page](https://xinntao.github.io/projects/gfpgan)]   [[Demo]()] <br>
|
||||
|
||||
> [[Paper](https://arxiv.org/abs/2101.04061)]   [[Project Page](https://xinntao.github.io/projects/gfpgan)]   [Demo] <br>
|
||||
> [Xintao Wang](https://xinntao.github.io/), [Yu Li](https://yu-li.github.io/), [Honglun Zhang](https://scholar.google.com/citations?hl=en&user=KjQLROoAAAAJ), [Ying Shan](https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en) <br>
|
||||
> Applied Research Center (ARC), Tencent PCG
|
||||
|
||||
#### Abstract
|
||||
|
||||
Blind face restoration usually relies on facial priors, such as facial geometry prior or reference prior, to restore realistic and faithful details. However, very low-quality inputs cannot offer accurate geometric prior while high-quality references are inaccessible, limiting the applicability in real-world scenarios. In this work, we propose GFP-GAN that leverages rich and diverse priors encapsulated in a pretrained face GAN for blind face restoration. This Generative Facial Prior (GFP) is incorporated into the face restoration process via novel channel-split spatial feature transform layers, which allow our method to achieve a good balance of realness and fidelity. Thanks to the powerful generative facial prior and delicate designs, our GFP-GAN could jointly restore facial details and enhance colors with just a single forward pass, while GAN inversion methods require expensive image-specific optimization at inference. Extensive experiments show that our method achieves superior performance to prior art on both synthetic and real-world datasets.
|
||||
|
||||
#### BibTeX
|
||||
|
||||
@InProceedings{wang2021gfpgan,
|
||||
author = {Xintao Wang and Yu Li and Honglun Zhang and Ying Shan},
|
||||
title = {Towards Real-World Blind Face Restoration with Generative Facial Prior},
|
||||
booktitle={The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||
year = {2021}
|
||||
}
|
||||
|
||||
<p align="center">
|
||||
<img src="https://xinntao.github.io/projects/GFPGAN_src/gfpgan_teaser.jpg">
|
||||
</p>
|
||||
@@ -32,30 +34,94 @@ Blind face restoration usually relies on facial priors, such as facial geometry
|
||||
|
||||
- Python >= 3.7 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
|
||||
- [PyTorch >= 1.7](https://pytorch.org/)
|
||||
- NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
|
||||
- Option: NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
|
||||
- Option: Linux (We have not tested on Windows)
|
||||
|
||||
### Installation
|
||||
|
||||
We now provide a *clean* version of GFPGAN, which does not require customized CUDA extensions. <br>
|
||||
If you want to use the original model in our paper, please see [PaperModel.md](PaperModel.md) for installation.
|
||||
|
||||
1. Clone repo
|
||||
|
||||
```bash
|
||||
git clone https://github.com/xinntao/GFPGAN.git
|
||||
git clone https://github.com/TencentARC/GFPGAN.git
|
||||
cd GFPGAN
|
||||
```
|
||||
|
||||
1. Install dependent packages
|
||||
|
||||
```bash
|
||||
cd GFPGAN
|
||||
# Install basicsr - https://github.com/xinntao/BasicSR
|
||||
# We use BasicSR for both training and inference
|
||||
pip install basicsr
|
||||
|
||||
# Install facexlib - https://github.com/xinntao/facexlib
|
||||
# We use face detection and face restoration helper in the facexlib package
|
||||
pip install facexlib
|
||||
|
||||
pip install -r requirements.txt
|
||||
python setup.py develop
|
||||
```
|
||||
|
||||
## :zap: Quick Inference
|
||||
|
||||
> python inference_gfpgan_full.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs
|
||||
Download pre-trained models: [GFPGANCleanv1-NoCE-C2.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth)
|
||||
|
||||
```bash
|
||||
wget https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth -P experiments/pretrained_models
|
||||
```
|
||||
|
||||
**Inference!**
|
||||
|
||||
```bash
|
||||
python inference_gfpgan.py --upscale_factor 2 --test_path inputs/whole_imgs --save_root results
|
||||
```
|
||||
|
||||
## :european_castle: Model Zoo
|
||||
|
||||
- [GFPGANCleanv1-NoCE-C2.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth): No colorization; no CUDA extensions are required. It is still in training. Trained with more data with pre-processing.
|
||||
- [GFPGANv1.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth): The paper model, with colorization.
|
||||
|
||||
## :computer: Training
|
||||
|
||||
We provide the training codes for GFPGAN (used in our paper). <br>
|
||||
You could improve it according to your own needs.
|
||||
|
||||
**Tips**
|
||||
|
||||
1. More high quality faces can improve the restoration quality.
|
||||
2. You may need to perform some pre-processing, such as beauty makeup.
|
||||
|
||||
**Procedures**
|
||||
|
||||
(You can try a simple version ( `options/train_gfpgan_v1_simple.yml`) that does not require face component landmarks.)
|
||||
|
||||
1. Dataset preparation: [FFHQ](https://github.com/NVlabs/ffhq-dataset)
|
||||
|
||||
1. Download pre-trained models and other data. Put them in the `experiments/pretrained_models` folder.
|
||||
1. [Pretrained StyleGAN2 model: StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth)
|
||||
1. [Component locations of FFHQ: FFHQ_eye_mouth_landmarks_512.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/FFHQ_eye_mouth_landmarks_512.pth)
|
||||
1. [A simple ArcFace model: arcface_resnet18.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/arcface_resnet18.pth)
|
||||
|
||||
1. Modify the configuration file `options/train_gfpgan_v1.yml` accordingly.
|
||||
|
||||
1. Training
|
||||
|
||||
> python -m torch.distributed.launch --nproc_per_node=4 --master_port=22021 gfpgan/train.py -opt options/train_gfpgan_v1.yml --launcher pytorch
|
||||
|
||||
## :scroll: License and Acknowledgement
|
||||
|
||||
GFPGAN is realeased under Apache License Version 2.0.
|
||||
GFPGAN is released under Apache License Version 2.0.
|
||||
|
||||
## BibTeX
|
||||
|
||||
@InProceedings{wang2021gfpgan,
|
||||
author = {Xintao Wang and Yu Li and Honglun Zhang and Ying Shan},
|
||||
title = {Towards Real-World Blind Face Restoration with Generative Facial Prior},
|
||||
booktitle={The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||
year = {2021}
|
||||
}
|
||||
|
||||
## :e-mail: Contact
|
||||
|
||||
|
||||
62
README_CN.md
62
README_CN.md
@@ -1,62 +0,0 @@
|
||||
# GFPGAN (CVPR 2021)
|
||||
|
||||
[**Paper**](https://arxiv.org/abs/2101.04061) **|** [**Project Page**](https://xinntao.github.io/projects/gfpgan)    [English](README.md) **|** [简体中文](README_CN.md)
|
||||
|
||||
GFPGAN is a blind face restoration algorithm towards real-world face images.
|
||||
|
||||
### :book: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior
|
||||
> [[Paper](https://arxiv.org/abs/2101.04061)]   [[Project Page](https://xinntao.github.io/projects/gfpgan)]   [[Demo]()] <br>
|
||||
> [Xintao Wang](https://xinntao.github.io/), [Yu Li](https://yu-li.github.io/), [Honglun Zhang](https://scholar.google.com/citations?hl=en&user=KjQLROoAAAAJ), [Ying Shan](https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en) <br>
|
||||
> Applied Research Center (ARC), Tencent PCG
|
||||
|
||||
#### Abstract
|
||||
|
||||
Blind face restoration usually relies on facial priors, such as facial geometry prior or reference prior, to restore realistic and faithful details. However, very low-quality inputs cannot offer accurate geometric prior while high-quality references are inaccessible, limiting the applicability in real-world scenarios. In this work, we propose GFP-GAN that leverages rich and diverse priors encapsulated in a pretrained face GAN for blind face restoration. This Generative Facial Prior (GFP) is incorporated into the face restoration process via novel channel-split spatial feature transform layers, which allow our method to achieve a good balance of realness and fidelity. Thanks to the powerful generative facial prior and delicate designs, our GFP-GAN could jointly restore facial details and enhance colors with just a single forward pass, while GAN inversion methods require expensive image-specific optimization at inference. Extensive experiments show that our method achieves superior performance to prior art on both synthetic and real-world datasets.
|
||||
|
||||
#### BibTeX
|
||||
|
||||
@InProceedings{wang2021gfpgan,
|
||||
author = {Xintao Wang and Yu Li and Honglun Zhang and Ying Shan},
|
||||
title = {Towards Real-World Blind Face Restoration with Generative Facial Prior},
|
||||
booktitle={The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||
year = {2021}
|
||||
}
|
||||
|
||||
<p align="center">
|
||||
<img src="https://xinntao.github.io/projects/GFPGAN_src/gfpgan_teaser.jpg">
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
## :wrench: Dependencies and Installation
|
||||
|
||||
- Python >= 3.7 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
|
||||
- [PyTorch >= 1.7](https://pytorch.org/)
|
||||
- NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
|
||||
|
||||
### Installation
|
||||
|
||||
1. Clone repo
|
||||
|
||||
```bash
|
||||
git clone https://github.com/xinntao/GFPGAN.git
|
||||
```
|
||||
|
||||
1. Install dependent packages
|
||||
|
||||
```bash
|
||||
cd GFPGAN
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## :zap: Quick Inference
|
||||
|
||||
> python inference_gfpgan_full.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs
|
||||
|
||||
## :scroll: License and Acknowledgement
|
||||
|
||||
GFPGAN is realeased under Apache License Version 2.0.
|
||||
|
||||
## :e-mail: Contact
|
||||
|
||||
If you have any question, please email `xintao.wang@outlook.com` or `xintaowang@tencent.com`.
|
||||
7
experiments/pretrained_models/README.md
Normal file
7
experiments/pretrained_models/README.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# Pre-trained Models and Other Data
|
||||
|
||||
Download pre-trained models and other data. Put them in this folder.
|
||||
|
||||
1. [Pretrained StyleGAN2 model: StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth)
|
||||
1. [Component locations of FFHQ: FFHQ_eye_mouth_landmarks_512.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/FFHQ_eye_mouth_landmarks_512.pth)
|
||||
1. [A simple ArcFace model: arcface_resnet18.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/arcface_resnet18.pth)
|
||||
6
gfpgan/__init__.py
Normal file
6
gfpgan/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# flake8: noqa
|
||||
from .archs import *
|
||||
from .data import *
|
||||
from .models import *
|
||||
from .utils import *
|
||||
from .version import __gitsha__, __version__
|
||||
@@ -1,12 +1,10 @@
|
||||
import importlib
|
||||
from basicsr.utils import scandir
|
||||
from os import path as osp
|
||||
|
||||
from basicsr.utils import scandir
|
||||
|
||||
# automatically scan and import arch modules for registry
|
||||
# scan all the files under the 'archs' folder and collect files ending with
|
||||
# '_arch.py'
|
||||
# scan all the files that end with '_arch.py' under the archs folder
|
||||
arch_folder = osp.dirname(osp.abspath(__file__))
|
||||
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
|
||||
# import all the arch modules
|
||||
_arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames]
|
||||
_arch_modules = [importlib.import_module(f'gfpgan.archs.{file_name}') for file_name in arch_filenames]
|
||||
@@ -1,5 +1,4 @@
|
||||
import torch.nn as nn
|
||||
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import math
|
||||
import random
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
|
||||
StyleGAN2Generator)
|
||||
from basicsr.ops.fused_act import FusedLeakyReLU
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class StyleGAN2GeneratorSFT(StyleGAN2Generator):
|
||||
304
gfpgan/archs/gfpganv1_clean_arch.py
Normal file
304
gfpgan/archs/gfpganv1_clean_arch.py
Normal file
@@ -0,0 +1,304 @@
|
||||
import math
|
||||
import random
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .stylegan2_clean_arch import StyleGAN2GeneratorClean
|
||||
|
||||
|
||||
class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
    """StyleGAN2 generator whose intermediate features are modulated by conditions.

    After the first style conv of every resolution level, the feature map is
    transformed as ``out * scale + shift`` using a (scale, shift) pair taken
    from ``conditions`` (spatial feature transform, SFT).

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of
            StyleGAN2. Default: 2.
        narrow (float): Forwarded unchanged to ``StyleGAN2GeneratorClean``.
            Default: 1.
        sft_half (bool): If True, only the second half of the channels is
            modulated by the conditions; the first half passes through
            untouched. Default: False.
    """

    def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
        super(StyleGAN2GeneratorCSFT, self).__init__(
            out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            narrow=narrow)

        self.sft_half = sft_half

    def forward(self,
                styles,
                conditions,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorCSFT.

        Args:
            styles (list[Tensor]): Sample codes of styles. Must contain one
                or two entries.
            conditions (list[Tensor]): SFT conditions laid out as
                [scale_0, shift_0, scale_1, shift_1, ...], one pair per
                resolution level. May cover fewer levels than the generator
                has; the remaining levels are left unmodulated.
            input_is_latent (bool): Whether input is latent style. If False,
                each style is first mapped through the style MLP.
                Default: False.
            noise (list | None): Per-layer input noise or None.
                Default: None.
            randomize_noise (bool): Only consulted when ``noise`` is None:
                if True every layer gets ``None`` as its noise argument,
                otherwise the noises stored on the module are used.
                Default: True.
            truncation (float): When smaller than 1, each style is pulled
                towards ``truncation_latent`` by this factor. Default: 1.
            truncation_latent (Tensor | None): Anchor used for truncation;
                required when ``truncation < 1``. Default: None.
            inject_index (int | None): The injection index for mixing two
                styles. Randomly chosen when None. Default: None.
            return_latents (bool): Whether to return style latents.
                Default: False.

        Returns:
            tuple: ``(image, latent)`` if ``return_latents`` is True,
            otherwise ``(image, None)``.

        Raises:
            ValueError: If ``styles`` does not contain exactly 1 or 2 codes.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]

        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]

        # style truncation
        if truncation < 1:
            styles = [truncation_latent + truncation * (style - truncation_latent) for style in styles]

        # get style latent with injection
        num_styles = len(styles)
        if num_styles == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif num_styles == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)
        else:
            # Previously this fell through and failed later with an unbound
            # `latent`; fail early with an explicit message instead.
            raise ValueError(f'styles must contain 1 or 2 style codes, but got {num_styles}.')

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        # `i` indexes both the latent slots and the (scale, shift) pairs:
        # level k uses conditions[2k] as scale and conditions[2k + 1] as shift.
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)

            # the conditions may have fewer levels
            if i < len(conditions):
                # SFT part to combine the conditions
                if self.sft_half:
                    out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                    out_sft = out_sft * conditions[i - 1] + conditions[i]
                    out = torch.cat([out_same, out_sft], dim=1)
                else:
                    out = out * conditions[i - 1] + conditions[i]

            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
    """Residual block with upsampling/downsampling.

    The main branch is conv -> leaky ReLU -> bilinear resize -> conv ->
    leaky ReLU; the skip branch resizes the input the same way and projects
    it with a bias-free 1x1 conv. Both branches are summed.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        mode (str): 'down' halves the spatial size, 'up' doubles it.
            Default: 'down'.

    Raises:
        ValueError: If ``mode`` is neither 'down' nor 'up'.
    """

    def __init__(self, in_channels, out_channels, mode='down'):
        super(ResBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        if mode == 'down':
            self.scale_factor = 0.5
        elif mode == 'up':
            self.scale_factor = 2
        else:
            # Without this check `scale_factor` would stay undefined and the
            # error would only surface later inside forward().
            raise ValueError(f"mode must be 'down' or 'up', but got {mode!r}.")

    def forward(self, x):
        """Return the resized residual output for input feature map ``x``."""
        out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
        # upsample/downsample
        out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
        # skip
        x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        skip = self.skip(x)
        out = out + skip
        return out
|
||||
|
||||
|
||||
class GFPGANv1Clean(nn.Module):
|
||||
"""GFPGANv1 Clean version."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
out_size,
|
||||
num_style_feat=512,
|
||||
channel_multiplier=1,
|
||||
decoder_load_path=None,
|
||||
fix_decoder=True,
|
||||
# for stylegan decoder
|
||||
num_mlp=8,
|
||||
input_is_latent=False,
|
||||
different_w=False,
|
||||
narrow=1,
|
||||
sft_half=False):
|
||||
|
||||
super(GFPGANv1Clean, self).__init__()
|
||||
self.input_is_latent = input_is_latent
|
||||
self.different_w = different_w
|
||||
self.num_style_feat = num_style_feat
|
||||
|
||||
unet_narrow = narrow * 0.5
|
||||
channels = {
|
||||
'4': int(512 * unet_narrow),
|
||||
'8': int(512 * unet_narrow),
|
||||
'16': int(512 * unet_narrow),
|
||||
'32': int(512 * unet_narrow),
|
||||
'64': int(256 * channel_multiplier * unet_narrow),
|
||||
'128': int(128 * channel_multiplier * unet_narrow),
|
||||
'256': int(64 * channel_multiplier * unet_narrow),
|
||||
'512': int(32 * channel_multiplier * unet_narrow),
|
||||
'1024': int(16 * channel_multiplier * unet_narrow)
|
||||
}
|
||||
|
||||
self.log_size = int(math.log(out_size, 2))
|
||||
first_out_size = 2**(int(math.log(out_size, 2)))
|
||||
|
||||
self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)
|
||||
|
||||
# downsample
|
||||
in_channels = channels[f'{first_out_size}']
|
||||
self.conv_body_down = nn.ModuleList()
|
||||
for i in range(self.log_size, 2, -1):
|
||||
out_channels = channels[f'{2**(i - 1)}']
|
||||
self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
|
||||
in_channels = out_channels
|
||||
|
||||
self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)
|
||||
|
||||
# upsample
|
||||
in_channels = channels['4']
|
||||
self.conv_body_up = nn.ModuleList()
|
||||
for i in range(3, self.log_size + 1):
|
||||
out_channels = channels[f'{2**i}']
|
||||
self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
|
||||
in_channels = out_channels
|
||||
|
||||
# to RGB
|
||||
self.toRGB = nn.ModuleList()
|
||||
for i in range(3, self.log_size + 1):
|
||||
self.toRGB.append(nn.Conv2d(channels[f'{2**i}'], 3, 1))
|
||||
|
||||
if different_w:
|
||||
linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
|
||||
else:
|
||||
linear_out_channel = num_style_feat
|
||||
|
||||
self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
|
||||
|
||||
self.stylegan_decoder = StyleGAN2GeneratorCSFT(
|
||||
out_size=out_size,
|
||||
num_style_feat=num_style_feat,
|
||||
num_mlp=num_mlp,
|
||||
channel_multiplier=channel_multiplier,
|
||||
narrow=narrow,
|
||||
sft_half=sft_half)
|
||||
|
||||
if decoder_load_path:
|
||||
self.stylegan_decoder.load_state_dict(
|
||||
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
|
||||
if fix_decoder:
|
||||
for name, param in self.stylegan_decoder.named_parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# for SFT
|
||||
self.condition_scale = nn.ModuleList()
|
||||
self.condition_shift = nn.ModuleList()
|
||||
for i in range(3, self.log_size + 1):
|
||||
out_channels = channels[f'{2**i}']
|
||||
if sft_half:
|
||||
sft_out_channels = out_channels
|
||||
else:
|
||||
sft_out_channels = out_channels * 2
|
||||
self.condition_scale.append(
|
||||
nn.Sequential(
|
||||
nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
|
||||
self.condition_shift.append(
|
||||
nn.Sequential(
|
||||
nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
|
||||
|
||||
def forward(self,
            x,
            return_latents=False,
            save_feat_path=None,
            load_feat_path=None,
            return_rgb=True,
            randomize_noise=True):
    """Encode the input, build the SFT conditions and run the StyleGAN2 decoder.

    Args:
        x (Tensor): Input batch fed to ``self.conv_body_first``.
        return_latents (bool): Forwarded to the decoder. The latents the
            decoder returns are discarded by this method. Default: False.
        save_feat_path (str | None): If given, the list of SFT conditions
            is written to this path with ``torch.save``. Default: None.
        load_feat_path (str | None): If given, the SFT conditions are read
            from this path and replace the ones computed from ``x``.
            Default: None.
        return_rgb (bool): Whether to collect the intermediate RGB
            predictions of every decoding stage. Default: True.
        randomize_noise (bool): Forwarded to the decoder. Default: True.

    Returns:
        tuple: ``(image, out_rgbs)`` where ``image`` is the decoder output
        and ``out_rgbs`` is the list of intermediate RGB predictions
        (empty when ``return_rgb`` is False).
    """
    conditions = []
    unet_skips = []
    out_rgbs = []

    # encoder
    feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
    for i in range(self.log_size - 2):
        feat = self.conv_body_down[i](feat)
        # prepend, so that the deepest feature ends up at index 0
        unet_skips.insert(0, feat)
    feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)

    # style code
    style_code = self.final_linear(feat.view(feat.size(0), -1))
    if self.different_w:
        # one style vector per decoder layer
        style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)

    # decode
    for i in range(self.log_size - 2):
        # add unet skip
        feat = feat + unet_skips[i]
        # ResUpLayer
        feat = self.conv_body_up[i](feat)
        # generate scale and shift for SFT layer
        scale = self.condition_scale[i](feat)
        conditions.append(scale.clone())
        shift = self.condition_shift[i](feat)
        conditions.append(shift.clone())
        # generate rgb images
        if return_rgb:
            out_rgbs.append(self.toRGB[i](feat))

    if save_feat_path is not None:
        torch.save(conditions, save_feat_path)
    if load_feat_path is not None:
        # Follow the device of the input instead of assuming CUDA, so that
        # stored conditions can also be replayed on CPU-only machines.
        conditions = torch.load(load_feat_path, map_location=x.device)
        conditions = [v.to(x.device) for v in conditions]

    # decoder
    image, _ = self.stylegan_decoder([style_code],
                                     conditions,
                                     return_latents=return_latents,
                                     input_is_latent=self.input_is_latent,
                                     randomize_noise=randomize_noise)

    return image, out_rgbs
|
||||
377
gfpgan/archs/stylegan2_clean_arch.py
Normal file
377
gfpgan/archs/stylegan2_clean_arch.py
Normal file
@@ -0,0 +1,377 @@
|
||||
import math
|
||||
import random
|
||||
import torch
|
||||
from basicsr.archs.arch_util import default_init_weights
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class NormStyleCode(nn.Module):
    """Rescale style codes by the inverse root-mean-square over dim 1."""

    def forward(self, x):
        """Normalize the style codes.

        Args:
            x (Tensor): Style codes with shape (b, c).

        Returns:
            Tensor: ``x`` multiplied by ``rsqrt(mean(x**2, dim=1) + 1e-8)``.
        """
        # the small constant keeps rsqrt finite for all-zero codes
        mean_square = torch.mean(x**2, dim=1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + 1e-8)
        return x * inv_rms
|
||||
|
||||
|
||||
class ModulatedConv2d(nn.Module):
    """Modulated Conv2d used in StyleGAN2.

    The convolution weight is scaled per sample by a linear projection of the
    style vector and, optionally, demodulated. The convolution itself has no
    bias term.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer.
            Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
        eps (float): A value added to the denominator for numerical stability.
            Default: 1e-8.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 eps=1e-8):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps

        # style vector -> one scale per input channel
        self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
        default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')

        # shared kernel, scaled by 1 / sqrt(fan_in)
        fan_in = in_channels * kernel_size**2
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) / math.sqrt(fan_in))
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution.
        """
        batch, in_ch, _, _ = x.shape
        ksize = self.kernel_size

        # (1, c_out, c_in, k, k) * (b, 1, c_in, 1, 1) -> (b, c_out, c_in, k, k)
        per_sample_scale = self.modulation(style).view(batch, 1, in_ch, 1, 1)
        kernel = self.weight * per_sample_scale

        if self.demodulate:
            inv_norm = torch.rsqrt(kernel.pow(2).sum([2, 3, 4]) + self.eps)
            kernel = kernel * inv_norm.view(batch, self.out_channels, 1, 1, 1)

        kernel = kernel.view(batch * self.out_channels, in_ch, ksize, ksize)

        # optional bilinear resampling of the input before the convolution
        scale_factor = None
        if self.sample_mode == 'upsample':
            scale_factor = 2
        elif self.sample_mode == 'downsample':
            scale_factor = 0.5
        if scale_factor is not None:
            x = F.interpolate(x, scale_factor=scale_factor, mode='bilinear', align_corners=False)

        # fold the batch into the channel axis so that a grouped convolution
        # applies a different kernel to every sample
        _, _, height, width = x.shape
        folded = x.view(1, batch * in_ch, height, width)
        result = F.conv2d(folded, kernel, padding=self.padding, groups=batch)
        return result.view(batch, self.out_channels, *result.shape[2:4])

    def __repr__(self):
        fields = (f'in_channels={self.in_channels}, '
                  f'out_channels={self.out_channels}, '
                  f'kernel_size={self.kernel_size}, '
                  f'demodulate={self.demodulate}, sample_mode={self.sample_mode}')
        return f'{self.__class__.__name__}({fields})'
|
||||
|
||||
|
||||
class StyleConv(nn.Module):
    """Modulated convolution followed by noise injection, bias and LeakyReLU.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
    """

    def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
        super().__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            out_channels,
            kernel_size,
            num_style_feat,
            demodulate=demodulate,
            sample_mode=sample_mode)
        # learnable strength of the injected noise, starts at zero
        self.weight = nn.Parameter(torch.zeros(1))
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, style, noise=None):
        """Apply the styled convolution.

        Args:
            x (Tensor): Input feature map.
            style (Tensor): Style vector passed to the modulated convolution.
            noise (Tensor | None): Noise to inject. When None, fresh
                standard-normal noise with one channel is sampled.
                Default: None.

        Returns:
            Tensor: Activated feature map.
        """
        # modulate; the sqrt(2) factor is kept for conversion
        feat = self.modulated_conv(x, style) * 2**0.5
        if noise is None:
            n, _, height, width = feat.shape
            noise = feat.new_empty(n, 1, height, width).normal_()
        # noise injection, then bias, then activation
        feat = feat + self.weight * noise + self.bias
        return self.activate(feat)
|
||||
|
||||
|
||||
class ToRGB(nn.Module):
    """Project features to a 3-channel image and accumulate a skip image.

    Args:
        in_channels (int): Channel number of input.
        num_style_feat (int): Channel number of style features.
        upsample (bool): Whether to upsample. Default: True.
    """

    def __init__(self, in_channels, num_style_feat, upsample=True):
        super().__init__()
        self.upsample = upsample
        # 1x1 modulated convolution without demodulation
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            3,
            kernel_size=1,
            num_style_feat=num_style_feat,
            demodulate=False,
            sample_mode=None)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        """Forward function.

        Args:
            x (Tensor): Feature tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).
            skip (Tensor): Base/skip tensor. Default: None.

        Returns:
            Tensor: RGB images.
        """
        rgb = self.modulated_conv(x, style) + self.bias
        if skip is None:
            return rgb
        if self.upsample:
            # bring the previous RGB image to the current resolution
            skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
        return rgb + skip
|
||||
|
||||
|
||||
class ConstantInput(nn.Module):
    """Learned constant tensor that is repeated along the batch dimension.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super().__init__()
        constant = torch.randn(1, num_channel, size, size)
        self.weight = nn.Parameter(constant)

    def forward(self, batch):
        """Return the constant repeated ``batch`` times along dim 0.

        Args:
            batch (int): Number of copies to produce.

        Returns:
            Tensor: Tensor with shape (batch, num_channel, size, size).
        """
        return self.weight.repeat(batch, 1, 1, 1)
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
class StyleGAN2GeneratorClean(nn.Module):
    """Clean version of StyleGAN2 Generator.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of
            StyleGAN2. Default: 2.
        narrow (float): Narrow ratio for channels. Default: 1.0.
    """

    def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1):
        super(StyleGAN2GeneratorClean, self).__init__()
        # Style MLP layers
        self.num_style_feat = num_style_feat
        style_mlp_layers = [NormStyleCode()]
        for _ in range(num_mlp):
            style_mlp_layers.extend(
                [nn.Linear(num_style_feat, num_style_feat, bias=True),
                 nn.LeakyReLU(negative_slope=0.2, inplace=True)])
        self.style_mlp = nn.Sequential(*style_mlp_layers)
        # initialization
        default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')

        # channel width for every spatial resolution, keyed by the resolution
        channels = {
            '4': int(512 * narrow),
            '8': int(512 * narrow),
            '16': int(512 * narrow),
            '32': int(512 * narrow),
            '64': int(256 * channel_multiplier * narrow),
            '128': int(128 * channel_multiplier * narrow),
            '256': int(64 * channel_multiplier * narrow),
            '512': int(32 * channel_multiplier * narrow),
            '1024': int(16 * channel_multiplier * narrow)
        }
        self.channels = channels

        self.constant_input = ConstantInput(channels['4'], size=4)
        self.style_conv1 = StyleConv(
            channels['4'],
            channels['4'],
            kernel_size=3,
            num_style_feat=num_style_feat,
            demodulate=True,
            sample_mode=None)
        self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False)

        self.log_size = int(math.log(out_size, 2))
        # one style conv at 4x4 plus two per upsampling stage
        self.num_layers = (self.log_size - 2) * 2 + 1
        self.num_latent = self.log_size * 2 - 2

        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channels = channels['4']
        # noise: fixed per-layer buffers, used when randomize_noise is False
        for layer_idx in range(self.num_layers):
            resolution = 2**((layer_idx + 5) // 2)
            shape = [1, 1, resolution, resolution]
            self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
        # style convs and to_rgbs
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode='upsample'))
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None))
            self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
            in_channels = out_channels

    def make_noise(self):
        """Make noise for noise injection.

        Returns:
            list[Tensor]: One (1, 1, r, r) standard-normal tensor per style
            conv layer, created on the device of the constant input.
        """
        device = self.constant_input.weight.device
        noises = [torch.randn(1, 1, 4, 4, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_latent(self, x):
        """Map style codes to latents with the style MLP."""
        return self.style_mlp(x)

    def mean_latent(self, num_latent):
        """Average the style-MLP output over ``num_latent`` random codes.

        Args:
            num_latent (int): Number of random style codes to sample.

        Returns:
            Tensor: Mean latent with shape (1, num_style_feat).
        """
        latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
        return latent

    def forward(self,
                styles,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2Generator.

        Args:
            styles (list[Tensor]): Sample codes of styles. Must hold one
                tensor, or two tensors for style mixing.
            input_is_latent (bool): Whether input is latent style.
                Default: False.
            noise (list[Tensor | None] | None): Per-layer noise, indexed by
                style conv layer, or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is
                None. When False, the stored noise buffers are used.
                Default: True.
            truncation (float): When smaller than 1, every style is
                interpolated towards ``truncation_latent`` by this factor.
                Default: 1.
            truncation_latent (Tensor | None): Latent the styles are pulled
                towards when ``truncation`` < 1. Default: None.
            inject_index (int | None): The injection index for mixing noise.
                Default: None.
            return_latents (bool): Whether to return style latents.
                Default: False.

        Returns:
            tuple: ``(image, latent)`` when ``return_latents`` is True,
            otherwise ``(image, None)``.

        Raises:
            ValueError: If ``styles`` holds neither one nor two tensors.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latent with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)
        else:
            # fail with a clear message instead of an unbound `latent` below
            raise ValueError(f'styles must contain 1 or 2 tensors, but got {len(styles)}.')

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        # every stage consumes two style convs, two noise maps and one to_rgb
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
|
||||
@@ -1,11 +1,10 @@
|
||||
import importlib
|
||||
from basicsr.utils import scandir
|
||||
from os import path as osp
|
||||
|
||||
from basicsr.utils import scandir
|
||||
|
||||
# automatically scan and import dataset modules for registry
|
||||
# scan all the files under the data folder with '_dataset' in file names
|
||||
# scan all the files that end with '_dataset.py' under the data folder
|
||||
data_folder = osp.dirname(osp.abspath(__file__))
|
||||
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
|
||||
# import all the dataset modules
|
||||
_dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames]
|
||||
_dataset_modules = [importlib.import_module(f'gfpgan.data.{file_name}') for file_name in dataset_filenames]
|
||||
@@ -4,14 +4,13 @@ import numpy as np
|
||||
import os.path as osp
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
|
||||
normalize)
|
||||
|
||||
from basicsr.data import degradations as degradations
|
||||
from basicsr.data.data_util import paths_from_folder
|
||||
from basicsr.data.transforms import augment
|
||||
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
||||
from basicsr.utils.registry import DATASET_REGISTRY
|
||||
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
|
||||
normalize)
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
@@ -1,12 +1,10 @@
|
||||
import importlib
|
||||
from basicsr.utils import scandir
|
||||
from os import path as osp
|
||||
|
||||
from basicsr.utils import scandir
|
||||
|
||||
# automatically scan and import model modules for registry
|
||||
# scan all the files under the 'models' folder and collect files ending with
|
||||
# '_model.py'
|
||||
# scan all the files that end with '_model.py' under the model folder
|
||||
model_folder = osp.dirname(osp.abspath(__file__))
|
||||
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
|
||||
# import all the model modules
|
||||
_model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames]
|
||||
_model_modules = [importlib.import_module(f'gfpgan.models.{file_name}') for file_name in model_filenames]
|
||||
@@ -1,11 +1,6 @@
|
||||
import math
|
||||
import os.path as osp
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from torch.nn import functional as F
|
||||
from torchvision.ops import roi_align
|
||||
from tqdm import tqdm
|
||||
|
||||
from basicsr.archs import build_network
|
||||
from basicsr.losses import build_loss
|
||||
from basicsr.losses.losses import r1_penalty
|
||||
@@ -13,6 +8,10 @@ from basicsr.metrics import calculate_metric
|
||||
from basicsr.models.base_model import BaseModel
|
||||
from basicsr.utils import get_root_logger, imwrite, tensor2img
|
||||
from basicsr.utils.registry import MODEL_REGISTRY
|
||||
from collections import OrderedDict
|
||||
from torch.nn import functional as F
|
||||
from torchvision.ops import roi_align
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
@MODEL_REGISTRY.register()
|
||||
11
gfpgan/train.py
Normal file
11
gfpgan/train.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# flake8: noqa
|
||||
import os.path as osp
|
||||
from basicsr.train import train_pipeline
|
||||
|
||||
import gfpgan.archs
|
||||
import gfpgan.data
|
||||
import gfpgan.models
|
||||
|
||||
if __name__ == '__main__':
    # two levels above this file, i.e. the parent of the directory holding it
    two_levels_up = osp.join(__file__, osp.pardir, osp.pardir)
    root_path = osp.abspath(two_levels_up)
    train_pipeline(root_path)
|
||||
134
gfpgan/utils.py
Normal file
134
gfpgan/utils.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import cv2
|
||||
import os
|
||||
import torch
|
||||
from basicsr.utils import img2tensor, tensor2img
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
from torch.hub import download_url_to_file, get_dir
|
||||
from torchvision.transforms.functional import normalize
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from gfpgan.archs.gfpganv1_arch import GFPGANv1
|
||||
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
|
||||
|
||||
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
class GFPGANer():
    """Helper that wires a GFPGAN network to a face detection/alignment helper.

    Args:
        model_path (str): Path to the checkpoint, or an ``https://`` URL that
            is downloaded through ``load_file_from_url``.
        upscale (int): Upscale factor handed to the face helper and to the
            background upsampler. Default: 2.
        arch (str): ``'clean'`` selects ``GFPGANv1Clean``; any other value
            selects ``GFPGANv1``. Default: 'clean'.
        channel_multiplier (int): Channel multiplier of the network.
            Default: 2.
        bg_upsampler (object | None): Optional background upsampler exposing
            ``enhance(img, outscale=...)``. Default: None.
    """

    def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None):
        self.upscale = upscale
        self.bg_upsampler = bg_upsampler

        # initialize model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # initialize the GFP-GAN; both architectures share every argument
        # except `fix_decoder`
        shared_kwargs = dict(
            out_size=512,
            num_style_feat=512,
            channel_multiplier=channel_multiplier,
            decoder_load_path=None,
            num_mlp=8,
            input_is_latent=True,
            different_w=True,
            narrow=1,
            sft_half=True)
        if arch == 'clean':
            self.gfpgan = GFPGANv1Clean(fix_decoder=False, **shared_kwargs)
        else:
            self.gfpgan = GFPGANv1(fix_decoder=True, **shared_kwargs)
        # initialize face helper
        self.face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model='retinaface_resnet50',
            save_ext='png',
            device=self.device)

        if model_path.startswith('https://'):
            model_path = load_file_from_url(url=model_path, model_dir='gfpgan/weights', progress=True, file_name=None)
        # Deserialize onto the CPU first: checkpoints saved on a GPU would
        # otherwise fail to load on machines without CUDA. The network is
        # moved to the target device below.
        loadnet = torch.load(model_path, map_location=lambda storage, loc: storage)
        keyname = 'params_ema' if 'params_ema' in loadnet else 'params'
        self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
        self.gfpgan.eval()
        self.gfpgan = self.gfpgan.to(self.device)

    @torch.no_grad()
    def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
        """Restore the faces of one image.

        Args:
            img (ndarray): Input image. NOTE(review): presumably a BGR image
                with values in [0, 255] as returned by ``cv2.imread`` (it is
                divided by 255 and converted with ``bgr2rgb=True``) - confirm.
            has_aligned (bool): Treat ``img`` as one already aligned face; it
                is resized to 512x512 and detection is skipped.
                Default: False.
            only_center_face (bool): Forwarded to the landmark detection.
                Default: False.
            paste_back (bool): Paste the restored faces back into the input
                image. Ignored when ``has_aligned`` is True. Default: True.

        Returns:
            tuple: ``(cropped_faces, restored_faces, restored_img)``;
            ``restored_img`` is None when nothing is pasted back.
        """
        self.face_helper.clean_all()

        if has_aligned:
            img = cv2.resize(img, (512, 512))
            self.face_helper.cropped_faces = [img]
        else:
            self.face_helper.read_image(img)
            # get face landmarks for each face
            self.face_helper.get_face_landmarks_5(only_center_face=only_center_face)
            # align and warp each face
            self.face_helper.align_warp_face()

        # face restoration
        for cropped_face in self.face_helper.cropped_faces:
            # prepare data: scale to [0, 1], then normalize to [-1, 1]
            cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)

            try:
                output = self.gfpgan(cropped_face_t, return_rgb=False)[0]
                # convert to image
                restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
            except RuntimeError as error:
                # best effort: keep the unrestored crop for this face
                print(f'\tFailed inference for GFPGAN: {error}.')
                restored_face = cropped_face

            restored_face = restored_face.astype('uint8')
            self.face_helper.add_restored_face(restored_face)

        if not has_aligned and paste_back:
            if self.bg_upsampler is not None:
                # Now only support RealESRGAN
                bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
            else:
                bg_img = None

            self.face_helper.get_inverse_affine(None)
            # paste each restored face to the input image
            restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
        else:
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
|
||||
|
||||
|
||||
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Download ``url`` into a local folder unless the file is already there.

    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): Address of the file to fetch.
        model_dir (str | None): Target folder, joined onto ``ROOT_DIR``.
            When None, the ``checkpoints`` folder of the torch hub directory
            is used. Default: None.
        progress (bool): Forwarded to ``download_url_to_file``. Default: True.
        file_name (str | None): Name of the stored file. When None, the last
            component of the URL path is used. Default: None.

    Returns:
        str: Absolute path of the cached file.
    """
    if model_dir is None:
        model_dir = os.path.join(get_dir(), 'checkpoints')

    target_dir = os.path.join(ROOT_DIR, model_dir)
    os.makedirs(target_dir, exist_ok=True)

    url_basename = os.path.basename(urlparse(url).path)
    filename = url_basename if file_name is None else file_name
    cached_file = os.path.abspath(os.path.join(target_dir, filename))

    if os.path.exists(cached_file):
        return cached_file

    print(f'Downloading: "{url}" to {cached_file}\n')
    download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
|
||||
3
gfpgan/weights/README.md
Normal file
3
gfpgan/weights/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Weights
|
||||
|
||||
Put the downloaded weights in this folder.
|
||||
96
inference_gfpgan.py
Normal file
96
inference_gfpgan.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import argparse
|
||||
import cv2
|
||||
import glob
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
from basicsr.utils import imwrite
|
||||
|
||||
from gfpgan import GFPGANer
|
||||
|
||||
|
||||
def main():
    """Restore the faces of every image in ``--test_path``.

    For each input image the cropped faces, the restored faces, a
    side-by-side comparison and, when available, the whole restored image are
    written below ``--save_root``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--upscale', type=int, default=2)
    parser.add_argument('--arch', type=str, default='clean')
    parser.add_argument('--channel', type=int, default=2)
    parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth')
    parser.add_argument('--bg_upsampler', type=str, default='realesrgan')
    parser.add_argument('--bg_tile', type=int, default=0)
    parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
    parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
    parser.add_argument('--only_center_face', action='store_true')
    parser.add_argument('--aligned', action='store_true')
    # NOTE: `store_false` - passing --paste_back DISABLES pasting back
    parser.add_argument('--paste_back', action='store_false')
    parser.add_argument('--save_root', type=str, default='results')

    args = parser.parse_args()
    if args.test_path.endswith('/'):
        args.test_path = args.test_path[:-1]
    os.makedirs(args.save_root, exist_ok=True)

    # background upsampler
    if args.bg_upsampler == 'realesrgan':
        if not torch.cuda.is_available():  # CPU
            import warnings
            warnings.warn('The unoptimized RealESRGAN is very slow on CPU. We do not use it. '
                          'If you really want to use it, please modify the corresponding codes.')
            bg_upsampler = None
        else:
            from realesrgan import RealESRGANer
            bg_upsampler = RealESRGANer(
                scale=2,
                model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
                tile=args.bg_tile,
                tile_pad=10,
                pre_pad=0,
                half=True)  # need to set False in CPU mode
    else:
        bg_upsampler = None
    # set up GFPGAN restorer
    restorer = GFPGANer(
        model_path=args.model_path,
        upscale=args.upscale,
        arch=args.arch,
        channel_multiplier=args.channel,
        bg_upsampler=bg_upsampler)

    img_list = sorted(glob.glob(os.path.join(args.test_path, '*')))
    for img_path in img_list:
        # read image
        img_name = os.path.basename(img_path)
        print(f'Processing {img_name} ...')
        basename, ext = os.path.splitext(img_name)
        input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)

        cropped_faces, restored_faces, restored_img = restorer.enhance(
            input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=args.paste_back)

        # save faces
        for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)):
            # save cropped face (the input crop, not the restored result)
            save_crop_path = os.path.join(args.save_root, 'cropped_faces', f'{basename}_{idx:02d}.png')
            imwrite(cropped_face, save_crop_path)
            # save restored face
            if args.suffix is not None:
                save_face_name = f'{basename}_{idx:02d}_{args.suffix}.png'
            else:
                save_face_name = f'{basename}_{idx:02d}.png'
            save_restore_path = os.path.join(args.save_root, 'restored_faces', save_face_name)
            imwrite(restored_face, save_restore_path)
            # save cmp image
            cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
            imwrite(cmp_img, os.path.join(args.save_root, 'cmp', f'{basename}_{idx:02d}.png'))

        # save restored img; `enhance` returns None for it when the input is
        # already aligned or pasting back is disabled
        if restored_img is not None:
            if args.suffix is not None:
                save_restore_path = os.path.join(args.save_root, 'restored_imgs', f'{basename}_{args.suffix}{ext}')
            else:
                save_restore_path = os.path.join(args.save_root, 'restored_imgs', img_name)
            imwrite(restored_img, save_restore_path)

    print(f'Results are in the [{args.save_root}] folder.')


if __name__ == '__main__':
    main()
|
||||
@@ -1,115 +0,0 @@
|
||||
import argparse
|
||||
import cv2
|
||||
import glob
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
from torchvision.transforms.functional import normalize
|
||||
|
||||
from archs.gfpganv1_arch import GFPGANv1
|
||||
from basicsr.utils import img2tensor, imwrite, tensor2img
|
||||
|
||||
|
||||
def restoration(gfpgan, face_helper, img_path, save_root, has_aligned=False, only_center_face=True, suffix=None):
    """Restore the faces of one image file and write the results to disk.

    Args:
        gfpgan (nn.Module): Restoration network; called as
            ``gfpgan(face, return_rgb=False)``.
        face_helper: Face detection/alignment helper (a
            ``FaceRestoreHelper`` in the calling script).
        img_path (str): Path of the image to process.
        save_root (str): Folder receiving the ``cropped_faces``,
            ``restored_faces`` and ``cmp`` sub-folders.
        has_aligned (bool): Treat the image as one already aligned face; it
            is resized to 512x512 and detection is skipped. Default: False.
        only_center_face (bool): Forwarded to the landmark detection.
            Default: True.
        suffix (str | None): Optional suffix for the restored face files.
            Default: None.
    """
    # read image
    img_name = os.path.basename(img_path)
    print(f'Processing {img_name} ...')
    basename, _ = os.path.splitext(img_name)
    input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    face_helper.clean_all()

    if has_aligned:
        input_img = cv2.resize(input_img, (512, 512))
        face_helper.cropped_faces = [input_img]
    else:
        face_helper.read_image(input_img)
        # get face landmarks for each face
        face_helper.get_face_landmarks_5(only_center_face=only_center_face, pad_blur=False)
        # align and warp each face
        save_crop_path = os.path.join(save_root, 'cropped_faces', img_name)
        face_helper.align_warp_face(save_crop_path)

    # Run on whatever device the network lives on instead of assuming CUDA;
    # the calling script falls back to the CPU when CUDA is unavailable.
    device = next(gfpgan.parameters()).device

    # face restoration
    for idx, cropped_face in enumerate(face_helper.cropped_faces):
        # prepare data: scale to [0, 1], then normalize to [-1, 1]
        cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
        normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        cropped_face_t = cropped_face_t.unsqueeze(0).to(device)

        try:
            with torch.no_grad():
                output = gfpgan(cropped_face_t, return_rgb=False)[0]
                # convert to image
                restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
        except RuntimeError as error:
            # best effort: keep the unrestored crop for this face
            print(f'\tFailed inference for GFPGAN: {error}.')
            restored_face = cropped_face

        restored_face = restored_face.astype('uint8')
        face_helper.add_restored_face(restored_face)

        if suffix is not None:
            save_face_name = f'{basename}_{idx:02d}_{suffix}.png'
        else:
            save_face_name = f'{basename}_{idx:02d}.png'
        save_restore_path = os.path.join(save_root, 'restored_faces', save_face_name)
        imwrite(restored_face, save_restore_path)

        # save cmp image
        cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
        imwrite(cmp_img, os.path.join(save_root, 'cmp', f'{basename}_{idx:02d}.png'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Script entry: parse the options, build the network and the face helper,
    # then restore every file found in the test folder.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    parser = argparse.ArgumentParser()
    parser.add_argument('--upscale_factor', type=int, default=1)
    parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANv1.pth')
    parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
    parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
    parser.add_argument('--only_center_face', action='store_true')
    parser.add_argument('--aligned', action='store_true')
    args = parser.parse_args()

    # drop a single trailing slash from the input folder
    if args.test_path.endswith('/'):
        args.test_path = args.test_path[:-1]

    save_root = 'results/'
    os.makedirs(save_root, exist_ok=True)

    # initialize the GFP-GAN
    network_options = dict(
        out_size=512,
        num_style_feat=512,
        channel_multiplier=1,
        decoder_load_path=None,
        fix_decoder=True,
        # for stylegan decoder
        num_mlp=8,
        input_is_latent=True,
        different_w=True,
        narrow=1,
        sft_half=True)
    gfpgan = GFPGANv1(**network_options)

    gfpgan.to(device)
    # keep every storage where it is deserialized (CPU) while loading
    checkpoint = torch.load(args.model_path, map_location=lambda storage, loc: storage)
    gfpgan.load_state_dict(checkpoint['params_ema'])
    gfpgan.eval()

    # initialize face helper
    face_helper = FaceRestoreHelper(
        upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png')

    input_pattern = os.path.join(args.test_path, '*')
    for img_path in sorted(glob.glob(input_pattern)):
        restoration(
            gfpgan,
            face_helper,
            img_path,
            save_root,
            has_aligned=args.aligned,
            only_center_face=args.only_center_face,
            suffix=args.suffix)

    print('Results are in the <results> folder.')
|
||||
BIN
inputs/whole_imgs/Blake_Lively.jpg
Normal file
BIN
inputs/whole_imgs/Blake_Lively.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 91 KiB |
@@ -9,9 +9,11 @@ datasets:
|
||||
train:
|
||||
name: FFHQ
|
||||
type: FFHQDegradationDataset
|
||||
dataroot_gt: datasets/ffhq/ffhq_512.lmdb
|
||||
# dataroot_gt: datasets/ffhq/ffhq_512.lmdb
|
||||
dataroot_gt: datasets/ffhq/ffhq_512
|
||||
io_backend:
|
||||
type: lmdb
|
||||
# type: lmdb
|
||||
type: disk
|
||||
|
||||
use_hflip: true
|
||||
mean: [0.5, 0.5, 0.5]
|
||||
@@ -32,6 +34,12 @@ datasets:
|
||||
color_jitter_pt_prob: 0.3
|
||||
gray_prob: 0.01
|
||||
|
||||
# If you do not want colorization, please set
|
||||
# color_jitter_prob: ~
|
||||
# color_jitter_pt_prob: ~
|
||||
# gray_prob: 0.01
|
||||
# gt_gray: True
|
||||
|
||||
crop_components: true
|
||||
component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
|
||||
eye_enlarge_ratio: 1.4
|
||||
@@ -40,14 +48,16 @@ datasets:
|
||||
use_shuffle: true
|
||||
num_worker_per_gpu: 6
|
||||
batch_size_per_gpu: 3
|
||||
dataset_enlarge_ratio: 100
|
||||
dataset_enlarge_ratio: 1
|
||||
prefetch_mode: ~
|
||||
|
||||
val:
|
||||
name: validation1020_512
|
||||
# Please modify accordingly to use your own validation
|
||||
# Or comment out the val block if you do not need validation during training
|
||||
name: validation
|
||||
type: PairedImageDataset
|
||||
dataroot_lq: datasets/faces/validation1020_512/input # TODO: modify before release
|
||||
dataroot_gt: datasets/faces/validation1020_512/input
|
||||
dataroot_lq: datasets/faces/validation/input
|
||||
dataroot_gt: datasets/faces/validation/reference
|
||||
io_backend:
|
||||
type: disk
|
||||
mean: [0.5, 0.5, 0.5]
|
||||
@@ -97,12 +107,13 @@ path:
|
||||
param_key_g: params_ema
|
||||
strict_load_g: ~
|
||||
pretrain_network_d: ~
|
||||
|
||||
resume_state: ~
|
||||
pretrain_network_d_left_eye: ~
|
||||
pretrain_network_d_right_eye: ~
|
||||
pretrain_network_d_mouth: ~
|
||||
pretrain_network_arcface: experiments/pretrained_models/arcface_resnet18.pth
|
||||
pretrain_network_identity: experiments/pretrained_models/arcface_resnet18.pth
|
||||
# resume
|
||||
resume_state: ~
|
||||
ignore_resume_networks: ['network_identity']
|
||||
|
||||
# training settings
|
||||
train:
|
||||
@@ -137,7 +148,7 @@ train:
|
||||
reduction: mean
|
||||
|
||||
# image pyramid loss
|
||||
pyramid_loss_weight: 0
|
||||
pyramid_loss_weight: 1
|
||||
remove_pyramid_loss: 50000
|
||||
# perceptual loss (content and style losses)
|
||||
perceptual_opt:
|
||||
216
options/train_gfpgan_v1_simple.yml
Normal file
216
options/train_gfpgan_v1_simple.yml
Normal file
@@ -0,0 +1,216 @@
|
||||
# general settings
|
||||
name: train_GFPGANv1_512_simple
|
||||
model_type: GFPGANModel
|
||||
num_gpu: 4
|
||||
manual_seed: 0
|
||||
|
||||
# dataset and data loader settings
|
||||
datasets:
|
||||
train:
|
||||
name: FFHQ
|
||||
type: FFHQDegradationDataset
|
||||
# dataroot_gt: datasets/ffhq/ffhq_512.lmdb
|
||||
dataroot_gt: datasets/ffhq/ffhq_512
|
||||
io_backend:
|
||||
# type: lmdb
|
||||
type: disk
|
||||
|
||||
use_hflip: true
|
||||
mean: [0.5, 0.5, 0.5]
|
||||
std: [0.5, 0.5, 0.5]
|
||||
out_size: 512
|
||||
|
||||
blur_kernel_size: 41
|
||||
kernel_list: ['iso', 'aniso']
|
||||
kernel_prob: [0.5, 0.5]
|
||||
blur_sigma: [0.1, 10]
|
||||
downsample_range: [0.8, 8]
|
||||
noise_range: [0, 20]
|
||||
jpeg_range: [60, 100]
|
||||
|
||||
# color jitter and gray
|
||||
color_jitter_prob: 0.3
|
||||
color_jitter_shift: 20
|
||||
color_jitter_pt_prob: 0.3
|
||||
gray_prob: 0.01
|
||||
|
||||
# If you do not want colorization, please set
|
||||
# color_jitter_prob: ~
|
||||
# color_jitter_pt_prob: ~
|
||||
# gray_prob: 0.01
|
||||
# gt_gray: True
|
||||
|
||||
# crop_components: false
|
||||
# component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
|
||||
# eye_enlarge_ratio: 1.4
|
||||
|
||||
# data loader
|
||||
use_shuffle: true
|
||||
num_worker_per_gpu: 6
|
||||
batch_size_per_gpu: 3
|
||||
dataset_enlarge_ratio: 1
|
||||
prefetch_mode: ~
|
||||
|
||||
val:
|
||||
# Please modify accordingly to use your own validation
|
||||
# Or comment out the val block if you do not need validation during training
|
||||
name: validation
|
||||
type: PairedImageDataset
|
||||
dataroot_lq: datasets/faces/validation/input
|
||||
dataroot_gt: datasets/faces/validation/reference
|
||||
io_backend:
|
||||
type: disk
|
||||
mean: [0.5, 0.5, 0.5]
|
||||
std: [0.5, 0.5, 0.5]
|
||||
scale: 1
|
||||
|
||||
# network structures
|
||||
network_g:
|
||||
type: GFPGANv1
|
||||
out_size: 512
|
||||
num_style_feat: 512
|
||||
channel_multiplier: 1
|
||||
resample_kernel: [1, 3, 3, 1]
|
||||
decoder_load_path: experiments/pretrained_models/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth
|
||||
fix_decoder: true
|
||||
num_mlp: 8
|
||||
lr_mlp: 0.01
|
||||
input_is_latent: true
|
||||
different_w: true
|
||||
narrow: 1
|
||||
sft_half: true
|
||||
|
||||
network_d:
|
||||
type: StyleGAN2Discriminator
|
||||
out_size: 512
|
||||
channel_multiplier: 1
|
||||
resample_kernel: [1, 3, 3, 1]
|
||||
|
||||
# network_d_left_eye:
|
||||
# type: FacialComponentDiscriminator
|
||||
|
||||
# network_d_right_eye:
|
||||
# type: FacialComponentDiscriminator
|
||||
|
||||
# network_d_mouth:
|
||||
# type: FacialComponentDiscriminator
|
||||
|
||||
network_identity:
|
||||
type: ResNetArcFace
|
||||
block: IRBlock
|
||||
layers: [2, 2, 2, 2]
|
||||
use_se: False
|
||||
|
||||
# path
|
||||
path:
|
||||
pretrain_network_g: ~
|
||||
param_key_g: params_ema
|
||||
strict_load_g: ~
|
||||
pretrain_network_d: ~
|
||||
# pretrain_network_d_left_eye: ~
|
||||
# pretrain_network_d_right_eye: ~
|
||||
# pretrain_network_d_mouth: ~
|
||||
pretrain_network_identity: experiments/pretrained_models/arcface_resnet18.pth
|
||||
# resume
|
||||
resume_state: ~
|
||||
ignore_resume_networks: ['network_identity']
|
||||
|
||||
# training settings
|
||||
train:
|
||||
optim_g:
|
||||
type: Adam
|
||||
lr: !!float 2e-3
|
||||
optim_d:
|
||||
type: Adam
|
||||
lr: !!float 2e-3
|
||||
optim_component:
|
||||
type: Adam
|
||||
lr: !!float 2e-3
|
||||
|
||||
scheduler:
|
||||
type: MultiStepLR
|
||||
milestones: [600000, 700000]
|
||||
gamma: 0.5
|
||||
|
||||
total_iter: 800000
|
||||
warmup_iter: -1 # no warm up
|
||||
|
||||
# losses
|
||||
# pixel loss
|
||||
pixel_opt:
|
||||
type: L1Loss
|
||||
loss_weight: !!float 1e-1
|
||||
reduction: mean
|
||||
# L1 loss used in pyramid loss, component style loss and identity loss
|
||||
L1_opt:
|
||||
type: L1Loss
|
||||
loss_weight: 1
|
||||
reduction: mean
|
||||
|
||||
# image pyramid loss
|
||||
pyramid_loss_weight: 1
|
||||
remove_pyramid_loss: 50000
|
||||
# perceptual loss (content and style losses)
|
||||
perceptual_opt:
|
||||
type: PerceptualLoss
|
||||
layer_weights:
|
||||
# before relu
|
||||
'conv1_2': 0.1
|
||||
'conv2_2': 0.1
|
||||
'conv3_4': 1
|
||||
'conv4_4': 1
|
||||
'conv5_4': 1
|
||||
vgg_type: vgg19
|
||||
use_input_norm: true
|
||||
perceptual_weight: !!float 1
|
||||
style_weight: 50
|
||||
range_norm: true
|
||||
criterion: l1
|
||||
# gan loss
|
||||
gan_opt:
|
||||
type: GANLoss
|
||||
gan_type: wgan_softplus
|
||||
loss_weight: !!float 1e-1
|
||||
# r1 regularization for discriminator
|
||||
r1_reg_weight: 10
|
||||
# facial component loss
|
||||
# gan_component_opt:
|
||||
# type: GANLoss
|
||||
# gan_type: vanilla
|
||||
# real_label_val: 1.0
|
||||
# fake_label_val: 0.0
|
||||
# loss_weight: !!float 1
|
||||
# comp_style_weight: 200
|
||||
# identity loss
|
||||
identity_weight: 10
|
||||
|
||||
net_d_iters: 1
|
||||
net_d_init_iters: 0
|
||||
net_d_reg_every: 16
|
||||
|
||||
# validation settings
|
||||
val:
|
||||
val_freq: !!float 5e3
|
||||
save_img: true
|
||||
|
||||
metrics:
|
||||
psnr: # metric name, can be arbitrary
|
||||
type: calculate_psnr
|
||||
crop_border: 0
|
||||
test_y_channel: false
|
||||
|
||||
# logging settings
|
||||
logger:
|
||||
print_freq: 100
|
||||
save_checkpoint_freq: !!float 5e3
|
||||
use_tb_logger: true
|
||||
wandb:
|
||||
project: ~
|
||||
resume_id: ~
|
||||
|
||||
# dist training settings
|
||||
dist_params:
|
||||
backend: nccl
|
||||
port: 29500
|
||||
|
||||
find_unused_parameters: true
|
||||
77
scripts/parse_landmark.py
Normal file
77
scripts/parse_landmark.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import cv2
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from basicsr.utils import FileClient, imfrombytes
|
||||
from collections import OrderedDict
|
||||
|
||||
print('Load JSON metadata...')
# The landmark annotations come from the json file shipped with the FFHQ dataset.
with open('ffhq-dataset-v2.json', 'rb') as f:
    json_data = json.load(f, object_pairs_hook=OrderedDict)

print('Open LMDB file...')
# The FFHQ images are only read when debug crops are requested (see save_img).
file_client = FileClient('lmdb', db_paths='datasets/ffhq/ffhq_512.lmdb')
with open('datasets/ffhq/ffhq_512.lmdb/meta_info.txt') as fin:
    paths = []
    for line in fin:
        # keep the text before the first '.' of every meta-info line
        paths.append(line.split('.')[0])

save_img = False
scale = 0.5  # 0.5 for official FFHQ (512x512), 1 for others
enlarge_ratio = 1.4  # only for eyes
save_dict = {}

# One entry per facial component:
#   (key in the saved dict, landmark indices, enlarge factor for the debug crop, debug file tag)
# The mouth crop is not enlarged, hence None.
components = (
    ('left_eye', list(range(36, 42)), enlarge_ratio, 'eye_left'),
    ('right_eye', list(range(42, 48)), enlarge_ratio, 'eye_right'),
    ('mouth', list(range(48, 68)), None, 'mouth'),
)

for item_idx, item in enumerate(json_data.values()):
    print(f'\r{item_idx} / {len(json_data)}, {item["image"]["file_path"]} ', end='', flush=True)

    # parse landmarks and rescale them to the working resolution
    lm = np.array(item['image']['face_landmarks']) * scale

    # load the image only when debug crops are written
    if save_img:
        img_bytes = file_client.get(paths[item_idx])
        img = imfrombytes(img_bytes, float32=True)

    item_dict = {}
    for key, indices, ratio, tag in components:
        points = lm[indices]
        center = np.mean(points, 0)  # (x, y)
        extent = np.max(points, 0) - np.min(points, 0)
        # half of the larger side of the landmark bounding box, but at least 16
        half_len = np.max((np.max(extent) / 2, 16))
        # the stored half length is the value *before* any enlargement
        item_dict[key] = [center[0], center[1], half_len]
        # (for testing horizontal flip, center[0] can be replaced by 512 - center[0] here)

        if ratio is not None:
            half_len *= ratio
        # crop box as (x1, y1, x2, y2)
        loc = np.hstack((center - half_len + 1, center + half_len)).astype(int)
        if save_img:
            crop = img[loc[1]:loc[3], loc[0]:loc[2], :]
            cv2.imwrite(f'tmp/{item_idx:08d}_{tag}.png', crop * 255)

    save_dict[f'{item_idx:08d}'] = item_dict

print('Save...')
torch.save(save_dict, './FFHQ_eye_mouth_landmarks_512.pth')
|
||||
@@ -16,7 +16,7 @@ split_before_expression_after_opening_paren = true
|
||||
line_length = 120
|
||||
multi_line_output = 0
|
||||
known_standard_library = pkg_resources,setuptools
|
||||
known_first_party = basicsr
|
||||
known_third_party = cv2,facexlib,numpy,torch,torchvision,tqdm
|
||||
known_first_party = gfpgan
|
||||
known_third_party = basicsr,cv2,facexlib,numpy,torch,torchvision,tqdm
|
||||
no_lines_before = STDLIB,LOCALFOLDER
|
||||
default_section = THIRDPARTY
|
||||
|
||||
113
setup.py
Normal file
113
setup.py
Normal file
@@ -0,0 +1,113 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
version_file = 'gfpgan/version.py'
|
||||
|
||||
|
||||
def readme():
    """Return the whole text of ``README.md`` (read as UTF-8) from the current directory."""
    with open('README.md', encoding='utf-8') as readme_file:
        return readme_file.read()
|
||||
|
||||
|
||||
def get_git_hash():
    """Return the SHA printed by ``git rev-parse HEAD``.

    Returns:
        str: The stripped, ASCII-decoded output of the git command, or
            ``'unknown'`` when launching git raises ``OSError``.
    """

    def _minimal_ext_cmd(cmd):
        # Forward only the few variables the child process needs ...
        env = {key: os.environ[key] for key in ('SYSTEMROOT', 'PATH', 'HOME') if key in os.environ}
        # ... and force the C locale (LANGUAGE is used on win32).
        env.update(LANGUAGE='C', LANG='C', LC_ALL='C')
        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env)
        stdout_data, _ = process.communicate()
        return stdout_data

    try:
        raw_output = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
        return raw_output.strip().decode('ascii')
    except OSError:
        return 'unknown'
|
||||
|
||||
|
||||
def get_hash():
    """Return a short identifier of the source revision being packaged.

    Resolution order:

    1. If ``.git`` exists in the current directory, the first 7 characters of
       the value returned by ``get_git_hash()``.
    2. Otherwise, if the generated version file (``version_file``) exists, the
       text after the last ``'+'`` in ``gfpgan.version.__version__`` (the whole
       version string when it contains no ``'+'``).
    3. Otherwise ``'unknown'``.

    Returns:
        str: The revision identifier.

    Raises:
        ImportError: If the version file exists but ``gfpgan.version`` cannot
            be imported.
    """
    if os.path.exists('.git'):
        sha = get_git_hash()[:7]
    elif os.path.exists(version_file):
        try:
            # The file checked above is gfpgan/version.py, so the version must
            # come from this package, not from the facexlib dependency.
            from gfpgan.version import __version__
            sha = __version__.split('+')[-1]
        except ImportError as err:
            raise ImportError('Unable to get git version') from err
    else:
        sha = 'unknown'

    return sha
|
||||
|
||||
|
||||
def write_version_py():
    """Generate the version module at ``version_file``.

    The short version is read from the ``VERSION`` file in the current
    directory and the revision comes from ``get_hash()``. The generated module
    defines ``__version__``, ``__gitsha__`` and ``version_info`` and records
    the generation time in a comment.
    """
    template = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
__gitsha__ = '{}'
version_info = ({})
"""
    sha = get_hash()
    with open('VERSION', 'r') as f:
        short_version = f.read().strip()

    # Numeric components stay bare, anything else is emitted as a quoted string.
    components = []
    for piece in short_version.split('.'):
        components.append(piece if piece.isdigit() else f'"{piece}"')
    version_info = ', '.join(components)

    rendered = template.format(time.asctime(), short_version, sha, version_info)
    with open(version_file, 'w') as f:
        f.write(rendered)
|
||||
|
||||
|
||||
def get_version():
    """Return ``__version__`` as defined in the generated version file.

    The file at ``version_file`` is executed in a dedicated namespace dict.
    Reading the result back through ``locals()`` is not reliable: since
    Python 3.13 ``exec`` inside a function no longer writes into the dict
    returned by a later ``locals()`` call, so that lookup raises ``KeyError``.

    Returns:
        str: The value bound to ``__version__`` by the version file.

    Raises:
        KeyError: If the version file does not define ``__version__``.
    """
    namespace = {}
    with open(version_file, 'r') as f:
        # The executed file is the one generated by write_version_py().
        exec(compile(f.read(), version_file, 'exec'), namespace)
    return namespace['__version__']
|
||||
|
||||
|
||||
def get_requirements(filename='requirements.txt'):
    """Return the lines of a requirements file with their newlines removed.

    Args:
        filename (str): Path of the requirements file, resolved relative to
            the directory containing this script. Default: 'requirements.txt'.

    Returns:
        list[str]: One entry per line of the file, in file order.
    """
    script_dir = os.path.dirname(os.path.realpath(__file__))
    requirements_path = os.path.join(script_dir, filename)
    requires = []
    with open(requirements_path, 'r') as f:
        for raw_line in f:
            requires.append(raw_line.replace('\n', ''))
    return requires
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Regenerate the version module first; get_version() below reads that file.
    write_version_py()

    # Top-level directories that are not packaged.
    excluded_packages = ('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')
    trove_classifiers = [
        'Development Status :: 4 - Beta',
        'License :: OSI Approved :: Apache Software License',
        'Operating System :: OS Independent',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
    ]

    setup(
        name='gfpgan',
        version=get_version(),
        description='GFPGAN aims at developing Practical Algorithms for Real-world Face Restoration',
        long_description=readme(),
        long_description_content_type='text/markdown',
        author='Xintao Wang',
        author_email='xintao.wang@outlook.com',
        keywords='computer vision, pytorch, image restoration, super-resolution, face restoration, gan, gfpgan',
        url='https://github.com/TencentARC/GFPGAN',
        include_package_data=True,
        packages=find_packages(exclude=excluded_packages),
        classifiers=trove_classifiers,
        license='Apache License Version 2.0',
        setup_requires=['cython', 'numpy'],
        install_requires=get_requirements(),
        zip_safe=False,
    )
|
||||
Reference in New Issue
Block a user