1/*!
2 * Copyright (c) 2015 by Contributors
3 * \file parameter.h
4 * \brief Provide lightweight util to do parameter setup and checking.
5 */
6#ifndef DMLC_PARAMETER_H_
7#define DMLC_PARAMETER_H_
8
9#include <cstddef>
10#include <cstdlib>
11#include <cmath>
12#include <sstream>
13#include <limits>
14#include <map>
15#include <set>
16#include <typeinfo>
17#include <string>
18#include <vector>
19#include <algorithm>
20#include <utility>
21#include <stdexcept>
22#include <iostream>
23#include <iomanip>
24#include <cerrno>
25#include "./base.h"
26#include "./json.h"
27#include "./logging.h"
28#include "./type_traits.h"
29#include "./optional.h"
30#include "./strtonum.h"
31
32namespace dmlc {
// NOTE: this file requires C++11 (it uses nullptr, override and range-based for).
34/*! \brief Error throwed by parameter checking */
35struct ParamError : public dmlc::Error {
36 /*!
37 * \brief constructor
38 * \param msg error message
39 */
40 explicit ParamError(const std::string &msg)
41 : dmlc::Error(msg) {}
42};
43
/*!
 * \brief Get environment variable with default.
 * \param key the name of environment variable.
 * \param default_value the default value of the environment variable
 *  (presumably returned when `key` is unset; the definition is not shown here).
 * \tparam ValueType the type of the returned value.
 * \return The value received
 */
template<typename ValueType>
inline ValueType GetEnv(const char *key,
                        ValueType default_value);
/*!
 * \brief Set environment variable.
 * \param key the name of environment variable.
 * \param value the new value for key.
 * \tparam ValueType the type of the value to store.
 * \note Returns nothing; the declaration's return type is void.
 */
template<typename ValueType>
inline void SetEnv(const char *key,
                   ValueType value);
62
/*! \brief internal namespace for parameter management */
namespace parameter {
// Forward declarations; each of these is defined later in this header.
class FieldAccessEntry;
class ParamManager;
template<typename DType>
class FieldEntry;
template<typename PType>
struct ParamManagerSingleton;

/*! \brief how keys that match no declared field are treated during initialization */
enum ParamInitOption {
  /*! \brief allow unknown parameters */
  kAllowUnknown = 0,
  /*! \brief need to match exact parameters */
  kAllMatch = 1,
  /*! \brief allow unmatched hidden field with format __*__ (such keys are skipped) */
  kAllowHidden = 2
};
}  // namespace parameter
/*!
 * \brief String-form description of a single parameter field, as reported
 *  by a field entry's GetFieldInfo().
 */
struct ParamFieldInfo {
  std::string name;           //!< name (key) of the field
  std::string type;           //!< type of the field in string form
  std::string type_info_str;  //!< type (or enum set) plus ", required" or default-value details
  std::string description;    //!< free-text description supplied via describe()
};
102
103/*!
104 * \brief Parameter is the base type every parameter struct should inherit from
105 * The following code is a complete example to setup parameters.
106 * \code
107 * struct Param : public dmlc::Parameter<Param> {
108 * float learning_rate;
109 * int num_hidden;
110 * std::string name;
111 * // declare parameters in header file
112 * DMLC_DECLARE_PARAMETER(Param) {
113 * DMLC_DECLARE_FIELD(num_hidden).set_range(0, 1000);
114 * DMLC_DECLARE_FIELD(learning_rate).set_default(0.01f);
115 * DMLC_DECLARE_FIELD(name).set_default("hello");
116 * }
117 * };
118 * // register it in cc file
119 * DMLC_REGISTER_PARAMETER(Param);
120 * \endcode
121 *
122 * After that, the Param struct will get all the functions defined in Parameter.
123 * \tparam PType the type of parameter struct
124 *
125 * \sa DMLC_DECLARE_FIELD, DMLC_REGISTER_PARAMETER, DMLC_DECLARE_PARAMETER
126 */
127template<typename PType>
128struct Parameter {
129 public:
130 /*!
131 * \brief initialize the parameter by keyword arguments.
132 * This function will initialize the parameter struct, check consistency
133 * and throw error if something wrong happens.
134 *
135 * \param kwargs map of keyword arguments, or vector of pairs
136 * \parma option The option on initialization.
137 * \tparam Container container type
138 * \throw ParamError when something go wrong.
139 */
140 template<typename Container>
141 inline void Init(const Container &kwargs,
142 parameter::ParamInitOption option = parameter::kAllowHidden) {
143 PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
144 kwargs.begin(), kwargs.end(),
145 NULL,
146 option);
147 }
148 /*!
149 * \brief initialize the parameter by keyword arguments.
150 * This is same as Init, but allow unknown arguments.
151 *
152 * \param kwargs map of keyword arguments, or vector of pairs
153 * \tparam Container container type
154 * \throw ParamError when something go wrong.
155 * \return vector of pairs of unknown arguments.
156 */
157 template<typename Container>
158 inline std::vector<std::pair<std::string, std::string> >
159 InitAllowUnknown(const Container &kwargs) {
160 std::vector<std::pair<std::string, std::string> > unknown;
161 PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
162 kwargs.begin(), kwargs.end(),
163 &unknown, parameter::kAllowUnknown);
164 return unknown;
165 }
166
167 /*!
168 * \brief Update the parameter by keyword arguments. This is same as
169 * `InitAllowUnknown', but without setting not provided parameters to their default.
170 *
171 * \tparam Container container type
172 *
173 * \param kwargs map of keyword arguments, or vector of pairs
174 *
175 * \throw ParamError when something go wrong.
176 * \return vector of pairs of unknown arguments.
177 */
178 template <typename Container>
179 std::vector<std::pair<std::string, std::string> >
180 UpdateAllowUnknown(Container const& kwargs) {
181 std::vector<std::pair<std::string, std::string> > unknown;
182 PType::__MANAGER__()->RunUpdate(static_cast<PType *>(this), kwargs.begin(),
183 kwargs.end(), parameter::kAllowUnknown,
184 &unknown, nullptr);
185 return unknown;
186 }
187
188 /*!
189 * \brief Update the dict with values stored in parameter.
190 *
191 * \param dict The dictionary to be updated.
192 * \tparam Container container type
193 */
194 template<typename Container>
195 inline void UpdateDict(Container *dict) const {
196 PType::__MANAGER__()->UpdateDict(this->head(), dict);
197 }
198 /*!
199 * \brief Return a dictionary representation of the parameters
200 * \return A dictionary that maps key -> value
201 */
202 inline std::map<std::string, std::string> __DICT__() const {
203 std::vector<std::pair<std::string, std::string> > vec
204 = PType::__MANAGER__()->GetDict(this->head());
205 return std::map<std::string, std::string>(vec.begin(), vec.end());
206 }
207 /*!
208 * \brief Write the parameters in JSON format.
209 * \param writer JSONWriter used for writing.
210 */
211 inline void Save(dmlc::JSONWriter *writer) const {
212 writer->Write(this->__DICT__());
213 }
214 /*!
215 * \brief Load the parameters from JSON.
216 * \param reader JSONReader used for loading.
217 * \throw ParamError when something go wrong.
218 */
219 inline void Load(dmlc::JSONReader *reader) {
220 std::map<std::string, std::string> kwargs;
221 reader->Read(&kwargs);
222 this->Init(kwargs);
223 }
224 /*!
225 * \brief Get the fields of the parameters.
226 * \return List of ParamFieldInfo of each field.
227 */
228 inline static std::vector<ParamFieldInfo> __FIELDS__() {
229 return PType::__MANAGER__()->GetFieldInfo();
230 }
231 /*!
232 * \brief Print docstring of the parameter
233 * \return the printed docstring
234 */
235 inline static std::string __DOC__() {
236 std::ostringstream os;
237 PType::__MANAGER__()->PrintDocString(os);
238 return os.str();
239 }
240
241 protected:
242 /*!
243 * \brief internal function to allow declare of a parameter memember
244 * \param manager the parameter manager
245 * \param key the key name of the parameter
246 * \param ref the reference to the parameter in the struct.
247 */
248 template<typename DType>
249 inline parameter::FieldEntry<DType>& DECLARE(
250 parameter::ParamManagerSingleton<PType> *manager,
251 const std::string &key, DType &ref) { // NOLINT(*)
252 parameter::FieldEntry<DType> *e =
253 new parameter::FieldEntry<DType>();
254 e->Init(key, this->head(), ref);
255 manager->manager.AddEntry(key, e);
256 return *e;
257 }
258
259 private:
260 /*! \return Get head pointer of child structure */
261 inline PType *head() const {
262 return static_cast<PType*>(const_cast<Parameter<PType>*>(this));
263 }
264};
265
266//! \cond Doxygen_Suppress
/*!
 * \brief macro used to declare parameter
 *
 * Example:
 * \code
 *   struct Param : public dmlc::Parameter<Param> {
 *     // declare parameters in header file
 *     DMLC_DECLARE_PARAMETER(Param) {
 *       // details of declarations
 *     }
 *   };
 * \endcode
 *
 * This macro goes inside the parameter struct (typically in a header). It
 * declares the static accessor __MANAGER__() and begins the member function
 * __DECLARE__, whose body is the brace block written right after the macro;
 * inside that body the name `manager` is in scope for DMLC_DECLARE_FIELD.
 * __MANAGER__() itself is defined by DMLC_REGISTER_PARAMETER, which must
 * appear in exactly one source file.
 * Refer to example code in Parameter for details
 *
 * \param PType the name of parameter struct.
 * \sa Parameter
 */
// The trailing backslash below intentionally continues onto the blank line.
#define DMLC_DECLARE_PARAMETER(PType)                                   \
  static ::dmlc::parameter::ParamManager *__MANAGER__();                \
  inline void __DECLARE__(::dmlc::parameter::ParamManagerSingleton<PType> *manager) \

/*!
 * \brief macro to declare fields.
 *  Only usable inside the body that follows DMLC_DECLARE_PARAMETER, where
 *  `manager` is in scope. Expands to a call of Parameter::DECLARE, which
 *  returns the field entry so that builders such as set_default, describe
 *  or set_range can be chained.
 * \param FieldName the name of the field (also used as its string key).
 */
#define DMLC_DECLARE_FIELD(FieldName)  this->DECLARE(manager, #FieldName, FieldName)
295
/*!
 * \brief macro to declare an alias of a field.
 *  After this, the alias key is accepted wherever the field's key is.
 *  Registration hits LOG(FATAL) if the field is not registered yet or if the
 *  alias name is already registered.
 * \param FieldName the name of the field.
 * \param AliasName the name of the alias, must be declared after the field is declared.
 */
#define DMLC_DECLARE_ALIAS(FieldName, AliasName)  manager->manager.AddAlias(#FieldName, #AliasName)
302
/*!
 * \brief Macro used to register parameter.
 *
 * This macro needs to be put in a source file so that registration only
 * happens once. It defines PType::__MANAGER__(), which lazily builds a
 * function-local static ParamManagerSingleton<PType>. It also binds a
 * file-static reference to that manager, so the manager is built during
 * static initialization. The caller supplies the trailing semicolon
 * (e.g. `DMLC_REGISTER_PARAMETER(Param);`).
 * Refer to example code in Parameter for details
 * \param PType the type of parameter struct.
 * \sa Parameter
 */
// The trailing backslash below intentionally continues onto the blank line.
#define DMLC_REGISTER_PARAMETER(PType)                                  \
  ::dmlc::parameter::ParamManager *PType::__MANAGER__() {               \
    static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \
    return &inst.manager;                                               \
  }                                                                     \
  static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager&         \
  __make__ ## PType ## ParamManager__ =                                 \
      (*PType::__MANAGER__())                                           \

320//! \endcond
321/*!
322 * \brief internal namespace for parameter management
323 * There is no need to use it directly in normal case
324 */
325namespace parameter {
326/*!
327 * \brief FieldAccessEntry interface to help manage the parameters
328 * Each entry can be used to access one parameter in the Parameter struct.
329 *
330 * This is an internal interface used that is used to manage parameters
331 */
332class FieldAccessEntry {
333 public:
334 FieldAccessEntry()
335 : has_default_(false), index_(0) {}
336 /*! \brief destructor */
337 virtual ~FieldAccessEntry() {}
338 /*!
339 * \brief set the default value.
340 * \param head the pointer to the head of the struct
341 * \throw error if no default is presented
342 */
343 virtual void SetDefault(void *head) const = 0;
344 /*!
345 * \brief set the parameter by string value
346 * \param head the pointer to the head of the struct
347 * \param value the value to be set
348 */
349 virtual void Set(void *head, const std::string &value) const = 0;
350 // check if value is OK
351 virtual void Check(void *head) const {}
352 /*!
353 * \brief get the string representation of value.
354 * \param head the pointer to the head of the struct
355 */
356 virtual std::string GetStringValue(void *head) const = 0;
357 /*!
358 * \brief Get field information
359 * \return the corresponding field information
360 */
361 virtual ParamFieldInfo GetFieldInfo() const = 0;
362
363 protected:
364 /*! \brief whether this parameter have default value */
365 bool has_default_;
366 /*! \brief positional index of parameter in struct */
367 size_t index_;
368 /*! \brief parameter key name */
369 std::string key_;
370 /*! \brief parameter type */
371 std::string type_;
372 /*! \brief description of the parameter */
373 std::string description_;
374 // internal offset of the field
375 ptrdiff_t offset_;
376 /*! \brief get pointer to parameter */
377 char* GetRawPtr(void* head) const {
378 return reinterpret_cast<char*>(head) + offset_;
379 }
380 /*!
381 * \brief print string representation of default value
382 * \parma os the stream to print the docstring to.
383 */
384 virtual void PrintDefaultValueString(std::ostream &os) const = 0; // NOLINT(*)
385 // allow ParamManager to modify self
386 friend class ParamManager;
387};
388
389/*!
390 * \brief manager class to handle parameter structure for each type
391 * An manager will be created for each parameter structure.
392 */
393class ParamManager {
394 public:
395 /*! \brief destructor */
396 ~ParamManager() {
397 for (size_t i = 0; i < entry_.size(); ++i) {
398 delete entry_[i];
399 }
400 }
401 /*!
402 * \brief find the access entry by parameter key
403 * \param key the key of the parameter.
404 * \return pointer to FieldAccessEntry, NULL if nothing is found.
405 */
406 inline FieldAccessEntry *Find(const std::string &key) const {
407 std::map<std::string, FieldAccessEntry*>::const_iterator it =
408 entry_map_.find(key);
409 if (it == entry_map_.end()) return NULL;
410 return it->second;
411 }
412 /*!
413 * \brief Set parameter by keyword arguments and default values.
414 * \param head head to the parameter field.
415 * \param begin begin iterator of original kwargs
416 * \param end end iterator of original kwargs
417 * \param unknown_args optional, used to hold unknown arguments
418 * When it is specified, unknown arguments will be stored into here, instead of raise an error
419 * \tparam RandomAccessIterator iterator type
420 * \throw ParamError when there is unknown argument and unknown_args == NULL, or required argument is missing.
421 */
422 template<typename RandomAccessIterator>
423 inline void RunInit(void *head,
424 RandomAccessIterator begin,
425 RandomAccessIterator end,
426 std::vector<std::pair<std::string, std::string> > *unknown_args,
427 parameter::ParamInitOption option) const {
428 std::set<FieldAccessEntry*> selected_args;
429 RunUpdate(head, begin, end, option, unknown_args, &selected_args);
430 for (auto const& kv : entry_map_) {
431 if (selected_args.find(kv.second) == selected_args.cend()) {
432 kv.second->SetDefault(head);
433 }
434 }
435 for (std::map<std::string, FieldAccessEntry*>::const_iterator it = entry_map_.begin();
436 it != entry_map_.end(); ++it) {
437 if (selected_args.count(it->second) == 0) {
438 it->second->SetDefault(head);
439 }
440 }
441 }
442 /*!
443 * \brief Update parameters by keyword arguments.
444 *
445 * \tparam RandomAccessIterator iterator type
446 * \param head head to the parameter field.
447 * \param begin begin iterator of original kwargs
448 * \param end end iterator of original kwargs
449 * \param unknown_args optional, used to hold unknown arguments
450 * When it is specified, unknown arguments will be stored into here, instead of raise an error
451 * \param selected_args The arguments used in update will be pushed into it, defaullt to nullptr.
452 * \throw ParamError when there is unknown argument and unknown_args == NULL, or required argument is missing.
453 */
454 template <typename RandomAccessIterator>
455 void RunUpdate(void *head,
456 RandomAccessIterator begin,
457 RandomAccessIterator end,
458 parameter::ParamInitOption option,
459 std::vector<std::pair<std::string, std::string> > *unknown_args,
460 std::set<FieldAccessEntry*>* selected_args = nullptr) const {
461 for (RandomAccessIterator it = begin; it != end; ++it) {
462 if (FieldAccessEntry *e = Find(it->first)) {
463 e->Set(head, it->second);
464 e->Check(head);
465 if (selected_args) {
466 selected_args->insert(e);
467 }
468 } else {
469 if (unknown_args != NULL) {
470 unknown_args->push_back(*it);
471 } else {
472 if (option != parameter::kAllowUnknown) {
473 if (option == parameter::kAllowHidden &&
474 it->first.length() > 4 &&
475 it->first.find("__") == 0 &&
476 it->first.rfind("__") == it->first.length()-2) {
477 continue;
478 }
479 std::ostringstream os;
480 os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n";
481 os << "----------------\n";
482 PrintDocString(os);
483 throw dmlc::ParamError(os.str());
484 }
485 }
486 }
487 }
488 }
489 /*!
490 * \brief internal function to add entry to manager,
491 * The manager will take ownership of the entry.
492 * \param key the key to the parameters
493 * \param e the pointer to the new entry.
494 */
495 inline void AddEntry(const std::string &key, FieldAccessEntry *e) {
496 e->index_ = entry_.size();
497 // TODO(bing) better error message
498 if (entry_map_.count(key) != 0) {
499 LOG(FATAL) << "key " << key << " has already been registered in " << name_;
500 }
501 entry_.push_back(e);
502 entry_map_[key] = e;
503 }
504 /*!
505 * \brief internal function to add entry to manager,
506 * The manager will take ownership of the entry.
507 * \param key the key to the parameters
508 * \param e the pointer to the new entry.
509 */
510 inline void AddAlias(const std::string& field, const std::string& alias) {
511 if (entry_map_.count(field) == 0) {
512 LOG(FATAL) << "key " << field << " has not been registered in " << name_;
513 }
514 if (entry_map_.count(alias) != 0) {
515 LOG(FATAL) << "Alias " << alias << " has already been registered in " << name_;
516 }
517 entry_map_[alias] = entry_map_[field];
518 }
519 /*!
520 * \brief set the name of parameter manager
521 * \param name the name to set
522 */
523 inline void set_name(const std::string &name) {
524 name_ = name;
525 }
526 /*!
527 * \brief get field information of each field.
528 * \return field information
529 */
530 inline std::vector<ParamFieldInfo> GetFieldInfo() const {
531 std::vector<ParamFieldInfo> ret(entry_.size());
532 for (size_t i = 0; i < entry_.size(); ++i) {
533 ret[i] = entry_[i]->GetFieldInfo();
534 }
535 return ret;
536 }
537 /*!
538 * \brief Print readible docstring to ostream, add newline.
539 * \parma os the stream to print the docstring to.
540 */
541 inline void PrintDocString(std::ostream &os) const { // NOLINT(*)
542 for (size_t i = 0; i < entry_.size(); ++i) {
543 ParamFieldInfo info = entry_[i]->GetFieldInfo();
544 os << info.name << " : " << info.type_info_str << '\n';
545 if (info.description.length() != 0) {
546 os << " " << info.description << '\n';
547 }
548 }
549 }
550 /*!
551 * \brief Get internal parameters in vector of pairs.
552 * \param head the head of the struct.
553 * \param skip_default skip the values that equals default value.
554 * \return the parameter dictionary.
555 */
556 inline std::vector<std::pair<std::string, std::string> > GetDict(void * head) const {
557 std::vector<std::pair<std::string, std::string> > ret;
558 for (std::map<std::string, FieldAccessEntry*>::const_iterator
559 it = entry_map_.begin(); it != entry_map_.end(); ++it) {
560 ret.push_back(std::make_pair(it->first, it->second->GetStringValue(head)));
561 }
562 return ret;
563 }
564 /*!
565 * \brief Update the dictionary with values in parameter.
566 * \param head the head of the struct.
567 * \tparam Container The container type
568 * \return the parameter dictionary.
569 */
570 template<typename Container>
571 inline void UpdateDict(void * head, Container* dict) const {
572 for (std::map<std::string, FieldAccessEntry*>::const_iterator
573 it = entry_map_.begin(); it != entry_map_.end(); ++it) {
574 (*dict)[it->first] = it->second->GetStringValue(head);
575 }
576 }
577
578 private:
579 /*! \brief parameter struct name */
580 std::string name_;
581 /*! \brief positional list of entries */
582 std::vector<FieldAccessEntry*> entry_;
583 /*! \brief map from key to entry */
584 std::map<std::string, FieldAccessEntry*> entry_map_;
585};
586
587//! \cond Doxygen_Suppress
588
589// The following piece of code will be template heavy and less documented
590// singleton parameter manager for certain type, used for initialization
591template<typename PType>
592struct ParamManagerSingleton {
593 ParamManager manager;
594 explicit ParamManagerSingleton(const std::string &param_name) {
595 PType param;
596 manager.set_name(param_name);
597 param.__DECLARE__(this);
598 }
599};
600
601// Base class of FieldEntry
602// implement set_default
603template<typename TEntry, typename DType>
604class FieldEntryBase : public FieldAccessEntry {
605 public:
606 // entry type
607 typedef TEntry EntryType;
608 // implement set value
609 void Set(void *head, const std::string &value) const override {
610 std::istringstream is(value);
611 is >> this->Get(head);
612 if (!is.fail()) {
613 while (!is.eof()) {
614 int ch = is.get();
615 if (ch == EOF) {
616 is.clear(); break;
617 }
618 if (!isspace(ch)) {
619 is.setstate(std::ios::failbit); break;
620 }
621 }
622 }
623
624 if (is.fail()) {
625 std::ostringstream os;
626 os << "Invalid Parameter format for " << key_
627 << " expect " << type_ << " but value=\'" << value<< '\'';
628 throw dmlc::ParamError(os.str());
629 }
630 }
631
632 std::string GetStringValue(void *head) const override {
633 std::ostringstream os;
634 PrintValue(os, this->Get(head));
635 return os.str();
636 }
637 ParamFieldInfo GetFieldInfo() const override {
638 ParamFieldInfo info;
639 std::ostringstream os;
640 info.name = key_;
641 info.type = type_;
642 os << type_;
643 if (has_default_) {
644 os << ',' << " optional, default=";
645 PrintDefaultValueString(os);
646 } else {
647 os << ", required";
648 }
649 info.type_info_str = os.str();
650 info.description = description_;
651 return info;
652 }
653 // implement set head to default value
654 void SetDefault(void *head) const override {
655 if (!has_default_) {
656 std::ostringstream os;
657 os << "Required parameter " << key_
658 << " of " << type_ << " is not presented";
659 throw dmlc::ParamError(os.str());
660 } else {
661 this->Get(head) = default_value_;
662 }
663 }
664 // return reference of self as derived type
665 inline TEntry &self() {
666 return *(static_cast<TEntry*>(this));
667 }
668 // implement set_default
669 inline TEntry &set_default(const DType &default_value) {
670 default_value_ = default_value;
671 has_default_ = true;
672 // return self to allow chaining
673 return this->self();
674 }
675 // implement describe
676 inline TEntry &describe(const std::string &description) {
677 description_ = description;
678 // return self to allow chaining
679 return this->self();
680 }
681 // initialization function
682 inline void Init(const std::string &key,
683 void *head, DType &ref) { // NOLINT(*)
684 this->key_ = key;
685 if (this->type_.length() == 0) {
686 this->type_ = dmlc::type_name<DType>();
687 }
688 this->offset_ = ((char*)&ref) - ((char*)head); // NOLINT(*)
689 }
690
691 protected:
692 // print the value
693 virtual void PrintValue(std::ostream &os, DType value) const { // NOLINT(*)
694 os << value;
695 }
696 void PrintDefaultValueString(std::ostream &os) const override { // NOLINT(*)
697 PrintValue(os, default_value_);
698 }
699 // get the internal representation of parameter
700 // for example if this entry corresponds field param.learning_rate
701 // then Get(&param) will return reference to param.learning_rate
702 inline DType &Get(void *head) const {
703 return *(DType*)this->GetRawPtr(head); // NOLINT(*)
704 }
705 // default value of field
706 DType default_value_;
707};
708
709// parameter base for numeric types that have range
710template<typename TEntry, typename DType>
711class FieldEntryNumeric
712 : public FieldEntryBase<TEntry, DType> {
713 public:
714 FieldEntryNumeric()
715 : has_begin_(false), has_end_(false) {}
716 // implement set_range
717 virtual TEntry &set_range(DType begin, DType end) {
718 begin_ = begin; end_ = end;
719 has_begin_ = true; has_end_ = true;
720 return this->self();
721 }
722 // implement set_range
723 virtual TEntry &set_lower_bound(DType begin) {
724 begin_ = begin; has_begin_ = true;
725 return this->self();
726 }
727 // consistency check for numeric ranges
728 virtual void Check(void *head) const {
729 FieldEntryBase<TEntry, DType>::Check(head);
730 DType v = this->Get(head);
731 if (has_begin_ && has_end_) {
732 if (v < begin_ || v > end_) {
733 std::ostringstream os;
734 os << "value " << v << " for Parameter " << this->key_
735 << " exceed bound [" << begin_ << ',' << end_ <<']' << '\n';
736 os << this->key_ << ": " << this->description_;
737 throw dmlc::ParamError(os.str());
738 }
739 } else if (has_begin_ && v < begin_) {
740 std::ostringstream os;
741 os << "value " << v << " for Parameter " << this->key_
742 << " should be greater equal to " << begin_ << '\n';
743 os << this->key_ << ": " << this->description_;
744 throw dmlc::ParamError(os.str());
745 } else if (has_end_ && v > end_) {
746 std::ostringstream os;
747 os << "value " << v << " for Parameter " << this->key_
748 << " should be smaller equal to " << end_ << '\n';
749 os << this->key_ << ": " << this->description_;
750 throw dmlc::ParamError(os.str());
751 }
752 }
753
754 protected:
755 // whether it have begin and end range
756 bool has_begin_, has_end_;
757 // data bound
758 DType begin_, end_;
759};
760
/*!
 * \brief FieldEntry defines parsing and checking behavior of DType.
 * This class can be specialized to implement specific behavior of more settings.
 * The generic version derives from FieldEntryNumeric (which adds range checks)
 * when dmlc::is_arithmetic<DType>::value holds, and from plain FieldEntryBase
 * otherwise.
 * \tparam DType the data type of the entry.
 */
template<typename DType>
class FieldEntry :
    public IfThenElseType<dmlc::is_arithmetic<DType>::value,
                          FieldEntryNumeric<FieldEntry<DType>, DType>,
                          FieldEntryBase<FieldEntry<DType>, DType> >::Type {
};
772
773// specialize define for int(enum)
774template<>
775class FieldEntry<int>
776 : public FieldEntryNumeric<FieldEntry<int>, int> {
777 public:
778 // construct
779 FieldEntry() : is_enum_(false) {}
780 // parent
781 typedef FieldEntryNumeric<FieldEntry<int>, int> Parent;
782 // override set
783 virtual void Set(void *head, const std::string &value) const {
784 if (is_enum_) {
785 std::map<std::string, int>::const_iterator it = enum_map_.find(value);
786 std::ostringstream os;
787 if (it == enum_map_.end()) {
788 os << "Invalid Input: \'" << value;
789 os << "\', valid values are: ";
790 PrintEnums(os);
791 throw dmlc::ParamError(os.str());
792 } else {
793 os << it->second;
794 Parent::Set(head, os.str());
795 }
796 } else {
797 Parent::Set(head, value);
798 }
799 }
800 virtual ParamFieldInfo GetFieldInfo() const {
801 if (is_enum_) {
802 ParamFieldInfo info;
803 std::ostringstream os;
804 info.name = key_;
805 info.type = type_;
806 PrintEnums(os);
807 if (has_default_) {
808 os << ',' << "optional, default=";
809 PrintDefaultValueString(os);
810 } else {
811 os << ", required";
812 }
813 info.type_info_str = os.str();
814 info.description = description_;
815 return info;
816 } else {
817 return Parent::GetFieldInfo();
818 }
819 }
820 // add enum
821 inline FieldEntry<int> &add_enum(const std::string &key, int value) {
822 if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \
823 enum_back_map_.count(value) != 0) {
824 std::ostringstream os;
825 os << "Enum " << "(" << key << ": " << value << " exisit!" << ")\n";
826 os << "Enums: ";
827 for (std::map<std::string, int>::const_iterator it = enum_map_.begin();
828 it != enum_map_.end(); ++it) {
829 os << "(" << it->first << ": " << it->second << "), ";
830 }
831 throw dmlc::ParamError(os.str());
832 }
833 enum_map_[key] = value;
834 enum_back_map_[value] = key;
835 is_enum_ = true;
836 return this->self();
837 }
838
839 protected:
840 // enum flag
841 bool is_enum_;
842 // enum map
843 std::map<std::string, int> enum_map_;
844 // enum map
845 std::map<int, std::string> enum_back_map_;
846 // override print behavior
847 virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*)
848 os << '\'';
849 PrintValue(os, default_value_);
850 os << '\'';
851 }
852 // override print default
853 virtual void PrintValue(std::ostream &os, int value) const { // NOLINT(*)
854 if (is_enum_) {
855 CHECK_NE(enum_back_map_.count(value), 0U)
856 << "Value not found in enum declared";
857 os << enum_back_map_.at(value);
858 } else {
859 os << value;
860 }
861 }
862
863
864 private:
865 inline void PrintEnums(std::ostream &os) const { // NOLINT(*)
866 os << '{';
867 for (std::map<std::string, int>::const_iterator
868 it = enum_map_.begin(); it != enum_map_.end(); ++it) {
869 if (it != enum_map_.begin()) {
870 os << ", ";
871 }
872 os << "\'" << it->first << '\'';
873 }
874 os << '}';
875 }
876};
877
878
879// specialize define for optional<int>(enum)
880template<>
881class FieldEntry<optional<int> >
882 : public FieldEntryBase<FieldEntry<optional<int> >, optional<int> > {
883 public:
884 // construct
885 FieldEntry() : is_enum_(false) {}
886 // parent
887 typedef FieldEntryBase<FieldEntry<optional<int> >, optional<int> > Parent;
888 // override set
889 virtual void Set(void *head, const std::string &value) const {
890 if (is_enum_ && value != "None") {
891 std::map<std::string, int>::const_iterator it = enum_map_.find(value);
892 std::ostringstream os;
893 if (it == enum_map_.end()) {
894 os << "Invalid Input: \'" << value;
895 os << "\', valid values are: ";
896 PrintEnums(os);
897 throw dmlc::ParamError(os.str());
898 } else {
899 os << it->second;
900 Parent::Set(head, os.str());
901 }
902 } else {
903 Parent::Set(head, value);
904 }
905 }
906 virtual ParamFieldInfo GetFieldInfo() const {
907 if (is_enum_) {
908 ParamFieldInfo info;
909 std::ostringstream os;
910 info.name = key_;
911 info.type = type_;
912 PrintEnums(os);
913 if (has_default_) {
914 os << ',' << "optional, default=";
915 PrintDefaultValueString(os);
916 } else {
917 os << ", required";
918 }
919 info.type_info_str = os.str();
920 info.description = description_;
921 return info;
922 } else {
923 return Parent::GetFieldInfo();
924 }
925 }
926 // add enum
927 inline FieldEntry<optional<int> > &add_enum(const std::string &key, int value) {
928 CHECK_NE(key, "None") << "None is reserved for empty optional<int>";
929 if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \
930 enum_back_map_.count(value) != 0) {
931 std::ostringstream os;
932 os << "Enum " << "(" << key << ": " << value << " exisit!" << ")\n";
933 os << "Enums: ";
934 for (std::map<std::string, int>::const_iterator it = enum_map_.begin();
935 it != enum_map_.end(); ++it) {
936 os << "(" << it->first << ": " << it->second << "), ";
937 }
938 throw dmlc::ParamError(os.str());
939 }
940 enum_map_[key] = value;
941 enum_back_map_[value] = key;
942 is_enum_ = true;
943 return this->self();
944 }
945
946 protected:
947 // enum flag
948 bool is_enum_;
949 // enum map
950 std::map<std::string, int> enum_map_;
951 // enum map
952 std::map<int, std::string> enum_back_map_;
953 // override print behavior
954 virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*)
955 os << '\'';
956 PrintValue(os, default_value_);
957 os << '\'';
958 }
959 // override print default
960 virtual void PrintValue(std::ostream &os, optional<int> value) const { // NOLINT(*)
961 if (is_enum_) {
962 if (!value) {
963 os << "None";
964 } else {
965 CHECK_NE(enum_back_map_.count(value.value()), 0U)
966 << "Value not found in enum declared";
967 os << enum_back_map_.at(value.value());
968 }
969 } else {
970 os << value;
971 }
972 }
973
974
975 private:
976 inline void PrintEnums(std::ostream &os) const { // NOLINT(*)
977 os << "{None";
978 for (std::map<std::string, int>::const_iterator
979 it = enum_map_.begin(); it != enum_map_.end(); ++it) {
980 os << ", ";
981 os << "\'" << it->first << '\'';
982 }
983 os << '}';
984 }
985};
986
987// specialize define for string
988template<>
989class FieldEntry<std::string>
990 : public FieldEntryBase<FieldEntry<std::string>, std::string> {
991 public:
992 // parent class
993 typedef FieldEntryBase<FieldEntry<std::string>, std::string> Parent;
994 // override set
995 virtual void Set(void *head, const std::string &value) const {
996 this->Get(head) = value;
997 }
998 // override print default
999 virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*)
1000 os << '\'' << default_value_ << '\'';
1001 }
1002};
1003
1004// specialize define for bool
1005template<>
1006class FieldEntry<bool>
1007 : public FieldEntryBase<FieldEntry<bool>, bool> {
1008 public:
1009 // parent class
1010 typedef FieldEntryBase<FieldEntry<bool>, bool> Parent;
1011 // override set
1012 virtual void Set(void *head, const std::string &value) const {
1013 std::string lower_case; lower_case.resize(value.length());
1014 std::transform(value.begin(), value.end(), lower_case.begin(), ::tolower);
1015 bool &ref = this->Get(head);
1016 if (lower_case == "true") {
1017 ref = true;
1018 } else if (lower_case == "false") {
1019 ref = false;
1020 } else if (lower_case == "1") {
1021 ref = true;
1022 } else if (lower_case == "0") {
1023 ref = false;
1024 } else {
1025 std::ostringstream os;
1026 os << "Invalid Parameter format for " << key_
1027 << " expect " << type_ << " but value=\'" << value<< '\'';
1028 throw dmlc::ParamError(os.str());
1029 }
1030 }
1031
1032 protected:
1033 // print default string
1034 virtual void PrintValue(std::ostream &os, bool value) const { // NOLINT(*)
1035 os << static_cast<int>(value);
1036 }
1037};
1038
1039
// Specialization for float. Uses dmlc::stof for platform-independent handling
// of INF, -INF, NAN, etc.
1042#if DMLC_USE_CXX11
1043template <>
1044class FieldEntry<float> : public FieldEntryNumeric<FieldEntry<float>, float> {
1045 public:
1046 // parent
1047 typedef FieldEntryNumeric<FieldEntry<float>, float> Parent;
1048 // override set
1049 virtual void Set(void *head, const std::string &value) const {
1050 size_t pos = 0; // number of characters processed by dmlc::stof()
1051 try {
1052 this->Get(head) = dmlc::stof(value, &pos);
1053 } catch (const std::invalid_argument &) {
1054 std::ostringstream os;
1055 os << "Invalid Parameter format for " << key_ << " expect " << type_
1056 << " but value=\'" << value << '\'';
1057 throw dmlc::ParamError(os.str());
1058 } catch (const std::out_of_range&) {
1059 std::ostringstream os;
1060 os << "Out of range value for " << key_ << ", value=\'" << value << '\'';
1061 throw dmlc::ParamError(os.str());
1062 }
1063 CHECK_LE(pos, value.length()); // just in case
1064 if (pos < value.length()) {
1065 std::ostringstream os;
1066 os << "Some trailing characters could not be parsed: \'"
1067 << value.substr(pos) << "\'";
1068 throw dmlc::ParamError(os.str());
1069 }
1070 }
1071
1072 protected:
1073 // print the value
1074 virtual void PrintValue(std::ostream &os, float value) const { // NOLINT(*)
1075 os << std::setprecision(std::numeric_limits<float>::max_digits10) << value;
1076 }
1077};
1078
// Specialization for double. Uses dmlc::stod for platform-independent handling
// of INF, -INF, NAN, etc.
1081template <>
1082class FieldEntry<double>
1083 : public FieldEntryNumeric<FieldEntry<double>, double> {
1084 public:
1085 // parent
1086 typedef FieldEntryNumeric<FieldEntry<double>, double> Parent;
1087 // override set
1088 virtual void Set(void *head, const std::string &value) const {
1089 size_t pos = 0; // number of characters processed by dmlc::stod()
1090 try {
1091 this->Get(head) = dmlc::stod(value, &pos);
1092 } catch (const std::invalid_argument &) {
1093 std::ostringstream os;
1094 os << "Invalid Parameter format for " << key_ << " expect " << type_
1095 << " but value=\'" << value << '\'';
1096 throw dmlc::ParamError(os.str());
1097 } catch (const std::out_of_range&) {
1098 std::ostringstream os;
1099 os << "Out of range value for " << key_ << ", value=\'" << value << '\'';
1100 throw dmlc::ParamError(os.str());
1101 }
1102 CHECK_LE(pos, value.length()); // just in case
1103 if (pos < value.length()) {
1104 std::ostringstream os;
1105 os << "Some trailing characters could not be parsed: \'"
1106 << value.substr(pos) << "\'";
1107 throw dmlc::ParamError(os.str());
1108 }
1109 }
1110
1111 protected:
1112 // print the value
1113 virtual void PrintValue(std::ostream &os, double value) const { // NOLINT(*)
1114 os << std::setprecision(std::numeric_limits<double>::max_digits10) << value;
1115 }
1116};
1117#endif // DMLC_USE_CXX11
1118
1119} // namespace parameter
1120//! \endcond
1121
1122// implement GetEnv
1123template<typename ValueType>
1124inline ValueType GetEnv(const char *key,
1125 ValueType default_value) {
1126 const char *val = getenv(key);
1127 // On some implementations, if the var is set to a blank string (i.e. "FOO="), then
1128 // a blank string will be returned instead of NULL. In order to be consistent, if
1129 // the environment var is a blank string, then also behave as if a null was returned.
1130 if (val == nullptr || !*val) {
1131 return default_value;
1132 }
1133 ValueType ret;
1134 parameter::FieldEntry<ValueType> e;
1135 e.Init(key, &ret, ret);
1136 e.Set(&ret, val);
1137 return ret;
1138}
1139
1140// implement SetEnv
1141template<typename ValueType>
1142inline void SetEnv(const char *key,
1143 ValueType value) {
1144 parameter::FieldEntry<ValueType> e;
1145 e.Init(key, &value, value);
1146#ifdef _WIN32
1147 _putenv_s(key, e.GetStringValue(&value).c_str());
1148#else
1149 setenv(key, e.GetStringValue(&value).c_str(), 1);
1150#endif // _WIN32
1151}
1152} // namespace dmlc
1153#endif // DMLC_PARAMETER_H_
1154