Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Clear relay cache after every build & Clear warning message cache after autotvm task extraction #6131

Merged
merged 1 commit into from
Jul 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/tvm/autotvm/task/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import logging

import tvm
from tvm.autotvm.task.dispatcher import DispatchContext, FallbackContext
from .task import create
from .topi_integration import TaskExtractEnv

Expand Down Expand Up @@ -140,6 +141,10 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
build_thread.start()
build_thread.join()
relay.backend.compile_engine.get().clear()
# Clear the warning message cache in FallbackContext
if isinstance(DispatchContext.current, FallbackContext):
DispatchContext.current.memory = {}
DispatchContext.warning_messages = set()

logger.disabled = old_state

Expand Down
3 changes: 3 additions & 0 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <memory>

#include "../../target/source/codegen_source_base.h"
#include "compile_engine.h"
#include "utils.h"

namespace tvm {
Expand Down Expand Up @@ -224,6 +225,8 @@ class RelayBuildModule : public runtime::ModuleNode {
targets_ = targets;
target_host_ = target_host;
BuildRelay(mod, params_);
// Clear compile engine so that tuning schedules can be changed between runs. See issue #6096.
CompileEngine::Global()->Clear();
}

protected:
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ class CompileEngineImpl : public CompileEngineNode {
};

/*! \brief The global compile engine */
const CompileEngine& CompileEngine::Global() {
CompileEngine& CompileEngine::Global() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason for dropping the const here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clearing the cache mutates the CompileEngine, so in order to clear the cache after each build we need to modify the global state.

// intentionally allocate raw pointer to avoid
// free during destructuion.
static CompileEngine* inst = new CompileEngine(make_object<CompileEngineImpl>());
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/compile_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ class CompileEngine : public ObjectRef {
CompileEngineNode* operator->() { return static_cast<CompileEngineNode*>(get_mutable()); }
using ContainerType = CompileEngineNode;
/*! \brief The global compile engine. */
TVM_DLL static const CompileEngine& Global();
TVM_DLL static CompileEngine& Global();
};

/*!
Expand Down