1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
17#define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
18#ifdef INTEL_MKL
19
20#include <list>
21#include <memory>
22#include <string>
23#include <unordered_map>
24#include <utility>
25#include <vector>
26
27#include "dnnl.hpp"
28#include "tensorflow/core/framework/op_kernel.h"
29#include "tensorflow/core/framework/tensor.h"
30#include "tensorflow/core/framework/tensor_shape.h"
31#include "tensorflow/core/graph/mkl_graph_util.h"
32#include "tensorflow/core/lib/core/errors.h"
33#include "tensorflow/core/lib/core/stringpiece.h"
34#include "tensorflow/core/lib/gtl/array_slice.h"
35#include "tensorflow/core/platform/cpu_info.h"
36#include "tensorflow/core/platform/logging.h"
37#include "tensorflow/core/platform/macros.h"
38#include "tensorflow/core/util/env_var.h"
39#include "tensorflow/core/util/mkl_threadpool.h"
40#include "tensorflow/core/util/padding.h"
41#include "tensorflow/core/util/tensor_format.h"
42#ifdef DNNL_AARCH64_USE_ACL
43#include "tensorflow/core/platform/mutex.h"
44#endif
45
46using dnnl::engine;
47using dnnl::memory;
48using dnnl::primitive;
49using dnnl::reorder;
50using dnnl::stream;
51using CPUDevice = Eigen::ThreadPoolDevice;
52using MemoryArgsMap = std::unordered_map<int, memory>;
53using ReorderPd = dnnl::reorder::primitive_desc;
54
55#ifdef _WIN32
56typedef unsigned int uint;
57#endif
58
59namespace tensorflow {
60
61// The file contains a number of utility classes and functions used by MKL
62// enabled kernels
63
64// This class encapsulates all the meta data that is associated with an MKL
65// tensor. A tensor is an MKL tensor if it was created as the result of an
66// MKL operation, and did not go through a conversion to a standard
67// Tensorflow tensor.
68
69// The dimensions order that oneDNN internally uses for 2D activations
70// [Batch, Channel, Height, Width] and
71// for 2D filters [Out_Channel, In_Channel, Height, Width].
72typedef enum {
73 Dim_N = 0,
74 Dim_C = 1,
75 Dim_H = 2,
76 Dim_W = 3,
77 Dim_O = 0,
78 Dim_I = 1
79} MklDnnDims;
80
81// The dimensions order that oneDNN internally uses for 3D activations
82// [Batch, Channel, Depth, Height, Width] and
83// for 3D filters [Out_Channel, In_Channel, Depth, Height, Width].
84typedef enum {
85 Dim3d_N = 0,
86 Dim3d_C = 1,
87 Dim3d_D = 2,
88 Dim3d_H = 3,
89 Dim3d_W = 4,
90 Dim3d_O = 0,
91 Dim3d_I = 1
92} MklDnnDims3D;
93
94// Enum for the order of dimensions of a TF 2D filter with shape [filter_height,
95// filter_width, in_channels, out_channels]
96typedef enum {
97 TF_2DFILTER_DIM_H = 0,
98 TF_2DFILTER_DIM_W = 1,
99 TF_2DFILTER_DIM_I = 2,
100 TF_2DFILTER_DIM_O = 3
101} TFFilterDims2d;
102
103// Enum for the order of dimensions of a TF 3D filter with shape [filter_depth,
104// filter_height, filter_width, in_channels, out_channels]
105typedef enum {
106 TF_3DFILTER_DIM_P = 0,
107 TF_3DFILTER_DIM_H = 1,
108 TF_3DFILTER_DIM_W = 2,
109 TF_3DFILTER_DIM_I = 3,
110 TF_3DFILTER_DIM_O = 4
111} TFFilterDims3d;
112
113// The dimensions order that oneDNN requires for the filter in a grouped
114// convolution (2D only)
115typedef enum {
116 MKL_GROUP_FILTER_DIM_G = 0,
117 MKL_GROUP_FILTER_DIM_O = 1,
118 MKL_GROUP_FILTER_DIM_I = 2,
119 MKL_GROUP_FILTER_DIM_H = 3,
120 MKL_GROUP_FILTER_DIM_W = 4
121} MklDnnFilterGroupDims;
122
123// Enum used to templatize MklOp kernel implementation
124// that support both fp32 and int8 versions.
125enum class MklQuantization {
126 QUANTIZED_VERSION,
127 FP_VERSION,
128};
129
130static const int kSmallBatchSize = 32;
131
132inline void execute_primitives(
133 std::vector<dnnl::primitive>& primitives, std::shared_ptr<stream> stream,
134 std::vector<std::unordered_map<int, memory>>& net_args) {
135 DCHECK_EQ(primitives.size(), net_args.size());
136 for (size_t i = 0; i < primitives.size(); ++i) {
137 primitives.at(i).execute(*stream, net_args.at(i));
138 }
139}
140
141// In oneDNN v1.x, the format (ex. NCHW) used to initialize a memory descriptor
142// (md) structure will no longer be recorded in its `format` field. Instead, it
143// will be set to a canonical `blocked` format for every fully described md.
144//
145// Currently, we query this `format` field while mapping oneDNN's data format
146// to TF's data format. Due to the above restriction, we will now get this data
147// format information from TF's `data_format` attribute (i.e. via
148// `TensorFormat`) for oneDNN v1.x.
149//
150// Some oneDNN operators such as ReLU do not have a `data_format` attribute
151// since they are usually in `blocked` format. Therefore, in order to
152// distinguish between blocked and non-blocked formats, we have defined a new
153// enum called `MklTensorFormat` that is semantically similar to `TensorFormat`
154// but with the following additional fields namely:
155// 1) FORMAT_BLOCKED: as described above, this is needed for element-wise
156// operators such as ReLU.
157// 2) FORMAT_INVALID: for error-checking (ex. unsupported format)
158// 3) FORMAT_X, FORMAT_NC, FORMAT_TNC: to distinguish between MKL tensors based
159// on their dimensions in operators such as Softmax, i.e.:
160// FORMAT_X - 1D tensor
161// FORMAT_NC - 2D tensor
162// FORMAT_TNC - 3D tensor
163enum class MklTensorFormat {
164 FORMAT_NHWC = 0,
165 FORMAT_NCHW = 1,
166 FORMAT_NDHWC = 2,
167 FORMAT_NCDHW = 3,
168 FORMAT_X = 4,
169 FORMAT_NC = 5,
170 FORMAT_TNC = 6,
171 FORMAT_BLOCKED = 7,
172 FORMAT_INVALID = 8,
173};
174
175// Forward declarations
176memory::format_tag MklTensorFormatToMklDnnDataFormat(MklTensorFormat format);
177
178TensorFormat MklDnn3DDataFormatToTFDataFormat(MklTensorFormat format);
179TensorFormat MklDnnDataFormatToTFDataFormat(MklTensorFormat format);
180
181memory::dims CalculateTFStrides(const memory::dims& dims_tf_order);
182Status CreateBlockedMemDescHelper(const memory::dims& dim,
183 const memory::dims& strides,
184 memory::data_type dtype,
185 dnnl_memory_desc_t* blocked_md);
186
187inline std::ostream& operator<<(std::ostream& os,
188 const memory::format_tag& tag) {
189 if (tag == memory::format_tag::undef) {
190 os << "undef";
191 } else if (tag == memory::format_tag::any) {
192 os << "any";
193 } else {
194 os << "invalid";
195 }
196 return os;
197}
198
199inline void operator<<(std::ostream& os, const MklTensorFormat& format) {
200 if (format == MklTensorFormat::FORMAT_NHWC) {
201 os << "FORMAT_NHWC";
202 } else if (format == MklTensorFormat::FORMAT_NCHW) {
203 os << "FORMAT_NCHW";
204 } else if (format == MklTensorFormat::FORMAT_NDHWC) {
205 os << "FORMAT_NDHWC";
206 } else if (format == MklTensorFormat::FORMAT_NCDHW) {
207 os << "FORMAT_NCDHW";
208 } else if (format == MklTensorFormat::FORMAT_X) {
209 os << "FORMAT_X";
210 } else if (format == MklTensorFormat::FORMAT_NC) {
211 os << "FORMAT_NC";
212 } else if (format == MklTensorFormat::FORMAT_TNC) {
213 os << "FORMAT_TNC";
214 } else if (format == MklTensorFormat::FORMAT_BLOCKED) {
215 os << "FORMAT_BLOCKED";
216 } else {
217 os << "INVALID FORMAT";
218 }
219}
220
221template <typename T>
222inline bool array_cmp(const T* a1, const T* a2, size_t size) {
223 for (size_t i = 0; i < size; ++i)
224 if (a1[i] != a2[i]) return false;
225 return true;
226}
227
228inline dnnl::stream* CreateStream(MklDnnThreadPool* eigen_tp,
229 const engine& engine) {
230#ifndef ENABLE_ONEDNN_OPENMP
231 if (eigen_tp != nullptr) {
232 stream* tp_stream =
233 new stream(dnnl::threadpool_interop::make_stream(engine, eigen_tp));
234 return tp_stream;
235 } else {
236 stream* tp_stream = new stream(engine);
237 return tp_stream;
238 }
239#else
240 stream* tp_stream = new stream(engine);
241 return tp_stream;
242#endif // !ENABLE_ONEDNN_OPENMP
243}
244
245class MklDnnShape {
246 private:
247 struct MklShapeData {
248 // Flag to indicate if the tensor is an MKL tensor or not
249 bool is_mkl_tensor_ = false;
250 // Number of dimensions in Tensorflow format
251 size_t dimension_ = 0;
252 dnnl_dims_t sizes_; // Required by MKL for conversions
253 MklTensorFormat tf_data_format_ = MklTensorFormat::FORMAT_BLOCKED;
254 memory::data_type T_ = memory::data_type::undef;
255 // MKL layout
256 dnnl_memory_desc_t mkl_md_;
257 /// TF dimension corresponding to this MKL dimension
258 dnnl_dims_t map_;
259 };
260 MklShapeData data_;
261
262 typedef std::remove_extent<dnnl_dims_t>::type dnnl_dim_t;
263
264#define INVALID_DIM_SIZE -1
265
266 public:
267 MklDnnShape() : data_{} {
268 for (size_t i = 0; i < sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
269 ++i) {
270 data_.sizes_[i] = -1;
271 }
272 for (size_t i = 0; i < sizeof(data_.map_) / sizeof(data_.map_[0]); ++i) {
273 data_.map_[i] = -1;
274 }
275 }
276
277 ~MklDnnShape() {}
278 TF_DISALLOW_COPY_AND_ASSIGN(MklDnnShape); // Cannot copy
279
280 /// Equality function for MklDnnShape objects
281 /// @return true if both are equal; false otherwise.
282 inline bool operator==(const MklDnnShape& input_shape) const {
283 if (this->IsMklTensor() != input_shape.IsMklTensor()) {
284 return false;
285 }
286
287 // If input tensors are in MKL layout, then we check for dimensions and
288 // sizes.
289 if (this->IsMklTensor()) {
290 const dnnl_memory_desc_t& cur_md = (this->GetMklLayout()).data;
291 const dnnl_memory_desc_t& input_shape_md =
292 input_shape.GetMklLayout().data;
293 return this->GetTfShape() == input_shape.GetTfShape() &&
294 dnnl_memory_desc_equal(&cur_md, &input_shape_md);
295 }
296
297 // Both inputs are not MKL tensors.
298 return true;
299 }
300
301 /// Equality operator for MklDnnShape and TFShape.
302 /// Returns: true if TF shapes for both are the same, false otherwise
303 inline bool operator==(const TensorShape& input_shape) const {
304 if (!this->IsMklTensor()) {
305 return false;
306 }
307
308 return this->GetTfShape() == input_shape;
309 }
310
311 inline const bool IsMklTensor() const { return data_.is_mkl_tensor_; }
312 inline void SetMklTensor(bool is_mkl_tensor) {
313 data_.is_mkl_tensor_ = is_mkl_tensor;
314 }
315
316 inline void SetDimensions(const size_t dimension) {
317 data_.dimension_ = dimension;
318 }
319 inline size_t GetDimension(char dimension) const {
320 int index = GetMklDnnTensorDimIndex(dimension);
321 CHECK(index >= 0 && index < this->GetDimension())
322 << "Invalid index from the dimension: " << index << ", " << dimension;
323 return this->DimSize(index);
324 }
325
326 inline size_t GetDimension3D(char dimension) const {
327 int index = GetMklDnnTensor3DDimIndex(dimension);
328 CHECK(index >= 0 && index < this->GetDimension())
329 << "Invalid index from the dimension: " << index << ", " << dimension;
330 return this->DimSize(index);
331 }
332
333 inline int32 GetMklDnnTensorDimIndex(char dimension) const {
334 switch (dimension) {
335 case 'N':
336 return MklDnnDims::Dim_N;
337 case 'C':
338 return MklDnnDims::Dim_C;
339 case 'H':
340 return MklDnnDims::Dim_H;
341 case 'W':
342 return MklDnnDims::Dim_W;
343 default:
344 LOG(FATAL) << "Invalid dimension: " << dimension;
345 return -1; // Avoid compiler warning about missing return value
346 }
347 }
348
349 inline int32 GetMklDnnTensor3DDimIndex(char dimension) const {
350 switch (dimension) {
351 case 'N':
352 return MklDnnDims3D::Dim3d_N;
353 case 'C':
354 return MklDnnDims3D::Dim3d_C;
355 case 'D':
356 return MklDnnDims3D::Dim3d_D;
357 case 'H':
358 return MklDnnDims3D::Dim3d_H;
359 case 'W':
360 return MklDnnDims3D::Dim3d_W;
361 default:
362 LOG(FATAL) << "Invalid dimension: " << dimension;
363 return -1; // Avoid compiler warning about missing return value
364 }
365 }
366
367 inline size_t GetDimension() const { return data_.dimension_; }
368 inline const int* GetSizes() const {
369 return reinterpret_cast<const int*>(&data_.sizes_[0]);
370 }
371
372 // Returns an dnnl::memory::dims object that contains the sizes of this
373 // MklDnnShape object.
374 inline memory::dims GetSizesAsMklDnnDims() const {
375 memory::dims retVal;
376 if (data_.is_mkl_tensor_) {
377 size_t dimensions = sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
378 for (size_t i = 0; i < dimensions; i++) {
379 if (data_.sizes_[i] != INVALID_DIM_SIZE)
380 retVal.push_back(data_.sizes_[i]);
381 }
382 } else {
383 CHECK_EQ(data_.is_mkl_tensor_, true);
384 }
385 return retVal;
386 }
387
388 inline int64 DimSize(int index) const {
389 CHECK_LT(index, sizeof(data_.sizes_) / sizeof(data_.sizes_[0]));
390 return data_.sizes_[index];
391 }
392
393 /// Return TensorShape that describes the Tensorflow shape of the tensor
394 /// represented by this MklShape.
395 inline TensorShape GetTfShape() const {
396 CHECK_EQ(data_.is_mkl_tensor_, true);
397
398 std::vector<int32> shape(data_.dimension_, -1);
399 // As mentioned in the comment above, we now rely on TF's `data_format`
400 // attribute to determine if TF shape is in blocked format or not.
401 if (data_.tf_data_format_ != MklTensorFormat::FORMAT_BLOCKED) {
402 for (size_t idx = 0; idx < data_.dimension_; ++idx) {
403 shape[idx] = data_.sizes_[TfDimIdx(idx)];
404 }
405 } else {
406 // If Tensorflow shape is in Blocked format, then we don't have dimension
407 // map for it. So we just create Tensorflow shape from sizes in the
408 // specified order.
409 for (size_t idx = 0; idx < data_.dimension_; ++idx) {
410 shape[idx] = data_.sizes_[idx];
411 }
412 }
413
414 TensorShape ts;
415 bool ret = TensorShapeUtils::MakeShape(shape, &ts).ok();
416 CHECK_EQ(ret, true);
417 return ts;
418 }
419
420 inline void SetElemType(memory::data_type dt) { data_.T_ = dt; }
421 inline const memory::data_type GetElemType() { return data_.T_; }
422
423 inline void SetMklLayout(memory::desc* md) {
424 CHECK_NOTNULL(md);
425 data_.mkl_md_ = md->data;
426 }
427
428 inline const memory::desc GetMklLayout() const {
429 return memory::desc(data_.mkl_md_);
430 }
431
432 inline MklTensorFormat GetTfDataFormat() const {
433 return data_.tf_data_format_;
434 }
435
436 /// We don't create primitive_descriptor for TensorFlow layout now.
437 /// We use lazy evaluation and create it only when needed. Input format can
438 /// also be Blocked format.
439 inline void SetTfLayout(size_t dims, const memory::dims& sizes,
440 MklTensorFormat format) {
441 DCHECK_EQ(dims, sizes.size())
442 << "SetTfLayout: Number of dimensions does not"
443 "match with dimension array";
444 data_.dimension_ = dims;
445 for (size_t ii = 0; ii < dims; ++ii) {
446 data_.sizes_[ii] = sizes[ii];
447 }
448 data_.tf_data_format_ = format;
449 if (format != MklTensorFormat::FORMAT_BLOCKED) {
450 if (dims == 2) {
451 data_.map_[0] = MklDnnDims::Dim_N;
452 data_.map_[1] = MklDnnDims::Dim_C;
453 } else {
454 SetTfDimOrder(dims, format);
455 }
456 }
457 }
458
459 inline const memory::desc GetTfLayout() const {
460 memory::dims dims;
461 for (size_t ii = 0; ii < data_.dimension_; ++ii) {
462 dims.push_back(data_.sizes_[ii]);
463 }
464
465 // Create Blocked memory desc if input TF format was set like that.
466 if (data_.tf_data_format_ == MklTensorFormat::FORMAT_BLOCKED) {
467 auto strides = CalculateTFStrides(dims);
468 dnnl_memory_desc_t blocked_md;
469 TF_CHECK_OK(
470 CreateBlockedMemDescHelper(dims, strides, data_.T_, &blocked_md));
471 return memory::desc(blocked_md);
472 } else {
473 auto format_tag =
474 MklTensorFormatToMklDnnDataFormat(data_.tf_data_format_);
475 return memory::desc(dims, data_.T_, format_tag);
476 }
477 }
478
479 inline const memory::desc GetCurLayout() const {
480 return IsMklTensor() ? GetMklLayout() : GetTfLayout();
481 }
482
483 // We don't need a case of default dimension order because
484 // when an operator that does not get data_format attribute gets all inputs
485 // in Tensorflow format, it will produce output in Tensorflow format.
486 inline void SetTfDimOrder(const size_t dimension, const dnnl_dims_t map) {
487 CHECK(dimension == data_.dimension_);
488 for (size_t ii = 0; ii < dimension; ii++) {
489 data_.map_[ii] = map[ii];
490 }
491 }
492
493 inline void SetTfDimOrder(const size_t dimension, TensorFormat data_format) {
494 if (dimension == 5) {
495 CHECK(dimension == data_.dimension_);
496 data_.map_[GetTensorDimIndex<3>(data_format, '0')] =
497 MklDnnDims3D::Dim3d_D;
498 data_.map_[GetTensorDimIndex<3>(data_format, '1')] =
499 MklDnnDims3D::Dim3d_H;
500 data_.map_[GetTensorDimIndex<3>(data_format, '2')] =
501 MklDnnDims3D::Dim3d_W;
502 data_.map_[GetTensorDimIndex<3>(data_format, 'C')] =
503 MklDnnDims3D::Dim3d_C;
504 data_.map_[GetTensorDimIndex<3>(data_format, 'N')] =
505 MklDnnDims3D::Dim3d_N;
506 } else {
507 CHECK_EQ(dimension, 4);
508 CHECK(dimension == data_.dimension_);
509 data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W;
510 data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H;
511 data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C;
512 data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N;
513 }
514 }
515
516 inline void SetTfDimOrder(const size_t dimension, MklTensorFormat format) {
517 TensorFormat data_format = MklDnnDataFormatToTFDataFormat(format);
518 SetTfDimOrder(dimension, data_format);
519 }
520
521 inline const dnnl_dim_t* GetTfToMklDimMap() const { return &data_.map_[0]; }
522 inline size_t TfDimIdx(int index) const { return data_.map_[index]; }
523 inline int64 TfDimSize(int index) const {
524 return data_.sizes_[TfDimIdx(index)];
525 }
526
527 /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
528 /// corresponds to MKL's Channel dimension.
529 inline bool IsMklChannelDim(int d) const {
530 return TfDimIdx(d) == MklDnnDims::Dim_C;
531 }
532
533 /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
534 /// corresponds to MKL's Batch dimension.
535 inline bool IsMklBatchDim(int d) const {
536 return TfDimIdx(d) == MklDnnDims::Dim_N;
537 }
538
539 /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
540 /// corresponds to MKL's Width dimension.
541 inline bool IsMklWidthDim(int d) const {
542 return TfDimIdx(d) == MklDnnDims::Dim_W;
543 }
544 /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
545 /// corresponds to MKL's Height dimension.
546 inline bool IsMklHeightDim(int d) const {
547 return TfDimIdx(d) == MklDnnDims::Dim_H;
548 }
549
550 /// Check if the TF-MKL dimension ordering map specifies if the input
551 /// tensor is in NCHW format.
552 inline bool IsTensorInNCHWFormat() const {
553 TensorFormat data_format = FORMAT_NCHW;
554 return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
555 IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
556 IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
557 IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
558 }
559
560 /// Check if the TF-MKL dimension ordering map specifies if the input
561 /// tensor is in NHWC format.
562 inline bool IsTensorInNHWCFormat() const {
563 TensorFormat data_format = FORMAT_NHWC;
564 return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
565 IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
566 IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
567 IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
568 }
569
570 /// The following methods are used for serializing and de-serializing the
571 /// contents of the mklshape object.
572 /// The data is serialized in this order
573 /// is_mkl_tensor_ : dimension_ : sizes_ : map_: format_ : T_ : mkl_pd_;
574
575 /// Size of buffer to hold the serialized object, the size is computed by
576 /// following above mentioned order
577 inline size_t GetSerializeBufferSize() const { return sizeof(MklShapeData); }
578
579 void SerializeMklDnnShape(unsigned char* buf, size_t buf_size) const {
580 CHECK(buf_size >= GetSerializeBufferSize())
581 << "Buffer size is too small to SerializeMklDnnShape";
582 *reinterpret_cast<MklShapeData*>(buf) = data_;
583 }
584
585 void DeSerializeMklDnnShape(const unsigned char* buf, size_t buf_size) {
586 // Make sure buffer holds at least is_mkl_tensor_.
587 CHECK(buf_size >= sizeof(data_.is_mkl_tensor_))
588 << "Buffer size is too small in DeSerializeMklDnnShape";
589
590 const bool is_mkl_tensor = *reinterpret_cast<const bool*>(buf);
591 if (is_mkl_tensor) { // If it is an MKL Tensor then read the rest
592 CHECK(buf_size >= GetSerializeBufferSize())
593 << "Buffer size is too small in DeSerializeMklDnnShape";
594 data_ = *reinterpret_cast<const MklShapeData*>(buf);
595 }
596 }
597};
598
599// List of MklShape objects. Used in Concat/Split layers.
600typedef std::vector<MklDnnShape> MklDnnShapeList;
601
602template <typename T>
603class MklDnnData;
604
605// TODO(intel-tf): Merge with the execute_primitives.
606inline void ExecutePrimitive(const std::vector<primitive>& net,
607 const std::vector<MemoryArgsMap>* net_args,
608 const engine& cpu_engine,
609 OpKernelContext* context = nullptr) {
610 DCHECK(net_args);
611 DCHECK_EQ(net.size(), net_args->size());
612 std::unique_ptr<stream> cpu_stream;
613 MklDnnThreadPool eigen_tp;
614 if (context != nullptr) {
615 eigen_tp = MklDnnThreadPool(context);
616 cpu_stream.reset(CreateStream(&eigen_tp, cpu_engine));
617 } else {
618 cpu_stream.reset(CreateStream(nullptr, cpu_engine));
619 }
620 for (size_t i = 0; i < net.size(); ++i) {
621 net.at(i).execute(*cpu_stream, net_args->at(i));
622 }
623 cpu_stream->wait();
624}
625template <typename T>
626inline Status ConvertMklToTF(OpKernelContext* context,
627 const Tensor& input_mkl_tensor,
628 const MklDnnShape& input_mkl_shape,
629 Tensor* output_tf_tensor) {
630 try {
631 if (!input_mkl_shape.IsMklTensor()) {
632 // Return input as is since it is already a TF tensor
633 *output_tf_tensor = input_mkl_tensor;
634 return Status::OK();
635 }
636
637 // Allocate output tensor.
638 TensorShape output_tf_shape = input_mkl_shape.GetTfShape();
639 TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<T>::v(), output_tf_shape,
640 output_tf_tensor));
641
642 engine cpu_engine(engine::kind::cpu, 0);
643 MklDnnData<T> input(&cpu_engine);
644
645 // Get MKL layout of input tensor.
646 auto input_mkl_md = input_mkl_shape.GetMklLayout();
647 auto output_tf_md = input_mkl_shape.GetTfLayout();
648 input.SetUsrMem(input_mkl_md, &input_mkl_tensor);
649
650 if (input.IsReorderNeeded(output_tf_md)) {
651 std::vector<primitive> net;
652 std::vector<MemoryArgsMap> net_args;
653 bool status = input.CheckReorderToOpMem(output_tf_md, output_tf_tensor,
654 net, net_args, cpu_engine);
655 if (!status) {
656 return Status(error::Code::INTERNAL,
657 "ConvertMklToTF(): Failed to create reorder for input");
658 }
659 ExecutePrimitive(net, &net_args, cpu_engine, context);
660 } else {
661 // If not, just forward input tensor to output tensor.
662 bool status =
663 output_tf_tensor->CopyFrom(input_mkl_tensor, output_tf_shape);
664 if (!status) {
665 return Status(
666 error::Code::INTERNAL,
667 "ConvertMklToTF(): Failed to forward input tensor to output");
668 }
669 }
670 return Status::OK();
671 } catch (dnnl::error& e) {
672 string error_msg = "Status: " + std::to_string(e.status) +
673 ", message: " + string(e.message) + ", in file " +
674 string(__FILE__) + ":" + std::to_string(__LINE__);
675 LOG(FATAL) << "Operation received an exception: " << error_msg;
676 }
677}
678
679// Get the MKL shape from the second string tensor
680inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape,
681 bool eager_mode) {
682 if (!eager_mode) {
683 mklshape->DeSerializeMklDnnShape(
684 ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
685 .flat<uint8>()
686 .data(),
687 ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
688 .flat<uint8>()
689 .size() *
690 sizeof(uint8));
691 } else {
692 mklshape->SetMklTensor(false);
693 }
694}
695
696inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape) {
697 GetMklShape(ctext, n, mklshape, false);
698}
699
700// Gets the actual input
701inline const Tensor& MklGetInput(OpKernelContext* ctext, int n) {
702 return ctext->input(GetTensorDataIndex(n, ctext->num_inputs()));
703}
704
705inline void GetMklInputList(OpKernelContext* ctext, StringPiece name,
706 OpInputList* input_tensors) {
707 CHECK_NOTNULL(input_tensors);
708 TF_CHECK_OK(ctext->input_list(name, input_tensors));
709}
710
711inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name,
712 MklDnnShapeList* mkl_shapes,
713 bool native_format = false) {
714 if (!native_format) {
715 OpInputList input_mkl_tensors;
716 GetMklInputList(ctext, strings::StrCat("mkl_", name), &input_mkl_tensors);
717
718 for (int i = 0; i < input_mkl_tensors.size(); i++) {
719 (*mkl_shapes)[i].DeSerializeMklDnnShape(
720 input_mkl_tensors[i].flat<uint8>().data(),
721 input_mkl_tensors[i].flat<uint8>().size() * sizeof(uint8));
722 }
723 } else {
724 for (int i = 0; i < mkl_shapes->size(); ++i) {
725 (*mkl_shapes)[i].SetMklTensor(false);
726 }
727 }
728}
729
730/// Get shape of input tensor pointed by 'input_idx' in TensorShape format.
731/// If the input tensor is in MKL layout, then obtains TensorShape from
732/// MklShape.
733inline TensorShape GetTfShape(OpKernelContext* context, size_t input_idx,
734 bool eager_mode = false) {
735 // Sanity check.
736 CHECK_NOTNULL(context);
737 CHECK_LT(input_idx, context->num_inputs());
738
739 MklDnnShape input_mkl_shape;
740 GetMklShape(context, input_idx, &input_mkl_shape, eager_mode);
741 if (input_mkl_shape.IsMklTensor() && !eager_mode) {
742 return input_mkl_shape.GetTfShape();
743 } else {
744 const Tensor& t = MklGetInput(context, input_idx);
745 return t.shape();
746 }
747}
748
749// Allocate the second output tensor that will contain
750// the MKL shape serialized
751inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
752 const MklDnnShape& mkl_shape) {
753 Tensor* second_tensor = nullptr;
754 TensorShape second_shape;
755 second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
756 OP_REQUIRES_OK(ctext, ctext->allocate_output(
757 GetTensorMetaDataIndex(n, ctext->num_outputs()),
758 second_shape, &second_tensor));
759 mkl_shape.SerializeMklDnnShape(
760 second_tensor->flat<uint8>().data(),
761 second_tensor->flat<uint8>().size() * sizeof(uint8));
762}
763
764// Allocate the output tensor, create a second output tensor that will contain
765// the MKL shape serialized
766inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
767 Tensor** output,
768 const TensorShape& tf_shape,
769 const MklDnnShape& mkl_shape,
770 bool eager_mode = false) {
771 OP_REQUIRES_OK(
772 ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()),
773 tf_shape, output));
774 if (!eager_mode) {
775 Tensor* second_tensor = nullptr;
776 TensorShape second_shape;
777 second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
778 OP_REQUIRES_OK(ctext, ctext->allocate_output(
779 GetTensorMetaDataIndex(n, ctext->num_outputs()),
780 second_shape, &second_tensor));
781 mkl_shape.SerializeMklDnnShape(
782 second_tensor->flat<uint8>().data(),
783 second_tensor->flat<uint8>().size() * sizeof(uint8));
784 }
785}
786
787// Allocates a temp tensor and returns the data buffer for temporary storage.
788template <typename T>
789inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
790 const memory::desc& pd, void** buf_out) {
791 TensorShape tf_shape;
792
793 tf_shape.AddDim(pd.get_size() / sizeof(T) + 1);
794 OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
795 tf_shape, tensor_out));
796 *buf_out = static_cast<void*>(tensor_out->flat<T>().data());
797}
798
799template <typename T>
800inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
801 TensorShape tf_shape) {
802 OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
803 tf_shape, tensor_out));
804}
805
806template <typename T>
807struct UserScratchPad {
808 template <typename MklPrim>
809 // NOTE: if scratchpad is not required for a particular primitive the
810 // spad_md.get_size() will return 0. It is fine to return
811 // nullptr in this case
812 inline void AllocateSPTensor(MklPrim* mkl_prim, OpKernelContext* context) {
813 allocated_ = false;
814 auto spad_md = mkl_prim->GetScratchPadDesc();
815 size_t spad_size = spad_md.get_size();
816 if (spad_size == 0) return;
817
818 size_t allocate_size = (spad_size + sizeof(T) - 1) / sizeof(T);
819 TensorShape tf_shape;
820 tf_shape.AddDim(allocate_size);
821 AllocTmpBuffer<T>(context, &scratch_pad_, tf_shape);
822 allocated_ = true;
823 }
824 inline void* Get() {
825 if (allocated_) {
826 return static_cast<void*>(scratch_pad_.flat<T>().data());
827 } else {
828 return nullptr;
829 }
830 }
831
832 private:
833 Tensor scratch_pad_;
834 bool allocated_ = false;
835};
836
837inline void GetStridesFromSizes(MklTensorFormat data_format, size_t* strides,
838 const size_t* sizes) {
839 DCHECK_NE(data_format, MklTensorFormat::FORMAT_INVALID);
840 // MKL requires strides in NCHW
841 if (data_format == MklTensorFormat::FORMAT_NHWC) {
842 strides[0] = sizes[2];
843 strides[1] = sizes[0] * sizes[2];
844 strides[2] = 1;
845 strides[3] = sizes[0] * sizes[1] * sizes[2];
846 } else {
847 strides[0] = 1;
848 strides[1] = sizes[0];
849 strides[2] = sizes[0] * sizes[1];
850 strides[3] = sizes[0] * sizes[1] * sizes[2];
851 }
852}
853
854inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in,
855 int idx_out) {
856 int num_inputs = context->num_inputs();
857 int num_outputs = context->num_outputs();
858 int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
859 int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
860 int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
861 int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
862
863 const Tensor& data = context->input(idx_data_in);
864 const Tensor& meta = context->input(idx_meta_in);
865 Tensor output(data.dtype());
866 Tensor meta_output(meta.dtype());
867
868 // TODO(intel-tf): alternatively, call forward_input_to_output_with_shape(...)
869 CHECK(output.CopyFrom(data, data.shape()));
870 CHECK(meta_output.CopyFrom(meta, meta.shape()));
871 context->set_output(idx_data_out, output);
872 context->set_output(idx_meta_out, meta_output);
873}
874
875inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, int idx_in,
876 int idx_out,
877 const TensorShape& shape) {
878 int num_inputs = context->num_inputs();
879 int num_outputs = context->num_outputs();
880 int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
881 int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
882
883 const Tensor& data = context->input(idx_data_in);
884 MklDnnShape mkl_shape_output;
885 mkl_shape_output.SetMklTensor(false);
886 AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
887 Tensor output(data.dtype());
888 // TODO(intel-tf): alternatively, call forward_input_to_output_with_shape(...)
889 CHECK(output.CopyFrom(data, shape));
890 context->set_output(idx_data_out, output);
891}
892
893inline void ForwardTfTensorInToOut(OpKernelContext* context, int idx_in,
894 int idx_out) {
895 int num_inputs = context->num_inputs();
896 int num_outputs = context->num_outputs();
897 int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
898 int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
899
900 MklDnnShape dnn_shape_output;
901 dnn_shape_output.SetMklTensor(false);
902 AllocateOutputSetMklShape(context, idx_out, dnn_shape_output);
903 if (IsRefType(context->input_dtype(idx_data_in))) {
904 context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
905 } else {
906 context->set_output(idx_data_out, context->input(idx_data_in));
907 }
908}
909
910inline void ForwardMklTensorInToOut(OpKernelContext* context, int idx_in,
911 int idx_out) {
912 int num_inputs = context->num_inputs();
913 int num_outputs = context->num_outputs();
914 int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
915 int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
916 int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
917 int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
918
919 if (IsRefType(context->input_dtype(idx_data_in))) {
920 context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
921 context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
922 } else {
923 context->set_output(idx_data_out, context->input(idx_data_in));
924 context->set_output(idx_meta_out, context->input(idx_meta_in));
925 }
926}
927
928// Set a dummy oneDNN shape (called when the output is in TF format)
929inline void SetDummyMklDnnShapeOutput(OpKernelContext* context,
930 uint32 idx_data_out) {
931 MklDnnShape mkl_shape_output;
932 mkl_shape_output.SetMklTensor(false);
933 AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
934}
935
936// If the input tensor has ref count as 1, it is forwarded to the desired
937// output port and the function returns true. In that case, it also allocates
938// the serialized MklDnnShape object. Otherwise, the function returns false.
939inline bool ForwardMklTensorInToOutWithMklShape(OpKernelContext* context,
940 int idx_in, int idx_out,
941 Tensor** output,
942 const MklDnnShape& mkl_shape,
943 bool always_forward = true) {
944 int num_inputs = context->num_inputs();
945 int num_outputs = context->num_outputs();
946 int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
947 int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
948 bool is_forwarded = false;
949 const Tensor& input_tensor = context->input(idx_data_in);
950 const auto output_shape = input_tensor.shape();
951 if (always_forward) {
952 if (IsRefType(context->input_dtype(idx_data_in))) {
953 context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
954 } else {
955 context->set_output(idx_data_out, input_tensor);
956 }
957 } else {
958 is_forwarded = context->forward_input_to_output_with_shape(
959 idx_data_in, idx_data_out, output_shape, output);
960 }
961 if (is_forwarded || always_forward) {
962 AllocateOutputSetMklShape(context, idx_out, mkl_shape);
963 return true;
964 }
965 return false;
966}
967
968// Forward the MKL shape ONLY (used in elementwise and other ops where
969// we call the eigen implementation and MKL shape is not used)
970inline void ForwardMklMetaDataInToOut(OpKernelContext* context,
971 uint32 idx_data_in,
972 uint32_t idx_data_out) {
973 uint32 idx_meta_in =
974 GetTensorMetaDataIndex(idx_data_in, context->num_inputs());
975 uint32 idx_meta_out =
976 GetTensorMetaDataIndex(idx_data_out, context->num_outputs());
977
978 if (IsRefType(context->input_dtype(idx_data_in))) {
979 context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
980 } else {
981 context->set_output(idx_meta_out, context->input(idx_meta_in));
982 }
983}
984
985// -------------------------------------------------------------------
986// Common utility functions used by MKL unit tests
987
988inline Tensor GetMklMetaTensor() {
989 MklDnnShape non_mkl_shape;
990 non_mkl_shape.SetMklTensor(false);
991
992 auto size = static_cast<int64_t>(non_mkl_shape.GetSerializeBufferSize());
993 Tensor tensor(DT_UINT8, {size});
994
995 non_mkl_shape.SerializeMklDnnShape(tensor.flat<uint8>().data(),
996 size * sizeof(uint8));
997 return tensor;
998}
999
1000// -------------------------------------------------------------------
1001
1002/// Return oneDNN data type (memory::data_type) for input type T
1003///
1004/// @input None
1005/// @return memory::data_type corresponding to type T
1006template <typename T>
1007static memory::data_type MklDnnType();
1008
1009/// Instantiation for float type. Add similar instantiations for other
1010/// type if needed.
1011template <>
1012memory::data_type MklDnnType<float>() {
1013 return memory::data_type::f32;
1014}
1015
1016template <>
1017memory::data_type MklDnnType<quint8>() {
1018 return memory::data_type::u8;
1019}
1020
1021template <>
1022memory::data_type MklDnnType<uint8>() {
1023 return memory::data_type::u8;
1024}
1025
1026template <>
1027memory::data_type MklDnnType<qint8>() {
1028 return memory::data_type::s8;
1029}
1030
1031template <>
1032memory::data_type MklDnnType<qint32>() {
1033 return memory::data_type::s32;
1034}
1035template <>
1036memory::data_type MklDnnType<bfloat16>() {
1037 return memory::data_type::bf16;
1038}
1039
1040// Map MklTensorFormat to oneDNN format tag
1041//
1042// @input: MklTensorFormat i.e. TensorFlow data format
1043// @return: oneDNN's memory format tag corresponding to MklTensorFormat.
1044// Fails with an error if invalid data format.
1045inline memory::format_tag MklTensorFormatToMklDnnDataFormat(
1046 MklTensorFormat format) {
1047 if (format == MklTensorFormat::FORMAT_NHWC) return memory::format_tag::nhwc;
1048 if (format == MklTensorFormat::FORMAT_NCHW) return memory::format_tag::nchw;
1049 if (format == MklTensorFormat::FORMAT_NDHWC) return memory::format_tag::ndhwc;
1050 if (format == MklTensorFormat::FORMAT_NCDHW) return memory::format_tag::ncdhw;
1051 if (format == MklTensorFormat::FORMAT_X) return memory::format_tag::x;
1052 if (format == MklTensorFormat::FORMAT_NC) return memory::format_tag::nc;
1053 if (format == MklTensorFormat::FORMAT_TNC) return memory::format_tag::tnc;
1054 return memory::format_tag::undef;
1055}
1056
1057/// Map TensorFlow data format into oneDNN 3D data format
1058/// @input: TensorFlow data format
1059/// @return: oneDNN 3D data format corresponding to TensorFlow data format;
1060/// Fails with an error if invalid data format.
1061inline MklTensorFormat TFDataFormatToMklDnn3DDataFormat(TensorFormat format) {
1062 if (format == FORMAT_NHWC) return MklTensorFormat::FORMAT_NDHWC;
1063 if (format == FORMAT_NCHW) return MklTensorFormat::FORMAT_NCDHW;
1064 TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
1065 return MklTensorFormat::FORMAT_INVALID;
1066}
1067
1068/// Map TensorFlow data format into oneDNN data format
1069///
1070/// @input: TensorFlow data format
1071/// @return: oneDNN data format corresponding to TensorFlow data format;
1072/// Fails with an error if invalid data format.
1073inline MklTensorFormat TFDataFormatToMklDnnDataFormat(TensorFormat format) {
1074 if (format == FORMAT_NHWC) return MklTensorFormat::FORMAT_NHWC;
1075 if (format == FORMAT_NCHW) return MklTensorFormat::FORMAT_NCHW;
1076 TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
1077 return MklTensorFormat::FORMAT_INVALID;
1078}
1079
1080/// Map oneDNN data format into TensorFlow data format
1081///
1082/// @input: oneDNN data format
1083/// @return: Tensorflow data format corresponding to oneDNN data format;
1084/// Fails with an error if invalid data format.
1085inline TensorFormat MklDnnDataFormatToTFDataFormat(MklTensorFormat format) {
1086 if (format == MklTensorFormat::FORMAT_NHWC ||
1087 format == MklTensorFormat::FORMAT_NDHWC)
1088 return FORMAT_NHWC;
1089 if (format == MklTensorFormat::FORMAT_NCHW ||
1090 format == MklTensorFormat::FORMAT_NCDHW)
1091 return FORMAT_NCHW;
1092 TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
1093
1094 // Return to prevent compiler warnings, otherwise TF_CHECK_OK will ensure
1095 // that we don't come here.
1096 return FORMAT_NHWC;
1097}
1098
1099/// Map TensorShape object into memory::dims required by oneDNN
1100///
1101/// This function will simply map input TensorShape into oneDNN dims
1102/// naively. So it will preserve the order of dimensions. E.g., if
1103/// input tensor is in NHWC format, then dims will be in NHWC format also.
1104///
1105/// @input TensorShape object in shape
1106/// @return memory::dims corresponding to TensorShape
1107inline memory::dims TFShapeToMklDnnDims(const TensorShape& shape) {
1108 memory::dims dims(shape.dims());
1109 for (int d = 0; d < shape.dims(); ++d) {
1110 dims[d] = shape.dim_size(d);
1111 }
1112 return dims;
1113}
1114
1115/// Map TensorShape object into memory::dims in NCHW format required by oneDNN
1116///
1117/// This function is a specific one than above function. It will map input
1118/// TensorShape into oneDNN dims in NCHW format. So it may not preserve the
1119/// order of dimensions. E.g., if input tensor is in NHWC format, then dims
1120/// will be in NCHW format, and not in NHWC format.
1121///
1122/// @input TensorShape object in shape
1123/// @return memory::dims in oneDNN required NCHW format
1124inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
1125 TensorFormat format) {
1126 // Check validity of format.
1127 DCHECK_NE(TFDataFormatToMklDnnDataFormat(format),
1128 MklTensorFormat::FORMAT_INVALID);
1129
1130 int n = shape.dim_size(GetTensorDimIndex(format, 'N'));
1131 int c = shape.dim_size(GetTensorDimIndex(format, 'C'));
1132 int h = shape.dim_size(GetTensorDimIndex(format, 'H'));
1133 int w = shape.dim_size(GetTensorDimIndex(format, 'W'));
1134
1135 // oneDNN requires dimensions in NCHW format.
1136 return memory::dims({n, c, h, w});
1137}
1138
1139inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape,
1140 TensorFormat format) {
1141 // Validate format.
1142 DCHECK_NE(TFDataFormatToMklDnn3DDataFormat(format),
1143 MklTensorFormat::FORMAT_INVALID);
1144
1145 int n = shape.dim_size(GetTensorDimIndex<3>(format, 'N'));
1146 int c = shape.dim_size(GetTensorDimIndex<3>(format, 'C'));
1147 int d = shape.dim_size(GetTensorDimIndex<3>(format, '0'));
1148 int h = shape.dim_size(GetTensorDimIndex<3>(format, '1'));
1149 int w = shape.dim_size(GetTensorDimIndex<3>(format, '2'));
1150
1151 // oneDNN requires dimensions in NCDHW format.
1152 return memory::dims({n, c, d, h, w});
1153}
1154
1155/// Overloaded version of function TFShapeToMklDnnDimsInNCHW above.
1156/// Input parameters are self-explanatory.
1157inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims,
1158 TensorFormat format) {
1159 // Validate format.
1160 DCHECK_NE(TFDataFormatToMklDnnDataFormat(format),
1161 MklTensorFormat::FORMAT_INVALID);
1162
1163 int n = in_dims[GetTensorDimIndex(format, 'N')];
1164 int c = in_dims[GetTensorDimIndex(format, 'C')];
1165 int h = in_dims[GetTensorDimIndex(format, 'H')];
1166 int w = in_dims[GetTensorDimIndex(format, 'W')];
1167
1168 // oneDNN requires dimensions in NCHW format.
1169 return memory::dims({n, c, h, w});
1170}
1171
1172/// Overloaded version of function TFShapeToMklDnnDimsInNCDHW above.
1173/// Input parameters are self-explanatory.
1174inline memory::dims MklDnnDimsInNCDHW(const memory::dims& in_dims,
1175 TensorFormat format) {
1176 // Validate format.
1177 DCHECK_NE(TFDataFormatToMklDnnDataFormat(format),
1178 MklTensorFormat::FORMAT_INVALID);
1179
1180 int n = in_dims[GetTensorDimIndex<3>(format, 'N')];
1181 int c = in_dims[GetTensorDimIndex<3>(format, 'C')];
1182 int d = in_dims[GetTensorDimIndex<3>(format, '0')];
1183 int h = in_dims[GetTensorDimIndex<3>(format, '1')];
1184 int w = in_dims[GetTensorDimIndex<3>(format, '2')];
1185
1186 // MKL DNN requires dimensions in NCDHW format.
1187 return memory::dims({n, c, d, h, w});
1188}
1189
1190/// Map MklDnn memory::dims object into TensorShape object.
1191///
1192/// This function will simply map input shape in oneDNN memory::dims format
1193/// in Tensorflow's TensorShape object by preserving dimension order.
1194///
1195/// @input oneDNN memory::dims object
1196/// @output TensorShape corresponding to memory::dims
1197inline TensorShape MklDnnDimsToTFShape(const memory::dims& dims) {
1198 std::vector<int32> shape(dims.size(), -1);
1199 for (int d = 0; d < dims.size(); d++) {
1200 shape[d] = dims[d];
1201 }
1202
1203 TensorShape ret;
1204 CHECK_EQ(TensorShapeUtils::MakeShape(shape, &ret).ok(), true);
1205 return ret;
1206}
1207
1208/// Function to calculate strides given tensor shape in Tensorflow order
1209/// E.g., if dims_tf_order is {1, 2, 3, 4}, then as per Tensorflow convention,
1210/// dimension with size 1 is outermost dimension; while dimension with size 4 is
1211/// innermost dimension. So strides for this tensor would be {4 * 3 * 2,
1212/// 4 * 3, 4, 1}, i.e., {24, 12, 4, 1}.
1213///
1214/// @input Tensorflow shape in memory::dims type
1215/// @return memory::dims containing strides for the tensor.
1216inline memory::dims CalculateTFStrides(const memory::dims& dims_tf_order) {
1217 CHECK_GT(dims_tf_order.size(), 0);
1218 memory::dims strides(dims_tf_order.size());
1219 int last_dim_idx = dims_tf_order.size() - 1;
1220 strides[last_dim_idx] = 1;
1221 for (int d = last_dim_idx - 1; d >= 0; d--) {
1222 strides[d] = strides[d + 1] * dims_tf_order[d + 1];
1223 }
1224 return strides;
1225}
1226
1227/// Helper function to create memory descriptor in Blocked format
1228///
1229/// @input: Tensor dimensions
1230/// @input: strides corresponding to dimensions. One can use utility
1231/// function such as CalculateTFStrides to compute strides
1232/// for given dimensions.
1233/// @output: dnnl_memory_desc_t object corresponding to blocked memory
1234/// format for given dimensions and strides.
1235/// @return: Status indicating whether the blocked memory descriptor
1236/// was successfully created.
1237inline Status CreateBlockedMemDescHelper(const memory::dims& dim,
1238 const memory::dims& strides,
1239 memory::data_type dtype,
1240 dnnl_memory_desc_t* blocked_md) {
1241 DCHECK_EQ(dim.size(), strides.size());
1242 const int kNumDims = dim.size();
1243 dnnl_dim_t* input_dims = new dnnl_dim_t[kNumDims];
1244 dnnl_dim_t* input_strides = new dnnl_dim_t[kNumDims];
1245 for (int i = 0; i < kNumDims; ++i) {
1246 input_dims[i] = dim[i];
1247 input_strides[i] = strides[i];
1248 }
1249 try {
1250 dnnl_memory_desc_init_by_strides(blocked_md, kNumDims, input_dims,
1251 memory::convert_to_c(dtype),
1252 input_strides);
1253 delete[] input_dims;
1254 delete[] input_strides;
1255 } catch (dnnl::error& e) {
1256 delete[] input_dims;
1257 delete[] input_strides;
1258 return Status(error::Code::INTERNAL,
1259 tensorflow::strings::StrCat(
1260 "Failed to create blocked memory descriptor.",
1261 "Status: ", e.status, ", message: ", e.message));
1262 }
1263 return Status::OK();
1264}
1265
1266inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc,
1267 const memory& src_mem,
1268 const memory& dst_mem, const engine& engine,
1269 OpKernelContext* ctx = nullptr) {
1270 std::vector<primitive> net;
1271 net.push_back(dnnl::reorder(reorder_desc));
1272 std::vector<MemoryArgsMap> net_args;
1273 net_args.push_back({{DNNL_ARG_FROM, src_mem}, {DNNL_ARG_TO, dst_mem}});
1274 ExecutePrimitive(net, &net_args, engine, ctx);
1275}
1276
1277class MklReorderPrimitive;
1278
1279template <typename T>
1280inline MklReorderPrimitive* FindOrCreateReorder(const memory* from,
1281 const memory* to);
1282
1283// Class to represent all the resources corresponding to a tensor in TensorFlow
1284// that are required to execute an operation (such as Convolution).
1285template <typename T>
1286class MklDnnData {
1287 private:
1288 /// oneDNN memory primitive for input user memory
1289 memory* user_memory_;
1290
1291 /// oneDNN memory primitive in case input or output reorder is needed.
1292 memory* reorder_memory_;
1293
1294 /// Operations memory descriptor
1295 memory::desc* op_md_;
1296 // flat to indicate if data is 3D or not.
1297 bool bIs3D;
1298 /// Operations temp buffer
1299 void* allocated_buffer_;
1300 /// CPU engine on which operation will be executed
1301 const engine* cpu_engine_;
1302
1303 public:
1304 explicit MklDnnData(const engine* e)
1305 : user_memory_(nullptr),
1306 reorder_memory_(nullptr),
1307 op_md_(nullptr),
1308 bIs3D(false),
1309 allocated_buffer_(nullptr),
1310 cpu_engine_(e) {}
1311
1312 // MklDnnData does not use any smart pointers,
1313 // hence default operator= will result in memory leak if user_memory was
1314 // already initialized. See
1315 // https://github.com/tensorflow/tensorflow/pull/45593 as an example of such
1316 // leak.
1317 MklDnnData(const MklDnnData&) = default;
1318 MklDnnData& operator=(const MklDnnData&) = delete;
1319
1320 ~MklDnnData() {
1321 if (allocated_buffer_ != nullptr) {
1322 cpu_allocator()->DeallocateRaw(allocated_buffer_);
1323 }
1324 cpu_engine_ = nullptr; // We don't own this.
1325 delete (user_memory_);
1326 delete (reorder_memory_);
1327 delete (op_md_);
1328 }
1329
1330 inline void* GetTensorBuffer(const Tensor* tensor) const {
1331 CHECK_NOTNULL(tensor);
1332 return const_cast<void*>(
1333 static_cast<const void*>(tensor->flat<T>().data()));
1334 }
1335
1336 void SetIs3DData(bool bIs3D_) { bIs3D = bIs3D_; }
1337 bool GetIs3D() { return bIs3D; }
1338
1339 /// Set user memory primitive using specified dimensions, memory format tag
1340 /// and data_buffer. Function automatically uses element data type by using
1341 /// input type T used for creating call object.
1342 ///
1343 /// In a nutshell, function allows user to describe the input tensor to
1344 /// an operation. E.g., filter of Conv2D is of shape {1, 2, 3, 4}, and
1345 /// memory format tag HWIO, and the buffer that contains actual values is
1346 /// pointed by data_buffer.
1347 inline void SetUsrMem(const memory::dims& dim, memory::format_tag fm,
1348 void* data_buffer = nullptr) {
1349 auto md = memory::desc(dim, MklDnnType<T>(), fm);
1350 SetUsrMem(md, data_buffer);
1351 }
1352
1353 inline void SetUsrMem(const memory::dims& dim, memory::format_tag fm,
1354 const Tensor* tensor) {
1355 DCHECK(tensor);
1356 SetUsrMem(dim, fm, GetTensorBuffer(tensor));
1357 }
1358
1359 /// Helper function to create memory descriptor in Blocked format
1360 ///
1361 /// @input: Tensor dimensions
1362 /// @input: strides corresponding to dimensions. One can use utility
1363 /// function such as CalculateTFStrides to compute strides
1364 /// for given dimensions.
1365 /// @return: memory::desc object corresponding to blocked memory format
1366 /// for given dimensions and strides.
1367 static inline memory::desc CreateBlockedMemDesc(const memory::dims& dim,
1368 const memory::dims& strides) {
1369 dnnl_memory_desc_t blocked_md;
1370 TF_CHECK_OK(
1371 CreateBlockedMemDescHelper(dim, strides, MklDnnType<T>(), &blocked_md));
1372 return memory::desc(blocked_md);
1373 }
1374
1375 /// A version of SetUsrMem call that allows user to create memory in blocked
1376 /// format. So in addition to accepting dimensions, it also accepts strides.
1377 /// This allows user to create memory for tensor in a format that is not
1378 /// supported by oneDNN. E.g., oneDNN does not support tensor format for 6
1379 /// dimensional tensor as a native format. But by using blocked format, a user
1380 /// can create memory for 6D tensor.
1381 inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
1382 void* data_buffer = nullptr) {
1383 CHECK_EQ(dim.size(), strides.size());
1384 auto blocked_md = MklDnnData<T>::CreateBlockedMemDesc(dim, strides);
1385 SetUsrMem(blocked_md, data_buffer);
1386 }
1387
1388 inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
1389 const Tensor* tensor) {
1390 CHECK_NOTNULL(tensor);
1391 SetUsrMem(dim, strides, GetTensorBuffer(tensor));
1392 }
1393
1394 /// A version of SetUsrMem with memory descriptor and tensor
1395 inline void SetUsrMem(const memory::desc& md, const Tensor* tensor) {
1396 CHECK_NOTNULL(tensor);
1397 SetUsrMem(md, GetTensorBuffer(tensor));
1398 }
1399
1400 /// A version of function to set user memory type that accepts memory
1401 /// descriptor directly, instead of accepting dimensions and format. This
1402 /// function is more generic than the one above, but the function above is
1403 /// sufficient in most cases.
1404 inline void SetUsrMem(const memory::desc& pd, void* data_buffer = nullptr) {
1405 DCHECK(cpu_engine_);
1406 if (user_memory_) delete user_memory_;
1407 // TODO(intel-tf): can we remove dynamic memory allocation?
1408 if (data_buffer) {
1409 user_memory_ = new memory(pd, *cpu_engine_, data_buffer);
1410 } else {
1411 user_memory_ = new memory(pd, *cpu_engine_);
1412 }
1413 }
1414
1415 /// Get function for user memory primitive.
1416 inline const memory* GetUsrMem() const { return user_memory_; }
1417
1418 /// Get function for descriptor of user memory.
1419 inline memory::desc GetUsrMemDesc() const {
1420 DCHECK(user_memory_);
1421 return user_memory_->get_desc();
1422 }
1423
1424 /// Get function for data buffer of user memory primitive.
1425 inline void* GetUsrMemDataHandle() const {
1426 CHECK_NOTNULL(user_memory_);
1427 return user_memory_->get_data_handle();
1428 }
1429
1430 /// Set function for data buffer of user memory primitive.
1431 inline void SetUsrMemDataHandle(void* data_buffer,
1432 std::shared_ptr<stream> t_stream = nullptr) {
1433 CHECK_NOTNULL(user_memory_);
1434 CHECK_NOTNULL(data_buffer);
1435#ifndef ENABLE_ONEDNN_OPENMP
1436 user_memory_->set_data_handle(data_buffer, *t_stream);
1437#else
1438 user_memory_->set_data_handle(data_buffer);
1439#endif // !ENABLE_ONEDNN_OPENMP
1440 }
1441
1442 /// Set function for data buffer of user memory primitive.
1443 inline void SetUsrMemDataHandle(const Tensor* tensor,
1444 std::shared_ptr<stream> t_stream = nullptr) {
1445 SetUsrMemDataHandle(GetTensorBuffer(tensor), t_stream);
1446 }
1447
1448 /// allocate function for data buffer
1449 inline void AllocateBuffer(size_t size) {
1450 const int64 kMemoryAlignment = 64; // For AVX512 memory alignment.
1451 allocated_buffer_ = cpu_allocator()->AllocateRaw(kMemoryAlignment, size);
1452 }
1453
1454 inline void* GetAllocatedBuffer() { return allocated_buffer_; }
1455
1456 /// Get the memory primitive for input and output of an op. If inputs
1457 /// to an op require reorders, then this function returns memory primitive
1458 /// for reorder. Otherwise, it will return memory primitive for user memory.
1459 ///
1460 /// E.g., Conv2D(I, F) is a primitive with I and F being inputs. Then to
1461 /// execute Conv2D, we need memory primitive for I and F. But if reorder is
1462 /// required for I and F (say I_r is reorder primitive for I; F_r is reorder
1463 /// primitive for F), then we need I_r and F_r to perform Conv2D.
1464 inline const memory& GetOpMem() const {
1465 return reorder_memory_ ? *reorder_memory_ : *user_memory_;
1466 }
1467
1468 /// Set memory descriptor of an operation in terms of dimensions and memory
1469 /// format. E.g., For Conv2D, the dimensions would be same as user dimensions
1470 /// but memory::format_tag would be dnnl::any because we want oneDNN to
1471 /// choose the best layout/format for given input dimensions.
1472 inline void SetOpMemDesc(const memory::dims& dim, memory::format_tag fm) {
1473 // TODO(intel-tf): can we remove dynamic memory allocation?
1474 op_md_ = new memory::desc(dim, MklDnnType<T>(), fm);
1475 }
1476
1477 /// Get function for memory descriptor for an operation
1478 inline const memory::desc& GetOpMemDesc() const { return *op_md_; }
1479
1480 /// Predicate that checks if we need to reorder user's memory into memory
1481 /// pointed by op_md.
1482 ///
1483 /// @input: op_md - memory descriptor of the given input of an operation.
1484 /// @return: true in case reorder of input is needed; false, otherwise.
1485 inline bool IsReorderNeeded(const memory::desc& op_pd) const {
1486 DCHECK(user_memory_);
1487 return op_pd != user_memory_->get_desc();
1488 }
1489
1490 /// Function to create a reorder from memory pointed by from to memory pointed
1491 /// by to. Returns created primitive.
1492 inline primitive CreateReorder(const memory* from, const memory* to) const {
1493 CHECK_NOTNULL(from);
1494 CHECK_NOTNULL(to);
1495 return reorder(*from, *to);
1496 }
1497
1498 /// Function to handle input reordering
1499 ///
1500 /// Check if we need to reorder this input of an operation.
1501 /// Return true and allocate reorder memory primitive if reorder is needed.
1502 /// Otherwise, return false and do not allocate reorder memory primitive.
1503 ///
1504 /// To check if reorder is needed, this function compares memory primitive
1505 /// descriptor (memory descriptor for v1.x) of an operation (op_pd) for
1506 /// the given input with the user-specified memory descriptor.
1507 ///
1508 /// @input: op_pd - memory primitive descriptor of the given input of an
1509 /// operation
1510 /// @input: net - net to which to add reorder primitive in case it is needed.
1511 /// @input: net_args - net to which user and reorder memories are added if
1512 /// needed. Each entry is a key-value pair of the form
1513 /// <argument-type, dnnl::memory>.
1514 /// @return: true in case reorder of input is needed; false, otherwise.
1515 inline bool CheckReorderToOpMem(const memory::desc& op_md,
1516 std::vector<primitive>& net,
1517 std::vector<MemoryArgsMap>& net_args,
1518 const engine& engine) {
1519 DCHECK(user_memory_);
1520 DCHECK_EQ(net.size(), net_args.size());
1521 if (IsReorderNeeded(op_md)) {
1522 // TODO(intel-tf): can we remove dynamic memory allocation?
1523 reorder_memory_ = new memory(op_md, engine);
1524 net.push_back(CreateReorder(user_memory_, reorder_memory_));
1525 net_args.push_back(MemoryArgsMap{{DNNL_ARG_FROM, *user_memory_},
1526 {DNNL_ARG_TO, *reorder_memory_}});
1527 return true;
1528 }
1529 return false;
1530 }
1531
1532 inline bool CheckReorderToOpMem(const memory::desc& op_md,
1533 const engine& engine,
1534 OpKernelContext* context = nullptr) {
1535 DCHECK(user_memory_);
1536 if (IsReorderNeeded(op_md)) {
1537 // TODO(intel-tf): can we remove dynamic memory allocation?
1538 // primitive reuse don't allow two same reorder prim in
1539 // one stream, so submit it immediately
1540 reorder_memory_ = new memory(op_md, engine);
1541 auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
1542 std::shared_ptr<stream> cpu_stream;
1543 MklDnnThreadPool eigen_tp;
1544 if (context != nullptr) {
1545 eigen_tp = MklDnnThreadPool(context);
1546 cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine()));
1547 } else {
1548 cpu_stream.reset(CreateStream(nullptr, prim->GetEngine()));
1549 }
1550 std::vector<primitive> net;
1551 net.push_back(*(prim->GetPrimitive()));
1552 std::vector<MemoryArgsMap> net_args;
1553 net_args.push_back(
1554 {{DNNL_ARG_FROM, *user_memory_}, {DNNL_ARG_TO, *reorder_memory_}});
1555 execute_primitives(net, cpu_stream, net_args);
1556 return true;
1557 }
1558 return false;
1559 }
1560
1561 /// Overloaded version of above function that accepts memory buffer
1562 /// where output of reorder needs to be stored.
1563 ///
1564 /// @input: op_pd - memory primitive descriptor (memory descriptor for v1.x)
1565 /// of the given input of an operation
1566 /// @reorder_data_handle - memory buffer where output of reorder needs to be
1567 /// stored. Primitive does not check if buffer has
1568 /// enough size to write.
1569 /// @input: net - net to which to add reorder primitive in case it is needed.
1570 /// @input: net_args - net to which user and reorder memories are added if
1571 /// needed. Each entry is a key-value pair of the form
1572 /// <argument-type, dnnl::memory>.
1573 /// @input: engine - oneDNN's abstraction of a computational device
1574 /// @return: true in case reorder of input is needed; false, otherwise.
1575 inline bool CheckReorderToOpMem(const memory::desc& op_md,
1576 void* reorder_data_handle,
1577 std::vector<primitive>& net,
1578 std::vector<MemoryArgsMap>& net_args,
1579 const engine& engine) {
1580 DCHECK(reorder_data_handle);
1581 DCHECK(user_memory_);
1582 if (IsReorderNeeded(op_md)) {
1583 // TODO(intel-tf): can we remove dynamic memory allocation?
1584 reorder_memory_ = new memory(op_md, engine, reorder_data_handle);
1585 net.push_back(CreateReorder(user_memory_, reorder_memory_));
1586 net_args.push_back(MemoryArgsMap{{DNNL_ARG_FROM, *user_memory_},
1587 {DNNL_ARG_TO, *reorder_memory_}});
1588 return true;
1589 }
1590 return false;
1591 }
1592
1593 /// This is a faster path with reorder primitive cache compared with
1594 /// CheckReorderToOpMem(..., std::vector<primitive>* net).
1595 /// The slower path will be removed in the future
1596 /// TODO(intel-tf): Need to use reorder cache here for better performance.
1597 inline bool CheckReorderToOpMem(const memory::desc& op_md,
1598 void* reorder_data_handle,
1599 const engine& engine,
1600 OpKernelContext* context = nullptr) {
1601 DCHECK(reorder_data_handle);
1602 DCHECK(user_memory_);
1603 if (IsReorderNeeded(op_md)) {
1604 // TODO(intel-tf): can we remove dynamic memory allocation?
1605 // primitive reuse don't allow two same reorder prim in
1606 // one stream, so submit it immediately
1607 reorder_memory_ = new memory(op_md, engine, reorder_data_handle);
1608 auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
1609 std::shared_ptr<stream> cpu_stream;
1610 MklDnnThreadPool eigen_tp;
1611 if (context != nullptr) {
1612 eigen_tp = MklDnnThreadPool(context);
1613 cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine()));
1614 } else {
1615 cpu_stream.reset(CreateStream(nullptr, prim->GetEngine()));
1616 }
1617 std::vector<primitive> net;
1618 net.push_back(*(prim->GetPrimitive()));
1619 std::vector<MemoryArgsMap> net_args;
1620 net_args.push_back(
1621 {{DNNL_ARG_FROM, *user_memory_}, {DNNL_ARG_TO, *reorder_memory_}});
1622 execute_primitives(net, cpu_stream, net_args);
1623 return true;
1624 }
1625 return false;
1626 }
1627
1628 /// Another overloaded version of CheckReorderToOpMem that accepts Tensor
1629 /// where output of reorder needs to be stored.
1630 ///
1631 /// @input: op_md - memory primitive descriptor (memory descriptor for v1.x)
1632 /// of the given input of an operation
1633 /// @reorder_tensor - Tensor whose buffer is to be used to store output of
1634 /// reorder. Primitive does not check if buffer is
1635 /// enough size to write.
1636 /// @input: net - net to which to add reorder primitive in case it is needed.
1637 /// @input: net_args - net to which user and reorder memories are added if
1638 /// needed. Each entry is a key-value pair of the form
1639 /// <argument-type, dnnl::memory>.
1640 /// @input: engine - MKL-DNN's abstraction of a computational device
1641 /// @return: true in case reorder of input is needed; false, otherwise.
1642 inline bool CheckReorderToOpMem(const memory::desc& op_md,
1643 Tensor* reorder_tensor,
1644 std::vector<primitive>& net,
1645 std::vector<MemoryArgsMap>& net_args,
1646 const engine& engine) {
1647 DCHECK(reorder_tensor);
1648 return CheckReorderToOpMem(op_md, GetTensorBuffer(reorder_tensor), net,
1649 net_args, engine);
1650 }
1651
1652 /// TODO(intel-tf): this is a faster path with reorder primitive cache
1653 /// compared with CheckReorderToOpMem(op_md, reorder_tensor, net, net_args,
1654 /// engine), will remove slow path in the future.
1655 inline bool CheckReorderToOpMem(const memory::desc& op_pd,
1656 Tensor* reorder_tensor,
1657 OpKernelContext* ctx = nullptr) {
1658 DCHECK(reorder_tensor);
1659 return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor),
1660 *cpu_engine_, ctx);
1661 }
1662
1663 /// Function to handle output reorder
1664 ///
1665 /// This function performs very similar functionality as input reordering
1666 /// function above. The only difference is that this function does not add
1667 /// reorder primitive to the net. The reason for this is: the reorder
1668 /// primitive for output needs to be added to the list only after operation
1669 /// has executed. But we need to prepare a temporary buffer in case output
1670 /// reorder is needed. And this temporary buffer will hold the output of
1671 /// an operation before it is fed to reorder primitive.
1672 ///
1673 /// @input - memory primitive descriptor (memory descriptor for v1.x) for the
1674 /// given output of an operation
1675 /// @return: true in case reorder of output is needed; false, otherwise.
1676 inline bool PrepareReorderToUserMemIfReq(const memory::desc& op_pd) {
1677 DCHECK(user_memory_);
1678 if (IsReorderNeeded(op_pd)) {
1679 // TODO(intel-tf): can we remove dynamic memory allocation?
1680 reorder_memory_ = new memory(op_pd, *cpu_engine_);
1681 return true;
1682 }
1683 return false;
1684 }
1685
1686 /// Function to actually insert reorder primitive in the net
1687 ///
1688 /// This function completes remaining part of output reordering. It inserts
1689 /// a reordering primitive from the temporary buffer that holds the output
1690 /// to the user-specified output buffer.
1691 ///
1692 /// @input: net - net to which to add reorder primitive
1693 /// @input: net_args - net to which user and reorder memories are added if
1694 /// needed. Each entry is a key-value pair of the form
1695 /// <argument-type, dnnl::memory>.
1696 inline void InsertReorderToUserMem(std::vector<primitive>& net,
1697 std::vector<MemoryArgsMap>& net_args) {
1698 DCHECK(user_memory_);
1699 DCHECK(reorder_memory_);
1700 net.push_back(CreateReorder(reorder_memory_, user_memory_));
1701 net_args.push_back(MemoryArgsMap{{DNNL_ARG_FROM, *reorder_memory_},
1702 {DNNL_ARG_TO, *user_memory_}});
1703 }
1704
1705 /// TODO(intel-tf): this is a faster path with reorder primitive cache
1706 /// compared with InsertReorderToUserMem(net, net_args), will remove
1707 /// slow path in the future
1708 inline void InsertReorderToUserMem(OpKernelContext* ctx = nullptr) {
1709 DCHECK(user_memory_);
1710 DCHECK(reorder_memory_);
1711 DCHECK(cpu_engine_);
1712 // primitive reuse don't allow two same reorder prim in
1713 // one stream, so submit it immediately
1714 std::vector<primitive> net;
1715 auto* prim = FindOrCreateReorder<T>(reorder_memory_, user_memory_);
1716 net.push_back(*(prim->GetPrimitive()));
1717 std::vector<MemoryArgsMap> net_args;
1718 net_args.push_back(
1719 {{DNNL_ARG_FROM, *reorder_memory_}, {DNNL_ARG_TO, *user_memory_}});
1720 std::shared_ptr<stream> cpu_stream;
1721 MklDnnThreadPool eigen_tp;
1722 if (ctx != nullptr) {
1723 eigen_tp = MklDnnThreadPool(ctx);
1724 cpu_stream.reset(CreateStream(&eigen_tp, prim->GetEngine()));
1725 } else {
1726 cpu_stream.reset(CreateStream(nullptr, prim->GetEngine()));
1727 }
1728 execute_primitives(net, cpu_stream, net_args);
1729 }
1730};
1731
1732/// Base class for operations with reuse of primitives
1733class MklPrimitive {
1734 public:
1735 virtual ~MklPrimitive() {}
1736 MklPrimitive() {}
1737 MklPrimitive(const engine& cpu_engine) { cpu_engine_ = cpu_engine; }
1738 // Dummy data which MKL DNN never operates on
1739 unsigned char* DummyData = nullptr;
1740 engine cpu_engine_ = engine(engine::kind::cpu, 0);
1741 const engine& GetEngine() { return cpu_engine_; }
1742};
1743
1744const dnnl::memory::dims NONE_DIMS = {};
1745
1746//
1747// LRUCache is a class which implements LRU (Least Recently Used) cache.
1748// The implementation is similar to that of
1749// tensorflow/core/platform/cloud/expiring_lru_cache.h
1750// without its thread-safe part because the cache is supposed to be
1751// used as thread local (for instance, MklPrimitive caching).
1752//
1753// The LRU list maintains objects in chronological order based on
1754// creation time, with the least recently accessed object at the
1755// tail of LRU list, while the most recently accessed object
1756// at the head of LRU list.
1757//
1758// This class is used to maintain an upper bound on the total number of
1759// cached items. When the cache reaches its capacity, the LRU item will
1760// be removed and replaced by a new one from SetOp call.
1761//
1762template <typename T>
1763class LRUCache {
1764 public:
1765 explicit LRUCache(size_t capacity) {
1766 capacity_ = capacity;
1767 Clear();
1768 }
1769
1770 T* GetOp(const string& key) {
1771#ifdef DNNL_AARCH64_USE_ACL
1772 mutex_lock lock(lru_mu_);
1773#endif
1774 auto it = cache_.find(key);
1775 if (it == cache_.end()) {
1776 return nullptr;
1777 }
1778
1779 // Move to the front of LRU list as the most recently accessed.
1780 lru_list_.erase(it->second.lru_iterator);
1781 lru_list_.push_front(it->first);
1782 it->second.lru_iterator = lru_list_.begin();
1783 return it->second.op;
1784 }
1785
1786 void SetOp(const string& key, T* op) {
1787#ifdef DNNL_AARCH64_USE_ACL
1788 mutex_lock lock(lru_mu_);
1789#endif
1790 if (lru_list_.size() >= capacity_) {
1791 Delete();
1792 }
1793
1794 // Insert an entry to the front of the LRU list
1795 lru_list_.push_front(key);
1796 Entry entry(op, lru_list_.begin());
1797 cache_.emplace(std::make_pair(key, std::move(entry)));
1798#ifdef DNNL_AARCH64_USE_ACL
1799 FinishedAllocation(key);
1800#endif
1801 }
1802
1803 void Clear() {
1804 if (lru_list_.empty()) return;
1805
1806 // Clean up the cache
1807 cache_.clear();
1808 lru_list_.clear();
1809 }
1810
1811#ifdef DNNL_AARCH64_USE_ACL
1812 bool IsAllocating(const string& key) {
1813 mutex_lock lock(in_flight_mu_);
1814 return in_flight_.find(key) != in_flight_.end();
1815 }
1816
1817 void Allocate(const string& key) {
1818 mutex_lock lock(in_flight_mu_);
1819 in_flight_.insert(key);
1820 }
1821
1822 void FinishedAllocation(const string& key) {
1823 mutex_lock lock(in_flight_mu_);
1824 in_flight_.erase(key);
1825 }
1826#endif
1827
1828 private:
1829 struct Entry {
1830 // The entry's value.
1831 T* op;
1832
1833 // A list iterator pointing to the entry's position in the LRU list.
1834 std::list<string>::iterator lru_iterator;
1835
1836 // Constructor
1837 Entry(T* op, std::list<string>::iterator it) {
1838 this->op = op;
1839 this->lru_iterator = it;
1840 }
1841
1842 // Move constructor
1843 Entry(Entry&& source) noexcept
1844 : lru_iterator(std::move(source.lru_iterator)) {
1845 op = std::move(source.op);
1846 source.op = std::forward<T*>(nullptr);
1847 }
1848
1849 // Destructor
1850 ~Entry() {
1851 if (op != nullptr) delete op;
1852 }
1853 };
1854
1855 // Remove the least recently accessed entry from LRU list, which
1856 // is the tail of lru_list_. Update cache_ correspondingly.
1857 bool Delete() {
1858 if (lru_list_.empty()) return false;
1859 string key = lru_list_.back();
1860 lru_list_.pop_back();
1861 cache_.erase(key);
1862 return true;
1863 }
1864
1865 // Cache capacity
1866 size_t capacity_;
1867
1868 // The cache, a map from string key to a LRU entry.
1869 std::unordered_map<string, Entry> cache_;
1870
1871 // The LRU list of entries.
1872 // The front of the list contains the key of the most recently accessed
1873 // entry, while the back of the list is the least recently accessed entry.
1874 std::list<string> lru_list_;
1875
1876#ifdef DNNL_AARCH64_USE_ACL
1877 // Guards access to the cache and LRU list
1878 mutex lru_mu_;
1879
1880 // The keys that are currently under creation
1881 std::set<string> in_flight_;
1882 TF_GUARDED_BY(in_flight_mu_)
1883 mutex in_flight_mu_;
1884#endif
1885};
1886
1887template <typename T>
1888class MklPrimitiveFactory {
1889 public:
1890 MklPrimitiveFactory() {}
1891
1892 ~MklPrimitiveFactory() {}
1893
1894 MklPrimitive* GetOp(const string& key) {
1895#ifndef DNNL_AARCH64_USE_ACL
1896 auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
1897 return lru_cache.GetOp(key);
1898#else
1899 while (true) {
1900 // TODO(milpuz01): Consider if it is possible to narrow scope to be
1901 // only around checks for allocations and conditional wait.
1902 mutex_lock lock(primitive_creation_mu_);
1903 auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
1904
1905 // Check to see whether primitive already exists.
1906 MklPrimitive* primitive = lru_cache.GetOp(key);
1907 if (primitive != nullptr) {
1908 return primitive;
1909 }
1910
1911 // Now check whether some other thread is creating this primitive.
1912 if (!lru_cache.IsAllocating(key)) {
1913 // This thread is going to pick it up and create the primitive.
1914 lru_cache.Allocate(key);
1915 return nullptr;
1916 // Now we release lock as primitive creation might take long time.
1917 }
1918
1919 // At this point we cannot create primitive as other thread is creating
1920 // it. We should wait for primitive to get created.
1921 primitive_creation_cv_.wait(lock);
1922
1923 // The primitive is created and is in the cache so we are going to try
1924 // retrieve it again after getting a lock on it as multiple threads might
1925 // be waiting for the primitive.
1926 }
1927#endif
1928 }
1929
1930 void SetOp(const string& key, MklPrimitive* op) {
1931#ifndef DNNL_AARCH64_USE_ACL
1932 auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
1933 lru_cache.SetOp(key, op);
1934#else
1935 {
1936 mutex_lock lock(primitive_creation_mu_);
1937 auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
1938 lru_cache.SetOp(key, op);
1939 }
1940
1941 // Now we can inform all waiting threads that primitive is created.
1942 primitive_creation_cv_.notify_all();
1943#endif
1944 }
1945
1946 /// Function to decide whether HW has AVX512 or AVX2
1947 /// For those legacy device(w/o AVX512 and AVX2),
1948 /// MKL-DNN GEMM will be used.
1949 static inline bool IsLegacyPlatform() {
1950 static const bool is_legacy_platform =
1951 (!port::TestCPUFeature(port::CPUFeature::AVX512F) &&
1952 !port::TestCPUFeature(port::CPUFeature::AVX2));
1953 return is_legacy_platform;
1954 }
1955
1956 /// Function to check whether primitive memory optimization is enabled
1957 static inline bool IsPrimitiveMemOptEnabled() {
1958 static const bool is_primitive_mem_opt_enabled = [] {
1959 bool value = true;
1960 TF_CHECK_OK(
1961 ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE", true, &value));
1962 return value;
1963 }();
1964 return is_primitive_mem_opt_enabled;
1965 }
1966
1967#ifdef DNNL_AARCH64_USE_ACL
1968 static int IncrementCounter() {
1969 static std::atomic_int counter{1};
1970 return counter.fetch_add(1);
1971 }
1972#endif
1973
1974 private:
1975 static inline LRUCache<MklPrimitive>& GetLRUCache() {
1976 static const int kCapacity = 1024; // cache capacity
1977#ifndef DNNL_AARCH64_USE_ACL
1978 static thread_local LRUCache<MklPrimitive> lru_cache_(kCapacity);
1979#else
1980 static LRUCache<MklPrimitive> lru_cache_(kCapacity);
1981 TF_GUARDED_BY(lru_mu_)
1982#endif
1983 return lru_cache_;
1984 }
1985
1986#ifdef DNNL_AARCH64_USE_ACL
1987 mutex primitive_creation_mu_;
1988 condition_variable primitive_creation_cv_;
1989#endif
1990};
1991
1992// utility class for creating keys of MKL primitive pool.
1993class FactoryKeyCreator {
1994 public:
1995 FactoryKeyCreator() { key_.reserve(kMaxKeyLength); }
1996
1997 ~FactoryKeyCreator() {}
1998
1999 void AddAsKey(const string& str) { Append(str); }
2000
2001 void AddAsKey(const dnnl::memory::dims& dims) {
2002 for (unsigned int i = 0; i < dims.size(); i++) {
2003 AddAsKey<int>(dims[i]);
2004 }
2005 }
2006
2007 template <typename T>
2008 void AddAsKey(const T data) {
2009 auto buffer = reinterpret_cast<const char*>(&data);
2010 Append(StringPiece(buffer, sizeof(T)));
2011 }
2012
2013 // generalisation to handle pointers
2014 void AddAsKey(const void* data) {
2015 auto buffer = reinterpret_cast<const char*>(&data);
2016 Append(StringPiece(buffer, sizeof(data)));
2017 }
2018
2019 string GetKey() { return key_; }
2020
2021 private:
2022 string key_;
2023 const char delimiter = 'x';
2024 const int kMaxKeyLength = 256;
2025 void Append(StringPiece s) {
2026 key_.append(string(s));
2027 key_.append(1, delimiter);
2028 }
2029};
2030
2031class MklReorderPrimitive : public MklPrimitive {
2032 public:
2033 explicit MklReorderPrimitive(const memory* from, const memory* to)
2034 : MklPrimitive(engine(engine::kind::cpu, 0)) {
2035 Setup(from, to);
2036 }
2037 ~MklReorderPrimitive() {}
2038
2039 std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; }
2040
2041 void SetMemory(const memory* from, const memory* to) {
2042 context_.src_mem->set_data_handle(from->get_data_handle());
2043 context_.dst_mem->set_data_handle(to->get_data_handle());
2044 }
2045
2046 std::shared_ptr<dnnl::stream> GetStream() { return stream_; }
2047
2048 private:
2049 struct ReorderContext {
2050 std::shared_ptr<dnnl::memory> src_mem;
2051 std::shared_ptr<dnnl::memory> dst_mem;
2052 std::shared_ptr<primitive> reorder_prim;
2053 ReorderContext()
2054 : src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
2055 } context_;
2056
2057 std::shared_ptr<dnnl::stream> stream_;
2058
2059 void Setup(const memory* from, const memory* to) {
2060 context_.src_mem.reset(
2061 new memory(from->get_desc(), cpu_engine_, DummyData));
2062 context_.dst_mem.reset(new memory(to->get_desc(), cpu_engine_, DummyData));
2063 context_.reorder_prim = std::make_shared<dnnl::reorder>(
2064 reorder(*context_.src_mem, *context_.dst_mem));
2065 stream_.reset(new stream(cpu_engine_));
2066 }
2067};
2068
2069template <typename T>
2070class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
2071 public:
2072 static MklReorderPrimitive* Get(const memory* from, const memory* to) {
2073 auto reorderPrim = static_cast<MklReorderPrimitive*>(
2074 MklReorderPrimitiveFactory<T>::GetInstance().GetReorder(from, to));
2075 if (reorderPrim == nullptr) {
2076 reorderPrim = new MklReorderPrimitive(from, to);
2077 MklReorderPrimitiveFactory<T>::GetInstance().SetReorder(from, to,
2078 reorderPrim);
2079 }
2080 reorderPrim->SetMemory(from, to);
2081 return reorderPrim;
2082 }
2083
2084 static MklReorderPrimitiveFactory& GetInstance() {
2085 static MklReorderPrimitiveFactory instance_;
2086 return instance_;
2087 }
2088
2089 static string CreateKey(const memory* from, const memory* to) {
2090 string prefix = "reorder";
2091 FactoryKeyCreator key_creator;
2092 auto const& from_desc = from->get_desc().data;
2093 auto const& to_desc = to->get_desc().data;
2094 memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
2095 memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);
2096 auto from_strides = from_desc.format_desc.blocking.strides;
2097
2098 // As DNNL memory desc has C style array and only init the used
2099 // part, so need use the valid part as key.
2100 auto from_inner_nblks = from_desc.format_desc.blocking.inner_nblks;
2101 auto from_inner_blks = from_desc.format_desc.blocking.inner_blks;
2102 auto from_inner_idxs = from_desc.format_desc.blocking.inner_idxs;
2103 memory::dims from_inner_blks_1(from_inner_blks,
2104 &from_inner_blks[from_inner_nblks]);
2105 memory::dims from_inner_idxs_1(from_inner_idxs,
2106 &from_inner_idxs[from_inner_nblks]);
2107 auto to_inner_nblks = to_desc.format_desc.blocking.inner_nblks;
2108 auto to_inner_blks = to_desc.format_desc.blocking.inner_blks;
2109 auto to_inner_idxs = to_desc.format_desc.blocking.inner_idxs;
2110 memory::dims to_inner_blks_1(to_inner_blks, &to_inner_blks[to_inner_nblks]);
2111 memory::dims to_inner_idxs_1(to_inner_idxs, &to_inner_idxs[to_inner_nblks]);
2112
2113 auto to_strides = to_desc.format_desc.blocking.strides;
2114 memory::dims from_strides_outer_blocks(from_strides,
2115 &from_strides[from_desc.ndims]);
2116 memory::dims to_strides_outer_blocks(to_strides,
2117 &to_strides[to_desc.ndims]);
2118
2119 key_creator.AddAsKey(prefix);
2120#ifdef DNNL_AARCH64_USE_ACL
2121 // The reorder primitives have local memory (calls to SetMemory) so we
2122 // need to make sure that memory for those primitives is cached per thread.
2123 key_creator.AddAsKey(std::this_thread::get_id());
2124#endif
2125 key_creator.AddAsKey(static_cast<int>(from_desc.extra.flags));
2126 key_creator.AddAsKey(static_cast<int>(from_inner_nblks));
2127 key_creator.AddAsKey(from_inner_blks_1);
2128 key_creator.AddAsKey(from_inner_idxs_1);
2129 key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
2130 key_creator.AddAsKey(from_dims);
2131 key_creator.AddAsKey(from_strides_outer_blocks);
2132 key_creator.AddAsKey(static_cast<int>(to_desc.extra.flags));
2133 key_creator.AddAsKey(static_cast<int>(to_inner_nblks));
2134 key_creator.AddAsKey(to_inner_blks_1);
2135 key_creator.AddAsKey(to_inner_idxs_1);
2136 key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
2137 key_creator.AddAsKey(to_dims);
2138 key_creator.AddAsKey(to_strides_outer_blocks);
2139 return key_creator.GetKey();
2140 }
2141
2142 private:
2143 MklReorderPrimitiveFactory() {}
2144 ~MklReorderPrimitiveFactory() {}
2145
2146 MklPrimitive* GetReorder(const memory* from, const memory* to) {
2147 string key = CreateKey(from, to);
2148 return this->GetOp(key);
2149 }
2150
2151 void SetReorder(const memory* from, const memory* to, MklPrimitive* op) {
2152 string key = CreateKey(from, to);
2153 this->SetOp(key, op);
2154 }
2155};
2156
2157/// Function to find(or create) a reorder from memory pointed by
2158/// from to memory pointed by to, it will created primitive or
2159/// get primitive from pool if it is cached.
2160/// Returns the primitive.
2161template <typename T>
2162inline MklReorderPrimitive* FindOrCreateReorder(const memory* from,
2163 const memory* to) {
2164 CHECK_NOTNULL(from);
2165 CHECK_NOTNULL(to);
2166 MklReorderPrimitive* reorder_prim =
2167 MklReorderPrimitiveFactory<T>::Get(from, to);
2168 return reorder_prim;
2169}
2170
2171// utility function to determine if it is conv 1x1 and stride != 1
2172// for purpose of temporarily disabling primitive reuse
2173inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
2174 memory::dims strides) {
2175 if (filter_dims.size() != 4 || strides.size() != 2) return false;
2176
2177 return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
2178 ((strides[0] != 1) || (strides[1] != 1)));
2179}
2180
2181} // namespace tensorflow
2182
2183/////////////////////////////////////////////////////////////////////
2184// Macros for handling registration for various types
2185/////////////////////////////////////////////////////////////////////
2186
2187#define REGISTER_TEST_FLOAT32(TEST) REGISTER_TEST(TEST, DT_FLOAT, Float32Input);
2188
2189#define REGISTER_TEST_BFLOAT16(TEST) \
2190 REGISTER_TEST(TEST, DT_BFLOAT16, BFloat16Input);
2191
2192#define REGISTER_TEST_ALL_TYPES(TEST) \
2193 REGISTER_TEST_FLOAT32(TEST); \
2194 REGISTER_TEST_BFLOAT16(TEST);
2195#else
2196#define REGISTER_TEST_ALL_TYPES(TEST) REGISTER_TEST_FLOAT32(TEST);
2197
2198#endif // INTEL_MKL
2199#endif // TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
2200