1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/dtensor/cc/small_constant_optimization.h"
17
18#include <cstdint>
19#include <memory>
20#include <utility>
21#include <vector>
22
23#include "absl/algorithm/container.h"
24#include "absl/strings/string_view.h"
25#include "tensorflow/c/eager/c_api.h"
26#include "tensorflow/c/eager/c_api_experimental.h"
27#include "tensorflow/c/tf_status.h"
28#include "tensorflow/c/tf_tensor_internal.h"
29#include "tensorflow/core/framework/tensor.pb.h"
30#include "tensorflow/core/platform/ctstring_internal.h"
31#include "tensorflow/core/platform/protobuf.h"
32#include "tensorflow/dtensor/cc/constants.h"
33#include "tensorflow/dtensor/proto/layout.pb.h"
34
35namespace tensorflow {
36namespace dtensor {
37
38namespace {
39
40constexpr TF_DataType kAllowedDataType[] = {TF_INT32, TF_INT64, TF_FLOAT,
41 TF_STRING};
42
43void AppendIntValues(const int num_of_elements, const int* int_values,
44 TensorProto* proto) {
45 for (int i = 0; i < num_of_elements; ++i) {
46 proto->add_int_val(int_values[i]);
47 }
48}
49
50void AppendInt64Values(const int num_of_elements, const int64_t* int64_values,
51 TensorProto* proto) {
52 for (int i = 0; i < num_of_elements; ++i) {
53 proto->add_int64_val(int64_values[i]);
54 }
55}
56
57void AppendStringValues(const int num_of_elements,
58 const TF_TString* string_values, TensorProto* proto) {
59 for (int i = 0; i < num_of_elements; ++i) {
60 proto->add_string_val(
61 std::string(TF_TString_GetDataPointer(&string_values[i]),
62 TF_TString_GetSize(&string_values[i])));
63 }
64}
65void AppendFloatValues(const int num_of_elements, const float* float_values,
66 TensorProto* proto) {
67 for (int i = 0; i < num_of_elements; ++i) {
68 proto->add_float_val(float_values[i]);
69 }
70}
71
72} // namespace
73
74absl::optional<NodeDef> ExtractSmallTensorValue(TFE_Context* context,
75 TFE_TensorHandle* tensor,
76 const Layout& layout,
77 TF_Status* status) {
78 if (!layout.IsFullyReplicated()) return std::nullopt;
79 auto num_elements = TFE_TensorHandleNumElements(tensor, status);
80 if (TF_GetCode(status) != TF_OK) return absl::nullopt;
81
82 if (num_elements >= kSmallTensorThreshold) return absl::nullopt;
83
84 // Check the DType before attempting to resolve the tensor so we don't try to
85 // copy resource-dtype tensors off the DTensor device. Currently we only
86 // extract small int32/int64_t tensors, primarily to catch shapes and axes,
87 // and tf_string tensors that are mostly used in save/restore ops.
88 const auto& dtype = TFE_TensorHandleDataType(tensor);
89 if (absl::c_find(kAllowedDataType, dtype) == std::end(kAllowedDataType)) {
90 return absl::nullopt;
91 }
92
93 // This is the enum from protobuf, or the following AddNodeAttr will always
94 // set the integer field.
95 const auto& datatype = static_cast<DataType>(dtype);
96 std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_tensor(
97 TFE_TensorHandleResolve(tensor, status), TF_DeleteTensor);
98 if (TF_GetCode(status) != TF_OK) return absl::nullopt;
99
100 NodeDef node_def;
101 node_def.set_op("Const");
102 AddNodeAttr("dtype", datatype, &node_def);
103
104 TensorProto tensor_proto;
105 tensor_proto.set_dtype(datatype);
106 switch (dtype) {
107 case TF_INT32:
108 AppendIntValues(num_elements,
109 static_cast<int*>(TF_TensorData(value_tensor.get())),
110 &tensor_proto);
111 break;
112 case TF_INT64:
113 AppendInt64Values(
114 num_elements,
115 static_cast<const int64_t*>(TF_TensorData(value_tensor.get())),
116 &tensor_proto);
117 break;
118 case TF_STRING:
119 AppendStringValues(
120 num_elements,
121 static_cast<const TF_TString*>(TF_TensorData(value_tensor.get())),
122 &tensor_proto);
123 break;
124 case TF_FLOAT:
125 AppendFloatValues(
126 num_elements,
127 static_cast<const float*>(TF_TensorData(value_tensor.get())),
128 &tensor_proto);
129 break;
130 default:
131 TF_SetStatus(status, TF_INTERNAL,
132 absl::StrCat("dtype: ", dtype,
133 " fell through the supported extraction list. "
134 "This should not happen.")
135 .c_str());
136 return absl::nullopt;
137 }
138
139 std::vector<int64_t> dim_list;
140 int num_dims = value_tensor->tensor->NumDims();
141 dim_list.reserve(num_dims);
142 for (int i = 0; i < num_dims; ++i) {
143 dim_list.push_back(value_tensor->tensor->Dim(i));
144 }
145
146 TensorShape shape(std::move(dim_list));
147 shape.AsProto(tensor_proto.mutable_tensor_shape());
148 AddNodeAttr("value", tensor_proto, &node_def);
149
150 AddNodeAttr(kLayoutAttr, {layout.ToString()}, &node_def);
151 AddNodeAttr(kMeshAttr, layout.mesh().ToString(), &node_def);
152 return node_def;
153}
154
155bool ShouldFoldInputArgument(absl::string_view operation_name,
156 int input_index) {
157 // Fold if we are in a function or if a special eager op.
158 // TODO(xiejw,power): Think about how to generalize this so it does not depend
159 // on operation_name. For example, we can check the max abs value of the
160 // tensor value.
161 if (operation_name == absl::string_view("StatelessRandomUniform") ||
162 operation_name == absl::string_view("StatelessRandomUniformFullInt") ||
163 operation_name == absl::string_view("StatelessRandomNormal") ||
164 operation_name == absl::string_view("StatelessTruncatedNormal")) {
165 // For all stateless rng ops, we avoid fold seed (input_index==1) in graph.
166 // This is an important optimization to avoid unnecessary MLIR SPMD lowering
167 // and TPU compilation during model parameters initialization process.
168 // which typically have the same shape for rng ops but different seeds.
169 return input_index != 1;
170 }
171
172 return true;
173}
174
175bool NodeDefsHaveDifferentTensorProto(const NodeDef& a, const NodeDef& b) {
176 const TensorProto* tensor_proto_a;
177 bool read_a_tensor_proto = TryGetNodeAttr(a, "value", &tensor_proto_a);
178 if (!read_a_tensor_proto) return true;
179
180 const TensorProto* tensor_proto_b;
181 bool read_b_tensor_proto = TryGetNodeAttr(b, "value", &tensor_proto_b);
182 if (!read_b_tensor_proto) return true;
183 return !protobuf::util::MessageDifferencer::Equals(*tensor_proto_a,
184 *tensor_proto_b);
185}
186
187} // namespace dtensor
188} // namespace tensorflow
189