Skip to content

Commit

Permalink
[FIX] Fixes apache#6096 (apache#6131)
Browse files Browse the repository at this point in the history
Clear the compile cache between module builds so that schedule changes
will have an effect. Also, clear the warning cache so that schedule
changes properly list untuned ops.
  • Loading branch information
tkonolige authored and Trevor Morris committed Sep 2, 2020
1 parent 996d6f9 commit 6c95307
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 2 deletions.
5 changes: 5 additions & 0 deletions python/tvm/autotvm/task/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import logging

import tvm
from tvm.autotvm.task.dispatcher import DispatchContext, FallbackContext
from .task import create
from .topi_integration import TaskExtractEnv

Expand Down Expand Up @@ -140,6 +141,10 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
build_thread.start()
build_thread.join()
relay.backend.compile_engine.get().clear()
# Clear the warning message cache in FallbackContext
if isinstance(DispatchContext.current, FallbackContext):
DispatchContext.current.memory = {}
DispatchContext.warning_messages = set()

logger.disabled = old_state

Expand Down
3 changes: 3 additions & 0 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <memory>

#include "../../target/source/codegen_source_base.h"
#include "compile_engine.h"
#include "utils.h"

namespace tvm {
Expand Down Expand Up @@ -224,6 +225,8 @@ class RelayBuildModule : public runtime::ModuleNode {
targets_ = targets;
target_host_ = target_host;
BuildRelay(mod, params_);
// Clear compile engine so that tuning schedules can be changed between runs. See issue #6096.
CompileEngine::Global()->Clear();
}

protected:
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ class CompileEngineImpl : public CompileEngineNode {
};

/*! \brief The global compile engine */
const CompileEngine& CompileEngine::Global() {
CompileEngine& CompileEngine::Global() {
// intentionally allocate raw pointer to avoid
// free during destructuion.
static CompileEngine* inst = new CompileEngine(make_object<CompileEngineImpl>());
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/compile_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ class CompileEngine : public ObjectRef {
CompileEngineNode* operator->() { return static_cast<CompileEngineNode*>(get_mutable()); }
using ContainerType = CompileEngineNode;
/*! \brief The global compile engine. */
TVM_DLL static const CompileEngine& Global();
TVM_DLL static CompileEngine& Global();
};

/*!
Expand Down

0 comments on commit 6c95307

Please sign in to comment.