// matMul.cc — TensorFlow custom-op example: registers MyMatMul and MyConv2D
// ops, a stub CPU kernel for MyConv2D, and a MatMul-style gradient function.
#include <iostream>

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/math/math_util.h"

using namespace tensorflow;

typedef FunctionDefHelper FDH;
  7. REGISTER_OP("MyMatMul")
  8. .Input("to_zero: int32")
  9. .Output("zeroed: int32")
  10. .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
  11. c->set_output(0, c->input(0));
  12. return Status::OK();
  13. });
  14. REGISTER_OP("MyConv2D")
  15. .Input("input: int32")
  16. .Input("filter: int32")
  17. .Output("output: int32")
  18. .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
  19. c->set_output(0, c->input(0));
  20. return Status::OK();
  21. });
  22. #include "tensorflow/core/framework/op_kernel.h"
  23. using namespace tensorflow;
  24. /*
  25. class Conv2DOp : public OpKernel {
  26. public:
  27. explicit Conv2DOp(OpKernelConstruction* context) : OpKernel(context) {}
  28. void Compute(OpKernelContext* context) override {
  29. // Grab the input tensor
  30. const Tensor& input_tensor = context->input(0);
  31. auto input = input_tensor.flat<int32>();
  32. printf("call n: %d\n", n++);
  33. // Create an output tensor
  34. Tensor* output_tensor = NULL;
  35. OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
  36. &output_tensor));
  37. auto output_flat = output_tensor->flat<int32>();
  38. // Set all but the first element of the output tensor to 0.
  39. const int N = input.size();
  40. for (int i = 1; i < N; i++) {
  41. output_flat(i) = 0;
  42. }
  43. // Preserve the first input value if possible.
  44. if (N > 0) output_flat(0) = input(0);
  45. }
  46. int n = 0;
  47. };
  48. */
  49. class Conv2DOp : public OpKernel {
  50. public:
  51. explicit Conv2DOp(OpKernelConstruction* context) : OpKernel(context) {
  52. }
  53. void Compute(OpKernelContext* context) override {
  54. // Input tensor is of the following dimensions:
  55. // [ batch, in_rows, in_cols, in_depth ]
  56. const Tensor& input = context->input(0);
  57. // Input filter is of the following dimensions:
  58. // [ filter_rows, filter_cols, in_depth, out_depth]
  59. const Tensor& filter = context->input(1);
  60. TensorShape out_shape = input.shape();
  61. // Output tensor is of the following dimensions:
  62. // [ in_batch, out_rows, out_cols, out_depth ]
  63. Tensor* output = nullptr;
  64. OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
  65. std::cout << "Conv2D" << std::endl;
  66. // If there is nothing to compute, return.
  67. if (out_shape.num_elements() == 0) {
  68. return;
  69. }
  70. }
  71. private:
  72. //LaunchConv2DOp<Device, T> launcher_;
  73. TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp);
  74. };
  75. REGISTER_KERNEL_BUILDER(Name("MyConv2D").Device(DEVICE_CPU), Conv2DOp);
  76. static Status MatMulGradHelper(FunctionDef* g, const string& opname,
  77. const string& attr_adj_x,
  78. const string& attr_adj_y, const string& x0,
  79. bool ax0, const string& x1, bool ax1,
  80. const string& y0, bool ay0, const string& y1,
  81. bool ay1) {
  82. // The final outputs are "dx" and "dy". If we're broadcasting compute
  83. // intermediate nodes for now.
  84. std::vector<FDH::Node> nodes = {
  85. {{("dx")},
  86. opname,
  87. {x0, x1},
  88. {{"T", "$T"}, {attr_adj_x, ax0}, {attr_adj_y, ax1}}},
  89. {{("dy")},
  90. opname,
  91. {y0, y1},
  92. {{"T", "$T"}, {attr_adj_x, ay0}, {attr_adj_y, ay1}}},
  93. };
  94. *g = FDH::Define(
  95. // Arg defs
  96. {"x: T", "y: T", "dz: T"},
  97. // Ret val defs
  98. {"dx: T", "dy: T"},
  99. // Attr defs
  100. {{"T: {half, float, double}"}},
  101. // Nodes
  102. nodes);
  103. return Status::OK();
  104. }
  105. Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
  106. const string opname = "MyMatMul";
  107. const string attr_adj_x = "transpose_a";
  108. const string attr_adj_y = "transpose_b";
  109. DataType T;
  110. TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
  111. if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
  112. return errors::Unimplemented(
  113. "MatMul gradient for complex is not supported yet.");
  114. }
  115. bool ta;
  116. bool tb;
  117. TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_adj_x, &ta));
  118. TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_adj_y, &tb));
  119. if (!ta && !tb) {
  120. return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y",
  121. true, "x", true, "dz", false);
  122. }
  123. if (!ta && tb) {
  124. return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y",
  125. false, "dz", true, "x", false);
  126. }
  127. if (ta && !tb) {
  128. return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", false, "dz",
  129. true, "x", false, "dz", false);
  130. }
  131. CHECK(ta && tb);
  132. return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", true, "dz",
  133. true, "dz", true, "x", true);
  134. }
  135. REGISTER_OP_GRADIENT("MyConv2D", MatMulGrad);