1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include <sstream>
#include <vector>

#include "tensorflow/c/kernels.h"
#include "tensorflow/c/ops.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/registration/registration.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/platform/macros.h"
26 | |
27 | // BitcastOp implements a bitcast kernel, creating an output tensor that shares |
28 | // the same data buffer as the input but with a different shape and/or data |
29 | // type. Its inputs are: |
30 | // |
31 | // * the input tensor |
32 | // * an attribute named "T" containing the TF_DataType of the input tensor |
33 | // * an attribute named "type" containing the TF_DataType of the output tensor |
34 | // |
35 | // Given an input tensor of shape [...], if the input DataType "T" is larger |
36 | // than the output DataType "type", then the shape changes from [...] |
37 | // to [..., sizeof(T)/sizeof(type)]. |
38 | // |
39 | // If "T" is smaller than "type", the operator requires that the rightmost |
40 | // dimension be equal to sizeof(type)/sizeof(T). The shape then goes from |
41 | // [..., sizeof(type)/sizeof(T)] to [...]. |
42 | // |
43 | // Bitcast is implemented as a low-level cast, so machines with different endian |
44 | // orderings will give different results. |
45 | typedef struct BitcastOp { |
46 | TF_DataType input_data_type; |
47 | TF_DataType output_data_type; |
48 | size_t in_size; |
49 | size_t out_size; |
50 | } BitcastOp; |
51 | |
52 | static void* BitcastOp_Create(TF_OpKernelConstruction* ctx) { |
53 | auto* kernel = new BitcastOp; |
54 | |
55 | TF_Status* s = TF_NewStatus(); |
56 | TF_OpKernelConstruction_GetAttrType(ctx, "T" , &kernel->input_data_type, s); |
57 | |
58 | if (TF_GetCode(s) == TF_OK) { |
59 | TF_OpKernelConstruction_GetAttrType(ctx, "type" , &kernel->output_data_type, |
60 | s); |
61 | } |
62 | |
63 | if (TF_GetCode(s) == TF_OK) { |
64 | kernel->in_size = TF_DataTypeSize(kernel->input_data_type); |
65 | kernel->out_size = TF_DataTypeSize(kernel->output_data_type); |
66 | |
67 | size_t check_size = std::max(kernel->in_size, kernel->out_size) % |
68 | std::min(kernel->in_size, kernel->out_size); |
69 | if (check_size != 0) { |
70 | std::ostringstream err; |
71 | err << "cannot convert between datatype " << kernel->input_data_type |
72 | << " and " << kernel->output_data_type; |
73 | TF_SetStatus(s, TF_INVALID_ARGUMENT, err.str().c_str()); |
74 | } |
75 | } |
76 | |
77 | if (TF_GetCode(s) != TF_OK) { |
78 | TF_OpKernelConstruction_Failure(ctx, s); |
79 | delete kernel; |
80 | kernel = nullptr; |
81 | } |
82 | |
83 | TF_DeleteStatus(s); |
84 | return kernel; |
85 | } |
86 | |
87 | static void BitcastOp_Delete(void* kernel) { |
88 | delete static_cast<BitcastOp*>(kernel); |
89 | } |
90 | |
91 | static void BitcastOp_Compute(void* kernel, TF_OpKernelContext* ctx) { |
92 | auto* k = static_cast<BitcastOp*>(kernel); |
93 | int dim_count = 0; |
94 | |
95 | TF_Tensor* tensor; |
96 | TF_Status* status = TF_NewStatus(); |
97 | TF_GetInput(ctx, 0, &tensor, status); |
98 | if (TF_GetCode(status) == TF_OK) { |
99 | dim_count = TF_NumDims(tensor); |
100 | if (!(k->in_size >= k->out_size || |
101 | (dim_count > 0 && |
102 | TF_Dim(tensor, dim_count - 1) == k->out_size / k->in_size))) { |
103 | std::ostringstream err; |
104 | err << "Cannot bitcast from " << k->input_data_type << " to " |
105 | << k->output_data_type; |
106 | TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str()); |
107 | } |
108 | } |
109 | |
110 | if (TF_GetCode(status) == TF_OK) { |
111 | auto* dims = new int64_t[dim_count + 1]; |
112 | int new_dim_count = dim_count; |
113 | for (int dim = 0; dim < dim_count; ++dim) { |
114 | dims[dim] = TF_Dim(tensor, dim); |
115 | } |
116 | if (k->out_size < k->in_size) { |
117 | dims[new_dim_count++] = static_cast<int64_t>(k->in_size / k->out_size); |
118 | } else if (k->out_size > k->in_size) { |
119 | --new_dim_count; |
120 | } |
121 | |
122 | TF_Tensor* output = TF_AllocateTensor(k->output_data_type, dims, 0, |
123 | TF_DataTypeSize(k->output_data_type)); |
124 | TF_TensorBitcastFrom(tensor, k->output_data_type, output, dims, |
125 | new_dim_count, status); |
126 | if (TF_GetCode(status) == TF_OK) { |
127 | TF_SetOutput(ctx, 0, output, status); |
128 | } |
129 | delete[] dims; |
130 | TF_DeleteTensor(output); |
131 | } |
132 | |
133 | if (TF_GetCode(status) != TF_OK) { |
134 | TF_OpKernelContext_Failure(ctx, status); |
135 | } |
136 | TF_DeleteStatus(status); |
137 | TF_DeleteTensor(tensor); |
138 | } |
139 | |
140 | void RegisterBitcastOpKernel() { |
141 | TF_Status* status = TF_NewStatus(); |
142 | { |
143 | auto* builder = TF_NewKernelBuilder("Bitcast" , tensorflow::DEVICE_CPU, |
144 | &BitcastOp_Create, &BitcastOp_Compute, |
145 | &BitcastOp_Delete); |
146 | TF_RegisterKernelBuilder("BitcastOp" , builder, status); |
147 | CHECK_EQ(TF_OK, TF_GetCode(status)) |
148 | << "Error while registering bitcast kernel" ; |
149 | } |
150 | |
151 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
152 | { |
153 | auto* builder = TF_NewKernelBuilder("Bitcast" , tensorflow::DEVICE_GPU, |
154 | &BitcastOp_Create, &BitcastOp_Compute, |
155 | &BitcastOp_Delete); |
156 | TF_RegisterKernelBuilder("BitcastOp" , builder, status); |
157 | CHECK_EQ(TF_OK, TF_GetCode(status)) |
158 | << "Error while registering CUDA bitcast kernel" ; |
159 | } |
160 | #endif |
161 | |
162 | TF_DeleteStatus(status); |
163 | } |
164 | |
165 | // A dummy static variable initialized by a lambda whose side-effect is to |
166 | // register the bitcast kernel. |
167 | TF_ATTRIBUTE_UNUSED static bool IsBitcastOpKernelRegistered = []() { |
168 | if (SHOULD_REGISTER_OP_KERNEL("BitcastOp" )) { |
169 | RegisterBitcastOpKernel(); |
170 | } |
171 | return true; |
172 | }(); |
173 | |