瀏覽代碼

fixed output channel count

subDesTagesMitExtraKaese 3 年之前
父節點
當前提交
0ba017ca56
共有 4 個文件被更改,包括 66 次插入 和 27 次删除
  1. 1 1
      c++/lib/json
  2. 24 17
      c++/src/conv2D.cpp
  3. 8 9
      doku/layer/conv2d.md
  4. 33 0
      tests/op_test.py

+ 1 - 1
c++/lib/json

@@ -1 +1 @@
-Subproject commit dd7e25927fe7a49c81d07943c32444f0a9011665
+Subproject commit 27f5a6e82731b5538f1063bb8e955458ea7746e7

+ 24 - 17
c++/src/conv2D.cpp

@@ -36,15 +36,15 @@ namespace tf_lib {
       filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'H'));
     DimensionHandle filter_cols_dim = c->Dim(
       filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'W'));
-      */
+      
     DimensionHandle filter_input_depth_dim = c->Dim(
       filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'I'));
-
+      */
     DimensionHandle output_rows, output_cols, output_channels;
     c->Subtract(input_spatial_dims[0], 4, &output_rows);
     c->Subtract(input_spatial_dims[1], 4, &output_cols);
 
-    c->Multiply(filter_input_depth_dim, output_depth_dim, &output_channels);
+    c->Subtract(output_depth_dim, 0, &output_channels);
 
     std::vector<DimensionHandle> out_dims(4);
     out_dims[0] = batch_size_dim;
@@ -80,19 +80,21 @@ namespace tf_lib {
     OP_REQUIRES_ASYNC(context, input_shape.dim_size(2) == 228, errors::InvalidArgument("Unsupported input width: ", input_shape.dim_size(2)), done);
     OP_REQUIRES_ASYNC(context, kernel_shape.dim_size(0) == 5, errors::InvalidArgument("Unsupported kernel height: ", kernel_shape.dim_size(0)), done);
     OP_REQUIRES_ASYNC(context, kernel_shape.dim_size(1) == 5, errors::InvalidArgument("Unsupported kernel width: ", kernel_shape.dim_size(1)), done);
+    OP_REQUIRES_ASYNC(context, kernel_shape.dim_size(2) == input_shape.dim_size(3), 
+      errors::InvalidArgument("kernel channels != input channels: ", kernel_shape.dim_size(2), " != ", input_shape.dim_size(3)), done);
 
     int batchSize = input_shape.dim_size(0);
     int channels = input_shape.dim_size(3);
-    int filters = kernel_shape.dim_size(3);
+    int outputChannels = kernel_shape.dim_size(3);
 
     TensorShape output_shape;
-    const int32 dims[] = {batchSize, outputSize, outputSize, channels * filters};
+    const int32 dims[] = {batchSize, outputSize, outputSize, outputChannels};
     TensorShapeUtils::MakeShape(dims, 4, &output_shape);
 
     output_shape.set_dim(0, batchSize);
     output_shape.set_dim(1, outputSize);
     output_shape.set_dim(2, outputSize);
-    output_shape.set_dim(3, channels * filters);
+    output_shape.set_dim(3, outputChannels);
 
     //printMu.lock();
     //std::cout << output_shape.DebugString() << std::endl;
@@ -108,20 +110,20 @@ namespace tf_lib {
     auto kernel_tensor = kernel.tensor<float, 4>();
     auto output_tensor = output->tensor<float, 4>();
 
-    auto worker = connectionManager.createWorker(Module::conv2D_5x5_Module, batchSize * channels * filters);
+    auto worker = connectionManager.createWorker(Module::conv2D_5x5_Module, batchSize * channels * outputChannels);
     {
       worker->setJobTimeout(milliseconds(300));
       worker->setRetryCount(10);
       auto jobs = worker->getJobList();
 
       for(int sample=0; sample<batchSize; sample++) {
-        for(int channel=0; channel<channels; channel++) {
-          for(int filter=0; filter<filters; filter++) {
-            auto job = jobs->getJob(sample * channels * filters + channel * filters + filter);
+        for(int outputChannel=0; outputChannel<outputChannels; outputChannel++) {
+          for(int channel=0; channel<channels; channel++) {
+            auto job = jobs->getJob(sample * outputChannels * channels + outputChannel * channels + channel);
             
             for(int x=0; x<kernelSize; x++) {
               for(int y=0; y<kernelSize; y++) {
-                job->setPayload(y*kernelSize + x, *((uint32_t*)&kernel_tensor(y, x, channel, filter)));
+                job->setPayload(y*kernelSize + x, *((uint32_t*)&kernel_tensor(y, x, channel, outputChannel)));
               }
             }
             for(int x=0; x<sizeWithBorder; x++) {
@@ -134,16 +136,21 @@ namespace tf_lib {
         }
       }
     }
-    worker->setDoneCallback([output_tensor, worker, done, batchSize, channels, filters, this]{
+    worker->setDoneCallback([output_tensor, worker, done, batchSize, channels, outputChannels, this]{
       auto jobs = worker->getJobList();
       for(int sample=0; sample<batchSize; sample++) {
-        for(int channel=0; channel<channels; channel++) {
-          for(int filter=0; filter<filters; filter++) {
-            auto job = jobs->getJob(sample * channels * filters + channel * filters + filter);
+        for(int outputChannel=0; outputChannel<outputChannels; outputChannel++) {
+          for(int x=0; x<outputSize; x++) {
+            for(int y=0; y<outputSize; y++) {
+              output_tensor(sample, y, x, outputChannel) = 0;
+            }
+          }
+          for(int channel=0; channel<channels; channel++) {
+            auto job = jobs->getJob(sample * outputChannels * channels + outputChannel * channels + channel);
             for(int x=0; x<outputSize; x++) {
               for(int y=0; y<outputSize; y++) {
-                uint32_t val = job->getResponsePayload((y+border*2)*sizeWithBorder + (x+border*2) + 1);
-                output_tensor(sample, y, x, channel) = *((float*)&val);
+                uint32_t val = job->getResponsePayload((y+border*2)*sizeWithBorder + (x+border*2));
+                output_tensor(sample, y, x, outputChannel) += *((float*)&val);
               }
             }
           }

+ 8 - 9
doku/layer/conv2d.md

@@ -7,25 +7,24 @@ Input:
 - an FPGA: `[imageY, imageX]`
 
 Kernel:
-- in TF: `[kernelY, kernelX, channels, filters]`
+- in TF: `[kernelY, kernelX, channels, outputChannels]`
 - an FPGA: `[kernelY, kernelX]`
 
 Output:
 - vom FPGA: `[imageY2, imageX2]`
-- an TF: `[batchSize, imageY2, imageX2, channels * filters]`
+- an TF: `[batchSize, imageY2, imageX2, outputChannels]`
 
 ## Parallelisierung
 
 1.  **Ohne FPGA-seitigem Speicher**
 
-    FPGA Recheneinheiten werden verteilt `(batchSize * channels * filters)` Mal verwendet.
+    FPGA Recheneinheiten werden verteilt `(batchSize * channels * outputChannels)` Mal verwendet.
 
     ```python
     for sample in range(batchSize):
-      for channnel in range(channels):
-        for filter in range(filters):
-          output[sample][channel + filter * channels] = f(
-            input[sample][channel], 
-            kernel[channel][filter]
-          )
+      for outputChannel in range(outputChannels):
+        output[sample][outputChannel] = sum([f(
+          input[sample][channel], 
+          kernel[channel][outputChannel]
+        ) for channel in range(channels)])
     ```

+ 33 - 0
tests/op_test.py

@@ -1,4 +1,5 @@
 import tensorflow as tf
+from tensorflow import nn
 import numpy as np
 from IPython import embed
 
@@ -35,5 +36,37 @@ class FPGALibTest(tf.test.TestCase):
         result = load_op.op_lib.MyDummyBig(input=input)
         self.assertAllEqual(result, input)
 
+  def testConv2DSingle(self):
+    img = np.ndarray((228,228), dtype=float)
+    img.fill(0)
+    for a in range(228):
+      for b in range(228):
+        img[a][b] = (a)*228+(b)
+    kernel = np.ndarray((5,5), dtype=float)
+    kernel.fill(0)
+    kernel[2][2] = 1
+    input = tf.constant(np.expand_dims(img, (0, 3)), dtype=float)
+    filter = tf.constant(np.expand_dims(kernel, (2, 3)), dtype=float)
+    with self.session():
+      ref = nn.convolution(input, filter)
+      result = load_op.op_lib.MyConv2D(input=input, filter=filter)
+      self.assertAllClose(result, ref)
+
+  def testConv2DRandom(self):
+    input = tf.random.uniform(shape=[1,228,228,1])
+    filter = tf.random.uniform(shape=[5,5,1,1])
+    with self.session():
+      ref = nn.convolution(input, filter)
+      result = load_op.op_lib.MyConv2D(input=input, filter=filter)
+      self.assertAllClose(result, ref)
+
+  def testConv2DMulti(self):
+    input = tf.random.uniform(shape=[2,228,228,3])
+    filter = tf.random.uniform(shape=[5,5,3,4])
+    with self.session():
+      ref = nn.convolution(input, filter)
+      result = load_op.op_lib.MyConv2D(input=input, filter=filter)
+      self.assertAllClose(result, ref)
+
 if __name__ == "__main__":
   tf.test.main()