1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_ |
17 | #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_ |
18 | |
19 | // Support for eager execution of TensorFlow kernels. |
20 | |
21 | #include <memory> |
22 | #include <unordered_map> |
23 | |
24 | #include "tensorflow/c/eager/abstract_op_attrs.h" |
25 | #include "tensorflow/c/tf_attrtype.h" |
26 | #include "tensorflow/core/common_runtime/device.h" |
27 | #include "tensorflow/core/framework/node_def.pb.h" |
28 | #include "tensorflow/core/framework/op_kernel.h" |
29 | #include "tensorflow/core/framework/types.h" |
30 | #include "tensorflow/core/lib/core/errors.h" |
31 | #include "tensorflow/core/lib/core/status.h" |
32 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
33 | #include "tensorflow/core/lib/gtl/optional.h" |
34 | #include "tensorflow/core/platform/fingerprint.h" |
35 | #include "tensorflow/core/util/tensor_slice_reader_cache.h" |
36 | |
37 | namespace tensorflow { |
38 | |
39 | // Maps attribute name to an encoding of the type of the attribute value. |
40 | // If the type is not a list type, the value is the same as the TF_AttrType type |
41 | // of the value. Else, the highest order bit is on, and the rest of the bits |
42 | // represent the TF_AttrType type of the values in the list. |
43 | typedef std::unordered_map<string, uint32> AttrTypeMap; |
44 | |
45 | // Look up OpDef for `op_name`. |
46 | Status OpDefForOp(const string& op_name, const OpDef** op_def); |
47 | |
48 | // Returns the AttrTypeMap for the TensorFlow operation named op_name. |
49 | // If op_name is not registered in global op registry, AttrTypeMapForOp assumes |
50 | // the op to be a function and returns the default attributes for a function. |
51 | // `is_function` is set to true in this case. |
52 | Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out, |
53 | bool* is_function); |
54 | |
55 | // Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'. |
56 | Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, |
57 | TF_AttrType* out, unsigned char* is_list); |
58 | |
59 | // KernelAndDevice::Init needs a NodeDef only to pass the attribute map through. |
60 | // An AttrBuilder is a convenience class to help with that - providing a smaller |
61 | // interface than NodeDefBuilder and avoiding expensive (unnecessary?) sanity |
62 | // checks (like number of inputs matching the OpDef - we only care about |
63 | // attributes here). |
64 | // |
65 | // TODO(ashankar): Take a closer look at checks in NodeDefBuilder and see which |
66 | // ones make sense to replicate. |
67 | |
68 | // This is a helper class for creating a NodeDef. Additionally, this class |
69 | // allows computing a cache key based on fingerprinting the attributes of this |
70 | // NodeDef. |
71 | // |
72 | // Example usage: |
73 | // AttrBuilder a; |
74 | // a.NumInputs(2); |
75 | // a.Set("T", TF_FLOAT); |
76 | // tensorflow::Fprint128 cache_key = a.CacheKey("cpu:0"); |
77 | // const NodeDef& n = a.BuildNodeDef(); |
78 | // |
79 | // Note that all calls to Set and NumInputs should happen before calling |
80 | // BuildNodeDef. Also, calls to NumInputs or Set between multiple invocations |
81 | // to CacheKey may cause different values to be returned by CacheKey. |
82 | // |
83 | // For performance reasons, the class internally delays the actual construction |
84 | // of the NodeDef till BuildNodeDef is called, or Set is called with certain |
85 | // uncommon types (see template specializations of Set to see which types |
86 | // trigger a NodeDef creation). |
87 | // |
88 | // Setting attributes via `Set` may cause arena-allocated protocol buffer |
89 | // messages to be destructed, which is not thread safe. This means that it is |
90 | // currently not safe to set attributes on *different* AttrBuilder objects from |
91 | // multiple threads. This does not apply to `CopyAttributes`. |
92 | class AttrBuilder : public AbstractOpAttrs { |
93 | public: |
94 | AttrBuilder() |
95 | : AbstractOpAttrs(AbstractOpAttrs::AbstractOpAttrsKind::kEager) {} |
96 | |
97 | ~AttrBuilder() override {} |
98 | explicit AttrBuilder(const char* op) |
99 | : AbstractOpAttrs(AbstractOpAttrs::AbstractOpAttrsKind::kEager) { |
100 | Reset(op); |
101 | } |
102 | |
103 | void Reset(const char* op) { |
104 | op_name_ = op; |
105 | num_inputs_ = 0; |
106 | encoded_attrs_.clear(); |
107 | node_def_initialized_ = false; |
108 | node_def_finalized_ = false; |
109 | cached_cache_key_ = absl::nullopt; |
110 | device_for_cached_cache_key_.clear(); |
111 | } |
112 | |
113 | const string& op_name() const { return op_name_; } |
114 | |
115 | // Needed to work around call to ValidateNodeDef in CreateOpKernel. |
116 | AttrBuilder& NumInputs(int n); |
117 | |
118 | template <class T> |
119 | AttrBuilder& Set(StringPiece attr_name, T&& value) { |
120 | SetAttrValue(value, &attr_tmp_); |
121 | AddAttrIfNotPresent(attr_name, attr_tmp_); |
122 | cached_cache_key_ = absl::nullopt; |
123 | return *this; |
124 | } |
125 | |
126 | size_t NumAttributes() const { return encoded_attrs_.size(); } |
127 | |
128 | AttrBuilder& Set(StringPiece attr_name, const AttrValue& value) { |
129 | AddAttrIfNotPresent(attr_name, value); |
130 | cached_cache_key_ = absl::nullopt; |
131 | return *this; |
132 | } |
133 | |
134 | // Retrieves the attribute value. |
135 | // Note that Get() can involve a linear scan of all attributes with the same |
136 | // value type in this Node. This is not an issue, because Get is used rarely |
137 | // and nodes have a small number of attributes. |
138 | template <class T> |
139 | Status Get(StringPiece attr_name, T* value) const { |
140 | // Common attributes are stored in AttrVecs. This Get() template |
141 | // is specialized for them below. If we end up here, the type must be |
142 | // among those that we store in the node_def_. |
143 | if (!node_def_initialized_) { |
144 | return errors::NotFound("No attr named'" , attr_name, |
145 | "' found in AttrBuilder for " , op_name_); |
146 | } |
147 | return GetNodeAttr(AttrSlice(node_def_), attr_name, value); |
148 | } |
149 | |
150 | tensorflow::Fprint128 CacheKey(const StringPiece device); |
151 | |
152 | // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far, as |
153 | // well as any default attr-value pairs from the associated op_def, if there |
154 | // is one. |
155 | void FillAttrValueMap(AttrValueMap* m) const; |
156 | |
157 | // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far except |
158 | // when the value matches the default for this attr. |
159 | // More precisely, if the global op registry contains an OpDef for this op |
160 | // and if an attribute value is the same as the default (according to the |
161 | // OpDef), this attr-value pair is not added to `m`. |
162 | void FillAttrValueMapWithoutDefaults(AttrValueMap* m) const; |
163 | const NodeDef& BuildNodeDef(); |
164 | |
165 | // Transfers the attributes from `other` to this AttrBuilder. Does not |
166 | // overwrite existing attributes. Since it does not require deserializing and |
167 | // re-serializing attributes, it is much more efficient than going through an |
168 | // AttrValueMap. |
169 | void CopyAttributes(const AttrBuilder& other); |
170 | |
171 | void GetNameAttrList(tensorflow::NameAttrList* name_and_attrs) const override; |
172 | |
173 | bool GetInt(absl::string_view attr_name, int64_t* result) const override; |
174 | bool GetFloat(absl::string_view attr_name, float* result) const override; |
175 | bool GetBool(absl::string_view attr_name, bool* result) const override; |
176 | bool GetType(absl::string_view attr_name, |
177 | tensorflow::DataType* result) const override; |
178 | Status GetTypeList( |
179 | absl::string_view attr_name, |
180 | absl::InlinedVector<DataType, 4>* type_list) const override; |
181 | |
182 | private: |
183 | tensorflow::Fprint128 BuildCacheKeyForDevice(const StringPiece device) const; |
184 | |
185 | // Initialize the node_def_ object. |
186 | // REQUIRES: node_def_initialized_ = false |
187 | void InitializeNodeDef(); |
188 | |
189 | template <class T> |
190 | void SetInAttrValueMap(AttrValueMap* m, const string& attr_name, |
191 | T&& value) const { |
192 | DCHECK(!node_def_finalized_) |
193 | << "Calling SetInAttrValueMap after BuildNodeDef." ; |
194 | // If attribute is set more than once, its first value prevails |
195 | m->insert({attr_name, value}); |
196 | } |
197 | |
198 | void AddAttrIfNotPresent(StringPiece attr_name, const AttrValue& value); |
199 | |
200 | gtl::FlatMap<string, string> encoded_attrs_; |
201 | mutable AttrValue attr_tmp_; // For encoding |
202 | |
203 | string op_name_; // Conceptually const, but can't be because of Reset(...) |
204 | int num_inputs_; |
205 | NodeDef node_def_; |
206 | bool node_def_initialized_; |
207 | bool node_def_finalized_; |
208 | |
209 | absl::optional<tensorflow::Fprint128> cached_cache_key_; |
210 | string device_for_cached_cache_key_; |
211 | }; |
212 | |
213 | template <> |
214 | Status AttrBuilder::Get(StringPiece attr_name, int* value) const; |
215 | template <> |
216 | Status AttrBuilder::Get(StringPiece attr_name, float* value) const; |
217 | template <> |
218 | Status AttrBuilder::Get(StringPiece attr_name, bool* value) const; |
219 | template <> |
220 | Status AttrBuilder::Get(StringPiece attr_name, |
221 | tensorflow::DataType* value) const; |
222 | } // namespace tensorflow |
223 | |
224 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_ |
225 | |