1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/dtensor/cc/dtensor_utils.h" |
17 | |
18 | #include <cstdlib> |
19 | |
20 | #include "absl/strings/numbers.h" |
21 | #include "tensorflow/core/platform/logging.h" |
22 | |
23 | namespace tensorflow { |
24 | namespace dtensor { |
25 | |
26 | // LINT.IfChange |
27 | int ClientId() { |
28 | char* client_id_str = std::getenv("DTENSOR_CLIENT_ID" ); |
29 | if (client_id_str == nullptr) return 0; |
30 | int client_id; |
31 | if (absl::SimpleAtoi(client_id_str, &client_id)) return client_id; |
32 | LOG(WARNING) << "Invalid DTENSOR_CLIENT_ID, using the default value 0." ; |
33 | return 0; |
34 | } |
35 | // LINT.ThenChange(//tensorflow/dtensor/python/dtensor_device.py) |
36 | |
37 | // LINT.IfChange |
38 | int NumClients() { |
39 | char* num_clients_str = std::getenv("DTENSOR_NUM_CLIENTS" ); |
40 | if (num_clients_str == nullptr) return 1; |
41 | int num_clients; |
42 | if (absl::SimpleAtoi(num_clients_str, &num_clients)) return num_clients; |
43 | LOG(WARNING) << "Invalid DTENSOR_NUM_CLIENTS, using the default value 1." ; |
44 | return 1; |
45 | } |
46 | // LINT.ThenChange(//tensorflow/dtensor/python/dtensor_device.py) |
47 | |
// Whether DTensor should log on every task rather than only task 0.
// Enabled by the mere presence of DTENSOR_LOG_ON_ALL_TASKS in the
// environment; the variable's value is ignored (even "0" enables it).
bool LogOnAllTasks() {
  return std::getenv("DTENSOR_LOG_ON_ALL_TASKS") != nullptr;
}
53 | |
// Whether DTensor should emit per-op logging. Enabled by the mere presence
// of DTENSOR_LOG_OP_BY_OP in the environment; the variable's value is
// ignored (even "0" enables it).
bool LogOpByOp() {
  return std::getenv("DTENSOR_LOG_OP_BY_OP") != nullptr;
}
59 | |
60 | int LayoutPropagationMaxSteps() { |
61 | char* dtensor_layout_propagation_max_steps_str = |
62 | std::getenv("DTENSOR_LAYOUT_PROPAGATION_MAX_STEPS" ); |
63 | if (dtensor_layout_propagation_max_steps_str == nullptr) return 500; |
64 | int dtensor_layout_propagation_max_steps; |
65 | if (absl::SimpleAtoi(dtensor_layout_propagation_max_steps_str, |
66 | &dtensor_layout_propagation_max_steps)) |
67 | return dtensor_layout_propagation_max_steps; |
68 | LOG(WARNING) << "Invalid DTENSOR_LAYOUT_PROPAGATION_MAX_STEPS, using " |
69 | "the default value 500." ; |
70 | return 500; |
71 | } |
72 | |
// Whether mixed-precision reduce is enabled. Turned on by the mere presence
// of DTENSOR_ENABLE_MIXED_PRECISION_REDUCE in the environment; the
// variable's value is ignored (even "0" enables it).
bool EnableMixedPrecisionReduce() {
  return std::getenv("DTENSOR_ENABLE_MIXED_PRECISION_REDUCE") != nullptr;
}
79 | |
// Whether reduce-scatter fusion should be disabled. Turned on by the mere
// presence of DTENSOR_DO_NOT_FUSE_REDUCE_SCATTER in the environment; the
// variable's value is ignored (even "0" enables it).
bool DoNotFuseReduceScatter() {
  return std::getenv("DTENSOR_DO_NOT_FUSE_REDUCE_SCATTER") != nullptr;
}
86 | |
87 | int ReduceInBfloat16MaxGroupSize() { |
88 | char* dtensor_reduce_in_bfloat16_max_group_size_str = |
89 | std::getenv("DTENSOR_REDUCE_IN_BFLOAT16_MAX_GROUP_SIZE" ); |
90 | if (dtensor_reduce_in_bfloat16_max_group_size_str == nullptr) return 8; |
91 | int dtensor_reduce_in_bfloat16_max_group_size; |
92 | if (absl::SimpleAtoi(dtensor_reduce_in_bfloat16_max_group_size_str, |
93 | &dtensor_reduce_in_bfloat16_max_group_size)) |
94 | return dtensor_reduce_in_bfloat16_max_group_size; |
95 | LOG(WARNING) << "Invalid DTENSOR_REDUCE_IN_BFLOAT16_MAX_GROUP_SIZE, using " |
96 | "the default value 8." ; |
97 | return 8; |
98 | } |
99 | |
100 | } // namespace dtensor |
101 | } // namespace tensorflow |
102 | |