conv2d.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. import tensorflow as tf
  2. from tensorflow.python.framework import tensor_shape
  3. from tensorflow.keras import layers, initializers, regularizers, constraints
  4. from .. import load_op
  5. class Conv2D(layers.Layer):
  6. def __init__(self,
  7. filters = 1,
  8. kernel_initializer = 'glorot_uniform',
  9. kernel_regularizer=None,
  10. kernel_constraint=None,
  11. ):
  12. super(Conv2D, self).__init__()
  13. #int, dim of output space
  14. self.filters = filters
  15. self.kernel_initializer = initializers.get(kernel_initializer)
  16. self.kernel_regularizer = regularizers.get(kernel_regularizer)
  17. self.kernel_constraint = constraints.get(kernel_constraint)
  18. def build(self, input_shape):
  19. input_shape = tf.TensorShape(input_shape)
  20. self.input_channel = input_shape[3]
  21. kernel_shape = (5,)*2 + (self.input_channel, self.filters)
  22. self.kernel = self.add_weight(
  23. name='kernel',
  24. shape=kernel_shape,
  25. initializer=self.kernel_initializer,
  26. regularizer=self.kernel_regularizer,
  27. constraint=self.kernel_constraint,
  28. trainable=True,
  29. dtype=self.dtype)
  30. def call(self, inputs):
  31. return load_op.op_lib.MyConv2D(input=inputs, filter=self.kernel)