1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | /*! |
20 | * \file tvm/node/attr_registry_map.h |
21 | * \brief Attribute map used in registry. |
22 | */ |
23 | #ifndef TVM_NODE_ATTR_REGISTRY_MAP_H_ |
24 | #define TVM_NODE_ATTR_REGISTRY_MAP_H_ |
25 | |
26 | #include <tvm/runtime/container/string.h> |
27 | |
28 | #include <utility> |
29 | #include <vector> |
30 | |
31 | namespace tvm { |
32 | |
33 | /*! |
34 | * \brief Generic attribute map. |
35 | * \tparam KeyType the type of the key. |
36 | */ |
37 | template <typename KeyType> |
38 | class AttrRegistryMapContainerMap { |
39 | public: |
40 | /*! |
41 | * \brief Check if the map has key. |
42 | * \param key The key to the map |
43 | * \return 1 if key is contained in map, 0 otherwise. |
44 | */ |
45 | int count(const KeyType& key) const { |
46 | if (key.defined()) { |
47 | const uint32_t idx = key->AttrRegistryIndex(); |
48 | return idx < data_.size() ? (data_[idx].second != 0) : 0; |
49 | } else { |
50 | return 0; |
51 | } |
52 | } |
53 | /*! |
54 | * \brief get the corresponding value element at key. |
55 | * \param key The key to the map |
56 | * \return the const reference to the content value. |
57 | */ |
58 | const runtime::TVMRetValue& operator[](const KeyType& key) const { |
59 | ICHECK(key.defined()); |
60 | const uint32_t idx = key->AttrRegistryIndex(); |
61 | ICHECK(idx < data_.size() && data_[idx].second != 0) |
62 | << "Attribute " << attr_name_ << " has not been registered for " << key->name; |
63 | return data_[idx].first; |
64 | } |
65 | /*! |
66 | * \brief get the corresponding value element at key with default value. |
67 | * \param key The key to the map |
68 | * \param def_value The default value when the key does not exist. |
69 | * \return the const reference to the content value. |
70 | * \tparam ValueType The content value type. |
71 | */ |
72 | template <typename ValueType> |
73 | ValueType get(const KeyType& key, ValueType def_value) const { |
74 | ICHECK(key.defined()); |
75 | const uint32_t idx = key->AttrRegistryIndex(); |
76 | if (idx < data_.size() && data_[idx].second != 0) { |
77 | return data_[idx].first; |
78 | } else { |
79 | return def_value; |
80 | } |
81 | } |
82 | |
83 | private: |
84 | /*! \brief The name of the attr field */ |
85 | String attr_name_; |
86 | /*! \brief The internal data. */ |
87 | std::vector<std::pair<runtime::TVMRetValue, int>> data_; |
88 | /*! \brief The constructor */ |
89 | AttrRegistryMapContainerMap() = default; |
90 | template <typename, typename> |
91 | friend class AttrRegistry; |
92 | friend class OpRegEntry; |
93 | }; |
94 | |
95 | /*! |
96 | * \brief Map<Key, ValueType> used to store meta-data. |
97 | * \tparam KeyType The type of the key |
98 | * \tparam ValueType The type of the value stored in map. |
99 | */ |
100 | template <typename KeyType, typename ValueType> |
101 | class AttrRegistryMap { |
102 | public: |
103 | /*! |
104 | * \brief constructor |
105 | * \param map The internal map. |
106 | */ |
107 | explicit AttrRegistryMap(const AttrRegistryMapContainerMap<KeyType>& map) : map_(map) {} |
108 | /*! |
109 | * \brief Check if the map has op as key. |
110 | * \param key The key to the map |
111 | * \return 1 if op is contained in map, 0 otherwise. |
112 | */ |
113 | int count(const KeyType& key) const { return map_.count(key); } |
114 | /*! |
115 | * \brief get the corresponding value element at key. |
116 | * \param key The key to the map |
117 | * \return the const reference to the content value. |
118 | */ |
119 | ValueType operator[](const KeyType& key) const { return map_[key]; } |
120 | /*! |
121 | * \brief get the corresponding value element at key with default value. |
122 | * \param key The key to the map |
123 | * \param def_value The default value when the key does not exist. |
124 | * \return the const reference to the content value. |
125 | */ |
126 | ValueType get(const KeyType& key, ValueType def_value) const { return map_.get(key, def_value); } |
127 | |
128 | protected: |
129 | /*! \brief The internal map field */ |
130 | const AttrRegistryMapContainerMap<KeyType>& map_; |
131 | }; |
132 | |
133 | } // namespace tvm |
134 | #endif // TVM_NODE_ATTR_REGISTRY_MAP_H_ |
135 | |