1#include <stdint.h>
2#include <stddef.h>
3#include <assert.h>
4
5#include <nnpack.h>
6#include <nnpack/macros.h>
7#include <nnpack/utils.h>
8
9#include <nnpack/hwinfo.h>
10#include <nnpack/activations.h>
11#include <nnpack/validation.h>
12
13
14struct NNP_CACHE_ALIGN relu_context {
15 nnp_grad_relu_function grad_relu_function;
16 const float* grad_output;
17 const float* input;
18 float* grad_input;
19 float negative_slope;
20};
21
22static void compute_grad_relu(
23 const struct relu_context context[restrict static 1],
24 size_t block_start, size_t block_size)
25{
26 nnp_grad_relu_function grad_relu = context->grad_relu_function;
27 const float* grad_output = context->grad_output;
28 const float* input = context->input;
29 float* grad_input = context->grad_input;
30 float negative_slope = context->negative_slope;
31
32 grad_relu(grad_output + block_start, input + block_start, grad_input + block_start, block_size, negative_slope);
33}
34
35enum nnp_status nnp_relu_input_gradient(
36 size_t batch_size,
37 size_t channels,
38 const float grad_output[],
39 const float input[],
40 float grad_input[],
41 float negative_slope,
42 pthreadpool_t threadpool)
43{
44 enum nnp_status status = validate_relu_arguments(batch_size, channels);
45 if (status != nnp_status_success) {
46 return status;
47 }
48
49 size_t elements = batch_size * channels;
50 const size_t simd_width = nnp_hwinfo.simd_width;
51
52 assert(((uintptr_t) grad_output) % sizeof(float) == 0);
53 assert(((uintptr_t) input) % sizeof(float) == 0);
54 assert(((uintptr_t) grad_input) % sizeof(float) == 0);
55
56 const size_t prologue_elements = min((size_t) (-(((uintptr_t) grad_input) / sizeof(float)) % simd_width), elements);
57 for (size_t i = 0; i < prologue_elements; i++) {
58 grad_input[i] = grad_relu(grad_output[i], input[i], negative_slope);
59 }
60 elements -= prologue_elements;
61 grad_output += prologue_elements;
62 input += prologue_elements;
63 grad_input += prologue_elements;
64
65 const size_t epilogue_elements = elements % simd_width;
66 for (size_t i = 0; i < epilogue_elements; i++) {
67 grad_input[elements - epilogue_elements + i] = grad_relu(
68 grad_output[elements - epilogue_elements + i],
69 input[elements - epilogue_elements + i],
70 negative_slope);
71 }
72 elements -= epilogue_elements;
73
74 struct relu_context relu_context = {
75 .grad_relu_function = nnp_hwinfo.activations.grad_relu,
76 .grad_output = grad_output,
77 .input = input,
78 .grad_input = grad_input,
79 .negative_slope = negative_slope,
80 };
81 pthreadpool_parallelize_1d_tile_1d(threadpool,
82 (pthreadpool_function_1d_tiled_t) compute_grad_relu,
83 &relu_context,
84 elements, round_down(nnp_hwinfo.blocking.l1 / sizeof(float), simd_width),
85 PTHREADPOOL_FLAG_DISABLE_DENORMALS);
86
87 return nnp_status_success;
88}
89