# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import config
import numpy as np
import utils

import paddle
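
# Each test below runs the same computation twice in static-graph mode, once
# with primitive (prim) operators enabled and once with them disabled, and
# checks that the results agree within the dtype-specific tolerances from
# config.TOLERANCE.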
@utils.place(config.DEVICES)
@utils.parameterize(
    (utils.TEST_CASE_NAME, 'fun', 'args', 'dtype'),
    (
        ('unary_float32', paddle.tanh, (np.random.rand(2, 3),), 'float32'),
        (
            'binary_float32',
            paddle.matmul,
            (np.random.rand(2, 3), np.random.rand(3, 2)),
            'float32',
        ),
        ('unary_float64', paddle.tanh, (np.random.rand(2, 3),), 'float64'),
        (
            'binary_float64',
            paddle.matmul,
            (np.random.rand(2, 3), np.random.rand(3, 2)),
            'float64',
        ),
    ),
)
class TestJacobianPrim(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.args = [arg.astype(cls.dtype) for arg in cls.args]
        cls._rtol = (
            config.TOLERANCE.get(cls.dtype).get('first_order_grad').get('rtol')
        )
        cls._atol = (
            config.TOLERANCE.get(cls.dtype).get('first_order_grad').get('atol')
        )

    def setUp(self):
        paddle.enable_static()
        paddle.incubate.autograd.enable_prim()

    def tearDown(self):
        paddle.incubate.autograd.disable_prim()
        paddle.disable_static()

    def test_jacobian_prim(self):
        def wrapper(fun, args):
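            # Build a static program that computes the Jacobian of `fun`,
            # lower primitive ops back to original ops if prim mode is on,
            # then execute the program and return the result as an ndarray.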
            mp = paddle.static.Program()
            sp = paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                static_args = [
                    paddle.static.data(f'arg{i}', arg.shape, self.dtype)
                    for i, arg in enumerate(args)
                ]
                for arg in static_args:
                    arg.stop_gradient = False
                jac = paddle.incubate.autograd.Jacobian(fun, static_args)[:]
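                # Under prim mode the graph holds primitive operators that
                # the executor cannot run directly; prim2orig lowers them
                # back to original operators.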
                if paddle.incubate.autograd.prim_enabled():
                    paddle.incubate.autograd.prim2orig()
            exe = paddle.static.Executor()
            exe.run(sp)
            [jac] = exe.run(
                mp,
                feed={f'arg{i}': arg for i, arg in enumerate(args)},
                fetch_list=[jac],
            )
            return jac
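
        # Evaluate once with prim lowering enabled and once with it disabled;
        # the two Jacobians must agree within tolerance.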
        paddle.incubate.autograd.enable_prim()
        prim_jac = wrapper(self.fun, self.args)
        paddle.incubate.autograd.disable_prim()
        orig_jac = wrapper(self.fun, self.args)

        np.testing.assert_allclose(
            orig_jac, prim_jac, rtol=self._rtol, atol=self._atol
        )


@utils.place(config.DEVICES)
@utils.parameterize(
    (utils.TEST_CASE_NAME, 'fun', 'args', 'dtype'),
    (
        ('unary_float32', paddle.tanh, (np.random.rand(1),), 'float32'),
        (
            'binary_float32',
            paddle.multiply,
            (np.random.rand(1), np.random.rand(1)),
            'float32',
        ),
        ('unary_float64', paddle.tanh, (np.random.rand(1),), 'float64'),
        (
            'binary_float64',
            paddle.multiply,
            (np.random.rand(1), np.random.rand(1)),
            'float64',
        ),
    ),
)
class TestHessianPrim(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.args = [arg.astype(cls.dtype) for arg in cls.args]
        cls._rtol = (
            config.TOLERANCE.get(cls.dtype).get('second_order_grad').get('rtol')
        )
        cls._atol = (
            config.TOLERANCE.get(cls.dtype).get('second_order_grad').get('atol')
        )

    def setUp(self):
        paddle.enable_static()
        paddle.incubate.autograd.enable_prim()

    def tearDown(self):
        paddle.incubate.autograd.disable_prim()
        paddle.disable_static()

    def test_hessian_prim(self):
        def wrapper(fun, args):
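            # Note: Hessian expects an effectively scalar (single-element)
            # function output, which is why the parameterized inputs above
            # have shape (1,).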
            mp = paddle.static.Program()
            sp = paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                static_args = [
                    paddle.static.data(f'arg{i}', arg.shape, self.dtype)
                    for i, arg in enumerate(args)
                ]
                for arg in static_args:
                    arg.stop_gradient = False
                hessian = paddle.incubate.autograd.Hessian(fun, static_args)[:]
                if paddle.incubate.autograd.prim_enabled():
                    paddle.incubate.autograd.prim2orig()
            exe = paddle.static.Executor()
            exe.run(sp)
            [hessian] = exe.run(
                mp,
                feed={f'arg{i}': arg for i, arg in enumerate(args)},
                fetch_list=[hessian],
            )
            return hessian

        paddle.incubate.autograd.enable_prim()
        prim_hessian = wrapper(self.fun, self.args)
        paddle.incubate.autograd.disable_prim()
        orig_hessian = wrapper(self.fun, self.args)

        np.testing.assert_allclose(
            orig_hessian, prim_hessian, rtol=self._rtol, atol=self._atol
        )


@utils.place(config.DEVICES)
@utils.parameterize(
    (utils.TEST_CASE_NAME, 'fun', 'args', 'dtype'),
    (
        ('unary_float32', paddle.tanh, (np.random.rand(2, 3),), 'float32'),
        (
            'binary_float32',
            paddle.matmul,
            (np.random.rand(2, 3), np.random.rand(3, 2)),
            'float32',
        ),
        ('unary_float64', paddle.tanh, (np.random.rand(2, 3),), 'float64'),
        (
            'binary_float64',
            paddle.matmul,
            (np.random.rand(2, 3), np.random.rand(3, 2)),
            'float64',
        ),
    ),
)
class TestJvpPrim(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.args = [arg.astype(cls.dtype) for arg in cls.args]
        cls._rtol = (
            config.TOLERANCE.get(cls.dtype).get('first_order_grad').get('rtol')
        )
        cls._atol = (
            config.TOLERANCE.get(cls.dtype).get('first_order_grad').get('atol')
        )

    def setUp(self):
        paddle.enable_static()
        paddle.incubate.autograd.enable_prim()

    def tearDown(self):
        paddle.incubate.autograd.disable_prim()
        paddle.disable_static()

    def test_jvp_prim(self):
        def wrapper(fun, args):
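            # jvp returns an (outputs, jvp_result) pair; only the
            # Jacobian-vector product, computed with the default tangents,
            # is checked here.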
            mp = paddle.static.Program()
            sp = paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                static_args = [
                    paddle.static.data(f'arg{i}', arg.shape, self.dtype)
                    for i, arg in enumerate(args)
                ]
                for arg in static_args:
                    arg.stop_gradient = False
                _, jvp_res = paddle.incubate.autograd.jvp(fun, static_args)
                if paddle.incubate.autograd.prim_enabled():
                    paddle.incubate.autograd.prim2orig()
            exe = paddle.static.Executor()
            exe.run(sp)
            jvp_res = exe.run(
                mp,
                feed={f'arg{i}': arg for i, arg in enumerate(args)},
                fetch_list=[jvp_res],
            )
            return jvp_res

        paddle.incubate.autograd.enable_prim()
        prim_jvp = wrapper(self.fun, self.args)
        paddle.incubate.autograd.disable_prim()
        orig_jvp = wrapper(self.fun, self.args)

        np.testing.assert_allclose(
            orig_jvp, prim_jvp, rtol=self._rtol, atol=self._atol
        )


@utils.place(config.DEVICES)
@utils.parameterize(
    (utils.TEST_CASE_NAME, 'fun', 'args', 'dtype'),
    (
        ('unary_float32', paddle.tanh, (np.random.rand(2, 3),), 'float32'),
        (
            'binary_float32',
            paddle.matmul,
            (np.random.rand(2, 3), np.random.rand(3, 2)),
            'float32',
        ),
        ('unary_float64', paddle.tanh, (np.random.rand(2, 3),), 'float64'),
        (
            'binary_float64',
            paddle.matmul,
            (np.random.rand(2, 3), np.random.rand(3, 2)),
            'float64',
        ),
    ),
)
class TestVjpPrim(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.args = [arg.astype(cls.dtype) for arg in cls.args]
        cls._rtol = (
            config.TOLERANCE.get(cls.dtype).get('first_order_grad').get('rtol')
        )
        cls._atol = (
            config.TOLERANCE.get(cls.dtype).get('first_order_grad').get('atol')
        )

    def setUp(self):
        paddle.enable_static()
        paddle.incubate.autograd.enable_prim()

    def tearDown(self):
        paddle.incubate.autograd.disable_prim()
        paddle.disable_static()

    def test_vjp_prim(self):
        def wrapper(fun, args):
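            # vjp returns an (outputs, gradients) pair; only the gradients,
            # computed with the default cotangents, are checked here.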
            mp = paddle.static.Program()
            sp = paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                static_args = [
                    paddle.static.data(f'arg{i}', arg.shape, self.dtype)
                    for i, arg in enumerate(args)
                ]
                for arg in static_args:
                    arg.stop_gradient = False
                _, vjp_res = paddle.incubate.autograd.vjp(fun, static_args)
                if paddle.incubate.autograd.prim_enabled():
                    paddle.incubate.autograd.prim2orig()
            exe = paddle.static.Executor()
            exe.run(sp)
            vjp_res = exe.run(
                mp,
                feed={f'arg{i}': arg for i, arg in enumerate(args)},
                fetch_list=[vjp_res],
            )
            return vjp_res

        paddle.incubate.autograd.enable_prim()
        prim_vjp = wrapper(self.fun, self.args)
        paddle.incubate.autograd.disable_prim()
        orig_vjp = wrapper(self.fun, self.args)
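
        # vjp produces one gradient per input, so compare them pairwise.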
        for orig, prim in zip(orig_vjp, prim_vjp):
            np.testing.assert_allclose(
                orig, prim, rtol=self._rtol, atol=self._atol
            )


if __name__ == "__main__":
    unittest.main()