-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtensorflow_training.py
More file actions
58 lines (48 loc) · 1.73 KB
/
tensorflow_training.py
File metadata and controls
58 lines (48 loc) · 1.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import argparse
import pathlib
from pyrtlnet.constants import quantized_model_prefix, test_data_file
from pyrtlnet.mnist_util import load_mnist_images
from pyrtlnet.tensorflow_training import (
evaluate_model,
quantize_model,
train_unquantized_model,
)
from pyrtlnet.training_util import save_mnist_data
def main() -> None:
    """Train an MNIST model, quantize it, evaluate both, and save the test data.

    Command-line arguments:
        --tensor_path: Directory used as the parent of the quantized model
            prefix and forwarded to ``save_mnist_data``. Defaults to ``"."``.

    Steps, in order:
        1. Load the MNIST training and test images/labels.
        2. Train the unquantized model (learning rate 0.001, 10 epochs) and
           evaluate it on the test set.
        3. Call ``quantize_model`` with a learning rate 10000 times smaller
           and one fifth of the epochs, then evaluate the returned model on
           the test set.
        4. Save the MNIST test data via ``save_mnist_data``.
    """
    arg_parser = argparse.ArgumentParser(prog="tensorflow_training.py")
    arg_parser.add_argument("--tensor_path", type=str, default=".")
    cli_args = arg_parser.parse_args()
    output_dir = cli_args.tensor_path

    # Load the MNIST dataset as (images, labels) pairs for training and test.
    training_set, testing_set = load_mnist_images()
    train_images, train_labels = training_set
    test_images, test_labels = testing_set

    # Hyperparameters for the initial (unquantized) training run.
    base_learning_rate = 0.001
    base_epochs = 10

    print("Training unquantized model.")
    float_model = train_unquantized_model(
        learning_rate=base_learning_rate,
        epochs=base_epochs,
        train_images=train_images,
        train_labels=train_labels,
    )

    print("Evaluating unquantized model.")
    evaluate_model(
        model=float_model,
        test_images=test_images,
        test_labels=test_labels,
    )

    # The quantization pass uses a much smaller learning rate and fewer
    # epochs than the initial training run.
    quantize_learning_rate = base_learning_rate / 10000
    quantize_epochs = int(base_epochs / 5)

    # Per the progress message below, quantize_model is expected to write
    # "<prefix>.tflite" and "<prefix>.npz"; that is not verifiable from here.
    prefix = pathlib.Path(output_dir) / quantized_model_prefix
    progress_message = (
        f"Training quantized model and writing {prefix}.tflite and "
        f"{prefix}.npz."
    )
    print(progress_message)
    quantized_model = quantize_model(
        model=float_model,
        learning_rate=quantize_learning_rate,
        epochs=quantize_epochs,
        train_images=train_images,
        train_labels=train_labels,
        quantized_model_prefix=prefix,
    )

    print("Evaluating quantized model.")
    evaluate_model(
        model=quantized_model,
        test_images=test_images,
        test_labels=test_labels,
    )

    # NOTE(review): the message names test_data_file while the save call is
    # given the tensor path; presumably save_mnist_data combines the two.
    # Confirm against pyrtlnet.training_util.
    print(f"Saving MNIST test data to {test_data_file}")
    save_mnist_data(
        tensor_path=output_dir,
        test_images=test_images,
        test_labels=test_labels,
    )
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()