/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/cc/dtensor_utils.h"

#include <cstdlib>

#include "absl/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace dtensor {

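// Returns the zero-based ID of this DTensor client, read from the
// DTENSOR_CLIENT_ID environment variable. Falls back to 0 when the variable
// is unset or cannot be parsed as an integer.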
// LINT.IfChange
int ClientId() {
  char* client_id_str = std::getenv("DTENSOR_CLIENT_ID");
  if (client_id_str == nullptr) return 0;
  int client_id;
  if (absl::SimpleAtoi(client_id_str, &client_id)) return client_id;
  LOG(WARNING) << "Invalid DTENSOR_CLIENT_ID, using the default value 0.";
  return 0;
}
// LINT.ThenChange(//tensorflow/dtensor/python/dtensor_device.py)

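// Returns the total number of DTensor clients, read from the
// DTENSOR_NUM_CLIENTS environment variable. Falls back to 1 when the variable
// is unset or cannot be parsed as an integer.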
// LINT.IfChange
int NumClients() {
  char* num_clients_str = std::getenv("DTENSOR_NUM_CLIENTS");
  if (num_clients_str == nullptr) return 1;
  int num_clients;
  if (absl::SimpleAtoi(num_clients_str, &num_clients)) return num_clients;
  LOG(WARNING) << "Invalid DTENSOR_NUM_CLIENTS, using the default value 1.";
  return 1;
}
// LINT.ThenChange(//tensorflow/dtensor/python/dtensor_device.py)

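// Returns true iff the DTENSOR_LOG_ON_ALL_TASKS environment variable is set,
// regardless of its value.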
bool LogOnAllTasks() {
  char* dtensor_log_on_all_tasks_str = std::getenv("DTENSOR_LOG_ON_ALL_TASKS");
  if (dtensor_log_on_all_tasks_str == nullptr) return false;
  return true;
}

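// Returns true iff the DTENSOR_LOG_OP_BY_OP environment variable is set,
// regardless of its value.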
bool LogOpByOp() {
  char* dtensor_log_op_by_op_str = std::getenv("DTENSOR_LOG_OP_BY_OP");
  if (dtensor_log_op_by_op_str == nullptr) return false;
  return true;
}

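// Returns the maximum number of layout propagation steps, read from the
// DTENSOR_LAYOUT_PROPAGATION_MAX_STEPS environment variable. Falls back to
// 500 when the variable is unset or cannot be parsed as an integer.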
int LayoutPropagationMaxSteps() {
  char* dtensor_layout_propagation_max_steps_str =
      std::getenv("DTENSOR_LAYOUT_PROPAGATION_MAX_STEPS");
  if (dtensor_layout_propagation_max_steps_str == nullptr) return 500;
  int dtensor_layout_propagation_max_steps;
  if (absl::SimpleAtoi(dtensor_layout_propagation_max_steps_str,
                       &dtensor_layout_propagation_max_steps))
    return dtensor_layout_propagation_max_steps;
  LOG(WARNING) << "Invalid DTENSOR_LAYOUT_PROPAGATION_MAX_STEPS, using "
                  "the default value 500.";
  return 500;
}

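// Returns true iff the DTENSOR_ENABLE_MIXED_PRECISION_REDUCE environment
// variable is set, regardless of its value.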
bool EnableMixedPrecisionReduce() {
  char* dtensor_enable_mixed_precision_reduce_str =
      std::getenv("DTENSOR_ENABLE_MIXED_PRECISION_REDUCE");
  if (dtensor_enable_mixed_precision_reduce_str == nullptr) return false;
  return true;
}

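// Returns true iff the DTENSOR_DO_NOT_FUSE_REDUCE_SCATTER environment
// variable is set, regardless of its value.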
bool DoNotFuseReduceScatter() {
  char* dtensor_do_not_fuse_reduce_scatter_str =
      std::getenv("DTENSOR_DO_NOT_FUSE_REDUCE_SCATTER");
  if (dtensor_do_not_fuse_reduce_scatter_str == nullptr) return false;
  return true;
}

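// Returns the maximum group size for performing reductions in bfloat16, read
// from the DTENSOR_REDUCE_IN_BFLOAT16_MAX_GROUP_SIZE environment variable.
// Falls back to 8 when the variable is unset or cannot be parsed as an
// integer.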
int ReduceInBfloat16MaxGroupSize() {
  char* dtensor_reduce_in_bfloat16_max_group_size_str =
      std::getenv("DTENSOR_REDUCE_IN_BFLOAT16_MAX_GROUP_SIZE");
  if (dtensor_reduce_in_bfloat16_max_group_size_str == nullptr) return 8;
  int dtensor_reduce_in_bfloat16_max_group_size;
  if (absl::SimpleAtoi(dtensor_reduce_in_bfloat16_max_group_size_str,
                       &dtensor_reduce_in_bfloat16_max_group_size))
    return dtensor_reduce_in_bfloat16_max_group_size;
  LOG(WARNING) << "Invalid DTENSOR_REDUCE_IN_BFLOAT16_MAX_GROUP_SIZE, using "
                  "the default value 8.";
  return 8;
}

}  // namespace dtensor
}  // namespace tensorflow