1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/serialization.hpp"
18#include "common/type_helpers.hpp"
19#include "common/utils.hpp"
20
21namespace dnnl {
22namespace impl {
23namespace serialization {
24
25status_t serialize_desc(
26 serialization_stream_t &sstream, const op_desc_t *op_desc) {
27#define CASE(pkind) \
28 case primitive_kind::pkind: \
29 serialize_desc(sstream, *(const pkind##_desc_t *)op_desc); \
30 break;
31
32 switch ((int)op_desc->kind) {
33 CASE(batch_normalization)
34 CASE(binary)
35 CASE(concat)
36 CASE(convolution)
37 CASE(deconvolution)
38 CASE(eltwise)
39 CASE(inner_product)
40 CASE(gemm)
41 CASE(layer_normalization)
42 CASE(lrn)
43 CASE(matmul)
44 CASE(pooling)
45 CASE(prelu)
46 CASE(reduction)
47 CASE(reorder)
48 CASE(resampling)
49 CASE(rnn)
50 CASE(shuffle)
51 CASE(softmax)
52 CASE(sum)
53 default: return status::invalid_arguments;
54 }
55#undef CASE
56 return status::success;
57}
58
59void serialize_md(serialization_stream_t &sstream, const memory_desc_t &md) {
60 sstream.write(&md.ndims);
61 sstream.write(md.dims, md.ndims);
62 sstream.write(&md.data_type);
63 sstream.write(md.padded_dims, md.ndims);
64 sstream.write(md.padded_offsets, md.ndims);
65 sstream.write(&md.offset0);
66 sstream.write(&md.format_kind);
67 // format desc
68 switch ((int)md.format_kind) {
69 case format_kind::undef:
70 case format_kind::any: break;
71 case format_kind::blocked:
72 sstream.write(md.format_desc.blocking.strides, md.ndims);
73 sstream.write(&md.format_desc.blocking.inner_nblks);
74 sstream.write(md.format_desc.blocking.inner_blks,
75 md.format_desc.blocking.inner_nblks);
76 sstream.write(md.format_desc.blocking.inner_idxs,
77 md.format_desc.blocking.inner_nblks);
78 break;
79 case format_kind::wino:
80 sstream.write(&md.format_desc.wino_desc.wino_format);
81 sstream.write(&md.format_desc.wino_desc.r);
82 sstream.write(&md.format_desc.wino_desc.alpha);
83 sstream.write(&md.format_desc.wino_desc.ic);
84 sstream.write(&md.format_desc.wino_desc.oc);
85 sstream.write(&md.format_desc.wino_desc.ic_block);
86 sstream.write(&md.format_desc.wino_desc.oc_block);
87 sstream.write(&md.format_desc.wino_desc.ic2_block);
88 sstream.write(&md.format_desc.wino_desc.oc2_block);
89 sstream.write(&md.format_desc.wino_desc.adj_scale);
90 sstream.write(&md.format_desc.wino_desc.size);
91 break;
92 case format_kind::rnn_packed:
93 sstream.write(&md.format_desc.rnn_packed_desc.format);
94 sstream.write(&md.format_desc.rnn_packed_desc.n_parts);
95 sstream.write(&md.format_desc.rnn_packed_desc.n);
96 sstream.write(&md.format_desc.rnn_packed_desc.ldb);
97 {
98 int n_parts = md.format_desc.rnn_packed_desc.n_parts;
99 sstream.write(md.format_desc.rnn_packed_desc.parts, n_parts);
100 sstream.write(
101 md.format_desc.rnn_packed_desc.part_pack_size, n_parts);
102 sstream.write(
103 md.format_desc.rnn_packed_desc.pack_part, n_parts);
104 }
105 sstream.write(&md.format_desc.rnn_packed_desc.offset_compensation);
106 sstream.write(&md.format_desc.rnn_packed_desc.size);
107 break;
108 default: assert(!"unknown format_kind");
109 }
110
111 if (md.extra.flags != dnnl_memory_extra_flag_none) {
112 sstream.write(&md.extra.flags);
113 if ((md.extra.flags
114 & (dnnl_memory_extra_flag_compensation_conv_s8s8
115 | dnnl_memory_extra_flag_rnn_u8s8_compensation))
116 && !types::extra_flag_rnn_s8s8_compensation_is_set(
117 md.extra.flags)) {
118 sstream.write(&md.extra.compensation_mask);
119 }
120
121 if (md.extra.flags & dnnl_memory_extra_flag_scale_adjust) {
122 sstream.write(&md.extra.scale_adjust);
123 }
124
125 if (md.extra.flags
126 & dnnl_memory_extra_flag_compensation_conv_asymmetric_src) {
127 sstream.write(&md.extra.asymm_compensation_mask);
128 }
129 }
130}
131
132void serialize_attr(
133 serialization_stream_t &sstream, const primitive_attr_t &attr) {
134 // scratchpad_mode
135 sstream.write(&attr.scratchpad_mode_);
136 // fpmath_mode
137 sstream.write(&attr.fpmath_mode_);
138
139 if (!attr.output_scales_.has_default_values()) {
140 // output_scales: mask
141 sstream.write(&attr.output_scales_.mask_);
142 } else if (!attr.scales_.has_default_values()) {
143 // go through scales for all arguments
144 for (const auto &p : attr.scales_.scales_) {
145 sstream.write(&p.first);
146 sstream.write(&p.second.mask_);
147 }
148 }
149 // zero_points
150 for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST})
151 if (!attr.zero_points_.has_default_values(arg)) {
152 // zero_points: arg
153 sstream.write(&arg);
154 int mask = 0;
155 attr.zero_points_.get(arg, &mask);
156 // zero_points: mask
157 sstream.write(&mask);
158 }
159 // post_ops: entry[:]
160 for (int i = 0; i < attr.post_ops_.len(); i++) {
161 const auto &entry = attr.post_ops_.entry_[i];
162 switch (entry.kind) {
163 case primitive_kind::eltwise:
164 sstream.write(&entry.eltwise.alg);
165 sstream.write(&entry.eltwise.scale);
166 sstream.write(&entry.eltwise.alpha);
167 sstream.write(&entry.eltwise.beta);
168 break;
169 case primitive_kind::sum:
170 sstream.write(&entry.sum.scale);
171 sstream.write(&entry.sum.zero_point);
172 sstream.write(&entry.sum.dt);
173 break;
174 case primitive_kind::convolution:
175 sstream.write(&entry.depthwise_conv.kernel);
176 sstream.write(&entry.depthwise_conv.stride);
177 sstream.write(&entry.depthwise_conv.padding);
178 sstream.write(&entry.depthwise_conv.wei_dt);
179 sstream.write(&entry.depthwise_conv.bias_dt);
180 sstream.write(&entry.depthwise_conv.dst_dt);
181 break;
182 case primitive_kind::binary:
183 sstream.write(&entry.binary.alg);
184 serialize_md(sstream, entry.binary.user_src1_desc);
185 break;
186 case primitive_kind::prelu: sstream.write(&entry.prelu.mask); break;
187 default: assert(!"unknown post_op");
188 }
189 }
190 // rnn_data_qparams: scale, shift
191 sstream.write(&attr.rnn_data_qparams_.scale_);
192 sstream.write(&attr.rnn_data_qparams_.shift_);
193 if (!attr.rnn_weights_qparams_.has_default_values()) {
194 // rnn_weights_qparams: mask
195 sstream.write(&attr.rnn_weights_qparams_.mask_);
196 // rnn_weights_qparams: count
197 sstream.write(&attr.rnn_weights_qparams_.count_);
198 // rnn_weights_qparams: scales[:]
199 sstream.write(attr.rnn_weights_qparams_.scales_,
200 attr.rnn_weights_qparams_.count_);
201 }
202 if (attr.gpu_attr_) {
203 attr.gpu_attr_->serialize(sstream);
204 } else {
205 int zero = 0;
206 sstream.write(&zero);
207 }
208}
209
210void serialize_desc(
211 serialization_stream_t &sstream, const concat_desc_t &desc) {
212 // Kinds
213 sstream.write(&desc.primitive_kind);
214 // Memory descriptors
215 serialize_md(sstream, *desc.dst_md);
216 // N
217 sstream.write(&desc.n);
218 // Concat dimension
219 sstream.write(&desc.concat_dimension);
220 // Array of mds
221 for (int i = 0; i < desc.n; i++)
222 serialize_md(sstream, *desc.src_mds[i]);
223}
224
225void serialize_desc(serialization_stream_t &sstream,
226 const batch_normalization_desc_t &desc) {
227 // Kinds
228 sstream.write(&desc.primitive_kind);
229 sstream.write(&desc.prop_kind);
230 // Memory descriptors
231 serialize_md(sstream, desc.src_desc);
232 serialize_md(sstream, desc.dst_desc);
233 serialize_md(sstream, desc.diff_src_desc);
234 serialize_md(sstream, desc.diff_dst_desc);
235 serialize_md(sstream, desc.scaleshift_desc);
236 serialize_md(sstream, desc.diff_scaleshift_desc);
237 serialize_md(sstream, desc.stat_desc);
238 // Epsilon
239 sstream.write(&desc.batch_norm_epsilon);
240 // Flags
241 sstream.write(&desc.flags);
242}
243
244void serialize_desc(
245 serialization_stream_t &sstream, const binary_desc_t &desc) {
246 // Kinds
247 sstream.write(&desc.primitive_kind);
248 sstream.write(&desc.alg_kind);
249 // Memory descriptors
250 serialize_md(sstream, desc.src_desc[0]);
251 serialize_md(sstream, desc.src_desc[1]);
252 serialize_md(sstream, desc.dst_desc);
253}
254
255// (De-)Convolution
256void serialize_desc(
257 serialization_stream_t &sstream, const convolution_desc_t &desc) {
258 // Kinds
259 sstream.write(&desc.primitive_kind);
260 sstream.write(&desc.prop_kind);
261 sstream.write(&desc.alg_kind);
262 // Memory descriptors
263 serialize_md(sstream, desc.src_desc);
264 serialize_md(sstream, desc.diff_src_desc);
265 serialize_md(sstream, desc.weights_desc);
266 serialize_md(sstream, desc.diff_weights_desc);
267 serialize_md(sstream, desc.bias_desc);
268 serialize_md(sstream, desc.diff_bias_desc);
269 serialize_md(sstream, desc.dst_desc);
270 serialize_md(sstream, desc.diff_dst_desc);
271 // Strides, dilates, padding
272 sstream.write(desc.strides, DNNL_MAX_NDIMS);
273 sstream.write(desc.dilates, DNNL_MAX_NDIMS);
274 sstream.write(desc.padding[0], DNNL_MAX_NDIMS);
275 sstream.write(desc.padding[1], DNNL_MAX_NDIMS);
276 // Accumulator type
277 sstream.write(&desc.accum_data_type);
278}
279
280// Eltwise
281void serialize_desc(
282 serialization_stream_t &sstream, const eltwise_desc_t &desc) {
283 // Kinds
284 sstream.write(&desc.primitive_kind);
285 sstream.write(&desc.prop_kind);
286 sstream.write(&desc.alg_kind);
287 // Memory descriptors
288 serialize_md(sstream, desc.src_desc);
289 serialize_md(sstream, desc.dst_desc);
290 serialize_md(sstream, desc.diff_src_desc);
291 serialize_md(sstream, desc.diff_dst_desc);
292 // Alpha, beta
293 sstream.write(&desc.alpha);
294 sstream.write(&desc.beta);
295}
296
297void serialize_desc(serialization_stream_t &sstream, const gemm_desc_t &desc) {
298 // Kind
299 sstream.write(&desc.primitive_kind);
300 serialize_md(sstream, desc.a_desc);
301 serialize_md(sstream, desc.b_desc);
302 serialize_md(sstream, desc.c_desc);
303 serialize_md(sstream, desc.bias_desc);
304 // Accumulator type
305 sstream.write(&desc.acc_type);
306 sstream.write(&desc.sum_ab);
307 sstream.write(&desc.sum_ab_type);
308}
309
310void serialize_desc(
311 serialization_stream_t &sstream, const inner_product_desc_t &desc) {
312 // Kinds
313 sstream.write(&desc.primitive_kind);
314 sstream.write(&desc.prop_kind);
315 // Memory descriptors
316 serialize_md(sstream, desc.src_desc);
317 serialize_md(sstream, desc.diff_src_desc);
318 serialize_md(sstream, desc.weights_desc);
319 serialize_md(sstream, desc.diff_weights_desc);
320 serialize_md(sstream, desc.bias_desc);
321 serialize_md(sstream, desc.diff_bias_desc);
322 serialize_md(sstream, desc.dst_desc);
323 serialize_md(sstream, desc.diff_dst_desc);
324 // Accumulator type
325 sstream.write(&desc.accum_data_type);
326}
327
328void serialize_desc(serialization_stream_t &sstream,
329 const layer_normalization_desc_t &desc) {
330 // Kinds
331 sstream.write(&desc.primitive_kind);
332 sstream.write(&desc.prop_kind);
333 // Memory descriptors
334 serialize_md(sstream, desc.src_desc);
335 serialize_md(sstream, desc.diff_src_desc);
336 serialize_md(sstream, desc.data_scaleshift_desc);
337 serialize_md(sstream, desc.diff_data_scaleshift_desc);
338 serialize_md(sstream, desc.dst_desc);
339 serialize_md(sstream, desc.diff_dst_desc);
340 serialize_md(sstream, desc.stat_desc);
341 // Epsilon
342 sstream.write(&desc.layer_norm_epsilon);
343 // Flags
344 sstream.write(&desc.flags);
345}
346
347void serialize_desc(serialization_stream_t &sstream, const lrn_desc_t &desc) {
348 // Kinds
349 sstream.write(&desc.primitive_kind);
350 sstream.write(&desc.prop_kind);
351 sstream.write(&desc.alg_kind);
352 // Memory descriptors
353 serialize_md(sstream, desc.src_desc);
354 serialize_md(sstream, desc.dst_desc);
355 serialize_md(sstream, desc.diff_src_desc);
356 serialize_md(sstream, desc.diff_dst_desc);
357 // Local size
358 sstream.write(&desc.local_size);
359 // Alpha, beta
360 sstream.write(&desc.lrn_alpha);
361 sstream.write(&desc.lrn_beta);
362 // k
363 sstream.write(&desc.lrn_k);
364}
365
366void serialize_desc(
367 serialization_stream_t &sstream, const matmul_desc_t &desc) {
368 // Kinds
369 sstream.write(&desc.primitive_kind);
370 // Memory descriptors
371 serialize_md(sstream, desc.src_desc);
372 serialize_md(sstream, desc.weights_desc);
373 serialize_md(sstream, desc.bias_desc);
374 serialize_md(sstream, desc.dst_desc);
375 // Accumulator type
376 sstream.write(&desc.accum_data_type);
377}
378
379void serialize_desc(
380 serialization_stream_t &sstream, const pooling_desc_t &desc) {
381 // Kinds
382 sstream.write(&desc.primitive_kind);
383 sstream.write(&desc.prop_kind);
384 sstream.write(&desc.alg_kind);
385 // Memory descriptors
386 serialize_md(sstream, desc.src_desc);
387 serialize_md(sstream, desc.diff_src_desc);
388 serialize_md(sstream, desc.dst_desc);
389 serialize_md(sstream, desc.diff_dst_desc);
390 // Strides, dilates, padding
391 sstream.write(desc.strides, DNNL_MAX_NDIMS);
392 sstream.write(desc.kernel, DNNL_MAX_NDIMS);
393 sstream.write(desc.padding[0], DNNL_MAX_NDIMS);
394 sstream.write(desc.padding[1], DNNL_MAX_NDIMS);
395 sstream.write(desc.dilation, DNNL_MAX_NDIMS);
396 // Accumulator type
397 sstream.write(&desc.accum_data_type);
398}
399
400void serialize_desc(serialization_stream_t &sstream, const prelu_desc_t &desc) {
401 // Kinds
402 sstream.write(&desc.primitive_kind);
403 sstream.write(&desc.prop_kind);
404 // Memory descriptors
405 serialize_md(sstream, desc.src_desc);
406 serialize_md(sstream, desc.weights_desc);
407 serialize_md(sstream, desc.dst_desc);
408 serialize_md(sstream, desc.diff_src_desc);
409 serialize_md(sstream, desc.diff_weights_desc);
410 serialize_md(sstream, desc.diff_dst_desc);
411}
412
413void serialize_desc(
414 serialization_stream_t &sstream, const reduction_desc_t &desc) {
415 // Kinds
416 sstream.write(&desc.primitive_kind);
417 sstream.write(&desc.alg_kind);
418 // Memory descriptors
419 serialize_md(sstream, desc.src_desc);
420 serialize_md(sstream, desc.dst_desc);
421 // P, eps
422 sstream.write(&desc.p);
423 sstream.write(&desc.eps);
424}
425
426void serialize_desc(
427 serialization_stream_t &sstream, const reorder_desc_t &desc) {
428 // Kinds
429 sstream.write(&desc.primitive_kind);
430 // Memory descriptors
431 serialize_md(sstream, *desc.src_md);
432 serialize_md(sstream, *desc.dst_md);
433 // Kinds of source and destination engines
434 sstream.write(&desc.src_engine_kind);
435 sstream.write(&desc.dst_engine_kind);
436 sstream.write(&desc.is_cross_engine);
437}
438
439void serialize_desc(
440 serialization_stream_t &sstream, const resampling_desc_t &desc) {
441 // Kinds
442 sstream.write(&desc.primitive_kind);
443 sstream.write(&desc.alg_kind);
444 // Memory descriptors
445 serialize_md(sstream, desc.src_desc);
446 serialize_md(sstream, desc.diff_src_desc);
447 serialize_md(sstream, desc.dst_desc);
448 serialize_md(sstream, desc.diff_dst_desc);
449 // Factors
450 sstream.write(desc.factors, DNNL_MAX_NDIMS);
451}
452
453void serialize_desc(serialization_stream_t &sstream, const rnn_desc_t &desc) {
454 // Kinds
455 sstream.write(&desc.primitive_kind);
456 sstream.write(&desc.prop_kind);
457 sstream.write(&desc.cell_kind);
458 sstream.write(&desc.direction);
459 // Memory descriptors
460 serialize_md(sstream, desc.src_layer_desc);
461 serialize_md(sstream, desc.src_iter_desc);
462 serialize_md(sstream, desc.src_iter_c_desc);
463 serialize_md(sstream, desc.weights_layer_desc);
464 serialize_md(sstream, desc.weights_iter_desc);
465 serialize_md(sstream, desc.bias_desc);
466 serialize_md(sstream, desc.dst_layer_desc);
467 serialize_md(sstream, desc.dst_iter_desc);
468 serialize_md(sstream, desc.dst_iter_c_desc);
469 serialize_md(sstream, desc.weights_peephole_desc);
470 serialize_md(sstream, desc.weights_projection_desc);
471 serialize_md(sstream, desc.diff_src_layer_desc);
472 serialize_md(sstream, desc.diff_src_iter_desc);
473 serialize_md(sstream, desc.diff_src_iter_c_desc);
474 serialize_md(sstream, desc.diff_weights_layer_desc);
475 serialize_md(sstream, desc.diff_weights_iter_desc);
476 serialize_md(sstream, desc.diff_bias_desc);
477 serialize_md(sstream, desc.diff_dst_layer_desc);
478 serialize_md(sstream, desc.diff_dst_iter_desc);
479 serialize_md(sstream, desc.diff_dst_iter_c_desc);
480 serialize_md(sstream, desc.diff_weights_peephole_desc);
481 serialize_md(sstream, desc.diff_weights_projection_desc);
482 // Flags
483 sstream.write(&desc.flags);
484 // Activation kind
485 sstream.write(&desc.activation_kind);
486 // Alpha, beta
487 sstream.write(&desc.alpha);
488 sstream.write(&desc.beta);
489}
490
491// Shuffle
492void serialize_desc(
493 serialization_stream_t &sstream, const shuffle_desc_t &desc) {
494 // Kinds
495 sstream.write(&desc.primitive_kind);
496 sstream.write(&desc.prop_kind);
497 // Memory descriptors
498 serialize_md(sstream, desc.src_desc);
499 serialize_md(sstream, desc.dst_desc);
500 // Axis
501 sstream.write(&desc.axis);
502 // Groupe size
503 sstream.write(&desc.group_size);
504}
505
506void serialize_desc(
507 serialization_stream_t &sstream, const softmax_desc_t &desc) {
508 // Kinds
509 sstream.write(&desc.primitive_kind);
510 sstream.write(&desc.prop_kind);
511 sstream.write(&desc.alg_kind);
512 // Memory descriptors
513 serialize_md(sstream, desc.src_desc);
514 serialize_md(sstream, desc.diff_src_desc);
515 serialize_md(sstream, desc.dst_desc);
516 serialize_md(sstream, desc.diff_dst_desc);
517 // Axis
518 sstream.write(&desc.softmax_axis);
519}
520
521void serialize_desc(serialization_stream_t &sstream, const sum_desc_t &desc) {
522 // Kinds
523 sstream.write(&desc.primitive_kind);
524 // Memory descriptors
525 serialize_md(sstream, *desc.dst_md);
526 // N
527 sstream.write(&desc.n);
528 // Scales
529 sstream.write(desc.scales, desc.n);
530 // Array of mds
531 for (int i = 0; i < desc.n; i++)
532 serialize_md(sstream, *desc.src_mds[i]);
533}
534
535} // namespace serialization
536} // namespace impl
537} // namespace dnnl
538