
r1.15.5-deeprec2302: incremental EV is not loaded correctly during restore #999

Open
HH-66 opened this issue Jun 20, 2024 · 4 comments
Labels
bug Something isn't working

Comments

HH-66 commented Jun 20, 2024

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 20.04): Ubuntu 20.04
  • DeepRec version or commit id: 2325297
  • Python version: python3.6.9
  • Bazel version (if compiling from source): 0.26.1
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version: None

Describe the current behavior
During restore, the EV variables in the incremental checkpoint are not loaded correctly, so they fail to overwrite the corresponding EV variables from the base checkpoint.

Describe the expected behavior
The incremental EV should be loaded correctly, overwriting the corresponding variables from the base checkpoint.
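
A minimal sketch (assembled from the API calls that appear in the reproduction later in this thread, not taken from the reporter's code) of the save/restore setup under discussion; the expectation is that restore loads the base checkpoint first and then applies the incremental checkpoint on top of it, overwriting the EV entries it contains:

import tensorflow as tf

# Embedding variable plus a saver/scaffold with incremental save/restore
# enabled, mirroring the reproduction further down in this thread.
var = tf.get_embedding_variable(
        name='embedding_table', embedding_dim=1, key_dtype=tf.int64)
saver = tf.train.Saver(
        sharded=True, incremental_save_restore=True, save_relative_paths=True)
scaffold = tf.train.Scaffold(saver=saver, incremental_save_restore=True)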

Code to reproduce the issue

Provide a reproducible test case that is the bare minimum necessary to generate the problem.

Other info / logs

Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.

candyzone (Collaborator) commented

A temporary fix based on deeprec2302:
candyzone@da651f0

candyzone added the bug label Jun 20, 2024
candyzone (Collaborator) commented

This issue has already been fixed in release deeprec2402.

torshie commented Jul 2, 2024

After adding a partitioner the problem still exists; it can be reproduced with the code below (on the 2302 release).

Notes on the test code:

The test model is designed to have the following properties (see the model_fn function for how each is implemented):

  • The initial value of every embedding is 1
  • Each training step on a key lowers that key's weight by 0.01
  • The train/eval loss equals the embedding weight

The keys are trained in sequence: key 0 for 20 steps, key 1 for 40 steps, key 2 for 30 steps, key 3 for 10 steps.

Training command:
./bare_minimum.py train

Evaluation command:
./bare_minimum.py eval --value 1  # evaluate the embedding weight of key 1

By editing the checkpoint_dir/checkpoint file you can evaluate either the incremental checkpoint or the full checkpoint.
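
The checkpoint_dir/checkpoint file is a plain-text CheckpointState proto, so switching between the two is a one-line edit; the step numbers and the incremental file name below are illustrative, copy the real path from the state file in .incremental_checkpoint:

# checkpoint_dir/checkpoint, pointing evaluation at the full checkpoint:
model_checkpoint_path: "model.ckpt-50"

# ...or at the incremental checkpoint (hypothetical name; use the path
# recorded in .incremental_checkpoint/checkpoint):
model_checkpoint_path: ".incremental_checkpoint/incr_model.ckpt-63"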

Everything else in the code exists mainly so that DeepRec also generates and loads incremental checkpoints correctly when used through the Estimator API.

#!/usr/bin/env python3

import argparse
import functools
import os.path
import time

import tensorflow as tf
import numpy

# Incremental-checkpoint intervals; set by patch_incr_ckpt() before use.
_incr_ckpt_secs = None
_incr_ckpt_steps = None


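# Sleep briefly after every training step so that the time-based incremental
# checkpoint (saved every second, see patch_incr_ckpt(secs=1) in main) fires
# between the full checkpoints written every 50 steps.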
class DelayHook(tf.train.SessionRunHook):
    def after_run(self, run_context, run_values):
        time.sleep(0.02)


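# Every newly created EV key is initialized to a constant 1, so a key's weight
# directly encodes how many steps it has been trained on.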
def get_ev_option():
    init_opt = tf.InitializerOption(initializer=tf.constant_initializer(1))
    return tf.EmbeddingVariableOption(
            init_option=init_opt, filter_option=None, evict_option=None)


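# The loss is the mean embedding value itself (labels are all zeros), so its
# gradient w.r.t. each looked-up weight is 1 and SGD with lr=0.01 lowers a
# key's weight by exactly 0.01 per step, giving the behavior described above.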
def model_fn(features, labels, mode, params):
    id_ = features['x']
    weights = tf.get_embedding_variable(
            name='embedding_table', embedding_dim=1,
            value_dtype=tf.float32,
            ev_option=get_ev_option(), key_dtype=tf.int64,
            partitioner=tf.fixed_size_partitioner(num_shards=1))

    x = tf.nn.embedding_lookup(weights, id_)
    y = tf.reduce_mean(x, axis=1)
    loss = tf.reduce_mean(y - labels)

    saver = tf.train.Saver(
            sharded=True, incremental_save_restore=True,
            save_relative_paths=True)
    scaffold = tf.train.Scaffold(saver=saver, incremental_save_restore=True)
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss, scaffold=scaffold)

    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

    saver_hook = tf.train.CheckpointSaverHook(
            incremental_save_secs=1, checkpoint_dir='checkpoint_dir',
            save_steps=50, scaffold=scaffold, listeners=[])
    log_hook = tf.train.LoggingTensorHook(
            {'loss': loss, 'step': tf.train.get_or_create_global_step()},
            every_n_iter=1)

    minimize = optimizer.minimize(
            loss, global_step=tf.train.get_or_create_global_step())

    return tf.estimator.EstimatorSpec(
            mode, loss=loss, train_op=minimize,
            training_chief_hooks=[saver_hook], scaffold=scaffold,
            training_hooks=[log_hook, DelayHook()])


def train_input_fn():
    def generator():
        for i in range(20):
            features = {
                'x': numpy.array([0], dtype=numpy.int64)
            }
            labels = numpy.zeros([1], dtype=numpy.float32)
            yield features, labels

        for i in range(40):
            features = {
                'x': numpy.array([1], dtype=numpy.int64),
            }
            labels = numpy.zeros([1], dtype=numpy.float32)
            yield features, labels

        for i in range(30):
            features = {
                'x': numpy.array([2], dtype=numpy.int64)
            }
            labels = numpy.zeros([1], dtype=numpy.float32)
            yield features, labels

        for i in range(10):
            features = {
                'x': numpy.array([3], dtype=numpy.int64)
            }
            labels = numpy.zeros([1], dtype=numpy.float32)
            yield features, labels

    return tf.data.Dataset.from_generator(
            generator, output_types=({'x': tf.int64}, tf.float32),
            output_shapes=({'x': tf.TensorShape([None])}, tf.TensorShape([None])))


def eval_input_fn(value):
    def generator():
        for i in range(10):
            features = {
                'x': numpy.array([value, value], dtype=numpy.int64)
            }
            labels = numpy.zeros([2], dtype=numpy.float32)
            yield features, labels

    return tf.data.Dataset.from_generator(
            generator, output_types=({'x': tf.int64}, tf.float32),
            output_shapes=({'x': tf.TensorShape([None])}, tf.TensorShape([None])))


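# Swap in a ChiefSessionCreator that drops checkpoint_filename_with_path, so
# restore always resolves the checkpoint through checkpoint_dir, and install
# the two evaluation-side patches below.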
def _patch_session_creator(checkpoint_dir):
    tf.logging.info('Patching monitored_session.ChiefSessionCreator')
    from tensorflow.python.training import monitored_session, checkpoint_management
    monitored_session.ChiefSessionCreator__ = monitored_session.ChiefSessionCreator
    monitored_session.ChiefSessionCreator = functools.partial(
            _session_creator, checkpoint_dir=checkpoint_dir)
    _patch_evaluate_and_export()
    _patch_evaluate_recover_session()


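# Wrapper around MonitoredTrainingSession that injects the incremental
# checkpoint intervals when the caller (the Estimator) does not pass them.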
def _create_session(*args, **kwargs):
    if _incr_ckpt_secs is not None:
        if 'save_incremental_checkpoint_secs' not in kwargs \
                or kwargs['save_incremental_checkpoint_secs'] is None:
            kwargs['save_incremental_checkpoint_secs'] = _incr_ckpt_secs
    if _incr_ckpt_steps is not None:
        if 'save_incremental_checkpoint_steps' not in kwargs \
                or kwargs['save_incremental_checkpoint_steps'] is None:
            kwargs['save_incremental_checkpoint_steps'] = _incr_ckpt_steps
    tf.logging.info("Creating MonitoredTrainingSession, %s, %s", args, kwargs)
    return tf.train.MonitoredTrainingSession__(*args, **kwargs)


def _session_creator(**kwargs):
    from tensorflow.python.training import monitored_session

    kwargs['checkpoint_filename_with_path'] = None

    tf.logging.info('Creating ChiefSessionCreator: %s', kwargs)
    return monitored_session.ChiefSessionCreator__(**kwargs)


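# Monkey-patch MonitoredTrainingSession with the wrapper above; the original
# is kept under the MonitoredTrainingSession__ alias used by _create_session.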
def patch_incr_ckpt(secs=0, steps=0):
    global _incr_ckpt_secs
    global _incr_ckpt_steps
    _incr_ckpt_secs = secs if secs > 0 else None
    _incr_ckpt_steps = steps if steps > 0 else None

    tf.logging.info("Patching MonitoredTrainingSession.")
    from tensorflow.python.training import training
    training.MonitoredTrainingSession__ = training.MonitoredTrainingSession
    training.MonitoredTrainingSession = _create_session
    tf.train.MonitoredTrainingSession__ = tf.train.MonitoredTrainingSession
    tf.train.MonitoredTrainingSession = _create_session


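# Patch _TrainingExecutor._Evaluator.evaluate_and_export so that continuous
# evaluation also notices new incremental checkpoints instead of skipping the
# pass whenever the full-checkpoint version is unchanged.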
def _patch_evaluate_and_export():
    from tensorflow.python.training import checkpoint_management
    from tensorflow_estimator.python.estimator.training import _EvalResult, _EvalStatus, _TrainingExecutor
    from tensorflow.python.framework import ops
    from tensorflow.python.eager import context

    def evaluate_and_export(self):
        tf.logging.info('custom evaluate_and_export')

        latest_ckpt_path = self._estimator.latest_checkpoint()
        if not latest_ckpt_path:
            self._log_err_msg('Estimator is not trained yet. Will start an '
                              'evaluation when a checkpoint is ready.')
            return _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT), []

        # .incremental_checkpoint
        with context.graph_mode():
            incremental_dir = os.path.join(self._estimator.model_dir, '.incremental_checkpoint')
            incremental_ckpt = checkpoint_management.latest_checkpoint(incremental_dir)

        base_version = int(latest_ckpt_path.split('-')[-1])
        incremental_version = int(incremental_ckpt.split('-')[-1]) if incremental_ckpt else None
        previous_version = int(self._previous_ckpt_path.split('-')[-1]) if self._previous_ckpt_path else None
        tf.logging.info(f'now version: {base_version} {incremental_version} <- {previous_version}')

        if previous_version and incremental_version and incremental_version == previous_version:
            self._log_err_msg(
                'No new checkpoint ready for evaluation. Skip the current '
                'evaluation pass as evaluation results are expected to be same '
                'for the same checkpoint.')
            return _EvalResult(status=_EvalStatus.NO_NEW_CHECKPOINT), []

        metrics = self._estimator.evaluate(
            input_fn=self._eval_spec.input_fn,
            steps=self._eval_spec.steps,
            name=self._eval_spec.name,
            checkpoint_path=latest_ckpt_path,
            hooks=self._eval_spec.hooks)

        # _EvalResult validates the metrics.
        eval_result = _EvalResult(
            status=_EvalStatus.EVALUATED,
            metrics=metrics,
            checkpoint_path=latest_ckpt_path)

        is_the_final_export = (
            eval_result.metrics[ops.GraphKeys.GLOBAL_STEP] >=
            self._max_training_steps if self._max_training_steps else False)
        export_results = self._export_eval_result(eval_result,
                                                  is_the_final_export)

        if is_the_final_export:
            tf.logging.debug('Calling exporter with the `is_the_final_export=True`.')
            self._is_final_export_triggered = True

        self._last_warning_time = 0
        self._previous_ckpt_path = incremental_ckpt if incremental_ckpt else latest_ckpt_path
        return eval_result, export_results

    _TrainingExecutor._Evaluator.evaluate_and_export__ = _TrainingExecutor._Evaluator.evaluate_and_export
    _TrainingExecutor._Evaluator.evaluate_and_export = evaluate_and_export

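# Patch SessionManager.recover_session to build an incremental saver and pass
# it to _restore_checkpoint, so the incremental checkpoint is restored on top
# of the base checkpoint.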
def _patch_evaluate_recover_session():
    def recover_session(self,
                        master,
                        saver=None,
                        checkpoint_dir=None,
                        checkpoint_filename_with_path=None,
                        wait_for_checkpoint=False,
                        max_wait_secs=7200,
                        config=None):
        from tensorflow.python.training import incremental_saver
        incr_saver = incremental_saver._get_incremental_saver(self._incremental_save_restore, self._saver)

        tf.logging.info("custom recover_session")

        sess, is_loaded_from_checkpoint = self._restore_checkpoint(
            master,
            saver,
            incr_saver,
            checkpoint_dir=checkpoint_dir,
            checkpoint_filename_with_path=checkpoint_filename_with_path,
            wait_for_checkpoint=wait_for_checkpoint,
            max_wait_secs=max_wait_secs,
            config=config)

        # Always try to run local_init_op
        local_init_success, msg = self._try_run_local_init_op(sess)

        if not is_loaded_from_checkpoint:
            # Do not need to run checks for readiness
            return sess, False

        restoring_file = checkpoint_dir or checkpoint_filename_with_path
        if not local_init_success:
            tf.logging.info(
                "Restoring model from %s did not make model ready for local init:"
                " %s", restoring_file, msg)
            return sess, False

        is_ready, msg = self._model_ready(sess)
        if not is_ready:
            tf.logging.info("Restoring model from %s did not make model ready: %s",
                         restoring_file, msg)
            return sess, False

        tf.logging.info("Restored model from %s", restoring_file)
        return sess, is_loaded_from_checkpoint

    tf.train.SessionManager.recover_session__ = tf.train.SessionManager.recover_session
    tf.train.SessionManager.recover_session = recover_session


def parse_cmdline():
    p = argparse.ArgumentParser()
    p.add_argument('mode', choices=('train', 'eval'))
    p.add_argument('--value', type=int, default=0)
    return p.parse_args()


def main():
    cmdline = parse_cmdline()

    tf.logging.set_verbosity(tf.logging.INFO)

    patch_incr_ckpt(secs=1)
    _patch_session_creator('checkpoint_dir')

    eval_input = functools.partial(eval_input_fn, value=cmdline.value)

    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn)
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input)
    config = tf.estimator.RunConfig(
        model_dir='checkpoint_dir',
        tf_random_seed=2020,
        save_summary_steps=1,
        save_checkpoints_steps=50,
        keep_checkpoint_max=20,
        experimental_max_worker_delay_secs=2000)
    estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)

    if cmdline.mode == 'eval':
        estimator.evaluate(eval_input)
    else:
        tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)


if __name__ == '__main__':
    main()

torshie commented Jul 2, 2024

Verified: candyzone@261ccfb fixes this problem.
