- import argparse
- import datetime
- import gzip
- import json
- import os
-
- import numpy as np
- import tensorflow as tf
-
- # Number of workers; updated by setup_env() for multi-worker training.
- num_workers = 1
-
-
- def load_data(data_dir):
-     """Load MNIST from the four gzip'd IDX files in data_dir."""
-     files = [
-         'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
-         't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'
-     ]
-     paths = [os.path.join(data_dir, f) for f in files]
-     # Skip the IDX headers: 8 bytes for label files (magic + count),
-     # 16 bytes for image files (magic + count + rows + cols).
-     with gzip.open(paths[0], 'rb') as f:
-         y_train = np.frombuffer(f.read(), np.uint8, offset=8)
-     with gzip.open(paths[1], 'rb') as f:
-         x_train = np.frombuffer(f.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
-     with gzip.open(paths[2], 'rb') as f:
-         y_test = np.frombuffer(f.read(), np.uint8, offset=8)
-     with gzip.open(paths[3], 'rb') as f:
-         x_test = np.frombuffer(f.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)
-     return (x_train, y_train), (x_test, y_test)
-
-
- def mnist_dataset(batch_size=64, data_dir=None):
-     # Load from disk when a directory is given, otherwise download via Keras.
-     if data_dir:
-         print(f'Loading mnist data from {data_dir}')
-         (x_train, y_train), (x_test, y_test) = load_data(data_dir)
-     else:
-         print('Loading mnist data from tf.keras.datasets.mnist')
-         (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
-     # Scale pixel values from [0, 255] to [0, 1].
-     x_train, x_test = x_train / 255.0, x_test / 255.0
-
-     # The training set repeats indefinitely (so model.fit() needs
-     # steps_per_epoch); the shuffle buffer covers all 60,000 images.
-     train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
-     test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)
-
-     return train_dataset, test_dataset
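-
- # Note on the datasets above: under multi_worker_mirrored, tf.data
- # auto-shards the pipeline across workers. A sketch for pinning the policy
- # explicitly (assumes the stock tf.data.Options API; DATA sharding, since
- # in-memory tensors cannot be sharded by FILE):
- #   options = tf.data.Options()
- #   options.experimental_distribute.auto_shard_policy = \
- #       tf.data.experimental.AutoShardPolicy.DATA
- #   train_dataset = train_dataset.with_options(options)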
-
-
- def get_strategy(strategy='off'):
-     strategy = strategy.lower()
-     # Multiple nodes, each of which may have multiple GPUs. (Newer TF
-     # releases expose this as tf.distribute.MultiWorkerMirroredStrategy.)
-     if strategy == "multi_worker_mirrored":
-         return tf.distribute.experimental.MultiWorkerMirroredStrategy()
-     # Single node with multiple GPUs.
-     if strategy == "mirrored":
-         return tf.distribute.MirroredStrategy()
-     # Default strategy: single node, at most one GPU.
-     return tf.distribute.get_strategy()
-
-
- def setup_env(args):
-     tf.config.set_soft_device_placement(True)
-
-     # Enable memory growth so each process allocates GPU memory on demand
-     # instead of reserving it all up front. This must happen before the GPUs
-     # are initialized, hence the RuntimeError guard.
-     try:
-         gpus = tf.config.list_physical_devices('GPU')
-         for gpu in gpus:
-             tf.config.experimental.set_memory_growth(gpu, True)
-         logical_gpus = tf.config.list_logical_devices('GPU')
-         print(f"Detected {len(gpus)} Physical GPUs, {len(logical_gpus)} Logical GPUs")
-     except RuntimeError as e:
-         print(e)
-
-     if args.strategy == 'multi_worker_mirrored':
-         # The VK_*/VC_* variables are injected by the cluster scheduler
-         # (Volcano-style); port 20000 is assumed to be free on every worker.
-         index = int(os.environ['VK_TASK_INDEX'])
-         task_name = os.environ["VC_TASK_NAME"].upper()
-         ips = os.environ[f'VC_{task_name}_HOSTS'].split(',')
-         global num_workers
-         num_workers = len(ips)
-         ips = [f'{ip}:20000' for ip in ips]
-         os.environ["TF_CONFIG"] = json.dumps({
-             "cluster": {
-                 "worker": ips
-             },
-             "task": {"type": "worker", "index": index}
-         })
-         print('Setup env TF_CONFIG:', os.environ["TF_CONFIG"])
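-
-         # Illustrative TF_CONFIG for a two-worker job (addresses made up):
-         #   {"cluster": {"worker": ["10.0.0.1:20000", "10.0.0.2:20000"]},
-         #    "task": {"type": "worker", "index": 0}}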
-
-
- def setup_config():
-     parser = argparse.ArgumentParser(description='Train an MNIST digit classifier')
-     parser.add_argument('--dataset', help='Directory containing the MNIST dataset; if omitted, it is downloaded via tf.keras')
-     parser.add_argument('--output', default='.', help='Directory for the saved model, logs, and TensorBoard data')
-     parser.add_argument('--batch-size', default=64, type=int, help='Per-worker batch size')
-     parser.add_argument('--epochs', default=10, type=int, help='Number of epochs')
-
-     parser.add_argument('--eval', action='store_true', help='Run evaluation on the test set after training finishes')
-     parser.add_argument(
-         '--strategy',
-         default='off',
-         choices=['off', 'mirrored', 'multi_worker_mirrored'],
-         help='TensorFlow distributed training strategy'
-     )
-     args = parser.parse_args()
-     return args
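-
- # Example invocations (values illustrative):
- #   python train.py --strategy=mirrored --batch-size=128
- #   python train.py --strategy=multi_worker_mirrored --epochs=20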
-
-
- def main():
-     args = setup_config()
-     # TF2 limitation: collective ops must be configured at program startup,
-     # so TF_CONFIG has to be in the environment (setup_env) before the
-     # strategy is created.
-     setup_env(args)
-     strategy = get_strategy(args.strategy)
-
-     # Build the input pipeline with the global batch size; the strategy
-     # splits each batch across replicas.
-     global_batch_size = args.batch_size * num_workers
-     train_dataset, test_dataset = mnist_dataset(
-         batch_size=global_batch_size,
-         data_dir=args.dataset
-     )
-     print(f'global_batch_size={global_batch_size}, num_workers={num_workers}')
-
-     with strategy.scope():
-         # Build and compile the model inside the strategy scope so that its
-         # variables are created (and replicated) under the strategy.
-         model = tf.keras.models.Sequential([
-             tf.keras.layers.Flatten(input_shape=(28, 28)),
-             tf.keras.layers.Dense(128, activation='relu'),
-             tf.keras.layers.Dropout(0.2),
-             tf.keras.layers.Dense(10)
-         ])
-         model.compile(
-             optimizer='adam',
-             loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-             metrics=['accuracy']
-         )
-         model.summary()
-
-     # Log to a timestamped TensorBoard directory under --output.
-     tensorboard_dir = os.path.join(args.output, "tensorboard-" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
-     callbacks = [tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir, histogram_freq=1)]
-
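-     # A sketch for deriving steps_per_epoch from the data rather than the
-     # fixed 70 used below (assumes the standard 60,000-image training split):
-     #   steps_per_epoch = 60000 // global_batch_size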
-     # Train; train_dataset repeats forever, so steps_per_epoch bounds each epoch.
-     model.fit(train_dataset, epochs=args.epochs, steps_per_epoch=70, validation_data=test_dataset, callbacks=callbacks)
-
-     # Evaluate on the held-out test set.
-     if args.eval:
-         model.evaluate(test_dataset, verbose=2)
-
-     # Save the model in SavedModel format.
-     os.makedirs(args.output, exist_ok=True)
-     model.save(args.output)
-     print(f'Saved model to {args.output}')
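-
-     # Note: with multi_worker_mirrored every worker must call model.save();
-     # a common pattern (a sketch, not implemented here) points non-chief
-     # workers at a throwaway path:
-     #   task = json.loads(os.environ.get('TF_CONFIG', '{}')).get('task', {})
-     #   if task.get('index', 0) != 0:
-     #       save_dir = os.path.join(args.output, f'workertemp_{task["index"]}')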
-
-
- if __name__ == '__main__':
-     # Usage: python train.py --dataset=/path/to/MNIST/dataset --output=/path/to/output
-     print("TensorFlow version:", tf.__version__)
-     main()
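-
- # Hypothetical multi-worker launch on the first of two hosts (variables shown
- # as a Volcano-style scheduler would inject them):
- #   VC_TASK_NAME=worker VC_WORKER_HOSTS=10.0.0.1,10.0.0.2 VK_TASK_INDEX=0 \
- #       python train.py --strategy=multi_worker_mirrored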