// conv2D.cpp
  1. #include "conv2D.hpp"
  2. namespace tf_lib {
// Count of Conv2DOp kernels constructed so far; the constructor takes
// `instances++` as this instance's id.
// NOTE(review): `volatile` does not make the increment atomic — confirm kernel
// construction is serialized, or switch to std::atomic<int>.
volatile int instances = 0;
// NOTE(review): not referenced anywhere visible in this file — possibly dead.
volatile int inParallel = 0;
// Serializes debug printing (see the commented-out printMu.lock() block in
// Conv2DOp::ComputeAsync).
std::mutex printMu;
  6. ShapeFunction conv2d_shape_fn = [](InferenceContext* c) {
  7. //INPUT: NHWC
  8. //KERNEL: HWIO
  9. //OUTPUT: NHWC
  10. constexpr int num_spatial_dims = 2;
  11. TensorFormat data_format;
  12. FormatFromString("NHWC", &data_format);
  13. FilterTensorFormat filter_format;
  14. FilterFormatFromString("HWIO", &filter_format);
  15. ShapeHandle input_shape, filter_shape, output_shape;
  16. TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
  17. TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));
  18. DimensionHandle batch_size_dim;
  19. DimensionHandle input_depth_dim;
  20. gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2);
  21. TF_RETURN_IF_ERROR(DimensionsFromShape(
  22. input_shape, data_format, &batch_size_dim,
  23. absl::MakeSpan(input_spatial_dims), &input_depth_dim, c));
  24. DimensionHandle output_depth_dim = c->Dim(
  25. filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
  26. /*
  27. DimensionHandle filter_rows_dim = c->Dim(
  28. filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'H'));
  29. DimensionHandle filter_cols_dim = c->Dim(
  30. filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'W'));
  31. */
  32. DimensionHandle filter_input_depth_dim = c->Dim(
  33. filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'I'));
  34. DimensionHandle output_rows, output_cols, output_channels;
  35. c->Add(input_spatial_dims[0], 0, &output_rows);
  36. c->Add(input_spatial_dims[1], 0, &output_cols);
  37. c->Multiply(filter_input_depth_dim, output_depth_dim, &output_channels);
  38. std::vector<DimensionHandle> out_dims(4);
  39. out_dims[0] = batch_size_dim;
  40. out_dims[1] = output_rows;
  41. out_dims[2] = output_cols;
  42. out_dims[3] = output_channels;
  43. output_shape = c->MakeShape(out_dims);
  44. c->set_output(0, output_shape);
  45. return Status::OK();
  46. };
// Assigns this kernel a unique id for tracing/debugging.
// NOTE(review): `instances++` on a volatile int is not atomic — if TF can
// construct kernels concurrently, ids may collide; confirm or use
// std::atomic<int>.
Conv2DOp::Conv2DOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
  instance = instances++;
};
  50. void Conv2DOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
  51. init();
  52. // Input tensor is of the following dimensions:
  53. // [ batch, in_rows, in_cols, in_depth ]
  54. const Tensor& input = context->input(0);
  55. ///const int32 *p = input.flat<int32>().data();
  56. // Input filter is of the following dimensions:
  57. // [ filter_rows, filter_cols, in_depth, out_depth]
  58. const Tensor& kernel = context->input(1);
  59. TensorShape kernel_shape = kernel.shape();
  60. TensorShape input_shape = input.shape();
  61. int batchSize = input_shape.dim_size(0);
  62. int channels = input_shape.dim_size(3);
  63. int filters = kernel_shape.dim_size(3);
  64. TensorShape output_shape;
  65. const int32 dims[] = {batchSize, outputSize, outputSize, channels * filters};
  66. TensorShapeUtils::MakeShape(dims, 4, &output_shape);
  67. output_shape.set_dim(0, batchSize);
  68. output_shape.set_dim(1, outputSize);
  69. output_shape.set_dim(2, outputSize);
  70. output_shape.set_dim(3, channels * filters);
  71. //printMu.lock();
  72. //std::cout << output_shape.DebugString() << std::endl;
  73. //printMu.unlock();
  74. // Output tensor is of the following dimensions:
  75. // [ in_batch, out_rows, out_cols, out_depth ]
  76. Tensor* output = nullptr;
  77. OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
  78. auto input_tensor = input.tensor<float, 4>();
  79. auto kernel_tensor = kernel.tensor<float, 4>();
  80. auto output_tensor = output->tensor<float, 4>();
  81. auto worker = connectionManager.createWorker(Module::conv2D_5x5_Module, batchSize * channels * filters);
  82. {
  83. worker->setJobTimeout(milliseconds(300));
  84. worker->setRetryCount(10);
  85. auto jobs = worker->getJobList();
  86. for(int sample=0; sample<batchSize; sample++) {
  87. for(int channel=0; channel<channels; channel++) {
  88. for(int filter=0; filter<filters; filter++) {
  89. auto job = jobs->getJob(sample * channels * filters + channel * filters + filter);
  90. for(int x=0; x<5; x++) {
  91. for(int y=0; y<5; y++) {
  92. job->setPayload(5*5 + x*outputSize + y, *((uint32_t*)&kernel_tensor(filter, y, x, channel)));
  93. }
  94. }
  95. for(int x=0; x<outputSize; x++) {
  96. for(int y=0; y<outputSize; y++) {
  97. job->setPayload(5*5 + x*outputSize + y, *((uint32_t*)&input_tensor(sample, y, x, channel)));
  98. }
  99. }
  100. job->setReady();
  101. }
  102. }
  103. }
  104. }
  105. worker->setDoneCallback([output_tensor, worker, done, batchSize, channels, filters, this]{
  106. auto jobs = worker->getJobList();
  107. for(int sample=0; sample<batchSize; sample++) {
  108. for(int channel=0; channel<channels; channel++) {
  109. for(int filter=0; filter<filters; filter++) {
  110. auto job = jobs->getJob(sample * channels * filters + channel * filters + filter);
  111. for(int x=0; x<outputSize; x++) {
  112. for(int y=0; y<outputSize; y++) {
  113. output_tensor(sample, y, x, channel) = job->getResponsePayload(x*outputSize + y);
  114. }
  115. }
  116. }
  117. }
  118. }
  119. done();
  120. connectionManager.removeFinishedWorkers();
  121. });
  122. worker->startAsync();
  123. }
  124. static Status MatMulGradHelper(FunctionDef* g, const string& opname,
  125. const string& attr_adj_x,
  126. const string& attr_adj_y, const string& x0,
  127. bool ax0, const string& x1, bool ax1,
  128. const string& y0, bool ay0, const string& y1,
  129. bool ay1) {
  130. // The final outputs are "dx" and "dy". If we're broadcasting compute
  131. // intermediate nodes for now.
  132. std::vector<FDH::Node> nodes = {
  133. {{("dx")},
  134. opname,
  135. {x0, x1},
  136. {{"T", "$T"}, {attr_adj_x, ax0}, {attr_adj_y, ax1}}},
  137. {{("dy")},
  138. opname,
  139. {y0, y1},
  140. {{"T", "$T"}, {attr_adj_x, ay0}, {attr_adj_y, ay1}}},
  141. };
  142. *g = FDH::Define(
  143. // Arg defs
  144. {"x: T", "y: T", "dz: T"},
  145. // Ret val defs
  146. {"dx: T", "dy: T"},
  147. // Attr defs
  148. {{"T: {half, float, double}"}},
  149. // Nodes
  150. nodes);
  151. return Status::OK();
  152. }
  153. Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
  154. const string opname = "MyMatMul";
  155. const string attr_adj_x = "transpose_a";
  156. const string attr_adj_y = "transpose_b";
  157. DataType T;
  158. TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
  159. if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
  160. return errors::Unimplemented(
  161. "MatMul gradient for complex is not supported yet.");
  162. }
  163. bool ta;
  164. bool tb;
  165. TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_adj_x, &ta));
  166. TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_adj_y, &tb));
  167. if (!ta && !tb) {
  168. return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y",
  169. true, "x", true, "dz", false);
  170. }
  171. if (!ta && tb) {
  172. return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y",
  173. false, "dz", true, "x", false);
  174. }
  175. if (ta && !tb) {
  176. return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", false, "dz",
  177. true, "x", false, "dz", false);
  178. }
  179. CHECK(ta && tb);
  180. return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", true, "dz",
  181. true, "dz", true, "x", true);
  182. }
  183. }