1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/common_shape_fns.h"
17#include "tensorflow/core/framework/op.h"
18#include "tensorflow/core/framework/shape_inference.h"
19
20namespace tensorflow {
21using shape_inference::DimensionHandle;
22using shape_inference::InferenceContext;
23using shape_inference::ShapeHandle;
24
25REGISTER_OP("AllToAll")
26 .Input("input: T")
27 .Input("group_assignment: int32")
28 .Output("output: T")
29 .Attr("T: {numbertype, bool}")
30 .Attr("concat_dimension: int")
31 .Attr("split_dimension: int")
32 .Attr("split_count: int")
33 .SetIsStateful()
34 .SetShapeFn([](InferenceContext* c) {
35 ShapeHandle input = c->input(0);
36 ShapeHandle group_assignment = c->input(1);
37 if (!c->RankKnown(input)) {
38 c->set_output(0, c->UnknownShape());
39 return OkStatus();
40 }
41
42 int64_t rank = c->Rank(input);
43 int concat_dimension;
44 int split_dimension;
45 int split_count;
46 TF_RETURN_IF_ERROR(c->GetAttr("split_count", &split_count));
47 if (split_count < 1) {
48 return errors::InvalidArgument("split_count ", split_count,
49 " must at least be one.");
50 }
51 if (c->RankKnown(group_assignment) && c->Rank(group_assignment) != 2) {
52 return errors::InvalidArgument("group_assignment must have rank 2.");
53 }
54 DimensionHandle num_replicas_per_group = c->Dim(group_assignment, 1);
55 if (c->ValueKnown(num_replicas_per_group) &&
56 (c->Value(num_replicas_per_group) != split_count)) {
57 return errors::InvalidArgument(
58 "split_count ", split_count,
59 " must equal the size of the second dimension of group_assignment ",
60 c->Value(num_replicas_per_group));
61 }
62
63 TF_RETURN_IF_ERROR(c->GetAttr("concat_dimension", &concat_dimension));
64
65 if (concat_dimension < 0 || concat_dimension >= rank) {
66 return errors::InvalidArgument("concat_dimension ", concat_dimension,
67 " is out of range of input rank ", rank);
68 }
69
70 TF_RETURN_IF_ERROR(c->GetAttr("split_dimension", &split_dimension));
71 if (split_dimension < 0 || split_dimension >= rank) {
72 return errors::InvalidArgument("split_dimension ", split_dimension,
73 " is out of range of input rank ", rank);
74 }
75
76 if (!c->ValueKnown(c->Dim(input, concat_dimension)) ||
77 !c->ValueKnown(c->Dim(input, split_dimension))) {
78 c->set_output(0, c->UnknownShape());
79 return OkStatus();
80 }
81
82 std::vector<DimensionHandle> dims;
83 dims.resize(rank);
84
85 for (int32_t i = 0; i < rank; ++i) {
86 dims[i] = c->Dim(input, i);
87 if (i == concat_dimension) {
88 dims[i] = c->MakeDim(c->Value(dims[i]) * split_count);
89 }
90 if (i == split_dimension) {
91 if (c->ValueKnown(dims[i]) &&
92 (c->Value(dims[i]) % split_count != 0)) {
93 return errors::InvalidArgument(
94 "input dimension ", c->Value(dims[i]),
95 " not divisible by split_count ", split_count);
96 }
97 dims[i] = c->MakeDim(c->Value(dims[i]) / split_count);
98 }
99 }
100
101 c->set_output(0, c->MakeShape(dims));
102 return OkStatus();
103 });
104
105REGISTER_OP("CrossReplicaSum")
106 .Input("input: T")
107 .Input("group_assignment: int32")
108 .Output("output: T")
109 .Attr("T: {half, bfloat16, float, float64, int32, uint32}")
110 .SetIsStateful()
111 .SetShapeFn(shape_inference::UnchangedShape);
112
113REGISTER_OP("CollectivePermute")
114 .Input("input: T")
115 .Input("source_target_pairs: int32")
116 .Output("output: T")
117 .Attr("T: numbertype")
118 .SetIsStateful()
119 .SetShapeFn(shape_inference::UnchangedShape);
120} // namespace tensorflow
121