1#include <stddef.h>
2
3#include <nnpack.h>
4#include <nnpack/macros.h>
5#include <nnpack/utils.h>
6
7#include <nnpack/hwinfo.h>
8#include <nnpack/softmax.h>
9#include <nnpack/validation.h>
10
11
12struct NNP_CACHE_ALIGN softmax_context {
13 nnp_softmax_function softmax_function;
14 size_t channels;
15 const float* input;
16 float* output;
17};
18
19static void compute_softmax_output(
20 const struct softmax_context context[restrict static 1],
21 size_t sample)
22{
23 const nnp_softmax_function softmax = context->softmax_function;
24 const size_t channels = context->channels;
25
26 const float (*input)[channels] = (const float(*)[channels]) context->input;
27 float (*output)[channels] = (float(*)[channels]) context->output;
28
29 softmax(channels, input[sample], output[sample]);
30}
31
32struct NNP_CACHE_ALIGN inplace_softmax_context {
33 nnp_inplace_softmax_function softmax_function;
34 size_t channels;
35 float* data;
36};
37
38static void compute_inplace_softmax_output(
39 const struct inplace_softmax_context context[restrict static 1],
40 size_t sample)
41{
42 const nnp_inplace_softmax_function softmax = context->softmax_function;
43 const size_t channels = context->channels;
44
45 float (*data)[channels] = (float(*)[channels]) context->data;
46
47 softmax(channels, data[sample]);
48}
49
50enum nnp_status nnp_softmax_output(
51 size_t batch_size,
52 size_t channels,
53 const float* input,
54 float* output,
55 pthreadpool_t threadpool)
56{
57 enum nnp_status status = validate_softmax_arguments(batch_size, channels);
58 if (status != nnp_status_success) {
59 return status;
60 }
61
62 if (input != output) {
63 /* Out-of-place softmax */
64 struct softmax_context softmax_context = {
65 .softmax_function = nnp_hwinfo.activations.softmax,
66 .channels = channels,
67 .input = input,
68 .output = output,
69 };
70 pthreadpool_parallelize_1d(threadpool,
71 (pthreadpool_function_1d_t) compute_softmax_output,
72 &softmax_context,
73 batch_size,
74 PTHREADPOOL_FLAG_DISABLE_DENORMALS);
75 } else {
76 /* In-place softmax */
77 struct inplace_softmax_context inplace_softmax_context = {
78 .softmax_function = nnp_hwinfo.activations.inplace_softmax,
79 .channels = channels,
80 .data = output,
81 };
82 pthreadpool_parallelize_1d(threadpool,
83 (pthreadpool_function_1d_t) compute_inplace_softmax_output,
84 &inplace_softmax_context,
85 batch_size,
86 PTHREADPOOL_FLAG_DISABLE_DENORMALS);
87 }
88
89 return nnp_status_success;
90}
91