신기하네.. 그냥 알아서 받네?
>>> mnist = tf.keras.datasets.mnist >>> (train_images, train_labels), (test_images, test_labels) = mnist.load_data() Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11490434/11490434 [==============================] - 1s 0us/step |
학습하고 tflite 파일로 저장하기
import tensorflow as tf import numpy as np mnist = tf.keras.datasets.mnist (train_images, train_labels), (test_images, test_labels) = mnist.load_data() train_images = train_images.astype(np.float32) / 255.0 test_images = test_images.astype(np.float32) / 255.0 model = tf.keras.Sequential([ tf.keras.layers.InputLayer(input_shape=(28, 28)), tf.keras.layers.Reshape(target_shape=(28, 28, 1)), tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'), tf.keras.layers.MaxPooling2D(pool_size=(2, 2)), tf.keras.layers.Flatten(), tf.keras.layers.Dense(10) ]) model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy( from_logits=True), metrics=['accuracy']) model.fit( train_images, train_labels, epochs=5, validation_data=(test_images, test_labels) ) # 일반 모델로 변환 converter = tf.lite.TFLiteConverter.from_keras_model(model) tflite_model = converter.convert() # 요건 차이 없음 # converter = tf.lite.TFLiteConverter.from_keras_model(model) # converter.optimizations = [tf.lite.Optimize.DEFAULT] # tflite_model_quant = converter.convert() # quant 를 하려면 아래 코드 실행해야 함 def representative_data_gen(): for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100): yield [input_value] converter = tf.lite.TFLiteConverter.from_keras_model(model) converter.optimizations = [tf.lite.Optimize.DEFAULT] converter.representative_dataset = representative_data_gen converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] converter.inference_input_type = tf.uint8 converter.inference_output_type = tf.uint8 tflite_model_quant = converter.convert() # 파일로 저장하기 import pathlib tflite_models_dir = pathlib.Path("/tmp/mnist_tflite_models/") tflite_models_dir.mkdir(exist_ok=True, parents=True) # Save the unquantized/float model: tflite_model_file = tflite_models_dir/"mnist_model.tflite" tflite_model_file.write_bytes(tflite_model) # Save the quantized model: tflite_model_quant_file = tflite_models_dir/"mnist_model_quant.tflite" tflite_model_quant_file.write_bytes(tflite_model_quant) |
[링크 : https://www.tensorflow.org/lite/performance/post_training_integer_quant?hl=ko]
netron을 통해 생성한걸 보는데 quant나 그냥이나 어째 차이가 없냐?
+
quantization 하면 uint8로 변경된다.
그나저나, MNIST에 대해서 오해가 있었다.
출력이 [1,10] 인데 0~9 까지의 숫자에 대한 필기 데이터베이스지 알파벳이 아니란 것 -_-!
그래서 출력이 딱 10개인 건 당연하다는 것..
[링크 : https://en.wikipedia.org/wiki/MNIST_database]
+
EMNIST 라고 알파벳 손글씨가 따로 있다.
[링크 : https://www.nist.gov/itl/products-and-services/emnist-dataset]
[링크 : https://www.tensorflow.org/datasets/catalog/emnist?hl=ko]
'프로그램 사용 > yolo_tensorflow' 카테고리의 다른 글
i.mx8mp gopoint 실행 경로 (0) | 2024.01.02 |
---|---|
tensorflow keras dataset (0) | 2024.01.02 |
yolo-label (0) | 2022.03.22 |
tflite bazel rpi3b+ (0) | 2022.01.27 |
bazel cross compile (0) | 2022.01.27 |