conv2d.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import tensorflow as tf
  2. from tensorflow.python.ops import nn_ops
  3. from tensorflow.python.framework import ops
  4. from tensorflow.python.ops import array_ops
  5. from tensorflow.python.framework import tensor_shape
  6. from tensorflow.keras import layers, initializers, regularizers, constraints
  7. from .. import load_op
  class Conv2D(layers.Layer):
      """Keras layer wrapping the custom ``MyConv2D`` op from ``load_op``.

      The layer owns a single trainable kernel of shape
      ``(5, 5, in_channels, filters)``. It forwards the input and that kernel
      to ``load_op.op_lib.MyConv2D``. No bias or activation is applied.

      NOTE(review): the gradient registered for ``MyConv2D`` in this module
      uses stride 1, ``VALID`` padding and the default NHWC layout. The op is
      therefore presumably a stride-1, VALID, channels-last convolution.
      Confirm this against the op's kernel source.

      Args:
          filters: Number of output channels (dimension of the output space).
          kernel_initializer: Initializer for the kernel (anything accepted by
              ``initializers.get``).
          kernel_regularizer: Optional regularizer applied to the kernel.
          kernel_constraint: Optional constraint applied to the kernel.
          **kwargs: Standard ``layers.Layer`` keyword arguments (``name``,
              ``dtype``, ``trainable``, ...). They are forwarded to the base
              class, which lets ``from_config`` rebuild the layer from
              ``get_config`` output.
      """

      # Spatial size of the square kernel. Kept fixed at 5 as in the original
      # implementation; the custom op may depend on this size (unverified).
      _KERNEL_SIZE = 5

      def __init__(self,
                   filters=1,
                   kernel_initializer='glorot_uniform',
                   kernel_regularizer=None,
                   kernel_constraint=None,
                   **kwargs):
          super(Conv2D, self).__init__(**kwargs)
          # int, dimensionality of the output space (number of output channels)
          self.filters = filters
          self.kernel_initializer = initializers.get(kernel_initializer)
          self.kernel_regularizer = regularizers.get(kernel_regularizer)
          self.kernel_constraint = constraints.get(kernel_constraint)

      def build(self, input_shape):
          """Create the kernel once the input channel count is known.

          Args:
              input_shape: Shape of the layer input. It must be rank 4, with a
                  statically known channel dimension at index 3.

          Raises:
              ValueError: If the input is not rank 4, or if its channel
                  dimension (index 3) is unknown.
          """
          # with_rank(4) rejects any other known rank. as_list() yields plain
          # ints (or None) on both TF1 and TF2. Plain indexing returns a
          # Dimension object on TF1, which is never None even when unknown.
          dims = tf.TensorShape(input_shape).with_rank(4).as_list()
          if dims[3] is None:
              raise ValueError(
                  'The channel dimension (axis 3) of the inputs to Conv2D '
                  'must be defined. Found `None`.')
          self.input_channel = dims[3]
          kernel_shape = (self._KERNEL_SIZE, self._KERNEL_SIZE,
                          self.input_channel, self.filters)
          self.kernel = self.add_weight(
              name='kernel',
              shape=kernel_shape,
              initializer=self.kernel_initializer,
              regularizer=self.kernel_regularizer,
              constraint=self.kernel_constraint,
              trainable=True,
              dtype=self.dtype)
          super(Conv2D, self).build(input_shape)

      def call(self, inputs):
          """Apply the custom ``MyConv2D`` op to ``inputs`` with the layer kernel."""
          return load_op.op_lib.MyConv2D(input=inputs, filter=self.kernel)

      def get_config(self):
          """Return the layer configuration so the layer can be re-created.

          Together with the ``**kwargs`` forwarding in ``__init__``, this lets
          ``Conv2D.from_config(layer.get_config())`` rebuild an equivalent
          layer, for example when a model using it is saved and reloaded.
          """
          config = {
              'filters': self.filters,
              'kernel_initializer':
                  initializers.serialize(self.kernel_initializer),
              'kernel_regularizer':
                  regularizers.serialize(self.kernel_regularizer),
              'kernel_constraint':
                  constraints.serialize(self.kernel_constraint),
          }
          base_config = super(Conv2D, self).get_config()
          return dict(list(base_config.items()) + list(config.items()))
  @ops.RegisterGradient("MyConv2D")
  def _my_conv_2d_grad(op, grad):
      """Gradient of the custom ``MyConv2D`` op with respect to both inputs.

      Treats ``MyConv2D`` as a stride-1, ``VALID``-padded 2-D convolution. The
      work is delegated to TensorFlow's stock convolution backprop ops, using
      their default (NHWC) data format.

      Args:
          op: The forward ``MyConv2D`` operation. ``op.inputs[0]`` is the
              convolved tensor and ``op.inputs[1]`` is the filter.
          grad: Upstream gradient with respect to the op's output.

      Returns:
          A two-element list, in the same order as ``op.inputs``: the gradient
          for the input, then the gradient for the filter.
      """
      conv_input, conv_filter = op.inputs[0], op.inputs[1]
      # Each backprop op needs the dynamic shape of the tensor whose gradient
      # it computes; shape_n fetches both shapes in one op.
      input_shape, filter_shape = array_ops.shape_n([conv_input, conv_filter])
      unit_strides = [1, 1, 1, 1]
      pad = "VALID"

      d_input = nn_ops.conv2d_backprop_input(
          input_shape, conv_filter, grad, strides=unit_strides, padding=pad)
      d_filter = nn_ops.conv2d_backprop_filter(
          conv_input, filter_shape, grad, strides=unit_strides, padding=pad)
      return [d_input, d_filter]