1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_DTENSOR_CC_SMALL_CONSTANT_OPTIMIZATION_H_
17#define TENSORFLOW_DTENSOR_CC_SMALL_CONSTANT_OPTIMIZATION_H_
18
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
23
namespace tensorflow {
namespace dtensor {

// Attempts to convert a small constant tensor into a constant NodeDef
// operation. The resulting constant value is then available for constant
// propagation in DTensor and MLIR.
//
// This conversion is currently required for some DTensor operations. In
// particular, reductions require access to the axis argument at compilation
// time. While this is not strictly necessary, it greatly simplifies SPMD code
// generation and is generally available.
//
// Parameters:
//   context: the TFE eager context associated with `tensor`.
//   tensor:  handle to the tensor whose value may be extracted.
//   layout:  the DTensor layout associated with `tensor`.
//   status:  output status. NOTE(review): presumably set to a non-OK code on
//            failure; confirm against the implementation.
// Returns a constant NodeDef holding the tensor's value when the conversion
// is performed. NOTE(review): absl::nullopt presumably means the tensor was
// not converted (e.g. not "small" enough, or an error occurred); confirm
// against the implementation and check `status` as well.
absl::optional<NodeDef> ExtractSmallTensorValue(TFE_Context* context,
                                                TFE_TensorHandle* tensor,
                                                const Layout& layout,
                                                TF_Status* status);

// Returns true if input number `input_index` of the operation named
// `operation_name` should be eligible for extraction into a graph constant.
bool ShouldFoldInputArgument(absl::string_view operation_name, int input_index);

// Returns true if the tensor protos carried by NodeDefs `a` and `b` differ.
bool NodeDefsHaveDifferentTensorProto(const NodeDef& a, const NodeDef& b);
}  // namespace dtensor
}  // namespace tensorflow
48
49#endif // TENSORFLOW_DTENSOR_CC_SMALL_CONSTANT_OPTIMIZATION_H_
50