/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <vector>

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"

namespace tensorflow {
namespace dtensor {

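// DTensorAllReduce reduces `input` elementwise across groups of devices.
// `group_assignment` is expected to be a [num_groups, group_size] matrix of
// device ids, and the output shape always matches the input shape (hence
// UnchangedShape below). Illustrative example, not taken from this file:
// with group_assignment = [[0, 1], [2, 3]] and reduce_op = "Add", devices 0
// and 1 each receive the elementwise sum of their two inputs, and devices 2
// and 3 likewise.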
26REGISTER_OP("DTensorAllReduce")
27 .Input("input: T")
28 .Input("group_assignment: int32")
29 .Output("output: T")
30 .Attr("T: {half, bfloat16, float, int32, uint32, int64, bool}")
31 .Attr("reduce_op: {'Min', 'Max', 'Mul', 'Add', 'Mean', 'Any', 'All'}")
32 .Attr("device_type: string") // e.g. "/device:TPU"
33 .SetShapeFn(shape_inference::UnchangedShape);
34
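// DTensorReduceScatter reduces `input` across each device group and then
// scatters the result along `scatter_dimension`, so each group member keeps
// one slice of the reduction. Illustrative example, not taken from this
// file: reducing a [8, 16] tensor with reduce_op = "Add" over a group of 4
// devices with scatter_dimension = 0 would leave each device holding a
// [2, 16] slice of the sum.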
35REGISTER_OP("DTensorReduceScatter")
36 .Input("input: T")
37 .Input("group_assignment: int32")
38 .Input("scatter_dimension: int32")
39 .Output("output: T")
40 .Attr("T: {half, bfloat16, float, int32, uint32, int64, bool}")
41 .Attr("reduce_op: {'Min', 'Max', 'Mul', 'Add', 'Mean', 'Any', 'All'}")
42 .Attr("device_type: string") // e.g. "/device:TPU"
43 .SetShapeFn(shape_inference::ReduceScatterShape);
44
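// DTensorAllScatter slices a tensor from `input_layout` to a more-sharded
// `output_layout`; each newly sharded dimension shrinks by its shard count,
// as the shape function below computes. Illustrative example, not taken
// from this file: resharding a [8, 16] tensor from a fully replicated
// layout to one sharded 2 ways on dimension 0 yields a [4, 16] local slice
// on each shard.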
45REGISTER_OP("DTensorAllScatter")
46 .Input("input: T")
47 .Output("output: T")
48 .Attr("T: {half, bfloat16, float, int32, uint32, int64, bool}")
49 .Attr("input_layout: string")
50 .Attr("output_layout: string")
51 .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
52 shape_inference::ShapeHandle in = c->input(0);
53 if (!c->RankKnown(in)) {
54 // Input shape unknown, so set unknown output shape.
55 c->set_output(0, in);
56 return OkStatus();
57 }
58
59 std::string input_layout_string;
60 std::string output_layout_string;
61 TF_RETURN_IF_ERROR(c->GetAttr("input_layout", &input_layout_string));
62 TF_RETURN_IF_ERROR(c->GetAttr("output_layout", &output_layout_string));
63 TF_ASSIGN_OR_RETURN(Layout input_layout,
64 Layout::FromString(input_layout_string));
65 TF_ASSIGN_OR_RETURN(Layout output_layout,
66 Layout::FromString(output_layout_string));
67 if (c->Rank(in) != input_layout.rank() ||
68 c->Rank(in) != output_layout.rank()) {
69 return errors::InvalidArgument(
70 "Input tensor rank and layout ranks do not agree: input rank ",
71 c->Rank(in), " input layout rank ", input_layout.rank(),
72 " output "
73 "layout rank ",
74 output_layout.rank());
75 }
76 const std::vector<int32> output_sharding = output_layout.num_shards();
77 std::vector<shape_inference::DimensionHandle> out_dims;
78 out_dims.reserve(c->Rank(in));
79 for (int i = 0; i < c->Rank(in); ++i) {
80 shape_inference::DimensionHandle dim = c->Dim(in, i);
81 if (!c->ValueKnown(dim) ||
82 input_layout.sharding_spec(i) == output_layout.sharding_spec(i)) {
83 out_dims.emplace_back(dim);
84 } else if (Layout::IsUnshardedDimension(
85 input_layout.sharding_spec(i))) {
86 shape_inference::DimensionHandle out_dim;
87 TF_RETURN_IF_ERROR(c->Divide(dim, output_sharding[i],
88 /*evenly_divisible=*/true, &out_dim));
89 out_dims.push_back(out_dim);
90 } else {
91 return errors::InvalidArgument(
92 "DTensorAllScatter only supports output layouts which are more "
93 "sharded than input layouts. Received input sharding spec ",
94 input_layout.sharding_spec(i), " and output sharding spec ",
95 output_layout.sharding_spec(i), " for dimension ", i, ".");
96 }
97 }
98 c->set_output(0, c->MakeShape(out_dims));
99 return OkStatus();
100 });
101
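// DTensorAllGather is the inverse of DTensorAllScatter: it gathers a tensor
// from a more-sharded `input_layout` to a less-sharded `output_layout`, so
// each newly unsharded dimension grows by its former shard count, as the
// shape function below computes. Illustrative example, not taken from this
// file: gathering a dimension that was sharded 4 ways turns a [2, 16] local
// slice back into an [8, 16] tensor.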
102REGISTER_OP("DTensorAllGather")
103 .Input("input: T")
104 .Output("output: T")
105 .Attr("T: {half, bfloat16, float, int32, uint32, int64, bool}")
106 .Attr("input_layout: string")
107 .Attr("output_layout: string")
108 .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
109 shape_inference::ShapeHandle in = c->input(0);
110 if (!c->RankKnown(in)) {
111 // Input shape unknown, so set unknown output shape.
112 c->set_output(0, in);
113 return OkStatus();
114 }
115
116 std::string input_layout_string;
117 std::string output_layout_string;
118 TF_RETURN_IF_ERROR(c->GetAttr("input_layout", &input_layout_string));
119 TF_RETURN_IF_ERROR(c->GetAttr("output_layout", &output_layout_string));
120 TF_ASSIGN_OR_RETURN(Layout input_layout,
121 Layout::FromString(input_layout_string));
122 TF_ASSIGN_OR_RETURN(Layout output_layout,
123 Layout::FromString(output_layout_string));
124 if (c->Rank(in) != input_layout.rank() ||
125 c->Rank(in) != output_layout.rank()) {
126 return errors::InvalidArgument(
127 "Input tensor rank and layout ranks do not agree: input rank ",
128 c->Rank(in), " input layout rank ", input_layout.rank(),
129 " output "
130 "layout rank ",
131 output_layout.rank());
132 }
133 const std::vector<int32> input_sharding = input_layout.num_shards();
134 std::vector<shape_inference::DimensionHandle> out_dims;
135 out_dims.reserve(c->Rank(in));
136 for (int32 i = 0; i < c->Rank(in); ++i) {
137 shape_inference::DimensionHandle dim = c->Dim(in, i);
138 if (!c->ValueKnown(dim) ||
139 input_layout.sharding_spec(i) == output_layout.sharding_spec(i)) {
140 out_dims.emplace_back(dim);
141 } else if (Layout::IsUnshardedDimension(
142 output_layout.sharding_spec(i))) {
143 shape_inference::DimensionHandle out_dim;
144 TF_RETURN_IF_ERROR(c->Multiply(dim, input_sharding[i], &out_dim));
145 out_dims.push_back(out_dim);
146 } else {
147 return errors::InvalidArgument(
148 "DTensorAllGatherr only supports input layouts which are more "
149 "sharded than output layouts. Received input sharding spec ",
150 input_layout.sharding_spec(i), " and output sharding spec ",
151 output_layout.sharding_spec(i), " for dimension ", i, ".");
152 }
153 }
154 c->set_output(0, c->MakeShape(out_dims));
155 return OkStatus();
156 });
157
158} // namespace dtensor
159} // namespace tensorflow
160