1/* Autogenerated by mlir-tblgen; don't manually edit */
2
3//===----------------------------------------------------------------------===//
4// DTensorAllGatherLowering
5//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSORALLGATHERLOWERING)
// No declarations are generated for this pass; this section only consumes the
// request macro.
#undef GEN_PASS_DECL_DTENSORALLGATHERLOWERING
#endif // GEN_PASS_DECL_DTENSORALLGATHERLOWERING
#if defined(GEN_PASS_DEF_DTENSORALLGATHERLOWERING)
namespace impl {

/// CRTP base class for the `dtensor-all-gather-lowering` pass, which operates
/// on `mlir::ModuleOp`. `DerivedT` is the concrete pass: its TypeID identifies
/// the pass and `clonePass()` copy-constructs a new `DerivedT`.
template <typename DerivedT>
class DTensorAllGatherLoweringBase : public ::mlir::OperationPass<mlir::ModuleOp> {
public:
  using Base = DTensorAllGatherLoweringBase;

  DTensorAllGatherLoweringBase()
      : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorAllGatherLoweringBase(const DTensorAllGatherLoweringBase &source)
      : ::mlir::OperationPass<mlir::ModuleOp>(source) {}

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-all-gather-lowering");
  }
  /// Name of the concrete pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorAllGatherLowering");
  }

  // The virtual accessors forward to the static ones above so each literal is
  // spelled exactly once.
  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Converts logical AllGather ops into physical AllGather ops.";
  }

  /// LLVM-style RTTI hook: true iff `candidate` has the TypeID of `DerivedT`.
  static bool classof(const ::mlir::Pass *candidate) {
    return ::mlir::TypeID::get<DerivedT>() == candidate->getTypeID();
  }

  /// Returns a fresh copy of this pass, copy-constructed as `DerivedT`.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Gives the class an explicit TypeID kept internal to the current library,
  /// since pass classes are not meant to be visible outside it.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllGatherLoweringBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSORALLGATHERLOWERING
#endif // GEN_PASS_DEF_DTENSORALLGATHERLOWERING
60
61//===----------------------------------------------------------------------===//
62// DTensorAllReduceCombineOptimization
63//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSORALLREDUCECOMBINEOPTIMIZATION)
// No declarations are generated for this pass; this section only consumes the
// request macro.
#undef GEN_PASS_DECL_DTENSORALLREDUCECOMBINEOPTIMIZATION
#endif // GEN_PASS_DECL_DTENSORALLREDUCECOMBINEOPTIMIZATION
#if defined(GEN_PASS_DEF_DTENSORALLREDUCECOMBINEOPTIMIZATION)
namespace impl {

/// CRTP base class for the `dtensor-allreduce-combine-optimization` pass, which
/// operates on `mlir::func::FuncOp`. `DerivedT` is the concrete pass: its
/// TypeID identifies the pass and `clonePass()` copy-constructs a new
/// `DerivedT`.
template <typename DerivedT>
class DTensorAllReduceCombineOptimizationBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
public:
  using Base = DTensorAllReduceCombineOptimizationBase;

  DTensorAllReduceCombineOptimizationBase()
      : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorAllReduceCombineOptimizationBase(const DTensorAllReduceCombineOptimizationBase &source)
      : ::mlir::OperationPass<mlir::func::FuncOp>(source) {}

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-allreduce-combine-optimization");
  }
  /// Name of the concrete pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorAllReduceCombineOptimization");
  }

  // The virtual accessors forward to the static ones above so each literal is
  // spelled exactly once.
  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Combine independent all reduce operations.";
  }

  /// LLVM-style RTTI hook: true iff `candidate` has the TypeID of `DerivedT`.
  static bool classof(const ::mlir::Pass *candidate) {
    return ::mlir::TypeID::get<DerivedT>() == candidate->getTypeID();
  }

  /// Returns a fresh copy of this pass, copy-constructed as `DerivedT`.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Gives the class an explicit TypeID kept internal to the current library,
  /// since pass classes are not meant to be visible outside it.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllReduceCombineOptimizationBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSORALLREDUCECOMBINEOPTIMIZATION
#endif // GEN_PASS_DEF_DTENSORALLREDUCECOMBINEOPTIMIZATION
118
119//===----------------------------------------------------------------------===//
120// DTensorAllReduceLowering
121//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSORALLREDUCELOWERING)
// No declarations are generated for this pass; this section only consumes the
// request macro.
#undef GEN_PASS_DECL_DTENSORALLREDUCELOWERING
#endif // GEN_PASS_DECL_DTENSORALLREDUCELOWERING
#if defined(GEN_PASS_DEF_DTENSORALLREDUCELOWERING)
namespace impl {

/// CRTP base class for the `dtensor-all-reduce-lowering` pass, which operates
/// on `mlir::ModuleOp`. `DerivedT` is the concrete pass: its TypeID identifies
/// the pass and `clonePass()` copy-constructs a new `DerivedT`.
template <typename DerivedT>
class DTensorAllReduceLoweringBase : public ::mlir::OperationPass<mlir::ModuleOp> {
public:
  using Base = DTensorAllReduceLoweringBase;

  DTensorAllReduceLoweringBase()
      : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorAllReduceLoweringBase(const DTensorAllReduceLoweringBase &source)
      : ::mlir::OperationPass<mlir::ModuleOp>(source) {}

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-all-reduce-lowering");
  }
  /// Name of the concrete pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorAllReduceLowering");
  }

  // The virtual accessors forward to the static ones above so each literal is
  // spelled exactly once.
  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Converts logical AllReduce ops into physical AllReduce ops.";
  }

  /// LLVM-style RTTI hook: true iff `candidate` has the TypeID of `DerivedT`.
  static bool classof(const ::mlir::Pass *candidate) {
    return ::mlir::TypeID::get<DerivedT>() == candidate->getTypeID();
  }

  /// Returns a fresh copy of this pass, copy-constructed as `DerivedT`.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Gives the class an explicit TypeID kept internal to the current library,
  /// since pass classes are not meant to be visible outside it.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllReduceLoweringBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSORALLREDUCELOWERING
#endif // GEN_PASS_DEF_DTENSORALLREDUCELOWERING
176
177//===----------------------------------------------------------------------===//
178// DTensorAllReduceScatterOptimization
179//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSORALLREDUCESCATTEROPTIMIZATION)
// No declarations are generated for this pass; this section only consumes the
// request macro.
#undef GEN_PASS_DECL_DTENSORALLREDUCESCATTEROPTIMIZATION
#endif // GEN_PASS_DECL_DTENSORALLREDUCESCATTEROPTIMIZATION
#if defined(GEN_PASS_DEF_DTENSORALLREDUCESCATTEROPTIMIZATION)
namespace impl {

/// CRTP base class for the `dtensor-allreduce-scatter-optimization` pass, which
/// operates on `mlir::func::FuncOp`. `DerivedT` is the concrete pass: its
/// TypeID identifies the pass and `clonePass()` copy-constructs a new
/// `DerivedT`.
template <typename DerivedT>
class DTensorAllReduceScatterOptimizationBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
public:
  using Base = DTensorAllReduceScatterOptimizationBase;

  DTensorAllReduceScatterOptimizationBase()
      : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorAllReduceScatterOptimizationBase(const DTensorAllReduceScatterOptimizationBase &source)
      : ::mlir::OperationPass<mlir::func::FuncOp>(source) {}

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-allreduce-scatter-optimization");
  }
  /// Name of the concrete pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorAllReduceScatterOptimization");
  }

  // The virtual accessors forward to the static ones above so each literal is
  // spelled exactly once.
  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Combines allreduce and scatter to reducescatter.";
  }

  /// LLVM-style RTTI hook: true iff `candidate` has the TypeID of `DerivedT`.
  static bool classof(const ::mlir::Pass *candidate) {
    return ::mlir::TypeID::get<DerivedT>() == candidate->getTypeID();
  }

  /// Returns a fresh copy of this pass, copy-constructed as `DerivedT`.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Gives the class an explicit TypeID kept internal to the current library,
  /// since pass classes are not meant to be visible outside it.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllReduceScatterOptimizationBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSORALLREDUCESCATTEROPTIMIZATION
#endif // GEN_PASS_DEF_DTENSORALLREDUCESCATTEROPTIMIZATION
234
235//===----------------------------------------------------------------------===//
236// DTensorAllReduceSumOptimization
237//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSORALLREDUCESUMOPTIMIZATION)
// No declarations are generated for this pass; this section only consumes the
// request macro.
#undef GEN_PASS_DECL_DTENSORALLREDUCESUMOPTIMIZATION
#endif // GEN_PASS_DECL_DTENSORALLREDUCESUMOPTIMIZATION
#if defined(GEN_PASS_DEF_DTENSORALLREDUCESUMOPTIMIZATION)
namespace impl {

/// CRTP base class for the `dtensor-allreduce-sum-optimization` pass, which
/// operates on `mlir::func::FuncOp`. `DerivedT` is the concrete pass: its
/// TypeID identifies the pass and `clonePass()` copy-constructs a new
/// `DerivedT`.
template <typename DerivedT>
class DTensorAllReduceSumOptimizationBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
public:
  using Base = DTensorAllReduceSumOptimizationBase;

  DTensorAllReduceSumOptimizationBase()
      : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorAllReduceSumOptimizationBase(const DTensorAllReduceSumOptimizationBase &source)
      : ::mlir::OperationPass<mlir::func::FuncOp>(source) {}

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-allreduce-sum-optimization");
  }
  /// Name of the concrete pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorAllReduceSumOptimization");
  }

  // The virtual accessors forward to the static ones above so each literal is
  // spelled exactly once.
  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Changes order of add/allreduce to minimize all reduce operations.";
  }

  /// LLVM-style RTTI hook: true iff `candidate` has the TypeID of `DerivedT`.
  static bool classof(const ::mlir::Pass *candidate) {
    return ::mlir::TypeID::get<DerivedT>() == candidate->getTypeID();
  }

  /// Returns a fresh copy of this pass, copy-constructed as `DerivedT`.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Gives the class an explicit TypeID kept internal to the current library,
  /// since pass classes are not meant to be visible outside it.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllReduceSumOptimizationBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSORALLREDUCESUMOPTIMIZATION
#endif // GEN_PASS_DEF_DTENSORALLREDUCESUMOPTIMIZATION
292
293//===----------------------------------------------------------------------===//
294// DTensorAllScatterLowering
295//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSORALLSCATTERLOWERING)
// No declarations are generated for this pass; this section only consumes the
// request macro.
#undef GEN_PASS_DECL_DTENSORALLSCATTERLOWERING
#endif // GEN_PASS_DECL_DTENSORALLSCATTERLOWERING
#if defined(GEN_PASS_DEF_DTENSORALLSCATTERLOWERING)
namespace impl {

/// CRTP base class for the `dtensor-all-scatter-lowering` pass, which operates
/// on `mlir::ModuleOp`. `DerivedT` is the concrete pass: its TypeID identifies
/// the pass and `clonePass()` copy-constructs a new `DerivedT`.
template <typename DerivedT>
class DTensorAllScatterLoweringBase : public ::mlir::OperationPass<mlir::ModuleOp> {
public:
  using Base = DTensorAllScatterLoweringBase;

  DTensorAllScatterLoweringBase()
      : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorAllScatterLoweringBase(const DTensorAllScatterLoweringBase &source)
      : ::mlir::OperationPass<mlir::ModuleOp>(source) {}

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-all-scatter-lowering");
  }
  /// Name of the concrete pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorAllScatterLowering");
  }

  // The virtual accessors forward to the static ones above so each literal is
  // spelled exactly once.
  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Converts logical AllScatter ops into physical Split ops.";
  }

  /// LLVM-style RTTI hook: true iff `candidate` has the TypeID of `DerivedT`.
  static bool classof(const ::mlir::Pass *candidate) {
    return ::mlir::TypeID::get<DerivedT>() == candidate->getTypeID();
  }

  /// Returns a fresh copy of this pass, copy-constructed as `DerivedT`.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Gives the class an explicit TypeID kept internal to the current library,
  /// since pass classes are not meant to be visible outside it.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllScatterLoweringBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSORALLSCATTERLOWERING
#endif // GEN_PASS_DEF_DTENSORALLSCATTERLOWERING
350
351//===----------------------------------------------------------------------===//
352// DTensorAnnotateGlobalShape
353//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSORANNOTATEGLOBALSHAPE)
// No declarations are generated for this pass; this section only consumes the
// request macro.
#undef GEN_PASS_DECL_DTENSORANNOTATEGLOBALSHAPE
#endif // GEN_PASS_DECL_DTENSORANNOTATEGLOBALSHAPE
#if defined(GEN_PASS_DEF_DTENSORANNOTATEGLOBALSHAPE)
namespace impl {

/// CRTP base class for the `dtensor-annotate-global-shape` pass, which operates
/// on `mlir::ModuleOp`. `DerivedT` is the concrete pass: its TypeID identifies
/// the pass and `clonePass()` copy-constructs a new `DerivedT`.
template <typename DerivedT>
class DTensorAnnotateGlobalShapeBase : public ::mlir::OperationPass<mlir::ModuleOp> {
public:
  using Base = DTensorAnnotateGlobalShapeBase;

  DTensorAnnotateGlobalShapeBase()
      : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorAnnotateGlobalShapeBase(const DTensorAnnotateGlobalShapeBase &source)
      : ::mlir::OperationPass<mlir::ModuleOp>(source) {}

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-annotate-global-shape");
  }
  /// Name of the concrete pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorAnnotateGlobalShape");
  }

  // The virtual accessors forward to the static ones above so each literal is
  // spelled exactly once.
  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Mark all operations and function arguments with `_global_shape` attribute to be used during SPMD expansion.";
  }

  /// LLVM-style RTTI hook: true iff `candidate` has the TypeID of `DerivedT`.
  static bool classof(const ::mlir::Pass *candidate) {
    return ::mlir::TypeID::get<DerivedT>() == candidate->getTypeID();
  }

  /// Returns a fresh copy of this pass, copy-constructed as `DerivedT`.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Gives the class an explicit TypeID kept internal to the current library,
  /// since pass classes are not meant to be visible outside it.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAnnotateGlobalShapeBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSORANNOTATEGLOBALSHAPE
#endif // GEN_PASS_DEF_DTENSORANNOTATEGLOBALSHAPE
408
409//===----------------------------------------------------------------------===//
410// DTensorClusterFunctionConversion
411//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSORCLUSTERFUNCTIONCONVERSION)
// No declarations are generated for this pass; this section only consumes the
// request macro.
#undef GEN_PASS_DECL_DTENSORCLUSTERFUNCTIONCONVERSION
#endif // GEN_PASS_DECL_DTENSORCLUSTERFUNCTIONCONVERSION
#if defined(GEN_PASS_DEF_DTENSORCLUSTERFUNCTIONCONVERSION)
namespace impl {

/// CRTP base class for the `dtensor-cluster-function-conversion` pass, which
/// operates on `mlir::ModuleOp`. `DerivedT` is the concrete pass: its TypeID
/// identifies the pass and `clonePass()` copy-constructs a new `DerivedT`.
template <typename DerivedT>
class DTensorClusterFunctionConversionBase : public ::mlir::OperationPass<mlir::ModuleOp> {
public:
  using Base = DTensorClusterFunctionConversionBase;

  DTensorClusterFunctionConversionBase()
      : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorClusterFunctionConversionBase(const DTensorClusterFunctionConversionBase &source)
      : ::mlir::OperationPass<mlir::ModuleOp>(source) {}

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-cluster-function-conversion");
  }
  /// Name of the concrete pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorClusterFunctionConversion");
  }

  // The virtual accessors forward to the static ones above so each literal is
  // spelled exactly once.
  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Converts tf_device.cluster_func ops into TF StatefulPartitioned call op with mesh attribute.";
  }

  /// LLVM-style RTTI hook: true iff `candidate` has the TypeID of `DerivedT`.
  static bool classof(const ::mlir::Pass *candidate) {
    return ::mlir::TypeID::get<DerivedT>() == candidate->getTypeID();
  }

  /// Returns a fresh copy of this pass, copy-constructed as `DerivedT`.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Gives the class an explicit TypeID kept internal to the current library,
  /// since pass classes are not meant to be visible outside it.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorClusterFunctionConversionBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSORCLUSTERFUNCTIONCONVERSION
#endif // GEN_PASS_DEF_DTENSORCLUSTERFUNCTIONCONVERSION
466
467//===----------------------------------------------------------------------===//
468// DTensorConstantFolding
469//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSORCONSTANTFOLDING)
// No declarations are generated for this pass; this section only consumes the
// request macro.
#undef GEN_PASS_DECL_DTENSORCONSTANTFOLDING
#endif // GEN_PASS_DECL_DTENSORCONSTANTFOLDING
#if defined(GEN_PASS_DEF_DTENSORCONSTANTFOLDING)
namespace impl {

/// CRTP base class for the `dtensor-constant-folding` pass, which operates on
/// `mlir::func::FuncOp`. `DerivedT` is the concrete pass: its TypeID identifies
/// the pass and `clonePass()` copy-constructs a new `DerivedT`.
template <typename DerivedT>
class DTensorConstantFoldingBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
public:
  using Base = DTensorConstantFoldingBase;

  DTensorConstantFoldingBase()
      : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorConstantFoldingBase(const DTensorConstantFoldingBase &source)
      : ::mlir::OperationPass<mlir::func::FuncOp>(source) {}

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-constant-folding");
  }
  /// Name of the concrete pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorConstantFolding");
  }

  // The virtual accessors forward to the static ones above so each literal is
  // spelled exactly once.
  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  // NOTE(review): the description text comes from the pass's .td definition;
  // any wording fix ("constants" -> "constant") belongs there, not here.
  ::llvm::StringRef getDescription() const override {
    return "Folds constants operations.";
  }

  /// LLVM-style RTTI hook: true iff `candidate` has the TypeID of `DerivedT`.
  static bool classof(const ::mlir::Pass *candidate) {
    return ::mlir::TypeID::get<DerivedT>() == candidate->getTypeID();
  }

  /// Returns a fresh copy of this pass, copy-constructed as `DerivedT`.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Gives the class an explicit TypeID kept internal to the current library,
  /// since pass classes are not meant to be visible outside it.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorConstantFoldingBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSORCONSTANTFOLDING
#endif // GEN_PASS_DEF_DTENSORCONSTANTFOLDING
524
525//===----------------------------------------------------------------------===//
526// DTensorDCE
527//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSORDCE)
// No declarations are generated for this pass; this section only consumes the
// request macro.
#undef GEN_PASS_DECL_DTENSORDCE
#endif // GEN_PASS_DECL_DTENSORDCE
#if defined(GEN_PASS_DEF_DTENSORDCE)
namespace impl {

/// CRTP base class for the `dtensor-dce` pass, which operates on
/// `mlir::func::FuncOp`. `DerivedT` is the concrete pass: its TypeID identifies
/// the pass and `clonePass()` copy-constructs a new `DerivedT`.
template <typename DerivedT>
class DTensorDCEBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
public:
  using Base = DTensorDCEBase;

  DTensorDCEBase()
      : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorDCEBase(const DTensorDCEBase &source)
      : ::mlir::OperationPass<mlir::func::FuncOp>(source) {}

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-dce");
  }
  /// Name of the concrete pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorDCE");
  }

  // The virtual accessors forward to the static ones above so each literal is
  // spelled exactly once.
  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Removes unused ops from graph.";
  }

  /// LLVM-style RTTI hook: true iff `candidate` has the TypeID of `DerivedT`.
  static bool classof(const ::mlir::Pass *candidate) {
    return ::mlir::TypeID::get<DerivedT>() == candidate->getTypeID();
  }

  /// Returns a fresh copy of this pass, copy-constructed as `DerivedT`.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Gives the class an explicit TypeID kept internal to the current library,
  /// since pass classes are not meant to be visible outside it.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorDCEBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSORDCE
#endif // GEN_PASS_DEF_DTENSORDCE
582
583//===----------------------------------------------------------------------===//
584// DTensorDesignateResourceHandleMesh
585//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSORDESIGNATERESOURCEHANDLEMESH)
// No declarations are generated for this pass; this section only consumes the
// request macro.
#undef GEN_PASS_DECL_DTENSORDESIGNATERESOURCEHANDLEMESH
#endif // GEN_PASS_DECL_DTENSORDESIGNATERESOURCEHANDLEMESH
#if defined(GEN_PASS_DEF_DTENSORDESIGNATERESOURCEHANDLEMESH)
namespace impl {

/// CRTP base class for the `dtensor-designate-resource-handle-mesh` pass, which
/// operates on `mlir::func::FuncOp`. `DerivedT` is the concrete pass: its
/// TypeID identifies the pass and `clonePass()` copy-constructs a new
/// `DerivedT`.
template <typename DerivedT>
class DTensorDesignateResourceHandleMeshBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
public:
  using Base = DTensorDesignateResourceHandleMeshBase;

  DTensorDesignateResourceHandleMeshBase()
      : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorDesignateResourceHandleMeshBase(const DTensorDesignateResourceHandleMeshBase &source)
      : ::mlir::OperationPass<mlir::func::FuncOp>(source) {}

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-designate-resource-handle-mesh");
  }
  /// Name of the concrete pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorDesignateResourceHandleMesh");
  }

  // The virtual accessors forward to the static ones above so each literal is
  // spelled exactly once.
  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Sets empty mesh attributes for device cluster that creates or destroys resource handles.";
  }

  /// LLVM-style RTTI hook: true iff `candidate` has the TypeID of `DerivedT`.
  static bool classof(const ::mlir::Pass *candidate) {
    return ::mlir::TypeID::get<DerivedT>() == candidate->getTypeID();
  }

  /// Returns a fresh copy of this pass, copy-constructed as `DerivedT`.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Gives the class an explicit TypeID kept internal to the current library,
  /// since pass classes are not meant to be visible outside it.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorDesignateResourceHandleMeshBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSORDESIGNATERESOURCEHANDLEMESH
#endif // GEN_PASS_DEF_DTENSORDESIGNATERESOURCEHANDLEMESH
640
641//===----------------------------------------------------------------------===//
642// DTensorDeviceMeshClusterCoarsening
643//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSORDEVICEMESHCLUSTERCOARSENING)
// No declarations are generated for this pass; this section only consumes the
// request macro.
#undef GEN_PASS_DECL_DTENSORDEVICEMESHCLUSTERCOARSENING
#endif // GEN_PASS_DECL_DTENSORDEVICEMESHCLUSTERCOARSENING
#if defined(GEN_PASS_DEF_DTENSORDEVICEMESHCLUSTERCOARSENING)
namespace impl {

/// CRTP base class for the `dtensor-device-mesh-cluster-coarsening` pass, which
/// operates on `mlir::func::FuncOp`. `DerivedT` is the concrete pass: its
/// TypeID identifies the pass and `clonePass()` copy-constructs a new
/// `DerivedT`.
template <typename DerivedT>
class DTensorDeviceMeshClusterCoarseningBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
public:
  using Base = DTensorDeviceMeshClusterCoarseningBase;

  DTensorDeviceMeshClusterCoarseningBase()
      : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorDeviceMeshClusterCoarseningBase(const DTensorDeviceMeshClusterCoarseningBase &source)
      : ::mlir::OperationPass<mlir::func::FuncOp>(source) {}

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-device-mesh-cluster-coarsening");
  }
  /// Name of the concrete pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorDeviceMeshClusterCoarsening");
  }

  // The virtual accessors forward to the static ones above so each literal is
  // spelled exactly once.
  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Merges tf_device.cluster op with same mesh attribute.";
  }

  /// LLVM-style RTTI hook: true iff `candidate` has the TypeID of `DerivedT`.
  static bool classof(const ::mlir::Pass *candidate) {
    return ::mlir::TypeID::get<DerivedT>() == candidate->getTypeID();
  }

  /// Returns a fresh copy of this pass, copy-constructed as `DerivedT`.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Gives the class an explicit TypeID kept internal to the current library,
  /// since pass classes are not meant to be visible outside it.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorDeviceMeshClusterCoarseningBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSORDEVICEMESHCLUSTERCOARSENING
#endif // GEN_PASS_DEF_DTENSORDEVICEMESHCLUSTERCOARSENING
698
//===----------------------------------------------------------------------===//
// DTensorEmbedding
//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSOREMBEDDING)
// Nothing is declared for this pass; the guard only consumes the macro.
#undef GEN_PASS_DECL_DTENSOREMBEDDING
#endif // GEN_PASS_DECL_DTENSOREMBEDDING

#if defined(GEN_PASS_DEF_DTENSOREMBEDDING)
namespace impl {

/// Base class for the `dtensor-embedding` pass, which runs on `mlir::ModuleOp`.
/// `DerivedT` is the concrete pass type (CRTP); `classof` and `clonePass` are
/// keyed on it. Generated by mlir-tblgen (see the file header): durable
/// changes belong in the TableGen pass definition, not here.
template <typename DerivedT>
class DTensorEmbeddingBase : public ::mlir::OperationPass<mlir::ModuleOp> {
public:
  using Base = DTensorEmbeddingBase;

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-embedding");
  }

  /// Name of the derived pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorEmbedding");
  }

  /// isa/dyn_cast support: true for any pass whose TypeID is DerivedT's.
  static bool classof(const ::mlir::Pass *pass) {
    return ::mlir::TypeID::get<DerivedT>() == pass->getTypeID();
  }

  DTensorEmbeddingBase()
      : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorEmbeddingBase(const DTensorEmbeddingBase &source)
      : ::mlir::OperationPass<mlir::ModuleOp>(source) {}

  /// Clones the pass by copy-constructing the most-derived type.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Extracts the ApplyEmbeddingOptimizer and checks if the attached optimizer is usable for TPU embeddings.";
  }

  /// Registers the dialects that must be loaded before this pass runs:
  /// CHLO, MemRef, MHLO, Shape, SCF and Tensor (inserted in that order).
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
    registry.insert<::mlir::chlo::ChloDialect, ::mlir::memref::MemRefDialect,
                    ::mlir::mhlo::MhloDialect, ::mlir::shape::ShapeDialect,
                    ::mlir::scf::SCFDialect, ::mlir::tensor::TensorDialect>();
  }

  /// Explicit, inline TypeID definition; pass classes should only be visible
  /// to the current library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorEmbeddingBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSOREMBEDDING
#endif // GEN_PASS_DEF_DTENSOREMBEDDING
768
//===----------------------------------------------------------------------===//
// DTensorEmbeddingCheckpoint
//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSOREMBEDDINGCHECKPOINT)
// Nothing is declared for this pass; the guard only consumes the macro.
#undef GEN_PASS_DECL_DTENSOREMBEDDINGCHECKPOINT
#endif // GEN_PASS_DECL_DTENSOREMBEDDINGCHECKPOINT

#if defined(GEN_PASS_DEF_DTENSOREMBEDDINGCHECKPOINT)
namespace impl {

/// Base class for the `dtensor-embedding-checkpoint` pass, which runs on
/// `mlir::ModuleOp`. `DerivedT` is the concrete pass type (CRTP); `classof`
/// and `clonePass` are keyed on it. Generated by mlir-tblgen (see the file
/// header): durable changes belong in the TableGen pass definition.
template <typename DerivedT>
class DTensorEmbeddingCheckpointBase
    : public ::mlir::OperationPass<mlir::ModuleOp> {
public:
  using Base = DTensorEmbeddingCheckpointBase;

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-embedding-checkpoint");
  }

  /// Name of the derived pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorEmbeddingCheckpoint");
  }

  /// isa/dyn_cast support: true for any pass whose TypeID is DerivedT's.
  static bool classof(const ::mlir::Pass *pass) {
    return ::mlir::TypeID::get<DerivedT>() == pass->getTypeID();
  }

  DTensorEmbeddingCheckpointBase()
      : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorEmbeddingCheckpointBase(const DTensorEmbeddingCheckpointBase &source)
      : ::mlir::OperationPass<mlir::ModuleOp>(source) {}

  /// Clones the pass by copy-constructing the most-derived type.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Checkpoint pass to save or load TPU embeddings variables.";
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Explicit, inline TypeID definition; pass classes should only be visible
  /// to the current library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorEmbeddingCheckpointBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSOREMBEDDINGCHECKPOINT
#endif // GEN_PASS_DEF_DTENSOREMBEDDINGCHECKPOINT
826
//===----------------------------------------------------------------------===//
// DTensorEmbeddingV2
//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSOREMBEDDINGV2)
// Nothing is declared for this pass; the guard only consumes the macro.
#undef GEN_PASS_DECL_DTENSOREMBEDDINGV2
#endif // GEN_PASS_DECL_DTENSOREMBEDDINGV2

#if defined(GEN_PASS_DEF_DTENSOREMBEDDINGV2)
namespace impl {

/// Base class for the `dtensor-embedding-v2` pass, which runs on
/// `mlir::ModuleOp`. `DerivedT` is the concrete pass type (CRTP); `classof`
/// and `clonePass` are keyed on it. Generated by mlir-tblgen (see the file
/// header): durable changes belong in the TableGen pass definition.
/// NOTE(review): the registered description is only "Embedding Pass"; a more
/// specific summary would have to be added in the TableGen source.
template <typename DerivedT>
class DTensorEmbeddingV2Base : public ::mlir::OperationPass<mlir::ModuleOp> {
public:
  using Base = DTensorEmbeddingV2Base;

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-embedding-v2");
  }

  /// Name of the derived pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorEmbeddingV2");
  }

  /// isa/dyn_cast support: true for any pass whose TypeID is DerivedT's.
  static bool classof(const ::mlir::Pass *pass) {
    return ::mlir::TypeID::get<DerivedT>() == pass->getTypeID();
  }

  DTensorEmbeddingV2Base()
      : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorEmbeddingV2Base(const DTensorEmbeddingV2Base &source)
      : ::mlir::OperationPass<mlir::ModuleOp>(source) {}

  /// Clones the pass by copy-constructing the most-derived type.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override { return "Embedding Pass"; }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Explicit, inline TypeID definition; pass classes should only be visible
  /// to the current library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorEmbeddingV2Base<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSOREMBEDDINGV2
#endif // GEN_PASS_DEF_DTENSOREMBEDDINGV2
884
//===----------------------------------------------------------------------===//
// DTensorFunctionRenaming
//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSORFUNCTIONRENAMING)
// Nothing is declared for this pass; the guard only consumes the macro.
#undef GEN_PASS_DECL_DTENSORFUNCTIONRENAMING
#endif // GEN_PASS_DECL_DTENSORFUNCTIONRENAMING

#if defined(GEN_PASS_DEF_DTENSORFUNCTIONRENAMING)
namespace impl {

/// Base class for the `dtensor-function-renaming` pass, which runs on
/// `mlir::ModuleOp`. `DerivedT` is the concrete pass type (CRTP); `classof`
/// and `clonePass` are keyed on it. Generated by mlir-tblgen (see the file
/// header): durable changes belong in the TableGen pass definition.
template <typename DerivedT>
class DTensorFunctionRenamingBase
    : public ::mlir::OperationPass<mlir::ModuleOp> {
public:
  using Base = DTensorFunctionRenamingBase;

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-function-renaming");
  }

  /// Name of the derived pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorFunctionRenaming");
  }

  /// isa/dyn_cast support: true for any pass whose TypeID is DerivedT's.
  static bool classof(const ::mlir::Pass *pass) {
    return ::mlir::TypeID::get<DerivedT>() == pass->getTypeID();
  }

  DTensorFunctionRenamingBase()
      : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorFunctionRenamingBase(const DTensorFunctionRenamingBase &source)
      : ::mlir::OperationPass<mlir::ModuleOp>(source) {}

  /// Clones the pass by copy-constructing the most-derived type.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Renames private functions by appending an id to each name. This is used to make private function names unique across modules.";
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Explicit, inline TypeID definition; pass classes should only be visible
  /// to the current library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorFunctionRenamingBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSORFUNCTIONRENAMING
#endif // GEN_PASS_DEF_DTENSORFUNCTIONRENAMING
942
//===----------------------------------------------------------------------===//
// DTensorHandleCrossClusterDependencies
//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSORHANDLECROSSCLUSTERDEPENDENCIES)
// Nothing is declared for this pass; the guard only consumes the macro.
#undef GEN_PASS_DECL_DTENSORHANDLECROSSCLUSTERDEPENDENCIES
#endif // GEN_PASS_DECL_DTENSORHANDLECROSSCLUSTERDEPENDENCIES

#if defined(GEN_PASS_DEF_DTENSORHANDLECROSSCLUSTERDEPENDENCIES)
namespace impl {

/// Base class for the `dtensor-handle_cross_cluster_dependences` pass, which
/// runs on `mlir::ModuleOp`. `DerivedT` is the concrete pass type (CRTP);
/// `classof` and `clonePass` are keyed on it. Generated by mlir-tblgen (see the
/// file header): the description typo fixed here must also be fixed in the
/// TableGen pass definition, or regeneration will bring it back.
template <typename DerivedT>
class DTensorHandleCrossClusterDependenciesBase
    : public ::mlir::OperationPass<mlir::ModuleOp> {
public:
  using Base = DTensorHandleCrossClusterDependenciesBase;

  /// Command-line flag that selects this pass.
  /// NOTE: the flag keeps its historical spelling (underscores, "dependences").
  /// It is the pass's external interface and is presumably referenced verbatim
  /// by pipelines and tests, so do not "fix" it here.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-handle_cross_cluster_dependences");
  }

  /// Name of the derived pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorHandleCrossClusterDependencies");
  }

  /// isa/dyn_cast support: true for any pass whose TypeID is DerivedT's.
  static bool classof(const ::mlir::Pass *pass) {
    return ::mlir::TypeID::get<DerivedT>() == pass->getTypeID();
  }

  DTensorHandleCrossClusterDependenciesBase()
      : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorHandleCrossClusterDependenciesBase(
      const DTensorHandleCrossClusterDependenciesBase &source)
      : ::mlir::OperationPass<mlir::ModuleOp>(source) {}

  /// Clones the pass by copy-constructing the most-derived type.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  /// User-visible help text (typo "depedences" corrected to "dependencies").
  ::llvm::StringRef getDescription() const override {
    return "Lowers cross mesh cluster data dependencies as send/recvs.";
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Explicit, inline TypeID definition; pass classes should only be visible
  /// to the current library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorHandleCrossClusterDependenciesBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSORHANDLECROSSCLUSTERDEPENDENCIES
#endif // GEN_PASS_DEF_DTENSORHANDLECROSSCLUSTERDEPENDENCIES
1000
//===----------------------------------------------------------------------===//
// DTensorInferShapesForRestoreV2Op
//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSORINFERSHAPESFORRESTOREV2OP)
// Nothing is declared for this pass; the guard only consumes the macro.
#undef GEN_PASS_DECL_DTENSORINFERSHAPESFORRESTOREV2OP
#endif // GEN_PASS_DECL_DTENSORINFERSHAPESFORRESTOREV2OP

#if defined(GEN_PASS_DEF_DTENSORINFERSHAPESFORRESTOREV2OP)
namespace impl {

/// Base class for the `dtensor-infer-shapes-for-restorev2-op` pass, which runs
/// on `mlir::ModuleOp`. `DerivedT` is the concrete pass type (CRTP); `classof`
/// and `clonePass` are keyed on it. Generated by mlir-tblgen (see the file
/// header): durable changes belong in the TableGen pass definition.
template <typename DerivedT>
class DTensorInferShapesForRestoreV2OpBase
    : public ::mlir::OperationPass<mlir::ModuleOp> {
public:
  using Base = DTensorInferShapesForRestoreV2OpBase;

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-infer-shapes-for-restorev2-op");
  }

  /// Name of the derived pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorInferShapesForRestoreV2Op");
  }

  /// isa/dyn_cast support: true for any pass whose TypeID is DerivedT's.
  static bool classof(const ::mlir::Pass *pass) {
    return ::mlir::TypeID::get<DerivedT>() == pass->getTypeID();
  }

  DTensorInferShapesForRestoreV2OpBase()
      : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorInferShapesForRestoreV2OpBase(
      const DTensorInferShapesForRestoreV2OpBase &source)
      : ::mlir::OperationPass<mlir::ModuleOp>(source) {}

  /// Clones the pass by copy-constructing the most-derived type.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Infer shapes of the outputs of tf.RestoreV2Op from the AssignVariableOps that consume those outputs. This is used for DTensor integration with TF Checkpoint.";
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Explicit, inline TypeID definition; pass classes should only be visible
  /// to the current library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorInferShapesForRestoreV2OpBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSORINFERSHAPESFORRESTOREV2OP
#endif // GEN_PASS_DEF_DTENSORINFERSHAPESFORRESTOREV2OP
1058
//===----------------------------------------------------------------------===//
// DTensorLayoutPropagationV2
//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSORLAYOUTPROPAGATIONV2)
// Nothing is declared for this pass; the guard only consumes the macro.
#undef GEN_PASS_DECL_DTENSORLAYOUTPROPAGATIONV2
#endif // GEN_PASS_DECL_DTENSORLAYOUTPROPAGATIONV2

#if defined(GEN_PASS_DEF_DTENSORLAYOUTPROPAGATIONV2)
namespace impl {

/// Base class for the `dtensor-layout-propagation-v2` pass, which runs on
/// `mlir::ModuleOp`. `DerivedT` is the concrete pass type (CRTP); `classof`
/// and `clonePass` are keyed on it. Generated by mlir-tblgen (see the file
/// header): durable changes belong in the TableGen pass definition.
template <typename DerivedT>
class DTensorLayoutPropagationV2Base
    : public ::mlir::OperationPass<mlir::ModuleOp> {
public:
  using Base = DTensorLayoutPropagationV2Base;

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-layout-propagation-v2");
  }

  /// Name of the derived pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorLayoutPropagationV2");
  }

  /// isa/dyn_cast support: true for any pass whose TypeID is DerivedT's.
  static bool classof(const ::mlir::Pass *pass) {
    return ::mlir::TypeID::get<DerivedT>() == pass->getTypeID();
  }

  DTensorLayoutPropagationV2Base()
      : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorLayoutPropagationV2Base(const DTensorLayoutPropagationV2Base &source)
      : ::mlir::OperationPass<mlir::ModuleOp>(source) {}

  /// Clones the pass by copy-constructing the most-derived type.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Propagates layout information for all the TF Ops.";
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Explicit, inline TypeID definition; pass classes should only be visible
  /// to the current library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorLayoutPropagationV2Base<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSORLAYOUTPROPAGATIONV2
#endif // GEN_PASS_DEF_DTENSORLAYOUTPROPAGATIONV2
1116
//===----------------------------------------------------------------------===//
// DTensorLowerSendRecv
//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSORLOWERSENDRECV)
// Nothing is declared for this pass; the guard only consumes the macro.
#undef GEN_PASS_DECL_DTENSORLOWERSENDRECV
#endif // GEN_PASS_DECL_DTENSORLOWERSENDRECV

#if defined(GEN_PASS_DEF_DTENSORLOWERSENDRECV)
namespace impl {

/// Base class for the `dtensor-lower-send-recv` pass, which runs on
/// `mlir::ModuleOp`. `DerivedT` is the concrete pass type (CRTP); `classof`
/// and `clonePass` are keyed on it. Generated by mlir-tblgen (see the file
/// header): durable changes belong in the TableGen pass definition.
template <typename DerivedT>
class DTensorLowerSendRecvBase : public ::mlir::OperationPass<mlir::ModuleOp> {
public:
  using Base = DTensorLowerSendRecvBase;

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-lower-send-recv");
  }

  /// Name of the derived pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorLowerSendRecv");
  }

  /// isa/dyn_cast support: true for any pass whose TypeID is DerivedT's.
  static bool classof(const ::mlir::Pass *pass) {
    return ::mlir::TypeID::get<DerivedT>() == pass->getTypeID();
  }

  DTensorLowerSendRecvBase()
      : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorLowerSendRecvBase(const DTensorLowerSendRecvBase &source)
      : ::mlir::OperationPass<mlir::ModuleOp>(source) {}

  /// Clones the pass by copy-constructing the most-derived type.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Lowers DTensorSend/DTensorRecv ops to send/recv ops.";
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Explicit, inline TypeID definition; pass classes should only be visible
  /// to the current library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorLowerSendRecvBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSORLOWERSENDRECV
#endif // GEN_PASS_DEF_DTENSORLOWERSENDRECV
1174
//===----------------------------------------------------------------------===//
// DTensorMergeClusters
//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSORMERGECLUSTERS)
// Nothing is declared for this pass; the guard only consumes the macro.
#undef GEN_PASS_DECL_DTENSORMERGECLUSTERS
#endif // GEN_PASS_DECL_DTENSORMERGECLUSTERS

#if defined(GEN_PASS_DEF_DTENSORMERGECLUSTERS)
namespace impl {

/// Base class for the `dtensor-merge-clusters` pass, which runs on
/// `mlir::ModuleOp`. `DerivedT` is the concrete pass type (CRTP); `classof`
/// and `clonePass` are keyed on it. Generated by mlir-tblgen (see the file
/// header): durable changes belong in the TableGen pass definition.
template <typename DerivedT>
class DTensorMergeClustersBase : public ::mlir::OperationPass<mlir::ModuleOp> {
public:
  using Base = DTensorMergeClustersBase;

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-merge-clusters");
  }

  /// Name of the derived pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorMergeClusters");
  }

  /// isa/dyn_cast support: true for any pass whose TypeID is DerivedT's.
  static bool classof(const ::mlir::Pass *pass) {
    return ::mlir::TypeID::get<DerivedT>() == pass->getTypeID();
  }

  DTensorMergeClustersBase()
      : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorMergeClustersBase(const DTensorMergeClustersBase &source)
      : ::mlir::OperationPass<mlir::ModuleOp>(source) {}

  /// Clones the pass by copy-constructing the most-derived type.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Merges tf_device.Clusters ops with same mesh specification.";
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Explicit, inline TypeID definition; pass classes should only be visible
  /// to the current library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorMergeClustersBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSORMERGECLUSTERS
#endif // GEN_PASS_DEF_DTENSORMERGECLUSTERS
1232
//===----------------------------------------------------------------------===//
// DTensorMeshPropagation
//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSORMESHPROPAGATION)
// Nothing is declared for this pass; the guard only consumes the macro.
#undef GEN_PASS_DECL_DTENSORMESHPROPAGATION
#endif // GEN_PASS_DECL_DTENSORMESHPROPAGATION

#if defined(GEN_PASS_DEF_DTENSORMESHPROPAGATION)
namespace impl {

/// Base class for the `dtensor-mesh-propagation` pass, which runs on
/// `mlir::ModuleOp`. `DerivedT` is the concrete pass type (CRTP); `classof`
/// and `clonePass` are keyed on it. Generated by mlir-tblgen (see the file
/// header): durable changes belong in the TableGen pass definition.
template <typename DerivedT>
class DTensorMeshPropagationBase : public ::mlir::OperationPass<mlir::ModuleOp> {
public:
  using Base = DTensorMeshPropagationBase;

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-mesh-propagation");
  }

  /// Name of the derived pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorMeshPropagation");
  }

  /// isa/dyn_cast support: true for any pass whose TypeID is DerivedT's.
  static bool classof(const ::mlir::Pass *pass) {
    return ::mlir::TypeID::get<DerivedT>() == pass->getTypeID();
  }

  DTensorMeshPropagationBase()
      : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorMeshPropagationBase(const DTensorMeshPropagationBase &source)
      : ::mlir::OperationPass<mlir::ModuleOp>(source) {}

  /// Clones the pass by copy-constructing the most-derived type.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Propagates mesh information to all clusters.";
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Explicit, inline TypeID definition; pass classes should only be visible
  /// to the current library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorMeshPropagationBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSORMESHPROPAGATION
#endif // GEN_PASS_DEF_DTENSORMESHPROPAGATION
1290
//===----------------------------------------------------------------------===//
// DTensorMixedPrecisionReduce
//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSORMIXEDPRECISIONREDUCE)
// Nothing is declared for this pass; the guard only consumes the macro.
#undef GEN_PASS_DECL_DTENSORMIXEDPRECISIONREDUCE
#endif // GEN_PASS_DECL_DTENSORMIXEDPRECISIONREDUCE

#if defined(GEN_PASS_DEF_DTENSORMIXEDPRECISIONREDUCE)
namespace impl {

/// Base class for the `dtensor-mixed-precision-reduce` pass, which runs on
/// `mlir::func::FuncOp` (unlike the ModuleOp-anchored passes around it).
/// `DerivedT` is the concrete pass type (CRTP); `classof` and `clonePass` are
/// keyed on it. Generated by mlir-tblgen (see the file header): durable
/// changes belong in the TableGen pass definition.
template <typename DerivedT>
class DTensorMixedPrecisionReduceBase
    : public ::mlir::OperationPass<mlir::func::FuncOp> {
public:
  using Base = DTensorMixedPrecisionReduceBase;

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-mixed-precision-reduce");
  }

  /// Name of the derived pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorMixedPrecisionReduce");
  }

  /// isa/dyn_cast support: true for any pass whose TypeID is DerivedT's.
  static bool classof(const ::mlir::Pass *pass) {
    return ::mlir::TypeID::get<DerivedT>() == pass->getTypeID();
  }

  DTensorMixedPrecisionReduceBase()
      : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorMixedPrecisionReduceBase(const DTensorMixedPrecisionReduceBase &source)
      : ::mlir::OperationPass<mlir::func::FuncOp>(source) {}

  /// Clones the pass by copy-constructing the most-derived type.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Upcast tensors to higher precision type for reduction ops.";
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Explicit, inline TypeID definition; pass classes should only be visible
  /// to the current library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorMixedPrecisionReduceBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSORMIXEDPRECISIONREDUCE
#endif // GEN_PASS_DEF_DTENSORMIXEDPRECISIONREDUCE
1348
//===----------------------------------------------------------------------===//
// DTensorMoveCompilationToHost
//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSORMOVECOMPILATIONTOHOST)
// Nothing is declared for this pass; the guard only consumes the macro.
#undef GEN_PASS_DECL_DTENSORMOVECOMPILATIONTOHOST
#endif // GEN_PASS_DECL_DTENSORMOVECOMPILATIONTOHOST

#if defined(GEN_PASS_DEF_DTENSORMOVECOMPILATIONTOHOST)
namespace impl {

/// Base class for the `dtensor-move-compilation-to-host` pass, which runs on
/// `mlir::ModuleOp`. `DerivedT` is the concrete pass type (CRTP); `classof`
/// and `clonePass` are keyed on it. Generated by mlir-tblgen (see the file
/// header): durable changes belong in the TableGen pass definition.
template <typename DerivedT>
class DTensorMoveCompilationToHostBase
    : public ::mlir::OperationPass<mlir::ModuleOp> {
public:
  using Base = DTensorMoveCompilationToHostBase;

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-move-compilation-to-host");
  }

  /// Name of the derived pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorMoveCompilationToHost");
  }

  /// isa/dyn_cast support: true for any pass whose TypeID is DerivedT's.
  static bool classof(const ::mlir::Pass *pass) {
    return ::mlir::TypeID::get<DerivedT>() == pass->getTypeID();
  }

  DTensorMoveCompilationToHostBase()
      : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorMoveCompilationToHostBase(const DTensorMoveCompilationToHostBase &source)
      : ::mlir::OperationPass<mlir::ModuleOp>(source) {}

  /// Clones the pass by copy-constructing the most-derived type.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Moves XLA compilation ops to host computation.";
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Explicit, inline TypeID definition; pass classes should only be visible
  /// to the current library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorMoveCompilationToHostBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSORMOVECOMPILATIONTOHOST
#endif // GEN_PASS_DEF_DTENSORMOVECOMPILATIONTOHOST
1406
//===----------------------------------------------------------------------===//
// DTensorOpToDeviceCluster
//===----------------------------------------------------------------------===//
#if defined(GEN_PASS_DECL_DTENSOROPTODEVICECLUSTER)
// Nothing is declared for this pass; the guard only consumes the macro.
#undef GEN_PASS_DECL_DTENSOROPTODEVICECLUSTER
#endif // GEN_PASS_DECL_DTENSOROPTODEVICECLUSTER

#if defined(GEN_PASS_DEF_DTENSOROPTODEVICECLUSTER)
namespace impl {

/// Base class for the `dtensor-op-to-device-cluster` pass, which runs on
/// `mlir::func::FuncOp`. `DerivedT` is the concrete pass type (CRTP);
/// `classof` and `clonePass` are keyed on it. Generated by mlir-tblgen (see
/// the file header): durable changes belong in the TableGen pass definition.
template <typename DerivedT>
class DTensorOpToDeviceClusterBase
    : public ::mlir::OperationPass<mlir::func::FuncOp> {
public:
  using Base = DTensorOpToDeviceClusterBase;

  /// Command-line flag that selects this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("dtensor-op-to-device-cluster");
  }

  /// Name of the derived pass.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("DTensorOpToDeviceCluster");
  }

  /// isa/dyn_cast support: true for any pass whose TypeID is DerivedT's.
  static bool classof(const ::mlir::Pass *pass) {
    return ::mlir::TypeID::get<DerivedT>() == pass->getTypeID();
  }

  DTensorOpToDeviceClusterBase()
      : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
  DTensorOpToDeviceClusterBase(const DTensorOpToDeviceClusterBase &source)
      : ::mlir::OperationPass<mlir::func::FuncOp>(source) {}

  /// Clones the pass by copy-constructing the most-derived type.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    const DerivedT &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

  ::llvm::StringRef getArgument() const override { return getArgumentName(); }
  ::llvm::StringRef getName() const override { return getPassName(); }
  ::llvm::StringRef getDescription() const override {
    return "Creates and wraps tf_device.cluster op for all TF ops";
  }

  /// This pass declares no dialects that must be loaded before it runs.
  void getDependentDialects(::mlir::DialectRegistry & /*registry*/) const override {}

  /// Explicit, inline TypeID definition; pass classes should only be visible
  /// to the current library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorOpToDeviceClusterBase<DerivedT>)
};

} // namespace impl
#undef GEN_PASS_DEF_DTENSOROPTODEVICECLUSTER
#endif // GEN_PASS_DEF_DTENSOROPTODEVICECLUSTER
1464
1465//===----------------------------------------------------------------------===//
1466// DTensorPropagateDefaultLayout
1467//===----------------------------------------------------------------------===//
1468#ifdef GEN_PASS_DECL_DTENSORPROPAGATEDEFAULTLAYOUT
1469#undef GEN_PASS_DECL_DTENSORPROPAGATEDEFAULTLAYOUT
1470#endif // GEN_PASS_DECL_DTENSORPROPAGATEDEFAULTLAYOUT
1471#ifdef GEN_PASS_DEF_DTENSORPROPAGATEDEFAULTLAYOUT
1472namespace impl {
1473
1474template <typename DerivedT>
1475class DTensorPropagateDefaultLayoutBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
1476public:
1477 using Base = DTensorPropagateDefaultLayoutBase;
1478
1479 DTensorPropagateDefaultLayoutBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
1480 DTensorPropagateDefaultLayoutBase(const DTensorPropagateDefaultLayoutBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {}
1481
1482 /// Returns the command-line argument attached to this pass.
1483 static constexpr ::llvm::StringLiteral getArgumentName() {
1484 return ::llvm::StringLiteral("dtensor-propagate-default-layout");
1485 }
1486 ::llvm::StringRef getArgument() const override { return "dtensor-propagate-default-layout"; }
1487
1488 ::llvm::StringRef getDescription() const override { return "Converts layout attributes added by end users to DTensorLayout op."; }
1489
1490 /// Returns the derived pass name.
1491 static constexpr ::llvm::StringLiteral getPassName() {
1492 return ::llvm::StringLiteral("DTensorPropagateDefaultLayout");
1493 }
1494 ::llvm::StringRef getName() const override { return "DTensorPropagateDefaultLayout"; }
1495
1496 /// Support isa/dyn_cast functionality for the derived pass class.
1497 static bool classof(const ::mlir::Pass *pass) {
1498 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
1499 }
1500
1501 /// A clone method to create a copy of this pass.
1502 std::unique_ptr<::mlir::Pass> clonePass() const override {
1503 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
1504 }
1505
1506 /// Return the dialect that must be loaded in the context before this pass.
1507 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
1508
1509 }
1510
1511 /// Explicitly declare the TypeID for this class. We declare an explicit private
1512 /// instantiation because Pass classes should only be visible by the current
1513 /// library.
1514 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorPropagateDefaultLayoutBase<DerivedT>)
1515
1516protected:
1517private:
1518};
1519} // namespace impl
1520#undef GEN_PASS_DEF_DTENSORPROPAGATEDEFAULTLAYOUT
1521#endif // GEN_PASS_DEF_DTENSORPROPAGATEDEFAULTLAYOUT
1522
1523//===----------------------------------------------------------------------===//
1524// DTensorPropagateDeviceIdToFunctionArgs
1525//===----------------------------------------------------------------------===//
1526#ifdef GEN_PASS_DECL_DTENSORPROPAGATEDEVICEIDTOFUNCTIONARGS
1527#undef GEN_PASS_DECL_DTENSORPROPAGATEDEVICEIDTOFUNCTIONARGS
1528#endif // GEN_PASS_DECL_DTENSORPROPAGATEDEVICEIDTOFUNCTIONARGS
1529#ifdef GEN_PASS_DEF_DTENSORPROPAGATEDEVICEIDTOFUNCTIONARGS
1530namespace impl {
1531
1532template <typename DerivedT>
1533class DTensorPropagateDeviceIdToFunctionArgsBase : public ::mlir::OperationPass<mlir::ModuleOp> {
1534public:
1535 using Base = DTensorPropagateDeviceIdToFunctionArgsBase;
1536
1537 DTensorPropagateDeviceIdToFunctionArgsBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
1538 DTensorPropagateDeviceIdToFunctionArgsBase(const DTensorPropagateDeviceIdToFunctionArgsBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
1539
1540 /// Returns the command-line argument attached to this pass.
1541 static constexpr ::llvm::StringLiteral getArgumentName() {
1542 return ::llvm::StringLiteral("dtensor-propagate-device-id-to-function-args");
1543 }
1544 ::llvm::StringRef getArgument() const override { return "dtensor-propagate-device-id-to-function-args"; }
1545
1546 ::llvm::StringRef getDescription() const override { return "Adds device id as arguments to all private function in graph."; }
1547
1548 /// Returns the derived pass name.
1549 static constexpr ::llvm::StringLiteral getPassName() {
1550 return ::llvm::StringLiteral("DTensorPropagateDeviceIdToFunctionArgs");
1551 }
1552 ::llvm::StringRef getName() const override { return "DTensorPropagateDeviceIdToFunctionArgs"; }
1553
1554 /// Support isa/dyn_cast functionality for the derived pass class.
1555 static bool classof(const ::mlir::Pass *pass) {
1556 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
1557 }
1558
1559 /// A clone method to create a copy of this pass.
1560 std::unique_ptr<::mlir::Pass> clonePass() const override {
1561 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
1562 }
1563
1564 /// Return the dialect that must be loaded in the context before this pass.
1565 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
1566
1567 }
1568
1569 /// Explicitly declare the TypeID for this class. We declare an explicit private
1570 /// instantiation because Pass classes should only be visible by the current
1571 /// library.
1572 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorPropagateDeviceIdToFunctionArgsBase<DerivedT>)
1573
1574protected:
1575private:
1576};
1577} // namespace impl
1578#undef GEN_PASS_DEF_DTENSORPROPAGATEDEVICEIDTOFUNCTIONARGS
1579#endif // GEN_PASS_DEF_DTENSORPROPAGATEDEVICEIDTOFUNCTIONARGS
1580
1581//===----------------------------------------------------------------------===//
1582// DTensorReduceScatterLowering
1583//===----------------------------------------------------------------------===//
1584#ifdef GEN_PASS_DECL_DTENSORREDUCESCATTERLOWERING
1585#undef GEN_PASS_DECL_DTENSORREDUCESCATTERLOWERING
1586#endif // GEN_PASS_DECL_DTENSORREDUCESCATTERLOWERING
1587#ifdef GEN_PASS_DEF_DTENSORREDUCESCATTERLOWERING
1588namespace impl {
1589
1590template <typename DerivedT>
1591class DTensorReduceScatterLoweringBase : public ::mlir::OperationPass<mlir::ModuleOp> {
1592public:
1593 using Base = DTensorReduceScatterLoweringBase;
1594
1595 DTensorReduceScatterLoweringBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
1596 DTensorReduceScatterLoweringBase(const DTensorReduceScatterLoweringBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
1597
1598 /// Returns the command-line argument attached to this pass.
1599 static constexpr ::llvm::StringLiteral getArgumentName() {
1600 return ::llvm::StringLiteral("dtensor-reduce-scatter-lowering");
1601 }
1602 ::llvm::StringRef getArgument() const override { return "dtensor-reduce-scatter-lowering"; }
1603
1604 ::llvm::StringRef getDescription() const override { return "Converts logical ReduceScatter ops into physical ReduceScatter ops."; }
1605
1606 /// Returns the derived pass name.
1607 static constexpr ::llvm::StringLiteral getPassName() {
1608 return ::llvm::StringLiteral("DTensorReduceScatterLowering");
1609 }
1610 ::llvm::StringRef getName() const override { return "DTensorReduceScatterLowering"; }
1611
1612 /// Support isa/dyn_cast functionality for the derived pass class.
1613 static bool classof(const ::mlir::Pass *pass) {
1614 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
1615 }
1616
1617 /// A clone method to create a copy of this pass.
1618 std::unique_ptr<::mlir::Pass> clonePass() const override {
1619 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
1620 }
1621
1622 /// Return the dialect that must be loaded in the context before this pass.
1623 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
1624
1625 }
1626
1627 /// Explicitly declare the TypeID for this class. We declare an explicit private
1628 /// instantiation because Pass classes should only be visible by the current
1629 /// library.
1630 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorReduceScatterLoweringBase<DerivedT>)
1631
1632protected:
1633private:
1634};
1635} // namespace impl
1636#undef GEN_PASS_DEF_DTENSORREDUCESCATTERLOWERING
1637#endif // GEN_PASS_DEF_DTENSORREDUCESCATTERLOWERING
1638
1639//===----------------------------------------------------------------------===//
1640// DTensorSPMDExpansion
1641//===----------------------------------------------------------------------===//
1642#ifdef GEN_PASS_DECL_DTENSORSPMDEXPANSION
1643#undef GEN_PASS_DECL_DTENSORSPMDEXPANSION
1644#endif // GEN_PASS_DECL_DTENSORSPMDEXPANSION
1645#ifdef GEN_PASS_DEF_DTENSORSPMDEXPANSION
1646namespace impl {
1647
1648template <typename DerivedT>
1649class DTensorSPMDExpansionBase : public ::mlir::OperationPass<mlir::ModuleOp> {
1650public:
1651 using Base = DTensorSPMDExpansionBase;
1652
1653 DTensorSPMDExpansionBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
1654 DTensorSPMDExpansionBase(const DTensorSPMDExpansionBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
1655
1656 /// Returns the command-line argument attached to this pass.
1657 static constexpr ::llvm::StringLiteral getArgumentName() {
1658 return ::llvm::StringLiteral("dtensor-spmd-expansion");
1659 }
1660 ::llvm::StringRef getArgument() const override { return "dtensor-spmd-expansion"; }
1661
1662 ::llvm::StringRef getDescription() const override { return "Converts ops into SPMD expanded form."; }
1663
1664 /// Returns the derived pass name.
1665 static constexpr ::llvm::StringLiteral getPassName() {
1666 return ::llvm::StringLiteral("DTensorSPMDExpansion");
1667 }
1668 ::llvm::StringRef getName() const override { return "DTensorSPMDExpansion"; }
1669
1670 /// Support isa/dyn_cast functionality for the derived pass class.
1671 static bool classof(const ::mlir::Pass *pass) {
1672 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
1673 }
1674
1675 /// A clone method to create a copy of this pass.
1676 std::unique_ptr<::mlir::Pass> clonePass() const override {
1677 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
1678 }
1679
1680 /// Return the dialect that must be loaded in the context before this pass.
1681 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
1682
1683 }
1684
1685 /// Explicitly declare the TypeID for this class. We declare an explicit private
1686 /// instantiation because Pass classes should only be visible by the current
1687 /// library.
1688 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorSPMDExpansionBase<DerivedT>)
1689
1690protected:
1691private:
1692};
1693} // namespace impl
1694#undef GEN_PASS_DEF_DTENSORSPMDEXPANSION
1695#endif // GEN_PASS_DEF_DTENSORSPMDEXPANSION
1696
1697//===----------------------------------------------------------------------===//
1698// DTensorSetDefaultSharding
1699//===----------------------------------------------------------------------===//
1700#ifdef GEN_PASS_DECL_DTENSORSETDEFAULTSHARDING
1701#undef GEN_PASS_DECL_DTENSORSETDEFAULTSHARDING
1702#endif // GEN_PASS_DECL_DTENSORSETDEFAULTSHARDING
1703#ifdef GEN_PASS_DEF_DTENSORSETDEFAULTSHARDING
1704namespace impl {
1705
1706template <typename DerivedT>
1707class DTensorSetDefaultShardingBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
1708public:
1709 using Base = DTensorSetDefaultShardingBase;
1710
1711 DTensorSetDefaultShardingBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
1712 DTensorSetDefaultShardingBase(const DTensorSetDefaultShardingBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {}
1713
1714 /// Returns the command-line argument attached to this pass.
1715 static constexpr ::llvm::StringLiteral getArgumentName() {
1716 return ::llvm::StringLiteral("dtensor-set-default-sharding");
1717 }
1718 ::llvm::StringRef getArgument() const override { return "dtensor-set-default-sharding"; }
1719
1720 ::llvm::StringRef getDescription() const override { return "Sets default sharding of TPU computation inputs/outputs to maximal."; }
1721
1722 /// Returns the derived pass name.
1723 static constexpr ::llvm::StringLiteral getPassName() {
1724 return ::llvm::StringLiteral("DTensorSetDefaultSharding");
1725 }
1726 ::llvm::StringRef getName() const override { return "DTensorSetDefaultSharding"; }
1727
1728 /// Support isa/dyn_cast functionality for the derived pass class.
1729 static bool classof(const ::mlir::Pass *pass) {
1730 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
1731 }
1732
1733 /// A clone method to create a copy of this pass.
1734 std::unique_ptr<::mlir::Pass> clonePass() const override {
1735 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
1736 }
1737
1738 /// Return the dialect that must be loaded in the context before this pass.
1739 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
1740
1741 }
1742
1743 /// Explicitly declare the TypeID for this class. We declare an explicit private
1744 /// instantiation because Pass classes should only be visible by the current
1745 /// library.
1746 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorSetDefaultShardingBase<DerivedT>)
1747
1748protected:
1749private:
1750};
1751} // namespace impl
1752#undef GEN_PASS_DEF_DTENSORSETDEFAULTSHARDING
1753#endif // GEN_PASS_DEF_DTENSORSETDEFAULTSHARDING
1754
1755//===----------------------------------------------------------------------===//
1756// DTensorSparseExpansion
1757//===----------------------------------------------------------------------===//
1758#ifdef GEN_PASS_DECL_DTENSORSPARSEEXPANSION
1759#undef GEN_PASS_DECL_DTENSORSPARSEEXPANSION
1760#endif // GEN_PASS_DECL_DTENSORSPARSEEXPANSION
1761#ifdef GEN_PASS_DEF_DTENSORSPARSEEXPANSION
1762namespace impl {
1763
1764template <typename DerivedT>
1765class DTensorSparseExpansionBase : public ::mlir::OperationPass<mlir::ModuleOp> {
1766public:
1767 using Base = DTensorSparseExpansionBase;
1768
1769 DTensorSparseExpansionBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
1770 DTensorSparseExpansionBase(const DTensorSparseExpansionBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
1771
1772 /// Returns the command-line argument attached to this pass.
1773 static constexpr ::llvm::StringLiteral getArgumentName() {
1774 return ::llvm::StringLiteral("dtensor-sparse-expansion");
1775 }
1776 ::llvm::StringRef getArgument() const override { return "dtensor-sparse-expansion"; }
1777
1778 ::llvm::StringRef getDescription() const override { return "Convert ops that take in SparseTensor input to its corresponding Sparse or Dense ops."; }
1779
1780 /// Returns the derived pass name.
1781 static constexpr ::llvm::StringLiteral getPassName() {
1782 return ::llvm::StringLiteral("DTensorSparseExpansion");
1783 }
1784 ::llvm::StringRef getName() const override { return "DTensorSparseExpansion"; }
1785
1786 /// Support isa/dyn_cast functionality for the derived pass class.
1787 static bool classof(const ::mlir::Pass *pass) {
1788 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
1789 }
1790
1791 /// A clone method to create a copy of this pass.
1792 std::unique_ptr<::mlir::Pass> clonePass() const override {
1793 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
1794 }
1795
1796 /// Return the dialect that must be loaded in the context before this pass.
1797 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
1798
1799 }
1800
1801 /// Explicitly declare the TypeID for this class. We declare an explicit private
1802 /// instantiation because Pass classes should only be visible by the current
1803 /// library.
1804 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorSparseExpansionBase<DerivedT>)
1805
1806protected:
1807private:
1808};
1809} // namespace impl
1810#undef GEN_PASS_DEF_DTENSORSPARSEEXPANSION
1811#endif // GEN_PASS_DEF_DTENSORSPARSEEXPANSION
1812
1813//===----------------------------------------------------------------------===//
1814// DTensorSparseTensorToDenseTensor
1815//===----------------------------------------------------------------------===//
1816#ifdef GEN_PASS_DECL_DTENSORSPARSETENSORTODENSETENSOR
1817#undef GEN_PASS_DECL_DTENSORSPARSETENSORTODENSETENSOR
1818#endif // GEN_PASS_DECL_DTENSORSPARSETENSORTODENSETENSOR
1819#ifdef GEN_PASS_DEF_DTENSORSPARSETENSORTODENSETENSOR
1820namespace impl {
1821
1822template <typename DerivedT>
1823class DTensorSparseTensorToDenseTensorBase : public ::mlir::OperationPass<mlir::ModuleOp> {
1824public:
1825 using Base = DTensorSparseTensorToDenseTensorBase;
1826
1827 DTensorSparseTensorToDenseTensorBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
1828 DTensorSparseTensorToDenseTensorBase(const DTensorSparseTensorToDenseTensorBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
1829
1830 /// Returns the command-line argument attached to this pass.
1831 static constexpr ::llvm::StringLiteral getArgumentName() {
1832 return ::llvm::StringLiteral("dtensor-sparse-tensor-to-dense-tensor");
1833 }
1834 ::llvm::StringRef getArgument() const override { return "dtensor-sparse-tensor-to-dense-tensor"; }
1835
1836 ::llvm::StringRef getDescription() const override { return "Converts SparseTensor inputs to its component tensors inputs and emits a SparseToDenseOp for every op that consumes a SparseTensor."; }
1837
1838 /// Returns the derived pass name.
1839 static constexpr ::llvm::StringLiteral getPassName() {
1840 return ::llvm::StringLiteral("DTensorSparseTensorToDenseTensor");
1841 }
1842 ::llvm::StringRef getName() const override { return "DTensorSparseTensorToDenseTensor"; }
1843
1844 /// Support isa/dyn_cast functionality for the derived pass class.
1845 static bool classof(const ::mlir::Pass *pass) {
1846 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
1847 }
1848
1849 /// A clone method to create a copy of this pass.
1850 std::unique_ptr<::mlir::Pass> clonePass() const override {
1851 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
1852 }
1853
1854 /// Return the dialect that must be loaded in the context before this pass.
1855 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
1856
1857 }
1858
1859 /// Explicitly declare the TypeID for this class. We declare an explicit private
1860 /// instantiation because Pass classes should only be visible by the current
1861 /// library.
1862 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorSparseTensorToDenseTensorBase<DerivedT>)
1863
1864protected:
1865private:
1866};
1867} // namespace impl
1868#undef GEN_PASS_DEF_DTENSORSPARSETENSORTODENSETENSOR
1869#endif // GEN_PASS_DEF_DTENSORSPARSETENSORTODENSETENSOR
1870
1871//===----------------------------------------------------------------------===//
1872// DTensorTPUIntegration
1873//===----------------------------------------------------------------------===//
1874#ifdef GEN_PASS_DECL_DTENSORTPUINTEGRATION
1875#undef GEN_PASS_DECL_DTENSORTPUINTEGRATION
1876#endif // GEN_PASS_DECL_DTENSORTPUINTEGRATION
1877#ifdef GEN_PASS_DEF_DTENSORTPUINTEGRATION
1878namespace impl {
1879
1880template <typename DerivedT>
1881class DTensorTPUIntegrationBase : public ::mlir::OperationPass<mlir::ModuleOp> {
1882public:
1883 using Base = DTensorTPUIntegrationBase;
1884
1885 DTensorTPUIntegrationBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
1886 DTensorTPUIntegrationBase(const DTensorTPUIntegrationBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
1887
1888 /// Returns the command-line argument attached to this pass.
1889 static constexpr ::llvm::StringLiteral getArgumentName() {
1890 return ::llvm::StringLiteral("dtensor-tpu-integration");
1891 }
1892 ::llvm::StringRef getArgument() const override { return "dtensor-tpu-integration"; }
1893
1894 ::llvm::StringRef getDescription() const override { return "Adds TPUReplicateMetadata and converts ops that run on TPU's to a single tf_device.cluster to be compatible with following TF2XLA MLIR passes."; }
1895
1896 /// Returns the derived pass name.
1897 static constexpr ::llvm::StringLiteral getPassName() {
1898 return ::llvm::StringLiteral("DTensorTPUIntegration");
1899 }
1900 ::llvm::StringRef getName() const override { return "DTensorTPUIntegration"; }
1901
1902 /// Support isa/dyn_cast functionality for the derived pass class.
1903 static bool classof(const ::mlir::Pass *pass) {
1904 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
1905 }
1906
1907 /// A clone method to create a copy of this pass.
1908 std::unique_ptr<::mlir::Pass> clonePass() const override {
1909 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
1910 }
1911
1912 /// Return the dialect that must be loaded in the context before this pass.
1913 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
1914
1915 }
1916
1917 /// Explicitly declare the TypeID for this class. We declare an explicit private
1918 /// instantiation because Pass classes should only be visible by the current
1919 /// library.
1920 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorTPUIntegrationBase<DerivedT>)
1921
1922protected:
1923private:
1924};
1925} // namespace impl
1926#undef GEN_PASS_DEF_DTENSORTPUINTEGRATION
1927#endif // GEN_PASS_DEF_DTENSORTPUINTEGRATION
1928
1929//===----------------------------------------------------------------------===//
1930// DTensorTpuAddResourceDeviceAttribute
1931//===----------------------------------------------------------------------===//
1932#ifdef GEN_PASS_DECL_DTENSORTPUADDRESOURCEDEVICEATTRIBUTE
1933#undef GEN_PASS_DECL_DTENSORTPUADDRESOURCEDEVICEATTRIBUTE
1934#endif // GEN_PASS_DECL_DTENSORTPUADDRESOURCEDEVICEATTRIBUTE
1935#ifdef GEN_PASS_DEF_DTENSORTPUADDRESOURCEDEVICEATTRIBUTE
1936namespace impl {
1937
1938template <typename DerivedT>
1939class DTensorTpuAddResourceDeviceAttributeBase : public ::mlir::OperationPass<mlir::ModuleOp> {
1940public:
1941 using Base = DTensorTpuAddResourceDeviceAttributeBase;
1942
1943 DTensorTpuAddResourceDeviceAttributeBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
1944 DTensorTpuAddResourceDeviceAttributeBase(const DTensorTpuAddResourceDeviceAttributeBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
1945
1946 /// Returns the command-line argument attached to this pass.
1947 static constexpr ::llvm::StringLiteral getArgumentName() {
1948 return ::llvm::StringLiteral("dtensor-tpu-add-resource-device-attribute");
1949 }
1950 ::llvm::StringRef getArgument() const override { return "dtensor-tpu-add-resource-device-attribute"; }
1951
1952 ::llvm::StringRef getDescription() const override { return "Adds placeholder device attributes to resources accessed by TPU computation to enable buffer aliasing."; }
1953
1954 /// Returns the derived pass name.
1955 static constexpr ::llvm::StringLiteral getPassName() {
1956 return ::llvm::StringLiteral("DTensorTpuAddResourceDeviceAttribute");
1957 }
1958 ::llvm::StringRef getName() const override { return "DTensorTpuAddResourceDeviceAttribute"; }
1959
1960 /// Support isa/dyn_cast functionality for the derived pass class.
1961 static bool classof(const ::mlir::Pass *pass) {
1962 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
1963 }
1964
1965 /// A clone method to create a copy of this pass.
1966 std::unique_ptr<::mlir::Pass> clonePass() const override {
1967 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
1968 }
1969
1970 /// Return the dialect that must be loaded in the context before this pass.
1971 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
1972
1973 }
1974
1975 /// Explicitly declare the TypeID for this class. We declare an explicit private
1976 /// instantiation because Pass classes should only be visible by the current
1977 /// library.
1978 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorTpuAddResourceDeviceAttributeBase<DerivedT>)
1979
1980protected:
1981private:
1982};
1983} // namespace impl
1984#undef GEN_PASS_DEF_DTENSORTPUADDRESOURCEDEVICEATTRIBUTE
1985#endif // GEN_PASS_DEF_DTENSORTPUADDRESOURCEDEVICEATTRIBUTE
1986
1987//===----------------------------------------------------------------------===//
1988// DTensorUndoMergeConstAcrossMesh
1989//===----------------------------------------------------------------------===//
1990#ifdef GEN_PASS_DECL_DTENSORUNDOMERGECONSTACROSSMESH
1991#undef GEN_PASS_DECL_DTENSORUNDOMERGECONSTACROSSMESH
1992#endif // GEN_PASS_DECL_DTENSORUNDOMERGECONSTACROSSMESH
1993#ifdef GEN_PASS_DEF_DTENSORUNDOMERGECONSTACROSSMESH
1994namespace impl {
1995
1996template <typename DerivedT>
1997class DTensorUndoMergeConstAcrossMeshBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
1998public:
1999 using Base = DTensorUndoMergeConstAcrossMeshBase;
2000
2001 DTensorUndoMergeConstAcrossMeshBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
2002 DTensorUndoMergeConstAcrossMeshBase(const DTensorUndoMergeConstAcrossMeshBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {}
2003
2004 /// Returns the command-line argument attached to this pass.
2005 static constexpr ::llvm::StringLiteral getArgumentName() {
2006 return ::llvm::StringLiteral("dtensor-undo-merge-const-across-mesh");
2007 }
2008 ::llvm::StringRef getArgument() const override { return "dtensor-undo-merge-const-across-mesh"; }
2009
2010 ::llvm::StringRef getDescription() const override { return "Undo constant merging across meshes"; }
2011
2012 /// Returns the derived pass name.
2013 static constexpr ::llvm::StringLiteral getPassName() {
2014 return ::llvm::StringLiteral("DTensorUndoMergeConstAcrossMesh");
2015 }
2016 ::llvm::StringRef getName() const override { return "DTensorUndoMergeConstAcrossMesh"; }
2017
2018 /// Support isa/dyn_cast functionality for the derived pass class.
2019 static bool classof(const ::mlir::Pass *pass) {
2020 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
2021 }
2022
2023 /// A clone method to create a copy of this pass.
2024 std::unique_ptr<::mlir::Pass> clonePass() const override {
2025 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
2026 }
2027
2028 /// Return the dialect that must be loaded in the context before this pass.
2029 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
2030
2031 }
2032
2033 /// Explicitly declare the TypeID for this class. We declare an explicit private
2034 /// instantiation because Pass classes should only be visible by the current
2035 /// library.
2036 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorUndoMergeConstAcrossMeshBase<DerivedT>)
2037
2038protected:
2039private:
2040};
2041} // namespace impl
2042#undef GEN_PASS_DEF_DTENSORUNDOMERGECONSTACROSSMESH
2043#endif // GEN_PASS_DEF_DTENSORUNDOMERGECONSTACROSSMESH
2044
2045//===----------------------------------------------------------------------===//
2046// DTensorUpdateTPUMetadata
2047//===----------------------------------------------------------------------===//
2048#ifdef GEN_PASS_DECL_DTENSORUPDATETPUMETADATA
2049#undef GEN_PASS_DECL_DTENSORUPDATETPUMETADATA
2050#endif // GEN_PASS_DECL_DTENSORUPDATETPUMETADATA
2051#ifdef GEN_PASS_DEF_DTENSORUPDATETPUMETADATA
2052namespace impl {
2053
2054template <typename DerivedT>
2055class DTensorUpdateTPUMetadataBase : public ::mlir::OperationPass<mlir::ModuleOp> {
2056public:
2057 using Base = DTensorUpdateTPUMetadataBase;
2058
2059 DTensorUpdateTPUMetadataBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
2060 DTensorUpdateTPUMetadataBase(const DTensorUpdateTPUMetadataBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
2061
2062 /// Returns the command-line argument attached to this pass.
2063 static constexpr ::llvm::StringLiteral getArgumentName() {
2064 return ::llvm::StringLiteral("dtensor-update-tpu-metadata");
2065 }
2066 ::llvm::StringRef getArgument() const override { return "dtensor-update-tpu-metadata"; }
2067
2068 ::llvm::StringRef getDescription() const override { return "Changes metadata on TPU specific ops such as device placement."; }
2069
2070 /// Returns the derived pass name.
2071 static constexpr ::llvm::StringLiteral getPassName() {
2072 return ::llvm::StringLiteral("DTensorUpdateTPUMetadata");
2073 }
2074 ::llvm::StringRef getName() const override { return "DTensorUpdateTPUMetadata"; }
2075
2076 /// Support isa/dyn_cast functionality for the derived pass class.
2077 static bool classof(const ::mlir::Pass *pass) {
2078 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
2079 }
2080
2081 /// A clone method to create a copy of this pass.
2082 std::unique_ptr<::mlir::Pass> clonePass() const override {
2083 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
2084 }
2085
2086 /// Return the dialect that must be loaded in the context before this pass.
2087 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
2088
2089 }
2090
2091 /// Explicitly declare the TypeID for this class. We declare an explicit private
2092 /// instantiation because Pass classes should only be visible by the current
2093 /// library.
2094 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorUpdateTPUMetadataBase<DerivedT>)
2095
2096protected:
2097private:
2098};
2099} // namespace impl
2100#undef GEN_PASS_DEF_DTENSORUPDATETPUMETADATA
2101#endif // GEN_PASS_DEF_DTENSORUPDATETPUMETADATA
2102#ifdef GEN_PASS_REGISTRATION
2103
2104//===----------------------------------------------------------------------===//
2105// DTensorAllGatherLowering Registration
2106//===----------------------------------------------------------------------===//
2107
2108inline void registerDTensorAllGatherLowering() {
2109 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2110 return CreateDTensorAllGatherLoweringPass();
2111 });
2112}
2113
2114// Old registration code, kept for temporary backwards compatibility.
2115inline void registerDTensorAllGatherLoweringPass() {
2116 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2117 return CreateDTensorAllGatherLoweringPass();
2118 });
2119}
2120
2121//===----------------------------------------------------------------------===//
2122// DTensorAllReduceCombineOptimization Registration
2123//===----------------------------------------------------------------------===//
2124
2125inline void registerDTensorAllReduceCombineOptimization() {
2126 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2127 return CreateDTensorAllReduceCombineOptimization();
2128 });
2129}
2130
2131// Old registration code, kept for temporary backwards compatibility.
2132inline void registerDTensorAllReduceCombineOptimizationPass() {
2133 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2134 return CreateDTensorAllReduceCombineOptimization();
2135 });
2136}
2137
2138//===----------------------------------------------------------------------===//
2139// DTensorAllReduceLowering Registration
2140//===----------------------------------------------------------------------===//
2141
2142inline void registerDTensorAllReduceLowering() {
2143 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2144 return CreateDTensorAllReduceLoweringPass();
2145 });
2146}
2147
2148// Old registration code, kept for temporary backwards compatibility.
2149inline void registerDTensorAllReduceLoweringPass() {
2150 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2151 return CreateDTensorAllReduceLoweringPass();
2152 });
2153}
2154
2155//===----------------------------------------------------------------------===//
2156// DTensorAllReduceScatterOptimization Registration
2157//===----------------------------------------------------------------------===//
2158
2159inline void registerDTensorAllReduceScatterOptimization() {
2160 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2161 return CreateDTensorAllReduceScatterOptimization();
2162 });
2163}
2164
2165// Old registration code, kept for temporary backwards compatibility.
2166inline void registerDTensorAllReduceScatterOptimizationPass() {
2167 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2168 return CreateDTensorAllReduceScatterOptimization();
2169 });
2170}
2171
2172//===----------------------------------------------------------------------===//
2173// DTensorAllReduceSumOptimization Registration
2174//===----------------------------------------------------------------------===//
2175
2176inline void registerDTensorAllReduceSumOptimization() {
2177 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2178 return CreateDTensorAllReduceSumOptimization();
2179 });
2180}
2181
2182// Old registration code, kept for temporary backwards compatibility.
2183inline void registerDTensorAllReduceSumOptimizationPass() {
2184 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2185 return CreateDTensorAllReduceSumOptimization();
2186 });
2187}
2188
2189//===----------------------------------------------------------------------===//
2190// DTensorAllScatterLowering Registration
2191//===----------------------------------------------------------------------===//
2192
2193inline void registerDTensorAllScatterLowering() {
2194 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2195 return CreateDTensorAllScatterLoweringPass();
2196 });
2197}
2198
2199// Old registration code, kept for temporary backwards compatibility.
2200inline void registerDTensorAllScatterLoweringPass() {
2201 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2202 return CreateDTensorAllScatterLoweringPass();
2203 });
2204}
2205
2206//===----------------------------------------------------------------------===//
2207// DTensorAnnotateGlobalShape Registration
2208//===----------------------------------------------------------------------===//
2209
2210inline void registerDTensorAnnotateGlobalShape() {
2211 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2212 return CreateDTensorAnnotateGlobalShape();
2213 });
2214}
2215
2216// Old registration code, kept for temporary backwards compatibility.
2217inline void registerDTensorAnnotateGlobalShapePass() {
2218 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2219 return CreateDTensorAnnotateGlobalShape();
2220 });
2221}
2222
2223//===----------------------------------------------------------------------===//
2224// DTensorClusterFunctionConversion Registration
2225//===----------------------------------------------------------------------===//
2226
2227inline void registerDTensorClusterFunctionConversion() {
2228 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2229 return CreateDTensorClusterFunctionConversion();
2230 });
2231}
2232
2233// Old registration code, kept for temporary backwards compatibility.
2234inline void registerDTensorClusterFunctionConversionPass() {
2235 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2236 return CreateDTensorClusterFunctionConversion();
2237 });
2238}
2239
2240//===----------------------------------------------------------------------===//
2241// DTensorConstantFolding Registration
2242//===----------------------------------------------------------------------===//
2243
2244inline void registerDTensorConstantFolding() {
2245 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2246 return CreateDTensorConstantFolding();
2247 });
2248}
2249
2250// Old registration code, kept for temporary backwards compatibility.
2251inline void registerDTensorConstantFoldingPass() {
2252 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2253 return CreateDTensorConstantFolding();
2254 });
2255}
2256
2257//===----------------------------------------------------------------------===//
2258// DTensorDCE Registration
2259//===----------------------------------------------------------------------===//
2260
2261inline void registerDTensorDCE() {
2262 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2263 return CreateDTensorDCE();
2264 });
2265}
2266
2267// Old registration code, kept for temporary backwards compatibility.
2268inline void registerDTensorDCEPass() {
2269 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2270 return CreateDTensorDCE();
2271 });
2272}
2273
2274//===----------------------------------------------------------------------===//
2275// DTensorDesignateResourceHandleMesh Registration
2276//===----------------------------------------------------------------------===//
2277
2278inline void registerDTensorDesignateResourceHandleMesh() {
2279 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2280 return CreateDTensorDesignateResourceHandleMesh();
2281 });
2282}
2283
2284// Old registration code, kept for temporary backwards compatibility.
2285inline void registerDTensorDesignateResourceHandleMeshPass() {
2286 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2287 return CreateDTensorDesignateResourceHandleMesh();
2288 });
2289}
2290
2291//===----------------------------------------------------------------------===//
2292// DTensorDeviceMeshClusterCoarsening Registration
2293//===----------------------------------------------------------------------===//
2294
2295inline void registerDTensorDeviceMeshClusterCoarsening() {
2296 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2297 return CreateDTensorDeviceMeshClusterCoarsening();
2298 });
2299}
2300
2301// Old registration code, kept for temporary backwards compatibility.
2302inline void registerDTensorDeviceMeshClusterCoarseningPass() {
2303 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2304 return CreateDTensorDeviceMeshClusterCoarsening();
2305 });
2306}
2307
2308//===----------------------------------------------------------------------===//
2309// DTensorEmbedding Registration
2310//===----------------------------------------------------------------------===//
2311
2312inline void registerDTensorEmbedding() {
2313 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2314 return CreateDTensorEmbeddingPass();
2315 });
2316}
2317
2318// Old registration code, kept for temporary backwards compatibility.
2319inline void registerDTensorEmbeddingPass() {
2320 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2321 return CreateDTensorEmbeddingPass();
2322 });
2323}
2324
2325//===----------------------------------------------------------------------===//
2326// DTensorEmbeddingCheckpoint Registration
2327//===----------------------------------------------------------------------===//
2328
2329inline void registerDTensorEmbeddingCheckpoint() {
2330 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2331 return CreateDTensorEmbeddingCheckpointPass();
2332 });
2333}
2334
2335// Old registration code, kept for temporary backwards compatibility.
2336inline void registerDTensorEmbeddingCheckpointPass() {
2337 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2338 return CreateDTensorEmbeddingCheckpointPass();
2339 });
2340}
2341
2342//===----------------------------------------------------------------------===//
2343// DTensorEmbeddingV2 Registration
2344//===----------------------------------------------------------------------===//
2345
2346inline void registerDTensorEmbeddingV2() {
2347 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2348 return CreateDTensorEmbeddingPassV2();
2349 });
2350}
2351
2352// Old registration code, kept for temporary backwards compatibility.
2353inline void registerDTensorEmbeddingV2Pass() {
2354 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2355 return CreateDTensorEmbeddingPassV2();
2356 });
2357}
2358
2359//===----------------------------------------------------------------------===//
2360// DTensorFunctionRenaming Registration
2361//===----------------------------------------------------------------------===//
2362
2363inline void registerDTensorFunctionRenaming() {
2364 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2365 return CreateFunctionRenamingPass();
2366 });
2367}
2368
2369// Old registration code, kept for temporary backwards compatibility.
2370inline void registerDTensorFunctionRenamingPass() {
2371 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2372 return CreateFunctionRenamingPass();
2373 });
2374}
2375
2376//===----------------------------------------------------------------------===//
2377// DTensorHandleCrossClusterDependencies Registration
2378//===----------------------------------------------------------------------===//
2379
2380inline void registerDTensorHandleCrossClusterDependencies() {
2381 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2382 return CreateDTensorHandleCrossClusterDependencies();
2383 });
2384}
2385
2386// Old registration code, kept for temporary backwards compatibility.
2387inline void registerDTensorHandleCrossClusterDependenciesPass() {
2388 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2389 return CreateDTensorHandleCrossClusterDependencies();
2390 });
2391}
2392
2393//===----------------------------------------------------------------------===//
2394// DTensorInferShapesForRestoreV2Op Registration
2395//===----------------------------------------------------------------------===//
2396
2397inline void registerDTensorInferShapesForRestoreV2Op() {
2398 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2399 return CreateDTensorInferShapesForRestoreV2Op();
2400 });
2401}
2402
2403// Old registration code, kept for temporary backwards compatibility.
2404inline void registerDTensorInferShapesForRestoreV2OpPass() {
2405 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2406 return CreateDTensorInferShapesForRestoreV2Op();
2407 });
2408}
2409
2410//===----------------------------------------------------------------------===//
2411// DTensorLayoutPropagationV2 Registration
2412//===----------------------------------------------------------------------===//
2413
2414inline void registerDTensorLayoutPropagationV2() {
2415 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2416 return CreateDTensorLayoutPropagationPassV2();
2417 });
2418}
2419
2420// Old registration code, kept for temporary backwards compatibility.
2421inline void registerDTensorLayoutPropagationV2Pass() {
2422 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2423 return CreateDTensorLayoutPropagationPassV2();
2424 });
2425}
2426
2427//===----------------------------------------------------------------------===//
2428// DTensorLowerSendRecv Registration
2429//===----------------------------------------------------------------------===//
2430
2431inline void registerDTensorLowerSendRecv() {
2432 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2433 return CreateDTensorLowerSendRecv();
2434 });
2435}
2436
2437// Old registration code, kept for temporary backwards compatibility.
2438inline void registerDTensorLowerSendRecvPass() {
2439 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2440 return CreateDTensorLowerSendRecv();
2441 });
2442}
2443
2444//===----------------------------------------------------------------------===//
2445// DTensorMergeClusters Registration
2446//===----------------------------------------------------------------------===//
2447
2448inline void registerDTensorMergeClusters() {
2449 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2450 return CreateDTensorMergeClustersPass();
2451 });
2452}
2453
2454// Old registration code, kept for temporary backwards compatibility.
2455inline void registerDTensorMergeClustersPass() {
2456 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2457 return CreateDTensorMergeClustersPass();
2458 });
2459}
2460
2461//===----------------------------------------------------------------------===//
2462// DTensorMeshPropagation Registration
2463//===----------------------------------------------------------------------===//
2464
2465inline void registerDTensorMeshPropagation() {
2466 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2467 return CreateDTensorMeshPropagationPass();
2468 });
2469}
2470
2471// Old registration code, kept for temporary backwards compatibility.
2472inline void registerDTensorMeshPropagationPass() {
2473 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2474 return CreateDTensorMeshPropagationPass();
2475 });
2476}
2477
2478//===----------------------------------------------------------------------===//
2479// DTensorMixedPrecisionReduce Registration
2480//===----------------------------------------------------------------------===//
2481
2482inline void registerDTensorMixedPrecisionReduce() {
2483 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2484 return CreateDTensorMixedPrecisionReducePass();
2485 });
2486}
2487
2488// Old registration code, kept for temporary backwards compatibility.
2489inline void registerDTensorMixedPrecisionReducePass() {
2490 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2491 return CreateDTensorMixedPrecisionReducePass();
2492 });
2493}
2494
2495//===----------------------------------------------------------------------===//
2496// DTensorMoveCompilationToHost Registration
2497//===----------------------------------------------------------------------===//
2498
2499inline void registerDTensorMoveCompilationToHost() {
2500 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2501 return CreateDTensorMoveCompilationToHost();
2502 });
2503}
2504
2505// Old registration code, kept for temporary backwards compatibility.
2506inline void registerDTensorMoveCompilationToHostPass() {
2507 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2508 return CreateDTensorMoveCompilationToHost();
2509 });
2510}
2511
2512//===----------------------------------------------------------------------===//
2513// DTensorOpToDeviceCluster Registration
2514//===----------------------------------------------------------------------===//
2515
2516inline void registerDTensorOpToDeviceCluster() {
2517 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2518 return CreateDTensorOpToDeviceClusterPass();
2519 });
2520}
2521
2522// Old registration code, kept for temporary backwards compatibility.
2523inline void registerDTensorOpToDeviceClusterPass() {
2524 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2525 return CreateDTensorOpToDeviceClusterPass();
2526 });
2527}
2528
2529//===----------------------------------------------------------------------===//
2530// DTensorPropagateDefaultLayout Registration
2531//===----------------------------------------------------------------------===//
2532
2533inline void registerDTensorPropagateDefaultLayout() {
2534 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2535 return CreateDTensorPropagateDefaultLayout();
2536 });
2537}
2538
2539// Old registration code, kept for temporary backwards compatibility.
2540inline void registerDTensorPropagateDefaultLayoutPass() {
2541 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2542 return CreateDTensorPropagateDefaultLayout();
2543 });
2544}
2545
2546//===----------------------------------------------------------------------===//
2547// DTensorPropagateDeviceIdToFunctionArgs Registration
2548//===----------------------------------------------------------------------===//
2549
2550inline void registerDTensorPropagateDeviceIdToFunctionArgs() {
2551 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2552 return CreateDTensorPropagateDeviceIdToFunctionArgs();
2553 });
2554}
2555
2556// Old registration code, kept for temporary backwards compatibility.
2557inline void registerDTensorPropagateDeviceIdToFunctionArgsPass() {
2558 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2559 return CreateDTensorPropagateDeviceIdToFunctionArgs();
2560 });
2561}
2562
2563//===----------------------------------------------------------------------===//
2564// DTensorReduceScatterLowering Registration
2565//===----------------------------------------------------------------------===//
2566
2567inline void registerDTensorReduceScatterLowering() {
2568 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2569 return CreateDTensorReduceScatterLoweringPass();
2570 });
2571}
2572
2573// Old registration code, kept for temporary backwards compatibility.
2574inline void registerDTensorReduceScatterLoweringPass() {
2575 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2576 return CreateDTensorReduceScatterLoweringPass();
2577 });
2578}
2579
2580//===----------------------------------------------------------------------===//
2581// DTensorSPMDExpansion Registration
2582//===----------------------------------------------------------------------===//
2583
2584inline void registerDTensorSPMDExpansion() {
2585 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2586 return CreateDTensorSPMDExpansion();
2587 });
2588}
2589
2590// Old registration code, kept for temporary backwards compatibility.
2591inline void registerDTensorSPMDExpansionPass() {
2592 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2593 return CreateDTensorSPMDExpansion();
2594 });
2595}
2596
2597//===----------------------------------------------------------------------===//
2598// DTensorSetDefaultSharding Registration
2599//===----------------------------------------------------------------------===//
2600
2601inline void registerDTensorSetDefaultSharding() {
2602 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2603 return CreateDTensorSetDefaultSharding();
2604 });
2605}
2606
2607// Old registration code, kept for temporary backwards compatibility.
2608inline void registerDTensorSetDefaultShardingPass() {
2609 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2610 return CreateDTensorSetDefaultSharding();
2611 });
2612}
2613
2614//===----------------------------------------------------------------------===//
2615// DTensorSparseExpansion Registration
2616//===----------------------------------------------------------------------===//
2617
2618inline void registerDTensorSparseExpansion() {
2619 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2620 return CreateDTensorSparseExpansion();
2621 });
2622}
2623
2624// Old registration code, kept for temporary backwards compatibility.
2625inline void registerDTensorSparseExpansionPass() {
2626 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2627 return CreateDTensorSparseExpansion();
2628 });
2629}
2630
2631//===----------------------------------------------------------------------===//
2632// DTensorSparseTensorToDenseTensor Registration
2633//===----------------------------------------------------------------------===//
2634
2635inline void registerDTensorSparseTensorToDenseTensor() {
2636 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2637 return CreateDTensorSparseTensorToDenseTensor();
2638 });
2639}
2640
2641// Old registration code, kept for temporary backwards compatibility.
2642inline void registerDTensorSparseTensorToDenseTensorPass() {
2643 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2644 return CreateDTensorSparseTensorToDenseTensor();
2645 });
2646}
2647
2648//===----------------------------------------------------------------------===//
2649// DTensorTPUIntegration Registration
2650//===----------------------------------------------------------------------===//
2651
2652inline void registerDTensorTPUIntegration() {
2653 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2654 return CreateDTensorTPUIntegration();
2655 });
2656}
2657
2658// Old registration code, kept for temporary backwards compatibility.
2659inline void registerDTensorTPUIntegrationPass() {
2660 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2661 return CreateDTensorTPUIntegration();
2662 });
2663}
2664
2665//===----------------------------------------------------------------------===//
2666// DTensorTpuAddResourceDeviceAttribute Registration
2667//===----------------------------------------------------------------------===//
2668
2669inline void registerDTensorTpuAddResourceDeviceAttribute() {
2670 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2671 return CreateDTensorTpuAddResourceDeviceAttribute();
2672 });
2673}
2674
2675// Old registration code, kept for temporary backwards compatibility.
2676inline void registerDTensorTpuAddResourceDeviceAttributePass() {
2677 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2678 return CreateDTensorTpuAddResourceDeviceAttribute();
2679 });
2680}
2681
2682//===----------------------------------------------------------------------===//
2683// DTensorUndoMergeConstAcrossMesh Registration
2684//===----------------------------------------------------------------------===//
2685
2686inline void registerDTensorUndoMergeConstAcrossMesh() {
2687 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2688 return CreateDTensorUndoMergeConstAcrossMesh();
2689 });
2690}
2691
2692// Old registration code, kept for temporary backwards compatibility.
2693inline void registerDTensorUndoMergeConstAcrossMeshPass() {
2694 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2695 return CreateDTensorUndoMergeConstAcrossMesh();
2696 });
2697}
2698
2699//===----------------------------------------------------------------------===//
2700// DTensorUpdateTPUMetadata Registration
2701//===----------------------------------------------------------------------===//
2702
2703inline void registerDTensorUpdateTPUMetadata() {
2704 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2705 return CreateDTensorUpdateTPUMetadata();
2706 });
2707}
2708
2709// Old registration code, kept for temporary backwards compatibility.
2710inline void registerDTensorUpdateTPUMetadataPass() {
2711 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
2712 return CreateDTensorUpdateTPUMetadata();
2713 });
2714}
2715
2716//===----------------------------------------------------------------------===//
2717// DTensor Registration
2718//===----------------------------------------------------------------------===//
2719
2720inline void registerDTensorPasses() {
2721 registerDTensorAllGatherLowering();
2722 registerDTensorAllReduceCombineOptimization();
2723 registerDTensorAllReduceLowering();
2724 registerDTensorAllReduceScatterOptimization();
2725 registerDTensorAllReduceSumOptimization();
2726 registerDTensorAllScatterLowering();
2727 registerDTensorAnnotateGlobalShape();
2728 registerDTensorClusterFunctionConversion();
2729 registerDTensorConstantFolding();
2730 registerDTensorDCE();
2731 registerDTensorDesignateResourceHandleMesh();
2732 registerDTensorDeviceMeshClusterCoarsening();
2733 registerDTensorEmbedding();
2734 registerDTensorEmbeddingCheckpoint();
2735 registerDTensorEmbeddingV2();
2736 registerDTensorFunctionRenaming();
2737 registerDTensorHandleCrossClusterDependencies();
2738 registerDTensorInferShapesForRestoreV2Op();
2739 registerDTensorLayoutPropagationV2();
2740 registerDTensorLowerSendRecv();
2741 registerDTensorMergeClusters();
2742 registerDTensorMeshPropagation();
2743 registerDTensorMixedPrecisionReduce();
2744 registerDTensorMoveCompilationToHost();
2745 registerDTensorOpToDeviceCluster();
2746 registerDTensorPropagateDefaultLayout();
2747 registerDTensorPropagateDeviceIdToFunctionArgs();
2748 registerDTensorReduceScatterLowering();
2749 registerDTensorSPMDExpansion();
2750 registerDTensorSetDefaultSharding();
2751 registerDTensorSparseExpansion();
2752 registerDTensorSparseTensorToDenseTensor();
2753 registerDTensorTPUIntegration();
2754 registerDTensorTpuAddResourceDeviceAttribute();
2755 registerDTensorUndoMergeConstAcrossMesh();
2756 registerDTensorUpdateTPUMetadata();
2757}
2758#undef GEN_PASS_REGISTRATION
2759#endif // GEN_PASS_REGISTRATION
2760// Deprecated. Please use the new per-pass macros.
2761#ifdef GEN_PASS_CLASSES
2762
2763template <typename DerivedT>
2764class DTensorAllGatherLoweringBase : public ::mlir::OperationPass<mlir::ModuleOp> {
2765public:
2766 using Base = DTensorAllGatherLoweringBase;
2767
2768 DTensorAllGatherLoweringBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
2769 DTensorAllGatherLoweringBase(const DTensorAllGatherLoweringBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
2770
2771 /// Returns the command-line argument attached to this pass.
2772 static constexpr ::llvm::StringLiteral getArgumentName() {
2773 return ::llvm::StringLiteral("dtensor-all-gather-lowering");
2774 }
2775 ::llvm::StringRef getArgument() const override { return "dtensor-all-gather-lowering"; }
2776
2777 ::llvm::StringRef getDescription() const override { return "Converts logical AllGather ops into physical AllGather ops."; }
2778
2779 /// Returns the derived pass name.
2780 static constexpr ::llvm::StringLiteral getPassName() {
2781 return ::llvm::StringLiteral("DTensorAllGatherLowering");
2782 }
2783 ::llvm::StringRef getName() const override { return "DTensorAllGatherLowering"; }
2784
2785 /// Support isa/dyn_cast functionality for the derived pass class.
2786 static bool classof(const ::mlir::Pass *pass) {
2787 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
2788 }
2789
2790 /// A clone method to create a copy of this pass.
2791 std::unique_ptr<::mlir::Pass> clonePass() const override {
2792 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
2793 }
2794
2795 /// Return the dialect that must be loaded in the context before this pass.
2796 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
2797
2798 }
2799
2800 /// Explicitly declare the TypeID for this class. We declare an explicit private
2801 /// instantiation because Pass classes should only be visible by the current
2802 /// library.
2803 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllGatherLoweringBase<DerivedT>)
2804
2805protected:
2806};
2807
2808template <typename DerivedT>
2809class DTensorAllReduceCombineOptimizationBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
2810public:
2811 using Base = DTensorAllReduceCombineOptimizationBase;
2812
2813 DTensorAllReduceCombineOptimizationBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
2814 DTensorAllReduceCombineOptimizationBase(const DTensorAllReduceCombineOptimizationBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {}
2815
2816 /// Returns the command-line argument attached to this pass.
2817 static constexpr ::llvm::StringLiteral getArgumentName() {
2818 return ::llvm::StringLiteral("dtensor-allreduce-combine-optimization");
2819 }
2820 ::llvm::StringRef getArgument() const override { return "dtensor-allreduce-combine-optimization"; }
2821
2822 ::llvm::StringRef getDescription() const override { return "Combine independent all reduce operations."; }
2823
2824 /// Returns the derived pass name.
2825 static constexpr ::llvm::StringLiteral getPassName() {
2826 return ::llvm::StringLiteral("DTensorAllReduceCombineOptimization");
2827 }
2828 ::llvm::StringRef getName() const override { return "DTensorAllReduceCombineOptimization"; }
2829
2830 /// Support isa/dyn_cast functionality for the derived pass class.
2831 static bool classof(const ::mlir::Pass *pass) {
2832 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
2833 }
2834
2835 /// A clone method to create a copy of this pass.
2836 std::unique_ptr<::mlir::Pass> clonePass() const override {
2837 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
2838 }
2839
2840 /// Return the dialect that must be loaded in the context before this pass.
2841 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
2842
2843 }
2844
2845 /// Explicitly declare the TypeID for this class. We declare an explicit private
2846 /// instantiation because Pass classes should only be visible by the current
2847 /// library.
2848 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllReduceCombineOptimizationBase<DerivedT>)
2849
2850protected:
2851};
2852
2853template <typename DerivedT>
2854class DTensorAllReduceLoweringBase : public ::mlir::OperationPass<mlir::ModuleOp> {
2855public:
2856 using Base = DTensorAllReduceLoweringBase;
2857
2858 DTensorAllReduceLoweringBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
2859 DTensorAllReduceLoweringBase(const DTensorAllReduceLoweringBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
2860
2861 /// Returns the command-line argument attached to this pass.
2862 static constexpr ::llvm::StringLiteral getArgumentName() {
2863 return ::llvm::StringLiteral("dtensor-all-reduce-lowering");
2864 }
2865 ::llvm::StringRef getArgument() const override { return "dtensor-all-reduce-lowering"; }
2866
2867 ::llvm::StringRef getDescription() const override { return "Converts logical AllReduce ops into physical AllReduce ops."; }
2868
2869 /// Returns the derived pass name.
2870 static constexpr ::llvm::StringLiteral getPassName() {
2871 return ::llvm::StringLiteral("DTensorAllReduceLowering");
2872 }
2873 ::llvm::StringRef getName() const override { return "DTensorAllReduceLowering"; }
2874
2875 /// Support isa/dyn_cast functionality for the derived pass class.
2876 static bool classof(const ::mlir::Pass *pass) {
2877 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
2878 }
2879
2880 /// A clone method to create a copy of this pass.
2881 std::unique_ptr<::mlir::Pass> clonePass() const override {
2882 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
2883 }
2884
2885 /// Return the dialect that must be loaded in the context before this pass.
2886 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
2887
2888 }
2889
2890 /// Explicitly declare the TypeID for this class. We declare an explicit private
2891 /// instantiation because Pass classes should only be visible by the current
2892 /// library.
2893 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllReduceLoweringBase<DerivedT>)
2894
2895protected:
2896};
2897
2898template <typename DerivedT>
2899class DTensorAllReduceScatterOptimizationBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
2900public:
2901 using Base = DTensorAllReduceScatterOptimizationBase;
2902
2903 DTensorAllReduceScatterOptimizationBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
2904 DTensorAllReduceScatterOptimizationBase(const DTensorAllReduceScatterOptimizationBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {}
2905
2906 /// Returns the command-line argument attached to this pass.
2907 static constexpr ::llvm::StringLiteral getArgumentName() {
2908 return ::llvm::StringLiteral("dtensor-allreduce-scatter-optimization");
2909 }
2910 ::llvm::StringRef getArgument() const override { return "dtensor-allreduce-scatter-optimization"; }
2911
2912 ::llvm::StringRef getDescription() const override { return "Combines allreduce and scatter to reducescatter."; }
2913
2914 /// Returns the derived pass name.
2915 static constexpr ::llvm::StringLiteral getPassName() {
2916 return ::llvm::StringLiteral("DTensorAllReduceScatterOptimization");
2917 }
2918 ::llvm::StringRef getName() const override { return "DTensorAllReduceScatterOptimization"; }
2919
2920 /// Support isa/dyn_cast functionality for the derived pass class.
2921 static bool classof(const ::mlir::Pass *pass) {
2922 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
2923 }
2924
2925 /// A clone method to create a copy of this pass.
2926 std::unique_ptr<::mlir::Pass> clonePass() const override {
2927 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
2928 }
2929
2930 /// Return the dialect that must be loaded in the context before this pass.
2931 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
2932
2933 }
2934
2935 /// Explicitly declare the TypeID for this class. We declare an explicit private
2936 /// instantiation because Pass classes should only be visible by the current
2937 /// library.
2938 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllReduceScatterOptimizationBase<DerivedT>)
2939
2940protected:
2941};
2942
2943template <typename DerivedT>
2944class DTensorAllReduceSumOptimizationBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
2945public:
2946 using Base = DTensorAllReduceSumOptimizationBase;
2947
2948 DTensorAllReduceSumOptimizationBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
2949 DTensorAllReduceSumOptimizationBase(const DTensorAllReduceSumOptimizationBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {}
2950
2951 /// Returns the command-line argument attached to this pass.
2952 static constexpr ::llvm::StringLiteral getArgumentName() {
2953 return ::llvm::StringLiteral("dtensor-allreduce-sum-optimization");
2954 }
2955 ::llvm::StringRef getArgument() const override { return "dtensor-allreduce-sum-optimization"; }
2956
2957 ::llvm::StringRef getDescription() const override { return "Changes order of add/allreduce to minimize all reduce operations."; }
2958
2959 /// Returns the derived pass name.
2960 static constexpr ::llvm::StringLiteral getPassName() {
2961 return ::llvm::StringLiteral("DTensorAllReduceSumOptimization");
2962 }
2963 ::llvm::StringRef getName() const override { return "DTensorAllReduceSumOptimization"; }
2964
2965 /// Support isa/dyn_cast functionality for the derived pass class.
2966 static bool classof(const ::mlir::Pass *pass) {
2967 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
2968 }
2969
2970 /// A clone method to create a copy of this pass.
2971 std::unique_ptr<::mlir::Pass> clonePass() const override {
2972 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
2973 }
2974
2975 /// Return the dialect that must be loaded in the context before this pass.
2976 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
2977
2978 }
2979
2980 /// Explicitly declare the TypeID for this class. We declare an explicit private
2981 /// instantiation because Pass classes should only be visible by the current
2982 /// library.
2983 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllReduceSumOptimizationBase<DerivedT>)
2984
2985protected:
2986};
2987
2988template <typename DerivedT>
2989class DTensorAllScatterLoweringBase : public ::mlir::OperationPass<mlir::ModuleOp> {
2990public:
2991 using Base = DTensorAllScatterLoweringBase;
2992
2993 DTensorAllScatterLoweringBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
2994 DTensorAllScatterLoweringBase(const DTensorAllScatterLoweringBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
2995
2996 /// Returns the command-line argument attached to this pass.
2997 static constexpr ::llvm::StringLiteral getArgumentName() {
2998 return ::llvm::StringLiteral("dtensor-all-scatter-lowering");
2999 }
3000 ::llvm::StringRef getArgument() const override { return "dtensor-all-scatter-lowering"; }
3001
3002 ::llvm::StringRef getDescription() const override { return "Converts logical AllScatter ops into physical Split ops."; }
3003
3004 /// Returns the derived pass name.
3005 static constexpr ::llvm::StringLiteral getPassName() {
3006 return ::llvm::StringLiteral("DTensorAllScatterLowering");
3007 }
3008 ::llvm::StringRef getName() const override { return "DTensorAllScatterLowering"; }
3009
3010 /// Support isa/dyn_cast functionality for the derived pass class.
3011 static bool classof(const ::mlir::Pass *pass) {
3012 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3013 }
3014
3015 /// A clone method to create a copy of this pass.
3016 std::unique_ptr<::mlir::Pass> clonePass() const override {
3017 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3018 }
3019
3020 /// Return the dialect that must be loaded in the context before this pass.
3021 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3022
3023 }
3024
3025 /// Explicitly declare the TypeID for this class. We declare an explicit private
3026 /// instantiation because Pass classes should only be visible by the current
3027 /// library.
3028 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAllScatterLoweringBase<DerivedT>)
3029
3030protected:
3031};
3032
3033template <typename DerivedT>
3034class DTensorAnnotateGlobalShapeBase : public ::mlir::OperationPass<mlir::ModuleOp> {
3035public:
3036 using Base = DTensorAnnotateGlobalShapeBase;
3037
3038 DTensorAnnotateGlobalShapeBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
3039 DTensorAnnotateGlobalShapeBase(const DTensorAnnotateGlobalShapeBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
3040
3041 /// Returns the command-line argument attached to this pass.
3042 static constexpr ::llvm::StringLiteral getArgumentName() {
3043 return ::llvm::StringLiteral("dtensor-annotate-global-shape");
3044 }
3045 ::llvm::StringRef getArgument() const override { return "dtensor-annotate-global-shape"; }
3046
3047 ::llvm::StringRef getDescription() const override { return "Mark all operations and function arguments with `_global_shape` attribute to be used during SPMD expansion."; }
3048
3049 /// Returns the derived pass name.
3050 static constexpr ::llvm::StringLiteral getPassName() {
3051 return ::llvm::StringLiteral("DTensorAnnotateGlobalShape");
3052 }
3053 ::llvm::StringRef getName() const override { return "DTensorAnnotateGlobalShape"; }
3054
3055 /// Support isa/dyn_cast functionality for the derived pass class.
3056 static bool classof(const ::mlir::Pass *pass) {
3057 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3058 }
3059
3060 /// A clone method to create a copy of this pass.
3061 std::unique_ptr<::mlir::Pass> clonePass() const override {
3062 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3063 }
3064
3065 /// Return the dialect that must be loaded in the context before this pass.
3066 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3067
3068 }
3069
3070 /// Explicitly declare the TypeID for this class. We declare an explicit private
3071 /// instantiation because Pass classes should only be visible by the current
3072 /// library.
3073 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorAnnotateGlobalShapeBase<DerivedT>)
3074
3075protected:
3076};
3077
3078template <typename DerivedT>
3079class DTensorClusterFunctionConversionBase : public ::mlir::OperationPass<mlir::ModuleOp> {
3080public:
3081 using Base = DTensorClusterFunctionConversionBase;
3082
3083 DTensorClusterFunctionConversionBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
3084 DTensorClusterFunctionConversionBase(const DTensorClusterFunctionConversionBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
3085
3086 /// Returns the command-line argument attached to this pass.
3087 static constexpr ::llvm::StringLiteral getArgumentName() {
3088 return ::llvm::StringLiteral("dtensor-cluster-function-conversion");
3089 }
3090 ::llvm::StringRef getArgument() const override { return "dtensor-cluster-function-conversion"; }
3091
3092 ::llvm::StringRef getDescription() const override { return "Converts tf_device.cluster_func ops into TF StatefulPartitioned call op with mesh attribute."; }
3093
3094 /// Returns the derived pass name.
3095 static constexpr ::llvm::StringLiteral getPassName() {
3096 return ::llvm::StringLiteral("DTensorClusterFunctionConversion");
3097 }
3098 ::llvm::StringRef getName() const override { return "DTensorClusterFunctionConversion"; }
3099
3100 /// Support isa/dyn_cast functionality for the derived pass class.
3101 static bool classof(const ::mlir::Pass *pass) {
3102 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3103 }
3104
3105 /// A clone method to create a copy of this pass.
3106 std::unique_ptr<::mlir::Pass> clonePass() const override {
3107 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3108 }
3109
3110 /// Return the dialect that must be loaded in the context before this pass.
3111 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3112
3113 }
3114
3115 /// Explicitly declare the TypeID for this class. We declare an explicit private
3116 /// instantiation because Pass classes should only be visible by the current
3117 /// library.
3118 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorClusterFunctionConversionBase<DerivedT>)
3119
3120protected:
3121};
3122
3123template <typename DerivedT>
3124class DTensorConstantFoldingBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
3125public:
3126 using Base = DTensorConstantFoldingBase;
3127
3128 DTensorConstantFoldingBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
3129 DTensorConstantFoldingBase(const DTensorConstantFoldingBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {}
3130
3131 /// Returns the command-line argument attached to this pass.
3132 static constexpr ::llvm::StringLiteral getArgumentName() {
3133 return ::llvm::StringLiteral("dtensor-constant-folding");
3134 }
3135 ::llvm::StringRef getArgument() const override { return "dtensor-constant-folding"; }
3136
3137 ::llvm::StringRef getDescription() const override { return "Folds constants operations."; }
3138
3139 /// Returns the derived pass name.
3140 static constexpr ::llvm::StringLiteral getPassName() {
3141 return ::llvm::StringLiteral("DTensorConstantFolding");
3142 }
3143 ::llvm::StringRef getName() const override { return "DTensorConstantFolding"; }
3144
3145 /// Support isa/dyn_cast functionality for the derived pass class.
3146 static bool classof(const ::mlir::Pass *pass) {
3147 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3148 }
3149
3150 /// A clone method to create a copy of this pass.
3151 std::unique_ptr<::mlir::Pass> clonePass() const override {
3152 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3153 }
3154
3155 /// Return the dialect that must be loaded in the context before this pass.
3156 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3157
3158 }
3159
3160 /// Explicitly declare the TypeID for this class. We declare an explicit private
3161 /// instantiation because Pass classes should only be visible by the current
3162 /// library.
3163 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorConstantFoldingBase<DerivedT>)
3164
3165protected:
3166};
3167
3168template <typename DerivedT>
3169class DTensorDCEBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
3170public:
3171 using Base = DTensorDCEBase;
3172
3173 DTensorDCEBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
3174 DTensorDCEBase(const DTensorDCEBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {}
3175
3176 /// Returns the command-line argument attached to this pass.
3177 static constexpr ::llvm::StringLiteral getArgumentName() {
3178 return ::llvm::StringLiteral("dtensor-dce");
3179 }
3180 ::llvm::StringRef getArgument() const override { return "dtensor-dce"; }
3181
3182 ::llvm::StringRef getDescription() const override { return "Removes unused ops from graph."; }
3183
3184 /// Returns the derived pass name.
3185 static constexpr ::llvm::StringLiteral getPassName() {
3186 return ::llvm::StringLiteral("DTensorDCE");
3187 }
3188 ::llvm::StringRef getName() const override { return "DTensorDCE"; }
3189
3190 /// Support isa/dyn_cast functionality for the derived pass class.
3191 static bool classof(const ::mlir::Pass *pass) {
3192 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3193 }
3194
3195 /// A clone method to create a copy of this pass.
3196 std::unique_ptr<::mlir::Pass> clonePass() const override {
3197 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3198 }
3199
3200 /// Return the dialect that must be loaded in the context before this pass.
3201 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3202
3203 }
3204
3205 /// Explicitly declare the TypeID for this class. We declare an explicit private
3206 /// instantiation because Pass classes should only be visible by the current
3207 /// library.
3208 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorDCEBase<DerivedT>)
3209
3210protected:
3211};
3212
3213template <typename DerivedT>
3214class DTensorDesignateResourceHandleMeshBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
3215public:
3216 using Base = DTensorDesignateResourceHandleMeshBase;
3217
3218 DTensorDesignateResourceHandleMeshBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
3219 DTensorDesignateResourceHandleMeshBase(const DTensorDesignateResourceHandleMeshBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {}
3220
3221 /// Returns the command-line argument attached to this pass.
3222 static constexpr ::llvm::StringLiteral getArgumentName() {
3223 return ::llvm::StringLiteral("dtensor-designate-resource-handle-mesh");
3224 }
3225 ::llvm::StringRef getArgument() const override { return "dtensor-designate-resource-handle-mesh"; }
3226
3227 ::llvm::StringRef getDescription() const override { return "Sets empty mesh attributes for device cluster that creates or destroys resource handles."; }
3228
3229 /// Returns the derived pass name.
3230 static constexpr ::llvm::StringLiteral getPassName() {
3231 return ::llvm::StringLiteral("DTensorDesignateResourceHandleMesh");
3232 }
3233 ::llvm::StringRef getName() const override { return "DTensorDesignateResourceHandleMesh"; }
3234
3235 /// Support isa/dyn_cast functionality for the derived pass class.
3236 static bool classof(const ::mlir::Pass *pass) {
3237 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3238 }
3239
3240 /// A clone method to create a copy of this pass.
3241 std::unique_ptr<::mlir::Pass> clonePass() const override {
3242 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3243 }
3244
3245 /// Return the dialect that must be loaded in the context before this pass.
3246 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3247
3248 }
3249
3250 /// Explicitly declare the TypeID for this class. We declare an explicit private
3251 /// instantiation because Pass classes should only be visible by the current
3252 /// library.
3253 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorDesignateResourceHandleMeshBase<DerivedT>)
3254
3255protected:
3256};
3257
3258template <typename DerivedT>
3259class DTensorDeviceMeshClusterCoarseningBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
3260public:
3261 using Base = DTensorDeviceMeshClusterCoarseningBase;
3262
3263 DTensorDeviceMeshClusterCoarseningBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
3264 DTensorDeviceMeshClusterCoarseningBase(const DTensorDeviceMeshClusterCoarseningBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {}
3265
3266 /// Returns the command-line argument attached to this pass.
3267 static constexpr ::llvm::StringLiteral getArgumentName() {
3268 return ::llvm::StringLiteral("dtensor-device-mesh-cluster-coarsening");
3269 }
3270 ::llvm::StringRef getArgument() const override { return "dtensor-device-mesh-cluster-coarsening"; }
3271
3272 ::llvm::StringRef getDescription() const override { return "Merges tf_device.cluster op with same mesh attribute."; }
3273
3274 /// Returns the derived pass name.
3275 static constexpr ::llvm::StringLiteral getPassName() {
3276 return ::llvm::StringLiteral("DTensorDeviceMeshClusterCoarsening");
3277 }
3278 ::llvm::StringRef getName() const override { return "DTensorDeviceMeshClusterCoarsening"; }
3279
3280 /// Support isa/dyn_cast functionality for the derived pass class.
3281 static bool classof(const ::mlir::Pass *pass) {
3282 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3283 }
3284
3285 /// A clone method to create a copy of this pass.
3286 std::unique_ptr<::mlir::Pass> clonePass() const override {
3287 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3288 }
3289
3290 /// Return the dialect that must be loaded in the context before this pass.
3291 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3292
3293 }
3294
3295 /// Explicitly declare the TypeID for this class. We declare an explicit private
3296 /// instantiation because Pass classes should only be visible by the current
3297 /// library.
3298 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorDeviceMeshClusterCoarseningBase<DerivedT>)
3299
3300protected:
3301};
3302
3303template <typename DerivedT>
3304class DTensorEmbeddingBase : public ::mlir::OperationPass<mlir::ModuleOp> {
3305public:
3306 using Base = DTensorEmbeddingBase;
3307
3308 DTensorEmbeddingBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
3309 DTensorEmbeddingBase(const DTensorEmbeddingBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
3310
3311 /// Returns the command-line argument attached to this pass.
3312 static constexpr ::llvm::StringLiteral getArgumentName() {
3313 return ::llvm::StringLiteral("dtensor-embedding");
3314 }
3315 ::llvm::StringRef getArgument() const override { return "dtensor-embedding"; }
3316
3317 ::llvm::StringRef getDescription() const override { return "Extracts the ApplyEmbeddingOptimizer and checks if the attached optimizer is usable for TPU embeddings."; }
3318
3319 /// Returns the derived pass name.
3320 static constexpr ::llvm::StringLiteral getPassName() {
3321 return ::llvm::StringLiteral("DTensorEmbedding");
3322 }
3323 ::llvm::StringRef getName() const override { return "DTensorEmbedding"; }
3324
3325 /// Support isa/dyn_cast functionality for the derived pass class.
3326 static bool classof(const ::mlir::Pass *pass) {
3327 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3328 }
3329
3330 /// A clone method to create a copy of this pass.
3331 std::unique_ptr<::mlir::Pass> clonePass() const override {
3332 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3333 }
3334
3335 /// Return the dialect that must be loaded in the context before this pass.
3336 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3337
3338 registry.insert<::mlir::chlo::ChloDialect>();
3339
3340 registry.insert<::mlir::memref::MemRefDialect>();
3341
3342 registry.insert<::mlir::mhlo::MhloDialect>();
3343
3344 registry.insert<::mlir::shape::ShapeDialect>();
3345
3346 registry.insert<::mlir::scf::SCFDialect>();
3347
3348 registry.insert<::mlir::tensor::TensorDialect>();
3349
3350 }
3351
3352 /// Explicitly declare the TypeID for this class. We declare an explicit private
3353 /// instantiation because Pass classes should only be visible by the current
3354 /// library.
3355 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorEmbeddingBase<DerivedT>)
3356
3357protected:
3358};
3359
3360template <typename DerivedT>
3361class DTensorEmbeddingCheckpointBase : public ::mlir::OperationPass<mlir::ModuleOp> {
3362public:
3363 using Base = DTensorEmbeddingCheckpointBase;
3364
3365 DTensorEmbeddingCheckpointBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
3366 DTensorEmbeddingCheckpointBase(const DTensorEmbeddingCheckpointBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
3367
3368 /// Returns the command-line argument attached to this pass.
3369 static constexpr ::llvm::StringLiteral getArgumentName() {
3370 return ::llvm::StringLiteral("dtensor-embedding-checkpoint");
3371 }
3372 ::llvm::StringRef getArgument() const override { return "dtensor-embedding-checkpoint"; }
3373
3374 ::llvm::StringRef getDescription() const override { return "Checkpoint pass to save or load TPU embeddings variables."; }
3375
3376 /// Returns the derived pass name.
3377 static constexpr ::llvm::StringLiteral getPassName() {
3378 return ::llvm::StringLiteral("DTensorEmbeddingCheckpoint");
3379 }
3380 ::llvm::StringRef getName() const override { return "DTensorEmbeddingCheckpoint"; }
3381
3382 /// Support isa/dyn_cast functionality for the derived pass class.
3383 static bool classof(const ::mlir::Pass *pass) {
3384 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3385 }
3386
3387 /// A clone method to create a copy of this pass.
3388 std::unique_ptr<::mlir::Pass> clonePass() const override {
3389 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3390 }
3391
3392 /// Return the dialect that must be loaded in the context before this pass.
3393 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3394
3395 }
3396
3397 /// Explicitly declare the TypeID for this class. We declare an explicit private
3398 /// instantiation because Pass classes should only be visible by the current
3399 /// library.
3400 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorEmbeddingCheckpointBase<DerivedT>)
3401
3402protected:
3403};
3404
3405template <typename DerivedT>
3406class DTensorEmbeddingV2Base : public ::mlir::OperationPass<mlir::ModuleOp> {
3407public:
3408 using Base = DTensorEmbeddingV2Base;
3409
3410 DTensorEmbeddingV2Base() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
3411 DTensorEmbeddingV2Base(const DTensorEmbeddingV2Base &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
3412
3413 /// Returns the command-line argument attached to this pass.
3414 static constexpr ::llvm::StringLiteral getArgumentName() {
3415 return ::llvm::StringLiteral("dtensor-embedding-v2");
3416 }
3417 ::llvm::StringRef getArgument() const override { return "dtensor-embedding-v2"; }
3418
3419 ::llvm::StringRef getDescription() const override { return "Embedding Pass"; }
3420
3421 /// Returns the derived pass name.
3422 static constexpr ::llvm::StringLiteral getPassName() {
3423 return ::llvm::StringLiteral("DTensorEmbeddingV2");
3424 }
3425 ::llvm::StringRef getName() const override { return "DTensorEmbeddingV2"; }
3426
3427 /// Support isa/dyn_cast functionality for the derived pass class.
3428 static bool classof(const ::mlir::Pass *pass) {
3429 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3430 }
3431
3432 /// A clone method to create a copy of this pass.
3433 std::unique_ptr<::mlir::Pass> clonePass() const override {
3434 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3435 }
3436
3437 /// Return the dialect that must be loaded in the context before this pass.
3438 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3439
3440 }
3441
3442 /// Explicitly declare the TypeID for this class. We declare an explicit private
3443 /// instantiation because Pass classes should only be visible by the current
3444 /// library.
3445 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorEmbeddingV2Base<DerivedT>)
3446
3447protected:
3448};
3449
3450template <typename DerivedT>
3451class DTensorFunctionRenamingBase : public ::mlir::OperationPass<mlir::ModuleOp> {
3452public:
3453 using Base = DTensorFunctionRenamingBase;
3454
3455 DTensorFunctionRenamingBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
3456 DTensorFunctionRenamingBase(const DTensorFunctionRenamingBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
3457
3458 /// Returns the command-line argument attached to this pass.
3459 static constexpr ::llvm::StringLiteral getArgumentName() {
3460 return ::llvm::StringLiteral("dtensor-function-renaming");
3461 }
3462 ::llvm::StringRef getArgument() const override { return "dtensor-function-renaming"; }
3463
3464 ::llvm::StringRef getDescription() const override { return "Renames private functions by appending an id to each name. This is used to make private function names unique across modules."; }
3465
3466 /// Returns the derived pass name.
3467 static constexpr ::llvm::StringLiteral getPassName() {
3468 return ::llvm::StringLiteral("DTensorFunctionRenaming");
3469 }
3470 ::llvm::StringRef getName() const override { return "DTensorFunctionRenaming"; }
3471
3472 /// Support isa/dyn_cast functionality for the derived pass class.
3473 static bool classof(const ::mlir::Pass *pass) {
3474 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3475 }
3476
3477 /// A clone method to create a copy of this pass.
3478 std::unique_ptr<::mlir::Pass> clonePass() const override {
3479 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3480 }
3481
3482 /// Return the dialect that must be loaded in the context before this pass.
3483 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3484
3485 }
3486
3487 /// Explicitly declare the TypeID for this class. We declare an explicit private
3488 /// instantiation because Pass classes should only be visible by the current
3489 /// library.
3490 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorFunctionRenamingBase<DerivedT>)
3491
3492protected:
3493};
3494
3495template <typename DerivedT>
3496class DTensorHandleCrossClusterDependenciesBase : public ::mlir::OperationPass<mlir::ModuleOp> {
3497public:
3498 using Base = DTensorHandleCrossClusterDependenciesBase;
3499
3500 DTensorHandleCrossClusterDependenciesBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
3501 DTensorHandleCrossClusterDependenciesBase(const DTensorHandleCrossClusterDependenciesBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
3502
3503 /// Returns the command-line argument attached to this pass.
3504 static constexpr ::llvm::StringLiteral getArgumentName() {
3505 return ::llvm::StringLiteral("dtensor-handle_cross_cluster_dependences");
3506 }
3507 ::llvm::StringRef getArgument() const override { return "dtensor-handle_cross_cluster_dependences"; }
3508
3509 ::llvm::StringRef getDescription() const override { return "Lowers cross mesh cluster data depedences as send/recvs."; }
3510
3511 /// Returns the derived pass name.
3512 static constexpr ::llvm::StringLiteral getPassName() {
3513 return ::llvm::StringLiteral("DTensorHandleCrossClusterDependencies");
3514 }
3515 ::llvm::StringRef getName() const override { return "DTensorHandleCrossClusterDependencies"; }
3516
3517 /// Support isa/dyn_cast functionality for the derived pass class.
3518 static bool classof(const ::mlir::Pass *pass) {
3519 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3520 }
3521
3522 /// A clone method to create a copy of this pass.
3523 std::unique_ptr<::mlir::Pass> clonePass() const override {
3524 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3525 }
3526
3527 /// Return the dialect that must be loaded in the context before this pass.
3528 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3529
3530 }
3531
3532 /// Explicitly declare the TypeID for this class. We declare an explicit private
3533 /// instantiation because Pass classes should only be visible by the current
3534 /// library.
3535 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorHandleCrossClusterDependenciesBase<DerivedT>)
3536
3537protected:
3538};
3539
3540template <typename DerivedT>
3541class DTensorInferShapesForRestoreV2OpBase : public ::mlir::OperationPass<mlir::ModuleOp> {
3542public:
3543 using Base = DTensorInferShapesForRestoreV2OpBase;
3544
3545 DTensorInferShapesForRestoreV2OpBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
3546 DTensorInferShapesForRestoreV2OpBase(const DTensorInferShapesForRestoreV2OpBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
3547
3548 /// Returns the command-line argument attached to this pass.
3549 static constexpr ::llvm::StringLiteral getArgumentName() {
3550 return ::llvm::StringLiteral("dtensor-infer-shapes-for-restorev2-op");
3551 }
3552 ::llvm::StringRef getArgument() const override { return "dtensor-infer-shapes-for-restorev2-op"; }
3553
3554 ::llvm::StringRef getDescription() const override { return "Infer shapes of the outputs of tf.RestoreV2Op from the AssignVariableOps that consume those outputs. This is used for DTensor integration with TF Checkpoint."; }
3555
3556 /// Returns the derived pass name.
3557 static constexpr ::llvm::StringLiteral getPassName() {
3558 return ::llvm::StringLiteral("DTensorInferShapesForRestoreV2Op");
3559 }
3560 ::llvm::StringRef getName() const override { return "DTensorInferShapesForRestoreV2Op"; }
3561
3562 /// Support isa/dyn_cast functionality for the derived pass class.
3563 static bool classof(const ::mlir::Pass *pass) {
3564 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3565 }
3566
3567 /// A clone method to create a copy of this pass.
3568 std::unique_ptr<::mlir::Pass> clonePass() const override {
3569 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3570 }
3571
3572 /// Return the dialect that must be loaded in the context before this pass.
3573 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3574
3575 }
3576
3577 /// Explicitly declare the TypeID for this class. We declare an explicit private
3578 /// instantiation because Pass classes should only be visible by the current
3579 /// library.
3580 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorInferShapesForRestoreV2OpBase<DerivedT>)
3581
3582protected:
3583};
3584
3585template <typename DerivedT>
3586class DTensorLayoutPropagationV2Base : public ::mlir::OperationPass<mlir::ModuleOp> {
3587public:
3588 using Base = DTensorLayoutPropagationV2Base;
3589
3590 DTensorLayoutPropagationV2Base() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
3591 DTensorLayoutPropagationV2Base(const DTensorLayoutPropagationV2Base &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
3592
3593 /// Returns the command-line argument attached to this pass.
3594 static constexpr ::llvm::StringLiteral getArgumentName() {
3595 return ::llvm::StringLiteral("dtensor-layout-propagation-v2");
3596 }
3597 ::llvm::StringRef getArgument() const override { return "dtensor-layout-propagation-v2"; }
3598
3599 ::llvm::StringRef getDescription() const override { return "Propagates layout information for all the TF Ops."; }
3600
3601 /// Returns the derived pass name.
3602 static constexpr ::llvm::StringLiteral getPassName() {
3603 return ::llvm::StringLiteral("DTensorLayoutPropagationV2");
3604 }
3605 ::llvm::StringRef getName() const override { return "DTensorLayoutPropagationV2"; }
3606
3607 /// Support isa/dyn_cast functionality for the derived pass class.
3608 static bool classof(const ::mlir::Pass *pass) {
3609 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3610 }
3611
3612 /// A clone method to create a copy of this pass.
3613 std::unique_ptr<::mlir::Pass> clonePass() const override {
3614 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3615 }
3616
3617 /// Return the dialect that must be loaded in the context before this pass.
3618 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3619
3620 }
3621
3622 /// Explicitly declare the TypeID for this class. We declare an explicit private
3623 /// instantiation because Pass classes should only be visible by the current
3624 /// library.
3625 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorLayoutPropagationV2Base<DerivedT>)
3626
3627protected:
3628};
3629
3630template <typename DerivedT>
3631class DTensorLowerSendRecvBase : public ::mlir::OperationPass<mlir::ModuleOp> {
3632public:
3633 using Base = DTensorLowerSendRecvBase;
3634
3635 DTensorLowerSendRecvBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
3636 DTensorLowerSendRecvBase(const DTensorLowerSendRecvBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
3637
3638 /// Returns the command-line argument attached to this pass.
3639 static constexpr ::llvm::StringLiteral getArgumentName() {
3640 return ::llvm::StringLiteral("dtensor-lower-send-recv");
3641 }
3642 ::llvm::StringRef getArgument() const override { return "dtensor-lower-send-recv"; }
3643
3644 ::llvm::StringRef getDescription() const override { return "Lowers DTensorSend/DTensorRecv ops to send/recv ops."; }
3645
3646 /// Returns the derived pass name.
3647 static constexpr ::llvm::StringLiteral getPassName() {
3648 return ::llvm::StringLiteral("DTensorLowerSendRecv");
3649 }
3650 ::llvm::StringRef getName() const override { return "DTensorLowerSendRecv"; }
3651
3652 /// Support isa/dyn_cast functionality for the derived pass class.
3653 static bool classof(const ::mlir::Pass *pass) {
3654 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3655 }
3656
3657 /// A clone method to create a copy of this pass.
3658 std::unique_ptr<::mlir::Pass> clonePass() const override {
3659 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3660 }
3661
3662 /// Return the dialect that must be loaded in the context before this pass.
3663 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3664
3665 }
3666
3667 /// Explicitly declare the TypeID for this class. We declare an explicit private
3668 /// instantiation because Pass classes should only be visible by the current
3669 /// library.
3670 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorLowerSendRecvBase<DerivedT>)
3671
3672protected:
3673};
3674
3675template <typename DerivedT>
3676class DTensorMergeClustersBase : public ::mlir::OperationPass<mlir::ModuleOp> {
3677public:
3678 using Base = DTensorMergeClustersBase;
3679
3680 DTensorMergeClustersBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
3681 DTensorMergeClustersBase(const DTensorMergeClustersBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
3682
3683 /// Returns the command-line argument attached to this pass.
3684 static constexpr ::llvm::StringLiteral getArgumentName() {
3685 return ::llvm::StringLiteral("dtensor-merge-clusters");
3686 }
3687 ::llvm::StringRef getArgument() const override { return "dtensor-merge-clusters"; }
3688
3689 ::llvm::StringRef getDescription() const override { return "Merges tf_device.Clusters ops with same mesh specification."; }
3690
3691 /// Returns the derived pass name.
3692 static constexpr ::llvm::StringLiteral getPassName() {
3693 return ::llvm::StringLiteral("DTensorMergeClusters");
3694 }
3695 ::llvm::StringRef getName() const override { return "DTensorMergeClusters"; }
3696
3697 /// Support isa/dyn_cast functionality for the derived pass class.
3698 static bool classof(const ::mlir::Pass *pass) {
3699 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3700 }
3701
3702 /// A clone method to create a copy of this pass.
3703 std::unique_ptr<::mlir::Pass> clonePass() const override {
3704 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3705 }
3706
3707 /// Return the dialect that must be loaded in the context before this pass.
3708 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3709
3710 }
3711
3712 /// Explicitly declare the TypeID for this class. We declare an explicit private
3713 /// instantiation because Pass classes should only be visible by the current
3714 /// library.
3715 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorMergeClustersBase<DerivedT>)
3716
3717protected:
3718};
3719
3720template <typename DerivedT>
3721class DTensorMeshPropagationBase : public ::mlir::OperationPass<mlir::ModuleOp> {
3722public:
3723 using Base = DTensorMeshPropagationBase;
3724
3725 DTensorMeshPropagationBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
3726 DTensorMeshPropagationBase(const DTensorMeshPropagationBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
3727
3728 /// Returns the command-line argument attached to this pass.
3729 static constexpr ::llvm::StringLiteral getArgumentName() {
3730 return ::llvm::StringLiteral("dtensor-mesh-propagation");
3731 }
3732 ::llvm::StringRef getArgument() const override { return "dtensor-mesh-propagation"; }
3733
3734 ::llvm::StringRef getDescription() const override { return "Propagates mesh information to all clusters."; }
3735
3736 /// Returns the derived pass name.
3737 static constexpr ::llvm::StringLiteral getPassName() {
3738 return ::llvm::StringLiteral("DTensorMeshPropagation");
3739 }
3740 ::llvm::StringRef getName() const override { return "DTensorMeshPropagation"; }
3741
3742 /// Support isa/dyn_cast functionality for the derived pass class.
3743 static bool classof(const ::mlir::Pass *pass) {
3744 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3745 }
3746
3747 /// A clone method to create a copy of this pass.
3748 std::unique_ptr<::mlir::Pass> clonePass() const override {
3749 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3750 }
3751
3752 /// Return the dialect that must be loaded in the context before this pass.
3753 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3754
3755 }
3756
3757 /// Explicitly declare the TypeID for this class. We declare an explicit private
3758 /// instantiation because Pass classes should only be visible by the current
3759 /// library.
3760 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorMeshPropagationBase<DerivedT>)
3761
3762protected:
3763};
3764
3765template <typename DerivedT>
3766class DTensorMixedPrecisionReduceBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
3767public:
3768 using Base = DTensorMixedPrecisionReduceBase;
3769
3770 DTensorMixedPrecisionReduceBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
3771 DTensorMixedPrecisionReduceBase(const DTensorMixedPrecisionReduceBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {}
3772
3773 /// Returns the command-line argument attached to this pass.
3774 static constexpr ::llvm::StringLiteral getArgumentName() {
3775 return ::llvm::StringLiteral("dtensor-mixed-precision-reduce");
3776 }
3777 ::llvm::StringRef getArgument() const override { return "dtensor-mixed-precision-reduce"; }
3778
3779 ::llvm::StringRef getDescription() const override { return "Upcast tensors to higher precision type for reduction ops."; }
3780
3781 /// Returns the derived pass name.
3782 static constexpr ::llvm::StringLiteral getPassName() {
3783 return ::llvm::StringLiteral("DTensorMixedPrecisionReduce");
3784 }
3785 ::llvm::StringRef getName() const override { return "DTensorMixedPrecisionReduce"; }
3786
3787 /// Support isa/dyn_cast functionality for the derived pass class.
3788 static bool classof(const ::mlir::Pass *pass) {
3789 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3790 }
3791
3792 /// A clone method to create a copy of this pass.
3793 std::unique_ptr<::mlir::Pass> clonePass() const override {
3794 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3795 }
3796
3797 /// Return the dialect that must be loaded in the context before this pass.
3798 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3799
3800 }
3801
3802 /// Explicitly declare the TypeID for this class. We declare an explicit private
3803 /// instantiation because Pass classes should only be visible by the current
3804 /// library.
3805 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorMixedPrecisionReduceBase<DerivedT>)
3806
3807protected:
3808};
3809
3810template <typename DerivedT>
3811class DTensorMoveCompilationToHostBase : public ::mlir::OperationPass<mlir::ModuleOp> {
3812public:
3813 using Base = DTensorMoveCompilationToHostBase;
3814
3815 DTensorMoveCompilationToHostBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
3816 DTensorMoveCompilationToHostBase(const DTensorMoveCompilationToHostBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
3817
3818 /// Returns the command-line argument attached to this pass.
3819 static constexpr ::llvm::StringLiteral getArgumentName() {
3820 return ::llvm::StringLiteral("dtensor-move-compilation-to-host");
3821 }
3822 ::llvm::StringRef getArgument() const override { return "dtensor-move-compilation-to-host"; }
3823
3824 ::llvm::StringRef getDescription() const override { return "Moves XLA compilation ops to host computation."; }
3825
3826 /// Returns the derived pass name.
3827 static constexpr ::llvm::StringLiteral getPassName() {
3828 return ::llvm::StringLiteral("DTensorMoveCompilationToHost");
3829 }
3830 ::llvm::StringRef getName() const override { return "DTensorMoveCompilationToHost"; }
3831
3832 /// Support isa/dyn_cast functionality for the derived pass class.
3833 static bool classof(const ::mlir::Pass *pass) {
3834 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3835 }
3836
3837 /// A clone method to create a copy of this pass.
3838 std::unique_ptr<::mlir::Pass> clonePass() const override {
3839 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3840 }
3841
3842 /// Return the dialect that must be loaded in the context before this pass.
3843 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3844
3845 }
3846
3847 /// Explicitly declare the TypeID for this class. We declare an explicit private
3848 /// instantiation because Pass classes should only be visible by the current
3849 /// library.
3850 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorMoveCompilationToHostBase<DerivedT>)
3851
3852protected:
3853};
3854
3855template <typename DerivedT>
3856class DTensorOpToDeviceClusterBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
3857public:
3858 using Base = DTensorOpToDeviceClusterBase;
3859
3860 DTensorOpToDeviceClusterBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
3861 DTensorOpToDeviceClusterBase(const DTensorOpToDeviceClusterBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {}
3862
3863 /// Returns the command-line argument attached to this pass.
3864 static constexpr ::llvm::StringLiteral getArgumentName() {
3865 return ::llvm::StringLiteral("dtensor-op-to-device-cluster");
3866 }
3867 ::llvm::StringRef getArgument() const override { return "dtensor-op-to-device-cluster"; }
3868
3869 ::llvm::StringRef getDescription() const override { return "Creates and wraps tf_device.cluster op for all TF ops"; }
3870
3871 /// Returns the derived pass name.
3872 static constexpr ::llvm::StringLiteral getPassName() {
3873 return ::llvm::StringLiteral("DTensorOpToDeviceCluster");
3874 }
3875 ::llvm::StringRef getName() const override { return "DTensorOpToDeviceCluster"; }
3876
3877 /// Support isa/dyn_cast functionality for the derived pass class.
3878 static bool classof(const ::mlir::Pass *pass) {
3879 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3880 }
3881
3882 /// A clone method to create a copy of this pass.
3883 std::unique_ptr<::mlir::Pass> clonePass() const override {
3884 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3885 }
3886
3887 /// Return the dialect that must be loaded in the context before this pass.
3888 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3889
3890 }
3891
3892 /// Explicitly declare the TypeID for this class. We declare an explicit private
3893 /// instantiation because Pass classes should only be visible by the current
3894 /// library.
3895 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorOpToDeviceClusterBase<DerivedT>)
3896
3897protected:
3898};
3899
3900template <typename DerivedT>
3901class DTensorPropagateDefaultLayoutBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
3902public:
3903 using Base = DTensorPropagateDefaultLayoutBase;
3904
3905 DTensorPropagateDefaultLayoutBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
3906 DTensorPropagateDefaultLayoutBase(const DTensorPropagateDefaultLayoutBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {}
3907
3908 /// Returns the command-line argument attached to this pass.
3909 static constexpr ::llvm::StringLiteral getArgumentName() {
3910 return ::llvm::StringLiteral("dtensor-propagate-default-layout");
3911 }
3912 ::llvm::StringRef getArgument() const override { return "dtensor-propagate-default-layout"; }
3913
3914 ::llvm::StringRef getDescription() const override { return "Converts layout attributes added by end users to DTensorLayout op."; }
3915
3916 /// Returns the derived pass name.
3917 static constexpr ::llvm::StringLiteral getPassName() {
3918 return ::llvm::StringLiteral("DTensorPropagateDefaultLayout");
3919 }
3920 ::llvm::StringRef getName() const override { return "DTensorPropagateDefaultLayout"; }
3921
3922 /// Support isa/dyn_cast functionality for the derived pass class.
3923 static bool classof(const ::mlir::Pass *pass) {
3924 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3925 }
3926
3927 /// A clone method to create a copy of this pass.
3928 std::unique_ptr<::mlir::Pass> clonePass() const override {
3929 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3930 }
3931
3932 /// Return the dialect that must be loaded in the context before this pass.
3933 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3934
3935 }
3936
3937 /// Explicitly declare the TypeID for this class. We declare an explicit private
3938 /// instantiation because Pass classes should only be visible by the current
3939 /// library.
3940 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorPropagateDefaultLayoutBase<DerivedT>)
3941
3942protected:
3943};
3944
3945template <typename DerivedT>
3946class DTensorPropagateDeviceIdToFunctionArgsBase : public ::mlir::OperationPass<mlir::ModuleOp> {
3947public:
3948 using Base = DTensorPropagateDeviceIdToFunctionArgsBase;
3949
3950 DTensorPropagateDeviceIdToFunctionArgsBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
3951 DTensorPropagateDeviceIdToFunctionArgsBase(const DTensorPropagateDeviceIdToFunctionArgsBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
3952
3953 /// Returns the command-line argument attached to this pass.
3954 static constexpr ::llvm::StringLiteral getArgumentName() {
3955 return ::llvm::StringLiteral("dtensor-propagate-device-id-to-function-args");
3956 }
3957 ::llvm::StringRef getArgument() const override { return "dtensor-propagate-device-id-to-function-args"; }
3958
3959 ::llvm::StringRef getDescription() const override { return "Adds device id as arguments to all private function in graph."; }
3960
3961 /// Returns the derived pass name.
3962 static constexpr ::llvm::StringLiteral getPassName() {
3963 return ::llvm::StringLiteral("DTensorPropagateDeviceIdToFunctionArgs");
3964 }
3965 ::llvm::StringRef getName() const override { return "DTensorPropagateDeviceIdToFunctionArgs"; }
3966
3967 /// Support isa/dyn_cast functionality for the derived pass class.
3968 static bool classof(const ::mlir::Pass *pass) {
3969 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
3970 }
3971
3972 /// A clone method to create a copy of this pass.
3973 std::unique_ptr<::mlir::Pass> clonePass() const override {
3974 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
3975 }
3976
3977 /// Return the dialect that must be loaded in the context before this pass.
3978 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
3979
3980 }
3981
3982 /// Explicitly declare the TypeID for this class. We declare an explicit private
3983 /// instantiation because Pass classes should only be visible by the current
3984 /// library.
3985 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorPropagateDeviceIdToFunctionArgsBase<DerivedT>)
3986
3987protected:
3988};
3989
3990template <typename DerivedT>
3991class DTensorReduceScatterLoweringBase : public ::mlir::OperationPass<mlir::ModuleOp> {
3992public:
3993 using Base = DTensorReduceScatterLoweringBase;
3994
3995 DTensorReduceScatterLoweringBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
3996 DTensorReduceScatterLoweringBase(const DTensorReduceScatterLoweringBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
3997
3998 /// Returns the command-line argument attached to this pass.
3999 static constexpr ::llvm::StringLiteral getArgumentName() {
4000 return ::llvm::StringLiteral("dtensor-reduce-scatter-lowering");
4001 }
4002 ::llvm::StringRef getArgument() const override { return "dtensor-reduce-scatter-lowering"; }
4003
4004 ::llvm::StringRef getDescription() const override { return "Converts logical ReduceScatter ops into physical ReduceScatter ops."; }
4005
4006 /// Returns the derived pass name.
4007 static constexpr ::llvm::StringLiteral getPassName() {
4008 return ::llvm::StringLiteral("DTensorReduceScatterLowering");
4009 }
4010 ::llvm::StringRef getName() const override { return "DTensorReduceScatterLowering"; }
4011
4012 /// Support isa/dyn_cast functionality for the derived pass class.
4013 static bool classof(const ::mlir::Pass *pass) {
4014 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
4015 }
4016
4017 /// A clone method to create a copy of this pass.
4018 std::unique_ptr<::mlir::Pass> clonePass() const override {
4019 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
4020 }
4021
4022 /// Return the dialect that must be loaded in the context before this pass.
4023 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
4024
4025 }
4026
4027 /// Explicitly declare the TypeID for this class. We declare an explicit private
4028 /// instantiation because Pass classes should only be visible by the current
4029 /// library.
4030 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorReduceScatterLoweringBase<DerivedT>)
4031
4032protected:
4033};
4034
4035template <typename DerivedT>
4036class DTensorSPMDExpansionBase : public ::mlir::OperationPass<mlir::ModuleOp> {
4037public:
4038 using Base = DTensorSPMDExpansionBase;
4039
4040 DTensorSPMDExpansionBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
4041 DTensorSPMDExpansionBase(const DTensorSPMDExpansionBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
4042
4043 /// Returns the command-line argument attached to this pass.
4044 static constexpr ::llvm::StringLiteral getArgumentName() {
4045 return ::llvm::StringLiteral("dtensor-spmd-expansion");
4046 }
4047 ::llvm::StringRef getArgument() const override { return "dtensor-spmd-expansion"; }
4048
4049 ::llvm::StringRef getDescription() const override { return "Converts ops into SPMD expanded form."; }
4050
4051 /// Returns the derived pass name.
4052 static constexpr ::llvm::StringLiteral getPassName() {
4053 return ::llvm::StringLiteral("DTensorSPMDExpansion");
4054 }
4055 ::llvm::StringRef getName() const override { return "DTensorSPMDExpansion"; }
4056
4057 /// Support isa/dyn_cast functionality for the derived pass class.
4058 static bool classof(const ::mlir::Pass *pass) {
4059 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
4060 }
4061
4062 /// A clone method to create a copy of this pass.
4063 std::unique_ptr<::mlir::Pass> clonePass() const override {
4064 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
4065 }
4066
4067 /// Return the dialect that must be loaded in the context before this pass.
4068 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
4069
4070 }
4071
4072 /// Explicitly declare the TypeID for this class. We declare an explicit private
4073 /// instantiation because Pass classes should only be visible by the current
4074 /// library.
4075 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorSPMDExpansionBase<DerivedT>)
4076
4077protected:
4078};
4079
4080template <typename DerivedT>
4081class DTensorSetDefaultShardingBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
4082public:
4083 using Base = DTensorSetDefaultShardingBase;
4084
4085 DTensorSetDefaultShardingBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
4086 DTensorSetDefaultShardingBase(const DTensorSetDefaultShardingBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {}
4087
4088 /// Returns the command-line argument attached to this pass.
4089 static constexpr ::llvm::StringLiteral getArgumentName() {
4090 return ::llvm::StringLiteral("dtensor-set-default-sharding");
4091 }
4092 ::llvm::StringRef getArgument() const override { return "dtensor-set-default-sharding"; }
4093
4094 ::llvm::StringRef getDescription() const override { return "Sets default sharding of TPU computation inputs/outputs to maximal."; }
4095
4096 /// Returns the derived pass name.
4097 static constexpr ::llvm::StringLiteral getPassName() {
4098 return ::llvm::StringLiteral("DTensorSetDefaultSharding");
4099 }
4100 ::llvm::StringRef getName() const override { return "DTensorSetDefaultSharding"; }
4101
4102 /// Support isa/dyn_cast functionality for the derived pass class.
4103 static bool classof(const ::mlir::Pass *pass) {
4104 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
4105 }
4106
4107 /// A clone method to create a copy of this pass.
4108 std::unique_ptr<::mlir::Pass> clonePass() const override {
4109 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
4110 }
4111
4112 /// Return the dialect that must be loaded in the context before this pass.
4113 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
4114
4115 }
4116
4117 /// Explicitly declare the TypeID for this class. We declare an explicit private
4118 /// instantiation because Pass classes should only be visible by the current
4119 /// library.
4120 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorSetDefaultShardingBase<DerivedT>)
4121
4122protected:
4123};
4124
4125template <typename DerivedT>
4126class DTensorSparseExpansionBase : public ::mlir::OperationPass<mlir::ModuleOp> {
4127public:
4128 using Base = DTensorSparseExpansionBase;
4129
4130 DTensorSparseExpansionBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
4131 DTensorSparseExpansionBase(const DTensorSparseExpansionBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
4132
4133 /// Returns the command-line argument attached to this pass.
4134 static constexpr ::llvm::StringLiteral getArgumentName() {
4135 return ::llvm::StringLiteral("dtensor-sparse-expansion");
4136 }
4137 ::llvm::StringRef getArgument() const override { return "dtensor-sparse-expansion"; }
4138
4139 ::llvm::StringRef getDescription() const override { return "Convert ops that take in SparseTensor input to its corresponding Sparse or Dense ops."; }
4140
4141 /// Returns the derived pass name.
4142 static constexpr ::llvm::StringLiteral getPassName() {
4143 return ::llvm::StringLiteral("DTensorSparseExpansion");
4144 }
4145 ::llvm::StringRef getName() const override { return "DTensorSparseExpansion"; }
4146
4147 /// Support isa/dyn_cast functionality for the derived pass class.
4148 static bool classof(const ::mlir::Pass *pass) {
4149 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
4150 }
4151
4152 /// A clone method to create a copy of this pass.
4153 std::unique_ptr<::mlir::Pass> clonePass() const override {
4154 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
4155 }
4156
4157 /// Return the dialect that must be loaded in the context before this pass.
4158 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
4159
4160 }
4161
4162 /// Explicitly declare the TypeID for this class. We declare an explicit private
4163 /// instantiation because Pass classes should only be visible by the current
4164 /// library.
4165 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorSparseExpansionBase<DerivedT>)
4166
4167protected:
4168};
4169
4170template <typename DerivedT>
4171class DTensorSparseTensorToDenseTensorBase : public ::mlir::OperationPass<mlir::ModuleOp> {
4172public:
4173 using Base = DTensorSparseTensorToDenseTensorBase;
4174
4175 DTensorSparseTensorToDenseTensorBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
4176 DTensorSparseTensorToDenseTensorBase(const DTensorSparseTensorToDenseTensorBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
4177
4178 /// Returns the command-line argument attached to this pass.
4179 static constexpr ::llvm::StringLiteral getArgumentName() {
4180 return ::llvm::StringLiteral("dtensor-sparse-tensor-to-dense-tensor");
4181 }
4182 ::llvm::StringRef getArgument() const override { return "dtensor-sparse-tensor-to-dense-tensor"; }
4183
4184 ::llvm::StringRef getDescription() const override { return "Converts SparseTensor inputs to its component tensors inputs and emits a SparseToDenseOp for every op that consumes a SparseTensor."; }
4185
4186 /// Returns the derived pass name.
4187 static constexpr ::llvm::StringLiteral getPassName() {
4188 return ::llvm::StringLiteral("DTensorSparseTensorToDenseTensor");
4189 }
4190 ::llvm::StringRef getName() const override { return "DTensorSparseTensorToDenseTensor"; }
4191
4192 /// Support isa/dyn_cast functionality for the derived pass class.
4193 static bool classof(const ::mlir::Pass *pass) {
4194 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
4195 }
4196
4197 /// A clone method to create a copy of this pass.
4198 std::unique_ptr<::mlir::Pass> clonePass() const override {
4199 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
4200 }
4201
4202 /// Return the dialect that must be loaded in the context before this pass.
4203 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
4204
4205 }
4206
4207 /// Explicitly declare the TypeID for this class. We declare an explicit private
4208 /// instantiation because Pass classes should only be visible by the current
4209 /// library.
4210 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorSparseTensorToDenseTensorBase<DerivedT>)
4211
4212protected:
4213};
4214
4215template <typename DerivedT>
4216class DTensorTPUIntegrationBase : public ::mlir::OperationPass<mlir::ModuleOp> {
4217public:
4218 using Base = DTensorTPUIntegrationBase;
4219
4220 DTensorTPUIntegrationBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
4221 DTensorTPUIntegrationBase(const DTensorTPUIntegrationBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
4222
4223 /// Returns the command-line argument attached to this pass.
4224 static constexpr ::llvm::StringLiteral getArgumentName() {
4225 return ::llvm::StringLiteral("dtensor-tpu-integration");
4226 }
4227 ::llvm::StringRef getArgument() const override { return "dtensor-tpu-integration"; }
4228
4229 ::llvm::StringRef getDescription() const override { return "Adds TPUReplicateMetadata and converts ops that run on TPU's to a single tf_device.cluster to be compatible with following TF2XLA MLIR passes."; }
4230
4231 /// Returns the derived pass name.
4232 static constexpr ::llvm::StringLiteral getPassName() {
4233 return ::llvm::StringLiteral("DTensorTPUIntegration");
4234 }
4235 ::llvm::StringRef getName() const override { return "DTensorTPUIntegration"; }
4236
4237 /// Support isa/dyn_cast functionality for the derived pass class.
4238 static bool classof(const ::mlir::Pass *pass) {
4239 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
4240 }
4241
4242 /// A clone method to create a copy of this pass.
4243 std::unique_ptr<::mlir::Pass> clonePass() const override {
4244 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
4245 }
4246
4247 /// Return the dialect that must be loaded in the context before this pass.
4248 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
4249
4250 }
4251
4252 /// Explicitly declare the TypeID for this class. We declare an explicit private
4253 /// instantiation because Pass classes should only be visible by the current
4254 /// library.
4255 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorTPUIntegrationBase<DerivedT>)
4256
4257protected:
4258};
4259
4260template <typename DerivedT>
4261class DTensorTpuAddResourceDeviceAttributeBase : public ::mlir::OperationPass<mlir::ModuleOp> {
4262public:
4263 using Base = DTensorTpuAddResourceDeviceAttributeBase;
4264
4265 DTensorTpuAddResourceDeviceAttributeBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
4266 DTensorTpuAddResourceDeviceAttributeBase(const DTensorTpuAddResourceDeviceAttributeBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
4267
4268 /// Returns the command-line argument attached to this pass.
4269 static constexpr ::llvm::StringLiteral getArgumentName() {
4270 return ::llvm::StringLiteral("dtensor-tpu-add-resource-device-attribute");
4271 }
4272 ::llvm::StringRef getArgument() const override { return "dtensor-tpu-add-resource-device-attribute"; }
4273
4274 ::llvm::StringRef getDescription() const override { return "Adds placeholder device attributes to resources accessed by TPU computation to enable buffer aliasing."; }
4275
4276 /// Returns the derived pass name.
4277 static constexpr ::llvm::StringLiteral getPassName() {
4278 return ::llvm::StringLiteral("DTensorTpuAddResourceDeviceAttribute");
4279 }
4280 ::llvm::StringRef getName() const override { return "DTensorTpuAddResourceDeviceAttribute"; }
4281
4282 /// Support isa/dyn_cast functionality for the derived pass class.
4283 static bool classof(const ::mlir::Pass *pass) {
4284 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
4285 }
4286
4287 /// A clone method to create a copy of this pass.
4288 std::unique_ptr<::mlir::Pass> clonePass() const override {
4289 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
4290 }
4291
4292 /// Return the dialect that must be loaded in the context before this pass.
4293 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
4294
4295 }
4296
4297 /// Explicitly declare the TypeID for this class. We declare an explicit private
4298 /// instantiation because Pass classes should only be visible by the current
4299 /// library.
4300 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorTpuAddResourceDeviceAttributeBase<DerivedT>)
4301
4302protected:
4303};
4304
4305template <typename DerivedT>
4306class DTensorUndoMergeConstAcrossMeshBase : public ::mlir::OperationPass<mlir::func::FuncOp> {
4307public:
4308 using Base = DTensorUndoMergeConstAcrossMeshBase;
4309
4310 DTensorUndoMergeConstAcrossMeshBase() : ::mlir::OperationPass<mlir::func::FuncOp>(::mlir::TypeID::get<DerivedT>()) {}
4311 DTensorUndoMergeConstAcrossMeshBase(const DTensorUndoMergeConstAcrossMeshBase &other) : ::mlir::OperationPass<mlir::func::FuncOp>(other) {}
4312
4313 /// Returns the command-line argument attached to this pass.
4314 static constexpr ::llvm::StringLiteral getArgumentName() {
4315 return ::llvm::StringLiteral("dtensor-undo-merge-const-across-mesh");
4316 }
4317 ::llvm::StringRef getArgument() const override { return "dtensor-undo-merge-const-across-mesh"; }
4318
4319 ::llvm::StringRef getDescription() const override { return "Undo constant merging across meshes"; }
4320
4321 /// Returns the derived pass name.
4322 static constexpr ::llvm::StringLiteral getPassName() {
4323 return ::llvm::StringLiteral("DTensorUndoMergeConstAcrossMesh");
4324 }
4325 ::llvm::StringRef getName() const override { return "DTensorUndoMergeConstAcrossMesh"; }
4326
4327 /// Support isa/dyn_cast functionality for the derived pass class.
4328 static bool classof(const ::mlir::Pass *pass) {
4329 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
4330 }
4331
4332 /// A clone method to create a copy of this pass.
4333 std::unique_ptr<::mlir::Pass> clonePass() const override {
4334 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
4335 }
4336
4337 /// Return the dialect that must be loaded in the context before this pass.
4338 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
4339
4340 }
4341
4342 /// Explicitly declare the TypeID for this class. We declare an explicit private
4343 /// instantiation because Pass classes should only be visible by the current
4344 /// library.
4345 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorUndoMergeConstAcrossMeshBase<DerivedT>)
4346
4347protected:
4348};
4349
4350template <typename DerivedT>
4351class DTensorUpdateTPUMetadataBase : public ::mlir::OperationPass<mlir::ModuleOp> {
4352public:
4353 using Base = DTensorUpdateTPUMetadataBase;
4354
4355 DTensorUpdateTPUMetadataBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
4356 DTensorUpdateTPUMetadataBase(const DTensorUpdateTPUMetadataBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
4357
4358 /// Returns the command-line argument attached to this pass.
4359 static constexpr ::llvm::StringLiteral getArgumentName() {
4360 return ::llvm::StringLiteral("dtensor-update-tpu-metadata");
4361 }
4362 ::llvm::StringRef getArgument() const override { return "dtensor-update-tpu-metadata"; }
4363
4364 ::llvm::StringRef getDescription() const override { return "Changes metadata on TPU specific ops such as device placement."; }
4365
4366 /// Returns the derived pass name.
4367 static constexpr ::llvm::StringLiteral getPassName() {
4368 return ::llvm::StringLiteral("DTensorUpdateTPUMetadata");
4369 }
4370 ::llvm::StringRef getName() const override { return "DTensorUpdateTPUMetadata"; }
4371
4372 /// Support isa/dyn_cast functionality for the derived pass class.
4373 static bool classof(const ::mlir::Pass *pass) {
4374 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
4375 }
4376
4377 /// A clone method to create a copy of this pass.
4378 std::unique_ptr<::mlir::Pass> clonePass() const override {
4379 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
4380 }
4381
4382 /// Return the dialect that must be loaded in the context before this pass.
4383 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
4384
4385 }
4386
4387 /// Explicitly declare the TypeID for this class. We declare an explicit private
4388 /// instantiation because Pass classes should only be visible by the current
4389 /// library.
4390 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DTensorUpdateTPUMetadataBase<DerivedT>)
4391
4392protected:
4393};
// End of the GEN_PASS_CLASSES section; its opening directive appears earlier
// in this file. Because the macro is undefined here, a later #include of this
// file skips that section unless GEN_PASS_CLASSES is defined again. This
// assumes the section is opened with `#ifdef GEN_PASS_CLASSES`; confirm that.
#undef GEN_PASS_CLASSES
#endif // GEN_PASS_CLASSES
4396