// zero_out.cc

#include <atomic>
#include <cstdint>
#include <cstdio>

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;
  5. REGISTER_OP("ZeroOut")
  6. .Input("to_zero: int32")
  7. .Output("zeroed: int32")
  8. .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
  9. c->set_output(0, c->input(0));
  10. return Status::OK();
  11. });
  12. #include "tensorflow/core/framework/op_kernel.h"
  13. using namespace tensorflow;
  14. class ZeroOutOp : public OpKernel {
  15. public:
  16. explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
  17. void Compute(OpKernelContext* context) override {
  18. // Grab the input tensor
  19. const Tensor& input_tensor = context->input(0);
  20. auto input = input_tensor.flat<int32>();
  21. printf("call n: %d\n", n++);
  22. // Create an output tensor
  23. Tensor* output_tensor = NULL;
  24. OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
  25. &output_tensor));
  26. auto output_flat = output_tensor->flat<int32>();
  27. // Set all but the first element of the output tensor to 0.
  28. const int N = input.size();
  29. for (int i = 1; i < N; i++) {
  30. output_flat(i) = 0;
  31. }
  32. // Preserve the first input value if possible.
  33. if (N > 0) output_flat(0) = input(0);
  34. }
  35. int n = 0;
  36. };
  37. REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);