/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"

namespace tensorflow {
namespace dtensor {

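// DTensorAllReduce reduces `input` with `reduce_op` across the device groups
// listed in `group_assignment`. Every participant receives the reduced
// result, so the output shape equals the input shape (UnchangedShape).
// As with other TF collectives, `group_assignment` is typically a 2-D int32
// tensor of [num_groups, group_size] device ids (convention assumed here).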
REGISTER_OP("DTensorAllReduce")
    .Input("input: T")
    .Input("group_assignment: int32")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, int32, uint32, int64, bool}")
    .Attr("reduce_op: {'Min', 'Max', 'Mul', 'Add', 'Mean', 'Any', 'All'}")
    .Attr("device_type: string")  // e.g. "/device:TPU"
    .SetShapeFn(shape_inference::UnchangedShape);

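// DTensorReduceScatter reduces `input` with `reduce_op` across the groups in
// `group_assignment` and then scatters the result along `scatter_dimension`,
// so each participant keeps one shard of the reduced tensor. The output
// shape is computed by shape_inference::ReduceScatterShape.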
REGISTER_OP("DTensorReduceScatter")
    .Input("input: T")
    .Input("group_assignment: int32")
    .Input("scatter_dimension: int32")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, int32, uint32, int64, bool}")
    .Attr("reduce_op: {'Min', 'Max', 'Mul', 'Add', 'Mean', 'Any', 'All'}")
    .Attr("device_type: string")  // e.g. "/device:TPU"
    .SetShapeFn(shape_inference::ReduceScatterShape);

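// DTensorAllScatter re-slices `input` from `input_layout` to a more sharded
// `output_layout`. For every dimension that goes from unsharded to sharded,
// the shape function below divides the dimension size by the number of
// output shards; e.g. a dimension of size 8 scattered over 4 shards yields a
// per-shard size of 2. Dimensions whose sharding specs already match are
// passed through unchanged.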
REGISTER_OP("DTensorAllScatter")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, int32, uint32, int64, bool}")
    .Attr("input_layout: string")
    .Attr("output_layout: string")
    .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
      shape_inference::ShapeHandle in = c->input(0);
      if (!c->RankKnown(in)) {
        // Input shape unknown, so set unknown output shape.
        c->set_output(0, in);
        return OkStatus();
      }

      std::string input_layout_string;
      std::string output_layout_string;
      TF_RETURN_IF_ERROR(c->GetAttr("input_layout", &input_layout_string));
      TF_RETURN_IF_ERROR(c->GetAttr("output_layout", &output_layout_string));
      TF_ASSIGN_OR_RETURN(Layout input_layout,
                          Layout::FromString(input_layout_string));
      TF_ASSIGN_OR_RETURN(Layout output_layout,
                          Layout::FromString(output_layout_string));
      if (c->Rank(in) != input_layout.rank() ||
          c->Rank(in) != output_layout.rank()) {
        return errors::InvalidArgument(
            "Input tensor rank and layout ranks do not agree: input rank ",
            c->Rank(in), " input layout rank ", input_layout.rank(),
            " output layout rank ", output_layout.rank());
      }
      const std::vector<int32> output_sharding = output_layout.num_shards();
      std::vector<shape_inference::DimensionHandle> out_dims;
      out_dims.reserve(c->Rank(in));
      for (int i = 0; i < c->Rank(in); ++i) {
        shape_inference::DimensionHandle dim = c->Dim(in, i);
        if (!c->ValueKnown(dim) ||
            input_layout.sharding_spec(i) == output_layout.sharding_spec(i)) {
          // Unknown or unchanged dimensions pass through as-is.
          out_dims.emplace_back(dim);
        } else if (Layout::IsUnshardedDimension(
                       input_layout.sharding_spec(i))) {
          // Scattering an unsharded dimension splits it across the output
          // shards, so the dimension must divide evenly by the shard count.
          shape_inference::DimensionHandle out_dim;
          TF_RETURN_IF_ERROR(c->Divide(dim, output_sharding[i],
                                       /*evenly_divisible=*/true, &out_dim));
          out_dims.push_back(out_dim);
        } else {
          return errors::InvalidArgument(
              "DTensorAllScatter only supports output layouts which are more "
              "sharded than input layouts. Received input sharding spec ",
              input_layout.sharding_spec(i), " and output sharding spec ",
              output_layout.sharding_spec(i), " for dimension ", i, ".");
        }
      }
      c->set_output(0, c->MakeShape(out_dims));
      return OkStatus();
    });

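// DTensorAllGather is the inverse of DTensorAllScatter: it re-assembles
// `input` from `input_layout` into a less sharded `output_layout`. For every
// dimension that goes from sharded to unsharded, the shape function below
// multiplies the per-shard dimension size by the number of input shards.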
REGISTER_OP("DTensorAllGather")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, int32, uint32, int64, bool}")
    .Attr("input_layout: string")
    .Attr("output_layout: string")
    .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
      shape_inference::ShapeHandle in = c->input(0);
      if (!c->RankKnown(in)) {
        // Input shape unknown, so set unknown output shape.
        c->set_output(0, in);
        return OkStatus();
      }

      std::string input_layout_string;
      std::string output_layout_string;
      TF_RETURN_IF_ERROR(c->GetAttr("input_layout", &input_layout_string));
      TF_RETURN_IF_ERROR(c->GetAttr("output_layout", &output_layout_string));
      TF_ASSIGN_OR_RETURN(Layout input_layout,
                          Layout::FromString(input_layout_string));
      TF_ASSIGN_OR_RETURN(Layout output_layout,
                          Layout::FromString(output_layout_string));
      if (c->Rank(in) != input_layout.rank() ||
          c->Rank(in) != output_layout.rank()) {
        return errors::InvalidArgument(
            "Input tensor rank and layout ranks do not agree: input rank ",
            c->Rank(in), " input layout rank ", input_layout.rank(),
            " output layout rank ", output_layout.rank());
      }
      const std::vector<int32> input_sharding = input_layout.num_shards();
      std::vector<shape_inference::DimensionHandle> out_dims;
      out_dims.reserve(c->Rank(in));
      for (int32 i = 0; i < c->Rank(in); ++i) {
        shape_inference::DimensionHandle dim = c->Dim(in, i);
        if (!c->ValueKnown(dim) ||
            input_layout.sharding_spec(i) == output_layout.sharding_spec(i)) {
          // Unknown or unchanged dimensions pass through as-is.
          out_dims.emplace_back(dim);
        } else if (Layout::IsUnshardedDimension(
                       output_layout.sharding_spec(i))) {
          // Gathering concatenates the shards of this dimension, so the
          // per-shard size is multiplied by the input shard count.
          shape_inference::DimensionHandle out_dim;
          TF_RETURN_IF_ERROR(c->Multiply(dim, input_sharding[i], &out_dim));
          out_dims.push_back(out_dim);
        } else {
          return errors::InvalidArgument(
              "DTensorAllGather only supports input layouts which are more "
              "sharded than output layouts. Received input sharding spec ",
              input_layout.sharding_spec(i), " and output sharding spec ",
              output_layout.sharding_spec(i), " for dimension ", i, ".");
        }
      }
      c->set_output(0, c->MakeShape(out_dims));
      return OkStatus();
    });

}  // namespace dtensor
}  // namespace tensorflow