dummyOp.cpp

#include "dummyOp.hpp"

namespace tf_lib {

DummyOp::DummyOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
}
void DummyOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
  init();

  // The input tensor has dimensions [ batch, in_rows, in_cols, in_depth ].
  const Tensor& input = context->input(0);
  // const int32* p = input.flat<int32>().data();
  TensorShape input_shape = input.shape();

  // Build a one-dimensional output shape of length dataLength.
  TensorShape output_shape;
  const int32 dims[] = {dataLength};
  TensorShapeUtils::MakeShape(dims, 1, &output_shape);
  output_shape.set_dim(0, dims[0]);

  // Allocate the output tensor; the async variant of the macro ensures done()
  // is still invoked if allocation fails.
  Tensor* output = nullptr;
  OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, output_shape, &output), done);

  auto input_tensor = input.tensor<int32, 1>();
  auto output_tensor = output->tensor<int32, 1>();

  // Acquire a worker for the dummy module and configure job timeout and retries.
  auto worker = connectionManager.createWorker(Module::dummyModule);
  worker->setJobTimeout(milliseconds(100));
  worker->setRetryCount(10);

  // Copy the input tensor into the job payload and mark it ready for submission.
  {
    auto job = worker->getJobList()->getJob(0);
    for (size_t i = 0; i < job->getPayloadSize(); i++) {
      job->setPayload(i, input_tensor(i));
      job->setReady();
    }
  }

  // Once the job completes, copy the response payload into the output tensor,
  // notify TensorFlow that the async computation has finished, and clean up.
  worker->setDoneCallback([output_tensor, worker, done] {
    auto job = worker->getJobList()->getJob(0);
    for (size_t i = 0; i < job->getResponsePayloadSize(); i++) {
      output_tensor(i) = job->getResponsePayload(i);
    }
    done();
    connectionManager.removeFinishedWorkers();
  });

  // Submit the job; ComputeAsync returns immediately and done() fires later.
  worker->startAsync();
}

} // namespace tf_lib
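
A kernel like this only becomes callable from TensorFlow once the op and the kernel are registered, which is not shown in this file. Below is a minimal sketch of what that registration typically looks like; the op name "MyDummy", the int32 input/output signature, and the unknown-shape inference function are assumptions for illustration, not taken from this repository.

// Sketch only: the real registration likely lives in dummyOp.hpp or a
// separate registration file, with project-specific names and shape functions.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "dummyOp.hpp"

using namespace tensorflow;

// Declare the op interface: one int32 input, one int32 output.
// "MyDummy" is a placeholder op name.
REGISTER_OP("MyDummy")
    .Input("input: int32")
    .Output("output: int32")
    .SetShapeFn(shape_inference::UnknownShape);

// Bind the op name to the AsyncOpKernel defined above for CPU execution.
REGISTER_KERNEL_BUILDER(Name("MyDummy").Device(DEVICE_CPU), tf_lib::DummyOp);

Once the shared library containing this registration is built, it can be loaded from Python with tf.load_op_library and the op invoked like any other TensorFlow op.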