1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_ATTR_VALUE_UTIL_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_ATTR_VALUE_UTIL_H_ |
18 | |
19 | #include <functional> |
20 | #include <string> |
21 | #include <vector> |
22 | |
23 | #include "tensorflow/core/framework/partial_tensor_shape.h" |
24 | #include "tensorflow/core/framework/tensor.h" |
25 | #include "tensorflow/core/framework/tensor_shape.h" |
26 | #include "tensorflow/core/framework/types.h" |
27 | #include "tensorflow/core/lib/core/status.h" |
28 | #include "tensorflow/core/lib/core/stringpiece.h" |
29 | #include "tensorflow/core/lib/gtl/array_slice.h" |
30 | |
31 | namespace tensorflow { |
32 | |
33 | // Forward-declare the protos so that their symbols can be removed from .so exports. |
34 | class AttrValue; |
35 | class NameAttrList; |
36 | |
37 | // Returns a human-readable rendering of attr_value that is more concise than |
38 | // a text-format proto. |
39 | std::string SummarizeAttrValue(const AttrValue& attr_value); |
40 | |
41 | // Returns an error Status if attr_value doesn't have the attr type indicated by "type". |
42 | Status AttrValueHasType(const AttrValue& attr_value, StringPiece type); |
43 | |
44 | // Parses "text", a value written in text-proto syntax, into the field of *out |
45 | // indicated by "type" (e.g. taken from the type field of an AttrDef). |
46 | // Examples: |
47 | //   * If type:"int" and text:"-14", then *out is set to "i: -14". |
48 | //   * If type:"list(string)" and text:"['foo', 'bar']", |
49 | //     then *out is set to "list { s: ['foo', 'bar'] }". |
50 | // Returns true on success and false otherwise. |
51 | bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out); |
52 | |
53 | // Sets *out based on the type of value. Overloads are grouped as: single values, then lists of values, then one taking another AttrValue; MoveAttrValue takes its vector by rvalue reference. |
54 | void SetAttrValue(const std::string& value, AttrValue* out); |
55 | void SetAttrValue(const tstring& value, AttrValue* out); |
56 | void SetAttrValue(const char* value, AttrValue* out); |
57 | void SetAttrValue(StringPiece value, AttrValue* out); |
58 | void SetAttrValue(int64_t value, AttrValue* out); |
59 | void SetAttrValue(int32_t value, AttrValue* out); |
60 | void SetAttrValue(float value, AttrValue* out); |
61 | void SetAttrValue(double value, AttrValue* out); |
62 | void SetAttrValue(bool value, AttrValue* out); |
63 | void SetAttrValue(DataType value, AttrValue* out); |
64 | void SetAttrValue(const TensorShape& value, AttrValue* out); |
65 | void SetAttrValue(const TensorShapeProto& value, AttrValue* out); |
66 | void SetAttrValue(const PartialTensorShape& value, AttrValue* out); |
67 | void SetAttrValue(const Tensor& value, AttrValue* out); |
68 | void SetAttrValue(const TensorProto& value, AttrValue* out); |
69 | void SetAttrValue(const NameAttrList& value, AttrValue* out); |
70 | |
71 | void SetAttrValue(gtl::ArraySlice<string> value, AttrValue* out); |
72 | void SetAttrValue(gtl::ArraySlice<tstring> value, AttrValue* out); |
73 | void SetAttrValue(gtl::ArraySlice<const char*> value, AttrValue* out); |
74 | void SetAttrValue(gtl::ArraySlice<StringPiece> value, AttrValue* out); |
75 | void SetAttrValue(gtl::ArraySlice<int64_t> value, AttrValue* out); |
76 | void SetAttrValue(gtl::ArraySlice<int32> value, AttrValue* out); |
77 | void SetAttrValue(gtl::ArraySlice<float> value, AttrValue* out); |
78 | void SetAttrValue(gtl::ArraySlice<double> value, AttrValue* out); |
79 | void SetAttrValue(gtl::ArraySlice<bool> value, AttrValue* out); |
80 | void SetAttrValue(const std::vector<bool>& value, AttrValue* out); |
81 | void SetAttrValue(std::initializer_list<bool> value, AttrValue* out); |
82 | void SetAttrValue(DataTypeSlice value, AttrValue* out); |
83 | void SetAttrValue(gtl::ArraySlice<TensorShape> value, AttrValue* out); |
84 | void SetAttrValue(gtl::ArraySlice<TensorShapeProto> value, AttrValue* out); |
85 | void SetAttrValue(gtl::ArraySlice<PartialTensorShape> value, AttrValue* out); |
86 | void SetAttrValue(gtl::ArraySlice<Tensor> value, AttrValue* out); |
87 | void SetAttrValue(gtl::ArraySlice<TensorProto> value, AttrValue* out); |
88 | void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out); |
89 | |
90 | void SetAttrValue(const AttrValue& value, AttrValue* out); |
91 | |
92 | void MoveAttrValue(std::vector<string>&& value, AttrValue* out); |
93 | |
94 | // Returns a hash of `a` that is consistent with AreAttrValuesEqual: if two |
95 | // AttrValues compare equal according to AreAttrValuesEqual, they will have |
96 | // the same hash value. |
97 | // As with protobuf deterministic serialization, the hash value is only |
98 | // guaranteed to be stable within a given binary, so the returned value should |
99 | // probably not be persisted. |
100 | uint64 AttrValueHash(const AttrValue& a); |
101 | |
102 | // WARNING: The paired equality check might return a false negative for large |
103 | // (> 32MB) tensors defined with different TensorProto representations. |
104 | // |
105 | // Hash half of a pair of consistent hash and equals functions that are |
106 | // guaranteed to be fast for AttrValues that may hold very large Tensors |
107 | // (larger than 32MB) defined by TensorProto; the equals half is presumably |
108 | // AreAttrValuesEqual with allow_false_negatives = true (TODO: confirm). If |
109 | // large identical Tensors use different representations (e.g. tensor content |
110 | // vs. bool_val), they get different hash codes and equals returns false. Small |
111 | // (< 32MB) tensors of differing representation use their tensor content. |
112 | uint64 FastAttrValueHash(const AttrValue& a); |
113 | // Returns true if `a` and `b` have the same value. If allow_false_negatives is |
114 | // true, compares the proto representations to avoid constructing large |
115 | // (> 32MB) tensors, at the cost of possibly reporting equal values as unequal. |
116 | bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b, |
117 | bool allow_false_negatives = false); |
118 | |
119 | // Returns true if "val" has a placeholder (see SubstitutePlaceholders). |
120 | bool HasPlaceHolder(const AttrValue& val); |
121 | |
122 | // SubstitutePlaceholders recursively replaces placeholders in "value" |
123 | // with an attr value by calling "substitute". Returns true iff all |
124 | // placeholders in "value" are replaced with a value. |
125 | // |
126 | // A SubstituteFunc is given a placeholder string. If the placeholder is |
127 | // unknown, it returns false. Otherwise, it overwrites the attr value passed |
128 | // to it and returns true. |
129 | using SubstituteFunc = std::function<bool(const string&, AttrValue*)>; |
130 | bool SubstitutePlaceholders(const SubstituteFunc& substitute, AttrValue* value); |
131 | |
132 | } // namespace tensorflow |
133 | |
134 | #endif // TENSORFLOW_CORE_FRAMEWORK_ATTR_VALUE_UTIL_H_ |
135 | |