Why is the result lower than the one in the paper? #12

Closed
MOAboAli opened this issue Mar 22, 2021 · 8 comments
Comments

@MOAboAli

Why is the result lower than the one reported in the paper?

@huynhtruc0309

Hello @mohamedaboalimaa, have you resolved this problem?

@MOAboAli
Author

No, I couldn't get the same results as the paper from the trained model. Do you have an idea how?

@lugiavn
Collaborator

lugiavn commented Mar 31, 2021

Thanks for raising the concern. Is this a model you trained yourself, or the model we provided?
There might be an issue with the model we provide for download, so we are investigating.
If you train it yourself, I think the performance should be comparable to what is reported in the paper.

@MOAboAli
Author

MOAboAli commented Apr 1, 2021

The model I used is the one you provide, "checkpoint_fashion200k.pth". The result from this model (on the test dataset) is:

['1 ---> 2.7',
 '5 ---> 7.8',
 '10 ---> 11.1',
 '50 ---> 26.4',
 '100 ---> 36.4']

@huynhtruc0309

huynhtruc0309 commented Apr 3, 2021

@mohamedaboalimaa The provided checkpoint on Fashion200k is not well trained. I can get the same recall as the paper by training it myself.
This is the result I get

     train recall_top1_correct_composition 0.3332
     train recall_top5_correct_composition 0.6061
     train recall_top10_correct_composition 0.7332
     train recall_top50_correct_composition 0.9294
     train recall_top100_correct_composition 0.9686
     test recall_top1_correct_composition 0.1378
     test recall_top5_correct_composition 0.3446
     test recall_top10_correct_composition 0.4262
     test recall_top50_correct_composition 0.6308
     test recall_top100_correct_composition 0.7184

There are a few things you should add to your code to get the best result: set deterministic seeds, save the checkpoint at the iteration with the lowest loss, and remember to set shuffle=True for the train loader.
This is my modified main.py

# Copyright 2019 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Main method to train the model."""


#!/usr/bin/python

import argparse
import sys
import time
import datasets
import img_text_composition_models
import numpy as np
import random
from torch.utils.tensorboard import SummaryWriter
import test_retrieval
import torch
import torch.utils.data
import torchvision
from tqdm import tqdm as tqdm

torch.set_num_threads(3)

def set_deterministic():
    torch.manual_seed(0)
    random.seed(0)
    np.random.seed(0)    
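    # cuDNN's autotuned kernels can still be nondeterministic even with the
    # seeds above; for stricter reproducibility one could optionally add
    # (at some speed cost):
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False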


def parse_opt():
    """Parses the input arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', type=str, default='')
    parser.add_argument('--comment', type=str, default='test_notebook')
    parser.add_argument('--dataset', type=str, default='fashion200k')
    parser.add_argument(
        '--dataset_path', type=str, default='./Fashion200k')
    parser.add_argument('--model', type=str, default='tirg')
    parser.add_argument('--embed_dim', type=int, default=512)
    parser.add_argument('--learning_rate', type=float, default=1e-2)
    parser.add_argument(
        '--learning_rate_decay_frequency', type=int, default=9999999)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--weight_decay', type=float, default=1e-6)
    parser.add_argument('--num_iters', type=int, default=210000)
    parser.add_argument('--loss', type=str, default='soft_triplet')
    parser.add_argument('--loader_num_workers', type=int, default=4)
    args = parser.parse_args()
    return args


def load_dataset(opt):
    """Loads the input datasets."""
    print('Reading dataset ', opt.dataset)
    if opt.dataset == 'css3d':
        trainset = datasets.CSSDataset(
            path=opt.dataset_path,
            split='train',
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                                 [0.229, 0.224, 0.225])
            ]))
        testset = datasets.CSSDataset(
            path=opt.dataset_path,
            split='test',
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                                 [0.229, 0.224, 0.225])
            ]))
    elif opt.dataset == 'fashion200k':
        trainset = datasets.Fashion200k(
            path=opt.dataset_path,
            split='train',
            transform=torchvision.transforms.Compose([
                torchvision.transforms.Resize(224),
                torchvision.transforms.CenterCrop(224),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                                 [0.229, 0.224, 0.225])
            ]))
        testset = datasets.Fashion200k(
            path=opt.dataset_path,
            split='test',
            transform=torchvision.transforms.Compose([
                torchvision.transforms.Resize(224),
                torchvision.transforms.CenterCrop(224),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                                 [0.229, 0.224, 0.225])
            ]))
    elif opt.dataset == 'mitstates':
        trainset = datasets.MITStates(
            path=opt.dataset_path,
            split='train',
            transform=torchvision.transforms.Compose([
                torchvision.transforms.Resize(224),
                torchvision.transforms.CenterCrop(224),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                                 [0.229, 0.224, 0.225])
            ]))
        testset = datasets.MITStates(
            path=opt.dataset_path,
            split='test',
            transform=torchvision.transforms.Compose([
                torchvision.transforms.Resize(224),
                torchvision.transforms.CenterCrop(224),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                                 [0.229, 0.224, 0.225])
            ]))
    else:
        print('Invalid dataset', opt.dataset)
        sys.exit()

    print('trainset size:', len(trainset))
    print('testset size:', len(testset))
    return trainset, testset


def create_model_and_optimizer(opt, texts):
    """Builds the model and related optimizer."""
    print('Creating model and optimizer for', opt.model)
    set_deterministic()
    if opt.model == 'imgonly':
        model = img_text_composition_models.SimpleModelImageOnly(
            texts, embed_dim=opt.embed_dim)
    elif opt.model == 'textonly':
        model = img_text_composition_models.SimpleModelTextOnly(
            texts, embed_dim=opt.embed_dim)
    elif opt.model == 'concat':
        model = img_text_composition_models.Concat(
            texts, embed_dim=opt.embed_dim)
    elif opt.model == 'tirg':
        model = img_text_composition_models.TIRG(
            texts, embed_dim=opt.embed_dim)
    elif opt.model == 'tirg_lastconv':
        model = img_text_composition_models.TIRGLastConv(
            texts, embed_dim=opt.embed_dim)
    else:
        print('Invalid model', opt.model)
        print('available: imgonly, textonly, concat, tirg or tirg_lastconv')
        sys.exit()
    model = model.cuda()
    # import pdb; pdb.set_trace()

    # create optimizer
    params = []
    # low learning rate for pretrained layers on real image datasets
    if opt.dataset != 'css3d':
        params.append({
            'params': [p for p in model.img_model.fc.parameters()],
            'lr': opt.learning_rate
        })
        params.append({
            'params': [p for p in model.img_model.parameters()],
            'lr': 0.1 * opt.learning_rate
        })
    params.append({'params': [p for p in model.parameters()]})
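    # a parameter can appear in several groups above (fc at the full lr, the
    # rest of img_model at 0.1x lr, the whole model at the default lr); the
    # loop below keeps only the first occurrence and swaps later duplicates
    # for throwaway tensors, so SGD updates each weight exactly once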
    for _, p1 in enumerate(params):  # remove duplicated params
        for _, p2 in enumerate(params):
            if p1 is not p2:
                for p11 in p1['params']:
                    for j, p22 in enumerate(p2['params']):
                        if p11 is p22:
                            p2['params'][j] = torch.tensor(
                                0.0, requires_grad=True)
    optimizer = torch.optim.SGD(
        params, lr=opt.learning_rate, momentum=0.9, weight_decay=opt.weight_decay)
    return model, optimizer


def train_loop(opt, logger, trainset, testset, model, optimizer):
    """Function for train loop"""
    print('Begin training')
    lowest_loss = float('inf')
    it = 0
    epoch = -1
    tic = time.time()
    while it < opt.num_iters:
        epoch += 1

        # show/log stats
        print('It', it, 'epoch', epoch, 'Elapsed time', round(time.time() - tic,
                                                              4), opt.comment)
        tic = time.time()

        # test
        if epoch % 3 == 1:
            tests = []
            for name, dataset in [('train', trainset), ('test', testset)]:
                t = test_retrieval.test(opt, model, dataset)
                tests += [(name + '/' + metric_name, metric_value)
                          for metric_name, metric_value in t]
            for metric_name, metric_value in tests:
                logger.add_scalar(metric_name, metric_value, it)
                print('    ', metric_name, round(metric_value, 4))

        # save checkpoint
        torch.save({
            'it': it,
            'opt': opt,
            'model_state_dict': model.state_dict(),
        },
            logger.file_writer.get_logdir() + '/latest_checkpoint.pth')

        # run training for 1 epoch
        model.train()
        trainloader = trainset.get_loader(
            batch_size=opt.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=opt.loader_num_workers)

        def training_1_iter(data, lowest_loss):
            assert type(data) is list
            img1 = np.stack([d['source_img_data'] for d in data])
            img1 = torch.from_numpy(img1).float()
            img1 = torch.autograd.Variable(img1).cuda()
            img2 = np.stack([d['target_img_data'] for d in data])
            img2 = torch.from_numpy(img2).float()
            img2 = torch.autograd.Variable(img2).cuda()
            mods = [str(d['mod']['str']) for d in data]
            mods = [t for t in mods]
            # print(mods)

            # compute loss
            losses = []
            if opt.loss == 'soft_triplet':
                loss_value = model.compute_loss(
                    img1, mods, img2, soft_triplet_loss=True)
            elif opt.loss == 'batch_based_classification':
                loss_value = model.compute_loss(
                    img1, mods, img2, soft_triplet_loss=False)
            else:
                print('Invalid loss function', opt.loss)
                sys.exit()
            loss_name = opt.loss
            loss_weight = 1.0
            losses += [(loss_name, loss_weight, loss_value)]
            total_loss = sum([
                loss_weight * loss_value
                for loss_name, loss_weight, loss_value in losses
            ])
            assert not torch.isnan(total_loss)
            losses += [('total training loss', None, total_loss)]

            # track losses
            for loss_name, loss_weight, loss_value in losses:
                logger.add_scalar(loss_name, loss_value, it)
            logger.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], it)
        
            # use .item() so we track a plain float instead of holding on to
            # the loss tensor (and its graph) across iterations
            if total_loss.item() < lowest_loss:
                torch.save({
                    'it': it,
                    'opt': opt,
                    'model_state_dict': model.state_dict(),
                }, logger.file_writer.get_logdir() + '/best_checkpoint.pth')
                lowest_loss = total_loss.item()
                
            # gradient descent
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # the reassignment of lowest_loss above does not escape this
            # nested function, so return the updated value to the caller
            return lowest_loss

        for data in tqdm(trainloader, desc='Training for epoch ' + str(epoch)):
            it += 1
            lowest_loss = training_1_iter(data, lowest_loss)

            # decay learning rate
            if it >= opt.learning_rate_decay_frequency and it % opt.learning_rate_decay_frequency == 0:
                for g in optimizer.param_groups:
                    g['lr'] *= 0.1

    print('Finished training')


def main():
    opt = parse_opt()
    print('Arguments:')
    for k in opt.__dict__.keys():
        print('    ', k, ':', str(opt.__dict__[k]))

    logger = SummaryWriter(comment=opt.comment)
    print('Log files saved to', logger.file_writer.get_logdir())
    for k in opt.__dict__.keys():
        logger.add_text(k, str(opt.__dict__[k]))

    trainset, testset = load_dataset(opt)
    print(len([t for t in trainset.get_all_texts()]))
    model, optimizer = create_model_and_optimizer(
        opt, [t for t in trainset.get_all_texts()])

    train_loop(opt, logger, trainset, testset, model, optimizer)
    logger.close()


if __name__ == '__main__':
    main()

I finally got it after 160k iterations.
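With the defaults in the script above, that run corresponds to a command along these lines (a sketch; adjust the dataset path and iteration count to your setup):

python main.py --dataset fashion200k --dataset_path ./Fashion200k --model tirg --num_iters 160000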

@MOAboAli
Author

MOAboAli commented Apr 3, 2021

@huynhtruc0309 Thank you for your response. How much time does it take to run 160K iterations? When I tried once on my PC, one iteration took almost a day.

@lugiavn
Collaborator

lugiavn commented Apr 5, 2021

We updated HEAD so that the text model vocabulary is more deterministic now; please use it so that your model is saved correctly. Note that you still need the original training dataset to construct the model.
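For reference, a minimal loading-and-evaluation sketch, assuming the repo's own modules (datasets, img_text_composition_models, test_retrieval) and a checkpoint that stores 'opt' and 'model_state_dict' as in the main.py above; the checkpoint path is a placeholder:

import torch
import torchvision
import datasets
import img_text_composition_models
import test_retrieval

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(224),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
])
# the vocabulary must be rebuilt from the original training split
trainset = datasets.Fashion200k(
    path='./Fashion200k', split='train', transform=transform)
testset = datasets.Fashion200k(
    path='./Fashion200k', split='test', transform=transform)

ckpt = torch.load('checkpoint_fashion200k.pth')  # placeholder path
model = img_text_composition_models.TIRG(
    [t for t in trainset.get_all_texts()], embed_dim=512)
model.load_state_dict(ckpt['model_state_dict'])
model = model.cuda()

# test_retrieval.test returns (metric_name, metric_value) pairs, as in main.py
for metric_name, metric_value in test_retrieval.test(ckpt['opt'], model, testset):
    print(metric_name, round(metric_value, 4))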

I trained a new model for fashion200k
https://drive.google.com/file/d/1U4TdV3T22ZSB-07BR_BppwPailfWnU6u/view?usp=sharing

test recall_top1_correct_composition 0.1481
test recall_top5_correct_composition 0.3402
test recall_top10_correct_composition 0.4245
test recall_top50_correct_composition 0.6548
test recall_top100_correct_composition 0.7436    

@MOAboAli
Author

MOAboAli commented Apr 5, 2021

@lugiavn I would like to thank you very much; it works perfectly now.

@MOAboAli MOAboAli closed this as completed Apr 5, 2021