1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_DTENSOR_CC_DTENSOR_UTILS_H_
17#define TENSORFLOW_DTENSOR_CC_DTENSOR_UTILS_H_
18
namespace tensorflow {
namespace dtensor {

// Returns the DTensor client ID of this process, usually equal to the TF task
// ID on this host.
int ClientId();

// Returns the total number of DTensor clients, usually equal to the total
// number of TF tasks.
int NumClients();

// Returns whether to enable logging for passes and layouts on all tasks
// (DTensor clients), not only on a single one.
bool LogOnAllTasks();

// Returns whether to log op-by-op execution in addition to function execution
// when logging is enabled.
bool LogOpByOp();

// Returns the maximum number of steps to run layout propagation. If the number
// of steps exceeds this amount, layout propagation fails.
int LayoutPropagationMaxSteps();

// Returns whether to upcast bfloat16 reduction inputs to float32 when the
// reduction group is large enough; the group-size threshold is given by
// ReduceInBfloat16MaxGroupSize().
bool EnableMixedPrecisionReduce();

// Returns whether *not* to fuse an AllReduce followed by an AllScatter into a
// single ReduceScatter op. The fused ReduceScatter op can be implemented more
// efficiently than the separate pair.
bool DoNotFuseReduceScatter();

// Returns the maximum reduction group size for bfloat16 reductions. If the
// group size exceeds this value, tensors are upcast to float32 before the
// reduce op.
int ReduceInBfloat16MaxGroupSize();

}  // namespace dtensor
}  // namespace tensorflow
56
57#endif // TENSORFLOW_DTENSOR_CC_DTENSOR_UTILS_H_
58