# cifar.py
  1. # from https://www.tensorflow.org/tutorials/images/cnn
  2. import tensorflow as tf
  3. from tensorflow.keras import datasets, layers, models
  4. import matplotlib.pyplot as plt
  5. import sys
  6. sys.path.append('../hostLib/')
  7. from hostLib.layers.conv2d import Conv2D as Conv2DFPGA
  8. (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
  9. # Normalize pixel values to be between 0 and 1
  10. train_images, test_images = train_images / 255.0, test_images / 255.0
  11. class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
  12. 'dog', 'frog', 'horse', 'ship', 'truck']
  13. plt.figure(figsize=(10,10))
  14. for i in range(25):
  15. plt.subplot(5,5,i+1)
  16. plt.xticks([])
  17. plt.yticks([])
  18. plt.grid(False)
  19. plt.imshow(train_images[i], cmap=plt.cm.binary)
  20. # The CIFAR labels happen to be arrays,
  21. # which is why you need the extra index
  22. plt.xlabel(class_names[train_labels[i][0]])
  23. plt.show()
  24. model = models.Sequential()
  25. model.add(layers.Lambda(lambda x: tf.image.resize(x, (228,228)))) #to-do: implement 3 stage 32x32_3x3 conv2d with relu
  26. model.add(Conv2DFPGA(1))#32
  27. model.add(layers.Activation('relu'))
  28. model.add(layers.Lambda(lambda x: tf.image.resize(x, (228,228))))
  29. model.add(Conv2DFPGA(1))#64
  30. model.add(layers.Activation('relu'))
  31. model.add(layers.Lambda(lambda x: tf.image.resize(x, (228,228))))
  32. model.add(Conv2DFPGA(1))#64
  33. model.add(layers.Activation('relu'))
  34. model.add(layers.MaxPooling2D(pool_size=(16, 16)))
  35. model.add(layers.Flatten())
  36. model.add(layers.Dense(64, activation='relu'))
  37. model.add(layers.Dense(10))
  38. model.build(input_shape=(None, 32, 32, 3))
  39. model.summary()
  40. model.compile(optimizer='adam',
  41. loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  42. metrics=['accuracy'])
  43. history = model.fit(train_images, train_labels, epochs=10,
  44. validation_data=(test_images, test_labels))
  45. plt.plot(history.history['accuracy'], label='accuracy')
  46. plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
  47. plt.xlabel('Epoch')
  48. plt.ylabel('Accuracy')
  49. plt.ylim([0.5, 1])
  50. plt.legend(loc='lower right')
  51. test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
  52. print(test_acc)
  53. plt.show()