JIT CSE Optimization - Add a gymnasium environment for reinforcement learning (#101856)

* Initial code

* Add notes

* Add CSE_HeuristicRLHook

* Move metric print location, double -> int

* Produce non-viable entries, fix output issue

* Shuffle features by type

* Initial JitEnv - not yet working

* Change to snake_case

* Initial RL implementation with stable-baselines3

* Enable parallel processing, fix some errors

* Clean up train.py, allow algorithm selection

* Fix paths

* Fix issue with null result

* Save method indexes

* Check if process is still running

* Up argument count before warning

* Track more statistics on tensorboard

* Fix an issue where we didn't let the model know it shouldn't pick something

* Reward improvements

- Scale up rewards.
- Clamp rewards to [-1, 1]
- Reward/penalize when complete if there are better/worse CSEs (this is very slow)
- Reward when complete based on whether we beat the heuristic or not
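The scaling-and-clamping scheme described above can be sketched as a small helper. This is illustrative only: the scale factor and the exact definition of the raw perf-score delta are assumptions, not the values used by the repo's training code.

```python
def shape_reward(raw_delta, scale=10.0):
    """Illustrative reward shaping per the notes above: scale up the raw
    perf-score delta, then clamp the result to [-1, 1].
    `scale` and the delta definition are assumptions for illustration."""
    r = raw_delta * scale
    return max(-1.0, min(1.0, r))
```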

* Update jitenv.py to remove unused import

* Fix inverted graph

* Split data into test/train

* Refactor for clarity

* Use numpy for randomness

* Add open questions

* Fix a couple of model saving issues

* Refactor and cleanup

* Add evaluate.py

* Fix inverted test/train

* Add a way to get the probabilities of actions

* Rename file

* Clean up imports

* Changed action space

- 0 to action_space.n-2 are now the CSEs to apply instead of adding and subtracting 1 to the action.
- 0 no longer means terminate; instead, action n-1 from the model is the terminate signal. This action is not passed to the JIT.
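The new action mapping can be sketched as a tiny decoder (hypothetical helper name; the real mapping lives in the environment code):

```python
def decode_action(action, n_actions):
    """Decode a model action under the scheme above: actions 0..n-2 name
    the CSE candidate to apply; action n-1 is the terminate signal and is
    never passed to the JIT."""
    if action == n_actions - 1:
        return None  # terminate the episode
    return action    # index of the CSE candidate to apply
```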

* Add field validator for perf_score

This shouldn't happen, but it's important enough to validate.

* Update applicability to ensure we have at least enough viable candidates and not more than total

* Fix a few bugs with evaluate

* Fix test/train split, some extra output

* Remove dead code, simplify format

* Rename JitEnv -> JitCseEnv

* More renames

* Try to factor the observation space

* Fix test/train split

* Reward cleanup

- Split reward function into shallow and deep.

* Remove 0 perfscore check

* Enable deep rewards

* Fix issue where jit failed and produced None method

* Simplify deeper rewards

* Update todo

* Add reward customization

* Clean up __all__

* Fix issue where we would JIT the first CSE candidate in reset

This was leftover code from the previous design of RLHook.

* Add two new features, emit selected sequence

* Jit one less method per cse chosen in deep rewards

* Use info dictionary instead of a specific state

* Fix segfault due to null variable

* Add superpmi_context

Getting the code well-factored so it's easy to modify.

* Add tensorboard entry for invalid choices, clear results

* Close the environment

* Add documentation for JIT changes

* Rename method

* Normalize observation

* Set return type hint for clarity

* Add RemoveFeaturesWrapper

* Update docstring

* Rename function

* Move feature normalization to a wrapper

- Also better wrapper factoring.
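As a rough illustration of what moving normalization into a wrapper buys, per-feature scaling can be sketched as below. The helper name and the [0, 1] scaling are assumptions; the real gymnasium wrapper and its feature maxima live in the repo's Python code.

```python
def normalize_features(values, maxima):
    """Illustrative per-feature normalization into [0, 1].
    A zero maximum maps to 0.0 to avoid division by zero."""
    return [v / m if m else 0.0 for v, m in zip(values, maxima)]
```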

* Remove import

* Fix warning

* Fix Windows issue

* Properly log when using A2C

* Add readme

* Change argument name

* Remove whitespace change

* Format fixes

* Fix formatting

* Update readme:  Fix grammar, add a note about evaluation

* Fixed incorrect filename in readme

* Save more data to .json in preparation of other model kinds
leculver authored May 9, 2024
1 parent ce2364c commit 279dbe1
Showing 20 changed files with 1,915 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/coreclr/jit/compiler.h
@@ -2541,6 +2541,7 @@ class Compiler
friend class CSE_HeuristicReplay;
friend class CSE_HeuristicRL;
friend class CSE_HeuristicParameterized;
friend class CSE_HeuristicRLHook;
friend class CSE_Heuristic;
friend class CodeGenInterface;
friend class CodeGen;
9 changes: 9 additions & 0 deletions src/coreclr/jit/jitconfigvalues.h
@@ -498,6 +498,15 @@ CONFIG_STRING(JitRLCSEAlpha, W("JitRLCSEAlpha"))
// If nonzero, dump candidate feature values
CONFIG_INTEGER(JitRLCSECandidateFeatures, W("JitRLCSECandidateFeatures"), 0)

// Enable CSE_HeuristicRLHook
CONFIG_INTEGER(JitRLHook, W("JitRLHook"), 0) // If 1, emit RL callbacks

// If 1, emit feature column names
CONFIG_INTEGER(JitRLHookEmitFeatureNames, W("JitRLHookEmitFeatureNames"), 0)

// A list of CSEs to choose, in the order they should be applied.
CONFIG_STRING(JitRLHookCSEDecisions, W("JitRLHookCSEDecisions"))

#if !defined(DEBUG) && !defined(_DEBUG)
RELEASE_CONFIG_INTEGER(JitEnableNoWayAssert, W("JitEnableNoWayAssert"), 0)
#else // defined(DEBUG) || defined(_DEBUG)
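Taken together, the three knobs above drive the hook from outside the JIT. A hypothetical helper for assembling the host-process environment might look like this; the `DOTNET_` knob names come from jitconfigvalues.h above, while the helper itself and how the host (corerun, superpmi, etc.) is launched are assumptions.

```python
import os

def rlhook_env(decisions=None, emit_names=True):
    """Hypothetical helper: build the environment variables that enable
    CSE_HeuristicRLHook for a JIT host process."""
    env = dict(os.environ)
    env["DOTNET_JitRLHook"] = "1"  # emit RL callbacks
    if emit_names:
        env["DOTNET_JitRLHookEmitFeatureNames"] = "1"  # feature column names
    if decisions:
        # CSE indices to apply, in order
        env["DOTNET_JitRLHookCSEDecisions"] = ",".join(str(d) for d in decisions)
    return env
```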
305 changes: 304 additions & 1 deletion src/coreclr/jit/optcse.cpp
@@ -2973,6 +2973,298 @@ void CSE_HeuristicParameterized::DumpChoices(ArrayStack<Choice>& choices, CSEdsc

#ifdef DEBUG

//------------------------------------------------------------------------
// CSE_HeuristicRLHook: a generic 'hook' for driving CSE decisions out of
// process using reinforcement learning
//
// Arguments:
// pCompiler - compiler instance
//
// Notes:
// This creates a hook to control CSE decisions from an external process
// when JitRLHook=1 is set. This will cause the JIT to emit a series of
// feature building blocks for each CSE in the method. Feature names for
// these values can be found by setting JitRLHookEmitFeatureNames=1. To
// control the CSE decisions, set JitRLHookCSEDecisions with a sequence
// of CSE indices to apply.
//
// This hook is only available in debug/checked builds, and does not
// contain any machine learning code.
//
CSE_HeuristicRLHook::CSE_HeuristicRLHook(Compiler* pCompiler)
: CSE_HeuristicCommon(pCompiler)
{
}

//------------------------------------------------------------------------
// ConsiderTree: check if this tree can be a CSE candidate
//
// Arguments:
// tree - tree in question
// isReturn - true if tree is part of a return statement
//
// Returns:
// true if this tree can be a CSE
bool CSE_HeuristicRLHook::ConsiderTree(GenTree* tree, bool isReturn)
{
return CanConsiderTree(tree, isReturn);
}

//------------------------------------------------------------------------
// ConsiderCandidates: examine candidates and perform CSEs.
// This simply defers to the JitRLHookCSEDecisions config value.
//
void CSE_HeuristicRLHook::ConsiderCandidates()
{
if (JitConfig.JitRLHookCSEDecisions() != nullptr)
{
ConfigIntArray JitRLHookCSEDecisions;
JitRLHookCSEDecisions.EnsureInit(JitConfig.JitRLHookCSEDecisions());

unsigned cnt = m_pCompiler->optCSECandidateCount;
for (unsigned i = 0; i < JitRLHookCSEDecisions.GetLength(); i++)
{
const int index = JitRLHookCSEDecisions.GetData()[i];
if ((index < 0) || (index >= (int)cnt))
{
JITDUMP("Invalid candidate number %d\n", index + 1);
continue;
}

CSEdsc* const dsc = m_pCompiler->optCSEtab[index];
if (!dsc->IsViable())
{
JITDUMP("Abandoned " FMT_CSE " -- not viable\n", dsc->csdIndex);
continue;
}

const int attempt = m_pCompiler->optCSEattempt++;
CSE_Candidate candidate(this, dsc);

JITDUMP("\nRLHook attempting " FMT_CSE "\n", candidate.CseIndex());
JITDUMP("CSE Expression : \n");
JITDUMPEXEC(m_pCompiler->gtDispTree(candidate.Expr()));
JITDUMP("\n");

PerformCSE(&candidate);
madeChanges = true;
}
}
}

//------------------------------------------------------------------------
// DumpMetrics: write out features for each CSE candidate
// Format:
// featureNames <comma separated list of feature names>
// features #<CSE index>,<comma separated list of feature values>
// seq <comma separated list of CSE indices>
//
// Notes:
// featureNames are emitted only if JitRLHookEmitFeatureNames is set.
// features are 0 indexed, and the index is the first value, following #.
// seq is a comma separated list of CSE indices that were applied, or
// omitted if none were selected
//
void CSE_HeuristicRLHook::DumpMetrics()
{
// Feature names, if requested
if (JitConfig.JitRLHookEmitFeatureNames() > 0)
{
printf(" featureNames ");
for (int i = 0; i < maxFeatures; i++)
{
printf("%s%s", (i == 0) ? "" : ",", s_featureNameAndType[i]);
}
}

// features
for (unsigned i = 0; i < m_pCompiler->optCSECandidateCount; i++)
{
CSEdsc* const cse = m_pCompiler->optCSEtab[i];

int features[maxFeatures];
GetFeatures(cse, features);

printf(" features #%i", cse->csdIndex);
for (int j = 0; j < maxFeatures; j++)
{
printf(",%d", features[j]);
}
}

// The selected sequence of CSEs that were applied
if (JitConfig.JitRLHookCSEDecisions() != nullptr)
{
ConfigIntArray JitRLHookCSEDecisions;
JitRLHookCSEDecisions.EnsureInit(JitConfig.JitRLHookCSEDecisions());

if (JitRLHookCSEDecisions.GetLength() > 0)
{
printf(" seq ");
for (unsigned i = 0; i < JitRLHookCSEDecisions.GetLength(); i++)
{
printf("%s%d", (i == 0) ? "" : ",", JitRLHookCSEDecisions.GetData()[i]);
}
}
}
}
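The metrics format documented above is what the Python side consumes. A minimal parsing sketch follows (hypothetical helper name, assuming one record per line; the real parsing lives in src/coreclr/scripts/cse_ml/jitml/method_context.py):

```python
def parse_rlhook_metrics(lines):
    """Sketch of parsing DumpMetrics output: 'featureNames' gives column
    names, 'features #<idx>,...' gives one row per CSE candidate, and
    'seq' lists the applied CSE indices (absent if none were applied)."""
    names, rows, seq = None, {}, []
    for line in lines:
        parts = line.strip().split(None, 1)
        if len(parts) != 2:
            continue
        key, rest = parts
        if key == "featureNames":
            names = rest.split(",")
        elif key == "features":
            idx, _, vals = rest.partition(",")
            rows[int(idx.lstrip("#"))] = [int(v) for v in vals.split(",")]
        elif key == "seq":
            seq = [int(v) for v in rest.split(",")]
    return names, rows, seq
```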

//------------------------------------------------------------------------
// GetFeatures: extract features for this CSE
// Arguments:
// cse - cse descriptor
// features - array to fill in with feature values, this must be of length
// maxFeatures or greater
//
// Notes:
// Features are intended to be building blocks of "real" features that
// are further defined and refined in the machine learning model. That
// means that each "feature" here is a simple value and not a composite
// of multiple values.
//
// Features do not need to be stable across builds, they can be changed,
// added, or removed. However, the corresponding code needs to be updated
// to match: src/coreclr/scripts/cse_ml/jitml/method_context.py
// See src/coreclr/scripts/cse_ml/README.md for more information.
//
void CSE_HeuristicRLHook::GetFeatures(CSEdsc* cse, int* features)
{
assert(cse != nullptr);
assert(features != nullptr);
CSE_Candidate candidate(this, cse);

int enregCount = 0;
for (unsigned trackedIndex = 0; trackedIndex < m_pCompiler->lvaTrackedCount; trackedIndex++)
{
LclVarDsc* varDsc = m_pCompiler->lvaGetDescByTrackedIndex(trackedIndex);
var_types varTyp = varDsc->TypeGet();

// Locals with no references aren't enregistered
if (varDsc->lvRefCnt() == 0)
{
continue;
}

// Some LclVars always have stack homes
if (varDsc->lvDoNotEnregister)
{
continue;
}

if (!varTypeIsFloating(varTyp))
{
enregCount++; // The primitive types, including TYP_SIMD types use one register

#ifndef TARGET_64BIT
if (varTyp == TYP_LONG)
{
enregCount++; // on 32-bit targets longs use two registers
}
#endif
}
}

const unsigned numBBs = m_pCompiler->fgBBcount;
bool isMakeCse = false;
unsigned minPostorderNum = numBBs;
unsigned maxPostorderNum = 0;
BasicBlock* minPostorderBlock = nullptr;
BasicBlock* maxPostorderBlock = nullptr;
for (treeStmtLst* treeList = cse->csdTreeList; treeList != nullptr; treeList = treeList->tslNext)
{
BasicBlock* const treeBlock = treeList->tslBlock;
unsigned postorderNum = treeBlock->bbPostorderNum;
if (postorderNum < minPostorderNum)
{
minPostorderNum = postorderNum;
minPostorderBlock = treeBlock;
}

if (postorderNum > maxPostorderNum)
{
maxPostorderNum = postorderNum;
maxPostorderBlock = treeBlock;
}

isMakeCse |= ((treeList->tslTree->gtFlags & GTF_MAKE_CSE) != 0);
}

const unsigned blockSpread = maxPostorderNum - minPostorderNum;

int type = rlHookTypeOther;
if (candidate.Expr()->TypeIs(TYP_INT))
{
type = rlHookTypeInt;
}
else if (candidate.Expr()->TypeIs(TYP_LONG))
{
type = rlHookTypeLong;
}
else if (candidate.Expr()->TypeIs(TYP_FLOAT))
{
type = rlHookTypeFloat;
}
else if (candidate.Expr()->TypeIs(TYP_DOUBLE))
{
type = rlHookTypeDouble;
}
else if (candidate.Expr()->TypeIs(TYP_STRUCT))
{
type = rlHookTypeStruct;
}

#ifdef FEATURE_SIMD
else if (varTypeIsSIMD(candidate.Expr()->TypeGet()))
{
type = rlHookTypeSimd;
}
#ifdef TARGET_XARCH
else if (candidate.Expr()->TypeIs(TYP_SIMD32, TYP_SIMD64))
{
type = rlHookTypeSimd;
}
#endif
#endif

int i = 0;
features[i++] = type;
features[i++] = cse->IsViable() ? 1 : 0;
features[i++] = cse->csdLiveAcrossCall ? 1 : 0;
features[i++] = cse->csdTree->OperIsConst() ? 1 : 0;
features[i++] = cse->csdIsSharedConst ? 1 : 0;
features[i++] = isMakeCse ? 1 : 0;
features[i++] = ((cse->csdTree->gtFlags & GTF_CALL) != 0) ? 1 : 0;
features[i++] = cse->csdTree->OperIs(GT_ADD, GT_NOT, GT_MUL, GT_LSH) ? 1 : 0;
features[i++] = cse->csdTree->GetCostEx();
features[i++] = cse->csdTree->GetCostSz();
features[i++] = cse->csdUseCount;
features[i++] = cse->csdDefCount;
features[i++] = (int)cse->csdUseWtCnt;
features[i++] = (int)cse->csdDefWtCnt;
features[i++] = cse->numDistinctLocals;
features[i++] = cse->numLocalOccurrences;
features[i++] = numBBs;
features[i++] = blockSpread;
features[i++] = enregCount;

assert(i <= maxFeatures);

for (; i < maxFeatures; i++)
{
features[i] = 0;
}
}

// These need to match the features above, and match the field name of MethodContext
// in src/coreclr/scripts/cse_ml/jitml/method_context.py
const char* const CSE_HeuristicRLHook::s_featureNameAndType[] = {
"type", "viable", "live_across_call", "const",
"shared_const", "make_cse", "has_call", "containable",
"cost_ex", "cost_sz", "use_count", "def_count",
"use_wt_cnt", "def_wt_cnt", "distinct_locals", "local_occurrences",
"bb_count", "block_spread", "enreg_count",
};

//------------------------------------------------------------------------
// CSE_HeuristicRL: construct RL CSE heuristic
//
@@ -5165,9 +5457,20 @@ CSE_HeuristicCommon* Compiler::optGetCSEheuristic()

// Enable optional policies
//
// RL takes precedence
// RL hook takes precedence
//
if (optCSEheuristic == nullptr)
{
bool useRLHook = (JitConfig.JitRLHook() > 0);

if (useRLHook)
{
optCSEheuristic = new (this, CMK_CSE) CSE_HeuristicRLHook(this);
}
}

// then RL
if (optCSEheuristic == nullptr)
{
bool useRLHeuristic = (JitConfig.JitRLCSE() != nullptr);

43 changes: 43 additions & 0 deletions src/coreclr/jit/optcse.h
@@ -217,6 +217,49 @@ class CSE_HeuristicParameterized : public CSE_HeuristicCommon

#ifdef DEBUG

// General Reinforcement Learning CSE heuristic hook.
//
// Produces a wide set of data to train a RL model.
// Consumes the decisions made by a model to perform CSEs.
//
class CSE_HeuristicRLHook : public CSE_HeuristicCommon
{
private:
static const char* const s_featureNameAndType[];

void GetFeatures(CSEdsc* cse, int* features);

enum
{
maxFeatures = 19,
};

enum
{
rlHookTypeOther = 0,
rlHookTypeInt = 1,
rlHookTypeLong = 2,
rlHookTypeFloat = 3,
rlHookTypeDouble = 4,
rlHookTypeStruct = 5,
rlHookTypeSimd = 6,
};

public:
CSE_HeuristicRLHook(Compiler*);
void ConsiderCandidates();
bool ConsiderTree(GenTree* tree, bool isReturn);

const char* Name() const
{
return "RL Hook CSE Heuristic";
}

#ifdef DEBUG
virtual void DumpMetrics();
#endif
};

// Reinforcement Learning CSE heuristic
//
// Uses a "linear" feature model with
2 changes: 2 additions & 0 deletions src/coreclr/scripts/cse_ml/.gitignore
@@ -0,0 +1,2 @@
# The root .gitignore doesn't mark this as ignored:
__pycache__