1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/kernels/variant_ops_util.h" |
17 | |
18 | #include <functional> |
19 | #include <utility> |
20 | |
21 | #include "tensorflow/core/framework/op_kernel.h" |
22 | #include "tensorflow/core/framework/tensor.h" |
23 | #include "tensorflow/core/framework/variant.h" |
24 | #include "tensorflow/core/lib/core/status.h" |
25 | |
26 | namespace tensorflow { |
// AddVariantTo efficiently performs:
//      temp[lhs_ix] <- array(lhs_ix) + array(rhs_ix)
// where array(ix) := (temp_filled[ix]
//                     ? temp[ix]
//                     : ctx->input(ix).scalar<Variant>()())
// This reduces (possibly expensive) copying of Variants from
// the inputs into temp at the lowest levels of the summation tree.
//
// On success, temp[lhs_ix] holds the new partial sum and
// temp_filled[lhs_ix] is set to true. rhs_ix is read-only here; the caller
// is responsible for recycling its slot afterwards.
static inline Status AddVariantTo(
    OpKernelContext* ctx, const int lhs_ix, const int rhs_ix,
    gtl::InlinedVector<Variant, 4>* temp,
    gtl::InlinedVector<bool, 4>* temp_filled,
    std::function<Status(OpKernelContext*, const Variant&, const Variant&,
                         Variant*)>
        binary_add_variant) {
  Variant tmp;
  // The output pointer `c` below aliases temp[lhs_ix]. Move any existing
  // partial sum out into `tmp` first so binary_add_variant never reads its
  // lhs operand through the same slot it is writing the result into.
  if (temp_filled->at(lhs_ix)) tmp = std::move(temp->at(lhs_ix));
  // Each operand comes from the intermediate buffer if a partial sum has
  // already been stored there, otherwise directly from the kernel input
  // (avoiding a copy of the input Variant into temp).
  const Variant& a = temp_filled->at(lhs_ix)
                         ? tmp
                         : ctx->input(lhs_ix).template scalar<Variant>()();
  const Variant& b = temp_filled->at(rhs_ix)
                         ? temp->at(rhs_ix)
                         : ctx->input(rhs_ix).template scalar<Variant>()();
  Variant* c = &temp->at(lhs_ix);
  TF_RETURN_IF_ERROR(binary_add_variant(ctx, a, b, c));
  temp_filled->at(lhs_ix) = true;
  return OkStatus();
}
54 | |
55 | void AddNVariant(OpKernelContext* ctx, |
56 | std::function<Status(OpKernelContext*, const Variant&, |
57 | const Variant&, Variant*)> |
58 | binary_add_variant) { |
59 | const Tensor& input0 = ctx->input(0); |
60 | const int num = ctx->num_inputs(); |
61 | |
62 | if (num == 1) { |
63 | ctx->set_output(0, input0); |
64 | return; |
65 | } |
66 | |
67 | for (int i = 0; i < num; ++i) { |
68 | // Step 1: ensure unary variants. |
69 | OP_REQUIRES( |
70 | ctx, ctx->input(i).dims() == 0, |
71 | errors::InvalidArgument( |
72 | "AddN of non-scalar Tensor with dtype=DT_VARIANT is not " |
73 | "supported; inputs[" , |
74 | i, " has shape: " , ctx->input(i).shape().DebugString(), "." )); |
75 | } |
76 | |
77 | // Step 2: Sum input variants in a tree-like structure using |
78 | // BinaryOpVariants(ADD_VARIANT_BINARY_OP, ...) |
79 | // For the output create a default-constructed variant object. |
80 | // |
81 | // Pairwise summation provides better numerical precision by |
82 | // reducing round-off error: |
83 | // |
84 | // https://en.wikipedia.org/wiki/Pairwise_summation |
85 | // |
86 | // These two vectors are used to store and mark intermediate sums. |
87 | gtl::InlinedVector<bool, 4> temp_filled(num, false); |
88 | gtl::InlinedVector<Variant, 4> temp(num); |
89 | |
90 | // Tree-based summation. |
91 | int skip = 1; |
92 | int n = num; |
93 | while (skip < n) { |
94 | int i = skip; |
95 | while (i < n) { |
96 | // TODO(ebrevdo, rmlarsen): Parallelize the pairwise summations in the |
97 | // inner loop if the variants are "large". |
98 | |
99 | // x[i - skip] += x[i] |
100 | OP_REQUIRES_OK(ctx, AddVariantTo(ctx, i - skip, i, &temp, &temp_filled, |
101 | binary_add_variant)); |
102 | // We won't use this index again, recover its memory. |
103 | temp[i].clear(); |
104 | i += 2 * skip; |
105 | } |
106 | if (i == n) { |
107 | // x[0] += x[i - skip] |
108 | OP_REQUIRES_OK(ctx, AddVariantTo(ctx, 0, i - skip, &temp, &temp_filled, |
109 | binary_add_variant)); |
110 | // We won't use this index again, recover its memory. |
111 | temp[i - skip].clear(); |
112 | n -= skip; |
113 | } |
114 | skip *= 2; |
115 | } |
116 | |
117 | Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({})); |
118 | out.scalar<Variant>()() = std::move(temp[0]); |
119 | ctx->set_output(0, out); |
120 | } |
121 | } // end namespace tensorflow |
122 | |