1 | #include <stdint.h> |
2 | #include <stddef.h> |
3 | #include <assert.h> |
4 | |
5 | #include <nnpack.h> |
6 | #include <nnpack/macros.h> |
7 | #include <nnpack/utils.h> |
8 | |
9 | #include <nnpack/hwinfo.h> |
10 | #include <nnpack/activations.h> |
11 | #include <nnpack/validation.h> |
12 | |
13 | |
/*
 * Arguments captured once in nnp_relu_input_gradient and shared (read-only)
 * by every worker thread running compute_grad_relu.
 */
struct NNP_CACHE_ALIGN relu_context {
	/* Gradient-of-ReLU kernel, taken from nnp_hwinfo.activations.grad_relu. */
	nnp_grad_relu_function grad_relu_function;
	/* Gradient with respect to the ReLU output (read). */
	const float* grad_output;
	/* Forward-pass input to the ReLU (read). */
	const float* input;
	/* Gradient with respect to the ReLU input (written). */
	float* grad_input;
	/* Slope applied on the negative side (leaky ReLU); passed to the kernel. */
	float negative_slope;
};
21 | |
22 | static void compute_grad_relu( |
23 | const struct relu_context context[restrict static 1], |
24 | size_t block_start, size_t block_size) |
25 | { |
26 | nnp_grad_relu_function grad_relu = context->grad_relu_function; |
27 | const float* grad_output = context->grad_output; |
28 | const float* input = context->input; |
29 | float* grad_input = context->grad_input; |
30 | float negative_slope = context->negative_slope; |
31 | |
32 | grad_relu(grad_output + block_start, input + block_start, grad_input + block_start, block_size, negative_slope); |
33 | } |
34 | |
35 | enum nnp_status nnp_relu_input_gradient( |
36 | size_t batch_size, |
37 | size_t channels, |
38 | const float grad_output[], |
39 | const float input[], |
40 | float grad_input[], |
41 | float negative_slope, |
42 | pthreadpool_t threadpool) |
43 | { |
44 | enum nnp_status status = validate_relu_arguments(batch_size, channels); |
45 | if (status != nnp_status_success) { |
46 | return status; |
47 | } |
48 | |
49 | size_t elements = batch_size * channels; |
50 | const size_t simd_width = nnp_hwinfo.simd_width; |
51 | |
52 | assert(((uintptr_t) grad_output) % sizeof(float) == 0); |
53 | assert(((uintptr_t) input) % sizeof(float) == 0); |
54 | assert(((uintptr_t) grad_input) % sizeof(float) == 0); |
55 | |
56 | const size_t prologue_elements = min((size_t) (-(((uintptr_t) grad_input) / sizeof(float)) % simd_width), elements); |
57 | for (size_t i = 0; i < prologue_elements; i++) { |
58 | grad_input[i] = grad_relu(grad_output[i], input[i], negative_slope); |
59 | } |
60 | elements -= prologue_elements; |
61 | grad_output += prologue_elements; |
62 | input += prologue_elements; |
63 | grad_input += prologue_elements; |
64 | |
65 | const size_t epilogue_elements = elements % simd_width; |
66 | for (size_t i = 0; i < epilogue_elements; i++) { |
67 | grad_input[elements - epilogue_elements + i] = grad_relu( |
68 | grad_output[elements - epilogue_elements + i], |
69 | input[elements - epilogue_elements + i], |
70 | negative_slope); |
71 | } |
72 | elements -= epilogue_elements; |
73 | |
74 | struct relu_context relu_context = { |
75 | .grad_relu_function = nnp_hwinfo.activations.grad_relu, |
76 | .grad_output = grad_output, |
77 | .input = input, |
78 | .grad_input = grad_input, |
79 | .negative_slope = negative_slope, |
80 | }; |
81 | pthreadpool_parallelize_1d_tile_1d(threadpool, |
82 | (pthreadpool_function_1d_tiled_t) compute_grad_relu, |
83 | &relu_context, |
84 | elements, round_down(nnp_hwinfo.blocking.l1 / sizeof(float), simd_width), |
85 | PTHREADPOOL_FLAG_DISABLE_DENORMALS); |
86 | |
87 | return nnp_status_success; |
88 | } |
89 | |