1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_
17#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_
18
19// Support for eager execution of TensorFlow kernels.
20
21#include <memory>
22#include <unordered_map>
23
24#include "tensorflow/c/eager/abstract_op_attrs.h"
25#include "tensorflow/c/tf_attrtype.h"
26#include "tensorflow/core/common_runtime/device.h"
27#include "tensorflow/core/framework/node_def.pb.h"
28#include "tensorflow/core/framework/op_kernel.h"
29#include "tensorflow/core/framework/types.h"
30#include "tensorflow/core/lib/core/errors.h"
31#include "tensorflow/core/lib/core/status.h"
32#include "tensorflow/core/lib/gtl/inlined_vector.h"
33#include "tensorflow/core/lib/gtl/optional.h"
34#include "tensorflow/core/platform/fingerprint.h"
35#include "tensorflow/core/util/tensor_slice_reader_cache.h"
36
37namespace tensorflow {
38
39// Maps attribute name to an encoding of the type of the attribute value.
40// If the type is not a list type, the value is the same as the TF_AttrType type
41// of the value. Else, the highest order bit is on, and the rest of the bits
42// represent the TF_AttrType type of the values in the list.
43typedef std::unordered_map<string, uint32> AttrTypeMap;
44
45// Look up OpDef for `op_name`.
46Status OpDefForOp(const string& op_name, const OpDef** op_def);
47
48// Returns the AttrTypeMap for the TensorFlow operation named op_name.
49// If op_name is not registered in global op registry, AttrTypeMapForOp assumes
50// the op to be a function and returns the default attributes for a function.
51// `is_function` is set to true in this case.
52Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out,
53 bool* is_function);
54
55// Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'.
56Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name,
57 TF_AttrType* out, unsigned char* is_list);
58
59// KernelAndDevice::Init needs a NodeDef only to pass the attribute map through.
60// An AttrBuilder is a convenience class to help with that - providing a smaller
61// interface than NodeDefBuilder and avoiding expensive (unnecessary?) sanity
62// checks (like number of inputs matching the OpDef - we only care about
63// attributes here).
64//
65// TODO(ashankar): Take a closer look at checks in NodeDefBuilder and see which
66// ones make sense to replicate.
67
68// This is a helper class for creating a NodeDef. Additionally, this class
69// allows computing a cache key based on fingerprinting the attributes of this
70// NodeDef.
71//
72// Example usage:
73// AttrBuilder a;
74// a.NumInputs(2);
75// a.Set("T", TF_FLOAT);
76// tensorflow::Fprint128 cache_key = a.CacheKey("cpu:0");
77// const NodeDef& n = a.BuildNodeDef();
78//
79// Note that all calls to Set and NumInputs should happen before calling
80// BuildNodeDef. Also, calls to NumInputs or Set between multiple invocations
81// to CacheKey may cause different values to be returned by CacheKey.
82//
83// For performance reasons, the class internally delays the actual construction
84// of the NodeDef till BuildNodeDef is called, or Set is called with certain
85// uncommon types (see template specializations of Set to see which types
86// trigger a NodeDef creation).
87//
88// Setting attributes via `Set` may cause arena-allocated protocol buffer
89// messages to be destructed, which is not thread safe. This means that it is
90// currently not safe to set attributes on *different* AttrBuilder objects from
91// multiple threads. This does not apply to `CopyAttributes`.
92class AttrBuilder : public AbstractOpAttrs {
93 public:
94 AttrBuilder()
95 : AbstractOpAttrs(AbstractOpAttrs::AbstractOpAttrsKind::kEager) {}
96
97 ~AttrBuilder() override {}
98 explicit AttrBuilder(const char* op)
99 : AbstractOpAttrs(AbstractOpAttrs::AbstractOpAttrsKind::kEager) {
100 Reset(op);
101 }
102
103 void Reset(const char* op) {
104 op_name_ = op;
105 num_inputs_ = 0;
106 encoded_attrs_.clear();
107 node_def_initialized_ = false;
108 node_def_finalized_ = false;
109 cached_cache_key_ = absl::nullopt;
110 device_for_cached_cache_key_.clear();
111 }
112
113 const string& op_name() const { return op_name_; }
114
115 // Needed to work around call to ValidateNodeDef in CreateOpKernel.
116 AttrBuilder& NumInputs(int n);
117
118 template <class T>
119 AttrBuilder& Set(StringPiece attr_name, T&& value) {
120 SetAttrValue(value, &attr_tmp_);
121 AddAttrIfNotPresent(attr_name, attr_tmp_);
122 cached_cache_key_ = absl::nullopt;
123 return *this;
124 }
125
126 size_t NumAttributes() const { return encoded_attrs_.size(); }
127
128 AttrBuilder& Set(StringPiece attr_name, const AttrValue& value) {
129 AddAttrIfNotPresent(attr_name, value);
130 cached_cache_key_ = absl::nullopt;
131 return *this;
132 }
133
134 // Retrieves the attribute value.
135 // Note that Get() can involve a linear scan of all attributes with the same
136 // value type in this Node. This is not an issue, because Get is used rarely
137 // and nodes have a small number of attributes.
138 template <class T>
139 Status Get(StringPiece attr_name, T* value) const {
140 // Common attributes are stored in AttrVecs. This Get() template
141 // is specialized for them below. If we end up here, the type must be
142 // among those that we store in the node_def_.
143 if (!node_def_initialized_) {
144 return errors::NotFound("No attr named'", attr_name,
145 "' found in AttrBuilder for ", op_name_);
146 }
147 return GetNodeAttr(AttrSlice(node_def_), attr_name, value);
148 }
149
150 tensorflow::Fprint128 CacheKey(const StringPiece device);
151
152 // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far, as
153 // well as any default attr-value pairs from the associated op_def, if there
154 // is one.
155 void FillAttrValueMap(AttrValueMap* m) const;
156
157 // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far except
158 // when the value matches the default for this attr.
159 // More precisely, if the global op registry contains an OpDef for this op
160 // and if an attribute value is the same as the default (according to the
161 // OpDef), this attr-value pair is not added to `m`.
162 void FillAttrValueMapWithoutDefaults(AttrValueMap* m) const;
163 const NodeDef& BuildNodeDef();
164
165 // Transfers the attributes from `other` to this AttrBuilder. Does not
166 // overwrite existing attributes. Since it does not require deserializing and
167 // re-serializing attributes, it is much more efficient than going through an
168 // AttrValueMap.
169 void CopyAttributes(const AttrBuilder& other);
170
171 void GetNameAttrList(tensorflow::NameAttrList* name_and_attrs) const override;
172
173 bool GetInt(absl::string_view attr_name, int64_t* result) const override;
174 bool GetFloat(absl::string_view attr_name, float* result) const override;
175 bool GetBool(absl::string_view attr_name, bool* result) const override;
176 bool GetType(absl::string_view attr_name,
177 tensorflow::DataType* result) const override;
178 Status GetTypeList(
179 absl::string_view attr_name,
180 absl::InlinedVector<DataType, 4>* type_list) const override;
181
182 private:
183 tensorflow::Fprint128 BuildCacheKeyForDevice(const StringPiece device) const;
184
185 // Initialize the node_def_ object.
186 // REQUIRES: node_def_initialized_ = false
187 void InitializeNodeDef();
188
189 template <class T>
190 void SetInAttrValueMap(AttrValueMap* m, const string& attr_name,
191 T&& value) const {
192 DCHECK(!node_def_finalized_)
193 << "Calling SetInAttrValueMap after BuildNodeDef.";
194 // If attribute is set more than once, its first value prevails
195 m->insert({attr_name, value});
196 }
197
198 void AddAttrIfNotPresent(StringPiece attr_name, const AttrValue& value);
199
200 gtl::FlatMap<string, string> encoded_attrs_;
201 mutable AttrValue attr_tmp_; // For encoding
202
203 string op_name_; // Conceptually const, but can't be because of Reset(...)
204 int num_inputs_;
205 NodeDef node_def_;
206 bool node_def_initialized_;
207 bool node_def_finalized_;
208
209 absl::optional<tensorflow::Fprint128> cached_cache_key_;
210 string device_for_cached_cache_key_;
211};
212
213template <>
214Status AttrBuilder::Get(StringPiece attr_name, int* value) const;
215template <>
216Status AttrBuilder::Get(StringPiece attr_name, float* value) const;
217template <>
218Status AttrBuilder::Get(StringPiece attr_name, bool* value) const;
219template <>
220Status AttrBuilder::Get(StringPiece attr_name,
221 tensorflow::DataType* value) const;
222} // namespace tensorflow
223
224#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_
225