1 | #include <stdint.h> |
2 | #include <stddef.h> |
3 | #include <assert.h> |
4 | |
5 | #include <nnpack.h> |
6 | #include <nnpack/macros.h> |
7 | #include <nnpack/utils.h> |
8 | |
9 | #include <nnpack/hwinfo.h> |
10 | #include <nnpack/activations.h> |
11 | #include <nnpack/validation.h> |
12 | |
13 | |
14 | struct NNP_CACHE_ALIGN relu_context { |
15 | nnp_relu_function relu_function; |
16 | const float* input; |
17 | float* output; |
18 | float negative_slope; |
19 | }; |
20 | |
21 | static void compute_relu_output( |
22 | const struct relu_context context[restrict static 1], |
23 | size_t block_start, size_t block_size) |
24 | { |
25 | nnp_relu_function relu = context->relu_function; |
26 | const float* input = context->input; |
27 | float* output = context->output; |
28 | float negative_slope = context->negative_slope; |
29 | |
30 | relu(input + block_start, output + block_start, block_size, negative_slope); |
31 | } |
32 | |
33 | struct NNP_CACHE_ALIGN inplace_relu_context { |
34 | nnp_inplace_relu_function relu_function; |
35 | float* data; |
36 | float negative_slope; |
37 | }; |
38 | |
39 | static void compute_inplace_relu_output( |
40 | const struct inplace_relu_context context[restrict static 1], |
41 | size_t block_start, size_t block_size) |
42 | { |
43 | nnp_inplace_relu_function relu = context->relu_function; |
44 | float* data = context->data; |
45 | float negative_slope = context->negative_slope; |
46 | |
47 | relu(data + block_start, block_size, negative_slope); |
48 | } |
49 | |
50 | enum nnp_status nnp_relu_output( |
51 | size_t batch_size, |
52 | size_t channels, |
53 | const float input[], |
54 | float output[], |
55 | float negative_slope, |
56 | pthreadpool_t threadpool) |
57 | { |
58 | enum nnp_status status = validate_relu_arguments(batch_size, channels); |
59 | if (status != nnp_status_success) { |
60 | return status; |
61 | } |
62 | |
63 | size_t elements = batch_size * channels; |
64 | const size_t simd_width = nnp_hwinfo.simd_width; |
65 | |
66 | assert(((uintptr_t) input) % sizeof(float) == 0); |
67 | assert(((uintptr_t) output) % sizeof(float) == 0); |
68 | |
69 | const size_t prologue_elements = min((size_t) (-(((uintptr_t) output) / sizeof(float)) % simd_width), elements); |
70 | for (size_t i = 0; i < prologue_elements; i++) { |
71 | output[i] = relu(input[i], negative_slope); |
72 | } |
73 | elements -= prologue_elements; |
74 | input += prologue_elements; |
75 | output += prologue_elements; |
76 | |
77 | const size_t epilogue_elements = elements % simd_width; |
78 | for (size_t i = 0; i < epilogue_elements; i++) { |
79 | output[elements - epilogue_elements + i] = |
80 | relu(input[elements - epilogue_elements + i], negative_slope); |
81 | } |
82 | elements -= epilogue_elements; |
83 | |
84 | if (input != output) { |
85 | /* Out-of-place transformation */ |
86 | struct relu_context relu_context = { |
87 | .relu_function = nnp_hwinfo.activations.relu, |
88 | .input = input, |
89 | .output = output, |
90 | .negative_slope = negative_slope, |
91 | }; |
92 | pthreadpool_parallelize_1d_tile_1d(threadpool, |
93 | (pthreadpool_function_1d_tiled_t) compute_relu_output, |
94 | &relu_context, |
95 | elements, round_down(nnp_hwinfo.blocking.l1 / sizeof(float), simd_width), |
96 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
97 | } else { |
98 | /* In-place transformation */ |
99 | struct inplace_relu_context inplace_relu_context = { |
100 | .relu_function = nnp_hwinfo.activations.inplace_relu, |
101 | .data = output, |
102 | .negative_slope = negative_slope, |
103 | }; |
104 | pthreadpool_parallelize_1d_tile_1d(threadpool, |
105 | (pthreadpool_function_1d_tiled_t) compute_inplace_relu_output, |
106 | &inplace_relu_context, |
107 | elements, round_down(nnp_hwinfo.blocking.l1 / sizeof(float), simd_width), |
108 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
109 | } |
110 | |
111 | return nnp_status_success; |
112 | } |
113 | |