1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include <algorithm>
#include <memory>
#include <sstream>

#include "tensorflow/c/kernels.h"
#include "tensorflow/c/ops.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/registration/registration.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/platform/macros.h"
26
27// BitcastOp implements a bitcast kernel, creating an output tensor that shares
28// the same data buffer as the input but with a different shape and/or data
29// type. Its inputs are:
30//
31// * the input tensor
32// * an attribute named "T" containing the TF_DataType of the input tensor
33// * an attribute named "type" containing the TF_DataType of the output tensor
34//
35// Given an input tensor of shape [...], if the input DataType "T" is larger
36// than the output DataType "type", then the shape changes from [...]
37// to [..., sizeof(T)/sizeof(type)].
38//
39// If "T" is smaller than "type", the operator requires that the rightmost
40// dimension be equal to sizeof(type)/sizeof(T). The shape then goes from
41// [..., sizeof(type)/sizeof(T)] to [...].
42//
43// Bitcast is implemented as a low-level cast, so machines with different endian
44// orderings will give different results.
// Per-kernel state, computed once in BitcastOp_Create and read (never
// mutated) by every subsequent BitcastOp_Compute invocation.
typedef struct BitcastOp {
  TF_DataType input_data_type;   // dtype of the input tensor (attr "T")
  TF_DataType output_data_type;  // dtype of the output tensor (attr "type")
  size_t in_size;                // byte size of one input element
  size_t out_size;               // byte size of one output element
} BitcastOp;
51
52static void* BitcastOp_Create(TF_OpKernelConstruction* ctx) {
53 auto* kernel = new BitcastOp;
54
55 TF_Status* s = TF_NewStatus();
56 TF_OpKernelConstruction_GetAttrType(ctx, "T", &kernel->input_data_type, s);
57
58 if (TF_GetCode(s) == TF_OK) {
59 TF_OpKernelConstruction_GetAttrType(ctx, "type", &kernel->output_data_type,
60 s);
61 }
62
63 if (TF_GetCode(s) == TF_OK) {
64 kernel->in_size = TF_DataTypeSize(kernel->input_data_type);
65 kernel->out_size = TF_DataTypeSize(kernel->output_data_type);
66
67 size_t check_size = std::max(kernel->in_size, kernel->out_size) %
68 std::min(kernel->in_size, kernel->out_size);
69 if (check_size != 0) {
70 std::ostringstream err;
71 err << "cannot convert between datatype " << kernel->input_data_type
72 << " and " << kernel->output_data_type;
73 TF_SetStatus(s, TF_INVALID_ARGUMENT, err.str().c_str());
74 }
75 }
76
77 if (TF_GetCode(s) != TF_OK) {
78 TF_OpKernelConstruction_Failure(ctx, s);
79 delete kernel;
80 kernel = nullptr;
81 }
82
83 TF_DeleteStatus(s);
84 return kernel;
85}
86
87static void BitcastOp_Delete(void* kernel) {
88 delete static_cast<BitcastOp*>(kernel);
89}
90
91static void BitcastOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
92 auto* k = static_cast<BitcastOp*>(kernel);
93 int dim_count = 0;
94
95 TF_Tensor* tensor;
96 TF_Status* status = TF_NewStatus();
97 TF_GetInput(ctx, 0, &tensor, status);
98 if (TF_GetCode(status) == TF_OK) {
99 dim_count = TF_NumDims(tensor);
100 if (!(k->in_size >= k->out_size ||
101 (dim_count > 0 &&
102 TF_Dim(tensor, dim_count - 1) == k->out_size / k->in_size))) {
103 std::ostringstream err;
104 err << "Cannot bitcast from " << k->input_data_type << " to "
105 << k->output_data_type;
106 TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str());
107 }
108 }
109
110 if (TF_GetCode(status) == TF_OK) {
111 auto* dims = new int64_t[dim_count + 1];
112 int new_dim_count = dim_count;
113 for (int dim = 0; dim < dim_count; ++dim) {
114 dims[dim] = TF_Dim(tensor, dim);
115 }
116 if (k->out_size < k->in_size) {
117 dims[new_dim_count++] = static_cast<int64_t>(k->in_size / k->out_size);
118 } else if (k->out_size > k->in_size) {
119 --new_dim_count;
120 }
121
122 TF_Tensor* output = TF_AllocateTensor(k->output_data_type, dims, 0,
123 TF_DataTypeSize(k->output_data_type));
124 TF_TensorBitcastFrom(tensor, k->output_data_type, output, dims,
125 new_dim_count, status);
126 if (TF_GetCode(status) == TF_OK) {
127 TF_SetOutput(ctx, 0, output, status);
128 }
129 delete[] dims;
130 TF_DeleteTensor(output);
131 }
132
133 if (TF_GetCode(status) != TF_OK) {
134 TF_OpKernelContext_Failure(ctx, status);
135 }
136 TF_DeleteStatus(status);
137 TF_DeleteTensor(tensor);
138}
139
140void RegisterBitcastOpKernel() {
141 TF_Status* status = TF_NewStatus();
142 {
143 auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_CPU,
144 &BitcastOp_Create, &BitcastOp_Compute,
145 &BitcastOp_Delete);
146 TF_RegisterKernelBuilder("BitcastOp", builder, status);
147 CHECK_EQ(TF_OK, TF_GetCode(status))
148 << "Error while registering bitcast kernel";
149 }
150
151#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
152 {
153 auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_GPU,
154 &BitcastOp_Create, &BitcastOp_Compute,
155 &BitcastOp_Delete);
156 TF_RegisterKernelBuilder("BitcastOp", builder, status);
157 CHECK_EQ(TF_OK, TF_GetCode(status))
158 << "Error while registering CUDA bitcast kernel";
159 }
160#endif
161
162 TF_DeleteStatus(status);
163}
164
165// A dummy static variable initialized by a lambda whose side-effect is to
166// register the bitcast kernel.
167TF_ATTRIBUTE_UNUSED static bool IsBitcastOpKernelRegistered = []() {
168 if (SHOULD_REGISTER_OP_KERNEL("BitcastOp")) {
169 RegisterBitcastOpKernel();
170 }
171 return true;
172}();
173