aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRobert Collins <rbtcollins@hp.com>2014-11-05 03:09:01 +1300
committerRobert Collins <rbtcollins@hp.com>2014-11-05 03:09:01 +1300
commitbf2bda3c9704181cebb6163f5eacd5ad4e1c15f4 (patch)
tree1f2e15056fc7e7bd965dd5b043183bc575d28f62 /Lib/unittest/loader.py
parentIssue #22773: fix failing test with old readline versions due to issue #19884. (diff)
downloadcpython-bf2bda3c9704181cebb6163f5eacd5ad4e1c15f4.tar.gz
cpython-bf2bda3c9704181cebb6163f5eacd5ad4e1c15f4.tar.bz2
cpython-bf2bda3c9704181cebb6163f5eacd5ad4e1c15f4.zip
Close #22457: Honour load_tests in the start_dir of discovery.
We were not honouring load_tests in a package/__init__.py when that was the start_dir parameter, though we do when it is a child package. The fix required a little care since it introduces the possibility of infinite recursion.
Diffstat (limited to 'Lib/unittest/loader.py')
-rw-r--r--Lib/unittest/loader.py162
1 files changed, 105 insertions, 57 deletions
diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py
index 811bedf928a..8c10ad1f41e 100644
--- a/Lib/unittest/loader.py
+++ b/Lib/unittest/loader.py
@@ -65,6 +65,9 @@ class TestLoader(object):
def __init__(self):
super(TestLoader, self).__init__()
self.errors = []
+ # Tracks packages which we have called into via load_tests, to
+ # avoid infinite re-entrancy.
+ self._loading_packages = set()
def loadTestsFromTestCase(self, testCaseClass):
"""Return a suite of all tests cases contained in testCaseClass"""
@@ -229,9 +232,13 @@ class TestLoader(object):
If a test package name (directory with '__init__.py') matches the
pattern then the package will be checked for a 'load_tests' function. If
- this exists then it will be called with loader, tests, pattern.
+ this exists then it will be called with (loader, tests, pattern) unless
+ the package has already had load_tests called from the same discovery
+ invocation, in which case the package module object is not scanned for
+ tests - this ensures that when a package uses discover to further
+ discover child tests that infinite recursion does not happen.
- If load_tests exists then discovery does *not* recurse into the package,
+ If load_tests exists then discovery does *not* recurse into the package,
load_tests is responsible for loading all tests in the package.
The pattern is deliberately not stored as a loader attribute so that
@@ -355,69 +362,110 @@ class TestLoader(object):
def _find_tests(self, start_dir, pattern, namespace=False):
"""Used by discovery. Yields test suites it loads."""
+ # Handle the __init__ in this package
+ name = self._get_name_from_path(start_dir)
+ # name is '.' when start_dir == top_level_dir (and top_level_dir is by
+ # definition not a package).
+ if name != '.' and name not in self._loading_packages:
+ # name is in self._loading_packages while we have called into
+ # loadTestsFromModule with name.
+ tests, should_recurse = self._find_test_path(
+ start_dir, pattern, namespace)
+ if tests is not None:
+ yield tests
+ if not should_recurse:
+ # Either an error occured, or load_tests was used by the
+ # package.
+ return
+ # Handle the contents.
paths = sorted(os.listdir(start_dir))
-
for path in paths:
full_path = os.path.join(start_dir, path)
- if os.path.isfile(full_path):
- if not VALID_MODULE_NAME.match(path):
- # valid Python identifiers only
- continue
- if not self._match_path(path, full_path, pattern):
- continue
- # if the test file matches, load it
+ tests, should_recurse = self._find_test_path(
+ full_path, pattern, namespace)
+ if tests is not None:
+ yield tests
+ if should_recurse:
+ # we found a package that didn't use load_tests.
name = self._get_name_from_path(full_path)
+ self._loading_packages.add(name)
try:
- module = self._get_module_from_name(name)
- except case.SkipTest as e:
- yield _make_skipped_test(name, e, self.suiteClass)
- except:
- error_case, error_message = \
- _make_failed_import_test(name, self.suiteClass)
- self.errors.append(error_message)
- yield error_case
- else:
- mod_file = os.path.abspath(getattr(module, '__file__', full_path))
- realpath = _jython_aware_splitext(os.path.realpath(mod_file))
- fullpath_noext = _jython_aware_splitext(os.path.realpath(full_path))
- if realpath.lower() != fullpath_noext.lower():
- module_dir = os.path.dirname(realpath)
- mod_name = _jython_aware_splitext(os.path.basename(full_path))
- expected_dir = os.path.dirname(full_path)
- msg = ("%r module incorrectly imported from %r. Expected %r. "
- "Is this module globally installed?")
- raise ImportError(msg % (mod_name, module_dir, expected_dir))
- yield self.loadTestsFromModule(module, pattern=pattern)
- elif os.path.isdir(full_path):
- if (not namespace and
- not os.path.isfile(os.path.join(full_path, '__init__.py'))):
- continue
-
- load_tests = None
- tests = None
- name = self._get_name_from_path(full_path)
+ yield from self._find_tests(full_path, pattern, namespace)
+ finally:
+ self._loading_packages.discard(name)
+
+ def _find_test_path(self, full_path, pattern, namespace=False):
+ """Used by discovery.
+
+ Loads tests from a single file, or a directories' __init__.py when
+ passed the directory.
+
+ Returns a tuple (None_or_tests_from_file, should_recurse).
+ """
+ basename = os.path.basename(full_path)
+ if os.path.isfile(full_path):
+ if not VALID_MODULE_NAME.match(basename):
+ # valid Python identifiers only
+ return None, False
+ if not self._match_path(basename, full_path, pattern):
+ return None, False
+ # if the test file matches, load it
+ name = self._get_name_from_path(full_path)
+ try:
+ module = self._get_module_from_name(name)
+ except case.SkipTest as e:
+ return _make_skipped_test(name, e, self.suiteClass), False
+ except:
+ error_case, error_message = \
+ _make_failed_import_test(name, self.suiteClass)
+ self.errors.append(error_message)
+ return error_case, False
+ else:
+ mod_file = os.path.abspath(
+ getattr(module, '__file__', full_path))
+ realpath = _jython_aware_splitext(
+ os.path.realpath(mod_file))
+ fullpath_noext = _jython_aware_splitext(
+ os.path.realpath(full_path))
+ if realpath.lower() != fullpath_noext.lower():
+ module_dir = os.path.dirname(realpath)
+ mod_name = _jython_aware_splitext(
+ os.path.basename(full_path))
+ expected_dir = os.path.dirname(full_path)
+ msg = ("%r module incorrectly imported from %r. Expected "
+ "%r. Is this module globally installed?")
+ raise ImportError(
+ msg % (mod_name, module_dir, expected_dir))
+ return self.loadTestsFromModule(module, pattern=pattern), False
+ elif os.path.isdir(full_path):
+ if (not namespace and
+ not os.path.isfile(os.path.join(full_path, '__init__.py'))):
+ return None, False
+
+ load_tests = None
+ tests = None
+ name = self._get_name_from_path(full_path)
+ try:
+ package = self._get_module_from_name(name)
+ except case.SkipTest as e:
+ return _make_skipped_test(name, e, self.suiteClass), False
+ except:
+ error_case, error_message = \
+ _make_failed_import_test(name, self.suiteClass)
+ self.errors.append(error_message)
+ return error_case, False
+ else:
+ load_tests = getattr(package, 'load_tests', None)
+ # Mark this package as being in load_tests (possibly ;))
+ self._loading_packages.add(name)
try:
- package = self._get_module_from_name(name)
- except case.SkipTest as e:
- yield _make_skipped_test(name, e, self.suiteClass)
- except:
- error_case, error_message = \
- _make_failed_import_test(name, self.suiteClass)
- self.errors.append(error_message)
- yield error_case
- else:
- load_tests = getattr(package, 'load_tests', None)
tests = self.loadTestsFromModule(package, pattern=pattern)
- if tests is not None:
- # tests loaded from package file
- yield tests
-
if load_tests is not None:
- # loadTestsFromModule(package) has load_tests for us.
- continue
- # recurse into the package
- yield from self._find_tests(full_path, pattern,
- namespace=namespace)
+ # loadTestsFromModule(package) has loaded tests for us.
+ return tests, False
+ return tests, True
+ finally:
+ self._loading_packages.discard(name)
defaultTestLoader = TestLoader()