conv2D.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. import tensorflow as tf
  2. from tensorflow.python.framework import tensor_shape
  3. from tensorflow.keras import layers, initializers, regularizers, constraints
  4. from .. import load_op
  5. class Conv2D(layers.Layer):
  6. def __init__(self,
  7. filters = 1,
  8. kernel_initializer = 'glorot_uniform',
  9. kernel_regularizer=None,
  10. kernel_constraint=None,
  11. ):
  12. super(Conv2D, self).__init__()
  13. #int, dim of output space
  14. self.filters = filters
  15. self.kernel_initializer = initializers.get(kernel_initializer)
  16. self.kernel_regularizer = regularizers.get(kernel_regularizer)
  17. self.kernel_constraint = constraints.get(kernel_constraint)
  18. def build(self, input_shape):
  19. input_shape = tf.TensorShape(input_shape)
  20. self.input_channel = input_shape[3]
  21. kernel_shape = (5,)*2 + (self.input_channel, self.filters)
  22. self.kernel = self.add_weight(
  23. name='kernel',
  24. shape=kernel_shape,
  25. initializer=self.kernel_initializer,
  26. regularizer=self.kernel_regularizer,
  27. constraint=self.kernel_constraint,
  28. trainable=True,
  29. dtype=self.dtype)
  30. def call(self, inputs):
  31. #out = tf.Tensor(tf.int32, shape=inputs.shape)
  32. ch_inputs = tf.unstack(inputs, axis=3)#tf.dtypes.cast(inputs, dtype=tf.int32), axis=3)
  33. ch_kernel = tf.unstack(tf.dtypes.cast(self.kernel, dtype=tf.int32), axis=2)
  34. ch_outputs = [None] * len(ch_inputs)
  35. for ch in range(len(ch_inputs)):
  36. print(ch_inputs[ch], ch_kernel[ch])
  37. ch_outputs[ch] = [None] * self.filters
  38. kernel_2d = tf.unstack(ch_kernel[ch], axis=2)
  39. for f in range(len(kernel_2d)):
  40. ch_outputs[ch][f] = load_op.op_lib.MyConv2D(input=ch_inputs[ch], filter=kernel_2d[f], delay=(f+1)*100)
  41. ch_outputs[ch] = tf.stack(ch_outputs[ch], axis=2)
  42. outs = tf.stack(ch_outputs, axis=2)
  43. return outs #tf.dtypes.cast(outs, dtype=tf.float32)