1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
17 | #include "tensorflow/dtensor/mlir/expansions/argmax_spmd_expander.h" |
18 | #include "tensorflow/dtensor/mlir/expansions/bias_add_spmd_expander.h" |
19 | #include "tensorflow/dtensor/mlir/expansions/broadcast_to_spmd_expander.h" |
20 | #include "tensorflow/dtensor/mlir/expansions/concat_spmd_expander.h" |
21 | #include "tensorflow/dtensor/mlir/expansions/control_flow_spmd_expander.h" |
22 | #include "tensorflow/dtensor/mlir/expansions/conv_spmd_expander.h" |
23 | #include "tensorflow/dtensor/mlir/expansions/cumsum_spmd_expander.h" |
24 | #include "tensorflow/dtensor/mlir/expansions/dataparallel_spmd_expander.h" |
25 | #include "tensorflow/dtensor/mlir/expansions/disable_copy_on_read_spmd_expander.h" |
26 | #include "tensorflow/dtensor/mlir/expansions/dtensor_op_spmd_expander.h" |
27 | #include "tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.h" |
28 | #include "tensorflow/dtensor/mlir/expansions/elementwise_spmd_expander.h" |
29 | #include "tensorflow/dtensor/mlir/expansions/expanddims_spmd_expander.h" |
30 | #include "tensorflow/dtensor/mlir/expansions/fill_spmd_expander.h" |
31 | #include "tensorflow/dtensor/mlir/expansions/gather_spmd_expander.h" |
32 | #include "tensorflow/dtensor/mlir/expansions/identity_n_spmd_expander.h" |
33 | #include "tensorflow/dtensor/mlir/expansions/in_top_k_spmd_expander.h" |
34 | #include "tensorflow/dtensor/mlir/expansions/io_op_spmd_expander.h" |
35 | #include "tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.h" |
36 | #include "tensorflow/dtensor/mlir/expansions/meta_spmd_expander.h" |
37 | #include "tensorflow/dtensor/mlir/expansions/nullary_spmd_expander.h" |
38 | #include "tensorflow/dtensor/mlir/expansions/qr_spmd_expander.h" |
39 | #include "tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.h" |
40 | #include "tensorflow/dtensor/mlir/expansions/range_spmd_expander.h" |
41 | #include "tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.h" |
42 | #include "tensorflow/dtensor/mlir/expansions/replicated_spmd_expander.h" |
43 | #include "tensorflow/dtensor/mlir/expansions/resource_spmd_expander.h" |
44 | #include "tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.h" |
45 | #include "tensorflow/dtensor/mlir/expansions/scatter_spmd_expander.h" |
46 | #include "tensorflow/dtensor/mlir/expansions/segmentation_spmd_expander.h" |
47 | #include "tensorflow/dtensor/mlir/expansions/slice_spmd_expander.h" |
48 | #include "tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.h" |
49 | #include "tensorflow/dtensor/mlir/expansions/sparse_to_dense_spmd_expander.h" |
50 | #include "tensorflow/dtensor/mlir/expansions/split_spmd_expander.h" |
51 | #include "tensorflow/dtensor/mlir/expansions/squeeze_spmd_expander.h" |
52 | #include "tensorflow/dtensor/mlir/expansions/tensorlist_getitem_spmd_expander.h" |
53 | #include "tensorflow/dtensor/mlir/expansions/tensorlist_reserve_spmd_expander.h" |
54 | #include "tensorflow/dtensor/mlir/expansions/tensorlist_setitem_spmd_expander.h" |
55 | #include "tensorflow/dtensor/mlir/expansions/top_k_spmd_expander.h" |
56 | #include "tensorflow/dtensor/mlir/expansions/trivial_spmd_expander.h" |
57 | #include "tensorflow/dtensor/mlir/spmd_expander.h" |
58 | |
59 | namespace tensorflow { |
60 | namespace dtensor { |
61 | |
62 | // Nullary |
63 | REGISTER_SPMD(Const, TF::ConstOp, NullarySPMDExpander); |
64 | |
65 | // Unary |
66 | REGISTER_SPMD(Abs, TF::AbsOp, ElementwiseSPMDExpander); |
67 | REGISTER_SPMD(Cast, TF::CastOp, ElementwiseSPMDExpander); |
68 | REGISTER_SPMD(Identity, TF::IdentityOp, ElementwiseSPMDExpander); |
69 | REGISTER_SPMD(Neg, TF::NegOp, ElementwiseSPMDExpander); |
70 | REGISTER_SPMD(ZerosLike, TF::ZerosLikeOp, ElementwiseSPMDExpander); |
71 | REGISTER_SPMD(Exp, TF::ExpOp, ElementwiseSPMDExpander); |
72 | REGISTER_SPMD(Sqrt, TF::SqrtOp, ElementwiseSPMDExpander); |
73 | REGISTER_SPMD(Rsqrt, TF::RsqrtOp, ElementwiseSPMDExpander); |
74 | REGISTER_SPMD(Log, TF::LogOp, ElementwiseSPMDExpander); |
75 | REGISTER_SPMD(StopGradient, TF::StopGradientOp, ElementwiseSPMDExpander); |
76 | REGISTER_SPMD(Reciprocal, TF::ReciprocalOp, ElementwiseSPMDExpander); |
77 | REGISTER_SPMD(Square, TF::SquareOp, ElementwiseSPMDExpander); |
78 | REGISTER_SPMD(Erf, TF::ErfOp, ElementwiseSPMDExpander); |
79 | REGISTER_SPMD(Tanh, TF::TanhOp, ElementwiseSPMDExpander); |
80 | REGISTER_SPMD(TanhGrad, TF::TanhGradOp, ElementwiseSPMDExpander); |
81 | REGISTER_SPMD(Relu, TF::ReluOp, ElementwiseSPMDExpander); |
82 | REGISTER_SPMD(ReluGrad, TF::ReluGradOp, ElementwiseSPMDExpander); |
83 | REGISTER_SPMD(Sigmoid, TF::SigmoidOp, ElementwiseSPMDExpander); |
84 | REGISTER_SPMD(SigmoidGrad, TF::SigmoidGradOp, ElementwiseSPMDExpander); |
85 | REGISTER_SPMD(IsFinite, TF::IsFiniteOp, ElementwiseSPMDExpander); |
86 | |
87 | // Elementwise |
88 | REGISTER_SPMD(Add, TF::AddOp, ElementwiseSPMDExpander); |
89 | REGISTER_SPMD(AddV2, TF::AddV2Op, ElementwiseSPMDExpander); |
90 | REGISTER_SPMD(AddN, TF::AddNOp, ElementwiseSPMDExpander); |
91 | REGISTER_SPMD(RealDiv, TF::RealDivOp, ElementwiseSPMDExpander); |
92 | REGISTER_SPMD(Div, TF::DivOp, ElementwiseSPMDExpander); |
93 | REGISTER_SPMD(DivNoNan, TF::DivNoNanOp, ElementwiseSPMDExpander); |
94 | REGISTER_SPMD(Equal, TF::EqualOp, ElementwiseSPMDExpander); |
95 | REGISTER_SPMD(FloorDiv, TF::FloorDivOp, ElementwiseSPMDExpander); |
96 | REGISTER_SPMD(FloorMod, TF::FloorModOp, ElementwiseSPMDExpander); |
97 | REGISTER_SPMD(NotEqual, TF::NotEqualOp, ElementwiseSPMDExpander); |
98 | REGISTER_SPMD(Less, TF::LessOp, ElementwiseSPMDExpander); |
99 | REGISTER_SPMD(LessEqual, TF::LessEqualOp, ElementwiseSPMDExpander); |
100 | REGISTER_SPMD(LogicalAnd, TF::LogicalAndOp, ElementwiseSPMDExpander); |
101 | REGISTER_SPMD(LogicalNot, TF::LogicalNotOp, ElementwiseSPMDExpander); |
102 | REGISTER_SPMD(Maximum, TF::MaximumOp, ElementwiseSPMDExpander); |
103 | REGISTER_SPMD(Minimum, TF::MinimumOp, ElementwiseSPMDExpander); |
104 | REGISTER_SPMD(Mul, TF::MulOp, ElementwiseSPMDExpander); |
105 | REGISTER_SPMD(Select, TF::SelectOp, ElementwiseSPMDExpander); |
106 | REGISTER_SPMD(SelectV2, TF::SelectV2Op, ElementwiseSPMDExpander); |
107 | REGISTER_SPMD(Sub, TF::SubOp, ElementwiseSPMDExpander); |
108 | REGISTER_SPMD(SquaredDifference, TF::SquaredDifferenceOp, |
109 | ElementwiseSPMDExpander); |
110 | REGISTER_SPMD(Greater, TF::GreaterOp, ElementwiseSPMDExpander); |
111 | REGISTER_SPMD(GreaterEqual, TF::GreaterEqualOp, ElementwiseSPMDExpander); |
112 | REGISTER_SPMD(RsqrtGrad, TF::RsqrtGradOp, ElementwiseSPMDExpander); |
113 | REGISTER_SPMD(SqrtGrad, TF::SqrtGradOp, ElementwiseSPMDExpander); |
114 | REGISTER_SPMD(Pow, TF::PowOp, ElementwiseSPMDExpander); |
115 | REGISTER_SPMD(BitwiseAnd, TF::BitwiseAndOp, ElementwiseSPMDExpander); |
116 | REGISTER_SPMD(BitwiseOr, TF::BitwiseOrOp, ElementwiseSPMDExpander); |
117 | REGISTER_SPMD(BitwiseXor, TF::BitwiseXorOp, ElementwiseSPMDExpander); |
118 | REGISTER_SPMD(LeftShift, TF::LeftShiftOp, ElementwiseSPMDExpander); |
119 | REGISTER_SPMD(RightShift, TF::RightShiftOp, ElementwiseSPMDExpander); |
120 | REGISTER_SPMD(LogicalOr, TF::LogicalOrOp, ElementwiseSPMDExpander); |
121 | REGISTER_SPMD(Cos, TF::CosOp, ElementwiseSPMDExpander); |
122 | REGISTER_SPMD(Acos, TF::AcosOp, ElementwiseSPMDExpander); |
123 | REGISTER_SPMD(Acosh, TF::AcoshOp, ElementwiseSPMDExpander); |
124 | REGISTER_SPMD(Angle, TF::AngleOp, ElementwiseSPMDExpander); |
125 | REGISTER_SPMD(Asin, TF::AsinOp, ElementwiseSPMDExpander); |
126 | REGISTER_SPMD(Asinh, TF::AsinhOp, ElementwiseSPMDExpander); |
127 | REGISTER_SPMD(Atan, TF::AtanOp, ElementwiseSPMDExpander); |
128 | REGISTER_SPMD(Atan2, TF::Atan2Op, ElementwiseSPMDExpander); |
129 | REGISTER_SPMD(Atanh, TF::AtanhOp, ElementwiseSPMDExpander); |
130 | REGISTER_SPMD(BesselI0e, TF::BesselI0eOp, ElementwiseSPMDExpander); |
131 | REGISTER_SPMD(BesselI1e, TF::BesselI1eOp, ElementwiseSPMDExpander); |
132 | REGISTER_SPMD(Betainc, TF::BetaincOp, ElementwiseSPMDExpander); |
133 | REGISTER_SPMD(Bitcast, TF::BitcastOp, ElementwiseSPMDExpander); |
134 | REGISTER_SPMD(Ceil, TF::CeilOp, ElementwiseSPMDExpander); |
135 | REGISTER_SPMD(CheckNumerics, TF::CheckNumericsOp, ElementwiseSPMDExpander); |
136 | REGISTER_SPMD(ClipByValue, TF::ClipByValueOp, ElementwiseSPMDExpander); |
137 | REGISTER_SPMD(Conj, TF::ConjOp, ElementwiseSPMDExpander); |
138 | REGISTER_SPMD(Cosh, TF::CoshOp, ElementwiseSPMDExpander); |
139 | REGISTER_SPMD(Complex, TF::ComplexOp, ElementwiseSPMDExpander); |
140 | REGISTER_SPMD(ComplexAbs, TF::ComplexAbsOp, ElementwiseSPMDExpander); |
141 | REGISTER_SPMD(Digamma, TF::DigammaOp, ElementwiseSPMDExpander); |
142 | |
143 | // TODO(b/193924452): Add the following Ops once unit tests are there. |
144 | // |
145 | REGISTER_SPMD(Elu, TF::EluOp, ElementwiseSPMDExpander); |
146 | REGISTER_SPMD(EluGrad, TF::EluGradOp, ElementwiseSPMDExpander); |
147 | REGISTER_SPMD(Erfc, TF::ErfcOp, ElementwiseSPMDExpander); |
148 | REGISTER_SPMD(Erfinv, TF::ErfinvOp, ElementwiseSPMDExpander); |
149 | REGISTER_SPMD(Expm1, TF::Expm1Op, ElementwiseSPMDExpander); |
150 | REGISTER_SPMD(Floor, TF::FloorOp, ElementwiseSPMDExpander); |
151 | // REGISTER_SPMD(HSVToRGB, TF::HSVToRGBOp, ElementwiseSPMDExpander); |
152 | REGISTER_SPMD(Igamma, TF::IgammaOp, ElementwiseSPMDExpander); |
153 | REGISTER_SPMD(Igammac, TF::IgammacOp, ElementwiseSPMDExpander); |
154 | REGISTER_SPMD(IgammaGradA, TF::IgammaGradAOp, ElementwiseSPMDExpander); |
155 | REGISTER_SPMD(Imag, TF::ImagOp, ElementwiseSPMDExpander); |
156 | // REGISTER_SPMD(InplaceAdd, TF::InplaceAddOp, ElementwiseSPMDExpander); |
157 | // REGISTER_SPMD(InplaceUpdate, TF::InplaceUpdateOp, ElementwiseSPMDExpander); |
158 | REGISTER_SPMD(Inv, TF::InvOp, ElementwiseSPMDExpander); |
159 | REGISTER_SPMD(Invert, TF::InvertOp, ElementwiseSPMDExpander); |
160 | REGISTER_SPMD(IsInf, TF::IsInfOp, ElementwiseSPMDExpander); |
161 | REGISTER_SPMD(IsNan, TF::IsNanOp, ElementwiseSPMDExpander); |
162 | REGISTER_SPMD(LeakyRelu, TF::LeakyReluOp, ElementwiseSPMDExpander); |
163 | REGISTER_SPMD(LeakyReluGrad, TF::LeakyReluGradOp, ElementwiseSPMDExpander); |
164 | REGISTER_SPMD(Lgamma, TF::LgammaOp, ElementwiseSPMDExpander); |
165 | REGISTER_SPMD(Log1p, TF::Log1pOp, ElementwiseSPMDExpander); |
166 | REGISTER_SPMD(MulNoNan, TF::MulNoNanOp, ElementwiseSPMDExpander); |
167 | REGISTER_SPMD(Ndtri, TF::NdtriOp, ElementwiseSPMDExpander); |
168 | REGISTER_SPMD(NextAfter, TF::NextAfterOp, ElementwiseSPMDExpander); |
169 | REGISTER_SPMD(Polygamma, TF::PolygammaOp, ElementwiseSPMDExpander); |
170 | REGISTER_SPMD(PopulationCount, TF::PopulationCountOp, ElementwiseSPMDExpander); |
171 | REGISTER_SPMD(PreventGradient, TF::PreventGradientOp, ElementwiseSPMDExpander); |
172 | REGISTER_SPMD(Real, TF::RealOp, ElementwiseSPMDExpander); |
173 | REGISTER_SPMD(ReciprocalGrad, TF::ReciprocalGradOp, ElementwiseSPMDExpander); |
174 | REGISTER_SPMD(Relu6, TF::Relu6Op, ElementwiseSPMDExpander); |
175 | REGISTER_SPMD(Relu6Grad, TF::Relu6GradOp, ElementwiseSPMDExpander); |
176 | REGISTER_SPMD(Rint, TF::RintOp, ElementwiseSPMDExpander); |
177 | REGISTER_SPMD(Round, TF::RoundOp, ElementwiseSPMDExpander); |
178 | REGISTER_SPMD(Selu, TF::SeluOp, ElementwiseSPMDExpander); |
179 | REGISTER_SPMD(SeluGrad, TF::SeluGradOp, ElementwiseSPMDExpander); |
180 | REGISTER_SPMD(Sign, TF::SignOp, ElementwiseSPMDExpander); |
181 | REGISTER_SPMD(Sin, TF::SinOp, ElementwiseSPMDExpander); |
182 | REGISTER_SPMD(Sinh, TF::SinhOp, ElementwiseSPMDExpander); |
183 | // REGISTER_SPMD(Snapshot, TF::SnapshotOp, ElementwiseSPMDExpander); |
184 | REGISTER_SPMD(Softplus, TF::SoftplusOp, ElementwiseSPMDExpander); |
185 | // REGISTER_SPMD(SoftplusGrad, TF::SoftplusGradOp, ElementwiseSPMDExpander); |
186 | REGISTER_SPMD(Softsign, TF::SoftsignOp, ElementwiseSPMDExpander); |
187 | // REGISTER_SPMD(SoftsignGrad, TF::SoftsignGradOp, ElementwiseSPMDExpander); |
188 | REGISTER_SPMD(Tan, TF::TanOp, ElementwiseSPMDExpander); |
189 | // REGISTER_SPMD(TridiagonalSolve, TF::TridiagonalSolveOp, |
190 | // ElementwiseSPMDExpander); |
191 | REGISTER_SPMD(TruncateDiv, TF::TruncateDivOp, ElementwiseSPMDExpander); |
192 | REGISTER_SPMD(TruncateMod, TF::TruncateModOp, ElementwiseSPMDExpander); |
193 | REGISTER_SPMD(Xdivy, TF::XdivyOp, ElementwiseSPMDExpander); |
194 | REGISTER_SPMD(Xlog1py, TF::Xlog1pyOp, ElementwiseSPMDExpander); |
195 | REGISTER_SPMD(Xlogy, TF::XlogyOp, ElementwiseSPMDExpander); |
196 | REGISTER_SPMD(Zeta, TF::ZetaOp, ElementwiseSPMDExpander); |
197 | |
198 | // IdentityN |
199 | // TODO(hongjunchoi): Make ElementwiseSPMDExpander support IdentityN. |
200 | REGISTER_SPMD(IdentityN, TF::IdentityNOp, IdentityNSPMDExpander); |
201 | |
202 | // Range |
203 | REGISTER_SPMD(Range, TF::RangeOp, RangeSPMDExpander); |
204 | |
205 | // Reductions |
206 | REGISTER_SPMD(All, TF::AllOp, ReduceSPMDExpander); |
207 | REGISTER_SPMD(Any, TF::AnyOp, ReduceSPMDExpander); |
208 | REGISTER_SPMD(Mean, TF::MeanOp, ReduceSPMDExpander); |
209 | REGISTER_SPMD(Max, TF::MaxOp, ReduceSPMDExpander); |
210 | REGISTER_SPMD(Min, TF::MinOp, ReduceSPMDExpander); |
211 | REGISTER_SPMD(Prod, TF::ProdOp, ReduceSPMDExpander); |
212 | REGISTER_SPMD(Sum, TF::SumOp, ReduceSPMDExpander); |
213 | REGISTER_SPMD(L2Loss, TF::L2LossOp, ReduceSPMDExpander); |
214 | |
215 | // Convolution |
216 | REGISTER_SPMD(Conv2D, TF::Conv2DOp, ConvSPMDExpander); |
217 | REGISTER_SPMD(Conv2DBackpropFilter, TF::Conv2DBackpropFilterOp, |
218 | ConvSPMDExpander); |
219 | REGISTER_SPMD(Conv2DBackpropInput, TF::Conv2DBackpropInputOp, ConvSPMDExpander); |
220 | REGISTER_SPMD(Conv3D, TF::Conv3DOp, ConvSPMDExpander); |
221 | REGISTER_SPMD(Conv3DBackpropFilterV2, TF::Conv3DBackpropFilterV2Op, |
222 | ConvSPMDExpander); |
223 | REGISTER_SPMD(Conv3DBackpropInputV2, TF::Conv3DBackpropInputV2Op, |
224 | ConvSPMDExpander); |
225 | REGISTER_SPMD(MaxPool, TF::MaxPoolOp, ConvSPMDExpander); |
226 | REGISTER_SPMD(MaxPoolGrad, TF::MaxPoolGradOp, ConvSPMDExpander); |
227 | |
228 | // Metadata |
229 | REGISTER_SPMD(Rank, TF::RankOp, ShapeSPMDExpander); |
230 | REGISTER_SPMD(Shape, TF::ShapeOp, ShapeSPMDExpander); |
231 | REGISTER_SPMD(ShapeN, TF::ShapeNOp, ShapeSPMDExpander); |
232 | |
233 | REGISTER_SPMD(BroadcastGradientArgs, TF::BroadcastGradientArgsOp, |
234 | MetadataSPMDExpander); |
235 | |
236 | // Resource ops |
237 | REGISTER_SPMD(AssignVariable, TF::AssignVariableOp, ResourceSPMDExpander); |
238 | REGISTER_SPMD(AssignAddVariable, TF::AssignAddVariableOp, ResourceSPMDExpander); |
239 | REGISTER_SPMD(AssignSubVariable, TF::AssignSubVariableOp, ResourceSPMDExpander); |
240 | REGISTER_SPMD(ReadVariable, TF::ReadVariableOp, ResourceSPMDExpander); |
241 | REGISTER_SPMD(VarHandle, TF::VarHandleOp, ResourceSPMDExpander); |
242 | REGISTER_SPMD(VarIsInitialized, TF::VarIsInitializedOp, ResourceSPMDExpander); |
243 | REGISTER_SPMD(DestroyResource, TF::DestroyResourceOp, ResourceSPMDExpander); |
244 | |
245 | // Einsum |
246 | REGISTER_SPMD(Einsum, TF::EinsumOp, EinsumSPMDExpander); |
247 | |
248 | // Matrix multiplication |
249 | REGISTER_SPMD(BatchMatMulV2, TF::BatchMatMulV2Op, MatMulSPMDExpander); |
250 | REGISTER_SPMD(MatMul, TF::MatMulOp, MatMulSPMDExpander); |
251 | |
252 | // Stack/unstack (pack/unpack) |
253 | REGISTER_SPMD(Pack, TF::PackOp, PackSPMDExpander); |
254 | REGISTER_SPMD(Unpack, TF::UnpackOp, UnpackSPMDExpander); |
255 | |
256 | // Reshape |
257 | REGISTER_SPMD(Reshape, TF::ReshapeOp, ReshapeSPMDExpander); |
258 | REGISTER_SPMD(Transpose, TF::TransposeOp, TransposeSPMDExpander); |
259 | REGISTER_SPMD(InvertPermutation, TF::InvertPermutationOp, |
260 | ReplicatedOpSPMDExpander, |
261 | /*relayout_when_sharded=*/true); |
262 | |
263 | // Pad |
264 | REGISTER_SPMD(Pad, TF::PadOp, PadSPMDExpander); |
265 | REGISTER_SPMD(PadV2, TF::PadV2Op, PadSPMDExpander); |
266 | |
267 | // Scatter/Gather |
268 | REGISTER_SPMD(GatherV2, TF::GatherV2Op, GatherV2SPMDExpander); |
269 | REGISTER_SPMD(GatherNd, TF::GatherNdOp, GatherNdSPMDExpander); |
270 | REGISTER_SPMD(TensorScatterUpdate, TF::TensorScatterUpdateOp, |
271 | TensorScatterOpSPMDExpander); |
272 | REGISTER_SPMD(TensorScatterAdd, TF::TensorScatterAddOp, |
273 | TensorScatterOpSPMDExpander); |
274 | |
275 | // ArgMax/ArgMin |
276 | REGISTER_SPMD(ArgMax, TF::ArgMaxOp, ArgMaxSPMDExpander); |
277 | |
278 | // Slice |
279 | REGISTER_SPMD(Slice, TF::SliceOp, SliceSPMDExpander); |
280 | REGISTER_SPMD(StridedSlice, TF::StridedSliceOp, StridedSliceSPMDExpander); |
281 | REGISTER_SPMD(TensorStridedSliceUpdate, TF::TensorStridedSliceUpdateOp, |
282 | TensorStridedSliceUpdateSPMDExpander); |
283 | REGISTER_SPMD(StridedSliceGrad, TF::StridedSliceGradOp, |
284 | StridedSliceGradSPMDExpander); |
285 | |
286 | // Split |
287 | REGISTER_SPMD(Split, TF::SplitOp, SplitSPMDExpander); |
288 | REGISTER_SPMD(SplitV, TF::SplitVOp, SplitVSPMDExpander); |
289 | |
290 | // Squeeze |
291 | REGISTER_SPMD(Squeeze, TF::SqueezeOp, SqueezeSPMDExpander); |
292 | |
293 | // Concat |
294 | REGISTER_SPMD(Concat, TF::ConcatOp, ConcatSPMDExpander); |
295 | REGISTER_SPMD(ConcatV2, TF::ConcatV2Op, ConcatSPMDExpander); |
296 | |
297 | // Softmax Loss ops |
298 | REGISTER_SPMD(SoftmaxCrossEntropyWithLogits, |
299 | TF::SoftmaxCrossEntropyWithLogitsOp, SoftmaxLossOpSPMDExpander); |
300 | REGISTER_SPMD(SparseSoftmaxCrossEntropyWithLogits, |
301 | TF::SparseSoftmaxCrossEntropyWithLogitsOp, |
302 | SoftmaxLossOpSPMDExpander); |
303 | |
304 | // Softmax ops |
305 | REGISTER_SPMD(Softmax, TF::SoftmaxOp, SoftmaxOpSPMDExpander); |
306 | REGISTER_SPMD(LogSoftmax, TF::LogSoftmaxOp, SoftmaxOpSPMDExpander); |
307 | |
308 | // Random ops |
309 | // LINT.IfChange |
310 | REGISTER_SPMD(StatelessRandomUniform, TF::StatelessRandomUniformOp, |
311 | RandomOpSPMDExpander); |
312 | REGISTER_SPMD(StatelessRandomUniformFullInt, |
313 | TF::StatelessRandomUniformFullIntOp, RandomOpSPMDExpander); |
314 | REGISTER_SPMD(StatelessRandomNormal, TF::StatelessRandomNormalOp, |
315 | RandomOpSPMDExpander); |
316 | REGISTER_SPMD(StatelessTruncatedNormal, TF::StatelessTruncatedNormalOp, |
317 | RandomOpSPMDExpander); |
318 | // LINT.ThenChange(//tensorflow/dtensor/cc/small_constant_optimization.cc) |
319 | // Random V2 ops |
320 | REGISTER_SPMD(StatelessRandomGetKeyCounter, TF::StatelessRandomGetKeyCounterOp, |
321 | ReplicatedOpSPMDExpander); |
322 | REGISTER_SPMD(RngReadAndSkip, TF::RngReadAndSkipOp, ReplicatedOpSPMDExpander); |
323 | REGISTER_SPMD(StatelessRandomNormalV2, TF::StatelessRandomNormalV2Op, |
324 | RandomOpSPMDExpander); |
325 | REGISTER_SPMD(StatelessRandomUniformV2, TF::StatelessRandomUniformV2Op, |
326 | RandomOpSPMDExpander); |
327 | REGISTER_SPMD(StatelessRandomUniformFullIntV2, |
328 | TF::StatelessRandomUniformFullIntV2Op, RandomOpSPMDExpander); |
329 | REGISTER_SPMD(StatelessRandomUniformIntV2, TF::StatelessRandomUniformIntV2Op, |
330 | RandomOpSPMDExpander); |
331 | REGISTER_SPMD(StatelessTruncatedNormalV2, TF::StatelessTruncatedNormalV2Op, |
332 | RandomOpSPMDExpander); |
333 | |
334 | // Input agnotics ops |
335 | REGISTER_SPMD(Fill, TF::FillOp, FillSPMDExpander); |
336 | |
337 | // Tile |
338 | REGISTER_SPMD(Tile, TF::TileOp, TileSPMDExpander); |
339 | |
340 | // Expansion of ResourceApply ops are no-ops as they are always element-wise. |
341 | // Also, ResourceApply ops do not have output values. As so, inferring layout |
342 | // from operand and consumers are trivial(no-op). |
343 | // Resource apply ops |
344 | REGISTER_SPMD(ResourceApplyAdagradV2, TF::ResourceApplyAdagradV2Op, |
345 | NoOpSPMDExpander); |
346 | REGISTER_SPMD(ResourceApplyAdam, TF::ResourceApplyAdamOp, NoOpSPMDExpander); |
347 | REGISTER_SPMD(ResourceApplyGradientDescent, TF::ResourceApplyGradientDescentOp, |
348 | NoOpSPMDExpander); |
349 | REGISTER_SPMD(ResourceApplyCenteredRMSProp, TF::ResourceApplyCenteredRMSPropOp, |
350 | NoOpSPMDExpander); |
351 | REGISTER_SPMD(ResourceApplyKerasMomentum, TF::ResourceApplyKerasMomentumOp, |
352 | NoOpSPMDExpander); |
353 | REGISTER_SPMD(ResourceApplyMomentum, TF::ResourceApplyMomentumOp, |
354 | NoOpSPMDExpander); |
355 | |
356 | // AssertOp |
357 | REGISTER_SPMD(Assert, TF::AssertOp, NoOpSPMDExpander); |
358 | |
359 | // Terminator ops |
360 | REGISTER_SPMD(Return, tf_device::ReturnOp, TerminatorSPMDExpander); |
361 | |
362 | // Onehot |
363 | REGISTER_SPMD(OneHot, TF::OneHotOp, OneHotSPMDExpander); |
364 | // ExpandDimsOp |
365 | REGISTER_SPMD(ExpandDims, TF::ExpandDimsOp, ExpandDimsExpander); |
366 | // UnsortedSegmentSumOp |
367 | REGISTER_SPMD(UnsortedSegmentSum, TF::UnsortedSegmentSumOp, |
368 | UnsortedSegmentSumSPMDExpander); |
369 | // BroadcastToOp |
370 | REGISTER_SPMD(BroadcastTo, TF::BroadcastToOp, BroadcastToSPMDExpander); |
371 | |
372 | // Save/Restore ops. |
373 | REGISTER_SPMD(SaveV2, TF::SaveV2Op, SaveRestoreSPMDExpander); |
374 | REGISTER_SPMD(MergeV2Checkpoints, TF::MergeV2CheckpointsOp, |
375 | SaveRestoreSPMDExpander); |
376 | REGISTER_SPMD(RestoreV2, TF::RestoreV2Op, SaveRestoreSPMDExpander); |
377 | REGISTER_SPMD(DTensorRestoreV2, TF::DTensorRestoreV2Op, |
378 | SaveRestoreSPMDExpander); |
379 | REGISTER_SPMD(DTensorShardedPrefix, TF::DTensorShardedPrefixOp, |
380 | DTensorShardPrefixSPMDExpander); |
381 | |
382 | // DTensor Virtual ops |
383 | REGISTER_SPMD(Relayout, TF::RelayoutOp, RelayoutSPMDExpander); |
384 | REGISTER_SPMD(DTensorSend, TF::DTensorSend, DTensorSendSPMDExpander); |
385 | REGISTER_SPMD(DTensorRecv, TF::DTensorRecv, DTensorRecvSPMDExpander); |
386 | |
387 | // TopKV2 |
388 | REGISTER_SPMD(TopKV2, TF::TopKV2Op, TopKSPMDExpander); |
389 | // InTopKV2 |
390 | REGISTER_SPMD(InTopKV2, TF::InTopKV2Op, InTopKSPMDExpander); |
391 | |
392 | // Control flow |
393 | REGISTER_SPMD(WhileRegion, TF::WhileRegionOp, WhileRegionSPMDExpander); |
394 | REGISTER_SPMD(IfRegion, TF::IfRegionOp, IfRegionSPMDExpander); |
395 | |
396 | // BiasAdd |
397 | REGISTER_SPMD(BiasAdd, TF::BiasAddOp, BiasAddExpander); |
398 | REGISTER_SPMD(BiasAddGrad, TF::BiasAddGradOp, ReduceSPMDExpander); |
399 | |
400 | // QR |
401 | REGISTER_SPMD(Qr, TF::QrOp, QRSPMDExpander); |
402 | |
403 | // Data Parallel |
404 | REGISTER_SPMD(AvgPool, TF::AvgPoolOp, DataparallelSPMDExpander, |
405 | llvm::DenseMap<int, int>{{0, 3}}, |
406 | llvm::DenseMap<int, int>{{0, 3}}); |
407 | REGISTER_SPMD(AvgPool3D, TF::AvgPool3DOp, DataparallelSPMDExpander, |
408 | llvm::DenseMap<int, int>{{0, 4}}, |
409 | llvm::DenseMap<int, int>{{0, 4}}); |
410 | REGISTER_SPMD(MaxPool3D, TF::MaxPool3DOp, DataparallelSPMDExpander, |
411 | llvm::DenseMap<int, int>{{0, 4}}, |
412 | llvm::DenseMap<int, int>{{0, 4}}); |
413 | REGISTER_SPMD(DepthwiseConv2dNative, TF::DepthwiseConv2dNativeOp, |
414 | DataparallelSPMDExpander, llvm::DenseMap<int, int>{{0, 3}}, |
415 | llvm::DenseMap<int, int>{{0, 3}}); |
416 | REGISTER_SPMD(ResizeBilinear, TF::ResizeBilinearOp, DataparallelSPMDExpander, |
417 | llvm::DenseMap<int, int>{{0, 3}}, |
418 | llvm::DenseMap<int, int>{{0, 3}}); |
419 | REGISTER_SPMD(ResizeNearestNeighbor, TF::ResizeNearestNeighborOp, |
420 | DataparallelSPMDExpander, llvm::DenseMap<int, int>{{0, 3}}, |
421 | llvm::DenseMap<int, int>{{0, 3}}); |
422 | REGISTER_SPMD(AdjustContrastv2, TF::AdjustContrastv2Op, |
423 | DataparallelSPMDExpander, llvm::DenseMap<int, int>{{0, 3}}, |
424 | llvm::DenseMap<int, int>{{0, 3}}); |
425 | REGISTER_SPMD(AdjustSaturation, TF::AdjustSaturationOp, |
426 | DataparallelSPMDExpander, llvm::DenseMap<int, int>{{0, 3}}, |
427 | llvm::DenseMap<int, int>{{0, 3}}); |
428 | REGISTER_SPMD(FFT, TF::FFTOp, DataparallelSPMDExpander, |
429 | llvm::DenseMap<int, int>{{0, 1}}, |
430 | llvm::DenseMap<int, int>{{0, 1}}); |
431 | REGISTER_SPMD(FFT2D, TF::FFT2DOp, DataparallelSPMDExpander, |
432 | llvm::DenseMap<int, int>{{0, 1}}, |
433 | llvm::DenseMap<int, int>{{0, 1}}); |
434 | REGISTER_SPMD(FFT3D, TF::FFT3DOp, DataparallelSPMDExpander, |
435 | llvm::DenseMap<int, int>{{0, 1}}, |
436 | llvm::DenseMap<int, int>{{0, 1}}); |
437 | REGISTER_SPMD(IFFT, TF::IFFTOp, DataparallelSPMDExpander, |
438 | llvm::DenseMap<int, int>{{0, 1}}, |
439 | llvm::DenseMap<int, int>{{0, 1}}); |
440 | REGISTER_SPMD(IFFT2D, TF::IFFT2DOp, DataparallelSPMDExpander, |
441 | llvm::DenseMap<int, int>{{0, 1}}, |
442 | llvm::DenseMap<int, int>{{0, 1}}); |
443 | REGISTER_SPMD(IFFT3D, TF::IFFT3DOp, DataparallelSPMDExpander, |
444 | llvm::DenseMap<int, int>{{0, 1}}, |
445 | llvm::DenseMap<int, int>{{0, 1}}); |
446 | REGISTER_SPMD(IRFFT, TF::IRFFTOp, DataparallelSPMDExpander, |
447 | llvm::DenseMap<int, int>{{0, 1}}, |
448 | llvm::DenseMap<int, int>{{0, 1}}); |
449 | REGISTER_SPMD(IRFFT2D, TF::IRFFT2DOp, DataparallelSPMDExpander, |
450 | llvm::DenseMap<int, int>{{0, 1}}, |
451 | llvm::DenseMap<int, int>{{0, 1}}); |
452 | REGISTER_SPMD(IRFFT3D, TF::IRFFT3DOp, DataparallelSPMDExpander, |
453 | llvm::DenseMap<int, int>{{0, 1}}, |
454 | llvm::DenseMap<int, int>{{0, 1}}); |
455 | REGISTER_SPMD(RFFT, TF::RFFTOp, DataparallelSPMDExpander, |
456 | llvm::DenseMap<int, int>{{0, 1}}, |
457 | llvm::DenseMap<int, int>{{0, 1}}); |
458 | REGISTER_SPMD(RFFT2D, TF::RFFT2DOp, DataparallelSPMDExpander, |
459 | llvm::DenseMap<int, int>{{0, 1}}, |
460 | llvm::DenseMap<int, int>{{0, 1}}); |
461 | REGISTER_SPMD(RFFT3D, TF::RFFT3DOp, DataparallelSPMDExpander, |
462 | llvm::DenseMap<int, int>{{0, 1}}, |
463 | llvm::DenseMap<int, int>{{0, 1}}); |
464 | REGISTER_SPMD(Cholesky, TF::CholeskyOp, DataparallelSPMDExpander, |
465 | llvm::DenseMap<int, int>{{0, 2}}, |
466 | llvm::DenseMap<int, int>{{0, 2}}); |
467 | // Data Parallel Grad Ops |
468 | REGISTER_SPMD(MaxPool3DGrad, TF::MaxPool3DGradOp, DataparallelSPMDExpander, |
469 | llvm::DenseMap<int, int>{{0, 4}, {1, 4}, {2, 4}}, |
470 | llvm::DenseMap<int, int>{{0, 4}}); |
471 | REGISTER_SPMD(MaxPool3DGradGrad, TF::MaxPool3DGradGradOp, |
472 | DataparallelSPMDExpander, |
473 | llvm::DenseMap<int, int>{{0, 4}, {1, 4}, {2, 4}}, |
474 | llvm::DenseMap<int, int>{{0, 4}}); |
475 | REGISTER_SPMD(MaxPoolGradGrad, TF::MaxPoolGradGradOp, DataparallelSPMDExpander, |
476 | llvm::DenseMap<int, int>{{0, 3}, {1, 3}, {2, 3}}, |
477 | llvm::DenseMap<int, int>{{0, 3}}); |
478 | REGISTER_SPMD(ResizeBilinearGrad, TF::ResizeBilinearGradOp, |
479 | DataparallelSPMDExpander, |
480 | llvm::DenseMap<int, int>{{0, 3}, {1, 3}}, |
481 | llvm::DenseMap<int, int>{{0, 3}}); |
482 | REGISTER_SPMD(ResizeNearestNeighborGrad, TF::ResizeNearestNeighborGradOp, |
483 | DataparallelSPMDExpander, llvm::DenseMap<int, int>{{0, 3}}, |
484 | llvm::DenseMap<int, int>{{0, 3}}); |
485 | |
486 | // DiagPart |
487 | REGISTER_SPMD(DiagPart, TF::DiagPartOp, ReplicatedOpSPMDExpander, |
488 | /*relayout_when_sharded=*/true); |
489 | |
490 | // Cumsum |
491 | REGISTER_SPMD(Cumsum, TF::CumsumOp, CumsumSPMDExpander); |
492 | |
493 | // SparseToDenseOp |
494 | REGISTER_SPMD(SparseToDense, TF::SparseToDenseOp, SparseToDenseSPMDExpander); |
495 | |
496 | // StringFormat |
497 | REGISTER_SPMD(StringFormat, TF::StringFormatOp, ReplicatedOpSPMDExpander, |
498 | /*relayout_when_sharded=*/true); |
499 | |
500 | // TensorList ops |
501 | REGISTER_SPMD(TensorListReserve, TF::TensorListReserveOp, |
502 | TensorListReserveSPMDExpander); |
503 | REGISTER_SPMD(TensorListGetItem, TF::TensorListGetItemOp, |
504 | TensorListGetItemSPMDExpander); |
505 | REGISTER_SPMD(TensorListSetItem, TF::TensorListSetItemOp, |
506 | TensorListSetItemSPMDExpander); |
507 | |
508 | // IO ops |
509 | REGISTER_SPMD(WriteSummary, TF::WriteSummaryOp, IOOpSPMDExpander); |
510 | REGISTER_SPMD(DisableCopyOnRead, TF::DisableCopyOnReadOp, |
511 | DisableCopyOnReadSPMDExpander); |
512 | REGISTER_SPMD(ShardedFilename, TF::ShardedFilenameOp, ReplicatedOpSPMDExpander); |
513 | } // namespace dtensor |
514 | } // namespace tensorflow |
515 | |