1 | /*! |
2 | * Copyright (c) 2015 by Contributors |
3 | * \file parameter.h |
4 | * \brief Provide lightweight util to do parameter setup and checking. |
5 | */ |
6 | #ifndef DMLC_PARAMETER_H_ |
7 | #define DMLC_PARAMETER_H_ |
8 | |
9 | #include <cstddef> |
10 | #include <cstdlib> |
11 | #include <cmath> |
12 | #include <sstream> |
13 | #include <limits> |
14 | #include <map> |
15 | #include <set> |
16 | #include <typeinfo> |
17 | #include <string> |
18 | #include <vector> |
19 | #include <algorithm> |
20 | #include <utility> |
21 | #include <stdexcept> |
22 | #include <iostream> |
23 | #include <iomanip> |
24 | #include <cerrno> |
25 | #include "./base.h" |
26 | #include "./json.h" |
27 | #include "./logging.h" |
28 | #include "./type_traits.h" |
29 | #include "./optional.h" |
30 | #include "./strtonum.h" |
31 | |
32 | namespace dmlc { |
// This header requires C++11 or later (it uses nullptr, auto, range-based for and override).
34 | /*! \brief Error throwed by parameter checking */ |
35 | struct ParamError : public dmlc::Error { |
36 | /*! |
37 | * \brief constructor |
38 | * \param msg error message |
39 | */ |
40 | explicit ParamError(const std::string &msg) |
41 | : dmlc::Error(msg) {} |
42 | }; |
43 | |
/*!
 * \brief Get environment variable with default.
 * \param key the name of environment variable.
 * \param default_value the default value of environment variable.
 * \return The value received (presumably default_value when the variable
 *         is unset -- the definition is not in this section; TODO confirm).
 * \tparam ValueType the type the variable's value is returned as.
 */
template<typename ValueType>
inline ValueType GetEnv(const char *key,
                        ValueType default_value);
/*!
 * \brief Set environment variable.
 * \param key the name of environment variable.
 * \param value the new value for key.
 * \tparam ValueType the type of the value being stored.
 */
template<typename ValueType>
inline void SetEnv(const char *key,
                   ValueType value);
62 | |
/*! \brief internal namespace for parameter management */
namespace parameter {
// Forward declarations; the full definitions appear later in this header.
template<typename PType>
struct ParamManagerSingleton;
template<typename DType>
class FieldEntry;
class FieldAccessEntry;
class ParamManager;

/*!
 * \brief option in parameter initialization; decides how keyword arguments
 *        that do not name any declared field are treated.
 */
enum ParamInitOption {
  /*! \brief allow unknown parameters */
  kAllowUnknown = 0,
  /*! \brief need to match exact parameters */
  kAllMatch = 1,
  /*! \brief allow unmatched hidden field with format __*__ */
  kAllowHidden = 2
};
}  // namespace parameter
/*!
 * \brief Information about a parameter field in string representations.
 *
 * Filled in by each field entry's GetFieldInfo() and used when printing
 * parameter documentation.
 */
struct ParamFieldInfo {
  std::string name;           //!< name of the field
  std::string type;           //!< type of the field in string format
  std::string type_info_str;  //!< type plus default value / enum constraint, as shown in docs
  std::string description;    //!< detailed description of the field
};
102 | |
103 | /*! |
104 | * \brief Parameter is the base type every parameter struct should inherit from |
105 | * The following code is a complete example to setup parameters. |
106 | * \code |
107 | * struct Param : public dmlc::Parameter<Param> { |
108 | * float learning_rate; |
109 | * int num_hidden; |
110 | * std::string name; |
111 | * // declare parameters in header file |
112 | * DMLC_DECLARE_PARAMETER(Param) { |
113 | * DMLC_DECLARE_FIELD(num_hidden).set_range(0, 1000); |
114 | * DMLC_DECLARE_FIELD(learning_rate).set_default(0.01f); |
115 | * DMLC_DECLARE_FIELD(name).set_default("hello"); |
116 | * } |
117 | * }; |
118 | * // register it in cc file |
119 | * DMLC_REGISTER_PARAMETER(Param); |
120 | * \endcode |
121 | * |
122 | * After that, the Param struct will get all the functions defined in Parameter. |
123 | * \tparam PType the type of parameter struct |
124 | * |
125 | * \sa DMLC_DECLARE_FIELD, DMLC_REGISTER_PARAMETER, DMLC_DECLARE_PARAMETER |
126 | */ |
127 | template<typename PType> |
128 | struct Parameter { |
129 | public: |
130 | /*! |
131 | * \brief initialize the parameter by keyword arguments. |
132 | * This function will initialize the parameter struct, check consistency |
133 | * and throw error if something wrong happens. |
134 | * |
135 | * \param kwargs map of keyword arguments, or vector of pairs |
136 | * \parma option The option on initialization. |
137 | * \tparam Container container type |
138 | * \throw ParamError when something go wrong. |
139 | */ |
140 | template<typename Container> |
141 | inline void Init(const Container &kwargs, |
142 | parameter::ParamInitOption option = parameter::kAllowHidden) { |
143 | PType::__MANAGER__()->RunInit(static_cast<PType*>(this), |
144 | kwargs.begin(), kwargs.end(), |
145 | NULL, |
146 | option); |
147 | } |
148 | /*! |
149 | * \brief initialize the parameter by keyword arguments. |
150 | * This is same as Init, but allow unknown arguments. |
151 | * |
152 | * \param kwargs map of keyword arguments, or vector of pairs |
153 | * \tparam Container container type |
154 | * \throw ParamError when something go wrong. |
155 | * \return vector of pairs of unknown arguments. |
156 | */ |
157 | template<typename Container> |
158 | inline std::vector<std::pair<std::string, std::string> > |
159 | InitAllowUnknown(const Container &kwargs) { |
160 | std::vector<std::pair<std::string, std::string> > unknown; |
161 | PType::__MANAGER__()->RunInit(static_cast<PType*>(this), |
162 | kwargs.begin(), kwargs.end(), |
163 | &unknown, parameter::kAllowUnknown); |
164 | return unknown; |
165 | } |
166 | |
167 | /*! |
168 | * \brief Update the parameter by keyword arguments. This is same as |
169 | * `InitAllowUnknown', but without setting not provided parameters to their default. |
170 | * |
171 | * \tparam Container container type |
172 | * |
173 | * \param kwargs map of keyword arguments, or vector of pairs |
174 | * |
175 | * \throw ParamError when something go wrong. |
176 | * \return vector of pairs of unknown arguments. |
177 | */ |
178 | template <typename Container> |
179 | std::vector<std::pair<std::string, std::string> > |
180 | UpdateAllowUnknown(Container const& kwargs) { |
181 | std::vector<std::pair<std::string, std::string> > unknown; |
182 | PType::__MANAGER__()->RunUpdate(static_cast<PType *>(this), kwargs.begin(), |
183 | kwargs.end(), parameter::kAllowUnknown, |
184 | &unknown, nullptr); |
185 | return unknown; |
186 | } |
187 | |
188 | /*! |
189 | * \brief Update the dict with values stored in parameter. |
190 | * |
191 | * \param dict The dictionary to be updated. |
192 | * \tparam Container container type |
193 | */ |
194 | template<typename Container> |
195 | inline void UpdateDict(Container *dict) const { |
196 | PType::__MANAGER__()->UpdateDict(this->head(), dict); |
197 | } |
198 | /*! |
199 | * \brief Return a dictionary representation of the parameters |
200 | * \return A dictionary that maps key -> value |
201 | */ |
202 | inline std::map<std::string, std::string> __DICT__() const { |
203 | std::vector<std::pair<std::string, std::string> > vec |
204 | = PType::__MANAGER__()->GetDict(this->head()); |
205 | return std::map<std::string, std::string>(vec.begin(), vec.end()); |
206 | } |
207 | /*! |
208 | * \brief Write the parameters in JSON format. |
209 | * \param writer JSONWriter used for writing. |
210 | */ |
211 | inline void Save(dmlc::JSONWriter *writer) const { |
212 | writer->Write(this->__DICT__()); |
213 | } |
214 | /*! |
215 | * \brief Load the parameters from JSON. |
216 | * \param reader JSONReader used for loading. |
217 | * \throw ParamError when something go wrong. |
218 | */ |
219 | inline void Load(dmlc::JSONReader *reader) { |
220 | std::map<std::string, std::string> kwargs; |
221 | reader->Read(&kwargs); |
222 | this->Init(kwargs); |
223 | } |
224 | /*! |
225 | * \brief Get the fields of the parameters. |
226 | * \return List of ParamFieldInfo of each field. |
227 | */ |
228 | inline static std::vector<ParamFieldInfo> __FIELDS__() { |
229 | return PType::__MANAGER__()->GetFieldInfo(); |
230 | } |
231 | /*! |
232 | * \brief Print docstring of the parameter |
233 | * \return the printed docstring |
234 | */ |
235 | inline static std::string __DOC__() { |
236 | std::ostringstream os; |
237 | PType::__MANAGER__()->PrintDocString(os); |
238 | return os.str(); |
239 | } |
240 | |
241 | protected: |
242 | /*! |
243 | * \brief internal function to allow declare of a parameter memember |
244 | * \param manager the parameter manager |
245 | * \param key the key name of the parameter |
246 | * \param ref the reference to the parameter in the struct. |
247 | */ |
248 | template<typename DType> |
249 | inline parameter::FieldEntry<DType>& DECLARE( |
250 | parameter::ParamManagerSingleton<PType> *manager, |
251 | const std::string &key, DType &ref) { // NOLINT(*) |
252 | parameter::FieldEntry<DType> *e = |
253 | new parameter::FieldEntry<DType>(); |
254 | e->Init(key, this->head(), ref); |
255 | manager->manager.AddEntry(key, e); |
256 | return *e; |
257 | } |
258 | |
259 | private: |
260 | /*! \return Get head pointer of child structure */ |
261 | inline PType *head() const { |
262 | return static_cast<PType*>(const_cast<Parameter<PType>*>(this)); |
263 | } |
264 | }; |
265 | |
//! \cond Doxygen_Suppress
/*!
 * \brief macro used to declare parameter
 *
 * Example:
 * \code
 *   struct Param : public dmlc::Parameter<Param> {
 *     // declare parameters in header file
 *     DMLC_DECLARE_PARAMETER(Param) {
 *       // details of declarations
 *     }
 *   };
 * \endcode
 *
 * This macro must be placed inside the body of the parameter struct
 * (typically in a header). It declares the static accessor __MANAGER__(),
 * which DMLC_REGISTER_PARAMETER defines in exactly one source file, and it
 * begins the definition of the member function __DECLARE__: the braces that
 * follow the macro form that function's body. Inside that body the argument
 * is named `manager`, which DMLC_DECLARE_FIELD and DMLC_DECLARE_ALIAS rely on.
 * Refer to example code in Parameter for details
 *
 * \param PType the name of parameter struct.
 * \sa Parameter
 */
#define DMLC_DECLARE_PARAMETER(PType)                                              \
  static ::dmlc::parameter::ParamManager *__MANAGER__();                           \
  inline void __DECLARE__(::dmlc::parameter::ParamManagerSingleton<PType> *manager) \

/*!
 * \brief macro to declare fields
 * Only usable inside the body that follows DMLC_DECLARE_PARAMETER, where
 * `manager` is in scope. The field is registered under its own name
 * (stringized) and the returned FieldEntry allows chaining setters such as
 * set_default(...) or describe(...).
 * \param FieldName the name of the field.
 */
#define DMLC_DECLARE_FIELD(FieldName) this->DECLARE(manager, #FieldName, FieldName)
295 | |
/*!
 * \brief macro to declare an alias of a field
 * After this, AliasName is accepted as a key for the same entry as
 * FieldName. Only usable inside the body that follows DMLC_DECLARE_PARAMETER.
 * It is a fatal error if FieldName is not declared yet or AliasName is
 * already in use.
 * \param FieldName the name of the field.
 * \param AliasName the name of the alias, must be declared after the field is declared.
 */
#define DMLC_DECLARE_ALIAS(FieldName, AliasName) manager->manager.AddAlias(#FieldName, #AliasName)
302 | |
/*!
 * \brief Macro used to register parameter.
 *
 * This macro needs to be put in a source file so that registration only happens once.
 * It defines PType::__MANAGER__(), which lazily builds a function-local static
 * ParamManagerSingleton<PType>, and a file-static reference initialised by
 * calling __MANAGER__() so the manager is built during static initialisation.
 * The expansion ends in that initialiser, so the use site must supply the
 * terminating semicolon: `DMLC_REGISTER_PARAMETER(Param);`.
 * Refer to example code in Parameter for details
 * \param PType the type of parameter struct.
 * \sa Parameter
 */
#define DMLC_REGISTER_PARAMETER(PType)                                  \
  ::dmlc::parameter::ParamManager *PType::__MANAGER__() {               \
    static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \
    return &inst.manager;                                               \
  }                                                                     \
  static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager&         \
  __make__ ## PType ## ParamManager__ =                                 \
      (*PType::__MANAGER__())                                           \

//! \endcond
321 | /*! |
322 | * \brief internal namespace for parameter management |
323 | * There is no need to use it directly in normal case |
324 | */ |
325 | namespace parameter { |
326 | /*! |
327 | * \brief FieldAccessEntry interface to help manage the parameters |
328 | * Each entry can be used to access one parameter in the Parameter struct. |
329 | * |
330 | * This is an internal interface used that is used to manage parameters |
331 | */ |
332 | class FieldAccessEntry { |
333 | public: |
334 | FieldAccessEntry() |
335 | : has_default_(false), index_(0) {} |
336 | /*! \brief destructor */ |
337 | virtual ~FieldAccessEntry() {} |
338 | /*! |
339 | * \brief set the default value. |
340 | * \param head the pointer to the head of the struct |
341 | * \throw error if no default is presented |
342 | */ |
343 | virtual void SetDefault(void *head) const = 0; |
344 | /*! |
345 | * \brief set the parameter by string value |
346 | * \param head the pointer to the head of the struct |
347 | * \param value the value to be set |
348 | */ |
349 | virtual void Set(void *head, const std::string &value) const = 0; |
350 | // check if value is OK |
351 | virtual void Check(void *head) const {} |
352 | /*! |
353 | * \brief get the string representation of value. |
354 | * \param head the pointer to the head of the struct |
355 | */ |
356 | virtual std::string GetStringValue(void *head) const = 0; |
357 | /*! |
358 | * \brief Get field information |
359 | * \return the corresponding field information |
360 | */ |
361 | virtual ParamFieldInfo GetFieldInfo() const = 0; |
362 | |
363 | protected: |
364 | /*! \brief whether this parameter have default value */ |
365 | bool has_default_; |
366 | /*! \brief positional index of parameter in struct */ |
367 | size_t index_; |
368 | /*! \brief parameter key name */ |
369 | std::string key_; |
370 | /*! \brief parameter type */ |
371 | std::string type_; |
372 | /*! \brief description of the parameter */ |
373 | std::string description_; |
374 | // internal offset of the field |
375 | ptrdiff_t offset_; |
376 | /*! \brief get pointer to parameter */ |
377 | char* GetRawPtr(void* head) const { |
378 | return reinterpret_cast<char*>(head) + offset_; |
379 | } |
380 | /*! |
381 | * \brief print string representation of default value |
382 | * \parma os the stream to print the docstring to. |
383 | */ |
384 | virtual void PrintDefaultValueString(std::ostream &os) const = 0; // NOLINT(*) |
385 | // allow ParamManager to modify self |
386 | friend class ParamManager; |
387 | }; |
388 | |
389 | /*! |
390 | * \brief manager class to handle parameter structure for each type |
391 | * An manager will be created for each parameter structure. |
392 | */ |
393 | class ParamManager { |
394 | public: |
395 | /*! \brief destructor */ |
396 | ~ParamManager() { |
397 | for (size_t i = 0; i < entry_.size(); ++i) { |
398 | delete entry_[i]; |
399 | } |
400 | } |
401 | /*! |
402 | * \brief find the access entry by parameter key |
403 | * \param key the key of the parameter. |
404 | * \return pointer to FieldAccessEntry, NULL if nothing is found. |
405 | */ |
406 | inline FieldAccessEntry *Find(const std::string &key) const { |
407 | std::map<std::string, FieldAccessEntry*>::const_iterator it = |
408 | entry_map_.find(key); |
409 | if (it == entry_map_.end()) return NULL; |
410 | return it->second; |
411 | } |
412 | /*! |
413 | * \brief Set parameter by keyword arguments and default values. |
414 | * \param head head to the parameter field. |
415 | * \param begin begin iterator of original kwargs |
416 | * \param end end iterator of original kwargs |
417 | * \param unknown_args optional, used to hold unknown arguments |
418 | * When it is specified, unknown arguments will be stored into here, instead of raise an error |
419 | * \tparam RandomAccessIterator iterator type |
420 | * \throw ParamError when there is unknown argument and unknown_args == NULL, or required argument is missing. |
421 | */ |
422 | template<typename RandomAccessIterator> |
423 | inline void RunInit(void *head, |
424 | RandomAccessIterator begin, |
425 | RandomAccessIterator end, |
426 | std::vector<std::pair<std::string, std::string> > *unknown_args, |
427 | parameter::ParamInitOption option) const { |
428 | std::set<FieldAccessEntry*> selected_args; |
429 | RunUpdate(head, begin, end, option, unknown_args, &selected_args); |
430 | for (auto const& kv : entry_map_) { |
431 | if (selected_args.find(kv.second) == selected_args.cend()) { |
432 | kv.second->SetDefault(head); |
433 | } |
434 | } |
435 | for (std::map<std::string, FieldAccessEntry*>::const_iterator it = entry_map_.begin(); |
436 | it != entry_map_.end(); ++it) { |
437 | if (selected_args.count(it->second) == 0) { |
438 | it->second->SetDefault(head); |
439 | } |
440 | } |
441 | } |
442 | /*! |
443 | * \brief Update parameters by keyword arguments. |
444 | * |
445 | * \tparam RandomAccessIterator iterator type |
446 | * \param head head to the parameter field. |
447 | * \param begin begin iterator of original kwargs |
448 | * \param end end iterator of original kwargs |
449 | * \param unknown_args optional, used to hold unknown arguments |
450 | * When it is specified, unknown arguments will be stored into here, instead of raise an error |
451 | * \param selected_args The arguments used in update will be pushed into it, defaullt to nullptr. |
452 | * \throw ParamError when there is unknown argument and unknown_args == NULL, or required argument is missing. |
453 | */ |
454 | template <typename RandomAccessIterator> |
455 | void RunUpdate(void *head, |
456 | RandomAccessIterator begin, |
457 | RandomAccessIterator end, |
458 | parameter::ParamInitOption option, |
459 | std::vector<std::pair<std::string, std::string> > *unknown_args, |
460 | std::set<FieldAccessEntry*>* selected_args = nullptr) const { |
461 | for (RandomAccessIterator it = begin; it != end; ++it) { |
462 | if (FieldAccessEntry *e = Find(it->first)) { |
463 | e->Set(head, it->second); |
464 | e->Check(head); |
465 | if (selected_args) { |
466 | selected_args->insert(e); |
467 | } |
468 | } else { |
469 | if (unknown_args != NULL) { |
470 | unknown_args->push_back(*it); |
471 | } else { |
472 | if (option != parameter::kAllowUnknown) { |
473 | if (option == parameter::kAllowHidden && |
474 | it->first.length() > 4 && |
475 | it->first.find("__" ) == 0 && |
476 | it->first.rfind("__" ) == it->first.length()-2) { |
477 | continue; |
478 | } |
479 | std::ostringstream os; |
480 | os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n" ; |
481 | os << "----------------\n" ; |
482 | PrintDocString(os); |
483 | throw dmlc::ParamError(os.str()); |
484 | } |
485 | } |
486 | } |
487 | } |
488 | } |
489 | /*! |
490 | * \brief internal function to add entry to manager, |
491 | * The manager will take ownership of the entry. |
492 | * \param key the key to the parameters |
493 | * \param e the pointer to the new entry. |
494 | */ |
495 | inline void AddEntry(const std::string &key, FieldAccessEntry *e) { |
496 | e->index_ = entry_.size(); |
497 | // TODO(bing) better error message |
498 | if (entry_map_.count(key) != 0) { |
499 | LOG(FATAL) << "key " << key << " has already been registered in " << name_; |
500 | } |
501 | entry_.push_back(e); |
502 | entry_map_[key] = e; |
503 | } |
504 | /*! |
505 | * \brief internal function to add entry to manager, |
506 | * The manager will take ownership of the entry. |
507 | * \param key the key to the parameters |
508 | * \param e the pointer to the new entry. |
509 | */ |
510 | inline void AddAlias(const std::string& field, const std::string& alias) { |
511 | if (entry_map_.count(field) == 0) { |
512 | LOG(FATAL) << "key " << field << " has not been registered in " << name_; |
513 | } |
514 | if (entry_map_.count(alias) != 0) { |
515 | LOG(FATAL) << "Alias " << alias << " has already been registered in " << name_; |
516 | } |
517 | entry_map_[alias] = entry_map_[field]; |
518 | } |
519 | /*! |
520 | * \brief set the name of parameter manager |
521 | * \param name the name to set |
522 | */ |
523 | inline void set_name(const std::string &name) { |
524 | name_ = name; |
525 | } |
526 | /*! |
527 | * \brief get field information of each field. |
528 | * \return field information |
529 | */ |
530 | inline std::vector<ParamFieldInfo> GetFieldInfo() const { |
531 | std::vector<ParamFieldInfo> ret(entry_.size()); |
532 | for (size_t i = 0; i < entry_.size(); ++i) { |
533 | ret[i] = entry_[i]->GetFieldInfo(); |
534 | } |
535 | return ret; |
536 | } |
537 | /*! |
538 | * \brief Print readible docstring to ostream, add newline. |
539 | * \parma os the stream to print the docstring to. |
540 | */ |
541 | inline void PrintDocString(std::ostream &os) const { // NOLINT(*) |
542 | for (size_t i = 0; i < entry_.size(); ++i) { |
543 | ParamFieldInfo info = entry_[i]->GetFieldInfo(); |
544 | os << info.name << " : " << info.type_info_str << '\n'; |
545 | if (info.description.length() != 0) { |
546 | os << " " << info.description << '\n'; |
547 | } |
548 | } |
549 | } |
550 | /*! |
551 | * \brief Get internal parameters in vector of pairs. |
552 | * \param head the head of the struct. |
553 | * \param skip_default skip the values that equals default value. |
554 | * \return the parameter dictionary. |
555 | */ |
556 | inline std::vector<std::pair<std::string, std::string> > GetDict(void * head) const { |
557 | std::vector<std::pair<std::string, std::string> > ret; |
558 | for (std::map<std::string, FieldAccessEntry*>::const_iterator |
559 | it = entry_map_.begin(); it != entry_map_.end(); ++it) { |
560 | ret.push_back(std::make_pair(it->first, it->second->GetStringValue(head))); |
561 | } |
562 | return ret; |
563 | } |
564 | /*! |
565 | * \brief Update the dictionary with values in parameter. |
566 | * \param head the head of the struct. |
567 | * \tparam Container The container type |
568 | * \return the parameter dictionary. |
569 | */ |
570 | template<typename Container> |
571 | inline void UpdateDict(void * head, Container* dict) const { |
572 | for (std::map<std::string, FieldAccessEntry*>::const_iterator |
573 | it = entry_map_.begin(); it != entry_map_.end(); ++it) { |
574 | (*dict)[it->first] = it->second->GetStringValue(head); |
575 | } |
576 | } |
577 | |
578 | private: |
579 | /*! \brief parameter struct name */ |
580 | std::string name_; |
581 | /*! \brief positional list of entries */ |
582 | std::vector<FieldAccessEntry*> entry_; |
583 | /*! \brief map from key to entry */ |
584 | std::map<std::string, FieldAccessEntry*> entry_map_; |
585 | }; |
586 | |
587 | //! \cond Doxygen_Suppress |
588 | |
589 | // The following piece of code will be template heavy and less documented |
590 | // singleton parameter manager for certain type, used for initialization |
591 | template<typename PType> |
592 | struct ParamManagerSingleton { |
593 | ParamManager manager; |
594 | explicit ParamManagerSingleton(const std::string ¶m_name) { |
595 | PType param; |
596 | manager.set_name(param_name); |
597 | param.__DECLARE__(this); |
598 | } |
599 | }; |
600 | |
601 | // Base class of FieldEntry |
602 | // implement set_default |
603 | template<typename TEntry, typename DType> |
604 | class FieldEntryBase : public FieldAccessEntry { |
605 | public: |
606 | // entry type |
607 | typedef TEntry EntryType; |
608 | // implement set value |
609 | void Set(void *head, const std::string &value) const override { |
610 | std::istringstream is(value); |
611 | is >> this->Get(head); |
612 | if (!is.fail()) { |
613 | while (!is.eof()) { |
614 | int ch = is.get(); |
615 | if (ch == EOF) { |
616 | is.clear(); break; |
617 | } |
618 | if (!isspace(ch)) { |
619 | is.setstate(std::ios::failbit); break; |
620 | } |
621 | } |
622 | } |
623 | |
624 | if (is.fail()) { |
625 | std::ostringstream os; |
626 | os << "Invalid Parameter format for " << key_ |
627 | << " expect " << type_ << " but value=\'" << value<< '\''; |
628 | throw dmlc::ParamError(os.str()); |
629 | } |
630 | } |
631 | |
632 | std::string GetStringValue(void *head) const override { |
633 | std::ostringstream os; |
634 | PrintValue(os, this->Get(head)); |
635 | return os.str(); |
636 | } |
637 | ParamFieldInfo GetFieldInfo() const override { |
638 | ParamFieldInfo info; |
639 | std::ostringstream os; |
640 | info.name = key_; |
641 | info.type = type_; |
642 | os << type_; |
643 | if (has_default_) { |
644 | os << ',' << " optional, default=" ; |
645 | PrintDefaultValueString(os); |
646 | } else { |
647 | os << ", required" ; |
648 | } |
649 | info.type_info_str = os.str(); |
650 | info.description = description_; |
651 | return info; |
652 | } |
653 | // implement set head to default value |
654 | void SetDefault(void *head) const override { |
655 | if (!has_default_) { |
656 | std::ostringstream os; |
657 | os << "Required parameter " << key_ |
658 | << " of " << type_ << " is not presented" ; |
659 | throw dmlc::ParamError(os.str()); |
660 | } else { |
661 | this->Get(head) = default_value_; |
662 | } |
663 | } |
664 | // return reference of self as derived type |
665 | inline TEntry &self() { |
666 | return *(static_cast<TEntry*>(this)); |
667 | } |
668 | // implement set_default |
669 | inline TEntry &set_default(const DType &default_value) { |
670 | default_value_ = default_value; |
671 | has_default_ = true; |
672 | // return self to allow chaining |
673 | return this->self(); |
674 | } |
675 | // implement describe |
676 | inline TEntry &describe(const std::string &description) { |
677 | description_ = description; |
678 | // return self to allow chaining |
679 | return this->self(); |
680 | } |
681 | // initialization function |
682 | inline void Init(const std::string &key, |
683 | void *head, DType &ref) { // NOLINT(*) |
684 | this->key_ = key; |
685 | if (this->type_.length() == 0) { |
686 | this->type_ = dmlc::type_name<DType>(); |
687 | } |
688 | this->offset_ = ((char*)&ref) - ((char*)head); // NOLINT(*) |
689 | } |
690 | |
691 | protected: |
692 | // print the value |
693 | virtual void PrintValue(std::ostream &os, DType value) const { // NOLINT(*) |
694 | os << value; |
695 | } |
696 | void PrintDefaultValueString(std::ostream &os) const override { // NOLINT(*) |
697 | PrintValue(os, default_value_); |
698 | } |
699 | // get the internal representation of parameter |
700 | // for example if this entry corresponds field param.learning_rate |
701 | // then Get(¶m) will return reference to param.learning_rate |
702 | inline DType &Get(void *head) const { |
703 | return *(DType*)this->GetRawPtr(head); // NOLINT(*) |
704 | } |
705 | // default value of field |
706 | DType default_value_; |
707 | }; |
708 | |
709 | // parameter base for numeric types that have range |
710 | template<typename TEntry, typename DType> |
711 | class FieldEntryNumeric |
712 | : public FieldEntryBase<TEntry, DType> { |
713 | public: |
714 | FieldEntryNumeric() |
715 | : has_begin_(false), has_end_(false) {} |
716 | // implement set_range |
717 | virtual TEntry &set_range(DType begin, DType end) { |
718 | begin_ = begin; end_ = end; |
719 | has_begin_ = true; has_end_ = true; |
720 | return this->self(); |
721 | } |
722 | // implement set_range |
723 | virtual TEntry &set_lower_bound(DType begin) { |
724 | begin_ = begin; has_begin_ = true; |
725 | return this->self(); |
726 | } |
727 | // consistency check for numeric ranges |
728 | virtual void Check(void *head) const { |
729 | FieldEntryBase<TEntry, DType>::Check(head); |
730 | DType v = this->Get(head); |
731 | if (has_begin_ && has_end_) { |
732 | if (v < begin_ || v > end_) { |
733 | std::ostringstream os; |
734 | os << "value " << v << " for Parameter " << this->key_ |
735 | << " exceed bound [" << begin_ << ',' << end_ <<']' << '\n'; |
736 | os << this->key_ << ": " << this->description_; |
737 | throw dmlc::ParamError(os.str()); |
738 | } |
739 | } else if (has_begin_ && v < begin_) { |
740 | std::ostringstream os; |
741 | os << "value " << v << " for Parameter " << this->key_ |
742 | << " should be greater equal to " << begin_ << '\n'; |
743 | os << this->key_ << ": " << this->description_; |
744 | throw dmlc::ParamError(os.str()); |
745 | } else if (has_end_ && v > end_) { |
746 | std::ostringstream os; |
747 | os << "value " << v << " for Parameter " << this->key_ |
748 | << " should be smaller equal to " << end_ << '\n'; |
749 | os << this->key_ << ": " << this->description_; |
750 | throw dmlc::ParamError(os.str()); |
751 | } |
752 | } |
753 | |
754 | protected: |
755 | // whether it have begin and end range |
756 | bool has_begin_, has_end_; |
757 | // data bound |
758 | DType begin_, end_; |
759 | }; |
760 | |
/*!
 * \brief FieldEntry defines parsing and checking behavior of DType.
 * This class can be specialized to implement specific behavior of more settings.
 *
 * The generic version derives from FieldEntryNumeric (which adds range
 * checks) when dmlc::is_arithmetic<DType>::value is true, and from
 * FieldEntryBase otherwise -- assuming IfThenElseType<cond, A, B>::Type
 * selects A for a true cond, as its name suggests (defined in type_traits.h).
 * \tparam DType the data type of the entry.
 */
template<typename DType>
class FieldEntry :
    public IfThenElseType<dmlc::is_arithmetic<DType>::value,
                          FieldEntryNumeric<FieldEntry<DType>, DType>,
                          FieldEntryBase<FieldEntry<DType>, DType> >::Type {
};
772 | |
773 | // specialize define for int(enum) |
774 | template<> |
775 | class FieldEntry<int> |
776 | : public FieldEntryNumeric<FieldEntry<int>, int> { |
777 | public: |
778 | // construct |
779 | FieldEntry() : is_enum_(false) {} |
780 | // parent |
781 | typedef FieldEntryNumeric<FieldEntry<int>, int> Parent; |
782 | // override set |
783 | virtual void Set(void *head, const std::string &value) const { |
784 | if (is_enum_) { |
785 | std::map<std::string, int>::const_iterator it = enum_map_.find(value); |
786 | std::ostringstream os; |
787 | if (it == enum_map_.end()) { |
788 | os << "Invalid Input: \'" << value; |
789 | os << "\', valid values are: " ; |
790 | PrintEnums(os); |
791 | throw dmlc::ParamError(os.str()); |
792 | } else { |
793 | os << it->second; |
794 | Parent::Set(head, os.str()); |
795 | } |
796 | } else { |
797 | Parent::Set(head, value); |
798 | } |
799 | } |
800 | virtual ParamFieldInfo GetFieldInfo() const { |
801 | if (is_enum_) { |
802 | ParamFieldInfo info; |
803 | std::ostringstream os; |
804 | info.name = key_; |
805 | info.type = type_; |
806 | PrintEnums(os); |
807 | if (has_default_) { |
808 | os << ',' << "optional, default=" ; |
809 | PrintDefaultValueString(os); |
810 | } else { |
811 | os << ", required" ; |
812 | } |
813 | info.type_info_str = os.str(); |
814 | info.description = description_; |
815 | return info; |
816 | } else { |
817 | return Parent::GetFieldInfo(); |
818 | } |
819 | } |
820 | // add enum |
821 | inline FieldEntry<int> &add_enum(const std::string &key, int value) { |
822 | if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \ |
823 | enum_back_map_.count(value) != 0) { |
824 | std::ostringstream os; |
825 | os << "Enum " << "(" << key << ": " << value << " exisit!" << ")\n" ; |
826 | os << "Enums: " ; |
827 | for (std::map<std::string, int>::const_iterator it = enum_map_.begin(); |
828 | it != enum_map_.end(); ++it) { |
829 | os << "(" << it->first << ": " << it->second << "), " ; |
830 | } |
831 | throw dmlc::ParamError(os.str()); |
832 | } |
833 | enum_map_[key] = value; |
834 | enum_back_map_[value] = key; |
835 | is_enum_ = true; |
836 | return this->self(); |
837 | } |
838 | |
839 | protected: |
840 | // enum flag |
841 | bool is_enum_; |
842 | // enum map |
843 | std::map<std::string, int> enum_map_; |
844 | // enum map |
845 | std::map<int, std::string> enum_back_map_; |
846 | // override print behavior |
847 | virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*) |
848 | os << '\''; |
849 | PrintValue(os, default_value_); |
850 | os << '\''; |
851 | } |
852 | // override print default |
853 | virtual void PrintValue(std::ostream &os, int value) const { // NOLINT(*) |
854 | if (is_enum_) { |
855 | CHECK_NE(enum_back_map_.count(value), 0U) |
856 | << "Value not found in enum declared" ; |
857 | os << enum_back_map_.at(value); |
858 | } else { |
859 | os << value; |
860 | } |
861 | } |
862 | |
863 | |
864 | private: |
865 | inline void PrintEnums(std::ostream &os) const { // NOLINT(*) |
866 | os << '{'; |
867 | for (std::map<std::string, int>::const_iterator |
868 | it = enum_map_.begin(); it != enum_map_.end(); ++it) { |
869 | if (it != enum_map_.begin()) { |
870 | os << ", " ; |
871 | } |
872 | os << "\'" << it->first << '\''; |
873 | } |
874 | os << '}'; |
875 | } |
876 | }; |
877 | |
878 | |
879 | // specialize define for optional<int>(enum) |
880 | template<> |
881 | class FieldEntry<optional<int> > |
882 | : public FieldEntryBase<FieldEntry<optional<int> >, optional<int> > { |
883 | public: |
884 | // construct |
885 | FieldEntry() : is_enum_(false) {} |
886 | // parent |
887 | typedef FieldEntryBase<FieldEntry<optional<int> >, optional<int> > Parent; |
888 | // override set |
889 | virtual void Set(void *head, const std::string &value) const { |
890 | if (is_enum_ && value != "None" ) { |
891 | std::map<std::string, int>::const_iterator it = enum_map_.find(value); |
892 | std::ostringstream os; |
893 | if (it == enum_map_.end()) { |
894 | os << "Invalid Input: \'" << value; |
895 | os << "\', valid values are: " ; |
896 | PrintEnums(os); |
897 | throw dmlc::ParamError(os.str()); |
898 | } else { |
899 | os << it->second; |
900 | Parent::Set(head, os.str()); |
901 | } |
902 | } else { |
903 | Parent::Set(head, value); |
904 | } |
905 | } |
906 | virtual ParamFieldInfo GetFieldInfo() const { |
907 | if (is_enum_) { |
908 | ParamFieldInfo info; |
909 | std::ostringstream os; |
910 | info.name = key_; |
911 | info.type = type_; |
912 | PrintEnums(os); |
913 | if (has_default_) { |
914 | os << ',' << "optional, default=" ; |
915 | PrintDefaultValueString(os); |
916 | } else { |
917 | os << ", required" ; |
918 | } |
919 | info.type_info_str = os.str(); |
920 | info.description = description_; |
921 | return info; |
922 | } else { |
923 | return Parent::GetFieldInfo(); |
924 | } |
925 | } |
926 | // add enum |
927 | inline FieldEntry<optional<int> > &add_enum(const std::string &key, int value) { |
928 | CHECK_NE(key, "None" ) << "None is reserved for empty optional<int>" ; |
929 | if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \ |
930 | enum_back_map_.count(value) != 0) { |
931 | std::ostringstream os; |
932 | os << "Enum " << "(" << key << ": " << value << " exisit!" << ")\n" ; |
933 | os << "Enums: " ; |
934 | for (std::map<std::string, int>::const_iterator it = enum_map_.begin(); |
935 | it != enum_map_.end(); ++it) { |
936 | os << "(" << it->first << ": " << it->second << "), " ; |
937 | } |
938 | throw dmlc::ParamError(os.str()); |
939 | } |
940 | enum_map_[key] = value; |
941 | enum_back_map_[value] = key; |
942 | is_enum_ = true; |
943 | return this->self(); |
944 | } |
945 | |
946 | protected: |
947 | // enum flag |
948 | bool is_enum_; |
949 | // enum map |
950 | std::map<std::string, int> enum_map_; |
951 | // enum map |
952 | std::map<int, std::string> enum_back_map_; |
953 | // override print behavior |
954 | virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*) |
955 | os << '\''; |
956 | PrintValue(os, default_value_); |
957 | os << '\''; |
958 | } |
959 | // override print default |
960 | virtual void PrintValue(std::ostream &os, optional<int> value) const { // NOLINT(*) |
961 | if (is_enum_) { |
962 | if (!value) { |
963 | os << "None" ; |
964 | } else { |
965 | CHECK_NE(enum_back_map_.count(value.value()), 0U) |
966 | << "Value not found in enum declared" ; |
967 | os << enum_back_map_.at(value.value()); |
968 | } |
969 | } else { |
970 | os << value; |
971 | } |
972 | } |
973 | |
974 | |
975 | private: |
976 | inline void PrintEnums(std::ostream &os) const { // NOLINT(*) |
977 | os << "{None" ; |
978 | for (std::map<std::string, int>::const_iterator |
979 | it = enum_map_.begin(); it != enum_map_.end(); ++it) { |
980 | os << ", " ; |
981 | os << "\'" << it->first << '\''; |
982 | } |
983 | os << '}'; |
984 | } |
985 | }; |
986 | |
987 | // specialize define for string |
988 | template<> |
989 | class FieldEntry<std::string> |
990 | : public FieldEntryBase<FieldEntry<std::string>, std::string> { |
991 | public: |
992 | // parent class |
993 | typedef FieldEntryBase<FieldEntry<std::string>, std::string> Parent; |
994 | // override set |
995 | virtual void Set(void *head, const std::string &value) const { |
996 | this->Get(head) = value; |
997 | } |
998 | // override print default |
999 | virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*) |
1000 | os << '\'' << default_value_ << '\''; |
1001 | } |
1002 | }; |
1003 | |
1004 | // specialize define for bool |
1005 | template<> |
1006 | class FieldEntry<bool> |
1007 | : public FieldEntryBase<FieldEntry<bool>, bool> { |
1008 | public: |
1009 | // parent class |
1010 | typedef FieldEntryBase<FieldEntry<bool>, bool> Parent; |
1011 | // override set |
1012 | virtual void Set(void *head, const std::string &value) const { |
1013 | std::string lower_case; lower_case.resize(value.length()); |
1014 | std::transform(value.begin(), value.end(), lower_case.begin(), ::tolower); |
1015 | bool &ref = this->Get(head); |
1016 | if (lower_case == "true" ) { |
1017 | ref = true; |
1018 | } else if (lower_case == "false" ) { |
1019 | ref = false; |
1020 | } else if (lower_case == "1" ) { |
1021 | ref = true; |
1022 | } else if (lower_case == "0" ) { |
1023 | ref = false; |
1024 | } else { |
1025 | std::ostringstream os; |
1026 | os << "Invalid Parameter format for " << key_ |
1027 | << " expect " << type_ << " but value=\'" << value<< '\''; |
1028 | throw dmlc::ParamError(os.str()); |
1029 | } |
1030 | } |
1031 | |
1032 | protected: |
1033 | // print default string |
1034 | virtual void PrintValue(std::ostream &os, bool value) const { // NOLINT(*) |
1035 | os << static_cast<int>(value); |
1036 | } |
1037 | }; |
1038 | |
1039 | |
1040 | // specialize define for float. Uses stof for platform independent handling of |
1041 | // INF, -INF, NAN, etc. |
1042 | #if DMLC_USE_CXX11 |
1043 | template <> |
1044 | class FieldEntry<float> : public FieldEntryNumeric<FieldEntry<float>, float> { |
1045 | public: |
1046 | // parent |
1047 | typedef FieldEntryNumeric<FieldEntry<float>, float> Parent; |
1048 | // override set |
1049 | virtual void Set(void *head, const std::string &value) const { |
1050 | size_t pos = 0; // number of characters processed by dmlc::stof() |
1051 | try { |
1052 | this->Get(head) = dmlc::stof(value, &pos); |
1053 | } catch (const std::invalid_argument &) { |
1054 | std::ostringstream os; |
1055 | os << "Invalid Parameter format for " << key_ << " expect " << type_ |
1056 | << " but value=\'" << value << '\''; |
1057 | throw dmlc::ParamError(os.str()); |
1058 | } catch (const std::out_of_range&) { |
1059 | std::ostringstream os; |
1060 | os << "Out of range value for " << key_ << ", value=\'" << value << '\''; |
1061 | throw dmlc::ParamError(os.str()); |
1062 | } |
1063 | CHECK_LE(pos, value.length()); // just in case |
1064 | if (pos < value.length()) { |
1065 | std::ostringstream os; |
1066 | os << "Some trailing characters could not be parsed: \'" |
1067 | << value.substr(pos) << "\'" ; |
1068 | throw dmlc::ParamError(os.str()); |
1069 | } |
1070 | } |
1071 | |
1072 | protected: |
1073 | // print the value |
1074 | virtual void PrintValue(std::ostream &os, float value) const { // NOLINT(*) |
1075 | os << std::setprecision(std::numeric_limits<float>::max_digits10) << value; |
1076 | } |
1077 | }; |
1078 | |
1079 | // specialize define for double. Uses stod for platform independent handling of |
1080 | // INF, -INF, NAN, etc. |
1081 | template <> |
1082 | class FieldEntry<double> |
1083 | : public FieldEntryNumeric<FieldEntry<double>, double> { |
1084 | public: |
1085 | // parent |
1086 | typedef FieldEntryNumeric<FieldEntry<double>, double> Parent; |
1087 | // override set |
1088 | virtual void Set(void *head, const std::string &value) const { |
1089 | size_t pos = 0; // number of characters processed by dmlc::stod() |
1090 | try { |
1091 | this->Get(head) = dmlc::stod(value, &pos); |
1092 | } catch (const std::invalid_argument &) { |
1093 | std::ostringstream os; |
1094 | os << "Invalid Parameter format for " << key_ << " expect " << type_ |
1095 | << " but value=\'" << value << '\''; |
1096 | throw dmlc::ParamError(os.str()); |
1097 | } catch (const std::out_of_range&) { |
1098 | std::ostringstream os; |
1099 | os << "Out of range value for " << key_ << ", value=\'" << value << '\''; |
1100 | throw dmlc::ParamError(os.str()); |
1101 | } |
1102 | CHECK_LE(pos, value.length()); // just in case |
1103 | if (pos < value.length()) { |
1104 | std::ostringstream os; |
1105 | os << "Some trailing characters could not be parsed: \'" |
1106 | << value.substr(pos) << "\'" ; |
1107 | throw dmlc::ParamError(os.str()); |
1108 | } |
1109 | } |
1110 | |
1111 | protected: |
1112 | // print the value |
1113 | virtual void PrintValue(std::ostream &os, double value) const { // NOLINT(*) |
1114 | os << std::setprecision(std::numeric_limits<double>::max_digits10) << value; |
1115 | } |
1116 | }; |
1117 | #endif // DMLC_USE_CXX11 |
1118 | |
1119 | } // namespace parameter |
1120 | //! \endcond |
1121 | |
1122 | // implement GetEnv |
1123 | template<typename ValueType> |
1124 | inline ValueType GetEnv(const char *key, |
1125 | ValueType default_value) { |
1126 | const char *val = getenv(key); |
1127 | // On some implementations, if the var is set to a blank string (i.e. "FOO="), then |
1128 | // a blank string will be returned instead of NULL. In order to be consistent, if |
1129 | // the environment var is a blank string, then also behave as if a null was returned. |
1130 | if (val == nullptr || !*val) { |
1131 | return default_value; |
1132 | } |
1133 | ValueType ret; |
1134 | parameter::FieldEntry<ValueType> e; |
1135 | e.Init(key, &ret, ret); |
1136 | e.Set(&ret, val); |
1137 | return ret; |
1138 | } |
1139 | |
1140 | // implement SetEnv |
1141 | template<typename ValueType> |
1142 | inline void SetEnv(const char *key, |
1143 | ValueType value) { |
1144 | parameter::FieldEntry<ValueType> e; |
1145 | e.Init(key, &value, value); |
1146 | #ifdef _WIN32 |
1147 | _putenv_s(key, e.GetStringValue(&value).c_str()); |
1148 | #else |
1149 | setenv(key, e.GetStringValue(&value).c_str(), 1); |
1150 | #endif // _WIN32 |
1151 | } |
1152 | } // namespace dmlc |
1153 | #endif // DMLC_PARAMETER_H_ |
1154 | |