/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

#include <qnnpack.h>
#include <qnnpack/operator.h>
#include <qnnpack/log.h>
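
/*
 * Creates a Soft ArgMax operator for quantized 8-bit (q8) data in NC layout
 * and returns it in *softargmax_out. On failure, no operator is created and
 * an error status is returned.
 */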
enum qnnp_status qnnp_create_softargmax_nc_q8(
    size_t channels,
    float input_scale,
    uint8_t output_zero_point,
    float output_scale,
    uint32_t flags,
    qnnp_operator_t* softargmax_out)
{
  qnnp_operator_t softargmax_op = NULL;
  enum qnnp_status status = qnnp_status_uninitialized;

  if (!qnnp_params.initialized) {
    qnnp_log_error("qnnp_create_softargmax_nc_q8 failed because QNNPACK is not properly initialized");
    goto error;
  }

  status = qnnp_status_invalid_parameter;

  if (channels == 0) {
    qnnp_log_error(
      "failed to create Soft ArgMax operator with %zu channels: number of channels must be non-zero", channels);
    goto error;
  }

  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    qnnp_log_error(
      "failed to create Soft ArgMax operator with %.7g input scale: scale must be finite and positive", input_scale);
    goto error;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    qnnp_log_error(
      "failed to create Soft ArgMax operator with %.7g output scale: scale must be finite and positive", output_scale);
    goto error;
  }
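
  /* Only a fixed output quantization (scale 1/256, zero point 0) is supported. */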
  status = qnnp_status_unsupported_parameter;

  if (output_scale != 0x1.0p-8f) {
    qnnp_log_error(
      "failed to create Soft ArgMax operator with %.7g output scale: only output scale of 1/256 is supported",
      output_scale);
    goto error;
  }

  if (output_zero_point != 0) {
    qnnp_log_error(
      "failed to create Soft ArgMax operator with %" PRIu8 " output zero point: only output zero point of 0 is supported",
      output_zero_point);
    goto error;
  }

  status = qnnp_status_out_of_memory;

  softargmax_op = calloc(1, sizeof(struct qnnp_operator));
  if (softargmax_op == NULL) {
    qnnp_log_error("failed to allocate %zu bytes for qnnp_operator structure", sizeof(struct qnnp_operator));
    goto error;
  }

  softargmax_op->lookup_table = malloc(256 * sizeof(uint32_t));
  if (softargmax_op->lookup_table == NULL) {
    qnnp_log_error(
      "failed to allocate %zu bytes for Soft ArgMax lookup table", 256 * sizeof(uint32_t));
    goto error;
  }
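
  /*
   * Precompute a 256-entry table mapping each quantized input x to
   * lrint(qscale * exp((x - 255) * input_scale)). Shifting by 255 keeps the
   * exponent non-positive, so no entry exceeds qscale; qscale itself is
   * chosen so that a sum of up to channels entries fits in a uint32_t, and
   * is additionally capped at 2^23 - 1.
   */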
  uint32_t* lookup_table = softargmax_op->lookup_table;
  const double qscale = fmin(((double) UINT32_MAX) / (double) channels, 8388607.0);
  for (int32_t i = 0; i < 256; i++) {
    const double scaled_exp_xi = qscale * exp((double) (i - 255) * (double) input_scale);
    lookup_table[(uint32_t) i] = (uint32_t) lrint(scaled_exp_xi);
  }

  softargmax_op->channels = channels;

  softargmax_op->ukernel_type = qnnp_ukernel_type_softargmax;
  softargmax_op->format = qnnp_format_quint8;

  *softargmax_out = softargmax_op;
  return qnnp_status_success;

error:
  qnnp_delete_operator(softargmax_op);
  return status;
}
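
/*
 * Sets up a previously created Soft ArgMax operator with a batch size,
 * input/output pointers, and strides for a subsequent run.
 */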
enum qnnp_status qnnp_setup_softargmax_nc_q8(
    qnnp_operator_t softargmax,
    size_t batch_size,
    const uint8_t* input,
    size_t input_stride,
    uint8_t* output,
    size_t output_stride)
{
  if (!qnnp_params.initialized) {
    qnnp_log_error("qnnp_setup_softargmax_nc_q8 failed because QNNPACK is not properly initialized");
    return qnnp_status_uninitialized;
  }
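
  /* A batch size of zero is valid: record it and return, as there is nothing further to set up. */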
  if (batch_size == 0) {
    softargmax->batch_size = 0;
    return qnnp_status_success;
  }

  softargmax->batch_size = batch_size;
  softargmax->input = input;
  softargmax->input_pixel_stride = input_stride;
  softargmax->output = output;
  softargmax->output_pixel_stride = output_stride;

  return qnnp_status_success;
}