conv2d.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import tensorflow as tf
  2. from tensorflow.python.ops import nn_ops
  3. from tensorflow.python.framework import ops
  4. from tensorflow.python.ops import array_ops
  5. from tensorflow.python.framework import tensor_shape
  6. from tensorflow.keras import layers, initializers, regularizers, constraints
  7. from .. import load_op
  8. class Conv2D(layers.Layer):
  9. def __init__(self,
  10. filters = 1,
  11. kernel_initializer = 'glorot_uniform',
  12. kernel_regularizer=None,
  13. kernel_constraint=None,
  14. implementation = 1):
  15. super(Conv2D, self).__init__()
  16. #int, dim of output space
  17. self.filters = filters
  18. self.kernel_initializer = initializers.get(kernel_initializer)
  19. self.kernel_regularizer = regularizers.get(kernel_regularizer)
  20. self.kernel_constraint = constraints.get(kernel_constraint)
  21. self.implementation = implementation
  22. def build(self, input_shape):
  23. input_shape = tf.TensorShape(input_shape)
  24. self.input_channel = input_shape[3]
  25. kernel_shape = (5,)*2 + (self.input_channel, self.filters)
  26. self.kernel = self.add_weight(
  27. name='kernel',
  28. shape=kernel_shape,
  29. initializer=self.kernel_initializer,
  30. regularizer=self.kernel_regularizer,
  31. constraint=self.kernel_constraint,
  32. trainable=True,
  33. dtype=self.dtype)
  34. def call(self, inputs):
  35. if self.implementation == 1:
  36. return load_op.op_lib.MyConv2D_1(input=inputs, filter=self.kernel)
  37. if self.implementation == 2:
  38. return load_op.op_lib.MyConv2D_2(input=inputs, filter=self.kernel)
  39. if self.implementation == 3:
  40. return load_op.op_lib.MyConv2D_3(input=inputs, filter=self.kernel)
  41. @ops.RegisterGradient("MyConv2D_1")
  42. @ops.RegisterGradient("MyConv2D_2")
  43. @ops.RegisterGradient("MyConv2D_3")
  44. def _my_conv_2d_grad(op, grad):
  45. shape_0, shape_1 = array_ops.shape_n([op.inputs[0], op.inputs[1]])
  46. return [
  47. nn_ops.conv2d_backprop_input(
  48. shape_0,
  49. op.inputs[1],
  50. grad,
  51. strides=[1,1,1,1],
  52. padding="VALID"
  53. ),
  54. nn_ops.conv2d_backprop_filter(
  55. op.inputs[0],
  56. shape_1,
  57. grad,
  58. strides=[1,1,1,1],
  59. padding="VALID"
  60. )
  61. ]