diff --git a/include/dmlc/any.h b/include/dmlc/any.h index 5707a363db99..d32e44e48c4d 100644 --- a/include/dmlc/any.h +++ b/include/dmlc/any.h @@ -259,7 +259,8 @@ inline const std::type_info& any::type() const { template inline void any::check_type() const { CHECK(type_ != nullptr) - << "The any container is empty"; + << "The any container is empty" + << " requested=" << typeid(T).name(); CHECK(type_->ptype_info == &typeid(T)) << "The stored type mismatch" << " stored=" << type_->ptype_info->name() diff --git a/include/dmlc/parameter.h b/include/dmlc/parameter.h index 2fbab2a44e32..0222df60664d 100644 --- a/include/dmlc/parameter.h +++ b/include/dmlc/parameter.h @@ -57,6 +57,14 @@ class FieldEntry; // forward declare ParamManagerSingleton template struct ParamManagerSingleton; + +/*! \brief option in parameter initialization */ +enum ParamInitOption { + /*! \brief allow unknown parameters */ + kAllowUnknown, + /*! \brief need to match exact parameters */ + kAllMatch +}; } // namespace parameter /*! * \brief Information about a parameter field in string representations. @@ -108,13 +116,17 @@ struct Parameter { * and throw error if something wrong happens. * * \param kwargs map of keyword arguments, or vector of pairs + * \parma option The option on initialization. * \tparam Container container type * \throw ParamError when something go wrong. */ template - inline void Init(const Container &kwargs) { + inline void Init(const Container &kwargs, + parameter::ParamInitOption option = parameter::kAllowUnknown) { PType::__MANAGER__()->RunInit(static_cast(this), - kwargs.begin(), kwargs.end(), NULL); + kwargs.begin(), kwargs.end(), + NULL, + option == parameter::kAllowUnknown); } /*! * \brief initialize the parameter by keyword arguments. @@ -130,7 +142,8 @@ struct Parameter { InitAllowUnknown(const Container &kwargs) { std::vector > unknown; PType::__MANAGER__()->RunInit(static_cast(this), - kwargs.begin(), kwargs.end(), &unknown); + kwargs.begin(), kwargs.end(), + &unknown, true); return unknown; } /*! @@ -355,7 +368,8 @@ class ParamManager { inline void RunInit(void *head, RandomAccessIterator begin, RandomAccessIterator end, - std::vector > *unknown_args) const { + std::vector > *unknown_args, + bool allow_unknown) const { std::set selected_args; for (RandomAccessIterator it = begin; it != end; ++it) { FieldAccessEntry *e = Find(it->first); @@ -367,11 +381,13 @@ class ParamManager { if (unknown_args != NULL) { unknown_args->push_back(*it); } else { - std::ostringstream os; - os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n"; - os << "----------------\n"; - PrintDocString(os); - throw dmlc::ParamError(os.str()); + if (!allow_unknown) { + std::ostringstream os; + os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n"; + os << "----------------\n"; + PrintDocString(os); + throw dmlc::ParamError(os.str()); + } } } } diff --git a/test/parameter_test.cc b/test/parameter_test.cc index a3411adcf7f8..c49b8f5288ba 100644 --- a/test/parameter_test.cc +++ b/test/parameter_test.cc @@ -46,16 +46,9 @@ int main(int argc, char *argv[]) { printf("Parameters\n-----------\n%s", Param::__DOC__().c_str()); std::vector > unknown; unknown = param.InitAllowUnknown(kwargs); - unknown = param2.InitAllowUnknown(unknown); + param2.Init(unknown); + - if (unknown.size() != 0) { - std::ostringstream os; - os << "Cannot find argument \'" << unknown[0].first << "\', Possible Arguments:\n"; - os << "----------------\n"; - os << param.__DOC__(); - os << param2.__DOC__(); - throw dmlc::ParamError(os.str()); - } printf("-----\n"); printf("param.num_hidden=%d\n", param.num_hidden); printf("param.learning_rate=%f\n", param.learning_rate);