#include <stdint.h>
#include <stddef.h>
#include <assert.h>

#include <nnpack.h>
#include <nnpack/macros.h>
#include <nnpack/utils.h>

#include <nnpack/hwinfo.h>
#include <nnpack/activations.h>
#include <nnpack/validation.h>

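/*
 * Worker context for the out-of-place ReLU kernel. A pointer to this
 * structure is passed to compute_relu_output(), which pthreadpool invokes
 * once per tile with a (start, size) range into the flattened array.
 */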
struct NNP_CACHE_ALIGN relu_context {
	nnp_relu_function relu_function;
	const float* input;
	float* output;
	float negative_slope;
};

static void compute_relu_output(
	const struct relu_context context[restrict static 1],
	size_t block_start, size_t block_size)
{
	nnp_relu_function relu = context->relu_function;
	const float* input = context->input;
	float* output = context->output;
	float negative_slope = context->negative_slope;

	relu(input + block_start, output + block_start, block_size, negative_slope);
}

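/*
 * Worker context for the in-place ReLU kernel: the in-place counterpart of
 * relu_context, carrying a single buffer that is both read and overwritten
 * by compute_inplace_relu_output().
 */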
struct NNP_CACHE_ALIGN inplace_relu_context {
	nnp_inplace_relu_function relu_function;
	float* data;
	float negative_slope;
};

static void compute_inplace_relu_output(
	const struct inplace_relu_context context[restrict static 1],
	size_t block_start, size_t block_size)
{
	nnp_inplace_relu_function relu = context->relu_function;
	float* data = context->data;
	float negative_slope = context->negative_slope;

	relu(data + block_start, block_size, negative_slope);
}

enum nnp_status nnp_relu_output(
	size_t batch_size,
	size_t channels,
	const float input[],
	float output[],
	float negative_slope,
	pthreadpool_t threadpool)
{
	enum nnp_status status = validate_relu_arguments(batch_size, channels);
	if (status != nnp_status_success) {
		return status;
	}

	size_t elements = batch_size * channels;
	const size_t simd_width = nnp_hwinfo.simd_width;

	assert(((uintptr_t) input) % sizeof(float) == 0);
	assert(((uintptr_t) output) % sizeof(float) == 0);

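	/*
	 * Scalar prologue: process just enough leading elements to bring the
	 * output pointer to a simd_width-element boundary before the vectorized
	 * kernel runs. The expression -(((uintptr_t) output) / sizeof(float)) %
	 * simd_width counts, in unsigned arithmetic, how many floats remain
	 * until the next multiple of simd_width; min() clamps it for arrays
	 * shorter than one SIMD register. For example, with simd_width == 8 and
	 * an output pointer 16 bytes past a 32-byte boundary, 4 scalar
	 * iterations restore alignment.
	 */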
	const size_t prologue_elements = min((size_t) (-(((uintptr_t) output) / sizeof(float)) % simd_width), elements);
	for (size_t i = 0; i < prologue_elements; i++) {
		output[i] = relu(input[i], negative_slope);
	}
	elements -= prologue_elements;
	input += prologue_elements;
	output += prologue_elements;

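	/*
	 * Scalar epilogue: handle the trailing elements that do not fill a whole
	 * SIMD register, leaving a multiple of simd_width for the parallel path.
	 */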
	const size_t epilogue_elements = elements % simd_width;
	for (size_t i = 0; i < epilogue_elements; i++) {
		output[elements - epilogue_elements + i] =
			relu(input[elements - epilogue_elements + i], negative_slope);
	}
	elements -= epilogue_elements;

	if (input != output) {
		/* Out-of-place transformation */
		struct relu_context relu_context = {
			.relu_function = nnp_hwinfo.activations.relu,
			.input = input,
			.output = output,
			.negative_slope = negative_slope,
		};
		pthreadpool_parallelize_1d_tile_1d(threadpool,
			(pthreadpool_task_1d_tile_1d_t) compute_relu_output,
			&relu_context,
			elements, round_down(nnp_hwinfo.blocking.l1 / sizeof(float), simd_width),
			PTHREADPOOL_FLAG_DISABLE_DENORMALS);
	} else {
		/* In-place transformation */
		struct inplace_relu_context inplace_relu_context = {
			.relu_function = nnp_hwinfo.activations.inplace_relu,
			.data = output,
			.negative_slope = negative_slope,
		};
		pthreadpool_parallelize_1d_tile_1d(threadpool,
			(pthreadpool_task_1d_tile_1d_t) compute_inplace_relu_output,
			&inplace_relu_context,
			elements, round_down(nnp_hwinfo.blocking.l1 / sizeof(float), simd_width),
			PTHREADPOOL_FLAG_DISABLE_DENORMALS);
	}

	return nnp_status_success;
}
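
/*
 * Illustrative usage (not part of this translation unit): a minimal sketch of
 * how a caller might drive nnp_relu_output(), assuming NNPACK was built with
 * pthreadpool support and that input and output point to caller-allocated
 * arrays of batch_size * channels floats.
 *
 *   nnp_initialize();
 *   pthreadpool_t threadpool = pthreadpool_create(0);  // 0 = one thread per core
 *   nnp_relu_output(batch_size, channels, input, output, 0.0f, threadpool);
 *   pthreadpool_destroy(threadpool);
 *   nnp_deinitialize();
 */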