train.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import tensorflow as tf
  2. import tensorflow.keras as keras
  3. from tensorflow.keras import layers
  4. from tensorflow.keras.layers import Input, Embedding, LSTM, Dense, Dropout, Flatten, MaxPooling2D, Conv2D
  5. from tensorflow.keras.models import Model, Sequential
  6. from tensorflow.keras.datasets import mnist
  7. from tensorflow.keras.utils import plot_model, to_categorical
  8. import numpy as np
  9. from IPython import embed
  10. import sys
  11. sys.path.append('../hostLib/')
  12. from hostLib.layers.conv2d import Conv2D as Conv2DFPGA
  13. batch_size = 128
  14. num_classes = 10
  15. epochs = 1200
  16. # input image dimensions
  17. img_rows, img_cols = 28, 28
  18. # the data, split between train and test sets
  19. (x_train, y_train), (x_test, y_test) = mnist.load_data()
  20. x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
  21. x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
  22. x_train = x_train.astype('float')
  23. x_test = x_test.astype('float')
  24. #x_train /= 255
  25. #x_test /= 255
  26. print('x_train shape:', x_train.shape)
  27. print(x_train.shape[0], 'train samples')
  28. print(x_test.shape[0], 'test samples')
  29. # convert class vectors to binary class matrices
  30. y_train = to_categorical(y_train, num_classes)
  31. y_test = to_categorical(y_test, num_classes)
  32. i = layers.Input(shape=(28, 28, 1))
  33. a = layers.Lambda(lambda x: tf.image.resize(x, (228,228)))(i)
  34. b = Conv2DFPGA(2)(a)
  35. c = Conv2DFPGA(1)(a)
  36. d = Conv2DFPGA(1)(layers.Lambda(lambda x: tf.image.resize(x, (228,228)))(b))
  37. e = Conv2DFPGA(2)(layers.Lambda(lambda x: tf.image.resize(x, (228,228)))(c))
  38. print(a)
  39. print(b)
  40. print(c)
  41. print(d)
  42. print(e)
  43. x = layers.Add()([d,e])
  44. y = layers.Flatten()(x)
  45. z = layers.Dense(num_classes, activation='softmax')(y)
  46. model = Model(inputs=i, outputs=z)
  47. print(model.output_shape)
  48. model.compile(loss=keras.losses.categorical_crossentropy,
  49. optimizer=keras.optimizers.Adadelta(),
  50. metrics=['accuracy'])
  51. plot_model(model, to_file='model.png', expand_nested=True, show_shapes=True)
  52. model.fit(x_train, y_train,
  53. batch_size=batch_size,
  54. epochs=epochs,
  55. verbose=1,
  56. validation_data=(x_test, y_test))
  57. score = model.evaluate(x_test, y_test, verbose=0)
  58. print('Test loss:', score[0])
  59. print('Test accuracy:', score[1])