// conv2D_1.cpp — FPGA-accelerated 5x5 Conv2D TensorFlow custom op (tf_lib).
#include "conv2D_1.hpp"

namespace tf_lib {

// Running count of Conv2DOp instances ever constructed; each op stores its
// creation index (see the Conv2DOp constructor). NOTE(review): volatile does
// not make `instances++` atomic — presumably ops are constructed from a
// single thread; confirm before relying on uniqueness.
volatile int instances = 0;
// NOTE(review): not referenced anywhere in this file — presumably a debug
// counter used elsewhere; verify before removing.
volatile int inParallel = 0;
// Serializes debug printing across concurrently running kernel instances.
std::mutex printMu;
  6. ShapeFunction conv2d_shape_fn = [](InferenceContext* c) {
  7. //INPUT: NHWC
  8. //KERNEL: HWIO
  9. //OUTPUT: NHWC
  10. constexpr int num_spatial_dims = 2;
  11. TensorFormat data_format;
  12. FormatFromString("NHWC", &data_format);
  13. FilterTensorFormat filter_format;
  14. FilterFormatFromString("HWIO", &filter_format);
  15. ShapeHandle input_shape, filter_shape, output_shape;
  16. TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
  17. TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));
  18. DimensionHandle batch_size_dim;
  19. DimensionHandle input_depth_dim;
  20. gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2);
  21. TF_RETURN_IF_ERROR(DimensionsFromShape(
  22. input_shape, data_format, &batch_size_dim,
  23. absl::MakeSpan(input_spatial_dims), &input_depth_dim, c));
  24. DimensionHandle output_depth_dim = c->Dim(
  25. filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
  26. /*
  27. DimensionHandle filter_rows_dim = c->Dim(
  28. filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'H'));
  29. DimensionHandle filter_cols_dim = c->Dim(
  30. filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'W'));
  31. DimensionHandle filter_input_depth_dim = c->Dim(
  32. filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'I'));
  33. */
  34. DimensionHandle output_rows, output_cols, output_channels;
  35. c->Subtract(input_spatial_dims[0], 4, &output_rows);
  36. c->Subtract(input_spatial_dims[1], 4, &output_cols);
  37. c->Subtract(output_depth_dim, 0, &output_channels);
  38. std::vector<DimensionHandle> out_dims(4);
  39. out_dims[0] = batch_size_dim;
  40. out_dims[1] = output_rows;
  41. out_dims[2] = output_cols;
  42. out_dims[3] = output_channels;
  43. output_shape = c->MakeShape(out_dims);
  44. c->set_output(0, output_shape);
  45. return Status::OK();
  46. };
  47. Conv2DOp::Conv2DOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
  48. instance = instances++;
  49. };
  50. void Conv2DOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
  51. init();
  52. // ############ TensorFlow namespace #############
  53. // Input tensor is of the following dimensions:
  54. // [ batch, in_rows, in_cols, in_depth ]
  55. const Tensor& input = context->input(0);
  56. // Input filter is of the following dimensions:
  57. // [ filter_rows, filter_cols, in_depth, out_depth]
  58. const Tensor& kernel = context->input(1);
  59. TensorShape kernel_shape = kernel.shape();
  60. TensorShape input_shape = input.shape();
  61. OP_REQUIRES_ASYNC(context, input_shape.dim_size(1) == 228, errors::InvalidArgument("Unsupported input height: ", input_shape.dim_size(1)), done);
  62. OP_REQUIRES_ASYNC(context, input_shape.dim_size(2) == 228, errors::InvalidArgument("Unsupported input width: ", input_shape.dim_size(2)), done);
  63. OP_REQUIRES_ASYNC(context, kernel_shape.dim_size(0) == 5, errors::InvalidArgument("Unsupported kernel height: ", kernel_shape.dim_size(0)), done);
  64. OP_REQUIRES_ASYNC(context, kernel_shape.dim_size(1) == 5, errors::InvalidArgument("Unsupported kernel width: ", kernel_shape.dim_size(1)), done);
  65. OP_REQUIRES_ASYNC(context, kernel_shape.dim_size(2) == input_shape.dim_size(3),
  66. errors::InvalidArgument("kernel channels != input channels: ", kernel_shape.dim_size(2), " != ", input_shape.dim_size(3)), done);
  67. int batchSize = input_shape.dim_size(0);
  68. int channels = input_shape.dim_size(3);
  69. int outputChannels = kernel_shape.dim_size(3);
  70. // create output tensor
  71. TensorShape output_shape;
  72. const int32 dims[] = {batchSize, outputSize, outputSize, outputChannels};
  73. TensorShapeUtils::MakeShape(dims, 4, &output_shape);
  74. output_shape.set_dim(0, batchSize);
  75. output_shape.set_dim(1, outputSize);
  76. output_shape.set_dim(2, outputSize);
  77. output_shape.set_dim(3, outputChannels);
  78. // Output tensor is of the following dimensions:
  79. // [ in_batch, out_rows, out_cols, out_depth ]
  80. Tensor* output = nullptr;
  81. OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
  82. // get data references
  83. auto input_tensor = input.tensor<float, 4>();
  84. auto kernel_tensor = kernel.tensor<float, 4>();
  85. auto output_tensor = output->tensor<float, 4>();
  86. // ############ FPGA communications library #############
  87. auto worker = connectionManager.createWorker(Module::conv2D_5x5_Module, batchSize * channels * outputChannels);
  88. {
  89. worker->setJobTimeout(milliseconds(300));
  90. worker->setRetryCount(10);
  91. auto jobs = worker->getJobList();
  92. for(int sample=0; sample<batchSize; sample++) {
  93. for(int outputChannel=0; outputChannel<outputChannels; outputChannel++) {
  94. for(int channel=0; channel<channels; channel++) {
  95. // get each job
  96. auto job = jobs->getJob(sample * outputChannels * channels + outputChannel * channels + channel);
  97. // write kernel to job
  98. for(int x=0; x<kernelSize; x++) {
  99. for(int y=0; y<kernelSize; y++) {
  100. job->setPayload(y*kernelSize + x, *((uint32_t*)&kernel_tensor(y, x, channel, outputChannel)));
  101. }
  102. }
  103. // write input pixels to job
  104. for(int x=0; x<sizeWithBorder; x++) {
  105. for(int y=0; y<sizeWithBorder; y++) {
  106. job->setPayload(kernelSize*kernelSize + y*sizeWithBorder + x, *((uint32_t*)&input_tensor(sample, y, x, channel)));
  107. }
  108. }
  109. job->setReady();
  110. }
  111. }
  112. }
  113. }
  114. worker->setDoneCallback([output_tensor, worker, done, batchSize, channels, outputChannels, this]{
  115. auto jobs = worker->getJobList();
  116. for(int sample=0; sample<batchSize; sample++) {
  117. for(int outputChannel=0; outputChannel<outputChannels; outputChannel++) {
  118. //set output matrix to zero
  119. for(int x=0; x<outputSize; x++) {
  120. for(int y=0; y<outputSize; y++) {
  121. output_tensor(sample, y, x, outputChannel) = 0;
  122. }
  123. }
  124. //accumulate the pixels of all output channels
  125. for(int channel=0; channel<channels; channel++) {
  126. auto job = jobs->getJob(sample * outputChannels * channels + outputChannel * channels + channel);
  127. for(int x=0; x<outputSize; x++) {
  128. for(int y=0; y<outputSize; y++) {
  129. uint32_t pixel = job->getResponsePayload((y+border*2)*sizeWithBorder + (x+border*2));
  130. output_tensor(sample, y, x, outputChannel) += *((float*)&pixel);
  131. }
  132. }
  133. }
  134. }
  135. }
  136. done();
  137. connectionManager.removeFinishedWorkers();
  138. });
  139. worker->startAsync();
  140. }
  141. static Status MatMulGradHelper(FunctionDef* g, const string& opname,
  142. const string& attr_adj_x,
  143. const string& attr_adj_y, const string& x0,
  144. bool ax0, const string& x1, bool ax1,
  145. const string& y0, bool ay0, const string& y1,
  146. bool ay1) {
  147. // The final outputs are "dx" and "dy". If we're broadcasting compute
  148. // intermediate nodes for now.
  149. std::vector<FDH::Node> nodes = {
  150. {{("dx")},
  151. opname,
  152. {x0, x1},
  153. {{"T", "$T"}, {attr_adj_x, ax0}, {attr_adj_y, ax1}}},
  154. {{("dy")},
  155. opname,
  156. {y0, y1},
  157. {{"T", "$T"}, {attr_adj_x, ay0}, {attr_adj_y, ay1}}},
  158. };
  159. *g = FDH::Define(
  160. // Arg defs
  161. {"x: T", "y: T", "dz: T"},
  162. // Ret val defs
  163. {"dx: T", "dy: T"},
  164. // Attr defs
  165. {{"T: {half, float, double}"}},
  166. // Nodes
  167. nodes);
  168. return Status::OK();
  169. }
  170. Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
  171. const string opname = "MyConv2D";
  172. const string attr_adj_x = "transpose_a";
  173. const string attr_adj_y = "transpose_b";
  174. DataType T;
  175. TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
  176. if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
  177. return errors::Unimplemented(
  178. "MatMul gradient for complex is not supported yet.");
  179. }
  180. bool ta;
  181. bool tb;
  182. TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_adj_x, &ta));
  183. TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_adj_y, &tb));
  184. if (!ta && !tb) {
  185. return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y",
  186. true, "x", true, "dz", false);
  187. }
  188. if (!ta && tb) {
  189. return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y",
  190. false, "dz", true, "x", false);
  191. }
  192. if (ta && !tb) {
  193. return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", false, "dz",
  194. true, "x", false, "dz", false);
  195. }
  196. CHECK(ta && tb);
  197. return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", true, "dz",
  198. true, "dz", true, "x", true);
  199. }
// Register the op interface, its CPU kernel, and its gradient under the
// name "MyConv2D_1".
REGISTER_OP("MyConv2D_1")
    .Input("input: float")
    .Input("filter: float")
    .Output("output: float")
    .SetShapeFn(conv2d_shape_fn);
REGISTER_KERNEL_BUILDER(Name("MyConv2D_1").Device(DEVICE_CPU), Conv2DOp);
// NOTE(review): MatMulGrad builds its gradient graph out of "MyConv2D"
// nodes (not "MyConv2D_1") — presumably intentional reuse; verify.
REGISTER_OP_GRADIENT("MyConv2D_1", MatMulGrad);
}  // namespace tf_lib