1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_COMPILER_JIT_FLAGS_H_ |
17 | #define TENSORFLOW_COMPILER_JIT_FLAGS_H_ |
18 | |
#include <optional>
#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/command_line_flags.h"
26 | |
27 | namespace tensorflow { |
28 | |
29 | struct XlaAutoJitFlag { |
30 | // Control compilation of operators into XLA computations on CPU and GPU |
31 | // devices. 0 = use ConfigProto setting; -1 = off; 1 = on for things very |
32 | // likely to be improved; 2 = on for everything. |
33 | // |
34 | // If all non-CPU ops in the graph being optimized are placed on a single GPU |
35 | // and there is at least one node placed on that GPU then |
36 | // `optimization_level_single_gpu` applies. Otherwise |
37 | // `optimization_level_general` applies. |
38 | // |
39 | // Experimental. |
40 | int32 optimization_level_single_gpu; |
41 | int32 optimization_level_general; |
42 | }; |
43 | |
44 | // Sets the xla_auto_jit_flag based on the given flag string. Supported syntax |
45 | // is: |
46 | // <number>: sets general and single_gpu setting to the provided number. |
47 | // single-gpu(<number>): sets the single_gpu setting to the provided number. |
48 | bool SetXlaAutoJitFlagFromFlagString(const string& value); |
49 | |
50 | // Flags associated with the XLA bridge's mark_for_compilation_pass module. |
51 | struct MarkForCompilationPassFlags { |
52 | XlaAutoJitFlag xla_auto_jit_flag; |
53 | |
54 | // Minimum number of operators in an XLA compilation. Ignored for operators |
55 | // placed on an XLA device or operators explicitly marked for compilation. |
56 | int32 tf_xla_min_cluster_size; |
57 | |
58 | // Maximum number of operators in an XLA compilation. |
59 | int32 tf_xla_max_cluster_size; |
60 | |
61 | // If non-empty, limit XLA clustering to the following TF operations. |
62 | string tf_xla_ops_to_cluster; |
63 | |
64 | // If non-empty, remove following operations from XLA clustering excludelist. |
65 | string tf_xla_cluster_exclude_ops; |
66 | |
67 | // Dump graphs during XLA compilation. |
68 | bool tf_xla_clustering_debug; |
69 | |
70 | // Enables global JIT compilation for CPU via SessionOptions. |
71 | bool tf_xla_cpu_global_jit; |
72 | |
73 | // "Compiler fuel" for clustering. Only this many ops will be marked as |
74 | // eligible for clustering. |
75 | int64_t tf_xla_clustering_fuel; |
76 | |
77 | // If tf_xla_disable_deadness_safety_checks_for_debugging is set to true then |
78 | // we do not do deadness related safety checks. This is unsound in general, |
79 | // but can be used as a debugging aid. |
80 | bool tf_xla_disable_deadness_safety_checks_for_debugging; |
81 | |
82 | // If tf_xla_disable_resource_variable_safety_checks_for_debugging is set to |
83 | // true then we do not do safety checks to preserve TensorFlow's resource |
84 | // variable concurrency semantics. This is unsound in general, but can be |
85 | // used as a debugging aid. |
86 | bool tf_xla_disable_resource_variable_safety_checks_for_debugging; |
87 | |
88 | // If true names of clustered operations will be computed deterministically |
89 | // so that they remain stable from run to run of auto clusteing. |
90 | bool tf_xla_deterministic_cluster_names; |
91 | |
92 | // If non-empty, JIT-compiled executables are saved to and loaded from the |
93 | // specified file system directory path. |
94 | std::string tf_xla_persistent_cache_directory; |
95 | |
96 | // If true, entries loaded into the XLA compile cache will not have their |
97 | // signatures checked strictly. This should generally not be disabled except |
98 | // for debugging. Defaults to false. |
99 | bool tf_xla_disable_strict_signature_checks; |
100 | |
101 | // Specifies the persistance cache prefix. Default is "xla_compile_cache" |
102 | string tf_xla_persistent_cache_prefix; |
103 | }; |
104 | |
// Flags consumed by the XLA bridge's xla_device module.
struct XlaDeviceFlags {
  // Puts the CPU device into "on-demand" mode: rather than being
  // auto-clustered, ops are compiled one by one just-in-time. Turning this on
  // through a legacy flag is a temporary mechanism; once the feature is
  // battle-tested it is intended to become a session option.
  bool tf_xla_compile_on_demand;

  // When set, "XLA" devices are enabled.
  bool tf_xla_enable_xla_devices;
};
116 | |
// Flags shared by the _Xla* ops and the kernels implementing them.
struct XlaOpsCommonFlags {
  // When true, _XlaCompile declines to compile every cluster, so XLA clusters
  // always run in the TF executor. Defaults to false.
  bool tf_xla_always_defer_compilation;

  // When true, _XlaCompile compiles the cluster asynchronously with respect
  // to the main execution; the fallback path is taken while compilation is
  // in progress.
  bool tf_xla_async_compilation;
};
126 | |
// Options read by the build_xla_ops pass.
struct BuildXlaOpsPassFlags {
  // When true, TF/XLA uses lazy compilation (only when auto-clustering).
  // Defaults to true.
  bool tf_xla_enable_lazy_compilation;

  // When true, Print nodes are inserted to print the values produced by XLA
  // clusters. Intended as a debugging aid.
  bool tf_xla_print_cluster_outputs;

  // When true, a CheckNumerics node is inserted for every floating point
  // typed input to an XLA cluster.
  bool tf_xla_check_cluster_input_numerics;

  // When true, a CheckNumerics node is inserted for every floating point
  // typed output from an XLA cluster.
  bool tf_xla_check_cluster_output_numerics;

  // Turns off all constant folding. Mostly used by tests, to guarantee they
  // run on XLA rather than on TF's CPU implementation.
  bool tf_xla_disable_constant_folding;
};
149 | |
150 | // Flags for the IntroduceFloatingPointJitter pass. |
151 | struct IntroduceFloatingPointJitterPassFlags { |
152 | // The amount of jitter to introduce. This amount is added to each element in |
153 | // the tensors named in `tensor_names. |
154 | float jitter_amount; |
155 | |
156 | // The Tensors to add the jitter to. The tensors are named in the TensorId |
157 | // format of <node name>:<output idx>. |
158 | std::vector<string> tensor_names; |
159 | }; |
160 | |
161 | // Flags for common MLIR configurations. |
162 | struct MlirCommonFlags { |
163 | ConfigProto::Experimental::MlirBridgeRollout tf_mlir_enable_mlir_bridge; |
164 | |
165 | bool tf_mlir_enable_merge_control_flow_pass; |
166 | bool tf_mlir_enable_convert_control_to_data_outputs_pass; |
167 | }; |
168 | |
// Options for the JitRt pipeline. tf_jitrt_pipeline.h documents each of them
// in detail.
struct JitRtFlags {
  bool always_specialize;
  bool cost_driven_async_parallel_for;

  // Keeps track of the JitRt queries that are currently "live" so that, after
  // a crash, the "query of death" can be identified.
  // See TfJitRtQueryOfDeathLogger.
  bool log_query_of_death;

  bool vectorize;

  // Turns on the crash reproducer for the JitRt MLIR pass manager.
  bool enable_crash_reproducer;
};
183 | |
184 | // Return a pointer to the DumpGraphFlags struct; |
185 | // repeated calls return the same pointer. |
186 | // This should be called only after Flags::Parse() has returned. |
187 | |
188 | // Getters for flags structs defined above. The first call to any of these |
189 | // parses TF_XLA_FLAGS for all of them. Those functions which return a pointer |
190 | // always return the same pointer. |
191 | MarkForCompilationPassFlags* GetMarkForCompilationPassFlags(); |
192 | BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags(); |
193 | XlaDeviceFlags* GetXlaDeviceFlags(); |
194 | const XlaOpsCommonFlags& GetXlaOpsCommonFlags(); |
195 | |
196 | const IntroduceFloatingPointJitterPassFlags& |
197 | GetIntroduceFloatingPointJitterPassFlags(); |
198 | |
199 | MlirCommonFlags* GetMlirCommonFlags(); |
200 | |
201 | void ResetJitCompilerFlags(); |
202 | |
203 | const JitRtFlags& GetJitRtFlags(); |
204 | |
205 | // Returns the effective MLIR bridge rollout state based on the flags and the |
206 | // optional configuration. |
207 | ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState( |
208 | std::optional<const ConfigProto> config_proto); |
209 | |
210 | // Appends the flag definitions associated with |
211 | // MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`. |
212 | // |
213 | // Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet. |
214 | void AppendMarkForCompilationPassFlags( |
215 | std::vector<tensorflow::Flag>* flag_list); |
216 | |
// Turns XLA compilation off: after this call, compilation is forced to
// return an error message instead. A server can use this to make JIT
// compilation opt-in.
void DisableXlaCompilation();

// Yields `false` as long as `DisableXlaCompilation` has not been called.
bool FailOnXlaCompilation();
223 | |
224 | } // namespace tensorflow |
225 | |
226 | #endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_ |
227 | |