1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/op.h"
17#include "tensorflow/core/framework/op_kernel.h"
18#include "tensorflow/core/framework/op_requires.h"
19#include "tensorflow/core/framework/variant.h"
20#include "tensorflow/core/framework/variant_encode_decode.h"
21#include "tensorflow/core/kernels/composite_tensor_variant.h"
22#include "tensorflow/core/platform/errors.h"
23#include "tensorflow/core/protobuf/composite_tensor_variant.pb.h"
24#include "tensorflow/core/protobuf/struct.pb.h"
25
26namespace tensorflow {
27
28class CompositeTensorVariantFromComponents : public OpKernel {
29 public:
30 explicit CompositeTensorVariantFromComponents(OpKernelConstruction* context)
31 : OpKernel(context) {
32 string type_spec_string;
33 OP_REQUIRES_OK(context, context->GetAttr("metadata", &type_spec_string));
34 OP_REQUIRES(context, metadata_.ParseFromString(type_spec_string),
35 errors::InvalidArgument("Error parsing metadata"));
36 }
37
38 void Compute(OpKernelContext* context) override {
39 OpInputList components_in;
40 OP_REQUIRES_OK(context, context->input_list("components", &components_in));
41
42 Tensor* encoded;
43 OP_REQUIRES_OK(context,
44 context->allocate_output(0, TensorShape({}), &encoded));
45
46 std::vector<Tensor> components{components_in.begin(), components_in.end()};
47 encoded->flat<Variant>()(0) =
48 CompositeTensorVariant(metadata_, absl::MakeSpan(components));
49 }
50
51 private:
52 CompositeTensorVariantMetadata metadata_;
53};
54
55class CompositeTensorVariantToComponents : public OpKernel {
56 public:
57 explicit CompositeTensorVariantToComponents(OpKernelConstruction* context)
58 : OpKernel(context) {
59 string type_spec_string;
60 OP_REQUIRES_OK(context, context->GetAttr("metadata", &type_spec_string));
61 OP_REQUIRES(context, metadata_.ParseFromString(type_spec_string),
62 errors::InvalidArgument("Error parsing `metadata`"));
63
64 OP_REQUIRES_OK(context,
65 context->GetAttr("Tcomponents", &component_dtypes_));
66 }
67
68 void Compute(OpKernelContext* context) override {
69 Tensor encoded_t = context->input(0);
70 OP_REQUIRES(
71 context, encoded_t.flat<Variant>().size() > 0,
72 errors::InvalidArgument("Input `encoded` must not be an empty variant "
73 "tensor, but got ",
74 encoded_t.DebugString()));
75 auto* encoded = encoded_t.flat<Variant>()(0).get<CompositeTensorVariant>();
76 OP_REQUIRES(context, encoded != nullptr,
77 errors::InvalidArgument("The input `encoded` is not a valid "
78 "CompositeTensorVariant tensor, got ",
79 encoded_t.DebugString()));
80
81 // Check that the encoded TypeSpec is compatible with the expected TypeSpec.
82 // For now, we just check that the class matches.
83 //
84 // TODO(b/173744905): Update this to do a generic compatibility check. This
85 // would require replacing the current design, where Python subclasses of
86 // TypeSpec can override is_compatible, with a design where compatibility
87 // can be deterministically determined from the metadata.
88 auto expected_class = metadata_.type_spec_proto().type_spec_class();
89 auto actual_class = encoded->metadata().type_spec_proto().type_spec_class();
90 OP_REQUIRES(
91 context, expected_class == actual_class,
92 errors::InvalidArgument(
93 "Expected a ", TypeSpecProto::TypeSpecClass_Name(expected_class),
94 " (based on `type_spec`), but `encoded` contains a ",
95 TypeSpecProto::TypeSpecClass_Name(actual_class)));
96
97 // Extract the component tensors.
98 OpOutputList components;
99 OP_REQUIRES_OK(context, context->output_list("components", &components));
100 int num_components = encoded->flat_components().size();
101
102 OP_REQUIRES(context, component_dtypes_.size() == num_components,
103 errors::InvalidArgument("Encoded value has ", num_components,
104 " tensor components; expected ",
105 component_dtypes_.size(),
106 " components based on type_spec"));
107
108 for (int i = 0; i < component_dtypes_.size(); i++) {
109 const Tensor& component = encoded->flat_components()[i];
110 OP_REQUIRES(context, component_dtypes_[i] == component.dtype(),
111 errors::InvalidArgument("Tensor component ", i, " had dtype ",
112 DataType_Name(component.dtype()),
113 "; expected dtype ",
114 DataType_Name(component_dtypes_[i])));
115 components.set(i, component);
116 }
117 }
118
119 private:
120 CompositeTensorVariantMetadata metadata_;
121 std::vector<DataType> component_dtypes_;
122};
123
124REGISTER_KERNEL_BUILDER(
125 Name("CompositeTensorVariantToComponents").Device(DEVICE_CPU),
126 CompositeTensorVariantToComponents);
127REGISTER_KERNEL_BUILDER(
128 Name("CompositeTensorVariantFromComponents").Device(DEVICE_CPU),
129 CompositeTensorVariantFromComponents);
130
131} // namespace tensorflow
132