1 | /* Autogenerated by mlir-tblgen; don't manually edit */ |
2 | |
3 | //===----------------------------------------------------------------------===// |
4 | // DTensorAllGatherLowering |
5 | //===----------------------------------------------------------------------===// |
6 | #ifdef GEN_PASS_DECL_DTENSORALLGATHERLOWERING |
7 | #undef GEN_PASS_DECL_DTENSORALLGATHERLOWERING |
8 | #endif // GEN_PASS_DECL_DTENSORALLGATHERLOWERING |
9 | #ifdef GEN_PASS_DEF_DTENSORALLGATHERLOWERING |
10 | namespace impl { |
11 | |
12 | template <typename DerivedT> |
13 | class DTensorAllGatherLoweringBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
14 | public: |
15 | using Base = DTensorAllGatherLoweringBase; |
16 | |
17 | DTensorAllGatherLoweringBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
18 | DTensorAllGatherLoweringBase(const DTensorAllGatherLoweringBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
19 | |
20 | /// Returns the command-line argument attached to this pass. |
21 | static constexpr ::llvm::StringLiteral getArgumentName() { |
22 | return ::llvm::StringLiteral("dtensor-all-gather-lowering" ); |
23 | } |
24 | ::llvm::StringRef getArgument() const override { return "dtensor-all-gather-lowering" ; } |
25 | |
26 | ::llvm::StringRef getDescription() const override { return "Converts logical AllGather ops into physical AllGather ops." ; } |
27 | |
28 | /// Returns the derived pass name. |
29 | static constexpr ::llvm::StringLiteral getPassName() { |
30 | return ::llvm::StringLiteral("DTensorAllGatherLowering" ); |
31 | } |
32 | ::llvm::StringRef getName() const override { return "DTensorAllGatherLowering" ; } |
33 | |
34 | /// Support isa/dyn_cast functionality for the derived pass class. |
35 | static bool classof(const ::mlir::Pass *pass) { |
36 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
37 | } |
38 | |
39 | /// A clone method to create a copy of this pass. |
40 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
41 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
42 | } |
43 | |
44 | /// Return the dialect that must be loaded in the context before this pass. |
45 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
46 | |
47 | } |
48 | |
49 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
50 | /// instantiation because Pass classes should only be visible by the current |
51 | /// library. |
52 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllGatherLoweringBase<DerivedT>) |
53 | |
54 | protected: |
55 | private: |
56 | }; |
57 | } // namespace impl |
58 | #undef GEN_PASS_DEF_DTENSORALLGATHERLOWERING |
59 | #endif // GEN_PASS_DEF_DTENSORALLGATHERLOWERING |
60 | |
61 | //===----------------------------------------------------------------------===// |
62 | // DTensorAllReduceCombineOptimization |
63 | //===----------------------------------------------------------------------===// |
64 | #ifdef GEN_PASS_DECL_DTENSORALLREDUCECOMBINEOPTIMIZATION |
65 | #undef GEN_PASS_DECL_DTENSORALLREDUCECOMBINEOPTIMIZATION |
66 | #endif // GEN_PASS_DECL_DTENSORALLREDUCECOMBINEOPTIMIZATION |
67 | #ifdef GEN_PASS_DEF_DTENSORALLREDUCECOMBINEOPTIMIZATION |
68 | namespace impl { |
69 | |
70 | template <typename DerivedT> |
71 | class DTensorAllReduceCombineOptimizationBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
72 | public: |
73 | using Base = DTensorAllReduceCombineOptimizationBase; |
74 | |
75 | DTensorAllReduceCombineOptimizationBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
76 | DTensorAllReduceCombineOptimizationBase(const DTensorAllReduceCombineOptimizationBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
77 | |
78 | /// Returns the command-line argument attached to this pass. |
79 | static constexpr ::llvm::StringLiteral getArgumentName() { |
80 | return ::llvm::StringLiteral("dtensor-allreduce-combine-optimization" ); |
81 | } |
82 | ::llvm::StringRef getArgument() const override { return "dtensor-allreduce-combine-optimization" ; } |
83 | |
84 | ::llvm::StringRef getDescription() const override { return "Combine independent all reduce operations." ; } |
85 | |
86 | /// Returns the derived pass name. |
87 | static constexpr ::llvm::StringLiteral getPassName() { |
88 | return ::llvm::StringLiteral("DTensorAllReduceCombineOptimization" ); |
89 | } |
90 | ::llvm::StringRef getName() const override { return "DTensorAllReduceCombineOptimization" ; } |
91 | |
92 | /// Support isa/dyn_cast functionality for the derived pass class. |
93 | static bool classof(const ::mlir::Pass *pass) { |
94 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
95 | } |
96 | |
97 | /// A clone method to create a copy of this pass. |
98 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
99 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
100 | } |
101 | |
102 | /// Return the dialect that must be loaded in the context before this pass. |
103 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
104 | |
105 | } |
106 | |
107 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
108 | /// instantiation because Pass classes should only be visible by the current |
109 | /// library. |
110 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllReduceCombineOptimizationBase<DerivedT>) |
111 | |
112 | protected: |
113 | private: |
114 | }; |
115 | } // namespace impl |
116 | #undef GEN_PASS_DEF_DTENSORALLREDUCECOMBINEOPTIMIZATION |
117 | #endif // GEN_PASS_DEF_DTENSORALLREDUCECOMBINEOPTIMIZATION |
118 | |
119 | //===----------------------------------------------------------------------===// |
120 | // DTensorAllReduceLowering |
121 | //===----------------------------------------------------------------------===// |
122 | #ifdef GEN_PASS_DECL_DTENSORALLREDUCELOWERING |
123 | #undef GEN_PASS_DECL_DTENSORALLREDUCELOWERING |
124 | #endif // GEN_PASS_DECL_DTENSORALLREDUCELOWERING |
125 | #ifdef GEN_PASS_DEF_DTENSORALLREDUCELOWERING |
126 | namespace impl { |
127 | |
128 | template <typename DerivedT> |
129 | class DTensorAllReduceLoweringBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
130 | public: |
131 | using Base = DTensorAllReduceLoweringBase; |
132 | |
133 | DTensorAllReduceLoweringBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
134 | DTensorAllReduceLoweringBase(const DTensorAllReduceLoweringBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
135 | |
136 | /// Returns the command-line argument attached to this pass. |
137 | static constexpr ::llvm::StringLiteral getArgumentName() { |
138 | return ::llvm::StringLiteral("dtensor-all-reduce-lowering" ); |
139 | } |
140 | ::llvm::StringRef getArgument() const override { return "dtensor-all-reduce-lowering" ; } |
141 | |
142 | ::llvm::StringRef getDescription() const override { return "Converts logical AllReduce ops into physical AllReduce ops." ; } |
143 | |
144 | /// Returns the derived pass name. |
145 | static constexpr ::llvm::StringLiteral getPassName() { |
146 | return ::llvm::StringLiteral("DTensorAllReduceLowering" ); |
147 | } |
148 | ::llvm::StringRef getName() const override { return "DTensorAllReduceLowering" ; } |
149 | |
150 | /// Support isa/dyn_cast functionality for the derived pass class. |
151 | static bool classof(const ::mlir::Pass *pass) { |
152 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
153 | } |
154 | |
155 | /// A clone method to create a copy of this pass. |
156 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
157 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
158 | } |
159 | |
160 | /// Return the dialect that must be loaded in the context before this pass. |
161 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
162 | |
163 | } |
164 | |
165 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
166 | /// instantiation because Pass classes should only be visible by the current |
167 | /// library. |
168 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllReduceLoweringBase<DerivedT>) |
169 | |
170 | protected: |
171 | private: |
172 | }; |
173 | } // namespace impl |
174 | #undef GEN_PASS_DEF_DTENSORALLREDUCELOWERING |
175 | #endif // GEN_PASS_DEF_DTENSORALLREDUCELOWERING |
176 | |
177 | //===----------------------------------------------------------------------===// |
178 | // DTensorAllReduceScatterOptimization |
179 | //===----------------------------------------------------------------------===// |
180 | #ifdef GEN_PASS_DECL_DTENSORALLREDUCESCATTEROPTIMIZATION |
181 | #undef GEN_PASS_DECL_DTENSORALLREDUCESCATTEROPTIMIZATION |
182 | #endif // GEN_PASS_DECL_DTENSORALLREDUCESCATTEROPTIMIZATION |
183 | #ifdef GEN_PASS_DEF_DTENSORALLREDUCESCATTEROPTIMIZATION |
184 | namespace impl { |
185 | |
186 | template <typename DerivedT> |
187 | class DTensorAllReduceScatterOptimizationBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
188 | public: |
189 | using Base = DTensorAllReduceScatterOptimizationBase; |
190 | |
191 | DTensorAllReduceScatterOptimizationBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
192 | DTensorAllReduceScatterOptimizationBase(const DTensorAllReduceScatterOptimizationBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
193 | |
194 | /// Returns the command-line argument attached to this pass. |
195 | static constexpr ::llvm::StringLiteral getArgumentName() { |
196 | return ::llvm::StringLiteral("dtensor-allreduce-scatter-optimization" ); |
197 | } |
198 | ::llvm::StringRef getArgument() const override { return "dtensor-allreduce-scatter-optimization" ; } |
199 | |
200 | ::llvm::StringRef getDescription() const override { return "Combines allreduce and scatter to reducescatter." ; } |
201 | |
202 | /// Returns the derived pass name. |
203 | static constexpr ::llvm::StringLiteral getPassName() { |
204 | return ::llvm::StringLiteral("DTensorAllReduceScatterOptimization" ); |
205 | } |
206 | ::llvm::StringRef getName() const override { return "DTensorAllReduceScatterOptimization" ; } |
207 | |
208 | /// Support isa/dyn_cast functionality for the derived pass class. |
209 | static bool classof(const ::mlir::Pass *pass) { |
210 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
211 | } |
212 | |
213 | /// A clone method to create a copy of this pass. |
214 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
215 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
216 | } |
217 | |
218 | /// Return the dialect that must be loaded in the context before this pass. |
219 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
220 | |
221 | } |
222 | |
223 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
224 | /// instantiation because Pass classes should only be visible by the current |
225 | /// library. |
226 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllReduceScatterOptimizationBase<DerivedT>) |
227 | |
228 | protected: |
229 | private: |
230 | }; |
231 | } // namespace impl |
232 | #undef GEN_PASS_DEF_DTENSORALLREDUCESCATTEROPTIMIZATION |
233 | #endif // GEN_PASS_DEF_DTENSORALLREDUCESCATTEROPTIMIZATION |
234 | |
235 | //===----------------------------------------------------------------------===// |
236 | // DTensorAllReduceSumOptimization |
237 | //===----------------------------------------------------------------------===// |
238 | #ifdef GEN_PASS_DECL_DTENSORALLREDUCESUMOPTIMIZATION |
239 | #undef GEN_PASS_DECL_DTENSORALLREDUCESUMOPTIMIZATION |
240 | #endif // GEN_PASS_DECL_DTENSORALLREDUCESUMOPTIMIZATION |
241 | #ifdef GEN_PASS_DEF_DTENSORALLREDUCESUMOPTIMIZATION |
242 | namespace impl { |
243 | |
244 | template <typename DerivedT> |
245 | class DTensorAllReduceSumOptimizationBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
246 | public: |
247 | using Base = DTensorAllReduceSumOptimizationBase; |
248 | |
249 | DTensorAllReduceSumOptimizationBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
250 | DTensorAllReduceSumOptimizationBase(const DTensorAllReduceSumOptimizationBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
251 | |
252 | /// Returns the command-line argument attached to this pass. |
253 | static constexpr ::llvm::StringLiteral getArgumentName() { |
254 | return ::llvm::StringLiteral("dtensor-allreduce-sum-optimization" ); |
255 | } |
256 | ::llvm::StringRef getArgument() const override { return "dtensor-allreduce-sum-optimization" ; } |
257 | |
258 | ::llvm::StringRef getDescription() const override { return "Changes order of add/allreduce to minimize all reduce operations." ; } |
259 | |
260 | /// Returns the derived pass name. |
261 | static constexpr ::llvm::StringLiteral getPassName() { |
262 | return ::llvm::StringLiteral("DTensorAllReduceSumOptimization" ); |
263 | } |
264 | ::llvm::StringRef getName() const override { return "DTensorAllReduceSumOptimization" ; } |
265 | |
266 | /// Support isa/dyn_cast functionality for the derived pass class. |
267 | static bool classof(const ::mlir::Pass *pass) { |
268 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
269 | } |
270 | |
271 | /// A clone method to create a copy of this pass. |
272 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
273 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
274 | } |
275 | |
276 | /// Return the dialect that must be loaded in the context before this pass. |
277 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
278 | |
279 | } |
280 | |
281 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
282 | /// instantiation because Pass classes should only be visible by the current |
283 | /// library. |
284 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllReduceSumOptimizationBase<DerivedT>) |
285 | |
286 | protected: |
287 | private: |
288 | }; |
289 | } // namespace impl |
290 | #undef GEN_PASS_DEF_DTENSORALLREDUCESUMOPTIMIZATION |
291 | #endif // GEN_PASS_DEF_DTENSORALLREDUCESUMOPTIMIZATION |
292 | |
293 | //===----------------------------------------------------------------------===// |
294 | // DTensorAllScatterLowering |
295 | //===----------------------------------------------------------------------===// |
296 | #ifdef GEN_PASS_DECL_DTENSORALLSCATTERLOWERING |
297 | #undef GEN_PASS_DECL_DTENSORALLSCATTERLOWERING |
298 | #endif // GEN_PASS_DECL_DTENSORALLSCATTERLOWERING |
299 | #ifdef GEN_PASS_DEF_DTENSORALLSCATTERLOWERING |
300 | namespace impl { |
301 | |
302 | template <typename DerivedT> |
303 | class DTensorAllScatterLoweringBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
304 | public: |
305 | using Base = DTensorAllScatterLoweringBase; |
306 | |
307 | DTensorAllScatterLoweringBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
308 | DTensorAllScatterLoweringBase(const DTensorAllScatterLoweringBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
309 | |
310 | /// Returns the command-line argument attached to this pass. |
311 | static constexpr ::llvm::StringLiteral getArgumentName() { |
312 | return ::llvm::StringLiteral("dtensor-all-scatter-lowering" ); |
313 | } |
314 | ::llvm::StringRef getArgument() const override { return "dtensor-all-scatter-lowering" ; } |
315 | |
316 | ::llvm::StringRef getDescription() const override { return "Converts logical AllScatter ops into physical Split ops." ; } |
317 | |
318 | /// Returns the derived pass name. |
319 | static constexpr ::llvm::StringLiteral getPassName() { |
320 | return ::llvm::StringLiteral("DTensorAllScatterLowering" ); |
321 | } |
322 | ::llvm::StringRef getName() const override { return "DTensorAllScatterLowering" ; } |
323 | |
324 | /// Support isa/dyn_cast functionality for the derived pass class. |
325 | static bool classof(const ::mlir::Pass *pass) { |
326 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
327 | } |
328 | |
329 | /// A clone method to create a copy of this pass. |
330 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
331 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
332 | } |
333 | |
334 | /// Return the dialect that must be loaded in the context before this pass. |
335 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
336 | |
337 | } |
338 | |
339 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
340 | /// instantiation because Pass classes should only be visible by the current |
341 | /// library. |
342 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllScatterLoweringBase<DerivedT>) |
343 | |
344 | protected: |
345 | private: |
346 | }; |
347 | } // namespace impl |
348 | #undef GEN_PASS_DEF_DTENSORALLSCATTERLOWERING |
349 | #endif // GEN_PASS_DEF_DTENSORALLSCATTERLOWERING |
350 | |
351 | //===----------------------------------------------------------------------===// |
352 | // DTensorAnnotateGlobalShape |
353 | //===----------------------------------------------------------------------===// |
354 | #ifdef GEN_PASS_DECL_DTENSORANNOTATEGLOBALSHAPE |
355 | #undef GEN_PASS_DECL_DTENSORANNOTATEGLOBALSHAPE |
356 | #endif // GEN_PASS_DECL_DTENSORANNOTATEGLOBALSHAPE |
357 | #ifdef GEN_PASS_DEF_DTENSORANNOTATEGLOBALSHAPE |
358 | namespace impl { |
359 | |
360 | template <typename DerivedT> |
361 | class DTensorAnnotateGlobalShapeBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
362 | public: |
363 | using Base = DTensorAnnotateGlobalShapeBase; |
364 | |
365 | DTensorAnnotateGlobalShapeBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
366 | DTensorAnnotateGlobalShapeBase(const DTensorAnnotateGlobalShapeBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
367 | |
368 | /// Returns the command-line argument attached to this pass. |
369 | static constexpr ::llvm::StringLiteral getArgumentName() { |
370 | return ::llvm::StringLiteral("dtensor-annotate-global-shape" ); |
371 | } |
372 | ::llvm::StringRef getArgument() const override { return "dtensor-annotate-global-shape" ; } |
373 | |
374 | ::llvm::StringRef getDescription() const override { return "Mark all operations and function arguments with `_global_shape` attribute to be used during SPMD expansion." ; } |
375 | |
376 | /// Returns the derived pass name. |
377 | static constexpr ::llvm::StringLiteral getPassName() { |
378 | return ::llvm::StringLiteral("DTensorAnnotateGlobalShape" ); |
379 | } |
380 | ::llvm::StringRef getName() const override { return "DTensorAnnotateGlobalShape" ; } |
381 | |
382 | /// Support isa/dyn_cast functionality for the derived pass class. |
383 | static bool classof(const ::mlir::Pass *pass) { |
384 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
385 | } |
386 | |
387 | /// A clone method to create a copy of this pass. |
388 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
389 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
390 | } |
391 | |
392 | /// Return the dialect that must be loaded in the context before this pass. |
393 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
394 | |
395 | } |
396 | |
397 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
398 | /// instantiation because Pass classes should only be visible by the current |
399 | /// library. |
400 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAnnotateGlobalShapeBase<DerivedT>) |
401 | |
402 | protected: |
403 | private: |
404 | }; |
405 | } // namespace impl |
406 | #undef GEN_PASS_DEF_DTENSORANNOTATEGLOBALSHAPE |
407 | #endif // GEN_PASS_DEF_DTENSORANNOTATEGLOBALSHAPE |
408 | |
409 | //===----------------------------------------------------------------------===// |
410 | // DTensorClusterFunctionConversion |
411 | //===----------------------------------------------------------------------===// |
412 | #ifdef GEN_PASS_DECL_DTENSORCLUSTERFUNCTIONCONVERSION |
413 | #undef GEN_PASS_DECL_DTENSORCLUSTERFUNCTIONCONVERSION |
414 | #endif // GEN_PASS_DECL_DTENSORCLUSTERFUNCTIONCONVERSION |
415 | #ifdef GEN_PASS_DEF_DTENSORCLUSTERFUNCTIONCONVERSION |
416 | namespace impl { |
417 | |
418 | template <typename DerivedT> |
419 | class DTensorClusterFunctionConversionBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
420 | public: |
421 | using Base = DTensorClusterFunctionConversionBase; |
422 | |
423 | DTensorClusterFunctionConversionBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
424 | DTensorClusterFunctionConversionBase(const DTensorClusterFunctionConversionBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
425 | |
426 | /// Returns the command-line argument attached to this pass. |
427 | static constexpr ::llvm::StringLiteral getArgumentName() { |
428 | return ::llvm::StringLiteral("dtensor-cluster-function-conversion" ); |
429 | } |
430 | ::llvm::StringRef getArgument() const override { return "dtensor-cluster-function-conversion" ; } |
431 | |
432 | ::llvm::StringRef getDescription() const override { return "Converts tf_device.cluster_func ops into TF StatefulPartitioned call op with mesh attribute." ; } |
433 | |
434 | /// Returns the derived pass name. |
435 | static constexpr ::llvm::StringLiteral getPassName() { |
436 | return ::llvm::StringLiteral("DTensorClusterFunctionConversion" ); |
437 | } |
438 | ::llvm::StringRef getName() const override { return "DTensorClusterFunctionConversion" ; } |
439 | |
440 | /// Support isa/dyn_cast functionality for the derived pass class. |
441 | static bool classof(const ::mlir::Pass *pass) { |
442 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
443 | } |
444 | |
445 | /// A clone method to create a copy of this pass. |
446 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
447 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
448 | } |
449 | |
450 | /// Return the dialect that must be loaded in the context before this pass. |
451 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
452 | |
453 | } |
454 | |
455 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
456 | /// instantiation because Pass classes should only be visible by the current |
457 | /// library. |
458 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorClusterFunctionConversionBase<DerivedT>) |
459 | |
460 | protected: |
461 | private: |
462 | }; |
463 | } // namespace impl |
464 | #undef GEN_PASS_DEF_DTENSORCLUSTERFUNCTIONCONVERSION |
465 | #endif // GEN_PASS_DEF_DTENSORCLUSTERFUNCTIONCONVERSION |
466 | |
467 | //===----------------------------------------------------------------------===// |
468 | // DTensorConstantFolding |
469 | //===----------------------------------------------------------------------===// |
470 | #ifdef GEN_PASS_DECL_DTENSORCONSTANTFOLDING |
471 | #undef GEN_PASS_DECL_DTENSORCONSTANTFOLDING |
472 | #endif // GEN_PASS_DECL_DTENSORCONSTANTFOLDING |
473 | #ifdef GEN_PASS_DEF_DTENSORCONSTANTFOLDING |
474 | namespace impl { |
475 | |
476 | template <typename DerivedT> |
477 | class DTensorConstantFoldingBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
478 | public: |
479 | using Base = DTensorConstantFoldingBase; |
480 | |
481 | DTensorConstantFoldingBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
482 | DTensorConstantFoldingBase(const DTensorConstantFoldingBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
483 | |
484 | /// Returns the command-line argument attached to this pass. |
485 | static constexpr ::llvm::StringLiteral getArgumentName() { |
486 | return ::llvm::StringLiteral("dtensor-constant-folding" ); |
487 | } |
488 | ::llvm::StringRef getArgument() const override { return "dtensor-constant-folding" ; } |
489 | |
490 | ::llvm::StringRef getDescription() const override { return "Folds constants operations." ; } |
491 | |
492 | /// Returns the derived pass name. |
493 | static constexpr ::llvm::StringLiteral getPassName() { |
494 | return ::llvm::StringLiteral("DTensorConstantFolding" ); |
495 | } |
496 | ::llvm::StringRef getName() const override { return "DTensorConstantFolding" ; } |
497 | |
498 | /// Support isa/dyn_cast functionality for the derived pass class. |
499 | static bool classof(const ::mlir::Pass *pass) { |
500 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
501 | } |
502 | |
503 | /// A clone method to create a copy of this pass. |
504 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
505 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
506 | } |
507 | |
508 | /// Return the dialect that must be loaded in the context before this pass. |
509 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
510 | |
511 | } |
512 | |
513 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
514 | /// instantiation because Pass classes should only be visible by the current |
515 | /// library. |
516 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorConstantFoldingBase<DerivedT>) |
517 | |
518 | protected: |
519 | private: |
520 | }; |
521 | } // namespace impl |
522 | #undef GEN_PASS_DEF_DTENSORCONSTANTFOLDING |
523 | #endif // GEN_PASS_DEF_DTENSORCONSTANTFOLDING |
524 | |
525 | //===----------------------------------------------------------------------===// |
526 | // DTensorDCE |
527 | //===----------------------------------------------------------------------===// |
528 | #ifdef GEN_PASS_DECL_DTENSORDCE |
529 | #undef GEN_PASS_DECL_DTENSORDCE |
530 | #endif // GEN_PASS_DECL_DTENSORDCE |
531 | #ifdef GEN_PASS_DEF_DTENSORDCE |
532 | namespace impl { |
533 | |
534 | template <typename DerivedT> |
535 | class DTensorDCEBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
536 | public: |
537 | using Base = DTensorDCEBase; |
538 | |
539 | DTensorDCEBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
540 | DTensorDCEBase(const DTensorDCEBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
541 | |
542 | /// Returns the command-line argument attached to this pass. |
543 | static constexpr ::llvm::StringLiteral getArgumentName() { |
544 | return ::llvm::StringLiteral("dtensor-dce" ); |
545 | } |
546 | ::llvm::StringRef getArgument() const override { return "dtensor-dce" ; } |
547 | |
548 | ::llvm::StringRef getDescription() const override { return "Removes unused ops from graph." ; } |
549 | |
550 | /// Returns the derived pass name. |
551 | static constexpr ::llvm::StringLiteral getPassName() { |
552 | return ::llvm::StringLiteral("DTensorDCE" ); |
553 | } |
554 | ::llvm::StringRef getName() const override { return "DTensorDCE" ; } |
555 | |
556 | /// Support isa/dyn_cast functionality for the derived pass class. |
557 | static bool classof(const ::mlir::Pass *pass) { |
558 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
559 | } |
560 | |
561 | /// A clone method to create a copy of this pass. |
562 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
563 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
564 | } |
565 | |
566 | /// Return the dialect that must be loaded in the context before this pass. |
567 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
568 | |
569 | } |
570 | |
571 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
572 | /// instantiation because Pass classes should only be visible by the current |
573 | /// library. |
574 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorDCEBase<DerivedT>) |
575 | |
576 | protected: |
577 | private: |
578 | }; |
579 | } // namespace impl |
580 | #undef GEN_PASS_DEF_DTENSORDCE |
581 | #endif // GEN_PASS_DEF_DTENSORDCE |
582 | |
583 | //===----------------------------------------------------------------------===// |
584 | // DTensorDesignateResourceHandleMesh |
585 | //===----------------------------------------------------------------------===// |
586 | #ifdef GEN_PASS_DECL_DTENSORDESIGNATERESOURCEHANDLEMESH |
587 | #undef GEN_PASS_DECL_DTENSORDESIGNATERESOURCEHANDLEMESH |
588 | #endif // GEN_PASS_DECL_DTENSORDESIGNATERESOURCEHANDLEMESH |
589 | #ifdef GEN_PASS_DEF_DTENSORDESIGNATERESOURCEHANDLEMESH |
590 | namespace impl { |
591 | |
592 | template <typename DerivedT> |
593 | class DTensorDesignateResourceHandleMeshBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
594 | public: |
595 | using Base = DTensorDesignateResourceHandleMeshBase; |
596 | |
597 | DTensorDesignateResourceHandleMeshBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
598 | DTensorDesignateResourceHandleMeshBase(const DTensorDesignateResourceHandleMeshBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
599 | |
600 | /// Returns the command-line argument attached to this pass. |
601 | static constexpr ::llvm::StringLiteral getArgumentName() { |
602 | return ::llvm::StringLiteral("dtensor-designate-resource-handle-mesh" ); |
603 | } |
604 | ::llvm::StringRef getArgument() const override { return "dtensor-designate-resource-handle-mesh" ; } |
605 | |
606 | ::llvm::StringRef getDescription() const override { return "Sets empty mesh attributes for device cluster that creates or destroys resource handles." ; } |
607 | |
608 | /// Returns the derived pass name. |
609 | static constexpr ::llvm::StringLiteral getPassName() { |
610 | return ::llvm::StringLiteral("DTensorDesignateResourceHandleMesh" ); |
611 | } |
612 | ::llvm::StringRef getName() const override { return "DTensorDesignateResourceHandleMesh" ; } |
613 | |
614 | /// Support isa/dyn_cast functionality for the derived pass class. |
615 | static bool classof(const ::mlir::Pass *pass) { |
616 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
617 | } |
618 | |
619 | /// A clone method to create a copy of this pass. |
620 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
621 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
622 | } |
623 | |
624 | /// Return the dialect that must be loaded in the context before this pass. |
625 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
626 | |
627 | } |
628 | |
629 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
630 | /// instantiation because Pass classes should only be visible by the current |
631 | /// library. |
632 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorDesignateResourceHandleMeshBase<DerivedT>) |
633 | |
634 | protected: |
635 | private: |
636 | }; |
637 | } // namespace impl |
638 | #undef GEN_PASS_DEF_DTENSORDESIGNATERESOURCEHANDLEMESH |
639 | #endif // GEN_PASS_DEF_DTENSORDESIGNATERESOURCEHANDLEMESH |
640 | |
641 | //===----------------------------------------------------------------------===// |
642 | // DTensorDeviceMeshClusterCoarsening |
643 | //===----------------------------------------------------------------------===// |
644 | #ifdef GEN_PASS_DECL_DTENSORDEVICEMESHCLUSTERCOARSENING |
645 | #undef GEN_PASS_DECL_DTENSORDEVICEMESHCLUSTERCOARSENING |
646 | #endif // GEN_PASS_DECL_DTENSORDEVICEMESHCLUSTERCOARSENING |
647 | #ifdef GEN_PASS_DEF_DTENSORDEVICEMESHCLUSTERCOARSENING |
648 | namespace impl { |
649 | |
650 | template <typename DerivedT> |
651 | class DTensorDeviceMeshClusterCoarseningBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
652 | public: |
653 | using Base = DTensorDeviceMeshClusterCoarseningBase; |
654 | |
655 | DTensorDeviceMeshClusterCoarseningBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
656 | DTensorDeviceMeshClusterCoarseningBase(const DTensorDeviceMeshClusterCoarseningBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
657 | |
658 | /// Returns the command-line argument attached to this pass. |
659 | static constexpr ::llvm::StringLiteral getArgumentName() { |
660 | return ::llvm::StringLiteral("dtensor-device-mesh-cluster-coarsening" ); |
661 | } |
662 | ::llvm::StringRef getArgument() const override { return "dtensor-device-mesh-cluster-coarsening" ; } |
663 | |
664 | ::llvm::StringRef getDescription() const override { return "Merges tf_device.cluster op with same mesh attribute." ; } |
665 | |
666 | /// Returns the derived pass name. |
667 | static constexpr ::llvm::StringLiteral getPassName() { |
668 | return ::llvm::StringLiteral("DTensorDeviceMeshClusterCoarsening" ); |
669 | } |
670 | ::llvm::StringRef getName() const override { return "DTensorDeviceMeshClusterCoarsening" ; } |
671 | |
672 | /// Support isa/dyn_cast functionality for the derived pass class. |
673 | static bool classof(const ::mlir::Pass *pass) { |
674 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
675 | } |
676 | |
677 | /// A clone method to create a copy of this pass. |
678 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
679 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
680 | } |
681 | |
682 | /// Return the dialect that must be loaded in the context before this pass. |
683 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
684 | |
685 | } |
686 | |
687 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
688 | /// instantiation because Pass classes should only be visible by the current |
689 | /// library. |
690 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorDeviceMeshClusterCoarseningBase<DerivedT>) |
691 | |
692 | protected: |
693 | private: |
694 | }; |
695 | } // namespace impl |
696 | #undef GEN_PASS_DEF_DTENSORDEVICEMESHCLUSTERCOARSENING |
697 | #endif // GEN_PASS_DEF_DTENSORDEVICEMESHCLUSTERCOARSENING |
698 | |
699 | //===----------------------------------------------------------------------===// |
700 | // DTensorEmbedding |
701 | //===----------------------------------------------------------------------===// |
702 | #ifdef GEN_PASS_DECL_DTENSOREMBEDDING |
703 | #undef GEN_PASS_DECL_DTENSOREMBEDDING |
704 | #endif // GEN_PASS_DECL_DTENSOREMBEDDING |
705 | #ifdef GEN_PASS_DEF_DTENSOREMBEDDING |
706 | namespace impl { |
707 | |
708 | template <typename DerivedT> |
709 | class DTensorEmbeddingBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
710 | public: |
711 | using Base = DTensorEmbeddingBase; |
712 | |
713 | DTensorEmbeddingBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
714 | DTensorEmbeddingBase(const DTensorEmbeddingBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
715 | |
716 | /// Returns the command-line argument attached to this pass. |
717 | static constexpr ::llvm::StringLiteral getArgumentName() { |
718 | return ::llvm::StringLiteral("dtensor-embedding" ); |
719 | } |
720 | ::llvm::StringRef getArgument() const override { return "dtensor-embedding" ; } |
721 | |
722 | ::llvm::StringRef getDescription() const override { return "Extracts the ApplyEmbeddingOptimizer and checks if the attached optimizer is usable for TPU embeddings." ; } |
723 | |
724 | /// Returns the derived pass name. |
725 | static constexpr ::llvm::StringLiteral getPassName() { |
726 | return ::llvm::StringLiteral("DTensorEmbedding" ); |
727 | } |
728 | ::llvm::StringRef getName() const override { return "DTensorEmbedding" ; } |
729 | |
730 | /// Support isa/dyn_cast functionality for the derived pass class. |
731 | static bool classof(const ::mlir::Pass *pass) { |
732 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
733 | } |
734 | |
735 | /// A clone method to create a copy of this pass. |
736 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
737 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
738 | } |
739 | |
740 | /// Return the dialect that must be loaded in the context before this pass. |
741 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
742 | |
743 | registry.insert<::mlir::chlo::ChloDialect>(); |
744 | |
745 | registry.insert<::mlir::memref::MemRefDialect>(); |
746 | |
747 | registry.insert<::mlir::mhlo::MhloDialect>(); |
748 | |
749 | registry.insert<::mlir::shape::ShapeDialect>(); |
750 | |
751 | registry.insert<::mlir::scf::SCFDialect>(); |
752 | |
753 | registry.insert<::mlir::tensor::TensorDialect>(); |
754 | |
755 | } |
756 | |
757 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
758 | /// instantiation because Pass classes should only be visible by the current |
759 | /// library. |
760 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorEmbeddingBase<DerivedT>) |
761 | |
762 | protected: |
763 | private: |
764 | }; |
765 | } // namespace impl |
766 | #undef GEN_PASS_DEF_DTENSOREMBEDDING |
767 | #endif // GEN_PASS_DEF_DTENSOREMBEDDING |
768 | |
769 | //===----------------------------------------------------------------------===// |
770 | // DTensorEmbeddingCheckpoint |
771 | //===----------------------------------------------------------------------===// |
772 | #ifdef GEN_PASS_DECL_DTENSOREMBEDDINGCHECKPOINT |
773 | #undef GEN_PASS_DECL_DTENSOREMBEDDINGCHECKPOINT |
774 | #endif // GEN_PASS_DECL_DTENSOREMBEDDINGCHECKPOINT |
775 | #ifdef GEN_PASS_DEF_DTENSOREMBEDDINGCHECKPOINT |
776 | namespace impl { |
777 | |
778 | template <typename DerivedT> |
779 | class DTensorEmbeddingCheckpointBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
780 | public: |
781 | using Base = DTensorEmbeddingCheckpointBase; |
782 | |
783 | DTensorEmbeddingCheckpointBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
784 | DTensorEmbeddingCheckpointBase(const DTensorEmbeddingCheckpointBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
785 | |
786 | /// Returns the command-line argument attached to this pass. |
787 | static constexpr ::llvm::StringLiteral getArgumentName() { |
788 | return ::llvm::StringLiteral("dtensor-embedding-checkpoint" ); |
789 | } |
790 | ::llvm::StringRef getArgument() const override { return "dtensor-embedding-checkpoint" ; } |
791 | |
792 | ::llvm::StringRef getDescription() const override { return "Checkpoint pass to save or load TPU embeddings variables." ; } |
793 | |
794 | /// Returns the derived pass name. |
795 | static constexpr ::llvm::StringLiteral getPassName() { |
796 | return ::llvm::StringLiteral("DTensorEmbeddingCheckpoint" ); |
797 | } |
798 | ::llvm::StringRef getName() const override { return "DTensorEmbeddingCheckpoint" ; } |
799 | |
800 | /// Support isa/dyn_cast functionality for the derived pass class. |
801 | static bool classof(const ::mlir::Pass *pass) { |
802 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
803 | } |
804 | |
805 | /// A clone method to create a copy of this pass. |
806 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
807 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
808 | } |
809 | |
810 | /// Return the dialect that must be loaded in the context before this pass. |
811 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
812 | |
813 | } |
814 | |
815 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
816 | /// instantiation because Pass classes should only be visible by the current |
817 | /// library. |
818 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorEmbeddingCheckpointBase<DerivedT>) |
819 | |
820 | protected: |
821 | private: |
822 | }; |
823 | } // namespace impl |
824 | #undef GEN_PASS_DEF_DTENSOREMBEDDINGCHECKPOINT |
825 | #endif // GEN_PASS_DEF_DTENSOREMBEDDINGCHECKPOINT |
826 | |
827 | //===----------------------------------------------------------------------===// |
828 | // DTensorEmbeddingV2 |
829 | //===----------------------------------------------------------------------===// |
830 | #ifdef GEN_PASS_DECL_DTENSOREMBEDDINGV2 |
831 | #undef GEN_PASS_DECL_DTENSOREMBEDDINGV2 |
832 | #endif // GEN_PASS_DECL_DTENSOREMBEDDINGV2 |
833 | #ifdef GEN_PASS_DEF_DTENSOREMBEDDINGV2 |
834 | namespace impl { |
835 | |
836 | template <typename DerivedT> |
837 | class DTensorEmbeddingV2Base : public ::mlir::OperationPass<mlir::ModuleOp> { |
838 | public: |
839 | using Base = DTensorEmbeddingV2Base; |
840 | |
841 | DTensorEmbeddingV2Base() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
842 | DTensorEmbeddingV2Base(const DTensorEmbeddingV2Base &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
843 | |
844 | /// Returns the command-line argument attached to this pass. |
845 | static constexpr ::llvm::StringLiteral getArgumentName() { |
846 | return ::llvm::StringLiteral("dtensor-embedding-v2" ); |
847 | } |
848 | ::llvm::StringRef getArgument() const override { return "dtensor-embedding-v2" ; } |
849 | |
850 | ::llvm::StringRef getDescription() const override { return "Embedding Pass" ; } |
851 | |
852 | /// Returns the derived pass name. |
853 | static constexpr ::llvm::StringLiteral getPassName() { |
854 | return ::llvm::StringLiteral("DTensorEmbeddingV2" ); |
855 | } |
856 | ::llvm::StringRef getName() const override { return "DTensorEmbeddingV2" ; } |
857 | |
858 | /// Support isa/dyn_cast functionality for the derived pass class. |
859 | static bool classof(const ::mlir::Pass *pass) { |
860 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
861 | } |
862 | |
863 | /// A clone method to create a copy of this pass. |
864 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
865 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
866 | } |
867 | |
868 | /// Return the dialect that must be loaded in the context before this pass. |
869 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
870 | |
871 | } |
872 | |
873 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
874 | /// instantiation because Pass classes should only be visible by the current |
875 | /// library. |
876 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorEmbeddingV2Base<DerivedT>) |
877 | |
878 | protected: |
879 | private: |
880 | }; |
881 | } // namespace impl |
882 | #undef GEN_PASS_DEF_DTENSOREMBEDDINGV2 |
883 | #endif // GEN_PASS_DEF_DTENSOREMBEDDINGV2 |
884 | |
885 | //===----------------------------------------------------------------------===// |
886 | // DTensorFunctionRenaming |
887 | //===----------------------------------------------------------------------===// |
888 | #ifdef GEN_PASS_DECL_DTENSORFUNCTIONRENAMING |
889 | #undef GEN_PASS_DECL_DTENSORFUNCTIONRENAMING |
890 | #endif // GEN_PASS_DECL_DTENSORFUNCTIONRENAMING |
891 | #ifdef GEN_PASS_DEF_DTENSORFUNCTIONRENAMING |
892 | namespace impl { |
893 | |
894 | template <typename DerivedT> |
895 | class DTensorFunctionRenamingBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
896 | public: |
897 | using Base = DTensorFunctionRenamingBase; |
898 | |
899 | DTensorFunctionRenamingBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
900 | DTensorFunctionRenamingBase(const DTensorFunctionRenamingBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
901 | |
902 | /// Returns the command-line argument attached to this pass. |
903 | static constexpr ::llvm::StringLiteral getArgumentName() { |
904 | return ::llvm::StringLiteral("dtensor-function-renaming" ); |
905 | } |
906 | ::llvm::StringRef getArgument() const override { return "dtensor-function-renaming" ; } |
907 | |
908 | ::llvm::StringRef getDescription() const override { return "Renames private functions by appending an id to each name. This is used to make private function names unique across modules." ; } |
909 | |
910 | /// Returns the derived pass name. |
911 | static constexpr ::llvm::StringLiteral getPassName() { |
912 | return ::llvm::StringLiteral("DTensorFunctionRenaming" ); |
913 | } |
914 | ::llvm::StringRef getName() const override { return "DTensorFunctionRenaming" ; } |
915 | |
916 | /// Support isa/dyn_cast functionality for the derived pass class. |
917 | static bool classof(const ::mlir::Pass *pass) { |
918 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
919 | } |
920 | |
921 | /// A clone method to create a copy of this pass. |
922 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
923 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
924 | } |
925 | |
926 | /// Return the dialect that must be loaded in the context before this pass. |
927 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
928 | |
929 | } |
930 | |
931 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
932 | /// instantiation because Pass classes should only be visible by the current |
933 | /// library. |
934 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorFunctionRenamingBase<DerivedT>) |
935 | |
936 | protected: |
937 | private: |
938 | }; |
939 | } // namespace impl |
940 | #undef GEN_PASS_DEF_DTENSORFUNCTIONRENAMING |
941 | #endif // GEN_PASS_DEF_DTENSORFUNCTIONRENAMING |
942 | |
943 | //===----------------------------------------------------------------------===// |
944 | // DTensorHandleCrossClusterDependencies |
945 | //===----------------------------------------------------------------------===// |
946 | #ifdef GEN_PASS_DECL_DTENSORHANDLECROSSCLUSTERDEPENDENCIES |
947 | #undef GEN_PASS_DECL_DTENSORHANDLECROSSCLUSTERDEPENDENCIES |
948 | #endif // GEN_PASS_DECL_DTENSORHANDLECROSSCLUSTERDEPENDENCIES |
949 | #ifdef GEN_PASS_DEF_DTENSORHANDLECROSSCLUSTERDEPENDENCIES |
950 | namespace impl { |
951 | |
952 | template <typename DerivedT> |
953 | class DTensorHandleCrossClusterDependenciesBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
954 | public: |
955 | using Base = DTensorHandleCrossClusterDependenciesBase; |
956 | |
957 | DTensorHandleCrossClusterDependenciesBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
958 | DTensorHandleCrossClusterDependenciesBase(const DTensorHandleCrossClusterDependenciesBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
959 | |
960 | /// Returns the command-line argument attached to this pass. |
961 | static constexpr ::llvm::StringLiteral getArgumentName() { |
962 | return ::llvm::StringLiteral("dtensor-handle_cross_cluster_dependences" ); |
963 | } |
964 | ::llvm::StringRef getArgument() const override { return "dtensor-handle_cross_cluster_dependences" ; } |
965 | |
966 | ::llvm::StringRef getDescription() const override { return "Lowers cross mesh cluster data depedences as send/recvs." ; } |
967 | |
968 | /// Returns the derived pass name. |
969 | static constexpr ::llvm::StringLiteral getPassName() { |
970 | return ::llvm::StringLiteral("DTensorHandleCrossClusterDependencies" ); |
971 | } |
972 | ::llvm::StringRef getName() const override { return "DTensorHandleCrossClusterDependencies" ; } |
973 | |
974 | /// Support isa/dyn_cast functionality for the derived pass class. |
975 | static bool classof(const ::mlir::Pass *pass) { |
976 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
977 | } |
978 | |
979 | /// A clone method to create a copy of this pass. |
980 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
981 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
982 | } |
983 | |
984 | /// Return the dialect that must be loaded in the context before this pass. |
985 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
986 | |
987 | } |
988 | |
989 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
990 | /// instantiation because Pass classes should only be visible by the current |
991 | /// library. |
992 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorHandleCrossClusterDependenciesBase<DerivedT>) |
993 | |
994 | protected: |
995 | private: |
996 | }; |
997 | } // namespace impl |
998 | #undef GEN_PASS_DEF_DTENSORHANDLECROSSCLUSTERDEPENDENCIES |
999 | #endif // GEN_PASS_DEF_DTENSORHANDLECROSSCLUSTERDEPENDENCIES |
1000 | |
1001 | //===----------------------------------------------------------------------===// |
1002 | // DTensorInferShapesForRestoreV2Op |
1003 | //===----------------------------------------------------------------------===// |
1004 | #ifdef GEN_PASS_DECL_DTENSORINFERSHAPESFORRESTOREV2OP |
1005 | #undef GEN_PASS_DECL_DTENSORINFERSHAPESFORRESTOREV2OP |
1006 | #endif // GEN_PASS_DECL_DTENSORINFERSHAPESFORRESTOREV2OP |
1007 | #ifdef GEN_PASS_DEF_DTENSORINFERSHAPESFORRESTOREV2OP |
1008 | namespace impl { |
1009 | |
1010 | template <typename DerivedT> |
1011 | class DTensorInferShapesForRestoreV2OpBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
1012 | public: |
1013 | using Base = DTensorInferShapesForRestoreV2OpBase; |
1014 | |
1015 | DTensorInferShapesForRestoreV2OpBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
1016 | DTensorInferShapesForRestoreV2OpBase(const DTensorInferShapesForRestoreV2OpBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
1017 | |
1018 | /// Returns the command-line argument attached to this pass. |
1019 | static constexpr ::llvm::StringLiteral getArgumentName() { |
1020 | return ::llvm::StringLiteral("dtensor-infer-shapes-for-restorev2-op" ); |
1021 | } |
1022 | ::llvm::StringRef getArgument() const override { return "dtensor-infer-shapes-for-restorev2-op" ; } |
1023 | |
1024 | ::llvm::StringRef getDescription() const override { return "Infer shapes of the outputs of tf.RestoreV2Op from the AssignVariableOps that consume those outputs. This is used for DTensor integration with TF Checkpoint." ; } |
1025 | |
1026 | /// Returns the derived pass name. |
1027 | static constexpr ::llvm::StringLiteral getPassName() { |
1028 | return ::llvm::StringLiteral("DTensorInferShapesForRestoreV2Op" ); |
1029 | } |
1030 | ::llvm::StringRef getName() const override { return "DTensorInferShapesForRestoreV2Op" ; } |
1031 | |
1032 | /// Support isa/dyn_cast functionality for the derived pass class. |
1033 | static bool classof(const ::mlir::Pass *pass) { |
1034 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
1035 | } |
1036 | |
1037 | /// A clone method to create a copy of this pass. |
1038 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
1039 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
1040 | } |
1041 | |
1042 | /// Return the dialect that must be loaded in the context before this pass. |
1043 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
1044 | |
1045 | } |
1046 | |
1047 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
1048 | /// instantiation because Pass classes should only be visible by the current |
1049 | /// library. |
1050 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorInferShapesForRestoreV2OpBase<DerivedT>) |
1051 | |
1052 | protected: |
1053 | private: |
1054 | }; |
1055 | } // namespace impl |
1056 | #undef GEN_PASS_DEF_DTENSORINFERSHAPESFORRESTOREV2OP |
1057 | #endif // GEN_PASS_DEF_DTENSORINFERSHAPESFORRESTOREV2OP |
1058 | |
1059 | //===----------------------------------------------------------------------===// |
1060 | // DTensorLayoutPropagationV2 |
1061 | //===----------------------------------------------------------------------===// |
1062 | #ifdef GEN_PASS_DECL_DTENSORLAYOUTPROPAGATIONV2 |
1063 | #undef GEN_PASS_DECL_DTENSORLAYOUTPROPAGATIONV2 |
1064 | #endif // GEN_PASS_DECL_DTENSORLAYOUTPROPAGATIONV2 |
1065 | #ifdef GEN_PASS_DEF_DTENSORLAYOUTPROPAGATIONV2 |
1066 | namespace impl { |
1067 | |
1068 | template <typename DerivedT> |
1069 | class DTensorLayoutPropagationV2Base : public ::mlir::OperationPass<mlir::ModuleOp> { |
1070 | public: |
1071 | using Base = DTensorLayoutPropagationV2Base; |
1072 | |
1073 | DTensorLayoutPropagationV2Base() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
1074 | DTensorLayoutPropagationV2Base(const DTensorLayoutPropagationV2Base &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
1075 | |
1076 | /// Returns the command-line argument attached to this pass. |
1077 | static constexpr ::llvm::StringLiteral getArgumentName() { |
1078 | return ::llvm::StringLiteral("dtensor-layout-propagation-v2" ); |
1079 | } |
1080 | ::llvm::StringRef getArgument() const override { return "dtensor-layout-propagation-v2" ; } |
1081 | |
1082 | ::llvm::StringRef getDescription() const override { return "Propagates layout information for all the TF Ops." ; } |
1083 | |
1084 | /// Returns the derived pass name. |
1085 | static constexpr ::llvm::StringLiteral getPassName() { |
1086 | return ::llvm::StringLiteral("DTensorLayoutPropagationV2" ); |
1087 | } |
1088 | ::llvm::StringRef getName() const override { return "DTensorLayoutPropagationV2" ; } |
1089 | |
1090 | /// Support isa/dyn_cast functionality for the derived pass class. |
1091 | static bool classof(const ::mlir::Pass *pass) { |
1092 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
1093 | } |
1094 | |
1095 | /// A clone method to create a copy of this pass. |
1096 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
1097 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
1098 | } |
1099 | |
1100 | /// Return the dialect that must be loaded in the context before this pass. |
1101 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
1102 | |
1103 | } |
1104 | |
1105 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
1106 | /// instantiation because Pass classes should only be visible by the current |
1107 | /// library. |
1108 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorLayoutPropagationV2Base<DerivedT>) |
1109 | |
1110 | protected: |
1111 | private: |
1112 | }; |
1113 | } // namespace impl |
1114 | #undef GEN_PASS_DEF_DTENSORLAYOUTPROPAGATIONV2 |
1115 | #endif // GEN_PASS_DEF_DTENSORLAYOUTPROPAGATIONV2 |
1116 | |
1117 | //===----------------------------------------------------------------------===// |
1118 | // DTensorLowerSendRecv |
1119 | //===----------------------------------------------------------------------===// |
1120 | #ifdef GEN_PASS_DECL_DTENSORLOWERSENDRECV |
1121 | #undef GEN_PASS_DECL_DTENSORLOWERSENDRECV |
1122 | #endif // GEN_PASS_DECL_DTENSORLOWERSENDRECV |
1123 | #ifdef GEN_PASS_DEF_DTENSORLOWERSENDRECV |
1124 | namespace impl { |
1125 | |
1126 | template <typename DerivedT> |
1127 | class DTensorLowerSendRecvBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
1128 | public: |
1129 | using Base = DTensorLowerSendRecvBase; |
1130 | |
1131 | DTensorLowerSendRecvBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
1132 | DTensorLowerSendRecvBase(const DTensorLowerSendRecvBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
1133 | |
1134 | /// Returns the command-line argument attached to this pass. |
1135 | static constexpr ::llvm::StringLiteral getArgumentName() { |
1136 | return ::llvm::StringLiteral("dtensor-lower-send-recv" ); |
1137 | } |
1138 | ::llvm::StringRef getArgument() const override { return "dtensor-lower-send-recv" ; } |
1139 | |
1140 | ::llvm::StringRef getDescription() const override { return "Lowers DTensorSend/DTensorRecv ops to send/recv ops." ; } |
1141 | |
1142 | /// Returns the derived pass name. |
1143 | static constexpr ::llvm::StringLiteral getPassName() { |
1144 | return ::llvm::StringLiteral("DTensorLowerSendRecv" ); |
1145 | } |
1146 | ::llvm::StringRef getName() const override { return "DTensorLowerSendRecv" ; } |
1147 | |
1148 | /// Support isa/dyn_cast functionality for the derived pass class. |
1149 | static bool classof(const ::mlir::Pass *pass) { |
1150 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
1151 | } |
1152 | |
1153 | /// A clone method to create a copy of this pass. |
1154 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
1155 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
1156 | } |
1157 | |
1158 | /// Return the dialect that must be loaded in the context before this pass. |
1159 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
1160 | |
1161 | } |
1162 | |
1163 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
1164 | /// instantiation because Pass classes should only be visible by the current |
1165 | /// library. |
1166 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorLowerSendRecvBase<DerivedT>) |
1167 | |
1168 | protected: |
1169 | private: |
1170 | }; |
1171 | } // namespace impl |
1172 | #undef GEN_PASS_DEF_DTENSORLOWERSENDRECV |
1173 | #endif // GEN_PASS_DEF_DTENSORLOWERSENDRECV |
1174 | |
1175 | //===----------------------------------------------------------------------===// |
1176 | // DTensorMergeClusters |
1177 | //===----------------------------------------------------------------------===// |
1178 | #ifdef GEN_PASS_DECL_DTENSORMERGECLUSTERS |
1179 | #undef GEN_PASS_DECL_DTENSORMERGECLUSTERS |
1180 | #endif // GEN_PASS_DECL_DTENSORMERGECLUSTERS |
1181 | #ifdef GEN_PASS_DEF_DTENSORMERGECLUSTERS |
1182 | namespace impl { |
1183 | |
1184 | template <typename DerivedT> |
1185 | class DTensorMergeClustersBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
1186 | public: |
1187 | using Base = DTensorMergeClustersBase; |
1188 | |
1189 | DTensorMergeClustersBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
1190 | DTensorMergeClustersBase(const DTensorMergeClustersBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
1191 | |
1192 | /// Returns the command-line argument attached to this pass. |
1193 | static constexpr ::llvm::StringLiteral getArgumentName() { |
1194 | return ::llvm::StringLiteral("dtensor-merge-clusters" ); |
1195 | } |
1196 | ::llvm::StringRef getArgument() const override { return "dtensor-merge-clusters" ; } |
1197 | |
1198 | ::llvm::StringRef getDescription() const override { return "Merges tf_device.Clusters ops with same mesh specification." ; } |
1199 | |
1200 | /// Returns the derived pass name. |
1201 | static constexpr ::llvm::StringLiteral getPassName() { |
1202 | return ::llvm::StringLiteral("DTensorMergeClusters" ); |
1203 | } |
1204 | ::llvm::StringRef getName() const override { return "DTensorMergeClusters" ; } |
1205 | |
1206 | /// Support isa/dyn_cast functionality for the derived pass class. |
1207 | static bool classof(const ::mlir::Pass *pass) { |
1208 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
1209 | } |
1210 | |
1211 | /// A clone method to create a copy of this pass. |
1212 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
1213 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
1214 | } |
1215 | |
1216 | /// Return the dialect that must be loaded in the context before this pass. |
1217 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
1218 | |
1219 | } |
1220 | |
1221 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
1222 | /// instantiation because Pass classes should only be visible by the current |
1223 | /// library. |
1224 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorMergeClustersBase<DerivedT>) |
1225 | |
1226 | protected: |
1227 | private: |
1228 | }; |
1229 | } // namespace impl |
1230 | #undef GEN_PASS_DEF_DTENSORMERGECLUSTERS |
1231 | #endif // GEN_PASS_DEF_DTENSORMERGECLUSTERS |
1232 | |
1233 | //===----------------------------------------------------------------------===// |
1234 | // DTensorMeshPropagation |
1235 | //===----------------------------------------------------------------------===// |
1236 | #ifdef GEN_PASS_DECL_DTENSORMESHPROPAGATION |
1237 | #undef GEN_PASS_DECL_DTENSORMESHPROPAGATION |
1238 | #endif // GEN_PASS_DECL_DTENSORMESHPROPAGATION |
1239 | #ifdef GEN_PASS_DEF_DTENSORMESHPROPAGATION |
1240 | namespace impl { |
1241 | |
1242 | template <typename DerivedT> |
1243 | class DTensorMeshPropagationBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
1244 | public: |
1245 | using Base = DTensorMeshPropagationBase; |
1246 | |
1247 | DTensorMeshPropagationBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
1248 | DTensorMeshPropagationBase(const DTensorMeshPropagationBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
1249 | |
1250 | /// Returns the command-line argument attached to this pass. |
1251 | static constexpr ::llvm::StringLiteral getArgumentName() { |
1252 | return ::llvm::StringLiteral("dtensor-mesh-propagation" ); |
1253 | } |
1254 | ::llvm::StringRef getArgument() const override { return "dtensor-mesh-propagation" ; } |
1255 | |
1256 | ::llvm::StringRef getDescription() const override { return "Propagates mesh information to all clusters." ; } |
1257 | |
1258 | /// Returns the derived pass name. |
1259 | static constexpr ::llvm::StringLiteral getPassName() { |
1260 | return ::llvm::StringLiteral("DTensorMeshPropagation" ); |
1261 | } |
1262 | ::llvm::StringRef getName() const override { return "DTensorMeshPropagation" ; } |
1263 | |
1264 | /// Support isa/dyn_cast functionality for the derived pass class. |
1265 | static bool classof(const ::mlir::Pass *pass) { |
1266 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
1267 | } |
1268 | |
1269 | /// A clone method to create a copy of this pass. |
1270 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
1271 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
1272 | } |
1273 | |
1274 | /// Return the dialect that must be loaded in the context before this pass. |
1275 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
1276 | |
1277 | } |
1278 | |
1279 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
1280 | /// instantiation because Pass classes should only be visible by the current |
1281 | /// library. |
1282 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorMeshPropagationBase<DerivedT>) |
1283 | |
1284 | protected: |
1285 | private: |
1286 | }; |
1287 | } // namespace impl |
1288 | #undef GEN_PASS_DEF_DTENSORMESHPROPAGATION |
1289 | #endif // GEN_PASS_DEF_DTENSORMESHPROPAGATION |
1290 | |
1291 | //===----------------------------------------------------------------------===// |
1292 | // DTensorMixedPrecisionReduce |
1293 | //===----------------------------------------------------------------------===// |
1294 | #ifdef GEN_PASS_DECL_DTENSORMIXEDPRECISIONREDUCE |
1295 | #undef GEN_PASS_DECL_DTENSORMIXEDPRECISIONREDUCE |
1296 | #endif // GEN_PASS_DECL_DTENSORMIXEDPRECISIONREDUCE |
1297 | #ifdef GEN_PASS_DEF_DTENSORMIXEDPRECISIONREDUCE |
1298 | namespace impl { |
1299 | |
1300 | template <typename DerivedT> |
1301 | class DTensorMixedPrecisionReduceBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
1302 | public: |
1303 | using Base = DTensorMixedPrecisionReduceBase; |
1304 | |
1305 | DTensorMixedPrecisionReduceBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
1306 | DTensorMixedPrecisionReduceBase(const DTensorMixedPrecisionReduceBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
1307 | |
1308 | /// Returns the command-line argument attached to this pass. |
1309 | static constexpr ::llvm::StringLiteral getArgumentName() { |
1310 | return ::llvm::StringLiteral("dtensor-mixed-precision-reduce" ); |
1311 | } |
1312 | ::llvm::StringRef getArgument() const override { return "dtensor-mixed-precision-reduce" ; } |
1313 | |
1314 | ::llvm::StringRef getDescription() const override { return "Upcast tensors to higher precision type for reduction ops." ; } |
1315 | |
1316 | /// Returns the derived pass name. |
1317 | static constexpr ::llvm::StringLiteral getPassName() { |
1318 | return ::llvm::StringLiteral("DTensorMixedPrecisionReduce" ); |
1319 | } |
1320 | ::llvm::StringRef getName() const override { return "DTensorMixedPrecisionReduce" ; } |
1321 | |
1322 | /// Support isa/dyn_cast functionality for the derived pass class. |
1323 | static bool classof(const ::mlir::Pass *pass) { |
1324 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
1325 | } |
1326 | |
1327 | /// A clone method to create a copy of this pass. |
1328 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
1329 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
1330 | } |
1331 | |
1332 | /// Return the dialect that must be loaded in the context before this pass. |
1333 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
1334 | |
1335 | } |
1336 | |
1337 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
1338 | /// instantiation because Pass classes should only be visible by the current |
1339 | /// library. |
1340 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorMixedPrecisionReduceBase<DerivedT>) |
1341 | |
1342 | protected: |
1343 | private: |
1344 | }; |
1345 | } // namespace impl |
1346 | #undef GEN_PASS_DEF_DTENSORMIXEDPRECISIONREDUCE |
1347 | #endif // GEN_PASS_DEF_DTENSORMIXEDPRECISIONREDUCE |
1348 | |
1349 | //===----------------------------------------------------------------------===// |
1350 | // DTensorMoveCompilationToHost |
1351 | //===----------------------------------------------------------------------===// |
1352 | #ifdef GEN_PASS_DECL_DTENSORMOVECOMPILATIONTOHOST |
1353 | #undef GEN_PASS_DECL_DTENSORMOVECOMPILATIONTOHOST |
1354 | #endif // GEN_PASS_DECL_DTENSORMOVECOMPILATIONTOHOST |
1355 | #ifdef GEN_PASS_DEF_DTENSORMOVECOMPILATIONTOHOST |
1356 | namespace impl { |
1357 | |
1358 | template <typename DerivedT> |
1359 | class DTensorMoveCompilationToHostBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
1360 | public: |
1361 | using Base = DTensorMoveCompilationToHostBase; |
1362 | |
1363 | DTensorMoveCompilationToHostBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
1364 | DTensorMoveCompilationToHostBase(const DTensorMoveCompilationToHostBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
1365 | |
1366 | /// Returns the command-line argument attached to this pass. |
1367 | static constexpr ::llvm::StringLiteral getArgumentName() { |
1368 | return ::llvm::StringLiteral("dtensor-move-compilation-to-host" ); |
1369 | } |
1370 | ::llvm::StringRef getArgument() const override { return "dtensor-move-compilation-to-host" ; } |
1371 | |
1372 | ::llvm::StringRef getDescription() const override { return "Moves XLA compilation ops to host computation." ; } |
1373 | |
1374 | /// Returns the derived pass name. |
1375 | static constexpr ::llvm::StringLiteral getPassName() { |
1376 | return ::llvm::StringLiteral("DTensorMoveCompilationToHost" ); |
1377 | } |
1378 | ::llvm::StringRef getName() const override { return "DTensorMoveCompilationToHost" ; } |
1379 | |
1380 | /// Support isa/dyn_cast functionality for the derived pass class. |
1381 | static bool classof(const ::mlir::Pass *pass) { |
1382 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
1383 | } |
1384 | |
1385 | /// A clone method to create a copy of this pass. |
1386 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
1387 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
1388 | } |
1389 | |
1390 | /// Return the dialect that must be loaded in the context before this pass. |
1391 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
1392 | |
1393 | } |
1394 | |
1395 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
1396 | /// instantiation because Pass classes should only be visible by the current |
1397 | /// library. |
1398 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorMoveCompilationToHostBase<DerivedT>) |
1399 | |
1400 | protected: |
1401 | private: |
1402 | }; |
1403 | } // namespace impl |
1404 | #undef GEN_PASS_DEF_DTENSORMOVECOMPILATIONTOHOST |
1405 | #endif // GEN_PASS_DEF_DTENSORMOVECOMPILATIONTOHOST |
1406 | |
1407 | //===----------------------------------------------------------------------===// |
1408 | // DTensorOpToDeviceCluster |
1409 | //===----------------------------------------------------------------------===// |
1410 | #ifdef GEN_PASS_DECL_DTENSOROPTODEVICECLUSTER |
1411 | #undef GEN_PASS_DECL_DTENSOROPTODEVICECLUSTER |
1412 | #endif // GEN_PASS_DECL_DTENSOROPTODEVICECLUSTER |
1413 | #ifdef GEN_PASS_DEF_DTENSOROPTODEVICECLUSTER |
1414 | namespace impl { |
1415 | |
1416 | template <typename DerivedT> |
1417 | class DTensorOpToDeviceClusterBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
1418 | public: |
1419 | using Base = DTensorOpToDeviceClusterBase; |
1420 | |
1421 | DTensorOpToDeviceClusterBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
1422 | DTensorOpToDeviceClusterBase(const DTensorOpToDeviceClusterBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
1423 | |
1424 | /// Returns the command-line argument attached to this pass. |
1425 | static constexpr ::llvm::StringLiteral getArgumentName() { |
1426 | return ::llvm::StringLiteral("dtensor-op-to-device-cluster" ); |
1427 | } |
1428 | ::llvm::StringRef getArgument() const override { return "dtensor-op-to-device-cluster" ; } |
1429 | |
1430 | ::llvm::StringRef getDescription() const override { return "Creates and wraps tf_device.cluster op for all TF ops" ; } |
1431 | |
1432 | /// Returns the derived pass name. |
1433 | static constexpr ::llvm::StringLiteral getPassName() { |
1434 | return ::llvm::StringLiteral("DTensorOpToDeviceCluster" ); |
1435 | } |
1436 | ::llvm::StringRef getName() const override { return "DTensorOpToDeviceCluster" ; } |
1437 | |
1438 | /// Support isa/dyn_cast functionality for the derived pass class. |
1439 | static bool classof(const ::mlir::Pass *pass) { |
1440 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
1441 | } |
1442 | |
1443 | /// A clone method to create a copy of this pass. |
1444 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
1445 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
1446 | } |
1447 | |
1448 | /// Return the dialect that must be loaded in the context before this pass. |
1449 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
1450 | |
1451 | } |
1452 | |
1453 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
1454 | /// instantiation because Pass classes should only be visible by the current |
1455 | /// library. |
1456 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorOpToDeviceClusterBase<DerivedT>) |
1457 | |
1458 | protected: |
1459 | private: |
1460 | }; |
1461 | } // namespace impl |
1462 | #undef GEN_PASS_DEF_DTENSOROPTODEVICECLUSTER |
1463 | #endif // GEN_PASS_DEF_DTENSOROPTODEVICECLUSTER |
1464 | |
1465 | //===----------------------------------------------------------------------===// |
1466 | // DTensorPropagateDefaultLayout |
1467 | //===----------------------------------------------------------------------===// |
1468 | #ifdef GEN_PASS_DECL_DTENSORPROPAGATEDEFAULTLAYOUT |
1469 | #undef GEN_PASS_DECL_DTENSORPROPAGATEDEFAULTLAYOUT |
1470 | #endif // GEN_PASS_DECL_DTENSORPROPAGATEDEFAULTLAYOUT |
1471 | #ifdef GEN_PASS_DEF_DTENSORPROPAGATEDEFAULTLAYOUT |
1472 | namespace impl { |
1473 | |
1474 | template <typename DerivedT> |
1475 | class DTensorPropagateDefaultLayoutBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
1476 | public: |
1477 | using Base = DTensorPropagateDefaultLayoutBase; |
1478 | |
1479 | DTensorPropagateDefaultLayoutBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
1480 | DTensorPropagateDefaultLayoutBase(const DTensorPropagateDefaultLayoutBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
1481 | |
1482 | /// Returns the command-line argument attached to this pass. |
1483 | static constexpr ::llvm::StringLiteral getArgumentName() { |
1484 | return ::llvm::StringLiteral("dtensor-propagate-default-layout" ); |
1485 | } |
1486 | ::llvm::StringRef getArgument() const override { return "dtensor-propagate-default-layout" ; } |
1487 | |
1488 | ::llvm::StringRef getDescription() const override { return "Converts layout attributes added by end users to DTensorLayout op." ; } |
1489 | |
1490 | /// Returns the derived pass name. |
1491 | static constexpr ::llvm::StringLiteral getPassName() { |
1492 | return ::llvm::StringLiteral("DTensorPropagateDefaultLayout" ); |
1493 | } |
1494 | ::llvm::StringRef getName() const override { return "DTensorPropagateDefaultLayout" ; } |
1495 | |
1496 | /// Support isa/dyn_cast functionality for the derived pass class. |
1497 | static bool classof(const ::mlir::Pass *pass) { |
1498 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
1499 | } |
1500 | |
1501 | /// A clone method to create a copy of this pass. |
1502 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
1503 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
1504 | } |
1505 | |
1506 | /// Return the dialect that must be loaded in the context before this pass. |
1507 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
1508 | |
1509 | } |
1510 | |
1511 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
1512 | /// instantiation because Pass classes should only be visible by the current |
1513 | /// library. |
1514 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorPropagateDefaultLayoutBase<DerivedT>) |
1515 | |
1516 | protected: |
1517 | private: |
1518 | }; |
1519 | } // namespace impl |
1520 | #undef GEN_PASS_DEF_DTENSORPROPAGATEDEFAULTLAYOUT |
1521 | #endif // GEN_PASS_DEF_DTENSORPROPAGATEDEFAULTLAYOUT |
1522 | |
1523 | //===----------------------------------------------------------------------===// |
1524 | // DTensorPropagateDeviceIdToFunctionArgs |
1525 | //===----------------------------------------------------------------------===// |
1526 | #ifdef GEN_PASS_DECL_DTENSORPROPAGATEDEVICEIDTOFUNCTIONARGS |
1527 | #undef GEN_PASS_DECL_DTENSORPROPAGATEDEVICEIDTOFUNCTIONARGS |
1528 | #endif // GEN_PASS_DECL_DTENSORPROPAGATEDEVICEIDTOFUNCTIONARGS |
1529 | #ifdef GEN_PASS_DEF_DTENSORPROPAGATEDEVICEIDTOFUNCTIONARGS |
1530 | namespace impl { |
1531 | |
1532 | template <typename DerivedT> |
1533 | class DTensorPropagateDeviceIdToFunctionArgsBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
1534 | public: |
1535 | using Base = DTensorPropagateDeviceIdToFunctionArgsBase; |
1536 | |
1537 | DTensorPropagateDeviceIdToFunctionArgsBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
1538 | DTensorPropagateDeviceIdToFunctionArgsBase(const DTensorPropagateDeviceIdToFunctionArgsBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
1539 | |
1540 | /// Returns the command-line argument attached to this pass. |
1541 | static constexpr ::llvm::StringLiteral getArgumentName() { |
1542 | return ::llvm::StringLiteral("dtensor-propagate-device-id-to-function-args" ); |
1543 | } |
1544 | ::llvm::StringRef getArgument() const override { return "dtensor-propagate-device-id-to-function-args" ; } |
1545 | |
1546 | ::llvm::StringRef getDescription() const override { return "Adds device id as arguments to all private function in graph." ; } |
1547 | |
1548 | /// Returns the derived pass name. |
1549 | static constexpr ::llvm::StringLiteral getPassName() { |
1550 | return ::llvm::StringLiteral("DTensorPropagateDeviceIdToFunctionArgs" ); |
1551 | } |
1552 | ::llvm::StringRef getName() const override { return "DTensorPropagateDeviceIdToFunctionArgs" ; } |
1553 | |
1554 | /// Support isa/dyn_cast functionality for the derived pass class. |
1555 | static bool classof(const ::mlir::Pass *pass) { |
1556 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
1557 | } |
1558 | |
1559 | /// A clone method to create a copy of this pass. |
1560 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
1561 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
1562 | } |
1563 | |
1564 | /// Return the dialect that must be loaded in the context before this pass. |
1565 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
1566 | |
1567 | } |
1568 | |
1569 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
1570 | /// instantiation because Pass classes should only be visible by the current |
1571 | /// library. |
1572 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorPropagateDeviceIdToFunctionArgsBase<DerivedT>) |
1573 | |
1574 | protected: |
1575 | private: |
1576 | }; |
1577 | } // namespace impl |
1578 | #undef GEN_PASS_DEF_DTENSORPROPAGATEDEVICEIDTOFUNCTIONARGS |
1579 | #endif // GEN_PASS_DEF_DTENSORPROPAGATEDEVICEIDTOFUNCTIONARGS |
1580 | |
1581 | //===----------------------------------------------------------------------===// |
1582 | // DTensorReduceScatterLowering |
1583 | //===----------------------------------------------------------------------===// |
1584 | #ifdef GEN_PASS_DECL_DTENSORREDUCESCATTERLOWERING |
1585 | #undef GEN_PASS_DECL_DTENSORREDUCESCATTERLOWERING |
1586 | #endif // GEN_PASS_DECL_DTENSORREDUCESCATTERLOWERING |
1587 | #ifdef GEN_PASS_DEF_DTENSORREDUCESCATTERLOWERING |
1588 | namespace impl { |
1589 | |
1590 | template <typename DerivedT> |
1591 | class DTensorReduceScatterLoweringBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
1592 | public: |
1593 | using Base = DTensorReduceScatterLoweringBase; |
1594 | |
1595 | DTensorReduceScatterLoweringBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
1596 | DTensorReduceScatterLoweringBase(const DTensorReduceScatterLoweringBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
1597 | |
1598 | /// Returns the command-line argument attached to this pass. |
1599 | static constexpr ::llvm::StringLiteral getArgumentName() { |
1600 | return ::llvm::StringLiteral("dtensor-reduce-scatter-lowering" ); |
1601 | } |
1602 | ::llvm::StringRef getArgument() const override { return "dtensor-reduce-scatter-lowering" ; } |
1603 | |
1604 | ::llvm::StringRef getDescription() const override { return "Converts logical ReduceScatter ops into physical ReduceScatter ops." ; } |
1605 | |
1606 | /// Returns the derived pass name. |
1607 | static constexpr ::llvm::StringLiteral getPassName() { |
1608 | return ::llvm::StringLiteral("DTensorReduceScatterLowering" ); |
1609 | } |
1610 | ::llvm::StringRef getName() const override { return "DTensorReduceScatterLowering" ; } |
1611 | |
1612 | /// Support isa/dyn_cast functionality for the derived pass class. |
1613 | static bool classof(const ::mlir::Pass *pass) { |
1614 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
1615 | } |
1616 | |
1617 | /// A clone method to create a copy of this pass. |
1618 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
1619 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
1620 | } |
1621 | |
1622 | /// Return the dialect that must be loaded in the context before this pass. |
1623 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
1624 | |
1625 | } |
1626 | |
1627 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
1628 | /// instantiation because Pass classes should only be visible by the current |
1629 | /// library. |
1630 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorReduceScatterLoweringBase<DerivedT>) |
1631 | |
1632 | protected: |
1633 | private: |
1634 | }; |
1635 | } // namespace impl |
1636 | #undef GEN_PASS_DEF_DTENSORREDUCESCATTERLOWERING |
1637 | #endif // GEN_PASS_DEF_DTENSORREDUCESCATTERLOWERING |
1638 | |
1639 | //===----------------------------------------------------------------------===// |
1640 | // DTensorSPMDExpansion |
1641 | //===----------------------------------------------------------------------===// |
1642 | #ifdef GEN_PASS_DECL_DTENSORSPMDEXPANSION |
1643 | #undef GEN_PASS_DECL_DTENSORSPMDEXPANSION |
1644 | #endif // GEN_PASS_DECL_DTENSORSPMDEXPANSION |
1645 | #ifdef GEN_PASS_DEF_DTENSORSPMDEXPANSION |
1646 | namespace impl { |
1647 | |
1648 | template <typename DerivedT> |
1649 | class DTensorSPMDExpansionBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
1650 | public: |
1651 | using Base = DTensorSPMDExpansionBase; |
1652 | |
1653 | DTensorSPMDExpansionBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
1654 | DTensorSPMDExpansionBase(const DTensorSPMDExpansionBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
1655 | |
1656 | /// Returns the command-line argument attached to this pass. |
1657 | static constexpr ::llvm::StringLiteral getArgumentName() { |
1658 | return ::llvm::StringLiteral("dtensor-spmd-expansion" ); |
1659 | } |
1660 | ::llvm::StringRef getArgument() const override { return "dtensor-spmd-expansion" ; } |
1661 | |
1662 | ::llvm::StringRef getDescription() const override { return "Converts ops into SPMD expanded form." ; } |
1663 | |
1664 | /// Returns the derived pass name. |
1665 | static constexpr ::llvm::StringLiteral getPassName() { |
1666 | return ::llvm::StringLiteral("DTensorSPMDExpansion" ); |
1667 | } |
1668 | ::llvm::StringRef getName() const override { return "DTensorSPMDExpansion" ; } |
1669 | |
1670 | /// Support isa/dyn_cast functionality for the derived pass class. |
1671 | static bool classof(const ::mlir::Pass *pass) { |
1672 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
1673 | } |
1674 | |
1675 | /// A clone method to create a copy of this pass. |
1676 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
1677 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
1678 | } |
1679 | |
1680 | /// Return the dialect that must be loaded in the context before this pass. |
1681 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
1682 | |
1683 | } |
1684 | |
1685 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
1686 | /// instantiation because Pass classes should only be visible by the current |
1687 | /// library. |
1688 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorSPMDExpansionBase<DerivedT>) |
1689 | |
1690 | protected: |
1691 | private: |
1692 | }; |
1693 | } // namespace impl |
1694 | #undef GEN_PASS_DEF_DTENSORSPMDEXPANSION |
1695 | #endif // GEN_PASS_DEF_DTENSORSPMDEXPANSION |
1696 | |
1697 | //===----------------------------------------------------------------------===// |
1698 | // DTensorSetDefaultSharding |
1699 | //===----------------------------------------------------------------------===// |
1700 | #ifdef GEN_PASS_DECL_DTENSORSETDEFAULTSHARDING |
1701 | #undef GEN_PASS_DECL_DTENSORSETDEFAULTSHARDING |
1702 | #endif // GEN_PASS_DECL_DTENSORSETDEFAULTSHARDING |
1703 | #ifdef GEN_PASS_DEF_DTENSORSETDEFAULTSHARDING |
1704 | namespace impl { |
1705 | |
1706 | template <typename DerivedT> |
1707 | class DTensorSetDefaultShardingBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
1708 | public: |
1709 | using Base = DTensorSetDefaultShardingBase; |
1710 | |
1711 | DTensorSetDefaultShardingBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
1712 | DTensorSetDefaultShardingBase(const DTensorSetDefaultShardingBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
1713 | |
1714 | /// Returns the command-line argument attached to this pass. |
1715 | static constexpr ::llvm::StringLiteral getArgumentName() { |
1716 | return ::llvm::StringLiteral("dtensor-set-default-sharding" ); |
1717 | } |
1718 | ::llvm::StringRef getArgument() const override { return "dtensor-set-default-sharding" ; } |
1719 | |
1720 | ::llvm::StringRef getDescription() const override { return "Sets default sharding of TPU computation inputs/outputs to maximal." ; } |
1721 | |
1722 | /// Returns the derived pass name. |
1723 | static constexpr ::llvm::StringLiteral getPassName() { |
1724 | return ::llvm::StringLiteral("DTensorSetDefaultSharding" ); |
1725 | } |
1726 | ::llvm::StringRef getName() const override { return "DTensorSetDefaultSharding" ; } |
1727 | |
1728 | /// Support isa/dyn_cast functionality for the derived pass class. |
1729 | static bool classof(const ::mlir::Pass *pass) { |
1730 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
1731 | } |
1732 | |
1733 | /// A clone method to create a copy of this pass. |
1734 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
1735 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
1736 | } |
1737 | |
1738 | /// Return the dialect that must be loaded in the context before this pass. |
1739 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
1740 | |
1741 | } |
1742 | |
1743 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
1744 | /// instantiation because Pass classes should only be visible by the current |
1745 | /// library. |
1746 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorSetDefaultShardingBase<DerivedT>) |
1747 | |
1748 | protected: |
1749 | private: |
1750 | }; |
1751 | } // namespace impl |
1752 | #undef GEN_PASS_DEF_DTENSORSETDEFAULTSHARDING |
1753 | #endif // GEN_PASS_DEF_DTENSORSETDEFAULTSHARDING |
1754 | |
1755 | //===----------------------------------------------------------------------===// |
1756 | // DTensorSparseExpansion |
1757 | //===----------------------------------------------------------------------===// |
1758 | #ifdef GEN_PASS_DECL_DTENSORSPARSEEXPANSION |
1759 | #undef GEN_PASS_DECL_DTENSORSPARSEEXPANSION |
1760 | #endif // GEN_PASS_DECL_DTENSORSPARSEEXPANSION |
1761 | #ifdef GEN_PASS_DEF_DTENSORSPARSEEXPANSION |
1762 | namespace impl { |
1763 | |
1764 | template <typename DerivedT> |
1765 | class DTensorSparseExpansionBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
1766 | public: |
1767 | using Base = DTensorSparseExpansionBase; |
1768 | |
1769 | DTensorSparseExpansionBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
1770 | DTensorSparseExpansionBase(const DTensorSparseExpansionBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
1771 | |
1772 | /// Returns the command-line argument attached to this pass. |
1773 | static constexpr ::llvm::StringLiteral getArgumentName() { |
1774 | return ::llvm::StringLiteral("dtensor-sparse-expansion" ); |
1775 | } |
1776 | ::llvm::StringRef getArgument() const override { return "dtensor-sparse-expansion" ; } |
1777 | |
1778 | ::llvm::StringRef getDescription() const override { return "Convert ops that take in SparseTensor input to its corresponding Sparse or Dense ops." ; } |
1779 | |
1780 | /// Returns the derived pass name. |
1781 | static constexpr ::llvm::StringLiteral getPassName() { |
1782 | return ::llvm::StringLiteral("DTensorSparseExpansion" ); |
1783 | } |
1784 | ::llvm::StringRef getName() const override { return "DTensorSparseExpansion" ; } |
1785 | |
1786 | /// Support isa/dyn_cast functionality for the derived pass class. |
1787 | static bool classof(const ::mlir::Pass *pass) { |
1788 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
1789 | } |
1790 | |
1791 | /// A clone method to create a copy of this pass. |
1792 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
1793 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
1794 | } |
1795 | |
1796 | /// Return the dialect that must be loaded in the context before this pass. |
1797 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
1798 | |
1799 | } |
1800 | |
1801 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
1802 | /// instantiation because Pass classes should only be visible by the current |
1803 | /// library. |
1804 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorSparseExpansionBase<DerivedT>) |
1805 | |
1806 | protected: |
1807 | private: |
1808 | }; |
1809 | } // namespace impl |
1810 | #undef GEN_PASS_DEF_DTENSORSPARSEEXPANSION |
1811 | #endif // GEN_PASS_DEF_DTENSORSPARSEEXPANSION |
1812 | |
1813 | //===----------------------------------------------------------------------===// |
1814 | // DTensorSparseTensorToDenseTensor |
1815 | //===----------------------------------------------------------------------===// |
1816 | #ifdef GEN_PASS_DECL_DTENSORSPARSETENSORTODENSETENSOR |
1817 | #undef GEN_PASS_DECL_DTENSORSPARSETENSORTODENSETENSOR |
1818 | #endif // GEN_PASS_DECL_DTENSORSPARSETENSORTODENSETENSOR |
1819 | #ifdef GEN_PASS_DEF_DTENSORSPARSETENSORTODENSETENSOR |
1820 | namespace impl { |
1821 | |
1822 | template <typename DerivedT> |
1823 | class DTensorSparseTensorToDenseTensorBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
1824 | public: |
1825 | using Base = DTensorSparseTensorToDenseTensorBase; |
1826 | |
1827 | DTensorSparseTensorToDenseTensorBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
1828 | DTensorSparseTensorToDenseTensorBase(const DTensorSparseTensorToDenseTensorBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
1829 | |
1830 | /// Returns the command-line argument attached to this pass. |
1831 | static constexpr ::llvm::StringLiteral getArgumentName() { |
1832 | return ::llvm::StringLiteral("dtensor-sparse-tensor-to-dense-tensor" ); |
1833 | } |
1834 | ::llvm::StringRef getArgument() const override { return "dtensor-sparse-tensor-to-dense-tensor" ; } |
1835 | |
1836 | ::llvm::StringRef getDescription() const override { return "Converts SparseTensor inputs to its component tensors inputs and emits a SparseToDenseOp for every op that consumes a SparseTensor." ; } |
1837 | |
1838 | /// Returns the derived pass name. |
1839 | static constexpr ::llvm::StringLiteral getPassName() { |
1840 | return ::llvm::StringLiteral("DTensorSparseTensorToDenseTensor" ); |
1841 | } |
1842 | ::llvm::StringRef getName() const override { return "DTensorSparseTensorToDenseTensor" ; } |
1843 | |
1844 | /// Support isa/dyn_cast functionality for the derived pass class. |
1845 | static bool classof(const ::mlir::Pass *pass) { |
1846 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
1847 | } |
1848 | |
1849 | /// A clone method to create a copy of this pass. |
1850 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
1851 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
1852 | } |
1853 | |
1854 | /// Return the dialect that must be loaded in the context before this pass. |
1855 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
1856 | |
1857 | } |
1858 | |
1859 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
1860 | /// instantiation because Pass classes should only be visible by the current |
1861 | /// library. |
1862 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorSparseTensorToDenseTensorBase<DerivedT>) |
1863 | |
1864 | protected: |
1865 | private: |
1866 | }; |
1867 | } // namespace impl |
1868 | #undef GEN_PASS_DEF_DTENSORSPARSETENSORTODENSETENSOR |
1869 | #endif // GEN_PASS_DEF_DTENSORSPARSETENSORTODENSETENSOR |
1870 | |
1871 | //===----------------------------------------------------------------------===// |
1872 | // DTensorTPUIntegration |
1873 | //===----------------------------------------------------------------------===// |
1874 | #ifdef GEN_PASS_DECL_DTENSORTPUINTEGRATION |
1875 | #undef GEN_PASS_DECL_DTENSORTPUINTEGRATION |
1876 | #endif // GEN_PASS_DECL_DTENSORTPUINTEGRATION |
1877 | #ifdef GEN_PASS_DEF_DTENSORTPUINTEGRATION |
1878 | namespace impl { |
1879 | |
1880 | template <typename DerivedT> |
1881 | class DTensorTPUIntegrationBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
1882 | public: |
1883 | using Base = DTensorTPUIntegrationBase; |
1884 | |
1885 | DTensorTPUIntegrationBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
1886 | DTensorTPUIntegrationBase(const DTensorTPUIntegrationBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
1887 | |
1888 | /// Returns the command-line argument attached to this pass. |
1889 | static constexpr ::llvm::StringLiteral getArgumentName() { |
1890 | return ::llvm::StringLiteral("dtensor-tpu-integration" ); |
1891 | } |
1892 | ::llvm::StringRef getArgument() const override { return "dtensor-tpu-integration" ; } |
1893 | |
1894 | ::llvm::StringRef getDescription() const override { return "Adds TPUReplicateMetadata and converts ops that run on TPU's to a single tf_device.cluster to be compatible with following TF2XLA MLIR passes." ; } |
1895 | |
1896 | /// Returns the derived pass name. |
1897 | static constexpr ::llvm::StringLiteral getPassName() { |
1898 | return ::llvm::StringLiteral("DTensorTPUIntegration" ); |
1899 | } |
1900 | ::llvm::StringRef getName() const override { return "DTensorTPUIntegration" ; } |
1901 | |
1902 | /// Support isa/dyn_cast functionality for the derived pass class. |
1903 | static bool classof(const ::mlir::Pass *pass) { |
1904 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
1905 | } |
1906 | |
1907 | /// A clone method to create a copy of this pass. |
1908 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
1909 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
1910 | } |
1911 | |
1912 | /// Return the dialect that must be loaded in the context before this pass. |
1913 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
1914 | |
1915 | } |
1916 | |
1917 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
1918 | /// instantiation because Pass classes should only be visible by the current |
1919 | /// library. |
1920 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorTPUIntegrationBase<DerivedT>) |
1921 | |
1922 | protected: |
1923 | private: |
1924 | }; |
1925 | } // namespace impl |
1926 | #undef GEN_PASS_DEF_DTENSORTPUINTEGRATION |
1927 | #endif // GEN_PASS_DEF_DTENSORTPUINTEGRATION |
1928 | |
1929 | //===----------------------------------------------------------------------===// |
1930 | // DTensorTpuAddResourceDeviceAttribute |
1931 | //===----------------------------------------------------------------------===// |
1932 | #ifdef GEN_PASS_DECL_DTENSORTPUADDRESOURCEDEVICEATTRIBUTE |
1933 | #undef GEN_PASS_DECL_DTENSORTPUADDRESOURCEDEVICEATTRIBUTE |
1934 | #endif // GEN_PASS_DECL_DTENSORTPUADDRESOURCEDEVICEATTRIBUTE |
1935 | #ifdef GEN_PASS_DEF_DTENSORTPUADDRESOURCEDEVICEATTRIBUTE |
1936 | namespace impl { |
1937 | |
1938 | template <typename DerivedT> |
1939 | class DTensorTpuAddResourceDeviceAttributeBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
1940 | public: |
1941 | using Base = DTensorTpuAddResourceDeviceAttributeBase; |
1942 | |
1943 | DTensorTpuAddResourceDeviceAttributeBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
1944 | DTensorTpuAddResourceDeviceAttributeBase(const DTensorTpuAddResourceDeviceAttributeBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
1945 | |
1946 | /// Returns the command-line argument attached to this pass. |
1947 | static constexpr ::llvm::StringLiteral getArgumentName() { |
1948 | return ::llvm::StringLiteral("dtensor-tpu-add-resource-device-attribute" ); |
1949 | } |
1950 | ::llvm::StringRef getArgument() const override { return "dtensor-tpu-add-resource-device-attribute" ; } |
1951 | |
1952 | ::llvm::StringRef getDescription() const override { return "Adds placeholder device attributes to resources accessed by TPU computation to enable buffer aliasing." ; } |
1953 | |
1954 | /// Returns the derived pass name. |
1955 | static constexpr ::llvm::StringLiteral getPassName() { |
1956 | return ::llvm::StringLiteral("DTensorTpuAddResourceDeviceAttribute" ); |
1957 | } |
1958 | ::llvm::StringRef getName() const override { return "DTensorTpuAddResourceDeviceAttribute" ; } |
1959 | |
1960 | /// Support isa/dyn_cast functionality for the derived pass class. |
1961 | static bool classof(const ::mlir::Pass *pass) { |
1962 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
1963 | } |
1964 | |
1965 | /// A clone method to create a copy of this pass. |
1966 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
1967 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
1968 | } |
1969 | |
1970 | /// Return the dialect that must be loaded in the context before this pass. |
1971 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
1972 | |
1973 | } |
1974 | |
1975 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
1976 | /// instantiation because Pass classes should only be visible by the current |
1977 | /// library. |
1978 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorTpuAddResourceDeviceAttributeBase<DerivedT>) |
1979 | |
1980 | protected: |
1981 | private: |
1982 | }; |
1983 | } // namespace impl |
1984 | #undef GEN_PASS_DEF_DTENSORTPUADDRESOURCEDEVICEATTRIBUTE |
1985 | #endif // GEN_PASS_DEF_DTENSORTPUADDRESOURCEDEVICEATTRIBUTE |
1986 | |
1987 | //===----------------------------------------------------------------------===// |
1988 | // DTensorUndoMergeConstAcrossMesh |
1989 | //===----------------------------------------------------------------------===// |
1990 | #ifdef GEN_PASS_DECL_DTENSORUNDOMERGECONSTACROSSMESH |
1991 | #undef GEN_PASS_DECL_DTENSORUNDOMERGECONSTACROSSMESH |
1992 | #endif // GEN_PASS_DECL_DTENSORUNDOMERGECONSTACROSSMESH |
1993 | #ifdef GEN_PASS_DEF_DTENSORUNDOMERGECONSTACROSSMESH |
1994 | namespace impl { |
1995 | |
1996 | template <typename DerivedT> |
1997 | class DTensorUndoMergeConstAcrossMeshBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
1998 | public: |
1999 | using Base = DTensorUndoMergeConstAcrossMeshBase; |
2000 | |
2001 | DTensorUndoMergeConstAcrossMeshBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
2002 | DTensorUndoMergeConstAcrossMeshBase(const DTensorUndoMergeConstAcrossMeshBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
2003 | |
2004 | /// Returns the command-line argument attached to this pass. |
2005 | static constexpr ::llvm::StringLiteral getArgumentName() { |
2006 | return ::llvm::StringLiteral("dtensor-undo-merge-const-across-mesh" ); |
2007 | } |
2008 | ::llvm::StringRef getArgument() const override { return "dtensor-undo-merge-const-across-mesh" ; } |
2009 | |
2010 | ::llvm::StringRef getDescription() const override { return "Undo constant merging across meshes" ; } |
2011 | |
2012 | /// Returns the derived pass name. |
2013 | static constexpr ::llvm::StringLiteral getPassName() { |
2014 | return ::llvm::StringLiteral("DTensorUndoMergeConstAcrossMesh" ); |
2015 | } |
2016 | ::llvm::StringRef getName() const override { return "DTensorUndoMergeConstAcrossMesh" ; } |
2017 | |
2018 | /// Support isa/dyn_cast functionality for the derived pass class. |
2019 | static bool classof(const ::mlir::Pass *pass) { |
2020 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
2021 | } |
2022 | |
2023 | /// A clone method to create a copy of this pass. |
2024 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
2025 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
2026 | } |
2027 | |
2028 | /// Return the dialect that must be loaded in the context before this pass. |
2029 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
2030 | |
2031 | } |
2032 | |
2033 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
2034 | /// instantiation because Pass classes should only be visible by the current |
2035 | /// library. |
2036 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorUndoMergeConstAcrossMeshBase<DerivedT>) |
2037 | |
2038 | protected: |
2039 | private: |
2040 | }; |
2041 | } // namespace impl |
2042 | #undef GEN_PASS_DEF_DTENSORUNDOMERGECONSTACROSSMESH |
2043 | #endif // GEN_PASS_DEF_DTENSORUNDOMERGECONSTACROSSMESH |
2044 | |
2045 | //===----------------------------------------------------------------------===// |
2046 | // DTensorUpdateTPUMetadata |
2047 | //===----------------------------------------------------------------------===// |
2048 | #ifdef GEN_PASS_DECL_DTENSORUPDATETPUMETADATA |
2049 | #undef GEN_PASS_DECL_DTENSORUPDATETPUMETADATA |
2050 | #endif // GEN_PASS_DECL_DTENSORUPDATETPUMETADATA |
2051 | #ifdef GEN_PASS_DEF_DTENSORUPDATETPUMETADATA |
2052 | namespace impl { |
2053 | |
2054 | template <typename DerivedT> |
2055 | class DTensorUpdateTPUMetadataBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
2056 | public: |
2057 | using Base = DTensorUpdateTPUMetadataBase; |
2058 | |
2059 | DTensorUpdateTPUMetadataBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
2060 | DTensorUpdateTPUMetadataBase(const DTensorUpdateTPUMetadataBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
2061 | |
2062 | /// Returns the command-line argument attached to this pass. |
2063 | static constexpr ::llvm::StringLiteral getArgumentName() { |
2064 | return ::llvm::StringLiteral("dtensor-update-tpu-metadata" ); |
2065 | } |
2066 | ::llvm::StringRef getArgument() const override { return "dtensor-update-tpu-metadata" ; } |
2067 | |
2068 | ::llvm::StringRef getDescription() const override { return "Changes metadata on TPU specific ops such as device placement." ; } |
2069 | |
2070 | /// Returns the derived pass name. |
2071 | static constexpr ::llvm::StringLiteral getPassName() { |
2072 | return ::llvm::StringLiteral("DTensorUpdateTPUMetadata" ); |
2073 | } |
2074 | ::llvm::StringRef getName() const override { return "DTensorUpdateTPUMetadata" ; } |
2075 | |
2076 | /// Support isa/dyn_cast functionality for the derived pass class. |
2077 | static bool classof(const ::mlir::Pass *pass) { |
2078 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
2079 | } |
2080 | |
2081 | /// A clone method to create a copy of this pass. |
2082 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
2083 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
2084 | } |
2085 | |
2086 | /// Return the dialect that must be loaded in the context before this pass. |
2087 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
2088 | |
2089 | } |
2090 | |
2091 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
2092 | /// instantiation because Pass classes should only be visible by the current |
2093 | /// library. |
2094 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorUpdateTPUMetadataBase<DerivedT>) |
2095 | |
2096 | protected: |
2097 | private: |
2098 | }; |
2099 | } // namespace impl |
2100 | #undef GEN_PASS_DEF_DTENSORUPDATETPUMETADATA |
2101 | #endif // GEN_PASS_DEF_DTENSORUPDATETPUMETADATA |
2102 | #ifdef GEN_PASS_REGISTRATION |
2103 | |
2104 | //===----------------------------------------------------------------------===// |
2105 | // DTensorAllGatherLowering Registration |
2106 | //===----------------------------------------------------------------------===// |
2107 | |
2108 | inline void registerDTensorAllGatherLowering() { |
2109 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2110 | return CreateDTensorAllGatherLoweringPass(); |
2111 | }); |
2112 | } |
2113 | |
2114 | // Old registration code, kept for temporary backwards compatibility. |
2115 | inline void registerDTensorAllGatherLoweringPass() { |
2116 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2117 | return CreateDTensorAllGatherLoweringPass(); |
2118 | }); |
2119 | } |
2120 | |
2121 | //===----------------------------------------------------------------------===// |
2122 | // DTensorAllReduceCombineOptimization Registration |
2123 | //===----------------------------------------------------------------------===// |
2124 | |
2125 | inline void registerDTensorAllReduceCombineOptimization() { |
2126 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2127 | return CreateDTensorAllReduceCombineOptimization(); |
2128 | }); |
2129 | } |
2130 | |
2131 | // Old registration code, kept for temporary backwards compatibility. |
2132 | inline void registerDTensorAllReduceCombineOptimizationPass() { |
2133 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2134 | return CreateDTensorAllReduceCombineOptimization(); |
2135 | }); |
2136 | } |
2137 | |
2138 | //===----------------------------------------------------------------------===// |
2139 | // DTensorAllReduceLowering Registration |
2140 | //===----------------------------------------------------------------------===// |
2141 | |
2142 | inline void registerDTensorAllReduceLowering() { |
2143 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2144 | return CreateDTensorAllReduceLoweringPass(); |
2145 | }); |
2146 | } |
2147 | |
2148 | // Old registration code, kept for temporary backwards compatibility. |
2149 | inline void registerDTensorAllReduceLoweringPass() { |
2150 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2151 | return CreateDTensorAllReduceLoweringPass(); |
2152 | }); |
2153 | } |
2154 | |
2155 | //===----------------------------------------------------------------------===// |
2156 | // DTensorAllReduceScatterOptimization Registration |
2157 | //===----------------------------------------------------------------------===// |
2158 | |
2159 | inline void registerDTensorAllReduceScatterOptimization() { |
2160 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2161 | return CreateDTensorAllReduceScatterOptimization(); |
2162 | }); |
2163 | } |
2164 | |
2165 | // Old registration code, kept for temporary backwards compatibility. |
2166 | inline void registerDTensorAllReduceScatterOptimizationPass() { |
2167 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2168 | return CreateDTensorAllReduceScatterOptimization(); |
2169 | }); |
2170 | } |
2171 | |
2172 | //===----------------------------------------------------------------------===// |
2173 | // DTensorAllReduceSumOptimization Registration |
2174 | //===----------------------------------------------------------------------===// |
2175 | |
2176 | inline void registerDTensorAllReduceSumOptimization() { |
2177 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2178 | return CreateDTensorAllReduceSumOptimization(); |
2179 | }); |
2180 | } |
2181 | |
2182 | // Old registration code, kept for temporary backwards compatibility. |
2183 | inline void registerDTensorAllReduceSumOptimizationPass() { |
2184 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2185 | return CreateDTensorAllReduceSumOptimization(); |
2186 | }); |
2187 | } |
2188 | |
2189 | //===----------------------------------------------------------------------===// |
2190 | // DTensorAllScatterLowering Registration |
2191 | //===----------------------------------------------------------------------===// |
2192 | |
2193 | inline void registerDTensorAllScatterLowering() { |
2194 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2195 | return CreateDTensorAllScatterLoweringPass(); |
2196 | }); |
2197 | } |
2198 | |
2199 | // Old registration code, kept for temporary backwards compatibility. |
2200 | inline void registerDTensorAllScatterLoweringPass() { |
2201 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2202 | return CreateDTensorAllScatterLoweringPass(); |
2203 | }); |
2204 | } |
2205 | |
2206 | //===----------------------------------------------------------------------===// |
2207 | // DTensorAnnotateGlobalShape Registration |
2208 | //===----------------------------------------------------------------------===// |
2209 | |
2210 | inline void registerDTensorAnnotateGlobalShape() { |
2211 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2212 | return CreateDTensorAnnotateGlobalShape(); |
2213 | }); |
2214 | } |
2215 | |
2216 | // Old registration code, kept for temporary backwards compatibility. |
2217 | inline void registerDTensorAnnotateGlobalShapePass() { |
2218 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2219 | return CreateDTensorAnnotateGlobalShape(); |
2220 | }); |
2221 | } |
2222 | |
2223 | //===----------------------------------------------------------------------===// |
2224 | // DTensorClusterFunctionConversion Registration |
2225 | //===----------------------------------------------------------------------===// |
2226 | |
2227 | inline void registerDTensorClusterFunctionConversion() { |
2228 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2229 | return CreateDTensorClusterFunctionConversion(); |
2230 | }); |
2231 | } |
2232 | |
2233 | // Old registration code, kept for temporary backwards compatibility. |
2234 | inline void registerDTensorClusterFunctionConversionPass() { |
2235 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2236 | return CreateDTensorClusterFunctionConversion(); |
2237 | }); |
2238 | } |
2239 | |
2240 | //===----------------------------------------------------------------------===// |
2241 | // DTensorConstantFolding Registration |
2242 | //===----------------------------------------------------------------------===// |
2243 | |
2244 | inline void registerDTensorConstantFolding() { |
2245 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2246 | return CreateDTensorConstantFolding(); |
2247 | }); |
2248 | } |
2249 | |
2250 | // Old registration code, kept for temporary backwards compatibility. |
2251 | inline void registerDTensorConstantFoldingPass() { |
2252 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2253 | return CreateDTensorConstantFolding(); |
2254 | }); |
2255 | } |
2256 | |
2257 | //===----------------------------------------------------------------------===// |
2258 | // DTensorDCE Registration |
2259 | //===----------------------------------------------------------------------===// |
2260 | |
2261 | inline void registerDTensorDCE() { |
2262 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2263 | return CreateDTensorDCE(); |
2264 | }); |
2265 | } |
2266 | |
2267 | // Old registration code, kept for temporary backwards compatibility. |
2268 | inline void registerDTensorDCEPass() { |
2269 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2270 | return CreateDTensorDCE(); |
2271 | }); |
2272 | } |
2273 | |
2274 | //===----------------------------------------------------------------------===// |
2275 | // DTensorDesignateResourceHandleMesh Registration |
2276 | //===----------------------------------------------------------------------===// |
2277 | |
2278 | inline void registerDTensorDesignateResourceHandleMesh() { |
2279 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2280 | return CreateDTensorDesignateResourceHandleMesh(); |
2281 | }); |
2282 | } |
2283 | |
2284 | // Old registration code, kept for temporary backwards compatibility. |
2285 | inline void registerDTensorDesignateResourceHandleMeshPass() { |
2286 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2287 | return CreateDTensorDesignateResourceHandleMesh(); |
2288 | }); |
2289 | } |
2290 | |
2291 | //===----------------------------------------------------------------------===// |
2292 | // DTensorDeviceMeshClusterCoarsening Registration |
2293 | //===----------------------------------------------------------------------===// |
2294 | |
2295 | inline void registerDTensorDeviceMeshClusterCoarsening() { |
2296 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2297 | return CreateDTensorDeviceMeshClusterCoarsening(); |
2298 | }); |
2299 | } |
2300 | |
2301 | // Old registration code, kept for temporary backwards compatibility. |
2302 | inline void registerDTensorDeviceMeshClusterCoarseningPass() { |
2303 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2304 | return CreateDTensorDeviceMeshClusterCoarsening(); |
2305 | }); |
2306 | } |
2307 | |
2308 | //===----------------------------------------------------------------------===// |
2309 | // DTensorEmbedding Registration |
2310 | //===----------------------------------------------------------------------===// |
2311 | |
2312 | inline void registerDTensorEmbedding() { |
2313 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2314 | return CreateDTensorEmbeddingPass(); |
2315 | }); |
2316 | } |
2317 | |
2318 | // Old registration code, kept for temporary backwards compatibility. |
2319 | inline void registerDTensorEmbeddingPass() { |
2320 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2321 | return CreateDTensorEmbeddingPass(); |
2322 | }); |
2323 | } |
2324 | |
2325 | //===----------------------------------------------------------------------===// |
2326 | // DTensorEmbeddingCheckpoint Registration |
2327 | //===----------------------------------------------------------------------===// |
2328 | |
2329 | inline void registerDTensorEmbeddingCheckpoint() { |
2330 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2331 | return CreateDTensorEmbeddingCheckpointPass(); |
2332 | }); |
2333 | } |
2334 | |
2335 | // Old registration code, kept for temporary backwards compatibility. |
2336 | inline void registerDTensorEmbeddingCheckpointPass() { |
2337 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2338 | return CreateDTensorEmbeddingCheckpointPass(); |
2339 | }); |
2340 | } |
2341 | |
2342 | //===----------------------------------------------------------------------===// |
2343 | // DTensorEmbeddingV2 Registration |
2344 | //===----------------------------------------------------------------------===// |
2345 | |
2346 | inline void registerDTensorEmbeddingV2() { |
2347 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2348 | return CreateDTensorEmbeddingPassV2(); |
2349 | }); |
2350 | } |
2351 | |
2352 | // Old registration code, kept for temporary backwards compatibility. |
2353 | inline void registerDTensorEmbeddingV2Pass() { |
2354 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2355 | return CreateDTensorEmbeddingPassV2(); |
2356 | }); |
2357 | } |
2358 | |
2359 | //===----------------------------------------------------------------------===// |
2360 | // DTensorFunctionRenaming Registration |
2361 | //===----------------------------------------------------------------------===// |
2362 | |
2363 | inline void registerDTensorFunctionRenaming() { |
2364 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2365 | return CreateFunctionRenamingPass(); |
2366 | }); |
2367 | } |
2368 | |
2369 | // Old registration code, kept for temporary backwards compatibility. |
2370 | inline void registerDTensorFunctionRenamingPass() { |
2371 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2372 | return CreateFunctionRenamingPass(); |
2373 | }); |
2374 | } |
2375 | |
2376 | //===----------------------------------------------------------------------===// |
2377 | // DTensorHandleCrossClusterDependencies Registration |
2378 | //===----------------------------------------------------------------------===// |
2379 | |
2380 | inline void registerDTensorHandleCrossClusterDependencies() { |
2381 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2382 | return CreateDTensorHandleCrossClusterDependencies(); |
2383 | }); |
2384 | } |
2385 | |
2386 | // Old registration code, kept for temporary backwards compatibility. |
2387 | inline void registerDTensorHandleCrossClusterDependenciesPass() { |
2388 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2389 | return CreateDTensorHandleCrossClusterDependencies(); |
2390 | }); |
2391 | } |
2392 | |
2393 | //===----------------------------------------------------------------------===// |
2394 | // DTensorInferShapesForRestoreV2Op Registration |
2395 | //===----------------------------------------------------------------------===// |
2396 | |
2397 | inline void registerDTensorInferShapesForRestoreV2Op() { |
2398 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2399 | return CreateDTensorInferShapesForRestoreV2Op(); |
2400 | }); |
2401 | } |
2402 | |
2403 | // Old registration code, kept for temporary backwards compatibility. |
2404 | inline void registerDTensorInferShapesForRestoreV2OpPass() { |
2405 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2406 | return CreateDTensorInferShapesForRestoreV2Op(); |
2407 | }); |
2408 | } |
2409 | |
2410 | //===----------------------------------------------------------------------===// |
2411 | // DTensorLayoutPropagationV2 Registration |
2412 | //===----------------------------------------------------------------------===// |
2413 | |
2414 | inline void registerDTensorLayoutPropagationV2() { |
2415 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2416 | return CreateDTensorLayoutPropagationPassV2(); |
2417 | }); |
2418 | } |
2419 | |
2420 | // Old registration code, kept for temporary backwards compatibility. |
2421 | inline void registerDTensorLayoutPropagationV2Pass() { |
2422 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2423 | return CreateDTensorLayoutPropagationPassV2(); |
2424 | }); |
2425 | } |
2426 | |
2427 | //===----------------------------------------------------------------------===// |
2428 | // DTensorLowerSendRecv Registration |
2429 | //===----------------------------------------------------------------------===// |
2430 | |
2431 | inline void registerDTensorLowerSendRecv() { |
2432 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2433 | return CreateDTensorLowerSendRecv(); |
2434 | }); |
2435 | } |
2436 | |
2437 | // Old registration code, kept for temporary backwards compatibility. |
2438 | inline void registerDTensorLowerSendRecvPass() { |
2439 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2440 | return CreateDTensorLowerSendRecv(); |
2441 | }); |
2442 | } |
2443 | |
2444 | //===----------------------------------------------------------------------===// |
2445 | // DTensorMergeClusters Registration |
2446 | //===----------------------------------------------------------------------===// |
2447 | |
2448 | inline void registerDTensorMergeClusters() { |
2449 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2450 | return CreateDTensorMergeClustersPass(); |
2451 | }); |
2452 | } |
2453 | |
2454 | // Old registration code, kept for temporary backwards compatibility. |
2455 | inline void registerDTensorMergeClustersPass() { |
2456 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2457 | return CreateDTensorMergeClustersPass(); |
2458 | }); |
2459 | } |
2460 | |
2461 | //===----------------------------------------------------------------------===// |
2462 | // DTensorMeshPropagation Registration |
2463 | //===----------------------------------------------------------------------===// |
2464 | |
2465 | inline void registerDTensorMeshPropagation() { |
2466 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2467 | return CreateDTensorMeshPropagationPass(); |
2468 | }); |
2469 | } |
2470 | |
2471 | // Old registration code, kept for temporary backwards compatibility. |
2472 | inline void registerDTensorMeshPropagationPass() { |
2473 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2474 | return CreateDTensorMeshPropagationPass(); |
2475 | }); |
2476 | } |
2477 | |
2478 | //===----------------------------------------------------------------------===// |
2479 | // DTensorMixedPrecisionReduce Registration |
2480 | //===----------------------------------------------------------------------===// |
2481 | |
2482 | inline void registerDTensorMixedPrecisionReduce() { |
2483 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2484 | return CreateDTensorMixedPrecisionReducePass(); |
2485 | }); |
2486 | } |
2487 | |
2488 | // Old registration code, kept for temporary backwards compatibility. |
2489 | inline void registerDTensorMixedPrecisionReducePass() { |
2490 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2491 | return CreateDTensorMixedPrecisionReducePass(); |
2492 | }); |
2493 | } |
2494 | |
2495 | //===----------------------------------------------------------------------===// |
2496 | // DTensorMoveCompilationToHost Registration |
2497 | //===----------------------------------------------------------------------===// |
2498 | |
2499 | inline void registerDTensorMoveCompilationToHost() { |
2500 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2501 | return CreateDTensorMoveCompilationToHost(); |
2502 | }); |
2503 | } |
2504 | |
2505 | // Old registration code, kept for temporary backwards compatibility. |
2506 | inline void registerDTensorMoveCompilationToHostPass() { |
2507 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2508 | return CreateDTensorMoveCompilationToHost(); |
2509 | }); |
2510 | } |
2511 | |
2512 | //===----------------------------------------------------------------------===// |
2513 | // DTensorOpToDeviceCluster Registration |
2514 | //===----------------------------------------------------------------------===// |
2515 | |
2516 | inline void registerDTensorOpToDeviceCluster() { |
2517 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2518 | return CreateDTensorOpToDeviceClusterPass(); |
2519 | }); |
2520 | } |
2521 | |
2522 | // Old registration code, kept for temporary backwards compatibility. |
2523 | inline void registerDTensorOpToDeviceClusterPass() { |
2524 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2525 | return CreateDTensorOpToDeviceClusterPass(); |
2526 | }); |
2527 | } |
2528 | |
2529 | //===----------------------------------------------------------------------===// |
2530 | // DTensorPropagateDefaultLayout Registration |
2531 | //===----------------------------------------------------------------------===// |
2532 | |
2533 | inline void registerDTensorPropagateDefaultLayout() { |
2534 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2535 | return CreateDTensorPropagateDefaultLayout(); |
2536 | }); |
2537 | } |
2538 | |
2539 | // Old registration code, kept for temporary backwards compatibility. |
2540 | inline void registerDTensorPropagateDefaultLayoutPass() { |
2541 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2542 | return CreateDTensorPropagateDefaultLayout(); |
2543 | }); |
2544 | } |
2545 | |
2546 | //===----------------------------------------------------------------------===// |
2547 | // DTensorPropagateDeviceIdToFunctionArgs Registration |
2548 | //===----------------------------------------------------------------------===// |
2549 | |
2550 | inline void registerDTensorPropagateDeviceIdToFunctionArgs() { |
2551 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2552 | return CreateDTensorPropagateDeviceIdToFunctionArgs(); |
2553 | }); |
2554 | } |
2555 | |
2556 | // Old registration code, kept for temporary backwards compatibility. |
2557 | inline void registerDTensorPropagateDeviceIdToFunctionArgsPass() { |
2558 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2559 | return CreateDTensorPropagateDeviceIdToFunctionArgs(); |
2560 | }); |
2561 | } |
2562 | |
2563 | //===----------------------------------------------------------------------===// |
2564 | // DTensorReduceScatterLowering Registration |
2565 | //===----------------------------------------------------------------------===// |
2566 | |
2567 | inline void registerDTensorReduceScatterLowering() { |
2568 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2569 | return CreateDTensorReduceScatterLoweringPass(); |
2570 | }); |
2571 | } |
2572 | |
2573 | // Old registration code, kept for temporary backwards compatibility. |
2574 | inline void registerDTensorReduceScatterLoweringPass() { |
2575 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2576 | return CreateDTensorReduceScatterLoweringPass(); |
2577 | }); |
2578 | } |
2579 | |
2580 | //===----------------------------------------------------------------------===// |
2581 | // DTensorSPMDExpansion Registration |
2582 | //===----------------------------------------------------------------------===// |
2583 | |
2584 | inline void registerDTensorSPMDExpansion() { |
2585 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2586 | return CreateDTensorSPMDExpansion(); |
2587 | }); |
2588 | } |
2589 | |
2590 | // Old registration code, kept for temporary backwards compatibility. |
2591 | inline void registerDTensorSPMDExpansionPass() { |
2592 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2593 | return CreateDTensorSPMDExpansion(); |
2594 | }); |
2595 | } |
2596 | |
2597 | //===----------------------------------------------------------------------===// |
2598 | // DTensorSetDefaultSharding Registration |
2599 | //===----------------------------------------------------------------------===// |
2600 | |
2601 | inline void registerDTensorSetDefaultSharding() { |
2602 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2603 | return CreateDTensorSetDefaultSharding(); |
2604 | }); |
2605 | } |
2606 | |
2607 | // Old registration code, kept for temporary backwards compatibility. |
2608 | inline void registerDTensorSetDefaultShardingPass() { |
2609 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2610 | return CreateDTensorSetDefaultSharding(); |
2611 | }); |
2612 | } |
2613 | |
2614 | //===----------------------------------------------------------------------===// |
2615 | // DTensorSparseExpansion Registration |
2616 | //===----------------------------------------------------------------------===// |
2617 | |
2618 | inline void registerDTensorSparseExpansion() { |
2619 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2620 | return CreateDTensorSparseExpansion(); |
2621 | }); |
2622 | } |
2623 | |
2624 | // Old registration code, kept for temporary backwards compatibility. |
2625 | inline void registerDTensorSparseExpansionPass() { |
2626 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2627 | return CreateDTensorSparseExpansion(); |
2628 | }); |
2629 | } |
2630 | |
2631 | //===----------------------------------------------------------------------===// |
2632 | // DTensorSparseTensorToDenseTensor Registration |
2633 | //===----------------------------------------------------------------------===// |
2634 | |
2635 | inline void registerDTensorSparseTensorToDenseTensor() { |
2636 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2637 | return CreateDTensorSparseTensorToDenseTensor(); |
2638 | }); |
2639 | } |
2640 | |
2641 | // Old registration code, kept for temporary backwards compatibility. |
2642 | inline void registerDTensorSparseTensorToDenseTensorPass() { |
2643 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2644 | return CreateDTensorSparseTensorToDenseTensor(); |
2645 | }); |
2646 | } |
2647 | |
2648 | //===----------------------------------------------------------------------===// |
2649 | // DTensorTPUIntegration Registration |
2650 | //===----------------------------------------------------------------------===// |
2651 | |
2652 | inline void registerDTensorTPUIntegration() { |
2653 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2654 | return CreateDTensorTPUIntegration(); |
2655 | }); |
2656 | } |
2657 | |
2658 | // Old registration code, kept for temporary backwards compatibility. |
2659 | inline void registerDTensorTPUIntegrationPass() { |
2660 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2661 | return CreateDTensorTPUIntegration(); |
2662 | }); |
2663 | } |
2664 | |
2665 | //===----------------------------------------------------------------------===// |
2666 | // DTensorTpuAddResourceDeviceAttribute Registration |
2667 | //===----------------------------------------------------------------------===// |
2668 | |
2669 | inline void registerDTensorTpuAddResourceDeviceAttribute() { |
2670 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2671 | return CreateDTensorTpuAddResourceDeviceAttribute(); |
2672 | }); |
2673 | } |
2674 | |
2675 | // Old registration code, kept for temporary backwards compatibility. |
2676 | inline void registerDTensorTpuAddResourceDeviceAttributePass() { |
2677 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2678 | return CreateDTensorTpuAddResourceDeviceAttribute(); |
2679 | }); |
2680 | } |
2681 | |
2682 | //===----------------------------------------------------------------------===// |
2683 | // DTensorUndoMergeConstAcrossMesh Registration |
2684 | //===----------------------------------------------------------------------===// |
2685 | |
2686 | inline void registerDTensorUndoMergeConstAcrossMesh() { |
2687 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2688 | return CreateDTensorUndoMergeConstAcrossMesh(); |
2689 | }); |
2690 | } |
2691 | |
2692 | // Old registration code, kept for temporary backwards compatibility. |
2693 | inline void registerDTensorUndoMergeConstAcrossMeshPass() { |
2694 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2695 | return CreateDTensorUndoMergeConstAcrossMesh(); |
2696 | }); |
2697 | } |
2698 | |
2699 | //===----------------------------------------------------------------------===// |
2700 | // DTensorUpdateTPUMetadata Registration |
2701 | //===----------------------------------------------------------------------===// |
2702 | |
2703 | inline void registerDTensorUpdateTPUMetadata() { |
2704 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2705 | return CreateDTensorUpdateTPUMetadata(); |
2706 | }); |
2707 | } |
2708 | |
2709 | // Old registration code, kept for temporary backwards compatibility. |
2710 | inline void registerDTensorUpdateTPUMetadataPass() { |
2711 | ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { |
2712 | return CreateDTensorUpdateTPUMetadata(); |
2713 | }); |
2714 | } |
2715 | |
2716 | //===----------------------------------------------------------------------===// |
2717 | // DTensor Registration |
2718 | //===----------------------------------------------------------------------===// |
2719 | |
2720 | inline void registerDTensorPasses() { |
2721 | registerDTensorAllGatherLowering(); |
2722 | registerDTensorAllReduceCombineOptimization(); |
2723 | registerDTensorAllReduceLowering(); |
2724 | registerDTensorAllReduceScatterOptimization(); |
2725 | registerDTensorAllReduceSumOptimization(); |
2726 | registerDTensorAllScatterLowering(); |
2727 | registerDTensorAnnotateGlobalShape(); |
2728 | registerDTensorClusterFunctionConversion(); |
2729 | registerDTensorConstantFolding(); |
2730 | registerDTensorDCE(); |
2731 | registerDTensorDesignateResourceHandleMesh(); |
2732 | registerDTensorDeviceMeshClusterCoarsening(); |
2733 | registerDTensorEmbedding(); |
2734 | registerDTensorEmbeddingCheckpoint(); |
2735 | registerDTensorEmbeddingV2(); |
2736 | registerDTensorFunctionRenaming(); |
2737 | registerDTensorHandleCrossClusterDependencies(); |
2738 | registerDTensorInferShapesForRestoreV2Op(); |
2739 | registerDTensorLayoutPropagationV2(); |
2740 | registerDTensorLowerSendRecv(); |
2741 | registerDTensorMergeClusters(); |
2742 | registerDTensorMeshPropagation(); |
2743 | registerDTensorMixedPrecisionReduce(); |
2744 | registerDTensorMoveCompilationToHost(); |
2745 | registerDTensorOpToDeviceCluster(); |
2746 | registerDTensorPropagateDefaultLayout(); |
2747 | registerDTensorPropagateDeviceIdToFunctionArgs(); |
2748 | registerDTensorReduceScatterLowering(); |
2749 | registerDTensorSPMDExpansion(); |
2750 | registerDTensorSetDefaultSharding(); |
2751 | registerDTensorSparseExpansion(); |
2752 | registerDTensorSparseTensorToDenseTensor(); |
2753 | registerDTensorTPUIntegration(); |
2754 | registerDTensorTpuAddResourceDeviceAttribute(); |
2755 | registerDTensorUndoMergeConstAcrossMesh(); |
2756 | registerDTensorUpdateTPUMetadata(); |
2757 | } |
2758 | #undef GEN_PASS_REGISTRATION |
2759 | #endif // GEN_PASS_REGISTRATION |
2760 | // Deprecated. Please use the new per-pass macros. |
2761 | #ifdef GEN_PASS_CLASSES |
2762 | |
2763 | template <typename DerivedT> |
2764 | class DTensorAllGatherLoweringBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
2765 | public: |
2766 | using Base = DTensorAllGatherLoweringBase; |
2767 | |
2768 | DTensorAllGatherLoweringBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
2769 | DTensorAllGatherLoweringBase(const DTensorAllGatherLoweringBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
2770 | |
2771 | /// Returns the command-line argument attached to this pass. |
2772 | static constexpr ::llvm::StringLiteral getArgumentName() { |
2773 | return ::llvm::StringLiteral("dtensor-all-gather-lowering" ); |
2774 | } |
2775 | ::llvm::StringRef getArgument() const override { return "dtensor-all-gather-lowering" ; } |
2776 | |
2777 | ::llvm::StringRef getDescription() const override { return "Converts logical AllGather ops into physical AllGather ops." ; } |
2778 | |
2779 | /// Returns the derived pass name. |
2780 | static constexpr ::llvm::StringLiteral getPassName() { |
2781 | return ::llvm::StringLiteral("DTensorAllGatherLowering" ); |
2782 | } |
2783 | ::llvm::StringRef getName() const override { return "DTensorAllGatherLowering" ; } |
2784 | |
2785 | /// Support isa/dyn_cast functionality for the derived pass class. |
2786 | static bool classof(const ::mlir::Pass *pass) { |
2787 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
2788 | } |
2789 | |
2790 | /// A clone method to create a copy of this pass. |
2791 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
2792 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
2793 | } |
2794 | |
2795 | /// Return the dialect that must be loaded in the context before this pass. |
2796 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
2797 | |
2798 | } |
2799 | |
2800 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
2801 | /// instantiation because Pass classes should only be visible by the current |
2802 | /// library. |
2803 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllGatherLoweringBase<DerivedT>) |
2804 | |
2805 | protected: |
2806 | }; |
2807 | |
2808 | template <typename DerivedT> |
2809 | class DTensorAllReduceCombineOptimizationBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
2810 | public: |
2811 | using Base = DTensorAllReduceCombineOptimizationBase; |
2812 | |
2813 | DTensorAllReduceCombineOptimizationBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
2814 | DTensorAllReduceCombineOptimizationBase(const DTensorAllReduceCombineOptimizationBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
2815 | |
2816 | /// Returns the command-line argument attached to this pass. |
2817 | static constexpr ::llvm::StringLiteral getArgumentName() { |
2818 | return ::llvm::StringLiteral("dtensor-allreduce-combine-optimization" ); |
2819 | } |
2820 | ::llvm::StringRef getArgument() const override { return "dtensor-allreduce-combine-optimization" ; } |
2821 | |
2822 | ::llvm::StringRef getDescription() const override { return "Combine independent all reduce operations." ; } |
2823 | |
2824 | /// Returns the derived pass name. |
2825 | static constexpr ::llvm::StringLiteral getPassName() { |
2826 | return ::llvm::StringLiteral("DTensorAllReduceCombineOptimization" ); |
2827 | } |
2828 | ::llvm::StringRef getName() const override { return "DTensorAllReduceCombineOptimization" ; } |
2829 | |
2830 | /// Support isa/dyn_cast functionality for the derived pass class. |
2831 | static bool classof(const ::mlir::Pass *pass) { |
2832 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
2833 | } |
2834 | |
2835 | /// A clone method to create a copy of this pass. |
2836 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
2837 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
2838 | } |
2839 | |
2840 | /// Return the dialect that must be loaded in the context before this pass. |
2841 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
2842 | |
2843 | } |
2844 | |
2845 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
2846 | /// instantiation because Pass classes should only be visible by the current |
2847 | /// library. |
2848 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllReduceCombineOptimizationBase<DerivedT>) |
2849 | |
2850 | protected: |
2851 | }; |
2852 | |
2853 | template <typename DerivedT> |
2854 | class DTensorAllReduceLoweringBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
2855 | public: |
2856 | using Base = DTensorAllReduceLoweringBase; |
2857 | |
2858 | DTensorAllReduceLoweringBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
2859 | DTensorAllReduceLoweringBase(const DTensorAllReduceLoweringBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
2860 | |
2861 | /// Returns the command-line argument attached to this pass. |
2862 | static constexpr ::llvm::StringLiteral getArgumentName() { |
2863 | return ::llvm::StringLiteral("dtensor-all-reduce-lowering" ); |
2864 | } |
2865 | ::llvm::StringRef getArgument() const override { return "dtensor-all-reduce-lowering" ; } |
2866 | |
2867 | ::llvm::StringRef getDescription() const override { return "Converts logical AllReduce ops into physical AllReduce ops." ; } |
2868 | |
2869 | /// Returns the derived pass name. |
2870 | static constexpr ::llvm::StringLiteral getPassName() { |
2871 | return ::llvm::StringLiteral("DTensorAllReduceLowering" ); |
2872 | } |
2873 | ::llvm::StringRef getName() const override { return "DTensorAllReduceLowering" ; } |
2874 | |
2875 | /// Support isa/dyn_cast functionality for the derived pass class. |
2876 | static bool classof(const ::mlir::Pass *pass) { |
2877 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
2878 | } |
2879 | |
2880 | /// A clone method to create a copy of this pass. |
2881 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
2882 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
2883 | } |
2884 | |
2885 | /// Return the dialect that must be loaded in the context before this pass. |
2886 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
2887 | |
2888 | } |
2889 | |
2890 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
2891 | /// instantiation because Pass classes should only be visible by the current |
2892 | /// library. |
2893 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllReduceLoweringBase<DerivedT>) |
2894 | |
2895 | protected: |
2896 | }; |
2897 | |
2898 | template <typename DerivedT> |
2899 | class DTensorAllReduceScatterOptimizationBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
2900 | public: |
2901 | using Base = DTensorAllReduceScatterOptimizationBase; |
2902 | |
2903 | DTensorAllReduceScatterOptimizationBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
2904 | DTensorAllReduceScatterOptimizationBase(const DTensorAllReduceScatterOptimizationBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
2905 | |
2906 | /// Returns the command-line argument attached to this pass. |
2907 | static constexpr ::llvm::StringLiteral getArgumentName() { |
2908 | return ::llvm::StringLiteral("dtensor-allreduce-scatter-optimization" ); |
2909 | } |
2910 | ::llvm::StringRef getArgument() const override { return "dtensor-allreduce-scatter-optimization" ; } |
2911 | |
2912 | ::llvm::StringRef getDescription() const override { return "Combines allreduce and scatter to reducescatter." ; } |
2913 | |
2914 | /// Returns the derived pass name. |
2915 | static constexpr ::llvm::StringLiteral getPassName() { |
2916 | return ::llvm::StringLiteral("DTensorAllReduceScatterOptimization" ); |
2917 | } |
2918 | ::llvm::StringRef getName() const override { return "DTensorAllReduceScatterOptimization" ; } |
2919 | |
2920 | /// Support isa/dyn_cast functionality for the derived pass class. |
2921 | static bool classof(const ::mlir::Pass *pass) { |
2922 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
2923 | } |
2924 | |
2925 | /// A clone method to create a copy of this pass. |
2926 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
2927 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
2928 | } |
2929 | |
2930 | /// Return the dialect that must be loaded in the context before this pass. |
2931 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
2932 | |
2933 | } |
2934 | |
2935 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
2936 | /// instantiation because Pass classes should only be visible by the current |
2937 | /// library. |
2938 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllReduceScatterOptimizationBase<DerivedT>) |
2939 | |
2940 | protected: |
2941 | }; |
2942 | |
2943 | template <typename DerivedT> |
2944 | class DTensorAllReduceSumOptimizationBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
2945 | public: |
2946 | using Base = DTensorAllReduceSumOptimizationBase; |
2947 | |
2948 | DTensorAllReduceSumOptimizationBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
2949 | DTensorAllReduceSumOptimizationBase(const DTensorAllReduceSumOptimizationBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
2950 | |
2951 | /// Returns the command-line argument attached to this pass. |
2952 | static constexpr ::llvm::StringLiteral getArgumentName() { |
2953 | return ::llvm::StringLiteral("dtensor-allreduce-sum-optimization" ); |
2954 | } |
2955 | ::llvm::StringRef getArgument() const override { return "dtensor-allreduce-sum-optimization" ; } |
2956 | |
2957 | ::llvm::StringRef getDescription() const override { return "Changes order of add/allreduce to minimize all reduce operations." ; } |
2958 | |
2959 | /// Returns the derived pass name. |
2960 | static constexpr ::llvm::StringLiteral getPassName() { |
2961 | return ::llvm::StringLiteral("DTensorAllReduceSumOptimization" ); |
2962 | } |
2963 | ::llvm::StringRef getName() const override { return "DTensorAllReduceSumOptimization" ; } |
2964 | |
2965 | /// Support isa/dyn_cast functionality for the derived pass class. |
2966 | static bool classof(const ::mlir::Pass *pass) { |
2967 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
2968 | } |
2969 | |
2970 | /// A clone method to create a copy of this pass. |
2971 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
2972 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
2973 | } |
2974 | |
2975 | /// Return the dialect that must be loaded in the context before this pass. |
2976 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
2977 | |
2978 | } |
2979 | |
2980 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
2981 | /// instantiation because Pass classes should only be visible by the current |
2982 | /// library. |
2983 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllReduceSumOptimizationBase<DerivedT>) |
2984 | |
2985 | protected: |
2986 | }; |
2987 | |
2988 | template <typename DerivedT> |
2989 | class DTensorAllScatterLoweringBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
2990 | public: |
2991 | using Base = DTensorAllScatterLoweringBase; |
2992 | |
2993 | DTensorAllScatterLoweringBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
2994 | DTensorAllScatterLoweringBase(const DTensorAllScatterLoweringBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
2995 | |
2996 | /// Returns the command-line argument attached to this pass. |
2997 | static constexpr ::llvm::StringLiteral getArgumentName() { |
2998 | return ::llvm::StringLiteral("dtensor-all-scatter-lowering" ); |
2999 | } |
3000 | ::llvm::StringRef getArgument() const override { return "dtensor-all-scatter-lowering" ; } |
3001 | |
3002 | ::llvm::StringRef getDescription() const override { return "Converts logical AllScatter ops into physical Split ops." ; } |
3003 | |
3004 | /// Returns the derived pass name. |
3005 | static constexpr ::llvm::StringLiteral getPassName() { |
3006 | return ::llvm::StringLiteral("DTensorAllScatterLowering" ); |
3007 | } |
3008 | ::llvm::StringRef getName() const override { return "DTensorAllScatterLowering" ; } |
3009 | |
3010 | /// Support isa/dyn_cast functionality for the derived pass class. |
3011 | static bool classof(const ::mlir::Pass *pass) { |
3012 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3013 | } |
3014 | |
3015 | /// A clone method to create a copy of this pass. |
3016 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3017 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3018 | } |
3019 | |
3020 | /// Return the dialect that must be loaded in the context before this pass. |
3021 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3022 | |
3023 | } |
3024 | |
3025 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3026 | /// instantiation because Pass classes should only be visible by the current |
3027 | /// library. |
3028 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllScatterLoweringBase<DerivedT>) |
3029 | |
3030 | protected: |
3031 | }; |
3032 | |
3033 | template <typename DerivedT> |
3034 | class DTensorAnnotateGlobalShapeBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
3035 | public: |
3036 | using Base = DTensorAnnotateGlobalShapeBase; |
3037 | |
3038 | DTensorAnnotateGlobalShapeBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
3039 | DTensorAnnotateGlobalShapeBase(const DTensorAnnotateGlobalShapeBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
3040 | |
3041 | /// Returns the command-line argument attached to this pass. |
3042 | static constexpr ::llvm::StringLiteral getArgumentName() { |
3043 | return ::llvm::StringLiteral("dtensor-annotate-global-shape" ); |
3044 | } |
3045 | ::llvm::StringRef getArgument() const override { return "dtensor-annotate-global-shape" ; } |
3046 | |
3047 | ::llvm::StringRef getDescription() const override { return "Mark all operations and function arguments with `_global_shape` attribute to be used during SPMD expansion." ; } |
3048 | |
3049 | /// Returns the derived pass name. |
3050 | static constexpr ::llvm::StringLiteral getPassName() { |
3051 | return ::llvm::StringLiteral("DTensorAnnotateGlobalShape" ); |
3052 | } |
3053 | ::llvm::StringRef getName() const override { return "DTensorAnnotateGlobalShape" ; } |
3054 | |
3055 | /// Support isa/dyn_cast functionality for the derived pass class. |
3056 | static bool classof(const ::mlir::Pass *pass) { |
3057 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3058 | } |
3059 | |
3060 | /// A clone method to create a copy of this pass. |
3061 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3062 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3063 | } |
3064 | |
3065 | /// Return the dialect that must be loaded in the context before this pass. |
3066 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3067 | |
3068 | } |
3069 | |
3070 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3071 | /// instantiation because Pass classes should only be visible by the current |
3072 | /// library. |
3073 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAnnotateGlobalShapeBase<DerivedT>) |
3074 | |
3075 | protected: |
3076 | }; |
3077 | |
3078 | template <typename DerivedT> |
3079 | class DTensorClusterFunctionConversionBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
3080 | public: |
3081 | using Base = DTensorClusterFunctionConversionBase; |
3082 | |
3083 | DTensorClusterFunctionConversionBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
3084 | DTensorClusterFunctionConversionBase(const DTensorClusterFunctionConversionBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
3085 | |
3086 | /// Returns the command-line argument attached to this pass. |
3087 | static constexpr ::llvm::StringLiteral getArgumentName() { |
3088 | return ::llvm::StringLiteral("dtensor-cluster-function-conversion" ); |
3089 | } |
3090 | ::llvm::StringRef getArgument() const override { return "dtensor-cluster-function-conversion" ; } |
3091 | |
3092 | ::llvm::StringRef getDescription() const override { return "Converts tf_device.cluster_func ops into TF StatefulPartitioned call op with mesh attribute." ; } |
3093 | |
3094 | /// Returns the derived pass name. |
3095 | static constexpr ::llvm::StringLiteral getPassName() { |
3096 | return ::llvm::StringLiteral("DTensorClusterFunctionConversion" ); |
3097 | } |
3098 | ::llvm::StringRef getName() const override { return "DTensorClusterFunctionConversion" ; } |
3099 | |
3100 | /// Support isa/dyn_cast functionality for the derived pass class. |
3101 | static bool classof(const ::mlir::Pass *pass) { |
3102 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3103 | } |
3104 | |
3105 | /// A clone method to create a copy of this pass. |
3106 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3107 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3108 | } |
3109 | |
3110 | /// Return the dialect that must be loaded in the context before this pass. |
3111 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3112 | |
3113 | } |
3114 | |
3115 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3116 | /// instantiation because Pass classes should only be visible by the current |
3117 | /// library. |
3118 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorClusterFunctionConversionBase<DerivedT>) |
3119 | |
3120 | protected: |
3121 | }; |
3122 | |
3123 | template <typename DerivedT> |
3124 | class DTensorConstantFoldingBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
3125 | public: |
3126 | using Base = DTensorConstantFoldingBase; |
3127 | |
3128 | DTensorConstantFoldingBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
3129 | DTensorConstantFoldingBase(const DTensorConstantFoldingBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
3130 | |
3131 | /// Returns the command-line argument attached to this pass. |
3132 | static constexpr ::llvm::StringLiteral getArgumentName() { |
3133 | return ::llvm::StringLiteral("dtensor-constant-folding" ); |
3134 | } |
3135 | ::llvm::StringRef getArgument() const override { return "dtensor-constant-folding" ; } |
3136 | |
3137 | ::llvm::StringRef getDescription() const override { return "Folds constants operations." ; } |
3138 | |
3139 | /// Returns the derived pass name. |
3140 | static constexpr ::llvm::StringLiteral getPassName() { |
3141 | return ::llvm::StringLiteral("DTensorConstantFolding" ); |
3142 | } |
3143 | ::llvm::StringRef getName() const override { return "DTensorConstantFolding" ; } |
3144 | |
3145 | /// Support isa/dyn_cast functionality for the derived pass class. |
3146 | static bool classof(const ::mlir::Pass *pass) { |
3147 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3148 | } |
3149 | |
3150 | /// A clone method to create a copy of this pass. |
3151 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3152 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3153 | } |
3154 | |
3155 | /// Return the dialect that must be loaded in the context before this pass. |
3156 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3157 | |
3158 | } |
3159 | |
3160 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3161 | /// instantiation because Pass classes should only be visible by the current |
3162 | /// library. |
3163 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorConstantFoldingBase<DerivedT>) |
3164 | |
3165 | protected: |
3166 | }; |
3167 | |
3168 | template <typename DerivedT> |
3169 | class DTensorDCEBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
3170 | public: |
3171 | using Base = DTensorDCEBase; |
3172 | |
3173 | DTensorDCEBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
3174 | DTensorDCEBase(const DTensorDCEBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
3175 | |
3176 | /// Returns the command-line argument attached to this pass. |
3177 | static constexpr ::llvm::StringLiteral getArgumentName() { |
3178 | return ::llvm::StringLiteral("dtensor-dce" ); |
3179 | } |
3180 | ::llvm::StringRef getArgument() const override { return "dtensor-dce" ; } |
3181 | |
3182 | ::llvm::StringRef getDescription() const override { return "Removes unused ops from graph." ; } |
3183 | |
3184 | /// Returns the derived pass name. |
3185 | static constexpr ::llvm::StringLiteral getPassName() { |
3186 | return ::llvm::StringLiteral("DTensorDCE" ); |
3187 | } |
3188 | ::llvm::StringRef getName() const override { return "DTensorDCE" ; } |
3189 | |
3190 | /// Support isa/dyn_cast functionality for the derived pass class. |
3191 | static bool classof(const ::mlir::Pass *pass) { |
3192 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3193 | } |
3194 | |
3195 | /// A clone method to create a copy of this pass. |
3196 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3197 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3198 | } |
3199 | |
3200 | /// Return the dialect that must be loaded in the context before this pass. |
3201 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3202 | |
3203 | } |
3204 | |
3205 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3206 | /// instantiation because Pass classes should only be visible by the current |
3207 | /// library. |
3208 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorDCEBase<DerivedT>) |
3209 | |
3210 | protected: |
3211 | }; |
3212 | |
3213 | template <typename DerivedT> |
3214 | class DTensorDesignateResourceHandleMeshBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
3215 | public: |
3216 | using Base = DTensorDesignateResourceHandleMeshBase; |
3217 | |
3218 | DTensorDesignateResourceHandleMeshBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
3219 | DTensorDesignateResourceHandleMeshBase(const DTensorDesignateResourceHandleMeshBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
3220 | |
3221 | /// Returns the command-line argument attached to this pass. |
3222 | static constexpr ::llvm::StringLiteral getArgumentName() { |
3223 | return ::llvm::StringLiteral("dtensor-designate-resource-handle-mesh" ); |
3224 | } |
3225 | ::llvm::StringRef getArgument() const override { return "dtensor-designate-resource-handle-mesh" ; } |
3226 | |
3227 | ::llvm::StringRef getDescription() const override { return "Sets empty mesh attributes for device cluster that creates or destroys resource handles." ; } |
3228 | |
3229 | /// Returns the derived pass name. |
3230 | static constexpr ::llvm::StringLiteral getPassName() { |
3231 | return ::llvm::StringLiteral("DTensorDesignateResourceHandleMesh" ); |
3232 | } |
3233 | ::llvm::StringRef getName() const override { return "DTensorDesignateResourceHandleMesh" ; } |
3234 | |
3235 | /// Support isa/dyn_cast functionality for the derived pass class. |
3236 | static bool classof(const ::mlir::Pass *pass) { |
3237 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3238 | } |
3239 | |
3240 | /// A clone method to create a copy of this pass. |
3241 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3242 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3243 | } |
3244 | |
3245 | /// Return the dialect that must be loaded in the context before this pass. |
3246 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3247 | |
3248 | } |
3249 | |
3250 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3251 | /// instantiation because Pass classes should only be visible by the current |
3252 | /// library. |
3253 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorDesignateResourceHandleMeshBase<DerivedT>) |
3254 | |
3255 | protected: |
3256 | }; |
3257 | |
3258 | template <typename DerivedT> |
3259 | class DTensorDeviceMeshClusterCoarseningBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
3260 | public: |
3261 | using Base = DTensorDeviceMeshClusterCoarseningBase; |
3262 | |
3263 | DTensorDeviceMeshClusterCoarseningBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
3264 | DTensorDeviceMeshClusterCoarseningBase(const DTensorDeviceMeshClusterCoarseningBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
3265 | |
3266 | /// Returns the command-line argument attached to this pass. |
3267 | static constexpr ::llvm::StringLiteral getArgumentName() { |
3268 | return ::llvm::StringLiteral("dtensor-device-mesh-cluster-coarsening" ); |
3269 | } |
3270 | ::llvm::StringRef getArgument() const override { return "dtensor-device-mesh-cluster-coarsening" ; } |
3271 | |
3272 | ::llvm::StringRef getDescription() const override { return "Merges tf_device.cluster op with same mesh attribute." ; } |
3273 | |
3274 | /// Returns the derived pass name. |
3275 | static constexpr ::llvm::StringLiteral getPassName() { |
3276 | return ::llvm::StringLiteral("DTensorDeviceMeshClusterCoarsening" ); |
3277 | } |
3278 | ::llvm::StringRef getName() const override { return "DTensorDeviceMeshClusterCoarsening" ; } |
3279 | |
3280 | /// Support isa/dyn_cast functionality for the derived pass class. |
3281 | static bool classof(const ::mlir::Pass *pass) { |
3282 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3283 | } |
3284 | |
3285 | /// A clone method to create a copy of this pass. |
3286 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3287 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3288 | } |
3289 | |
3290 | /// Return the dialect that must be loaded in the context before this pass. |
3291 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3292 | |
3293 | } |
3294 | |
3295 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3296 | /// instantiation because Pass classes should only be visible by the current |
3297 | /// library. |
3298 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorDeviceMeshClusterCoarseningBase<DerivedT>) |
3299 | |
3300 | protected: |
3301 | }; |
3302 | |
3303 | template <typename DerivedT> |
3304 | class DTensorEmbeddingBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
3305 | public: |
3306 | using Base = DTensorEmbeddingBase; |
3307 | |
3308 | DTensorEmbeddingBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
3309 | DTensorEmbeddingBase(const DTensorEmbeddingBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
3310 | |
3311 | /// Returns the command-line argument attached to this pass. |
3312 | static constexpr ::llvm::StringLiteral getArgumentName() { |
3313 | return ::llvm::StringLiteral("dtensor-embedding" ); |
3314 | } |
3315 | ::llvm::StringRef getArgument() const override { return "dtensor-embedding" ; } |
3316 | |
3317 | ::llvm::StringRef getDescription() const override { return "Extracts the ApplyEmbeddingOptimizer and checks if the attached optimizer is usable for TPU embeddings." ; } |
3318 | |
3319 | /// Returns the derived pass name. |
3320 | static constexpr ::llvm::StringLiteral getPassName() { |
3321 | return ::llvm::StringLiteral("DTensorEmbedding" ); |
3322 | } |
3323 | ::llvm::StringRef getName() const override { return "DTensorEmbedding" ; } |
3324 | |
3325 | /// Support isa/dyn_cast functionality for the derived pass class. |
3326 | static bool classof(const ::mlir::Pass *pass) { |
3327 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3328 | } |
3329 | |
3330 | /// A clone method to create a copy of this pass. |
3331 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3332 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3333 | } |
3334 | |
3335 | /// Return the dialect that must be loaded in the context before this pass. |
3336 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3337 | |
3338 | registry.insert<::mlir::chlo::ChloDialect>(); |
3339 | |
3340 | registry.insert<::mlir::memref::MemRefDialect>(); |
3341 | |
3342 | registry.insert<::mlir::mhlo::MhloDialect>(); |
3343 | |
3344 | registry.insert<::mlir::shape::ShapeDialect>(); |
3345 | |
3346 | registry.insert<::mlir::scf::SCFDialect>(); |
3347 | |
3348 | registry.insert<::mlir::tensor::TensorDialect>(); |
3349 | |
3350 | } |
3351 | |
3352 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3353 | /// instantiation because Pass classes should only be visible by the current |
3354 | /// library. |
3355 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorEmbeddingBase<DerivedT>) |
3356 | |
3357 | protected: |
3358 | }; |
3359 | |
3360 | template <typename DerivedT> |
3361 | class DTensorEmbeddingCheckpointBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
3362 | public: |
3363 | using Base = DTensorEmbeddingCheckpointBase; |
3364 | |
3365 | DTensorEmbeddingCheckpointBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
3366 | DTensorEmbeddingCheckpointBase(const DTensorEmbeddingCheckpointBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
3367 | |
3368 | /// Returns the command-line argument attached to this pass. |
3369 | static constexpr ::llvm::StringLiteral getArgumentName() { |
3370 | return ::llvm::StringLiteral("dtensor-embedding-checkpoint" ); |
3371 | } |
3372 | ::llvm::StringRef getArgument() const override { return "dtensor-embedding-checkpoint" ; } |
3373 | |
3374 | ::llvm::StringRef getDescription() const override { return "Checkpoint pass to save or load TPU embeddings variables." ; } |
3375 | |
3376 | /// Returns the derived pass name. |
3377 | static constexpr ::llvm::StringLiteral getPassName() { |
3378 | return ::llvm::StringLiteral("DTensorEmbeddingCheckpoint" ); |
3379 | } |
3380 | ::llvm::StringRef getName() const override { return "DTensorEmbeddingCheckpoint" ; } |
3381 | |
3382 | /// Support isa/dyn_cast functionality for the derived pass class. |
3383 | static bool classof(const ::mlir::Pass *pass) { |
3384 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3385 | } |
3386 | |
3387 | /// A clone method to create a copy of this pass. |
3388 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3389 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3390 | } |
3391 | |
3392 | /// Return the dialect that must be loaded in the context before this pass. |
3393 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3394 | |
3395 | } |
3396 | |
3397 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3398 | /// instantiation because Pass classes should only be visible by the current |
3399 | /// library. |
3400 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorEmbeddingCheckpointBase<DerivedT>) |
3401 | |
3402 | protected: |
3403 | }; |
3404 | |
3405 | template <typename DerivedT> |
3406 | class DTensorEmbeddingV2Base : public ::mlir::OperationPass<mlir::ModuleOp> { |
3407 | public: |
3408 | using Base = DTensorEmbeddingV2Base; |
3409 | |
3410 | DTensorEmbeddingV2Base() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
3411 | DTensorEmbeddingV2Base(const DTensorEmbeddingV2Base &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
3412 | |
3413 | /// Returns the command-line argument attached to this pass. |
3414 | static constexpr ::llvm::StringLiteral getArgumentName() { |
3415 | return ::llvm::StringLiteral("dtensor-embedding-v2" ); |
3416 | } |
3417 | ::llvm::StringRef getArgument() const override { return "dtensor-embedding-v2" ; } |
3418 | |
3419 | ::llvm::StringRef getDescription() const override { return "Embedding Pass" ; } |
3420 | |
3421 | /// Returns the derived pass name. |
3422 | static constexpr ::llvm::StringLiteral getPassName() { |
3423 | return ::llvm::StringLiteral("DTensorEmbeddingV2" ); |
3424 | } |
3425 | ::llvm::StringRef getName() const override { return "DTensorEmbeddingV2" ; } |
3426 | |
3427 | /// Support isa/dyn_cast functionality for the derived pass class. |
3428 | static bool classof(const ::mlir::Pass *pass) { |
3429 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3430 | } |
3431 | |
3432 | /// A clone method to create a copy of this pass. |
3433 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3434 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3435 | } |
3436 | |
3437 | /// Return the dialect that must be loaded in the context before this pass. |
3438 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3439 | |
3440 | } |
3441 | |
3442 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3443 | /// instantiation because Pass classes should only be visible by the current |
3444 | /// library. |
3445 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorEmbeddingV2Base<DerivedT>) |
3446 | |
3447 | protected: |
3448 | }; |
3449 | |
3450 | template <typename DerivedT> |
3451 | class DTensorFunctionRenamingBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
3452 | public: |
3453 | using Base = DTensorFunctionRenamingBase; |
3454 | |
3455 | DTensorFunctionRenamingBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
3456 | DTensorFunctionRenamingBase(const DTensorFunctionRenamingBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
3457 | |
3458 | /// Returns the command-line argument attached to this pass. |
3459 | static constexpr ::llvm::StringLiteral getArgumentName() { |
3460 | return ::llvm::StringLiteral("dtensor-function-renaming" ); |
3461 | } |
3462 | ::llvm::StringRef getArgument() const override { return "dtensor-function-renaming" ; } |
3463 | |
3464 | ::llvm::StringRef getDescription() const override { return "Renames private functions by appending an id to each name. This is used to make private function names unique across modules." ; } |
3465 | |
3466 | /// Returns the derived pass name. |
3467 | static constexpr ::llvm::StringLiteral getPassName() { |
3468 | return ::llvm::StringLiteral("DTensorFunctionRenaming" ); |
3469 | } |
3470 | ::llvm::StringRef getName() const override { return "DTensorFunctionRenaming" ; } |
3471 | |
3472 | /// Support isa/dyn_cast functionality for the derived pass class. |
3473 | static bool classof(const ::mlir::Pass *pass) { |
3474 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3475 | } |
3476 | |
3477 | /// A clone method to create a copy of this pass. |
3478 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3479 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3480 | } |
3481 | |
3482 | /// Return the dialect that must be loaded in the context before this pass. |
3483 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3484 | |
3485 | } |
3486 | |
3487 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3488 | /// instantiation because Pass classes should only be visible by the current |
3489 | /// library. |
3490 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorFunctionRenamingBase<DerivedT>) |
3491 | |
3492 | protected: |
3493 | }; |
3494 | |
3495 | template <typename DerivedT> |
3496 | class DTensorHandleCrossClusterDependenciesBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
3497 | public: |
3498 | using Base = DTensorHandleCrossClusterDependenciesBase; |
3499 | |
3500 | DTensorHandleCrossClusterDependenciesBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
3501 | DTensorHandleCrossClusterDependenciesBase(const DTensorHandleCrossClusterDependenciesBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
3502 | |
3503 | /// Returns the command-line argument attached to this pass. |
3504 | static constexpr ::llvm::StringLiteral getArgumentName() { |
3505 | return ::llvm::StringLiteral("dtensor-handle_cross_cluster_dependences" ); |
3506 | } |
3507 | ::llvm::StringRef getArgument() const override { return "dtensor-handle_cross_cluster_dependences" ; } |
3508 | |
3509 | ::llvm::StringRef getDescription() const override { return "Lowers cross mesh cluster data depedences as send/recvs." ; } |
3510 | |
3511 | /// Returns the derived pass name. |
3512 | static constexpr ::llvm::StringLiteral getPassName() { |
3513 | return ::llvm::StringLiteral("DTensorHandleCrossClusterDependencies" ); |
3514 | } |
3515 | ::llvm::StringRef getName() const override { return "DTensorHandleCrossClusterDependencies" ; } |
3516 | |
3517 | /// Support isa/dyn_cast functionality for the derived pass class. |
3518 | static bool classof(const ::mlir::Pass *pass) { |
3519 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3520 | } |
3521 | |
3522 | /// A clone method to create a copy of this pass. |
3523 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3524 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3525 | } |
3526 | |
3527 | /// Return the dialect that must be loaded in the context before this pass. |
3528 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3529 | |
3530 | } |
3531 | |
3532 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3533 | /// instantiation because Pass classes should only be visible by the current |
3534 | /// library. |
3535 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorHandleCrossClusterDependenciesBase<DerivedT>) |
3536 | |
3537 | protected: |
3538 | }; |
3539 | |
3540 | template <typename DerivedT> |
3541 | class DTensorInferShapesForRestoreV2OpBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
3542 | public: |
3543 | using Base = DTensorInferShapesForRestoreV2OpBase; |
3544 | |
3545 | DTensorInferShapesForRestoreV2OpBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
3546 | DTensorInferShapesForRestoreV2OpBase(const DTensorInferShapesForRestoreV2OpBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
3547 | |
3548 | /// Returns the command-line argument attached to this pass. |
3549 | static constexpr ::llvm::StringLiteral getArgumentName() { |
3550 | return ::llvm::StringLiteral("dtensor-infer-shapes-for-restorev2-op" ); |
3551 | } |
3552 | ::llvm::StringRef getArgument() const override { return "dtensor-infer-shapes-for-restorev2-op" ; } |
3553 | |
3554 | ::llvm::StringRef getDescription() const override { return "Infer shapes of the outputs of tf.RestoreV2Op from the AssignVariableOps that consume those outputs. This is used for DTensor integration with TF Checkpoint." ; } |
3555 | |
3556 | /// Returns the derived pass name. |
3557 | static constexpr ::llvm::StringLiteral getPassName() { |
3558 | return ::llvm::StringLiteral("DTensorInferShapesForRestoreV2Op" ); |
3559 | } |
3560 | ::llvm::StringRef getName() const override { return "DTensorInferShapesForRestoreV2Op" ; } |
3561 | |
3562 | /// Support isa/dyn_cast functionality for the derived pass class. |
3563 | static bool classof(const ::mlir::Pass *pass) { |
3564 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3565 | } |
3566 | |
3567 | /// A clone method to create a copy of this pass. |
3568 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3569 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3570 | } |
3571 | |
3572 | /// Return the dialect that must be loaded in the context before this pass. |
3573 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3574 | |
3575 | } |
3576 | |
3577 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3578 | /// instantiation because Pass classes should only be visible by the current |
3579 | /// library. |
3580 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorInferShapesForRestoreV2OpBase<DerivedT>) |
3581 | |
3582 | protected: |
3583 | }; |
3584 | |
3585 | template <typename DerivedT> |
3586 | class DTensorLayoutPropagationV2Base : public ::mlir::OperationPass<mlir::ModuleOp> { |
3587 | public: |
3588 | using Base = DTensorLayoutPropagationV2Base; |
3589 | |
3590 | DTensorLayoutPropagationV2Base() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
3591 | DTensorLayoutPropagationV2Base(const DTensorLayoutPropagationV2Base &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
3592 | |
3593 | /// Returns the command-line argument attached to this pass. |
3594 | static constexpr ::llvm::StringLiteral getArgumentName() { |
3595 | return ::llvm::StringLiteral("dtensor-layout-propagation-v2" ); |
3596 | } |
3597 | ::llvm::StringRef getArgument() const override { return "dtensor-layout-propagation-v2" ; } |
3598 | |
3599 | ::llvm::StringRef getDescription() const override { return "Propagates layout information for all the TF Ops." ; } |
3600 | |
3601 | /// Returns the derived pass name. |
3602 | static constexpr ::llvm::StringLiteral getPassName() { |
3603 | return ::llvm::StringLiteral("DTensorLayoutPropagationV2" ); |
3604 | } |
3605 | ::llvm::StringRef getName() const override { return "DTensorLayoutPropagationV2" ; } |
3606 | |
3607 | /// Support isa/dyn_cast functionality for the derived pass class. |
3608 | static bool classof(const ::mlir::Pass *pass) { |
3609 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3610 | } |
3611 | |
3612 | /// A clone method to create a copy of this pass. |
3613 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3614 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3615 | } |
3616 | |
3617 | /// Return the dialect that must be loaded in the context before this pass. |
3618 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3619 | |
3620 | } |
3621 | |
3622 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3623 | /// instantiation because Pass classes should only be visible by the current |
3624 | /// library. |
3625 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorLayoutPropagationV2Base<DerivedT>) |
3626 | |
3627 | protected: |
3628 | }; |
3629 | |
3630 | template <typename DerivedT> |
3631 | class DTensorLowerSendRecvBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
3632 | public: |
3633 | using Base = DTensorLowerSendRecvBase; |
3634 | |
3635 | DTensorLowerSendRecvBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
3636 | DTensorLowerSendRecvBase(const DTensorLowerSendRecvBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
3637 | |
3638 | /// Returns the command-line argument attached to this pass. |
3639 | static constexpr ::llvm::StringLiteral getArgumentName() { |
3640 | return ::llvm::StringLiteral("dtensor-lower-send-recv" ); |
3641 | } |
3642 | ::llvm::StringRef getArgument() const override { return "dtensor-lower-send-recv" ; } |
3643 | |
3644 | ::llvm::StringRef getDescription() const override { return "Lowers DTensorSend/DTensorRecv ops to send/recv ops." ; } |
3645 | |
3646 | /// Returns the derived pass name. |
3647 | static constexpr ::llvm::StringLiteral getPassName() { |
3648 | return ::llvm::StringLiteral("DTensorLowerSendRecv" ); |
3649 | } |
3650 | ::llvm::StringRef getName() const override { return "DTensorLowerSendRecv" ; } |
3651 | |
3652 | /// Support isa/dyn_cast functionality for the derived pass class. |
3653 | static bool classof(const ::mlir::Pass *pass) { |
3654 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3655 | } |
3656 | |
3657 | /// A clone method to create a copy of this pass. |
3658 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3659 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3660 | } |
3661 | |
3662 | /// Return the dialect that must be loaded in the context before this pass. |
3663 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3664 | |
3665 | } |
3666 | |
3667 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3668 | /// instantiation because Pass classes should only be visible by the current |
3669 | /// library. |
3670 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorLowerSendRecvBase<DerivedT>) |
3671 | |
3672 | protected: |
3673 | }; |
3674 | |
3675 | template <typename DerivedT> |
3676 | class DTensorMergeClustersBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
3677 | public: |
3678 | using Base = DTensorMergeClustersBase; |
3679 | |
3680 | DTensorMergeClustersBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
3681 | DTensorMergeClustersBase(const DTensorMergeClustersBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
3682 | |
3683 | /// Returns the command-line argument attached to this pass. |
3684 | static constexpr ::llvm::StringLiteral getArgumentName() { |
3685 | return ::llvm::StringLiteral("dtensor-merge-clusters" ); |
3686 | } |
3687 | ::llvm::StringRef getArgument() const override { return "dtensor-merge-clusters" ; } |
3688 | |
3689 | ::llvm::StringRef getDescription() const override { return "Merges tf_device.Clusters ops with same mesh specification." ; } |
3690 | |
3691 | /// Returns the derived pass name. |
3692 | static constexpr ::llvm::StringLiteral getPassName() { |
3693 | return ::llvm::StringLiteral("DTensorMergeClusters" ); |
3694 | } |
3695 | ::llvm::StringRef getName() const override { return "DTensorMergeClusters" ; } |
3696 | |
3697 | /// Support isa/dyn_cast functionality for the derived pass class. |
3698 | static bool classof(const ::mlir::Pass *pass) { |
3699 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3700 | } |
3701 | |
3702 | /// A clone method to create a copy of this pass. |
3703 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3704 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3705 | } |
3706 | |
3707 | /// Return the dialect that must be loaded in the context before this pass. |
3708 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3709 | |
3710 | } |
3711 | |
3712 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3713 | /// instantiation because Pass classes should only be visible by the current |
3714 | /// library. |
3715 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorMergeClustersBase<DerivedT>) |
3716 | |
3717 | protected: |
3718 | }; |
3719 | |
3720 | template <typename DerivedT> |
3721 | class DTensorMeshPropagationBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
3722 | public: |
3723 | using Base = DTensorMeshPropagationBase; |
3724 | |
3725 | DTensorMeshPropagationBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
3726 | DTensorMeshPropagationBase(const DTensorMeshPropagationBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
3727 | |
3728 | /// Returns the command-line argument attached to this pass. |
3729 | static constexpr ::llvm::StringLiteral getArgumentName() { |
3730 | return ::llvm::StringLiteral("dtensor-mesh-propagation" ); |
3731 | } |
3732 | ::llvm::StringRef getArgument() const override { return "dtensor-mesh-propagation" ; } |
3733 | |
3734 | ::llvm::StringRef getDescription() const override { return "Propagates mesh information to all clusters." ; } |
3735 | |
3736 | /// Returns the derived pass name. |
3737 | static constexpr ::llvm::StringLiteral getPassName() { |
3738 | return ::llvm::StringLiteral("DTensorMeshPropagation" ); |
3739 | } |
3740 | ::llvm::StringRef getName() const override { return "DTensorMeshPropagation" ; } |
3741 | |
3742 | /// Support isa/dyn_cast functionality for the derived pass class. |
3743 | static bool classof(const ::mlir::Pass *pass) { |
3744 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3745 | } |
3746 | |
3747 | /// A clone method to create a copy of this pass. |
3748 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3749 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3750 | } |
3751 | |
3752 | /// Return the dialect that must be loaded in the context before this pass. |
3753 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3754 | |
3755 | } |
3756 | |
3757 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3758 | /// instantiation because Pass classes should only be visible by the current |
3759 | /// library. |
3760 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorMeshPropagationBase<DerivedT>) |
3761 | |
3762 | protected: |
3763 | }; |
3764 | |
3765 | template <typename DerivedT> |
3766 | class DTensorMixedPrecisionReduceBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
3767 | public: |
3768 | using Base = DTensorMixedPrecisionReduceBase; |
3769 | |
3770 | DTensorMixedPrecisionReduceBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
3771 | DTensorMixedPrecisionReduceBase(const DTensorMixedPrecisionReduceBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
3772 | |
3773 | /// Returns the command-line argument attached to this pass. |
3774 | static constexpr ::llvm::StringLiteral getArgumentName() { |
3775 | return ::llvm::StringLiteral("dtensor-mixed-precision-reduce" ); |
3776 | } |
3777 | ::llvm::StringRef getArgument() const override { return "dtensor-mixed-precision-reduce" ; } |
3778 | |
3779 | ::llvm::StringRef getDescription() const override { return "Upcast tensors to higher precision type for reduction ops." ; } |
3780 | |
3781 | /// Returns the derived pass name. |
3782 | static constexpr ::llvm::StringLiteral getPassName() { |
3783 | return ::llvm::StringLiteral("DTensorMixedPrecisionReduce" ); |
3784 | } |
3785 | ::llvm::StringRef getName() const override { return "DTensorMixedPrecisionReduce" ; } |
3786 | |
3787 | /// Support isa/dyn_cast functionality for the derived pass class. |
3788 | static bool classof(const ::mlir::Pass *pass) { |
3789 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3790 | } |
3791 | |
3792 | /// A clone method to create a copy of this pass. |
3793 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3794 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3795 | } |
3796 | |
3797 | /// Return the dialect that must be loaded in the context before this pass. |
3798 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3799 | |
3800 | } |
3801 | |
3802 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3803 | /// instantiation because Pass classes should only be visible by the current |
3804 | /// library. |
3805 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorMixedPrecisionReduceBase<DerivedT>) |
3806 | |
3807 | protected: |
3808 | }; |
3809 | |
3810 | template <typename DerivedT> |
3811 | class DTensorMoveCompilationToHostBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
3812 | public: |
3813 | using Base = DTensorMoveCompilationToHostBase; |
3814 | |
3815 | DTensorMoveCompilationToHostBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
3816 | DTensorMoveCompilationToHostBase(const DTensorMoveCompilationToHostBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
3817 | |
3818 | /// Returns the command-line argument attached to this pass. |
3819 | static constexpr ::llvm::StringLiteral getArgumentName() { |
3820 | return ::llvm::StringLiteral("dtensor-move-compilation-to-host" ); |
3821 | } |
3822 | ::llvm::StringRef getArgument() const override { return "dtensor-move-compilation-to-host" ; } |
3823 | |
3824 | ::llvm::StringRef getDescription() const override { return "Moves XLA compilation ops to host computation." ; } |
3825 | |
3826 | /// Returns the derived pass name. |
3827 | static constexpr ::llvm::StringLiteral getPassName() { |
3828 | return ::llvm::StringLiteral("DTensorMoveCompilationToHost" ); |
3829 | } |
3830 | ::llvm::StringRef getName() const override { return "DTensorMoveCompilationToHost" ; } |
3831 | |
3832 | /// Support isa/dyn_cast functionality for the derived pass class. |
3833 | static bool classof(const ::mlir::Pass *pass) { |
3834 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3835 | } |
3836 | |
3837 | /// A clone method to create a copy of this pass. |
3838 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3839 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3840 | } |
3841 | |
3842 | /// Return the dialect that must be loaded in the context before this pass. |
3843 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3844 | |
3845 | } |
3846 | |
3847 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3848 | /// instantiation because Pass classes should only be visible by the current |
3849 | /// library. |
3850 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorMoveCompilationToHostBase<DerivedT>) |
3851 | |
3852 | protected: |
3853 | }; |
3854 | |
3855 | template <typename DerivedT> |
3856 | class DTensorOpToDeviceClusterBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
3857 | public: |
3858 | using Base = DTensorOpToDeviceClusterBase; |
3859 | |
3860 | DTensorOpToDeviceClusterBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
3861 | DTensorOpToDeviceClusterBase(const DTensorOpToDeviceClusterBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
3862 | |
3863 | /// Returns the command-line argument attached to this pass. |
3864 | static constexpr ::llvm::StringLiteral getArgumentName() { |
3865 | return ::llvm::StringLiteral("dtensor-op-to-device-cluster" ); |
3866 | } |
3867 | ::llvm::StringRef getArgument() const override { return "dtensor-op-to-device-cluster" ; } |
3868 | |
3869 | ::llvm::StringRef getDescription() const override { return "Creates and wraps tf_device.cluster op for all TF ops" ; } |
3870 | |
3871 | /// Returns the derived pass name. |
3872 | static constexpr ::llvm::StringLiteral getPassName() { |
3873 | return ::llvm::StringLiteral("DTensorOpToDeviceCluster" ); |
3874 | } |
3875 | ::llvm::StringRef getName() const override { return "DTensorOpToDeviceCluster" ; } |
3876 | |
3877 | /// Support isa/dyn_cast functionality for the derived pass class. |
3878 | static bool classof(const ::mlir::Pass *pass) { |
3879 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3880 | } |
3881 | |
3882 | /// A clone method to create a copy of this pass. |
3883 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3884 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3885 | } |
3886 | |
3887 | /// Return the dialect that must be loaded in the context before this pass. |
3888 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3889 | |
3890 | } |
3891 | |
3892 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3893 | /// instantiation because Pass classes should only be visible by the current |
3894 | /// library. |
3895 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorOpToDeviceClusterBase<DerivedT>) |
3896 | |
3897 | protected: |
3898 | }; |
3899 | |
3900 | template <typename DerivedT> |
3901 | class DTensorPropagateDefaultLayoutBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
3902 | public: |
3903 | using Base = DTensorPropagateDefaultLayoutBase; |
3904 | |
3905 | DTensorPropagateDefaultLayoutBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
3906 | DTensorPropagateDefaultLayoutBase(const DTensorPropagateDefaultLayoutBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
3907 | |
3908 | /// Returns the command-line argument attached to this pass. |
3909 | static constexpr ::llvm::StringLiteral getArgumentName() { |
3910 | return ::llvm::StringLiteral("dtensor-propagate-default-layout" ); |
3911 | } |
3912 | ::llvm::StringRef getArgument() const override { return "dtensor-propagate-default-layout" ; } |
3913 | |
3914 | ::llvm::StringRef getDescription() const override { return "Converts layout attributes added by end users to DTensorLayout op." ; } |
3915 | |
3916 | /// Returns the derived pass name. |
3917 | static constexpr ::llvm::StringLiteral getPassName() { |
3918 | return ::llvm::StringLiteral("DTensorPropagateDefaultLayout" ); |
3919 | } |
3920 | ::llvm::StringRef getName() const override { return "DTensorPropagateDefaultLayout" ; } |
3921 | |
3922 | /// Support isa/dyn_cast functionality for the derived pass class. |
3923 | static bool classof(const ::mlir::Pass *pass) { |
3924 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3925 | } |
3926 | |
3927 | /// A clone method to create a copy of this pass. |
3928 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3929 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3930 | } |
3931 | |
3932 | /// Return the dialect that must be loaded in the context before this pass. |
3933 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3934 | |
3935 | } |
3936 | |
3937 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3938 | /// instantiation because Pass classes should only be visible by the current |
3939 | /// library. |
3940 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorPropagateDefaultLayoutBase<DerivedT>) |
3941 | |
3942 | protected: |
3943 | }; |
3944 | |
3945 | template <typename DerivedT> |
3946 | class DTensorPropagateDeviceIdToFunctionArgsBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
3947 | public: |
3948 | using Base = DTensorPropagateDeviceIdToFunctionArgsBase; |
3949 | |
3950 | DTensorPropagateDeviceIdToFunctionArgsBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
3951 | DTensorPropagateDeviceIdToFunctionArgsBase(const DTensorPropagateDeviceIdToFunctionArgsBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
3952 | |
3953 | /// Returns the command-line argument attached to this pass. |
3954 | static constexpr ::llvm::StringLiteral getArgumentName() { |
3955 | return ::llvm::StringLiteral("dtensor-propagate-device-id-to-function-args" ); |
3956 | } |
3957 | ::llvm::StringRef getArgument() const override { return "dtensor-propagate-device-id-to-function-args" ; } |
3958 | |
3959 | ::llvm::StringRef getDescription() const override { return "Adds device id as arguments to all private function in graph." ; } |
3960 | |
3961 | /// Returns the derived pass name. |
3962 | static constexpr ::llvm::StringLiteral getPassName() { |
3963 | return ::llvm::StringLiteral("DTensorPropagateDeviceIdToFunctionArgs" ); |
3964 | } |
3965 | ::llvm::StringRef getName() const override { return "DTensorPropagateDeviceIdToFunctionArgs" ; } |
3966 | |
3967 | /// Support isa/dyn_cast functionality for the derived pass class. |
3968 | static bool classof(const ::mlir::Pass *pass) { |
3969 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
3970 | } |
3971 | |
3972 | /// A clone method to create a copy of this pass. |
3973 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
3974 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
3975 | } |
3976 | |
3977 | /// Return the dialect that must be loaded in the context before this pass. |
3978 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
3979 | |
3980 | } |
3981 | |
3982 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
3983 | /// instantiation because Pass classes should only be visible by the current |
3984 | /// library. |
3985 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorPropagateDeviceIdToFunctionArgsBase<DerivedT>) |
3986 | |
3987 | protected: |
3988 | }; |
3989 | |
3990 | template <typename DerivedT> |
3991 | class DTensorReduceScatterLoweringBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
3992 | public: |
3993 | using Base = DTensorReduceScatterLoweringBase; |
3994 | |
3995 | DTensorReduceScatterLoweringBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
3996 | DTensorReduceScatterLoweringBase(const DTensorReduceScatterLoweringBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
3997 | |
3998 | /// Returns the command-line argument attached to this pass. |
3999 | static constexpr ::llvm::StringLiteral getArgumentName() { |
4000 | return ::llvm::StringLiteral("dtensor-reduce-scatter-lowering" ); |
4001 | } |
4002 | ::llvm::StringRef getArgument() const override { return "dtensor-reduce-scatter-lowering" ; } |
4003 | |
4004 | ::llvm::StringRef getDescription() const override { return "Converts logical ReduceScatter ops into physical ReduceScatter ops." ; } |
4005 | |
4006 | /// Returns the derived pass name. |
4007 | static constexpr ::llvm::StringLiteral getPassName() { |
4008 | return ::llvm::StringLiteral("DTensorReduceScatterLowering" ); |
4009 | } |
4010 | ::llvm::StringRef getName() const override { return "DTensorReduceScatterLowering" ; } |
4011 | |
4012 | /// Support isa/dyn_cast functionality for the derived pass class. |
4013 | static bool classof(const ::mlir::Pass *pass) { |
4014 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
4015 | } |
4016 | |
4017 | /// A clone method to create a copy of this pass. |
4018 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
4019 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
4020 | } |
4021 | |
4022 | /// Return the dialect that must be loaded in the context before this pass. |
4023 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
4024 | |
4025 | } |
4026 | |
4027 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
4028 | /// instantiation because Pass classes should only be visible by the current |
4029 | /// library. |
4030 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorReduceScatterLoweringBase<DerivedT>) |
4031 | |
4032 | protected: |
4033 | }; |
4034 | |
4035 | template <typename DerivedT> |
4036 | class DTensorSPMDExpansionBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
4037 | public: |
4038 | using Base = DTensorSPMDExpansionBase; |
4039 | |
4040 | DTensorSPMDExpansionBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
4041 | DTensorSPMDExpansionBase(const DTensorSPMDExpansionBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
4042 | |
4043 | /// Returns the command-line argument attached to this pass. |
4044 | static constexpr ::llvm::StringLiteral getArgumentName() { |
4045 | return ::llvm::StringLiteral("dtensor-spmd-expansion" ); |
4046 | } |
4047 | ::llvm::StringRef getArgument() const override { return "dtensor-spmd-expansion" ; } |
4048 | |
4049 | ::llvm::StringRef getDescription() const override { return "Converts ops into SPMD expanded form." ; } |
4050 | |
4051 | /// Returns the derived pass name. |
4052 | static constexpr ::llvm::StringLiteral getPassName() { |
4053 | return ::llvm::StringLiteral("DTensorSPMDExpansion" ); |
4054 | } |
4055 | ::llvm::StringRef getName() const override { return "DTensorSPMDExpansion" ; } |
4056 | |
4057 | /// Support isa/dyn_cast functionality for the derived pass class. |
4058 | static bool classof(const ::mlir::Pass *pass) { |
4059 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
4060 | } |
4061 | |
4062 | /// A clone method to create a copy of this pass. |
4063 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
4064 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
4065 | } |
4066 | |
4067 | /// Return the dialect that must be loaded in the context before this pass. |
4068 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
4069 | |
4070 | } |
4071 | |
4072 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
4073 | /// instantiation because Pass classes should only be visible by the current |
4074 | /// library. |
4075 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorSPMDExpansionBase<DerivedT>) |
4076 | |
4077 | protected: |
4078 | }; |
4079 | |
4080 | template <typename DerivedT> |
4081 | class DTensorSetDefaultShardingBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
4082 | public: |
4083 | using Base = DTensorSetDefaultShardingBase; |
4084 | |
4085 | DTensorSetDefaultShardingBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
4086 | DTensorSetDefaultShardingBase(const DTensorSetDefaultShardingBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
4087 | |
4088 | /// Returns the command-line argument attached to this pass. |
4089 | static constexpr ::llvm::StringLiteral getArgumentName() { |
4090 | return ::llvm::StringLiteral("dtensor-set-default-sharding" ); |
4091 | } |
4092 | ::llvm::StringRef getArgument() const override { return "dtensor-set-default-sharding" ; } |
4093 | |
4094 | ::llvm::StringRef getDescription() const override { return "Sets default sharding of TPU computation inputs/outputs to maximal." ; } |
4095 | |
4096 | /// Returns the derived pass name. |
4097 | static constexpr ::llvm::StringLiteral getPassName() { |
4098 | return ::llvm::StringLiteral("DTensorSetDefaultSharding" ); |
4099 | } |
4100 | ::llvm::StringRef getName() const override { return "DTensorSetDefaultSharding" ; } |
4101 | |
4102 | /// Support isa/dyn_cast functionality for the derived pass class. |
4103 | static bool classof(const ::mlir::Pass *pass) { |
4104 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
4105 | } |
4106 | |
4107 | /// A clone method to create a copy of this pass. |
4108 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
4109 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
4110 | } |
4111 | |
4112 | /// Return the dialect that must be loaded in the context before this pass. |
4113 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
4114 | |
4115 | } |
4116 | |
4117 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
4118 | /// instantiation because Pass classes should only be visible by the current |
4119 | /// library. |
4120 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorSetDefaultShardingBase<DerivedT>) |
4121 | |
4122 | protected: |
4123 | }; |
4124 | |
4125 | template <typename DerivedT> |
4126 | class DTensorSparseExpansionBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
4127 | public: |
4128 | using Base = DTensorSparseExpansionBase; |
4129 | |
4130 | DTensorSparseExpansionBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
4131 | DTensorSparseExpansionBase(const DTensorSparseExpansionBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
4132 | |
4133 | /// Returns the command-line argument attached to this pass. |
4134 | static constexpr ::llvm::StringLiteral getArgumentName() { |
4135 | return ::llvm::StringLiteral("dtensor-sparse-expansion" ); |
4136 | } |
4137 | ::llvm::StringRef getArgument() const override { return "dtensor-sparse-expansion" ; } |
4138 | |
4139 | ::llvm::StringRef getDescription() const override { return "Convert ops that take in SparseTensor input to its corresponding Sparse or Dense ops." ; } |
4140 | |
4141 | /// Returns the derived pass name. |
4142 | static constexpr ::llvm::StringLiteral getPassName() { |
4143 | return ::llvm::StringLiteral("DTensorSparseExpansion" ); |
4144 | } |
4145 | ::llvm::StringRef getName() const override { return "DTensorSparseExpansion" ; } |
4146 | |
4147 | /// Support isa/dyn_cast functionality for the derived pass class. |
4148 | static bool classof(const ::mlir::Pass *pass) { |
4149 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
4150 | } |
4151 | |
4152 | /// A clone method to create a copy of this pass. |
4153 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
4154 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
4155 | } |
4156 | |
4157 | /// Return the dialect that must be loaded in the context before this pass. |
4158 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
4159 | |
4160 | } |
4161 | |
4162 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
4163 | /// instantiation because Pass classes should only be visible by the current |
4164 | /// library. |
4165 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorSparseExpansionBase<DerivedT>) |
4166 | |
4167 | protected: |
4168 | }; |
4169 | |
4170 | template <typename DerivedT> |
4171 | class DTensorSparseTensorToDenseTensorBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
4172 | public: |
4173 | using Base = DTensorSparseTensorToDenseTensorBase; |
4174 | |
4175 | DTensorSparseTensorToDenseTensorBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
4176 | DTensorSparseTensorToDenseTensorBase(const DTensorSparseTensorToDenseTensorBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
4177 | |
4178 | /// Returns the command-line argument attached to this pass. |
4179 | static constexpr ::llvm::StringLiteral getArgumentName() { |
4180 | return ::llvm::StringLiteral("dtensor-sparse-tensor-to-dense-tensor" ); |
4181 | } |
4182 | ::llvm::StringRef getArgument() const override { return "dtensor-sparse-tensor-to-dense-tensor" ; } |
4183 | |
4184 | ::llvm::StringRef getDescription() const override { return "Converts SparseTensor inputs to its component tensors inputs and emits a SparseToDenseOp for every op that consumes a SparseTensor." ; } |
4185 | |
4186 | /// Returns the derived pass name. |
4187 | static constexpr ::llvm::StringLiteral getPassName() { |
4188 | return ::llvm::StringLiteral("DTensorSparseTensorToDenseTensor" ); |
4189 | } |
4190 | ::llvm::StringRef getName() const override { return "DTensorSparseTensorToDenseTensor" ; } |
4191 | |
4192 | /// Support isa/dyn_cast functionality for the derived pass class. |
4193 | static bool classof(const ::mlir::Pass *pass) { |
4194 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
4195 | } |
4196 | |
4197 | /// A clone method to create a copy of this pass. |
4198 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
4199 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
4200 | } |
4201 | |
4202 | /// Return the dialect that must be loaded in the context before this pass. |
4203 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
4204 | |
4205 | } |
4206 | |
4207 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
4208 | /// instantiation because Pass classes should only be visible by the current |
4209 | /// library. |
4210 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorSparseTensorToDenseTensorBase<DerivedT>) |
4211 | |
4212 | protected: |
4213 | }; |
4214 | |
4215 | template <typename DerivedT> |
4216 | class DTensorTPUIntegrationBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
4217 | public: |
4218 | using Base = DTensorTPUIntegrationBase; |
4219 | |
4220 | DTensorTPUIntegrationBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
4221 | DTensorTPUIntegrationBase(const DTensorTPUIntegrationBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
4222 | |
4223 | /// Returns the command-line argument attached to this pass. |
4224 | static constexpr ::llvm::StringLiteral getArgumentName() { |
4225 | return ::llvm::StringLiteral("dtensor-tpu-integration" ); |
4226 | } |
4227 | ::llvm::StringRef getArgument() const override { return "dtensor-tpu-integration" ; } |
4228 | |
4229 | ::llvm::StringRef getDescription() const override { return "Adds TPUReplicateMetadata and converts ops that run on TPU's to a single tf_device.cluster to be compatible with following TF2XLA MLIR passes." ; } |
4230 | |
4231 | /// Returns the derived pass name. |
4232 | static constexpr ::llvm::StringLiteral getPassName() { |
4233 | return ::llvm::StringLiteral("DTensorTPUIntegration" ); |
4234 | } |
4235 | ::llvm::StringRef getName() const override { return "DTensorTPUIntegration" ; } |
4236 | |
4237 | /// Support isa/dyn_cast functionality for the derived pass class. |
4238 | static bool classof(const ::mlir::Pass *pass) { |
4239 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
4240 | } |
4241 | |
4242 | /// A clone method to create a copy of this pass. |
4243 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
4244 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
4245 | } |
4246 | |
4247 | /// Return the dialect that must be loaded in the context before this pass. |
4248 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
4249 | |
4250 | } |
4251 | |
4252 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
4253 | /// instantiation because Pass classes should only be visible by the current |
4254 | /// library. |
4255 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorTPUIntegrationBase<DerivedT>) |
4256 | |
4257 | protected: |
4258 | }; |
4259 | |
4260 | template <typename DerivedT> |
4261 | class DTensorTpuAddResourceDeviceAttributeBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
4262 | public: |
4263 | using Base = DTensorTpuAddResourceDeviceAttributeBase; |
4264 | |
4265 | DTensorTpuAddResourceDeviceAttributeBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
4266 | DTensorTpuAddResourceDeviceAttributeBase(const DTensorTpuAddResourceDeviceAttributeBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
4267 | |
4268 | /// Returns the command-line argument attached to this pass. |
4269 | static constexpr ::llvm::StringLiteral getArgumentName() { |
4270 | return ::llvm::StringLiteral("dtensor-tpu-add-resource-device-attribute" ); |
4271 | } |
4272 | ::llvm::StringRef getArgument() const override { return "dtensor-tpu-add-resource-device-attribute" ; } |
4273 | |
4274 | ::llvm::StringRef getDescription() const override { return "Adds placeholder device attributes to resources accessed by TPU computation to enable buffer aliasing." ; } |
4275 | |
4276 | /// Returns the derived pass name. |
4277 | static constexpr ::llvm::StringLiteral getPassName() { |
4278 | return ::llvm::StringLiteral("DTensorTpuAddResourceDeviceAttribute" ); |
4279 | } |
4280 | ::llvm::StringRef getName() const override { return "DTensorTpuAddResourceDeviceAttribute" ; } |
4281 | |
4282 | /// Support isa/dyn_cast functionality for the derived pass class. |
4283 | static bool classof(const ::mlir::Pass *pass) { |
4284 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
4285 | } |
4286 | |
4287 | /// A clone method to create a copy of this pass. |
4288 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
4289 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
4290 | } |
4291 | |
4292 | /// Return the dialect that must be loaded in the context before this pass. |
4293 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
4294 | |
4295 | } |
4296 | |
4297 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
4298 | /// instantiation because Pass classes should only be visible by the current |
4299 | /// library. |
4300 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorTpuAddResourceDeviceAttributeBase<DerivedT>) |
4301 | |
4302 | protected: |
4303 | }; |
4304 | |
4305 | template <typename DerivedT> |
4306 | class DTensorUndoMergeConstAcrossMeshBase : public ::mlir::OperationPass<mlir::func::FuncOp> { |
4307 | public: |
4308 | using Base = DTensorUndoMergeConstAcrossMeshBase; |
4309 | |
4310 | DTensorUndoMergeConstAcrossMeshBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {} |
4311 | DTensorUndoMergeConstAcrossMeshBase(const DTensorUndoMergeConstAcrossMeshBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {} |
4312 | |
4313 | /// Returns the command-line argument attached to this pass. |
4314 | static constexpr ::llvm::StringLiteral getArgumentName() { |
4315 | return ::llvm::StringLiteral("dtensor-undo-merge-const-across-mesh" ); |
4316 | } |
4317 | ::llvm::StringRef getArgument() const override { return "dtensor-undo-merge-const-across-mesh" ; } |
4318 | |
4319 | ::llvm::StringRef getDescription() const override { return "Undo constant merging across meshes" ; } |
4320 | |
4321 | /// Returns the derived pass name. |
4322 | static constexpr ::llvm::StringLiteral getPassName() { |
4323 | return ::llvm::StringLiteral("DTensorUndoMergeConstAcrossMesh" ); |
4324 | } |
4325 | ::llvm::StringRef getName() const override { return "DTensorUndoMergeConstAcrossMesh" ; } |
4326 | |
4327 | /// Support isa/dyn_cast functionality for the derived pass class. |
4328 | static bool classof(const ::mlir::Pass *pass) { |
4329 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
4330 | } |
4331 | |
4332 | /// A clone method to create a copy of this pass. |
4333 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
4334 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
4335 | } |
4336 | |
4337 | /// Return the dialect that must be loaded in the context before this pass. |
4338 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
4339 | |
4340 | } |
4341 | |
4342 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
4343 | /// instantiation because Pass classes should only be visible by the current |
4344 | /// library. |
4345 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorUndoMergeConstAcrossMeshBase<DerivedT>) |
4346 | |
4347 | protected: |
4348 | }; |
4349 | |
4350 | template <typename DerivedT> |
4351 | class DTensorUpdateTPUMetadataBase : public ::mlir::OperationPass<mlir::ModuleOp> { |
4352 | public: |
4353 | using Base = DTensorUpdateTPUMetadataBase; |
4354 | |
4355 | DTensorUpdateTPUMetadataBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {} |
4356 | DTensorUpdateTPUMetadataBase(const DTensorUpdateTPUMetadataBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {} |
4357 | |
4358 | /// Returns the command-line argument attached to this pass. |
4359 | static constexpr ::llvm::StringLiteral getArgumentName() { |
4360 | return ::llvm::StringLiteral("dtensor-update-tpu-metadata" ); |
4361 | } |
4362 | ::llvm::StringRef getArgument() const override { return "dtensor-update-tpu-metadata" ; } |
4363 | |
4364 | ::llvm::StringRef getDescription() const override { return "Changes metadata on TPU specific ops such as device placement." ; } |
4365 | |
4366 | /// Returns the derived pass name. |
4367 | static constexpr ::llvm::StringLiteral getPassName() { |
4368 | return ::llvm::StringLiteral("DTensorUpdateTPUMetadata" ); |
4369 | } |
4370 | ::llvm::StringRef getName() const override { return "DTensorUpdateTPUMetadata" ; } |
4371 | |
4372 | /// Support isa/dyn_cast functionality for the derived pass class. |
4373 | static bool classof(const ::mlir::Pass *pass) { |
4374 | return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); |
4375 | } |
4376 | |
4377 | /// A clone method to create a copy of this pass. |
4378 | std::unique_ptr<::mlir::Pass> clonePass() const override { |
4379 | return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); |
4380 | } |
4381 | |
4382 | /// Return the dialect that must be loaded in the context before this pass. |
4383 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
4384 | |
4385 | } |
4386 | |
4387 | /// Explicitly declare the TypeID for this class. We declare an explicit private |
4388 | /// instantiation because Pass classes should only be visible by the current |
4389 | /// library. |
4390 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorUpdateTPUMetadataBase<DerivedT>) |
4391 | |
4392 | protected: |
4393 | }; |
4394 | #undef GEN_PASS_CLASSES |
4395 | #endif // GEN_PASS_CLASSES |
4396 | |