
PyTorch PaQ-2-PiQ: Patch Quality 2 Picture Quality Prediction

PyTorch implementation of PaQ2PiQ

Demo

Open In Colab

Get Started

Check out demo.ipynb

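Get predictions from an image file:

# InferenceModel and RoIPoolModel are defined in this repository (see demo.ipynb for the imports)
model = InferenceModel(RoIPoolModel(), 'models/RoIPoolModel.pth')
output = model.predict_from_file("images/Picture1.jpg")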
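To score many pictures at once, you can wrap the file-based call in a loop. This is a minimal sketch. `rank_images` is a hypothetical helper and is not part of this repository. It takes the predictor as a plain callable (for example `model.predict_from_file`), so it does not depend on how the model is loaded.

```python
from pathlib import Path


def rank_images(predict, folder, pattern="*.jpg"):
    """Score every image in `folder` matching `pattern`; return (path, score) pairs, best first.

    `predict` is any callable that maps a file path to the output dictionary
    described in this README, e.g. model.predict_from_file (hypothetical usage).
    """
    scored = []
    for path in sorted(Path(folder).glob(pattern)):
        output = predict(str(path))
        scored.append((str(path), float(output['global_score'])))
    # Highest predicted global quality first; ties keep filename order
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored
```

Usage: `ranking = rank_images(model.predict_from_file, "images")`.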

Predict from a PIL image:

from PIL import Image

model = InferenceModel(RoIPoolModel(), 'models/RoIPoolModel.pth')
image = Image.open("images/Picture1.jpg")
output = model.predict_from_pil_image(image)

The output is a dictionary:

output['global_score']  # float: predicted global (whole-picture) quality score
output['local_scores']  # 20x20 numpy array: predicted local (patch) quality scores
output['category']      # quality label, from low to high: 'Bad', 'Poor', 'Fair', 'Good', 'Excellent'
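The 20x20 `local_scores` grid can help you locate the weakest region of a picture. The sketch below uses a hypothetical helper (`worst_patch` is not part of this repository). It finds the lowest-scoring cell and maps it back to an approximate pixel box. This assumes, without confirmation from this README, that the grid cells tile the image uniformly.

```python
def worst_patch(local_scores, image_width, image_height):
    """Return (row, col, score, pixel_box) for the lowest local score.

    local_scores: 2-D grid (numpy array or list of lists).
    ASSUMPTION: the grid cells tile the image uniformly, so each cell
    covers an equal share of the image width and height.
    """
    rows = len(local_scores)
    cols = len(local_scores[0])
    # Lowest score wins; ties go to the first cell in row-major order
    score, r, c = min(
        (float(local_scores[i][j]), i, j)
        for i in range(rows)
        for j in range(cols)
    )
    cell_w = image_width / cols
    cell_h = image_height / rows
    # (left, top, right, bottom) in pixels, rounded to integers
    box = (round(c * cell_w), round(r * cell_h),
           round((c + 1) * cell_w), round((r + 1) * cell_h))
    return r, c, score, box
```

Usage with the PIL example: `row, col, score, box = worst_patch(output['local_scores'], *image.size)`. PIL's `Image.size` is `(width, height)`.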

Installing

git clone https://github.com/baidut/paq2piq
cd paq2piq
virtualenv -p python3.6 env
source ./env/bin/activate
pip install -r requirements.txt

Dataset

The model was trained on FLIVE. You can get it from here. (Feel free to open an issue here if you run into any problems.) For each image, we cropped three patches of different sizes. The image data and patch locations are taken as input, and the corresponding scores as output. Here is an example: [figure: data]
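One way to represent a training sample as described above is sketched below. Everything here is an assumption for illustration and not this repository's data format: the `QualitySample` name, its field names, the (left, top, right, bottom) box convention, and treating the whole picture as an extra region alongside the three patches.

```python
from dataclasses import dataclass
from typing import List, Tuple

# (left, top, right, bottom) in pixels -- an assumed convention
Box = Tuple[int, int, int, int]


@dataclass
class QualitySample:
    """Hypothetical record for one training image and its three patches."""
    image_path: str
    width: int
    height: int
    image_score: float
    patch_boxes: List[Box]     # three patches of different sizes
    patch_scores: List[float]  # one quality score per patch

    def inputs_and_targets(self):
        """Return ((image_path, boxes), targets).

        ASSUMPTION: the whole picture is treated as one more region, so the
        picture score and the patch scores are regressed the same way.
        """
        if len(self.patch_boxes) != 3 or len(self.patch_scores) != 3:
            raise ValueError("expected exactly three patches per image")
        full_image: Box = (0, 0, self.width, self.height)
        boxes = [full_image] + list(self.patch_boxes)
        targets = [self.image_score] + list(self.patch_scores)
        return (self.image_path, boxes), targets
```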

Model

The model uses a ResNet-18 backbone pretrained on ImageNet.
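The class name `RoIPoolModel` suggests that region scores come from region-of-interest (RoI) pooling over the backbone's feature map. The sketch below illustrates that operation in plain Python. It is not this repository's implementation. It max-pools one box of a single-channel feature map into a fixed-size grid, assuming the box is given in integer feature-map cells and is inclusive on both ends.

```python
def ceil_div(a, b):
    """Integer ceiling division for non-negative a and positive b."""
    return -(-a // b)


def roi_max_pool(feature_map, box, output_size):
    """Max-pool the region `box` of a 2-D feature map into an output_size grid.

    feature_map: list of equal-length rows (H x W numbers)
    box: (x1, y1, x2, y2) in feature-map cells, inclusive on both ends
    output_size: (out_h, out_w)
    """
    x1, y1, x2, y2 = box
    roi_h, roi_w = y2 - y1 + 1, x2 - x1 + 1
    out_h, out_w = output_size
    pooled = []
    for i in range(out_h):
        # Rows covered by output bin i (half-open range, never empty)
        r0, r1 = y1 + (i * roi_h) // out_h, y1 + ceil_div((i + 1) * roi_h, out_h)
        row = []
        for j in range(out_w):
            # Columns covered by output bin j
            c0, c1 = x1 + (j * roi_w) // out_w, x1 + ceil_div((j + 1) * roi_w, out_w)
            row.append(max(feature_map[r][c]
                           for r in range(r0, r1)
                           for c in range(c0, c1)))
        pooled.append(row)
    return pooled
```

In a PyTorch model, this pooling would typically run per channel on the backbone's feature map, with pixel boxes scaled down by the backbone stride. `torchvision.ops.roi_pool` provides a batched version of the operation.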

Pre-trained model

Download

Train it with PyTorch Lightning

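import pytorch_lightning as pl
from pytorch_lightning_module import *  # provides RoIPoolLightningModule

module = RoIPoolLightningModule()
# Train on GPU 0; PyTorch Lightning 2.x replaced `gpus` with accelerator="gpu", devices=[0]
trainer = pl.Trainer(gpus=[0])
trainer.fit(module)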

Train it with pure PyTorch

Adjust the following settings as needed:

export PYTHONPATH=.
export PATH_TO_MODEL=models/RoIPoolModel.pth
export PATH_TO_IMAGES=/storage/DATA/images/
export PATH_TO_CSV=/storage/DATA/FLIVE/
export BATCH_SIZE=16
export NUM_WORKERS=2
export NUM_EPOCH=50
export INIT_LR=0.0001
export EXPERIMENT_DIR_NAME=/storage/experiment_n0001

Train model

python cli.py train_model --path_to_save_csv $PATH_TO_CSV \
                                --path_to_images $PATH_TO_IMAGES \
                                --batch_size $BATCH_SIZE \
                                --num_workers $NUM_WORKERS \
                                --num_epoch $NUM_EPOCH \
                                --init_lr $INIT_LR \
                                --experiment_dir_name $EXPERIMENT_DIR_NAME

Use TensorBoard to track training progress

tensorboard --logdir .

Validate the model on the validation and test sets

python cli.py validate_model --path_to_model_state $PATH_TO_MODEL \
                                    --path_to_save_csv $PATH_TO_CSV \
                                    --path_to_images $PATH_TO_IMAGES \
                                    --batch_size $BATCH_SIZE \
                                    --num_workers $NUM_WORKERS

Get scores for one image

python cli.py get-image-score --path_to_model_state $PATH_TO_MODEL \
--path_to_image test_image.jpg

Contributing

Contributions are welcome.

Acknowledgments