/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/variant_ops_util.h"

#include <functional>
#include <utility>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
// AddVariantTo efficiently performs:
//    temp[lhs_ix] <- array(lhs_ix) + array(rhs_ix)
// where array(ix) := (temp_filled[ix]
//                     ? temp[ix]
//                     : ctx->input(ix).scalar<Variant>()())
// This reduces (possibly expensive) copying of Variants from
// the inputs into temp at the lowest levels of the summation tree.
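//
// For example, the first time a pair of leaf inputs is combined, neither
// temp slot has been filled yet, so both operands are read in place from
// ctx->input() and only the computed sum is materialized in temp[lhs_ix].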
static inline Status AddVariantTo(
    OpKernelContext* ctx, const int lhs_ix, const int rhs_ix,
    gtl::InlinedVector<Variant, 4>* temp,
    gtl::InlinedVector<bool, 4>* temp_filled,
    std::function<Status(OpKernelContext*, const Variant&, const Variant&,
                         Variant*)>
        binary_add_variant) {
  Variant tmp;
  if (temp_filled->at(lhs_ix)) tmp = std::move(temp->at(lhs_ix));
  const Variant& a = temp_filled->at(lhs_ix)
                         ? tmp
                         : ctx->input(lhs_ix).template scalar<Variant>()();
  const Variant& b = temp_filled->at(rhs_ix)
                         ? temp->at(rhs_ix)
                         : ctx->input(rhs_ix).template scalar<Variant>()();
  Variant* c = &temp->at(lhs_ix);
  TF_RETURN_IF_ERROR(binary_add_variant(ctx, a, b, c));
  temp_filled->at(lhs_ix) = true;
  return OkStatus();
}

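// Typical wiring from an AddN-style kernel (illustrative sketch, not the
// actual kernel registration; assumes a Device type and the BinaryOpVariants
// helper from variant_op_registry.h):
//
//   AddNVariant(ctx, [](OpKernelContext* c, const Variant& a,
//                       const Variant& b, Variant* out) {
//     return BinaryOpVariants<Device>(c, ADD_VARIANT_BINARY_OP, a, b, out);
//   });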
void AddNVariant(OpKernelContext* ctx,
                 std::function<Status(OpKernelContext*, const Variant&,
                                      const Variant&, Variant*)>
                     binary_add_variant) {
  const Tensor& input0 = ctx->input(0);
  const int num = ctx->num_inputs();

  if (num == 1) {
    ctx->set_output(0, input0);
    return;
  }

  for (int i = 0; i < num; ++i) {
    // Step 1: ensure all inputs are scalar (rank-0) Variants.
    OP_REQUIRES(
        ctx, ctx->input(i).dims() == 0,
        errors::InvalidArgument(
            "AddN of non-scalar Tensor with dtype=DT_VARIANT is not "
            "supported; inputs[",
            i, "] has shape: ", ctx->input(i).shape().DebugString(), "."));
  }

  // Step 2: Sum input variants in a tree-like structure using
  //   BinaryOpVariants(ADD_VARIANT_BINARY_OP, ...)
  // For the output create a default-constructed variant object.
  //
  // Pairwise summation provides better numerical precision by
  // reducing round-off error:
  //
  //   https://en.wikipedia.org/wiki/Pairwise_summation
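  //
  // Illustrative traces of the loop below, where xk denotes input k:
  // with num = 4 the result is accumulated as
  //   temp[0] = (x0 + x1) + (x2 + x3)
  // and with num = 5 the odd leftover x4 is first folded into temp[0]:
  //   temp[0] = ((x0 + x1) + x4) + (x2 + x3)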
  //
  // These two vectors are used to store and mark intermediate sums.
  gtl::InlinedVector<bool, 4> temp_filled(num, false);
  gtl::InlinedVector<Variant, 4> temp(num);

  // Tree-based summation.
  int skip = 1;
  int n = num;
  while (skip < n) {
    int i = skip;
    while (i < n) {
      // TODO(ebrevdo, rmlarsen): Parallelize the pairwise summations in the
      // inner loop if the variants are "large".

      // x[i - skip] += x[i]
      OP_REQUIRES_OK(ctx, AddVariantTo(ctx, i - skip, i, &temp, &temp_filled,
                                       binary_add_variant));
      // We won't use this index again, recover its memory.
      temp[i].clear();
      i += 2 * skip;
    }
    if (i == n) {
      // x[0] += x[i - skip]
      OP_REQUIRES_OK(ctx, AddVariantTo(ctx, 0, i - skip, &temp, &temp_filled,
                                       binary_add_variant));
      // We won't use this index again, recover its memory.
      temp[i - skip].clear();
      n -= skip;
    }
    skip *= 2;
  }

  Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
  out.scalar<Variant>()() = std::move(temp[0]);
  ctx->set_output(0, out);
}
}  // end namespace tensorflow