|
- import os
- import pytest
- import numpy as np
- import mindspore as ms
-
# Raw value of the TEST_ERROR environment variable (None when it is unset).
_TEST_ERROR = os.environ.get('TEST_ERROR')

# Device-based skipping stays enabled unless TEST_ERROR is "TRUE" or "1"
# (compared case-insensitively); an unset or empty variable keeps it enabled.
_skip_test_error = not (_TEST_ERROR and _TEST_ERROR.upper() in ('TRUE', '1'))
-
def is_test_under_cpu_context():
    """Tell whether device-based skipping is on and the device target is CPU.

    Returns:
        False when ``_skip_test_error`` is off; otherwise whether the
        MindSpore context's ``device_target`` equals ``'CPU'`` after
        upper-casing.
    """
    if not _skip_test_error:
        return False
    device = ms.context.get_context("device_target")
    return device.upper() == 'CPU'
-
def is_test_under_gpu_context():
    """Tell whether device-based skipping is on and the device target is GPU.

    Returns:
        False when ``_skip_test_error`` is off; otherwise whether the
        MindSpore context's ``device_target`` equals ``'GPU'`` after
        upper-casing.
    """
    if not _skip_test_error:
        return False
    device = ms.context.get_context("device_target")
    return device.upper() == 'GPU'
-
def is_test_under_ascend_context():
    """Tell whether device-based skipping is on and the device target is Ascend.

    Returns:
        False when ``_skip_test_error`` is off; otherwise whether the
        MindSpore context's ``device_target`` equals ``'ASCEND'`` after
        upper-casing.
    """
    if not _skip_test_error:
        return False
    device = ms.context.get_context("device_target")
    return device.upper() == 'ASCEND'
-
-
def SKIP_ENV_CPU(reason):
    """Build a ``pytest.mark.skipif`` marker keyed on the CPU device check.

    The condition is computed once, when this function is called, from
    ``is_test_under_cpu_context()``.

    Args:
        reason: Text passed through as the marker's ``reason``.

    Returns:
        The ``pytest.mark.skipif`` marker.
    """
    should_skip = is_test_under_cpu_context()
    return pytest.mark.skipif(condition=should_skip, reason=reason)
-
def SKIP_ENV_GPU(reason):
    """Build a ``pytest.mark.skipif`` marker keyed on the GPU device check.

    The condition is computed once, when this function is called, from
    ``is_test_under_gpu_context()``.

    Args:
        reason: Text passed through as the marker's ``reason``.

    Returns:
        The ``pytest.mark.skipif`` marker.
    """
    should_skip = is_test_under_gpu_context()
    return pytest.mark.skipif(condition=should_skip, reason=reason)
-
def SKIP_ENV_ASCEND(reason):
    """Build a ``pytest.mark.skipif`` marker keyed on the Ascend device check.

    The condition is computed once, when this function is called, from
    ``is_test_under_ascend_context()``.

    Args:
        reason: Text passed through as the marker's ``reason``.

    Returns:
        The ``pytest.mark.skipif`` marker.
    """
    should_skip = is_test_under_ascend_context()
    return pytest.mark.skipif(condition=should_skip, reason=reason)
-
-
# Raw value of the TEST_MODE environment variable (None when it is unset).
_MODE_ENV = os.environ.get('TEST_MODE')


def set_mode_by_env_config():
    """Set the MindSpore execution mode from the TEST_MODE environment value.

    ``TEST_MODE == "1"`` selects ``ms.GRAPH_MODE``; any other value, or an
    unset variable, selects ``ms.PYNATIVE_MODE``.
    """
    chosen_mode = ms.GRAPH_MODE if _MODE_ENV == "1" else ms.PYNATIVE_MODE
    ms.context.set_context(mode=chosen_mode)
-
-
def is_test_under_graph_context():
    """Return whether the MindSpore context's ``mode`` equals ``ms.GRAPH_MODE``."""
    current_mode = ms.context.get_context("mode")
    return current_mode == ms.GRAPH_MODE
-
def is_test_under_pynative_context():
    """Return whether the MindSpore context's ``mode`` equals ``ms.PYNATIVE_MODE``."""
    current_mode = ms.context.get_context("mode")
    return current_mode == ms.PYNATIVE_MODE
-
def SKIP_ENV_GRAPH_MODE(reason):
    """Build a ``pytest.mark.skipif`` marker keyed on the graph-mode check.

    The condition is computed once, when this function is called, from
    ``is_test_under_graph_context()``.

    Args:
        reason: Text passed through as the marker's ``reason``.

    Returns:
        The ``pytest.mark.skipif`` marker.
    """
    should_skip = is_test_under_graph_context()
    return pytest.mark.skipif(condition=should_skip, reason=reason)
-
def SKIP_ENV_PYNATIVE_MODE(reason):
    """Build a ``pytest.mark.skipif`` marker keyed on the PyNative-mode check.

    The condition is computed once, when this function is called, from
    ``is_test_under_pynative_context()``.

    Args:
        reason: Text passed through as the marker's ``reason``.

    Returns:
        The ``pytest.mark.skipif`` marker.
    """
    should_skip = is_test_under_pynative_context()
    return pytest.mark.skipif(condition=should_skip, reason=reason)
-
def SKIP_ENV_ASCEND_GRAPH_MODE(reason):
    """Build a ``pytest.mark.skipif`` marker for graph mode on Ascend.

    The condition is computed once, when this function is called. The graph
    check runs first; the Ascend check is only consulted when it holds.

    Args:
        reason: Text passed through as the marker's ``reason``.

    Returns:
        The ``pytest.mark.skipif`` marker.
    """
    should_skip = is_test_under_graph_context() and is_test_under_ascend_context()
    return pytest.mark.skipif(condition=should_skip, reason=reason)
-
- def _param_compare(a1, b1, rtol=1e-5, atol=1e-8, equal_nan=False):
- a1 = a1.numpy()
- b1 = b1.numpy()
- assert np.allclose(a1, b1, rtol=rtol, atol=atol, equal_nan=equal_nan)
- assert a1.dtype == b1.dtype
- assert a1.shape == b1.shape
-
def param_compare(a1, b1, rtol=1e-5, atol=1e-8, equal_nan=False):
    """Compare two values, or two equal-length sequences element by element.

    When either argument is a tuple or list, the lengths must match and each
    pair of same-index elements is passed to ``_param_compare``. Otherwise
    the two arguments are passed to ``_param_compare`` directly.

    Args:
        a1: First value, or a tuple/list of values.
        b1: Second value, or a tuple/list of values.
        rtol: Relative tolerance forwarded to ``_param_compare``.
        atol: Absolute tolerance forwarded to ``_param_compare``.
        equal_nan: Forwarded to ``_param_compare``.

    Raises:
        AssertionError: If the lengths differ or any comparison fails.
    """
    tolerances = {"rtol": rtol, "atol": atol, "equal_nan": equal_nan}
    sequence_types = (tuple, list)

    if not (isinstance(a1, sequence_types) or isinstance(b1, sequence_types)):
        _param_compare(a1, b1, **tolerances)
        return

    assert len(a1) == len(b1)
    for idx in range(len(a1)):
        _param_compare(a1[idx], b1[idx], **tolerances)
-
def type_shape_compare(a1, b1):
    """Assert that two objects have the same dtype and shape; values are ignored.

    Args:
        a1: First object; must provide a ``.numpy()`` method.
        b1: Second object; must provide a ``.numpy()`` method.

    Raises:
        AssertionError: If the dtypes or the shapes differ.
    """
    lhs, rhs = a1.numpy(), b1.numpy()
    assert lhs.dtype == rhs.dtype
    assert lhs.shape == rhs.shape
|