diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h index e73c1193f4c3..6b733d074f6a 100644 --- a/include/tvm/support/random_engine.h +++ b/include/tvm/support/random_engine.h @@ -19,7 +19,8 @@ /*! * \file random_engine.h - * \brief Random number generator, for Sampler and Sampling functions. + * \brief Random number generator. It provides a generic interface consistent with + * `std::uniform_random_bit_generator` */ #ifndef TVM_SUPPORT_RANDOM_ENGINE_H_ @@ -41,10 +42,11 @@ namespace support { * included for simplification. For full member functions of std::minstd_rand, please check out the * following link: https://en.cppreference.com/w/cpp/numeric/random/linear_congruential_engine */ + class LinearCongruentialEngine { public: /*! - * \brief The result type is defined as int64_t here for meta_schedule sampler usage. + * \brief The result type is defined as uint64_t here to avoid overflow. * \note The type name is not in Google style because it is used in STL's distribution inferface. */ using result_type = uint64_t; @@ -63,13 +65,13 @@ class LinearCongruentialEngine { * \brief The minimum possible value of random state here. * \note The function name is uncapilized because it is used in STL's distribution inferface. */ - result_type min() { return 0; } + static constexpr result_type min() { return 0; } /*! * \brief The maximum possible value of random state here. * \note The function name is uncapilized because it is used in STL's distribution inferface. */ - result_type max() { return modulus - 1; } + static constexpr result_type max() { return modulus - 1; } /*! * \brief Operator to move the random state to the next and return the new random state. 
According diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 5e223c98d74d..79fed09c3e36 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -19,6 +19,7 @@ #ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_ #define TVM_TIR_SCHEDULE_SCHEDULE_H_ +#include #include #include @@ -118,9 +119,9 @@ class ScheduleNode : public runtime::Object { * \brief Seed the randomness * \param seed The new random seed, -1 if use device random, otherwise non-negative */ - virtual void Seed(int64_t seed = -1) { - LOG(FATAL) << "ValueError: The schedule cannot be seeded because no randomness is allowed"; - } + virtual void Seed(support::LinearCongruentialEngine::TRandState seed = -1) = 0; + /*! \brief Fork the random state */ + virtual support::LinearCongruentialEngine::TRandState ForkSeed() = 0; public: /******** Lookup/Remove random variables ********/ @@ -184,6 +185,16 @@ class ScheduleNode : public runtime::Object { public: /******** Schedule: Sampling ********/ + /*! + * \brief Sample an integer given the probability distribution + * \param candidates The candidates + * \param probs The probability distribution of the candidates + * \param decision The sampling decision + * \return The random variable sampled from candidates + */ + virtual ExprRV SampleCategorical(const Array& candidates, const Array& probs, + Optional decision = NullOpt) = 0; + /******** Schedule: Get blocks & loops ********/ /*! * \brief Retrieve a block in a specific function with its name @@ -356,6 +367,7 @@ class Schedule : public runtime::ObjectRef { /*! * \brief Construct a concrete TensorIR schedule from an IRModule * \param mod The IRModule to be scheduled + * \param seed The seed value for schedule's random state * \param debug_mask Do extra correctness checking after the class creation * and each time after calling the Replace method. 
* \param error_render_level The level of error rendering @@ -365,11 +377,12 @@ class Schedule : public runtime::ObjectRef { * 1) VerifySRefTree * 2) VerifyCachedFlags */ - TVM_DLL static Schedule Concrete(IRModule mod, int debug_mask, - ScheduleErrorRenderLevel error_render_level); + TVM_DLL static Schedule Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed, + int debug_mask, ScheduleErrorRenderLevel error_render_level); /*! * \brief Construct a traced concrete TensorIR schedule from an IRModule * \param mod The IRModule to be scheduled + * \param seed The seed value for schedule's random state * \param debug_mask Do extra correctness checking after the class creation * and each time after calling the Replace method. * \param error_render_level The level of error rendering @@ -379,8 +392,8 @@ class Schedule : public runtime::ObjectRef { * 1) VerifySRefTree * 2) VerifyCachedFlags */ - TVM_DLL static Schedule Traced(IRModule mod, int debug_mask, - ScheduleErrorRenderLevel error_render_level); + TVM_DLL static Schedule Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed, + int debug_mask, ScheduleErrorRenderLevel error_render_level); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode); }; diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index c9cbf45b9055..9433d019f9a5 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -79,6 +79,16 @@ def _parse_error_render_level(error_render_level: str) -> int: return _ERROR_RENDER_LEVEL.get(error_render_level) +def _parse_seed(seed: Optional[int]) -> int: + if seed is None: + return -1 + if not isinstance(seed, int): + raise TypeError(f"Expected `seed` to be int or None, but gets: {seed}") + if seed != -1 and (seed < 1 or seed > 2147483647):  # -1 means device random, like None + raise ValueError(f"seed must be in the range [1, 2147483647], but gets: {seed}") + return seed + + @_register_object("tir.Schedule") class Schedule(Object): 
"""The user-facing schedule class @@ -98,6 +108,7 @@ def __init__( self, mod: Union[PrimFunc, IRModule], *, + seed: Optional[int] = None, debug_mask: Union[str, int] = "none", error_render_level: str = "detail", ) -> None: @@ -107,6 +118,10 @@ def __init__( ---------- mod : Union[PrimFunc, IRModule] The IRModule or PrimFunc to be scheduled + seed: Optional[int] + The seed value for schedule's random state + Note that None and -1 means use device random, otherwise only integer between 1 and + 2147483647 is allowed. debug_mask : Union[str, int] Do extra correctness checking after the class creation and each time after calling the Replace method. @@ -130,6 +145,7 @@ def __init__( self.__init_handle_by_constructor__( _ffi_api.TracedSchedule, # type: ignore # pylint: disable=no-member _parse_mod(mod), + _parse_seed(seed), _parse_debug_mask(debug_mask), _parse_error_render_level(error_render_level), ) @@ -138,12 +154,14 @@ def __init__( def _create_non_traced( mod: Union[PrimFunc, IRModule], *, + seed: Optional[int] = None, debug_mask: Union[str, int] = "none", error_render_level: str = "detail", ) -> "Schedule": """Construct a non-traced TensorIR schedule class from an IRModule.""" return _ffi_api.ConcreteSchedule( # type: ignore # pylint: disable=no-member _parse_mod(mod), + _parse_seed(seed), _parse_debug_mask(debug_mask), _parse_error_render_level(error_render_level), ) @@ -190,6 +208,16 @@ def seed(self, seed: int) -> None: """ return _ffi_api.ScheduleSeed(self, seed) # type: ignore # pylint: disable=no-member + def fork_seed(self) -> int: + """Returns a forked random state as seed for new schedules + + Returns + ------- + seed : int + The forked random state, not the same as the current random state + """ + return _ffi_api.ScheduleForkSeed(self) # type: ignore # pylint: disable=no-member + def show(self, rand_var: RAND_VAR_TYPE) -> str: """Returns a string representation of the value that the random variable evaluates to @@ -268,6 +296,35 @@ def remove_rv(self, 
rand_var: RAND_VAR_TYPE) -> None: ########## Schedule: Sampling ########## + def sample_categorical( + self, + candidates: List[int], + probs: List[float], + decision: Optional[int] = None, + ) -> ExprRV: + """Sample an integer given the probability distribution + + Parameters + ---------- + candidates : List[int] + The candidates to be sampled from + probs : List[float] + The probability of each candidate + decision : Optional[int] + The sampling decision, if any + + Returns + ------- + result : ExprRV + The random variable sampled from candidates + """ + return _ffi_api.ScheduleSampleCategorical( # type: ignore # pylint: disable=no-member + self, + candidates, + probs, + decision, + ) + ########## Schedule: Get blocks & loops ########## def get_block( self, diff --git a/src/support/array.h b/src/support/array.h index 2cf416c471ec..89e17433344b 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -18,6 +18,7 @@ */ #ifndef TVM_SUPPORT_ARRAY_H_ #define TVM_SUPPORT_ARRAY_H_ +#include #include #include @@ -67,6 +68,73 @@ inline bool ArrayWithSameContent(const std::vector& a, const std::vector return true; } +/*! 
+ * \brief Convert a tvm::runtime::Array to std::vector + * \tparam TSrc The type of elements in the source Array + * \tparam TDst The type of elements in the result vector + * \return The result vector + */ +template +std::vector AsVector(const Array& vec); + +/********** Implementation details of AsVector **********/ +namespace details { + +template +struct AsVectorImpl {}; + +template +struct AsVectorImpl { + inline std::vector operator()(const Array& vec) const { + return std::vector(vec.begin(), vec.end()); + } +}; + +template +struct AsVectorImpl { + inline std::vector operator()(const Array& vec) const { + std::vector results; + for (const TSrcObjectRef& x : vec) { + const auto* n = x.template as(); + ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); + results.push_back(n->value); + } + return results; + } +}; + +template +struct AsVectorImpl { + inline std::vector operator()(const Array& vec) const { + std::vector results; + for (const TSrcObjectRef& x : vec) { + const auto* n = x.template as(); + ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); + results.push_back(n->value); + } + return results; + } +}; + +template +struct AsVectorImpl { + inline std::vector operator()(const Array& array) const { + std::vector results; + for (const TSrcObjectRef& x : array) { + const auto* n = x.template as(); + ICHECK(n) << "TypeError: Expects FloatImm, but gets: " << x->GetTypeKey(); + results.push_back(n->value); + } + return results; + } +}; +} // namespace details + +template +inline std::vector AsVector(const Array& vec) { + return details::AsVectorImpl()(vec); +} + } // namespace support } // namespace tvm #endif // TVM_SUPPORT_ARRAY_H_ diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 084d0b0eec6a..cd9aad8ae512 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -18,16 +18,19 @@ */ #include "./concrete_schedule.h" +#include + 
namespace tvm { namespace tir { -Schedule Schedule::Concrete(IRModule mod, int debug_mask, - ScheduleErrorRenderLevel error_render_level) { +Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed, + int debug_mask, ScheduleErrorRenderLevel error_render_level) { ObjectPtr n = make_object(); n->state_ = ScheduleState(mod, debug_mask); n->error_render_level_ = error_render_level; n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); + n->Seed(seed);  // Seed() maps -1 to a device-random seed return Schedule(std::move(n)); } @@ -208,6 +211,29 @@ Schedule ConcreteScheduleNode::Copy() const { } /******** Schedule: Schedule: Sampling ********/ + +void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState seed) { + if (seed == -1) { + seed = std::random_device()(); + } + support::LinearCongruentialEngine(&rand_state_).Seed(seed); +} + +support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() { + // For reproducibility, we compute the new seed using RNG's random state and a different + // set of parameters. Note that both 32767 and 1999999973 are prime numbers. 
+ return (support::LinearCongruentialEngine(&rand_state_)() * 32767) % 1999999973; +} + +ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision) { + TVM_TIR_SCHEDULE_BEGIN(); + return CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); + TVM_TIR_SCHEDULE_END("sample-categorical", this->error_render_level_); + throw; +} + /******** Schedule: Get blocks & loops ********/ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 97819d63edb6..0bd902d183bf 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -43,6 +43,8 @@ class ConcreteScheduleNode : public ScheduleNode { TSymbolTable symbol_table_; /*! \brief A persistent stateless arithmetic analyzer. */ std::unique_ptr analyzer_; + /*! \brief The value of random state for sampling. */ + support::LinearCongruentialEngine::TRandState rand_state_; public: void VisitAttrs(tvm::AttrVisitor* v) { @@ -50,6 +52,7 @@ class ConcreteScheduleNode : public ScheduleNode { // `error_render_level_` is not visited // `symbol_table_` is not visited // `analyzer_` is not visited + // `rand_state_` is not visited } virtual ~ConcreteScheduleNode() = default; @@ -58,6 +61,8 @@ class ConcreteScheduleNode : public ScheduleNode { ScheduleState state() const final { return state_; } Optional trace() const override { return NullOpt; } Schedule Copy() const override; + void Seed(support::LinearCongruentialEngine::TRandState seed = -1) final; + support::LinearCongruentialEngine::TRandState ForkSeed() final; public: /******** Lookup random variables ********/ @@ -75,6 +80,16 @@ class ConcreteScheduleNode : public ScheduleNode { public: /******** Schedule: Sampling ********/ + /*! 
+ * \brief Sample an integer given the probability distribution + * \param candidates The candidates + * \param probs The probability distribution of the candidates + * \param decision The sampling decision, if it's given we would validate the decision, otherwise + * we would sample a decision from the distribution and set the decision accordingly. + * \return The random variable sampled from candidates + */ + ExprRV SampleCategorical(const Array& candidates, const Array& probs, + Optional decision = NullOpt) override; /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") override; Array GetLoops(const BlockRV& block_rv) override; @@ -126,17 +141,11 @@ class ConcreteScheduleNode : public ScheduleNode { template inline T CreateRV(const StmtSRef& sref); /*! - * \brief Add an expr as a random variable into the symbol table - * \param expr The expr to be added to the symbol table + * \brief Add an integer as a random variable into the symbol table + * \param value The integer to be added to the symbol table * \return The new random variable created */ - inline ExprRV CreateRV(const PrimExpr& expr); - /*! - * \brief Add expr as random variables into the symbol table - * \param exprs The expr to be added to the symbol table - * \return The new random variables created - */ - inline Array CreateRV(const Array& exprs); + inline ExprRV CreateRV(int64_t value); /*! 
\brief Remove a random variable from the symbol table */ inline void RemoveFromSymbolTable(const ObjectRef& rv); }; @@ -251,23 +260,12 @@ inline T ConcreteScheduleNode::CreateRV(const StmtSRef& sref) { return std::move(rv); } -inline ExprRV ConcreteScheduleNode::CreateRV(const PrimExpr& expr) { - ExprRV rv; - this->symbol_table_.Set(rv, expr); +inline ExprRV ConcreteScheduleNode::CreateRV(int64_t value) { + Var rv("v" + std::to_string(this->symbol_table_.size() + 1), DataType::Int(32)); + this->symbol_table_.Set(rv, Integer(static_cast(value))); return std::move(rv); } -inline Array ConcreteScheduleNode::CreateRV(const Array& exprs) { - Array result; - result.reserve(exprs.size()); - for (const PrimExpr& expr : exprs) { - ExprRV rv; - this->symbol_table_.Set(rv, expr); - result.push_back(rv); - } - return result; -} - inline void ConcreteScheduleNode::RemoveFromSymbolTable(const ObjectRef& obj) { auto it = this->symbol_table_.find(obj); if (it != this->symbol_table_.end()) { diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 2cf59f0b27c0..be33c2acca10 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -19,12 +19,26 @@ #ifndef TVM_TIR_SCHEDULE_PRIMITIVE_H_ #define TVM_TIR_SCHEDULE_PRIMITIVE_H_ +#include #include namespace tvm { namespace tir { /******** Schedule: Sampling ********/ +/*! + * \brief Sample one category from candidates according to the probability weights. + * \param rand_state The pointer to schedule's random state + * \param candidates The candidates + * \param probs The probability distribution of the candidates + * \param decision The sampling decision, if any; on return it holds the index of the + * selected candidate + * \return The candidate value that was selected + */ +TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, + const Array& candidates, const Array& probs, + Optional* decision); + /******** Schedule: Get blocks & loops ********/ /*! 
* \brief Retrieves blocks in a specific function with its name diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc new file mode 100644 index 000000000000..ac40d27c4bf3 --- /dev/null +++ b/src/tir/schedule/primitive/sampling.cc @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include + +#include "../primitive.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, + const Array& candidates, const Array& probs, + Optional* decision) { + CHECK(candidates.size() == probs.size()) + << "ValueError: number of candidates does not match number of probabilities."; + int i = -1; + int n = candidates.size(); + + if (decision->defined()) { + const auto* int_imm = decision->as(); + i = int_imm->value; + CHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n + << ", but decision is: " << i; + } else { + std::vector weights = support::AsVector(probs); + std::discrete_distribution dist(weights.begin(), weights.end()); + support::LinearCongruentialEngine rand_(rand_state); + i = dist(rand_); + ICHECK(0 <= i && i < n) << "ValueError: Unexpected decision generated, where n = " << n + << ", but decision is: " << i; + } + + *decision = Integer(i); // decision is guaranteed not to be nullptr. 
+ return candidates[i]; +} + +struct SampleCategoricalTraits : public UnpackedInstTraits { + static constexpr const char* kName = "SampleCategorical"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 0; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 1; + + static ExprRV UnpackedApplyToSchedule(Schedule sch, // + Array candidates, // + Array probs, // + Optional decision) { + return sch->SampleCategorical(candidates, probs, decision); + } + + static String UnpackedAsPython(Array outputs, // + Array candidates, // + Array probs, // + Optional decision) { + PythonAPICall py("sample_categorical"); + py.Input("candidates", candidates); + py.Input("probs", probs); + py.Decision(decision); + py.SingleOutput(outputs); + return py.Str(); + } + + friend struct UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(SampleCategoricalTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 29681fdf0926..d24cdc625912 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -50,23 +50,27 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState") // .set_body_method(&ScheduleNode::state); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace") // .set_body_method(&ScheduleNode::trace); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") // - .set_body_method(&ScheduleNode::Seed); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") // .set_body_method(&ScheduleNode::Copy); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") // + .set_body_method(&ScheduleNode::Seed); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleForkSeed") // + .set_body_method(&ScheduleNode::ForkSeed); /**************** (FFI) Constructor ****************/ TVM_REGISTER_GLOBAL("tir.schedule.BlockRV").set_body_typed([]() { return BlockRV(); }); TVM_REGISTER_GLOBAL("tir.schedule.LoopRV").set_body_typed([]() { return LoopRV(); }); 
TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") - .set_body_typed([](IRModule mod, int debug_mask, int error_render_level) -> Schedule { - return Schedule::Concrete(mod, debug_mask, + .set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed, + int debug_mask, int error_render_level) -> Schedule { + return Schedule::Concrete(mod, seed, debug_mask, static_cast(error_render_level)); }); TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule") - .set_body_typed([](IRModule mod, int debug_mask, int error_render_level) -> Schedule { - return Schedule::Traced(mod, debug_mask, + .set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed, + int debug_mask, int error_render_level) -> Schedule { + return Schedule::Traced(mod, seed, debug_mask, static_cast(error_render_level)); }); @@ -117,6 +121,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV") }); /******** (FFI) Sampling ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleCategorical") + .set_body_method(&ScheduleNode::SampleCategorical); /******** (FFI) Get blocks & loops ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock") .set_body_method(&ScheduleNode::GetBlock); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index ae6a194b9888..af4a6588f064 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -21,14 +21,15 @@ namespace tvm { namespace tir { -Schedule Schedule::Traced(IRModule mod, int debug_mask, - ScheduleErrorRenderLevel error_render_level) { +Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed, + int debug_mask, ScheduleErrorRenderLevel error_render_level) { ObjectPtr n = make_object(); n->state_ = ScheduleState(mod, debug_mask); n->error_render_level_ = error_render_level; n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); n->trace_ = Trace(); + n->Seed(seed);  // Seed() maps -1 to a device-random seed return 
Schedule(std::move(n)); } @@ -42,6 +43,19 @@ Schedule TracedScheduleNode::Copy() const { } /******** Schedule: Sampling ********/ +ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision) { + ExprRV result = + CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); + static const InstructionKind& kind = InstructionKind::Get("SampleCategorical"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{}, + /*attrs=*/{candidates, probs}, + /*outputs=*/{result}), + /*decision=*/decision); + return result; +} /******** Schedule: Get blocks & loops ********/ diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 11128ba32fad..48dadbc03b3b 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -47,6 +47,16 @@ class TracedScheduleNode : public ConcreteScheduleNode { public: /******** Schedule: Sampling ********/ + /*! + * \brief Sample an integer given the probability distribution + * \param candidates The candidates + * \param probs The probability distribution of the candidates + * \param decision The sampling decision, if it's given we would validate the decision, otherwise + * we would sample a decision from the distribution and set the decision accordingly. 
+ * \return The random variable sampled from candidates + */ + ExprRV SampleCategorical(const Array& candidates, const Array& probs, + Optional decision = NullOpt) final; /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") final; diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py new file mode 100644 index 000000000000..2bfd68663c99 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import sys +from collections import defaultdict + +import pytest +import tvm +from tvm import tir +from tvm.script import ty +from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule import Trace + + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def elementwise(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def test_sample_categorical(): + """Test sample categorical sampling function""" + n = 1000 + sch = tir.Schedule(elementwise, seed=42, debug_mask="all") + counter = defaultdict(int) + candidates = [5, 2, 7, 1] + probs = [0.15, 0.55, 0.05, 0.25] + for _ in range(n): + v = sch.get(sch.sample_categorical(candidates, probs)) + counter[v] += 1 + for i, prob in enumerate(probs): + assert (prob - 0.07) * n <= counter[candidates[i]] <= (prob + 0.07) * n + verify_trace_roundtrip(sch, mod=elementwise) + + +def test_sample_categorical_copy(): + """Check the random variable sampling results after schedule copy""" + n = 100 + sch = tir.Schedule(elementwise, seed=42, debug_mask="all") + candidates = [1, 2, 3, 4] + probs = [0.1, 0.2, 0.3, 0.4] + rv_decisions = [] + for _ in range(n): + rv = sch.sample_categorical(candidates, probs) # pylint: disable=invalid-name + rv_decisions.append((rv, sch.get(rv))) + sch_copy = sch.copy() + for rv, decision in rv_decisions: # pylint: disable=invalid-name + decision_copy = sch_copy.get(rv) + assert int(decision) == int(decision_copy) + + +def test_sample_categorical_serialize(): + """Check the random variable sampling results after schedule serialization""" + n = 100 + sch = tir.Schedule(elementwise, seed=42, debug_mask="all") + candidates = [5, 6, 7, 8] + probs = [0.23, 0.19, 0.37, 0.21] + decisions = [] + for _ in range(n): + rv = 
sch.get(sch.sample_categorical(candidates, probs)) # pylint: disable=invalid-name + decisions.append(rv) + new_sch = verify_trace_roundtrip(sch, mod=elementwise) + for i, new_inst in enumerate(new_sch.trace.insts): + assert decisions[i] == candidates[new_sch.trace.decisions[new_inst].value] + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:]))