train.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import tensorflow as tf
  2. import tensorflow.keras as keras
  3. from tensorflow.keras import layers
  4. from tensorflow.keras.layers import Input, Embedding, LSTM, Dense, Dropout, Flatten, MaxPooling2D, Conv2D
  5. from tensorflow.keras.models import Model, Sequential
  6. from tensorflow.keras.datasets import mnist
  7. from tensorflow.keras.utils import plot_model, to_categorical
  8. import numpy as np
  9. from IPython import embed
  10. zero_out_module = tf.load_op_library('./zero_out.so')
  11. batch_size = 128
  12. num_classes = 10
  13. epochs = 1 # 12
  14. # input image dimensions
  15. img_rows, img_cols = 28, 28
  16. # the data, split between train and test sets
  17. (x_train, y_train), (x_test, y_test) = mnist.load_data()
  18. x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
  19. x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
  20. input_shape = (img_rows, img_cols, 1)
  21. x_train = x_train.astype('float32')
  22. x_test = x_test.astype('float32')
  23. x_train /= 255
  24. x_test /= 255
  25. print('x_train shape:', x_train.shape)
  26. print(x_train.shape[0], 'train samples')
  27. print(x_test.shape[0], 'test samples')
  28. # convert class vectors to binary class matrices
  29. y_train = to_categorical(y_train, num_classes)
  30. y_test = to_categorical(y_test, num_classes)
  31. class Linear(layers.Layer):
  32. def __init__(self, units=32, input_dim=32):
  33. super(Linear, self).__init__()
  34. def call(self, inputs):
  35. ints = tf.dtypes.cast(inputs, dtype=tf.int32)
  36. print(ints)
  37. outs = zero_out_module.zero_out(ints)
  38. return tf.dtypes.cast(outs, dtype=tf.float32)
  39. model = Sequential()
  40. model.add(Flatten())
  41. model.add(Dense(128, activation='relu'))
  42. model.add(Dropout(0.5))
  43. model.add(Dense(num_classes, activation='softmax'))
  44. model.add(Linear())
  45. model.compile(loss=keras.losses.categorical_crossentropy,
  46. optimizer=keras.optimizers.Adadelta(),
  47. metrics=['accuracy'])
  48. model.fit(x_train, y_train,
  49. batch_size=batch_size,
  50. epochs=epochs,
  51. verbose=1,
  52. validation_data=(x_test, y_test))
  53. score = model.evaluate(x_test, y_test, verbose=0)
  54. print('Test loss:', score[0])
  55. print('Test accuracy:', score[1])
  56. plot_model(model, to_file='model.png', expand_nested=True, show_shapes=True)