diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..b09cd78 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..7927a92
--- /dev/null
+++ b/README.md
@@ -0,0 +1,167 @@
+# AutoColor
+![pytorch](https://img.shields.io/badge/pytorch-v1.9.0-green.svg?style=plastic)
+
+## Introduction
+This project provides automatic colorization for manga. It is mainly built on [Clip](https://github.com/Lednik7/CLIP-ONNX), [MAE](https://github.com/facebookresearch/mae) and [timm](https://github.com/rwightman/pytorch-image-models).
+
+## Usage
+### Method 1: download the packaged application (for Windows 10)
+Download link:
+
+(Deployed with onnxruntime and packaged into a standalone program with pyinstaller.)
+### Method 2: run the Python program directly
+a. Clone or download the code (requires pytorch, onnxruntime, etc.):
+```bash
+git clone https://github.com/danczs/AutoColor.git
+```
+b. Install the required packages:
+```bash
+pip install -r requirements.txt
+```
+c. Download the [pytorch models]() and the [onnx models]() into ```./deployment/pytorch_models``` and ```./deployment/onnx_models```, respectively.
+
+d. Launch the GUI:
+```bash
+cd deployment
+python auto_color_gui.py
+```
+Line 14 of 'auto_color_gui.py' selects whether the ONNX or the pytorch models are used:
+```bash
+# using pytorch or onnx deployment
+from autocolor_pytorch_deployment import AutoColorDeployment # pytorch deployment
+#from autocolor_onnx_deployment import AutoColorDeployment # onnx deployment
+```
+## Model Training
+For modifying and retraining the models.
+### Data preparation
+Download the [kaggle subset](https://www.kaggle.com/datasets/mylesoneill/tagged-anime-illustrations/code) of the Danbooru dataset (about 40 GB).
+Dataset layout:
+```bash
+/path/to/archive/
+    danbooru-images/
+        0000
+            1.jpg
+            2.jpg
+            ...
+        0001
+        ...
+        0149
+    danbooru-metadata/
+    moeimouto-faces/
+```
+Only the first 10 folders of danbooru-images (0000-0009, about 20k images) are used for training. Copy these 10 folders into a separate new folder (e.g., archive_subset10), as sketched below.
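+A minimal sketch of this copying step, assuming the Kaggle archive layout shown above (the source and destination paths are placeholders, not part of this repository):
+```python
+# copy_subset.py -- illustrative helper, paths are placeholders
+import shutil
+from pathlib import Path
+
+src = Path('/path/to/archive/danbooru-images')   # extracted Kaggle archive
+dst = Path('/path/to/archive_subset10')          # training subset used below
+dst.mkdir(parents=True, exist_ok=True)
+
+for i in range(10):                              # folders 0000-0009, ~20k images
+    name = f'{i:04d}'
+    shutil.copytree(src / name, dst / name, dirs_exist_ok=True)
+```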
+### Model preparation
+Download the corresponding [pre-trained models]() into the ```./models``` folder for the subsequent feature extraction and training.
+
+### Feature preparation
+a. Extract and save the CLIP features:
+```bash
+python get_clip_features.py --data_path=/path/to/archive_subset10
+```
+The extracted CLIP features are saved to ```./features/clip_features_subset.npy```.
+
+b. Extract and save the MAE features:
+```bash
+python get_mae_features.py --data_path=/path/to/archive_subset10 --output_path=/path/to/mae_features
+```
+The MAE features are saved to the ```--output_path``` folder, one .npy file per image (about 12 GB in total).
+The index of these feature files is saved to ```./features/mae_feature_names.txt```.
+
+### Training
+#### Training the color decoder
+```bash
+python auto_color_main.py --grad_state='010' \
+    --output_dir=/path/to/output \
+    --mae_model_path=./models/mae_visualize_vit_base.pth \
+    --mae_feature_path=./features/mae_feature_names.txt \
+    --clip_feature_path=./features/clip_features_subset.npy \
+    --colordecoder_model_path=./models/color_decoder_pretrained.pth
+```
+```--colordecoder_model_path``` is optional; when it is not set, the color decoder is initialized from the decoder weights of ```mae_visualize_vit_base.pth```.
+```./models/color_decoder_pretrained.pth``` is a model pre-trained on folders 0000-0049 (about 100k images) and slightly improves the final performance.
+The trained model is saved to ```--output_dir```.
+#### Training the super color model
+The super color model can be trained alone (faster) or together with a trained color decoder; the overall results of the two options are similar.
+
+a. Training alone
+
+Input/output: low-resolution color image + high-resolution grayscale image --> high-resolution color image
+```bash
+python auto_color_main.py --grad_state='001' \
+    --mae_feature_path=./features/mae_feature_names.txt \
+    --clip_feature_path=./features/clip_features_subset.npy \
+    --supercolor_only
+```
+b. Training on the output of a trained color decoder
+
+Input/output: low-resolution color image predicted by the color decoder + high-resolution grayscale image --> high-resolution color image
+```bash
+python auto_color_main.py --grad_state='001' \
+    --mae_feature_path=./features/mae_feature_names.txt \
+    --clip_feature_path=./features/clip_features_subset.npy \
+    --colordecoder_model_path=/path/to/trained_colordecoder_model.pth
+```
+## Model Deployment
+### pytorch deployment
+a. Copy the trained color decoder and super color models into ```deployment/pytorch_models```
+   and name them ```color_decoder.pt``` and ```super_color.pth```, respectively.
+
+b. Make sure that line 14 of ```auto_color_gui.py``` imports the pytorch deployment module:
+```bash
+# using pytorch or onnx deployment
+from autocolor_pytorch_deployment import AutoColorDeployment # pytorch deployment
+#from autocolor_onnx_deployment import AutoColorDeployment # onnx deployment
+```
+c. Launch the GUI:
+```bash
+cd deployment
+python auto_color_gui.py
+```
+
+### onnx deployment
+a. Convert the pytorch models to onnx models with ```conert_to_onnx_models.py```:
+```bash
+python conert_to_onnx_models.py
+```
+b. Make sure that line 14 of ```auto_color_gui.py``` imports the onnx deployment module:
+```bash
+# using pytorch or onnx deployment
+#from autocolor_pytorch_deployment import AutoColorDeployment # pytorch deployment
+from autocolor_onnx_deployment import AutoColorDeployment # onnx deployment
+```
+c. Launch the GUI:
+```bash
+cd deployment
+python auto_color_gui.py
+```
+### onnx packaging
+```bash
+pip install pyinstaller  # install pyinstaller
+cd deployment
+python -m PyInstaller -F -w -i feather_icon.ico auto_color_gui.py --add-data "\\path\\to\\.conda\\envs\\your_env_name\\Lib\\site-packages\\onnxruntime\\capi\\*.dll;onnxruntime\\capi"
+```
+Copy ```onnx_models, example_white.jpg, feather_icon.ico``` into the ```deployment\dist\``` folder:
+```bash
+deployment\dist\
+    onnx_models
+        color_decoder_onnx.onnx
+        super_color_onnx.onnx
+        ...
+    auto_color_gui.exe
+    example_white.jpg
+    feather_icon.ico
+```
+Run ```auto_color_gui.exe```.
+
+## citing
+```bash
+@misc{chen2022autocolor,
+  author = {Chen, Zhengsu},
+  title = {Auto Color},
+  year = {2022},
+  publisher = {GitHub},
+  journal = {GitHub repository},
+  howpublished = {\url{https://github.com/danczs/AutoColor}}
+}
+```
\ No newline at end of file
diff --git a/auto_color_main.py b/auto_color_main.py
new file mode 100644
index 0000000..0672038
--- /dev/null
+++ b/auto_color_main.py
@@ -0,0 +1,269 @@
+'''
+By danczs (https://github.com/danczs)
+References:
+    https://github.com/facebookresearch/mae
+    https://github.com/openai/CLIP
+    https://github.com/Lednik7/CLIP-ONNX
+    https://github.com/rwightman/pytorch-image-models
+'''
+
+import argparse
+import os
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+
+import util.misc as misc
+from util.dataset_autocolor import build_dataset
+import timm.optim.optim_factory as optim_factory
+
+from util.misc import NativeScalerWithGradNormCount as NativeScaler
+import util.lr_sched as lr_sched
+
+from color_decoder import mae_color_decoder_base
+from super_color import SuperColor
+import torch.nn.functional as F
+#
+def get_args_parser():
+    parser = argparse.ArgumentParser('AutoColor training', add_help=False)
+
+    parser.add_argument('--batch_size', default=32, type=int,
+                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
+    parser.add_argument('--epochs', default=10, type=int)
+
+    parser.add_argument('--input_size', default=224, type=int,
+                        help='images input size')
+    parser.add_argument('--input_size_supercolor', default=448, type=int,
+                        help='images input size')
+    parser.add_argument('--data_path', default='E://data//carton_subset//train', type=str,
+                        help='dataset 
path') + parser.add_argument('--nb_classes', default=1000, type=int, + help='number of the classification types') + parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', + help='Drop path rate (default: 0.1)') + parser.add_argument('--output_dir', default=None, type=str, + help='the output dir of models and logs') + parser.add_argument('--eval', action='store_true', help='evaluete the model') + + parser.add_argument('--colormask_prob', type=float, default=0.1, metavar='PCT', + help='the hyper-paramter of generating a colormask') + parser.add_argument('--mae_feature_path', default=None, type=str, + help='the mae feature path') + parser.add_argument('--clip_feature_path', default=None, type=str, + help='the clip feature path') + parser.add_argument('--mae_model_path', default=None, type=str, + help='the clip model') + parser.add_argument('--clip_model_path', default=None, type=str, + help='the clip feature path') + parser.add_argument('--colordecoder_model_path', default=None, type=str, + help='the initialization of color decoder weights. \ + If not specified, it will use the pre-trained mae decoder weights.') + parser.add_argument('--supercolor_model_path', default=None, type=str, + help='the initialization of the super color weights') + parser.add_argument('--grad_state', default='010', type=str, + help=' whether or not to train mae encoder, mae color decoder and super color. \ + e.g. 010 indicates only training the color decoder') + parser.add_argument('--supercolor_only',action='store_true', + help='only train or eval the supercolor model') + + # Optimizer parameters + parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', + help='Clip gradient norm (default: None, no clipping)') + parser.add_argument('--weight_decay', type=float, default=0.05, + help='weight decay (default: 0.05)') + + parser.add_argument('--blr_cd', type=float, default=5e-3, metavar='LR', + help='base learning rate color decoder: absolute_lr = base_lr * total_batch_size / 256') + parser.add_argument('--blr_sc', type=float, default=1e-1, metavar='LR', + help='base learning rate of super color: absolute_lr = base_lr * total_batch_size / 256') + parser.add_argument('--layer_decay', type=float, default=0.75, + help='layer-wise lr decay from ELECTRA/BEiT') + + parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', + help='lower lr bound for cyclic schedulers that hit 0') + + parser.add_argument('--warmup_epochs', type=int, default=1, metavar='N', + help='epochs to warmup LR') + parser.add_argument('--alpha', type=float, default=0.5, + help='hyper-parameter to balance L1 loss and L2 loss') + parser.add_argument('--accum_iter', default=1, type=int, + help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') + parser.add_argument('--seed', default=0, type=int) + return parser + +def main(args): + # fix the seed for reproducibility + seed = args.seed + misc.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + cudnn.benchmark = True + + args.mae_model_path = './models/mae_visualize_vit_base.pth' + args.colordecoder_model_path = 'models/color_decoder_pretrained.pth' #'./checkpoints_github/colordecoder_alpha0.5_lr0.005.pth'# + #args.colordecoder_model_path = './checkpoints_github/colordecoder_alpha0.5_lr0.005_p0.05.pth' + #args.supercolor_model_path = './checkpoints_github/supercolor_alpha0.5_lr0.1_p0.0.pth' + args.mae_feature_path = './features/mae_feature_names.txt' + args.clip_feature_path = 
'./features/subset10_clip_feature.npy' + args.grad_state = '010' + args.batch_size = 64 + args.colormask_prob = 0.1 + #args.eval = True + #args.supercolor_only = True + + #dataset anda data loader + dataset = build_dataset(args=args) + if args.eval: + sampler = torch.utils.data.SequentialSampler(dataset) + else: + sampler = torch.utils.data.RandomSampler(dataset) + + data_loader = torch.utils.data.DataLoader( + dataset, sampler=sampler, + batch_size=args.batch_size, + num_workers=2, + pin_memory=True, + drop_last= not args.eval + ) + + device = "cuda" if torch.cuda.is_available() else "cpu" + # mae model + assert len(args.grad_state) == 3 + mae_eval, color_decoder_eval, super_color_eval = [i=='0' for i in args.grad_state] + + #mae encoder model + if args.mae_feature_path is None: + from mae_encoder import mae_vit_base_patch16_dec512d8b + mae_model = mae_vit_base_patch16_dec512d8b() + mae_weights = torch.load(args.mae_model_path, map_location='cpu')['model'] + msg = mae_model.load_state_dict(mae_weights, strict=False) + print(msg) + if mae_eval or args.eval: + mae_model.eval() + mae_model = mae_model.to(device) + + # build color decoder model + if not args.supercolor_only: + color_decoder = mae_color_decoder_base() + if args.colordecoder_model_path is None: + mae_weights = torch.load(args.mae_model_path, map_location='cpu')['model'] + del mae_weights['decoder_pos_embed'] + msg = color_decoder.load_state_dict(mae_weights, strict=False) + print(msg) + else: + color_decoder_weight = torch.load(args.colordecoder_model_path, map_location='cpu') + msg = color_decoder.load_state_dict(color_decoder_weight, strict=False) + print(msg) + if color_decoder_eval or args.eval: + color_decoder.eval() + color_decoder = color_decoder.to(device) + + #build supercolor model + if not super_color_eval or args.eval: + super_color = SuperColor(kernel_size=5, group=4) + if args.supercolor_model_path: + super_color_checkpoint = torch.load(args.supercolor_model_path, map_location='cpu') + msg = super_color.load_state_dict(super_color_checkpoint, strict=False) + print(msg) + if args.eval: + super_color.eval() + super_color = super_color.to(device) + + if color_decoder_eval is False: + lr_cd = args.blr_cd * args.batch_size / 256 + param_groups_cd = optim_factory.param_groups_weight_decay(color_decoder,weight_decay=args.weight_decay) + optimizer_cd = torch.optim.AdamW(param_groups_cd, lr=lr_cd, betas=(0.9, 0.95)) + else: + optimizer_cd = None + if super_color_eval is False: + lr_sc = args.blr_sc * args.batch_size / 256 + param_groups_sc = optim_factory.param_groups_weight_decay(super_color, weight_decay=args.weight_decay) + optimizer_sc = torch.optim.AdamW(param_groups_sc, lr=lr_sc, betas=(0.9, 0.95)) + else: + optimizer_sc = None + + loss_scaler = NativeScaler() + print(f"Start training for {args.epochs} epochs") + + for epoch in range(args.epochs): + avg_loss =0 + for iter_step, (mae_feature, clip_feature, color_mask, img_l, img_l_gray, img_h, img_h_gray, target) in enumerate(data_loader): + color_mask = color_mask.to(device, non_blocking=True) + img_l = img_l.to(device, non_blocking=True) + img_h = img_h.to(device, non_blocking=True) + img_l_gray = img_l_gray.to(device, non_blocking=True) + img_h_gray = img_h_gray.to(device, non_blocking=True) + clip_feature = clip_feature.to(device, non_blocking=True) + + if iter_step % args.accum_iter == 0: + if optimizer_cd is not None: + lr_sched.adjust_learning_rate(optimizer_cd, iter_step / len(data_loader) + epoch, lr_cd, args) + if optimizer_sc is not None: + 
lr_sched.adjust_learning_rate(optimizer_sc, iter_step / len(data_loader) + epoch, lr_sc, args) + + if args.mae_feature_path is not None: + mae_feature = mae_feature.to(device,non_blocking=True) + else: + with torch.cuda.amp.autocast(): + mae_feature = mae_model(img_l_gray) + + with torch.cuda.amp.autocast(): + if not args.supercolor_only: + pred = color_decoder(mae_feature, clip_feature, color_mask=color_mask) + pred_upsampling = F.interpolate(pred + img_l_gray, size=(img_h.size()[2:])) + if not color_decoder_eval or args.eval: + loss_decoder = color_decoder.forward_loss(pred, img_l_gray, img_l,alpha=args.alpha) + else: + pred_upsampling = F.interpolate(img_l, size=(img_h.size()[2:])) + + if not super_color_eval or args.eval: + color_mask_sc = F.interpolate(color_mask, size=(img_h.size()[2:])) + sc_pred = super_color(pred_upsampling.detach(), img_h_gray, color_mask_sc) #detach cd and sc + loss_sc = super_color.forward_loss(sc_pred, img_h_gray, img_h,alpha=args.alpha) + + if args.eval: + if loss_decoder is not None: + avg_loss += loss_decoder + if loss_sc is not None: + avg_loss += loss_sc + continue + + if optimizer_cd is not None: + + loss_decoder /= args.accum_iter + loss_scaler(loss_decoder, optimizer_cd, parameters=color_decoder.parameters(), + update_grad=(iter_step + 1) % args.accum_iter == 0) + if (iter_step + 1) % args.accum_iter == 0: + optimizer_cd.zero_grad() + avg_loss += loss_decoder.detach().item() + lr = optimizer_cd.param_groups[0]["lr"] + if iter_step % 100 == 0: + print('epoch:{} iter:{} color deocder loss:{} lr:{}'.format(epoch, iter_step, loss_decoder, lr)) + + if optimizer_sc is not None: + + loss_sc /= args.accum_iter + loss_scaler(loss_sc, optimizer_sc, parameters=super_color.parameters(), + update_grad=(iter_step + 1) % args.accum_iter == 0) + if (iter_step + 1) % args.accum_iter == 0: + optimizer_sc.zero_grad() + avg_loss += loss_sc.detach().item() + lr = optimizer_sc.param_groups[0]["lr"] + if iter_step % 100 == 0: + print('epoch:{} iter:{} super color loss:{} lr:{}'.format(epoch, iter_step, loss_sc, lr)) + torch.cuda.synchronize() + + print('epoch:{} avg loss:{}'.format(epoch,avg_loss/len(data_loader))) + if args.eval: + break + + if optimizer_cd: + torch.save(color_decoder.state_dict(), + os.path.join(args.output_dir,'colordecoder_alpha{}_lr{}_p{}.pth'.format(args.alpha,args.blr_cd,args.colormask_prob))) + if optimizer_sc: + torch.save(super_color.state_dict(), + os.path.join(args.output_dir,'supercolor_alpha{}_lr{}_p{}.pth'.format(args.alpha, args.blr_sc,args.colormask_prob))) + +if __name__=='__main__': + args = get_args_parser() + args = args.parse_args() + main(args) \ No newline at end of file diff --git a/color_decoder.py b/color_decoder.py new file mode 100644 index 0000000..97ca20e --- /dev/null +++ b/color_decoder.py @@ -0,0 +1,170 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
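+#
+# ColorDecoder (defined below) is a ViT-style decoder: it embeds MAE encoder
+# tokens, appends a CLIP feature token, mixes in a user color-hint mask, and
+# predicts per-patch color residuals that are added to the grayscale input.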
+# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- + +import torch +import torch.nn as nn +from functools import partial +from timm.models.vision_transformer import Block + +from util.pos_embed import get_2d_sincos_pos_embed +import numpy as np + +class LayerNorm(nn.Module): + def __init__(self, dim): + super().__init__() + #self.ln = nn.LayerNorm((198,dim),eps=1e-6) + self.ln = nn.LayerNorm(dim,eps=1e-6) + + def forward(self, x): + return self.ln(x) + +class ColorDecoder(nn.Module): + """ color decoder with VisionTransformer + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, + embed_dim=1024, clip_feature_dim=512, decoder_embed_dim=512, + decoder_depth=8, decoder_num_heads=16, mlp_ratio=4., norm_layer=nn.LayerNorm): + super().__init__() + + self.num_patches = (img_size // patch_size)**2 + self.patch_size = patch_size + self.decoder_embed_dim = decoder_embed_dim + + self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) + self.clip_embed = nn.Linear(clip_feature_dim, decoder_embed_dim, bias=True) + + self.token_num = self.num_patches + 2 + self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.token_num , decoder_embed_dim), + requires_grad=False) # fixed sin-cos embedding + + self.decoder_blocks = nn.ModuleList([ + Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)#remove qk_scale=None, + for i in range(decoder_depth)]) + + self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True) # decoder to patch + self.decoder_conv1 = nn.Conv2d(patch_size ** 2 * in_chans, patch_size ** 2 * in_chans, 3, stride=1, padding=(3-1)//2, bias=True) + + self.color_embdding = nn.Linear(patch_size ** 2 * in_chans, decoder_embed_dim, bias=True) + self.initialize_weights() + + def initialize_weights(self): + decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], + int(self.num_patches ** .5), cls_token=True) + extra_token = 1 + decoder_pos_embed = np.concatenate([decoder_pos_embed, np.zeros([extra_token, self.decoder_embed_dim]) + 0.5], axis=0) + + self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0.) 
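+    # Patch layout used by the methods below: with img_size=224 and
+    # patch_size=16 there are (224/16)**2 = 196 patches, each flattened to
+    # 16*16*3 = 768 values, so patchify maps (N, 3, 224, 224) to
+    # (N, 196, 768) and unpatchify maps it back to an image tensor.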
+ + def patchify(self, imgs): + """ + imgs: (N, 3, H, W) + x: (N, L, patch_size**2 *3) + """ + p = self.patch_size + assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = torch.einsum('nchpwq->nhwpqc', x) + x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3)) + return x + + def unpatchify(self, x): + """ + x: (N, L, patch_size**2 *3) + imgs: (N, 3, H, W) + """ + p = self.patch_size + h = w = int(x.shape[1] ** .5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) + return imgs + + def feature_unpatchify(self,x): + p = self.patch_size + h = w = x.shape[2] + + x = x.reshape(shape=(x.shape[0],p,p,3,h,w)) + x = torch.einsum('npqchw->nchpwq',x) + x = x.reshape(shape=(x.shape[0],3,h*p,w*p)) + return x + + def forward_loss(self, pred, gray_target, target, alpha=0.): + """ + loss of pred + gray_target and traget, which have the same shape + """ + loss_l2 = (pred + gray_target - target) ** 2 + loss_l2 = loss_l2.mean() # [N, L], mean loss per patch + loss_l1 = torch.abs((pred + gray_target - target)) + loss_l1 = loss_l1.mean() + + return loss_l2 * (1.0 - alpha) + loss_l1 * alpha + + def forward(self, x, clip_x, color_mask): + # embed tokens + x = self.decoder_embed(x) + + color_mask = self.patchify(color_mask) + x_color = self.color_embdding(color_mask) + x_color = torch.cat([x[:, 0, :].unsqueeze(1), x_color], dim=1) + x = x + x_color + + clip_x = clip_x.unsqueeze(1) + clip_x = self.clip_embed(clip_x) + + x = torch.cat([x, clip_x], dim=1) + # add pos embed + x = x + self.decoder_pos_embed + + # apply Transformer blocks + for blk in self.decoder_blocks: + x = blk(x) + + x = self.decoder_pred(x) + x = x[:, 1:-1, :] + + h = w = int(x.shape[1] ** .5) + dim = x.shape[2] + x = x.reshape(shape=(x.shape[0], h, w, dim)) + x = x.permute(0, 3, 1, 2) + x = self.decoder_conv1(x) + x = self.feature_unpatchify(x) + return x + +def mae_color_decoder_base(**kwargs): + model = ColorDecoder( + patch_size=16, embed_dim=768, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, + mlp_ratio=4, norm_layer=LayerNorm, **kwargs)#partial(nn.LayerNorm, eps=1e-6) + return model + + diff --git a/deployment/auto_color_gui.py b/deployment/auto_color_gui.py new file mode 100644 index 0000000..ceb8def --- /dev/null +++ b/deployment/auto_color_gui.py @@ -0,0 +1,747 @@ +import sys +sys.path.append("..") +import tkinter as tk +from PIL import Image, ImageTk +from tkinter import ttk +from tkinter import filedialog + +from tkinter.colorchooser import askcolor +import numpy as np +import os +import time +from color_pick import ColorPick + +# using pytorch or onnx deployment +#from autocolor_pytorch_deployment import AutoColorDeployment # pytorch deployment +from autocolor_onnx_deployment import AutoColorDeployment # onnx deployment + +IMAGE_SIZE = 448 +BASE_SIZE = 224 + +class CanvasState(): + NORM = 1 + PEN = 2 + ERASER = 3 + def __init__(self): + super().__init__() + self.state = CanvasState.NORM + self.pen_color = '#FF0000' + self.base_pen_size = IMAGE_SIZE // BASE_SIZE + self.pen_size = self.base_pen_size * 4 + self.pen_color_array = np.array([255, 0, 0]) + self.drawing_rect = False + self.rect_scale = 1 + self.rect_x = 0 + self.rect_y = 0 + self.rect_end_x = 0 + self.rect_end_y = 0 + self.canvas_scale = 1 + + def get_state(self): + return self.status + + def set_state(self, state): + assert 
state in [CanvasState.NORM, CanvasState.PEN, CanvasState.ERASER] + self.state = state + return self.state + + def is_norm(self): + return self.state == CanvasState.NORM + + def using_pen(self): + return self.state == CanvasState.PEN + + def using_eraser(self): + return self.state == CanvasState.ERASER + +if __name__ == '__main__': + #init + root = tk.Tk() + root.title('AutoColor') + root.geometry('1400x780+100+10') + root.iconbitmap('feather_icon.ico') + image_path = './example_white.jpg' + pil_image = Image.open(image_path) + + # input image and input info frame + input_image_info_frame = tk.Frame(root) + input_image_info_frame.pack(side='left', anchor=tk.N) + + # input_image_info_frame -> input_image_frame + input_image_frame = tk.Frame(input_image_info_frame) + input_image_frame.pack(side='left', anchor=tk.N, padx=10, pady=10) + image_main_label = tk.Label(input_image_frame, text='待上色图片') + image_main_label.pack(anchor=tk.W) + image_main_frame = tk.Frame(input_image_frame) + image_main_frame.pack() + canvas_state = CanvasState() + color_grids = np.zeros((BASE_SIZE, BASE_SIZE, 3), dtype=np.uint8) + auto_color_model = AutoColorDeployment() + np.random.seed(0) + pil_image = pil_image.resize((IMAGE_SIZE, IMAGE_SIZE)) + image_main = ImageTk.PhotoImage(pil_image) + input_image = [[image_main, pil_image, pil_image]] + output_images = [[pil_image, pil_image]] + input_image_config = [] + + # resize and pad an image + def resize_and_pad(image, size, pad_color='black',resample=Image.Resampling.BICUBIC):#Resampling. + w, h = image.size + if w > h: + new_w, new_h = size, round(size / w * h) + else: + new_w, new_h = round(size / h * w), size + image_resize = image.resize((new_w, new_h), resample=resample)#Resampling. + pad_image = Image.new(mode='RGB', size=(size, size), color=pad_color) + pad_t, pad_l = (size - new_h) // 2, (size - new_w) // 2 + pad_b, pad_r = size - new_h - pad_t, size - new_w - pad_l + pad_image.paste(image_resize, (pad_l, pad_t)) + return pad_image, (pad_l, pad_r, pad_t, pad_b) + + # adjust the input gray image by add value to pixels + def scale_gray_image(value): + tk_img, img, img_ori = input_image.pop() + image_ori_np = np.array(img_ori,dtype=np.int32) + image_ori_np = image_ori_np + int(value) + image_ori_np = np.clip(image_ori_np,0,255) + image_ori_np = image_ori_np.astype(np.uint8) + image_ori = Image.fromarray(image_ori_np) + img,_ = resize_and_pad(image_ori,IMAGE_SIZE * canvas_state.canvas_scale,pad_color='black') + tk_img = ImageTk.PhotoImage(img) + input_image.append([tk_img, img, img_ori]) + cv_image = image_canvas.find_withtag('image') + image_canvas.itemconfig(cv_image[0],image=input_image[0][0]) + + #select the input image and convert it to grayscale + def select_input_img(): + imagedir = filedialog.askopenfilenames() + if len(imagedir) == 0: + return + img_ori = Image.open(imagedir[0]) + img_ori = img_ori.convert('L').convert('RGB') + + while len(input_image_config) > 0: + input_image_config.pop() + input_image_config.append([imagedir[0]]) + img, _ = resize_and_pad(img_ori, size=IMAGE_SIZE) + tk_img = ImageTk.PhotoImage(img) + while len(input_image) > 0: + input_image.pop() + input_image.append([tk_img, img, img_ori]) + image_canvas.delete("all") + image_canvas.create_image(IMAGE_SIZE // 2, IMAGE_SIZE // 2, image=input_image[0][0], tags=['image']) + color_grids[:, :, :] = 0 + gray_scale.set(0) + + #scale bar for ajusting the input gray image + gray_scale = tk.Scale(image_main_frame, + from_=-255, + to=255, + resolution=3, + length=IMAGE_SIZE, + sliderlength=20, + 
#showvalue=False, + command=scale_gray_image) + gray_scale.pack(side='left',anchor=tk.W) + gray_scale.set(0) + image_canvas = tk.Canvas(image_main_frame, bg="white", width=IMAGE_SIZE, height=IMAGE_SIZE) + image_canvas.pack(side='left') + image_canvas.config(scrollregion=(0,0,IMAGE_SIZE,IMAGE_SIZE)) + scroll_x_frame = tk.Frame(input_image_frame) + scroll_x_frame.pack(side=tk.TOP, fill=tk.X) + + #scale the canvas size + def canvas_scaling(): + canvas_scale = canvas_state.canvas_scale + #print(hbar.get(),vbar.get()) + if canvas_state.canvas_scale == 1: + canvas_state.canvas_scale = 2 + canvas_scaling_button.config(text = '缩小图片') + else: + canvas_state.canvas_scale = 1 + canvas_scaling_button.config(text = '放大图片') + real_size = IMAGE_SIZE * canvas_state.canvas_scale + image_canvas.config( + scrollregion=(0, 0,real_size , real_size) + ) + tk,_,img_ori =input_image[0] + img, _ = resize_and_pad(img_ori, size=real_size) + tk_img = ImageTk.PhotoImage(img) + input_image[0] = [tk_img,img,img_ori] + image_canvas.create_image(real_size // 2, real_size // 2, image=input_image[0][0], tags=['image']) + + #adjust the color info accordingly + all_rect_items = image_canvas.find_withtag('rect') + for i in range(len(all_rect_items)): + x0, y0, x1, y1 = image_canvas.coords(all_rect_items[i]) + factor = canvas_state.canvas_scale if canvas_state.canvas_scale==2 else 0.5 + x0, y0 = int((x0 - 1) * factor) , int((y0 - 1) * factor) + x1, y1 = int((x1 + 1) * factor), int((y1 + 1) * factor) + out_line = image_canvas.itemcget(all_rect_items[i], 'outline') + image_canvas.create_rectangle(x0 + 1, y0 + 1, x1 - 1, y1 - 1, + outline= out_line, + fill=out_line, + tags=['rect']) + image_canvas.delete(all_rect_items[i]) + + # adjust the selected local input accordingly + local_input_items = image_canvas.find_withtag('local_input') + for i in range(len(local_input_items)): + x0, y0, x1, y1 = image_canvas.coords(local_input_items[i]) + factor = canvas_state.canvas_scale if canvas_state.canvas_scale == 2 else 0.5 + x0, y0 = int((x0 - 1) * factor), int((y0 - 1) * factor) + x1, y1 = int((x1 + 1) * factor), int((y1 + 1) * factor) + out_line = image_canvas.itemcget(local_input_items[i], 'outline') + image_canvas.create_rectangle(x0, y0, x1, y1, + outline=out_line, + #fill=out_line, + tags=['local_input']) + image_canvas.delete(local_input_items[i]) + + canvas_scaling_button = tk.Button(scroll_x_frame, text='放大图片', command=canvas_scaling) + canvas_scaling_button.pack(side=tk.LEFT) + + #config the scrollbars + hbar = tk.Scrollbar(scroll_x_frame, orient=tk.HORIZONTAL) + hbar.pack(side=tk.TOP, fill=tk.X, anchor=tk.E) + hbar.config(command=image_canvas.xview) + vbar = tk.Scrollbar(image_main_frame, orient=tk.VERTICAL) + vbar.pack(side=tk.RIGHT, fill=tk.Y) + vbar.config(command=image_canvas.yview) + + image_canvas.config(xscrollcommand=hbar.set, yscrollcommand=vbar.set) + input_image_button = tk.Button(input_image_frame, text='选择输入图片') + input_image_button.pack(pady=3) + input_image_button.config(command=select_input_img) + + #input color info frame + color_info_frame = tk.LabelFrame(input_image_frame, text='参考颜色', padx=5, pady=5) + color_info_frame.pack(fill=tk.X) + + #pen and eraser + selet_frame = tk.Frame(color_info_frame) + selet_frame.pack(pady=5, fill=tk.X) + pen_button = tk.Button(selet_frame, text='使用画笔') + pen_button.pack(side='left') + color_canvas = tk.Canvas(selet_frame, + bg='red', + height=20, + width=20) + color_canvas.pack(side='left', padx=10) + color_button = tk.Button(selet_frame, text='选择画笔颜色') + 
color_button.pack(side='left',padx=2) + + color_pick_button = tk.Button(selet_frame, text='屏幕取色') + color_pick_button.pack(side='left') + + pen_size_frame = tk.Frame(selet_frame) + pen_size_frame.pack(pady=5, fill=tk.X) + pen_size = ['4', '8', '16']#'1', '2', + combobox = ttk.Combobox(pen_size_frame, width=11) + combobox.pack(side='right',padx=5) + combobox['value'] = pen_size + combobox.current(0) + + pen_size_label = tk.Label(pen_size_frame, text='画笔尺寸:') + pen_size_label.pack(side='right') + + eraser_and_clear_frame = tk.Frame(color_info_frame) + eraser_and_clear_frame.pack(fill=tk.X,side='top',pady=2) + eraser_button = tk.Button(eraser_and_clear_frame, text='使用橡皮') + eraser_button.pack(side='left') + clear_button = tk.Button(eraser_and_clear_frame, text='清空颜色') + clear_button.pack(side='right') + + def set_color_canvas(i): + i = int(i) + color_canvas.config(bg = color_rgb_list[i][1]) + canvas_state.pen_color = color_rgb_list[i][1] + canvas_state.pen_color_array = color_rgb_list[i][0] + color_use_time[i] = time.time() + + def set_color_btn(color): + select_btn_index = 0 + min_time = color_use_time[0] + for i, use_time in enumerate(color_use_time): + if use_time < min_time: + select_btn_index = i + min_time = use_time + color_use_time[select_btn_index] = time.time() + + color_rgb_list[select_btn_index] = color + color_button_list[select_btn_index].config(bg=color[1]) + set_color_canvas(select_btn_index) + + def rgb2hexstr(color): + rgb = [int(v) for v in color] + rgb_hex = [hex(v)[2:] for v in rgb] + rgb_hex = [ ('0'+v if len(v)<2 else v ) for v in rgb_hex] + return '#' + ''.join(rgb_hex) + + def hexstr2rgb(color): + assert len(color) == 7 and color[0]=='#' + color = color.lower() + r,g,b = color[1:3],color[3:5],color[5:7] + return [int(r,16),int(g,16),int(b,16)] + + def pick_color_from_screen(): + w = ColorPick(root) + color_pick_button.wait_window(w.top) + root.state('normal') + color = w.get_color() + color_str = rgb2hexstr(color) + set_color_btn([color, color_str]) + + #color pick from screen + color_pick_button.config(command=pick_color_from_screen) + + #config the alternate color list + #using the time stamps to keep the recently used color + color_button_frame = tk.Frame(color_info_frame) + color_button_frame.pack(fill=tk.X,side='top') + color_num = 13 + color_button_list = [] + color_rgb_list = [] + color_use_time=[0]*color_num + for _ in range(color_num): + color_int = np.random.randint(255,size=3) + color_rgb_list.append([color_int,rgb2hexstr(color_int)]) + for i in list(range(color_num)): + fn = lambda a=i: set_color_canvas(a) + color_btn = tk.Button(color_info_frame, + height=1, + width=3, + bd=0, + bg=color_rgb_list[i][1], + command=fn, + ) + color_btn.pack(side='left',padx=3) + color_button_list.append(color_btn) + + #config the local input + rect_input_button = tk.Button(input_image_frame, text='选择上色区域') + rect_input_button.pack(pady=10) + + def select_rect_input(): + drawing_rect = canvas_state.drawing_rect + if drawing_rect: + canvas_state.drawing_rect = False + rect_input_button.config(text='选择上色区域') + rect_item = image_canvas.find_withtag('local_input') + image_canvas.delete(rect_item) + else: + canvas_state.drawing_rect = True + rect_input_button.config(text='清除选择框') + if canvas_state.using_pen(): + pen_button.config(text='使用画笔') + if canvas_state.using_eraser(): + eraser_button.config(text='使用橡皮') + canvas_state.set_state(CanvasState.NORM) + image_canvas.config(cursor='arrow') + + rect_input_button.config(command=select_rect_input) + + # input_image_info_frame -> 
middle_info_frame + middle_info_frame = tk.Frame(input_image_info_frame, padx=20) + middle_info_frame.pack(side='right', anchor=tk.N, pady=10) + image_info_frame = tk.LabelFrame(middle_info_frame, text='参考图片', padx=5, pady=5) + image_info_frame.pack() + image_choose_button = tk.Button(image_info_frame, text='选择参考图片') + image_choose_button.pack(anchor=tk.W,pady=10) + info_image = pil_image.resize((IMAGE_SIZE // 2, IMAGE_SIZE // 2)) + info_image = ImageTk.PhotoImage(info_image) + image_info_label = tk.Label(image_info_frame, image=info_image) + image_info_label.pack(side='left', fill=tk.X,pady=10) + info_images = [] + tk_info_images = [] + + #selet the input info images + def selet_info_image(): + MAX_INFO_NUM = 16 + pad_pixel = 4 + imagedir = filedialog.askopenfilenames() + if len(imagedir) == 0: + return + image_num = len(imagedir) + if image_num > MAX_INFO_NUM: + image_num = MAX_INFO_NUM + while len(info_images) > 0: + info_images.pop() + info_grid = int(np.sqrt(image_num)) + info_grid = info_grid if info_grid*info_grid == image_num else info_grid + 1 + if info_grid == 1: + pad_pixel = 0 + image_show_size = (BASE_SIZE - pad_pixel* (info_grid-1) ) // info_grid + combine_show_image = Image.new(mode='RGB',size=(BASE_SIZE,BASE_SIZE),color="#ffffff") + for i in range(info_grid): + for j in range(info_grid): + image_index = i*info_grid + j + if image_index <= image_num - 1: + pad_i = i * (pad_pixel + image_show_size ) + pad_j = j * (pad_pixel + image_show_size ) + img_ori = Image.open(imagedir[image_index]) + img_ori = img_ori.convert('RGB') + img_info, _ = resize_and_pad(img_ori, size=BASE_SIZE) + info_images.append(img_info) + img_show, _ = resize_and_pad(img_info,size=image_show_size) + combine_show_image.paste(img_show,(pad_j,pad_i)) + + tk_img = ImageTk.PhotoImage(combine_show_image) + while len(tk_info_images)>0: + tk_info_images.pop() + #keep the pointer of tk_img to avoid recycling + tk_info_images.append(tk_img) + image_info_label.config(image=tk_img) + + image_choose_button.config(command=selet_info_image) + + # input_image_info_frame -> middle_info_frame -> text info frame + text_info_frame = tk.LabelFrame(middle_info_frame, text='参考文本', padx=15,pady=10) + text_info_frame.pack(fill=tk.X,pady=10) + text = tk.Text(text_info_frame, width=10, height=10, undo=True, autoseparators=False) + text.pack(fill=tk.X) + + # clear the canvas + def clear_canvas(): + ori_img = input_image[0][1] + tk_image = ImageTk.PhotoImage(ori_img) + input_image[0][0] = tk_image + image_canvas.delete("all") + real_size = IMAGE_SIZE * canvas_state.canvas_scale + image_canvas.create_image(real_size // 2, real_size // 2, image=input_image[0][0],tags=['image']) + color_grids[:, :, :] = 0 + + # select or unselect the pen + def click_pen(): + if canvas_state.using_pen(): + canvas_state.set_state(CanvasState.NORM) + pen_button.config(text='使用画笔') + image_canvas.config(cursor='arrow') + else: + if canvas_state.using_eraser(): + eraser_button.config(text='使用橡皮') + canvas_state.set_state(CanvasState.PEN) + pen_button.config(text='取消画笔') + image_canvas.config(cursor='pencil') + + #select or unselect the eraser + def click_eraser(): + if canvas_state.using_eraser(): + canvas_state.set_state(CanvasState.NORM) + eraser_button.config(text='使用橡皮') + image_canvas.config(cursor='arrow') + + else: + if canvas_state.using_pen(): + pen_button.config(text='使用画笔') + + canvas_state.set_state(CanvasState.ERASER) + eraser_button.config(text='取消橡皮') + image_canvas.config(cursor='tcross') + + def selet_pen_size(event): + pen_size = 
combobox.get() + canvas_state.pen_size = canvas_state.base_pen_size * int(pen_size) + + #select pen color + def select_color(): + color = askcolor(title="颜色选择框", color="red") + if color[1] is None: + return + color_canvas.config(bg=color[1]) + canvas_state.pen_color = color[1] + canvas_state.pen_color_array = np.array(color[0]) + set_color_btn(color) + + color_button.config(command=select_color) + clear_button.config(command=clear_canvas) + combobox.bind('<>', selet_pen_size) + pen_button.config(command=click_pen) + eraser_button.config(command=click_eraser) + last_x = tk.IntVar(value=0) + last_y = tk.IntVar(value=0) + + # canvas event: LeftButtonDown + def onLeftButtonDown(event): + scroll_x = hbar.get() + scroll_y = vbar.get() + scroll_x = round(scroll_x[0] * IMAGE_SIZE * canvas_state.canvas_scale) + scroll_y = round(scroll_y[0] * IMAGE_SIZE * canvas_state.canvas_scale) + event.x = event.x + scroll_x + event.y = event.y + scroll_y + pen_size = canvas_state.pen_size + if canvas_state.is_norm(): + if canvas_state.is_norm(): + if canvas_state.drawing_rect: + canvas_state.rect_x = event.x + canvas_state.rect_y = event.y + canvas_state.rect_scale = canvas_state.canvas_scale + return + elif canvas_state.using_pen(): + x = event.x + y = event.y + x = x // pen_size * pen_size + y = y // pen_size * pen_size + + image_canvas.create_rectangle(x + 1, y + 1, x + pen_size - 1, y + pen_size - 1, + outline=canvas_state.pen_color, + fill=canvas_state.pen_color,tags=['rect']) + + x1 = x // canvas_state.base_pen_size // canvas_state.canvas_scale + y1 = y //canvas_state.base_pen_size // canvas_state.canvas_scale + step = pen_size // canvas_state.base_pen_size // canvas_state.canvas_scale + for i in range(3): + color_grids[x1:x1 + step, y1:y1 + step, i] = int(canvas_state.pen_color_array[i]) + + last_x.set(x) + last_y.set(y) + elif canvas_state.using_eraser(): + pen_size = min(pen_size,4) + x = event.x + y = event.y + x = x // pen_size * pen_size + y = y // pen_size * pen_size + + items = image_canvas.find_overlapping(x + 1, y + 1, x + pen_size - 1, y + pen_size - 1) + for i in range(len(items)): + tags = image_canvas.itemcget(items[i], 'tags') + + if 'rect' in tags: + x0,y0,x1,y1 = image_canvas.coords(items[i]) + x0, y0 = int(x0 - 1) // canvas_state.base_pen_size, int(y0 - 1)//canvas_state.base_pen_size + x1, y1 = int(x1 + 1) // canvas_state.base_pen_size, int(y1 + 1)//canvas_state.base_pen_size + color_grids[x0:x1, y0:y1, :] = 0 + image_canvas.delete(items[i]) + + #canvas event: LeftButtonMove + def onLeftButtonMove(event): + scroll_x = hbar.get() + scroll_y = vbar.get() + scroll_x = round(scroll_x[0] * IMAGE_SIZE * canvas_state.canvas_scale) + scroll_y = round(scroll_y[0] * IMAGE_SIZE * canvas_state.canvas_scale) + event.x = event.x + scroll_x + event.y = event.y + scroll_y + pen_size = canvas_state.pen_size + + if canvas_state.is_norm(): + if canvas_state.drawing_rect: + rect_item = image_canvas.find_withtag('local_input') + #print(rect_item) + image_canvas.delete(rect_item) + image_canvas.create_rectangle(canvas_state.rect_x,canvas_state.rect_y,event.x,event.y, + tag='local_input',outline='red') + canvas_state.rect_end_x = event.x + canvas_state.rect_end_y = event.y + return + elif canvas_state.using_pen(): + x = event.x + y = event.y + x = x // pen_size * pen_size + y = y // pen_size * pen_size + if x != last_x.get() or y != last_y.get(): + #print('moving create',x,y,last_x.get(),last_y.get(), abs(x - last_x.get()) + abs(y - last_x.get() )) + image_canvas.create_rectangle(x + 1, y + 1, x + pen_size - 
1, y + pen_size - 1, + outline=canvas_state.pen_color, + fill=canvas_state.pen_color, + tags=['rect']) + last_x.set(x) + last_y.set(y) + + x1 = x // canvas_state.base_pen_size // canvas_state.canvas_scale + y1 = y // canvas_state.base_pen_size // canvas_state.canvas_scale + step = pen_size // canvas_state.base_pen_size // canvas_state.canvas_scale + for i in range(3): + color_grids[x1:x1 + step, y1:y1 + step, i] = int(canvas_state.pen_color_array[i]) + + + elif canvas_state.using_eraser(): + pen_size = max(pen_size,4) + x = event.x + y = event.y + x = x // pen_size * pen_size + y = y // pen_size * pen_size + if x != last_x.get() or y != last_x.get(): + items = image_canvas.find_overlapping(x + 1, y + 1, x + pen_size - 1, y + pen_size - 1) + # print(items) + for i in range(len(items)): + tags = image_canvas.itemcget(items[i], 'tags') + if 'rect' in tags: + x0, y0, x1, y1 = image_canvas.coords(items[i]) + x0, y0 = int(x0 - 1) // canvas_state.base_pen_size, int(y0 - 1) // canvas_state.base_pen_size + x1, y1 = int(x1 + 1) // canvas_state.base_pen_size, int(y1 + 1) // canvas_state.base_pen_size + color_grids[x0:x1, y0:y1, :] = 0 + image_canvas.delete(items[i]) + + last_x.set(x) + last_y.set(y) + + image_canvas.bind('', onLeftButtonDown) + image_canvas.bind('', onLeftButtonMove) + + #utilize gpu or cpu + use_gpu = tk.IntVar() + def set_auto_color_device(): + auto_color_model.set_device(use_gpu.get()) + gpu_device = tk.Checkbutton(middle_info_frame, text="使用gpu", variable=use_gpu, command=set_auto_color_device) + gpu_device.pack(pady=15) + + # generate colorful image + generate_button = tk.Button(middle_info_frame, text='执行自动上色', pady=15) + generate_button.pack(fill=tk.X, padx=5, pady=0) + generate_button.config(fg='red') + + # get different resolution inputs for the model + def get_diff_res_inputs(image_size, target_size,image): + input_image_model = [] + + while image_size < target_size: + temp, pad = resize_and_pad(image, image_size, pad_color='black') + input_image_model.append(temp) + image_size = image_size * 2 + temp, pad = resize_and_pad(image, target_size, pad_color='black') + input_image_model.append(temp) + return input_image_model, pad + + # get scaled gray image + def get_scaled_gray_image(value,image): + if value == 0: + return image + image = np.array(image,dtype=np.int32) + image = image + int(value) + image = np.clip(image, 0, 255) + image = image.astype(np.uint8) + image = Image.fromarray(image) + return image + + #colorize the input image with the AI model + def auto_color(): + scale_value = gray_scale.get() + if canvas_state.drawing_rect: + #local rect input + input_image_ori = input_image[0][2] + input_image_ori = get_scaled_gray_image(scale_value, input_image_ori) + w, h = input_image_ori.size + new_size = max(w,h) + temp, pad_ori = resize_and_pad(input_image_ori, new_size, pad_color='black') + crop_x = round(canvas_state.rect_x // canvas_state.rect_scale * new_size / IMAGE_SIZE) + crop_y = round(canvas_state.rect_y // canvas_state.rect_scale * new_size / IMAGE_SIZE) + crop_end_x = round(canvas_state.rect_end_x // canvas_state.rect_scale * new_size / IMAGE_SIZE) + crop_end_y = round(canvas_state.rect_end_y // canvas_state.rect_scale* new_size / IMAGE_SIZE) + crop_image = temp.crop((crop_x,crop_y,crop_end_x,crop_end_y)) + target_size = max([crop_end_y - crop_y, crop_end_x - crop_x]) + target_size = max([target_size,BASE_SIZE]) + diff_res_inputs,pad = get_diff_res_inputs(BASE_SIZE,target_size,crop_image) + else: + # full image input + input_image_ori = input_image[0][2] + 
input_image_ori = get_scaled_gray_image(scale_value, input_image_ori) + w, h = input_image_ori.size + target_size = max([w, h]) + image_size = BASE_SIZE + + diff_res_inputs,pad_ori = get_diff_res_inputs(image_size,target_size,input_image_ori) + if len(input_image_config[0])<2: + input_image_config[0].append(pad_ori) + + image_info = info_images + + input_text = text.get('0.0', 'end') + input_text = input_text.strip() + input_text = None if len(input_text) == 0 else input_text + + #local rect input + if canvas_state.drawing_rect: + color_info = color_grids.repeat(canvas_state.base_pen_size, axis=0).repeat(canvas_state.base_pen_size, axis=1) + + x0 = canvas_state.rect_x // canvas_state.rect_scale + x1 = canvas_state.rect_end_x // canvas_state.rect_scale + y0 = canvas_state.rect_y // canvas_state.rect_scale + y1 = canvas_state.rect_end_y // canvas_state.rect_scale + color_info = color_info[x0:x1, y0:y1,:] + color_info = np.transpose(color_info, (1, 0, 2)) + color_info = Image.fromarray(color_info) + color_info, _ = resize_and_pad(color_info, BASE_SIZE, pad_color='black', resample=Image.Resampling.NEAREST) + output_image = auto_color_model.autocolor_forward(diff_res_inputs, image_info, input_text, color_info) + + if len(output_images)>0: + _,output_image_ori = output_images.pop() + else: + output_image_ori,pad_ori = resize_and_pad(input_image_ori, max(input_image_ori.size), pad_color='black') + pad_l, pad_r, pad_t, pad_b = pad + output_image = output_image.crop( (pad_l,pad_t,target_size - pad_r, target_size - pad_b)) + output_image = output_image.resize((crop_end_x-crop_x,crop_end_y - crop_y)) + output_image_ori.paste(output_image,(crop_x,crop_y)) + output_image_show = output_image_ori.resize((IMAGE_SIZE, IMAGE_SIZE), resample=Image.Resampling.BICUBIC) + tk_output_image = ImageTk.PhotoImage(output_image_show) + while len(output_images) > 0: + output_images.pop() + output_images.append([tk_output_image, output_image_ori]) + right_image_label.config(image=tk_output_image) + else: + # full image input + color_info = color_grids + color_info = np.transpose(color_info, (1, 0, 2)) + output_image = auto_color_model.autocolor_forward(diff_res_inputs, image_info, input_text, color_info) + output_image_resize = output_image.resize((IMAGE_SIZE, IMAGE_SIZE), resample=Image.Resampling.BICUBIC) + # output_image.show() + while len(output_images) > 0: + output_images.pop() + + tk_output_image = ImageTk.PhotoImage(output_image_resize) + output_images.append([tk_output_image, output_image]) + right_image_label.config(image=tk_output_image) + + generate_button.config(command=auto_color) + + # right_output_frame + right_output_frame = tk.Frame(root) + right_output_frame.pack(side='left', anchor=tk.N, pady=10) + right_output_label = tk.Label(right_output_frame, text='已上色图片') + right_output_label.pack(anchor=tk.W) + right_image_label = tk.Label(right_output_frame, image=image_main) + right_image_label.pack() + + fast_save_button = tk.Button(right_output_frame, text='快速保存') + fast_save_button.pack(side='right', pady=5) + + output_save_button = tk.Button(right_output_frame, text='保存输出图片') + output_save_button.pack(side='left',pady=5) + + + output_show_button = tk.Button(right_output_frame, text='展示图片') + output_show_button.pack(pady=5) + + def save_output_image(): + file_path = filedialog.asksaveasfilename(title=u'保存文件') + if len(file_path) == 0: + return + _, image = output_images[0] + path, pad = input_image_config[0] + pad_l, pad_r, pad_t, pad_b = pad + w, h = image.size + image = image.crop((pad_l, pad_t, w - 
pad_r, h - pad_b)) + image.save(file_path) + + + def fast_save(): + _, image = output_images[0] + path, pad = input_image_config[0] + pad_l, pad_r, pad_t, pad_b = pad + w, h = image.size + new_file_name = os.path.join(os.path.dirname(path), 'autocolor_' + os.path.basename(path)) + image = image.crop((pad_l, pad_t, w - pad_r, h - pad_b)) + image.save(new_file_name) + + def show_output_image(): + _, image = output_images[0] + image.show() + # color_info = np.transpose(color_grids, (1, 0, 2)) + # Image.fromarray(color_info).show() + + output_save_button.config(command=save_output_image) + output_show_button.config(command=show_output_image) + fast_save_button.config(command=fast_save) + + root.mainloop() + + diff --git a/deployment/auto_color_gui.spec b/deployment/auto_color_gui.spec new file mode 100644 index 0000000..039257a --- /dev/null +++ b/deployment/auto_color_gui.spec @@ -0,0 +1,45 @@ +# -*- mode: python ; coding: utf-8 -*- + + +block_cipher = None + + +a = Analysis( + ['auto_color_gui.py'], + pathex=[], + binaries=[], + datas=[('D:\\\\tools\\\\anaconda_file\\\\.conda\\\\envs\\\\cuda113\\\\Lib\\\\site-packages\\\\onnxruntime\\\\capi\\\\*.dll', 'onnxruntime\\\\capi')], + hiddenimports=[], + hookspath=[], + hooksconfig={}, + runtime_hooks=[], + excludes=[], + win_no_prefer_redirects=False, + win_private_assemblies=False, + cipher=block_cipher, + noarchive=False, +) +pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher) + +exe = EXE( + pyz, + a.scripts, + a.binaries, + a.zipfiles, + a.datas, + [], + name='auto_color_gui', + debug=False, + bootloader_ignore_signals=False, + strip=False, + upx=True, + upx_exclude=[], + runtime_tmpdir=None, + console=False, + disable_windowed_traceback=False, + argv_emulation=False, + target_arch=None, + codesign_identity=None, + entitlements_file=None, + icon='feather_icon.ico', +) diff --git a/deployment/autocolor_onnx_deployment.py b/deployment/autocolor_onnx_deployment.py new file mode 100644 index 0000000..efe2518 --- /dev/null +++ b/deployment/autocolor_onnx_deployment.py @@ -0,0 +1,181 @@ +from PIL import Image +from typing import Any, Union, List +import onnxruntime +import numpy as np +from simple_tokenizer import SimpleTokenizer as _Tokenizer +# + +class AutoColorDeployment: + def __init__(self): + self.device = 0 # 0 for cpu, 1 for gpu + providers = ['CPUExecutionProvider'] + + self.mae_encoder_file = './onnx_models/mae_encoder_vitb_onnx.onnx' + self.clip_text_file = './onnx_models/clip_textual.onnx' + self.clip_image_file = './onnx_models/clip_visual.onnx' + self.color_decoder_file = './onnx_models/color_decoder_onnx.onnx' + self.super_color_file = './onnx_models/super_color_onnx.onnx' + self.mae_encoder_session = onnxruntime.InferenceSession(self.mae_encoder_file, providers=providers) + self.clip_text_session = onnxruntime.InferenceSession(self.clip_text_file, providers=providers) + self.color_decoder_session = onnxruntime.InferenceSession(self.color_decoder_file, providers=providers) + self.clip_image_session = onnxruntime.InferenceSession(self.clip_image_file, providers=providers) + self.super_color_session = onnxruntime.InferenceSession(self.super_color_file, providers=providers) + self._tokenizer = _Tokenizer(bpe_path='./onnx_models/bpe_simple_vocab_16e6.txt.gz') + + # 0 for cpu, 1 for gpu + def set_device(self, device): + if device == self.device: + return + else: + self.device = device + providers = ['CPUExecutionProvider'] if device==0 else ['CUDAExecutionProvider'] + self.mae_encoder_session = 
onnxruntime.InferenceSession(self.mae_encoder_file, providers=providers) + self.clip_text_session = onnxruntime.InferenceSession(self.clip_text_file, providers=providers) + self.color_decoder_session = onnxruntime.InferenceSession(self.color_decoder_file, providers=providers) + self.clip_image_session = onnxruntime.InferenceSession(self.clip_image_file, providers=providers) + self.super_color_session = onnxruntime.InferenceSession(self.super_color_file, providers=providers) + + # nearest-neighbour interpolation, equivalent to F.interpolate(x, mode='nearest') in pytorch + def numpy_nearest_interp(self, sample, size): + + b, c, w, h = sample.shape + new_w, new_h = size + if new_w % w == 0 and new_h % h == 0: + return sample.repeat(new_w // w, axis=2).repeat(new_h // h, axis=3) + x = np.arange(new_w) * w / new_w + y = np.arange(new_h) * h / new_h + x = x[:, None].repeat(new_h, axis=1).reshape(-1).astype(np.int) + y = y[None, :].repeat(new_w, axis=0).reshape(-1).astype(np.int) + interp_sample = sample[:, :, x, y].reshape(b, c, new_w, new_h) + return interp_sample + + def get_image_from_numpy(self, sample,clip_norm=False): + image = np.transpose(sample, (1, 2, 0)) + if clip_norm: + mean = np.asarray([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)*255 + std = np.asarray([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)*255 + else: + mean = np.asarray([123.68, 116.28, 103.53], dtype=np.float32) + std = np.asarray([58.395, 57.120, 57.375], dtype=np.float32) + + image = image * std + mean + image = np.clip(image, 0, 255) + image = image.astype(np.uint8) + pil_image = Image.fromarray(image) + return pil_image + + def pre_processing(self, img, clip_norm=False): + img = np.array(img, dtype=np.float32) + if clip_norm: + mean = np.asarray([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)*255 + std = np.asarray([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)*255 + else: + mean = np.asarray([123.68, 116.28, 103.53], dtype=np.float32) + std = np.asarray([58.395, 57.120, 57.375], dtype=np.float32) + img = (img - mean) / std + img = np.transpose(img, (2, 0, 1)) + img = img[None,:,:,:] + return img + + def clip_tokenize(self,texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) : + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional array containing the resulting tokens, shape = [number of input strings, context_length]. + Unlike the original CLIP tokenizer, the tokens are returned as a NumPy integer array rather than a torch LongTensor. 
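+ + Example (illustrative; assumes deploy is an AutoColorDeployment instance): + >>> deploy.clip_tokenize(["a red apple"]).shape + (1, 77)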
+ """ + if isinstance(texts, str): + texts = [texts] + + sot_token = self._tokenizer.encoder["<|startoftext|>"] + eot_token = self._tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts] + # if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + # result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + # else: + result = np.zeros((len(all_tokens), context_length), dtype=np.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = tokens + + return result + + def autocolor_forward(self, input_images, image_info, input_text, color_mask): + """ + input_images : List[PIL.Image] + The input gray images with different resolutions. + e.g. for a 1080 x 1080 input image, the input image list is [224x224, 448x448, 896x896, 1080x1080] + + image_info : PIL.Image 224x224 + The info for image colorization, which will be fed to the clip image model + + input_text : str + The info for image colorization, which will be fed to the clip text model + + color_mask: numpy.array 224x224x3 + The color info for image colorization + + Returns + ------- + the colored image + """ + image_clip_feature_sum = 0 + if len(image_info) > 0: + for img in image_info: + clip_input = self.pre_processing(img,clip_norm=True) + clip_feature = self.clip_image_session.run(None, {"input":clip_input})[0] + image_clip_feature_sum += clip_feature + + for i in range(len(input_images)): + img = self.pre_processing(input_images[i]) + input_images[i] = img + + if input_text is not None: + text = self.clip_tokenize([input_text]) + text_features = self.clip_text_session.run(None,{"input":text})[0] + if len(image_info) > 0: + clip_feature = ( image_clip_feature_sum + text_features ) * 0.5 + else: + clip_feature = text_features + + mae_feature = self.mae_encoder_session.run(None,{"input":input_images[0]})[0] + color_mask = self.pre_processing(color_mask) + + #color decoder + pred = self.color_decoder_session.run(None,{"input_mae":mae_feature, "input_clip":clip_feature, "input_color":color_mask})[0] + + #super color + for i in range(1,len(input_images)): + img_pred = pred + input_images[i-1] + pred_upsampling = self.numpy_nearest_interp(img_pred,input_images[i].shape[2:]) + color_mask = self.numpy_nearest_interp(color_mask,input_images[i].shape[2:]) + sc_pred = self.super_color_session.run(None,{"input_interp":pred_upsampling, "input_gray":input_images[i], "input_color":color_mask})[0] + pred = sc_pred + img_pred = pred + input_images[-1] + output_img = self.get_image_from_numpy(img_pred[0]) + return output_img + +if __name__=='__main__': + auto_color_deploy = AutoColorDeployment() + + + diff --git a/deployment/autocolor_pytorch_deployment.py b/deployment/autocolor_pytorch_deployment.py new file mode 100644 index 0000000..3eab7ce --- /dev/null +++ b/deployment/autocolor_pytorch_deployment.py @@ -0,0 +1,136 @@ +from PIL import Image +import numpy as np +import torch +import torch.nn.functional as F + +import sys +sys.path.append("..") + +from mae_encoder import mae_vit_base_patch16_dec512d8b +from color_decoder import mae_color_decoder_base +from super_color import SuperColor +import clip + +class AutoColorDeployment: + def __init__(self): + self.device = "cuda" if torch.cuda.is_available() else "cpu" + 
mae_encoder_dir = './pytorch_models/mae_visualize_vit_base.pth' + mae_encoder_model = mae_vit_base_patch16_dec512d8b() + mae_cp = torch.load(mae_encoder_dir, map_location='cpu') + msg = mae_encoder_model.load_state_dict(mae_cp['model'],strict=False) + mae_encoder_model = mae_encoder_model.to(self.device) + mae_encoder_model.eval() + self.mae_encoder_model = mae_encoder_model + print(msg) + + clip_model,_ = clip.load("ViT-B/16", download_root='./pytorch_models',device=self.device) + clip_model.eval() + self.clip_model = clip_model + + color_decoder_dir = './pytorch_models/color_decoder.pth' + color_decoder_model = mae_color_decoder_base() + color_decoder_cp = torch.load(color_decoder_dir,map_location='cpu') + msg = color_decoder_model.load_state_dict(color_decoder_cp, strict=False) + color_decoder_model = color_decoder_model.to(self.device) + color_decoder_model.eval() + self.color_decoder_model = color_decoder_model + print(msg) + + super_color_dir = './pytorch_models/super_color.pth' + super_color_model = SuperColor(kernel_size=5, group=4) + super_color_checkpoint = torch.load(super_color_dir, map_location='cpu') + msg = super_color_model.load_state_dict(super_color_checkpoint, strict=False) + super_color_model = super_color_model.to(self.device) + super_color_model.eval() + self.super_color_model = super_color_model + print(msg) + + def set_device(self, device): + return + + def get_image_from_tensor(self, sample,clip_norm=False): + sample = sample.cpu().detach().numpy() + image = np.transpose(sample, (1, 2, 0)) + if clip_norm: + mean = np.asarray([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)*255 + std = np.asarray([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)*255 + else: + mean = np.asarray([123.68, 116.28, 103.53], dtype=np.float32) + std = np.asarray([58.395, 57.120, 57.375], dtype=np.float32) + + image = image * std + mean + image = np.clip(image, 0, 255) + image = image.astype(np.uint8) + pil_image = Image.fromarray(image) + return pil_image + + def pre_processing(self, img, clip_norm=False): + img = np.array(img, dtype=np.float32) + if clip_norm: + mean = np.asarray([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)*255 + std = np.asarray([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)*255 + else: + mean = np.asarray([123.68, 116.28, 103.53], dtype=np.float32) + std = np.asarray([58.395, 57.120, 57.375], dtype=np.float32) + img = (img - mean) / std + img = np.transpose(img, (2, 0, 1)) + img = torch.from_numpy(img[None,:,:,:]) + return img + + def autocolor_forward(self,input_images, image_info, input_text, color_mask): + """ + input_images : List[PIL.Image] + The input gray images with different resolutions. + e.g. 
for a 1080 x 1080 input image, the input image list is [224x224, 448x448, 896x896, 1080x1080] + + image_info : PIL.Image 224x224 + The image info for image colorization, which will be fed into the clip image model + + input_text : str + The text info for image colorization, which will be fed into the clip text model + + color_mask: numpy.array 224x224x3 + The color info for image colorization + + Returns + ------- + the colored image + """ + image_clip_feature_sum = 0 + if len(image_info) > 0: + for img in image_info: + clip_input = self.pre_processing(img,clip_norm=True) + clip_input = clip_input.to(self.device, non_blocking=True) + clip_feature = self.clip_model.encode_image(clip_input) + image_clip_feature_sum += clip_feature + + for i in range(len(input_images)): + img = self.pre_processing(input_images[i]) + input_images[i] = img.to(self.device, non_blocking=True) + + if input_text is not None: + text = clip.tokenize([input_text]).to(self.device) + text_features = self.clip_model.encode_text(text)#.repeat(BATCH_SIZE,1) + if len(image_info) > 0: + clip_feature = ( image_clip_feature_sum + text_features ) * 0.5 + else: + clip_feature = text_features + + mae_feature = self.mae_encoder_model(input_images[0]) + color_mask = self.pre_processing(color_mask) + color_mask = color_mask.to(self.device, non_blocking=True) + with torch.no_grad(): + with torch.cuda.amp.autocast(): + pred = self.color_decoder_model(mae_feature, clip_feature, color_mask=color_mask) + + for i in range(1,len(input_images)): + img_pred = pred + input_images[i-1] + pred_upsampling = F.interpolate(img_pred, size=(input_images[i].size()[2:])) + + color_mask = F.interpolate(color_mask, size=(input_images[i].size()[2:])) + sc_pred = self.super_color_model(pred_upsampling,input_images[i],color_mask) + + pred = sc_pred + img_pred = pred + input_images[-1] + output_img = self.get_image_from_tensor(img_pred[0]) + return output_img diff --git a/deployment/clip_onnx/__init__.py b/deployment/clip_onnx/__init__.py new file mode 100644 index 0000000..a4648a0 --- /dev/null +++ b/deployment/clip_onnx/__init__.py @@ -0,0 +1,7 @@ +''' +Reference: https://github.com/Lednik7/CLIP-ONNX +''' +from .clip_converter import clip_converter +from .clip_onnx import clip_onnx +from .utils import Textual, attention +from .benchmark import speed_test diff --git a/deployment/clip_onnx/benchmark.py b/deployment/clip_onnx/benchmark.py new file mode 100644 index 0000000..8fbf7aa --- /dev/null +++ b/deployment/clip_onnx/benchmark.py @@ -0,0 +1,16 @@ +import time +import torch + + +def speed_test(func, data_gen, n: int = 5, empty_cache: bool = True): + if empty_cache: + torch.cuda.empty_cache() + values = [] + for _ in range(n): + input_data = data_gen() + t = time.time() + func(input_data) + values.append(time.time() - t) + if empty_cache: + torch.cuda.empty_cache() + return sum(values) / n diff --git a/deployment/clip_onnx/clip_converter.py b/deployment/clip_onnx/clip_converter.py new file mode 100644 index 0000000..17d90a9 --- /dev/null +++ b/deployment/clip_onnx/clip_converter.py @@ -0,0 +1,94 @@ +import torch +import onnx +from torch import nn +from onnxruntime.quantization import quantize_dynamic, QuantType +from .utils import Textual, DEFAULT_EXPORT + + +class clip_converter(nn.Module): + def __init__(self, model, visual_path: str = "clip_visual.onnx", + textual_path: str = "clip_textual.onnx"): + super().__init__() + self.model = model + self.visual_path = visual_path + self.textual_path = textual_path + self.visual_flag = False + self.textual_flag 
= False + self.logit_scale = self.model.logit_scale.exp() + + self.model.eval() + for x in self.model.parameters(): + x.requires_grad = False + + def quantization(self, mode: str = "dynamic"): + assert mode in ["dynamic"] + if mode == "dynamic": + model_quant_visual = f"{self.visual_path}.quant" + quantize_dynamic(self.visual_path, + model_quant_visual, + weight_type=QuantType.QUInt8) + self.visual_path = model_quant_visual + + model_quant_textual = f"{self.textual_path}.quant" + quantize_dynamic(self.textual_path, + model_quant_textual, + weight_type=QuantType.QUInt8) + self.textual_path = model_quant_textual + + def torch_export(self, model, dummy_input, path: str, export_params=DEFAULT_EXPORT): + torch.onnx.export(model, dummy_input, path, **export_params) + + def onnx_checker(self, path: str): + model = onnx.load(path) + onnx.checker.check_model(model) + del model + + def convert_visual(self, dummy_input, wrapper=lambda x: x, + export_params=DEFAULT_EXPORT): + visual = wrapper(self.model.visual) + self.torch_export(visual, dummy_input, self.visual_path, + export_params=export_params) + self.onnx_checker(self.visual_path) + + def convert_textual(self, dummy_input, wrapper=Textual, + export_params=DEFAULT_EXPORT): + textual = wrapper(self.model) + self.torch_export(textual, dummy_input, self.textual_path, + export_params=export_params) + self.onnx_checker(self.textual_path) + + def convert2onnx(self, visual_input=None, textual_input=None, verbose=True, + visual_wrapper=lambda x: x, + textual_wrapper=Textual, + visual_export_params=DEFAULT_EXPORT, + textual_export_params=DEFAULT_EXPORT): + isinstance_visual_input = isinstance(visual_input, (torch.Tensor)) + isinstance_textual_input = isinstance(textual_input, (torch.Tensor)) + + if (not isinstance_visual_input) and (not isinstance_textual_input): + raise Exception("[CLIP ONNX] Please, choose a dummy input") + elif not isinstance_visual_input: + print("[CLIP ONNX] Convert only textual model") + elif not isinstance_textual_input: + print("[CLIP ONNX] Convert only visual model") + + if isinstance_visual_input: + self.visual_flag = True + if verbose: + print("[CLIP ONNX] Start convert visual model") + self.convert_visual(visual_input, visual_wrapper, visual_export_params) + if verbose: + print("[CLIP ONNX] Start check visual model") + self.onnx_checker(self.visual_path) + + if isinstance_textual_input: + self.textual_flag = True + if verbose: + print("[CLIP ONNX] Start convert textual model") + self.convert_textual(textual_input, textual_wrapper, textual_export_params) + if verbose: + print("[CLIP ONNX] Start check textual model") + self.onnx_checker(self.textual_path) + + if verbose: + print("[CLIP ONNX] Models converts successfully") diff --git a/deployment/clip_onnx/clip_onnx.py b/deployment/clip_onnx/clip_onnx.py new file mode 100644 index 0000000..7fad0f6 --- /dev/null +++ b/deployment/clip_onnx/clip_onnx.py @@ -0,0 +1,67 @@ +from .clip_converter import clip_converter +import torch +import onnxruntime + + +class clip_onnx(clip_converter): + def __init__(self, model=None, + visual_path: str = "clip_visual.onnx", + textual_path: str = "clip_textual.onnx"): + if not isinstance(model, (type(None))): + super().__init__(model, visual_path, textual_path) + else: + print("[CLIP ONNX] Load mode") + + def load_onnx(self, visual_path=None, textual_path=None, logit_scale=None): + if visual_path and textual_path: + if not logit_scale: + raise Exception("For this mode logit_scale must be specified. 
Example: model.logit_scale.exp()") + self.logit_scale = logit_scale + if visual_path: + self.visual_path = visual_path + self.visual_flag = True + if textual_path: + self.textual_path = textual_path + self.textual_flag = True + + def start_sessions(self, providers=['TensorrtExecutionProvider', + 'CUDAExecutionProvider', + 'CPUExecutionProvider']): + if self.visual_flag: + self.visual_session = onnxruntime.InferenceSession(self.visual_path, + providers=providers) + if self.textual_flag: + self.textual_session = onnxruntime.InferenceSession(self.textual_path, + providers=providers) + + def visual_run(self, onnx_image): + onnx_input_image = {self.visual_session.get_inputs()[0].name: onnx_image} + visual_output, = self.visual_session.run(None, onnx_input_image) + return visual_output + + def textual_run(self, onnx_text): + onnx_input_text = {self.textual_session.get_inputs()[0].name: onnx_text} + textual_output, = self.textual_session.run(None, onnx_input_text) + return textual_output + + def __call__(self, image, text, device: str = "cpu"): + assert self.visual_flag and self.textual_flag + image_features = torch.from_numpy(self.visual_run(image)).to(device) + text_features = torch.from_numpy(self.textual_run(text)).to(device) + + # normalized features + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_image = self.logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + def encode_image(self, image): + return self.visual_run(image) + + def encode_text(self, text): + return self.textual_run(text) diff --git a/deployment/clip_onnx/utils.py b/deployment/clip_onnx/utils.py new file mode 100644 index 0000000..19c3257 --- /dev/null +++ b/deployment/clip_onnx/utils.py @@ -0,0 +1,54 @@ +import torch.nn.functional as F +import torch +from torch import nn + + +class Textual(nn.Module): + def __init__(self, model): + super().__init__() + self.transformer = model.transformer + self.positional_embedding = model.positional_embedding + self.transformer = model.transformer + self.ln_final = model.ln_final + self.text_projection = model.text_projection + self.token_embedding = model.token_embedding + + def forward(self, text): + x = self.token_embedding(text) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # needs .float() before .argmax( ) to work + x = x[torch.arange(x.shape[0]), text.float().argmax(dim=-1)] @ self.text_projection + + return x + + +def attention(self, x: torch.Tensor): + # onnx doesn't like multi_head_attention_forward so this is a reimplementation + q, k, v = (torch.einsum("tbh, oh -> tbo", x, self.attn.in_proj_weight) + self.attn.in_proj_bias).contiguous().chunk( + 3, dim=-1) + tgt_len = q.shape[0] + bsz = q.shape[1] + num_heads = self.attn.num_heads + head_dim = q.shape[2] // num_heads + attn_output, attn_output_weights = F._scaled_dot_product_attention( + q.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1), + k.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1), + v.reshape(tgt_len, bsz * num_heads, 
head_dim).transpose(0, 1), None, 0.0 + ) + attn_output = attn_output.transpose(0, 1).contiguous().view(q.shape) + attn_output = F.linear(attn_output, self.attn.out_proj.weight, self.attn.out_proj.bias) + return attn_output + + +DEFAULT_EXPORT = dict(input_names=['input'], output_names=['output'], + export_params=True, verbose=False, opset_version=12, + do_constant_folding=True, + dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}) diff --git a/deployment/color_pick.py b/deployment/color_pick.py new file mode 100644 index 0000000..0d57fa6 --- /dev/null +++ b/deployment/color_pick.py @@ -0,0 +1,37 @@ +''' +Reference: https://blog.csdn.net/dongfuguo/article/details/118704759 +''' + +import tkinter +import tkinter.filedialog +import tkinter.messagebox +from PIL import ImageGrab, Image, ImageTk + +import ctypes +try: + ctypes.windll.shcore.SetProcessDpiAwareness(2) # if your windows version >= 8.1 +except: + ctypes.windll.user32.SetProcessDPIAware() # win 8.0 or less + +class ColorPick: + def __init__(self,root): + self.img = ImageGrab.grab() + screenWidth, screenHeight = self.img.size + + self.top = tkinter.Toplevel(root, width=screenWidth, height=screenHeight) + self.picked_color = (0,0,0) + self.top.overrideredirect(True) + self.tk_image = ImageTk.PhotoImage(self.img) + + self.canvas = tkinter.Canvas(self.top, bg='white', width=screenWidth, height=screenHeight,cursor='target') + + self.canvas.create_image(screenWidth // 2, screenHeight // 2, image=self.tk_image) + + def onLeftButtonDown(event): + color = self.img.getpixel((event.x, event.y)) + self.picked_color = color + self.top.destroy() + self.canvas.bind('', onLeftButtonDown) + self.canvas.pack(fill=tkinter.BOTH, expand=tkinter.YES) + def get_color(self): + return self.picked_color diff --git a/deployment/convert_to_onnx_models.py b/deployment/convert_to_onnx_models.py new file mode 100644 index 0000000..a227403 --- /dev/null +++ b/deployment/convert_to_onnx_models.py @@ -0,0 +1,205 @@ +import numpy as np +import torch +import sys +sys.path.append("..") +import onnx +import onnxruntime + +def get_mae_onnx_model(): + from mae_encoder import mae_vit_base_patch16_dec512d8b + device = "cuda" if torch.cuda.is_available() else "cpu" + mae_encoder_dir = './pytorch_models/mae_visualize_vit_base.pth' + mae_encoder_model = mae_vit_base_patch16_dec512d8b() + mae_cp = torch.load(mae_encoder_dir, map_location='cpu') + msg = mae_encoder_model.load_state_dict(mae_cp['model'],strict=False) + print(msg) + mae_encoder_model = mae_encoder_model.to(device) + mae_encoder_model.eval() + + export_onnx_file = './onnx_models/mae_encoder_vitb_onnx.onnx' + x = torch.onnx.export(mae_encoder_model, + torch.randn(1,3,224,224,device=device), + export_onnx_file, + verbose=False, + input_names=['input'], + output_names=['output'], + opset_version=12, + do_constant_folding=True, + #dynamic_axes={'input':{0:"batch_size",2:"h"},"output":{0:"batch_size"},} + ) + print(x) + net = onnx.load(export_onnx_file) + onnx.checker.check_model(net) + onnx.helper.printable_graph(net.graph) + + #test onnx model + session = onnxruntime.InferenceSession(export_onnx_file, providers=['CUDAExecutionProvider','CPUExecutionProvider']) + input = np.random.rand(1,3,224,224).astype('float32') + out_r = session.run(None, {"input":input}) + print(out_r[0].shape) + print(np.mean(out_r[0])) + torch.set_printoptions(precision=12) + input_tensor = torch.from_numpy(input).to(device) + out = mae_encoder_model(input_tensor) + print(out.size()) + print(torch.mean(out)) + +def 
get_clip_onnx_model(): + from clip_onnx import clip_onnx + import clip + device = "cpu" + model, preprocess = clip.load("ViT-B/16", download_root='./pytorch_models',device=device) + + + image = torch.randn(1,3,224,224,device=device) + text = clip.tokenize(["a diagram"]).cpu().to(device) # [3, 77] + + visual_path = "./onnx_models/clip_visual.onnx" + textual_path = "./onnx_models/clip_textual.onnx" + + onnx_model = clip_onnx(model, visual_path=visual_path, textual_path=textual_path) + onnx_model.convert2onnx(image, text, verbose=True) + + #image ecoder test + torch.set_printoptions(precision=10) + session = onnxruntime.InferenceSession(visual_path, providers=['CUDAExecutionProvider','CPUExecutionProvider']) + + input = np.random.rand(1,3,224,224).astype('float32') + out_image = session.run(None, {"input":input}) + print(out_image[0].shape, np.mean(out_image[0])) + + input_tensor = torch.from_numpy(input).to(device) + model = model.to(device) + out_feature = model.encode_image(input_tensor) + print(out_feature.size(), torch.mean(out_feature)) + + #text encoder test + session = onnxruntime.InferenceSession(textual_path, providers=['CUDAExecutionProvider','CPUExecutionProvider']) + input = text.numpy() + out_text = session.run(None, {"input": input}) + print(out_text[0].shape, np.mean(out_text[0])) + + input_tensor = text.to(device) + out_feature = model.encode_text(input_tensor) + print(out_feature.size(), torch.mean(out_feature)) + + +def get_color_decoder_onnx_model(): + from color_decoder import mae_color_decoder_base + export_onnx_file = './onnx_models/color_decoder_onnx.onnx' + device = "cuda" if torch.cuda.is_available() else "cpu" + color_decoder_dir = './pytorch_models/color_decoder.pth' + color_decoder_model = mae_color_decoder_base() + color_decoder_cp = torch.load(color_decoder_dir, map_location='cpu') + msg = color_decoder_model.load_state_dict(color_decoder_cp, strict=False) + print(msg) + color_decoder_model = color_decoder_model.to(device) + color_decoder_model.eval() + color_decoder_model = color_decoder_model + x = torch.onnx.export(color_decoder_model, + (torch.randn(1,197,768,device=device),torch.randn(1,512,device=device),torch.randn(1,3,224,224,device=device)), + export_onnx_file, + verbose=False, + input_names=['input_mae','input_clip','input_color'], + output_names=['output'], + opset_version=12, + do_constant_folding=True, + #dynamic_axes={'input':{0:"batch_size"},"output":{0:"batch_size"},} + ) + net = onnx.load(export_onnx_file) + onnx.checker.check_model(net) + onnx.helper.printable_graph(net.graph) + + #test onnx model + session = onnxruntime.InferenceSession(export_onnx_file, providers=['CUDAExecutionProvider','CPUExecutionProvider']) + input_mae = np.random.rand(1,197,768).astype('float32') + input_clip = np.random.rand(1, 512).astype('float32') + input_color = np.random.rand(1, 3, 224, 224).astype('float32') + out_r = session.run(None, {"input_mae":input_mae,"input_clip":input_clip,"input_color":input_color}) + print(out_r[0].shape) + print(np.mean(out_r[0])) + torch.set_printoptions(precision=10) + input_mae_tensor = torch.from_numpy(input_mae).cuda() + input_clip_tensor = torch.from_numpy(input_clip).cuda() + input_color_tensor = torch.from_numpy(input_color).cuda() + + out = color_decoder_model( input_mae_tensor,input_clip_tensor,input_color_tensor) + print(out.size()) + print(torch.mean(out)) + +def get_super_color_onnx_model(): + from super_color import SuperColor + export_onnx_file = './onnx_models/super_color_onnx.onnx' + device = "cuda" if 
torch.cuda.is_available() else "cpu" + + super_color_dir = './pytorch_models/super_color.pth' + super_color_model = SuperColor(kernel_size=5, group=4) + super_color_checkpoint = torch.load(super_color_dir, map_location='cpu') + msg = super_color_model.load_state_dict(super_color_checkpoint, strict=False) + super_color_model = super_color_model.to(device) + super_color_model.eval() + super_color_model = super_color_model + print(msg) + + x = torch.onnx.export(super_color_model, + (torch.randn(1, 3, 448, 448, device='cuda'), + torch.randn(1, 3, 448, 448, device='cuda'), + torch.randn(1, 3, 448, 448, device='cuda')), + export_onnx_file, + verbose=False, + input_names=['input_interp', 'input_gray','input_color'], + output_names=['output'], + opset_version=12, + do_constant_folding=True, + dynamic_axes={'input_interp':{2:"width",3:"height"},'input_gray':{2:"width",3:"height"},'input_color':{2:"width",3:"height"},"output":{2:"width",3:"height"},} + ) + print(x) + net = onnx.load(export_onnx_file) + onnx.checker.check_model(net) + onnx.helper.printable_graph(net.graph) + + #test inputs with different resolutions + session = onnxruntime.InferenceSession(export_onnx_file, providers=['CUDAExecutionProvider','CPUExecutionProvider']) + # test 448x448 inputs + input_interp = np.random.rand(1, 3,448, 448).astype('float32') + input_gray = np.random.rand(1,3, 448, 448).astype('float32') + input_color = np.random.rand(1,3, 448, 448).astype('float32') + + out_r = session.run(None, {"input_interp": input_interp, "input_gray": input_gray,"input_color":input_color}) + print(out_r[0].shape) + print(np.mean(out_r[0])) + torch.set_printoptions(precision=10) + input_interp = torch.from_numpy(input_interp).cuda() + input_gray = torch.from_numpy(input_gray).cuda() + input_color = torch.from_numpy(input_color).cuda() + + + out = super_color_model(input_interp, input_gray,input_color) + print(out.size()) + print(torch.mean(out)) + + #test 896x896 inputs + input_interp = np.random.rand(1, 3, 896, 896).astype('float32') + input_gray = np.random.rand(1, 3, 896, 896).astype('float32') + input_color = np.random.rand(1, 3, 896, 896).astype('float32') + + out_r = session.run(None, {"input_interp": input_interp, "input_gray": input_gray,"input_color":input_color}) + print(out_r[0].shape) + print(np.mean(out_r[0])) + torch.set_printoptions(precision=10) + input_interp = torch.from_numpy(input_interp).cuda() + input_gray = torch.from_numpy(input_gray).cuda() + input_color = torch.from_numpy(input_color).cuda() + + out = super_color_model(input_interp, input_gray,input_color) + print(out.size()) + print(torch.mean(out)) + + +if __name__=='__main__': + #get_mae_onnx_model() + get_color_decoder_onnx_model() + #get_clip_onnx_model() + #get_super_color_onnx_model() + + diff --git a/deployment/example_white.jpg b/deployment/example_white.jpg new file mode 100644 index 0000000..fb43580 Binary files /dev/null and b/deployment/example_white.jpg differ diff --git a/deployment/feather_icon.ico b/deployment/feather_icon.ico new file mode 100644 index 0000000..267436d Binary files /dev/null and b/deployment/feather_icon.ico differ diff --git a/deployment/simple_tokenizer.py b/deployment/simple_tokenizer.py new file mode 100644 index 0000000..0a66286 --- /dev/null +++ b/deployment/simple_tokenizer.py @@ -0,0 +1,132 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), 
"bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, 
tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text diff --git a/get_clip_features.py b/get_clip_features.py new file mode 100644 index 0000000..8a9d003 --- /dev/null +++ b/get_clip_features.py @@ -0,0 +1,53 @@ +import argparse +import clip + +#model, preprocess = clip.load("ViT-B/16", device=device) +from util.datasets_clip import build_dataset_clip +import torch +import numpy as np + +BATCH_SIZE = 20 +# +def get_args_parser(): + parser = argparse.ArgumentParser('extract clip features', add_help=False) + parser.add_argument('--input_size', default=224, type=int, + help='images input size') + parser.add_argument('--data_path', default='E://data//carton_subset//val', type=str, + help='dataset path') + return parser + + +def main(args): + dataset_test = build_dataset_clip(args=args) + sampler_test = torch.utils.data.SequentialSampler(dataset_test) + data_loader_test = torch.utils.data.DataLoader( + dataset_test, sampler=sampler_test, + batch_size=BATCH_SIZE, + num_workers=2, + pin_memory=True, + drop_last=False + ) + + device = "cuda" if torch.cuda.is_available() else "cpu" + model, preprocess = clip.load("ViT-B/16", download_root='./models', device=device) + all_features = None + all_path = [] + with torch.no_grad(): + for samples, targets, path in data_loader_test: + samples = samples.to(device, non_blocking=True) + image_features = model.encode_image(samples) + image_features = image_features.detach().cpu().numpy() + if all_features is None: + all_features = image_features + else: + all_features = np.concatenate([all_features, image_features], axis=0) + all_path += path + if all_features.shape[0] % 100 ==0: + print(all_features.shape[0]) + + np.save('clip_features_subset.npy', all_features) + +if __name__=='__main__': + args = get_args_parser() + args = args.parse_args() + main(args) \ No newline at end of file diff --git a/get_mae_features.py b/get_mae_features.py new file mode 100644 index 0000000..e2007cd --- /dev/null +++ b/get_mae_features.py @@ -0,0 +1,62 @@ +import argparse +import numpy as np +import torch +from util.datasets import build_dataset +from mae import mae_vit_base_patch16_dec512d8b +import os +BATCH_SIZE = 20 +# +def get_args_parser(): + parser = argparse.ArgumentParser('extract mae features', add_help=False) + parser.add_argument('--input_size', default=224, type=int, + help='images input size') + parser.add_argument('--data_path', default='', type=str, + help='path//to//dataset') + parser.add_argument('--output_path', default='', type=str, + help='path//to//output//mae_feature//') + return parser + +def main(args): + model_dir = './models/mae_visualize_vit_base.pth' + dataset_test = build_dataset(args=args) + sampler_test = torch.utils.data.SequentialSampler(dataset_test) + data_loader_test = torch.utils.data.DataLoader( + dataset_test, sampler=sampler_test, + batch_size=BATCH_SIZE, + num_workers=2, + pin_memory=True, + drop_last=False + ) + + checkpoint = torch.load(model_dir, map_location='cpu') + checkpoint_model = checkpoint['model'] + model = mae_vit_base_patch16_dec512d8b() + msg = model.load_state_dict(checkpoint_model, strict=False) + print(msg) + device = "cuda" if torch.cuda.is_available() else "cpu" + model = model.cuda() + model.eval() + + mae_feature_names = [] + for samples, targets, path in data_loader_test: + samples = samples.to(device, non_blocking=True) + x, mask, ids_restore = 
model.forward_encoder(samples,mask_ratio=0.0) + x = x.detach().cpu().numpy() + mae_features = x + for b in range(x.shape[0]): + img_name = path[b].split('\\')[-1].split('.')[-2] + feature_name = os.path.join(args.output_path, 'mae_feature_{}.npy'.format(img_name)) + mae_feature_names.append(feature_name) + np.save(feature_name,mae_features[b]) + + if len(mae_feature_names) % 100 ==0: + print(len(mae_feature_names)) + + with open('features/mae_feature_names_train.txt', 'w') as file: + for i, p in enumerate(mae_feature_names): + file.write(p + '\n') + +if __name__ == '__main__': + args = get_args_parser() + args = args.parse_args() + main(args) \ No newline at end of file diff --git a/mae.py b/mae.py new file mode 100644 index 0000000..74f79e0 --- /dev/null +++ b/mae.py @@ -0,0 +1,255 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- + +from functools import partial + +import torch +import torch.nn as nn + +from timm.models.vision_transformer import PatchEmbed, Block + +from util.pos_embed import get_2d_sincos_pos_embed + + +class MaskedAutoencoderViT(nn.Module): + """ Masked Autoencoder with VisionTransformer backbone + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, + embed_dim=1024, depth=24, num_heads=16, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, + mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False): + super().__init__() + + # -------------------------------------------------------------------------- + # MAE encoder specifics + self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), + requires_grad=False) # fixed sin-cos embedding + + self.blocks = nn.ModuleList([ + Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)#remove qk_scale=None, + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + # -------------------------------------------------------------------------- + + # -------------------------------------------------------------------------- + # MAE decoder specifics + self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) + + self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), + requires_grad=False) # fixed sin-cos embedding + + self.decoder_blocks = nn.ModuleList([ + Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)#remove qk_scale=None, + for i in range(decoder_depth)]) + + self.decoder_norm = norm_layer(decoder_embed_dim) + self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True) # decoder to patch + # -------------------------------------------------------------------------- + + self.norm_pix_loss = norm_pix_loss + + self.initialize_weights() + + def initialize_weights(self): + # initialization + # initialize (and freeze) pos_embed by sin-cos embedding + pos_embed = 
get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5), + cls_token=True) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], + int(self.patch_embed.num_patches ** .5), cls_token=True) + self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) + + # initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) + torch.nn.init.normal_(self.cls_token, std=.02) + torch.nn.init.normal_(self.mask_token, std=.02) + + # initialize nn.Linear and nn.LayerNorm + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def patchify(self, imgs): + """ + imgs: (N, 3, H, W) + x: (N, L, patch_size**2 *3) + """ + p = self.patch_embed.patch_size[0] + assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = torch.einsum('nchpwq->nhwpqc', x) + x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3)) + return x + + def unpatchify(self, x): + """ + x: (N, L, patch_size**2 *3) + imgs: (N, 3, H, W) + """ + p = self.patch_embed.patch_size[0] + h = w = int(x.shape[1] ** .5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) + return imgs + + def random_masking(self, x, mask_ratio): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. 
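+ e.g. a 224x224 input with 16x16 patches gives L = 196 tokens; with mask_ratio=0.75, len_keep = int(196 * 0.25) = 49 tokens are kept.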
+ x: [N, L, D], sequence + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore + + def forward_encoder(self, x, mask_ratio): + # embed patches + x = self.patch_embed(x) + + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + x, mask, ids_restore = self.random_masking(x, mask_ratio) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + + return x, mask, ids_restore + + def forward_decoder(self, x, ids_restore): + # embed tokens + x = self.decoder_embed(x) + + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) + x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token + x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle + x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token + + # add pos embed + x = x + self.decoder_pos_embed + + # apply Transformer blocks + for blk in self.decoder_blocks: + x = blk(x) + x = self.decoder_norm(x) + + # predictor projection + x = self.decoder_pred(x) + + # remove cls token + x = x[:, 1:, :] + + return x + + def forward_loss(self, imgs, pred, mask): + """ + imgs: [N, 3, H, W] + pred: [N, L, p*p*3] + mask: [N, L], 0 is keep, 1 is remove, + """ + target = self.patchify(imgs) + if self.norm_pix_loss: + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.e-6) ** .5 + + loss = (pred - target) ** 2 + loss = loss.mean(dim=-1) # [N, L], mean loss per patch + + loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches + return loss + + def forward(self, imgs, mask_ratio=0.75): + latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) + pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] + loss = self.forward_loss(imgs, pred, mask) + return loss, pred, mask + + +def mae_vit_base_patch16_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=16, embed_dim=768, depth=12, num_heads=12, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, + mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def mae_vit_large_patch16_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, + mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def mae_vit_huge_patch14_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, + decoder_embed_dim=512, decoder_depth=8, 
decoder_num_heads=16, + mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +# set recommended archs +mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks +mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks +mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks diff --git a/mae_encoder.py b/mae_encoder.py new file mode 100644 index 0000000..2df151b --- /dev/null +++ b/mae_encoder.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- + +from functools import partial + +import torch +import torch.nn as nn + +from timm.models.vision_transformer import PatchEmbed, Block + +from util.pos_embed import get_2d_sincos_pos_embed + +class MaskedAutoencoderViT(nn.Module): + """ Masked Autoencoder with VisionTransformer backbone + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, + embed_dim=1024, depth=24, num_heads=16, + mlp_ratio=4., norm_layer=nn.LayerNorm): + super().__init__() + + # -------------------------------------------------------------------------- + # MAE encoder specifics + self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), + requires_grad=False) # fixed sin-cos embedding + + self.blocks = nn.ModuleList([ + Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)#remove qk_scale=None, + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + # -------------------------------------------------------------------------- + self.initialize_weights() + + def initialize_weights(self): + # initialization + # initialize (and freeze) pos_embed by sin-cos embedding + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5), + cls_token=True) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + torch.nn.init.normal_(self.cls_token, std=.02) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def patchify(self, imgs): + """ + imgs: (N, 3, H, W) + x: (N, L, patch_size**2 *3) + """ + p = self.patch_embed.patch_size[0] + assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = torch.einsum('nchpwq->nhwpqc', x) + x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3)) + return x + + def unpatchify(self, x): + """ + x: (N, L, patch_size**2 *3) + imgs: (N, 3, H, W) + """ + p = self.patch_embed.patch_size[0] + h = w = 
int(x.shape[1] ** .5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) + return imgs + + def forward(self, x): + # embed patches + x = self.patch_embed(x) + + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + + return x + + def forward_loss(self, imgs, pred, mask): + """ + imgs: [N, 3, H, W] + pred: [N, L, p*p*3] + mask: [N, L], 0 is keep, 1 is remove, + """ + target = self.patchify(imgs) + loss = (pred - target) ** 2 + loss = loss.mean(dim=-1) # [N, L], mean loss per patch + + loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches + return loss + +def mae_vit_base_patch16_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=16, embed_dim=768, depth=12, num_heads=12, + mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def mae_vit_large_patch16_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, + mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def mae_vit_huge_patch14_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, + mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + +# set recommended archs +mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks +mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks +mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..ba28737 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +timm==0.6.7 +clip diff --git a/super_color.py b/super_color.py new file mode 100644 index 0000000..d34437f --- /dev/null +++ b/super_color.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn + +class ConvBlock(nn.Module): + def __init__(self,in_features, kernel_size=3, hidden_features=None, out_features=None, + norm_layer=nn.LayerNorm, act_layer=nn.ReLU, group=8): + super().__init__() + self.conv1 = nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0, bias=False) + self.act1 = act_layer() + self.norm1 = norm_layer(hidden_features) + + self.conv2 = nn.Conv2d(hidden_features, hidden_features, kernel_size, stride=1, padding=(kernel_size-1)//2, bias=False, groups=group) + self.act2 = act_layer() + self.norm2 = norm_layer(hidden_features) + + self.conv3 = nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0, bias=False) + self.act3 = act_layer() + self.norm3 = norm_layer(out_features) + + def forward(self, x): + short_cut = x + + x = self.conv1(x) + x = self.norm1(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.norm2(x) + x = self.act2(x) + + x = self.conv3(x) + x = self.norm3(x) + x = self.act3(x + short_cut) + return x + +class SuperColor(nn.Module): + def __init__(self, block_num=2, input_dim=16, kernel_size=3, group=4): + super().__init__() + image_dim = 9 + self.embedding = nn.Conv2d(image_dim, input_dim, 3, stride=1, groups=1, padding=1, bias=True) + self.blocks = nn.ModuleList([ + ConvBlock(input_dim, 
kernel_size=kernel_size, hidden_features=input_dim//2, group=group, out_features=16, norm_layer=nn.BatchNorm2d) + for i in range(block_num)]) + self.embedding_decoder = nn.Conv2d(16, 3, 3, stride=1, groups=1, padding=1, bias=True) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0.) + + def forward(self, img_l_interp,gray_h,color_mask): + x = torch.cat([img_l_interp, gray_h],dim=1) + x = torch.cat([x, color_mask], dim=1) + x = self.embedding(x) + for b in self.blocks: + x = b(x) + x = self.embedding_decoder(x) + return x + + def forward_loss(self, pred, gray_h, img_h, alpha=0.0): + loss_l2 = (pred + gray_h - img_h) ** 2 + loss_l2 = loss_l2.mean() + loss_l1 = torch.abs((pred + gray_h - img_h)) + loss_l1 = loss_l1.mean() + return loss_l2 * (1.0 - alpha) + loss_l1 * alpha + + + + diff --git a/util/crop.py b/util/crop.py new file mode 100644 index 0000000..fcb2612 --- /dev/null +++ b/util/crop.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch + +from torchvision import transforms +from torchvision.transforms import functional as F + + +class RandomResizedCrop(transforms.RandomResizedCrop): + """ + RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. + This may lead to results different with torchvision's version. 
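+    Note: get_params() below calls torchvision.transforms.functional._get_image_size,
+    a private helper from older torchvision releases; newer releases expose it as
+    F.get_image_size, so the call may need updating if torchvision is upgraded.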
+ Following BYOL's TF code: + https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 + """ + @staticmethod + def get_params(img, scale, ratio): + width, height = F._get_image_size(img) + area = height * width + + target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() + log_ratio = torch.log(torch.tensor(ratio)) + aspect_ratio = torch.exp( + torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) + ).item() + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + w = min(w, width) + h = min(h, height) + + i = torch.randint(0, height - h + 1, size=(1,)).item() + j = torch.randint(0, width - w + 1, size=(1,)).item() + + return i, j, h, w \ No newline at end of file diff --git a/util/dataset_autocolor.py b/util/dataset_autocolor.py new file mode 100644 index 0000000..cc76f0e --- /dev/null +++ b/util/dataset_autocolor.py @@ -0,0 +1,189 @@ +import numpy as np +from PIL import Image +from torchvision import transforms +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import InterpolationMode + +from torchvision.datasets import ImageFolder +from typing import Any, Callable, Optional, Tuple +import torch + + +class AutoColorImageFolder(ImageFolder): + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + transform_l_s1: Optional[Callable] = None, + transform_h_s1: Optional[Callable] = None, + transform_s2: Optional[Callable] = None, + transform_gray: Optional[Callable] = None, + mae_feature_path = None, + clip_feature_path = None + ): + super().__init__( + root, + transform = transform + ) + self.imgs = self.samples + self.last_data = None + + self.transform_l_s1 = transform_l_s1 + self.transform_h_s1 = transform_h_s1 + self.transform_s2 = transform_s2 + self.transform_gray = transform_gray + if clip_feature_path: + self.clip_features = torch.from_numpy(np.load(clip_feature_path)) + else: + self.clip_features = None + + if mae_feature_path: + self.mae_feature_files = [] + with open(mae_feature_path, 'r') as file: + for line in file: + self.mae_feature_files.append(line.strip()) + else: + self.mae_feature_files = None + + def __getitem__(self, index: int) -> Tuple[Any, Any, Any]: + """ + Args: + index (int): Index + + Returns: + tuple: (sample, target) where target is class_index of the target class. 
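+            Note: in this subclass the returned tuple is actually
+            (mae_feature, clip_feature, color_mask, img_l, img_l_gray,
+            img_h, img_h_gray, target); the (sample, target) wording above
+            is inherited from torchvision's ImageFolder docstring.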
+ """ + path, target = self.samples[index] + sample = self.loader(path) + + if self.transform is not None: + img_l_interp = self.transform(sample) + color_mask = self.transform_s2(img_l_interp) + img_l_s1 = self.transform_l_s1(sample) + img_h_s1 = self.transform_h_s1(sample) + img_l_gray = self.transform_gray(img_l_s1) + img_h_gray = self.transform_gray(img_h_s1) + img_h = self.transform_s2(img_h_s1) + img_l = self.transform_s2(img_l_s1) + img_l_gray = self.transform_s2(img_l_gray) + img_h_gray = self.transform_s2(img_h_gray) + if self.mae_feature_files is not None: + mae_feature = np.load(self.mae_feature_files[index]) + mae_feature = torch.from_numpy(mae_feature).float() + else: + mae_feature = [] + if self.clip_features is not None: + clip_feature = self.clip_features[index] + else: + clip_feature = [] + + if self.target_transform is not None: + target = self.target_transform(target) + + return mae_feature, clip_feature, color_mask, img_l, img_l_gray, img_h, img_h_gray, target + +def build_dataset(args): + transform_l_s1 = transform_img_l_s1(args) + transform_h_s1 = transform_img_h_s1(args) + transform_s2 = transform_img_s2() + gray_transform = transformer_gray() + color_mask_transformer = ColorMask(img_size=args.input_size, p=args.colormask_prob, grids=[2,4,8,16]) + + root = args.data_path + dataset = AutoColorImageFolder(root, transform=color_mask_transformer, + transform_l_s1=transform_l_s1, + transform_h_s1=transform_h_s1, + transform_s2=transform_s2, + transform_gray=gray_transform, + mae_feature_path = args.mae_feature_path, + clip_feature_path = args.clip_feature_path) + return dataset + +def transform_img_h_s1(args): + t = [] + t.append( + transforms.Resize(args.input_size_supercolor, interpolation=InterpolationMode.BICUBIC), # to maintain same ratio w.r.t. 224 images + ) + t.append(transforms.CenterCrop(args.input_size_supercolor)) + return transforms.Compose(t) + +def transform_img_l_s1(args): + t = [] + t.append( + transforms.Resize(args.input_size, interpolation=InterpolationMode.BICUBIC), # to maintain same ratio w.r.t. 224 images + ) + t.append(transforms.CenterCrop(args.input_size)) + return transforms.Compose(t) + +def transform_img_s2(): + mean = IMAGENET_DEFAULT_MEAN + std = IMAGENET_DEFAULT_STD + t = [] + t.append(transforms.ToTensor()) + t.append(transforms.Normalize(mean, std)) + return transforms.Compose(t) + +def transformer_gray(): + return transforms.Grayscale(num_output_channels=3) + +class ColorMask(torch.nn.Module): + def __init__(self, img_size=224, p=0.2, grids=[4,8,16]): + super().__init__() + self.p = p + self.grids = grids + self.img_size = img_size + + def forward(self,img): + return self.random_color_mask(img,self.img_size) + + def random_color_mask(self, img, size=224): + p = self.p + if p < 1e-6: + return np.zeros((size,size,3),dtype=np.float32) + grids = self.grids + all_size = [size // g for g in grids] + output_img = 0. 
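+        # Coarse-to-fine random color hints: for each grid size g in `grids`, the
+        # image is downsampled to (size // g) cells, each cell is revealed with a
+        # probability drawn uniformly from [0, p], and the revealed cells are
+        # upsampled back by repeating g times along H and W. `prev_mask` ensures
+        # finer grids only reveal color where no coarser grid already has; the
+        # composite is finally rescaled from [0, 255] to [0, 1].
+        # Note: Image.Resampling.BICUBIC assumes Pillow >= 9.1 (older Pillow uses Image.BICUBIC).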
+ prev_mask = 1.0 + for i in range(len(all_size)): + tmp_size = all_size[i] + p_tmp = np.random.rand() * p + + p_size = (tmp_size, tmp_size) + p_img = img.resize(p_size, resample=Image.Resampling.BICUBIC) + + p_mask = np.random.rand(tmp_size, tmp_size) + p_mask = p_mask[:, :, None] + p_mask = (p_mask < p_tmp).astype(np.float32) + + p_mask = p_mask.repeat(grids[i], axis=0).repeat(grids[i], axis=1) + p_img = np.array(p_img).astype(np.float32) + p_img = p_img.repeat(grids[i], axis=0).repeat(grids[i], axis=1) + + p_img = prev_mask * p_mask * p_img + prev_mask = (1.0 - p_mask) * prev_mask + output_img += p_img + + #code to show color mask example + # image = output_img.astype(np.uint8) + # pil_image = Image.fromarray(image) + # pil_image.show() + return output_img/255. + +if __name__ == '__main__': + class Test: + def __init__(self): + self.input_size=224 + self.input_size_supercolor = 448 + self.colormask_prob = 0.1 + self.mae_feature_path = None + self.clip_feature_path = None + args = Test() + args.input_size = 224 + args.data_path = 'E://data//carton_subset//train' + + data_set = build_dataset(args) + for i in range(10): + mae_feature, clip_feature, color_mask, img_l, img_l_gray, img_h, img_h_gray, target = data_set.__getitem__(i) + #print(sample.shape,target,path,clip_feature.shape,mae_feature.shape) + #print(get_image_backend()) + diff --git a/util/datasets.py b/util/datasets.py new file mode 100644 index 0000000..35af927 --- /dev/null +++ b/util/datasets.py @@ -0,0 +1,57 @@ +from torchvision import datasets, transforms + +from timm.data import create_transform +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import InterpolationMode + +from torchvision.datasets import ImageFolder +from typing import Any, Callable, cast, Dict, List, Optional, Tuple +class MyImageFolder(ImageFolder): + + def __getitem__(self, index: int) -> Tuple[Any, Any, Any]: + """ + Args: + index (int): Index + + Returns: + tuple: (sample, target) where target is class_index of the target class. + """ + path, target = self.samples[index] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + + return sample, target, path + + +def build_dataset(args): + transform = build_transform(args) + + root = args.data_path + dataset = MyImageFolder(root, transform=transform) + + print(dataset) + + return dataset + + +def build_transform(args): + mean = IMAGENET_DEFAULT_MEAN + std = IMAGENET_DEFAULT_STD + + t = [] + size = args.input_size + t.append( + transforms.Resize(size, interpolation=InterpolationMode.BICUBIC), # to maintain same ratio w.r.t. 224 images + ) + + #new add + t.append(transforms.Grayscale(num_output_channels=3)) + + t.append(transforms.CenterCrop(args.input_size)) + + t.append(transforms.ToTensor()) + t.append(transforms.Normalize(mean, std)) + return transforms.Compose(t) diff --git a/util/datasets_clip.py b/util/datasets_clip.py new file mode 100644 index 0000000..3fed971 --- /dev/null +++ b/util/datasets_clip.py @@ -0,0 +1,40 @@ +from typing import Any, Tuple +from torchvision import transforms +from torchvision.datasets import ImageFolder +from torchvision.transforms import InterpolationMode + +class MyImageFolder(ImageFolder): + + def __getitem__(self, index: int) -> Tuple[Any, Any, Any]: + """ + Args: + index (int): Index + + Returns: + tuple: (sample, target) where target is class_index of the target class. 
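+            Note: unlike torchvision's ImageFolder, this returns a 3-tuple
+            (sample, target, path); the extra path lets callers keep track of
+            which file each sample came from.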
+ """ + path, target = self.samples[index] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + + return sample, target, path + +def build_dataset_clip(args): + transform = build_transform( args) + root = args.data_path + dataset = MyImageFolder(root, transform=transform) + + return dataset + +def build_transform(args): + t = [] + t.append( + transforms.Resize(args.input_size, interpolation=InterpolationMode.BICUBIC), # to maintain same ratio w.r.t. 224 images + ) + t.append(transforms.CenterCrop(args.input_size)) + t.append(transforms.ToTensor()) + t.append(transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))) + return transforms.Compose(t) diff --git a/util/lars.py b/util/lars.py new file mode 100644 index 0000000..509c5f6 --- /dev/null +++ b/util/lars.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# LARS optimizer, implementation from MoCo v3: +# https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- + +import torch + + +class LARS(torch.optim.Optimizer): + """ + LARS optimizer, no rate scaling or weight decay for parameters <= 1D. + """ + def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self): + for g in self.param_groups: + for p in g['params']: + dp = p.grad + + if dp is None: + continue + + if p.ndim > 1: # if not normalization gamma/beta or bias + dp = dp.add(p, alpha=g['weight_decay']) + param_norm = torch.norm(p) + update_norm = torch.norm(dp) + one = torch.ones_like(param_norm) + q = torch.where(param_norm > 0., + torch.where(update_norm > 0, + (g['trust_coefficient'] * param_norm / update_norm), one), + one) + dp = dp.mul(q) + + param_state = self.state[p] + if 'mu' not in param_state: + param_state['mu'] = torch.zeros_like(p) + mu = param_state['mu'] + mu.mul_(g['momentum']).add_(dp) + p.add_(mu, alpha=-g['lr']) \ No newline at end of file diff --git a/util/lr_decay.py b/util/lr_decay.py new file mode 100644 index 0000000..7fa11f1 --- /dev/null +++ b/util/lr_decay.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
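+
+# Usage sketch (illustrative; names such as the no_weight_decay list are assumptions,
+# adapt them to your training script):
+#   groups = param_groups_lrd(model, weight_decay=0.05,
+#                             no_weight_decay_list=['cls_token', 'pos_embed'],
+#                             layer_decay=0.75)
+#   optimizer = torch.optim.AdamW(groups, lr=1e-3)
+# adjust_learning_rate() in util/lr_sched.py then scales each group's lr by its "lr_scale".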
+# -------------------------------------------------------- +# References: +# ELECTRA https://github.com/google-research/electra +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import json + + +def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): + """ + Parameter groups for layer-wise lr decay + Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 + """ + param_group_names = {} + param_groups = {} + + num_layers = len(model.blocks) + 1 + + layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + + # no decay: all 1D parameters and model specific ones + if p.ndim == 1 or n in no_weight_decay_list: + g_decay = "no_decay" + this_decay = 0. + else: + g_decay = "decay" + this_decay = weight_decay + + layer_id = get_layer_id_for_vit(n, num_layers) + group_name = "layer_%d_%s" % (layer_id, g_decay) + + if group_name not in param_group_names: + this_scale = layer_scales[layer_id] + + param_group_names[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "params": [], + } + param_groups[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "params": [], + } + + param_group_names[group_name]["params"].append(n) + param_groups[group_name]["params"].append(p) + + # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) + + return list(param_groups.values()) + + +def get_layer_id_for_vit(name, num_layers): + """ + Assign a parameter with its layer id + Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 + """ + if name in ['cls_token', 'pos_embed']: + return 0 + elif name.startswith('patch_embed'): + return 0 + elif name.startswith('blocks'): + return int(name.split('.')[1]) + 1 + else: + return num_layers \ No newline at end of file diff --git a/util/lr_sched.py b/util/lr_sched.py new file mode 100644 index 0000000..ddb76a1 --- /dev/null +++ b/util/lr_sched.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +def adjust_learning_rate(optimizer, epoch, lr, args): + """Decay the learning rate with half-cycle cosine after warmup""" + if epoch < args.warmup_epochs: + lr = lr * epoch / args.warmup_epochs + else: + lr = args.min_lr + (lr - args.min_lr) * 0.5 * \ + (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) + for param_group in optimizer.param_groups: + if "lr_scale" in param_group: + param_group["lr"] = lr * param_group["lr_scale"] + else: + param_group["lr"] = lr + return lr diff --git a/util/misc.py b/util/misc.py new file mode 100644 index 0000000..ad9a786 --- /dev/null +++ b/util/misc.py @@ -0,0 +1,340 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
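+
+# Note: `from torch._six import inf` below only works on PyTorch < 2.0 (torch._six
+# was removed in 2.0); `math.inf` is a drop-in replacement for the comparison in
+# get_grad_norm_ on newer versions.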
+# -------------------------------------------------------- +# References: +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import builtins +import datetime +import os +import time +from collections import defaultdict, deque +from pathlib import Path + +import torch +import torch.distributed as dist +from torch._six import inf + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + log_msg = [ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ] + if torch.cuda.is_available(): + log_msg.append('max mem: {memory:.0f}') + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = 
str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + force = force or (get_world_size() > 8) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print('[{}] '.format(now), end='') # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if args.dist_on_itp: + args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) + os.environ['LOCAL_RANK'] = str(args.gpu) + os.environ['RANK'] = str(args.rank) + os.environ['WORLD_SIZE'] = str(args.world_size) + # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] + elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + setup_for_distributed(is_master=True) # hack + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}, gpu {}'.format( + args.rank, args.dist_url, args.gpu), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +class NativeScalerWithGradNormCount: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place + norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + else: + 
self._scaler.unscale_(optimizer) + norm = get_grad_norm_(parameters) + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.) + device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) + return total_norm + + +def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): + output_dir = Path(args.output_dir) + epoch_name = str(epoch) + if loss_scaler is not None: + checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] + for checkpoint_path in checkpoint_paths: + to_save = { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'epoch': epoch, + 'scaler': loss_scaler.state_dict(), + 'args': args, + } + + save_on_master(to_save, checkpoint_path) + else: + client_state = {'epoch': epoch} + model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) + + +def load_model(args, model_without_ddp, optimizer, loss_scaler): + if args.resume: + if args.resume.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.resume, map_location='cpu') + model_without_ddp.load_state_dict(checkpoint['model']) + print("Resume checkpoint %s" % args.resume) + if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): + optimizer.load_state_dict(checkpoint['optimizer']) + args.start_epoch = checkpoint['epoch'] + 1 + if 'scaler' in checkpoint: + loss_scaler.load_state_dict(checkpoint['scaler']) + print("With optim & sched!") + + +def all_reduce_mean(x): + world_size = get_world_size() + if world_size > 1: + x_reduce = torch.tensor(x).cuda() + dist.all_reduce(x_reduce) + x_reduce /= world_size + return x_reduce.item() + else: + return x \ No newline at end of file diff --git a/util/pos_embed.py b/util/pos_embed.py new file mode 100644 index 0000000..6acf8bd --- /dev/null +++ b/util/pos_embed.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# --------------------------------------------------------
+# Position embedding utils
+# --------------------------------------------------------
+
+import numpy as np
+
+import torch
+
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+    """
+    grid_size: int of the grid height and width
+    return:
+    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    grid_h = np.arange(grid_size, dtype=np.float32)
+    grid_w = np.arange(grid_size, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size, grid_size])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=float)  # `np.float` was removed in NumPy 1.24; the builtin float is equivalent
+    omega /= embed_dim / 2.
+    omega = 1. / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
+
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model, checkpoint_model):
+    if 'pos_embed' in checkpoint_model:
+        pos_embed_checkpoint = checkpoint_model['pos_embed']
+        embedding_size = pos_embed_checkpoint.shape[-1]
+        num_patches = model.patch_embed.num_patches
+        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+        # height (== width) for the checkpoint position embedding
+        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+        # height (== width) for the new position embedding
+        new_size = int(num_patches ** 0.5)
+        # class_token and dist_token are kept unchanged
+        if orig_size != new_size:
+            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+            # only the position tokens are interpolated
+            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+            pos_tokens = torch.nn.functional.interpolate(
+                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+            checkpoint_model['pos_embed'] = new_pos_embed
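+
+
+# Usage sketch (illustrative; the checkpoint filename is hypothetical): when
+# fine-tuning at a different input resolution than pre-training, interpolate the
+# position embedding before loading the weights saved by util.misc.save_model:
+#   checkpoint = torch.load('checkpoint-799.pth', map_location='cpu')
+#   interpolate_pos_embed(model, checkpoint['model'])
+#   model.load_state_dict(checkpoint['model'], strict=False)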