1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19/*!
20 * \file tvm/node/attr_registry_map.h
21 * \brief Attribute map used in registry.
22 */
23#ifndef TVM_NODE_ATTR_REGISTRY_MAP_H_
24#define TVM_NODE_ATTR_REGISTRY_MAP_H_
25
26#include <tvm/runtime/container/string.h>
27
28#include <utility>
29#include <vector>
30
31namespace tvm {
32
33/*!
34 * \brief Generic attribute map.
35 * \tparam KeyType the type of the key.
36 */
37template <typename KeyType>
38class AttrRegistryMapContainerMap {
39 public:
40 /*!
41 * \brief Check if the map has key.
42 * \param key The key to the map
43 * \return 1 if key is contained in map, 0 otherwise.
44 */
45 int count(const KeyType& key) const {
46 if (key.defined()) {
47 const uint32_t idx = key->AttrRegistryIndex();
48 return idx < data_.size() ? (data_[idx].second != 0) : 0;
49 } else {
50 return 0;
51 }
52 }
53 /*!
54 * \brief get the corresponding value element at key.
55 * \param key The key to the map
56 * \return the const reference to the content value.
57 */
58 const runtime::TVMRetValue& operator[](const KeyType& key) const {
59 ICHECK(key.defined());
60 const uint32_t idx = key->AttrRegistryIndex();
61 ICHECK(idx < data_.size() && data_[idx].second != 0)
62 << "Attribute " << attr_name_ << " has not been registered for " << key->name;
63 return data_[idx].first;
64 }
65 /*!
66 * \brief get the corresponding value element at key with default value.
67 * \param key The key to the map
68 * \param def_value The default value when the key does not exist.
69 * \return the const reference to the content value.
70 * \tparam ValueType The content value type.
71 */
72 template <typename ValueType>
73 ValueType get(const KeyType& key, ValueType def_value) const {
74 ICHECK(key.defined());
75 const uint32_t idx = key->AttrRegistryIndex();
76 if (idx < data_.size() && data_[idx].second != 0) {
77 return data_[idx].first;
78 } else {
79 return def_value;
80 }
81 }
82
83 private:
84 /*! \brief The name of the attr field */
85 String attr_name_;
86 /*! \brief The internal data. */
87 std::vector<std::pair<runtime::TVMRetValue, int>> data_;
88 /*! \brief The constructor */
89 AttrRegistryMapContainerMap() = default;
90 template <typename, typename>
91 friend class AttrRegistry;
92 friend class OpRegEntry;
93};
94
95/*!
96 * \brief Map<Key, ValueType> used to store meta-data.
97 * \tparam KeyType The type of the key
98 * \tparam ValueType The type of the value stored in map.
99 */
100template <typename KeyType, typename ValueType>
101class AttrRegistryMap {
102 public:
103 /*!
104 * \brief constructor
105 * \param map The internal map.
106 */
107 explicit AttrRegistryMap(const AttrRegistryMapContainerMap<KeyType>& map) : map_(map) {}
108 /*!
109 * \brief Check if the map has op as key.
110 * \param key The key to the map
111 * \return 1 if op is contained in map, 0 otherwise.
112 */
113 int count(const KeyType& key) const { return map_.count(key); }
114 /*!
115 * \brief get the corresponding value element at key.
116 * \param key The key to the map
117 * \return the const reference to the content value.
118 */
119 ValueType operator[](const KeyType& key) const { return map_[key]; }
120 /*!
121 * \brief get the corresponding value element at key with default value.
122 * \param key The key to the map
123 * \param def_value The default value when the key does not exist.
124 * \return the const reference to the content value.
125 */
126 ValueType get(const KeyType& key, ValueType def_value) const { return map_.get(key, def_value); }
127
128 protected:
129 /*! \brief The internal map field */
130 const AttrRegistryMapContainerMap<KeyType>& map_;
131};
132
133} // namespace tvm
134#endif // TVM_NODE_ATTR_REGISTRY_MAP_H_
135