New upstream version 0.0~git20220418.02b17c5
Mo Zhou
1 year, 10 months ago
34 | 34 | using batch_normalization_flag = dnnl::normalization_flags; |
35 | 35 | using query = dnnl::query; |
36 | 36 | using scale_t = std::vector<float>; |
37 | using zero_point_t = std::vector<int32_t>; | |
37 | 38 | using exec_args = std::unordered_map<int, memory>; |
38 | 39 | |
39 | 40 | // for computation cache |
40 | 40 | return attr; |
41 | 41 | } |
42 | 42 | |
43 | static attr_t fuse_gelu(float scale = 1.0, float alpha = 0.f, | |
44 | float beta = 0.f) { | |
45 | attr_t attr; | |
46 | post_ops po; | |
47 | po.append_eltwise(scale, algorithm::eltwise_gelu_erf, alpha, beta); | |
48 | attr.set_post_ops(po); | |
49 | return attr; | |
50 | } | |
51 | ||
52 | static attr_t fuse_tanh(float scale = 1.0, float alpha = 0.f, | |
53 | float beta = 0.f) { | |
54 | attr_t attr; | |
55 | post_ops po; | |
56 | po.append_eltwise(scale, algorithm::eltwise_tanh, alpha, beta); | |
57 | attr.set_post_ops(po); | |
58 | return attr; | |
59 | } | |
60 | ||
43 | 61 | static attr_t residual(float sum_scale = 1.0, float relu_scale = 1.0, |
44 | 62 | float alpha = 0.f, float beta = 0.f) { |
45 | 63 | attr_t attr; |
46 | 64 | post_ops po; |
47 | 65 | po.append_sum(sum_scale); |
48 | 66 | po.append_eltwise(relu_scale, algorithm::eltwise_relu, alpha, beta); |
67 | attr.set_post_ops(po); | |
68 | return attr; | |
69 | } | |
70 | ||
71 | static attr_t fuse_clamp(float lower_bound = -1.0, float upper_bound = 1.0) { | |
72 | attr_t attr; | |
73 | post_ops po; | |
74 | po.append_eltwise(1.0, algorithm::eltwise_clip, lower_bound, upper_bound); | |
49 | 75 | attr.set_post_ops(po); |
50 | 76 | return attr; |
51 | 77 | } |
146 | 172 | |
147 | 173 | } // namespace ideep |
148 | 174 | |
149 | #endif⏎ | |
175 | #endif |
18 | 18 | #include "operators/lstm.hpp" |
19 | 19 | #include "operators/matmul.hpp" |
20 | 20 | #include "operators/pool.hpp" |
21 | #include "operators/prelu.hpp" | |
21 | 22 | #include "operators/softmax.hpp" |
22 | 23 | #include "operators/spliter.hpp" |
23 | 24 | #include "operators/sum.hpp" |
53 | 53 | |
54 | 54 | bool fuse_norm_relu = (bool) (flags & batch_normalization_flag::fuse_norm_relu); |
55 | 55 | attr_t attr = fuse_norm_relu ? attr_t::fuse_relu() : attr_t(); |
56 | attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
56 | 57 | auto pd = primitive_desc( |
57 | 58 | {prop_kind::forward_inference, src_desc, epsilon, pd_flags}, attr, aengine); |
58 | 59 | |
59 | 60 | tensor scale_shift {pd.weights_desc()}; |
61 | tensor scratchpad(pd.scratchpad_desc()); | |
60 | 62 | auto* scale_shift_buf = static_cast<char *>(scale_shift.get_data_handle()); |
61 | 63 | std::memcpy(scale_shift_buf, scale.get_data_handle(), scale.get_size()); |
62 | 64 | std::memcpy(scale_shift_buf + scale.get_size(), |
72 | 74 | {DNNL_ARG_SCALE_SHIFT, scale_shift}, |
73 | 75 | {DNNL_ARG_VARIANCE, expected_var}, |
74 | 76 | {DNNL_ARG_MEAN, expected_mean}, |
75 | {DNNL_ARG_DST, dst}}); | |
77 | {DNNL_ARG_DST, dst}, | |
78 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
76 | 79 | } else { |
77 | 80 | super(pd).execute(stream::default_stream(), |
78 | 81 | {{DNNL_ARG_SRC, expected_src}, |
79 | 82 | {DNNL_ARG_SCALE_SHIFT, scale_shift}, |
80 | {DNNL_ARG_DST, dst}}); | |
83 | {DNNL_ARG_DST, dst}, | |
84 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
81 | 85 | } |
82 | 86 | } |
83 | 87 | }; |
104 | 108 | auto src_desc = src._get_unblocked_desc_if_4c_blocked(); |
105 | 109 | // auto src_desc = src.get_desc(); |
106 | 110 | |
111 | auto op_attr = dnnl::primitive_attr(); | |
112 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
113 | ||
107 | 114 | auto pd = primitive_desc( |
108 | {prop_kind::forward_training, src_desc, epsilon, pd_flags}, aengine); | |
115 | {prop_kind::forward_training, src_desc, epsilon, pd_flags}, | |
116 | op_attr, | |
117 | aengine); | |
109 | 118 | |
110 | 119 | tensor scale_shift {pd.weights_desc()}; |
120 | tensor scratchpad(pd.scratchpad_desc()); | |
111 | 121 | auto* scale_shift_buf = static_cast<char *>(scale_shift.get_data_handle()); |
112 | 122 | std::memcpy(scale_shift_buf, scale.get_data_handle(), scale.get_size()); |
113 | 123 | std::memcpy(scale_shift_buf + scale.get_size(), |
121 | 131 | {DNNL_ARG_SCALE_SHIFT, scale_shift}, |
122 | 132 | {DNNL_ARG_MEAN, mean}, |
123 | 133 | {DNNL_ARG_VARIANCE, variance}, |
124 | {DNNL_ARG_DST, dst}}; | |
134 | {DNNL_ARG_DST, dst}, | |
135 | {DNNL_ARG_SCRATCHPAD, scratchpad}}; | |
125 | 136 | if (with_workspace) { |
126 | 137 | dst.init_workspace(pd.workspace_desc()); |
127 | 138 | args.insert({DNNL_ARG_WORKSPACE, dst.get_workspace()}); |
173 | 184 | auto forward_hints = dnnl::batch_normalization_forward::primitive_desc( |
174 | 185 | {prop_kind::forward_training, src_desc, epsilon, pd_flags}, aengine); |
175 | 186 | |
187 | auto op_attr = dnnl::primitive_attr(); | |
188 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
189 | ||
176 | 190 | auto pd = primitive_desc( |
177 | 191 | {prop_kind::backward, forward_hints.dst_desc(), src_desc, epsilon, pd_flags}, |
178 | aengine, forward_hints); | |
192 | op_attr, aengine, forward_hints); | |
179 | 193 | |
180 | 194 | auto expected_diff_dst = diff_dst.reorder_if_differ_in(pd.diff_dst_desc()); |
181 | 195 | auto expected_src = src.reorder_if_differ_in(pd.src_desc()); |
184 | 198 | diff_src.reinit_if_possible(pd.diff_src_desc()); |
185 | 199 | diff_scale_shift.reinit_if_possible(pd.diff_weights_desc()); |
186 | 200 | |
201 | tensor scratchpad(pd.scratchpad_desc()); | |
202 | ||
187 | 203 | exec_args args {{DNNL_ARG_SRC, expected_src}, |
188 | 204 | {DNNL_ARG_DIFF_DST, expected_diff_dst}, |
189 | 205 | {DNNL_ARG_SCALE_SHIFT, scale}, // only need scale |
190 | 206 | {DNNL_ARG_MEAN, expected_mean}, |
191 | 207 | {DNNL_ARG_VARIANCE, expected_variance}, |
192 | 208 | {DNNL_ARG_DIFF_SRC, diff_src}, |
193 | {DNNL_ARG_DIFF_SCALE_SHIFT, diff_scale_shift}}; | |
209 | {DNNL_ARG_DIFF_SCALE_SHIFT, diff_scale_shift}, | |
210 | {DNNL_ARG_SCRATCHPAD, scratchpad}}; | |
194 | 211 | if (with_workspace) { |
195 | 212 | args.insert({DNNL_ARG_WORKSPACE, dst.get_workspace()}); |
196 | 213 | } |
15 | 15 | auto src1_desc = src1.get_desc(); |
16 | 16 | auto dst_desc = src0_desc.to_format_any(); |
17 | 17 | |
18 | auto op_attr = dnnl::primitive_attr(); | |
19 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
20 | ||
18 | 21 | auto pd = primitive_desc( |
19 | {aalgorithm, src0_desc, src1_desc, dst_desc}, aengine); | |
20 | ||
22 | {aalgorithm, src0_desc, src1_desc, dst_desc}, op_attr, aengine); | |
23 | ||
24 | tensor scratchpad(pd.scratchpad_desc()); | |
25 | ||
21 | 26 | auto expected_src0 = src0.reorder_if_differ_in(pd.src0_desc()); |
22 | 27 | auto expected_src1 = src1.reorder_if_differ_in(pd.src1_desc()); |
23 | 28 | dst.reinit_if_possible(pd.dst_desc()); |
25 | 30 | super(pd).execute(stream::default_stream(), |
26 | 31 | {{DNNL_ARG_SRC_0, expected_src0}, |
27 | 32 | {DNNL_ARG_SRC_1, expected_src1}, |
28 | {DNNL_ARG_DST, dst}}); | |
33 | {DNNL_ARG_DST, dst}, | |
34 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
29 | 35 | } |
30 | 36 | }; |
31 | 37 |
16 | 16 | IDEEP_ENFORCE(src.get_data_type() == data_type::f32, "invalid data type"); |
17 | 17 | |
18 | 18 | auto group_size = static_cast<int>(src.get_dim(axis) / group); |
19 | ||
20 | auto op_attr = dnnl::primitive_attr(); | |
21 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
22 | ||
19 | 23 | auto pd = |
20 | primitive_desc({aprop_kind, src.get_desc(), axis, group_size}, aengine); | |
24 | primitive_desc({aprop_kind, src.get_desc(), axis, group_size}, aengine, op_attr); | |
21 | 25 | |
22 | 26 | auto expected_src = src.reorder_if_differ_in(pd.src_desc()); |
23 | 27 | dst.reinit_if_possible(pd.dst_desc()); |
24 | 28 | |
29 | tensor scratchpad(pd.scratchpad_desc()); | |
30 | ||
25 | 31 | super(pd).execute(stream::default_stream(), |
26 | {{DNNL_ARG_SRC, expected_src}, {DNNL_ARG_DST, dst}}); | |
32 | {{DNNL_ARG_SRC, expected_src}, | |
33 | {DNNL_ARG_DST, dst}, | |
34 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
27 | 35 | } |
28 | 36 | }; |
29 | 37 | |
41 | 49 | |
42 | 50 | auto forward_hints = dnnl::shuffle_forward::primitive_desc( |
43 | 51 | {prop_kind::forward, data_desc, group_size, axis}, aengine); |
44 | auto pd = | |
45 | primitive_desc({data_desc, axis, group_size}, aengine, forward_hints); | |
52 | ||
53 | auto op_attr = dnnl::primitive_attr(); | |
54 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
55 | ||
56 | auto pd = primitive_desc( | |
57 | {data_desc, axis, group_size}, aengine, forward_hints, op_attr); | |
46 | 58 | |
47 | 59 | auto expected_diff_dst = diff_dst.reorder_if_differ_in(pd.diff_dst_desc()); |
48 | 60 | diff_src.reinit_if_possible(pd.diff_src_desc()); |
49 | 61 | |
62 | tensor scratchpad(pd.scratchpad_desc()); | |
63 | ||
50 | 64 | super(pd).execute(stream::default_stream(), |
51 | 65 | {{DNNL_ARG_DIFF_DST, expected_diff_dst}, |
52 | {DNNL_ARG_DIFF_SRC, diff_src}}); | |
66 | {DNNL_ARG_DIFF_SRC, diff_src}, | |
67 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
53 | 68 | } |
54 | 69 | }; |
55 | 70 |
15 | 15 | return static_cast<memory::desc>(t.get_desc()); |
16 | 16 | }); |
17 | 17 | |
18 | auto op_attr = dnnl::primitive_attr(); | |
19 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
20 | ||
18 | 21 | // create a pd to query the optimimal format for src and dst |
19 | auto pd = primitive_desc(axis, input_descs, aengine); | |
22 | auto pd = primitive_desc(axis, input_descs, aengine, op_attr); | |
20 | 23 | auto expected_desc = tensor::desc(pd.dst_desc()); |
21 | 24 | |
22 | 25 | output.reinit_if_possible(expected_desc); |
23 | ||
24 | exec_args args {{DNNL_ARG_DST, output}}; | |
25 | 26 | |
26 | 27 | // DNNL currently supports two types of implementations in the concat: |
27 | 28 | // (Very fast) Works only when all memories are in the same format |
44 | 45 | return static_cast<memory::desc>(t.get_desc()); |
45 | 46 | }); |
46 | 47 | // recreate the pd on new inputs with same formats |
47 | pd = primitive_desc(axis, input_descs, aengine); | |
48 | pd = primitive_desc(axis, input_descs, aengine, op_attr); | |
48 | 49 | } |
50 | ||
51 | tensor scratchpad(pd.scratchpad_desc()); | |
52 | exec_args args {{DNNL_ARG_DST, output}, {DNNL_ARG_SCRATCHPAD, scratchpad}}; | |
49 | 53 | |
50 | 54 | for (int i = 0; i < opt_inputs.size(); ++i) { |
51 | 55 | args.insert({DNNL_ARG_MULTIPLE_SRC + i, opt_inputs[i]}); |
7 | 7 | // bias_attr contains requantization scales for bias |
8 | 8 | attr_t bias_attr; |
9 | 9 | scale_t dst_scales; |
10 | zero_point_t src_zero_point; | |
10 | 11 | int groups; |
11 | 12 | tensor scratchpad; |
12 | 13 | }; |
13 | 14 | |
15 | struct conv_deconv_utils { | |
16 | /// Common logic to prepare parameters for conv/deconv. | |
17 | static void prepare_parameters(const tensor& src, | |
18 | const tensor& weights, | |
19 | const tensor& bias, | |
20 | const dims& dst_dims, | |
21 | const tensor& dst, | |
22 | const dims& dilates, | |
23 | int groups, | |
24 | const scale_t& src_scales, | |
25 | const scale_t& weights_scales, | |
26 | const scale_t& dst_scales, | |
27 | const zero_point_t& src_zero_points, | |
28 | const zero_point_t& dst_zero_points, | |
29 | const attr_t& attr, | |
30 | const lowp_kind alowp_kind, | |
31 | bool with_bias, | |
32 | bool is_deconv, | |
33 | tensor& weight_grouped, /* Output */ | |
34 | dims& dil_compatible, /* Output */ | |
35 | attr_t& op_attr, /* Output */ | |
36 | attr_t& src_attr, /* Output */ | |
37 | attr_t& weights_attr, /* Output */ | |
38 | attr_t& bias_attr, /* Output */ | |
39 | tensor::desc& src_desc, /* Output */ | |
40 | tensor::desc& weights_desc, /* Output */ | |
41 | tensor::desc& bias_desc, /* Output */ | |
42 | tensor::desc& dst_desc /* Output */) { | |
43 | scale_t dst_scales_in; | |
44 | data_type dst_data_type; | |
45 | op_attr = attr; | |
46 | ||
47 | // make weights and dilates compatible with DNNL | |
48 | weight_grouped = weights.make_grouped_weights(groups, is_deconv); | |
49 | dil_compatible = utils::get_compatible_dilates(dilates); | |
50 | ||
51 | auto& weights_scales_in = | |
52 | weight_grouped.has_scale() ? weight_grouped.get_scale() : weights_scales; | |
53 | if (!weights_scales_in.empty()) { | |
54 | IDEEP_ENFORCE(alowp_kind == u8s8 || alowp_kind == s8s8, | |
55 | "Unsupported lowp kind"); | |
56 | int scale_size = (weights_scales_in.size() > 1) ? dst_dims[1] : 1; | |
57 | auto& src_scales_in = | |
58 | src.has_scale() ? src.get_scale() | |
59 | : (src_scales.empty() ? IDEEP_DEF_SCALE : src_scales); | |
60 | ||
61 | // determine dst data type | |
62 | if (dst.get_data_type() != data_type::undef) { | |
63 | dst_data_type = dst.get_data_type(); | |
64 | } else if (dst_scales.empty() || dst_scales == IDEEP_DEF_SCALE) { | |
65 | dst_data_type = data_type::f32; | |
66 | } else if (attr.non_negitive_output()) { | |
67 | dst_data_type = data_type::u8; | |
68 | } else { | |
69 | dst_data_type = data_type::s8; | |
70 | } | |
71 | ||
72 | // fill primitive attr | |
73 | dst_scales_in = dst_scales.empty() || dst_data_type == data_type::f32 | |
74 | ? IDEEP_DEF_SCALE | |
75 | : dst_scales; | |
76 | const auto default_zero_point = zero_point_t(1); | |
77 | const auto& src_zero_point = src.has_zero_point() ? src.get_zero_point() : | |
78 | src_zero_points.empty() ? default_zero_point : src_zero_points; | |
79 | const auto& weights_zero_point = weight_grouped.has_zero_point() ? weight_grouped.get_zero_point() : default_zero_point; | |
80 | const auto& dst_zero_point = dst.has_zero_point() ? dst.get_zero_point() : | |
81 | dst_zero_points.empty() ? default_zero_point : dst_zero_points; | |
82 | const auto src_zero_point_size = static_cast<dim>(src_zero_point.size()); | |
83 | const auto weights_zero_point_size = 1; | |
84 | const auto dst_zero_point_size = static_cast<dim>(dst_zero_point.size()); | |
85 | IDEEP_ENFORCE(src_zero_point_size == 1 && dst_zero_point_size == 1, | |
86 | "DNNL only support 1-dim zero_point"); | |
87 | ||
88 | scale_t bias_scales, op_scales; | |
89 | std::tie(bias_scales, op_scales) = utils::compute_scales( | |
90 | src_scales_in[0], dst_scales_in[0], weights_scales_in); | |
91 | ||
92 | if (attr.has_op_kind(kind::sum)) { | |
93 | float sum_scale = | |
94 | dst_scales_in[0] / (dst.has_scale() ? dst.get_scale()[0] : 1.0f); | |
95 | if (attr.has_op_kind(kind::eltwise)) { | |
96 | op_attr = attr_t::residual(sum_scale); | |
97 | } else { | |
98 | op_attr = attr_t::fuse_sum(sum_scale); | |
99 | } | |
100 | } | |
101 | op_attr.set_output_scales(utils::op_scale_mask(scale_size), op_scales); | |
102 | zero_point_t src_zero_point_in_attr; | |
103 | int zp_mask = utils::tensor_zp_mask(1); | |
104 | attr.get_zero_points(DNNL_ARG_SRC, zp_mask, src_zero_point_in_attr); | |
105 | if (src_zero_point_in_attr == zero_point_t({DNNL_RUNTIME_S32_VAL})) { // runtime src zero point | |
106 | op_attr.set_zero_points(DNNL_ARG_SRC, | |
107 | zp_mask, | |
108 | src_zero_point_in_attr); | |
109 | } else { | |
110 | op_attr.set_zero_points(DNNL_ARG_SRC, | |
111 | ideep::utils::tensor_zp_mask(src_zero_point_size), | |
112 | src_zero_point); | |
113 | } | |
114 | op_attr.set_zero_points(DNNL_ARG_WEIGHTS, | |
115 | ideep::utils::tensor_zp_mask(weights_zero_point_size), | |
116 | zero_point_t(1, weights_zero_point[0])); | |
117 | if (dst_data_type != data_type::f32) { | |
118 | op_attr.set_zero_points(DNNL_ARG_DST, | |
119 | ideep::utils::tensor_zp_mask(dst_zero_point_size), | |
120 | dst_zero_point); | |
121 | } | |
122 | ||
123 | src_desc = {src.get_dims(), | |
124 | alowp_kind == u8s8 ? data_type::u8 : data_type::s8, tag::any}; | |
125 | if (src.get_data_type() == data_type::f32) { | |
126 | src_attr = {0, src_scales_in}; | |
127 | } | |
128 | ||
129 | weights_desc = weight_grouped.get_desc().to_type(data_type::s8); | |
130 | if (weight_grouped.get_data_type() == data_type::f32) { | |
131 | weights_attr = {utils::tensor_scale_mask(scale_size, groups > 1), | |
132 | weights_scales_in}; | |
133 | } | |
134 | ||
135 | if (with_bias) { | |
136 | bias_desc = {bias.get_dims(), data_type::f32, tag::any}; // Use f32 instead of s32 to improve accuracy | |
137 | if (bias.get_data_type() == data_type::f32) { | |
138 | bias_attr = {utils::tensor_scale_mask(scale_size, false), | |
139 | bias_scales}; | |
140 | } | |
141 | } | |
142 | } else { | |
143 | if (src.has_scale()) { | |
144 | auto src_scale = src.get_scale(); | |
145 | src_scale[0] = 1.0f / src_scale[0]; | |
146 | src_attr = {0, src_scale}; | |
147 | } | |
148 | ||
149 | IDEEP_ENFORCE(utils::one_of(weight_grouped.get_data_type(), | |
150 | data_type::f32, data_type::bf16), | |
151 | "Incorrect data type in weights"); | |
152 | ||
153 | // align weights data type with src | |
154 | dst_data_type = src.get_data_type() == data_type::bf16 ? data_type::bf16 | |
155 | : data_type::f32; | |
156 | src_desc = src.get_desc().to_type(dst_data_type); | |
157 | weights_desc = weight_grouped.get_desc().to_type(dst_data_type); | |
158 | ||
159 | if (with_bias) { | |
160 | IDEEP_ENFORCE(utils::one_of(bias.get_data_type(), | |
161 | data_type::f32, data_type::bf16), | |
162 | "Incorrect data type in bias"); | |
163 | bias_desc = bias.get_desc(); | |
164 | } | |
165 | } | |
166 | ||
167 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
168 | ||
169 | dst_desc = attr.has_op_kind(kind::sum) | |
170 | ? dst.get_desc() | |
171 | : tensor::desc(dst_dims, dst_data_type); | |
172 | } | |
173 | ||
174 | /// Get true zero point from input tensor, specified zero point or op attr | |
175 | /// Priority: input.get_zero_point() > input_zero_point > op_attr > default | |
176 | /// | |
177 | /// @param input Get the true zero point from this tensor. | |
178 | /// @param arg_idx Parameter argument index as passed to the | |
179 | /// primitive::execute() call. Such as DNNL_ARG_SRC. | |
180 | /// @param op_attr Attr of the conv/deconv operation. | |
181 | /// @param aengine Cpu execution engine. | |
182 | /// @param zero_point Output tensor of zero points. | |
183 | static void obtain_runtime_zero_point(const tensor& input, | |
184 | const zero_point_t& input_zero_point, | |
185 | const int& arg_idx, | |
186 | const dnnl::primitive_attr& op_attr, | |
187 | const engine& aengine, | |
188 | tensor& zero_point /* Output */) { | |
189 | zero_point_t src_zero_point_in_attr; | |
190 | int zp_mask = utils::tensor_zp_mask(1); | |
191 | op_attr.get_zero_points(arg_idx, zp_mask, src_zero_point_in_attr); | |
192 | dim src_zero_point_size = 1; | |
193 | const zero_point_t* zero_point_data = NULL; | |
194 | const zero_point_t default_zero_point = {0}; | |
195 | if (input.has_zero_point()) { | |
196 | src_zero_point_size = static_cast<dim>(input.get_zero_point().size()); | |
197 | zero_point_data = &input.get_zero_point(); | |
198 | } else if (!input_zero_point.empty()) { | |
199 | src_zero_point_size = static_cast<dim>(input_zero_point.size()); | |
200 | zero_point_data = &input_zero_point; | |
201 | } else if (src_zero_point_in_attr == zero_point_t({DNNL_RUNTIME_S32_VAL}) || | |
202 | src_zero_point_in_attr.empty()) { // runtime zero point of input | |
203 | src_zero_point_size = static_cast<dim>(default_zero_point.size()); | |
204 | zero_point_data = &default_zero_point; | |
205 | } else { | |
206 | src_zero_point_size = static_cast<dim>(src_zero_point_in_attr.size()); | |
207 | zero_point_data = &src_zero_point_in_attr; | |
208 | } | |
209 | tensor::desc src_zero_point_desc = {{src_zero_point_size}, data_type::s32, {1}}; | |
210 | zero_point.init(src_zero_point_desc, aengine); | |
211 | auto src_z = reinterpret_cast<int32_t *>(zero_point.get_data_handle()); | |
212 | for (memory::dim i = 0; i < src_zero_point_size; ++i) // fill in zero point data | |
213 | src_z[i] = (*zero_point_data)[i]; | |
214 | ||
215 | } | |
216 | }; | |
217 | ||
14 | 218 | struct convolution_forward |
15 | 219 | : public dnnl::convolution_forward, |
16 | 220 | utils::computation_cache<dnnl::convolution_forward::primitive_desc> { |
17 | 221 | |
18 | 222 | using super = dnnl::convolution_forward; |
19 | 223 | |
20 | // prepare with bias | |
224 | // 2-in-1 compute (prepare & compute) with bias | |
225 | // Bias is not used if it is empty. | |
226 | // Zero points are passed explicitly as arguments for quantization | |
227 | template <bool plain_format = false> | |
228 | static void compute_v2(const tensor& src, | |
229 | const tensor& weights, | |
230 | const tensor& bias, | |
231 | const dims& dst_dims, | |
232 | tensor& dst, | |
233 | const dims& strides, | |
234 | const dims& dilates, | |
235 | const dims& padding_l, | |
236 | const dims& padding_r, | |
237 | int groups, | |
238 | const scale_t& src_scales = scale_t(), | |
239 | const scale_t& weights_scales = scale_t(), | |
240 | const scale_t& dst_scales = scale_t(), | |
241 | const zero_point_t& src_zero_point = zero_point_t(), | |
242 | const zero_point_t& dst_zero_point = zero_point_t(), | |
243 | const attr_t& attr = attr_t(), | |
244 | algorithm aalgorithm = algorithm::convolution_direct, | |
245 | prop_kind aprop_kind = prop_kind::forward, | |
246 | const lowp_kind alowp_kind = u8s8, | |
247 | const engine& aengine = engine::cpu_engine()) { | |
248 | if (bias.is_empty()) { | |
249 | compute_dispatch</*with_bias=*/false, plain_format>( | |
250 | src, weights, bias, dst_dims, dst, strides, dilates, | |
251 | padding_l, padding_r, groups, src_scales, weights_scales, dst_scales, | |
252 | src_zero_point, dst_zero_point, attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
253 | } else { | |
254 | compute_dispatch</*with_bias=*/true, plain_format>( | |
255 | src, weights, bias, dst_dims, dst, strides, dilates, | |
256 | padding_l, padding_r, groups, src_scales, weights_scales, dst_scales, | |
257 | src_zero_point, dst_zero_point, attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
258 | } | |
259 | } | |
260 | ||
261 | // 2-in-1 compute (prepare & compute) without bias | |
262 | // Zero points are passed explicitly as arguments for quantization | |
263 | template <bool plain_format = false> | |
264 | static void compute_v2(const tensor& src, | |
265 | const tensor& weights, | |
266 | const dims& dst_dims, | |
267 | tensor& dst, | |
268 | const dims& strides, | |
269 | const dims& dilates, | |
270 | const dims& padding_l, | |
271 | const dims& padding_r, | |
272 | int groups, | |
273 | const scale_t& src_scales = scale_t(), | |
274 | const scale_t& weights_scales = scale_t(), | |
275 | const scale_t& dst_scales = scale_t(), | |
276 | const zero_point_t& src_zero_point = zero_point_t(), | |
277 | const zero_point_t& dst_zero_point = zero_point_t(), | |
278 | const attr_t& attr = attr_t(), | |
279 | algorithm aalgorithm = algorithm::convolution_direct, | |
280 | prop_kind aprop_kind = prop_kind::forward, | |
281 | const lowp_kind alowp_kind = u8s8, | |
282 | const engine& aengine = engine::cpu_engine()) { | |
283 | static tensor dummy_bias; | |
284 | compute_dispatch</*with_bias=*/false, plain_format>( | |
285 | src, weights, dummy_bias, dst_dims, dst, strides, dilates, | |
286 | padding_l, padding_r, groups, src_scales, weights_scales, dst_scales, | |
287 | src_zero_point, dst_zero_point, attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
288 | } | |
289 | ||
290 | // Prepare with bias. | |
291 | // Bias is not used if it is empty. | |
292 | // Zero points are set to tensor for quantization | |
21 | 293 | static void prepare( |
22 | 294 | convolution_forward_params& param, |
23 | 295 | const tensor& src, |
38 | 310 | prop_kind aprop_kind = prop_kind::forward, |
39 | 311 | const lowp_kind alowp_kind = u8s8, |
40 | 312 | const engine& aengine = engine::cpu_engine()) { |
41 | do_prepare</*with_bias=*/true, /*keep_format=*/false>( | |
42 | param, src, weights, bias, dst_dims, dst, strides, dilates, | |
43 | padding_l, padding_r, groups, src_scales, weights_scales, dst_scales, | |
44 | attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
45 | } | |
46 | ||
47 | // prepare without bias | |
313 | if (bias.is_empty()) { | |
314 | do_prepare</*with_bias=*/false, /*keep_format=*/false>( | |
315 | param, src, weights, bias, dst_dims, dst, strides, dilates, | |
316 | padding_l, padding_r, groups, src_scales, weights_scales, dst_scales, | |
317 | zero_point_t(), zero_point_t(), attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
318 | } else { | |
319 | do_prepare</*with_bias=*/true, /*keep_format=*/false>( | |
320 | param, src, weights, bias, dst_dims, dst, strides, dilates, | |
321 | padding_l, padding_r, groups, src_scales, weights_scales, dst_scales, | |
322 | zero_point_t(), zero_point_t(), attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
323 | } | |
324 | } | |
325 | ||
326 | // Prepare without bias. | |
327 | // Zero points are set to tensor for quantization | |
48 | 328 | static void prepare( |
49 | 329 | convolution_forward_params& param, |
50 | 330 | const tensor& src, |
68 | 348 | do_prepare</*with_bias=*/false, /*keep_format=*/false>( |
69 | 349 | param, src, weights, dummy_bias, dst_dims, dst, strides, dilates, |
70 | 350 | padding_l, padding_r, groups, src_scales, weights_scales, dst_scales, |
71 | attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
72 | } | |
73 | ||
74 | // compute with bias | |
351 | zero_point_t(), zero_point_t(), attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
352 | } | |
353 | ||
354 | // Compute with bias | |
355 | // Bias is not used if it is empty. | |
75 | 356 | static void compute(const convolution_forward_params& param, |
76 | 357 | const tensor& src, |
77 | 358 | const tensor& weights, |
78 | 359 | const tensor& bias, |
79 | 360 | tensor& dst) { |
80 | do_compute</*with_bias=*/true>(param, src, weights, bias, dst); | |
81 | } | |
82 | ||
83 | // compute without bias | |
361 | if (bias.is_empty()) { | |
362 | do_compute</*with_bias=*/false>(param, src, weights, bias, dst); | |
363 | } else { | |
364 | do_compute</*with_bias=*/true>(param, src, weights, bias, dst); | |
365 | } | |
366 | } | |
367 | ||
368 | // Compute without bias | |
84 | 369 | static void compute(const convolution_forward_params& param, |
85 | 370 | const tensor& src, |
86 | 371 | const tensor& weights, |
89 | 374 | do_compute</*with_bias=*/false>(param, src, weights, dummy_bias, dst); |
90 | 375 | } |
91 | 376 | |
92 | // 2-in-1 compute (prepare & compute) with bias | |
377 | // Compute with given primitive & src zero point with or without bias | |
378 | static void compute(const super::primitive_desc pd, | |
379 | const super& primitive, | |
380 | const tensor& src, | |
381 | const tensor& weights, | |
382 | const tensor& expected_bias, | |
383 | tensor& dst, | |
384 | const tensor& src_zero_point, | |
385 | int groups) { | |
386 | if (expected_bias.is_empty()) { | |
387 | do_compute</*with_bias=*/false>( | |
388 | pd, primitive, src, weights, expected_bias, dst, src_zero_point, groups); | |
389 | } else { | |
390 | do_compute</*with_bias=*/true>( | |
391 | pd, primitive, src, weights, expected_bias, dst, src_zero_point, groups); | |
392 | } | |
393 | } | |
394 | ||
395 | // Deprecated. 2-in-1 compute (prepare & compute) with bias | |
396 | // Zero points are set to tensor for quantization | |
93 | 397 | template <bool plain_format = false> |
94 | 398 | static void compute(const tensor& src, |
95 | 399 | const tensor& weights, |
112 | 416 | compute_dispatch</*with_bias=*/true, plain_format>( |
113 | 417 | src, weights, bias, dst_dims, dst, strides, dilates, |
114 | 418 | padding_l, padding_r, groups, src_scales, weights_scales, dst_scales, |
115 | attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
116 | } | |
117 | ||
118 | // 2-in-1 compute (prepare & compute) without bias | |
419 | zero_point_t(), zero_point_t(), attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
420 | } | |
421 | ||
422 | // Deprecated. 2-in-1 compute (prepare & compute) without bias | |
423 | // Zero points are set to tensor for quantization | |
119 | 424 | template <bool plain_format = false> |
120 | 425 | static void compute(const tensor& src, |
121 | 426 | const tensor& weights, |
138 | 443 | compute_dispatch</*with_bias=*/false, plain_format>( |
139 | 444 | src, weights, dummy_bias, dst_dims, dst, strides, dilates, |
140 | 445 | padding_l, padding_r, groups, src_scales, weights_scales, dst_scales, |
141 | attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
446 | zero_point_t(), zero_point_t(), attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
142 | 447 | } |
143 | 448 | |
144 | 449 | static tensor::desc expected_weights_desc( |
184 | 489 | y_dims.push_back(1); |
185 | 490 | y_dims.push_back(oc); |
186 | 491 | if (4 == src_size) { |
187 | x_dims.push_back(2 * kernel_size[0]); | |
492 | x_dims.push_back(4 * kernel_size[0]); | |
188 | 493 | x_dims.push_back(4 * kernel_size[1]); |
189 | 494 | } else { |
190 | x_dims.push_back(2 * kernel_size[0]); | |
191 | x_dims.push_back(4 * kernel_size[1]); | |
495 | x_dims.push_back(8 * kernel_size[0]); | |
496 | x_dims.push_back(8 * kernel_size[1]); | |
192 | 497 | x_dims.push_back(8 * kernel_size[2]); |
193 | 498 | } |
194 | 499 | } else { |
224 | 529 | |
225 | 530 | auto pd = get_primitive_desc</*with_bias=*/false>( |
226 | 531 | src_desc, weights_desc, tensor::desc(), dst_desc, strides, dilates_, |
227 | padding_l, padding_r, attr_t(), aalgorithm, apkind); | |
532 | padding_l, padding_r, attr, aalgorithm, apkind); | |
228 | 533 | |
229 | 534 | // embed group info into weights_desc |
230 | 535 | return tensor::desc(pd.weights_desc(), groups); |
264 | 569 | |
265 | 570 | // For nhwc path, weight uses format_tag::any, |
266 | 571 | // while activation uses format_tag::nhwc. |
267 | bool is_nhwc = src_desc.is_nhwc() || weights_desc.is_nhwc(); | |
268 | if (is_nhwc) { | |
269 | src_desc_query = src_desc.to_format(tag::nhwc); | |
270 | weights_desc_query = weights_desc.to_format_any(); | |
271 | bias_desc_query = with_bias ? bias_desc.to_format_any() : tensor::desc(); | |
272 | dst_desc_query = dst_desc.to_format(tag::nhwc); | |
273 | } | |
274 | ||
275 | auto key = utils::create_key(aprop_kind, aalgorithm, src_desc_query, | |
276 | weights_desc_query, with_bias, strides, | |
277 | dilates, padding_l, padding_r, attr); | |
572 | auto ndims = src_desc.get_dims().size(); | |
573 | if (ndims == 4) { | |
574 | bool is_channels_last = src_desc.is_nhwc() || weights_desc.is_nhwc(); | |
575 | if (is_channels_last) { | |
576 | src_desc_query = src_desc.to_format(tag::nhwc); | |
577 | weights_desc_query = weights_desc.to_format_any(); | |
578 | bias_desc_query = with_bias ? bias_desc.to_format_any() : tensor::desc(); | |
579 | dst_desc_query = dst_desc.to_format(tag::nhwc); | |
580 | } | |
581 | } else if (ndims == 5) { | |
582 | bool is_channels_last = src_desc.is_ndhwc() || weights_desc.is_ndhwc(); | |
583 | if (is_channels_last) { | |
584 | src_desc_query = src_desc.to_format(tag::ndhwc); | |
585 | weights_desc_query = weights_desc.to_format_any(); | |
586 | bias_desc_query = with_bias ? bias_desc.to_format_any() : tensor::desc(); | |
587 | dst_desc_query = dst_desc.to_format(tag::ndhwc); | |
588 | } | |
589 | } | |
590 | ||
591 | auto key = utils::create_key( | |
592 | aprop_kind, | |
593 | aalgorithm, | |
594 | src_desc_query, | |
595 | weights_desc_query, | |
596 | with_bias, | |
597 | strides, | |
598 | dilates, | |
599 | padding_l, | |
600 | padding_r, | |
601 | attr, | |
602 | omp_get_max_threads()); | |
278 | 603 | return fetch_or_create(key, [&]() { |
279 | if (with_bias) { | |
280 | return primitive_desc({aprop_kind, aalgorithm, src_desc_query, | |
281 | weights_desc_query, bias_desc_query, dst_desc_query, | |
282 | strides, dilates, padding_l, padding_r}, | |
283 | attr, aengine); | |
284 | } else { | |
285 | return primitive_desc({aprop_kind, aalgorithm, src_desc_query, | |
286 | weights_desc_query, dst_desc_query, | |
287 | strides, dilates, padding_l, padding_r}, | |
288 | attr, aengine); | |
289 | } | |
604 | if (with_bias) { | |
605 | return primitive_desc( | |
606 | {aprop_kind, | |
607 | aalgorithm, | |
608 | src_desc_query, | |
609 | weights_desc_query, | |
610 | bias_desc_query, | |
611 | dst_desc_query, | |
612 | strides, | |
613 | dilates, | |
614 | padding_l, | |
615 | padding_r}, | |
616 | attr, | |
617 | aengine); | |
618 | } else { | |
619 | return primitive_desc( | |
620 | {aprop_kind, | |
621 | aalgorithm, | |
622 | src_desc_query, | |
623 | weights_desc_query, | |
624 | dst_desc_query, | |
625 | strides, | |
626 | dilates, | |
627 | padding_l, | |
628 | padding_r}, | |
629 | attr, | |
630 | aengine); | |
631 | } | |
290 | 632 | }); |
291 | 633 | } |
292 | 634 | |
326 | 668 | const scale_t& src_scales = scale_t(), |
327 | 669 | const scale_t& weights_scales = scale_t(), |
328 | 670 | const scale_t& dst_scales = scale_t(), |
671 | const zero_point_t& src_zero_point = zero_point_t(), | |
672 | const zero_point_t& dst_zero_point = zero_point_t(), | |
329 | 673 | const attr_t& attr = attr_t(), |
330 | 674 | algorithm aalgorithm = algorithm::convolution_direct, |
331 | 675 | prop_kind aprop_kind = prop_kind::forward, |
342 | 686 | do_prepare<with_bias, /*keep_format=*/true>( |
343 | 687 | params, src, weights, bias, dst_dims, dst, strides, dilates, |
344 | 688 | padding_l, padding_r, groups, src_scales, weights_scales, dst_scales, |
345 | attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
689 | src_zero_point, dst_zero_point, attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
346 | 690 | do_compute<with_bias>(params, src, weights, bias, dst); |
347 | 691 | } else { |
348 | 692 | tensor dst_blocked; |
349 | 693 | do_prepare<with_bias, /*keep_format=*/false>( |
350 | 694 | params, src, weights, bias, dst_dims, dst_blocked, strides, dilates, |
351 | 695 | padding_l, padding_r, groups, src_scales, weights_scales, dst_scales, |
352 | attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
696 | src_zero_point, dst_zero_point, attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
353 | 697 | do_compute<with_bias>(params, src, weights, bias, dst_blocked); |
354 | 698 | dst.feed_from(dst_blocked); |
355 | 699 | } |
358 | 702 | do_prepare<with_bias, /*keep_format=*/false>( |
359 | 703 | params, src, weights, bias, dst_dims, dst, strides, dilates, |
360 | 704 | padding_l, padding_r, groups, src_scales, weights_scales, dst_scales, |
361 | attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
705 | src_zero_point, dst_zero_point, attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
362 | 706 | do_compute<with_bias>(params, src, weights, bias, dst); |
363 | 707 | } |
364 | 708 | } |
379 | 723 | const scale_t& src_scales, |
380 | 724 | const scale_t& weights_scales, |
381 | 725 | const scale_t& dst_scales, |
726 | const zero_point_t& src_zero_point, | |
727 | const zero_point_t& dst_zero_point, | |
382 | 728 | const attr_t& attr, |
383 | 729 | algorithm aalgorithm, |
384 | 730 | prop_kind aprop_kind, |
387 | 733 | |
388 | 734 | scale_t dst_scales_in; |
389 | 735 | data_type dst_data_type; |
390 | tensor::desc src_desc, weights_desc, bias_desc; | |
736 | tensor::desc src_desc, weights_desc, bias_desc, dst_desc; | |
391 | 737 | attr_t op_attr, src_attr, weights_attr, bias_attr; |
392 | ||
393 | // make weights and dilates compatible with DNNL | |
394 | auto weights_ = weights.make_grouped_weights(groups); | |
395 | auto dilates_ = utils::get_compatible_dilates(dilates); | |
396 | ||
397 | auto& weights_scales_in = | |
398 | weights_.has_scale() ? weights_.get_scale() : weights_scales; | |
399 | if (!weights_scales_in.empty()) { | |
400 | IDEEP_ENFORCE(alowp_kind == u8s8 || alowp_kind == s8s8, | |
401 | "Unsupported lowp kind"); | |
402 | int scale_size = (weights_scales_in.size() > 1) ? dst_dims[1] : 1; | |
403 | auto src_scales_in = | |
404 | src.has_scale() ? src.get_scale() | |
405 | : (src_scales.empty() ? IDEEP_DEF_SCALE : src_scales); | |
406 | ||
407 | // determine dst data type | |
408 | if (attr.has_op_kind(kind::sum)) { | |
409 | dst_data_type = dst.get_data_type(); | |
410 | } else if (dst_scales.empty() || dst_scales == IDEEP_DEF_SCALE) { | |
411 | dst_data_type = data_type::f32; | |
412 | } else if (attr.non_negitive_output()) { | |
413 | dst_data_type = data_type::u8; | |
414 | } else { | |
415 | dst_data_type = data_type::s8; | |
416 | } | |
417 | ||
418 | // fill primitive attr | |
419 | dst_scales_in = dst_scales.empty() || dst_data_type == data_type::f32 | |
420 | ? IDEEP_DEF_SCALE | |
421 | : dst_scales; | |
422 | ||
423 | scale_t bias_scales, op_scales; | |
424 | std::tie(bias_scales, op_scales) = utils::compute_scales( | |
425 | src_scales_in[0], dst_scales_in[0], weights_scales_in); | |
426 | ||
427 | if (attr.has_op_kind(kind::sum)) { | |
428 | float sum_scale = | |
429 | dst_scales_in[0] / (dst.has_scale() ? dst.get_scale()[0] : 1.0f); | |
430 | if (attr.has_op_kind(kind::eltwise)) { | |
431 | op_attr = attr_t::residual(sum_scale); | |
432 | } else { | |
433 | op_attr = attr_t::fuse_sum(sum_scale); | |
434 | } | |
435 | } else if (attr.has_op_kind(kind::eltwise)) { | |
436 | op_attr = attr_t::fuse_relu(); | |
437 | } | |
438 | op_attr.set_output_scales(utils::op_scale_mask(scale_size), op_scales); | |
439 | ||
440 | src_desc = {src.get_dims(), | |
441 | alowp_kind == u8s8 ? data_type::u8 : data_type::s8, tag::any}; | |
442 | if (src.get_data_type() == data_type::f32) { | |
443 | src_attr = {0, src_scales_in}; | |
444 | } | |
445 | ||
446 | weights_desc = weights_.get_desc().to_type(data_type::s8); | |
447 | if (weights_.get_data_type() == data_type::f32) { | |
448 | weights_attr = {utils::tensor_scale_mask(scale_size, groups > 1), | |
449 | weights_scales_in}; | |
450 | } | |
451 | ||
452 | if (with_bias) { | |
453 | bias_desc = {bias.get_dims(), data_type::s32, tag::any}; | |
454 | if (bias.get_data_type() == data_type::f32) { | |
455 | bias_attr = {utils::tensor_scale_mask(scale_size, false), | |
456 | bias_scales}; | |
457 | } | |
458 | } | |
459 | } else { | |
460 | op_attr = attr; | |
461 | ||
462 | if (src.has_scale()) { | |
463 | auto src_scale = src.get_scale(); | |
464 | src_scale[0] = 1.0f / src_scale[0]; | |
465 | src_attr = {0, src_scale}; | |
466 | } | |
467 | ||
468 | IDEEP_ENFORCE(utils::one_of(weights_.get_data_type(), | |
469 | data_type::f32, data_type::bf16), | |
470 | "Incorrect data type in weights"); | |
471 | ||
472 | // align weights data type with src | |
473 | dst_data_type = src.get_data_type() == data_type::bf16 ? data_type::bf16 | |
474 | : data_type::f32; | |
475 | src_desc = src.get_desc().to_type(dst_data_type); | |
476 | weights_desc = weights_.get_desc().to_type(dst_data_type); | |
477 | ||
478 | if (with_bias) { | |
479 | IDEEP_ENFORCE(utils::one_of(bias.get_data_type(), | |
480 | data_type::f32, data_type::bf16), | |
481 | "Incorrect data type in bias"); | |
482 | bias_desc = bias.get_desc(); | |
483 | } | |
484 | } | |
485 | ||
486 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
487 | ||
488 | auto dst_desc = attr.has_op_kind(kind::sum) | |
489 | ? dst.get_desc() | |
490 | : tensor::desc(dst_dims, dst_data_type); | |
491 | ||
738 | tensor weights_grouped; | |
739 | dims dil_compatible; | |
740 | ||
741 | conv_deconv_utils::prepare_parameters( | |
742 | src, weights, bias, dst_dims, dst, dilates, groups, | |
743 | src_scales, weights_scales, dst_scales, src_zero_point, dst_zero_point, | |
744 | attr, alowp_kind, with_bias, false, | |
745 | weights_grouped, dil_compatible, op_attr, src_attr, weights_attr, bias_attr, | |
746 | src_desc, weights_desc, bias_desc, dst_desc); | |
492 | 747 | auto pd = get_primitive_desc<with_bias, keep_format>( |
493 | src_desc, weights_desc, bias_desc, dst_desc, strides, dilates_, | |
748 | src_desc, weights_desc, bias_desc, dst_desc, strides, dil_compatible, | |
494 | 749 | padding_l, padding_r, op_attr, aalgorithm, aprop_kind, aengine); |
495 | 750 | |
496 | 751 | // allocate scratchpad |
497 | 752 | tensor scratchpad(pd.scratchpad_desc()); |
498 | 753 | |
499 | param = {pd, bias_attr, dst_scales, groups, scratchpad}; | |
754 | param = {std::move(pd), bias_attr, std::move(dst_scales), std::move(src_zero_point), | |
755 | groups, std::move(scratchpad)}; | |
500 | 756 | } |
501 | 757 | |
502 | 758 | template <bool with_bias> |
504 | 760 | const tensor& src, const tensor& weights, |
505 | 761 | const tensor& bias, tensor& dst) { |
506 | 762 | auto& pd = param.pd; |
507 | auto scratchpad = param.scratchpad; | |
763 | auto& scratchpad = param.scratchpad; | |
508 | 764 | auto expected_src = src.reorder_if_differ_in(pd.src_desc()); |
509 | 765 | auto expected_weights = weights.make_grouped_weights(param.groups) |
510 | 766 | .reorder_if_differ_in(pd.weights_desc()); |
514 | 770 | dst.set_scale(param.dst_scales); |
515 | 771 | } |
516 | 772 | |
773 | tensor src_zero_point_m; | |
774 | conv_deconv_utils::obtain_runtime_zero_point( | |
775 | src, param.src_zero_point, DNNL_ARG_SRC, pd.get_primitive_attr(), | |
776 | ideep::engine(pd.get_engine().get_kind()), src_zero_point_m); | |
517 | 777 | if (with_bias) { |
518 | auto expected_bias = | |
519 | bias.reorder_if_differ_in(pd.bias_desc(), param.bias_attr); | |
778 | auto expected_bias = bias.reorder_if_differ_in(pd.bias_desc(), param.bias_attr); | |
520 | 779 | super(pd).execute(stream::default_stream(), |
521 | 780 | {{DNNL_ARG_SRC, expected_src}, |
522 | 781 | {DNNL_ARG_WEIGHTS, expected_weights}, |
523 | 782 | {DNNL_ARG_BIAS, expected_bias}, |
524 | 783 | {DNNL_ARG_DST, dst}, |
525 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
784 | {DNNL_ARG_SCRATCHPAD, scratchpad}, | |
785 | {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zero_point_m}}); | |
526 | 786 | } else { |
527 | 787 | super(pd).execute(stream::default_stream(), |
528 | 788 | {{DNNL_ARG_SRC, expected_src}, |
529 | 789 | {DNNL_ARG_WEIGHTS, expected_weights}, |
530 | 790 | {DNNL_ARG_DST, dst}, |
531 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
791 | {DNNL_ARG_SCRATCHPAD, scratchpad}, | |
792 | {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zero_point_m}}); | |
793 | } | |
794 | } | |
795 | ||
796 | // Do_compute with given primitive & src zero point | |
797 | // Bias scale has been applied before passed in. | |
798 | template <bool with_bias> | |
799 | static void do_compute(const super::primitive_desc& pd, | |
800 | const super& primitive, | |
801 | const tensor& src, | |
802 | const tensor& weights, | |
803 | const tensor& expected_bias, | |
804 | tensor& dst, | |
805 | const tensor& src_zero_point, | |
806 | int groups) { | |
807 | auto scratchpad = tensor(pd.scratchpad_desc()); | |
808 | auto weights_grouped = weights.make_grouped_weights(groups); | |
809 | if (with_bias) { | |
810 | primitive.execute(stream::default_stream(), | |
811 | {{DNNL_ARG_SRC, src}, | |
812 | {DNNL_ARG_WEIGHTS, weights_grouped}, | |
813 | {DNNL_ARG_BIAS, expected_bias}, | |
814 | {DNNL_ARG_DST, dst}, | |
815 | {DNNL_ARG_SCRATCHPAD, scratchpad}, | |
816 | {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zero_point}}); | |
817 | } else { | |
818 | primitive.execute(stream::default_stream(), | |
819 | {{DNNL_ARG_SRC, src}, | |
820 | {DNNL_ARG_WEIGHTS, weights_grouped}, | |
821 | {DNNL_ARG_DST, dst}, | |
822 | {DNNL_ARG_SCRATCHPAD, scratchpad}, | |
823 | {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zero_point}}); | |
532 | 824 | } |
533 | 825 | } |
534 | 826 | }; |
554 | 846 | auto dilates_ = utils::get_compatible_dilates(dilates); |
555 | 847 | |
556 | 848 | bool is_nhwc = diff_dst.get_desc().is_nhwc(); |
557 | auto format_tag = is_nhwc ? tag::nhwc : tag::any; | |
849 | bool is_ndhwc = diff_dst.get_desc().is_ndhwc(); | |
850 | auto format_tag = is_nhwc ? tag::nhwc : (is_ndhwc ? tag::ndhwc : tag::any); | |
558 | 851 | auto diff_dst_desc = diff_dst.get_desc().to_format(format_tag); |
559 | 852 | // align weight data type with diff_dst for bf16 |
560 | 853 | auto weights_desc = |
568 | 861 | diff_src_desc, weights_desc, tensor::desc(), diff_dst_desc, strides, |
569 | 862 | dilates_, padding_l, padding_r); |
570 | 863 | |
864 | auto op_attr = dnnl::primitive_attr(); | |
865 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
866 | ||
571 | 867 | auto pd = primitive_desc( |
572 | 868 | {aalgorithm, diff_src_desc, weights_desc, diff_dst_desc, strides, |
573 | dilates_, padding_l, padding_r}, aengine, forward_hints); | |
869 | dilates_, padding_l, padding_r}, op_attr, aengine, forward_hints); | |
574 | 870 | |
575 | 871 | auto expected_diff_dst = diff_dst.reorder_if_differ_in(pd.diff_dst_desc()); |
576 | 872 | auto expected_weights = weights_.reorder_if_differ_in(pd.weights_desc()); |
577 | 873 | diff_src.reinit_if_possible(pd.diff_src_desc()); |
578 | 874 | |
875 | tensor scratchpad(pd.scratchpad_desc()); | |
876 | ||
579 | 877 | super(pd).execute(stream::default_stream(), |
580 | 878 | {{DNNL_ARG_DIFF_DST, expected_diff_dst}, |
581 | 879 | {DNNL_ARG_WEIGHTS, expected_weights}, |
582 | {DNNL_ARG_DIFF_SRC, diff_src}}); | |
880 | {DNNL_ARG_DIFF_SRC, diff_src}, | |
881 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
583 | 882 | } |
584 | 883 | }; |
585 | 884 | |
653 | 952 | } |
654 | 953 | |
655 | 954 | bool is_nhwc = diff_dst.get_desc().is_nhwc(); |
656 | auto format_tag = is_nhwc ? tag::nhwc : tag::any; | |
955 | bool is_ndhwc = diff_dst.get_desc().is_ndhwc(); | |
956 | auto format_tag = is_nhwc ? tag::nhwc : (is_ndhwc ? tag::ndhwc : tag::any); | |
657 | 957 | auto diff_dst_desc = diff_dst.get_desc().to_format(format_tag); |
658 | 958 | auto src_desc = src.get_desc().to_format(format_tag); |
659 | 959 | |
672 | 972 | dilates_, padding_l, padding_r, attr_t(), aalgorithm, |
673 | 973 | prop_kind::forward, aengine); |
674 | 974 | |
975 | auto op_attr = dnnl::primitive_attr(); | |
976 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
977 | ||
675 | 978 | auto pd = with_diff_bias |
676 | 979 | ? primitive_desc({aalgorithm, src_desc, diff_weights_desc, |
677 | 980 | diff_bias_desc, diff_dst_desc, strides, dilates_, |
678 | padding_l, padding_r}, aengine, forward_hints) | |
981 | padding_l, padding_r}, op_attr, aengine, forward_hints) | |
679 | 982 | : primitive_desc({aalgorithm, src_desc, diff_weights_desc, |
680 | 983 | diff_dst_desc, strides, dilates_, |
681 | padding_l, padding_r}, aengine, forward_hints); | |
984 | padding_l, padding_r}, op_attr, aengine, forward_hints); | |
682 | 985 | |
683 | 986 | auto expected_diff_dst = diff_dst.reorder_if_differ_in(pd.diff_dst_desc()); |
684 | 987 | auto expected_src = src.reorder_if_differ_in(pd.src_desc()); |
687 | 990 | tensor::desc(pd.diff_weights_desc(), groups); |
688 | 991 | diff_weights.reinit_if_possible(expected_diff_weights_desc); |
689 | 992 | |
993 | tensor scratchpad(pd.scratchpad_desc()); | |
994 | ||
690 | 995 | if (with_diff_bias) { |
691 | 996 | diff_bias.reinit_if_possible(pd.diff_bias_desc()); |
692 | 997 | super(pd).execute(stream::default_stream(), |
693 | 998 | {{DNNL_ARG_DIFF_DST, expected_diff_dst}, |
694 | 999 | {DNNL_ARG_SRC, expected_src}, |
695 | 1000 | {DNNL_ARG_DIFF_WEIGHTS, diff_weights}, |
696 | {DNNL_ARG_DIFF_BIAS, diff_bias}}); | |
1001 | {DNNL_ARG_DIFF_BIAS, diff_bias}, | |
1002 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
697 | 1003 | } else { |
698 | 1004 | super(pd).execute(stream::default_stream(), |
699 | 1005 | {{DNNL_ARG_DIFF_DST, expected_diff_dst}, |
700 | 1006 | {DNNL_ARG_SRC, expected_src}, |
701 | {DNNL_ARG_DIFF_WEIGHTS, diff_weights}}); | |
1007 | {DNNL_ARG_DIFF_WEIGHTS, diff_weights}, | |
1008 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
702 | 1009 | } |
703 | 1010 | } |
704 | 1011 | }; |
2 | 2 | |
3 | 3 | namespace ideep { |
4 | 4 | |
5 | struct deconv_forward_params { | |
6 | dnnl::deconvolution_forward::primitive_desc pd; | |
7 | attr_t bias_attr; | |
8 | tensor input_zero_point; | |
9 | int groups; | |
10 | }; | |
11 | ||
5 | 12 | struct convolution_transpose_forward : public dnnl::deconvolution_forward { |
6 | 13 | |
7 | 14 | using super = dnnl::deconvolution_forward; |
8 | 15 | |
16 | // With bias. Zero points are passed explicitly as arguments for quantization | |
17 | // Bias is not used if it is empty. | |
18 | static void compute_v2(const tensor& src, | |
19 | const tensor& weights, // dim: {o, i[, d], h, w} | |
20 | const tensor& bias, | |
21 | const dims& dst_dims, | |
22 | tensor& dst, | |
23 | const dims& strides, | |
24 | const dims& padding_l, | |
25 | const dims& padding_r, | |
26 | const dims& dilates = {1, 1}, | |
27 | int groups = 1, | |
28 | const scale_t& src_scales = scale_t(), | |
29 | const scale_t& weights_scales = scale_t(), | |
30 | const scale_t& dst_scales = scale_t(), | |
31 | const zero_point_t& src_zero_point = zero_point_t(), | |
32 | const zero_point_t& dst_zero_point = zero_point_t(), | |
33 | const attr_t& attr = attr_t(), | |
34 | algorithm aalgorithm = algorithm::deconvolution_direct, | |
35 | prop_kind aprop_kind = prop_kind::forward, | |
36 | const lowp_kind alowp_kind = u8s8, | |
37 | const engine& aengine = engine::cpu_engine()) { | |
38 | if (bias.is_empty()) { | |
39 | compute_impl</*with_bias=*/false>( | |
40 | src, weights, bias, dst_dims, dst, strides, dilates, | |
41 | padding_l, padding_r, groups, src_scales, weights_scales, dst_scales, | |
42 | src_zero_point, dst_zero_point, attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
43 | } else { | |
44 | compute_impl</*with_bias=*/true>( | |
45 | src, weights, bias, dst_dims, dst, strides, dilates, | |
46 | padding_l, padding_r, groups, src_scales, weights_scales, dst_scales, | |
47 | src_zero_point, dst_zero_point, attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
48 | } | |
49 | } | |
50 | ||
51 | // Without bias. Zero points are passed explicitly as arguments for quantization | |
52 | static void compute_v2(const tensor& src, | |
53 | const tensor& weights, // dim: {o, i[, d], h, w} | |
54 | const dims& dst_dims, | |
55 | tensor& dst, | |
56 | const dims& strides, | |
57 | const dims& padding_l, | |
58 | const dims& padding_r, | |
59 | const dims& dilates = {1, 1}, | |
60 | int groups = 1, | |
61 | const scale_t& src_scales = scale_t(), | |
62 | const scale_t& weights_scales = scale_t(), | |
63 | const scale_t& dst_scales = scale_t(), | |
64 | const zero_point_t& src_zero_point = zero_point_t(), | |
65 | const zero_point_t& dst_zero_point = zero_point_t(), | |
66 | const attr_t& attr = attr_t(), | |
67 | algorithm aalgorithm = algorithm::deconvolution_direct, | |
68 | prop_kind aprop_kind = prop_kind::forward, | |
69 | const lowp_kind alowp_kind = u8s8, | |
70 | const engine& aengine = engine::cpu_engine()) { | |
71 | static tensor dummy_bias; | |
72 | compute_impl</*with_bias=*/false>( | |
73 | src, weights, dummy_bias, dst_dims, dst, strides, dilates, | |
74 | padding_l, padding_r, groups, src_scales, weights_scales, dst_scales, | |
75 | src_zero_point, dst_zero_point, attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
76 | } | |
77 | ||
78 | // Deprecated. With bias. Zero points are set to tensor for quantization. | |
9 | 79 | static void compute(const tensor& src, |
10 | const tensor& weights, // dim: {i, o[, d], h, w} | |
80 | const tensor& weights, // dim: {o, i[, d], h, w} | |
11 | 81 | const tensor& bias, |
12 | 82 | const dims& dst_dims, |
13 | 83 | tensor& dst, |
16 | 86 | const dims& padding_r, |
17 | 87 | const dims& dilates = {1, 1}, |
18 | 88 | int groups = 1, |
89 | const scale_t& src_scales = scale_t(), | |
90 | const scale_t& weights_scales = scale_t(), | |
91 | const scale_t& dst_scales = scale_t(), | |
19 | 92 | const attr_t& attr = attr_t(), |
20 | 93 | algorithm aalgorithm = algorithm::deconvolution_direct, |
21 | 94 | prop_kind aprop_kind = prop_kind::forward, |
95 | const lowp_kind alowp_kind = u8s8, | |
22 | 96 | const engine& aengine = engine::cpu_engine()) { |
23 | 97 | compute_impl</*with_bias=*/true>( |
24 | 98 | src, weights, bias, dst_dims, dst, strides, dilates, |
25 | padding_l, padding_r, groups, attr, aalgorithm, aprop_kind, aengine); | |
26 | } | |
27 | ||
99 | padding_l, padding_r, groups, src_scales, weights_scales, dst_scales, | |
100 | zero_point_t(), zero_point_t(), attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
101 | } | |
102 | ||
103 | // Deprecated. Without bias. Zero points are set to tensor for quantization. | |
28 | 104 | static void compute(const tensor& src, |
29 | const tensor& weights, // dim: {i, o[, d], h, w} | |
105 | const tensor& weights, // dim: {o, i[, d], h, w} | |
30 | 106 | const dims& dst_dims, |
31 | 107 | tensor& dst, |
32 | 108 | const dims& strides, |
34 | 110 | const dims& padding_r, |
35 | 111 | const dims& dilates = {1, 1}, |
36 | 112 | int groups = 1, |
113 | const scale_t& src_scales = scale_t(), | |
114 | const scale_t& weights_scales = scale_t(), | |
115 | const scale_t& dst_scales = scale_t(), | |
37 | 116 | const attr_t& attr = attr_t(), |
38 | 117 | algorithm aalgorithm = algorithm::deconvolution_direct, |
39 | 118 | prop_kind aprop_kind = prop_kind::forward, |
119 | const lowp_kind alowp_kind = u8s8, | |
40 | 120 | const engine& aengine = engine::cpu_engine()) { |
41 | 121 | static tensor dummy_bias; |
42 | 122 | compute_impl</*with_bias=*/false>( |
43 | 123 | src, weights, dummy_bias, dst_dims, dst, strides, dilates, |
44 | padding_l, padding_r, groups, attr, aalgorithm, aprop_kind, aengine); | |
124 | padding_l, padding_r, groups, src_scales, weights_scales, dst_scales, | |
125 | zero_point_t(), zero_point_t(), attr, aalgorithm, aprop_kind, alowp_kind, aengine); | |
126 | } | |
127 | ||
128 | // Bias is not used if it is empty. | |
129 | static void prepare(deconv_forward_params& param, | |
130 | const tensor& src, | |
131 | const tensor& weights, // dim: {o, i[, d], h, w} | |
132 | const tensor& bias, | |
133 | const dims& dst_dims, | |
134 | tensor& dst, | |
135 | const dims& strides, | |
136 | const dims& padding_l, | |
137 | const dims& padding_r, | |
138 | const dims& dilates = {1, 1}, | |
139 | int groups = 1, | |
140 | const scale_t& src_scales = scale_t(), | |
141 | const scale_t& weights_scales = scale_t(), | |
142 | const scale_t& dst_scales = scale_t(), | |
143 | const zero_point_t& src_zero_point = zero_point_t(), | |
144 | const zero_point_t& dst_zero_point = zero_point_t(), | |
145 | const attr_t& attr = attr_t(), | |
146 | algorithm aalgorithm = algorithm::deconvolution_direct, | |
147 | prop_kind aprop_kind = prop_kind::forward, | |
148 | const lowp_kind alowp_kind = u8s8, | |
149 | const engine& aengine = engine::cpu_engine()) { | |
150 | bool with_bias = (!bias.is_empty()); | |
151 | do_prepare(param, src, weights, bias, with_bias, dst_dims, dst, | |
152 | strides, dilates, padding_l, padding_r, groups, | |
153 | src_scales, weights_scales, dst_scales, | |
154 | src_zero_point, dst_zero_point, attr, | |
155 | aalgorithm, aprop_kind, alowp_kind, aengine); | |
156 | } | |
157 | ||
158 | // Bias is not used if it is empty. | |
159 | static void compute(const super::primitive_desc& pd, | |
160 | const super& primitive, | |
161 | const tensor& src, | |
162 | const tensor& weights, | |
163 | const tensor& expected_bias, | |
164 | tensor& dst, | |
165 | const tensor& src_zero_point, | |
166 | int groups) { | |
167 | bool with_bias = (!expected_bias.is_empty()); | |
168 | do_compute(pd, primitive, src, weights, expected_bias, | |
169 | with_bias, dst, src_zero_point, groups); | |
45 | 170 | } |
46 | 171 | |
47 | 172 | static tensor::desc expected_weights_desc( |
85 | 210 | x_dims.push_back(ic); |
86 | 211 | y_dims.push_back(1); |
87 | 212 | y_dims.push_back(oc); |
213 | auto valid_x_dim = [=](int idx) { | |
214 | return std::max((padding_l[idx] + padding_r[idx] - (1 + (kernel_size[idx] - 1) * dilates[idx])) / strides[idx] + 2, | |
215 | 2 * kernel_size[idx]); | |
216 | }; | |
88 | 217 | if (4 == src_size) { |
89 | x_dims.push_back(2 * kernel_size[0]); | |
90 | x_dims.push_back(4 * kernel_size[1]); | |
218 | x_dims.push_back(valid_x_dim(0)); | |
219 | x_dims.push_back(valid_x_dim(1)); | |
91 | 220 | } else { |
92 | x_dims.push_back(2 * kernel_size[0]); | |
93 | x_dims.push_back(4 * kernel_size[1]); | |
94 | x_dims.push_back(8 * kernel_size[2]); | |
221 | x_dims.push_back(valid_x_dim(0)); | |
222 | x_dims.push_back(valid_x_dim(1)); | |
223 | x_dims.push_back(valid_x_dim(2)); | |
95 | 224 | } |
96 | 225 | } else { |
97 | 226 | // Use the real data |
113 | 242 | |
114 | 243 | auto pd = get_primitive_desc</*with_bias=*/false>( |
115 | 244 | src_desc, weights_desc, tensor::desc(), dst_desc, strides, dilates_, |
116 | padding_l, padding_r, attr_t(), aalgorithm, aprop_kind); | |
245 | padding_l, padding_r, attr, aalgorithm, aprop_kind); | |
117 | 246 | |
118 | 247 | // embed group info into weights_desc |
119 | 248 | if (grouped) { |
142 | 271 | // For nhwc path, weight uses format_tag::any, |
143 | 272 | // while activation uses format_tag::nhwc |
144 | 273 | bool is_nhwc = src_desc.is_nhwc() || weights_desc.is_nhwc(); |
145 | auto format_tag = is_nhwc ? tag::nhwc : tag::any; | |
274 | bool is_ndhwc = src_desc.is_ndhwc() || weights_desc.is_ndhwc(); | |
275 | auto format_tag = is_nhwc ? tag::nhwc : (is_ndhwc ? tag::ndhwc : tag::any); | |
146 | 276 | auto src_desc_query = src_desc.to_format(format_tag); |
147 | 277 | auto weights_desc_query = weights_desc.to_format_any(); |
148 | 278 | auto bias_desc_query = with_bias ? bias_desc.to_format_any() : tensor::desc(); |
173 | 303 | const dims& padding_l, |
174 | 304 | const dims& padding_r, |
175 | 305 | int groups, |
306 | const scale_t& src_scales, | |
307 | const scale_t& weights_scales, | |
308 | const scale_t& dst_scales, | |
309 | const zero_point_t& src_zero_point, | |
310 | const zero_point_t& dst_zero_point, | |
176 | 311 | const attr_t& attr, |
177 | 312 | algorithm aalgorithm, |
178 | 313 | prop_kind aprop_kind, |
314 | const lowp_kind alowp_kind, | |
179 | 315 | const engine& aengine) { |
180 | ||
181 | // make weights and dilates compatible with DNNL | |
182 | auto weights_ = weights.make_grouped_weights(groups, true); | |
183 | auto dilates_ = utils::get_compatible_dilates(dilates); | |
184 | ||
185 | tensor::desc dst_desc(dst_dims, src.get_data_type()); | |
316 | scale_t dst_scales_in; | |
317 | data_type dst_data_type; | |
318 | tensor::desc src_desc, weights_desc, bias_desc, dst_desc; | |
319 | attr_t op_attr, src_attr, weights_attr, bias_attr; | |
320 | tensor weights_grouped; | |
321 | dims dil_compatible; | |
322 | ||
323 | conv_deconv_utils::prepare_parameters( | |
324 | src, weights, bias, dst_dims, dst, dilates, groups, | |
325 | src_scales, weights_scales, dst_scales, src_zero_point, dst_zero_point, | |
326 | attr, alowp_kind, with_bias, true, | |
327 | weights_grouped, dil_compatible, op_attr, src_attr, weights_attr, bias_attr, | |
328 | src_desc, weights_desc, bias_desc, dst_desc); | |
186 | 329 | |
187 | 330 | auto pd = get_primitive_desc<with_bias>( |
188 | src.get_desc(), weights_.get_desc(), bias.get_desc(), dst_desc, | |
189 | strides, dilates_, padding_l, padding_r, attr, aalgorithm, | |
331 | src_desc, weights_desc, bias_desc, dst_desc, | |
332 | strides, dil_compatible, padding_l, padding_r, op_attr, aalgorithm, | |
190 | 333 | aprop_kind, aengine); |
191 | 334 | |
335 | tensor scratchpad(pd.scratchpad_desc()); | |
192 | 336 | auto expected_src = src.reorder_if_differ_in(pd.src_desc()); |
193 | auto expected_weights = weights_.reorder_if_differ_in(pd.weights_desc()); | |
337 | auto expected_weights = weights_grouped.reorder_if_differ_in(pd.weights_desc()); | |
194 | 338 | dst.reinit_if_possible(pd.dst_desc()); |
195 | 339 | |
340 | tensor src_zero_point_m; | |
341 | conv_deconv_utils::obtain_runtime_zero_point(src, src_zero_point, DNNL_ARG_SRC, | |
342 | op_attr, aengine, src_zero_point_m); | |
343 | ||
196 | 344 | if (with_bias) { |
197 | auto expected_bias = bias.reorder_if_differ_in(pd.bias_desc()); | |
345 | auto expected_bias = bias.reorder_if_differ_in(pd.bias_desc(), bias_attr); | |
198 | 346 | super(pd).execute(stream::default_stream(), |
199 | 347 | {{DNNL_ARG_SRC, expected_src}, |
200 | 348 | {DNNL_ARG_WEIGHTS, expected_weights}, |
201 | 349 | {DNNL_ARG_BIAS, expected_bias}, |
202 | {DNNL_ARG_DST, dst}}); | |
350 | {DNNL_ARG_DST, dst}, | |
351 | {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zero_point_m}, | |
352 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
203 | 353 | } else { |
204 | 354 | super(pd).execute(stream::default_stream(), |
205 | 355 | {{DNNL_ARG_SRC, expected_src}, |
206 | 356 | {DNNL_ARG_WEIGHTS, expected_weights}, |
207 | {DNNL_ARG_DST, dst}}); | |
208 | } | |
209 | } | |
357 | {DNNL_ARG_DST, dst}, | |
358 | {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zero_point_m}, | |
359 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
360 | } | |
361 | } | |
362 | ||
363 | static void do_prepare(deconv_forward_params& param, | |
364 | const tensor& src, | |
365 | const tensor& weights, | |
366 | const tensor& bias, | |
367 | bool with_bias, | |
368 | const dims& dst_dims, | |
369 | tensor& dst, | |
370 | const dims& strides, | |
371 | const dims& dilates, | |
372 | const dims& padding_l, | |
373 | const dims& padding_r, | |
374 | int groups, | |
375 | const scale_t& src_scales, | |
376 | const scale_t& weights_scales, | |
377 | const scale_t& dst_scales, | |
378 | const zero_point_t& src_zero_point, | |
379 | const zero_point_t& dst_zero_point, | |
380 | const attr_t& attr, | |
381 | algorithm aalgorithm, | |
382 | prop_kind aprop_kind, | |
383 | const lowp_kind alowp_kind, | |
384 | const engine& aengine) { | |
385 | scale_t dst_scales_in; | |
386 | data_type dst_data_type; | |
387 | tensor::desc src_desc, weights_desc, bias_desc, dst_desc; | |
388 | attr_t op_attr, src_attr, weights_attr, bias_attr; | |
389 | tensor weights_grouped; | |
390 | dims dil_compatible; | |
391 | ||
392 | conv_deconv_utils::prepare_parameters( | |
393 | src, weights, bias, dst_dims, dst, dilates, groups, | |
394 | src_scales, weights_scales, dst_scales, src_zero_point, dst_zero_point, | |
395 | attr, alowp_kind, with_bias, true, | |
396 | weights_grouped, dil_compatible, op_attr, src_attr, weights_attr, bias_attr, | |
397 | src_desc, weights_desc, bias_desc, dst_desc); | |
398 | ||
399 | auto pd = with_bias ? | |
400 | get_primitive_desc</*with_bias=*/true>( | |
401 | src_desc, weights_desc, bias_desc, dst_desc, | |
402 | strides, dil_compatible, padding_l, padding_r, op_attr, aalgorithm, | |
403 | aprop_kind, aengine) : | |
404 | get_primitive_desc</*with_bias=*/false>( | |
405 | src_desc, weights_desc, bias_desc, dst_desc, | |
406 | strides, dil_compatible, padding_l, padding_r, op_attr, aalgorithm, | |
407 | aprop_kind, aengine); | |
408 | ||
409 | auto expected_src = src.reorder_if_differ_in(pd.src_desc()); | |
410 | auto expected_weights = weights_grouped.reorder_if_differ_in(pd.weights_desc()); | |
411 | dst.reinit_if_possible(pd.dst_desc()); | |
412 | ||
413 | tensor src_zero_point_m; | |
414 | conv_deconv_utils::obtain_runtime_zero_point(src, src_zero_point, DNNL_ARG_SRC, | |
415 | op_attr, aengine, src_zero_point_m); | |
416 | param.pd = std::move(pd); | |
417 | param.bias_attr = bias_attr; | |
418 | param.input_zero_point = std::move(src_zero_point_m); | |
419 | param.groups = groups; | |
420 | } | |
421 | ||
422 | static void do_compute(const super::primitive_desc& pd, | |
423 | const super& primitive, | |
424 | const tensor& src, | |
425 | const tensor& weights, | |
426 | const tensor& expected_bias, | |
427 | bool with_bias, | |
428 | tensor& dst, | |
429 | const tensor& src_zero_point, | |
430 | int groups) { | |
431 | tensor scratchpad(pd.scratchpad_desc()); | |
432 | auto expected_weights = weights.make_grouped_weights(groups); | |
433 | if (with_bias) { | |
434 | primitive.execute(stream::default_stream(), | |
435 | {{DNNL_ARG_SRC, src}, | |
436 | {DNNL_ARG_WEIGHTS, expected_weights}, | |
437 | {DNNL_ARG_BIAS, expected_bias}, | |
438 | {DNNL_ARG_DST, dst}, | |
439 | {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zero_point}, | |
440 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
441 | } else { | |
442 | primitive.execute(stream::default_stream(), | |
443 | {{DNNL_ARG_SRC, src}, | |
444 | {DNNL_ARG_WEIGHTS, expected_weights}, | |
445 | {DNNL_ARG_DST, dst}, | |
446 | {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zero_point}, | |
447 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
448 | } | |
449 | } | |
450 | ||
210 | 451 | }; |
211 | 452 | |
212 | 453 | struct convolution_transpose_backward_data |
230 | 471 | auto dilates_ = utils::get_compatible_dilates(dilates); |
231 | 472 | |
232 | 473 | bool is_nhwc = diff_dst.get_desc().is_nhwc(); |
233 | auto format_tag = is_nhwc ? tag::nhwc : tag::any; | |
474 | bool is_ndhwc = diff_dst.get_desc().is_ndhwc(); | |
475 | auto format_tag = is_nhwc ? tag::nhwc : (is_ndhwc ? tag::ndhwc : tag::any); | |
234 | 476 | auto diff_dst_desc = diff_dst.get_desc().to_format(format_tag); |
235 | 477 | auto weights_desc = weights_.get_desc().to_format_any(); |
236 | 478 | |
241 | 483 | diff_src_desc, weights_desc, tensor::desc(), diff_dst_desc, strides, |
242 | 484 | dilates_, padding_l, padding_r); |
243 | 485 | |
486 | auto op_attr = dnnl::primitive_attr(); | |
487 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
488 | ||
244 | 489 | auto pd = primitive_desc( |
245 | 490 | {aalgorithm, diff_src_desc, weights_desc, diff_dst_desc, strides, |
246 | dilates_, padding_l, padding_r}, aengine, forward_hints); | |
491 | dilates_, padding_l, padding_r}, op_attr, aengine, forward_hints); | |
247 | 492 | |
248 | 493 | auto expected_diff_dst = diff_dst.reorder_if_differ_in(pd.diff_dst_desc()); |
249 | 494 | auto expected_weights = weights_.reorder_if_differ_in(pd.weights_desc()); |
250 | 495 | diff_src.reinit_if_possible(pd.diff_src_desc()); |
496 | tensor scratchpad(pd.scratchpad_desc()); | |
251 | 497 | |
252 | 498 | super(pd).execute(stream::default_stream(), |
253 | 499 | {{DNNL_ARG_DIFF_DST, expected_diff_dst}, |
254 | 500 | {DNNL_ARG_WEIGHTS, expected_weights}, |
255 | {DNNL_ARG_DIFF_SRC, diff_src}}); | |
501 | {DNNL_ARG_DIFF_SRC, diff_src}, | |
502 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
256 | 503 | } |
257 | 504 | }; |
258 | 505 | |
325 | 572 | } |
326 | 573 | |
327 | 574 | bool is_nhwc = diff_dst.get_desc().is_nhwc(); |
328 | auto format_tag = is_nhwc ? tag::nhwc : tag::any; | |
575 | bool is_ndhwc = diff_dst.get_desc().is_ndhwc(); | |
576 | auto format_tag = is_nhwc ? tag::nhwc : (is_ndhwc ? tag::ndhwc : tag::any); | |
329 | 577 | auto diff_dst_desc = diff_dst.get_desc().to_format(format_tag); |
330 | 578 | auto src_desc = src.get_desc().to_format(format_tag); |
331 | 579 | |
340 | 588 | dilates_, padding_l, padding_r, attr_t(), aalgorithm, |
341 | 589 | prop_kind::forward, aengine); |
342 | 590 | |
591 | auto op_attr = dnnl::primitive_attr(); | |
592 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
593 | ||
343 | 594 | auto pd = with_diff_bias |
344 | 595 | ? primitive_desc({aalgorithm, src_desc, diff_weights_desc, |
345 | 596 | diff_bias_desc, diff_dst_desc, strides, dilates_, |
346 | padding_l, padding_r}, aengine, forward_hints) | |
597 | padding_l, padding_r}, op_attr, aengine, forward_hints) | |
347 | 598 | : primitive_desc({aalgorithm, src_desc, diff_weights_desc, |
348 | 599 | diff_dst_desc, strides, dilates_, |
349 | padding_l, padding_r}, aengine, forward_hints); | |
600 | padding_l, padding_r}, op_attr, aengine, forward_hints); | |
350 | 601 | |
351 | 602 | auto expected_diff_dst = diff_dst.reorder_if_differ_in(pd.diff_dst_desc()); |
352 | 603 | auto expected_src = src.reorder_if_differ_in(pd.src_desc()); |
354 | 605 | auto expected_diff_weights_desc = |
355 | 606 | tensor::desc(pd.diff_weights_desc(), groups); |
356 | 607 | diff_weights.reinit_if_possible(expected_diff_weights_desc); |
608 | tensor scratchpad(pd.scratchpad_desc()); | |
357 | 609 | |
358 | 610 | if (with_diff_bias) { |
359 | 611 | diff_bias.reinit_if_possible(pd.diff_bias_desc()); |
361 | 613 | {{DNNL_ARG_DIFF_DST, expected_diff_dst}, |
362 | 614 | {DNNL_ARG_SRC, expected_src}, |
363 | 615 | {DNNL_ARG_DIFF_WEIGHTS, diff_weights}, |
364 | {DNNL_ARG_DIFF_BIAS, diff_bias}}); | |
616 | {DNNL_ARG_DIFF_BIAS, diff_bias}, | |
617 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
365 | 618 | } else { |
366 | 619 | super(pd).execute(stream::default_stream(), |
367 | 620 | {{DNNL_ARG_DIFF_DST, expected_diff_dst}, |
368 | 621 | {DNNL_ARG_SRC, expected_src}, |
369 | {DNNL_ARG_DIFF_WEIGHTS, diff_weights}}); | |
622 | {DNNL_ARG_DIFF_WEIGHTS, diff_weights}, | |
623 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
370 | 624 | } |
371 | 625 | |
372 | 626 | // recover output dims to align with pytorch |
21 | 21 | } |
22 | 22 | auto src_desc = src_in.get_desc(); |
23 | 23 | |
24 | auto op_attr = dnnl::primitive_attr(); | |
25 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
26 | ||
24 | 27 | auto pd = primitive_desc( |
25 | {aprop_kind, aalgorithm, src_desc, alpha, beta}, aengine); | |
28 | {aprop_kind, aalgorithm, src_desc, alpha, beta}, op_attr, aengine); | |
26 | 29 | |
27 | 30 | dst.reinit_if_possible(pd.dst_desc()); |
28 | 31 | if (src_in.has_scale()) { |
29 | 32 | dst.set_scale(src_in.get_scale()); |
30 | 33 | } |
34 | tensor scratchpad(pd.scratchpad_desc()); | |
31 | 35 | |
32 | 36 | super(pd).execute(stream::default_stream(), |
33 | {{DNNL_ARG_SRC, src_in}, {DNNL_ARG_DST, dst}}); | |
37 | {{DNNL_ARG_SRC, src_in}, | |
38 | {DNNL_ARG_DST, dst}, | |
39 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
34 | 40 | |
35 | 41 | // xpz: ??? |
36 | 42 | if (dst.has_scale() && aalgorithm == algorithm::eltwise_relu && |
55 | 61 | |
56 | 62 | auto forward_hints = eltwise_forward::primitive_desc( |
57 | 63 | {prop_kind::forward, aalgorithm, src_desc, alpha, beta}, aengine); |
64 | ||
65 | auto op_attr = dnnl::primitive_attr(); | |
66 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
67 | ||
58 | 68 | auto pd = |
59 | 69 | primitive_desc({aalgorithm, forward_hints.dst_desc(), src_desc, alpha, beta}, |
60 | aengine, forward_hints); | |
70 | op_attr, aengine, forward_hints); | |
61 | 71 | |
62 | 72 | auto expected_diff_dst = diff_dst.reorder_if_differ_in(pd.diff_dst_desc()); |
63 | 73 | auto expected_src = src.reorder_if_differ_in(pd.src_desc()); |
72 | 82 | algorithm::eltwise_exp_use_dst_for_bwd); |
73 | 83 | auto src_dst_arg = use_dst ? DNNL_ARG_DST : DNNL_ARG_SRC; |
74 | 84 | |
85 | tensor scratchpad(pd.scratchpad_desc()); | |
75 | 86 | super(pd).execute(stream::default_stream(), |
76 | 87 | {{DNNL_ARG_DIFF_DST, expected_diff_dst}, |
77 | 88 | {src_dst_arg, expected_src}, |
78 | {DNNL_ARG_DIFF_SRC, diff_src}}); | |
89 | {DNNL_ARG_DIFF_SRC, diff_src}, | |
90 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
79 | 91 | } |
80 | 92 | }; |
81 | 93 | } // namespace ideep |
2 | 2 | |
3 | 3 | namespace ideep { |
4 | 4 | |
5 | struct inner_product_forward : public dnnl::inner_product_forward { | |
6 | ||
5 | struct inner_product_forward | |
6 | : public dnnl::inner_product_forward, | |
7 | utils::computation_cache<dnnl::inner_product_forward::primitive_desc> { | |
7 | 8 | using super = dnnl::inner_product_forward; |
8 | 9 | |
9 | 10 | static void compute(const tensor& src, |
61 | 62 | primitive_desc({aprop_kind, src_desc, weights_desc, dst_desc}, aengine); |
62 | 63 | return pd.weights_desc(); |
63 | 64 | } |
65 | ||
66 | static primitive_desc get_primitive_desc( | |
67 | const tensor::desc& src_desc, | |
68 | const tensor::desc& weights_desc, | |
69 | const tensor::desc& dst_desc, | |
70 | const tensor::desc& bias_desc = tensor::desc(), | |
71 | const bool with_bias = false, | |
72 | const attr_t& attr = attr_t(), | |
73 | const prop_kind aprop_kind = prop_kind::forward, | |
74 | const engine& aengine = engine::cpu_engine()) { | |
75 | auto key = utils::create_key( | |
76 | aprop_kind, | |
77 | src_desc, | |
78 | weights_desc, | |
79 | bias_desc, | |
80 | dst_desc, | |
81 | attr, | |
82 | with_bias, | |
83 | omp_get_max_threads()); | |
84 | return fetch_or_create(key, [&]() { | |
85 | if (with_bias) { | |
86 | return primitive_desc( | |
87 | {aprop_kind, src_desc, weights_desc, bias_desc, dst_desc}, | |
88 | attr, | |
89 | aengine); | |
90 | } else { | |
91 | return primitive_desc( | |
92 | {aprop_kind, src_desc, weights_desc, dst_desc}, attr, aengine); | |
93 | } | |
94 | }); | |
95 | }; | |
64 | 96 | |
65 | 97 | private: |
66 | 98 | template <bool with_bias> |
135 | 167 | } |
136 | 168 | |
137 | 169 | // determine dst data type |
138 | if (dst_scales.empty() || dst_scales == IDEEP_DEF_SCALE) { | |
170 | if (dst.get_data_type() != data_type::undef) { | |
171 | dst_data_type = dst.get_data_type(); | |
172 | } else if (dst_scales.empty() || dst_scales == IDEEP_DEF_SCALE) { | |
139 | 173 | dst_data_type = data_type::f32; |
140 | 174 | } else if (attr.non_negitive_output()) { |
141 | 175 | dst_data_type = data_type::u8; |
163 | 197 | } |
164 | 198 | } else { |
165 | 199 | op_attr = attr; |
166 | src_desc = {src.get_dims(), data_type::f32, format_tag::any}; | |
167 | 200 | if (src.has_scale()) { |
168 | 201 | auto src_scale = src.get_scale(); |
169 | 202 | src_scale[0] = 1.f / src_scale[0]; |
177 | 210 | // align weights data type with src |
178 | 211 | dst_data_type = src.get_data_type() == data_type::bf16 ? data_type::bf16 |
179 | 212 | : data_type::f32; |
180 | src_desc = src.get_desc().to_type(dst_data_type).to_format_any(); | |
181 | weights_desc = weights.get_desc().to_type(dst_data_type).to_format_any(); | |
213 | src_desc = {src.get_dims(), dst_data_type, format_tag::any}; | |
214 | weights_desc = {weights.get_dims(), dst_data_type, format_tag::any}; | |
182 | 215 | if (with_bias) { |
183 | 216 | IDEEP_ENFORCE(utils::one_of(bias.get_data_type(), |
184 | 217 | data_type::f32, data_type::bf16), |
188 | 221 | } |
189 | 222 | |
190 | 223 | tensor::desc dst_desc(dst_dims, dst_data_type, format_tag::any); |
191 | auto pd = with_bias | |
192 | ? primitive_desc({aprop_kind, src_desc, weights_desc, bias_desc, | |
193 | dst_desc}, op_attr, aengine) | |
194 | : primitive_desc({aprop_kind, src_desc, weights_desc, dst_desc}, | |
195 | op_attr, aengine); | |
196 | ||
197 | // reorder src, weight, dst if needed | |
224 | ||
225 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
226 | ||
227 | auto pd = get_primitive_desc( | |
228 | src_desc, | |
229 | weights_desc, | |
230 | dst_desc, | |
231 | bias_desc, | |
232 | with_bias, | |
233 | op_attr, | |
234 | aprop_kind); | |
235 | ||
198 | 236 | auto expected_src = src.reorder_if_differ_in(pd.src_desc(), src_attr); |
199 | 237 | auto expected_weights = weights.reorder_if_differ_in(pd.weights_desc(), weights_attr); |
200 | ||
201 | // [ Note output buffer] | |
202 | // In this case, dst is an empty ideep tensor, can be re-init | |
203 | // If dst is not empty, ideep must write result to dst's memory and it is caller's duty to | |
204 | // make sure dst is big enough to hold the result | |
205 | if (dst.is_empty()) | |
206 | dst.init(pd.dst_desc()); | |
207 | auto expected_dst = dst.reorder_if_differ_in(pd.dst_desc()); | |
208 | if (!dst_scales.empty() && utils::one_of(dst.get_data_type(), data_type::u8, data_type::s8)) { | |
238 | ||
239 | tensor expected_dst; | |
240 | if (dst.is_empty() || dst.get_desc() != pd.dst_desc()){ | |
241 | // If dst buffer are not given by user or user given dst buffer are not under expected format | |
242 | // We need init a new one. "dst.get_desc() != pd.dst_desc()" conditional is setting for | |
243 | // caffe2 caller, it might given a non-empty but uncorrect dst (maybe the size is uncorrect) | |
244 | expected_dst.init(pd.dst_desc()); | |
245 | if (!dst.is_empty() && op_attr.has_op_kind(kind::sum)) { | |
246 | // We need copy the content of given buffer if ip is fused with sum | |
247 | expected_dst.feed_from(dst); | |
248 | } | |
249 | } else { | |
250 | // The format of given dst buffer is expected | |
251 | expected_dst = dst; | |
252 | } | |
253 | ||
254 | if (!dst_scales.empty() && utils::one_of(dst.get_data_type(), data_type::u8, data_type::s8)) { | |
209 | 255 | expected_dst.set_scale(dst_scales_in); |
210 | 256 | } |
211 | 257 | |
258 | tensor scratchpad(pd.scratchpad_desc()); | |
259 | ||
212 | 260 | if (with_bias){ |
213 | // reorder bias if needed | |
214 | 261 | auto expected_bias = bias.reorder_if_differ_in(pd.bias_desc(), bias_attr); |
215 | 262 | super(pd).execute(stream::default_stream(), |
216 | 263 | {{DNNL_ARG_SRC, expected_src}, |
217 | 264 | {DNNL_ARG_WEIGHTS, expected_weights}, |
218 | 265 | {DNNL_ARG_BIAS, expected_bias}, |
219 | {DNNL_ARG_DST, expected_dst}}); | |
266 | {DNNL_ARG_DST, expected_dst}, | |
267 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
220 | 268 | } else { |
221 | 269 | super(pd).execute(stream::default_stream(), |
222 | 270 | {{DNNL_ARG_SRC, expected_src}, |
223 | 271 | {DNNL_ARG_WEIGHTS, expected_weights}, |
224 | {DNNL_ARG_DST, expected_dst}}); | |
272 | {DNNL_ARG_DST, expected_dst}, | |
273 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
225 | 274 | } |
226 | 275 | |
227 | 276 | if (attr.non_negitive_output() && expected_dst.get_data_type() == data_type::s8) { |
228 | 277 | expected_dst.to_type(data_type::u8); |
229 | 278 | } |
230 | 279 | // reorder back to dst's buffer if needed |
231 | expected_dst.reorder_to_if_differ_from(dst); | |
280 | if (dst.is_empty() || | |
281 | // when dst is empty, expect return buffer allocate by ideep | |
282 | dst.get_desc() == expected_dst.get_desc() || | |
283 | // dst and expected_dst is the same under this case | |
284 | !dst.get_desc().has_same_shape_as(expected_dst.get_desc())){ | |
285 | // for caffe2 caller, get an uncorrect size dst from caller, can return buffer allocate by ideep | |
286 | dst = expected_dst; | |
287 | } else { | |
288 | dst.feed_from(expected_dst); | |
289 | } | |
232 | 290 | } |
233 | 291 | }; |
234 | 292 | |
257 | 315 | } |
258 | 316 | |
259 | 317 | auto diff_dst_desc = diff_dst.get_desc().to_format_any(); |
260 | auto weights_desc = weights_.get_desc().to_format_any(); | |
318 | auto weights_desc = weights_.get_desc(); | |
261 | 319 | auto diff_src_desc = |
262 | 320 | tensor::desc(diff_src_dims, diff_dst.get_data_type(), tag::any); |
263 | 321 | |
264 | auto forward_hints = | |
265 | inner_product_forward::primitive_desc( | |
266 | {prop_kind::forward, diff_src_desc, weights_desc, diff_dst_desc}, | |
267 | aengine); | |
322 | auto forward_hints = inner_product_forward::get_primitive_desc( | |
323 | diff_src_desc, weights_desc, diff_dst_desc); | |
324 | ||
325 | auto op_attr = dnnl::primitive_attr(); | |
326 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
268 | 327 | |
269 | 328 | auto pd = primitive_desc( |
270 | {diff_src_desc, weights_desc, diff_dst_desc}, aengine, forward_hints); | |
271 | ||
272 | // reorder diff_dst(grad_y), weights, diff_src(grad_x) if needed | |
329 | {diff_src_desc, weights_desc, diff_dst_desc}, op_attr, aengine, forward_hints); | |
330 | ||
273 | 331 | auto expected_diff_dst = diff_dst.reorder_if_differ_in(pd.diff_dst_desc()); |
274 | 332 | auto expected_weights = weights_.reorder_if_differ_in(pd.weights_desc()); |
275 | // see [Notes output buffer] | |
276 | if (diff_src.is_empty()) | |
277 | diff_src.init(pd.diff_src_desc()); | |
278 | auto expected_diff_src = diff_src.reorder_if_differ_in(pd.diff_src_desc()); | |
333 | tensor expected_diff_src; | |
334 | if (diff_src.is_empty() || diff_src.get_desc() != pd.diff_src_desc()){ | |
335 | // If diff_src buffer are not given by user or user given diff_src buffer are not under expected format | |
336 | // We need init a new one | |
337 | expected_diff_src.init(pd.diff_src_desc()); | |
338 | } else { | |
339 | // The format of given diff_src buffer is expected | |
340 | expected_diff_src = diff_src; | |
341 | } | |
342 | ||
343 | tensor scratchpad(pd.scratchpad_desc()); | |
279 | 344 | |
280 | 345 | super(pd).execute(stream::default_stream(), |
281 | 346 | {{DNNL_ARG_DIFF_DST, expected_diff_dst}, |
282 | 347 | {DNNL_ARG_WEIGHTS, expected_weights}, |
283 | {DNNL_ARG_DIFF_SRC, expected_diff_src}}); | |
284 | expected_diff_src.reorder_to_if_differ_from(diff_src); | |
348 | {DNNL_ARG_DIFF_SRC, expected_diff_src}, | |
349 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
350 | // reorder back to diff_src's buffer if needed | |
351 | if (diff_src.is_empty() || | |
352 | diff_src.get_desc() == expected_diff_src.get_desc() || | |
353 | !diff_src.get_desc().has_same_shape_as(expected_diff_src.get_desc())){ | |
354 | diff_src = expected_diff_src; | |
355 | } else { | |
356 | diff_src.feed_from(expected_diff_src); | |
357 | } | |
285 | 358 | } |
286 | 359 | }; |
287 | 360 | |
337 | 410 | if (diff_weight_type_in != diff_dst_type) { |
338 | 411 | weights_desc = weights_desc.to_type(diff_dst_type); |
339 | 412 | } |
340 | auto forward_hints = with_diff_bias | |
341 | ? inner_product_forward::primitive_desc({prop_kind::forward, src_desc, | |
342 | weights_desc, diff_bias_desc, diff_dst_desc}, aengine) | |
343 | : inner_product_forward::primitive_desc({prop_kind::forward, src_desc, | |
344 | weights_desc, diff_dst_desc}, aengine); | |
413 | auto forward_hints = inner_product_forward::get_primitive_desc( | |
414 | src_desc, weights_desc, diff_dst_desc, diff_bias_desc, with_diff_bias); | |
415 | ||
416 | auto op_attr = dnnl::primitive_attr(); | |
417 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
418 | ||
345 | 419 | auto pd = with_diff_bias |
346 | 420 | ? primitive_desc({src_desc, diff_weights_desc, diff_bias_desc, |
347 | diff_dst_desc}, aengine, forward_hints) | |
421 | diff_dst_desc}, op_attr, aengine, forward_hints) | |
348 | 422 | : primitive_desc({src_desc, diff_weights_desc, diff_dst_desc}, |
349 | aengine, forward_hints); | |
350 | ||
351 | // reorder diff_dst(grad_y), diff_weights(grad_w), src if needed | |
423 | op_attr, aengine, forward_hints); | |
424 | ||
352 | 425 | auto expected_diff_dst = diff_dst.reorder_if_differ_in(pd.diff_dst_desc()); |
353 | 426 | auto expected_src = src.reorder_if_differ_in(pd.src_desc()); |
354 | if (diff_weights.is_empty()) | |
355 | diff_weights.init(pd.diff_weights_desc()); | |
356 | auto expected_diff_weights = diff_weights.reorder_if_differ_in(pd.diff_weights_desc()); | |
427 | tensor expected_diff_weights; | |
428 | if (diff_weights.is_empty() || diff_weights.get_desc() != pd.diff_weights_desc()){ | |
429 | // If diff_weights buffer are not given by user or user given diff_weights buffer are not under expected format | |
430 | // We need init a new one | |
431 | expected_diff_weights.init(pd.diff_weights_desc()); | |
432 | } else { | |
433 | // The format of given diff_weights buffer is expected | |
434 | expected_diff_weights = diff_weights; | |
435 | } | |
436 | ||
437 | tensor scratchpad(pd.scratchpad_desc()); | |
357 | 438 | |
358 | 439 | exec_args args {{DNNL_ARG_DIFF_DST, expected_diff_dst}, |
359 | 440 | {DNNL_ARG_SRC, expected_src}, |
360 | {DNNL_ARG_DIFF_WEIGHTS ,expected_diff_weights}}; | |
361 | ||
362 | ideep::tensor expected_diff_bias; | |
441 | {DNNL_ARG_DIFF_WEIGHTS ,expected_diff_weights}, | |
442 | {DNNL_ARG_SCRATCHPAD, scratchpad}}; | |
443 | ||
363 | 444 | if (with_diff_bias) { |
364 | // reorder diff_bias(grad_b) | |
365 | if (diff_bias.is_empty()) | |
366 | diff_bias.init(pd.diff_bias_desc()); | |
367 | expected_diff_bias = diff_bias.reorder_if_differ_in(pd.diff_bias_desc()); | |
368 | args.insert({DNNL_ARG_DIFF_BIAS, expected_diff_bias}); | |
445 | diff_bias.reinit_if_possible(pd.diff_bias_desc()); | |
446 | args.insert({DNNL_ARG_DIFF_BIAS, diff_bias}); | |
369 | 447 | } |
370 | 448 | |
371 | 449 | super(pd).execute(stream::default_stream(), args); |
372 | expected_diff_weights.reorder_to_if_differ_from(diff_weights); | |
373 | expected_diff_bias.reorder_to_if_differ_from(diff_bias); | |
450 | // reorder back to diff_weights's buffer if needed | |
451 | if (diff_weights.is_empty() || | |
452 | diff_weights.get_desc() == expected_diff_weights.get_desc() || | |
453 | !diff_weights.get_desc().has_same_shape_as(expected_diff_weights.get_desc())){ | |
454 | diff_weights = expected_diff_weights; | |
455 | } else { | |
456 | diff_weights.feed_from(expected_diff_weights); | |
457 | } | |
374 | 458 | } |
375 | 459 | }; |
376 | 460 |
16 | 16 | const engine& aengine = engine::cpu_engine()) { |
17 | 17 | auto flags = batch_normalization_flag::use_scale_shift; |
18 | 18 | auto src_desc = src.get_desc(); |
19 | auto op_attr = dnnl::primitive_attr(); | |
20 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
19 | 21 | auto pd = primitive_desc( |
20 | {prop_kind::forward_training, src_desc, epsilon, flags}, aengine); | |
22 | {prop_kind::forward_training, src_desc, epsilon, flags}, | |
23 | op_attr, | |
24 | aengine); | |
21 | 25 | |
22 | 26 | tensor scale_shift {pd.weights_desc()}; |
23 | 27 | auto* scale_shift_buf = static_cast<char *>(scale_shift.get_data_handle()); |
28 | 32 | mean.reinit_if_possible(pd.mean_desc()); |
29 | 33 | variance.reinit_if_possible(pd.variance_desc()); |
30 | 34 | dst.reinit_if_possible(pd.dst_desc()); |
35 | tensor scratchpad(pd.scratchpad_desc()); | |
31 | 36 | |
32 | 37 | super(pd).execute(stream::default_stream(), |
33 | 38 | {{DNNL_ARG_SRC, expected_src}, |
34 | 39 | {DNNL_ARG_SCALE_SHIFT, scale_shift}, |
35 | 40 | {DNNL_ARG_MEAN, mean}, |
36 | 41 | {DNNL_ARG_VARIANCE, variance}, |
37 | {DNNL_ARG_DST, dst}}); | |
42 | {DNNL_ARG_DST, dst}, | |
43 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
38 | 44 | } |
39 | 45 | }; |
40 | 46 |
18 | 18 | |
19 | 19 | // workaround: use src.get_desc() once issue intel/mkl-dnn#588 is resolved |
20 | 20 | auto src_desc = src._get_unblocked_desc_if_4c_blocked(); |
21 | ||
22 | auto op_attr = dnnl::primitive_attr(); | |
23 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
24 | ||
21 | 25 | // auto src_desc = src.get_desc(); |
22 | 26 | auto pd = primitive_desc( |
23 | 27 | {aprop_kind, aalgorithm, src_desc, local_size, alpha, beta, k}, |
28 | op_attr, | |
24 | 29 | aengine); |
25 | 30 | |
26 | 31 | auto expected_src = src.reorder_if_differ_in(pd.src_desc()); |
27 | 32 | dst.reinit_if_possible(pd.dst_desc()); |
33 | tensor scratchpad(pd.scratchpad_desc()); | |
28 | 34 | |
29 | exec_args args {{DNNL_ARG_SRC, expected_src}, {DNNL_ARG_DST, dst}}; | |
35 | exec_args args { | |
36 | {DNNL_ARG_SRC, expected_src}, | |
37 | {DNNL_ARG_DST, dst}, | |
38 | {DNNL_ARG_SCRATCHPAD, scratchpad}}; | |
30 | 39 | |
31 | 40 | bool with_workspace = aprop_kind == prop_kind::forward_training; |
32 | 41 | if (with_workspace) { |
61 | 70 | src_desc, local_size, alpha, beta, k}, |
62 | 71 | aengine); |
63 | 72 | |
73 | auto op_attr = dnnl::primitive_attr(); | |
74 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
75 | ||
64 | 76 | auto pd = primitive_desc( |
65 | 77 | {aalgorithm, src_desc, diff_dst.get_desc(), local_size, alpha, beta, k}, |
66 | aengine, forward_hints); | |
78 | op_attr, aengine, forward_hints); | |
67 | 79 | |
68 | 80 | auto expected_diff_dst = diff_dst.reorder_if_differ_in(pd.diff_dst_desc()); |
69 | 81 | diff_src.reinit_if_possible(pd.diff_src_desc()); |
82 | tensor scratchpad(pd.scratchpad_desc()); | |
70 | 83 | |
71 | 84 | exec_args args {{DNNL_ARG_SRC, src}, |
72 | 85 | {DNNL_ARG_DIFF_DST, expected_diff_dst}, |
73 | {DNNL_ARG_DIFF_SRC, diff_src}}; | |
86 | {DNNL_ARG_DIFF_SRC, diff_src}, | |
87 | {DNNL_ARG_SCRATCHPAD, scratchpad}}; | |
74 | 88 | |
75 | 89 | if (dst.has_workspace()) { |
76 | 90 | args.insert({DNNL_ARG_WORKSPACE, dst.get_workspace()}); |
2 | 2 | |
3 | 3 | namespace ideep { |
4 | 4 | |
5 | struct matmul_forward : public dnnl::matmul { | |
6 | ||
5 | struct matmul_forward_params { | |
6 | dnnl::matmul::primitive_desc pd; | |
7 | dnnl::matmul primitive; | |
8 | attr_t op_attr; | |
9 | // bias_attr contains requantization scales for bias | |
10 | attr_t bias_attr; | |
11 | scale_t dst_scales; | |
12 | attr_t src_attr; | |
13 | attr_t weights_attr; | |
14 | tensor::desc src_desc; | |
15 | tensor::desc weights_desc; | |
16 | tensor scales_m; | |
17 | tensor src_zero_point_m; | |
18 | tensor wei_zero_point_m; | |
19 | tensor dst_zero_point_m; | |
20 | }; | |
21 | ||
22 | struct matmul_forward : public dnnl::matmul, | |
23 | utils::computation_cache<dnnl::matmul::primitive_desc> { | |
7 | 24 | using super = dnnl::matmul; |
8 | 25 | |
26 | // With bias. Zero points are passed explicitly as arguments for quantization | |
27 | // Bias is not used if it is empty. | |
28 | static void compute_v2( | |
29 | const tensor& src, | |
30 | const tensor& weights, | |
31 | const tensor& bias, | |
32 | tensor& dst, | |
33 | const float dst_coeff = 1.0f, | |
34 | const float sum_coeff = 1.0f, | |
35 | const scale_t& src_scales = scale_t(), | |
36 | const scale_t& weights_scales = scale_t(), | |
37 | const scale_t& dst_scales = scale_t(), | |
38 | const zero_point_t& src_zero_points = zero_point_t(), | |
39 | const zero_point_t& dst_zero_points = zero_point_t(), | |
40 | const attr_t& attr = attr_t(), | |
41 | const data_type dst_type = data_type::undef, | |
42 | const lowp_kind alowp_kind = u8s8, | |
43 | const engine& aengine = engine::cpu_engine()) { | |
44 | if (bias.is_empty()) { | |
45 | compute_impl</*with_bias=*/false>(src, weights, bias, dst, dst_coeff, sum_coeff, | |
46 | src_scales, weights_scales, dst_scales, | |
47 | src_zero_points, dst_zero_points, | |
48 | attr, dst_type, alowp_kind, aengine); | |
49 | } else { | |
50 | compute_impl</*with_bias=*/true>(src, weights, bias, dst, dst_coeff, sum_coeff, | |
51 | src_scales, weights_scales, dst_scales, | |
52 | src_zero_points, dst_zero_points, | |
53 | attr, dst_type, alowp_kind, aengine); | |
54 | } | |
55 | } | |
56 | ||
57 | // Without bias. Zero points are passed explicitly as arguments for quantization | |
58 | static void compute_v2( | |
59 | const tensor& src, | |
60 | const tensor& weights, | |
61 | tensor& dst, | |
62 | const float dst_coeff = 1.0f, | |
63 | const float sum_coeff = 1.0f, | |
64 | const scale_t& src_scales = scale_t(), | |
65 | const scale_t& weights_scales = scale_t(), | |
66 | const scale_t& dst_scales = scale_t(), | |
67 | const zero_point_t& src_zero_points = zero_point_t(), | |
68 | const zero_point_t& dst_zero_points = zero_point_t(), | |
69 | const attr_t& attr = attr_t(), | |
70 | const data_type dst_type = data_type::undef, | |
71 | const lowp_kind alowp_kind = u8s8, | |
72 | const engine& aengine = engine::cpu_engine()) { | |
73 | static tensor dummy_bias; | |
74 | compute_impl</*with_bias=*/false>(src, weights, dummy_bias, dst, dst_coeff, | |
75 | sum_coeff, src_scales, weights_scales, dst_scales, | |
76 | src_zero_points, dst_zero_points, | |
77 | attr, dst_type, alowp_kind, aengine); | |
78 | } | |
79 | ||
80 | // Deprecated. With bias. Set zero points to tensors for quantization. | |
9 | 81 | static void compute( |
10 | 82 | const tensor& src, |
11 | 83 | const tensor& weights, |
22 | 94 | const engine& aengine = engine::cpu_engine()) { |
23 | 95 | compute_impl</*with_bias=*/true>(src, weights, bias, dst, dst_coeff, sum_coeff, |
24 | 96 | src_scales, weights_scales, dst_scales, |
97 | zero_point_t(), zero_point_t(), | |
25 | 98 | attr, dst_type, alowp_kind, aengine); |
26 | 99 | } |
27 | 100 | |
101 | // Deprecated. Without bias. Set zero points to tensors for quantization. | |
28 | 102 | static void compute( |
29 | 103 | const tensor& src, |
30 | 104 | const tensor& weights, |
40 | 114 | const engine& aengine = engine::cpu_engine()) { |
41 | 115 | static tensor dummy_bias; |
42 | 116 | compute_impl</*with_bias=*/false>(src, weights, dummy_bias, dst, dst_coeff, |
43 | sum_coeff, src_scales, weights_scales, | |
44 | dst_scales, attr, dst_type, alowp_kind, aengine); | |
117 | sum_coeff, src_scales, weights_scales, dst_scales, | |
118 | zero_point_t(), zero_point_t(), | |
119 | attr, dst_type, alowp_kind, aengine); | |
120 | } | |
121 | ||
122 | // Bias is not used if it is empty. | |
123 | template <bool is_dynamic> | |
124 | static void prepare(matmul_forward_params& param, | |
125 | const tensor& src, | |
126 | const tensor& weights, | |
127 | const tensor& bias, | |
128 | tensor& dst, | |
129 | const float dst_coeff = 1.0f, | |
130 | const float sum_coeff = 1.0f, | |
131 | const scale_t& src_scales = scale_t(), | |
132 | const scale_t& weights_scales = scale_t(), | |
133 | const scale_t& dst_scales = scale_t(), | |
134 | const zero_point_t& src_zero_points = zero_point_t(), | |
135 | const zero_point_t& dst_zero_points = zero_point_t(), | |
136 | const attr_t& attr = attr_t(), | |
137 | const data_type dst_type = data_type::undef, | |
138 | const lowp_kind alowp_kind = u8s8, | |
139 | const engine& aengine = engine::cpu_engine()) { | |
140 | auto prepare_type = is_dynamic ? type_prepare_dynamic : type_prepare_static; | |
141 | bool with_bias = (!bias.is_empty()); | |
142 | do_prepare(param, prepare_type, src, weights, bias, with_bias, dst, dst_coeff, sum_coeff, | |
143 | src_scales, weights_scales, dst_scales, src_zero_points, dst_zero_points, | |
144 | attr, dst_type, alowp_kind, aengine); | |
145 | } | |
146 | ||
147 | // Bias is not used if it is empty. | |
148 | static void compute(const matmul_forward_params& param, | |
149 | const tensor& src, | |
150 | const tensor& weights, | |
151 | const tensor& bias, | |
152 | tensor& dst) { | |
153 | bool with_bias = (!bias.is_empty()); | |
154 | do_compute(param, type_compute_static, src, weights, bias, with_bias, dst); | |
155 | } | |
156 | ||
157 | // Bias is not used if it is empty. | |
158 | static void compute_dynamic( | |
159 | const matmul_forward_params& param, | |
160 | const tensor& src, | |
161 | const tensor& weights, | |
162 | const tensor& bias, | |
163 | tensor& dst, | |
164 | const float dst_coeff = 1.0f, | |
165 | const float sum_coeff = 1.0f, | |
166 | const scale_t& src_scales = scale_t(), | |
167 | const scale_t& weights_scales = scale_t(), | |
168 | const scale_t& dst_scales = scale_t(), | |
169 | const zero_point_t& src_zero_points = zero_point_t(), | |
170 | const zero_point_t& dst_zero_points = zero_point_t(), | |
171 | const attr_t& attr = attr_t(), | |
172 | const data_type dst_type = data_type::undef, | |
173 | const lowp_kind alowp_kind = u8s8, | |
174 | const engine& aengine = engine::cpu_engine()) { | |
175 | bool with_bias = (!bias.is_empty()); | |
176 | // Call do_prepare here to calculate scales and zero points. | |
177 | matmul_forward_params param_for_compute; | |
178 | do_prepare(param_for_compute, type_compute_dynamic, src, weights, bias, with_bias, dst, dst_coeff, sum_coeff, | |
179 | src_scales, weights_scales, dst_scales, src_zero_points, dst_zero_points, | |
180 | attr, dst_type, alowp_kind, aengine); | |
181 | param_for_compute.pd = param.pd; | |
182 | param_for_compute.primitive = param.primitive; | |
183 | param_for_compute.op_attr = param.op_attr; | |
184 | do_compute(param_for_compute, type_compute_dynamic, src, weights, bias, with_bias, dst); | |
45 | 185 | } |
46 | 186 | |
47 | 187 | static tensor::desc expected_weights_desc( |
51 | 191 | const engine& aengine = engine::cpu_engine()) { |
52 | 192 | auto ndims = weights_dims.size(); |
53 | 193 | auto x_dims = weights_dims; |
54 | x_dims[ndims-2] = 1; | |
194 | x_dims[ndims-2] = DNNL_RUNTIME_DIM_VAL; | |
55 | 195 | x_dims[ndims-1] = weights_dims[ndims-2]; |
56 | 196 | auto y_dims = {x_dims[0], weights_dims[1]}; |
57 | 197 | if (ndims == 3) |
62 | 202 | "Invalid dims for data and weights"); |
63 | 203 | tensor::desc x_desc(x_dims, x_dtype, ndims == 2 ? tag::ab : tag::abc); |
64 | 204 | tensor::desc y_desc(y_dims, y_dtype, ndims == 2 ? tag::ab : tag::abc); |
65 | tensor::desc weights_desc(weights_dims , dtype, ndims == 2 ? tag::ab : tag::abc); | |
66 | auto pd = primitive_desc({x_desc, weights_desc, y_desc}, aengine); | |
205 | tensor::desc weights_desc(weights_dims , dtype, tag::any); | |
206 | attr_t attr; | |
207 | attr.set_output_scales(/* mask */ (1 << 1), {DNNL_RUNTIME_F32_VAL}); | |
208 | attr.set_zero_points(DNNL_ARG_SRC, /* mask */ 0, {DNNL_RUNTIME_S32_VAL}); | |
209 | attr.set_zero_points(DNNL_ARG_DST, /* mask */ 0, {DNNL_RUNTIME_S32_VAL}); | |
210 | auto pd = primitive_desc({x_desc, weights_desc, y_desc}, attr, aengine); | |
67 | 211 | return pd.weights_desc(); |
68 | 212 | } |
69 | 213 | |
70 | 214 | private: |
215 | enum task_type { | |
216 | type_prepare_static, | |
217 | type_prepare_dynamic, | |
218 | type_compute_static, | |
219 | type_compute_dynamic, | |
220 | }; | |
221 | ||
71 | 222 | template <bool with_bias> |
72 | 223 | static void compute_impl(const tensor& src, |
73 | 224 | const tensor& weights, |
78 | 229 | const scale_t& src_scales = scale_t(), |
79 | 230 | const scale_t& weights_scales = scale_t(), |
80 | 231 | const scale_t& dst_scales = scale_t(), |
232 | const zero_point_t& src_zero_points = zero_point_t(), | |
233 | const zero_point_t& dst_zero_points = zero_point_t(), | |
81 | 234 | const attr_t& attr = attr_t(), |
82 | 235 | const data_type dst_type = data_type::undef, |
83 | 236 | const lowp_kind alowp_kind = u8s8, |
84 | 237 | const engine& aengine = engine::cpu_engine()) { |
238 | matmul_forward_params param; | |
239 | do_prepare(param, type_prepare_static, src, weights, bias, with_bias, dst, dst_coeff, sum_coeff, | |
240 | src_scales, weights_scales, dst_scales, src_zero_points, dst_zero_points, | |
241 | attr, dst_type, alowp_kind, aengine); | |
242 | do_compute(param, type_compute_static, src, weights, bias, with_bias, dst); | |
243 | } | |
244 | ||
245 | static void do_prepare(matmul_forward_params& param, | |
246 | task_type task, | |
247 | const tensor& src, | |
248 | const tensor& weights, | |
249 | const tensor& bias, | |
250 | bool with_bias, | |
251 | tensor& dst, | |
252 | const float dst_coeff = 1.0f, | |
253 | const float sum_coeff = 1.0f, | |
254 | const scale_t& src_scales = scale_t(), | |
255 | const scale_t& weights_scales = scale_t(), | |
256 | const scale_t& dst_scales = scale_t(), | |
257 | const zero_point_t& src_zero_points = zero_point_t(), | |
258 | const zero_point_t& dst_zero_points = zero_point_t(), | |
259 | const attr_t& attr = attr_t(), | |
260 | const data_type dst_type = data_type::undef, | |
261 | const lowp_kind alowp_kind = u8s8, | |
262 | const engine& aengine = engine::cpu_engine()) { | |
85 | 263 | IDEEP_ENFORCE(src.ndims() == weights.ndims(), "Invalid dims in src or weights"); |
86 | 264 | |
87 | 265 | tensor::desc src_desc, weights_desc, bias_desc; |
88 | attr_t op_attr, src_attr, weights_attr, bias_attr; | |
266 | attr_t op_attr = attr, src_attr, weights_attr, bias_attr; | |
89 | 267 | scale_t dst_scales_in; |
90 | 268 | auto dst_data_type = data_type::f32; |
91 | 269 | |
92 | tensor::dims dst_dims = {src.get_dim(0), weights.get_dim(1)}; | |
270 | bool is_dynamic = (task == type_prepare_dynamic || task == type_compute_dynamic); | |
271 | const int64_t runtime_bs = DNNL_RUNTIME_DIM_VAL; | |
272 | tensor::dims src_dims = src.get_dims(); | |
273 | if (task == type_prepare_dynamic) { | |
274 | src_dims[0] = runtime_bs; | |
275 | } | |
276 | tensor::dims dst_dims = {src_dims[0], weights.get_dim(1)}; | |
93 | 277 | auto ndims = weights.ndims(); |
94 | if (ndims == 3) | |
95 | dst_dims = {src.get_dim(0), src.get_dim(1), weights.get_dim(2)}; | |
96 | ||
97 | auto weights_scales_in = | |
278 | if (ndims == 3) | |
279 | dst_dims = {src_dims[0], src.get_dim(1), weights.get_dim(2)}; | |
280 | ||
281 | auto& weights_scales_in = | |
98 | 282 | weights.has_scale() ? weights.get_scale() : weights_scales; |
99 | 283 | tensor scales_m, src_zero_point_m, wei_zero_point_m, dst_zero_point_m; |
100 | 284 | if (!weights_scales_in.empty()) { |
101 | 285 | IDEEP_ENFORCE(alowp_kind == u8s8 || alowp_kind == s8s8, |
102 | 286 | "Unsupported lowp kind"); |
103 | 287 | |
104 | auto src_scales_in = | |
288 | auto src_scales_in = | |
105 | 289 | src.has_scale() ? src.get_scale() |
106 | 290 | : (src_scales.empty() ? IDEEP_DEF_SCALE : src_scales); |
107 | src_desc = {src.get_dims(), | |
108 | alowp_kind == u8s8 ? data_type::u8 : data_type::s8, | |
109 | tag::any}; | |
291 | auto src_data_type = (alowp_kind == u8s8) ? data_type::u8 : data_type::s8; | |
292 | std::vector<int64_t> src_strides = (ndims == 3) ? | |
293 | std::vector<int64_t>({src_dims[1] * src_dims[2], src_dims[1], 1}) : | |
294 | std::vector<int64_t>({src_dims[1], 1}); | |
295 | src_desc = is_dynamic ? | |
296 | tensor::desc(src_dims, src_data_type, src_strides) : | |
297 | tensor::desc(src_dims, src_data_type, tag::any); | |
110 | 298 | if (src.get_data_type() == data_type::f32) { |
111 | 299 | src_attr = {0, src_scales_in}; |
112 | 300 | } |
114 | 302 | int scale_size = (weights_scales_in.size() > 1) ? weights.get_dim(1) : 1; |
115 | 303 | weights_desc = weights.get_desc(); |
116 | 304 | if (weights.get_data_type() == data_type::f32) { |
117 | weights_attr = {utils::tensor_scale_mask(scale_size, false), | |
305 | weights_attr = {utils::tensor_scale_mask(scale_size, false), | |
118 | 306 | weights_scales_in}; |
119 | 307 | } |
120 | ||
308 | ||
121 | 309 | // determine dst data type |
122 | if (dst_scales.empty() || dst_scales == IDEEP_DEF_SCALE) { | |
310 | if (dst.get_data_type() != data_type::undef) { | |
311 | dst_data_type = dst.get_data_type(); | |
312 | } else if (dst_scales.empty() || dst_scales == IDEEP_DEF_SCALE) { | |
123 | 313 | dst_data_type = data_type::f32; |
124 | 314 | } else { |
125 | 315 | dst_data_type = data_type::u8; |
127 | 317 | |
128 | 318 | // fill primitive attr |
129 | 319 | scale_t op_scales(scale_size), bias_scales(scale_size); |
130 | dst_scales_in = (dst_scales.empty() || dst_data_type == data_type::f32) | |
131 | ? IDEEP_DEF_SCALE | |
320 | dst_scales_in = (dst_scales.empty() || dst_data_type == data_type::f32) | |
321 | ? IDEEP_DEF_SCALE | |
132 | 322 | : dst_scales; |
133 | auto src_zero_point = src.has_zero_point() | |
134 | ? src.get_zero_point() : std::vector<int32_t>(1); | |
135 | auto src_zero_point_size = static_cast<dim>(src_zero_point.size()); | |
136 | auto dst_zero_point = dst.has_zero_point() | |
137 | ? dst.get_zero_point() : std::vector<int32_t>(1); | |
138 | auto dst_zero_point_size = static_cast<dim>(dst_zero_point.size()); | |
139 | IDEEP_ENFORCE(src_zero_point_size == 1 && dst_zero_point_size == 1, | |
323 | const zero_point_t default_zero_points = zero_point_t(1); | |
324 | const auto& src_zero_point = src.has_zero_point() ? src.get_zero_point() : | |
325 | src_zero_points.empty() ? default_zero_points : src_zero_points; | |
326 | const auto src_zero_point_size = static_cast<dim>(src_zero_point.size()); | |
327 | const auto& dst_zero_point = dst.has_zero_point() ? dst.get_zero_point() : | |
328 | dst_zero_points.empty() ? default_zero_points : dst_zero_points; | |
329 | const auto dst_zero_point_size = static_cast<dim>(dst_zero_point.size()); | |
330 | IDEEP_ENFORCE(src_zero_point_size == 1 && dst_zero_point_size == 1, | |
140 | 331 | "DNNL only support 1-dim zero_point"); |
141 | auto wei_zero_point = weights.has_zero_point() | |
142 | ? weights.get_zero_point() : std::vector<int32_t>(1); | |
143 | dim wei_zero_point_size = 1; | |
144 | ||
332 | const auto& wei_zero_point = weights.has_zero_point() ? | |
333 | weights.get_zero_point() : default_zero_points; | |
334 | const dim wei_zero_point_size = 1; | |
335 | ||
145 | 336 | if (attr.has_op_kind(kind::sum)) { |
146 | float sum_scale = | |
147 | sum_coeff * dst_scales_in[0] / (dst.has_scale() ? dst.get_scale()[0] : 1.0f); | |
337 | float sum_scale = | |
338 | sum_coeff * dst_scales_in[0] / (dst.has_scale() ? dst.get_scale()[0] : 1.0f); | |
148 | 339 | op_attr = attr_t::fuse_sum(sum_scale); |
149 | 340 | } |
150 | 341 | |
151 | 342 | auto bias_scales_in = |
152 | 343 | bias.has_scale() ? bias.get_scale() : IDEEP_DEF_SCALE; |
153 | bias_scales_in = bias_scales_in.size() == 1 ? | |
154 | std::vector<float>(scale_size, bias_scales_in[0]) : bias_scales_in; | |
155 | bool flag_runtime = false; | |
344 | bias_scales_in = bias_scales_in.size() == 1 ? | |
345 | std::vector<float>(scale_size, bias_scales_in[0]) : bias_scales_in; | |
346 | bool flag_runtime = is_dynamic; | |
156 | 347 | if (flag_runtime) { |
157 | op_attr.set_output_scales(utils::op_scale_mask(scale_size), {DNNL_RUNTIME_F32_VAL}); | |
158 | tensor::desc scales_desc = {{scale_size}, data_type::f32, {1}}; | |
159 | scales_m.init(scales_desc, aengine); | |
160 | auto s = reinterpret_cast<float *>(scales_m.get_data_handle()); | |
161 | for (memory::dim i = 0; i < scale_size; ++i) { | |
162 | bias_scales[i] = src_scales_in[0] * weights_scales_in[i] | |
163 | / (dst_coeff * bias_scales_in[i]); | |
164 | s[i] = dst_coeff * dst_scales_in[0] / (src_scales_in[0] * weights_scales_in[i]); | |
165 | } | |
166 | ||
167 | op_attr.set_zero_points(DNNL_ARG_SRC, utils::tensor_zp_mask(1), {DNNL_RUNTIME_S32_VAL}); | |
168 | tensor::desc src_zero_point_desc = {{src_zero_point_size}, data_type::s32, {1}}; | |
169 | src_zero_point_m.init(src_zero_point_desc, aengine); | |
170 | auto src_z = reinterpret_cast<int32_t *>(src_zero_point_m.get_data_handle()); | |
171 | for (memory::dim i = 0; i < src_zero_point_size; ++i) | |
172 | src_z[i] = src_zero_point[i]; | |
173 | ||
174 | op_attr.set_zero_points(DNNL_ARG_WEIGHTS, utils::tensor_zp_mask(1), {DNNL_RUNTIME_S32_VAL}); | |
175 | tensor::desc wei_zero_point_desc = {{wei_zero_point_size}, data_type::s32, {1}}; | |
176 | wei_zero_point_m.init(wei_zero_point_desc, aengine); | |
177 | auto wei_z = reinterpret_cast<int32_t *>(wei_zero_point_m.get_data_handle()); | |
178 | for (memory::dim i = 0; i < wei_zero_point_size; ++i) | |
179 | wei_z[i] = wei_zero_point[i]; | |
180 | ||
181 | if (dst_data_type != data_type::f32) { | |
182 | op_attr.set_zero_points(DNNL_ARG_DST, utils::tensor_zp_mask(1), {DNNL_RUNTIME_S32_VAL}); | |
183 | tensor::desc dst_zero_point_desc = {{dst_zero_point_size}, data_type::s32, {1}}; | |
184 | dst_zero_point_m.init(dst_zero_point_desc, aengine); | |
185 | auto dst_z = reinterpret_cast<int32_t *>(dst_zero_point_m.get_data_handle()); | |
186 | for (memory::dim i = 0; i < dst_zero_point_size; ++i) | |
187 | dst_z[i] = dst_zero_point[i]; | |
348 | if (task == type_prepare_dynamic) { | |
349 | op_attr.set_output_scales(utils::op_scale_mask(scale_size), {DNNL_RUNTIME_F32_VAL}); | |
350 | op_attr.set_zero_points(DNNL_ARG_SRC, utils::tensor_zp_mask(1), {DNNL_RUNTIME_S32_VAL}); | |
351 | op_attr.set_zero_points(DNNL_ARG_WEIGHTS, utils::tensor_zp_mask(1), {DNNL_RUNTIME_S32_VAL}); | |
352 | if (dst_data_type != data_type::f32) { | |
353 | op_attr.set_zero_points(DNNL_ARG_DST, utils::tensor_zp_mask(1), {DNNL_RUNTIME_S32_VAL}); | |
354 | } | |
355 | } else { // task == type_compute_dynamic | |
356 | tensor::desc scales_desc = {{scale_size}, data_type::f32, {1}}; | |
357 | scales_m.init(scales_desc, aengine); | |
358 | auto s = reinterpret_cast<float *>(scales_m.get_data_handle()); | |
359 | for (memory::dim i = 0; i < scale_size; ++i) { | |
360 | bias_scales[i] = src_scales_in[0] * weights_scales_in[i] / (dst_coeff * bias_scales_in[i]); | |
361 | s[i] = dst_coeff * dst_scales_in[0] / (src_scales_in[0] * weights_scales_in[i]); | |
362 | } | |
363 | if (src.get_data_type() == data_type::f32) { | |
364 | // Set zero point for reorder (quantization). 1st arg should be DNNL_ARG_DST rather than DNNL_ARG_SRC | |
365 | src_attr.set_zero_points(DNNL_ARG_DST, | |
366 | utils::tensor_zp_mask(src_zero_point.size()), src_zero_point); | |
367 | } | |
368 | ||
369 | tensor::desc src_zero_point_desc = {{src_zero_point_size}, data_type::s32, {1}}; | |
370 | src_zero_point_m.init(src_zero_point_desc, aengine); | |
371 | auto src_z = reinterpret_cast<int32_t *>(src_zero_point_m.get_data_handle()); | |
372 | for (memory::dim i = 0; i < src_zero_point_size; ++i) | |
373 | src_z[i] = src_zero_point[i]; | |
374 | ||
375 | tensor::desc wei_zero_point_desc = {{wei_zero_point_size}, data_type::s32, {1}}; | |
376 | wei_zero_point_m.init(wei_zero_point_desc, aengine); | |
377 | auto wei_z = reinterpret_cast<int32_t *>(wei_zero_point_m.get_data_handle()); | |
378 | for (memory::dim i = 0; i < wei_zero_point_size; ++i) | |
379 | wei_z[i] = wei_zero_point[i]; | |
380 | ||
381 | if (dst_data_type != data_type::f32) { | |
382 | tensor::desc dst_zero_point_desc = {{dst_zero_point_size}, data_type::s32, {1}}; | |
383 | dst_zero_point_m.init(dst_zero_point_desc, aengine); | |
384 | auto dst_z = reinterpret_cast<int32_t *>(dst_zero_point_m.get_data_handle()); | |
385 | for (memory::dim i = 0; i < dst_zero_point_size; ++i) | |
386 | dst_z[i] = dst_zero_point[i]; | |
387 | } | |
188 | 388 | } |
189 | 389 | } else { |
190 | 390 | for (int i = 0; i < scale_size; i++) { |
191 | bias_scales[i] = src_scales_in[0] * weights_scales_in[i] | |
192 | / (dst_coeff * bias_scales_in[i]); | |
391 | bias_scales[i] = src_scales_in[0] * weights_scales_in[i] / (dst_coeff * bias_scales_in[i]); | |
193 | 392 | op_scales[i] = dst_coeff * dst_scales_in[0] / (src_scales_in[0] * weights_scales_in[i]); |
194 | 393 | } |
195 | 394 | op_attr.set_output_scales(utils::op_scale_mask(scale_size), op_scales); |
196 | op_attr.set_zero_points(DNNL_ARG_SRC, | |
395 | op_attr.set_zero_points(DNNL_ARG_SRC, | |
197 | 396 | utils::tensor_zp_mask(src_zero_point.size()), src_zero_point); |
198 | op_attr.set_zero_points(DNNL_ARG_WEIGHTS, | |
199 | utils::tensor_zp_mask(1), std::vector<int32_t>(1,wei_zero_point[0])); | |
397 | if (src.get_data_type() == data_type::f32) { | |
398 | // Set zero point for reorder (quantization). 1st arg should be DNNL_ARG_DST rather than DNNL_ARG_SRC | |
399 | src_attr.set_zero_points(DNNL_ARG_DST, | |
400 | utils::tensor_zp_mask(src_zero_point.size()), src_zero_point); | |
401 | } | |
402 | op_attr.set_zero_points(DNNL_ARG_WEIGHTS, | |
403 | utils::tensor_zp_mask(1), zero_point_t(1,wei_zero_point[0])); | |
200 | 404 | if (dst_data_type != data_type::f32) { |
201 | op_attr.set_zero_points(DNNL_ARG_DST, | |
405 | op_attr.set_zero_points(DNNL_ARG_DST, | |
202 | 406 | utils::tensor_zp_mask(dst_zero_point.size()), dst_zero_point); |
203 | 407 | } |
204 | 408 | } |
205 | 409 | |
206 | 410 | if (with_bias) { |
207 | 411 | tag bia_tag = bias.get_dims().size() == 2 ? tag::ab : tag::abc; |
208 | bias_desc = {bias.get_dims(), data_type::s32, bia_tag}; | |
412 | bias_desc = {bias.get_dims(), data_type::f32, bia_tag}; // Use f32 instead of s32 to improve accuracy | |
209 | 413 | if (bias.get_data_type() != data_type::s32) { |
210 | auto ndims = bias.get_dims().size(); | |
414 | auto ndims = bias.get_dims().size(); | |
211 | 415 | int mask = scale_size > 1 ? 1 << (ndims - 1) : 0; |
212 | 416 | bias_attr = {mask, bias_scales}; |
213 | 417 | } |
214 | 418 | } |
215 | 419 | } else { |
216 | op_attr = attr; | |
217 | 420 | if (src.has_scale()) { |
218 | 421 | auto src_scale = src.get_scale(); |
219 | 422 | src_scale[0] = 1.0f / src_scale[0]; |
220 | 423 | src_attr = {0, src_scale}; |
221 | 424 | } |
222 | 425 | |
223 | // We always set tag "any" to all input desc, and assume it's DNNL's duty to find best solution | |
224 | // (which means minimize the execution times (include gemm computation and tensor format conversion)) | |
426 | // We intentionally didn't set weight desc to format `any` so DNNL wouldn't | |
427 | // have to determine weight format for us. Because the weight tensor from | |
428 | // pytorch may have a transposed format (say `ba`). However, DNNL would | |
429 | // choose plain format for it by default (`ab` in this case), which would | |
430 | // introduces *an extra reorder* afterwards. Here we keep the weight format | |
431 | // untouched thanks to optimizations for both plain and transposed formats | |
432 | // in DNNL. | |
225 | 433 | IDEEP_ENFORCE(weights.get_data_type() == data_type::f32 || |
226 | 434 | weights.get_data_type() == data_type::bf16, |
227 | 435 | "Incorrect data type in weights"); |
228 | dst_data_type = src.get_data_type() == data_type::bf16 ? | |
436 | dst_data_type = src.get_data_type() == data_type::bf16 ? | |
229 | 437 | data_type::bf16 : data_type::f32; |
230 | src_desc = src.get_desc().to_type(dst_data_type).to_format_any(); | |
231 | weights_desc = weights.get_desc().to_type(dst_data_type).to_format_any(); | |
438 | src_desc = src.get_desc().to_type(dst_data_type); | |
439 | weights_desc = weights.get_desc().to_type(dst_data_type); | |
232 | 440 | if (with_bias) { |
233 | 441 | IDEEP_ENFORCE(bias.get_data_type() == data_type::f32 || |
234 | 442 | bias.get_data_type() == data_type::bf16, |
242 | 450 | op_attr = attr_t::fuse_sum(sum_coeff); |
243 | 451 | } |
244 | 452 | int scale_size = 1; |
245 | op_attr.set_output_scales(utils::op_scale_mask(scale_size), | |
453 | op_attr.set_output_scales(utils::op_scale_mask(scale_size), | |
246 | 454 | std::vector<float>(1, dst_coeff)); |
247 | 455 | } |
248 | ||
249 | dst_data_type = dst_type == data_type::undef ? dst_data_type : dst_type; | |
250 | tensor::desc dst_desc(dst_dims, dst_data_type, tag::any); | |
251 | auto pd = with_bias | |
252 | ? primitive_desc({src_desc, weights_desc, bias_desc, dst_desc}, | |
253 | op_attr, aengine) | |
254 | : primitive_desc({src_desc, weights_desc, dst_desc}, | |
255 | op_attr, aengine); | |
256 | // reorder src, weight, dst if needed | |
257 | auto expected_src = src.reorder_if_differ_in(pd.src_desc(), src_attr); | |
258 | auto expected_weights = weights.reorder_if_differ_in(pd.weights_desc(), weights_attr); | |
259 | ||
260 | // [ Note output buffer] | |
261 | // In this case, dst is an empty ideep tensor, can be re-init | |
262 | // If dst is not empty, ideep must write result to dst's memory and it is caller's duty to | |
263 | // make sure dst is big enough to hold the result | |
264 | if (dst.is_empty()) | |
265 | dst.init(pd.dst_desc()); | |
266 | auto expected_dst = dst.reorder_if_differ_in(pd.dst_desc()); | |
267 | if (!dst_scales.empty() && utils::one_of(dst.get_data_type(), data_type::u8, data_type::s8)) { | |
268 | expected_dst.set_scale(dst_scales_in); | |
269 | } | |
270 | ||
456 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
457 | ||
458 | if (task == type_prepare_static || task == type_prepare_dynamic) { | |
459 | dst_data_type = dst_type == data_type::undef ? dst_data_type : dst_type; | |
460 | std::vector<int64_t> dst_strides = (ndims == 3) ? | |
461 | std::vector<int64_t>({dst_dims[2]* dst_dims[1], dst_dims[1], 1}) : | |
462 | std::vector<int64_t>({dst_dims[1], 1}); | |
463 | tensor::desc dst_desc = is_dynamic ? | |
464 | tensor::desc(dst_dims, dst_data_type, dst_strides) : | |
465 | tensor::desc(dst_dims, dst_data_type, tag::any); | |
466 | auto key = utils::create_key( | |
467 | src_desc, | |
468 | weights_desc, | |
469 | bias_desc, | |
470 | dst_desc, | |
471 | op_attr, | |
472 | with_bias, | |
473 | omp_get_max_threads()); | |
474 | auto pd = fetch_or_create(key, [&]() { | |
475 | if (with_bias) { | |
476 | return primitive_desc( | |
477 | {src_desc, weights_desc, bias_desc, dst_desc}, op_attr, aengine); | |
478 | } else { | |
479 | return primitive_desc( | |
480 | {src_desc, weights_desc, dst_desc}, op_attr, aengine); | |
481 | } | |
482 | }); | |
483 | param.primitive = std::move(super(pd)); | |
484 | param.pd = std::move(pd); | |
485 | param.op_attr = std::move(op_attr); | |
486 | } | |
487 | ||
488 | param.src_attr = std::move(src_attr); | |
489 | param.weights_attr = std::move(weights_attr); | |
490 | param.bias_attr = std::move(bias_attr); | |
491 | param.dst_scales = std::move(dst_scales_in); | |
492 | if (task == type_compute_dynamic) { | |
493 | param.src_desc = src_desc; | |
494 | param.weights_desc = weights_desc; | |
495 | param.scales_m = std::move(scales_m); | |
496 | param.src_zero_point_m = std::move(src_zero_point_m); | |
497 | param.wei_zero_point_m = std::move(wei_zero_point_m); | |
498 | param.dst_zero_point_m = std::move(dst_zero_point_m); | |
499 | } | |
500 | } | |
501 | ||
502 | static void do_compute(const matmul_forward_params& param, | |
503 | task_type task, | |
504 | const tensor& src, | |
505 | const tensor& weights, | |
506 | const tensor& bias, | |
507 | bool with_bias, | |
508 | tensor& dst) { | |
509 | auto& pd = param.pd; | |
510 | auto& primitive = param.primitive; | |
511 | auto& op_attr = param.op_attr; | |
512 | auto& src_attr = param.src_attr; | |
513 | auto& weights_attr = param.weights_attr; | |
514 | auto& bias_attr = param.bias_attr; | |
515 | auto& dst_scales_in = param.dst_scales; | |
516 | auto& src_desc = param.src_desc; | |
517 | auto& wei_desc = param.weights_desc; | |
518 | auto& scales_m = param.scales_m; | |
519 | auto& src_zero_point_m = param.src_zero_point_m; | |
520 | auto& wei_zero_point_m = param.wei_zero_point_m; | |
521 | auto& dst_zero_point_m = param.dst_zero_point_m; | |
522 | ||
523 | auto expected_src_desc = (task == type_compute_dynamic) ? src_desc : tensor::desc(pd.src_desc()); | |
524 | auto expected_wei_desc = (task == type_compute_dynamic) ? wei_desc : tensor::desc(pd.weights_desc()); | |
525 | auto expected_dst_desc = (task == type_compute_dynamic) ? dst.get_desc() : pd.dst_desc(); | |
526 | auto expected_src = src.reorder_if_differ_in(expected_src_desc, src_attr); | |
527 | auto expected_weights = weights.reorder_if_differ_in(expected_wei_desc, weights_attr); | |
528 | tensor expected_dst; | |
529 | if (dst.is_empty() || dst.get_desc() != expected_dst_desc){ | |
530 | // If dst buffer are not given by user or user given dst buffer are not under expected format | |
531 | // We need init a new one | |
532 | expected_dst.init(expected_dst_desc); | |
533 | if (!dst.is_empty() && op_attr.has_op_kind(kind::sum)) { | |
534 | // We need copy the content of given buffer if matmul is fused with sum | |
535 | expected_dst.feed_from(dst); | |
536 | } | |
537 | } else { | |
538 | // The format of given dst buffer is expected | |
539 | expected_dst = dst; | |
540 | } | |
541 | ||
542 | if (!dst_scales_in.empty() && utils::one_of(dst.get_data_type(), data_type::u8, data_type::s8)) { | |
543 | expected_dst.set_scale(dst_scales_in); | |
544 | } | |
545 | tensor scratchpad(pd.scratchpad_desc()); | |
271 | 546 | if (with_bias){ |
272 | // reorder bias if needed | |
273 | 547 | auto expected_bias = bias.reorder_if_differ_in(pd.bias_desc(), bias_attr); |
274 | super(pd).execute(stream::default_stream(), | |
548 | primitive.execute(stream::default_stream(), | |
275 | 549 | {{DNNL_ARG_SRC, expected_src}, |
276 | 550 | {DNNL_ARG_WEIGHTS, expected_weights}, |
277 | 551 | {DNNL_ARG_BIAS, expected_bias}, |
279 | 553 | {DNNL_ARG_ATTR_OUTPUT_SCALES, scales_m}, |
280 | 554 | {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zero_point_m}, |
281 | 555 | {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, wei_zero_point_m}, |
282 | {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zero_point_m}}); | |
556 | {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zero_point_m}, | |
557 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
283 | 558 | } else { |
284 | super(pd).execute(stream::default_stream(), | |
559 | primitive.execute(stream::default_stream(), | |
285 | 560 | {{DNNL_ARG_SRC, expected_src}, |
286 | 561 | {DNNL_ARG_WEIGHTS, expected_weights}, |
287 | 562 | {DNNL_ARG_DST, expected_dst}, |
288 | 563 | {DNNL_ARG_ATTR_OUTPUT_SCALES, scales_m}, |
289 | 564 | {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zero_point_m}, |
290 | 565 | {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, wei_zero_point_m}, |
291 | {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zero_point_m}}); | |
292 | } | |
293 | // reorder back to dst's buffer if needed | |
294 | expected_dst.reorder_to_if_differ_from(dst); | |
295 | } | |
566 | {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zero_point_m}, | |
567 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
568 | } | |
569 | // reorder back to dst's buffer if needed | |
570 | if (dst.is_empty() || | |
571 | dst.get_desc() == expected_dst.get_desc() || | |
572 | !dst.get_desc().has_same_shape_as(expected_dst.get_desc())){ | |
573 | dst = expected_dst; | |
574 | } else { | |
575 | dst.feed_from(expected_dst); | |
576 | } | |
577 | } | |
578 | ||
296 | 579 | }; |
297 | 580 | |
298 | 581 | } // namespace ideep |
1 | 1 | #define IDEEP_OPERATORS_POOL_HPP |
2 | 2 | |
3 | 3 | namespace ideep { |
4 | // pooling_v2_forward/backward supports dilation, | |
5 | // while pooling_forward/backward does not. | |
4 | 6 | |
5 | 7 | struct pooling_forward : public dnnl::pooling_forward { |
6 | 8 | |
25 | 27 | |
26 | 28 | tensor::desc dst_desc(output_sizes, src.get_data_type(), tag::any); |
27 | 29 | |
30 | auto op_attr = dnnl::primitive_attr(); | |
31 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
32 | ||
28 | 33 | auto pd = primitive_desc( |
29 | 34 | {aprop_kind, aalgorithm, src_desc, dst_desc, strides, kernel, padding_l, |
30 | padding_r}, aengine); | |
35 | padding_r}, op_attr, aengine); | |
31 | 36 | |
32 | 37 | auto expected_src = src.reorder_if_differ_in(pd.src_desc()); |
33 | 38 | dst.reinit_if_possible(pd.dst_desc()); |
35 | 40 | dst.set_scale(src.get_scale()); |
36 | 41 | } |
37 | 42 | |
38 | exec_args args {{DNNL_ARG_SRC, expected_src}, {DNNL_ARG_DST, dst}}; | |
43 | tensor scratchpad(pd.scratchpad_desc()); | |
44 | ||
45 | exec_args args { | |
46 | {DNNL_ARG_SRC, expected_src}, | |
47 | {DNNL_ARG_DST, dst}, | |
48 | {DNNL_ARG_SCRATCHPAD, scratchpad}}; | |
49 | if (with_workspace) { | |
50 | dst.init_workspace(pd.workspace_desc()); | |
51 | args.insert({DNNL_ARG_WORKSPACE, dst.get_workspace()}); | |
52 | } | |
53 | ||
54 | super(pd).execute(stream::default_stream(), args); | |
55 | } | |
56 | }; | |
57 | ||
58 | struct pooling_v2_forward : public dnnl::pooling_v2_forward { | |
59 | ||
60 | using super = dnnl::pooling_v2_forward; | |
61 | ||
62 | static void compute(const tensor& src, | |
63 | const dims& output_sizes, | |
64 | tensor& dst, | |
65 | const dims& strides, | |
66 | const dims& kernel, | |
67 | const dims& dilation, | |
68 | const dims& padding_l, | |
69 | const dims& padding_r, | |
70 | algorithm aalgorithm, | |
71 | prop_kind aprop_kind = prop_kind::forward, | |
72 | const engine& aengine = engine::cpu_engine()) { | |
73 | bool with_workspace = aprop_kind == prop_kind::forward_training && | |
74 | aalgorithm == dnnl::algorithm::pooling_max; | |
75 | ||
76 | // workaround: use src.get_desc() once issue intel/mkl-dnn#588 is resolved | |
77 | auto src_desc = src._get_unblocked_desc_if_4c_blocked(); | |
78 | // auto src_desc = src.get_desc(); | |
79 | ||
80 | tensor::desc dst_desc(output_sizes, src.get_data_type(), tag::any); | |
81 | ||
82 | auto dil_compatible = utils::get_compatible_dilates(dilation); | |
83 | ||
84 | auto op_attr = dnnl::primitive_attr(); | |
85 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
86 | ||
87 | auto pd = primitive_desc( | |
88 | {aprop_kind, aalgorithm, src_desc, dst_desc, strides, kernel, | |
89 | dil_compatible, padding_l, padding_r}, op_attr, aengine); | |
90 | ||
91 | auto expected_src = src.reorder_if_differ_in(pd.src_desc()); | |
92 | dst.reinit_if_possible(pd.dst_desc()); | |
93 | if (src.has_scale()) { | |
94 | dst.set_scale(src.get_scale()); | |
95 | } | |
96 | ||
97 | tensor scratchpad(pd.scratchpad_desc()); | |
98 | ||
99 | exec_args args { | |
100 | {DNNL_ARG_SRC, expected_src}, | |
101 | {DNNL_ARG_DST, dst}, | |
102 | {DNNL_ARG_SCRATCHPAD, scratchpad}}; | |
103 | ||
39 | 104 | if (with_workspace) { |
40 | 105 | dst.init_workspace(pd.workspace_desc()); |
41 | 106 | args.insert({DNNL_ARG_WORKSPACE, dst.get_workspace()}); |
59 | 124 | const dims& padding_r, |
60 | 125 | algorithm aalgorithm, |
61 | 126 | const engine& aengine = engine::cpu_engine()) { |
62 | auto src_desc = src.get_desc().to_format_any(); | |
127 | auto src_desc = src.get_desc(); | |
63 | 128 | auto dst_desc = dst.get_desc(); |
64 | 129 | |
65 | 130 | auto forward_hints = |
67 | 132 | {prop_kind::forward, aalgorithm, src_desc, dst_desc, strides, |
68 | 133 | kernel, padding_l, padding_r}, aengine); |
69 | 134 | |
135 | auto op_attr = dnnl::primitive_attr(); | |
136 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
137 | ||
70 | 138 | auto pd = primitive_desc( |
71 | 139 | {aalgorithm, src_desc, dst_desc, strides, kernel, padding_l, padding_r}, |
72 | aengine, forward_hints); | |
140 | op_attr, aengine, forward_hints); | |
73 | 141 | |
74 | 142 | auto expected_diff_dst = diff_dst.reorder_if_differ_in(pd.diff_dst_desc()); |
75 | 143 | diff_src.reinit_if_possible(pd.diff_src_desc()); |
76 | 144 | |
145 | tensor scratchpad(pd.scratchpad_desc()); | |
146 | ||
77 | 147 | exec_args args {{DNNL_ARG_DIFF_DST, expected_diff_dst}, |
78 | {DNNL_ARG_DIFF_SRC, diff_src}}; | |
148 | {DNNL_ARG_DIFF_SRC, diff_src}, | |
149 | {DNNL_ARG_SCRATCHPAD, scratchpad}}; | |
150 | ||
79 | 151 | if (dst.has_workspace()) { |
80 | 152 | auto expected_workspace = |
81 | 153 | dst.get_workspace().reorder_if_differ_in(pd.workspace_desc()); |
86 | 158 | } |
87 | 159 | }; |
88 | 160 | |
161 | struct pooling_v2_backward : public dnnl::pooling_v2_backward { | |
162 | ||
163 | using super = dnnl::pooling_v2_backward; | |
164 | ||
165 | static void compute(const tensor& diff_dst, | |
166 | const tensor& dst, | |
167 | const tensor& src, | |
168 | tensor& diff_src, | |
169 | const dims& strides, | |
170 | const dims& kernel, | |
171 | const dims& dilation, | |
172 | const dims& padding_l, | |
173 | const dims& padding_r, | |
174 | algorithm aalgorithm, | |
175 | const engine& aengine = engine::cpu_engine()) { | |
176 | auto src_desc = src.get_desc(); | |
177 | auto dst_desc = dst.get_desc(); | |
178 | auto dil_compatible = utils::get_compatible_dilates(dilation); | |
179 | ||
180 | auto forward_hints = | |
181 | pooling_v2_forward::primitive_desc( | |
182 | {prop_kind::forward, aalgorithm, src_desc, dst_desc, strides, | |
183 | kernel, dil_compatible, padding_l, padding_r}, aengine); | |
184 | ||
185 | auto op_attr = dnnl::primitive_attr(); | |
186 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
187 | ||
188 | auto pd = primitive_desc( | |
189 | {aalgorithm, src_desc, dst_desc, strides, kernel, dil_compatible, | |
190 | padding_l, padding_r}, op_attr, aengine, forward_hints); | |
191 | ||
192 | auto expected_diff_dst = diff_dst.reorder_if_differ_in(pd.diff_dst_desc()); | |
193 | diff_src.reinit_if_possible(pd.diff_src_desc()); | |
194 | tensor scratchpad(pd.scratchpad_desc()); | |
195 | ||
196 | exec_args args {{DNNL_ARG_DIFF_DST, expected_diff_dst}, | |
197 | {DNNL_ARG_DIFF_SRC, diff_src}, | |
198 | {DNNL_ARG_SCRATCHPAD, scratchpad}}; | |
199 | if (dst.has_workspace()) { | |
200 | auto expected_workspace = | |
201 | dst.get_workspace().reorder_if_differ_in(pd.workspace_desc()); | |
202 | args.insert({DNNL_ARG_WORKSPACE, expected_workspace}); | |
203 | } | |
204 | ||
205 | super(pd).execute(stream::default_stream(), args); | |
206 | } | |
207 | }; | |
208 | ||
89 | 209 | } // namespace ideep |
90 | 210 | |
91 | 211 | #endif |
0 | #ifndef IDEEP_OPERATORS_PRELU_HPP | |
1 | #define IDEEP_OPERATORS_PRELU_HPP | |
2 | ||
3 | namespace ideep { | |
4 | ||
5 | struct prelu_forward : public dnnl::prelu_forward { | |
6 | ||
7 | using super = dnnl::prelu_forward; | |
8 | ||
9 | static void compute(const tensor& src, | |
10 | const tensor& weight, | |
11 | tensor& dst, | |
12 | prop_kind aprop_kind = prop_kind::forward, | |
13 | const engine& aengine = engine::cpu_engine()) { | |
14 | auto src_in = src; | |
15 | auto weight_in = weight; | |
16 | ||
17 | // Reshape weight to src dimension | |
18 | auto new_dims = src.get_dims(); | |
19 | if (src.ndims() != weight.ndims()) { | |
20 | std::vector<dim> dim_w(src.ndims(), 1); | |
21 | dim_w[1] = weight.get_dim(0); | |
22 | weight_in.reshape(dim_w); | |
23 | } | |
24 | ||
25 | auto src_desc = src_in.get_desc(); | |
26 | auto weight_desc = weight_in.get_desc().to_format_any(); | |
27 | ||
28 | auto op_attr = dnnl::primitive_attr(); | |
29 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
30 | ||
31 | auto pd = primitive_desc({aprop_kind, src_desc, weight_desc}, op_attr, aengine); | |
32 | auto expected_weights = weight_in.reorder_if_differ_in(pd.weights_desc()); | |
33 | dst.reinit_if_possible(pd.dst_desc()); | |
34 | ||
35 | tensor scratchpad(pd.scratchpad_desc()); | |
36 | ||
37 | super(pd).execute(stream::default_stream(), | |
38 | {{DNNL_ARG_SRC, src_in}, | |
39 | {DNNL_ARG_WEIGHTS, expected_weights}, | |
40 | {DNNL_ARG_DST, dst}, | |
41 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
42 | } | |
43 | }; | |
44 | ||
45 | struct prelu_backward : public dnnl::prelu_backward { | |
46 | ||
47 | using super = dnnl::prelu_backward; | |
48 | ||
49 | static void compute(const tensor& src, | |
50 | const tensor& weight, | |
51 | const tensor& diff_dst, | |
52 | tensor& diff_src, | |
53 | tensor& diff_weight, | |
54 | prop_kind aprop_kind = prop_kind::backward, | |
55 | const engine& aengine = engine::cpu_engine()) { | |
56 | auto src_in = src; | |
57 | auto weight_in = weight; | |
58 | auto diff_dst_in = diff_dst; | |
59 | auto weight_dims = weight_in.get_dims(); | |
60 | ||
61 | // Reshape wieght to src dimension | |
62 | auto new_dims = src.get_dims(); | |
63 | if (src.ndims() != weight.ndims()) { | |
64 | std::vector<dim> dim_w(src.ndims(), 1); | |
65 | dim_w[1] = weight.get_dim(0); | |
66 | weight_in.reshape(dim_w); | |
67 | } | |
68 | ||
69 | auto src_desc = src_in.get_desc(); | |
70 | auto weight_desc = weight_in.get_desc().to_format_any(); | |
71 | auto diff_dst_desc = diff_dst_in.get_desc(); | |
72 | auto diff_weights_desc = | |
73 | tensor::desc(weight_in.get_dims(), diff_dst_in.get_data_type(), tag::any).to_format_any(); | |
74 | auto forward_hints = prelu_forward::primitive_desc( | |
75 | {prop_kind::forward, src_desc, weight_desc}, aengine); | |
76 | ||
77 | auto op_attr = dnnl::primitive_attr(); | |
78 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
79 | ||
80 | auto pd = primitive_desc( | |
81 | {src_desc, weight_desc, diff_dst_desc, diff_weights_desc}, op_attr, aengine, forward_hints); | |
82 | ||
83 | auto expected_diff_dst = diff_dst_in.reorder_if_differ_in(pd.diff_dst_desc()); | |
84 | auto expected_src = src_in.reorder_if_differ_in(pd.src_desc()); | |
85 | auto expected_weights = weight_in.reorder_if_differ_in(pd.weights_desc()); | |
86 | diff_src.reinit_if_possible(pd.diff_src_desc()); | |
87 | diff_weight.reinit_if_possible(pd.diff_weights_desc()); | |
88 | ||
89 | tensor scratchpad(pd.scratchpad_desc()); | |
90 | ||
91 | super(pd).execute(stream::default_stream(), | |
92 | {{DNNL_ARG_DIFF_DST, expected_diff_dst}, | |
93 | {DNNL_ARG_SRC, expected_src}, | |
94 | {DNNL_ARG_WEIGHTS, expected_weights}, | |
95 | {DNNL_ARG_DIFF_SRC, diff_src}, | |
96 | {DNNL_ARG_DIFF_WEIGHTS ,diff_weight}, | |
97 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
98 | ||
99 | // Reshape weight back to original dimension | |
100 | if (diff_weight.get_dims() != weight_dims) { | |
101 | diff_weight.reshape(weight_dims); | |
102 | } | |
103 | } | |
104 | }; | |
105 | ||
106 | } // namespace ideep | |
107 | ||
108 | #endif |
14 | 14 | auto src_desc = src.get_desc(); |
15 | 15 | dst.reinit_if_possible(src_desc); |
16 | 16 | |
17 | auto op_attr = dnnl::primitive_attr(); | |
18 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
19 | ||
17 | 20 | auto pd = primitive_desc( |
18 | {aprop_kind, src_desc, softmax_axis}, aengine); | |
19 | ||
20 | super(pd).execute(stream::default_stream(), | |
21 | {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}}); | |
21 | {aprop_kind, src_desc, softmax_axis}, op_attr, aengine); | |
22 | tensor scratchpad(pd.scratchpad_desc()); | |
23 | super(pd).execute( | |
24 | stream::default_stream(), | |
25 | {{DNNL_ARG_SRC, src}, | |
26 | {DNNL_ARG_DST, dst}, | |
27 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
22 | 28 | } |
23 | 29 | }; |
24 | 30 | |
35 | 41 | auto forward_hints = softmax_forward::primitive_desc( |
36 | 42 | {prop_kind::forward_inference, dst.get_desc(), softmax_axis}, aengine); |
37 | 43 | |
44 | auto op_attr = dnnl::primitive_attr(); | |
45 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
46 | ||
38 | 47 | auto pd = |
39 | 48 | primitive_desc({diff_dst.get_desc(), dst.get_desc(), softmax_axis}, |
40 | aengine, forward_hints); | |
49 | op_attr, aengine, forward_hints); | |
41 | 50 | auto expected_dst = dst.reorder_if_differ_in(pd.dst_desc()); |
42 | 51 | auto expected_diff_dst = diff_dst.reorder_if_differ_in(pd.diff_dst_desc()); |
43 | 52 | diff_src.reinit_if_possible(pd.diff_src_desc()); |
44 | 53 | |
54 | tensor scratchpad(pd.scratchpad_desc()); | |
55 | ||
45 | 56 | super(pd).execute(stream::default_stream(), |
46 | 57 | {{DNNL_ARG_DST, expected_dst}, |
47 | 58 | {DNNL_ARG_DIFF_DST, expected_diff_dst}, |
48 | {DNNL_ARG_DIFF_SRC, diff_src}}); | |
59 | {DNNL_ARG_DIFF_SRC, diff_src}, | |
60 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
49 | 61 | |
50 | 62 | } |
51 | 63 | }; |
14 | 14 | // "upcast" vector<tensor::desc> to vector<memory::desc> |
15 | 15 | return static_cast<memory::desc>(t.get_desc()); |
16 | 16 | }); |
17 | auto pd = primitive_desc(scales, src_descs, aengine); | |
17 | ||
18 | auto op_attr = dnnl::primitive_attr(); | |
19 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
20 | ||
21 | auto pd = primitive_desc(scales, src_descs, aengine, op_attr); | |
18 | 22 | |
19 | 23 | dst.reinit_if_possible(pd.dst_desc()); |
20 | ||
21 | exec_args args {{DNNL_ARG_DST, dst}}; | |
24 | tensor scratchpad(pd.scratchpad_desc()); | |
25 | exec_args args {{DNNL_ARG_DST, dst}, {DNNL_ARG_SCRATCHPAD, scratchpad}}; | |
22 | 26 | for (int i = 0; i < srcs.size(); ++i) { |
23 | 27 | args.insert({DNNL_ARG_MULTIPLE_SRC + i, srcs[i]}); |
24 | 28 | } |
157 | 157 | && strides[w] == dims[c] |
158 | 158 | && strides[c] == 1; |
159 | 159 | }; |
160 | ||
161 | inline bool is_ndhwc() const { | |
162 | if (!is_plain() || data.ndims != 5) return false; | |
163 | const auto &dims = data.dims; | |
164 | const auto &strides = blocking_strides(); | |
165 | const auto n = 0, c = 1, d =2, h = 3, w = 4; | |
166 | return strides[n] == dims[d] * dims[h] * dims[w] * dims[c] | |
167 | && strides[d] == dims[h] * dims[w] * dims[c] | |
168 | && strides[h] == dims[w] * dims[c] | |
169 | && strides[w] == dims[c] | |
170 | && strides[c] == 1; | |
171 | } | |
160 | 172 | |
161 | 173 | inline bool is_nchw() const { |
162 | 174 | if (!is_plain() || data.ndims != 4) return false; |
304 | 316 | new_inner_idxs[i] = perms[old_inner_idxs[i]]; |
305 | 317 | new_inner_blks[i] = old_inner_blks[i]; |
306 | 318 | } |
319 | new_desc.data.extra = data.extra; | |
307 | 320 | |
308 | 321 | return new_desc; |
309 | 322 | } |
370 | 383 | return desc(md); |
371 | 384 | } |
372 | 385 | |
373 | private: | |
374 | ||
375 | /// Returns dimension vector | |
376 | inline dims get_internal_dims() const { | |
377 | return dims(data.dims, data.dims + data.ndims); | |
378 | } | |
379 | ||
380 | const dims_t &padded_dims() const { return data.padded_dims; } | |
381 | ||
382 | const dims_t &padded_offsets() const { return data.padded_offsets; } | |
383 | ||
384 | dim_t offset0() const { return data.offset0; } | |
385 | ||
386 | inline format_kind_t format_kind() const { return data.format_kind; } | |
387 | ||
388 | bool is_blocking_desc() const { return format_kind() == dnnl_blocked; } | |
389 | ||
390 | bool is_wino_desc() const { return format_kind() == dnnl_format_kind_wino; } | |
391 | ||
392 | bool is_rnn_packed_desc() const { | |
393 | return format_kind() == dnnl_format_kind_rnn_packed; | |
394 | } | |
395 | ||
396 | 386 | const blocking_desc_t &blocking_desc() const { |
397 | 387 | IDEEP_ENFORCE(is_blocking_desc(), |
398 | 388 | "Cannot get blocking desc on a non-blocking desc"); |
399 | 389 | return data.format_desc.blocking; |
390 | } | |
391 | ||
392 | private: | |
393 | ||
394 | /// Returns dimension vector | |
395 | inline dims get_internal_dims() const { | |
396 | return dims(data.dims, data.dims + data.ndims); | |
397 | } | |
398 | ||
399 | const dims_t &padded_dims() const { return data.padded_dims; } | |
400 | ||
401 | const dims_t &padded_offsets() const { return data.padded_offsets; } | |
402 | ||
403 | dim_t offset0() const { return data.offset0; } | |
404 | ||
405 | inline format_kind_t format_kind() const { return data.format_kind; } | |
406 | ||
407 | bool is_blocking_desc() const { return format_kind() == dnnl_blocked; } | |
408 | ||
409 | bool is_wino_desc() const { return format_kind() == dnnl_format_kind_wino; } | |
410 | ||
411 | bool is_rnn_packed_desc() const { | |
412 | return format_kind() == dnnl_format_kind_rnn_packed; | |
400 | 413 | } |
401 | 414 | |
402 | 415 | dims_t& blocking_strides() const { |
653 | 666 | } |
654 | 667 | |
655 | 668 | tensor reorder_if_differ_in(const desc &expected_desc, const attr_t &aattr = attr_t()) const { |
656 | if (expected_desc == get_desc()) { | |
669 | auto output_scales = std::get<0>(aattr.get_output_scales()); | |
670 | auto is_empty_or_ones = output_scales.empty() | |
671 | || std::all_of(output_scales.begin(), output_scales.end(), [](float i){return 1.0==i;}); | |
672 | if (expected_desc == get_desc() && is_empty_or_ones) { | |
657 | 673 | return *this; |
658 | 674 | } else { |
659 | 675 | tensor dst{expected_desc}; |
660 | 676 | this->reorder_to(dst, aattr); |
661 | 677 | return dst; |
662 | 678 | } |
663 | } | |
664 | ||
665 | // Reorder data from *this to dst if dst's memory desc(size, stride, format, etc) is different from *this; | |
666 | void reorder_to_if_differ_from(tensor &dst, const attr_t &aattr = attr_t()) const { | |
667 | if (dst.get_desc() != get_desc()) { | |
668 | this->reorder_to(dst, aattr); | |
669 | } | |
670 | return; | |
671 | 679 | } |
672 | 680 | |
673 | 681 | // workaround for issue intel/mkl-dnn#588 |
697 | 705 | // handle channels last with groups |
698 | 706 | if (is_deconv) { |
699 | 707 | // deconv: judge whether is channels last on iohw format |
700 | auto is_channels_last = old_desc.transpose(0, 1).is_nhwc(); | |
701 | if (is_channels_last) { | |
708 | auto old_desc_trans = old_desc.transpose(0, 1); | |
709 | if (old_desc_trans.is_nhwc()) { | |
702 | 710 | // giohw (acbde) => gihwo (acdeb) |
703 | 711 | grouped_desc = grouped_desc.to_format(format_tag::acdeb); |
712 | } else if (old_desc_trans.is_ndhwc()) { | |
713 | // giodhw (acbdef) => gidhwo (acdefb) | |
714 | // TODO: onednn doesn't have the tag of acdefb for now | |
715 | // grouped_desc = grouped_desc.to_format(format_tag::acdefb); | |
716 | // | |
717 | // work around by re-create desc based on dims and strides. | |
718 | auto ddims = grouped_desc.get_dims(); | |
719 | auto ddata_type = grouped_desc.get_data_type(); | |
720 | auto g = groups; | |
721 | auto o = ddims[0] / g; | |
722 | auto i = ddims[1]; | |
723 | auto d = ddims[2]; | |
724 | auto h = ddims[3]; | |
725 | auto w = ddims[4]; | |
726 | desc new_desc{{g, o, i, d, h, w}, ddata_type, { | |
727 | /*g*/i * d * h * w * o, | |
728 | /*o*/1, | |
729 | /*i*/d * h * w *o, | |
730 | /*d*/h * w * o, | |
731 | /*h*/w * o, | |
732 | /*w*/o}}; | |
733 | grouped_desc = new_desc; | |
704 | 734 | } |
705 | 735 | } else { |
706 | 736 | // conv: judge whether is channels last on oihw format |
707 | auto is_channels_last = old_desc.is_nhwc(); | |
708 | if (is_channels_last) { | |
737 | if (old_desc.is_nhwc()) { | |
709 | 738 | // goihw (abcde) => gohwi (abdec) |
710 | 739 | grouped_desc = grouped_desc.to_format(format_tag::abdec); |
740 | } else if (old_desc.is_ndhwc()) { | |
741 | // goidhw (abcdef) => godhwi (abdefc) | |
742 | grouped_desc = grouped_desc.to_format(format_tag::abdefc); | |
711 | 743 | } |
712 | 744 | } |
713 | 745 | |
727 | 759 | tensor &reshape(const dims &adims) { |
728 | 760 | IDEEP_ENFORCE(has_same_volume(adims), "reshape to incompatible shape"); |
729 | 761 | |
730 | // count the number of non-one dimensions | |
731 | // e.g. the actual rank of shape [1, 1, 35, 1] is one | |
732 | auto actual_rank = [](const dims &shape) { | |
733 | auto cnt = 0; | |
734 | for (auto d : shape) if (d > 1) cnt++; | |
735 | return cnt; | |
762 | auto need_convert_to_default_format = [](const desc &src_desc, | |
763 | const dims &shape) { | |
764 | // if src_desc is default format, do not need to conver format. | |
765 | if (src_desc.is_default()) { | |
766 | return false; | |
767 | } else { | |
768 | // count the number of non-one dimensions | |
769 | // e.g. the squeezed_ndims of shape [1, 1, 35, 1] is one. | |
770 | auto squeezed_ndims = 0; | |
771 | for (auto d : shape) | |
772 | if (d > 1) | |
773 | squeezed_ndims++; | |
774 | if (squeezed_ndims == 0) | |
775 | return false; // [1, 1, ...] | |
776 | // For squeezed_ndims is one, src_desc is plain format | |
777 | // or src_desc is block format, but the blocking dim's size is not one, | |
778 | // for example, aBcd16b, the shape is [1, 2048, 1, 1], the blocking dim | |
779 | // is the second dimension, the strid is one for the blockind dim, the | |
780 | // format does not matter for data idexing. But for aBc16b with shape | |
781 | // [1, 1, 7], we need do format change even the squeezed_ndims is one, | |
782 | // because the last dimension is not contiguous, the stride is 16. | |
783 | if (squeezed_ndims == 1) { | |
784 | if (src_desc.is_plain()) | |
785 | return false; | |
786 | // block format, only one dim is blocked, and the size of blocked dim > 1. | |
787 | auto block_desc = src_desc.blocking_desc(); | |
788 | if (block_desc.inner_nblks == 1 && shape[block_desc.inner_idxs[0]] > 1) { | |
789 | return false; | |
790 | } | |
791 | } | |
792 | return true; | |
793 | } | |
736 | 794 | }; |
737 | ||
738 | 795 | auto old_dims = get_dims(); |
739 | 796 | if (adims != old_dims) { |
740 | // Since we are going to set the desc to new dims with default format, | |
741 | // we have to make sure it's already in default format. In particular, | |
742 | // tensor format does not matter if actual rank <= 1 | |
743 | if (!get_desc().is_default() && actual_rank(old_dims) > 1) { | |
797 | if (need_convert_to_default_format(get_desc(), old_dims)) { | |
744 | 798 | to_default_format(); |
745 | 799 | } |
746 | 800 | // set desc with default format |
763 | 817 | } |
764 | 818 | |
765 | 819 | inline void reorder_from(const tensor &src) { |
766 | dnnl::reorder(src, *this) | |
767 | .execute(stream::default_stream(), const_cast<tensor &>(src), *this); | |
820 | auto op_attr = dnnl::primitive_attr(); | |
821 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
822 | auto pd = dnnl::reorder::primitive_desc(src, *this, op_attr); | |
823 | ||
824 | tensor scratchpad(pd.scratchpad_desc()); | |
825 | dnnl::reorder(pd).execute( | |
826 | stream::default_stream(), | |
827 | {{DNNL_ARG_FROM, const_cast<tensor&>(src)}, | |
828 | {DNNL_ARG_TO, *this}, | |
829 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
768 | 830 | } |
769 | 831 | |
770 | 832 | inline void reorder_to(tensor &dst, const attr_t &aattr = attr_t()) const { |
771 | auto pd = dnnl::reorder::primitive_desc(*this, dst, aattr); | |
772 | dnnl::reorder(pd) | |
773 | .execute(stream::default_stream(), const_cast<tensor &>(*this), dst); | |
833 | attr_t op_attr = aattr; | |
834 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
835 | auto pd = dnnl::reorder::primitive_desc(*this, dst, op_attr); | |
836 | ||
837 | tensor scratchpad(pd.scratchpad_desc()); | |
838 | dnnl::reorder(pd).execute( | |
839 | stream::default_stream(), | |
840 | {{DNNL_ARG_FROM, const_cast<tensor&>(*this)}, | |
841 | {DNNL_ARG_TO, dst}, | |
842 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
774 | 843 | } |
775 | 844 | |
776 | 845 | /// Convert the tensor to public format, and f32 data type by default |
864 | 933 | void insert_submemory(const tensor &src, const dims &adims, |
865 | 934 | const dims &offsets, const attr_t &attr = attr_t()) { |
866 | 935 | auto view = get_desc().submemory_desc(adims, offsets); |
867 | dnnl::reorder({src.get_engine(), src.get_desc(), get_engine(), view, attr}) | |
868 | .execute(stream::default_stream(), const_cast<tensor &>(src), *this); | |
936 | ||
937 | attr_t op_attr = attr; | |
938 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
939 | ||
940 | auto pd = dnnl::reorder::primitive_desc( | |
941 | src.get_engine(), src.get_desc(), get_engine(), view, op_attr); | |
942 | tensor scratchpad(pd.scratchpad_desc()); | |
943 | dnnl::reorder(pd).execute( | |
944 | stream::default_stream(), | |
945 | {{DNNL_ARG_FROM, const_cast<tensor&>(src)}, | |
946 | {DNNL_ARG_TO, *this}, | |
947 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
869 | 948 | } |
870 | 949 | |
871 | 950 | // reorder part of this tensor to dst |
872 | 951 | void extract_submemory(tensor &dst, const dims &adims, const dims &offsets, |
873 | 952 | const attr_t &attr = attr_t()) const { |
874 | 953 | auto view = get_desc().submemory_desc(adims, offsets); |
875 | dnnl::reorder({get_engine(), view, dst.get_engine(), dst.get_desc(), attr}) | |
876 | .execute(stream::default_stream(), const_cast<tensor &>(*this), dst); | |
954 | ||
955 | attr_t op_attr = attr; | |
956 | op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | |
957 | ||
958 | auto pd = dnnl::reorder::primitive_desc( | |
959 | get_engine(), view, dst.get_engine(), dst.get_desc(), op_attr); | |
960 | tensor scratchpad(pd.scratchpad_desc()); | |
961 | dnnl::reorder(pd).execute( | |
962 | stream::default_stream(), | |
963 | {{DNNL_ARG_FROM, const_cast<tensor&>(*this)}, | |
964 | {DNNL_ARG_TO, dst}, | |
965 | {DNNL_ARG_SCRATCHPAD, scratchpad}}); | |
877 | 966 | } |
878 | 967 | |
879 | 968 | // simple api for extract_submemory |
908 | 997 | bool has_zero_point() const { return zero_point_ != nullptr && !zero_point_->empty(); } |
909 | 998 | |
910 | 999 | /// Return the zero_point of this param. |
911 | const std::vector<int32_t> &get_zero_point() const { return *zero_point_.get(); } | |
1000 | const zero_point_t &get_zero_point() const { return *zero_point_.get(); } | |
912 | 1001 | |
913 | 1002 | /// Set new scale into param |
914 | void set_zero_point(const std::vector<int32_t> &zp) { zero_point_.reset(new std::vector<int32_t>(zp)); } | |
1003 | void set_zero_point(const zero_point_t &zp) { zero_point_.reset(new zero_point_t(zp)); } | |
915 | 1004 | |
916 | 1005 | /// Need reorder if current param used by non DNNL routines. |
917 | 1006 | // legacy API for caffe2 |
993 | 1082 | |
994 | 1083 | std::shared_ptr<tensor> workspace_; |
995 | 1084 | std::shared_ptr<scale_t> scale_; |
996 | std::shared_ptr<std::vector<int32_t>> zero_point_; | |
1085 | std::shared_ptr<zero_point_t> zero_point_; | |
997 | 1086 | std::shared_ptr<void> buffer_; |
998 | 1087 | engine eng_; |
999 | 1088 | }; |
297 | 297 | arr[i] = static_cast<T>(val); |
298 | 298 | } |
299 | 299 | |
300 | } | |
301 | } | |
302 | #endif | |
300 | inline int set_verbose(int level) { | |
301 | dnnl::status ret = dnnl::set_verbose(level); | |
302 | return ret == dnnl::status::success; | |
303 | } | |
304 | ||
305 | } // namespace utils | |
306 | } // namespace ideep | |
307 | #endif |
1 | 1 | #define _IDEEP_PIN_SINGLETONS_HPP_ |
2 | 2 | |
3 | 3 | #include "ideep.hpp" |
4 | #include "mkldnn_compat.hpp" | |
5 | ||
4 | 6 | |
5 | 7 | namespace ideep { |
6 | 8 | |
18 | 20 | RegisterEngineAllocator(engine& eng, |
19 | 21 | const std::function<void*(size_t)>& malloc, |
20 | 22 | const std::function<void(void*)>& free) { |
23 | // change runtime flag start with "MKLDNN_" to "DNNL_" | |
24 | EnvSetter env_setter; | |
21 | 25 | eng.set_allocator(malloc, free); |
22 | 26 | } |
23 | 27 | }; |
0 | #ifndef _MKLDNN_COMPAT_HPP_ | |
1 | #define _MKLDNN_COMPAT_HPP_ | |
2 | ||
3 | #include "ideep.hpp" | |
4 | ||
5 | #ifdef _WIN32 | |
6 | #include <windows.h> | |
7 | #endif | |
8 | ||
9 | namespace ideep { | |
10 | struct EnvSetter { | |
11 | // oneDNN will only accept runtime flags which start with "DNNL_/ONEDNN_" from version v2.5. | |
12 | // If user setting runtime flags start with MKLDNN_, we need to keep it works for a while before | |
13 | // we finally deprecated it. | |
14 | // This is a compatibility layer for runtime flags start with MKLDNN_ | |
15 | EnvSetter(){ | |
16 | for (auto name: mkldnn_runtime_flags){ | |
17 | query_and_set_env(name.c_str()); | |
18 | } | |
19 | } | |
20 | ||
21 | void query_and_set_env(std::string name){ | |
22 | std::string value; | |
23 | if (getenv_user(name, value)){ | |
24 | std::string dnnl_name = "DNNL_"; | |
25 | dnnl_name += std::string(name); | |
26 | #ifdef _WIN32 | |
27 | SetEnvironmentVariable(dnnl_name.c_str(), value.c_str()); | |
28 | #else | |
29 | setenv(dnnl_name.c_str(), value.c_str(), 1); | |
30 | #endif | |
31 | } | |
32 | } | |
33 | ||
34 | bool getenv_user(std::string name, std::string& value) { | |
35 | std::string name_str = "MKLDNN_" + std::string(name); | |
36 | size_t value_length = 0; | |
37 | const char* p = getenv(name_str.c_str()); | |
38 | value_length = p == nullptr ? 0 : strlen(p); | |
39 | if (value_length > 0) { | |
40 | value += std::string(p); | |
41 | return true; | |
42 | } | |
43 | return false; | |
44 | } | |
45 | ||
46 | // current runtime flags in mkldnn | |
47 | const std::vector<std::string> mkldnn_runtime_flags = { | |
48 | "VERBOSE", | |
49 | "ITT_TASK_LEVEL", | |
50 | "PRIMITIVE_CACHE_CAPACITY", | |
51 | "SC_STACK_SIZE", | |
52 | "SC_SOFT_STACK_LIMIT", | |
53 | "JIT_PROFILE", | |
54 | "VERBOSE_TIMESTAMP", | |
55 | "DEFAULT_FPMATH_MODE", | |
56 | "MAX_CPU_ISA", | |
57 | "CPU_ISA_HINTS" | |
58 | }; | |
59 | ||
60 | }; | |
61 | ||
62 | } | |
63 | ||
64 | #endif | |
65 |