jpayne@69
|
1 # Pytest customization
|
jpayne@69
|
2 import os
|
jpayne@69
|
3 import pytest
|
jpayne@69
|
4 import warnings
|
jpayne@69
|
5
|
jpayne@69
|
6 import numpy as np
|
jpayne@69
|
7 import numpy.testing as npt
|
jpayne@69
|
8 from scipy._lib._fpumode import get_fpu_mode
|
jpayne@69
|
9 from scipy._lib._testutils import FPUModeChangeWarning
|
jpayne@69
|
10 from scipy._lib import _pep440
|
jpayne@69
|
11
|
jpayne@69
|
12
|
jpayne@69
|
def pytest_configure(config):
    """Register the custom markers used by this test suite.

    Adds the ``slow``, ``xslow`` and ``xfail_on_32bit`` descriptions to the
    ``markers`` ini option of ``config``, in that order. A ``timeout``
    description is added afterwards, but only when importing
    ``pytest_timeout`` fails for any reason.

    Parameters
    ----------
    config : object
        The pytest configuration object; only its ``addinivalue_line``
        method is used here.
    """
    marker_descriptions = (
        "slow: Tests that are very slow.",
        "xslow: mark test as extremely slow (not run unless explicitly requested)",
        "xfail_on_32bit: mark test as failing on 32-bit platforms",
    )
    for description in marker_descriptions:
        config.addinivalue_line("markers", description)

    try:
        import pytest_timeout  # noqa:F401
    except Exception:
        # Plugin could not be imported: declare the marker ourselves
        # (presumably the plugin declares it when present -- confirm).
        config.addinivalue_line(
            "markers", 'timeout: mark a test for a non-default timeout')
|
jpayne@69
|
25
|
jpayne@69
|
26
|
jpayne@69
|
27 def _get_mark(item, name):
|
jpayne@69
|
28 if _pep440.parse(pytest.__version__) >= _pep440.Version("3.6.0"):
|
jpayne@69
|
29 mark = item.get_closest_marker(name)
|
jpayne@69
|
30 else:
|
jpayne@69
|
31 mark = item.get_marker(name)
|
jpayne@69
|
32 return mark
|
jpayne@69
|
33
|
jpayne@69
|
34
|
jpayne@69
|
def pytest_runtest_setup(item):
    """Per-test setup hook: honour custom markers and cap BLAS threads.

    * Tests marked ``xslow`` are skipped unless the ``SCIPY_XSLOW``
      environment variable parses as a non-zero integer.
    * Tests marked ``xfail_on_32bit`` are xfailed when ``np.intp`` is
      narrower than 8 bytes; the marker's first argument is used as the
      reason.
    * If ``threadpoolctl`` can be imported, ``PYTEST_XDIST_WORKER_COUNT`` is
      set to a valid integer and ``OMP_NUM_THREADS`` is unset or empty, the
      BLAS thread pool is limited so that the workers together use roughly
      half of the reported CPUs. This step is best-effort: any failure in it
      is swallowed rather than failing test setup.

    Parameters
    ----------
    item : object
        The test item about to be run.
    """
    mark = _get_mark(item, "xslow")
    if mark is not None:
        try:
            v = int(os.environ.get('SCIPY_XSLOW', '0'))
        except ValueError:
            # Non-integer value: treat as "not requested".
            v = False
        if not v:
            pytest.skip("very slow test; set environment variable SCIPY_XSLOW=1 to run it")
    mark = _get_mark(item, 'xfail_on_32bit')
    if mark is not None and np.intp(0).itemsize < 8:
        pytest.xfail('Fails on our 32-bit test platform(s): %s' % (mark.args[0],))

    # Older versions of threadpoolctl have an issue that may lead to this
    # warning being emitted, see gh-14441
    with npt.suppress_warnings() as sup:
        sup.filter(pytest.PytestUnraisableExceptionWarning)

        try:
            from threadpoolctl import threadpool_limits

            HAS_THREADPOOLCTL = True
        except Exception:  # observed in gh-14441: (ImportError, AttributeError)
            # Optional dependency only. All exceptions are caught, for robustness
            HAS_THREADPOOLCTL = False

        if HAS_THREADPOOLCTL:
            # Set the number of openmp threads based on the number of workers
            # xdist is using to prevent oversubscription. Simplified version of what
            # sklearn does (it can rely on threadpoolctl and its builtin OpenMP helper
            # functions)
            try:
                xdist_worker_count = int(os.environ['PYTEST_XDIST_WORKER_COUNT'])
            except (KeyError, ValueError):
                # KeyError: variable not set (e.g. pytest-xdist is not
                # installed). ValueError: value is not an integer. Either
                # way there is nothing sensible to divide by, so give up.
                return

            if not os.getenv('OMP_NUM_THREADS'):
                # os.cpu_count() returns None when the count cannot be
                # determined; fall back to 1 instead of raising TypeError.
                cpu_total = os.cpu_count() or 1
                max_openmp_threads = cpu_total // 2  # use nr of physical cores
                # Guard against a zero/negative worker count so the
                # division below cannot raise.
                worker_count = max(xdist_worker_count, 1)
                threads_per_worker = max(max_openmp_threads // worker_count, 1)
                try:
                    threadpool_limits(threads_per_worker, user_api='blas')
                except Exception:
                    # May raise AttributeError for older versions of OpenBLAS.
                    # Catch any error for robustness.
                    return
|
jpayne@69
|
81
|
jpayne@69
|
82
|
jpayne@69
|
@pytest.fixture(scope="function", autouse=True)
def check_fpu_mode(request):
    """
    Warn when a test leaves the FPU mode different from how it found it.

    The mode reported by ``get_fpu_mode()`` is sampled before and after the
    test body runs; if the two values differ, an ``FPUModeChangeWarning``
    carrying both values in hexadecimal is issued.
    """
    mode_before = get_fpu_mode()
    yield
    mode_after = get_fpu_mode()

    if mode_after == mode_before:
        # Nothing changed: stay silent.
        return

    message = (
        "FPU mode changed from {before:#x} to {after:#x} during the test"
    ).format(before=mode_before, after=mode_after)
    warnings.warn(message, category=FPUModeChangeWarning, stacklevel=0)
|