1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | #include "tensorflow/core/framework/shape_inference.h" |
19 | |
20 | namespace tensorflow { |
21 | |
22 | using shape_inference::InferenceContext; |
23 | using shape_inference::ShapeHandle; |
24 | |
// Registers the "TPUReplicateMetadata" op.
// This registration declares attributes only: no inputs and no outputs are
// declared here. Shape inference is delegated to
// shape_inference::UnknownShape.
// NOTE(review): presumably this node carries replication configuration that
// a TPU graph rewrite reads — confirm against the rewrite pass.
// Attribute declaration order is reflected in the registered OpDef, so avoid
// reordering the .Attr() calls below.
REGISTER_OP("TPUReplicateMetadata")
    // Unlike _TPUReplicate (which requires >= 1), zero is accepted here.
    .Attr("num_replicas: int >= 0")
    .Attr("num_cores_per_replica: int = 1")
    .Attr("topology: string = \"\"")
    .Attr("use_tpu: bool = true")
    .Attr("device_assignment: list(int) = []")
    // Deprecated. Use num_cores_per_replica instead.
    .Attr("computation_shape: list(int) = []")
    .Attr("host_compute_core: list(string) = []")
    .Attr("padding_map: list(string) = []")  // Deprecated.
    .Attr("step_marker_location: string = \"STEP_MARK_AT_ENTRY\"")
    .Attr("allow_soft_placement: bool = false")
    .Attr("use_spmd_for_xla_partitioning: bool = false")
    .Attr("tpu_compile_options_proto: string = \"\"")
    .SetShapeFn(shape_inference::UnknownShape);
40 | |
41 | REGISTER_OP("TPUReplicatedInput" ) |
42 | .Input("inputs: N * T" ) |
43 | .Output("output: T" ) |
44 | .Attr("N: int >= 1" ) |
45 | .Attr("T: type" ) |
46 | .Attr("is_mirrored_variable: bool = false" ) |
47 | // `index` attribute is unused |
48 | .Attr("index: int = -1" ) |
49 | // All inputs are packed into one input |
50 | .Attr("is_packed: bool = false" ) |
51 | .SetShapeFn([](InferenceContext* c) { |
52 | ShapeHandle cur = c->input(c->num_inputs() - 1); |
53 | for (int i = c->num_inputs() - 2; i >= 0; --i) { |
54 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur), |
55 | "From merging shape " , i, |
56 | " with other shapes." ); |
57 | } |
58 | c->set_output(0, cur); |
59 | |
60 | // If this is a resource, unify the resource shapes. |
61 | DataType dtype; |
62 | TF_RETURN_IF_ERROR(c->GetAttr("T" , &dtype)); |
63 | if (dtype == DT_RESOURCE) { |
64 | const std::vector<shape_inference::ShapeAndType>* shapes_and_types = |
65 | nullptr; |
66 | for (int i = c->num_inputs() - 1; i >= 0; --i) { |
67 | if (shapes_and_types) { |
68 | // The return value of MergeInputHandleShapesAndTypes indicates |
69 | // the shape was refined, not that there was an error. |
70 | // TODO(phawkins): there seems to be no way to discover errors. |
71 | (void)!c->MergeInputHandleShapesAndTypes(i, *shapes_and_types); |
72 | } else { |
73 | shapes_and_types = c->input_handle_shapes_and_types(i); |
74 | } |
75 | } |
76 | if (shapes_and_types) { |
77 | c->set_output_handle_shapes_and_types(0, *shapes_and_types); |
78 | } |
79 | } |
80 | return OkStatus(); |
81 | }); |
82 | |
83 | REGISTER_OP("TPUReplicatedOutput" ) |
84 | .Input("input: T" ) |
85 | .Output("outputs: num_replicas * T" ) |
86 | .Attr("num_replicas: int >= 1" ) |
87 | .Attr("T: type" ) |
88 | .SetShapeFn([](InferenceContext* c) { |
89 | for (int i = 0; i < c->num_outputs(); ++i) { |
90 | c->set_output(i, c->input(0)); |
91 | } |
92 | return OkStatus(); |
93 | }); |
94 | |
95 | REGISTER_OP("TPUCompilationResult" ) |
96 | .Output("output: string" ) |
97 | .SetShapeFn(shape_inference::ScalarShape); |
98 | |
// Registers the "_TPUReplicate" op.
// The leading underscore follows TensorFlow's naming for internal ops.
// NOTE(review): presumably this node is produced by a graph rewrite rather
// than by user code — confirm.
//
// The function to run is given by the `computation` func attr. Inputs come
// in four groups, in this order:
//   inputs               - types given by Tinputs
//   broadcast_inputs     - types given by Tbroadcast_inputs
//   variables            - NumVariables resource handles
//   guaranteed_constants - types given by Tguaranteed_constants
// Outputs are typed by `output_types`. Shape inference is delegated to
// shape_inference::UnknownShape.
// Attribute/input declaration order is reflected in the registered OpDef, so
// avoid reordering the calls below.
REGISTER_OP("_TPUReplicate")
    .Attr("computation: func")
    // Must be at least 1 here (TPUReplicateMetadata accepts 0).
    .Attr("num_replicas: int >= 1")
    .Attr("num_cores_per_replica: int = 1")
    .Attr("topology: string = \"\"")
    .Attr("use_tpu: bool = true")
    .Attr("device_assignment: list(int) = []")
    .Attr("host_compute_core: list(string) = []")
    .Attr("Tinputs: list(type) >= 0")
    .Attr("Tbroadcast_inputs: list(type) >= 0")
    .Attr("NumVariables: int >= 0")
    .Attr("Tguaranteed_constants: list(type) >= 0")
    .Attr("output_types: list(type) >= 0")
    .Attr("padding_map: list(string) = []")  // Deprecated.
    .Attr("step_marker_location: string = \"STEP_MARK_AT_ENTRY\"")
    .Attr("allow_soft_placement: bool = false")
    .Attr("num_distributed_variables: int = 0")
    .Attr("use_spmd_for_xla_partitioning: bool = false")
    .Attr("tpu_compile_options_proto: string = \"\"")
    .Input("inputs: Tinputs")
    .Input("broadcast_inputs: Tbroadcast_inputs")
    .Input("variables: NumVariables * resource")
    .Input("guaranteed_constants: Tguaranteed_constants")
    .Output("outputs: output_types")
    .SetShapeFn(shape_inference::UnknownShape);
124 | |
125 | } // namespace tensorflow |
126 | |