1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | #include "tensorflow/core/framework/shape_inference.h" |
19 | |
20 | namespace tensorflow { |
21 | using shape_inference::DimensionHandle; |
22 | using shape_inference::InferenceContext; |
23 | using shape_inference::ShapeHandle; |
24 | |
25 | REGISTER_OP("AllToAll" ) |
26 | .Input("input: T" ) |
27 | .Input("group_assignment: int32" ) |
28 | .Output("output: T" ) |
29 | .Attr("T: {numbertype, bool}" ) |
30 | .Attr("concat_dimension: int" ) |
31 | .Attr("split_dimension: int" ) |
32 | .Attr("split_count: int" ) |
33 | .SetIsStateful() |
34 | .SetShapeFn([](InferenceContext* c) { |
35 | ShapeHandle input = c->input(0); |
36 | ShapeHandle group_assignment = c->input(1); |
37 | if (!c->RankKnown(input)) { |
38 | c->set_output(0, c->UnknownShape()); |
39 | return OkStatus(); |
40 | } |
41 | |
42 | int64_t rank = c->Rank(input); |
43 | int concat_dimension; |
44 | int split_dimension; |
45 | int split_count; |
46 | TF_RETURN_IF_ERROR(c->GetAttr("split_count" , &split_count)); |
47 | if (split_count < 1) { |
48 | return errors::InvalidArgument("split_count " , split_count, |
49 | " must at least be one." ); |
50 | } |
51 | if (c->RankKnown(group_assignment) && c->Rank(group_assignment) != 2) { |
52 | return errors::InvalidArgument("group_assignment must have rank 2." ); |
53 | } |
54 | DimensionHandle num_replicas_per_group = c->Dim(group_assignment, 1); |
55 | if (c->ValueKnown(num_replicas_per_group) && |
56 | (c->Value(num_replicas_per_group) != split_count)) { |
57 | return errors::InvalidArgument( |
58 | "split_count " , split_count, |
59 | " must equal the size of the second dimension of group_assignment " , |
60 | c->Value(num_replicas_per_group)); |
61 | } |
62 | |
63 | TF_RETURN_IF_ERROR(c->GetAttr("concat_dimension" , &concat_dimension)); |
64 | |
65 | if (concat_dimension < 0 || concat_dimension >= rank) { |
66 | return errors::InvalidArgument("concat_dimension " , concat_dimension, |
67 | " is out of range of input rank " , rank); |
68 | } |
69 | |
70 | TF_RETURN_IF_ERROR(c->GetAttr("split_dimension" , &split_dimension)); |
71 | if (split_dimension < 0 || split_dimension >= rank) { |
72 | return errors::InvalidArgument("split_dimension " , split_dimension, |
73 | " is out of range of input rank " , rank); |
74 | } |
75 | |
76 | if (!c->ValueKnown(c->Dim(input, concat_dimension)) || |
77 | !c->ValueKnown(c->Dim(input, split_dimension))) { |
78 | c->set_output(0, c->UnknownShape()); |
79 | return OkStatus(); |
80 | } |
81 | |
82 | std::vector<DimensionHandle> dims; |
83 | dims.resize(rank); |
84 | |
85 | for (int32_t i = 0; i < rank; ++i) { |
86 | dims[i] = c->Dim(input, i); |
87 | if (i == concat_dimension) { |
88 | dims[i] = c->MakeDim(c->Value(dims[i]) * split_count); |
89 | } |
90 | if (i == split_dimension) { |
91 | if (c->ValueKnown(dims[i]) && |
92 | (c->Value(dims[i]) % split_count != 0)) { |
93 | return errors::InvalidArgument( |
94 | "input dimension " , c->Value(dims[i]), |
95 | " not divisible by split_count " , split_count); |
96 | } |
97 | dims[i] = c->MakeDim(c->Value(dims[i]) / split_count); |
98 | } |
99 | } |
100 | |
101 | c->set_output(0, c->MakeShape(dims)); |
102 | return OkStatus(); |
103 | }); |
104 | |
105 | REGISTER_OP("CrossReplicaSum" ) |
106 | .Input("input: T" ) |
107 | .Input("group_assignment: int32" ) |
108 | .Output("output: T" ) |
109 | .Attr("T: {half, bfloat16, float, float64, int32, uint32}" ) |
110 | .SetIsStateful() |
111 | .SetShapeFn(shape_inference::UnchangedShape); |
112 | |
113 | REGISTER_OP("CollectivePermute" ) |
114 | .Input("input: T" ) |
115 | .Input("source_target_pairs: int32" ) |
116 | .Output("output: T" ) |
117 | .Attr("T: numbertype" ) |
118 | .SetIsStateful() |
119 | .SetShapeFn(shape_inference::UnchangedShape); |
120 | } // namespace tensorflow |
121 | |