1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
17#include "tensorflow/dtensor/mlir/expansions/argmax_spmd_expander.h"
18#include "tensorflow/dtensor/mlir/expansions/bias_add_spmd_expander.h"
19#include "tensorflow/dtensor/mlir/expansions/broadcast_to_spmd_expander.h"
20#include "tensorflow/dtensor/mlir/expansions/concat_spmd_expander.h"
21#include "tensorflow/dtensor/mlir/expansions/control_flow_spmd_expander.h"
22#include "tensorflow/dtensor/mlir/expansions/conv_spmd_expander.h"
23#include "tensorflow/dtensor/mlir/expansions/cumsum_spmd_expander.h"
24#include "tensorflow/dtensor/mlir/expansions/dataparallel_spmd_expander.h"
25#include "tensorflow/dtensor/mlir/expansions/disable_copy_on_read_spmd_expander.h"
26#include "tensorflow/dtensor/mlir/expansions/dtensor_op_spmd_expander.h"
27#include "tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.h"
28#include "tensorflow/dtensor/mlir/expansions/elementwise_spmd_expander.h"
29#include "tensorflow/dtensor/mlir/expansions/expanddims_spmd_expander.h"
30#include "tensorflow/dtensor/mlir/expansions/fill_spmd_expander.h"
31#include "tensorflow/dtensor/mlir/expansions/gather_spmd_expander.h"
32#include "tensorflow/dtensor/mlir/expansions/identity_n_spmd_expander.h"
33#include "tensorflow/dtensor/mlir/expansions/in_top_k_spmd_expander.h"
34#include "tensorflow/dtensor/mlir/expansions/io_op_spmd_expander.h"
35#include "tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.h"
36#include "tensorflow/dtensor/mlir/expansions/meta_spmd_expander.h"
37#include "tensorflow/dtensor/mlir/expansions/nullary_spmd_expander.h"
38#include "tensorflow/dtensor/mlir/expansions/qr_spmd_expander.h"
39#include "tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.h"
40#include "tensorflow/dtensor/mlir/expansions/range_spmd_expander.h"
41#include "tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.h"
42#include "tensorflow/dtensor/mlir/expansions/replicated_spmd_expander.h"
43#include "tensorflow/dtensor/mlir/expansions/resource_spmd_expander.h"
44#include "tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.h"
45#include "tensorflow/dtensor/mlir/expansions/scatter_spmd_expander.h"
46#include "tensorflow/dtensor/mlir/expansions/segmentation_spmd_expander.h"
47#include "tensorflow/dtensor/mlir/expansions/slice_spmd_expander.h"
48#include "tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.h"
49#include "tensorflow/dtensor/mlir/expansions/sparse_to_dense_spmd_expander.h"
50#include "tensorflow/dtensor/mlir/expansions/split_spmd_expander.h"
51#include "tensorflow/dtensor/mlir/expansions/squeeze_spmd_expander.h"
52#include "tensorflow/dtensor/mlir/expansions/tensorlist_getitem_spmd_expander.h"
53#include "tensorflow/dtensor/mlir/expansions/tensorlist_reserve_spmd_expander.h"
54#include "tensorflow/dtensor/mlir/expansions/tensorlist_setitem_spmd_expander.h"
55#include "tensorflow/dtensor/mlir/expansions/top_k_spmd_expander.h"
56#include "tensorflow/dtensor/mlir/expansions/trivial_spmd_expander.h"
57#include "tensorflow/dtensor/mlir/spmd_expander.h"
58
59namespace tensorflow {
60namespace dtensor {
61
// Every REGISTER_SPMD line names an op, its TF dialect op class, and the SPMD
// expander class chosen for it. The macro is defined outside this file
// (presumably in spmd_expander.h, included above -- confirm there).
62// Nullary: Const is the only op paired with NullarySPMDExpander in this file.
63REGISTER_SPMD(Const, TF::ConstOp, NullarySPMDExpander);
64
65// Unary: math, activation and activation-gradient ops. Every entry in this
// group is paired with ElementwiseSPMDExpander.
66REGISTER_SPMD(Abs, TF::AbsOp, ElementwiseSPMDExpander);
67REGISTER_SPMD(Cast, TF::CastOp, ElementwiseSPMDExpander);
68REGISTER_SPMD(Identity, TF::IdentityOp, ElementwiseSPMDExpander);
69REGISTER_SPMD(Neg, TF::NegOp, ElementwiseSPMDExpander);
70REGISTER_SPMD(ZerosLike, TF::ZerosLikeOp, ElementwiseSPMDExpander);
71REGISTER_SPMD(Exp, TF::ExpOp, ElementwiseSPMDExpander);
72REGISTER_SPMD(Sqrt, TF::SqrtOp, ElementwiseSPMDExpander);
73REGISTER_SPMD(Rsqrt, TF::RsqrtOp, ElementwiseSPMDExpander);
74REGISTER_SPMD(Log, TF::LogOp, ElementwiseSPMDExpander);
75REGISTER_SPMD(StopGradient, TF::StopGradientOp, ElementwiseSPMDExpander);
76REGISTER_SPMD(Reciprocal, TF::ReciprocalOp, ElementwiseSPMDExpander);
77REGISTER_SPMD(Square, TF::SquareOp, ElementwiseSPMDExpander);
78REGISTER_SPMD(Erf, TF::ErfOp, ElementwiseSPMDExpander);
79REGISTER_SPMD(Tanh, TF::TanhOp, ElementwiseSPMDExpander);
80REGISTER_SPMD(TanhGrad, TF::TanhGradOp, ElementwiseSPMDExpander);
81REGISTER_SPMD(Relu, TF::ReluOp, ElementwiseSPMDExpander);
82REGISTER_SPMD(ReluGrad, TF::ReluGradOp, ElementwiseSPMDExpander);
83REGISTER_SPMD(Sigmoid, TF::SigmoidOp, ElementwiseSPMDExpander);
84REGISTER_SPMD(SigmoidGrad, TF::SigmoidGradOp, ElementwiseSPMDExpander);
85REGISTER_SPMD(IsFinite, TF::IsFiniteOp, ElementwiseSPMDExpander);
86
87// Elementwise: arithmetic, comparison, logical, bitwise, trigonometric and
// special-function ops. Every entry in this group is paired with
// ElementwiseSPMDExpander; the order carries no grouping beyond this heading.
88REGISTER_SPMD(Add, TF::AddOp, ElementwiseSPMDExpander);
89REGISTER_SPMD(AddV2, TF::AddV2Op, ElementwiseSPMDExpander);
90REGISTER_SPMD(AddN, TF::AddNOp, ElementwiseSPMDExpander);
91REGISTER_SPMD(RealDiv, TF::RealDivOp, ElementwiseSPMDExpander);
92REGISTER_SPMD(Div, TF::DivOp, ElementwiseSPMDExpander);
93REGISTER_SPMD(DivNoNan, TF::DivNoNanOp, ElementwiseSPMDExpander);
94REGISTER_SPMD(Equal, TF::EqualOp, ElementwiseSPMDExpander);
95REGISTER_SPMD(FloorDiv, TF::FloorDivOp, ElementwiseSPMDExpander);
96REGISTER_SPMD(FloorMod, TF::FloorModOp, ElementwiseSPMDExpander);
97REGISTER_SPMD(NotEqual, TF::NotEqualOp, ElementwiseSPMDExpander);
98REGISTER_SPMD(Less, TF::LessOp, ElementwiseSPMDExpander);
99REGISTER_SPMD(LessEqual, TF::LessEqualOp, ElementwiseSPMDExpander);
100REGISTER_SPMD(LogicalAnd, TF::LogicalAndOp, ElementwiseSPMDExpander);
101REGISTER_SPMD(LogicalNot, TF::LogicalNotOp, ElementwiseSPMDExpander);
102REGISTER_SPMD(Maximum, TF::MaximumOp, ElementwiseSPMDExpander);
103REGISTER_SPMD(Minimum, TF::MinimumOp, ElementwiseSPMDExpander);
104REGISTER_SPMD(Mul, TF::MulOp, ElementwiseSPMDExpander);
105REGISTER_SPMD(Select, TF::SelectOp, ElementwiseSPMDExpander);
106REGISTER_SPMD(SelectV2, TF::SelectV2Op, ElementwiseSPMDExpander);
107REGISTER_SPMD(Sub, TF::SubOp, ElementwiseSPMDExpander);
108REGISTER_SPMD(SquaredDifference, TF::SquaredDifferenceOp,
109 ElementwiseSPMDExpander);
110REGISTER_SPMD(Greater, TF::GreaterOp, ElementwiseSPMDExpander);
111REGISTER_SPMD(GreaterEqual, TF::GreaterEqualOp, ElementwiseSPMDExpander);
112REGISTER_SPMD(RsqrtGrad, TF::RsqrtGradOp, ElementwiseSPMDExpander);
113REGISTER_SPMD(SqrtGrad, TF::SqrtGradOp, ElementwiseSPMDExpander);
114REGISTER_SPMD(Pow, TF::PowOp, ElementwiseSPMDExpander);
115REGISTER_SPMD(BitwiseAnd, TF::BitwiseAndOp, ElementwiseSPMDExpander);
116REGISTER_SPMD(BitwiseOr, TF::BitwiseOrOp, ElementwiseSPMDExpander);
117REGISTER_SPMD(BitwiseXor, TF::BitwiseXorOp, ElementwiseSPMDExpander);
118REGISTER_SPMD(LeftShift, TF::LeftShiftOp, ElementwiseSPMDExpander);
119REGISTER_SPMD(RightShift, TF::RightShiftOp, ElementwiseSPMDExpander);
120REGISTER_SPMD(LogicalOr, TF::LogicalOrOp, ElementwiseSPMDExpander);
121REGISTER_SPMD(Cos, TF::CosOp, ElementwiseSPMDExpander);
122REGISTER_SPMD(Acos, TF::AcosOp, ElementwiseSPMDExpander);
123REGISTER_SPMD(Acosh, TF::AcoshOp, ElementwiseSPMDExpander);
124REGISTER_SPMD(Angle, TF::AngleOp, ElementwiseSPMDExpander);
125REGISTER_SPMD(Asin, TF::AsinOp, ElementwiseSPMDExpander);
126REGISTER_SPMD(Asinh, TF::AsinhOp, ElementwiseSPMDExpander);
127REGISTER_SPMD(Atan, TF::AtanOp, ElementwiseSPMDExpander);
128REGISTER_SPMD(Atan2, TF::Atan2Op, ElementwiseSPMDExpander);
129REGISTER_SPMD(Atanh, TF::AtanhOp, ElementwiseSPMDExpander);
130REGISTER_SPMD(BesselI0e, TF::BesselI0eOp, ElementwiseSPMDExpander);
131REGISTER_SPMD(BesselI1e, TF::BesselI1eOp, ElementwiseSPMDExpander);
132REGISTER_SPMD(Betainc, TF::BetaincOp, ElementwiseSPMDExpander);
133REGISTER_SPMD(Bitcast, TF::BitcastOp, ElementwiseSPMDExpander);
134REGISTER_SPMD(Ceil, TF::CeilOp, ElementwiseSPMDExpander);
135REGISTER_SPMD(CheckNumerics, TF::CheckNumericsOp, ElementwiseSPMDExpander);
136REGISTER_SPMD(ClipByValue, TF::ClipByValueOp, ElementwiseSPMDExpander);
137REGISTER_SPMD(Conj, TF::ConjOp, ElementwiseSPMDExpander);
138REGISTER_SPMD(Cosh, TF::CoshOp, ElementwiseSPMDExpander);
139REGISTER_SPMD(Complex, TF::ComplexOp, ElementwiseSPMDExpander);
140REGISTER_SPMD(ComplexAbs, TF::ComplexAbsOp, ElementwiseSPMDExpander);
141REGISTER_SPMD(Digamma, TF::DigammaOp, ElementwiseSPMDExpander);
142
143// TODO(b/193924452): Add the following Ops once unit tests are there.
144//
// NOTE(review): most registrations in this section are already active; only
// the commented-out ones (HSVToRGB, InplaceAdd, InplaceUpdate, Snapshot,
// SoftplusGrad, SoftsignGrad, TridiagonalSolve) remain disabled. Presumably
// those are the ops this TODO still applies to -- confirm before enabling.
// All active entries here are paired with ElementwiseSPMDExpander.
145REGISTER_SPMD(Elu, TF::EluOp, ElementwiseSPMDExpander);
146REGISTER_SPMD(EluGrad, TF::EluGradOp, ElementwiseSPMDExpander);
147REGISTER_SPMD(Erfc, TF::ErfcOp, ElementwiseSPMDExpander);
148REGISTER_SPMD(Erfinv, TF::ErfinvOp, ElementwiseSPMDExpander);
149REGISTER_SPMD(Expm1, TF::Expm1Op, ElementwiseSPMDExpander);
150REGISTER_SPMD(Floor, TF::FloorOp, ElementwiseSPMDExpander);
151// REGISTER_SPMD(HSVToRGB, TF::HSVToRGBOp, ElementwiseSPMDExpander);
152REGISTER_SPMD(Igamma, TF::IgammaOp, ElementwiseSPMDExpander);
153REGISTER_SPMD(Igammac, TF::IgammacOp, ElementwiseSPMDExpander);
154REGISTER_SPMD(IgammaGradA, TF::IgammaGradAOp, ElementwiseSPMDExpander);
155REGISTER_SPMD(Imag, TF::ImagOp, ElementwiseSPMDExpander);
156// REGISTER_SPMD(InplaceAdd, TF::InplaceAddOp, ElementwiseSPMDExpander);
157// REGISTER_SPMD(InplaceUpdate, TF::InplaceUpdateOp, ElementwiseSPMDExpander);
158REGISTER_SPMD(Inv, TF::InvOp, ElementwiseSPMDExpander);
159REGISTER_SPMD(Invert, TF::InvertOp, ElementwiseSPMDExpander);
160REGISTER_SPMD(IsInf, TF::IsInfOp, ElementwiseSPMDExpander);
161REGISTER_SPMD(IsNan, TF::IsNanOp, ElementwiseSPMDExpander);
162REGISTER_SPMD(LeakyRelu, TF::LeakyReluOp, ElementwiseSPMDExpander);
163REGISTER_SPMD(LeakyReluGrad, TF::LeakyReluGradOp, ElementwiseSPMDExpander);
164REGISTER_SPMD(Lgamma, TF::LgammaOp, ElementwiseSPMDExpander);
165REGISTER_SPMD(Log1p, TF::Log1pOp, ElementwiseSPMDExpander);
166REGISTER_SPMD(MulNoNan, TF::MulNoNanOp, ElementwiseSPMDExpander);
167REGISTER_SPMD(Ndtri, TF::NdtriOp, ElementwiseSPMDExpander);
168REGISTER_SPMD(NextAfter, TF::NextAfterOp, ElementwiseSPMDExpander);
169REGISTER_SPMD(Polygamma, TF::PolygammaOp, ElementwiseSPMDExpander);
170REGISTER_SPMD(PopulationCount, TF::PopulationCountOp, ElementwiseSPMDExpander);
171REGISTER_SPMD(PreventGradient, TF::PreventGradientOp, ElementwiseSPMDExpander);
172REGISTER_SPMD(Real, TF::RealOp, ElementwiseSPMDExpander);
173REGISTER_SPMD(ReciprocalGrad, TF::ReciprocalGradOp, ElementwiseSPMDExpander);
174REGISTER_SPMD(Relu6, TF::Relu6Op, ElementwiseSPMDExpander);
175REGISTER_SPMD(Relu6Grad, TF::Relu6GradOp, ElementwiseSPMDExpander);
176REGISTER_SPMD(Rint, TF::RintOp, ElementwiseSPMDExpander);
177REGISTER_SPMD(Round, TF::RoundOp, ElementwiseSPMDExpander);
178REGISTER_SPMD(Selu, TF::SeluOp, ElementwiseSPMDExpander);
179REGISTER_SPMD(SeluGrad, TF::SeluGradOp, ElementwiseSPMDExpander);
180REGISTER_SPMD(Sign, TF::SignOp, ElementwiseSPMDExpander);
181REGISTER_SPMD(Sin, TF::SinOp, ElementwiseSPMDExpander);
182REGISTER_SPMD(Sinh, TF::SinhOp, ElementwiseSPMDExpander);
183// REGISTER_SPMD(Snapshot, TF::SnapshotOp, ElementwiseSPMDExpander);
184REGISTER_SPMD(Softplus, TF::SoftplusOp, ElementwiseSPMDExpander);
185// REGISTER_SPMD(SoftplusGrad, TF::SoftplusGradOp, ElementwiseSPMDExpander);
186REGISTER_SPMD(Softsign, TF::SoftsignOp, ElementwiseSPMDExpander);
187// REGISTER_SPMD(SoftsignGrad, TF::SoftsignGradOp, ElementwiseSPMDExpander);
188REGISTER_SPMD(Tan, TF::TanOp, ElementwiseSPMDExpander);
189// REGISTER_SPMD(TridiagonalSolve, TF::TridiagonalSolveOp,
190// ElementwiseSPMDExpander);
191REGISTER_SPMD(TruncateDiv, TF::TruncateDivOp, ElementwiseSPMDExpander);
192REGISTER_SPMD(TruncateMod, TF::TruncateModOp, ElementwiseSPMDExpander);
193REGISTER_SPMD(Xdivy, TF::XdivyOp, ElementwiseSPMDExpander);
194REGISTER_SPMD(Xlog1py, TF::Xlog1pyOp, ElementwiseSPMDExpander);
195REGISTER_SPMD(Xlogy, TF::XlogyOp, ElementwiseSPMDExpander);
196REGISTER_SPMD(Zeta, TF::ZetaOp, ElementwiseSPMDExpander);
197
198// IdentityN: paired with its own IdentityNSPMDExpander, whereas plain Identity
// is registered with ElementwiseSPMDExpander earlier in this file.
199// TODO(hongjunchoi): Make ElementwiseSPMDExpander support IdentityN.
200REGISTER_SPMD(IdentityN, TF::IdentityNOp, IdentityNSPMDExpander);
201
202// Range: paired with a dedicated RangeSPMDExpander.
203REGISTER_SPMD(Range, TF::RangeOp, RangeSPMDExpander);
204
205// Reductions: every entry here is paired with ReduceSPMDExpander. BiasAddGrad,
// registered later in this file under "BiasAdd", uses the same expander.
206REGISTER_SPMD(All, TF::AllOp, ReduceSPMDExpander);
207REGISTER_SPMD(Any, TF::AnyOp, ReduceSPMDExpander);
208REGISTER_SPMD(Mean, TF::MeanOp, ReduceSPMDExpander);
209REGISTER_SPMD(Max, TF::MaxOp, ReduceSPMDExpander);
210REGISTER_SPMD(Min, TF::MinOp, ReduceSPMDExpander);
211REGISTER_SPMD(Prod, TF::ProdOp, ReduceSPMDExpander);
212REGISTER_SPMD(Sum, TF::SumOp, ReduceSPMDExpander);
213REGISTER_SPMD(L2Loss, TF::L2LossOp, ReduceSPMDExpander);
214
215// Convolution: 2-D/3-D convolutions, their backprop ops, and MaxPool /
// MaxPoolGrad are all paired with ConvSPMDExpander. The remaining pooling ops
// (AvgPool, AvgPool3D, MaxPool3D and their grads) are registered later in this
// file under "Data Parallel" with DataparallelSPMDExpander instead.
216REGISTER_SPMD(Conv2D, TF::Conv2DOp, ConvSPMDExpander);
217REGISTER_SPMD(Conv2DBackpropFilter, TF::Conv2DBackpropFilterOp,
218 ConvSPMDExpander);
219REGISTER_SPMD(Conv2DBackpropInput, TF::Conv2DBackpropInputOp, ConvSPMDExpander);
220REGISTER_SPMD(Conv3D, TF::Conv3DOp, ConvSPMDExpander);
221REGISTER_SPMD(Conv3DBackpropFilterV2, TF::Conv3DBackpropFilterV2Op,
222 ConvSPMDExpander);
223REGISTER_SPMD(Conv3DBackpropInputV2, TF::Conv3DBackpropInputV2Op,
224 ConvSPMDExpander);
225REGISTER_SPMD(MaxPool, TF::MaxPoolOp, ConvSPMDExpander);
226REGISTER_SPMD(MaxPoolGrad, TF::MaxPoolGradOp, ConvSPMDExpander);
227
228// Metadata: Rank, Shape and ShapeN are paired with ShapeSPMDExpander;
// BroadcastGradientArgs is paired with MetadataSPMDExpander.
229REGISTER_SPMD(Rank, TF::RankOp, ShapeSPMDExpander);
230REGISTER_SPMD(Shape, TF::ShapeOp, ShapeSPMDExpander);
231REGISTER_SPMD(ShapeN, TF::ShapeNOp, ShapeSPMDExpander);
232
233REGISTER_SPMD(BroadcastGradientArgs, TF::BroadcastGradientArgsOp,
234 MetadataSPMDExpander);
235
236// Resource ops: variable handle creation, reads, assignments and destruction
// are all paired with ResourceSPMDExpander. The ResourceApply* optimizer ops
// are registered separately, further down, with NoOpSPMDExpander.
237REGISTER_SPMD(AssignVariable, TF::AssignVariableOp, ResourceSPMDExpander);
238REGISTER_SPMD(AssignAddVariable, TF::AssignAddVariableOp, ResourceSPMDExpander);
239REGISTER_SPMD(AssignSubVariable, TF::AssignSubVariableOp, ResourceSPMDExpander);
240REGISTER_SPMD(ReadVariable, TF::ReadVariableOp, ResourceSPMDExpander);
241REGISTER_SPMD(VarHandle, TF::VarHandleOp, ResourceSPMDExpander);
242REGISTER_SPMD(VarIsInitialized, TF::VarIsInitializedOp, ResourceSPMDExpander);
243REGISTER_SPMD(DestroyResource, TF::DestroyResourceOp, ResourceSPMDExpander);
244
245// Einsum: paired with a dedicated EinsumSPMDExpander.
246REGISTER_SPMD(Einsum, TF::EinsumOp, EinsumSPMDExpander);
247
248// Matrix multiplication: batched and plain MatMul share MatMulSPMDExpander.
249REGISTER_SPMD(BatchMatMulV2, TF::BatchMatMulV2Op, MatMulSPMDExpander);
250REGISTER_SPMD(MatMul, TF::MatMulOp, MatMulSPMDExpander);
251
252// Stack/unstack (pack/unpack): each direction has its own expander class.
253REGISTER_SPMD(Pack, TF::PackOp, PackSPMDExpander);
254REGISTER_SPMD(Unpack, TF::UnpackOp, UnpackSPMDExpander);
255
256// Reshape, Transpose and InvertPermutation.
// InvertPermutation passes one extra argument after the expander class,
// labelled relayout_when_sharded. NOTE(review): presumably REGISTER_SPMD
// forwards trailing arguments to the expander's constructor -- confirm against
// the macro definition and replicated_spmd_expander.h.
257REGISTER_SPMD(Reshape, TF::ReshapeOp, ReshapeSPMDExpander);
258REGISTER_SPMD(Transpose, TF::TransposeOp, TransposeSPMDExpander);
259REGISTER_SPMD(InvertPermutation, TF::InvertPermutationOp,
260 ReplicatedOpSPMDExpander,
261 /*relayout_when_sharded=*/true);
262
263// Pad: both Pad versions share PadSPMDExpander.
264REGISTER_SPMD(Pad, TF::PadOp, PadSPMDExpander);
265REGISTER_SPMD(PadV2, TF::PadV2Op, PadSPMDExpander);
266
267// Scatter/Gather: GatherV2 and GatherNd have separate expanders; both
// TensorScatter ops share TensorScatterOpSPMDExpander.
268REGISTER_SPMD(GatherV2, TF::GatherV2Op, GatherV2SPMDExpander);
269REGISTER_SPMD(GatherNd, TF::GatherNdOp, GatherNdSPMDExpander);
270REGISTER_SPMD(TensorScatterUpdate, TF::TensorScatterUpdateOp,
271 TensorScatterOpSPMDExpander);
272REGISTER_SPMD(TensorScatterAdd, TF::TensorScatterAddOp,
273 TensorScatterOpSPMDExpander);
274
275// ArgMax/ArgMin
// NOTE(review): the heading mentions ArgMin, but only ArgMax is registered in
// the visible part of this file.
276REGISTER_SPMD(ArgMax, TF::ArgMaxOp, ArgMaxSPMDExpander);
277
278// Slice: each slicing op (including the strided-slice update and gradient
// ops) is paired with its own expander class.
279REGISTER_SPMD(Slice, TF::SliceOp, SliceSPMDExpander);
280REGISTER_SPMD(StridedSlice, TF::StridedSliceOp, StridedSliceSPMDExpander);
281REGISTER_SPMD(TensorStridedSliceUpdate, TF::TensorStridedSliceUpdateOp,
282 TensorStridedSliceUpdateSPMDExpander);
283REGISTER_SPMD(StridedSliceGrad, TF::StridedSliceGradOp,
284 StridedSliceGradSPMDExpander);
285
286// Split: Split and SplitV use distinct expander classes.
287REGISTER_SPMD(Split, TF::SplitOp, SplitSPMDExpander);
288REGISTER_SPMD(SplitV, TF::SplitVOp, SplitVSPMDExpander);
289
290// Squeeze
291REGISTER_SPMD(Squeeze, TF::SqueezeOp, SqueezeSPMDExpander);
292
293// Concat: both Concat versions share ConcatSPMDExpander.
294REGISTER_SPMD(Concat, TF::ConcatOp, ConcatSPMDExpander);
295REGISTER_SPMD(ConcatV2, TF::ConcatV2Op, ConcatSPMDExpander);
296
297// Softmax Loss ops: the dense and sparse cross-entropy ops share
// SoftmaxLossOpSPMDExpander.
298REGISTER_SPMD(SoftmaxCrossEntropyWithLogits,
299 TF::SoftmaxCrossEntropyWithLogitsOp, SoftmaxLossOpSPMDExpander);
300REGISTER_SPMD(SparseSoftmaxCrossEntropyWithLogits,
301 TF::SparseSoftmaxCrossEntropyWithLogitsOp,
302 SoftmaxLossOpSPMDExpander);
303
304// Softmax ops: Softmax and LogSoftmax share SoftmaxOpSPMDExpander, a different
// class from the loss expander above.
305REGISTER_SPMD(Softmax, TF::SoftmaxOp, SoftmaxOpSPMDExpander);
306REGISTER_SPMD(LogSoftmax, TF::LogSoftmaxOp, SoftmaxOpSPMDExpander);
307
// The stateless random generators below are paired with RandomOpSPMDExpander.
// The LINT.IfChange / LINT.ThenChange markers tie the first four registrations
// to small_constant_optimization.cc; edits inside the markers are expected to
// be mirrored there.
308// Random ops
309// LINT.IfChange
310REGISTER_SPMD(StatelessRandomUniform, TF::StatelessRandomUniformOp,
311 RandomOpSPMDExpander);
312REGISTER_SPMD(StatelessRandomUniformFullInt,
313 TF::StatelessRandomUniformFullIntOp, RandomOpSPMDExpander);
314REGISTER_SPMD(StatelessRandomNormal, TF::StatelessRandomNormalOp,
315 RandomOpSPMDExpander);
316REGISTER_SPMD(StatelessTruncatedNormal, TF::StatelessTruncatedNormalOp,
317 RandomOpSPMDExpander);
318// LINT.ThenChange(//tensorflow/dtensor/cc/small_constant_optimization.cc)
319// Random V2 ops
// StatelessRandomGetKeyCounter and RngReadAndSkip are paired with
// ReplicatedOpSPMDExpander (no extra arguments); the V2 generators use
// RandomOpSPMDExpander like their V1 counterparts.
320REGISTER_SPMD(StatelessRandomGetKeyCounter, TF::StatelessRandomGetKeyCounterOp,
321 ReplicatedOpSPMDExpander);
322REGISTER_SPMD(RngReadAndSkip, TF::RngReadAndSkipOp, ReplicatedOpSPMDExpander);
323REGISTER_SPMD(StatelessRandomNormalV2, TF::StatelessRandomNormalV2Op,
324 RandomOpSPMDExpander);
325REGISTER_SPMD(StatelessRandomUniformV2, TF::StatelessRandomUniformV2Op,
326 RandomOpSPMDExpander);
327REGISTER_SPMD(StatelessRandomUniformFullIntV2,
328 TF::StatelessRandomUniformFullIntV2Op, RandomOpSPMDExpander);
329REGISTER_SPMD(StatelessRandomUniformIntV2, TF::StatelessRandomUniformIntV2Op,
330 RandomOpSPMDExpander);
331REGISTER_SPMD(StatelessTruncatedNormalV2, TF::StatelessTruncatedNormalV2Op,
332 RandomOpSPMDExpander);
333
334// Input-agnostic ops
335REGISTER_SPMD(Fill, TF::FillOp, FillSPMDExpander);
336
337// Tile
338REGISTER_SPMD(Tile, TF::TileOp, TileSPMDExpander);
339
340// Expanding ResourceApply ops is a no-op because they are always element-wise.
341// ResourceApply ops also produce no output values, so inferring layouts from
342// operands or consumers is trivial (a no-op).
343// Resource apply ops
344REGISTER_SPMD(ResourceApplyAdagradV2, TF::ResourceApplyAdagradV2Op,
345 NoOpSPMDExpander);
346REGISTER_SPMD(ResourceApplyAdam, TF::ResourceApplyAdamOp, NoOpSPMDExpander);
347REGISTER_SPMD(ResourceApplyGradientDescent, TF::ResourceApplyGradientDescentOp,
348 NoOpSPMDExpander);
349REGISTER_SPMD(ResourceApplyCenteredRMSProp, TF::ResourceApplyCenteredRMSPropOp,
350 NoOpSPMDExpander);
351REGISTER_SPMD(ResourceApplyKerasMomentum, TF::ResourceApplyKerasMomentumOp,
352 NoOpSPMDExpander);
353REGISTER_SPMD(ResourceApplyMomentum, TF::ResourceApplyMomentumOp,
354 NoOpSPMDExpander);
355
356// AssertOp: also paired with NoOpSPMDExpander.
357REGISTER_SPMD(Assert, TF::AssertOp, NoOpSPMDExpander);
358
359// Terminator ops: Return is the only registration here whose op class comes
// from the tf_device namespace rather than TF.
360REGISTER_SPMD(Return, tf_device::ReturnOp, TerminatorSPMDExpander);
361
// Four ops that each have a dedicated expander class.
362// OneHot
363REGISTER_SPMD(OneHot, TF::OneHotOp, OneHotSPMDExpander);
364// ExpandDimsOp (note the expander class is named ExpandDimsExpander, without
// the "SPMD" infix used by most other expanders)
365REGISTER_SPMD(ExpandDims, TF::ExpandDimsOp, ExpandDimsExpander);
366// UnsortedSegmentSumOp
367REGISTER_SPMD(UnsortedSegmentSum, TF::UnsortedSegmentSumOp,
368 UnsortedSegmentSumSPMDExpander);
369// BroadcastToOp
370REGISTER_SPMD(BroadcastTo, TF::BroadcastToOp, BroadcastToSPMDExpander);
371
372// Save/Restore ops: checkpoint save, merge and restore ops share
// SaveRestoreSPMDExpander; DTensorShardedPrefix uses its own
// DTensorShardPrefixSPMDExpander.
373REGISTER_SPMD(SaveV2, TF::SaveV2Op, SaveRestoreSPMDExpander);
374REGISTER_SPMD(MergeV2Checkpoints, TF::MergeV2CheckpointsOp,
375 SaveRestoreSPMDExpander);
376REGISTER_SPMD(RestoreV2, TF::RestoreV2Op, SaveRestoreSPMDExpander);
377REGISTER_SPMD(DTensorRestoreV2, TF::DTensorRestoreV2Op,
378 SaveRestoreSPMDExpander);
379REGISTER_SPMD(DTensorShardedPrefix, TF::DTensorShardedPrefixOp,
380 DTensorShardPrefixSPMDExpander);
381
382// DTensor Virtual ops. Note that the send/recv op classes are spelled
// TF::DTensorSend and TF::DTensorRecv, without the "Op" suffix used elsewhere.
383REGISTER_SPMD(Relayout, TF::RelayoutOp, RelayoutSPMDExpander);
384REGISTER_SPMD(DTensorSend, TF::DTensorSend, DTensorSendSPMDExpander);
385REGISTER_SPMD(DTensorRecv, TF::DTensorRecv, DTensorRecvSPMDExpander);
386
387// TopKV2
388REGISTER_SPMD(TopKV2, TF::TopKV2Op, TopKSPMDExpander);
389// InTopKV2
390REGISTER_SPMD(InTopKV2, TF::InTopKV2Op, InTopKSPMDExpander);
391
392// Control flow: the region-based while/if ops each have their own expander.
393REGISTER_SPMD(WhileRegion, TF::WhileRegionOp, WhileRegionSPMDExpander);
394REGISTER_SPMD(IfRegion, TF::IfRegionOp, IfRegionSPMDExpander);
395
396// BiasAdd: the forward op uses BiasAddExpander, while its gradient is paired
// with ReduceSPMDExpander, the same class used for the reduction ops.
397REGISTER_SPMD(BiasAdd, TF::BiasAddOp, BiasAddExpander);
398REGISTER_SPMD(BiasAddGrad, TF::BiasAddGradOp, ReduceSPMDExpander);
399
400// QR
401REGISTER_SPMD(Qr, TF::QrOp, QRSPMDExpander);
402
403// Data Parallel
// Each DataparallelSPMDExpander registration passes two
// llvm::DenseMap<int, int> arguments after the expander class.
// NOTE(review): presumably the first map is keyed by operand index and the
// second by result index, and the mapped value is that tensor's number of
// non-batch dimensions (3 for the pooling/image ops here, 4 for the *3D
// pooling ops) -- confirm against dataparallel_spmd_expander.h.
404REGISTER_SPMD(AvgPool, TF::AvgPoolOp, DataparallelSPMDExpander,
405 llvm::DenseMap<int, int>{{0, 3}},
406 llvm::DenseMap<int, int>{{0, 3}});
407REGISTER_SPMD(AvgPool3D, TF::AvgPool3DOp, DataparallelSPMDExpander,
408 llvm::DenseMap<int, int>{{0, 4}},
409 llvm::DenseMap<int, int>{{0, 4}});
410REGISTER_SPMD(MaxPool3D, TF::MaxPool3DOp, DataparallelSPMDExpander,
411 llvm::DenseMap<int, int>{{0, 4}},
412 llvm::DenseMap<int, int>{{0, 4}});
413REGISTER_SPMD(DepthwiseConv2dNative, TF::DepthwiseConv2dNativeOp,
414 DataparallelSPMDExpander, llvm::DenseMap<int, int>{{0, 3}},
415 llvm::DenseMap<int, int>{{0, 3}});
416REGISTER_SPMD(ResizeBilinear, TF::ResizeBilinearOp, DataparallelSPMDExpander,
417 llvm::DenseMap<int, int>{{0, 3}},
418 llvm::DenseMap<int, int>{{0, 3}});
419REGISTER_SPMD(ResizeNearestNeighbor, TF::ResizeNearestNeighborOp,
420 DataparallelSPMDExpander, llvm::DenseMap<int, int>{{0, 3}},
421 llvm::DenseMap<int, int>{{0, 3}});
422REGISTER_SPMD(AdjustContrastv2, TF::AdjustContrastv2Op,
423 DataparallelSPMDExpander, llvm::DenseMap<int, int>{{0, 3}},
424 llvm::DenseMap<int, int>{{0, 3}});
425REGISTER_SPMD(AdjustSaturation, TF::AdjustSaturationOp,
426 DataparallelSPMDExpander, llvm::DenseMap<int, int>{{0, 3}},
427 llvm::DenseMap<int, int>{{0, 3}});
// FFT family, handled by the generic DataparallelSPMDExpander.
// The two llvm::DenseMap<int, int> arguments describe the batchable operand
// and the batchable result (index 0 in both cases). Going by the other
// data-parallel registrations in this file (pooling/image ops use 3, the *3D
// pooling ops use 4, Cholesky uses 2), the mapped value is the number of
// trailing non-batch dimensions of that tensor.
// NOTE(review): confirm this reading against dataparallel_spmd_expander.h.
// The 1-D transforms work on the innermost dimension, the 2-D transforms on
// the innermost two and the 3-D transforms on the innermost three, so they
// declare 1, 2 and 3 non-batch dimensions respectively. Declaring 1 for the
// 2-D/3-D variants would let a dimension that is being transformed be treated
// as a batch dimension.
428REGISTER_SPMD(FFT, TF::FFTOp, DataparallelSPMDExpander,
429 llvm::DenseMap<int, int>{{0, 1}},
430 llvm::DenseMap<int, int>{{0, 1}});
431REGISTER_SPMD(FFT2D, TF::FFT2DOp, DataparallelSPMDExpander,
432 llvm::DenseMap<int, int>{{0, 2}},
433 llvm::DenseMap<int, int>{{0, 2}});
434REGISTER_SPMD(FFT3D, TF::FFT3DOp, DataparallelSPMDExpander,
435 llvm::DenseMap<int, int>{{0, 3}},
436 llvm::DenseMap<int, int>{{0, 3}});
437REGISTER_SPMD(IFFT, TF::IFFTOp, DataparallelSPMDExpander,
438 llvm::DenseMap<int, int>{{0, 1}},
439 llvm::DenseMap<int, int>{{0, 1}});
440REGISTER_SPMD(IFFT2D, TF::IFFT2DOp, DataparallelSPMDExpander,
441 llvm::DenseMap<int, int>{{0, 2}},
442 llvm::DenseMap<int, int>{{0, 2}});
443REGISTER_SPMD(IFFT3D, TF::IFFT3DOp, DataparallelSPMDExpander,
444 llvm::DenseMap<int, int>{{0, 3}},
445 llvm::DenseMap<int, int>{{0, 3}});
446REGISTER_SPMD(IRFFT, TF::IRFFTOp, DataparallelSPMDExpander,
447 llvm::DenseMap<int, int>{{0, 1}},
448 llvm::DenseMap<int, int>{{0, 1}});
449REGISTER_SPMD(IRFFT2D, TF::IRFFT2DOp, DataparallelSPMDExpander,
450 llvm::DenseMap<int, int>{{0, 2}},
451 llvm::DenseMap<int, int>{{0, 2}});
452REGISTER_SPMD(IRFFT3D, TF::IRFFT3DOp, DataparallelSPMDExpander,
453 llvm::DenseMap<int, int>{{0, 3}},
454 llvm::DenseMap<int, int>{{0, 3}});
455REGISTER_SPMD(RFFT, TF::RFFTOp, DataparallelSPMDExpander,
456 llvm::DenseMap<int, int>{{0, 1}},
457 llvm::DenseMap<int, int>{{0, 1}});
458REGISTER_SPMD(RFFT2D, TF::RFFT2DOp, DataparallelSPMDExpander,
459 llvm::DenseMap<int, int>{{0, 2}},
460 llvm::DenseMap<int, int>{{0, 2}});
461REGISTER_SPMD(RFFT3D, TF::RFFT3DOp, DataparallelSPMDExpander,
462 llvm::DenseMap<int, int>{{0, 3}},
463 llvm::DenseMap<int, int>{{0, 3}});
// Cholesky is also paired with DataparallelSPMDExpander; both of its maps use
// the value 2 for index 0.
464REGISTER_SPMD(Cholesky, TF::CholeskyOp, DataparallelSPMDExpander,
465 llvm::DenseMap<int, int>{{0, 2}},
466 llvm::DenseMap<int, int>{{0, 2}});
467// Data Parallel Grad Ops
// The gradient ops list several entries in the first map (one per batchable
// operand, judging by the keys 0, 1, 2) but a single entry in the second.
// NOTE(review): ResizeNearestNeighborGrad lists only operand 0, unlike
// ResizeBilinearGrad which lists operands 0 and 1 -- presumably because its
// second operand is not a batched tensor; confirm against the op definitions.
468REGISTER_SPMD(MaxPool3DGrad, TF::MaxPool3DGradOp, DataparallelSPMDExpander,
469 llvm::DenseMap<int, int>{{0, 4}, {1, 4}, {2, 4}},
470 llvm::DenseMap<int, int>{{0, 4}});
471REGISTER_SPMD(MaxPool3DGradGrad, TF::MaxPool3DGradGradOp,
472 DataparallelSPMDExpander,
473 llvm::DenseMap<int, int>{{0, 4}, {1, 4}, {2, 4}},
474 llvm::DenseMap<int, int>{{0, 4}});
475REGISTER_SPMD(MaxPoolGradGrad, TF::MaxPoolGradGradOp, DataparallelSPMDExpander,
476 llvm::DenseMap<int, int>{{0, 3}, {1, 3}, {2, 3}},
477 llvm::DenseMap<int, int>{{0, 3}});
478REGISTER_SPMD(ResizeBilinearGrad, TF::ResizeBilinearGradOp,
479 DataparallelSPMDExpander,
480 llvm::DenseMap<int, int>{{0, 3}, {1, 3}},
481 llvm::DenseMap<int, int>{{0, 3}});
482REGISTER_SPMD(ResizeNearestNeighborGrad, TF::ResizeNearestNeighborGradOp,
483 DataparallelSPMDExpander, llvm::DenseMap<int, int>{{0, 3}},
484 llvm::DenseMap<int, int>{{0, 3}});
485
486// DiagPart: paired with ReplicatedOpSPMDExpander and passes an extra argument
// labelled relayout_when_sharded (set to true).
487REGISTER_SPMD(DiagPart, TF::DiagPartOp, ReplicatedOpSPMDExpander,
488 /*relayout_when_sharded=*/true);
489
490// Cumsum
491REGISTER_SPMD(Cumsum, TF::CumsumOp, CumsumSPMDExpander);
492
493// SparseToDenseOp
494REGISTER_SPMD(SparseToDense, TF::SparseToDenseOp, SparseToDenseSPMDExpander);
495
496// StringFormat: same ReplicatedOpSPMDExpander setup as DiagPart above.
497REGISTER_SPMD(StringFormat, TF::StringFormatOp, ReplicatedOpSPMDExpander,
498 /*relayout_when_sharded=*/true);
499
500// TensorList ops: reserve, get-item and set-item each have their own expander.
501REGISTER_SPMD(TensorListReserve, TF::TensorListReserveOp,
502 TensorListReserveSPMDExpander);
503REGISTER_SPMD(TensorListGetItem, TF::TensorListGetItemOp,
504 TensorListGetItemSPMDExpander);
505REGISTER_SPMD(TensorListSetItem, TF::TensorListSetItemOp,
506 TensorListSetItemSPMDExpander);
507
508// IO ops: WriteSummary uses IOOpSPMDExpander, DisableCopyOnRead has a
// dedicated expander, and ShardedFilename uses ReplicatedOpSPMDExpander
// without the relayout_when_sharded argument passed by some other
// registrations in this file.
509REGISTER_SPMD(WriteSummary, TF::WriteSummaryOp, IOOpSPMDExpander);
510REGISTER_SPMD(DisableCopyOnRead, TF::DisableCopyOnReadOp,
511 DisableCopyOnReadSPMDExpander);
512REGISTER_SPMD(ShardedFilename, TF::ShardedFilenameOp, ReplicatedOpSPMDExpander);
513} // namespace dtensor
514} // namespace tensorflow
515