1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_UTIL_MKL_UTIL_H_ |
17 | #define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_ |
18 | #ifdef INTEL_MKL |
19 | |
20 | #include <list> |
21 | #include <memory> |
22 | #include <string> |
23 | #include <unordered_map> |
24 | #include <utility> |
25 | #include <vector> |
26 | |
27 | #include "dnnl.hpp" |
28 | #include "tensorflow/core/framework/op_kernel.h" |
29 | #include "tensorflow/core/framework/tensor.h" |
30 | #include "tensorflow/core/framework/tensor_shape.h" |
31 | #include "tensorflow/core/graph/mkl_graph_util.h" |
32 | #include "tensorflow/core/lib/core/errors.h" |
33 | #include "tensorflow/core/lib/core/stringpiece.h" |
34 | #include "tensorflow/core/lib/gtl/array_slice.h" |
35 | #include "tensorflow/core/platform/cpu_info.h" |
36 | #include "tensorflow/core/platform/logging.h" |
37 | #include "tensorflow/core/platform/macros.h" |
38 | #include "tensorflow/core/util/env_var.h" |
39 | #include "tensorflow/core/util/mkl_threadpool.h" |
40 | #include "tensorflow/core/util/padding.h" |
41 | #include "tensorflow/core/util/tensor_format.h" |
42 | #ifdef DNNL_AARCH64_USE_ACL |
43 | #include "tensorflow/core/platform/mutex.h" |
44 | #endif |
45 | |
46 | using dnnl::engine; |
47 | using dnnl::memory; |
48 | using dnnl::primitive; |
49 | using dnnl::reorder; |
50 | using dnnl::stream; |
51 | using CPUDevice = Eigen::ThreadPoolDevice; |
52 | using MemoryArgsMap = std::unordered_map<int, memory>; |
53 | using ReorderPd = dnnl::reorder::primitive_desc; |
54 | |
55 | #ifdef _WIN32 |
56 | typedef unsigned int uint; |
57 | #endif |
58 | |
59 | namespace tensorflow { |
60 | |
// This file contains a number of utility classes and functions used by
// MKL-enabled kernels.
63 | |
// This class encapsulates all the metadata that is associated with an MKL
// tensor. A tensor is an MKL tensor if it was created as the result of an
// MKL operation, and did not go through a conversion to a standard
// TensorFlow tensor.
68 | |
// The dimension order that oneDNN uses internally for 2D activations is
// [Batch, Channel, Height, Width] and for 2D filters is
// [Out_Channel, In_Channel, Height, Width].
72 | typedef enum { |
73 | Dim_N = 0, |
74 | Dim_C = 1, |
75 | Dim_H = 2, |
76 | Dim_W = 3, |
77 | Dim_O = 0, |
78 | Dim_I = 1 |
79 | } MklDnnDims; |
80 | |
// The dimension order that oneDNN uses internally for 3D activations is
// [Batch, Channel, Depth, Height, Width] and for 3D filters is
// [Out_Channel, In_Channel, Depth, Height, Width].
84 | typedef enum { |
85 | Dim3d_N = 0, |
86 | Dim3d_C = 1, |
87 | Dim3d_D = 2, |
88 | Dim3d_H = 3, |
89 | Dim3d_W = 4, |
90 | Dim3d_O = 0, |
91 | Dim3d_I = 1 |
92 | } MklDnnDims3D; |
93 | |
94 | // Enum for the order of dimensions of a TF 2D filter with shape [filter_height, |
95 | // filter_width, in_channels, out_channels] |
96 | typedef enum { |
97 | TF_2DFILTER_DIM_H = 0, |
98 | TF_2DFILTER_DIM_W = 1, |
99 | TF_2DFILTER_DIM_I = 2, |
100 | TF_2DFILTER_DIM_O = 3 |
101 | } TFFilterDims2d; |
102 | |
103 | // Enum for the order of dimensions of a TF 3D filter with shape [filter_depth, |
104 | // filter_height, filter_width, in_channels, out_channels] |
105 | typedef enum { |
106 | TF_3DFILTER_DIM_P = 0, |
107 | TF_3DFILTER_DIM_H = 1, |
108 | TF_3DFILTER_DIM_W = 2, |
109 | TF_3DFILTER_DIM_I = 3, |
110 | TF_3DFILTER_DIM_O = 4 |
111 | } TFFilterDims3d; |
112 | |
// The dimension order that oneDNN requires for the filter in a grouped
// convolution (2D only).
115 | typedef enum { |
116 | MKL_GROUP_FILTER_DIM_G = 0, |
117 | MKL_GROUP_FILTER_DIM_O = 1, |
118 | MKL_GROUP_FILTER_DIM_I = 2, |
119 | MKL_GROUP_FILTER_DIM_H = 3, |
120 | MKL_GROUP_FILTER_DIM_W = 4 |
121 | } MklDnnFilterGroupDims; |
122 | |
// Enum used to templatize MklOp kernel implementations
// that support both fp32 and int8 versions.
125 | enum class MklQuantization { |
126 | QUANTIZED_VERSION, |
127 | FP_VERSION, |
128 | }; |
129 | |
130 | static const int kSmallBatchSize = 32; |
131 | |
132 | inline void execute_primitives( |
133 | std::vector<dnnl::primitive>& primitives, std::shared_ptr<stream> stream, |
134 | std::vector<std::unordered_map<int, memory>>& net_args) { |
135 | DCHECK_EQ(primitives.size(), net_args.size()); |
136 | for (size_t i = 0; i < primitives.size(); ++i) { |
137 | primitives.at(i).execute(*stream, net_args.at(i)); |
138 | } |
139 | } |
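
// Illustrative usage sketch (not part of this header's API): callers
// typically build the primitive vector and the argument-map vector side by
// side and then run them on a single stream. The reorder_pd/src_mem/dst_mem
// names below are hypothetical.
//
//   std::vector<dnnl::primitive> net;
//   std::vector<std::unordered_map<int, memory>> net_args;
//   net.push_back(dnnl::reorder(reorder_pd));
//   net_args.push_back({{DNNL_ARG_FROM, src_mem}, {DNNL_ARG_TO, dst_mem}});
//   execute_primitives(net, cpu_stream, net_args);
//   cpu_stream->wait();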
140 | |
141 | // In oneDNN v1.x, the format (ex. NCHW) used to initialize a memory descriptor |
142 | // (md) structure will no longer be recorded in its `format` field. Instead, it |
143 | // will be set to a canonical `blocked` format for every fully described md. |
144 | // |
145 | // Currently, we query this `format` field while mapping oneDNN's data format |
146 | // to TF's data format. Due to the above restriction, we will now get this data |
147 | // format information from TF's `data_format` attribute (i.e. via |
148 | // `TensorFormat`) for oneDNN v1.x. |
149 | // |
// Some oneDNN operators such as ReLU do not have a `data_format` attribute
// since they are usually in `blocked` format. Therefore, in order to
// distinguish between blocked and non-blocked formats, we have defined a new
// enum called `MklTensorFormat` that is semantically similar to `TensorFormat`
// but adds the following values:
155 | // 1) FORMAT_BLOCKED: as described above, this is needed for element-wise |
156 | // operators such as ReLU. |
157 | // 2) FORMAT_INVALID: for error-checking (ex. unsupported format) |
158 | // 3) FORMAT_X, FORMAT_NC, FORMAT_TNC: to distinguish between MKL tensors based |
159 | // on their dimensions in operators such as Softmax, i.e.: |
160 | // FORMAT_X - 1D tensor |
161 | // FORMAT_NC - 2D tensor |
162 | // FORMAT_TNC - 3D tensor |
163 | enum class MklTensorFormat { |
164 | FORMAT_NHWC = 0, |
165 | FORMAT_NCHW = 1, |
166 | FORMAT_NDHWC = 2, |
167 | FORMAT_NCDHW = 3, |
168 | FORMAT_X = 4, |
169 | FORMAT_NC = 5, |
170 | FORMAT_TNC = 6, |
171 | FORMAT_BLOCKED = 7, |
172 | FORMAT_INVALID = 8, |
173 | }; |
174 | |
175 | // Forward declarations |
176 | memory::format_tag MklTensorFormatToMklDnnDataFormat(MklTensorFormat format); |
177 | |
178 | TensorFormat MklDnn3DDataFormatToTFDataFormat(MklTensorFormat format); |
179 | TensorFormat MklDnnDataFormatToTFDataFormat(MklTensorFormat format); |
180 | |
181 | memory::dims CalculateTFStrides(const memory::dims& dims_tf_order); |
182 | Status CreateBlockedMemDescHelper(const memory::dims& dim, |
183 | const memory::dims& strides, |
184 | memory::data_type dtype, |
185 | dnnl_memory_desc_t* blocked_md); |
186 | |
187 | inline std::ostream& operator<<(std::ostream& os, |
188 | const memory::format_tag& tag) { |
189 | if (tag == memory::format_tag::undef) { |
    os << "undef";
  } else if (tag == memory::format_tag::any) {
    os << "any";
  } else {
    os << "invalid";
195 | } |
196 | return os; |
197 | } |
198 | |
199 | inline void operator<<(std::ostream& os, const MklTensorFormat& format) { |
  if (format == MklTensorFormat::FORMAT_NHWC) {
    os << "FORMAT_NHWC";
  } else if (format == MklTensorFormat::FORMAT_NCHW) {
    os << "FORMAT_NCHW";
  } else if (format == MklTensorFormat::FORMAT_NDHWC) {
    os << "FORMAT_NDHWC";
  } else if (format == MklTensorFormat::FORMAT_NCDHW) {
    os << "FORMAT_NCDHW";
  } else if (format == MklTensorFormat::FORMAT_X) {
    os << "FORMAT_X";
  } else if (format == MklTensorFormat::FORMAT_NC) {
    os << "FORMAT_NC";
  } else if (format == MklTensorFormat::FORMAT_TNC) {
    os << "FORMAT_TNC";
  } else if (format == MklTensorFormat::FORMAT_BLOCKED) {
    os << "FORMAT_BLOCKED";
  } else {
    os << "INVALID FORMAT";
218 | } |
219 | } |
220 | |
221 | template <typename T> |
222 | inline bool array_cmp(const T* a1, const T* a2, size_t size) { |
223 | for (size_t i = 0; i < size; ++i) |
224 | if (a1[i] != a2[i]) return false; |
225 | return true; |
226 | } |
227 | |
228 | inline dnnl::stream* CreateStream(MklDnnThreadPool* eigen_tp, |
229 | const engine& engine) { |
230 | #ifndef ENABLE_ONEDNN_OPENMP |
231 | if (eigen_tp != nullptr) { |
232 | stream* tp_stream = |
233 | new stream(dnnl::threadpool_interop::make_stream(engine, eigen_tp)); |
234 | return tp_stream; |
235 | } else { |
236 | stream* tp_stream = new stream(engine); |
237 | return tp_stream; |
238 | } |
239 | #else |
240 | stream* tp_stream = new stream(engine); |
241 | return tp_stream; |
242 | #endif // !ENABLE_ONEDNN_OPENMP |
243 | } |
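
// Illustrative sketch: the returned stream is heap-allocated and owned by the
// caller, so it is usually wrapped in a smart pointer (see ExecutePrimitive()
// below for the in-tree pattern).
//
//   MklDnnThreadPool eigen_tp(context);  // 'context' is an OpKernelContext*.
//   std::unique_ptr<stream> cpu_stream(CreateStream(&eigen_tp, cpu_engine));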
244 | |
245 | class MklDnnShape { |
246 | private: |
247 | struct MklShapeData { |
248 | // Flag to indicate if the tensor is an MKL tensor or not |
249 | bool is_mkl_tensor_ = false; |
250 | // Number of dimensions in Tensorflow format |
251 | size_t dimension_ = 0; |
252 | dnnl_dims_t sizes_; // Required by MKL for conversions |
253 | MklTensorFormat tf_data_format_ = MklTensorFormat::FORMAT_BLOCKED; |
254 | memory::data_type T_ = memory::data_type::undef; |
255 | // MKL layout |
256 | dnnl_memory_desc_t mkl_md_; |
257 | /// TF dimension corresponding to this MKL dimension |
258 | dnnl_dims_t map_; |
259 | }; |
260 | MklShapeData data_; |
261 | |
262 | typedef std::remove_extent<dnnl_dims_t>::type dnnl_dim_t; |
263 | |
264 | #define INVALID_DIM_SIZE -1 |
265 | |
266 | public: |
267 | MklDnnShape() : data_{} { |
268 | for (size_t i = 0; i < sizeof(data_.sizes_) / sizeof(data_.sizes_[0]); |
269 | ++i) { |
270 | data_.sizes_[i] = -1; |
271 | } |
272 | for (size_t i = 0; i < sizeof(data_.map_) / sizeof(data_.map_[0]); ++i) { |
273 | data_.map_[i] = -1; |
274 | } |
275 | } |
276 | |
277 | ~MklDnnShape() {} |
278 | TF_DISALLOW_COPY_AND_ASSIGN(MklDnnShape); // Cannot copy |
279 | |
280 | /// Equality function for MklDnnShape objects |
281 | /// @return true if both are equal; false otherwise. |
282 | inline bool operator==(const MklDnnShape& input_shape) const { |
283 | if (this->IsMklTensor() != input_shape.IsMklTensor()) { |
284 | return false; |
285 | } |
286 | |
287 | // If input tensors are in MKL layout, then we check for dimensions and |
288 | // sizes. |
289 | if (this->IsMklTensor()) { |
290 | const dnnl_memory_desc_t& cur_md = (this->GetMklLayout()).data; |
291 | const dnnl_memory_desc_t& input_shape_md = |
292 | input_shape.GetMklLayout().data; |
293 | return this->GetTfShape() == input_shape.GetTfShape() && |
294 | dnnl_memory_desc_equal(&cur_md, &input_shape_md); |
295 | } |
296 | |
    // Neither input is an MKL tensor.
298 | return true; |
299 | } |
300 | |
301 | /// Equality operator for MklDnnShape and TFShape. |
302 | /// Returns: true if TF shapes for both are the same, false otherwise |
303 | inline bool operator==(const TensorShape& input_shape) const { |
304 | if (!this->IsMklTensor()) { |
305 | return false; |
306 | } |
307 | |
308 | return this->GetTfShape() == input_shape; |
309 | } |
310 | |
311 | inline const bool IsMklTensor() const { return data_.is_mkl_tensor_; } |
312 | inline void SetMklTensor(bool is_mkl_tensor) { |
313 | data_.is_mkl_tensor_ = is_mkl_tensor; |
314 | } |
315 | |
316 | inline void SetDimensions(const size_t dimension) { |
317 | data_.dimension_ = dimension; |
318 | } |
319 | inline size_t GetDimension(char dimension) const { |
320 | int index = GetMklDnnTensorDimIndex(dimension); |
321 | CHECK(index >= 0 && index < this->GetDimension()) |
322 | << "Invalid index from the dimension: " << index << ", " << dimension; |
323 | return this->DimSize(index); |
324 | } |
325 | |
326 | inline size_t GetDimension3D(char dimension) const { |
327 | int index = GetMklDnnTensor3DDimIndex(dimension); |
328 | CHECK(index >= 0 && index < this->GetDimension()) |
329 | << "Invalid index from the dimension: " << index << ", " << dimension; |
330 | return this->DimSize(index); |
331 | } |
332 | |
333 | inline int32 GetMklDnnTensorDimIndex(char dimension) const { |
334 | switch (dimension) { |
335 | case 'N': |
336 | return MklDnnDims::Dim_N; |
337 | case 'C': |
338 | return MklDnnDims::Dim_C; |
339 | case 'H': |
340 | return MklDnnDims::Dim_H; |
341 | case 'W': |
342 | return MklDnnDims::Dim_W; |
343 | default: |
344 | LOG(FATAL) << "Invalid dimension: " << dimension; |
345 | return -1; // Avoid compiler warning about missing return value |
346 | } |
347 | } |
348 | |
349 | inline int32 GetMklDnnTensor3DDimIndex(char dimension) const { |
350 | switch (dimension) { |
351 | case 'N': |
352 | return MklDnnDims3D::Dim3d_N; |
353 | case 'C': |
354 | return MklDnnDims3D::Dim3d_C; |
355 | case 'D': |
356 | return MklDnnDims3D::Dim3d_D; |
357 | case 'H': |
358 | return MklDnnDims3D::Dim3d_H; |
359 | case 'W': |
360 | return MklDnnDims3D::Dim3d_W; |
361 | default: |
362 | LOG(FATAL) << "Invalid dimension: " << dimension; |
363 | return -1; // Avoid compiler warning about missing return value |
364 | } |
365 | } |
366 | |
367 | inline size_t GetDimension() const { return data_.dimension_; } |
368 | inline const int* GetSizes() const { |
369 | return reinterpret_cast<const int*>(&data_.sizes_[0]); |
370 | } |
371 | |
  // Returns a dnnl::memory::dims object that contains the sizes of this
  // MklDnnShape object.
374 | inline memory::dims GetSizesAsMklDnnDims() const { |
375 | memory::dims retVal; |
376 | if (data_.is_mkl_tensor_) { |
377 | size_t dimensions = sizeof(data_.sizes_) / sizeof(data_.sizes_[0]); |
378 | for (size_t i = 0; i < dimensions; i++) { |
379 | if (data_.sizes_[i] != INVALID_DIM_SIZE) |
380 | retVal.push_back(data_.sizes_[i]); |
381 | } |
382 | } else { |
383 | CHECK_EQ(data_.is_mkl_tensor_, true); |
384 | } |
385 | return retVal; |
386 | } |
387 | |
388 | inline int64 DimSize(int index) const { |
389 | CHECK_LT(index, sizeof(data_.sizes_) / sizeof(data_.sizes_[0])); |
390 | return data_.sizes_[index]; |
391 | } |
392 | |
393 | /// Return TensorShape that describes the Tensorflow shape of the tensor |
394 | /// represented by this MklShape. |
395 | inline TensorShape GetTfShape() const { |
396 | CHECK_EQ(data_.is_mkl_tensor_, true); |
397 | |
398 | std::vector<int32> shape(data_.dimension_, -1); |
399 | // As mentioned in the comment above, we now rely on TF's `data_format` |
400 | // attribute to determine if TF shape is in blocked format or not. |
401 | if (data_.tf_data_format_ != MklTensorFormat::FORMAT_BLOCKED) { |
402 | for (size_t idx = 0; idx < data_.dimension_; ++idx) { |
403 | shape[idx] = data_.sizes_[TfDimIdx(idx)]; |
404 | } |
405 | } else { |
406 | // If Tensorflow shape is in Blocked format, then we don't have dimension |
407 | // map for it. So we just create Tensorflow shape from sizes in the |
408 | // specified order. |
409 | for (size_t idx = 0; idx < data_.dimension_; ++idx) { |
410 | shape[idx] = data_.sizes_[idx]; |
411 | } |
412 | } |
413 | |
414 | TensorShape ts; |
415 | bool ret = TensorShapeUtils::MakeShape(shape, &ts).ok(); |
416 | CHECK_EQ(ret, true); |
417 | return ts; |
418 | } |
419 | |
420 | inline void SetElemType(memory::data_type dt) { data_.T_ = dt; } |
421 | inline const memory::data_type GetElemType() { return data_.T_; } |
422 | |
423 | inline void SetMklLayout(memory::desc* md) { |
424 | CHECK_NOTNULL(md); |
425 | data_.mkl_md_ = md->data; |
426 | } |
427 | |
428 | inline const memory::desc GetMklLayout() const { |
429 | return memory::desc(data_.mkl_md_); |
430 | } |
431 | |
432 | inline MklTensorFormat GetTfDataFormat() const { |
433 | return data_.tf_data_format_; |
434 | } |
435 | |
436 | /// We don't create primitive_descriptor for TensorFlow layout now. |
437 | /// We use lazy evaluation and create it only when needed. Input format can |
438 | /// also be Blocked format. |
439 | inline void SetTfLayout(size_t dims, const memory::dims& sizes, |
440 | MklTensorFormat format) { |
    DCHECK_EQ(dims, sizes.size())
        << "SetTfLayout: Number of dimensions does not "
           "match with dimension array";
444 | data_.dimension_ = dims; |
445 | for (size_t ii = 0; ii < dims; ++ii) { |
446 | data_.sizes_[ii] = sizes[ii]; |
447 | } |
448 | data_.tf_data_format_ = format; |
449 | if (format != MklTensorFormat::FORMAT_BLOCKED) { |
450 | if (dims == 2) { |
451 | data_.map_[0] = MklDnnDims::Dim_N; |
452 | data_.map_[1] = MklDnnDims::Dim_C; |
453 | } else { |
454 | SetTfDimOrder(dims, format); |
455 | } |
456 | } |
457 | } |
458 | |
459 | inline const memory::desc GetTfLayout() const { |
460 | memory::dims dims; |
461 | for (size_t ii = 0; ii < data_.dimension_; ++ii) { |
462 | dims.push_back(data_.sizes_[ii]); |
463 | } |
464 | |
    // Create a blocked memory descriptor if the input TF format was set to
    // blocked format.
466 | if (data_.tf_data_format_ == MklTensorFormat::FORMAT_BLOCKED) { |
467 | auto strides = CalculateTFStrides(dims); |
468 | dnnl_memory_desc_t blocked_md; |
469 | TF_CHECK_OK( |
470 | CreateBlockedMemDescHelper(dims, strides, data_.T_, &blocked_md)); |
471 | return memory::desc(blocked_md); |
472 | } else { |
473 | auto format_tag = |
474 | MklTensorFormatToMklDnnDataFormat(data_.tf_data_format_); |
475 | return memory::desc(dims, data_.T_, format_tag); |
476 | } |
477 | } |
478 | |
479 | inline const memory::desc GetCurLayout() const { |
480 | return IsMklTensor() ? GetMklLayout() : GetTfLayout(); |
481 | } |
482 | |
  // We don't need a default dimension-order case because an operator that
  // does not have a data_format attribute and gets all its inputs in
  // TensorFlow format will also produce its output in TensorFlow format.
486 | inline void SetTfDimOrder(const size_t dimension, const dnnl_dims_t map) { |
487 | CHECK(dimension == data_.dimension_); |
488 | for (size_t ii = 0; ii < dimension; ii++) { |
489 | data_.map_[ii] = map[ii]; |
490 | } |
491 | } |
492 | |
493 | inline void SetTfDimOrder(const size_t dimension, TensorFormat data_format) { |
494 | if (dimension == 5) { |
495 | CHECK(dimension == data_.dimension_); |
496 | data_.map_[GetTensorDimIndex<3>(data_format, '0')] = |
497 | MklDnnDims3D::Dim3d_D; |
498 | data_.map_[GetTensorDimIndex<3>(data_format, '1')] = |
499 | MklDnnDims3D::Dim3d_H; |
500 | data_.map_[GetTensorDimIndex<3>(data_format, '2')] = |
501 | MklDnnDims3D::Dim3d_W; |
502 | data_.map_[GetTensorDimIndex<3>(data_format, 'C')] = |
503 | MklDnnDims3D::Dim3d_C; |
504 | data_.map_[GetTensorDimIndex<3>(data_format, 'N')] = |
505 | MklDnnDims3D::Dim3d_N; |
506 | } else { |
507 | CHECK_EQ(dimension, 4); |
508 | CHECK(dimension == data_.dimension_); |
509 | data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W; |
510 | data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H; |
511 | data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C; |
512 | data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N; |
513 | } |
514 | } |
515 | |
516 | inline void SetTfDimOrder(const size_t dimension, MklTensorFormat format) { |
517 | TensorFormat data_format = MklDnnDataFormatToTFDataFormat(format); |
518 | SetTfDimOrder(dimension, data_format); |
519 | } |
520 | |
521 | inline const dnnl_dim_t* GetTfToMklDimMap() const { return &data_.map_[0]; } |
522 | inline size_t TfDimIdx(int index) const { return data_.map_[index]; } |
523 | inline int64 TfDimSize(int index) const { |
524 | return data_.sizes_[TfDimIdx(index)]; |
525 | } |
526 | |
527 | /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd' |
528 | /// corresponds to MKL's Channel dimension. |
529 | inline bool IsMklChannelDim(int d) const { |
530 | return TfDimIdx(d) == MklDnnDims::Dim_C; |
531 | } |
532 | |
533 | /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd' |
534 | /// corresponds to MKL's Batch dimension. |
535 | inline bool IsMklBatchDim(int d) const { |
536 | return TfDimIdx(d) == MklDnnDims::Dim_N; |
537 | } |
538 | |
539 | /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd' |
540 | /// corresponds to MKL's Width dimension. |
541 | inline bool IsMklWidthDim(int d) const { |
542 | return TfDimIdx(d) == MklDnnDims::Dim_W; |
543 | } |
544 | /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd' |
545 | /// corresponds to MKL's Height dimension. |
546 | inline bool IsMklHeightDim(int d) const { |
547 | return TfDimIdx(d) == MklDnnDims::Dim_H; |
548 | } |
549 | |
550 | /// Check if the TF-MKL dimension ordering map specifies if the input |
551 | /// tensor is in NCHW format. |
552 | inline bool IsTensorInNCHWFormat() const { |
553 | TensorFormat data_format = FORMAT_NCHW; |
554 | return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) && |
555 | IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) && |
556 | IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) && |
557 | IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W'))); |
558 | } |
559 | |
560 | /// Check if the TF-MKL dimension ordering map specifies if the input |
561 | /// tensor is in NHWC format. |
562 | inline bool IsTensorInNHWCFormat() const { |
563 | TensorFormat data_format = FORMAT_NHWC; |
564 | return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) && |
565 | IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) && |
566 | IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) && |
567 | IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W'))); |
568 | } |
569 | |
570 | /// The following methods are used for serializing and de-serializing the |
571 | /// contents of the mklshape object. |
572 | /// The data is serialized in this order |
573 | /// is_mkl_tensor_ : dimension_ : sizes_ : map_: format_ : T_ : mkl_pd_; |
574 | |
  /// Size of the buffer required to hold the serialized object; the size is
  /// computed by following the order mentioned above.
577 | inline size_t GetSerializeBufferSize() const { return sizeof(MklShapeData); } |
578 | |
579 | void SerializeMklDnnShape(unsigned char* buf, size_t buf_size) const { |
580 | CHECK(buf_size >= GetSerializeBufferSize()) |
        << "Buffer size is too small to SerializeMklDnnShape";
582 | *reinterpret_cast<MklShapeData*>(buf) = data_; |
583 | } |
584 | |
585 | void DeSerializeMklDnnShape(const unsigned char* buf, size_t buf_size) { |
586 | // Make sure buffer holds at least is_mkl_tensor_. |
587 | CHECK(buf_size >= sizeof(data_.is_mkl_tensor_)) |
        << "Buffer size is too small in DeSerializeMklDnnShape";
589 | |
590 | const bool is_mkl_tensor = *reinterpret_cast<const bool*>(buf); |
591 | if (is_mkl_tensor) { // If it is an MKL Tensor then read the rest |
592 | CHECK(buf_size >= GetSerializeBufferSize()) |
          << "Buffer size is too small in DeSerializeMklDnnShape";
594 | data_ = *reinterpret_cast<const MklShapeData*>(buf); |
595 | } |
596 | } |
597 | }; |
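
// Illustrative sketch: a kernel that produces output in MKL layout typically
// fills an MklDnnShape like this before serializing it into the metadata
// output tensor (output_md and output_dims are hypothetical).
//
//   MklDnnShape output_mkl_shape;
//   output_mkl_shape.SetMklTensor(true);
//   output_mkl_shape.SetMklLayout(&output_md);  // oneDNN memory::desc
//   output_mkl_shape.SetElemType(MklDnnType<float>());
//   output_mkl_shape.SetTfLayout(output_dims.size(), output_dims,
//                                MklTensorFormat::FORMAT_NHWC);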
598 | |
599 | // List of MklShape objects. Used in Concat/Split layers. |
600 | typedef std::vector<MklDnnShape> MklDnnShapeList; |
601 | |
602 | template <typename T> |
603 | class MklDnnData; |
604 | |
605 | // TODO(intel-tf): Merge with the execute_primitives. |
606 | inline void ExecutePrimitive(const std::vector<primitive>& net, |
607 | const std::vector<MemoryArgsMap>* net_args, |
608 | const engine& cpu_engine, |
609 | OpKernelContext* context = nullptr) { |
610 | DCHECK(net_args); |
611 | DCHECK_EQ(net.size(), net_args->size()); |
612 | std::unique_ptr<stream> cpu_stream; |
613 | MklDnnThreadPool eigen_tp; |
614 | if (context != nullptr) { |
615 | eigen_tp = MklDnnThreadPool(context); |
616 | cpu_stream.reset(CreateStream(&eigen_tp, cpu_engine)); |
617 | } else { |
618 | cpu_stream.reset(CreateStream(nullptr, cpu_engine)); |
619 | } |
620 | for (size_t i = 0; i < net.size(); ++i) { |
621 | net.at(i).execute(*cpu_stream, net_args->at(i)); |
622 | } |
623 | cpu_stream->wait(); |
624 | } |
625 | template <typename T> |
626 | inline Status ConvertMklToTF(OpKernelContext* context, |
627 | const Tensor& input_mkl_tensor, |
628 | const MklDnnShape& input_mkl_shape, |
629 | Tensor* output_tf_tensor) { |
630 | try { |
631 | if (!input_mkl_shape.IsMklTensor()) { |
632 | // Return input as is since it is already a TF tensor |
633 | *output_tf_tensor = input_mkl_tensor; |
634 | return Status::OK(); |
635 | } |
636 | |
637 | // Allocate output tensor. |
638 | TensorShape output_tf_shape = input_mkl_shape.GetTfShape(); |
639 | TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<T>::v(), output_tf_shape, |
640 | output_tf_tensor)); |
641 | |
642 | engine cpu_engine(engine::kind::cpu, 0); |
643 | MklDnnData<T> input(&cpu_engine); |
644 | |
645 | // Get MKL layout of input tensor. |
646 | auto input_mkl_md = input_mkl_shape.GetMklLayout(); |
647 | auto output_tf_md = input_mkl_shape.GetTfLayout(); |
648 | input.SetUsrMem(input_mkl_md, &input_mkl_tensor); |
649 | |
650 | if (input.IsReorderNeeded(output_tf_md)) { |
651 | std::vector<primitive> net; |
652 | std::vector<MemoryArgsMap> net_args; |
653 | bool status = input.CheckReorderToOpMem(output_tf_md, output_tf_tensor, |
654 | net, net_args, cpu_engine); |
655 | if (!status) { |
656 | return Status(error::Code::INTERNAL, |
                      "ConvertMklToTF(): Failed to create reorder for input");
658 | } |
659 | ExecutePrimitive(net, &net_args, cpu_engine, context); |
660 | } else { |
661 | // If not, just forward input tensor to output tensor. |
662 | bool status = |
663 | output_tf_tensor->CopyFrom(input_mkl_tensor, output_tf_shape); |
664 | if (!status) { |
665 | return Status( |
666 | error::Code::INTERNAL, |
            "ConvertMklToTF(): Failed to forward input tensor to output");
668 | } |
669 | } |
670 | return Status::OK(); |
671 | } catch (dnnl::error& e) { |
672 | string error_msg = "Status: " + std::to_string(e.status) + |
673 | ", message: " + string(e.message) + ", in file " + |
674 | string(__FILE__) + ":" + std::to_string(__LINE__); |
675 | LOG(FATAL) << "Operation received an exception: " << error_msg; |
676 | } |
677 | } |
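
// Illustrative sketch: converting input 0 of an op from MKL layout to plain
// TF layout before handing it to an Eigen implementation (the helpers used
// here are defined below in this header).
//
//   MklDnnShape mkl_shape;
//   GetMklShape(context, 0, &mkl_shape);
//   Tensor tf_tensor;
//   OP_REQUIRES_OK(context,
//                  ConvertMklToTF<float>(context, MklGetInput(context, 0),
//                                        mkl_shape, &tf_tensor));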
678 | |
// Get the MKL shape from the second (serialized metadata) tensor.
680 | inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape, |
681 | bool eager_mode) { |
682 | if (!eager_mode) { |
683 | mklshape->DeSerializeMklDnnShape( |
684 | ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs())) |
685 | .flat<uint8>() |
686 | .data(), |
687 | ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs())) |
688 | .flat<uint8>() |
689 | .size() * |
690 | sizeof(uint8)); |
691 | } else { |
692 | mklshape->SetMklTensor(false); |
693 | } |
694 | } |
695 | |
696 | inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape) { |
697 | GetMklShape(ctext, n, mklshape, false); |
698 | } |
699 | |
700 | // Gets the actual input |
701 | inline const Tensor& MklGetInput(OpKernelContext* ctext, int n) { |
702 | return ctext->input(GetTensorDataIndex(n, ctext->num_inputs())); |
703 | } |
704 | |
705 | inline void GetMklInputList(OpKernelContext* ctext, StringPiece name, |
706 | OpInputList* input_tensors) { |
707 | CHECK_NOTNULL(input_tensors); |
708 | TF_CHECK_OK(ctext->input_list(name, input_tensors)); |
709 | } |
710 | |
711 | inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name, |
712 | MklDnnShapeList* mkl_shapes, |
713 | bool native_format = false) { |
714 | if (!native_format) { |
715 | OpInputList input_mkl_tensors; |
    GetMklInputList(ctext, strings::StrCat("mkl_", name), &input_mkl_tensors);
717 | |
718 | for (int i = 0; i < input_mkl_tensors.size(); i++) { |
719 | (*mkl_shapes)[i].DeSerializeMklDnnShape( |
720 | input_mkl_tensors[i].flat<uint8>().data(), |
721 | input_mkl_tensors[i].flat<uint8>().size() * sizeof(uint8)); |
722 | } |
723 | } else { |
724 | for (int i = 0; i < mkl_shapes->size(); ++i) { |
725 | (*mkl_shapes)[i].SetMklTensor(false); |
726 | } |
727 | } |
728 | } |
729 | |
730 | /// Get shape of input tensor pointed by 'input_idx' in TensorShape format. |
731 | /// If the input tensor is in MKL layout, then obtains TensorShape from |
732 | /// MklShape. |
733 | inline TensorShape GetTfShape(OpKernelContext* context, size_t input_idx, |
734 | bool eager_mode = false) { |
735 | // Sanity check. |
736 | CHECK_NOTNULL(context); |
737 | CHECK_LT(input_idx, context->num_inputs()); |
738 | |
739 | MklDnnShape input_mkl_shape; |
740 | GetMklShape(context, input_idx, &input_mkl_shape, eager_mode); |
741 | if (input_mkl_shape.IsMklTensor() && !eager_mode) { |
742 | return input_mkl_shape.GetTfShape(); |
743 | } else { |
744 | const Tensor& t = MklGetInput(context, input_idx); |
745 | return t.shape(); |
746 | } |
747 | } |
748 | |
// Allocate the second output tensor that will contain
// the serialized MKL shape.
751 | inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n, |
752 | const MklDnnShape& mkl_shape) { |
753 | Tensor* second_tensor = nullptr; |
754 | TensorShape second_shape; |
755 | second_shape.AddDim(mkl_shape.GetSerializeBufferSize()); |
756 | OP_REQUIRES_OK(ctext, ctext->allocate_output( |
757 | GetTensorMetaDataIndex(n, ctext->num_outputs()), |
758 | second_shape, &second_tensor)); |
759 | mkl_shape.SerializeMklDnnShape( |
760 | second_tensor->flat<uint8>().data(), |
761 | second_tensor->flat<uint8>().size() * sizeof(uint8)); |
762 | } |
763 | |
// Allocate the output tensor, and create a second output tensor that will
// contain the serialized MKL shape.
766 | inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n, |
767 | Tensor** output, |
768 | const TensorShape& tf_shape, |
769 | const MklDnnShape& mkl_shape, |
770 | bool eager_mode = false) { |
771 | OP_REQUIRES_OK( |
772 | ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()), |
773 | tf_shape, output)); |
774 | if (!eager_mode) { |
775 | Tensor* second_tensor = nullptr; |
776 | TensorShape second_shape; |
777 | second_shape.AddDim(mkl_shape.GetSerializeBufferSize()); |
778 | OP_REQUIRES_OK(ctext, ctext->allocate_output( |
779 | GetTensorMetaDataIndex(n, ctext->num_outputs()), |
780 | second_shape, &second_tensor)); |
781 | mkl_shape.SerializeMklDnnShape( |
782 | second_tensor->flat<uint8>().data(), |
783 | second_tensor->flat<uint8>().size() * sizeof(uint8)); |
784 | } |
785 | } |
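
// Illustrative sketch: an op whose output stays in plain TF layout allocates
// the data output together with a dummy (non-MKL) metadata output.
//
//   MklDnnShape output_mkl_shape;
//   output_mkl_shape.SetMklTensor(false);
//   Tensor* output_tensor = nullptr;
//   AllocateOutputSetMklShape(context, 0, &output_tensor, tf_output_shape,
//                             output_mkl_shape);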
786 | |
787 | // Allocates a temp tensor and returns the data buffer for temporary storage. |
788 | template <typename T> |
789 | inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out, |
790 | const memory::desc& pd, void** buf_out) { |
791 | TensorShape tf_shape; |
792 | |
793 | tf_shape.AddDim(pd.get_size() / sizeof(T) + 1); |
794 | OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(), |
795 | tf_shape, tensor_out)); |
796 | *buf_out = static_cast<void*>(tensor_out->flat<T>().data()); |
797 | } |
798 | |
799 | template <typename T> |
800 | inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out, |
801 | TensorShape tf_shape) { |
802 | OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(), |
803 | tf_shape, tensor_out)); |
804 | } |
805 | |
806 | template <typename T> |
807 | struct UserScratchPad { |
  // NOTE: If a scratchpad is not required for a particular primitive,
  // spad_md.get_size() will return 0. It is fine to return nullptr in
  // this case.
  template <typename MklPrim>
812 | inline void AllocateSPTensor(MklPrim* mkl_prim, OpKernelContext* context) { |
813 | allocated_ = false; |
814 | auto spad_md = mkl_prim->GetScratchPadDesc(); |
815 | size_t spad_size = spad_md.get_size(); |
816 | if (spad_size == 0) return; |
817 | |
818 | size_t allocate_size = (spad_size + sizeof(T) - 1) / sizeof(T); |
819 | TensorShape tf_shape; |
820 | tf_shape.AddDim(allocate_size); |
821 | AllocTmpBuffer<T>(context, &scratch_pad_, tf_shape); |
822 | allocated_ = true; |
823 | } |
824 | inline void* Get() { |
825 | if (allocated_) { |
826 | return static_cast<void*>(scratch_pad_.flat<T>().data()); |
827 | } else { |
828 | return nullptr; |
829 | } |
830 | } |
831 | |
832 | private: |
833 | Tensor scratch_pad_; |
834 | bool allocated_ = false; |
835 | }; |
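
// Illustrative sketch (assumes the primitive wrapper class exposes
// GetScratchPadDesc(), as AllocateSPTensor() above expects):
//
//   UserScratchPad<unsigned char> scratch_pad;
//   scratch_pad.AllocateSPTensor(prim, context);
//   // scratch_pad.Get() is then wrapped in a dnnl::memory and passed in the
//   // primitive's execution arguments as DNNL_ARG_SCRATCHPAD; it may be
//   // nullptr when no scratchpad is needed.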
836 | |
837 | inline void GetStridesFromSizes(MklTensorFormat data_format, size_t* strides, |
838 | const size_t* sizes) { |
839 | DCHECK_NE(data_format, MklTensorFormat::FORMAT_INVALID); |
840 | // MKL requires strides in NCHW |
841 | if (data_format == MklTensorFormat::FORMAT_NHWC) { |
842 | strides[0] = sizes[2]; |
843 | strides[1] = sizes[0] * sizes[2]; |
844 | strides[2] = 1; |
845 | strides[3] = sizes[0] * sizes[1] * sizes[2]; |
846 | } else { |
847 | strides[0] = 1; |
848 | strides[1] = sizes[0]; |
849 | strides[2] = sizes[0] * sizes[1]; |
850 | strides[3] = sizes[0] * sizes[1] * sizes[2]; |
851 | } |
852 | } |
853 | |
854 | inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in, |
855 | int idx_out) { |
856 | int num_inputs = context->num_inputs(); |
857 | int num_outputs = context->num_outputs(); |
858 | int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); |
859 | int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs); |
860 | int idx_data_out = GetTensorDataIndex(idx_out, num_outputs); |
861 | int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs); |
862 | |
863 | const Tensor& data = context->input(idx_data_in); |
864 | const Tensor& meta = context->input(idx_meta_in); |
865 | Tensor output(data.dtype()); |
866 | Tensor meta_output(meta.dtype()); |
867 | |
868 | // TODO(intel-tf): alternatively, call forward_input_to_output_with_shape(...) |
869 | CHECK(output.CopyFrom(data, data.shape())); |
870 | CHECK(meta_output.CopyFrom(meta, meta.shape())); |
871 | context->set_output(idx_data_out, output); |
872 | context->set_output(idx_meta_out, meta_output); |
873 | } |
874 | |
875 | inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, int idx_in, |
876 | int idx_out, |
877 | const TensorShape& shape) { |
878 | int num_inputs = context->num_inputs(); |
879 | int num_outputs = context->num_outputs(); |
880 | int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); |
881 | int idx_data_out = GetTensorDataIndex(idx_out, num_outputs); |
882 | |
883 | const Tensor& data = context->input(idx_data_in); |
884 | MklDnnShape mkl_shape_output; |
885 | mkl_shape_output.SetMklTensor(false); |
886 | AllocateOutputSetMklShape(context, idx_out, mkl_shape_output); |
887 | Tensor output(data.dtype()); |
888 | // TODO(intel-tf): alternatively, call forward_input_to_output_with_shape(...) |
889 | CHECK(output.CopyFrom(data, shape)); |
890 | context->set_output(idx_data_out, output); |
891 | } |
892 | |
893 | inline void ForwardTfTensorInToOut(OpKernelContext* context, int idx_in, |
894 | int idx_out) { |
895 | int num_inputs = context->num_inputs(); |
896 | int num_outputs = context->num_outputs(); |
897 | int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); |
898 | int idx_data_out = GetTensorDataIndex(idx_out, num_outputs); |
899 | |
900 | MklDnnShape dnn_shape_output; |
901 | dnn_shape_output.SetMklTensor(false); |
902 | AllocateOutputSetMklShape(context, idx_out, dnn_shape_output); |
903 | if (IsRefType(context->input_dtype(idx_data_in))) { |
904 | context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out); |
905 | } else { |
906 | context->set_output(idx_data_out, context->input(idx_data_in)); |
907 | } |
908 | } |
909 | |
910 | inline void ForwardMklTensorInToOut(OpKernelContext* context, int idx_in, |
911 | int idx_out) { |
912 | int num_inputs = context->num_inputs(); |
913 | int num_outputs = context->num_outputs(); |
914 | int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); |
915 | int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs); |
916 | int idx_data_out = GetTensorDataIndex(idx_out, num_outputs); |
917 | int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs); |
918 | |
919 | if (IsRefType(context->input_dtype(idx_data_in))) { |
920 | context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out); |
921 | context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out); |
922 | } else { |
923 | context->set_output(idx_data_out, context->input(idx_data_in)); |
924 | context->set_output(idx_meta_out, context->input(idx_meta_in)); |
925 | } |
926 | } |
927 | |
928 | // Set a dummy oneDNN shape (called when the output is in TF format) |
929 | inline void SetDummyMklDnnShapeOutput(OpKernelContext* context, |
930 | uint32 idx_data_out) { |
931 | MklDnnShape mkl_shape_output; |
932 | mkl_shape_output.SetMklTensor(false); |
933 | AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output); |
934 | } |
935 | |
936 | // If the input tensor has ref count as 1, it is forwarded to the desired |
937 | // output port and the function returns true. In that case, it also allocates |
938 | // the serialized MklDnnShape object. Otherwise, the function returns false. |
939 | inline bool ForwardMklTensorInToOutWithMklShape(OpKernelContext* context, |
940 | int idx_in, int idx_out, |
941 | Tensor** output, |
942 | const MklDnnShape& mkl_shape, |
943 | bool always_forward = true) { |
944 | int num_inputs = context->num_inputs(); |
945 | int num_outputs = context->num_outputs(); |
946 | int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); |
947 | int idx_data_out = GetTensorDataIndex(idx_out, num_outputs); |
948 | bool is_forwarded = false; |
949 | const Tensor& input_tensor = context->input(idx_data_in); |
950 | const auto output_shape = input_tensor.shape(); |
951 | if (always_forward) { |
952 | if (IsRefType(context->input_dtype(idx_data_in))) { |
953 | context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out); |
954 | } else { |
955 | context->set_output(idx_data_out, input_tensor); |
956 | } |
957 | } else { |
958 | is_forwarded = context->forward_input_to_output_with_shape( |
959 | idx_data_in, idx_data_out, output_shape, output); |
960 | } |
961 | if (is_forwarded || always_forward) { |
962 | AllocateOutputSetMklShape(context, idx_out, mkl_shape); |
963 | return true; |
964 | } |
965 | return false; |
966 | } |
967 | |
// Forward the MKL shape ONLY (used in element-wise and other ops where
// we call the Eigen implementation and the MKL shape is not used).
970 | inline void ForwardMklMetaDataInToOut(OpKernelContext* context, |
971 | uint32 idx_data_in, |
972 | uint32_t idx_data_out) { |
973 | uint32 idx_meta_in = |
974 | GetTensorMetaDataIndex(idx_data_in, context->num_inputs()); |
975 | uint32 idx_meta_out = |
976 | GetTensorMetaDataIndex(idx_data_out, context->num_outputs()); |
977 | |
978 | if (IsRefType(context->input_dtype(idx_data_in))) { |
979 | context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out); |
980 | } else { |
981 | context->set_output(idx_meta_out, context->input(idx_meta_in)); |
982 | } |
983 | } |
984 | |
985 | // ------------------------------------------------------------------- |
986 | // Common utility functions used by MKL unit tests |
987 | |
988 | inline Tensor GetMklMetaTensor() { |
989 | MklDnnShape non_mkl_shape; |
990 | non_mkl_shape.SetMklTensor(false); |
991 | |
992 | auto size = static_cast<int64_t>(non_mkl_shape.GetSerializeBufferSize()); |
993 | Tensor tensor(DT_UINT8, {size}); |
994 | |
995 | non_mkl_shape.SerializeMklDnnShape(tensor.flat<uint8>().data(), |
996 | size * sizeof(uint8)); |
997 | return tensor; |
998 | } |
999 | |
1000 | // ------------------------------------------------------------------- |
1001 | |
1002 | /// Return oneDNN data type (memory::data_type) for input type T |
1003 | /// |
1004 | /// @input None |
1005 | /// @return memory::data_type corresponding to type T |
1006 | template <typename T> |
1007 | static memory::data_type MklDnnType(); |
1008 | |
/// Instantiation for float type. Add similar instantiations for other
/// types if needed.
1011 | template <> |
1012 | memory::data_type MklDnnType<float>() { |
1013 | return memory::data_type::f32; |
1014 | } |
1015 | |
1016 | template <> |
1017 | memory::data_type MklDnnType<quint8>() { |
1018 | return memory::data_type::u8; |
1019 | } |
1020 | |
1021 | template <> |
1022 | memory::data_type MklDnnType<uint8>() { |
1023 | return memory::data_type::u8; |
1024 | } |
1025 | |
1026 | template <> |
1027 | memory::data_type MklDnnType<qint8>() { |
1028 | return memory::data_type::s8; |
1029 | } |
1030 | |
1031 | template <> |
1032 | memory::data_type MklDnnType<qint32>() { |
1033 | return memory::data_type::s32; |
1034 | } |
1035 | template <> |
1036 | memory::data_type MklDnnType<bfloat16>() { |
1037 | return memory::data_type::bf16; |
1038 | } |
1039 | |
1040 | // Map MklTensorFormat to oneDNN format tag |
1041 | // |
1042 | // @input: MklTensorFormat i.e. TensorFlow data format |
// @return: oneDNN's memory format tag corresponding to MklTensorFormat.
//          Returns format_tag::undef if the format is invalid.
1045 | inline memory::format_tag MklTensorFormatToMklDnnDataFormat( |
1046 | MklTensorFormat format) { |
1047 | if (format == MklTensorFormat::FORMAT_NHWC) return memory::format_tag::nhwc; |
1048 | if (format == MklTensorFormat::FORMAT_NCHW) return memory::format_tag::nchw; |
  if (format == MklTensorFormat::FORMAT_NDHWC)
    return memory::format_tag::ndhwc;
  if (format == MklTensorFormat::FORMAT_NCDHW)
    return memory::format_tag::ncdhw;
1051 | if (format == MklTensorFormat::FORMAT_X) return memory::format_tag::x; |
1052 | if (format == MklTensorFormat::FORMAT_NC) return memory::format_tag::nc; |
1053 | if (format == MklTensorFormat::FORMAT_TNC) return memory::format_tag::tnc; |
1054 | return memory::format_tag::undef; |
1055 | } |
1056 | |
1057 | /// Map TensorFlow data format into oneDNN 3D data format |
1058 | /// @input: TensorFlow data format |
1059 | /// @return: oneDNN 3D data format corresponding to TensorFlow data format; |
1060 | /// Fails with an error if invalid data format. |
1061 | inline MklTensorFormat TFDataFormatToMklDnn3DDataFormat(TensorFormat format) { |
1062 | if (format == FORMAT_NHWC) return MklTensorFormat::FORMAT_NDHWC; |
1063 | if (format == FORMAT_NCHW) return MklTensorFormat::FORMAT_NCDHW; |
  TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
1065 | return MklTensorFormat::FORMAT_INVALID; |
1066 | } |
1067 | |
1068 | /// Map TensorFlow data format into oneDNN data format |
1069 | /// |
1070 | /// @input: TensorFlow data format |
1071 | /// @return: oneDNN data format corresponding to TensorFlow data format; |
1072 | /// Fails with an error if invalid data format. |
1073 | inline MklTensorFormat TFDataFormatToMklDnnDataFormat(TensorFormat format) { |
1074 | if (format == FORMAT_NHWC) return MklTensorFormat::FORMAT_NHWC; |
1075 | if (format == FORMAT_NCHW) return MklTensorFormat::FORMAT_NCHW; |
  TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
1077 | return MklTensorFormat::FORMAT_INVALID; |
1078 | } |
1079 | |
1080 | /// Map oneDNN data format into TensorFlow data format |
1081 | /// |
1082 | /// @input: oneDNN data format |
1083 | /// @return: Tensorflow data format corresponding to oneDNN data format; |
1084 | /// Fails with an error if invalid data format. |
1085 | inline TensorFormat MklDnnDataFormatToTFDataFormat(MklTensorFormat format) { |
1086 | if (format == MklTensorFormat::FORMAT_NHWC || |
1087 | format == MklTensorFormat::FORMAT_NDHWC) |
1088 | return FORMAT_NHWC; |
1089 | if (format == MklTensorFormat::FORMAT_NCHW || |
1090 | format == MklTensorFormat::FORMAT_NCDHW) |
1091 | return FORMAT_NCHW; |
  TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
1093 | |
1094 | // Return to prevent compiler warnings, otherwise TF_CHECK_OK will ensure |
1095 | // that we don't come here. |
1096 | return FORMAT_NHWC; |
1097 | } |
1098 | |
1099 | /// Map TensorShape object into memory::dims required by oneDNN |
1100 | /// |
1101 | /// This function will simply map input TensorShape into oneDNN dims |
1102 | /// naively. So it will preserve the order of dimensions. E.g., if |
1103 | /// input tensor is in NHWC format, then dims will be in NHWC format also. |
1104 | /// |
1105 | /// @input TensorShape object in shape |
1106 | /// @return memory::dims corresponding to TensorShape |
1107 | inline memory::dims TFShapeToMklDnnDims(const TensorShape& shape) { |
1108 | memory::dims dims(shape.dims()); |
1109 | for (int d = 0; d < shape.dims(); ++d) { |
1110 | dims[d] = shape.dim_size(d); |
1111 | } |
1112 | return dims; |
1113 | } |
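
// For example, TFShapeToMklDnnDims maps a TensorShape of {8, 224, 224, 3}
// (NHWC) to memory::dims {8, 224, 224, 3}; the dimension order is preserved.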
1114 | |
1115 | /// Map TensorShape object into memory::dims in NCHW format required by oneDNN |
1116 | /// |
/// This function is a more specific version of the function above. It will
/// map the input TensorShape into oneDNN dims in NCHW format, so it may not
/// preserve the order of dimensions. E.g., if the input tensor is in NHWC
/// format, then dims will be in NCHW format, and not in NHWC format.
1121 | /// |
1122 | /// @input TensorShape object in shape |
1123 | /// @return memory::dims in oneDNN required NCHW format |
1124 | inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape, |
1125 | TensorFormat format) { |
1126 | // Check validity of format. |
1127 | DCHECK_NE(TFDataFormatToMklDnnDataFormat(format), |
1128 | MklTensorFormat::FORMAT_INVALID); |
1129 | |
1130 | int n = shape.dim_size(GetTensorDimIndex(format, 'N')); |
1131 | int c = shape.dim_size(GetTensorDimIndex(format, 'C')); |
1132 | int h = shape.dim_size(GetTensorDimIndex(format, 'H')); |
1133 | int w = shape.dim_size(GetTensorDimIndex(format, 'W')); |
1134 | |
1135 | // oneDNN requires dimensions in NCHW format. |
1136 | return memory::dims({n, c, h, w}); |
1137 | } |
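
// For example, TFShapeToMklDnnDimsInNCHW maps an NHWC TensorShape of
// {8, 224, 224, 3} with format FORMAT_NHWC to memory::dims {8, 3, 224, 224}.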
1138 | |
1139 | inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape, |
1140 | TensorFormat format) { |
1141 | // Validate format. |
1142 | DCHECK_NE(TFDataFormatToMklDnn3DDataFormat(format), |
1143 | MklTensorFormat::FORMAT_INVALID); |
1144 | |
1145 | int n = shape.dim_size(GetTensorDimIndex<3>(format, 'N')); |
1146 | int c = shape.dim_size(GetTensorDimIndex<3>(format, 'C')); |
1147 | int d = shape.dim_size(GetTensorDimIndex<3>(format, '0')); |
1148 | int h = shape.dim_size(GetTensorDimIndex<3>(format, '1')); |
1149 | int w = shape.dim_size(GetTensorDimIndex<3>(format, '2')); |
1150 | |
1151 | // oneDNN requires dimensions in NCDHW format. |
1152 | return memory::dims({n, c, d, h, w}); |
1153 | } |
1154 | |
1155 | /// Overloaded version of function TFShapeToMklDnnDimsInNCHW above. |
1156 | /// Input parameters are self-explanatory. |
1157 | inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims, |
1158 | TensorFormat format) { |
1159 | // Validate format. |
1160 | DCHECK_NE(TFDataFormatToMklDnnDataFormat(format), |
1161 | MklTensorFormat::FORMAT_INVALID); |
1162 | |
1163 | int n = in_dims[GetTensorDimIndex(format, 'N')]; |
1164 | int c = in_dims[GetTensorDimIndex(format, 'C')]; |
1165 | int h = in_dims[GetTensorDimIndex(format, 'H')]; |
1166 | int w = in_dims[GetTensorDimIndex(format, 'W')]; |
1167 | |
1168 | // oneDNN requires dimensions in NCHW format. |
1169 | return memory::dims({n, c, h, w}); |
1170 | } |
1171 | |
1172 | /// Overloaded version of function TFShapeToMklDnnDimsInNCDHW above. |
1173 | /// Input parameters are self-explanatory. |
1174 | inline memory::dims MklDnnDimsInNCDHW(const memory::dims& in_dims, |
1175 | TensorFormat format) { |
1176 | // Validate format. |
1177 | DCHECK_NE(TFDataFormatToMklDnnDataFormat(format), |
1178 | MklTensorFormat::FORMAT_INVALID); |
1179 | |
1180 | int n = in_dims[GetTensorDimIndex<3>(format, 'N')]; |
1181 | int c = in_dims[GetTensorDimIndex<3>(format, 'C')]; |
1182 | int d = in_dims[GetTensorDimIndex<3>(format, '0')]; |
1183 | int h = in_dims[GetTensorDimIndex<3>(format, '1')]; |
1184 | int w = in_dims[GetTensorDimIndex<3>(format, '2')]; |
1185 | |
  // oneDNN requires dimensions in NCDHW format.
1187 | return memory::dims({n, c, d, h, w}); |
1188 | } |
1189 | |
1190 | /// Map MklDnn memory::dims object into TensorShape object. |
1191 | /// |
/// This function will simply map the input shape in oneDNN memory::dims
/// format into TensorFlow's TensorShape object, preserving dimension order.
1194 | /// |
1195 | /// @input oneDNN memory::dims object |
1196 | /// @output TensorShape corresponding to memory::dims |
1197 | inline TensorShape MklDnnDimsToTFShape(const memory::dims& dims) { |
1198 | std::vector<int32> shape(dims.size(), -1); |
1199 | for (int d = 0; d < dims.size(); d++) { |
1200 | shape[d] = dims[d]; |
1201 | } |
1202 | |
1203 | TensorShape ret; |
1204 | CHECK_EQ(TensorShapeUtils::MakeShape(shape, &ret).ok(), true); |
1205 | return ret; |
1206 | } |
1207 | |
1208 | /// Function to calculate strides given tensor shape in Tensorflow order |
1209 | /// E.g., if dims_tf_order is {1, 2, 3, 4}, then as per Tensorflow convention, |
1210 | /// dimension with size 1 is outermost dimension; while dimension with size 4 is |
1211 | /// innermost dimension. So strides for this tensor would be {4 * 3 * 2, |
1212 | /// 4 * 3, 4, 1}, i.e., {24, 12, 4, 1}. |
1213 | /// |
1214 | /// @input Tensorflow shape in memory::dims type |
1215 | /// @return memory::dims containing strides for the tensor. |
1216 | inline memory::dims CalculateTFStrides(const memory::dims& dims_tf_order) { |
1217 | CHECK_GT(dims_tf_order.size(), 0); |
1218 | memory::dims strides(dims_tf_order.size()); |
1219 | int last_dim_idx = dims_tf_order.size() - 1; |
1220 | strides[last_dim_idx] = 1; |
1221 | for (int d = last_dim_idx - 1; d >= 0; d--) { |
1222 | strides[d] = strides[d + 1] * dims_tf_order[d + 1]; |
1223 | } |
1224 | return strides; |
1225 | } |
1226 | |
1227 | /// Helper function to create memory descriptor in Blocked format |
1228 | /// |
1229 | /// @input: Tensor dimensions |
1230 | /// @input: strides corresponding to dimensions. One can use utility |
1231 | /// function such as CalculateTFStrides to compute strides |
1232 | /// for given dimensions. |
1233 | /// @output: dnnl_memory_desc_t object corresponding to blocked memory |
1234 | /// format for given dimensions and strides. |
1235 | /// @return: Status indicating whether the blocked memory descriptor |
1236 | /// was successfully created. |
1237 | inline Status CreateBlockedMemDescHelper(const memory::dims& dim, |
1238 | const memory::dims& strides, |
1239 | memory::data_type dtype, |
1240 | dnnl_memory_desc_t* blocked_md) { |
1241 | DCHECK_EQ(dim.size(), strides.size()); |
1242 | const int kNumDims = dim.size(); |
1243 | dnnl_dim_t* input_dims = new dnnl_dim_t[kNumDims]; |
1244 | dnnl_dim_t* input_strides = new dnnl_dim_t[kNumDims]; |
1245 | for (int i = 0; i < kNumDims; ++i) { |
1246 | input_dims[i] = dim[i]; |
1247 | input_strides[i] = strides[i]; |
1248 | } |
1249 | try { |
1250 | dnnl_memory_desc_init_by_strides(blocked_md, kNumDims, input_dims, |
1251 | memory::convert_to_c(dtype), |
1252 | input_strides); |
1253 | delete[] input_dims; |
1254 | delete[] input_strides; |
1255 | } catch (dnnl::error& e) { |
1256 | delete[] input_dims; |
1257 | delete[] input_strides; |
1258 | return Status(error::Code::INTERNAL, |
1259 | tensorflow::strings::StrCat( |
                      "Failed to create blocked memory descriptor. ",
                      "Status: ", e.status, ", message: ", e.message));
1262 | } |
1263 | return Status::OK(); |
1264 | } |
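
// Illustrative sketch: building a blocked memory descriptor for a rank that
// has no native oneDNN format tag (e.g. a 6D tensor), using TF-order strides.
//
//   memory::dims dims = {2, 3, 4, 5, 6, 7};
//   memory::dims strides = CalculateTFStrides(dims);
//   dnnl_memory_desc_t blocked_md;
//   TF_CHECK_OK(CreateBlockedMemDescHelper(dims, strides,
//                                          memory::data_type::f32,
//                                          &blocked_md));
//   memory::desc md(blocked_md);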
1265 | |
1266 | inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc, |
1267 | const memory& src_mem, |
1268 | const memory& dst_mem, const engine& engine, |
1269 | OpKernelContext* ctx = nullptr) { |
1270 | std::vector<primitive> net; |
1271 | net.push_back(dnnl::reorder(reorder_desc)); |
1272 | std::vector<MemoryArgsMap> net_args; |
1273 | net_args.push_back({{DNNL_ARG_FROM, src_mem}, {DNNL_ARG_TO, dst_mem}}); |
1274 | ExecutePrimitive(net, &net_args, engine, ctx); |
1275 | } |
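
// Illustrative sketch: a reorder primitive descriptor can be created directly
// from the source/destination memory descriptors and then executed here.
//
//   ReorderPd reorder_pd(cpu_engine, src_mem.get_desc(),
//                        cpu_engine, dst_mem.get_desc());
//   CreateAndExecuteReorder(reorder_pd, src_mem, dst_mem, cpu_engine, context);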
1276 | |
1277 | class MklReorderPrimitive; |
1278 | |
1279 | template <typename T> |
1280 | inline MklReorderPrimitive* FindOrCreateReorder(const memory* from, |
1281 | const memory* to); |
1282 | |
1283 | // Class to represent all the resources corresponding to a tensor in TensorFlow |
1284 | // that are required to execute an operation (such as Convolution). |
1285 | template <typename T> |
1286 | class MklDnnData { |
1287 | private: |
1288 | /// oneDNN memory primitive for input user memory |
1289 | memory* user_memory_; |
1290 | |
1291 | /// oneDNN memory primitive in case input or output reorder is needed. |
1292 | memory* reorder_memory_; |
1293 | |
1294 | /// Operations memory descriptor |
1295 | memory::desc* op_md_; |
  // Flag to indicate if the data is 3D or not.
1297 | bool bIs3D; |
1298 | /// Operations temp buffer |
1299 | void* allocated_buffer_; |
1300 | /// CPU engine on which operation will be executed |
1301 | const engine* cpu_engine_; |
1302 | |
1303 | public: |
1304 | explicit MklDnnData(const engine* e) |
1305 | : user_memory_(nullptr), |
1306 | reorder_memory_(nullptr), |
1307 | op_md_(nullptr), |
1308 | bIs3D(false), |
1309 | allocated_buffer_(nullptr), |
1310 | cpu_engine_(e) {} |
1311 | |
1312 | // MklDnnData does not use any smart pointers, |
1313 | // hence default operator= will result in memory leak if user_memory was |
1314 | // already initialized. See |
1315 | // https://github.com/tensorflow/tensorflow/pull/45593 as an example of such |
1316 | // leak. |
1317 | MklDnnData(const MklDnnData&) = default; |
1318 | MklDnnData& operator=(const MklDnnData&) = delete; |
1319 | |
1320 | ~MklDnnData() { |
1321 | if (allocated_buffer_ != nullptr) { |
1322 | cpu_allocator()->DeallocateRaw(allocated_buffer_); |
1323 | } |
1324 | cpu_engine_ = nullptr; // We don't own this. |
1325 | delete (user_memory_); |
1326 | delete (reorder_memory_); |
1327 | delete (op_md_); |
1328 | } |
1329 | |
1330 | inline void* GetTensorBuffer(const Tensor* tensor) const { |
1331 | CHECK_NOTNULL(tensor); |
1332 | return const_cast<void*>( |
1333 | static_cast<const void*>(tensor->flat<T>().data())); |
1334 | } |
1335 | |
1336 | void SetIs3DData(bool bIs3D_) { bIs3D = bIs3D_; } |
1337 | bool GetIs3D() { return bIs3D; } |
1338 | |
  /// Set user memory primitive using specified dimensions, memory format tag
  /// and data_buffer. The function automatically uses the element data type
  /// from the type T with which this object was instantiated.
  ///
  /// In a nutshell, this function lets the user describe the input tensor to
  /// an operation. E.g., the filter of Conv2D has shape {1, 2, 3, 4} and
  /// memory format tag HWIO, and the buffer that contains the actual values
  /// is pointed to by data_buffer.
1347 | inline void SetUsrMem(const memory::dims& dim, memory::format_tag fm, |
1348 | void* data_buffer = nullptr) { |
1349 | auto md = memory::desc(dim, MklDnnType<T>(), fm); |
1350 | SetUsrMem(md, data_buffer); |
1351 | } |
1352 | |
1353 | inline void SetUsrMem(const memory::dims& dim, memory::format_tag fm, |
1354 | const Tensor* tensor) { |
1355 | DCHECK(tensor); |
1356 | SetUsrMem(dim, fm, GetTensorBuffer(tensor)); |
1357 | } |
1358 | |
1359 | /// Helper function to create memory descriptor in Blocked format |
1360 | /// |
1361 | /// @input: Tensor dimensions |
1362 | /// @input: strides corresponding to dimensions. One can use utility |
1363 | /// function such as CalculateTFStrides to compute strides |
1364 | /// for given dimensions. |
1365 | /// @return: memory::desc object corresponding to blocked memory format |
1366 | /// for given dimensions and strides. |
1367 | static inline memory::desc CreateBlockedMemDesc(const memory::dims& dim, |
1368 | const memory::dims& strides) { |
1369 | dnnl_memory_desc_t blocked_md; |
1370 | TF_CHECK_OK( |
1371 | CreateBlockedMemDescHelper(dim, strides, MklDnnType<T>(), &blocked_md)); |
1372 | return memory::desc(blocked_md); |
1373 | } |
1374 | |
  /// A version of the SetUsrMem call that allows the user to create memory in
  /// blocked format. So in addition to accepting dimensions, it also accepts
  /// strides. This allows the user to create memory for a tensor in a format
  /// that is not natively supported by oneDNN. E.g., oneDNN does not have a
  /// native format tag for 6-dimensional tensors, but by using blocked format
  /// a user can still create memory for a 6D tensor.
1381 | inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides, |
1382 | void* data_buffer = nullptr) { |
1383 | CHECK_EQ(dim.size(), strides.size()); |
1384 | auto blocked_md = MklDnnData<T>::CreateBlockedMemDesc(dim, strides); |
1385 | SetUsrMem(blocked_md, data_buffer); |
1386 | } |
1387 | |
1388 | inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides, |
1389 | const Tensor* tensor) { |
1390 | CHECK_NOTNULL(tensor); |
1391 | SetUsrMem(dim, strides, GetTensorBuffer(tensor)); |
1392 | } |
1393 | |
1394 | /// A version of SetUsrMem with memory descriptor and tensor |
1395 | inline void SetUsrMem(const memory::desc& md, const Tensor* tensor) { |
1396 | CHECK_NOTNULL(tensor); |
1397 | SetUsrMem(md, GetTensorBuffer(tensor)); |
1398 | } |
1399 | |
  /// A version of SetUsrMem that accepts a memory descriptor directly,
  /// instead of dimensions and a format tag. This function is more generic
  /// than the one above, but the one above is sufficient in most cases.
1404 | inline void SetUsrMem(const memory::desc& pd, void* data_buffer = nullptr) { |
1405 | DCHECK(cpu_engine_); |
1406 | if (user_memory_) delete user_memory_; |
1407 | // TODO(intel-tf): can we remove dynamic memory allocation? |
1408 | if (data_buffer) { |
1409 | user_memory_ = new memory(pd, *cpu_engine_, data_buffer); |
1410 | } else { |
1411 | user_memory_ = new memory(pd, *cpu_engine_); |
1412 | } |
1413 | } |
1414 | |
1415 | /// Get function for user memory primitive. |
1416 | inline const memory* GetUsrMem() const { return user_memory_; } |
1417 | |
1418 | /// Get function for descriptor of user memory. |
1419 | inline memory::desc GetUsrMemDesc() const { |
1420 | DCHECK(user_memory_); |
1421 | return user_memory_->get_desc(); |
1422 | } |
1423 | |
1424 | /// Get function for data buffer of user memory primitive. |
1425 | inline void* GetUsrMemDataHandle() const { |
1426 | CHECK_NOTNULL(user_memory_); |
1427 | return user_memory_->get_data_handle(); |
1428 | } |
1429 | |
1430 | /// Set function for data buffer of user memory primitive. |
1431 | inline void SetUsrMemDataHandle(void* data_buffer, |
1432 | std::shared_ptr<stream> t_stream = nullptr) { |
1433 | CHECK_NOTNULL(user_memory_); |
1434 | CHECK_NOTNULL(data_buffer); |
1435 | #ifndef ENABLE_ONEDNN_OPENMP |
1436 | user_memory_->set_data_handle(data_buffer, *t_stream); |
1437 | #else |
1438 | user_memory_->set_data_handle(data_buffer); |
1439 | #endif // !ENABLE_ONEDNN_OPENMP |
1440 | } |
1441 | |
1442 | /// Set function for data buffer of user memory primitive. |
1443 | inline void SetUsrMemDataHandle(const Tensor* tensor, |
1444 | std::shared_ptr<stream> t_stream = nullptr) { |
1445 | SetUsrMemDataHandle(GetTensorBuffer(tensor), t_stream); |
1446 | } |
1447 | |
1448 | /// allocate function for data buffer |
1449 | inline void AllocateBuffer(size_t size) { |
1450 | const int64 kMemoryAlignment = 64; // For AVX512 memory alignment. |
1451 | allocated_buffer_ = cpu_allocator()->AllocateRaw(kMemoryAlignment, size); |
1452 | } |
1453 | |
1454 | inline void* GetAllocatedBuffer() { return allocated_buffer_; } |
1455 | |
1456 | /// Get the memory primitive for input and output of an op. If inputs |
1457 | /// to an op require reorders, then this function returns memory primitive |
1458 | /// for reorder. Otherwise, it will return memory primitive for user memory. |
1459 | /// |
1460 | /// E.g., Conv2D(I, F) is a primitive with I and F being inputs. Then to |
1461 | /// execute Conv2D, we need memory primitive for I and F. But if reorder is |
1462 | /// required for I and F (say I_r is reorder primitive for I; F_r is reorder |
1463 | /// primitive for F), then we need I_r and F_r to perform Conv2D. |
1464 | inline const memory& GetOpMem() const { |
1465 | return reorder_memory_ ? *reorder_memory_ : *user_memory_; |
1466 | } |
1467 | |
1468 | /// Set memory descriptor of an operation in terms of dimensions and memory |
1469 | /// format. E.g., For Conv2D, the dimensions would be same as user dimensions |
1470 | /// but memory::format_tag would be dnnl::any because we want oneDNN to |
1471 | /// choose the best layout/format for given input dimensions. |
1472 | inline void SetOpMemDesc(const memory::dims& dim, memory::format_tag fm) { |
1473 | // TODO(intel-tf): can we remove dynamic memory allocation? |
1474 | op_md_ = new memory::desc(dim, MklDnnType<T>(), fm); |
1475 | } |
1476 | |
1477 | /// Get function for memory descriptor for an operation |
1478 | inline const memory::desc& GetOpMemDesc() const { return *op_md_; } |
1479 | |
  /// Predicate that checks if we need to reorder the user's memory into the
  /// memory described by op_pd.
  ///
  /// @input: op_pd - memory descriptor of the given input of an operation.
  /// @return: true in case reorder of input is needed; false, otherwise.
1485 | inline bool IsReorderNeeded(const memory::desc& op_pd) const { |
1486 | DCHECK(user_memory_); |
1487 | return op_pd != user_memory_->get_desc(); |
1488 | } |
1489 | |
  /// Function to create a reorder from the memory pointed to by `from` to the
  /// memory pointed to by `to`. Returns the created primitive.
1492 | inline primitive CreateReorder(const memory* from, const memory* to) const { |
1493 | CHECK_NOTNULL(from); |
1494 | CHECK_NOTNULL(to); |
1495 | return reorder(*from, *to); |
1496 | } |
1497 | |
1498 | /// Function to handle input reordering |
1499 | /// |
1500 | /// Check if we need to reorder this input of an operation. |
1501 | /// Return true and allocate reorder memory primitive if reorder is needed. |
1502 | /// Otherwise, return false and do not allocate reorder memory primitive. |
1503 | /// |
1504 | /// To check if reorder is needed, this function compares memory primitive |
1505 | /// descriptor (memory descriptor for v1.x) of an operation (op_pd) for |
1506 | /// the given input with the user-specified memory descriptor. |
1507 | /// |
1508 | /// @input: op_pd - memory primitive descriptor of the given input of an |
1509 | /// operation |
1510 | /// @input: net - net to which to add reorder primitive in case it is needed. |
1511 | /// @input: net_args - net to which user and reorder memories are added if |
1512 | /// needed. Each entry is a key-value pair of the form |
1513 | /// <argument-type, dnnl::memory>. |
1514 | /// @return: true in case reorder of input is needed; false, otherwise. |
1515 | inline bool CheckReorderToOpMem(const memory::desc& op_md, |
1516 | std::vector<primitive>& net, |
1517 | std::vector<MemoryArgsMap>& net_args, |
1518 | const engine& engine) { |
1519 | DCHECK(user_memory_); |
1520 | DCHECK_EQ(net.size(), net_args.size()); |
1521 | if (IsReorderNeeded(op_md)) { |
1522 | // TODO(intel-tf): can we remove dynamic memory allocation? |
1523 | reorder_memory_ = new memory(op_md, engine); |
1524 | net.push_back(CreateReorder(user_memory_, reorder_memory_)); |
1525 | net_args.push_back(MemoryArgsMap{{DNNL_ARG_FROM, *user_memory_}, |
1526 | {DNNL_ARG_TO, *reorder_memory_}}); |
1527 | return true; |
1528 | } |
1529 | return false; |
1530 | } |
1531 | |
1532 | inline bool CheckReorderToOpMem(const memory::desc& op_md, |
1533 | const engine& engine, |
1534 | OpKernelContext* context = nullptr) { |
1535 | DCHECK(user_memory_); |
1536 | if (IsReorderNeeded(op_md)) { |
1537 | // TODO(intel-tf): can we remove dynamic memory allocation? |
      // Primitive reuse doesn't allow two identical reorder primitives in
      // one stream, so submit it immediately.
1540 | reorder_memory_ = new memory(op_md, engine); |
1541 | auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_); |
1542 | std::shared_ptr<stream> cpu_stream; |
1543 | MklDnnThreadPool eigen_tp; |
1544 | if (context != nullptr) { |
1545 | eigen_tp = MklDnnThreadPool(context); |
1546 | cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine())); |
1547 | } else { |
1548 | cpu_stream.reset(CreateStream(nullptr, prim->GetEngine())); |
1549 | } |
1550 | std::vector<primitive> net; |
1551 | net.push_back(*(prim->GetPrimitive())); |
1552 | std::vector<MemoryArgsMap> net_args; |
1553 | net_args.push_back( |
1554 | {{DNNL_ARG_FROM, *user_memory_}, {DNNL_ARG_TO, *reorder_memory_}}); |
1555 | execute_primitives(net, cpu_stream, net_args); |
1556 | return true; |
1557 | } |
1558 | return false; |
1559 | } |
1560 | |
1561 | /// Overloaded version of above function that accepts memory buffer |
1562 | /// where output of reorder needs to be stored. |
1563 | /// |
1564 | /// @input: op_pd - memory primitive descriptor (memory descriptor for v1.x) |
1565 | /// of the given input of an operation |
1566 | /// @reorder_data_handle - memory buffer where output of reorder needs to be |
1567 | /// stored. Primitive does not check if buffer has |
1568 | /// enough size to write. |
1569 | /// @input: net - net to which to add reorder primitive in case it is needed. |
1570 | /// @input: net_args - net to which user and reorder memories are added if |
1571 | /// needed. Each entry is a key-value pair of the form |
1572 | /// <argument-type, dnnl::memory>. |
1573 | /// @input: engine - oneDNN's abstraction of a computational device |
1574 | /// @return: true in case reorder of input is needed; false, otherwise. |
1575 | inline bool CheckReorderToOpMem(const memory::desc& op_md, |
1576 | void* reorder_data_handle, |
1577 | std::vector<primitive>& net, |
1578 | std::vector<MemoryArgsMap>& net_args, |
1579 | const engine& engine) { |
1580 | DCHECK(reorder_data_handle); |
1581 | DCHECK(user_memory_); |
1582 | if (IsReorderNeeded(op_md)) { |
1583 | // TODO(intel-tf): can we remove dynamic memory allocation? |
1584 | reorder_memory_ = new memory(op_md, engine, reorder_data_handle); |
1585 | net.push_back(CreateReorder(user_memory_, reorder_memory_)); |
1586 | net_args.push_back(MemoryArgsMap{{DNNL_ARG_FROM, *user_memory_}, |
1587 | {DNNL_ARG_TO, *reorder_memory_}}); |
1588 | return true; |
1589 | } |
1590 | return false; |
1591 | } |
1592 | |
  /// This is a faster path with the reorder primitive cache compared with
  /// CheckReorderToOpMem(..., std::vector<primitive>& net).
  /// The slower path will be removed in the future.
1596 | /// TODO(intel-tf): Need to use reorder cache here for better performance. |
1597 | inline bool CheckReorderToOpMem(const memory::desc& op_md, |
1598 | void* reorder_data_handle, |
1599 | const engine& engine, |
1600 | OpKernelContext* context = nullptr) { |
1601 | DCHECK(reorder_data_handle); |
1602 | DCHECK(user_memory_); |
1603 | if (IsReorderNeeded(op_md)) { |
1604 | // TODO(intel-tf): can we remove dynamic memory allocation? |
      // Primitive reuse doesn't allow two identical reorder primitives in
      // one stream, so submit it immediately.
1607 | reorder_memory_ = new memory(op_md, engine, reorder_data_handle); |
1608 | auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_); |
1609 | std::shared_ptr<stream> cpu_stream; |
1610 | MklDnnThreadPool eigen_tp; |
1611 | if (context != nullptr) { |
1612 | eigen_tp = MklDnnThreadPool(context); |
1613 | cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine())); |
1614 | } else { |
1615 | cpu_stream.reset(CreateStream(nullptr, prim->GetEngine())); |
1616 | } |
1617 | std::vector<primitive> net; |
1618 | net.push_back(*(prim->GetPrimitive())); |
1619 | std::vector<MemoryArgsMap> net_args; |
1620 | net_args.push_back( |
1621 | {{DNNL_ARG_FROM, *user_memory_}, {DNNL_ARG_TO, *reorder_memory_}}); |
1622 | execute_primitives(net, cpu_stream, net_args); |
1623 | return true; |
1624 | } |
1625 | return false; |
1626 | } |
1627 | |
1628 | /// Another overloaded version of CheckReorderToOpMem that accepts Tensor |
1629 | /// where output of reorder needs to be stored. |
1630 | /// |
1631 | /// @input: op_md - memory primitive descriptor (memory descriptor for v1.x) |
1632 | /// of the given input of an operation |
  /// @reorder_tensor - Tensor whose buffer is to be used to store the output
  ///                   of the reorder. The primitive does not check whether
  ///                   the buffer is large enough to write to.
1636 | /// @input: net - net to which to add reorder primitive in case it is needed. |
1637 | /// @input: net_args - net to which user and reorder memories are added if |
1638 | /// needed. Each entry is a key-value pair of the form |
1639 | /// <argument-type, dnnl::memory>. |
  /// @input: engine - oneDNN's abstraction of a computational device
1641 | /// @return: true in case reorder of input is needed; false, otherwise. |
1642 | inline bool CheckReorderToOpMem(const memory::desc& op_md, |
1643 | Tensor* reorder_tensor, |
1644 | std::vector<primitive>& net, |
1645 | std::vector<MemoryArgsMap>& net_args, |
1646 | const engine& engine) { |
1647 | DCHECK(reorder_tensor); |
1648 | return CheckReorderToOpMem(op_md, GetTensorBuffer(reorder_tensor), net, |
1649 | net_args, engine); |
1650 | } |
1651 | |
1652 | /// TODO(intel-tf): this is a faster path with reorder primitive cache |
1653 | /// compared with CheckReorderToOpMem(op_md, reorder_tensor, net, net_args, |
1654 | /// engine), will remove slow path in the future. |
1655 | inline bool CheckReorderToOpMem(const memory::desc& op_pd, |
1656 | Tensor* reorder_tensor, |
1657 | OpKernelContext* ctx = nullptr) { |
1658 | DCHECK(reorder_tensor); |
1659 | return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor), |
1660 | *cpu_engine_, ctx); |
1661 | } |
1662 | |
1663 | /// Function to handle output reorder |
1664 | /// |
  /// This function performs functionality very similar to the input reordering
  /// function above. The only difference is that this function does not add
  /// the reorder primitive to the net. The reason is that the reorder
  /// primitive for the output needs to be added to the list only after the
  /// operation has executed. However, we need to prepare a temporary buffer in
  /// case an output reorder is needed; this temporary buffer will hold the
  /// output of the operation before it is fed to the reorder primitive.
1672 | /// |
1673 | /// @input - memory primitive descriptor (memory descriptor for v1.x) for the |
1674 | /// given output of an operation |
1675 | /// @return: true in case reorder of output is needed; false, otherwise. |
1676 | inline bool PrepareReorderToUserMemIfReq(const memory::desc& op_pd) { |
1677 | DCHECK(user_memory_); |
1678 | if (IsReorderNeeded(op_pd)) { |
1679 | // TODO(intel-tf): can we remove dynamic memory allocation? |
1680 | reorder_memory_ = new memory(op_pd, *cpu_engine_); |
1681 | return true; |
1682 | } |
1683 | return false; |
1684 | } |
1685 | |
1686 | /// Function to actually insert reorder primitive in the net |
1687 | /// |
1688 | /// This function completes remaining part of output reordering. It inserts |
1689 | /// a reordering primitive from the temporary buffer that holds the output |
1690 | /// to the user-specified output buffer. |
1691 | /// |
1692 | /// @input: net - net to which to add reorder primitive |
1693 | /// @input: net_args - net to which user and reorder memories are added if |
1694 | /// needed. Each entry is a key-value pair of the form |
1695 | /// <argument-type, dnnl::memory>. |
1696 | inline void InsertReorderToUserMem(std::vector<primitive>& net, |
1697 | std::vector<MemoryArgsMap>& net_args) { |
1698 | DCHECK(user_memory_); |
1699 | DCHECK(reorder_memory_); |
1700 | net.push_back(CreateReorder(reorder_memory_, user_memory_)); |
1701 | net_args.push_back(MemoryArgsMap{{DNNL_ARG_FROM, *reorder_memory_}, |
1702 | {DNNL_ARG_TO, *user_memory_}}); |
1703 | } |
1704 | |
1705 | /// TODO(intel-tf): this is a faster path with reorder primitive cache |
1706 | /// compared with InsertReorderToUserMem(net, net_args), will remove |
1707 | /// slow path in the future |
1708 | inline void InsertReorderToUserMem(OpKernelContext* ctx = nullptr) { |
1709 | DCHECK(user_memory_); |
1710 | DCHECK(reorder_memory_); |
1711 | DCHECK(cpu_engine_); |
    // Primitive reuse doesn't allow two identical reorder primitives in
    // one stream, so submit it immediately.
1714 | std::vector<primitive> net; |
1715 | auto* prim = FindOrCreateReorder<T>(reorder_memory_, user_memory_); |
1716 | net.push_back(*(prim->GetPrimitive())); |
1717 | std::vector<MemoryArgsMap> net_args; |
1718 | net_args.push_back( |
1719 | {{DNNL_ARG_FROM, *reorder_memory_}, {DNNL_ARG_TO, *user_memory_}}); |
1720 | std::shared_ptr<stream> cpu_stream; |
1721 | MklDnnThreadPool eigen_tp; |
1722 | if (ctx != nullptr) { |
1723 | eigen_tp = MklDnnThreadPool(ctx); |
1724 | cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine())); |
1725 | } else { |
1726 | cpu_stream.reset(CreateStream(nullptr, prim->GetEngine())); |
1727 | } |
1728 | execute_primitives(net, cpu_stream, net_args); |
1729 | } |
1730 | }; |
1731 | |
1732 | /// Base class for operations with reuse of primitives |
1733 | class MklPrimitive { |
1734 | public: |
1735 | virtual ~MklPrimitive() {} |
1736 | MklPrimitive() {} |
1737 | MklPrimitive(const engine& cpu_engine) { cpu_engine_ = cpu_engine; } |
1738 | // Dummy data which MKL DNN never operates on |
1739 | unsigned char* DummyData = nullptr; |
1740 | engine cpu_engine_ = engine(engine::kind::cpu, 0); |
1741 | const engine& GetEngine() { return cpu_engine_; } |
1742 | }; |
1743 | |
1744 | const dnnl::memory::dims NONE_DIMS = {}; |
1745 | |
1746 | // |
// LRUCache is a class that implements an LRU (Least Recently Used) cache.
// The implementation is similar to that of
//    tensorflow/core/platform/cloud/expiring_lru_cache.h
// without its thread-safe part, because the cache is meant to be
// used as thread local (for instance, for MklPrimitive caching).
//
// The LRU list maintains objects in order of last access, with the least
// recently accessed object at the tail of the list and the most recently
// accessed object at the head.
//
// This class is used to maintain an upper bound on the total number of
// cached items. When the cache reaches its capacity, the least recently
// used item is removed and replaced by the new one from the SetOp call.
1761 | // |
1762 | template <typename T> |
1763 | class LRUCache { |
1764 | public: |
1765 | explicit LRUCache(size_t capacity) { |
1766 | capacity_ = capacity; |
1767 | Clear(); |
1768 | } |
1769 | |
1770 | T* GetOp(const string& key) { |
1771 | #ifdef DNNL_AARCH64_USE_ACL |
1772 | mutex_lock lock(lru_mu_); |
1773 | #endif |
1774 | auto it = cache_.find(key); |
1775 | if (it == cache_.end()) { |
1776 | return nullptr; |
1777 | } |
1778 | |
1779 | // Move to the front of LRU list as the most recently accessed. |
1780 | lru_list_.erase(it->second.lru_iterator); |
1781 | lru_list_.push_front(it->first); |
1782 | it->second.lru_iterator = lru_list_.begin(); |
1783 | return it->second.op; |
1784 | } |
1785 | |
1786 | void SetOp(const string& key, T* op) { |
1787 | #ifdef DNNL_AARCH64_USE_ACL |
1788 | mutex_lock lock(lru_mu_); |
1789 | #endif |
1790 | if (lru_list_.size() >= capacity_) { |
1791 | Delete(); |
1792 | } |
1793 | |
1794 | // Insert an entry to the front of the LRU list |
1795 | lru_list_.push_front(key); |
1796 | Entry entry(op, lru_list_.begin()); |
1797 | cache_.emplace(std::make_pair(key, std::move(entry))); |
1798 | #ifdef DNNL_AARCH64_USE_ACL |
1799 | FinishedAllocation(key); |
1800 | #endif |
1801 | } |
1802 | |
1803 | void Clear() { |
1804 | if (lru_list_.empty()) return; |
1805 | |
1806 | // Clean up the cache |
1807 | cache_.clear(); |
1808 | lru_list_.clear(); |
1809 | } |
1810 | |
1811 | #ifdef DNNL_AARCH64_USE_ACL |
1812 | bool IsAllocating(const string& key) { |
1813 | mutex_lock lock(in_flight_mu_); |
1814 | return in_flight_.find(key) != in_flight_.end(); |
1815 | } |
1816 | |
1817 | void Allocate(const string& key) { |
1818 | mutex_lock lock(in_flight_mu_); |
1819 | in_flight_.insert(key); |
1820 | } |
1821 | |
1822 | void FinishedAllocation(const string& key) { |
1823 | mutex_lock lock(in_flight_mu_); |
1824 | in_flight_.erase(key); |
1825 | } |
1826 | #endif |
1827 | |
1828 | private: |
1829 | struct Entry { |
1830 | // The entry's value. |
1831 | T* op; |
1832 | |
1833 | // A list iterator pointing to the entry's position in the LRU list. |
1834 | std::list<string>::iterator lru_iterator; |
1835 | |
1836 | // Constructor |
1837 | Entry(T* op, std::list<string>::iterator it) { |
1838 | this->op = op; |
1839 | this->lru_iterator = it; |
1840 | } |
1841 | |
1842 | // Move constructor |
1843 | Entry(Entry&& source) noexcept |
1844 | : lru_iterator(std::move(source.lru_iterator)) { |
      op = source.op;
      source.op = nullptr;
1847 | } |
1848 | |
1849 | // Destructor |
1850 | ~Entry() { |
1851 | if (op != nullptr) delete op; |
1852 | } |
1853 | }; |
1854 | |
1855 | // Remove the least recently accessed entry from LRU list, which |
1856 | // is the tail of lru_list_. Update cache_ correspondingly. |
1857 | bool Delete() { |
1858 | if (lru_list_.empty()) return false; |
1859 | string key = lru_list_.back(); |
1860 | lru_list_.pop_back(); |
1861 | cache_.erase(key); |
1862 | return true; |
1863 | } |
1864 | |
1865 | // Cache capacity |
1866 | size_t capacity_; |
1867 | |
1868 | // The cache, a map from string key to a LRU entry. |
1869 | std::unordered_map<string, Entry> cache_; |
1870 | |
1871 | // The LRU list of entries. |
1872 | // The front of the list contains the key of the most recently accessed |
1873 | // entry, while the back of the list is the least recently accessed entry. |
1874 | std::list<string> lru_list_; |
1875 | |
1876 | #ifdef DNNL_AARCH64_USE_ACL |
1877 | // Guards access to the cache and LRU list |
1878 | mutex lru_mu_; |
1879 | |
1880 | // The keys that are currently under creation |
  std::set<string> in_flight_ TF_GUARDED_BY(in_flight_mu_);
  mutex in_flight_mu_;
1884 | #endif |
1885 | }; |
1886 | |
1887 | template <typename T> |
1888 | class MklPrimitiveFactory { |
1889 | public: |
1890 | MklPrimitiveFactory() {} |
1891 | |
1892 | ~MklPrimitiveFactory() {} |
1893 | |
1894 | MklPrimitive* GetOp(const string& key) { |
1895 | #ifndef DNNL_AARCH64_USE_ACL |
1896 | auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache(); |
1897 | return lru_cache.GetOp(key); |
1898 | #else |
1899 | while (true) { |
1900 | // TODO(milpuz01): Consider if it is possible to narrow scope to be |
1901 | // only around checks for allocations and conditional wait. |
1902 | mutex_lock lock(primitive_creation_mu_); |
1903 | auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache(); |
1904 | |
1905 | // Check to see whether primitive already exists. |
1906 | MklPrimitive* primitive = lru_cache.GetOp(key); |
1907 | if (primitive != nullptr) { |
1908 | return primitive; |
1909 | } |
1910 | |
1911 | // Now check whether some other thread is creating this primitive. |
1912 | if (!lru_cache.IsAllocating(key)) { |
        // This thread is going to pick it up and create the primitive.
        lru_cache.Allocate(key);
        // Returning releases the lock; primitive creation might take a long
        // time.
        return nullptr;
1917 | } |
1918 | |
1919 | // At this point we cannot create primitive as other thread is creating |
1920 | // it. We should wait for primitive to get created. |
1921 | primitive_creation_cv_.wait(lock); |
1922 | |
1923 | // The primitive is created and is in the cache so we are going to try |
1924 | // retrieve it again after getting a lock on it as multiple threads might |
1925 | // be waiting for the primitive. |
1926 | } |
1927 | #endif |
1928 | } |
1929 | |
1930 | void SetOp(const string& key, MklPrimitive* op) { |
1931 | #ifndef DNNL_AARCH64_USE_ACL |
1932 | auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache(); |
1933 | lru_cache.SetOp(key, op); |
1934 | #else |
1935 | { |
1936 | mutex_lock lock(primitive_creation_mu_); |
1937 | auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache(); |
1938 | lru_cache.SetOp(key, op); |
1939 | } |
1940 | |
1941 | // Now we can inform all waiting threads that primitive is created. |
1942 | primitive_creation_cv_.notify_all(); |
1943 | #endif |
1944 | } |
1945 | |
  /// Function to decide whether the HW has AVX512 or AVX2.
  /// For legacy devices (without AVX512 and AVX2),
  /// MKL-DNN GEMM will be used.
1949 | static inline bool IsLegacyPlatform() { |
1950 | static const bool is_legacy_platform = |
1951 | (!port::TestCPUFeature(port::CPUFeature::AVX512F) && |
1952 | !port::TestCPUFeature(port::CPUFeature::AVX2)); |
1953 | return is_legacy_platform; |
1954 | } |
1955 | |
1956 | /// Function to check whether primitive memory optimization is enabled |
1957 | static inline bool IsPrimitiveMemOptEnabled() { |
1958 | static const bool is_primitive_mem_opt_enabled = [] { |
1959 | bool value = true; |
1960 | TF_CHECK_OK( |
1961 | ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE" , true, &value)); |
1962 | return value; |
1963 | }(); |
1964 | return is_primitive_mem_opt_enabled; |
1965 | } |
1966 | |
1967 | #ifdef DNNL_AARCH64_USE_ACL |
1968 | static int IncrementCounter() { |
1969 | static std::atomic_int counter{1}; |
1970 | return counter.fetch_add(1); |
1971 | } |
1972 | #endif |
1973 | |
1974 | private: |
1975 | static inline LRUCache<MklPrimitive>& GetLRUCache() { |
1976 | static const int kCapacity = 1024; // cache capacity |
1977 | #ifndef DNNL_AARCH64_USE_ACL |
1978 | static thread_local LRUCache<MklPrimitive> lru_cache_(kCapacity); |
1979 | #else |
    // Under ACL the cache is shared across threads; LRUCache synchronizes its
    // own state internally (see lru_mu_ in LRUCache).
    static LRUCache<MklPrimitive> lru_cache_(kCapacity);
1982 | #endif |
1983 | return lru_cache_; |
1984 | } |
1985 | |
1986 | #ifdef DNNL_AARCH64_USE_ACL |
1987 | mutex primitive_creation_mu_; |
1988 | condition_variable primitive_creation_cv_; |
1989 | #endif |
1990 | }; |
1991 | |
// Utility class for creating keys for the MKL primitive pool.
1993 | class FactoryKeyCreator { |
1994 | public: |
1995 | FactoryKeyCreator() { key_.reserve(kMaxKeyLength); } |
1996 | |
1997 | ~FactoryKeyCreator() {} |
1998 | |
1999 | void AddAsKey(const string& str) { Append(str); } |
2000 | |
2001 | void AddAsKey(const dnnl::memory::dims& dims) { |
2002 | for (unsigned int i = 0; i < dims.size(); i++) { |
2003 | AddAsKey<int>(dims[i]); |
2004 | } |
2005 | } |
2006 | |
2007 | template <typename T> |
2008 | void AddAsKey(const T data) { |
2009 | auto buffer = reinterpret_cast<const char*>(&data); |
2010 | Append(StringPiece(buffer, sizeof(T))); |
2011 | } |
2012 | |
2013 | // generalisation to handle pointers |
2014 | void AddAsKey(const void* data) { |
2015 | auto buffer = reinterpret_cast<const char*>(&data); |
2016 | Append(StringPiece(buffer, sizeof(data))); |
2017 | } |
2018 | |
2019 | string GetKey() { return key_; } |
2020 | |
2021 | private: |
2022 | string key_; |
2023 | const char delimiter = 'x'; |
2024 | const int kMaxKeyLength = 256; |
2025 | void Append(StringPiece s) { |
2026 | key_.append(string(s)); |
2027 | key_.append(1, delimiter); |
2028 | } |
2029 | }; |
2030 | |
2031 | class MklReorderPrimitive : public MklPrimitive { |
2032 | public: |
2033 | explicit MklReorderPrimitive(const memory* from, const memory* to) |
2034 | : MklPrimitive(engine(engine::kind::cpu, 0)) { |
2035 | Setup(from, to); |
2036 | } |
2037 | ~MklReorderPrimitive() {} |
2038 | |
2039 | std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; } |
2040 | |
2041 | void SetMemory(const memory* from, const memory* to) { |
2042 | context_.src_mem->set_data_handle(from->get_data_handle()); |
2043 | context_.dst_mem->set_data_handle(to->get_data_handle()); |
2044 | } |
2045 | |
2046 | std::shared_ptr<dnnl::stream> GetStream() { return stream_; } |
2047 | |
2048 | private: |
2049 | struct ReorderContext { |
2050 | std::shared_ptr<dnnl::memory> src_mem; |
2051 | std::shared_ptr<dnnl::memory> dst_mem; |
2052 | std::shared_ptr<primitive> reorder_prim; |
2053 | ReorderContext() |
2054 | : src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {} |
2055 | } context_; |
2056 | |
2057 | std::shared_ptr<dnnl::stream> stream_; |
2058 | |
2059 | void Setup(const memory* from, const memory* to) { |
2060 | context_.src_mem.reset( |
2061 | new memory(from->get_desc(), cpu_engine_, DummyData)); |
2062 | context_.dst_mem.reset(new memory(to->get_desc(), cpu_engine_, DummyData)); |
2063 | context_.reorder_prim = std::make_shared<dnnl::reorder>( |
2064 | reorder(*context_.src_mem, *context_.dst_mem)); |
2065 | stream_.reset(new stream(cpu_engine_)); |
2066 | } |
2067 | }; |
2068 | |
2069 | template <typename T> |
2070 | class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> { |
2071 | public: |
2072 | static MklReorderPrimitive* Get(const memory* from, const memory* to) { |
2073 | auto reorderPrim = static_cast<MklReorderPrimitive*>( |
2074 | MklReorderPrimitiveFactory<T>::GetInstance().GetReorder(from, to)); |
2075 | if (reorderPrim == nullptr) { |
2076 | reorderPrim = new MklReorderPrimitive(from, to); |
2077 | MklReorderPrimitiveFactory<T>::GetInstance().SetReorder(from, to, |
2078 | reorderPrim); |
2079 | } |
2080 | reorderPrim->SetMemory(from, to); |
2081 | return reorderPrim; |
2082 | } |
2083 | |
2084 | static MklReorderPrimitiveFactory& GetInstance() { |
2085 | static MklReorderPrimitiveFactory instance_; |
2086 | return instance_; |
2087 | } |
2088 | |
2089 | static string CreateKey(const memory* from, const memory* to) { |
2090 | string prefix = "reorder" ; |
2091 | FactoryKeyCreator key_creator; |
2092 | auto const& from_desc = from->get_desc().data; |
2093 | auto const& to_desc = to->get_desc().data; |
2094 | memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]); |
2095 | memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]); |
2096 | auto from_strides = from_desc.format_desc.blocking.strides; |
2097 | |
    // The DNNL memory descriptor uses C-style arrays and only initializes the
    // used part, so only the valid part is used as the key.
2100 | auto from_inner_nblks = from_desc.format_desc.blocking.inner_nblks; |
2101 | auto from_inner_blks = from_desc.format_desc.blocking.inner_blks; |
2102 | auto from_inner_idxs = from_desc.format_desc.blocking.inner_idxs; |
2103 | memory::dims from_inner_blks_1(from_inner_blks, |
2104 | &from_inner_blks[from_inner_nblks]); |
2105 | memory::dims from_inner_idxs_1(from_inner_idxs, |
2106 | &from_inner_idxs[from_inner_nblks]); |
2107 | auto to_inner_nblks = to_desc.format_desc.blocking.inner_nblks; |
2108 | auto to_inner_blks = to_desc.format_desc.blocking.inner_blks; |
2109 | auto to_inner_idxs = to_desc.format_desc.blocking.inner_idxs; |
2110 | memory::dims to_inner_blks_1(to_inner_blks, &to_inner_blks[to_inner_nblks]); |
2111 | memory::dims to_inner_idxs_1(to_inner_idxs, &to_inner_idxs[to_inner_nblks]); |
2112 | |
2113 | auto to_strides = to_desc.format_desc.blocking.strides; |
2114 | memory::dims from_strides_outer_blocks(from_strides, |
2115 | &from_strides[from_desc.ndims]); |
2116 | memory::dims to_strides_outer_blocks(to_strides, |
2117 | &to_strides[to_desc.ndims]); |
2118 | |
2119 | key_creator.AddAsKey(prefix); |
2120 | #ifdef DNNL_AARCH64_USE_ACL |
2121 | // The reorder primitives have local memory (calls to SetMemory) so we |
2122 | // need to make sure that memory for those primitives is cached per thread. |
2123 | key_creator.AddAsKey(std::this_thread::get_id()); |
2124 | #endif |
2125 | key_creator.AddAsKey(static_cast<int>(from_desc.extra.flags)); |
2126 | key_creator.AddAsKey(static_cast<int>(from_inner_nblks)); |
2127 | key_creator.AddAsKey(from_inner_blks_1); |
2128 | key_creator.AddAsKey(from_inner_idxs_1); |
2129 | key_creator.AddAsKey(static_cast<int>(from_desc.data_type)); |
2130 | key_creator.AddAsKey(from_dims); |
2131 | key_creator.AddAsKey(from_strides_outer_blocks); |
2132 | key_creator.AddAsKey(static_cast<int>(to_desc.extra.flags)); |
2133 | key_creator.AddAsKey(static_cast<int>(to_inner_nblks)); |
2134 | key_creator.AddAsKey(to_inner_blks_1); |
2135 | key_creator.AddAsKey(to_inner_idxs_1); |
2136 | key_creator.AddAsKey(static_cast<int>(to_desc.data_type)); |
2137 | key_creator.AddAsKey(to_dims); |
2138 | key_creator.AddAsKey(to_strides_outer_blocks); |
2139 | return key_creator.GetKey(); |
2140 | } |
2141 | |
2142 | private: |
2143 | MklReorderPrimitiveFactory() {} |
2144 | ~MklReorderPrimitiveFactory() {} |
2145 | |
2146 | MklPrimitive* GetReorder(const memory* from, const memory* to) { |
2147 | string key = CreateKey(from, to); |
2148 | return this->GetOp(key); |
2149 | } |
2150 | |
2151 | void SetReorder(const memory* from, const memory* to, MklPrimitive* op) { |
2152 | string key = CreateKey(from, to); |
2153 | this->SetOp(key, op); |
2154 | } |
2155 | }; |
2156 | |
/// Function to find (or create) a reorder from the memory pointed to by
/// `from` to the memory pointed to by `to`. It will create the primitive or
/// fetch it from the pool if it is already cached.
/// Returns the primitive.
2161 | template <typename T> |
2162 | inline MklReorderPrimitive* FindOrCreateReorder(const memory* from, |
2163 | const memory* to) { |
2164 | CHECK_NOTNULL(from); |
2165 | CHECK_NOTNULL(to); |
2166 | MklReorderPrimitive* reorder_prim = |
2167 | MklReorderPrimitiveFactory<T>::Get(from, to); |
2168 | return reorder_prim; |
2169 | } |
2170 | |
// Utility function to determine whether a convolution is 1x1 with stride != 1,
// for the purpose of temporarily disabling primitive reuse.
2173 | inline bool IsConv1x1StrideNot1(memory::dims filter_dims, |
2174 | memory::dims strides) { |
2175 | if (filter_dims.size() != 4 || strides.size() != 2) return false; |
2176 | |
2177 | return ((filter_dims[2] == 1) && (filter_dims[3] == 1) && |
2178 | ((strides[0] != 1) || (strides[1] != 1))); |
2179 | } |
2180 | |
2181 | } // namespace tensorflow |
2182 | |
2183 | ///////////////////////////////////////////////////////////////////// |
2184 | // Macros for handling registration for various types |
2185 | ///////////////////////////////////////////////////////////////////// |
2186 | |
2187 | #define REGISTER_TEST_FLOAT32(TEST) REGISTER_TEST(TEST, DT_FLOAT, Float32Input); |
2188 | |
2189 | #define REGISTER_TEST_BFLOAT16(TEST) \ |
2190 | REGISTER_TEST(TEST, DT_BFLOAT16, BFloat16Input); |
2191 | |
2192 | #define REGISTER_TEST_ALL_TYPES(TEST) \ |
2193 | REGISTER_TEST_FLOAT32(TEST); \ |
2194 | REGISTER_TEST_BFLOAT16(TEST); |
2195 | #else |
2196 | #define REGISTER_TEST_ALL_TYPES(TEST) REGISTER_TEST_FLOAT32(TEST); |
2197 | |
2198 | #endif // INTEL_MKL |
2199 | #endif // TENSORFLOW_CORE_UTIL_MKL_UTIL_H_ |
2200 | |