/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16#ifndef TENSORFLOW_CORE_FRAMEWORK_ATTR_VALUE_UTIL_H_
17#define TENSORFLOW_CORE_FRAMEWORK_ATTR_VALUE_UTIL_H_
18
19#include <functional>
20#include <string>
21#include <vector>
22
23#include "tensorflow/core/framework/partial_tensor_shape.h"
24#include "tensorflow/core/framework/tensor.h"
25#include "tensorflow/core/framework/tensor_shape.h"
26#include "tensorflow/core/framework/types.h"
27#include "tensorflow/core/lib/core/status.h"
28#include "tensorflow/core/lib/core/stringpiece.h"
29#include "tensorflow/core/lib/gtl/array_slice.h"
30
31namespace tensorflow {
32
33// Forward declare protos so their symbols can be removed from .so exports
34class AttrValue;
35class NameAttrList;
36
37// A human-readable rendering of attr_value, that is more concise than a
38// text-format proto.
39std::string SummarizeAttrValue(const AttrValue& attr_value);
40
41// Generates an error if attr_value doesn't have the indicated attr type.
42Status AttrValueHasType(const AttrValue& attr_value, StringPiece type);
43
44// Converts a text proto value from "text" into the field of *out
45// indicated by "type" (e.g. from the type field of an AttrDef).
46// Examples:
47// * If type:"int" and text:"-14", then *out is set to "i: -14"
48// * If type:"list(string)" and text:"['foo', 'bar']",
49// then *out is set to "list { s: ['foo', 'bar'] }"
50// Returns true on success.
51bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out);
52
53// Sets *out based on the type of value.
54void SetAttrValue(const std::string& value, AttrValue* out);
55void SetAttrValue(const tstring& value, AttrValue* out);
56void SetAttrValue(const char* value, AttrValue* out);
57void SetAttrValue(StringPiece value, AttrValue* out);
58void SetAttrValue(int64_t value, AttrValue* out);
59void SetAttrValue(int32_t value, AttrValue* out);
60void SetAttrValue(float value, AttrValue* out);
61void SetAttrValue(double value, AttrValue* out);
62void SetAttrValue(bool value, AttrValue* out);
63void SetAttrValue(DataType value, AttrValue* out);
64void SetAttrValue(const TensorShape& value, AttrValue* out);
65void SetAttrValue(const TensorShapeProto& value, AttrValue* out);
66void SetAttrValue(const PartialTensorShape& value, AttrValue* out);
67void SetAttrValue(const Tensor& value, AttrValue* out);
68void SetAttrValue(const TensorProto& value, AttrValue* out);
69void SetAttrValue(const NameAttrList& value, AttrValue* out);
70
71void SetAttrValue(gtl::ArraySlice<string> value, AttrValue* out);
72void SetAttrValue(gtl::ArraySlice<tstring> value, AttrValue* out);
73void SetAttrValue(gtl::ArraySlice<const char*> value, AttrValue* out);
74void SetAttrValue(gtl::ArraySlice<StringPiece> value, AttrValue* out);
75void SetAttrValue(gtl::ArraySlice<int64_t> value, AttrValue* out);
76void SetAttrValue(gtl::ArraySlice<int32> value, AttrValue* out);
77void SetAttrValue(gtl::ArraySlice<float> value, AttrValue* out);
78void SetAttrValue(gtl::ArraySlice<double> value, AttrValue* out);
79void SetAttrValue(gtl::ArraySlice<bool> value, AttrValue* out);
80void SetAttrValue(const std::vector<bool>& value, AttrValue* out);
81void SetAttrValue(std::initializer_list<bool> value, AttrValue* out);
82void SetAttrValue(DataTypeSlice value, AttrValue* out);
83void SetAttrValue(gtl::ArraySlice<TensorShape> value, AttrValue* out);
84void SetAttrValue(gtl::ArraySlice<TensorShapeProto> value, AttrValue* out);
85void SetAttrValue(gtl::ArraySlice<PartialTensorShape> value, AttrValue* out);
86void SetAttrValue(gtl::ArraySlice<Tensor> value, AttrValue* out);
87void SetAttrValue(gtl::ArraySlice<TensorProto> value, AttrValue* out);
88void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out);
89
90void SetAttrValue(const AttrValue& value, AttrValue* out);
91
92void MoveAttrValue(std::vector<string>&& value, AttrValue* out);
93
94// Returns a hash of `a` that is consistent with AreAttrValuesEqual. In other
95// words, if two AttrValues compare equal according to AreAttrValuesEqual,
96// they will have the same hash value.
97// Similarly to protobuf deterministic serialization, hash value is
98// guaranteed to be stable only for a given binary. In particular, one should
99// probably not persist the returned value.
100uint64 AttrValueHash(const AttrValue& a);
101
102// WARNING: Equality check might return false-negative for large (> 32mb)
103// tensors defined with different TensorProto representations.
104//
105// A pair of consistent hash and equals functions that are guaranteed to be fast
106// with AttrValues that potentially can have very large Tensors (larger than
107// 32mb) defined by TensorProto. If large identical Tensors are defined using
108// different representations (e.g. one with tensor content, and second with
109// bool_val), they will have different hash code and equals will return false.
110// Small (less than 32mb) tensors with different TensorProto representations
111// hashed/compared by their tensor content.
112uint64 FastAttrValueHash(const AttrValue& a);
113// Returns true if a and b have the same value. If false negatives are allowed,
114// then compares proto representation to avoid construction of large (> 32mb)
115// tensors.
116bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b,
117 bool allow_false_negatives = false);
118
119// Returns true if "val" has a placeholder.
120bool HasPlaceHolder(const AttrValue& val);
121
122// SubstitutePlaceholders recursively replaces placeholders in 'value'
123// with an attr value by calling SubstituteFunc. Returns true iff all
124// placeholders in "value" are replaced with a value.
125//
126// SubstituteFunc is given a placeholder string. If the placeholder is
127// unknown, SubstituteFunc returns false. Otherwise, overwrites the
128// attr value and returns true.
129using SubstituteFunc = std::function<bool(const string&, AttrValue*)>;
130bool SubstitutePlaceholders(const SubstituteFunc& substitute, AttrValue* value);
131
132} // namespace tensorflow
133
134#endif // TENSORFLOW_CORE_FRAMEWORK_ATTR_VALUE_UTIL_H_
135