1 | /* |
2 | * Copyright (c) Facebook, Inc. and its affiliates. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under the BSD-style license found in the |
6 | * LICENSE file in the root directory of this source tree. |
7 | */ |
8 | |
9 | #include <assert.h> |
10 | #include <math.h> |
11 | #include <stddef.h> |
12 | #include <stdint.h> |
13 | #include <stdlib.h> |
14 | |
15 | #include <qnnpack.h> |
16 | #include <qnnpack/operator.h> |
17 | #include <qnnpack/log.h> |
18 | |
19 | |
20 | enum qnnp_status qnnp_create_softargmax_nc_q8( |
21 | size_t channels, |
22 | float input_scale, |
23 | uint8_t output_zero_point, |
24 | float output_scale, |
25 | uint32_t flags, |
26 | qnnp_operator_t* softargmax_out) |
27 | { |
28 | qnnp_operator_t softargmax_op = NULL; |
29 | enum qnnp_status status = qnnp_status_uninitialized; |
30 | |
31 | if (!qnnp_params.initialized) { |
32 | qnnp_log_error("qnnp_create_softargmax_nc_q8 failed because QNNPACK is not properly initialized" ); |
33 | goto error; |
34 | } |
35 | |
36 | status = qnnp_status_invalid_parameter; |
37 | |
38 | if (channels == 0) { |
39 | qnnp_log_error( |
40 | "failed to create Soft ArgMax operator with %zu channels: number of channels must be non-zero" , channels); |
41 | goto error; |
42 | } |
43 | |
44 | if (input_scale <= 0.0f || !isnormal(input_scale)) { |
45 | qnnp_log_error( |
46 | "failed to create Soft ArgMax operator with %.7g input scale: scale must be finite and positive" , input_scale); |
47 | goto error; |
48 | } |
49 | |
50 | if (output_scale <= 0.0f || !isnormal(output_scale)) { |
51 | qnnp_log_error( |
52 | "failed to create Soft ArgMax operator with %.7g output scale: scale must be finite and positive" , output_scale); |
53 | goto error; |
54 | } |
55 | |
56 | status = qnnp_status_unsupported_parameter; |
57 | |
58 | if (output_scale != 0x1.0p-8f) { |
59 | qnnp_log_error( |
60 | "failed to create Soft ArgMax operator with %.7g output scale: only output scale of 1/256 is supported" , |
61 | output_scale); |
62 | goto error; |
63 | } |
64 | |
65 | if (output_zero_point != 0) { |
66 | qnnp_log_error( |
67 | "failed to create Soft ArgMax operator with %" PRIu8 " output zero point: only output zero point of 0 is supported" , |
68 | output_zero_point); |
69 | goto error; |
70 | } |
71 | |
72 | status = qnnp_status_out_of_memory; |
73 | |
74 | softargmax_op = calloc(1, sizeof(struct qnnp_operator)); |
75 | if (softargmax_op == NULL) { |
76 | qnnp_log_error("failed to allocate %zu bytes for qnnp_operator structure" , sizeof(struct qnnp_operator)); |
77 | goto error; |
78 | } |
79 | |
80 | softargmax_op->lookup_table = malloc(256 * sizeof(uint32_t)); |
81 | if (softargmax_op->lookup_table == NULL) { |
82 | qnnp_log_error("failed to allocate 256 bytes for Soft ArgMax lookup table" ); |
83 | goto error; |
84 | } |
85 | |
86 | uint32_t* lookup_table = softargmax_op->lookup_table; |
87 | const double qscale = fmin(((double) UINT32_MAX) / (double) channels, 8388607.0); |
88 | for (int32_t i = 0; i < 256; i++) { |
89 | const double scaled_exp_xi = qscale * exp((double) (i - 255) * (double) input_scale); |
90 | lookup_table[(uint32_t) i] = (uint32_t) lrint(scaled_exp_xi); |
91 | } |
92 | |
93 | softargmax_op->channels = channels; |
94 | |
95 | softargmax_op->ukernel_type = qnnp_ukernel_type_softargmax; |
96 | softargmax_op->format = qnnp_format_quint8; |
97 | |
98 | *softargmax_out = softargmax_op; |
99 | return qnnp_status_success; |
100 | |
101 | error: |
102 | qnnp_delete_operator(softargmax_op); |
103 | return status; |
104 | } |
105 | |
106 | enum qnnp_status qnnp_setup_softargmax_nc_q8( |
107 | qnnp_operator_t softargmax, |
108 | size_t batch_size, |
109 | const uint8_t* input, |
110 | size_t input_stride, |
111 | uint8_t* output, |
112 | size_t output_stride) |
113 | { |
114 | if (!qnnp_params.initialized) { |
115 | qnnp_log_error("qnnp_setup_softargmax_nc_q8 failed because QNNPACK is not properly initialized" ); |
116 | return qnnp_status_uninitialized; |
117 | } |
118 | |
119 | if (batch_size == 0) { |
120 | softargmax->batch_size = 0; |
121 | return qnnp_status_success; |
122 | } |
123 | |
124 | softargmax->batch_size = batch_size; |
125 | softargmax->input = input; |
126 | softargmax->input_pixel_stride = input_stride; |
127 | softargmax->output = output; |
128 | softargmax->output_pixel_stride = output_stride; |
129 | |
130 | return qnnp_status_success; |
131 | } |
132 | |