1 | #include <stddef.h> |
2 | |
3 | #include <nnpack.h> |
4 | #include <nnpack/macros.h> |
5 | #include <nnpack/utils.h> |
6 | |
7 | #include <nnpack/hwinfo.h> |
8 | #include <nnpack/softmax.h> |
9 | #include <nnpack/validation.h> |
10 | |
11 | |
/*
 * Arguments shared by all invocations of compute_softmax_output
 * (the out-of-place path of nnp_softmax_output).
 *
 * NOTE(review): NNP_CACHE_ALIGN comes from a project header not shown here;
 * presumably it aligns the structure to a cache-line boundary -- confirm.
 */
struct NNP_CACHE_ALIGN softmax_context {
	/* Kernel called as softmax_function(channels, input_row, output_row). */
	nnp_softmax_function softmax_function;
	/* Number of floats in one sample; also the stride between consecutive samples. */
	size_t channels;
	/* Start of the input buffer, indexed as [sample][channel]; only read. */
	const float* input;
	/* Start of the output buffer, indexed the same way as the input. */
	float* output;
};
18 | |
19 | static void compute_softmax_output( |
20 | const struct softmax_context context[restrict static 1], |
21 | size_t sample) |
22 | { |
23 | const nnp_softmax_function softmax = context->softmax_function; |
24 | const size_t channels = context->channels; |
25 | |
26 | const float (*input)[channels] = (const float(*)[channels]) context->input; |
27 | float (*output)[channels] = (float(*)[channels]) context->output; |
28 | |
29 | softmax(channels, input[sample], output[sample]); |
30 | } |
31 | |
/*
 * Arguments shared by all invocations of compute_inplace_softmax_output
 * (the in-place path of nnp_softmax_output, taken when input == output).
 *
 * NOTE(review): NNP_CACHE_ALIGN comes from a project header not shown here;
 * presumably it aligns the structure to a cache-line boundary -- confirm.
 */
struct NNP_CACHE_ALIGN inplace_softmax_context {
	/* Kernel called as softmax_function(channels, row). */
	nnp_inplace_softmax_function softmax_function;
	/* Number of floats in one sample; also the stride between consecutive samples. */
	size_t channels;
	/* Start of the buffer, indexed as [sample][channel]; passed non-const to the kernel. */
	float* data;
};
37 | |
38 | static void compute_inplace_softmax_output( |
39 | const struct inplace_softmax_context context[restrict static 1], |
40 | size_t sample) |
41 | { |
42 | const nnp_inplace_softmax_function softmax = context->softmax_function; |
43 | const size_t channels = context->channels; |
44 | |
45 | float (*data)[channels] = (float(*)[channels]) context->data; |
46 | |
47 | softmax(channels, data[sample]); |
48 | } |
49 | |
50 | enum nnp_status nnp_softmax_output( |
51 | size_t batch_size, |
52 | size_t channels, |
53 | const float* input, |
54 | float* output, |
55 | pthreadpool_t threadpool) |
56 | { |
57 | enum nnp_status status = validate_softmax_arguments(batch_size, channels); |
58 | if (status != nnp_status_success) { |
59 | return status; |
60 | } |
61 | |
62 | if (input != output) { |
63 | /* Out-of-place softmax */ |
64 | struct softmax_context softmax_context = { |
65 | .softmax_function = nnp_hwinfo.activations.softmax, |
66 | .channels = channels, |
67 | .input = input, |
68 | .output = output, |
69 | }; |
70 | pthreadpool_parallelize_1d(threadpool, |
71 | (pthreadpool_function_1d_t) compute_softmax_output, |
72 | &softmax_context, |
73 | batch_size, |
74 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
75 | } else { |
76 | /* In-place softmax */ |
77 | struct inplace_softmax_context inplace_softmax_context = { |
78 | .softmax_function = nnp_hwinfo.activations.inplace_softmax, |
79 | .channels = channels, |
80 | .data = output, |
81 | }; |
82 | pthreadpool_parallelize_1d(threadpool, |
83 | (pthreadpool_function_1d_t) compute_inplace_softmax_output, |
84 | &inplace_softmax_context, |
85 | batch_size, |
86 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
87 | } |
88 | |
89 | return nnp_status_success; |
90 | } |
91 | |