|
- # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License
-
- import unittest
-
- from paddle.distributed.auto_parallel.static.process_mesh_v2 import (
- ProcessMesh,
- compute_compatible_process_mesh,
- merge_process_mesh,
- )
-
-
- class TestProcessMesh(unittest.TestCase):
- def test_process_mesh(self):
- mesh = [[0, 1, 2], [3, 4, 5]]
- mesh2 = [[0, 1], [2, 3]]
- process_mesh = ProcessMesh(mesh, dim_names=["x", "y"])
- process_mesh2 = ProcessMesh(mesh2)
- self.assertEqual(process_mesh.shape, [2, 3])
- self.assertEqual(process_mesh.process_ids, [0, 1, 2, 3, 4, 5])
- self.assertEqual(process_mesh.dim_names, ["x", "y"])
- self.assertEqual(process_mesh.size, 6)
- self.assertEqual(process_mesh.ndim, 2)
- self.assertEqual(process_mesh.dim_size(0), 2)
- self.assertEqual(process_mesh.dim_size(-1), 3)
- self.assertEqual(process_mesh.dim_size("x"), 2)
- self.assertEqual(process_mesh.dim_size("y"), 3)
- self.assertEqual(process_mesh.empty(), False)
- self.assertEqual(process_mesh.contains(0), True)
- self.assertEqual(process_mesh.contains(6), False)
- self.assertEqual(process_mesh, process_mesh)
- self.assertNotEqual(process_mesh, process_mesh2)
- self.assertEqual(str(process_mesh), str(process_mesh))
-
- def test_compute_compatible_process_mesh(self):
- process_mesh1 = ProcessMesh(
- [[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"]
- )
- compatible_process_mesh = compute_compatible_process_mesh(
- [process_mesh1, None]
- )
- self.assertEqual(compatible_process_mesh, process_mesh1)
- compatible_process_mesh = compute_compatible_process_mesh(
- [None, process_mesh1]
- )
- self.assertEqual(compatible_process_mesh, process_mesh1)
-
- process_mesh2 = ProcessMesh([[0, 1, 2], [3, 4, 5]])
- compatible_process_mesh = compute_compatible_process_mesh(
- [process_mesh1, process_mesh2]
- )
- self.assertEqual(compatible_process_mesh, process_mesh1)
- self.assertEqual(compatible_process_mesh, process_mesh2)
-
- process_mesh2 = ProcessMesh([[0, 1, 2, 3, 4, 5]])
- compatible_process_mesh = compute_compatible_process_mesh(
- [process_mesh1, process_mesh2]
- )
- self.assertEqual(compatible_process_mesh, process_mesh1)
-
- process_mesh2 = ProcessMesh([[0, 1, 2]])
- compatible_process_mesh = compute_compatible_process_mesh(
- [process_mesh1, process_mesh2]
- )
- self.assertEqual(compatible_process_mesh, process_mesh1)
-
- def test_merge_process_mesh(self):
- process_mesh1 = ProcessMesh(
- [[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"]
- )
- merged_process_mesh = merge_process_mesh([process_mesh1, None])
- self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
- merged_process_mesh = merge_process_mesh([None, process_mesh1])
- self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
-
- process_mesh2 = ProcessMesh([[0, 1, 2], [3, 4, 5]])
- merged_process_mesh = merge_process_mesh([process_mesh1, process_mesh2])
- self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
-
- process_mesh2 = ProcessMesh([[0, 1, 2]])
- merged_process_mesh = merge_process_mesh([process_mesh1, process_mesh2])
- self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
-
- process_mesh2 = ProcessMesh([[6, 7]])
- merged_process_mesh = merge_process_mesh([process_mesh1, process_mesh2])
- self.assertEqual(
- merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5, 6, 7])
- )
-
-
- if __name__ == "__main__":
- unittest.main()
|