// helper.cpp
#include "helper.hpp"

#include <cstddef>
#include <vector>
  2. namespace tf_lib {
  3. using namespace tensorflow;
  4. using namespace tensorflow::shape_inference;
  5. Status DimensionsFromShape(ShapeHandle shape, TensorFormat format,
  6. DimensionHandle* batch_dim,
  7. gtl::MutableArraySlice<DimensionHandle> spatial_dims,
  8. DimensionHandle* filter_dim,
  9. InferenceContext* context) {
  10. const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
  11. // Batch.
  12. *batch_dim = context->Dim(shape, GetTensorBatchDimIndex(rank, format));
  13. // Spatial.
  14. for (uint spatial_dim_index = 0; spatial_dim_index < spatial_dims.size();
  15. ++spatial_dim_index) {
  16. spatial_dims[spatial_dim_index] = context->Dim(
  17. shape, GetTensorSpatialDimIndex(rank, format, spatial_dim_index));
  18. }
  19. // Channel.
  20. *filter_dim = context->Dim(shape, GetTensorFeatureDimIndex(rank, format));
  21. if (format == FORMAT_NCHW_VECT_C) {
  22. TF_RETURN_IF_ERROR(context->Multiply(
  23. *filter_dim,
  24. context->Dim(shape, GetTensorInnerFeatureDimIndex(rank, format)),
  25. filter_dim));
  26. }
  27. return Status::OK();
  28. }
  29. Status ShapeFromDimensions(DimensionHandle batch_dim,
  30. gtl::ArraySlice<DimensionHandle> spatial_dims,
  31. DimensionHandle filter_dim, TensorFormat format,
  32. InferenceContext* context, ShapeHandle* shape) {
  33. const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
  34. std::vector<DimensionHandle> out_dims(rank);
  35. // Batch.
  36. out_dims[tensorflow::GetTensorBatchDimIndex(rank, format)] = batch_dim;
  37. // Spatial.
  38. for (uint spatial_dim_index = 0; spatial_dim_index < spatial_dims.size();
  39. ++spatial_dim_index) {
  40. out_dims[tensorflow::GetTensorSpatialDimIndex(
  41. rank, format, spatial_dim_index)] = spatial_dims[spatial_dim_index];
  42. }
  43. // Channel.
  44. if (format == tensorflow::FORMAT_NCHW_VECT_C) {
  45. // When format is NCHW_VECT_C, factor the feature map count
  46. // into the outer feature count and the inner feature count (=4).
  47. TF_RETURN_IF_ERROR(context->Divide(
  48. filter_dim, 4, /*evenly_divisible=*/true,
  49. &out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)]));
  50. out_dims[GetTensorInnerFeatureDimIndex(rank, format)] = context->MakeDim(4);
  51. } else {
  52. out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)] = filter_dim;
  53. }
  54. *shape = context->MakeShape(out_dims);
  55. return tensorflow::Status::OK();
  56. }
  57. }