# Copyright 2022 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for orbit.controller."""

import os

from absl import logging
from absl.testing import parameterized

import numpy as np

from orbit import controller
from orbit import runner
from orbit import standard_runner
import orbit.utils

import tensorflow as tf


def create_model():
  x = tf.keras.layers.Input(shape=(3,), name="input")
  y = tf.keras.layers.Dense(4, name="dense")(x)
  model = tf.keras.Model(x, y)
  return model


def summaries_with_matching_keyword(keyword, summary_dir):
  """Returns summary protos matching given keyword from event file."""
  matches = []
  event_paths = tf.io.gfile.glob(os.path.join(summary_dir, "events*"))
  for event in tf.compat.v1.train.summary_iterator(event_paths[-1]):
    if event.summary is not None:
      for value in event.summary.value:
        if keyword in value.tag:
          matches.append(event.summary)
  return matches
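

# A minimal companion sketch (not exercised by the tests below): decoding the
# matched protos into floats. It assumes scalar summaries are stored in the
# `tensor` field, as `tf.summary.scalar` writes them in TF2.
def scalar_values_with_matching_keyword(keyword, summary_dir):
  """Returns float values of scalar summaries whose tag contains `keyword`."""
  values = []
  for summary in summaries_with_matching_keyword(keyword, summary_dir):
    for value in summary.value:
      if keyword in value.tag and value.HasField("tensor"):
        values.append(float(tf.make_ndarray(value.tensor)))
  return values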


def dataset_fn(ctx):
  del ctx
  inputs = np.zeros((10, 3), dtype=np.float32)
  targets = np.ones((10, 4), dtype=np.float32)
  dataset = tf.data.Dataset.from_tensor_slices((inputs, targets))
  dataset = dataset.repeat(100)
  dataset = dataset.batch(10, drop_remainder=True)
  return dataset
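

# Illustrative only (a sketch, assuming eager execution): each batch yielded
# by `dataset_fn` is an (inputs, targets) pair of shapes (10, 3) and (10, 4).
def _peek_one_batch():
  inputs, targets = next(iter(dataset_fn(ctx=None)))
  assert inputs.shape == (10, 3)
  assert targets.shape == (10, 4)
  return inputs, targets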


class TestRunner(standard_runner.StandardTrainer,
                 standard_runner.StandardEvaluator):
  """Implements the training and evaluation APIs for the test model."""

  def __init__(self, return_numpy=False):
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
    self.global_step = self.optimizer.iterations
    self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
    self.eval_loss = tf.keras.metrics.Mean("eval_loss", dtype=tf.float32)
    self.return_numpy = return_numpy
    train_dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
    eval_dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
    standard_runner.StandardTrainer.__init__(self, train_dataset)
    standard_runner.StandardEvaluator.__init__(self, eval_dataset)

  def train_step(self, iterator):

    def _replicated_step(inputs):
      """Replicated training step."""
      inputs, targets = inputs
      with tf.GradientTape() as tape:
        outputs = self.model(inputs)
        loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
      grads = tape.gradient(loss, self.model.variables)
      self.optimizer.apply_gradients(zip(grads, self.model.variables))
      self.train_loss.update_state(loss)

    self.strategy.run(_replicated_step, args=(next(iterator),))

  def train_loop_end(self):
    train_loss = self.train_loss.result()
    return {
        "loss": train_loss.numpy() if self.return_numpy else train_loss,
    }

  def build_eval_dataset(self):
    return self.strategy.distribute_datasets_from_function(dataset_fn)

  def eval_begin(self):
    self.eval_loss.reset_states()

  def eval_step(self, iterator):

    def _replicated_step(inputs):
      """Replicated evaluation step."""
      inputs, targets = inputs
      outputs = self.model(inputs)
      loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
      self.eval_loss.update_state(loss)

    self.strategy.run(_replicated_step, args=(next(iterator),))

  def eval_end(self):
    eval_loss = self.eval_loss.result()
    return {
        "eval_loss": eval_loss.numpy() if self.return_numpy else eval_loss,
    }
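

# A minimal end-to-end sketch of how `TestRunner` plugs into the Controller
# (illustration only; the real coverage lives in `ControllerTest` below). The
# `model_dir` argument is a placeholder for any writable directory.
def _example_train_and_evaluate(model_dir):
  test_runner = TestRunner()
  ctrl = controller.Controller(
      trainer=test_runner,
      evaluator=test_runner,
      global_step=test_runner.global_step,
      steps_per_loop=2,
      summary_dir=os.path.join(model_dir, "summaries/train"),
      eval_summary_dir=os.path.join(model_dir, "summaries/eval"))
  ctrl.train_and_evaluate(train_steps=10, eval_steps=2, eval_interval=6)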


class TestEvaluator(standard_runner.StandardEvaluator):
  """Implements the evaluation APIs for the test model."""

  def __init__(self):
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    eval_dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
    standard_runner.StandardEvaluator.__init__(self, eval_dataset)

  def eval_reduce(self, state, output):
    state.append(output)
    return state

  def eval_begin(self):
    return []

  def eval_step(self, iterator):

    def _replicated_step(inputs):
      """Replicated evaluation step."""
      inputs, targets = inputs
      outputs = self.model(inputs)
      loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
      return loss

    per_replica_losses = self.strategy.run(
        _replicated_step, args=(next(iterator),))
    mean_loss = self.strategy.reduce(
        tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
    return mean_loss

  def eval_end(self, outputs):
    return {
        "eval_loss": tf.reduce_mean(outputs),
    }
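

# Note on the control flow above: `StandardEvaluator` threads state through
# the loop roughly as `state = eval_begin()`, then
# `state = eval_reduce(state, step_output)` once per step, and finally
# `eval_end(state)`. So `eval_end` here receives the list of per-step mean
# losses accumulated by `eval_reduce` (a simplified description of the Orbit
# loop, not a verbatim excerpt).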


class TestEvaluatorNoOutput(runner.AbstractEvaluator):

  def evaluate(self, num_steps):
    pass


class TestEvaluatorWithNestedSummary(standard_runner.StandardEvaluator):
- """Implements the training and evaluation APIs for the test model."""

  def __init__(self):
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
    dataset2 = self.strategy.distribute_datasets_from_function(dataset_fn)
    self.loss = tf.keras.metrics.Mean("loss", dtype=tf.float32)
    self.accuracy = tf.keras.metrics.CategoricalAccuracy(
        "accuracy", dtype=tf.float32)
    self.loss2 = tf.keras.metrics.Mean("loss", dtype=tf.float32)
    self.accuracy2 = tf.keras.metrics.CategoricalAccuracy(
        "accuracy", dtype=tf.float32)
    standard_runner.StandardEvaluator.__init__(
        self, eval_dataset={
            "dataset": dataset,
            "dataset2": dataset2
        })

  def eval_step(self, iterator):

    def _replicated_step(loss, accuracy, inputs):
      """Replicated evaluation step."""
      inputs, targets = inputs
      outputs = self.model(inputs)
      loss.update_state(tf.keras.losses.MSE(targets, outputs))
      accuracy.update_state(targets, outputs)

    self.strategy.run(
        lambda inputs: _replicated_step(self.loss, self.accuracy, inputs),
        args=(next(iterator["dataset"]),))
    self.strategy.run(
        lambda inputs: _replicated_step(self.loss2, self.accuracy2, inputs),
        args=(next(iterator["dataset2"]),))

  def eval_end(self):
    return {
        "dataset": {
            "loss": self.loss.result(),
            "accuracy": self.accuracy.result()
        },
        "dataset2": {
            "loss": self.loss2.result(),
            "accuracy": self.accuracy2.result()
        },
    }
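

# The nested dict returned by `eval_end` above is what yields the per-dataset
# summary subdirectories ("dataset" and "dataset2") that
# `test_evaluate_with_nested_summaries` below checks for.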


class TestTrainerWithSummaries(standard_runner.StandardTrainer):
  """A trainer that writes summaries inside its train step, for testing."""

  def __init__(self):
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
    self.global_step = self.optimizer.iterations
    self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
    train_dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
    standard_runner.StandardTrainer.__init__(
        self,
        train_dataset,
        options=standard_runner.StandardTrainerOptions(
            use_tpu_summary_optimization=True))

  def build_train_dataset(self):
    return self.strategy.distribute_datasets_from_function(dataset_fn)

  def train_step(self, iterator):

    def _replicated_step(inputs):
      """Replicated training step."""
      inputs, targets = inputs
      with tf.GradientTape() as tape:
        outputs = self.model(inputs)
        loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
        tf.summary.scalar("loss", loss)
      grads = tape.gradient(loss, self.model.variables)
      self.optimizer.apply_gradients(zip(grads, self.model.variables))
      self.train_loss.update_state(loss)

    self.strategy.run(_replicated_step, args=(next(iterator),))


class ControllerTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.model_dir = self.get_temp_dir()

  def test_no_checkpoint(self):
    test_runner = TestRunner()
    # No checkpoint manager and no strategy.
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)
    # Loss values should be written into summaries.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
    # Without a checkpoint manager nothing is restored, so after resetting
    # the global step, training runs the full 10 steps again.
    test_runner.global_step.assign(0)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)

  def test_no_checkpoint_and_summaries(self):
    test_runner = TestRunner()
    # No checkpoint manager and no summary directories.
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)

  def test_has_checkpoint_no_summaries(self):
    test_runner = TestRunner()
    # Has checkpoint, but no summary directories.
    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        steps_per_loop=2)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)

    # No summaries are saved.
    self.assertEmpty(tf.io.gfile.glob(
        os.path.join(checkpoint_manager.directory, "events.*")))

  def test_has_checkpoint_eval_summary_only(self):
    test_runner = TestRunner()
    # Has a checkpoint manager and an eval summary directory, but no train
    # summary directory.
    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
        steps_per_loop=2)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)

    # Training summaries are not saved.
    self.assertEmpty(tf.io.gfile.glob(
        os.path.join(checkpoint_manager.directory, "events.*")))
    # Evaluation summaries are saved.
    self.assertNotEmpty(tf.io.gfile.glob(
        os.path.join(self.model_dir, "summaries/eval/events.*")))

  def test_restore_from_most_recent_checkpoint(self):
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=5)
    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
        steps_per_loop=5)
    test_controller.train(20)
    self.assertLen(checkpoint_manager.checkpoints, 4)
    restored_path = test_controller.restore_checkpoint()
    self.assertEqual(restored_path, checkpoint_manager.checkpoints[-1])

  @parameterized.named_parameters(("return_numpy", True),
                                  ("return_tensor", False))
  def test_train_and_evaluate(self, return_numpy):
    test_runner = TestRunner(return_numpy=return_numpy)

    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)

    # Checkpoints are saved.
    self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))

    # Loss values should be written into summaries.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))

  def test_train_only(self):
    test_runner = TestRunner()

    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)
    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
    )
    test_controller.train(steps=10)

    # Checkpoints are saved.
    self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))

    # Only train summaries are written.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertFalse(
        tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))

  def test_evaluate_only(self):
    test_runner = TestRunner()

    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint.save(os.path.join(self.model_dir, "ckpt"))
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    eval_results = test_controller.evaluate(steps=2)

    # Only eval summaries are written.
    self.assertFalse(
        tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
    self.assertIn("eval_loss", eval_results)

    # Tests continuous eval with timeout and timeout_fn.
    done_file = os.path.join(self.model_dir, "summaries/eval/Done")

    def timeout_fn():
      with tf.io.gfile.GFile(done_file, "w") as f:
        f.write("DONE")
      return True
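
    # Returning True from `timeout_fn` tells `evaluate_continuously` to stop
    # waiting for new checkpoints once `timeout` seconds pass without one (our
    # reading of the `tf.train.checkpoints_iterator`-style semantics).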

    test_controller = controller.Controller(
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.evaluate_continuously(
        timeout=1, timeout_fn=timeout_fn, steps=2)
    self.assertNotEmpty(tf.io.gfile.glob(done_file))

  def test_no_eval_steps(self):
    test_runner = TestRunner()

    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint.save(os.path.join(self.model_dir, "ckpt"))
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager)
    test_controller.evaluate()

  def test_already_trained_model(self):
    test_runner = TestRunner()
    test_runner.global_step.assign(10)

    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)
    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        checkpoint_manager=checkpoint_manager)
    # `global_step` already equals `train_steps`, so no further steps run.
    test_controller.train(steps=10)

  def test_summaries_inside_train_fn(self):
    test_runner = TestTrainerWithSummaries()

    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        summary_interval=2,
        checkpoint_manager=checkpoint_manager,
    )
    test_controller.train(steps=10)

    # No checkpoints are saved, since no checkpoint_interval was set.
    self.assertEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))

    # Only train summaries are written.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertFalse(
        tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))

  def test_train_and_evaluate_with_same_summary_dir(self):
    test_runner = TestRunner()

    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries"),
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries"))
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)

    # Train and eval loss values should be written into the shared summary dir.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "summaries")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries")))

  def test_early_stop_on_eval_loss(self):
    test_runner = TestRunner()

    class EarlyStopController(controller.Controller):
      """A subclass of Controller that supports early stopping."""

      def train_and_evaluate(self,
                             train_steps: int = None,
                             eval_steps: int = None,
                             eval_interval: int = None):
        while self.global_step.numpy() < train_steps:
          interval = min(train_steps - self.global_step.numpy(), eval_interval)
          num_steps = self.global_step.numpy() + interval
          self.train(steps=num_steps, checkpoint_at_completion=False)
          self.evaluate(steps=eval_steps)
          # Early stop condition.
          if test_runner.eval_loss.result() < 0.1:
            logging.info(
                "Training early stopped as eval_loss %s is less than 0.1",
                test_runner.eval_loss.result())
            return

    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)
    test_controller = EarlyStopController(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        checkpoint_manager=checkpoint_manager)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=6, eval_interval=2)

    self.assertLess(test_runner.global_step, 10)

  def test_evaluate_with_loss_output(self):
    test_evaluator = TestEvaluator()

    checkpoint = tf.train.Checkpoint(model=test_evaluator.model)
    checkpoint.save(os.path.join(self.model_dir, "ckpt"))
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint, self.model_dir, max_to_keep=None)
    test_controller = controller.Controller(
        evaluator=test_evaluator,
        global_step=tf.Variable(0, dtype=tf.int64),
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.evaluate(steps=5)

    # Only eval summaries are written.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))

  def test_evaluate_with_no_output(self):
    test_controller = controller.Controller(
        evaluator=TestEvaluatorNoOutput(),
        global_step=tf.Variable(0, dtype=tf.int64),
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    self.assertEqual(test_controller.evaluate(steps=5), {})

  def test_train_and_evaluate_reset_datasets(self):
    test_runner = TestRunner()

    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2)

    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)

    train_dataset = (
        test_runner.strategy.distribute_datasets_from_function(dataset_fn))
    eval_dataset = (
        test_runner.strategy.distribute_datasets_from_function(dataset_fn))
    test_runner.train_dataset = train_dataset
    test_runner.eval_dataset = eval_dataset

    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)

  def test_eval_and_checkpoint_interval(self):
    test_runner = TestRunner()

    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=5)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=10,
        checkpoint_manager=checkpoint_manager,
        summary_dir=self.model_dir)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=5)

    # Expect 2 checkpoints to be saved, at steps 5 and 10.
    self.assertLen(
        tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt-*.data*")), 2)
    # Expect evaluation to run twice, at steps 5 and 10.
    self.assertLen(
        summaries_with_matching_keyword("eval_loss", self.model_dir), 2)

  @parameterized.named_parameters(("DefaultSummary", False),
                                  ("InjectSummary", True))
  def test_evaluate_with_nested_summaries(self, inject_summary_manager):
    test_evaluator = TestEvaluatorWithNestedSummary()
    if inject_summary_manager:
      summary_manager = orbit.utils.SummaryManager(
          self.model_dir,
          tf.summary.scalar,
          global_step=tf.Variable(0, dtype=tf.int64))
    else:
      summary_manager = None
    test_controller = controller.Controller(
        evaluator=test_evaluator,
        global_step=tf.Variable(0, dtype=tf.int64),
        eval_summary_dir=self.model_dir,
        summary_manager=summary_manager)
    test_controller.evaluate(steps=5)

    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "dataset")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "dataset")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "accuracy", os.path.join(self.model_dir, "dataset")))

    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "dataset2")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "dataset2")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "accuracy", os.path.join(self.model_dir, "dataset2")))

  def test_actions(self):
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)

    class OutputRecorderAction:
      """Simple `Action` that just saves the outputs passed to `__call__`."""

      def __init__(self):
        self.outputs = []

      def __call__(self, output):
        self.outputs.append(output)

    train_output_recorder = OutputRecorderAction()
    eval_output_recorder = OutputRecorderAction()

    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        train_actions=[train_output_recorder],
        eval_actions=[eval_output_recorder],
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)

    self.assertLen(train_output_recorder.outputs, 5)
    for output in train_output_recorder.outputs:
      self.assertIn("loss", output)
      self.assertGreaterEqual(output["loss"], 0)

    self.assertLen(eval_output_recorder.outputs, 2)
    for output in eval_output_recorder.outputs:
      self.assertIn("eval_loss", output)
      self.assertGreaterEqual(output["eval_loss"], 0)

  def test_step_per_loop_callable(self):
    test_runner = TestRunner()

    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)

    def steps_per_loop_fn(global_step):
      if global_step > 4:
        return 4
      return 2
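
    # With `steps=10`, this schedule should yield loops of 2, 2, 2, and then 4
    # steps (an inferred trace; the assertion below only checks the final
    # step count).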

    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=steps_per_loop_fn,
        checkpoint_manager=checkpoint_manager,
    )
    test_controller.train(steps=10)
    self.assertEqual(test_runner.global_step, 10)


if __name__ == "__main__":
  tf.test.main()