Skip to content

Commit

Permalink
[FIX] Single GPU training job update (GoogleCloudPlatform#794)
Browse files Browse the repository at this point in the history
single GPU job update
  • Loading branch information
xiangshen-dk authored Sep 3, 2024
1 parent 1d8ccc3 commit b22273b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import numpy as np
import keras

strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Expand All @@ -28,7 +29,7 @@ def scale(image):

images_dir = "/data/mnist_predict/"

img_dataset = tf.keras.utils.image_dataset_from_directory(
img_dataset = keras.utils.image_dataset_from_directory(
images_dir,
image_size=(28, 28),
color_mode="grayscale",
Expand All @@ -41,13 +42,13 @@ def scale(image):

img_prediction_dataset = img_dataset.map(scale)

model_path = '/data/mnist_saved_model/'
model_path = '/data/mnist_saved_model/mnist.keras'

with strategy.scope():
replicated_model = tf.keras.models.load_model(model_path)
replicated_model = keras.models.load_model(model_path)
replicated_model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(),
metrics=['accuracy'])

predictions = replicated_model.predict(img_prediction_dataset)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow_datasets as tfds
import tensorflow as tf
import keras
import glob

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

Expand Down Expand Up @@ -45,16 +47,17 @@ def scale(image, label):
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

with strategy.scope():
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10)
model = keras.Sequential([
keras.Input(shape=(28, 28, 1)),
keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
keras.layers.MaxPooling2D(),
keras.layers.Flatten(),
keras.layers.Dense(64, activation='relu'),
keras.layers.Dense(10)
])

model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(),
model.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(),
metrics=['accuracy'])

# Define the checkpoint directory to store the checkpoints.
Expand All @@ -71,7 +74,7 @@ def decay(epoch):
return 1e-5

# Define a callback for printing the learning rate at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
class PrintLR(keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
model.optimizer.learning_rate.numpy()))
Expand All @@ -87,15 +90,25 @@ def on_epoch_end(self, epoch, logs=None):
EPOCHS = 12

model.fit(train_dataset, epochs=EPOCHS, callbacks=callbacks)
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

# Function to find the latest .h5 file
def find_latest_h5_checkpoint(checkpoint_dir):
list_of_files = glob.glob(f'{checkpoint_dir}/*.h5')
if list_of_files:
latest_file = max(list_of_files, key=os.path.getctime)
return latest_file
else:
return None

model.load_weights(find_latest_h5_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))

path = '/data/mnist_saved_model/'
path = '/data/mnist_saved_model/mnist.keras'

model.save(path, save_format='tf')
model.save(path)

print('Training finished. Model saved')

0 comments on commit b22273b

Please sign in to comment.