dummyBigOp.cpp 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. #include "dummyBigOp.hpp"
  2. namespace tf_lib {
  3. DummyBigOp::DummyBigOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
  4. };
  5. void DummyBigOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
  6. init();
  7. // Input tensor is of the following dimensions:
  8. // [ batch, in_rows, in_cols, in_depth ]
  9. const Tensor& input = context->input(0);
  10. TensorShape input_shape = input.shape();
  11. TensorShape output_shape;
  12. const int32 dims[] = {dataLength};
  13. TensorShapeUtils::MakeShape(dims, 1, &output_shape);
  14. output_shape.set_dim(0, dims[0]);
  15. Tensor* output = nullptr;
  16. OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
  17. auto input_tensor = input.tensor<int32, 1>();
  18. auto output_tensor = output->tensor<int32, 1>();
  19. auto worker = connectionManager.createWorker(Module::dummyBigModule);
  20. worker->setJobTimeout(milliseconds(100));
  21. worker->setRetryCount(10);
  22. {
  23. auto job = worker->getJobList()->getJob(0);
  24. for(size_t i=0; i<job->getPayloadSize(); i++) {
  25. job->setPayload(i, input_tensor(i));
  26. job->setReady();
  27. }
  28. }
  29. worker->setDoneCallback([output_tensor, worker, done]{
  30. auto job = worker->getJobList()->getJob(0);
  31. for(size_t i=0; i<job->getResponsePayloadSize(); i++) {
  32. output_tensor(i) = job->getResponsePayload(i);
  33. }
  34. done();
  35. connectionManager.removeFinishedWorkers();
  36. });
  37. worker->startAsync();
  38. }
  39. }