|
- # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import unittest
-
- import numpy as np
-
- import paddle
- from paddle import base
- from paddle.base import core
- from paddle.pir_utils import test_with_pir_api
-
-
class TestTakeAPI(unittest.TestCase):
    """Compare ``paddle.take`` with ``np.take`` in static and dynamic mode.

    Subclasses customise a case by overriding ``set_mode``, ``set_dtype`` or
    ``set_input``; ``setUp`` calls the three hooks in that order.
    """

    def set_mode(self):
        """Choose the ``mode`` argument forwarded to ``take``."""
        self.mode = 'raise'

    def set_dtype(self):
        """Choose the dtypes used for the input and for the index."""
        self.input_dtype = 'float64'
        self.index_dtype = 'int64'

    def set_input(self):
        """Create the numpy input (values 0..11) and index (values -4..1)."""
        self.input_shape = [3, 4]
        self.index_shape = [2, 3]
        self.input_np = (
            np.arange(0, 12).reshape(self.input_shape).astype(self.input_dtype)
        )
        self.index_np = (
            np.arange(-4, 2).reshape(self.index_shape).astype(self.index_dtype)
        )

    def setUp(self):
        """Populate mode, dtypes, inputs and the execution place."""
        self.set_mode()
        # set_input reads the dtypes, so set_dtype has to run before it.
        self.set_dtype()
        self.set_input()
        # Use GPU 0 when Paddle is compiled with CUDA, otherwise the CPU.
        self.place = (
            base.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else base.CPUPlace()
        )

    @test_with_pir_api
    def test_static_graph(self):
        """Build and run ``paddle.take`` in a static program."""
        paddle.enable_static()
        main_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        # program_guard expects the main program first, then the startup
        # program; the ops below are therefore recorded in main_program.
        with paddle.static.program_guard(main_program, startup_program):
            x = paddle.static.data(
                name='input', dtype=self.input_dtype, shape=self.input_shape
            )
            index = paddle.static.data(
                name='index', dtype=self.index_dtype, shape=self.index_shape
            )
            out = paddle.take(x, index, mode=self.mode)

            exe = paddle.static.Executor(self.place)
            # Run the program that was just built explicitly instead of
            # relying on whichever program is the current default.
            st_result = exe.run(
                main_program,
                feed={'input': self.input_np, 'index': self.index_np},
                fetch_list=[out],
            )
            np.testing.assert_allclose(
                st_result[0],
                np.take(self.input_np, self.index_np, mode=self.mode),
            )

    def test_dygraph(self):
        """Run ``paddle.take`` eagerly and compare with ``np.take``."""
        paddle.disable_static(self.place)
        x = paddle.to_tensor(self.input_np)
        index = paddle.to_tensor(self.index_np)
        dy_result = paddle.take(x, index, mode=self.mode)
        # assert_allclose takes (actual, desired); keep the same order as
        # the static-graph test.
        np.testing.assert_allclose(
            dy_result.numpy(),
            np.take(self.input_np, self.index_np, mode=self.mode),
        )
-
-
class TestTakeInt32(TestTakeAPI):
    """Run the ``TestTakeAPI`` cases with an int32 input and int64 index."""

    def set_dtype(self):
        """Override the dtypes: int32 input, int64 index."""
        self.input_dtype, self.index_dtype = 'int32', 'int64'
-
-
class TestTakeInt64(TestTakeAPI):
    """Run the ``TestTakeAPI`` cases with an int64 input and int64 index."""

    def set_dtype(self):
        """Override the dtypes: int64 for both input and index."""
        self.input_dtype, self.index_dtype = 'int64', 'int64'
-
-
class TestTakeFloat32(TestTakeAPI):
    """Run the ``TestTakeAPI`` cases with a float32 input and int64 index."""

    def set_dtype(self):
        """Override the dtypes: float32 input, int64 index."""
        self.input_dtype, self.index_dtype = 'float32', 'int64'
-
-
class TestTakeTypeError(TestTakeAPI):
    """Check that ``paddle.take`` rejects invalid ``index`` arguments."""

    @test_with_pir_api
    def test_static_type_error(self):
        """Static mode: a numpy array passed as ``index`` raises TypeError."""
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            data_x = paddle.static.data(
                name='input', dtype=self.input_dtype, shape=self.input_shape
            )
            with self.assertRaises(TypeError):
                paddle.take(data_x, self.index_np, self.mode)

    def test_dygraph_type_error(self):
        """Dynamic mode: a numpy array passed as ``index`` raises TypeError."""
        paddle.disable_static(self.place)
        tensor_x = paddle.to_tensor(self.input_np)
        with self.assertRaises(TypeError):
            paddle.take(tensor_x, self.index_np, self.mode)

    @test_with_pir_api
    def test_static_dtype_error(self):
        """Static mode: a float32 ``index`` variable raises TypeError."""
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            data_x = paddle.static.data(
                name='input', dtype='float64', shape=self.input_shape
            )
            float_index = paddle.static.data(
                name='index', dtype='float32', shape=self.index_shape
            )
            with self.assertRaises(TypeError):
                paddle.take(data_x, float_index, self.mode)

    def test_dygraph_dtype_error(self):
        """Dynamic mode: a float32 ``index`` tensor raises TypeError."""
        paddle.disable_static(self.place)
        tensor_x = paddle.to_tensor(self.input_np)
        float_index = paddle.to_tensor(self.index_np, dtype='float32')
        with self.assertRaises(TypeError):
            paddle.take(tensor_x, float_index, self.mode)
-
-
class TestTakeModeRaisePos(unittest.TestCase):
    """Expect ValueError when positive indices exceed the input size.

    NOTE(review): both tests call ``paddle.index_select`` rather than
    ``paddle.take``, and ``self.mode`` is never read in this class.
    Confirm that this is the intended coverage.
    """

    def set_mode(self):
        """Record the take mode for this case."""
        self.mode = 'raise'

    def set_dtype(self):
        """Record the dtypes used for the input and for the index."""
        self.input_dtype = 'float64'
        self.index_dtype = 'int64'

    def set_input(self):
        """Create a 12-element input and an index covering -10..19."""
        self.input_shape = [3, 4]
        self.index_shape = [5, 6]
        flat_input = np.arange(0, 12)
        self.input_np = flat_input.reshape(self.input_shape).astype(
            self.input_dtype
        )
        # positive indices are out of range
        flat_index = np.arange(-10, 20)
        self.index_np = flat_index.reshape(self.index_shape).astype(
            self.index_dtype
        )

    def setUp(self):
        """Populate mode, dtypes, inputs and the execution place."""
        self.set_mode()
        self.set_dtype()
        self.set_input()
        # Use GPU 0 when Paddle is compiled with CUDA, otherwise the CPU.
        if core.is_compiled_with_cuda():
            self.place = base.CUDAPlace(0)
        else:
            self.place = base.CPUPlace()

    @test_with_pir_api
    def test_static_index_error(self):
        """Static mode: ``paddle.index_select`` must raise ValueError."""
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            data_x = paddle.static.data(
                name='input', dtype=self.input_dtype, shape=self.input_shape
            )
            data_index = paddle.static.data(
                name='index', dtype=self.index_dtype, shape=self.index_shape
            )
            with self.assertRaises(ValueError):
                paddle.index_select(data_x, data_index)

    def test_dygraph_index_error(self):
        """Dynamic mode: ``paddle.index_select`` must raise ValueError."""
        paddle.disable_static(self.place)
        tensor_x = paddle.to_tensor(self.input_np)
        tensor_index = paddle.to_tensor(self.index_np, dtype=self.index_dtype)
        with self.assertRaises(ValueError):
            paddle.index_select(tensor_x, tensor_index)
-
-
class TestTakeModeRaiseNeg(TestTakeModeRaisePos):
    """Test negative index out of range error.

    Only ``set_input`` differs from ``TestTakeModeRaisePos``; the mode,
    dtypes, ``setUp`` and the test methods are inherited from it instead of
    being copied here.
    """

    def set_input(self):
        """Create a 12-element input and an index covering -20..9."""
        self.input_shape = [3, 4]
        self.index_shape = [5, 6]
        self.input_np = (
            np.arange(0, 12).reshape(self.input_shape).astype(self.input_dtype)
        )
        self.index_np = (
            np.arange(-20, 10)
            .reshape(self.index_shape)
            .astype(self.index_dtype)
        )  # negative indices are out of range
-
-
class TestTakeModeWrap(TestTakeAPI):
    """Run the ``TestTakeAPI`` cases with ``mode='wrap'``."""

    def set_mode(self):
        """Select the 'wrap' mode."""
        self.mode = 'wrap'

    def set_input(self):
        """Create a 12-element input and an index covering -20..19."""
        self.input_shape = [3, 4]
        self.index_shape = [5, 8]
        flat_input = np.arange(0, 12)
        self.input_np = flat_input.reshape(self.input_shape).astype(
            self.input_dtype
        )
        # Both ends of the index are out of bounds
        flat_index = np.arange(-20, 20)
        self.index_np = flat_index.reshape(self.index_shape).astype(
            self.index_dtype
        )
-
-
class TestTakeModeClip(TestTakeAPI):
    """Run the ``TestTakeAPI`` cases with ``mode='clip'``."""

    def set_mode(self):
        """Select the 'clip' mode."""
        self.mode = 'clip'

    def set_input(self):
        """Create a 12-element input and an index covering -20..19."""
        self.input_shape = [3, 4]
        self.index_shape = [5, 8]
        flat_input = np.arange(0, 12)
        self.input_np = flat_input.reshape(self.input_shape).astype(
            self.input_dtype
        )
        # Both ends of the index are out of bounds
        flat_index = np.arange(-20, 20)
        self.index_np = flat_index.reshape(self.index_shape).astype(
            self.index_dtype
        )
-
-
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|