From d6a7c2e44b4f497357f8569d423104232a58f384 Mon Sep 17 00:00:00 2001 From: Brian Harring Date: Sat, 24 Dec 2022 13:14:53 -0800 Subject: Reformat w/ black 22.12.0 for consistency. Signed-off-by: Brian Harring Signed-off-by: Arthur Zamarin --- doc/conf.py | 157 +++++----- src/snakeoil/__init__.py | 4 +- src/snakeoil/_fileutils.py | 25 +- src/snakeoil/bash.py | 115 ++++---- src/snakeoil/caching.py | 10 +- src/snakeoil/chksum/__init__.py | 19 +- src/snakeoil/chksum/defaults.py | 53 ++-- src/snakeoil/cli/arghparse.py | 481 +++++++++++++++++++------------ src/snakeoil/cli/exceptions.py | 6 +- src/snakeoil/cli/input.py | 37 ++- src/snakeoil/cli/tool.py | 63 ++-- src/snakeoil/compatibility.py | 15 +- src/snakeoil/compression/__init__.py | 107 +++---- src/snakeoil/compression/_bzip2.py | 30 +- src/snakeoil/compression/_util.py | 59 ++-- src/snakeoil/compression/_xz.py | 28 +- src/snakeoil/constraints.py | 30 +- src/snakeoil/containers.py | 26 +- src/snakeoil/contexts.py | 78 +++-- src/snakeoil/currying.py | 60 +++- src/snakeoil/data_source.py | 81 ++++-- src/snakeoil/decorators.py | 7 + src/snakeoil/demandimport.py | 30 +- src/snakeoil/demandload.py | 95 +++--- src/snakeoil/dependant_methods.py | 20 +- src/snakeoil/errors.py | 14 +- src/snakeoil/fileutils.py | 43 +-- src/snakeoil/formatters.py | 115 ++++---- src/snakeoil/iterables.py | 12 +- src/snakeoil/klass.py | 164 ++++++++--- src/snakeoil/mappings.py | 67 +++-- src/snakeoil/modules.py | 1 + src/snakeoil/obj.py | 169 ++++++++--- src/snakeoil/osutils/__init__.py | 40 ++- src/snakeoil/osutils/mount.py | 10 +- src/snakeoil/osutils/native_readdir.py | 55 ++-- src/snakeoil/pickling.py | 3 +- src/snakeoil/process/__init__.py | 8 +- src/snakeoil/process/namespaces.py | 47 +-- src/snakeoil/process/spawn.py | 95 ++++-- src/snakeoil/sequences.py | 35 ++- src/snakeoil/stringio.py | 2 +- src/snakeoil/strings.py | 14 +- src/snakeoil/tar.py | 27 +- src/snakeoil/test/__init__.py | 29 +- src/snakeoil/test/argparse_helpers.py | 57 ++-- src/snakeoil/test/eq_hash_inheritance.py | 11 +- src/snakeoil/test/mixins.py | 46 +-- src/snakeoil/test/modules.py | 4 +- src/snakeoil/test/slot_shadowing.py | 22 +- src/snakeoil/version.py | 50 ++-- src/snakeoil/weakrefs.py | 1 - tests/cli/test_arghparse.py | 336 ++++++++++----------- tests/cli/test_input.py | 99 ++++--- tests/compression/__init__.py | 68 +++-- tests/compression/test_bzip2.py | 40 +-- tests/compression/test_init.py | 77 +++-- tests/compression/test_xz.py | 30 +- tests/test_bash.py | 283 ++++++++++-------- tests/test_caching.py | 22 +- tests/test_chksum.py | 4 +- tests/test_chksum_defaults.py | 37 ++- tests/test_constraints.py | 64 ++-- tests/test_containers.py | 19 +- tests/test_contexts.py | 51 ++-- tests/test_currying.py | 119 ++++---- tests/test_data_source.py | 37 ++- tests/test_decorators.py | 37 ++- tests/test_demandload.py | 86 +++--- tests/test_demandload_usage.py | 6 +- tests/test_dependant_methods.py | 21 +- tests/test_fileutils.py | 105 +++---- tests/test_formatters.py | 243 +++++++++------- tests/test_iterables.py | 11 +- tests/test_klass.py | 154 ++++++---- tests/test_mappings.py | 208 ++++++------- tests/test_modules.py | 67 +++-- tests/test_obj.py | 85 ++++-- tests/test_osutils.py | 221 +++++++------- tests/test_process.py | 6 +- tests/test_process_spawn.py | 50 ++-- tests/test_sequences.py | 110 +++---- tests/test_stringio.py | 5 +- tests/test_strings.py | 34 +-- tests/test_version.py | 139 +++++---- 85 files changed, 3268 insertions(+), 2383 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index d2cbc09..0a83d80 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -18,231 +18,248 @@ import sys # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath('../src/')) +sys.path.insert(0, os.path.abspath("../src/")) # generate API docs -subprocess.call([ - 'sphinx-apidoc', '-ef', '-o', 'api', '../src/snakeoil', - '../src/snakeoil/dist', '../src/snakeoil/test', # excludes -]) +subprocess.call( + [ + "sphinx-apidoc", + "-ef", + "-o", + "api", + "../src/snakeoil", + "../src/snakeoil/dist", + "../src/snakeoil/test", # excludes + ] +) # -- General configuration ----------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. -needs_sphinx = '1.0' +needs_sphinx = "1.0" # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. extensions = [ - 'sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'sphinx.ext.doctest', - 'sphinx.ext.extlinks', - 'sphinx.ext.intersphinx', 'sphinx.ext.todo', 'sphinx.ext.coverage', - 'sphinx.ext.ifconfig', 'sphinx.ext.graphviz', - 'sphinx.ext.viewcode', + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.doctest", + "sphinx.ext.extlinks", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.ifconfig", + "sphinx.ext.graphviz", + "sphinx.ext.viewcode", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = 'snakeoil' -authors = '' -copyright = '2007-2022, snakeoil contributors' +project = "snakeoil" +authors = "" +copyright = "2007-2022, snakeoil contributors" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = '-trunk' +version = "-trunk" # The full version, including alpha/beta/rc tags. -release = '-trunk' +release = "-trunk" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. -#language = None +# language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] +exclude_patterns = ["_build"] # The reST default role (used for this markup: `text`) to use for all documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # -- Options for HTML output --------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -#html_theme = 'default' -html_theme = 'default' +# html_theme = 'default' +html_theme = "default" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} +# html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] +# html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -#html_static_path = ['_static'] +# html_static_path = ['_static'] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. html_show_sourcelink = False # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = 'snakeoildoc' +htmlhelp_basename = "snakeoildoc" # -- Options for LaTeX output -------------------------------------------------- # The paper size ('letter' or 'a4'). -#latex_paper_size = 'letter' +# latex_paper_size = 'letter' # The font size ('10pt', '11pt' or '12pt'). -#latex_font_size = '10pt' +# latex_font_size = '10pt' # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ - ('index', 'snakeoil.tex', 'snakeoil Documentation', - authors, 'manual'), + ("index", "snakeoil.tex", "snakeoil Documentation", authors, "manual"), ] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Additional stuff for the LaTeX preamble. -#latex_preamble = '' +# latex_preamble = '' # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output -------------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - ('index', 'snakeoil', 'snakeoil Documentation', [], 1) -] +man_pages = [("index", "snakeoil", "snakeoil Documentation", [], 1)] # Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = {'http://docs.python.org/': None} -autodoc_default_flags = ['members', 'show-inheritance', 'inherited-members'] # + ['undoc-members'] +intersphinx_mapping = {"http://docs.python.org/": None} +autodoc_default_flags = [ + "members", + "show-inheritance", + "inherited-members", +] # + ['undoc-members'] autosummary_generate = False rst_epilog = """ .. |homepage| replace:: https://github.com/pkgcore/snakeoil .. |release_url| replace:: https://github.com/pkgcore/snakeoil/releases -""" % {"release": release} +""" % { + "release": release +} extlinks = { - 'git_tag': ('https://github.com/pkgcore/snakeoil/releases/tag/%s', 'git log '), - 'git_release': ('https://github.com/pkgcore/snakeoil/archive/%s.tar.gz', - 'release download '), + "git_tag": ("https://github.com/pkgcore/snakeoil/releases/tag/%s", "git log "), + "git_release": ( + "https://github.com/pkgcore/snakeoil/archive/%s.tar.gz", + "release download ", + ), } diff --git a/src/snakeoil/__init__.py b/src/snakeoil/__init__.py index 636d663..a1c1f20 100644 --- a/src/snakeoil/__init__.py +++ b/src/snakeoil/__init__.py @@ -10,5 +10,5 @@ This library is a bit of a grabbag of the following: * optimized implementations of common patterns """ -__title__ = 'snakeoil' -__version__ = '0.10.4' +__title__ = "snakeoil" +__version__ = "0.10.4" diff --git a/src/snakeoil/_fileutils.py b/src/snakeoil/_fileutils.py index 51f82c7..4a22648 100644 --- a/src/snakeoil/_fileutils.py +++ b/src/snakeoil/_fileutils.py @@ -5,7 +5,9 @@ Access this functionality from :py:module:`snakeoil.osutils` instead """ __all__ = ( - "mmap_and_close", "readlines_iter", "native_readlines", + "mmap_and_close", + "readlines_iter", + "native_readlines", "native_readfile", ) @@ -31,6 +33,7 @@ def mmap_and_close(fd, *args, **kwargs): class readlines_iter: __slots__ = ("iterable", "mtime", "source") + def __init__(self, iterable, mtime, close=True, source=None): if source is None: source = iterable @@ -54,17 +57,25 @@ class readlines_iter: source.close() def close(self): - if hasattr(self.source, 'close'): + if hasattr(self.source, "close"): self.source.close() def __iter__(self): return self.iterable + def _native_readlines_shim(*args, **kwds): - return native_readlines('r', *args, **kwds) + return native_readlines("r", *args, **kwds) + -def native_readlines(mode, mypath, strip_whitespace=True, swallow_missing=False, - none_on_missing=False, encoding=None): +def native_readlines( + mode, + mypath, + strip_whitespace=True, + swallow_missing=False, + none_on_missing=False, + encoding=None, +): """Read a file, yielding each line. :param mypath: fs path for the file to read @@ -102,8 +113,10 @@ def _py2k_ascii_strict_filter(source): raise ValueError("character ordinal over 127") yield line + def _native_readfile_shim(*args, **kwds): - return native_readfile('r', *args, **kwds) + return native_readfile("r", *args, **kwds) + def native_readfile(mode, mypath, none_on_missing=False, encoding=None): """Read a file, returning the contents. diff --git a/src/snakeoil/bash.py b/src/snakeoil/bash.py index 5eaaca2..3ca737c 100644 --- a/src/snakeoil/bash.py +++ b/src/snakeoil/bash.py @@ -15,19 +15,25 @@ from .fileutils import readlines from .log import logger from .mappings import ProtectedDict -demand_compile_regexp('line_cont_regexp', r'^(.*[^\\]|)\\$') -demand_compile_regexp('inline_comment_regexp', r'^.*\s#.*$') -demand_compile_regexp('var_find', r'\\?(\${\w+}|\$\w+)') -demand_compile_regexp('backslash_find', r'\\.') -demand_compile_regexp('ansi_escape_re', r'(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]') +demand_compile_regexp("line_cont_regexp", r"^(.*[^\\]|)\\$") +demand_compile_regexp("inline_comment_regexp", r"^.*\s#.*$") +demand_compile_regexp("var_find", r"\\?(\${\w+}|\$\w+)") +demand_compile_regexp("backslash_find", r"\\.") +demand_compile_regexp("ansi_escape_re", r"(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]") __all__ = ( - "iter_read_bash", "read_bash", "read_dict", "read_bash_dict", - "bash_parser", "BashParseError") - - -def iter_read_bash(bash_source, allow_inline_comments=True, - allow_line_cont=False, enum_line=False): + "iter_read_bash", + "read_bash", + "read_dict", + "read_bash_dict", + "bash_parser", + "BashParseError", +) + + +def iter_read_bash( + bash_source, allow_inline_comments=True, allow_line_cont=False, enum_line=False +): """Iterate over a file honoring bash commenting rules and line continuations. Note that it's considered good behaviour to close filehandles, as @@ -45,7 +51,7 @@ def iter_read_bash(bash_source, allow_inline_comments=True, """ if isinstance(bash_source, str): bash_source = readlines(bash_source, True) - s = '' + s = "" for lineno, line in enumerate(bash_source, 1): if allow_line_cont and s: s += line @@ -53,19 +59,20 @@ def iter_read_bash(bash_source, allow_inline_comments=True, s = line.lstrip() if s: - if s[0] != '#': + if s[0] != "#": if allow_inline_comments: - if (not allow_line_cont or - (allow_line_cont and inline_comment_regexp.match(line))): + if not allow_line_cont or ( + allow_line_cont and inline_comment_regexp.match(line) + ): s = s.split("#", 1)[0].rstrip() if allow_line_cont and line_cont_regexp.match(line): - s = s.rstrip('\\\n') + s = s.rstrip("\\\n") continue if enum_line: yield lineno, s.rstrip() else: yield s.rstrip() - s = '' + s = "" if s: if enum_line: yield lineno, s @@ -122,7 +129,7 @@ def read_bash_dict(bash_source, vars_dict=None, sourcing_command=None): try: while tok is not None: key = s.get_token() - if key == 'export': + if key == "export": # discard 'export' token from "export VAR=VALUE" lines key = s.get_token() if key is None: @@ -133,23 +140,23 @@ def read_bash_dict(bash_source, vars_dict=None, sourcing_command=None): # detect empty assigns continue eq = s.get_token() - if eq != '=': + if eq != "=": raise BashParseError( - bash_source, s.lineno, - "got token %r, was expecting '='" % eq) + bash_source, s.lineno, "got token %r, was expecting '='" % eq + ) val = s.get_token() if val is None: - val = '' - elif val == 'export': + val = "" + elif val == "export": val = s.get_token() # look ahead to see if we just got an empty assign. next_tok = s.get_token() - if next_tok == '=': + if next_tok == "=": # ... we did. # leftmost insertions, thus reversed ordering s.push_token(next_tok) s.push_token(val) - val = '' + val = "" else: s.push_token(next_tok) d[key] = val @@ -163,9 +170,15 @@ def read_bash_dict(bash_source, vars_dict=None, sourcing_command=None): return d -def read_dict(bash_source, splitter="=", source_isiter=False, - allow_inline_comments=True, strip=False, filename=None, - ignore_errors=False): +def read_dict( + bash_source, + splitter="=", + source_isiter=False, + allow_inline_comments=True, + strip=False, + filename=None, + ignore_errors=False, +): """Read key value pairs from a file, ignoring bash-style comments. :param splitter: the string to split on. Can be None to @@ -180,12 +193,11 @@ def read_dict(bash_source, splitter="=", source_isiter=False, d = {} if not source_isiter: filename = bash_source - i = iter_read_bash( - bash_source, allow_inline_comments=allow_inline_comments) + i = iter_read_bash(bash_source, allow_inline_comments=allow_inline_comments) else: if filename is None: # XXX what to do? - filename = '' + filename = "" i = bash_source line_count = 0 try: @@ -195,10 +207,11 @@ def read_dict(bash_source, splitter="=", source_isiter=False, k, v = k.split(splitter, 1) except ValueError as e: if filename == "": - filename = getattr(bash_source, 'name', bash_source) + filename = getattr(bash_source, "name", bash_source) if ignore_errors: logger.error( - 'bash parse error in %r, line %s', filename, line_count) + "bash parse error in %r, line %s", filename, line_count + ) continue else: raise BashParseError(filename, line_count) from e @@ -239,7 +252,7 @@ class bash_parser(shlex): :param env: initial environment to use for variable interpolation :type env: must be a mapping; if None, an empty dict is used """ - self.__dict__['state'] = ' ' + self.__dict__["state"] = " " super().__init__(source, posix=True, infile=infile) self.wordchars += "@${}/.-+/:~^*" self.wordchars = frozenset(self.wordchars) @@ -252,12 +265,10 @@ class bash_parser(shlex): def __setattr__(self, attr, val): if attr == "state": - if (self.state, val) in ( - ('"', 'a'), ('a', '"'), ('a', ' '), ("'", 'a')): + if (self.state, val) in (('"', "a"), ("a", '"'), ("a", " "), ("'", "a")): strl = len(self.token) if self.__pos != strl: - self.changed_state.append( - (self.state, self.token[self.__pos:])) + self.changed_state.append((self.state, self.token[self.__pos :])) self.__pos = strl self.__dict__[attr] = val @@ -275,13 +286,13 @@ class bash_parser(shlex): return token if self.state is None: # eof reached. - self.changed_state.append((self.state, token[self.__pos:])) + self.changed_state.append((self.state, token[self.__pos :])) else: - self.changed_state.append((self.state, self.token[self.__pos:])) - tok = '' + self.changed_state.append((self.state, self.token[self.__pos :])) + tok = "" for s, t in self.changed_state: if s in ('"', "a"): - tok += self.var_expand(t).replace("\\\n", '') + tok += self.var_expand(t).replace("\\\n", "") else: tok += t return tok @@ -291,26 +302,27 @@ class bash_parser(shlex): l = [] while match := var_find.search(val, pos): pos = match.start() - if val[pos] == '\\': + if val[pos] == "\\": # it's escaped. either it's \\$ or \\${ , either way, # skipping two ahead handles it. pos += 2 else: - var = val[match.start():match.end()].strip("${}") + var = val[match.start() : match.end()].strip("${}") if prev != pos: l.append(val[prev:pos]) if var in self.env: if not isinstance(self.env[var], str): raise ValueError( - "env key %r must be a string, not %s: %r" % ( - var, type(self.env[var]), self.env[var])) + "env key %r must be a string, not %s: %r" + % (var, type(self.env[var]), self.env[var]) + ) l.append(self.env[var]) else: l.append("") prev = pos = match.end() # do \\ cleansing, collapsing val down also. - val = backslash_find.sub(_nuke_backslash, ''.join(l) + val[prev:]) + val = backslash_find.sub(_nuke_backslash, "".join(l) + val[prev:]) return val @@ -320,10 +332,11 @@ class BashParseError(Exception): def __init__(self, filename, line, errmsg=None): if errmsg is not None: super().__init__( - "error parsing '%s' on or before line %i: err %s" % - (filename, line, errmsg)) + "error parsing '%s' on or before line %i: err %s" + % (filename, line, errmsg) + ) else: super().__init__( - "error parsing '%s' on or before line %i" % - (filename, line)) + "error parsing '%s' on or before line %i" % (filename, line) + ) self.file, self.line, self.errmsg = filename, line, errmsg diff --git a/src/snakeoil/caching.py b/src/snakeoil/caching.py index 3ed75d2..b20f074 100644 --- a/src/snakeoil/caching.py +++ b/src/snakeoil/caching.py @@ -79,21 +79,22 @@ class WeakInstMeta(type): Examples of usage is the restrictions subsystem for U{pkgcore project} """ + def __new__(cls, name, bases, d): if d.get("__inst_caching__", False): d["__inst_caching__"] = True d["__inst_dict__"] = WeakValueDictionary() else: d["__inst_caching__"] = False - slots = d.get('__slots__') + slots = d.get("__slots__") # get ourselves a singleton to be safe... o = object() if slots is not None: for base in bases: - if getattr(base, '__weakref__', o) is not o: + if getattr(base, "__weakref__", o) is not o: break else: - d['__slots__'] = tuple(slots) + ('__weakref__',) + d["__slots__"] = tuple(slots) + ("__weakref__",) return type.__new__(cls, name, bases, d) def __call__(cls, *a, **kw): @@ -105,8 +106,7 @@ class WeakInstMeta(type): try: instance = cls.__inst_dict__.get(key) except (NotImplementedError, TypeError) as t: - warnings.warn( - f"caching keys for {cls}, got {t} for a={a}, kw={kw}") + warnings.warn(f"caching keys for {cls}, got {t} for a={a}, kw={kw}") del t key = instance = None diff --git a/src/snakeoil/chksum/__init__.py b/src/snakeoil/chksum/__init__.py index 8f2bc73..5bd6a7c 100644 --- a/src/snakeoil/chksum/__init__.py +++ b/src/snakeoil/chksum/__init__.py @@ -64,10 +64,9 @@ def init(additional_handlers=None): :param additional_handlers: None, or pass in a dict of type:func """ - global __inited__ # pylint: disable=global-statement + global __inited__ # pylint: disable=global-statement - if additional_handlers is not None and not isinstance( - additional_handlers, dict): + if additional_handlers is not None and not isinstance(additional_handlers, dict): raise TypeError("additional handlers must be a dict!") chksum_types.clear() @@ -119,15 +118,19 @@ def get_chksums(location, *chksums, **kwds): # try to hand off to the per file handler, may be faster. if len(chksums) == 1: return [handlers[chksums[0]](location)] - if len(chksums) == 2 and 'size' in chksums: + if len(chksums) == 2 and "size" in chksums: parallelize = False else: parallelize = kwds.get("parallelize", True) can_mmap = True for k in chksums: can_mmap &= handlers[k].can_mmap - return chksum_loop_over_file(location, [handlers[k].new() for k in chksums], - parallelize=parallelize, can_mmap=can_mmap) + return chksum_loop_over_file( + location, + [handlers[k].new() for k in chksums], + parallelize=parallelize, + can_mmap=can_mmap, + ) class LazilyHashedPath(metaclass=klass.immutable_instance): @@ -135,7 +138,7 @@ class LazilyHashedPath(metaclass=klass.immutable_instance): def __init__(self, path, **initial_values): f = object.__setattr__ - f(self, 'path', path) + f(self, "path", path) for attr, val in initial_values.items(): f(self, attr, val) @@ -143,7 +146,7 @@ class LazilyHashedPath(metaclass=klass.immutable_instance): if not attr.islower(): # Disallow sHa1. raise AttributeError(attr) - elif attr == 'mtime': + elif attr == "mtime": val = osutils.stat_mtime_long(self.path) else: try: diff --git a/src/snakeoil/chksum/defaults.py b/src/snakeoil/chksum/defaults.py index bf2be41..fe01377 100644 --- a/src/snakeoil/chksum/defaults.py +++ b/src/snakeoil/chksum/defaults.py @@ -11,7 +11,7 @@ from sys import intern from ..data_source import base as base_data_source from ..fileutils import mmap_or_open_for_read -blocksize = 2 ** 17 +blocksize = 2**17 blake2b_size = 128 blake2s_size = 64 @@ -36,8 +36,11 @@ def chf_thread(queue, callback): def chksum_loop_over_file(filename, chfs, parallelize=True, can_mmap=True): chfs = [chf() for chf in chfs] loop_over_file( - filename, [chf.update for chf in chfs], - parallelize=parallelize, can_mmap=can_mmap) + filename, + [chf.update for chf in chfs], + parallelize=parallelize, + can_mmap=can_mmap, + ) return [int(chf.hexdigest(), 16) for chf in chfs] @@ -54,7 +57,7 @@ def loop_over_file(handle, callbacks, parallelize=True, can_mmap=True): else: f = handle close_f = False - if getattr(handle, 'encoding', None): + if getattr(handle, "encoding", None): # wanker. bypass the encoding, go straight to the raw source. f = f.buffer # reset; we do it for compat, but it also avoids unpleasant issues from @@ -68,8 +71,10 @@ def loop_over_file(handle, callbacks, parallelize=True, can_mmap=True): if parallelize: queues = [queue.Queue(8) for _ in callbacks] - threads = [threading.Thread(target=chf_thread, args=(queue, functor)) - for queue, functor in zip(queues, callbacks)] + threads = [ + threading.Thread(target=chf_thread, args=(queue, functor)) + for queue, functor in zip(queues, callbacks) + ] for thread in threads: thread.start() @@ -79,7 +84,7 @@ def loop_over_file(handle, callbacks, parallelize=True, can_mmap=True): if m is not None: for callback in callbacks: callback(m) - elif hasattr(f, 'getvalue'): + elif hasattr(f, "getvalue"): data = f.getvalue() if not isinstance(data, bytes): data = data.encode() @@ -107,7 +112,6 @@ def loop_over_file(handle, callbacks, parallelize=True, can_mmap=True): class Chksummer: - def __init__(self, chf_type, obj, str_size, can_mmap=True): self.obj = obj self.chf_type = chf_type @@ -118,15 +122,14 @@ class Chksummer: return self.obj def long2str(self, val): - return ("%x" % val).rjust(self.str_size, '0') + return ("%x" % val).rjust(self.str_size, "0") @staticmethod def str2long(val): return int(val, 16) def __call__(self, filename): - return chksum_loop_over_file( - filename, [self.obj], can_mmap=self.can_mmap)[0] + return chksum_loop_over_file(filename, [self.obj], can_mmap=self.can_mmap)[0] def __str__(self): return "%s chksummer" % self.chf_type @@ -134,31 +137,26 @@ class Chksummer: chksum_types = { chksumname: Chksummer(chksumname, partial(hashlib.new, hashlibname), size) - for hashlibname, chksumname, size in [ # conditional upon FIPS, but available in >3.8. - ('md5', 'md5', md5_size), - + ("md5", "md5", md5_size), # Guaranteed as of python 3.8 - ('blake2b', 'blake2b', blake2b_size), - ('blake2s', 'blake2s', blake2s_size), - ('sha1', 'sha1', sha1_size), - ('sha256', 'sha256', sha256_size), - ('sha3_256', 'sha3_256', sha3_256_size), - ('sha3_512', 'sha3_512', sha3_512_size), - ('sha512', 'sha512', sha512_size), - + ("blake2b", "blake2b", blake2b_size), + ("blake2s", "blake2s", blake2s_size), + ("sha1", "sha1", sha1_size), + ("sha256", "sha256", sha256_size), + ("sha3_256", "sha3_256", sha3_256_size), + ("sha3_512", "sha3_512", sha3_512_size), + ("sha512", "sha512", sha512_size), # not guaranteed, but may be available. - ('whirlpool', 'whirlpool', whirlpool_size), - ('ripemd160', 'rmd160', rmd160_size), - + ("whirlpool", "whirlpool", whirlpool_size), + ("ripemd160", "rmd160", rmd160_size), ] if hashlibname in hashlib.algorithms_available } class SizeUpdater: - def __init__(self): self.count = 0 @@ -176,8 +174,7 @@ class SizeChksummer(Chksummer): """ def __init__(self): - super().__init__( - chf_type='size', obj=SizeUpdater, str_size=1000000000) + super().__init__(chf_type="size", obj=SizeUpdater, str_size=1000000000) @staticmethod def long2str(val): diff --git a/src/snakeoil/cli/arghparse.py b/src/snakeoil/cli/arghparse.py index 1c832c3..774699e 100644 --- a/src/snakeoil/cli/arghparse.py +++ b/src/snakeoil/cli/arghparse.py @@ -9,8 +9,18 @@ import pkgutil import subprocess import sys import traceback -from argparse import (_UNRECOGNIZED_ARGS_ATTR, OPTIONAL, PARSER, REMAINDER, SUPPRESS, ZERO_OR_MORE, - ArgumentError, _, _get_action_name, _SubParsersAction) +from argparse import ( + _UNRECOGNIZED_ARGS_ATTR, + OPTIONAL, + PARSER, + REMAINDER, + SUPPRESS, + ZERO_OR_MORE, + ArgumentError, + _, + _get_action_name, + _SubParsersAction, +) from collections import Counter from functools import partial from itertools import chain @@ -31,11 +41,11 @@ from ..version import get_version _generate_docs = False -@klass.patch('argparse.ArgumentParser.add_subparsers') -@klass.patch('argparse._SubParsersAction.add_parser') -@klass.patch('argparse._ActionsContainer.add_mutually_exclusive_group') -@klass.patch('argparse._ActionsContainer.add_argument_group') -@klass.patch('argparse._ActionsContainer.add_argument') +@klass.patch("argparse.ArgumentParser.add_subparsers") +@klass.patch("argparse._SubParsersAction.add_parser") +@klass.patch("argparse._ActionsContainer.add_mutually_exclusive_group") +@klass.patch("argparse._ActionsContainer.add_argument_group") +@klass.patch("argparse._ActionsContainer.add_argument") def _add_argument_docs(orig_func, self, *args, **kwargs): """Enable docs keyword argument support for argparse arguments. @@ -48,16 +58,16 @@ def _add_argument_docs(orig_func, self, *args, **kwargs): enable the global _generate_docs variable in order to replace the summarized help strings with the extended doc strings. """ - docs = kwargs.pop('docs', None) + docs = kwargs.pop("docs", None) obj = orig_func(self, *args, **kwargs) if _generate_docs and docs is not None: if isinstance(docs, (list, tuple)): # list args are often used if originator wanted to strip # off first description summary line - docs = '\n'.join(docs) - docs = '\n'.join(dedent(docs).strip().split('\n')) + docs = "\n".join(docs) + docs = "\n".join(dedent(docs).strip().split("\n")) - if orig_func.__name__ == 'add_subparsers': + if orig_func.__name__ == "add_subparsers": # store original description before overriding it with extended # docs for general subparsers argument groups self._subparsers._description = self._subparsers.description @@ -93,7 +103,7 @@ class ParseNonblockingStdin(argparse.Action): """Accept arguments from standard input in a non-blocking fashion.""" def __init__(self, *args, **kwargs): - self.filter_func = kwargs.pop('filter_func', lambda x: x.strip()) + self.filter_func = kwargs.pop("filter_func", lambda x: x.strip()) super().__init__(*args, **kwargs) def _stdin(self): @@ -106,9 +116,11 @@ class ParseNonblockingStdin(argparse.Action): break def __call__(self, parser, namespace, values, option_string=None): - if values is not None and len(values) == 1 and values[0] == '-': + if values is not None and len(values) == 1 and values[0] == "-": if sys.stdin.isatty(): - raise argparse.ArgumentError(self, "'-' is only valid when piping data in") + raise argparse.ArgumentError( + self, "'-' is only valid when piping data in" + ) values = self._stdin() setattr(namespace, self.dest, values) @@ -117,16 +129,18 @@ class ParseStdin(ExtendAction): """Accept arguments from standard input in a blocking fashion.""" def __init__(self, *args, **kwargs): - self.filter_func = kwargs.pop('filter_func', lambda x: x.strip()) + self.filter_func = kwargs.pop("filter_func", lambda x: x.strip()) super().__init__(*args, **kwargs) def __call__(self, parser, namespace, values, option_string=None): - if values is not None and len(values) == 1 and values[0] == '-': + if values is not None and len(values) == 1 and values[0] == "-": if sys.stdin.isatty(): - raise argparse.ArgumentError(self, "'-' is only valid when piping data in") + raise argparse.ArgumentError( + self, "'-' is only valid when piping data in" + ) values = [x.rstrip() for x in sys.stdin.readlines() if self.filter_func(x)] # reassign stdin to allow interactivity (currently only works for unix) - sys.stdin = open('/dev/tty') + sys.stdin = open("/dev/tty") super().__call__(parser, namespace, values, option_string) @@ -136,10 +150,10 @@ class CommaSeparatedValues(argparse._AppendAction): def parse_values(self, values): items = [] if isinstance(values, str): - items.extend(x for x in values.split(',') if x) + items.extend(x for x in values.split(",") if x) else: for value in values: - items.extend(x for x in value.split(',') if x) + items.extend(x for x in value.split(",") if x) return items def __call__(self, parser, namespace, values, option_string=None): @@ -174,16 +188,16 @@ class CommaSeparatedNegations(argparse._AppendAction): values = [values] for value in values: try: - neg, pos = split_negations(x for x in value.split(',') if x) + neg, pos = split_negations(x for x in value.split(",") if x) except ValueError as e: raise argparse.ArgumentTypeError(e) disabled.extend(neg) enabled.extend(pos) if colliding := set(disabled).intersection(enabled): - collisions = ', '.join(map(repr, sorted(colliding))) + collisions = ", ".join(map(repr, sorted(colliding))) s = pluralism(colliding) - msg = f'colliding value{s}: {collisions}' + msg = f"colliding value{s}: {collisions}" raise argparse.ArgumentError(self, msg) return disabled, enabled @@ -222,7 +236,7 @@ class CommaSeparatedElements(argparse._AppendAction): values = [values] for value in values: try: - neg, neu, pos = split_elements(x for x in value.split(',') if x) + neg, neu, pos = split_elements(x for x in value.split(",") if x) except ValueError as e: raise argparse.ArgumentTypeError(e) disabled.extend(neg) @@ -231,9 +245,9 @@ class CommaSeparatedElements(argparse._AppendAction): elements = [set(x) for x in (disabled, neutral, enabled) if x] if len(elements) > 1 and (colliding := set.intersection(*elements)): - collisions = ', '.join(map(repr, sorted(colliding))) + collisions = ", ".join(map(repr, sorted(colliding))) s = pluralism(colliding) - msg = f'colliding value{s}: {collisions}' + msg = f"colliding value{s}: {collisions}" raise argparse.ArgumentError(self, msg) return disabled, neutral, enabled @@ -260,14 +274,14 @@ class ManHelpAction(argparse._HelpAction): """Display man pages for long --help option and abbreviated output for -h.""" def __call__(self, parser, namespace, values, option_string=None): - if option_string == '--help': + if option_string == "--help": # Try spawning man page -- assumes one level deep for subcommand # specific man pages with commands separated by hyphen. For example # running `pinspect profile --help` tries to open pinspect-profile # man page, but `pinspect profile masks --help` also tries to open # pinspect-profile. - man_page = '-'.join(parser.prog.split()[:2]) - p = subprocess.Popen(['man', man_page], stderr=subprocess.DEVNULL) + man_page = "-".join(parser.prog.split()[:2]) + p = subprocess.Popen(["man", man_page], stderr=subprocess.DEVNULL) p.communicate() if p.returncode == 0: parser.exit() @@ -279,16 +293,17 @@ class ManHelpAction(argparse._HelpAction): class StoreBool(argparse._StoreAction): - - def __init__(self, - option_strings, - dest, - nargs=None, - const=None, - default=None, - required=False, - help=None, - metavar='BOOLEAN'): + def __init__( + self, + option_strings, + dest, + nargs=None, + const=None, + default=None, + required=False, + help=None, + metavar="BOOLEAN", + ): super().__init__( option_strings=option_strings, dest=dest, @@ -298,38 +313,42 @@ class StoreBool(argparse._StoreAction): type=self.boolean, required=required, help=help, - metavar=metavar) + metavar=metavar, + ) @staticmethod def boolean(value): value = value.lower() - if value in ('y', 'yes', 'true', '1'): + if value in ("y", "yes", "true", "1"): return True - elif value in ('n', 'no', 'false', '0'): + elif value in ("n", "no", "false", "0"): return False raise ValueError("value %r must be [y|yes|true|1|n|no|false|0]" % (value,)) class EnableDebug(argparse._StoreTrueAction): - def __call__(self, parser, namespace, values, option_string=None): super().__call__(parser, namespace, values, option_string=option_string) logging.root.setLevel(logging.DEBUG) class Verbosity(argparse.Action): - def __init__(self, option_strings, dest, default=None, required=False, help=None): super().__init__( - option_strings=option_strings, dest=dest, nargs=0, - default=default, required=required, help=help) + option_strings=option_strings, + dest=dest, + nargs=0, + default=default, + required=required, + help=help, + ) # map verbose/quiet args to increment/decrement the underlying verbosity value self.value_map = { - '-q': -1, - '--quiet': -1, - '-v': 1, - '--verbose': 1, + "-q": -1, + "--quiet": -1, + "-v": 1, + "--verbose": 1, } def __call__(self, parser, namespace, values, option_string=None): @@ -343,7 +362,6 @@ class Verbosity(argparse.Action): class DelayedValue: - def __init__(self, invokable, priority=0): self.priority = priority if not callable(invokable): @@ -355,7 +373,6 @@ class DelayedValue: class DelayedDefault(DelayedValue): - @classmethod def wipe(cls, attrs, priority): if isinstance(attrs, str): @@ -376,20 +393,17 @@ class DelayedDefault(DelayedValue): class DelayedParse(DelayedValue): - def __call__(self, namespace, attr): self.invokable() class OrderedParse(DelayedValue): - def __call__(self, namespace, attr): self.invokable(namespace) delattr(namespace, attr) class Delayed(argparse.Action): - def __init__(self, option_strings, dest, target=None, priority=0, **kwargs): if target is None: raise ValueError("target must be non None for Delayed") @@ -397,21 +411,30 @@ class Delayed(argparse.Action): self.priority = int(priority) self.target = target(option_strings=option_strings, dest=dest, **kwargs.copy()) super().__init__( - option_strings=option_strings[:], dest=dest, - nargs=kwargs.get("nargs", None), required=kwargs.get("required", None), - help=kwargs.get("help", None), metavar=kwargs.get("metavar", None), - default=kwargs.get("default", None)) + option_strings=option_strings[:], + dest=dest, + nargs=kwargs.get("nargs", None), + required=kwargs.get("required", None), + help=kwargs.get("help", None), + metavar=kwargs.get("metavar", None), + default=kwargs.get("default", None), + ) def __call__(self, parser, namespace, values, option_string=None): - setattr(namespace, self.dest, DelayedParse( - partial(self.target, parser, namespace, values, option_string), - self.priority)) + setattr( + namespace, + self.dest, + DelayedParse( + partial(self.target, parser, namespace, values, option_string), + self.priority, + ), + ) class Expansion(argparse.Action): - - def __init__(self, option_strings, dest, nargs=None, help=None, - required=None, subst=None): + def __init__( + self, option_strings, dest, nargs=None, help=None, required=None, subst=None + ): if subst is None: raise TypeError("substitution string must be set") # simple aliases with no required arguments shouldn't need to specify nargs @@ -424,7 +447,8 @@ class Expansion(argparse.Action): help=help, required=required, default=False, - nargs=nargs) + nargs=nargs, + ) self.subst = tuple(subst) def __call__(self, parser, namespace, values, option_string=None): @@ -434,7 +458,7 @@ class Expansion(argparse.Action): if isinstance(values, str): vals = [vals] dvals = {str(idx): val for idx, val in enumerate(vals)} - dvals['*'] = ' '.join(vals) + dvals["*"] = " ".join(vals) for action in actions: action_map.update((option, action) for option in action.option_strings) @@ -445,8 +469,8 @@ class Expansion(argparse.Action): args = [x % dvals for x in args] if not action: raise ValueError( - "unable to find option %r for %r" % - (option, self.option_strings)) + "unable to find option %r for %r" % (option, self.option_strings) + ) if action.type is not None: args = list(map(action.type, args)) if action.nargs in (1, None): @@ -456,7 +480,6 @@ class Expansion(argparse.Action): class _SubParser(argparse._SubParsersAction): - def add_parser(self, name, cls=None, **kwargs): """Subparser that links description/help if one is specified.""" description = kwargs.get("description") @@ -465,7 +488,7 @@ class _SubParser(argparse._SubParsersAction): if help_txt is not None: kwargs["description"] = help_txt elif help_txt is None: - kwargs["help"] = description.split('\n', 1)[0] + kwargs["help"] = description.split("\n", 1)[0] # support using a custom parser class for the subparser orig_class = self._parser_class @@ -486,7 +509,7 @@ class _SubParser(argparse._SubParsersAction): Note that this assumes a specific module naming and layout scheme for commands. """ prog = self._prog_prefix - module = f'{prog}.scripts.{prog}_{subcmd}' + module = f"{prog}.scripts.{prog}_{subcmd}" func = partial(self._lazy_parser, module, subcmd) self._name_parser_map[subcmd] = lazy_object_proxy.Proxy(func) @@ -507,8 +530,8 @@ class _SubParser(argparse._SubParsersAction): try: parser = self._name_parser_map[parser_name] except KeyError: - tup = parser_name, ', '.join(self._name_parser_map) - msg = _('unknown parser %r (choices: %s)') % tup + tup = parser_name, ", ".join(self._name_parser_map) + msg = _("unknown parser %r (choices: %s)") % tup raise argparse.ArgumentError(self, msg) # parse all the remaining options into the namespace @@ -526,11 +549,13 @@ class CsvHelpFormatter(argparse.HelpFormatter): def _format_args(self, action, default_metavar): get_metavar = self._metavar_formatter(action, default_metavar) if isinstance(action, (CommaSeparatedValues, CommaSeparatedValuesAppend)): - result = '%s[,%s,...]' % get_metavar(2) - elif isinstance(action, (CommaSeparatedNegations, CommaSeparatedNegationsAppend)): - result = '%s[,-%s,...]' % get_metavar(2) + result = "%s[,%s,...]" % get_metavar(2) + elif isinstance( + action, (CommaSeparatedNegations, CommaSeparatedNegationsAppend) + ): + result = "%s[,-%s,...]" % get_metavar(2) elif isinstance(action, (CommaSeparatedElements, CommaSeparatedElementsAppend)): - result = '%s[,-%s,+%s...]' % get_metavar(3) + result = "%s[,-%s,+%s...]" % get_metavar(3) else: result = super()._format_args(action, default_metavar) return result @@ -540,7 +565,7 @@ class SortedHelpFormatter(CsvHelpFormatter): """Help formatter that sorts arguments by option strings.""" def add_arguments(self, actions): - actions = sorted(actions, key=attrgetter('option_strings')) + actions = sorted(actions, key=attrgetter("option_strings")) super().add_arguments(actions) @@ -576,7 +601,7 @@ class SubcmdAbbrevArgumentParser(argparse.ArgumentParser): # for everything but PARSER, REMAINDER args, strip out first '--' if action.nargs not in [PARSER, REMAINDER]: try: - arg_strings.remove('--') + arg_strings.remove("--") except ValueError: pass @@ -592,8 +617,11 @@ class SubcmdAbbrevArgumentParser(argparse.ArgumentParser): # when nargs='*' on a positional, if there were no command-line # args, use the default if it is anything other than None - elif (not arg_strings and action.nargs == ZERO_OR_MORE and - not action.option_strings): + elif ( + not arg_strings + and action.nargs == ZERO_OR_MORE + and not action.option_strings + ): if action.default is not None: value = action.default else: @@ -602,7 +630,7 @@ class SubcmdAbbrevArgumentParser(argparse.ArgumentParser): # single argument or optional argument produces a single value elif len(arg_strings) == 1 and action.nargs in [None, OPTIONAL]: - arg_string, = arg_strings + (arg_string,) = arg_strings value = self._get_value(action, arg_string) self._check_value(action, value) @@ -688,7 +716,7 @@ class OptionalsParser(argparse.ArgumentParser): for i, mutex_action in enumerate(mutex_group._group_actions): conflicts = action_conflicts.setdefault(mutex_action, []) conflicts.extend(group_actions[:i]) - conflicts.extend(group_actions[i + 1:]) + conflicts.extend(group_actions[i + 1 :]) # find all option indices, and determine the arg_string_pattern # which has an 'O' if there is an option at an index, @@ -699,24 +727,24 @@ class OptionalsParser(argparse.ArgumentParser): for i, arg_string in enumerate(arg_strings_iter): # all args after -- are non-options - if arg_string == '--': - arg_string_pattern_parts.append('-') + if arg_string == "--": + arg_string_pattern_parts.append("-") for arg_string in arg_strings_iter: - arg_string_pattern_parts.append('A') + arg_string_pattern_parts.append("A") # otherwise, add the arg to the arg strings # and note the index if it was an option else: option_tuple = self._parse_optional(arg_string) if option_tuple is None: - pattern = 'A' + pattern = "A" else: option_string_indices[i] = option_tuple - pattern = 'O' + pattern = "O" arg_string_pattern_parts.append(pattern) # join the pieces together to form the pattern - arg_strings_pattern = ''.join(arg_string_pattern_parts) + arg_strings_pattern = "".join(arg_string_pattern_parts) # converts arg strings to the appropriate and then takes the action seen_actions = set() @@ -733,7 +761,7 @@ class OptionalsParser(argparse.ArgumentParser): seen_non_default_actions.add(action) for conflict_action in action_conflicts.get(action, []): if conflict_action in seen_non_default_actions: - msg = _('not allowed with argument %s') + msg = _("not allowed with argument %s") action_name = _get_action_name(conflict_action) raise ArgumentError(action, msg % action_name) @@ -762,14 +790,14 @@ class OptionalsParser(argparse.ArgumentParser): # if we match help options, skip them for now so subparsers # show up in the help output - if arg_strings[start_index] in ('-h', '--help'): + if arg_strings[start_index] in ("-h", "--help"): extras.append(arg_strings[start_index]) return start_index + 1 # if there is an explicit argument, try to match the # optional's string arguments to only this if explicit_arg is not None: - arg_count = match_argument(action, 'A') + arg_count = match_argument(action, "A") # if the action is a single-dash option and takes no # arguments, try to parse more single-dash options out @@ -785,7 +813,7 @@ class OptionalsParser(argparse.ArgumentParser): action = optionals_map[option_string] explicit_arg = new_explicit_arg else: - msg = _('ignored explicit argument %r') + msg = _("ignored explicit argument %r") raise ArgumentError(action, msg % explicit_arg) # if the action expect exactly one argument, we've @@ -799,7 +827,7 @@ class OptionalsParser(argparse.ArgumentParser): # error if a double-dash option did not use the # explicit argument else: - msg = _('ignored explicit argument %r') + msg = _("ignored explicit argument %r") raise ArgumentError(action, msg % explicit_arg) # if there is no explicit argument, try to match the @@ -835,13 +863,13 @@ class OptionalsParser(argparse.ArgumentParser): # slice off the appropriate arg strings for each Positional # and add the Positional and its args to the list for action, arg_count in zip(positionals, arg_counts): - args = arg_strings[start_index: start_index + arg_count] + args = arg_strings[start_index : start_index + arg_count] start_index += arg_count take_action(action, args) # slice off the Positionals that we just parsed and return the # index at which the Positionals' string args stopped - positionals[:] = positionals[len(arg_counts):] + positionals[:] = positionals[len(arg_counts) :] return start_index # consume Positionals and Optionals alternately, until we have @@ -855,10 +883,9 @@ class OptionalsParser(argparse.ArgumentParser): while start_index <= max_option_string_index: # consume any Positionals preceding the next option - next_option_string_index = min([ - index - for index in option_string_indices - if index >= start_index]) + next_option_string_index = min( + [index for index in option_string_indices if index >= start_index] + ) if start_index != next_option_string_index: # positionals_end_index = consume_positionals(start_index) positionals_end_index = start_index @@ -894,7 +921,9 @@ class OptionalsParser(argparse.ArgumentParser): for action in self._actions: if action not in seen_actions: # ignore required subcommands and positionals as they'll be handled later - skip = not action.option_strings or isinstance(action, _SubParsersAction) + skip = not action.option_strings or isinstance( + action, _SubParsersAction + ) if action.required and not skip: required_actions.append(_get_action_name(action)) else: @@ -902,16 +931,23 @@ class OptionalsParser(argparse.ArgumentParser): # parsing arguments to avoid calling convert functions # twice (which may fail) if the argument was given, but # only if it was defined already in the namespace - if (action.default is not None and - isinstance(action.default, str) and - hasattr(namespace, action.dest) and - action.default is getattr(namespace, action.dest)): - setattr(namespace, action.dest, - self._get_value(action, action.default)) + if ( + action.default is not None + and isinstance(action.default, str) + and hasattr(namespace, action.dest) + and action.default is getattr(namespace, action.dest) + ): + setattr( + namespace, + action.dest, + self._get_value(action, action.default), + ) if required_actions: - self.error(_('the following arguments are required: %s') % - ', '.join(required_actions)) + self.error( + _("the following arguments are required: %s") + % ", ".join(required_actions) + ) # make sure all required groups had one option present for group in self._mutually_exclusive_groups: @@ -922,11 +958,13 @@ class OptionalsParser(argparse.ArgumentParser): # if no actions were used, report the error else: - names = [_get_action_name(action) - for action in group._group_actions - if action.help is not SUPPRESS] - msg = _('one of the arguments %s is required') - self.error(msg % ' '.join(names)) + names = [ + _get_action_name(action) + for action in group._group_actions + if action.help is not SUPPRESS + ] + msg = _("one of the arguments %s is required") + self.error(msg % " ".join(names)) # return the updated namespace and the extra arguments return namespace, extras @@ -937,31 +975,49 @@ class CsvActionsParser(argparse.ArgumentParser): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.register('action', 'csv', CommaSeparatedValues) - self.register('action', 'csv_append', CommaSeparatedValuesAppend) - self.register('action', 'csv_negations', CommaSeparatedNegations) - self.register('action', 'csv_negations_append', CommaSeparatedNegationsAppend) - self.register('action', 'csv_elements', CommaSeparatedElements) - self.register('action', 'csv_elements_append', CommaSeparatedElementsAppend) + self.register("action", "csv", CommaSeparatedValues) + self.register("action", "csv_append", CommaSeparatedValuesAppend) + self.register("action", "csv_negations", CommaSeparatedNegations) + self.register("action", "csv_negations_append", CommaSeparatedNegationsAppend) + self.register("action", "csv_elements", CommaSeparatedElements) + self.register("action", "csv_elements_append", CommaSeparatedElementsAppend) class ArgumentParser(OptionalsParser, CsvActionsParser): """Extended, argparse-compatible argument parser.""" - def __init__(self, suppress=False, subcmds=False, color=True, debug=True, quiet=True, - verbose=True, version=True, add_help=True, sorted_help=False, - description=None, docs=None, script=None, prog=None, **kwargs): - self.debug = debug and '--debug' in sys.argv[1:] + def __init__( + self, + suppress=False, + subcmds=False, + color=True, + debug=True, + quiet=True, + verbose=True, + version=True, + add_help=True, + sorted_help=False, + description=None, + docs=None, + script=None, + prog=None, + **kwargs, + ): + self.debug = debug and "--debug" in sys.argv[1:] self.verbosity = int(verbose) if self.verbosity: argv = Counter(sys.argv[1:]) # Only supports single, short opts (i.e. -vv isn't recognized), # post argparsing the proper value supporting those kind of args is # in the options namespace. - self.verbosity = sum(chain.from_iterable(( - (-1 for x in range(argv['-q'] + argv['--quiet'])), - (1 for x in range(argv['-v'] + argv['--verbose'])), - ))) + self.verbosity = sum( + chain.from_iterable( + ( + (-1 for x in range(argv["-q"] + argv["--quiet"])), + (1 for x in range(argv["-v"] + argv["--verbose"])), + ) + ) + ) # subparsers action object from calling add_subparsers() self.__subparsers = None @@ -979,7 +1035,7 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): # usage such as adding conflicting options to both the root command and # subcommands without causing issues in addition to helping support # default subparsers. - self._parents = tuple(kwargs.get('parents', ())) + self._parents = tuple(kwargs.get("parents", ())) # extract the description to use and set docs for doc generation description = self._update_desc(description, docs) @@ -993,11 +1049,12 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): raise TypeError except TypeError: raise ValueError( - "invalid script parameter, should be (__file__, __name__)") + "invalid script parameter, should be (__file__, __name__)" + ) - project = script_module.split('.')[0] + project = script_module.split(".")[0] if prog is None: - prog = script_module.split('.')[-1] + prog = script_module.split(".")[-1] if sorted_help: formatter = SortedHelpFormatter @@ -1005,27 +1062,36 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): formatter = CsvHelpFormatter super().__init__( - description=description, formatter_class=formatter, - prog=prog, add_help=False, **kwargs) + description=description, + formatter_class=formatter, + prog=prog, + add_help=False, + **kwargs, + ) # register custom actions - self.register('action', 'parsers', _SubParser) + self.register("action", "parsers", _SubParser) if not suppress: - base_opts = self.add_argument_group('base options') + base_opts = self.add_argument_group("base options") if add_help: base_opts.add_argument( - '-h', '--help', action=ManHelpAction, default=argparse.SUPPRESS, - help='show this help message and exit', + "-h", + "--help", + action=ManHelpAction, + default=argparse.SUPPRESS, + help="show this help message and exit", docs=""" Show this help message and exit. To get more information see the related man page. - """) + """, + ) if version and script is not None: # Note that this option will currently only be available on the # base command, not on subcommands. base_opts.add_argument( - '--version', action='version', + "--version", + action="version", version=get_version(project, script_path), help="show this program's version info and exit", docs=""" @@ -1034,39 +1100,58 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): When running from within a git repo or a version installed from git the latest commit hash and date will be shown. - """) + """, + ) if debug: base_opts.add_argument( - '--debug', action=EnableDebug, help='enable debugging checks', - docs='Enable debug checks and show verbose debug output.') + "--debug", + action=EnableDebug, + help="enable debugging checks", + docs="Enable debug checks and show verbose debug output.", + ) if quiet: base_opts.add_argument( - '-q', '--quiet', action=Verbosity, dest='verbosity', default=0, - help='suppress non-error messages', - docs="Suppress non-error, informational messages.") + "-q", + "--quiet", + action=Verbosity, + dest="verbosity", + default=0, + help="suppress non-error messages", + docs="Suppress non-error, informational messages.", + ) if verbose: base_opts.add_argument( - '-v', '--verbose', action=Verbosity, dest='verbosity', default=0, - help='show verbose output', - docs="Increase the verbosity of various output.") + "-v", + "--verbose", + action=Verbosity, + dest="verbosity", + default=0, + help="show verbose output", + docs="Increase the verbosity of various output.", + ) if color: base_opts.add_argument( - '--color', action=StoreBool, + "--color", + action=StoreBool, default=sys.stdout.isatty(), - help='enable/disable color support', + help="enable/disable color support", docs=""" Toggle colored output support. This can be used to forcibly enable color support when piping output or other sitations where stdout is not a tty. - """) + """, + ) # register existing subcommands if subcmds: - prefix = f'{prog}.scripts.{prog}_' + prefix = f"{prog}.scripts.{prog}_" if subcmd_modules := [ - name[len(prefix):] for _, name, _ in - pkgutil.walk_packages([os.path.dirname(script_path)], f'{prog}.scripts.') - if name.startswith(prefix)]: + name[len(prefix) :] + for _, name, _ in pkgutil.walk_packages( + [os.path.dirname(script_path)], f"{prog}.scripts." + ) + if name.startswith(prefix) + ]: subparsers = self.add_subparsers() for subcmd in subcmd_modules: subparsers.add_command(subcmd) @@ -1080,7 +1165,7 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): """ description_lines = [] if description is not None: - description_lines = description.strip().split('\n', 1) + description_lines = description.strip().split("\n", 1) description = description_lines[0] if _generate_docs: if docs is None and len(description_lines) == 2: @@ -1156,7 +1241,9 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): try: # run registered early parse functions from all parsers - for functor, parser in chain.from_iterable(x.__early_parse for x in self.parsers): + for functor, parser in chain.from_iterable( + x.__early_parse for x in self.parsers + ): namespace, args = functor(parser, namespace, args) # parse the arguments and exit if there are any errors @@ -1176,7 +1263,7 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): args, unknown_args = self.parse_known_args(args, namespace) # make sure the correct function and prog are set if running a subcommand - subcmd_parser = self.subparsers.get(getattr(args, 'subcommand', None), None) + subcmd_parser = self.subparsers.get(getattr(args, "subcommand", None), None) if subcmd_parser is not None: # override the running program with full subcommand self.prog = subcmd_parser.prog @@ -1186,7 +1273,7 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): namespace.main_func = subcmd_parser.__main_func if unknown_args: - self.error('unrecognized arguments: %s' % ' '.join(unknown_args)) + self.error("unrecognized arguments: %s" % " ".join(unknown_args)) # Two runs are required; first, handle any suppression defaults # introduced. Subparsers defaults cannot override the parent parser, as @@ -1198,14 +1285,20 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): # intentionally no protection of suppression code; this should # just work. - i = ((attr, val) for attr, val in args.__dict__.items() - if isinstance(val, DelayedDefault)) + i = ( + (attr, val) + for attr, val in args.__dict__.items() + if isinstance(val, DelayedDefault) + ) for attr, functor in sorted(i, key=lambda val: val[1].priority): functor(args, attr) # now run the delays - i = ((attr, val) for attr, val in args.__dict__.items() - if isinstance(val, DelayedValue)) + i = ( + (attr, val) + for attr, val in args.__dict__.items() + if isinstance(val, DelayedValue) + ) try: for attr, delayed in sorted(i, key=lambda val: val[1].priority): delayed(args, attr) @@ -1216,7 +1309,9 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): self.error(str(e)) # run final arg validation - final_checks = [k for k in args.__dict__.keys() if k.startswith('__final_check__')] + final_checks = [ + k for k in args.__dict__.keys() if k.startswith("__final_check__") + ] for check in final_checks: functor = args.pop(check) functor(self, args) @@ -1232,7 +1327,7 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): if self.debug and sys.exc_info() != (None, None, None): # output traceback if any exception is on the stack traceback.print_exc() - self.exit(status, '%s: error: %s\n' % (self.prog, message)) + self.exit(status, "%s: error: %s\n" % (self.prog, message)) def bind_main_func(self, functor): """Decorator to set a main function for the parser.""" @@ -1245,8 +1340,8 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): def bind_class(self, obj): if not isinstance(obj, ArgparseCommand): raise ValueError( - "expected obj to be an instance of " - "ArgparseCommand; got %r" % (obj,)) + "expected obj to be an instance of " "ArgparseCommand; got %r" % (obj,) + ) obj.bind_to_parser(self) return self @@ -1261,10 +1356,12 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): """Only run delayed default functor if the attribute isn't set.""" if isinstance(object.__getattribute__(namespace, attr), DelayedValue): functor(namespace, attr) + if name is None: name = functor.__name__ self.set_defaults(**{name: DelayedValue(default, priority)}) return functor + return f def bind_parse_priority(self, priority): @@ -1272,6 +1369,7 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): name = functor.__name__ self.set_defaults(**{name: OrderedParse(functor, priority)}) return functor + return f def add_subparsers(self, **kwargs): @@ -1280,9 +1378,9 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): if self.__subparsers is not None: return self.__subparsers - kwargs.setdefault('title', 'subcommands') - kwargs.setdefault('dest', 'subcommand') - kwargs.setdefault('prog', self.prog) + kwargs.setdefault("title", "subcommands") + kwargs.setdefault("dest", "subcommand") + kwargs.setdefault("prog", self.prog) subparsers = argparse.ArgumentParser.add_subparsers(self, **kwargs) subparsers.required = True self.__subparsers = subparsers @@ -1300,18 +1398,17 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): def bind_final_check(self, functor): """Decorator to bind a function for argument validation.""" - name = f'__final_check__{functor.__name__}' + name = f"__final_check__{functor.__name__}" self.set_defaults(**{name: functor}) return functor class ArgparseCommand: - def bind_to_parser(self, parser): parser.bind_main_func(self) def __call__(self, namespace, out, err): - raise NotImplementedError(self, '__call__') + raise NotImplementedError(self, "__call__") class FileType(argparse.FileType): @@ -1322,11 +1419,11 @@ class FileType(argparse.FileType): def __call__(self, string): # the special argument "-" means sys.std{in,out} - if string == '-': - if 'r' in self._mode: - return sys.stdin.buffer if 'b' in self._mode else sys.stdin - elif any(c in self._mode for c in 'wax'): - return sys.stdout.buffer if 'b' in self._mode else sys.stdout + if string == "-": + if "r" in self._mode: + return sys.stdin.buffer if "b" in self._mode else sys.stdin + elif any(c in self._mode for c in "wax"): + return sys.stdout.buffer if "b" in self._mode else sys.stdout else: msg = _('argument "-" with mode %r') % self._mode raise ValueError(msg) @@ -1342,23 +1439,27 @@ class FileType(argparse.FileType): def existent_path(value): """Check if file argument path exists.""" if not os.path.exists(value): - raise argparse.ArgumentTypeError(f'nonexistent path: {value!r}') + raise argparse.ArgumentTypeError(f"nonexistent path: {value!r}") try: return os.path.realpath(value) except EnvironmentError as e: - raise ValueError(f'while resolving path {value!r}, encountered error: {e}') from e + raise ValueError( + f"while resolving path {value!r}, encountered error: {e}" + ) from e def existent_dir(value): """Check if argument path exists and is a directory.""" if not os.path.exists(value): - raise argparse.ArgumentTypeError(f'nonexistent dir: {value!r}') + raise argparse.ArgumentTypeError(f"nonexistent dir: {value!r}") elif not os.path.isdir(value): - raise argparse.ArgumentTypeError(f'file already exists: {value!r}') + raise argparse.ArgumentTypeError(f"file already exists: {value!r}") try: return os.path.realpath(value) except EnvironmentError as e: - raise ValueError(f'while resolving path {value!r}, encountered error: {e}') from e + raise ValueError( + f"while resolving path {value!r}, encountered error: {e}" + ) from e def create_dir(value): @@ -1367,9 +1468,9 @@ def create_dir(value): try: os.makedirs(path, exist_ok=True) except FileExistsError: - raise argparse.ArgumentTypeError(f'file already exists: {value!r}') + raise argparse.ArgumentTypeError(f"file already exists: {value!r}") except IOError as e: - raise argparse.ArgumentTypeError(f'failed creating dir: {e}') + raise argparse.ArgumentTypeError(f"failed creating dir: {e}") return path @@ -1378,12 +1479,12 @@ def bounded_int(func, desc, x): try: n = int(x) except ValueError: - raise argparse.ArgumentTypeError('invalid integer value') + raise argparse.ArgumentTypeError("invalid integer value") if not func(n): - raise argparse.ArgumentTypeError(f'must be {desc}') + raise argparse.ArgumentTypeError(f"must be {desc}") return n def positive_int(x): """Check if argument is a positive integer.""" - return bounded_int(lambda n: n >= 1, '>= 1', x) + return bounded_int(lambda n: n >= 1, ">= 1", x) diff --git a/src/snakeoil/cli/exceptions.py b/src/snakeoil/cli/exceptions.py index fb0149f..42b8d67 100644 --- a/src/snakeoil/cli/exceptions.py +++ b/src/snakeoil/cli/exceptions.py @@ -11,7 +11,7 @@ class UserException(Exception): self._verbosity = verbosity def msg(self, verbosity=0): - return '' + return "" class ExitException(Exception): @@ -30,6 +30,8 @@ class ExitException(Exception): def find_user_exception(exc): """Find the UserException related to a given exception if one exists.""" try: - return next(e for e in walk_exception_chain(exc) if isinstance(e, UserException)) + return next( + e for e in walk_exception_chain(exc) if isinstance(e, UserException) + ) except StopIteration: return None diff --git a/src/snakeoil/cli/input.py b/src/snakeoil/cli/input.py index a89d61e..9db9ea8 100644 --- a/src/snakeoil/cli/input.py +++ b/src/snakeoil/cli/input.py @@ -41,9 +41,9 @@ def userquery(prompt, out, err, responses=None, default_answer=None, limit=3): """ if responses is None: responses = { - 'yes': (True, out.fg('green'), 'Yes'), - 'no': (False, out.fg('red'), 'No'), - } + "yes": (True, out.fg("green"), "Yes"), + "no": (False, out.fg("red"), "No"), + } if default_answer is None: default_answer = True if default_answer is not None: @@ -52,25 +52,25 @@ def userquery(prompt, out, err, responses=None, default_answer=None, limit=3): default_answer_name = val[1:] break else: - raise ValueError('default answer matches no responses') + raise ValueError("default answer matches no responses") for i in range(limit): # XXX see docstring about crummyness if isinstance(prompt, tuple): out.write(autoline=False, *prompt) else: out.write(prompt, autoline=False) - out.write(' [', autoline=False) + out.write(" [", autoline=False) prompts = list(responses.values()) for choice in prompts[:-1]: out.write(autoline=False, *choice[1:]) - out.write(out.reset, '/', autoline=False) + out.write(out.reset, "/", autoline=False) out.write(autoline=False, *prompts[-1][1:]) - out.write(out.reset, ']', autoline=False) + out.write(out.reset, "]", autoline=False) if default_answer is not None: - out.write(' (default: ', autoline=False) + out.write(" (default: ", autoline=False) out.write(autoline=False, *default_answer_name) - out.write(')', autoline=False) - out.write(': ', autoline=False) + out.write(")", autoline=False) + out.write(": ", autoline=False) try: response = input() except EOFError as e: @@ -83,15 +83,20 @@ def userquery(prompt, out, err, responses=None, default_answer=None, limit=3): raise if not response: return default_answer - results = sorted(set( - (key, value) for key, value in responses.items() - if key[:len(response)].lower() == response.lower())) + results = sorted( + set( + (key, value) + for key, value in responses.items() + if key[: len(response)].lower() == response.lower() + ) + ) if not results: - err.write('Sorry, response %r not understood.' % (response,)) + err.write("Sorry, response %r not understood." % (response,)) elif len(results) > 1: err.write( - 'Response %r is ambiguous (%s)' % - (response, ', '.join(key for key, val in results))) + "Response %r is ambiguous (%s)" + % (response, ", ".join(key for key, val in results)) + ) else: return list(results)[0][1][0] diff --git a/src/snakeoil/cli/tool.py b/src/snakeoil/cli/tool.py index 2a142b7..d00b3cf 100644 --- a/src/snakeoil/cli/tool.py +++ b/src/snakeoil/cli/tool.py @@ -36,14 +36,14 @@ class Tool: if not sys.stdout.isatty() and sys.stdout == sys.__stdout__: # if redirecting/piping stdout use line buffering, skip if # stdout has been set to some non-standard object - outfile = os.fdopen(sys.stdout.fileno(), 'w', 1) + outfile = os.fdopen(sys.stdout.fileno(), "w", 1) else: outfile = sys.stdout if errfile is None: errfile = sys.stderr out_fd = err_fd = None - if hasattr(outfile, 'fileno') and hasattr(errfile, 'fileno'): + if hasattr(outfile, "fileno") and hasattr(errfile, "fileno"): # annoyingly, fileno can exist but through unsupport try: out_fd, err_fd = outfile.fileno(), errfile.fileno() @@ -52,9 +52,11 @@ class Tool: if out_fd is not None and err_fd is not None: out_stat, err_stat = os.fstat(out_fd), os.fstat(err_fd) - if out_stat.st_dev == err_stat.st_dev \ - and out_stat.st_ino == err_stat.st_ino and \ - not errfile.isatty(): + if ( + out_stat.st_dev == err_stat.st_dev + and out_stat.st_ino == err_stat.st_ino + and not errfile.isatty() + ): # they're the same underlying fd. thus # point the handles at the same so we don't # get intermixed buffering issues. @@ -64,7 +66,7 @@ class Tool: self._errfile = errfile self.out = self.parser.out = formatters.PlainTextFormatter(outfile) self.err = self.parser.err = formatters.PlainTextFormatter(errfile) - self.out.verbosity = self.err.verbosity = getattr(self.parser, 'verbosity', 0) + self.out.verbosity = self.err.verbosity = getattr(self.parser, "verbosity", 0) def __call__(self, args=None): """Run the utility. @@ -98,19 +100,21 @@ class Tool: try: self.pre_parse(args, namespace) options = self.parser.parse_args(args=args, namespace=namespace) - main_func = options.pop('main_func', None) + main_func = options.pop("main_func", None) if main_func is None: raise RuntimeError("argparser missing main method") # reconfigure formatters for colored output if enabled - if getattr(options, 'color', True): + if getattr(options, "color", True): formatter_factory = partial( - formatters.get_formatter, force_color=getattr(options, 'color', False)) + formatters.get_formatter, + force_color=getattr(options, "color", False), + ) self.out = formatter_factory(self._outfile) self.err = formatter_factory(self._errfile) # reconfigure formatters with properly parsed output verbosity - self.out.verbosity = self.err.verbosity = getattr(options, 'verbosity', 0) + self.out.verbosity = self.err.verbosity = getattr(options, "verbosity", 0) if logging.root.handlers: # Remove the default handler. @@ -138,13 +142,13 @@ class Tool: exc = find_user_exception(e) if exc is not None: # allow exception attribute to override user verbosity level - if getattr(exc, '_verbosity', None) is not None: + if getattr(exc, "_verbosity", None) is not None: verbosity = exc._verbosity else: - verbosity = getattr(self.parser, 'verbosity', 0) + verbosity = getattr(self.parser, "verbosity", 0) # output verbose error message if it exists if verbosity > 0: - msg = exc.msg(verbosity).strip('\n') + msg = exc.msg(verbosity).strip("\n") if msg: self.err.write(msg) raise SystemExit @@ -166,15 +170,17 @@ class Tool: try: with suppress_warnings: - self.options, func = self.parse_args(args=self.args, namespace=self.options) + self.options, func = self.parse_args( + args=self.args, namespace=self.options + ) exitstatus = func(self.options, self.out, self.err) except SystemExit as e: # handle argparse or other third party modules using sys.exit internally exitstatus = e.code except KeyboardInterrupt: - self._errfile.write('keyboard interrupted- exiting') + self._errfile.write("keyboard interrupted- exiting") if self.parser.debug: - self._errfile.write('\n') + self._errfile.write("\n") traceback.print_exc() signal(SIGINT, SIG_DFL) os.killpg(os.getpgid(0), SIGINT) @@ -187,9 +193,9 @@ class Tool: if self.options is not None: # set terminal title on exit if exitstatus: - self.out.title(f'{self.options.prog} failed') + self.out.title(f"{self.options.prog} failed") else: - self.out.title(f'{self.options.prog} succeeded') + self.out.title(f"{self.options.prog} succeeded") return exitstatus @@ -204,18 +210,25 @@ class FormattingHandler(logging.Handler): def emit(self, record): if record.levelno >= logging.ERROR: - color = 'red' + color = "red" elif record.levelno >= logging.WARNING: - color = 'yellow' + color = "yellow" else: - color = 'cyan' - first_prefix = (self.out.fg(color), self.out.bold, record.levelname, - self.out.reset, ' ', record.name, ': ') - later_prefix = (len(record.levelname) + len(record.name)) * ' ' + ' : ' + color = "cyan" + first_prefix = ( + self.out.fg(color), + self.out.bold, + record.levelname, + self.out.reset, + " ", + record.name, + ": ", + ) + later_prefix = (len(record.levelname) + len(record.name)) * " " + " : " self.out.first_prefix.extend(first_prefix) self.out.later_prefix.append(later_prefix) try: - for line in self.format(record).split('\n'): + for line in self.format(record).split("\n"): self.out.write(line, wrap=True) except Exception: self.handleError(record) diff --git a/src/snakeoil/compatibility.py b/src/snakeoil/compatibility.py index c9a8a54..4fee417 100644 --- a/src/snakeoil/compatibility.py +++ b/src/snakeoil/compatibility.py @@ -8,12 +8,15 @@ __all__ = ("cmp", "sorted_cmp", "sort_cmp") def sorted_key_from_cmp(cmp_func, key_func=None): class _key_proxy: - __slots__ = ('_obj',) + __slots__ = ("_obj",) + + if key_func: # done this way for speed reasons. - if key_func: # done this way for speed reasons. def __init__(self, obj, key_convert=key_func): self._obj = key_convert(obj) + else: + def __init__(self, obj): self._obj = obj @@ -40,13 +43,13 @@ def cmp(obj1, obj2, raw_cmp=_raw_cmp): def sorted_cmp(sequence, func, key=None, reverse=False): - return sorted(sequence, reverse=reverse, - key=sorted_key_from_cmp(func, key_func=key)) + return sorted( + sequence, reverse=reverse, key=sorted_key_from_cmp(func, key_func=key) + ) def sort_cmp(list_inst, func, key=None, reverse=False): - list_inst.sort(reverse=reverse, - key=sorted_key_from_cmp(func, key_func=key)) + list_inst.sort(reverse=reverse, key=sorted_key_from_cmp(func, key_func=key)) IGNORED_EXCEPTIONS = (RuntimeError, MemoryError, SystemExit, KeyboardInterrupt) diff --git a/src/snakeoil/compression/__init__.py b/src/snakeoil/compression/__init__.py index 77c0631..9eee103 100644 --- a/src/snakeoil/compression/__init__.py +++ b/src/snakeoil/compression/__init__.py @@ -9,13 +9,12 @@ from ..process.spawn import spawn_get_output class _transform_source: - def __init__(self, name): self.name = name @cached_property def module(self): - return import_module(f'snakeoil.compression._{self.name}') + return import_module(f"snakeoil.compression._{self.name}") def compress_data(self, data, level, parallelize=False): parallelize = parallelize and self.module.parallelizable @@ -34,7 +33,7 @@ class _transform_source: return self.module.decompress_handle(handle, parallelize=parallelize) -_transforms = {name: _transform_source(name) for name in ('bzip2', 'xz')} +_transforms = {name: _transform_source(name) for name in ("bzip2", "xz")} def compress_data(compressor_type, data, level=9, **kwds): @@ -73,13 +72,13 @@ class ArComp: cls = cls.known_exts[ext] return super(ArComp, cls).__new__(cls) except KeyError: - raise ArCompError(f'unknown compression file extension: {ext!r}') + raise ArCompError(f"unknown compression file extension: {ext!r}") def __init_subclass__(cls, **kwargs): """Initialize result subclasses and register archive extensions.""" super().__init_subclass__(**kwargs) - if not all((cls.binary, cls.default_unpack_cmd, cls.exts)): # pragma: no cover - raise ValueError(f'class missing required attrs: {cls!r}') + if not all((cls.binary, cls.default_unpack_cmd, cls.exts)): # pragma: no cover + raise ValueError(f"class missing required attrs: {cls!r}") for ext in cls.exts: cls.known_exts[ext] = cls @@ -95,9 +94,10 @@ class ArComp: except process.CommandNotFound: continue else: - choices = ', '.join(self.binary) + choices = ", ".join(self.binary) raise ArCompError( - f'required binary not found from the following choices: {choices}') + f"required binary not found from the following choices: {choices}" + ) cmd = self.default_unpack_cmd.format(binary=binary, path=self.path) return cmd @@ -112,7 +112,7 @@ class _Archive: cmd = shlex.split(self._unpack_cmd.format(path=self.path)) ret, output = spawn_get_output(cmd, collect_fds=(2,), **kwargs) if ret: - msg = '\n'.join(output) if output else f'unpacking failed: {self.path!r}' + msg = "\n".join(output) if output else f"unpacking failed: {self.path!r}" raise ArCompError(msg, code=ret) @@ -121,11 +121,12 @@ class _CompressedFile: def unpack(self, dest=None, **kwargs): cmd = shlex.split(self._unpack_cmd.format(path=self.path)) - with open(dest, 'wb') as f: + with open(dest, "wb") as f: ret, output = spawn_get_output( - cmd, collect_fds=(2,), fd_pipes={1: f.fileno()}, **kwargs) + cmd, collect_fds=(2,), fd_pipes={1: f.fileno()}, **kwargs + ) if ret: - msg = '\n'.join(output) if output else f'unpacking failed: {self.path!r}' + msg = "\n".join(output) if output else f"unpacking failed: {self.path!r}" raise ArCompError(msg, code=ret) @@ -134,18 +135,25 @@ class _CompressedStdin: def unpack(self, dest=None, **kwargs): cmd = shlex.split(self._unpack_cmd) - with open(self.path, 'rb') as src, open(dest, 'wb') as f: + with open(self.path, "rb") as src, open(dest, "wb") as f: ret, output = spawn_get_output( - cmd, collect_fds=(2,), fd_pipes={0: src.fileno(), 1: f.fileno()}, **kwargs) + cmd, + collect_fds=(2,), + fd_pipes={0: src.fileno(), 1: f.fileno()}, + **kwargs, + ) if ret: - msg = '\n'.join(output) if output else f'unpacking failed: {self.path!r}' + msg = "\n".join(output) if output else f"unpacking failed: {self.path!r}" raise ArCompError(msg, code=ret) class _Tar(_Archive, ArComp): - exts = frozenset(['.tar']) - binary = ('gtar', 'tar',) + exts = frozenset([".tar"]) + binary = ( + "gtar", + "tar", + ) compress_binary = None default_unpack_cmd = '{binary} xf "{path}"' @@ -162,95 +170,96 @@ class _Tar(_Archive, ArComp): except process.CommandNotFound: pass else: - choices = ', '.join(next(zip(*self.compress_binary))) + choices = ", ".join(next(zip(*self.compress_binary))) raise ArCompError( - 'no compression binary found from the ' - f'following choices: {choices}') + "no compression binary found from the " + f"following choices: {choices}" + ) return cmd class _TarGZ(_Tar): - exts = frozenset(['.tar.gz', '.tgz', '.tar.Z', '.tar.z']) - compress_binary = (('pigz',), ('gzip',)) + exts = frozenset([".tar.gz", ".tgz", ".tar.Z", ".tar.z"]) + compress_binary = (("pigz",), ("gzip",)) class _TarBZ2(_Tar): - exts = frozenset(['.tar.bz2', '.tbz2', '.tbz']) - compress_binary = (('lbzip2',), ('pbzip2',), ('bzip2',)) + exts = frozenset([".tar.bz2", ".tbz2", ".tbz"]) + compress_binary = (("lbzip2",), ("pbzip2",), ("bzip2",)) class _TarLZMA(_Tar): - exts = frozenset(['.tar.lzma']) - compress_binary = (('lzma',)) + exts = frozenset([".tar.lzma"]) + compress_binary = ("lzma",) class _TarXZ(_Tar): - exts = frozenset(['.tar.xz', '.txz']) - compress_binary = (('pixz',), ('xz', f'-T{multiprocessing.cpu_count()}')) + exts = frozenset([".tar.xz", ".txz"]) + compress_binary = (("pixz",), ("xz", f"-T{multiprocessing.cpu_count()}")) class _Zip(_Archive, ArComp): - exts = frozenset(['.ZIP', '.zip', '.jar']) - binary = ('unzip',) + exts = frozenset([".ZIP", ".zip", ".jar"]) + binary = ("unzip",) default_unpack_cmd = '{binary} -qo "{path}"' class _GZ(_CompressedStdin, ArComp): - exts = frozenset(['.gz', '.Z', '.z']) - binary = ('pigz', 'gzip') - default_unpack_cmd = '{binary} -d -c' + exts = frozenset([".gz", ".Z", ".z"]) + binary = ("pigz", "gzip") + default_unpack_cmd = "{binary} -d -c" class _BZ2(_CompressedStdin, ArComp): - exts = frozenset(['.bz2', '.bz']) - binary = ('lbzip2', 'pbzip2', 'bzip2') - default_unpack_cmd = '{binary} -d -c' + exts = frozenset([".bz2", ".bz"]) + binary = ("lbzip2", "pbzip2", "bzip2") + default_unpack_cmd = "{binary} -d -c" class _XZ(_CompressedStdin, ArComp): - exts = frozenset(['.xz']) - binary = ('pixz', 'xz') - default_unpack_cmd = '{binary} -d -c' + exts = frozenset([".xz"]) + binary = ("pixz", "xz") + default_unpack_cmd = "{binary} -d -c" class _7Z(_Archive, ArComp): - exts = frozenset(['.7Z', '.7z']) - binary = ('7z',) + exts = frozenset([".7Z", ".7z"]) + binary = ("7z",) default_unpack_cmd = '{binary} x -y "{path}"' class _Rar(_Archive, ArComp): - exts = frozenset(['.RAR', '.rar']) - binary = ('unrar',) + exts = frozenset([".RAR", ".rar"]) + binary = ("unrar",) default_unpack_cmd = '{binary} x -idq -o+ "{path}"' class _LHA(_Archive, ArComp): - exts = frozenset(['.LHa', '.LHA', '.lha', '.lzh']) - binary = ('lha',) + exts = frozenset([".LHa", ".LHA", ".lha", ".lzh"]) + binary = ("lha",) default_unpack_cmd = '{binary} xfq "{path}"' class _Ar(_Archive, ArComp): - exts = frozenset(['.a', '.deb']) - binary = ('ar',) + exts = frozenset([".a", ".deb"]) + binary = ("ar",) default_unpack_cmd = '{binary} x "{path}"' class _LZMA(_CompressedFile, ArComp): - exts = frozenset(['.lzma']) - binary = ('lzma',) + exts = frozenset([".lzma"]) + binary = ("lzma",) default_unpack_cmd = '{binary} -dc "{path}"' diff --git a/src/snakeoil/compression/_bzip2.py b/src/snakeoil/compression/_bzip2.py index 1a38922..122debd 100644 --- a/src/snakeoil/compression/_bzip2.py +++ b/src/snakeoil/compression/_bzip2.py @@ -25,6 +25,7 @@ bz2_path = process.find_binary("bzip2") try: from bz2 import BZ2File, compress as _compress_data, decompress as _decompress_data + native = True except ImportError: @@ -40,7 +41,7 @@ _decompress_handle = partial(_util.decompress_handle, bz2_path) try: lbzip2_path = process.find_binary("lbzip2") - lbzip2_compress_args = (f'-n{multiprocessing.cpu_count()}', ) + lbzip2_compress_args = (f"-n{multiprocessing.cpu_count()}",) lbzip2_decompress_args = lbzip2_compress_args parallelizable = True except process.CommandNotFound: @@ -51,28 +52,35 @@ except process.CommandNotFound: def compress_data(data, level=9, parallelize=False): if parallelize and parallelizable: - return _util.compress_data(lbzip2_path, data, compresslevel=level, - extra_args=lbzip2_compress_args) + return _util.compress_data( + lbzip2_path, data, compresslevel=level, extra_args=lbzip2_compress_args + ) return _compress_data(data, compresslevel=level) + def decompress_data(data, parallelize=False): if parallelize and parallelizable: - return _util.decompress_data(lbzip2_path, data, - extra_args=lbzip2_decompress_args) + return _util.decompress_data( + lbzip2_path, data, extra_args=lbzip2_decompress_args + ) return _decompress_data(data) + def compress_handle(handle, level=9, parallelize=False): if parallelize and parallelizable: - return _util.compress_handle(lbzip2_path, handle, compresslevel=level, - extra_args=lbzip2_compress_args) + return _util.compress_handle( + lbzip2_path, handle, compresslevel=level, extra_args=lbzip2_compress_args + ) elif native and isinstance(handle, str): - return BZ2File(handle, mode='w', compresslevel=level) + return BZ2File(handle, mode="w", compresslevel=level) return _compress_handle(handle, compresslevel=level) + def decompress_handle(handle, parallelize=False): if parallelize and parallelizable: - return _util.decompress_handle(lbzip2_path, handle, - extra_args=lbzip2_decompress_args) + return _util.decompress_handle( + lbzip2_path, handle, extra_args=lbzip2_decompress_args + ) elif native and isinstance(handle, str): - return BZ2File(handle, mode='r') + return BZ2File(handle, mode="r") return _decompress_handle(handle) diff --git a/src/snakeoil/compression/_util.py b/src/snakeoil/compression/_util.py index e1af5ae..b95d80c 100644 --- a/src/snakeoil/compression/_util.py +++ b/src/snakeoil/compression/_util.py @@ -6,15 +6,20 @@ import subprocess def _drive_process(args, mode, data): - p = subprocess.Popen(args, - stdin=subprocess.PIPE, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, close_fds=True) + p = subprocess.Popen( + args, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + close_fds=True, + ) try: stdout, stderr = p.communicate(data) if p.returncode != 0: - args = ' '.join(args) + args = " ".join(args) raise ValueError( - f"{mode} returned {p.returncode} exitcode from '{args}', stderr={stderr.decode()}") + f"{mode} returned {p.returncode} exitcode from '{args}', stderr={stderr.decode()}" + ) return stdout finally: if p is not None and p.returncode is None: @@ -22,21 +27,20 @@ def _drive_process(args, mode, data): def compress_data(binary, data, compresslevel=9, extra_args=()): - args = [binary, f'-{compresslevel}c'] + args = [binary, f"-{compresslevel}c"] args.extend(extra_args) - return _drive_process(args, 'compression', data) + return _drive_process(args, "compression", data) def decompress_data(binary, data, extra_args=()): - args = [binary, '-dc'] + args = [binary, "-dc"] args.extend(extra_args) - return _drive_process(args, 'decompression', data) + return _drive_process(args, "decompression", data) class _process_handle: - def __init__(self, handle, args, is_read=False): - self.mode = 'rb' if is_read else 'wb' + self.mode = "rb" if is_read else "wb" self.args = tuple(args) self.is_read = is_read @@ -51,9 +55,10 @@ class _process_handle: handle = open(handle, mode=self.mode) close = True elif not isinstance(handle, int): - if not hasattr(handle, 'fileno'): + if not hasattr(handle, "fileno"): raise TypeError( - f"handle {handle!r} isn't a string, integer, and lacks a fileno method") + f"handle {handle!r} isn't a string, integer, and lacks a fileno method" + ) handle = handle.fileno() try: @@ -64,18 +69,17 @@ class _process_handle: def _setup_process(self, handle): self.position = 0 - stderr = open(os.devnull, 'wb') + stderr = open(os.devnull, "wb") kwds = dict(stderr=stderr) if self.is_read: - kwds['stdin'] = handle - kwds['stdout'] = subprocess.PIPE + kwds["stdin"] = handle + kwds["stdout"] = subprocess.PIPE else: - kwds['stdout'] = handle - kwds['stdin'] = subprocess.PIPE + kwds["stdout"] = handle + kwds["stdin"] = subprocess.PIPE try: - self._process = subprocess.Popen( - self.args, close_fds=True, **kwds) + self._process = subprocess.Popen(self.args, close_fds=True, **kwds) finally: stderr.close() @@ -106,7 +110,8 @@ class _process_handle: if self._allow_reopen is None: raise TypeError( f"instance {self} can't do negative seeks: " - f"asked for {position}, was at {self.position}") + f"asked for {position}, was at {self.position}" + ) self._terminate() self._open_handle(self._allow_reopen) return self.seek(position) @@ -130,7 +135,7 @@ class _process_handle: # reallocating it continually; via this usage, we # only slice once the val is less than seek_size; # iow, two allocations worst case. - null_block = '\0' * seek_size + null_block = "\0" * seek_size while val: self.write(null_block[:val]) offset -= val @@ -145,11 +150,13 @@ class _process_handle: raise def close(self): - if not hasattr(self, '_process'): + if not hasattr(self, "_process"): return if self._process.returncode is not None: if self._process.returncode != 0: - raise Exception(f"{self.args} invocation had non zero exit: {self._process.returncode}") + raise Exception( + f"{self.args} invocation had non zero exit: {self._process.returncode}" + ) return self.handle.close() @@ -163,12 +170,12 @@ class _process_handle: def compress_handle(binary_path, handle, compresslevel=9, extra_args=()): - args = [binary_path, f'-{compresslevel}c'] + args = [binary_path, f"-{compresslevel}c"] args.extend(extra_args) return _process_handle(handle, args, False) def decompress_handle(binary_path, handle, extra_args=()): - args = [binary_path, '-dc'] + args = [binary_path, "-dc"] args.extend(extra_args) return _process_handle(handle, args, True) diff --git a/src/snakeoil/compression/_xz.py b/src/snakeoil/compression/_xz.py index 4707737..9a91a4c 100644 --- a/src/snakeoil/compression/_xz.py +++ b/src/snakeoil/compression/_xz.py @@ -21,7 +21,7 @@ from ..compression import _util # if xz can't be found, throw an error. xz_path = process.find_binary("xz") -xz_compress_args = (f'-T{multiprocessing.cpu_count()}',) +xz_compress_args = (f"-T{multiprocessing.cpu_count()}",) xz_decompress_args = xz_compress_args parallelizable = True @@ -29,6 +29,7 @@ try: from lzma import LZMAFile from lzma import compress as _compress_data from lzma import decompress as _decompress_data + native = True except ImportError: @@ -45,30 +46,33 @@ _decompress_handle = partial(_util.decompress_handle, xz_path) def compress_data(data, level=9, parallelize=False): if parallelize and parallelizable: - return _util.compress_data(xz_path, data, compresslevel=level, - extra_args=xz_compress_args) + return _util.compress_data( + xz_path, data, compresslevel=level, extra_args=xz_compress_args + ) if native: return _compress_data(data, preset=level) return _compress_data(data, compresslevel=level) + def decompress_data(data, parallelize=False): if parallelize and parallelizable: - return _util.decompress_data(xz_path, data, - extra_args=xz_decompress_args) + return _util.decompress_data(xz_path, data, extra_args=xz_decompress_args) return _decompress_data(data) + def compress_handle(handle, level=9, parallelize=False): if parallelize and parallelizable: - return _util.compress_handle(xz_path, handle, compresslevel=level, - extra_args=xz_compress_args) + return _util.compress_handle( + xz_path, handle, compresslevel=level, extra_args=xz_compress_args + ) elif native and isinstance(handle, str): - return LZMAFile(handle, mode='w', preset=level) + return LZMAFile(handle, mode="w", preset=level) return _compress_handle(handle, compresslevel=level) + def decompress_handle(handle, parallelize=False): if parallelize and parallelizable: - return _util.decompress_handle(xz_path, handle, - extra_args=xz_decompress_args) - elif (native and isinstance(handle, str)): - return LZMAFile(handle, mode='r') + return _util.decompress_handle(xz_path, handle, extra_args=xz_decompress_args) + elif native and isinstance(handle, str): + return LZMAFile(handle, mode="r") return _decompress_handle(handle) diff --git a/src/snakeoil/constraints.py b/src/snakeoil/constraints.py index 63e4671..c239727 100644 --- a/src/snakeoil/constraints.py +++ b/src/snakeoil/constraints.py @@ -31,8 +31,9 @@ class Constraint(Protocol): domain. :return: ``True`` if the assignment is satisfied. """ + def __call__(self, **kwargs: Any) -> bool: - raise NotImplementedError('Constraint', '__call__') + raise NotImplementedError("Constraint", "__call__") class _Domain(list): @@ -75,10 +76,13 @@ class Problem: of a :py:class:`dict` assigning to each variable in the problem a single value from it's domain. """ + def __init__(self): self.variables: dict[str, _Domain] = {} self.constraints: list[tuple[Constraint, frozenset[str]]] = [] - self.vconstraints: dict[str, list[tuple[Constraint, frozenset[str]]]] = defaultdict(list) + self.vconstraints: dict[ + str, list[tuple[Constraint, frozenset[str]]] + ] = defaultdict(list) def add_variable(self, domain: Iterable[Any], *variables: str): """Add variables to the problem, which use the specified domain. @@ -94,7 +98,9 @@ class Problem: from each domain. """ for variable in variables: - assert variable not in self.variables, f'variable {variable!r} was already added' + assert ( + variable not in self.variables + ), f"variable {variable!r} was already added" self.variables[variable] = _Domain(domain) def add_constraint(self, constraint: Constraint, variables: frozenset[str]): @@ -110,10 +116,15 @@ class Problem: """ self.constraints.append((constraint, variables)) for variable in variables: - assert variable in self.variables, f'unknown variable {variable!r}' + assert variable in self.variables, f"unknown variable {variable!r}" self.vconstraints[variable].append((constraint, variables)) - def __check(self, constraint: Constraint, variables: frozenset[str], assignments: dict[str, Any]) -> bool: + def __check( + self, + constraint: Constraint, + variables: frozenset[str], + assignments: dict[str, Any], + ) -> bool: assignments = {k: v for k, v in assignments.items() if k in variables} unassigned = variables - assignments.keys() if not unassigned: @@ -147,14 +158,17 @@ class Problem: # mix the Degree and Minimum Remaining Values (MRV) heuristics lst = sorted( (-len(self.vconstraints[name]), len(domain), name) - for name, domain in self.variables.items()) + for name, domain in self.variables.items() + ) for _, _, variable in lst: if variable not in assignments: values = self.variables[variable][:] push_domains = tuple( - domain for name, domain in self.variables.items() - if name != variable and name not in assignments) + domain + for name, domain in self.variables.items() + if name != variable and name not in assignments + ) break else: # no unassigned variables, we've got a solution. diff --git a/src/snakeoil/containers.py b/src/snakeoil/containers.py index ba21155..ebfc1b3 100644 --- a/src/snakeoil/containers.py +++ b/src/snakeoil/containers.py @@ -4,8 +4,12 @@ Container classes and functionality for implementing them """ __all__ = ( - "InvertedContains", "SetMixin", "LimitedChangeSet", "Unchangable", - "ProtectedSet", "RefCountingSet" + "InvertedContains", + "SetMixin", + "LimitedChangeSet", + "Unchangable", + "ProtectedSet", + "RefCountingSet", ) from itertools import chain, filterfalse @@ -70,9 +74,11 @@ class SetMixin: @steal_docs(set) def __xor__(self, other, kls=None): - return (kls or self.__class__)(chain( - (x for x in self if x not in other), - (x for x in other if x not in self))) + return (kls or self.__class__)( + chain( + (x for x in self if x not in other), (x for x in other if x not in self) + ) + ) @steal_docs(set) def __rxor__(self, other): @@ -120,8 +126,7 @@ class LimitedChangeSet(SetMixin): def _default_key_validator(val): return val - def __init__(self, initial_keys, unchangable_keys=None, - key_validator=None): + def __init__(self, initial_keys, unchangable_keys=None, key_validator=None): """ :param initial_keys: iterable holding the initial values to set :param unchangable_keys: container holding keys that cannot be changed @@ -185,8 +190,7 @@ class LimitedChangeSet(SetMixin): def rollback(self, point=0): l = self.changes_count() if point < 0 or point > l: - raise TypeError( - "%s point must be >=0 and <= changes_count()" % point) + raise TypeError("%s point must be >=0 and <= changes_count()" % point) while l > point: change, key = self._change_order.pop(-1) self._changed.remove(key) @@ -221,9 +225,8 @@ class LimitedChangeSet(SetMixin): class Unchangable(Exception): - def __init__(self, key): - super().__init__(f'key {key!r} is unchangable') + super().__init__(f"key {key!r} is unchangable") self.key = key @@ -240,6 +243,7 @@ class ProtectedSet(SetMixin): >>> myset.remove(2) >>> assert 2 not in protected """ + def __init__(self, orig_set): self._orig = orig_set self._new = set() diff --git a/src/snakeoil/contexts.py b/src/snakeoil/contexts.py index 57092b4..394574b 100644 --- a/src/snakeoil/contexts.py +++ b/src/snakeoil/contexts.py @@ -41,6 +41,7 @@ from .sequences import predicate_split # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. + class SplitExec: """Context manager separating code execution across parent/child processes. @@ -48,6 +49,7 @@ class SplitExec: of the context are executed only on the forked child. Exceptions are pickled and passed back to the parent. """ + def __init__(self): self.__trace_lock = threading.Lock() self.__orig_sys_trace = None @@ -184,7 +186,7 @@ class SplitExec: @staticmethod def __excepthook(_exc_type, exc_value, exc_traceback): """Output the proper traceback information from the chroot context.""" - if hasattr(exc_value, '__traceback_list__'): + if hasattr(exc_value, "__traceback_list__"): sys.stderr.write(exc_value.__traceback_list__) else: traceback.print_tb(exc_traceback) @@ -253,7 +255,7 @@ class SplitExec: except AttributeError: # an offset of two accounts for this method and its caller frame = inspect.stack(0)[2][0] - while frame.f_locals.get('self') is self: + while frame.f_locals.get("self") is self: frame = frame.f_back self.__frame = frame # pylint: disable=W0201 return frame @@ -262,11 +264,24 @@ class SplitExec: class Namespace(SplitExec): """Context manager that provides Linux namespace support.""" - def __init__(self, mount=False, uts=True, ipc=False, net=False, pid=False, - user=False, hostname=None): + def __init__( + self, + mount=False, + uts=True, + ipc=False, + net=False, + pid=False, + user=False, + hostname=None, + ): self._hostname = hostname self._namespaces = { - 'mount': mount, 'uts': uts, 'ipc': ipc, 'net': net, 'pid': pid, 'user': user, + "mount": mount, + "uts": uts, + "ipc": ipc, + "net": net, + "pid": pid, + "user": user, } super().__init__() @@ -279,8 +294,8 @@ class GitStash(AbstractContextManager): def __init__(self, path, pathspecs=None, staged=False): self.path = path - self.pathspecs = ['--'] + pathspecs if pathspecs else [] - self._staged = ['--keep-index'] if staged else [] + self.pathspecs = ["--"] + pathspecs if pathspecs else [] + self._staged = ["--keep-index"] if staged else [] self._stashed = False def __enter__(self): @@ -288,14 +303,18 @@ class GitStash(AbstractContextManager): # check for untracked or modified/uncommitted files try: p = subprocess.run( - ['git', 'status', '--porcelain=1', '-u'] + self.pathspecs, - stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, - cwd=self.path, encoding='utf8', check=True) + ["git", "status", "--porcelain=1", "-u"] + self.pathspecs, + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + cwd=self.path, + encoding="utf8", + check=True, + ) except subprocess.CalledProcessError: - raise ValueError(f'not a git repo: {self.path}') + raise ValueError(f"not a git repo: {self.path}") # split file changes into unstaged vs staged - unstaged, staged = predicate_split(lambda x: x[1] == ' ', p.stdout.splitlines()) + unstaged, staged = predicate_split(lambda x: x[1] == " ", p.stdout.splitlines()) # don't stash when no relevant changes exist if self._staged: @@ -306,14 +325,18 @@ class GitStash(AbstractContextManager): # stash all existing untracked or modified/uncommitted files try: - stash_cmd = ['git', 'stash', 'push', '-u', '-m', 'pkgcheck scan --commits'] + stash_cmd = ["git", "stash", "push", "-u", "-m", "pkgcheck scan --commits"] subprocess.run( stash_cmd + self._staged + self.pathspecs, - stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, - cwd=self.path, check=True, encoding='utf8') + stdout=subprocess.DEVNULL, + stderr=subprocess.PIPE, + cwd=self.path, + check=True, + encoding="utf8", + ) except subprocess.CalledProcessError as e: error = e.stderr.splitlines()[0] - raise UserException(f'git failed stashing files: {error}') + raise UserException(f"git failed stashing files: {error}") self._stashed = True def __exit__(self, _exc_type, _exc_value, _traceback): @@ -321,12 +344,16 @@ class GitStash(AbstractContextManager): if self._stashed: try: subprocess.run( - ['git', 'stash', 'pop'], - stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, - cwd=self.path, check=True, encoding='utf8') + ["git", "stash", "pop"], + stdout=subprocess.DEVNULL, + stderr=subprocess.PIPE, + cwd=self.path, + check=True, + encoding="utf8", + ) except subprocess.CalledProcessError as e: error = e.stderr.splitlines()[0] - raise UserException(f'git failed applying stash: {error}') + raise UserException(f"git failed applying stash: {error}") @contextmanager @@ -347,7 +374,7 @@ def chdir(path): @contextmanager -def syspath(path: str, condition: bool=True, position: int=0): +def syspath(path: str, condition: bool = True, position: int = 0): """Context manager that mangles ``sys.path`` and then reverts on exit. :param path: The directory path to add to ``sys.path``. @@ -425,6 +452,7 @@ def os_environ(*remove, **update): # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + @contextmanager def patch(target, new): """Simplified module monkey patching via context manager. @@ -434,7 +462,7 @@ def patch(target, new): """ def _import_module(target): - components = target.split('.') + components = target.split(".") import_path = components.pop(0) module = import_module(import_path) for comp in components: @@ -448,16 +476,16 @@ def patch(target, new): def _get_target(target): if isinstance(target, str): try: - module, attr = target.rsplit('.', 1) + module, attr = target.rsplit(".", 1) except (TypeError, ValueError): - raise TypeError(f'invalid target: {target!r}') + raise TypeError(f"invalid target: {target!r}") module = _import_module(module) return module, attr else: try: obj, attr = target except (TypeError, ValueError): - raise TypeError(f'invalid target: {target!r}') + raise TypeError(f"invalid target: {target!r}") return obj, attr obj, attr = _get_target(target) diff --git a/src/snakeoil/currying.py b/src/snakeoil/currying.py index f8d8d97..bc16dc0 100644 --- a/src/snakeoil/currying.py +++ b/src/snakeoil/currying.py @@ -46,14 +46,19 @@ def pre_curry(func, *args, **kwargs): """ if not kwargs: + def callit(*moreargs, **morekwargs): return func(*(args + moreargs), **morekwargs) + elif not args: + def callit(*moreargs, **morekwargs): kw = kwargs.copy() kw.update(morekwargs) return func(*moreargs, **kw) + else: + def callit(*moreargs, **morekwargs): kw = kwargs.copy() kw.update(morekwargs) @@ -67,14 +72,19 @@ def post_curry(func, *args, **kwargs): """passed in args are appended to any further args supplied""" if not kwargs: + def callit(*moreargs, **morekwargs): return func(*(moreargs + args), **morekwargs) + elif not args: + def callit(*moreargs, **morekwargs): kw = morekwargs.copy() kw.update(kwargs) return func(*moreargs, **kw) + else: + def callit(*moreargs, **morekwargs): kw = morekwargs.copy() kw.update(kwargs) @@ -112,18 +122,32 @@ def wrap_exception(recast_exception, *args, **kwds): # set this here so that 2to3 will rewrite it. try: if not issubclass(recast_exception, Exception): - raise ValueError("recast_exception must be an %s derivative: got %r" % - (Exception, recast_exception)) + raise ValueError( + "recast_exception must be an %s derivative: got %r" + % (Exception, recast_exception) + ) except TypeError as e: - raise TypeError("recast_exception must be an %s derivative; got %r, failed %r", - (Exception.__name__, recast_exception, e)) + raise TypeError( + "recast_exception must be an %s derivative; got %r, failed %r", + (Exception.__name__, recast_exception, e), + ) ignores = kwds.pop("ignores", (recast_exception,)) pass_error = kwds.pop("pass_error", None) - return wrap_exception_complex(partial(_simple_throw, recast_exception, args, kwds, pass_error), ignores) - - -def _simple_throw(recast_exception, recast_args, recast_kwds, pass_error, - exception, functor, args, kwds): + return wrap_exception_complex( + partial(_simple_throw, recast_exception, args, kwds, pass_error), ignores + ) + + +def _simple_throw( + recast_exception, + recast_args, + recast_kwds, + pass_error, + exception, + functor, + args, + kwds, +): if pass_error: recast_kwds[pass_error] = exception return recast_exception(*recast_args, **recast_kwds) @@ -131,15 +155,22 @@ def _simple_throw(recast_exception, recast_args, recast_kwds, pass_error, def wrap_exception_complex(creation_func, ignores): try: - if not hasattr(ignores, '__iter__') and issubclass(ignores, Exception) or ignores is Exception: + if ( + not hasattr(ignores, "__iter__") + and issubclass(ignores, Exception) + or ignores is Exception + ): ignores = (ignores,) ignores = tuple(ignores) except TypeError as e: - raise TypeError("ignores must be either a tuple of %s, or a %s: got %r, error %r" - % (Exception.__name__, Exception.__name__, ignores, e)) + raise TypeError( + "ignores must be either a tuple of %s, or a %s: got %r, error %r" + % (Exception.__name__, Exception.__name__, ignores, e) + ) if not all(issubclass(x, Exception) for x in ignores): - raise TypeError("ignores has a non %s derivative in it: %r" % - (Exception.__name__, ignores)) + raise TypeError( + "ignores has a non %s derivative in it: %r" % (Exception.__name__, ignores) + ) return partial(_inner_wrap_exception, creation_func, ignores) @@ -153,5 +184,6 @@ def _inner_wrap_exception(exception_maker, ignores, functor): raise except Exception as e: raise exception_maker(e, functor, args, kwargs) from e + _wrap_exception.func = functor return pretty_docs(_wrap_exception, name=functor.__name__) diff --git a/src/snakeoil/data_source.py b/src/snakeoil/data_source.py index 1faa540..3b0ccb1 100644 --- a/src/snakeoil/data_source.py +++ b/src/snakeoil/data_source.py @@ -33,8 +33,13 @@ we caught the exception. """ __all__ = ( - "base", "bz2_source", "data_source", "local_source", "text_data_source", - "bytes_data_source", "invokable_data_source", + "base", + "bz2_source", + "data_source", + "local_source", + "text_data_source", + "bytes_data_source", + "invokable_data_source", ) import errno @@ -62,10 +67,9 @@ def _mk_writable_cls(base, name): exceptions attribute """ - base_cls = base exceptions = (MemoryError,) - __slots__ = ('_callback',) + __slots__ = ("_callback",) def __init__(self, callback, data): """ @@ -85,6 +89,7 @@ def _mk_writable_cls(base, name): self._callback(self.read()) self._callback = None self.base_cls.close(self) + kls.__name__ = name return kls @@ -100,6 +105,7 @@ class text_ro_StringIO(stringio.text_readonly): Specifically this adds the necessary `exceptions` attribute; see :py:class:`snakeoil.stringio.text_readonly` for methods details. """ + __slots__ = () exceptions = (MemoryError, TypeError) @@ -111,6 +117,7 @@ class bytes_ro_StringIO(stringio.bytes_readonly): Specifically this adds the necessary `exceptions` attribute; see :py:class:`snakeoil.stringio.bytes_readonly` for methods details. """ + __slots__ = () exceptions = (MemoryError, TypeError) @@ -131,6 +138,7 @@ class base: :ivar path: If None, no local path is available- else it's the ondisk path to the data """ + __slots__ = ("weakref",) path = None @@ -155,7 +163,8 @@ class base: def transfer_to_path(self, path): return self.transfer_to_data_source( - local_source(path, mutable=True, encoding=None)) + local_source(path, mutable=True, encoding=None) + ) def transfer_to_data_source(self, write_source): read_f, m, write_f = None, None, None @@ -208,31 +217,32 @@ class local_source(base): raise TypeError("data source %s is immutable" % (self,)) if self.encoding: opener = open_file - opener = post_curry(opener, buffering=self.buffering_window, - encoding=self.encoding) + opener = post_curry( + opener, buffering=self.buffering_window, encoding=self.encoding + ) else: opener = post_curry(open_file, self.buffering_window) if not writable: - return opener(self.path, 'r') + return opener(self.path, "r") try: return opener(self.path, "r+") except IOError as ie: if ie.errno != errno.ENOENT: raise - return opener(self.path, 'w+') + return opener(self.path, "w+") @klass.steal_docs(base) def bytes_fileobj(self, writable=False): if not writable: - return open_file(self.path, 'rb', self.buffering_window) + return open_file(self.path, "rb", self.buffering_window) if not self.mutable: raise TypeError("data source %s is immutable" % (self,)) try: - return open_file(self.path, 'rb+', self.buffering_window) + return open_file(self.path, "rb+", self.buffering_window) except IOError as ie: if ie.errno != errno.ENOENT: raise - return open_file(self.path, 'wb+', self.buffering_window) + return open_file(self.path, "wb+", self.buffering_window) class bz2_source(base): @@ -255,7 +265,8 @@ class bz2_source(base): def text_fileobj(self, writable=False): data = compression.decompress_data( - 'bzip2', fileutils.readfile_bytes(self.path)).decode() + "bzip2", fileutils.readfile_bytes(self.path) + ).decode() if writable: if not self.mutable: raise TypeError(f"data source {self} is not mutable") @@ -263,8 +274,7 @@ class bz2_source(base): return text_ro_StringIO(data) def bytes_fileobj(self, writable=False): - data = compression.decompress_data( - 'bzip2', fileutils.readfile_bytes(self.path)) + data = compression.decompress_data("bzip2", fileutils.readfile_bytes(self.path)) if writable: if not self.mutable: raise TypeError(f"data source {self} is not mutable") @@ -275,7 +285,7 @@ class bz2_source(base): if isinstance(data, str): data = data.encode() with open(self.path, "wb") as f: - f.write(compression.compress_data('bzip2', data)) + f.write(compression.compress_data("bzip2", data)) class data_source(base): @@ -293,7 +303,7 @@ class data_source(base): :ivar path: note that path is None for this class- no on disk location available. """ - __slots__ = ('data', 'mutable') + __slots__ = ("data", "mutable") def __init__(self, data, mutable=False): """ @@ -305,7 +315,7 @@ class data_source(base): self.mutable = mutable def _convert_data(self, mode): - if mode == 'bytes': + if mode == "bytes": if isinstance(self.data, bytes): return self.data return self.data.encode() @@ -318,9 +328,8 @@ class data_source(base): if writable: if not self.mutable: raise TypeError(f"data source {self} is not mutable") - return text_wr_StringIO(self._reset_data, - self._convert_data('text')) - return text_ro_StringIO(self._convert_data('text')) + return text_wr_StringIO(self._reset_data, self._convert_data("text")) + return text_ro_StringIO(self._convert_data("text")) def _reset_data(self, data): if isinstance(self.data, bytes): @@ -335,9 +344,8 @@ class data_source(base): if writable: if not self.mutable: raise TypeError(f"data source {self} is not mutable") - return bytes_wr_StringIO(self._reset_data, - self._convert_data('bytes')) - return bytes_ro_StringIO(self._convert_data('bytes')) + return bytes_wr_StringIO(self._reset_data, self._convert_data("bytes")) + return bytes_ro_StringIO(self._convert_data("bytes")) class text_data_source(data_source): @@ -355,7 +363,7 @@ class text_data_source(data_source): data_source.__init__(self, data, mutable=mutable) def _convert_data(self, mode): - if mode != 'bytes': + if mode != "bytes": return self.data return self.data.encode() @@ -375,7 +383,7 @@ class bytes_data_source(data_source): data_source.__init__(self, data, mutable=mutable) def _convert_data(self, mode): - if mode == 'bytes': + if mode == "bytes": return self.data return self.data.decode() @@ -390,6 +398,7 @@ class invokable_data_source(data_source): Note that this instance is explicitly readonly. """ + __slots__ = () def __init__(self, data): @@ -412,7 +421,9 @@ class invokable_data_source(data_source): return self.data(False) @classmethod - def wrap_function(cls, invokable, returns_text=True, returns_handle=False, encoding_hint=None): + def wrap_function( + cls, invokable, returns_text=True, returns_handle=False, encoding_hint=None + ): """ Helper function to automatically convert a function that returns text or bytes into appropriate callable @@ -425,10 +436,20 @@ class invokable_data_source(data_source): :param encoding_hint: the preferred encoding to use for encoding :return: invokable_data_source instance """ - return cls(partial(cls._simple_wrapper, invokable, encoding_hint, returns_text, returns_handle)) + return cls( + partial( + cls._simple_wrapper, + invokable, + encoding_hint, + returns_text, + returns_handle, + ) + ) @staticmethod - def _simple_wrapper(invokable, encoding_hint, returns_text, returns_handle, text_wanted): + def _simple_wrapper( + invokable, encoding_hint, returns_text, returns_handle, text_wanted + ): data = invokable() if returns_text != text_wanted: if text_wanted: @@ -446,7 +467,7 @@ class invokable_data_source(data_source): data = data.read() if encoding_hint is None: # fallback to utf8 - encoding_hint = 'utf8' + encoding_hint = "utf8" data = data.encode(encoding_hint) elif returns_handle: return data diff --git a/src/snakeoil/decorators.py b/src/snakeoil/decorators.py index 550d19c..c0e0429 100644 --- a/src/snakeoil/decorators.py +++ b/src/snakeoil/decorators.py @@ -7,29 +7,36 @@ from .contexts import Namespace, SplitExec def splitexec(func): """Run the decorated function in another process.""" + @wraps(func) def wrapper(*args, **kwargs): with SplitExec(): return func(*args, **kwargs) + return wrapper def namespace(**namespaces): """Run the decorated function in a specified namespace.""" + def decorator(func): @wraps(func) def wrapper(*args, **kwargs): with Namespace(**namespaces): return func(*args, **kwargs) + return wrapper + return decorator def coroutine(func): """Prime a coroutine for input.""" + @wraps(func) def prime(*args, **kwargs): cr = func(*args, **kwargs) next(cr) return cr + return prime diff --git a/src/snakeoil/demandimport.py b/src/snakeoil/demandimport.py index c87e114..0a24ccc 100644 --- a/src/snakeoil/demandimport.py +++ b/src/snakeoil/demandimport.py @@ -13,14 +13,16 @@ from importlib.util import LazyLoader _disabled = False # modules that have issues when lazily imported -_skip = frozenset([ - '__builtin__', - '__future__', - 'builtins', - 'grp', - 'pwd', - 'OpenSSL.SSL', # pyopenssl -]) +_skip = frozenset( + [ + "__builtin__", + "__future__", + "builtins", + "grp", + "pwd", + "OpenSSL.SSL", # pyopenssl + ] +) class _LazyLoader(LazyLoader): @@ -35,10 +37,8 @@ class _LazyLoader(LazyLoader): # custom loaders using our extended LazyLoader -_extensions_loader = _LazyLoader.factory( - importlib.machinery.ExtensionFileLoader) -_bytecode_loader = _LazyLoader.factory( - importlib.machinery.SourcelessFileLoader) +_extensions_loader = _LazyLoader.factory(importlib.machinery.ExtensionFileLoader) +_bytecode_loader = _LazyLoader.factory(importlib.machinery.SourcelessFileLoader) _source_loader = _LazyLoader.factory(importlib.machinery.SourceFileLoader) @@ -54,7 +54,11 @@ def _filefinder(path): def enable(): """Enable lazy loading for all future module imports.""" - if os.environ.get('SNAKEOIL_DEMANDIMPORT', 'y').lower() not in ('n', 'no' '0', 'false'): + if os.environ.get("SNAKEOIL_DEMANDIMPORT", "y").lower() not in ( + "n", + "no" "0", + "false", + ): sys.path_hooks.insert(0, _filefinder) diff --git a/src/snakeoil/demandload.py b/src/snakeoil/demandload.py index 0c84e8c..3800622 100644 --- a/src/snakeoil/demandload.py +++ b/src/snakeoil/demandload.py @@ -49,8 +49,9 @@ from .modules import load_any # There are some demandloaded imports below the definition of demandload. -_allowed_chars = "".join((x.isalnum() or x in "_.") and " " or "a" - for x in map(chr, range(256))) +_allowed_chars = "".join( + (x.isalnum() or x in "_.") and " " or "a" for x in map(chr, range(256)) +) def parse_imports(imports): @@ -72,15 +73,16 @@ def parse_imports(imports): :rtype: iterable of tuples of two C{str} objects. """ for s in imports: - fromlist = s.split(':', 1) + fromlist = s.split(":", 1) if len(fromlist) == 1: # Not a "from" import. - if '.' in s: + if "." in s: raise ValueError( "dotted imports are disallowed; see " "snakeoil.demandload docstring for " - f"details; {s!r}") - split = s.split('@', 1) + f"details; {s!r}" + ) + split = s.split("@", 1) for s in split: if not s.translate(_allowed_chars).isspace(): raise ValueError(f"bad target: {s}") @@ -94,28 +96,33 @@ def parse_imports(imports): base, targets = fromlist if not base.translate(_allowed_chars).isspace(): raise ValueError(f"bad target: {base}") - for target in targets.split(','): - split = target.split('@', 1) + for target in targets.split(","): + split = target.split("@", 1) for s in split: if not s.translate(_allowed_chars).isspace(): raise ValueError(f"bad target: {s}") - yield base + '.' + split[0], split[-1] + yield base + "." + split[0], split[-1] + def _protection_enabled_disabled(): return False + def _noisy_protection_disabled(): return False + def _protection_enabled_enabled(): val = os.environ.get("SNAKEOIL_DEMANDLOAD_PROTECTION", "n").lower() return val in ("yes", "true", "1", "y") + def _noisy_protection_enabled(): val = os.environ.get("SNAKEOIL_DEMANDLOAD_WARN", "y").lower() return val in ("yes", "true", "1", "y") -if 'pydoc' in sys.modules or 'epydoc' in sys.modules: + +if "pydoc" in sys.modules or "epydoc" in sys.modules: _protection_enabled = _protection_enabled_disabled _noisy_protection = _noisy_protection_disabled else: @@ -164,15 +171,15 @@ class Placeholder: """ if not callable(load_func): raise TypeError(f"load_func must be callable; got {load_func!r}") - object.__setattr__(self, '_scope', scope) - object.__setattr__(self, '_name', name) - object.__setattr__(self, '_replacing_tids', []) - object.__setattr__(self, '_load_func', load_func) - object.__setattr__(self, '_loading_lock', threading.Lock()) + object.__setattr__(self, "_scope", scope) + object.__setattr__(self, "_name", name) + object.__setattr__(self, "_replacing_tids", []) + object.__setattr__(self, "_load_func", load_func) + object.__setattr__(self, "_loading_lock", threading.Lock()) def _target_already_loaded(self, complain=True): - name = object.__getattribute__(self, '_name') - scope = object.__getattribute__(self, '_scope') + name = object.__getattribute__(self, "_name") + scope = object.__getattribute__(self, "_scope") # in a threaded environment, it's possible for tid1 to get the # placeholder from globals, python switches to tid2, which triggers @@ -188,13 +195,16 @@ class Placeholder: # it's impossible for this pathway to accidentally be triggered twice- # meaning it is a misuse by the consuming client code. if complain: - tids_to_complain_about = object.__getattribute__(self, '_replacing_tids') + tids_to_complain_about = object.__getattribute__(self, "_replacing_tids") if threading.current_thread().ident in tids_to_complain_about: if _protection_enabled(): - raise ValueError(f'Placeholder for {name!r} was triggered twice') + raise ValueError(f"Placeholder for {name!r} was triggered twice") elif _noisy_protection(): - logging.warning('Placeholder for %r was triggered multiple times ' - 'in file %r', name, scope.get("__file__", "unknown")) + logging.warning( + "Placeholder for %r was triggered multiple times " "in file %r", + name, + scope.get("__file__", "unknown"), + ) return scope[name] def _get_target(self): @@ -202,9 +212,9 @@ class Placeholder: :return: the result of calling C{_load_func}. """ - preloaded_func = object.__getattribute__(self, '_target_already_loaded') - with object.__getattribute__(self, '_loading_lock'): - load_func = object.__getattribute__(self, '_load_func') + preloaded_func = object.__getattribute__(self, "_target_already_loaded") + with object.__getattribute__(self, "_loading_lock"): + load_func = object.__getattribute__(self, "_load_func") if load_func is None: # This means that there was contention; two threads made it into # _get_target. That's fine; suppress complaints, and return the @@ -215,18 +225,17 @@ class Placeholder: # fix the scope, and replace this method with one that shortcircuits # (and appropriately complains) the lookup. result = load_func() - scope = object.__getattribute__(self, '_scope') - name = object.__getattribute__(self, '_name') + scope = object.__getattribute__(self, "_scope") + name = object.__getattribute__(self, "_name") scope[name] = result # Replace this method with the fast path/preloaded one; this # is to ensure complaints get leveled if needed. - object.__setattr__(self, '_get_target', preloaded_func) - object.__setattr__(self, '_load_func', None) - + object.__setattr__(self, "_get_target", preloaded_func) + object.__setattr__(self, "_load_func", None) # note this step *has* to follow scope modification; else it # will go maximum depth recursion. - tids = object.__getattribute__(self, '_replacing_tids') + tids = object.__getattribute__(self, "_replacing_tids") tids.append(threading.current_thread().ident) return result @@ -237,18 +246,18 @@ class Placeholder: # Various methods proxied to our replacement. def __str__(self): - return self.__getattribute__('__str__')() + return self.__getattribute__("__str__")() def __getattribute__(self, attr): - result = object.__getattribute__(self, '_get_target')() + result = object.__getattribute__(self, "_get_target")() return getattr(result, attr) def __setattr__(self, attr, value): - result = object.__getattribute__(self, '_get_target')() + result = object.__getattribute__(self, "_get_target")() setattr(result, attr, value) def __call__(self, *args, **kwargs): - result = object.__getattribute__(self, '_get_target')() + result = object.__getattribute__(self, "_get_target")() return result(*args, **kwargs) @@ -267,7 +276,7 @@ def demandload(*imports, **kwargs): """ # pull the caller's global namespace if undefined - scope = kwargs.pop('scope', sys._getframe(1).f_globals) + scope = kwargs.pop("scope", sys._getframe(1).f_globals) for source, target in parse_imports(imports): scope[target] = Placeholder.load_namespace(scope, target, source) @@ -280,7 +289,7 @@ enabled_demandload = demandload def disabled_demandload(*imports, **kwargs): """Exactly like :py:func:`demandload` but does all imports immediately.""" - scope = kwargs.pop('scope', sys._getframe(1).f_globals) + scope = kwargs.pop("scope", sys._getframe(1).f_globals) for source, target in parse_imports(imports): scope[target] = load_any(source) @@ -292,21 +301,25 @@ def demand_compile_regexp(name, *args, **kwargs): :param name: the name of the compiled re object in that scope. """ - scope = kwargs.pop('scope', sys._getframe(1).f_globals) + scope = kwargs.pop("scope", sys._getframe(1).f_globals) scope[name] = Placeholder.load_regex(scope, name, *args, **kwargs) def disabled_demand_compile_regexp(name, *args, **kwargs): """Exactly like :py:func:`demand_compile_regexp` but does all imports immediately.""" - scope = kwargs.pop('scope', sys._getframe(1).f_globals) + scope = kwargs.pop("scope", sys._getframe(1).f_globals) scope[name] = re.compile(*args, **kwargs) -if os.environ.get("SNAKEOIL_DEMANDLOAD_DISABLED", 'n').lower() in ('y', 'yes' '1', 'true'): +if os.environ.get("SNAKEOIL_DEMANDLOAD_DISABLED", "n").lower() in ( + "y", + "yes" "1", + "true", +): demandload = disabled_demandload demand_compile_regexp = disabled_demand_compile_regexp demandload( - 'logging', - 're', + "logging", + "re", ) diff --git a/src/snakeoil/dependant_methods.py b/src/snakeoil/dependant_methods.py index 551af1e..031ad63 100644 --- a/src/snakeoil/dependant_methods.py +++ b/src/snakeoil/dependant_methods.py @@ -63,7 +63,7 @@ def _ensure_deps(cls_id, name, func, self, *a, **kw): s = _yield_deps(self, self.stage_depends, name) r = True - if not hasattr(self, '_stage_state'): + if not hasattr(self, "_stage_state"): self._stage_state = set() for dep in s: if dep not in self._stage_state: @@ -108,8 +108,8 @@ def __wrap_stage_dependencies__(cls): f = getattr(cls, x) except AttributeError: raise TypeError( - "class %r stage_depends specifies %r, which doesn't exist" % - (cls, x)) + "class %r stage_depends specifies %r, which doesn't exist" % (cls, x) + ) f2 = pre_curry(_ensure_deps, cls_id, x, f) f2.sd_raw_func = f setattr(cls, x, f2) @@ -122,9 +122,9 @@ def __unwrap_stage_dependencies__(cls): f = getattr(cls, x) except AttributeError: raise TypeError( - "class %r stage_depends specifies %r, which doesn't exist" % - (cls, x)) - setattr(cls, x, getattr(f, 'sd_raw_func', f)) + "class %r stage_depends specifies %r, which doesn't exist" % (cls, x) + ) + setattr(cls, x, getattr(f, "sd_raw_func", f)) def __set_stage_state__(self, state): @@ -165,17 +165,17 @@ class ForcedDepends(type): def __new__(cls, name, bases, d): obj = super(ForcedDepends, cls).__new__(cls, name, bases, d) - if not hasattr(obj, 'stage_depends'): + if not hasattr(obj, "stage_depends"): obj.stage_depends = {} for x in ("wrap", "unwrap"): - s = '__%s_stage_dependencies__' % x + s = "__%s_stage_dependencies__" % x if not hasattr(obj, s): setattr(obj, s, classmethod(globals()[s])) obj.__unwrap_stage_dependencies__() obj.__wrap_stage_dependencies__() - if not hasattr(obj, '__force_stage_state__'): + if not hasattr(obj, "__force_stage_state__"): obj.__set_stage_state__ = __set_stage_state__ - if not hasattr(obj, '__stage_step_callback__'): + if not hasattr(obj, "__stage_step_callback__"): obj.__stage_step_callback__ = __stage_step_callback__ return obj diff --git a/src/snakeoil/errors.py b/src/snakeoil/errors.py index 9b7b541..ab37348 100644 --- a/src/snakeoil/errors.py +++ b/src/snakeoil/errors.py @@ -12,20 +12,20 @@ def walk_exception_chain(exc, ignore_first=False, reverse=False): def _inner_walk_exception_chain(exc, ignore_first): if not ignore_first: yield exc - exc = getattr(exc, '__cause__', None) + exc = getattr(exc, "__cause__", None) while exc is not None: yield exc - exc = getattr(exc, '__cause__', None) + exc = getattr(exc, "__cause__", None) def dump_error(raw_exc, msg=None, handle=sys.stderr, tb=None): # force default output for exceptions - if getattr(handle, 'reset', False): + if getattr(handle, "reset", False): handle.write(handle.reset) - prefix = '' + prefix = "" if msg: - prefix = ' ' + prefix = " " handle.write(msg.rstrip("\n") + ":\n") if tb: handle.write("Traceback follows:\n") @@ -34,8 +34,8 @@ def dump_error(raw_exc, msg=None, handle=sys.stderr, tb=None): if raw_exc is not None: for exc in walk_exception_chain(raw_exc): exc_strings.extend( - prefix + x.strip() - for x in (x for x in str(exc).split("\n") if x)) + prefix + x.strip() for x in (x for x in str(exc).split("\n") if x) + ) if exc_strings: if msg and tb: handle.write(f"\n{raw_exc.__class__.__name__}:\n") diff --git a/src/snakeoil/fileutils.py b/src/snakeoil/fileutils.py index ed3a034..b0aa724 100644 --- a/src/snakeoil/fileutils.py +++ b/src/snakeoil/fileutils.py @@ -23,23 +23,28 @@ def touch(fname, mode=0o644, **kwargs): See os.utime for other supported arguments. """ flags = os.O_CREAT | os.O_APPEND - dir_fd = kwargs.get('dir_fd', None) + dir_fd = kwargs.get("dir_fd", None) os_open = partial(os.open, dir_fd=dir_fd) with os.fdopen(os_open(fname, flags, mode)) as f: os.utime( f.fileno() if os.utime in os.supports_fd else fname, - dir_fd=None if os.supports_fd else dir_fd, **kwargs) + dir_fd=None if os.supports_fd else dir_fd, + **kwargs + ) + def mmap_or_open_for_read(path): size = os.stat(path).st_size if size == 0: - return (None, data_source.bytes_ro_StringIO(b'')) + return (None, data_source.bytes_ro_StringIO(b"")) fd = None try: fd = os.open(path, os.O_RDONLY) - return (_fileutils.mmap_and_close( - fd, size, mmap.MAP_SHARED, mmap.PROT_READ), None) + return ( + _fileutils.mmap_and_close(fd, size, mmap.MAP_SHARED, mmap.PROT_READ), + None, + ) except IGNORED_EXCEPTIONS: raise except: @@ -85,7 +90,8 @@ class AtomicWriteFile_mixin: fp = os.path.realpath(fp) self._original_fp = fp self._temp_fp = os.path.join( - os.path.dirname(fp), ".update." + os.path.basename(fp)) + os.path.dirname(fp), ".update." + os.path.basename(fp) + ) old_umask = None if perms: # give it just write perms @@ -140,7 +146,7 @@ class AtomicWriteFile(AtomicWriteFile_mixin): self.raw = open(self._temp_fp, mode=self._computed_mode) def _real_close(self): - if hasattr(self, 'raw'): + if hasattr(self, "raw"): return self.raw.close() return None @@ -149,24 +155,23 @@ class AtomicWriteFile(AtomicWriteFile_mixin): def _mk_pretty_derived_func(func, name_base, name, *args, **kwds): if name: - name = '_' + name - return pretty_docs(partial(func, *args, **kwds), - name='%s%s' % (name_base, name)) + name = "_" + name + return pretty_docs(partial(func, *args, **kwds), name="%s%s" % (name_base, name)) -_mk_readfile = partial( - _mk_pretty_derived_func, _fileutils.native_readfile, 'readfile') +_mk_readfile = partial(_mk_pretty_derived_func, _fileutils.native_readfile, "readfile") -readfile_ascii = _mk_readfile('ascii', 'rt') -readfile_bytes = _mk_readfile('bytes', 'rb') -readfile_utf8 = _mk_readfile('utf8', 'r', encoding='utf8') +readfile_ascii = _mk_readfile("ascii", "rt") +readfile_bytes = _mk_readfile("bytes", "rb") +readfile_utf8 = _mk_readfile("utf8", "r", encoding="utf8") readfile = readfile_utf8 _mk_readlines = partial( - _mk_pretty_derived_func, _fileutils.native_readlines, 'readlines') + _mk_pretty_derived_func, _fileutils.native_readlines, "readlines" +) -readlines_ascii = _mk_readlines('ascii', 'r', encoding='ascii') -readlines_bytes = _mk_readlines('bytes', 'rb') -readlines_utf8 = _mk_readlines('utf8', 'r', encoding='utf8') +readlines_ascii = _mk_readlines("ascii", "r", encoding="ascii") +readlines_bytes = _mk_readlines("bytes", "rb") +readlines_utf8 = _mk_readlines("utf8", "r", encoding="utf8") readlines = readlines_utf8 diff --git a/src/snakeoil/formatters.py b/src/snakeoil/formatters.py index 397667f..3faf661 100644 --- a/src/snakeoil/formatters.py +++ b/src/snakeoil/formatters.py @@ -10,7 +10,9 @@ from .klass import GetAttrProxy, steal_docs from .mappings import defaultdictkey __all__ = ( - "Formatter", "PlainTextFormatter", "get_formatter", + "Formatter", + "PlainTextFormatter", + "get_formatter", "decorate_forced_wrapping", ) @@ -98,13 +100,11 @@ class Formatter: def error(self, message): """Format a string as an error message.""" - self.write(message, prefixes=( - self.fg('red'), self.bold, '!!! ', self.reset)) + self.write(message, prefixes=(self.fg("red"), self.bold, "!!! ", self.reset)) def warn(self, message): """Format a string as a warning message.""" - self.write(message, prefixes=( - self.fg('yellow'), self.bold, '*** ', self.reset)) + self.write(message, prefixes=(self.fg("yellow"), self.bold, "*** ", self.reset)) def title(self, string): """Set the title to string""" @@ -123,7 +123,7 @@ class PlainTextFormatter(Formatter): every write. """ - bold = underline = reset = '' + bold = underline = reset = "" def __init__(self, stream, width=79, encoding=None): """Initialize. @@ -144,12 +144,12 @@ class PlainTextFormatter(Formatter): else: self.stream = stream if encoding is None: - encoding = getattr(self.stream, 'encoding', None) + encoding = getattr(self.stream, "encoding", None) if encoding is None: try: encoding = locale.getpreferredencoding() except locale.Error: - encoding = 'ascii' + encoding = "ascii" self.encoding = encoding self.width = width self._pos = 0 @@ -162,7 +162,7 @@ class PlainTextFormatter(Formatter): return True def _force_encoding(self, val): - return val.encode(self.encoding, 'replace') + return val.encode(self.encoding, "replace") def _write_prefix(self, wrap): if self._in_first_line: @@ -190,34 +190,32 @@ class PlainTextFormatter(Formatter): @steal_docs(Formatter) def write(self, *args, **kwargs): - wrap = kwargs.get('wrap', self.wrap) - autoline = kwargs.get('autoline', self.autoline) - prefixes = kwargs.get('prefixes') - first_prefixes = kwargs.get('first_prefixes') - later_prefixes = kwargs.get('later_prefixes') + wrap = kwargs.get("wrap", self.wrap) + autoline = kwargs.get("autoline", self.autoline) + prefixes = kwargs.get("prefixes") + first_prefixes = kwargs.get("first_prefixes") + later_prefixes = kwargs.get("later_prefixes") if prefixes is not None: if first_prefixes is not None or later_prefixes is not None: raise TypeError( - 'do not pass first_prefixes or later_prefixes ' - 'if prefixes is passed') + "do not pass first_prefixes or later_prefixes " + "if prefixes is passed" + ) first_prefixes = later_prefixes = prefixes - prefix = kwargs.get('prefix') - first_prefix = kwargs.get('first_prefix') - later_prefix = kwargs.get('later_prefix') + prefix = kwargs.get("prefix") + first_prefix = kwargs.get("first_prefix") + later_prefix = kwargs.get("later_prefix") if prefix is not None: if first_prefix is not None or later_prefix is not None: - raise TypeError( - 'do not pass first_prefix or later_prefix with prefix') + raise TypeError("do not pass first_prefix or later_prefix with prefix") first_prefix = later_prefix = prefix if first_prefix is not None: if first_prefixes is not None: - raise TypeError( - 'do not pass both first_prefix and first_prefixes') + raise TypeError("do not pass both first_prefix and first_prefixes") first_prefixes = (first_prefix,) if later_prefix is not None: if later_prefixes is not None: - raise TypeError( - 'do not pass both later_prefix and later_prefixes') + raise TypeError("do not pass both later_prefix and later_prefixes") later_prefixes = (later_prefix,) if first_prefixes is not None: self.first_prefix.extend(first_prefixes) @@ -242,7 +240,7 @@ class PlainTextFormatter(Formatter): while wrap and self._pos + len(arg) > self.width: # We have to split. maxlen = self.width - self._pos - space = arg.rfind(' ', 0, maxlen) + space = arg.rfind(" ", 0, maxlen) if space == -1: # No space to split on. @@ -254,7 +252,7 @@ class PlainTextFormatter(Formatter): # written something we can also go to the next # line. if self._in_first_line or self._wrote_something: - bit = '' + bit = "" else: # Forcibly split this as far to the right as # possible. @@ -263,11 +261,11 @@ class PlainTextFormatter(Formatter): else: bit = arg[:space] # Omit the space we split on. - arg = arg[space + 1:] + arg = arg[space + 1 :] if conversion_needed: bit = self._force_encoding(bit) self.stream.write(bit) - self.stream.write(self._force_encoding('\n')) + self.stream.write(self._force_encoding("\n")) self._pos = 0 self._in_first_line = False self._wrote_something = False @@ -280,7 +278,7 @@ class PlainTextFormatter(Formatter): arg = self._force_encoding(arg) self.stream.write(arg) if autoline: - self.stream.write(self._force_encoding('\n')) + self.stream.write(self._force_encoding("\n")) self._wrote_something = False self._pos = 0 self._in_first_line = True @@ -290,32 +288,28 @@ class PlainTextFormatter(Formatter): raise finally: if first_prefixes is not None: - self.first_prefix = self.first_prefix[:-len(first_prefixes)] + self.first_prefix = self.first_prefix[: -len(first_prefixes)] if later_prefixes is not None: - self.later_prefix = self.later_prefix[:-len(later_prefixes)] + self.later_prefix = self.later_prefix[: -len(later_prefixes)] def fg(self, color=None): """change fg color Compatibility method- no coloring escapes are returned from it. """ - return '' + return "" def bg(self, color=None): """change bg color Compatibility method- no coloring escapes are returned from it. """ - return '' + return "" def flush(self): self.stream.flush() - - - - class TerminfoDisabled(Exception): """Raised if Terminfo is disabled.""" @@ -331,7 +325,7 @@ class TerminfoUnsupported(Exception): self.term = term def __str__(self): - return f'unsupported terminal type: {self.term!r}' + return f"unsupported terminal type: {self.term!r}" # This is necessary because the curses module is optional (and we @@ -341,6 +335,7 @@ try: except ImportError: TerminfoColor = None else: + class TerminfoColor: """Class encapsulating a specific terminfo entry for a color. @@ -351,8 +346,8 @@ else: __slots__ = ("mode", "color", "__weakref__") def __init__(self, mode, color): - object.__setattr__(self, 'mode', mode) - object.__setattr__(self, 'color', color) + object.__setattr__(self, "mode", mode) + object.__setattr__(self, "color", color) def __call__(self, formatter): if self.color is None: @@ -374,7 +369,7 @@ else: if template: res = curses.tparm(template, color) else: - res = b'' + res = b"" formatter._current_colors[self.mode] = res formatter.stream.write(res) @@ -393,7 +388,7 @@ else: def __init__(self, value): if value is None: raise _BogusTerminfo() - object.__setattr__(self, 'value', value) + object.__setattr__(self, "value", value) def __setattr__(self, key, value): raise AttributeError(f"{self.__class__.__name__} instances are immutable") @@ -441,33 +436,32 @@ else: super().__init__(stream, encoding=encoding) fd = stream.fileno() if term is None: - if term := os.environ.get('TERM'): + if term := os.environ.get("TERM"): try: curses.setupterm(fd=fd, term=term) except curses.error: pass else: - raise TerminfoDisabled('no terminfo entries') + raise TerminfoDisabled("no terminfo entries") else: # TODO maybe do something more useful than raising curses.error # if term is not in the terminfo db here? curses.setupterm(fd=fd, term=term) self._term = term - self.width = curses.tigetnum('cols') + self.width = curses.tigetnum("cols") try: - self.reset = TerminfoReset(curses.tigetstr('sgr0')) - self.bold = TerminfoMode(curses.tigetstr('bold')) - self.underline = TerminfoMode(curses.tigetstr('smul')) - self._color_reset = curses.tigetstr('op') - self._set_color = ( - curses.tigetstr('setaf'), - curses.tigetstr('setab')) + self.reset = TerminfoReset(curses.tigetstr("sgr0")) + self.bold = TerminfoMode(curses.tigetstr("bold")) + self.underline = TerminfoMode(curses.tigetstr("smul")) + self._color_reset = curses.tigetstr("op") + self._set_color = (curses.tigetstr("setaf"), curses.tigetstr("setab")) except (_BogusTerminfo, curses.error) as e: raise TerminfoUnsupported(self._term) from e if not all(self._set_color): raise TerminfoDisabled( - 'setting background/foreground colors is not supported') + "setting background/foreground colors is not supported" + ) curses.tparm(self._set_color[0], curses.COLOR_WHITE) @@ -507,16 +501,14 @@ else: # not set the hs flag. So just check for the ability to # jump to and out of the status line, without checking if # the status line we're using exists. - tsl = curses.tigetstr('tsl') - fsl = curses.tigetstr('fsl') + tsl = curses.tigetstr("tsl") + fsl = curses.tigetstr("fsl") if tsl and fsl: - self.stream.write( - tsl + string.encode(self.encoding, 'replace') + fsl) + self.stream.write(tsl + string.encode(self.encoding, "replace") + fsl) self.stream.flush() class ObserverFormatter: - def __init__(self, real_formatter): self._formatter = real_formatter @@ -542,7 +534,7 @@ def get_formatter(stream, force_color=False): # needs an fd to pass to curses, not just a filelike talking to a tty. if os.isatty(fd) or force_color: try: - term = 'ansi' if force_color else None + term = "ansi" if force_color else None return TerminfoFormatter(stream, term=term) except (curses.error, TerminfoDisabled, TerminfoUnsupported): # This happens if TERM is unset and possibly in more cases. @@ -553,6 +545,7 @@ def get_formatter(stream, force_color=False): def decorate_forced_wrapping(setting=True): """Decorator to force a specific line wrapping state for the duration of invocation.""" + def wrapped_func(func): def f(out, *args, **kwds): oldwrap = out.wrap @@ -561,5 +554,7 @@ def decorate_forced_wrapping(setting=True): return func(out, *args, **kwds) finally: out.wrap = oldwrap + return f + return wrapped_func diff --git a/src/snakeoil/iterables.py b/src/snakeoil/iterables.py index 787af67..7176a2b 100644 --- a/src/snakeoil/iterables.py +++ b/src/snakeoil/iterables.py @@ -19,8 +19,7 @@ def partition(iterable, predicate=bool): filter and the second the matched items. """ a, b = itertools.tee((predicate(x), x) for x in iterable) - return ((x for pred, x in a if not pred), - (x for pred, x in b if pred)) + return ((x for pred, x in a if not pred), (x for pred, x in b if pred)) class expandable_chain: @@ -107,6 +106,7 @@ class caching_iter: 3 """ + __slots__ = ("iterable", "__weakref__", "cached_list", "sorter") def __init__(self, iterable, sorter=None): @@ -139,7 +139,7 @@ class caching_iter: if self.iterable is not None: i = itertools.islice(self.iterable, 0, index - (existing_len - 1)) self.cached_list.extend(i) - if len(self.cached_list) -1 != index: + if len(self.cached_list) - 1 != index: # consumed, baby. self.iterable = None self.cached_list = tuple(self.cached_list) @@ -209,8 +209,7 @@ class caching_iter: return len(self.cached_list) def __iter__(self): - if (self.sorter is not None and - self.iterable is not None): + if self.sorter is not None and self.iterable is not None: if self.cached_list: self.cached_list.extend(self.iterable) self.cached_list = tuple(self.sorter(self.cached_list)) @@ -237,8 +236,7 @@ class caching_iter: return hash(self.cached_list) def __str__(self): - return "iterable(%s), cached: %s" % ( - self.iterable, str(self.cached_list)) + return "iterable(%s), cached: %s" % (self.iterable, str(self.cached_list)) def iter_sort(sorter, *iterables): diff --git a/src/snakeoil/klass.py b/src/snakeoil/klass.py index 0e59258..23d6d3c 100644 --- a/src/snakeoil/klass.py +++ b/src/snakeoil/klass.py @@ -7,12 +7,27 @@ involved in writing classes. """ __all__ = ( - "generic_equality", "reflective_hash", "inject_richcmp_methods_from_cmp", - "static_attrgetter", "instance_attrgetter", "jit_attr", "jit_attr_none", - "jit_attr_named", "jit_attr_ext_method", "alias_attr", "cached_hash", - "cached_property", "cached_property_named", - "steal_docs", "immutable_instance", "inject_immutable_instance", - "alias_method", "aliased", "alias", "patch", "SlotsPicklingMixin", + "generic_equality", + "reflective_hash", + "inject_richcmp_methods_from_cmp", + "static_attrgetter", + "instance_attrgetter", + "jit_attr", + "jit_attr_none", + "jit_attr_named", + "jit_attr_ext_method", + "alias_attr", + "cached_hash", + "cached_property", + "cached_property_named", + "steal_docs", + "immutable_instance", + "inject_immutable_instance", + "alias_method", + "aliased", + "alias", + "patch", + "SlotsPicklingMixin", ) import inspect @@ -32,6 +47,7 @@ sentinel = object() def GetAttrProxy(target): def reflected_getattr(self, attr): return getattr(object.__getattribute__(self, target), attr) + return reflected_getattr @@ -43,6 +59,7 @@ def DirProxy(target): except AttributeError: attrs.extend(obj.__slots__) return sorted(set(attrs)) + return combined_dir @@ -69,6 +86,8 @@ def get(self, key, default=None): _attrlist_getter = attrgetter("__attr_comparison__") + + def generic_attr_eq(inst1, inst2): """ compare inst1 to inst2, returning True if equal, False if not. @@ -78,8 +97,7 @@ def generic_attr_eq(inst1, inst2): if inst1 is inst2: return True for attr in _attrlist_getter(inst1): - if getattr(inst1, attr, sentinel) != \ - getattr(inst2, attr, sentinel): + if getattr(inst1, attr, sentinel) != getattr(inst2, attr, sentinel): return False return True @@ -105,28 +123,36 @@ def reflective_hash(attr): :param attr: attribute name to pull the hash from on the instance :return: hash value for instance this func is used in. """ + def __hash__(self): return getattr(self, attr) + return __hash__ + def _internal_jit_attr( - func, attr_name, singleton=None, - use_cls_setattr=False, use_singleton=True, doc=None): + func, attr_name, singleton=None, use_cls_setattr=False, use_singleton=True, doc=None +): """Object implementing the descriptor protocol for use in Just In Time access to attributes. Consumers should likely be using the :py:func:`jit_func` line of helper functions instead of directly consuming this. """ - doc = getattr(func, '__doc__', None) if doc is None else doc + doc = getattr(func, "__doc__", None) if doc is None else doc class _internal_jit_attr(_raw_internal_jit_attr): __doc__ = doc __slots__ = () + kls = _internal_jit_attr return kls( - func, attr_name, singleton=singleton, use_cls_setattr=use_cls_setattr, - use_singleton=use_singleton) + func, + attr_name, + singleton=singleton, + use_cls_setattr=use_cls_setattr, + use_singleton=use_singleton, + ) class _raw_internal_jit_attr: @@ -134,8 +160,9 @@ class _raw_internal_jit_attr: __slots__ = ("storage_attr", "function", "_setter", "singleton", "use_singleton") - def __init__(self, func, attr_name, singleton=None, - use_cls_setattr=False, use_singleton=True): + def __init__( + self, func, attr_name, singleton=None, use_cls_setattr=False, use_singleton=True + ): """ :param func: function to invoke upon first request for this content :param attr_name: attribute name to store the generated value in @@ -178,8 +205,9 @@ class _raw_internal_jit_attr: return obj -def generic_equality(name, bases, scope, real_type=type, - eq=generic_attr_eq, ne=generic_attr_ne): +def generic_equality( + name, bases, scope, real_type=type, eq=generic_attr_eq, ne=generic_attr_ne +): """ metaclass generating __eq__/__ne__ methods from an attribute list @@ -208,7 +236,9 @@ def generic_equality(name, bases, scope, real_type=type, attrlist = scope[attrlist] for x in attrlist: if not isinstance(x, str): - raise TypeError(f"all members of attrlist must be strings- got {type(x)!r} {x!r}") + raise TypeError( + f"all members of attrlist must be strings- got {type(x)!r} {x!r}" + ) scope["__attr_comparison__"] = tuple(attrlist) scope.setdefault("__eq__", eq) @@ -285,9 +315,14 @@ def inject_richcmp_methods_from_cmp(scope): :param scope: the modifiable scope of a class namespace to work on """ - for key, func in (("__lt__", generic_lt), ("__le__", generic_le), - ("__eq__", generic_eq), ("__ne__", generic_ne), - ("__ge__", generic_ge), ("__gt__", generic_gt)): + for key, func in ( + ("__lt__", generic_lt), + ("__le__", generic_le), + ("__eq__", generic_eq), + ("__ne__", generic_ne), + ("__ge__", generic_ge), + ("__gt__", generic_gt), + ): scope.setdefault(key, func) @@ -329,7 +364,8 @@ class chained_getter(metaclass=partial(generic_equality, real_type=WeakInstMeta) >>> print(o.recursive == foo.seq.__hash__) True """ - __slots__ = ('namespace', 'getter') + + __slots__ = ("namespace", "getter") __fifo_cache__ = deque() __inst_caching__ = True __attr_comparison__ = ("namespace",) @@ -361,16 +397,20 @@ instance_attrgetter = chained_getter # this annoyingly means our docs have to be recommitted every change, # even if no real code changed (since the id() continually moves)... class _singleton_kls: - def __str__(self): return "uncached singleton instance" _uncached_singleton = _singleton_kls -T = typing.TypeVar('T') +T = typing.TypeVar("T") + -def jit_attr(func: typing.Callable[[typing.Any], T], kls=_internal_jit_attr, uncached_val: typing.Any=_uncached_singleton) -> T: +def jit_attr( + func: typing.Callable[[typing.Any], T], + kls=_internal_jit_attr, + uncached_val: typing.Any = _uncached_singleton, +) -> T: """ decorator to JIT generate, and cache the wrapped functions result in '_' + func.__name__ on the instance. @@ -399,8 +439,13 @@ def jit_attr_none(func: typing.Callable[[typing.Any], T], kls=_internal_jit_attr return jit_attr(func, kls=kls, uncached_val=None) -def jit_attr_named(stored_attr_name: str, use_cls_setattr=False, kls=_internal_jit_attr, - uncached_val: typing.Any=_uncached_singleton, doc=None): +def jit_attr_named( + stored_attr_name: str, + use_cls_setattr=False, + kls=_internal_jit_attr, + uncached_val: typing.Any = _uncached_singleton, + doc=None, +): """ Version of :py:func:`jit_attr` decorator that allows for explicit control over the attribute name used to store the cache value. @@ -410,9 +455,14 @@ def jit_attr_named(stored_attr_name: str, use_cls_setattr=False, kls=_internal_j return post_curry(kls, stored_attr_name, uncached_val, use_cls_setattr, doc=doc) -def jit_attr_ext_method(func_name: str, stored_attr_name: str, - use_cls_setattr=False, kls=_internal_jit_attr, - uncached_val: typing.Any=_uncached_singleton, doc=None): +def jit_attr_ext_method( + func_name: str, + stored_attr_name: str, + use_cls_setattr=False, + kls=_internal_jit_attr, + uncached_val: typing.Any = _uncached_singleton, + doc=None, +): """ Decorator handing maximal control of attribute JIT'ing to the invoker. @@ -421,11 +471,20 @@ def jit_attr_ext_method(func_name: str, stored_attr_name: str, Generally speaking, you only need this when you are doing something rather *special*. """ - return kls(alias_method(func_name), stored_attr_name, - uncached_val, use_cls_setattr, doc=doc) + return kls( + alias_method(func_name), + stored_attr_name, + uncached_val, + use_cls_setattr, + doc=doc, + ) -def cached_property(func: typing.Callable[[typing.Any], T], kls=_internal_jit_attr, use_cls_setattr=False) -> T: +def cached_property( + func: typing.Callable[[typing.Any], T], + kls=_internal_jit_attr, + use_cls_setattr=False, +) -> T: """ like `property`, just with caching @@ -454,8 +513,9 @@ def cached_property(func: typing.Callable[[typing.Any], T], kls=_internal_jit_at >>> print(obj.attr) 1 """ - return kls(func, func.__name__, None, use_singleton=False, - use_cls_setattr=use_cls_setattr) + return kls( + func, func.__name__, None, use_singleton=False, use_cls_setattr=use_cls_setattr + ) def cached_property_named(name: str, kls=_internal_jit_attr, use_cls_setattr=False): @@ -538,11 +598,13 @@ def cached_hash(func): >>> assert hash(f) == 12345 # note we still get the same value >>> assert f.hash_invocations == 1 # and that the function was invoked only once. """ + def __hash__(self): - val = getattr(self, '_hash', None) + val = getattr(self, "_hash", None) if val is None: - object.__setattr__(self, '_hash', val := func(self)) + object.__setattr__(self, "_hash", val := func(self)) return val + return __hash__ @@ -574,6 +636,7 @@ def steal_docs(target, ignore_missing=False, name=None): >>> f = foo([1,2,3]) >>> assert f.extend.__doc__ == list.extend.__doc__ """ + def inner(functor): if inspect.isclass(target): if name is not None: @@ -590,6 +653,7 @@ def steal_docs(target, ignore_missing=False, name=None): obj = target functor.__doc__ = obj.__doc__ return functor + return inner @@ -611,7 +675,7 @@ def patch(target, external_decorator=None): """ def _import_module(target): - components = target.split('.') + components = target.split(".") import_path = components.pop(0) module = import_module(import_path) for comp in components: @@ -624,7 +688,7 @@ def patch(target, external_decorator=None): def _get_target(target): try: - module, attr = target.rsplit('.', 1) + module, attr = target.rsplit(".", 1) except (TypeError, ValueError): raise TypeError(f"invalid target: {target!r}") module = _import_module(module) @@ -632,7 +696,7 @@ def patch(target, external_decorator=None): def decorator(func): # use the original function wrapper - func = getattr(func, '_func', func) + func = getattr(func, "_func", func) module, attr = _get_target(target) orig_func = getattr(module, attr) @@ -749,6 +813,7 @@ class alias: >>> speak = Speak() >>> assert speak.shout('foo') == speak.yell('foo') == speak.scream('foo') """ + def __init__(self, *aliases): self.aliases = set(aliases) @@ -762,11 +827,14 @@ def aliased(cls): orig_methods = cls.__dict__.copy() seen_aliases = set() for _name, method in orig_methods.items(): - if hasattr(method, '_aliases'): - collisions = method._aliases.intersection(orig_methods.keys() | seen_aliases) + if hasattr(method, "_aliases"): + collisions = method._aliases.intersection( + orig_methods.keys() | seen_aliases + ) if collisions: raise ValueError( - f"aliases collide with existing attributes: {', '.join(collisions)}") + f"aliases collide with existing attributes: {', '.join(collisions)}" + ) seen_aliases |= method._aliases for alias in method._aliases: setattr(cls, alias, method) @@ -780,9 +848,13 @@ class SlotsPicklingMixin: def __getstate__(self): all_slots = itertools.chain.from_iterable( - getattr(t, '__slots__', ()) for t in type(self).__mro__) - state = {attr: getattr(self, attr) for attr in all_slots - if hasattr(self, attr) and attr != '__weakref__'} + getattr(t, "__slots__", ()) for t in type(self).__mro__ + ) + state = { + attr: getattr(self, attr) + for attr in all_slots + if hasattr(self, attr) and attr != "__weakref__" + } return state def __setstate__(self, state): diff --git a/src/snakeoil/mappings.py b/src/snakeoil/mappings.py index c349897..d4ac221 100644 --- a/src/snakeoil/mappings.py +++ b/src/snakeoil/mappings.py @@ -3,10 +3,17 @@ Miscellaneous mapping related classes and functionality """ __all__ = ( - "DictMixin", "LazyValDict", "LazyFullValLoadDict", - "ProtectedDict", "ImmutableDict", "IndeterminantDict", - "defaultdictkey", "AttrAccessible", "StackedDict", - "make_SlottedDict_kls", "ProxiedAttrs", + "DictMixin", + "LazyValDict", + "LazyFullValLoadDict", + "ProtectedDict", + "ImmutableDict", + "IndeterminantDict", + "defaultdictkey", + "AttrAccessible", + "StackedDict", + "make_SlottedDict_kls", + "ProxiedAttrs", ) import operator @@ -168,6 +175,7 @@ class LazyValDict(DictMixin): given a function to get keys, and to look up the val for those keys, it'll lazily load key definitions and values as requested """ + __slots__ = ("_keys", "_keys_func", "_vals", "_val_func") __externally_mutable__ = False @@ -184,8 +192,7 @@ class LazyValDict(DictMixin): self._keys_func = None else: if not callable(get_keys_func): - raise TypeError( - "get_keys_func isn't iterable or callable") + raise TypeError("get_keys_func isn't iterable or callable") self._keys_func = get_keys_func self._val_func = get_val_func self._vals = {} @@ -234,6 +241,7 @@ class LazyFullValLoadDict(LazyValDict): The val function must still return values one by one per key. """ + __slots__ = () def __getitem__(self, key): @@ -297,8 +305,7 @@ class ProtectedDict(DictMixin): yield k def __contains__(self, key): - return key in self.new or (key not in self.blacklist and - key in self.orig) + return key in self.new or (key not in self.blacklist and key in self.orig) class ImmutableDict(Mapping): @@ -320,14 +327,14 @@ class ImmutableDict(Mapping): try: mapping = {k: v for k, v in data} except TypeError as e: - raise TypeError(f'unsupported data format: {e}') - object.__setattr__(self, '_dict', mapping) + raise TypeError(f"unsupported data format: {e}") + object.__setattr__(self, "_dict", mapping) def __getitem__(self, key): # hack to avoid recursion exceptions for subclasses that use # inject_getitem_as_getattr() - if key == '_dict': - return object.__getattribute__(self, '_dict') + if key == "_dict": + return object.__getattribute__(self, "_dict") return self._dict[key] def __iter__(self): @@ -356,7 +363,7 @@ class OrderedFrozenSet(Set): try: self._dict = ImmutableDict({x: None for x in iterable}) except TypeError as e: - raise TypeError('not iterable') from e + raise TypeError("not iterable") from e def __contains__(self, key): return key in self._dict @@ -369,7 +376,7 @@ class OrderedFrozenSet(Set): try: return next(islice(self._dict, key, None)) except StopIteration: - raise IndexError('index out of range') + raise IndexError("index out of range") # handle keys using slice notation return self.__class__(list(self._dict)[key]) @@ -384,8 +391,8 @@ class OrderedFrozenSet(Set): return set(self._dict) == other def __str__(self): - elements_str = ', '.join(map(repr, self._dict)) - return f'{{{elements_str}}}' + elements_str = ", ".join(map(repr, self._dict)) + return f"{{{elements_str}}}" def __repr__(self): return self.__str__() @@ -413,7 +420,7 @@ class OrderedSet(OrderedFrozenSet, MutableSet): try: self._dict = {x: None for x in iterable} except TypeError as e: - raise TypeError('not iterable') from e + raise TypeError("not iterable") from e def add(self, value): self._dict[value] = None @@ -434,7 +441,7 @@ class OrderedSet(OrderedFrozenSet, MutableSet): self._dict.update((x, None) for x in iterable) def __hash__(self): - raise TypeError(f'unhashable type: {self.__class__.__name__!r}') + raise TypeError(f"unhashable type: {self.__class__.__name__!r}") class IndeterminantDict: @@ -473,12 +480,21 @@ class IndeterminantDict: def __unmodifiable(func, *args): raise TypeError(f"indeterminate dict: '{func}()' can't modify {args!r}") - for func in ('__delitem__', '__setitem__', 'setdefault', 'popitem', 'update', 'clear'): + + for func in ( + "__delitem__", + "__setitem__", + "setdefault", + "popitem", + "update", + "clear", + ): locals()[func] = partial(__unmodifiable, func) def __indeterminate(func, *args): raise TypeError(f"indeterminate dict: '{func}()' is inaccessible") - for func in ('__iter__', '__len__', 'keys', 'values', 'items'): + + for func in ("__iter__", "__len__", "keys", "values", "items"): locals()[func] = partial(__indeterminate, func) @@ -650,6 +666,7 @@ def _KeyError_to_Attr(functor): return functor(self, *args) except KeyError: raise AttributeError(args[0]) + inner.__name__ = functor.__name__ inner.__doc__ = functor.__doc__ return inner @@ -681,9 +698,9 @@ def inject_getitem_as_getattr(scope): :param scope: the scope of a class to modify, adding methods as needed """ - scope.setdefault('__getattr__', _KeyError_to_Attr(operator.__getitem__)) - scope.setdefault('__delattr__', _KeyError_to_Attr(operator.__delitem__)) - scope.setdefault('__setattr__', _KeyError_to_Attr(operator.__setitem__)) + scope.setdefault("__getattr__", _KeyError_to_Attr(operator.__getitem__)) + scope.setdefault("__delattr__", _KeyError_to_Attr(operator.__delitem__)) + scope.setdefault("__setattr__", _KeyError_to_Attr(operator.__setitem__)) class AttrAccessible(dict): @@ -713,7 +730,7 @@ class ProxiedAttrs(DictMixin): :param target: The object to wrap. """ - __slots__ = ('__target__',) + __slots__ = ("__target__",) def __init__(self, target): self.__target__ = target @@ -860,7 +877,7 @@ class _SlottedDict(DictMixin): def make_SlottedDict_kls(keys): """Create a space efficient mapping class with a limited set of keys.""" new_keys = tuple(sorted(keys)) - cls_name = f'SlottedDict_{hash(new_keys)}' + cls_name = f"SlottedDict_{hash(new_keys)}" o = globals().get(cls_name, None) if o is None: o = type(cls_name, (_SlottedDict,), {}) diff --git a/src/snakeoil/modules.py b/src/snakeoil/modules.py index 740ea24..ec69701 100644 --- a/src/snakeoil/modules.py +++ b/src/snakeoil/modules.py @@ -14,6 +14,7 @@ class FailedImport(ImportError): """ Raised when a requested target cannot be imported """ + def __init__(self, trg, e): super().__init__(self, f"Failed importing target '{trg}': '{e}'") self.trg, self.e = trg, e diff --git a/src/snakeoil/obj.py b/src/snakeoil/obj.py index 73cede9..a8598bf 100644 --- a/src/snakeoil/obj.py +++ b/src/snakeoil/obj.py @@ -74,7 +74,6 @@ try to proxy builtin objects like tuples, lists, dicts, sets, etc. """ - __all__ = ("DelayedInstantiation", "DelayedInstantiation_kls", "make_kls", "popattr") from . import klass @@ -87,14 +86,25 @@ from . import klass # pointless class creation- thus having two separate lists. base_kls_descriptors = [ - '__delattr__', '__hash__', '__reduce__', - '__reduce_ex__', '__repr__', '__setattr__', '__str__', - '__format__', '__subclasshook__', # >=py2.6 - '__le__', '__lt__', '__ge__', '__gt__', '__eq__', '__ne__', # py3 - '__dir__', # >=py3.3 + "__delattr__", + "__hash__", + "__reduce__", + "__reduce_ex__", + "__repr__", + "__setattr__", + "__str__", + "__format__", + "__subclasshook__", # >=py2.6 + "__le__", + "__lt__", + "__ge__", + "__gt__", + "__eq__", + "__ne__", # py3 + "__dir__", # >=py3.3 ] -if hasattr(object, '__sizeof__'): - base_kls_descriptors.append('__sizeof__') +if hasattr(object, "__sizeof__"): + base_kls_descriptors.append("__sizeof__") base_kls_descriptors = frozenset(base_kls_descriptors) @@ -134,13 +144,13 @@ class BaseDelayedObject: def __getattribute__(self, attr): obj = object.__getattribute__(self, "__obj__") if obj is None: - if attr == '__class__': + if attr == "__class__": return object.__getattribute__(self, "__delayed__")[0] - elif attr == '__doc__': + elif attr == "__doc__": kls = object.__getattribute__(self, "__delayed__")[0] - return getattr(kls, '__doc__', None) + return getattr(kls, "__doc__", None) - obj = object.__getattribute__(self, '__instantiate_proxy_instance__')() + obj = object.__getattribute__(self, "__instantiate_proxy_instance__")() if attr == "__obj__": # special casing for klass.alias_method @@ -157,61 +167,122 @@ class BaseDelayedObject: # special case the normal descriptors for x in base_kls_descriptors: locals()[x] = klass.alias_method( - "__obj__.%s" % (x,), - doc=getattr(getattr(object, x), '__doc__', None)) + "__obj__.%s" % (x,), doc=getattr(getattr(object, x), "__doc__", None) + ) # pylint: disable=undefined-loop-variable del x # note that we ignore __getattribute__; we already handle it. -kls_descriptors = frozenset([ - # rich comparison protocol... - '__le__', '__lt__', '__eq__', '__ne__', '__gt__', '__ge__', - # unicode conversion - '__unicode__', - # truth... - '__bool__', - # container protocol... - '__len__', '__getitem__', '__setitem__', '__delitem__', - '__iter__', '__contains__', '__index__', '__reversed__', - # deprecated sequence protocol bits... - '__getslice__', '__setslice__', '__delslice__', - # numeric... - '__add__', '__sub__', '__mul__', '__floordiv__', '__mod__', - '__divmod__', '__pow__', '__lshift__', '__rshift__', - '__and__', '__xor__', '__or__', '__div__', '__truediv__', - '__rad__', '__rsub__', '__rmul__', '__rdiv__', '__rtruediv__', - '__rfloordiv__', '__rmod__', '__rdivmod__', '__rpow__', - '__rlshift__', '__rrshift__', '__rand__', '__rxor__', '__ror__', - '__iadd__', '__isub__', '__imul__', '__idiv__', '__itruediv__', - '__ifloordiv__', '__imod__', '__ipow__', '__ilshift__', - '__irshift__', '__iand__', '__ixor__', '__ior__', - '__neg__', '__pos__', '__abs__', '__invert__', '__complex__', - '__int__', '__long__', '__float__', '__oct__', '__hex__', - '__coerce__', '__trunc__', '__radd__', '__floor__', '__ceil__', - '__round__', - # remaining... - '__call__', '__sizeof__', -]) +kls_descriptors = frozenset( + [ + # rich comparison protocol... + "__le__", + "__lt__", + "__eq__", + "__ne__", + "__gt__", + "__ge__", + # unicode conversion + "__unicode__", + # truth... + "__bool__", + # container protocol... + "__len__", + "__getitem__", + "__setitem__", + "__delitem__", + "__iter__", + "__contains__", + "__index__", + "__reversed__", + # deprecated sequence protocol bits... + "__getslice__", + "__setslice__", + "__delslice__", + # numeric... + "__add__", + "__sub__", + "__mul__", + "__floordiv__", + "__mod__", + "__divmod__", + "__pow__", + "__lshift__", + "__rshift__", + "__and__", + "__xor__", + "__or__", + "__div__", + "__truediv__", + "__rad__", + "__rsub__", + "__rmul__", + "__rdiv__", + "__rtruediv__", + "__rfloordiv__", + "__rmod__", + "__rdivmod__", + "__rpow__", + "__rlshift__", + "__rrshift__", + "__rand__", + "__rxor__", + "__ror__", + "__iadd__", + "__isub__", + "__imul__", + "__idiv__", + "__itruediv__", + "__ifloordiv__", + "__imod__", + "__ipow__", + "__ilshift__", + "__irshift__", + "__iand__", + "__ixor__", + "__ior__", + "__neg__", + "__pos__", + "__abs__", + "__invert__", + "__complex__", + "__int__", + "__long__", + "__float__", + "__oct__", + "__hex__", + "__coerce__", + "__trunc__", + "__radd__", + "__floor__", + "__ceil__", + "__round__", + # remaining... + "__call__", + "__sizeof__", + ] +) kls_descriptors = kls_descriptors.difference(base_kls_descriptors) -descriptor_overrides = {k: klass.alias_method(f"__obj__.{k}") - for k in kls_descriptors} +descriptor_overrides = {k: klass.alias_method(f"__obj__.{k}") for k in kls_descriptors} _method_cache = {} + + def make_kls(kls, proxy_base=BaseDelayedObject): special_descriptors = kls_descriptors.intersection(dir(kls)) - doc = getattr(kls, '__doc__', None) + doc = getattr(kls, "__doc__", None) if not special_descriptors and doc is None: return proxy_base key = (tuple(sorted(special_descriptors)), doc) o = _method_cache.get(key, None) if o is None: + class CustomDelayedObject(proxy_base): - locals().update((k, descriptor_overrides[k]) - for k in special_descriptors) + locals().update((k, descriptor_overrides[k]) for k in special_descriptors) __doc__ = doc o = CustomDelayedObject @@ -230,6 +301,8 @@ def DelayedInstantiation_kls(kls, *a, **kwd): _class_cache = {} + + def DelayedInstantiation(resultant_kls, func, *a, **kwd): """Generate an objects that does not get initialized before it is used. diff --git a/src/snakeoil/osutils/__init__.py b/src/snakeoil/osutils/__init__.py index f0ea498..036a78d 100644 --- a/src/snakeoil/osutils/__init__.py +++ b/src/snakeoil/osutils/__init__.py @@ -34,9 +34,18 @@ pretty quickly. """ __all__ = ( - 'abspath', 'abssymlink', 'ensure_dirs', 'join', 'pjoin', 'listdir_files', - 'listdir_dirs', 'listdir', 'readdir', 'normpath', 'unlink_if_exists', - 'supported_systems', + "abspath", + "abssymlink", + "ensure_dirs", + "join", + "pjoin", + "listdir_files", + "listdir_dirs", + "listdir", + "readdir", + "normpath", + "unlink_if_exists", + "supported_systems", ) import errno @@ -86,14 +95,18 @@ def supported_systems(*systems): ... NotImplementedError: func2 not supported on nonexistent """ + def _decorator(f): def _wrapper(*args, **kwargs): if sys.platform.startswith(systems): return f(*args, **kwargs) else: - raise NotImplementedError('%s not supported on %s' - % (f.__name__, sys.platform)) + raise NotImplementedError( + "%s not supported on %s" % (f.__name__, sys.platform) + ) + return _wrapper + return _decorator @@ -134,7 +147,7 @@ def ensure_dirs(path, gid=-1, uid=-1, mode=0o777, minimal=True): try: um = os.umask(0) # if the dir perms would lack +wx, we have to force it - force_temp_perms = ((mode & 0o300) != 0o300) + force_temp_perms = (mode & 0o300) != 0o300 resets = [] apath = normpath(os.path.abspath(path)) sticky_parent = False @@ -149,7 +162,7 @@ def ensure_dirs(path, gid=-1, uid=-1, mode=0o777, minimal=True): # if it's a subdir, we need +wx at least if apath != base: - sticky_parent = (st.st_mode & stat.S_ISGID) + sticky_parent = st.st_mode & stat.S_ISGID except OSError: # nothing exists. @@ -185,8 +198,7 @@ def ensure_dirs(path, gid=-1, uid=-1, mode=0o777, minimal=True): return False try: - if ((gid != -1 and gid != st.st_gid) or - (uid != -1 and uid != st.st_uid)): + if (gid != -1 and gid != st.st_gid) or (uid != -1 and uid != st.st_uid): os.chown(path, uid, gid) if minimal: if mode != (st.st_mode & mode): @@ -207,9 +219,9 @@ def abssymlink(path): a symlink """ mylink = os.readlink(path) - if mylink[0] != '/': + if mylink[0] != "/": mydir = os.path.dirname(path) - mylink = mydir + '/' + mylink + mylink = mydir + "/" + mylink return normpath(mylink) @@ -256,7 +268,7 @@ def normpath(mypath: str) -> str: `os.path.normpath` only in that it'll convert leading '//' into '/' """ newpath = os.path.normpath(mypath) - double_sep = b'//' if isinstance(newpath, bytes) else '//' + double_sep = b"//" if isinstance(newpath, bytes) else "//" if newpath.startswith(double_sep): return newpath[1:] return newpath @@ -306,9 +318,9 @@ def fallback_access(path, mode, root=0): return mode == (mode & (st.st_mode & 0x7)) -if os.uname()[0].lower() == 'sunos': +if os.uname()[0].lower() == "sunos": access = fallback_access - access.__name__ = 'access' + access.__name__ = "access" else: access = os.access diff --git a/src/snakeoil/osutils/mount.py b/src/snakeoil/osutils/mount.py index 05eb47d..b28853d 100644 --- a/src/snakeoil/osutils/mount.py +++ b/src/snakeoil/osutils/mount.py @@ -1,4 +1,4 @@ -__all__ = ('mount', 'umount') +__all__ = ("mount", "umount") import ctypes import os @@ -40,10 +40,10 @@ MNT_EXPIRE = 4 UMOUNT_NOFOLLOW = 8 -@supported_systems('linux') +@supported_systems("linux") def mount(source, target, fstype, flags, data=None): """Call mount(2); see the man page for details.""" - libc = ctypes.CDLL(find_library('c'), use_errno=True) + libc = ctypes.CDLL(find_library("c"), use_errno=True) source = source.encode() if isinstance(source, str) else source target = target.encode() if isinstance(target, str) else target fstype = fstype.encode() if isinstance(fstype, str) else fstype @@ -52,10 +52,10 @@ def mount(source, target, fstype, flags, data=None): raise OSError(e, os.strerror(e)) -@supported_systems('linux') +@supported_systems("linux") def umount(target, flags=None): """Call umount or umount2; see the umount(2) man page for details.""" - libc = ctypes.CDLL(find_library('c'), use_errno=True) + libc = ctypes.CDLL(find_library("c"), use_errno=True) target = target.encode() if isinstance(target, str) else target args = [] func = libc.umount diff --git a/src/snakeoil/osutils/native_readdir.py b/src/snakeoil/osutils/native_readdir.py index 0efb9c9..b129a9b 100644 --- a/src/snakeoil/osutils/native_readdir.py +++ b/src/snakeoil/osutils/native_readdir.py @@ -3,8 +3,18 @@ import errno import os -from stat import (S_IFBLK, S_IFCHR, S_IFDIR, S_IFIFO, S_IFLNK, S_IFMT, S_IFREG, S_IFSOCK, S_ISDIR, - S_ISREG) +from stat import ( + S_IFBLK, + S_IFCHR, + S_IFDIR, + S_IFIFO, + S_IFLNK, + S_IFMT, + S_IFREG, + S_IFSOCK, + S_ISDIR, + S_ISREG, +) from ..mappings import ProtectedDict @@ -14,6 +24,7 @@ listdir = os.listdir # import cycle. pjoin = os.path.join + def stat_swallow_enoent(path, check, default=False, stat=os.stat): try: return check(stat(path).st_mode) @@ -22,6 +33,7 @@ def stat_swallow_enoent(path, check, default=False, stat=os.stat): return default raise + def listdir_dirs(path, followSymlinks=True): """ Return a list of all subdirectories within a directory @@ -36,11 +48,12 @@ def listdir_dirs(path, followSymlinks=True): pjf = pjoin lstat = os.lstat if followSymlinks: - return [x for x in os.listdir(path) if - stat_swallow_enoent(pjf(path, x), scheck)] + return [ + x for x in os.listdir(path) if stat_swallow_enoent(pjf(path, x), scheck) + ] lstat = os.lstat - return [x for x in os.listdir(path) if - scheck(lstat(pjf(path, x)).st_mode)] + return [x for x in os.listdir(path) if scheck(lstat(pjf(path, x)).st_mode)] + def listdir_files(path, followSymlinks=True): """ @@ -56,24 +69,28 @@ def listdir_files(path, followSymlinks=True): scheck = S_ISREG pjf = pjoin if followSymlinks: - return [x for x in os.listdir(path) if - stat_swallow_enoent(pjf(path, x), scheck)] + return [ + x for x in os.listdir(path) if stat_swallow_enoent(pjf(path, x), scheck) + ] lstat = os.lstat - return [x for x in os.listdir(path) if - scheck(lstat(pjf(path, x)).st_mode)] + return [x for x in os.listdir(path) if scheck(lstat(pjf(path, x)).st_mode)] + # we store this outside the function to ensure that # the strings used are reused, thus avoiding unneeded # allocations -d_type_mapping = ProtectedDict({ - S_IFREG: "file", - S_IFDIR: "directory", - S_IFLNK: "symlink", - S_IFCHR: "chardev", - S_IFBLK: "block", - S_IFSOCK: "socket", - S_IFIFO: "fifo", -}) +d_type_mapping = ProtectedDict( + { + S_IFREG: "file", + S_IFDIR: "directory", + S_IFLNK: "symlink", + S_IFCHR: "chardev", + S_IFBLK: "block", + S_IFSOCK: "socket", + S_IFIFO: "fifo", + } +) + def readdir(path): """ diff --git a/src/snakeoil/pickling.py b/src/snakeoil/pickling.py index e212e6d..9707ecd 100644 --- a/src/snakeoil/pickling.py +++ b/src/snakeoil/pickling.py @@ -3,7 +3,8 @@ pickling convenience module """ __all__ = ( - "iter_stream", "dump_stream", + "iter_stream", + "dump_stream", ) from pickle import dump, load diff --git a/src/snakeoil/process/__init__.py b/src/snakeoil/process/__init__.py index f242269..8966fd6 100644 --- a/src/snakeoil/process/__init__.py +++ b/src/snakeoil/process/__init__.py @@ -44,7 +44,7 @@ def get_exit_status(status: int): if os.WIFSIGNALED(status): return 128 + os.WTERMSIG(status) else: - assert os.WIFEXITED(status), 'Unexpected exit status %r' % status + assert os.WIFEXITED(status), "Unexpected exit status %r" % status return os.WEXITSTATUS(status) @@ -84,16 +84,14 @@ def exit_as_status(status: int): class CommandNotFound(Exception): - def __init__(self, command): - super().__init__(f'failed to find binary: {command!r}') + super().__init__(f"failed to find binary: {command!r}") self.command = command class ProcessNotFound(Exception): - def __init__(self, pid): - super().__init__(f'nonexistent process: {pid}') + super().__init__(f"nonexistent process: {pid}") closerange = os.closerange diff --git a/src/snakeoil/process/namespaces.py b/src/snakeoil/process/namespaces.py index fd6bd74..6a823c6 100644 --- a/src/snakeoil/process/namespaces.py +++ b/src/snakeoil/process/namespaces.py @@ -11,8 +11,15 @@ import socket import subprocess import sys -from ..osutils.mount import (MS_NODEV, MS_NOEXEC, MS_NOSUID, MS_PRIVATE, MS_REC, MS_RELATIME, - MS_SLAVE) +from ..osutils.mount import ( + MS_NODEV, + MS_NOEXEC, + MS_NOSUID, + MS_PRIVATE, + MS_REC, + MS_RELATIME, + MS_SLAVE, +) from ..osutils.mount import mount as _mount from . import exit_as_status @@ -39,7 +46,7 @@ def setns(fd, nstype): fp = open(fd) fd = fp.fileno() - libc = ctypes.CDLL(ctypes.util.find_library('c'), use_errno=True) + libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True) if libc.setns(ctypes.c_int(fd), ctypes.c_int(nstype)) != 0: e = ctypes.get_errno() raise OSError(e, os.strerror(e)) @@ -54,7 +61,7 @@ def unshare(flags): :param flags: Namespaces to unshare; bitwise OR of CLONE_* flags. :raises OSError: if unshare failed. """ - libc = ctypes.CDLL(ctypes.util.find_library('c'), use_errno=True) + libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True) if libc.unshare(ctypes.c_int(flags)) != 0: e = ctypes.get_errno() raise OSError(e, os.strerror(e)) @@ -140,15 +147,13 @@ def create_pidns(): # Make sure to unshare the existing mount point if needed. Some distros # create shared mount points everywhere by default. try: - _mount(None, '/proc', 'proc', MS_PRIVATE | MS_REC) + _mount(None, "/proc", "proc", MS_PRIVATE | MS_REC) except OSError as e: if e.errno != errno.EINVAL: raise # The child needs its own proc mount as it'll be different. - _mount( - 'proc', '/proc', 'proc', - MS_NOSUID | MS_NODEV | MS_NOEXEC | MS_RELATIME) + _mount("proc", "/proc", "proc", MS_NOSUID | MS_NODEV | MS_NOEXEC | MS_RELATIME) if pid := os.fork(): # Mask SIGINT with the assumption that the child will catch & process it. @@ -195,12 +200,13 @@ def create_netns(): # Since we've unshared the net namespace, we need to bring up loopback. # The kernel automatically adds the various ip addresses, so skip that. try: - subprocess.call(['ip', 'link', 'set', 'up', 'lo']) + subprocess.call(["ip", "link", "set", "up", "lo"]) except OSError as e: if e.errno == errno.ENOENT: sys.stderr.write( - 'warning: could not bring up loopback for network; ' - 'install the iproute2 package\n') + "warning: could not bring up loopback for network; " + "install the iproute2 package\n" + ) else: raise @@ -243,16 +249,17 @@ def create_userns(): # For all other errors, abort. They shouldn't happen. raise - with open('/proc/self/setgroups', 'w') as f: - f.write('deny') - with open('/proc/self/uid_map', 'w') as f: - f.write('0 %s 1\n' % uid) - with open('/proc/self/gid_map', 'w') as f: - f.write('0 %s 1\n' % gid) + with open("/proc/self/setgroups", "w") as f: + f.write("deny") + with open("/proc/self/uid_map", "w") as f: + f.write("0 %s 1\n" % uid) + with open("/proc/self/gid_map", "w") as f: + f.write("0 %s 1\n" % gid) -def simple_unshare(mount=True, uts=True, ipc=True, net=False, pid=False, - user=False, hostname=None): +def simple_unshare( + mount=True, uts=True, ipc=True, net=False, pid=False, user=False, hostname=None +): """Simpler helper for setting up namespaces quickly. If support for any namespace type is not available, we'll silently skip it. @@ -278,7 +285,7 @@ def simple_unshare(mount=True, uts=True, ipc=True, net=False, pid=False, # on systems that share the rootfs by default, but allow events in the # parent to propagate down. try: - _mount(None, '/', None, MS_REC | MS_SLAVE) + _mount(None, "/", None, MS_REC | MS_SLAVE) except OSError as e: if e.errno != errno.EINVAL: raise diff --git a/src/snakeoil/process/spawn.py b/src/snakeoil/process/spawn.py index 3413e3f..48b60b1 100644 --- a/src/snakeoil/process/spawn.py +++ b/src/snakeoil/process/spawn.py @@ -3,8 +3,12 @@ subprocess related functionality """ __all__ = [ - "cleanup_pids", "spawn", "spawn_sandbox", "spawn_bash", - "spawn_get_output", "bash_version", + "cleanup_pids", + "spawn", + "spawn_sandbox", + "spawn_bash", + "spawn_get_output", + "bash_version", ] import atexit @@ -17,11 +21,12 @@ from ..mappings import ProtectedDict from ..osutils import access from . import CommandNotFound, closerange, find_binary -BASH_BINARY = find_binary('bash', fallback='/bin/bash') -SANDBOX_BINARY = find_binary('sandbox', fallback='/usr/bin/sandbox') +BASH_BINARY = find_binary("bash", fallback="/bin/bash") +SANDBOX_BINARY = find_binary("sandbox", fallback="/usr/bin/sandbox") try: import resource + max_fd_limit = resource.getrlimit(resource.RLIMIT_NOFILE)[0] except ImportError: max_fd_limit = 256 @@ -36,8 +41,14 @@ def bash_version(force=False): pass try: ret, ver = spawn_get_output( - [BASH_BINARY, '--norc', '--noprofile', '-c', - 'printf ${BASH_VERSINFO[0]}.${BASH_VERSINFO[1]}.${BASH_VERSINFO[2]}']) + [ + BASH_BINARY, + "--norc", + "--noprofile", + "-c", + "printf ${BASH_VERSINFO[0]}.${BASH_VERSINFO[1]}.${BASH_VERSINFO[2]}", + ] + ) if ret == 0: try: ver = ver[0] @@ -54,7 +65,7 @@ def bash_version(force=False): def spawn_bash(mycommand, debug=False, name=None, **kwds): """spawn the command via bash -c""" - args = [BASH_BINARY, '--norc', '--noprofile'] + args = [BASH_BINARY, "--norc", "--noprofile"] if debug: # Print commands and their arguments as they are executed. args.append("-x") @@ -84,6 +95,8 @@ def spawn_sandbox(mycommand, name=None, **kwds): _exithandlers = [] + + def atexit_register(func, *args, **kargs): """Wrapper around atexit.register that is needed in order to track what is registered. For example, when portage restarts itself via @@ -119,6 +132,8 @@ atexit.register(run_exitfuncs) # we exit. spawn() takes care of adding and removing pids to this list # as it creates and cleans up processes. spawned_pids = [] + + def cleanup_pids(pids=None): """reap list of pids if specified, else all children""" @@ -146,8 +161,19 @@ def cleanup_pids(pids=None): pass -def spawn(mycommand, env=None, name=None, fd_pipes=None, returnpid=False, - uid=None, gid=None, groups=None, umask=None, cwd=None, pgid=None): +def spawn( + mycommand, + env=None, + name=None, + fd_pipes=None, + returnpid=False, + uid=None, + gid=None, + groups=None, + umask=None, + cwd=None, + pgid=None, +): """wrapper around execve @@ -177,8 +203,19 @@ def spawn(mycommand, env=None, name=None, fd_pipes=None, returnpid=False, # 'Catch "Exception"' # pylint: disable-msg=W0703 try: - _exec(binary, mycommand, name, fd_pipes, env, gid, groups, - uid, umask, cwd, pgid) + _exec( + binary, + mycommand, + name, + fd_pipes, + env, + gid, + groups, + uid, + umask, + cwd, + pgid, + ) except Exception as e: # We need to catch _any_ exception so that it doesn't # propogate out of this function and cause exiting @@ -228,8 +265,19 @@ def spawn(mycommand, env=None, name=None, fd_pipes=None, returnpid=False, return 0 -def _exec(binary, mycommand, name=None, fd_pipes=None, env=None, gid=None, - groups=None, uid=None, umask=None, cwd=None, pgid=None): +def _exec( + binary, + mycommand, + name=None, + fd_pipes=None, + env=None, + gid=None, + groups=None, + uid=None, + umask=None, + cwd=None, + pgid=None, +): """internal function to handle exec'ing the child process. If it succeeds this function does not return. It might raise an @@ -321,8 +369,15 @@ def _exec(binary, mycommand, name=None, fd_pipes=None, env=None, gid=None, os.execve(binary, myargs, env) -def spawn_get_output(mycommand, spawn_type=None, raw_exit_code=False, collect_fds=(1,), - fd_pipes=None, split_lines=True, **kwds): +def spawn_get_output( + mycommand, + spawn_type=None, + raw_exit_code=False, + collect_fds=(1,), + fd_pipes=None, + split_lines=True, + **kwds, +): """Call spawn, collecting the output to fd's specified in collect_fds list. @@ -386,8 +441,8 @@ def process_exit_code(retval): :return: The exit code if it exit'd, the signal if it died from signalling. """ # If it got a signal, return the signal that was sent. - if retval & 0xff: - return (retval & 0xff) << 8 + if retval & 0xFF: + return (retval & 0xFF) << 8 # Otherwise, return its exit code. return retval >> 8 @@ -399,7 +454,7 @@ class ExecutionFailure(Exception): self.msg = msg def __str__(self): - return f'Execution Failure: {self.msg}' + return f"Execution Failure: {self.msg}" # cached capabilities @@ -411,7 +466,7 @@ def is_sandbox_capable(force=False): return is_sandbox_capable.cached_result except AttributeError: pass - if 'SANDBOX_ACTIVE' in os.environ: + if "SANDBOX_ACTIVE" in os.environ: # we can not spawn a sandbox inside another one res = False elif not (os.path.isfile(SANDBOX_BINARY) and access(SANDBOX_BINARY, os.X_OK)): @@ -432,5 +487,5 @@ def is_userpriv_capable(force=False): return is_userpriv_capable.cached_result except AttributeError: pass - res = is_userpriv_capable.cached_result = (os.getuid() == 0) + res = is_userpriv_capable.cached_result = os.getuid() == 0 return res diff --git a/src/snakeoil/sequences.py b/src/snakeoil/sequences.py index c041136..b80b101 100644 --- a/src/snakeoil/sequences.py +++ b/src/snakeoil/sequences.py @@ -1,9 +1,14 @@ """sequence related operations and classes""" __all__ = ( - 'unstable_unique', 'stable_unique', 'iter_stable_unique', - 'iflatten_instance', 'iflatten_func', 'ChainedLists', 'predicate_split', - 'split_negations', + "unstable_unique", + "stable_unique", + "iter_stable_unique", + "iflatten_instance", + "iflatten_func", + "ChainedLists", + "predicate_split", + "split_negations", ) from typing import Any, Callable, Iterable, Type @@ -89,7 +94,9 @@ def iter_stable_unique(iterable): break -def iflatten_instance(l: Iterable, skip_flattening: Iterable[Type] = (str, bytes)) -> Iterable: +def iflatten_instance( + l: Iterable, skip_flattening: Iterable[Type] = (str, bytes) +) -> Iterable: """collapse [[1],2] into [1,2] :param skip_flattening: list of classes to not descend through @@ -103,9 +110,10 @@ def iflatten_instance(l: Iterable, skip_flattening: Iterable[Type] = (str, bytes try: while True: x = next(iters) - if (hasattr(x, '__iter__') and not ( - isinstance(x, skip_flattening) or ( - isinstance(x, (str, bytes)) and len(x) == 1))): + if hasattr(x, "__iter__") and not ( + isinstance(x, skip_flattening) + or (isinstance(x, (str, bytes)) and len(x) == 1) + ): iters.appendleft(x) else: yield x @@ -128,7 +136,7 @@ def iflatten_func(l: Iterable, skip_func: Callable[[Any], bool]) -> Iterable: try: while True: x = next(iters) - if hasattr(x, '__iter__') and not skip_func(x): + if hasattr(x, "__iter__") and not skip_func(x): iters.appendleft(x) else: yield x @@ -164,6 +172,7 @@ class ChainedLists: ... TypeError: not mutable """ + __slots__ = ("_lists", "__weakref__") def __init__(self, *lists): @@ -258,7 +267,7 @@ def predicate_split(func, stream, key=None): def split_negations(iterable, func=str): - """"Split an iterable into negative and positive elements. + """ "Split an iterable into negative and positive elements. :param iterable: iterable targeted for splitting :param func: wrapper method to modify tokens @@ -267,7 +276,7 @@ def split_negations(iterable, func=str): """ neg, pos = [], [] for token in iterable: - if token[0] == '-': + if token[0] == "-": if len(token) == 1: raise ValueError("'-' negation without a token") token = token[1:] @@ -281,7 +290,7 @@ def split_negations(iterable, func=str): def split_elements(iterable, func=str): - """"Split an iterable into negative, neutral, and positive elements. + """ "Split an iterable into negative, neutral, and positive elements. :param iterable: iterable targeted for splitting :param func: wrapper method to modify tokens @@ -289,11 +298,11 @@ def split_elements(iterable, func=str): :return: Tuple containing negative, neutral, and positive element tuples, respectively. """ neg, neu, pos = [], [], [] - token_map = {'-': neg, '+': pos} + token_map = {"-": neg, "+": pos} for token in iterable: if token[0] in token_map: if len(token) == 1: - raise ValueError('%r without a token' % (token[0],)) + raise ValueError("%r without a token" % (token[0],)) l = token_map[token[0]] token = token[1:] else: diff --git a/src/snakeoil/stringio.py b/src/snakeoil/stringio.py index c17db92..1a392eb 100644 --- a/src/snakeoil/stringio.py +++ b/src/snakeoil/stringio.py @@ -27,7 +27,7 @@ is usable under both py2k and py3k. """ # TODO: deprecated, remove in 0.9.0 -__all__ = ('text_readonly', 'bytes_readonly') +__all__ = ("text_readonly", "bytes_readonly") import io diff --git a/src/snakeoil/strings.py b/src/snakeoil/strings.py index 7d7b2a8..d0dc58a 100644 --- a/src/snakeoil/strings.py +++ b/src/snakeoil/strings.py @@ -2,10 +2,10 @@ from .demandload import demand_compile_regexp -demand_compile_regexp('_whitespace_regex', r'^(?P\s+)') +demand_compile_regexp("_whitespace_regex", r"^(?P\s+)") -def pluralism(obj, none=None, singular='', plural='s'): +def pluralism(obj, none=None, singular="", plural="s"): """Return singular or plural suffix depending on object's length or value.""" # default to plural for empty objects, e.g. there are 0 repos if none is None: @@ -27,16 +27,16 @@ def pluralism(obj, none=None, singular='', plural='s'): def doc_dedent(s): """Support dedenting docstrings with initial line having no indentation.""" try: - lines = s.split('\n') + lines = s.split("\n") except AttributeError: - raise TypeError(f'{s!r} is not a string') + raise TypeError(f"{s!r} is not a string") if lines: # find first line with an indent if one exists for line in lines: if mo := _whitespace_regex.match(line): - indent = mo.group('indent') + indent = mo.group("indent") break else: - indent = '' + indent = "" len_i = len(indent) - return '\n'.join(x[len_i:] if x.startswith(indent) else x for x in lines) + return "\n".join(x[len_i:] if x.startswith(indent) else x for x in lines) diff --git a/src/snakeoil/tar.py b/src/snakeoil/tar.py index 7daa9eb..74061a5 100644 --- a/src/snakeoil/tar.py +++ b/src/snakeoil/tar.py @@ -43,13 +43,30 @@ class TarInfo(tarfile.TarInfo): :ivar uname: same as TarInfo.uname, just interned via a property. """ - if not hasattr(tarfile.TarInfo, '__slots__'): + if not hasattr(tarfile.TarInfo, "__slots__"): __slots__ = ( - "name", "mode", "uid", "gid", "size", "mtime", "chksum", "type", - "linkname", "_uname", "_gname", "devmajor", "devminor", "prefix", - "offset", "offset_data", "_buf", "sparse", "_link_target") + "name", + "mode", + "uid", + "gid", + "size", + "mtime", + "chksum", + "type", + "linkname", + "_uname", + "_gname", + "devmajor", + "devminor", + "prefix", + "offset", + "offset_data", + "_buf", + "sparse", + "_link_target", + ) else: - __slots__ = ('_buf', '_uname', '_gname') + __slots__ = ("_buf", "_uname", "_gname") def get_buf(self): return self.tobuf() diff --git a/src/snakeoil/test/__init__.py b/src/snakeoil/test/__init__.py index bb9381f..b93094a 100644 --- a/src/snakeoil/test/__init__.py +++ b/src/snakeoil/test/__init__.py @@ -13,18 +13,20 @@ from snakeoil import klass def random_str(length): """Return a random string of specified length.""" - return ''.join(random.choices(string.ascii_letters + string.digits, k=length)) + return "".join(random.choices(string.ascii_letters + string.digits, k=length)) def coverage(): """Extract coverage instance (if it exists) from the current running context.""" cov = None import inspect + try: import coverage + frame = inspect.currentframe() while frame is not None: - cov = getattr(frame.f_locals.get('self'), 'coverage', None) + cov = getattr(frame.f_locals.get("self"), "coverage", None) if isinstance(cov, coverage.coverage): break frame = frame.f_back @@ -33,7 +35,7 @@ def coverage(): return cov -@klass.patch('os._exit') +@klass.patch("os._exit") def _os_exit(orig_exit, val): """Monkeypatch os._exit() to save coverage data before exit.""" cov = coverage() @@ -51,7 +53,9 @@ def protect_process(functor, name=None): if os.environ.get(_PROTECT_ENV_VAR, False): return functor(self) if name is None: - name = f"{self.__class__.__module__}.{self.__class__.__name__}.{method_name}" + name = ( + f"{self.__class__.__module__}.{self.__class__.__name__}.{method_name}" + ) runner_path = __file__ if runner_path.endswith(".pyc") or runner_path.endswith(".pyo"): runner_path = runner_path.rsplit(".", maxsplit=1)[0] + ".py" @@ -59,11 +63,18 @@ def protect_process(functor, name=None): try: os.environ[_PROTECT_ENV_VAR] = "yes" args = [sys.executable, __file__, name] - p = subprocess.Popen(args, shell=False, env=os.environ.copy(), - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + p = subprocess.Popen( + args, + shell=False, + env=os.environ.copy(), + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) stdout, _ = p.communicate() ret = p.wait() - assert ret == 0, f"subprocess run: {args!r}\nnon zero exit: {ret}\nstdout:\n{stdout}" + assert ( + ret == 0 + ), f"subprocess run: {args!r}\nnon zero exit: {ret}\nstdout:\n{stdout}" finally: if wipe: os.environ.pop(_PROTECT_ENV_VAR, None) @@ -71,7 +82,7 @@ def protect_process(functor, name=None): for x in ("__doc__", "__name__"): if hasattr(functor, x): setattr(_inner_run, x, getattr(functor, x)) - method_name = getattr(functor, '__name__', None) + method_name = getattr(functor, "__name__", None) return _inner_run @@ -84,4 +95,4 @@ def hide_imports(*import_names: str): raise ImportError() return orig_import(name, *args, **kwargs) - return patch('builtins.__import__', side_effect=mock_import) + return patch("builtins.__import__", side_effect=mock_import) diff --git a/src/snakeoil/test/argparse_helpers.py b/src/snakeoil/test/argparse_helpers.py index 152e335..bfcdc3b 100644 --- a/src/snakeoil/test/argparse_helpers.py +++ b/src/snakeoil/test/argparse_helpers.py @@ -56,25 +56,24 @@ class Color(FormatterObject): self.color = color def __repr__(self): - return f'' + return f"" class Reset(FormatterObject): __inst_caching__ = True def __repr__(self): - return '' + return "" class Bold(FormatterObject): __inst_caching__ = True def __repr__(self): - return '' + return "" class ListStream(list): - def write(self, *args): stringlist = [] objectlist = [] @@ -82,13 +81,16 @@ class ListStream(list): if isinstance(arg, bytes): stringlist.append(arg) else: - objectlist.append(b''.join(stringlist)) + objectlist.append(b"".join(stringlist)) stringlist = [] objectlist.append(arg) - objectlist.append(b''.join(stringlist)) + objectlist.append(b"".join(stringlist)) # We use len because boolean ops shortcircuit - if (len(self) and isinstance(self[-1], bytes) and - isinstance(objectlist[0], bytes)): + if ( + len(self) + and isinstance(self[-1], bytes) + and isinstance(objectlist[0], bytes) + ): self[-1] = self[-1] + objectlist.pop(0) self.extend(objectlist) @@ -97,7 +99,6 @@ class ListStream(list): class FakeStreamFormatter(PlainTextFormatter): - def __init__(self): super().__init__(ListStream([])) self.reset = Reset() @@ -108,15 +109,15 @@ class FakeStreamFormatter(PlainTextFormatter): self.stream = ListStream([]) def fg(self, color=None): - return Color('fg', color) + return Color("fg", color) def bg(self, color=None): - return Color('bg', color) + return Color("bg", color) def get_text_stream(self): - return b''.join( - (x for x in self.stream - if not isinstance(x, FormatterObject))).decode('ascii') + return b"".join( + (x for x in self.stream if not isinstance(x, FormatterObject)) + ).decode("ascii") class ArgParseMixin: @@ -148,7 +149,7 @@ class ArgParseMixin: except Error as e: assert message == e.message else: - raise AssertionError('no error triggered') + raise AssertionError("no error triggered") def assertExit(self, status, message, *args, **kwargs): """Pass args, assert they trigger the right exit condition.""" @@ -158,7 +159,7 @@ class ArgParseMixin: assert message == e.message.strip() assert status == e.status else: - raise AssertionError('no exit triggered') + raise AssertionError("no exit triggered") def assertOut(self, out, *args, **kwargs): """Like :obj:`assertOutAndErr` but without err.""" @@ -182,17 +183,25 @@ class ArgParseMixin: main = self.get_main(options) main(options, outformatter, errformatter) diffs = [] - for name, strings, formatter in [('out', out, outformatter), - ('err', err, errformatter)]: + for name, strings, formatter in [ + ("out", out, outformatter), + ("err", err, errformatter), + ]: actual = formatter.get_text_stream() if strings: - expected = '\n'.join(strings) + expected = "\n".join(strings) else: - expected = '' + expected = "" if expected != actual: - diffs.extend(difflib.unified_diff( - strings, actual.split('\n')[:-1], - 'expected %s' % (name,), 'actual', lineterm='')) + diffs.extend( + difflib.unified_diff( + strings, + actual.split("\n")[:-1], + "expected %s" % (name,), + "actual", + lineterm="", + ) + ) if diffs: - raise AssertionError('\n' + '\n'.join(diffs)) + raise AssertionError("\n" + "\n".join(diffs)) return options diff --git a/src/snakeoil/test/eq_hash_inheritance.py b/src/snakeoil/test/eq_hash_inheritance.py index 5012f9d..eaa42a3 100644 --- a/src/snakeoil/test/eq_hash_inheritance.py +++ b/src/snakeoil/test/eq_hash_inheritance.py @@ -3,7 +3,7 @@ from . import mixins class Test(mixins.TargetedNamespaceWalker, mixins.KlassWalker): - target_namespace = 'snakeoil' + target_namespace = "snakeoil" singleton = object() @@ -26,8 +26,8 @@ class Test(mixins.TargetedNamespaceWalker, mixins.KlassWalker): # object sets __hash__/__eq__, which isn't usually # intended to be inherited/reused continue - eq = getattr(parent, '__eq__', self.singleton) - h = getattr(parent, '__hash__', self.singleton) + eq = getattr(parent, "__eq__", self.singleton) + h = getattr(parent, "__hash__", self.singleton) if eq == object.__eq__ and h == object.__hash__: continue if eq and h: @@ -37,10 +37,11 @@ class Test(mixins.TargetedNamespaceWalker, mixins.KlassWalker): # pylint: disable=undefined-loop-variable # 'parent' is guaranteed to be defined due to the 'else' clause above - assert getattr(cls, '__hash__') is not None, ( + assert getattr(cls, "__hash__") is not None, ( f"class '{cls.__module__}.{cls.__name__}' had its __hash__ reset, " "while it would've inherited __hash__ from parent " f"'{parent.__module__}.{parent.__name__}'; this occurs in py3k when " "__eq__ is defined alone. If this is desired behaviour, set " "__hash__intentionally_disabled__ to True to explicitly ignore this" - " class") + " class" + ) diff --git a/src/snakeoil/test/mixins.py b/src/snakeoil/test/mixins.py index aa66839..0648de1 100644 --- a/src/snakeoil/test/mixins.py +++ b/src/snakeoil/test/mixins.py @@ -16,23 +16,27 @@ class PythonNamespaceWalker: # This is for py3.2/PEP3149; dso's now have the interp + major/minor embedded # in the name. # TODO: update this for pypy's naming - abi_target = 'cpython-%i%i' % tuple(sys.version_info[:2]) + abi_target = "cpython-%i%i" % tuple(sys.version_info[:2]) - module_blacklist = frozenset({ - 'snakeoil.cli.arghparse', 'snakeoil.pickling', - }) + module_blacklist = frozenset( + { + "snakeoil.cli.arghparse", + "snakeoil.pickling", + } + ) def _default_module_blacklister(self, target): - return target in self.module_blacklist or target.startswith('snakeoil.dist') + return target in self.module_blacklist or target.startswith("snakeoil.dist") def walk_namespace(self, namespace, **kwds): - location = os.path.abspath(os.path.dirname( - self.poor_mans_load(namespace).__file__)) - return self.get_modules(self.recurse(location), namespace=namespace, - **kwds) - - def get_modules(self, feed, namespace=None, blacklist_func=None, - ignore_failed_imports=None): + location = os.path.abspath( + os.path.dirname(self.poor_mans_load(namespace).__file__) + ) + return self.get_modules(self.recurse(location), namespace=namespace, **kwds) + + def get_modules( + self, feed, namespace=None, blacklist_func=None, ignore_failed_imports=None + ): if ignore_failed_imports is None: ignore_failed_imports = self.ignore_all_import_failures if namespace is None: @@ -57,7 +61,7 @@ class PythonNamespaceWalker: raise def recurse(self, location, valid_namespace=True): - if os.path.dirname(location) == '__pycache__': + if os.path.dirname(location) == "__pycache__": # Shouldn't be possible, but make sure we avoid this if it manages # to occur. return @@ -78,10 +82,13 @@ class PythonNamespaceWalker: # file disappeared under our feet... lock file from # trial can cause this. ignore. import logging - logging.debug("file %r disappeared under our feet, ignoring", - os.path.join(location, x)) - seen = set(['__init__']) + logging.debug( + "file %r disappeared under our feet, ignoring", + os.path.join(location, x), + ) + + seen = set(["__init__"]) for x, st in stats: if not (x.startswith(".") or x.endswith("~")) and stat.S_ISREG(st): if x.endswith((".py", ".pyc", ".pyo", ".so")): @@ -89,8 +96,8 @@ class PythonNamespaceWalker: # Ensure we're not looking at a >=py3k .so which injects # the version name in... if y not in seen: - if '.' in y and x.endswith('.so'): - y, abi = x.rsplit('.', 1) + if "." in y and x.endswith(".so"): + y, abi = x.rsplit(".", 1) if abi != self.abi_target: continue seen.add(y) @@ -135,6 +142,7 @@ class TargetedNamespaceWalker(PythonNamespaceWalker): for _mod in self.walk_namespace(namespace): pass + class _classWalker: cls_blacklist = frozenset() @@ -173,7 +181,6 @@ class _classWalker: class SubclassWalker(_classWalker): - def walk_derivatives(self, cls, seen=None): if len(inspect.signature(cls.__subclasses__).parameters) != 0: return @@ -193,7 +200,6 @@ class SubclassWalker(_classWalker): class KlassWalker(_classWalker): - def walk_derivatives(self, cls, seen=None): if len(inspect.signature(cls.__subclasses__).parameters) != 0: return diff --git a/src/snakeoil/test/modules.py b/src/snakeoil/test/modules.py index 6b3a6a1..0ae116b 100644 --- a/src/snakeoil/test/modules.py +++ b/src/snakeoil/test/modules.py @@ -3,12 +3,12 @@ from . import mixins class ExportedModules(mixins.PythonNamespaceWalker): - target_namespace = 'snakeoil' + target_namespace = "snakeoil" def test__all__accuracy(self): failures = [] for module in self.walk_namespace(self.target_namespace): - for target in getattr(module, '__all__', ()): + for target in getattr(module, "__all__", ()): if not hasattr(module, target): failures.append((module, target)) assert not failures, f"nonexistent __all__ targets spotted: {failures}" diff --git a/src/snakeoil/test/slot_shadowing.py b/src/snakeoil/test/slot_shadowing.py index 3e260c2..fac6619 100644 --- a/src/snakeoil/test/slot_shadowing.py +++ b/src/snakeoil/test/slot_shadowing.py @@ -5,7 +5,7 @@ from . import mixins class SlotShadowing(mixins.TargetedNamespaceWalker, mixins.SubclassWalker): - target_namespace = 'snakeoil' + target_namespace = "snakeoil" err_if_slots_is_str = True err_if_slots_is_mutable = True @@ -22,20 +22,20 @@ class SlotShadowing(mixins.TargetedNamespaceWalker, mixins.SubclassWalker): @staticmethod def mk_name(kls): - return f'{kls.__module__}.{kls.__name__}' + return f"{kls.__module__}.{kls.__name__}" def _should_ignore(self, kls): return self.mk_name(kls).split(".")[0] != self.target_namespace def run_check(self, kls): - if getattr(kls, '__slotting_intentionally_disabled__', False): + if getattr(kls, "__slotting_intentionally_disabled__", False): return slotting = {} raw_slottings = {} for parent in self.recurse_parents(kls): - slots = getattr(parent, '__slots__', None) + slots = getattr(parent, "__slots__", None) if slots is None: continue @@ -49,14 +49,15 @@ class SlotShadowing(mixins.TargetedNamespaceWalker, mixins.SubclassWalker): for slot in slots: slotting.setdefault(slot, parent) - slots = getattr(kls, '__slots__', None) + slots = getattr(kls, "__slots__", None) if slots is None and not slotting: return if isinstance(slots, str): if self.err_if_slots_is_str: pytest.fail( - f"cls {kls!r}; slots is {slots!r} (should be a tuple or list)") + f"cls {kls!r}; slots is {slots!r} (should be a tuple or list)" + ) slots = (slots,) if slots is None: @@ -64,8 +65,7 @@ class SlotShadowing(mixins.TargetedNamespaceWalker, mixins.SubclassWalker): if not isinstance(slots, tuple): if self.err_if_slots_is_mutable: - pytest.fail( - f"cls {kls!r}; slots is {slots!r}- - should be a tuple") + pytest.fail(f"cls {kls!r}; slots is {slots!r}- - should be a tuple") slots = tuple(slots) if slots is None or (slots and slots in raw_slottings): @@ -74,9 +74,11 @@ class SlotShadowing(mixins.TargetedNamespaceWalker, mixins.SubclassWalker): # daftly copied the parents... thus defeating the purpose. pytest.fail( f"cls {kls!r}; slots is {slots!r}, seemingly inherited from " - f"{raw_slottings[slots]!r}; the derivative class should be __slots__ = ()") + f"{raw_slottings[slots]!r}; the derivative class should be __slots__ = ()" + ) for slot in slots: if slot in slotting: pytest.fail( - f"cls {kls!r}; slot {slot!r} was already defined at {slotting[slot]!r}") + f"cls {kls!r}; slot {slot!r} was already defined at {slotting[slot]!r}" + ) diff --git a/src/snakeoil/version.py b/src/snakeoil/version.py index fa3fc96..2fd65d2 100644 --- a/src/snakeoil/version.py +++ b/src/snakeoil/version.py @@ -28,43 +28,46 @@ def get_version(project, repo_file, api_version=None): version_info = None if api_version is None: try: - api_version = getattr(import_module(project), '__version__') + api_version = getattr(import_module(project), "__version__") except ImportError: - raise ValueError(f'no {project} module in the syspath') + raise ValueError(f"no {project} module in the syspath") try: - version_info = getattr( - import_module(f'{project}._verinfo'), 'version_info') + version_info = getattr(import_module(f"{project}._verinfo"), "version_info") except ImportError: # we're probably in a git repo path = os.path.dirname(os.path.abspath(repo_file)) version_info = get_git_version(path) if version_info is None: - s = '' - elif version_info['tag'] == api_version: + s = "" + elif version_info["tag"] == api_version: s = f" -- released {version_info['date']}" else: - rev = version_info['rev'][:7] - date = version_info['date'] - commits = version_info.get('commits', None) - commits = f'-{commits}' if commits is not None else '' - s = f'{commits}-g{rev} -- {date}' + rev = version_info["rev"][:7] + date = version_info["date"] + commits = version_info.get("commits", None) + commits = f"-{commits}" if commits is not None else "" + s = f"{commits}-g{rev} -- {date}" - _ver = f'{project} {api_version}{s}' + _ver = f"{project} {api_version}{s}" return _ver def _run_git(path, cmd): env = dict(os.environ) - for key in env.copy(): # pragma: no cover + for key in env.copy(): # pragma: no cover if key.startswith("LC_"): del env[key] env["LC_CTYPE"] = "C" env["LC_ALL"] = "C" r = subprocess.Popen( - ['git'] + list(cmd), stdout=subprocess.PIPE, env=env, - stderr=subprocess.DEVNULL, cwd=path) + ["git"] + list(cmd), + stdout=subprocess.PIPE, + env=env, + stderr=subprocess.DEVNULL, + cwd=path, + ) stdout = r.communicate()[0] return stdout, r.returncode @@ -83,21 +86,20 @@ def get_git_version(path): tag = _get_git_tag(path, data[0]) # get number of commits since most recent tag - stdout, ret = _run_git(path, ['describe', '--tags', '--abbrev=0']) + stdout, ret = _run_git(path, ["describe", "--tags", "--abbrev=0"]) prev_tag = None commits = None if ret == 0: prev_tag = stdout.decode().strip() - stdout, ret = _run_git( - path, ['log', '--oneline', f'{prev_tag}..HEAD']) + stdout, ret = _run_git(path, ["log", "--oneline", f"{prev_tag}..HEAD"]) if ret == 0: commits = len(stdout.decode().splitlines()) return { - 'rev': data[0], - 'date': data[1], - 'tag': tag, - 'commits': commits, + "rev": data[0], + "date": data[1], + "tag": tag, + "commits": commits, } except EnvironmentError as exc: # ENOENT is thrown when the git binary can't be found. @@ -107,14 +109,14 @@ def get_git_version(path): def _get_git_tag(path, rev): - stdout, _ = _run_git(path, ['name-rev', '--tag', rev]) + stdout, _ = _run_git(path, ["name-rev", "--tag", rev]) tag = stdout.decode().split() if len(tag) != 2: return None tag = tag[1] if not tag.startswith("tags/"): return None - tag = tag[len("tags/"):] + tag = tag[len("tags/") :] if tag.endswith("^0"): tag = tag[:-2] if tag.startswith("v"): diff --git a/src/snakeoil/weakrefs.py b/src/snakeoil/weakrefs.py index 1fb7e0b..83a79c0 100644 --- a/src/snakeoil/weakrefs.py +++ b/src/snakeoil/weakrefs.py @@ -13,7 +13,6 @@ def finalize_instance(obj, weakref_inst): class WeakRefProxy(BaseDelayedObject): - def __instantiate_proxy_instance__(self): obj = BaseDelayedObject.__instantiate_proxy_instance__(self) weakref = ref(self, partial(finalize_instance, obj)) diff --git a/tests/cli/test_arghparse.py b/tests/cli/test_arghparse.py index ccaa65e..4741d86 100644 --- a/tests/cli/test_arghparse.py +++ b/tests/cli/test_arghparse.py @@ -11,56 +11,60 @@ from snakeoil.test import argparse_helpers class TestArgparseDocs: - def test_add_argument_docs(self): # force using an unpatched version of argparse reload(argparse) parser = argparse.ArgumentParser() - parser.add_argument('--foo', action='store_true') + parser.add_argument("--foo", action="store_true") # vanilla argparse doesn't support docs kwargs with pytest.raises(TypeError): parser.add_argument( - '-b', '--blah', action='store_true', docs='Blah blah blah') + "-b", "--blah", action="store_true", docs="Blah blah blah" + ) with pytest.raises(TypeError): - parser.add_argument_group('fa', description='fa la la', docs='fa la la la') + parser.add_argument_group("fa", description="fa la la", docs="fa la la la") with pytest.raises(TypeError): - parser.add_mutually_exclusive_group('fee', description='fi', docs='fo fum') + parser.add_mutually_exclusive_group("fee", description="fi", docs="fo fum") # forcibly monkey-patch argparse to allow docs kwargs reload(arghparse) - default = 'baz baz' - docs = 'blah blah' + default = "baz baz" + docs = "blah blah" for enable_docs, expected_txt in ((False, default), (True, docs)): arghparse._generate_docs = enable_docs parser = argparse.ArgumentParser() subparsers = parser.add_subparsers(description=default, docs=docs) - subparser = subparsers.add_parser('foo', description=default, docs=docs) + subparser = subparsers.add_parser("foo", description=default, docs=docs) action = parser.add_argument( - '-b', '--blah', action='store_true', help=default, docs=docs) - arg_group = parser.add_argument_group('fa', description=default, docs=docs) + "-b", "--blah", action="store_true", help=default, docs=docs + ) + arg_group = parser.add_argument_group("fa", description=default, docs=docs) mut_arg_group = parser.add_mutually_exclusive_group() mut_action = mut_arg_group.add_argument( - '-f', '--fee', action='store_true', help=default, docs=docs) + "-f", "--fee", action="store_true", help=default, docs=docs + ) - assert getattr(parser._subparsers, 'description', None) == expected_txt - assert getattr(subparser, 'description', None) == expected_txt - assert getattr(action, 'help', None) == expected_txt - assert getattr(arg_group, 'description', None) == expected_txt - assert getattr(mut_action, 'help', None) == expected_txt + assert getattr(parser._subparsers, "description", None) == expected_txt + assert getattr(subparser, "description", None) == expected_txt + assert getattr(action, "help", None) == expected_txt + assert getattr(arg_group, "description", None) == expected_txt + assert getattr(mut_action, "help", None) == expected_txt # list/tuple-based docs arghparse._generate_docs = True - docs = 'foo bar' + docs = "foo bar" parser = argparse.ArgumentParser() list_action = parser.add_argument( - '-b', '--blah', action='store_true', help=default, docs=list(docs.split())) + "-b", "--blah", action="store_true", help=default, docs=list(docs.split()) + ) tuple_action = parser.add_argument( - '-c', '--cat', action='store_true', help=default, docs=tuple(docs.split())) - assert getattr(list_action, 'help', None) == 'foo\nbar' - assert getattr(tuple_action, 'help', None) == 'foo\nbar' + "-c", "--cat", action="store_true", help=default, docs=tuple(docs.split()) + ) + assert getattr(list_action, "help", None) == "foo\nbar" + assert getattr(tuple_action, "help", None) == "foo\nbar" class TestOptionalsParser: @@ -68,7 +72,9 @@ class TestOptionalsParser: # TODO: move this to a generic argparse fixture @pytest.fixture(autouse=True) def __setup_optionals_parser(self): - self.optionals_parser = argparse_helpers.mangle_parser(arghparse.OptionalsParser()) + self.optionals_parser = argparse_helpers.mangle_parser( + arghparse.OptionalsParser() + ) def test_no_args(self): args, unknown = self.optionals_parser.parse_known_optionals([]) @@ -76,14 +82,14 @@ class TestOptionalsParser: assert unknown == [] def test_only_positionals(self): - self.optionals_parser.add_argument('args') + self.optionals_parser.add_argument("args") args, unknown = self.optionals_parser.parse_known_optionals([]) - assert vars(args) == {'args': None} + assert vars(args) == {"args": None} assert unknown == [] def test_optionals(self): - self.optionals_parser.add_argument('--opt1') - self.optionals_parser.add_argument('args') + self.optionals_parser.add_argument("--opt1") + self.optionals_parser.add_argument("args") parse = self.optionals_parser.parse_known_optionals # no args @@ -92,37 +98,37 @@ class TestOptionalsParser: assert unknown == [] # only known optional - args, unknown = parse(['--opt1', 'yes']) - assert args.opt1 == 'yes' + args, unknown = parse(["--opt1", "yes"]) + assert args.opt1 == "yes" assert unknown == [] # unknown optional - args, unknown = parse(['--foo']) + args, unknown = parse(["--foo"]) assert args.opt1 is None - assert unknown == ['--foo'] + assert unknown == ["--foo"] # unknown optional and positional - args, unknown = parse(['--foo', 'arg']) + args, unknown = parse(["--foo", "arg"]) assert args.opt1 is None - assert unknown == ['--foo', 'arg'] + assert unknown == ["--foo", "arg"] # known optional with unknown optional - args, unknown = parse(['--opt1', 'yes', '--foo']) - assert args.opt1 == 'yes' - assert unknown == ['--foo'] + args, unknown = parse(["--opt1", "yes", "--foo"]) + assert args.opt1 == "yes" + assert unknown == ["--foo"] # different order - args, unknown = parse(['--foo', '--opt1', 'yes']) - assert args.opt1 == 'yes' - assert unknown == ['--foo'] + args, unknown = parse(["--foo", "--opt1", "yes"]) + assert args.opt1 == "yes" + assert unknown == ["--foo"] # known optional with unknown positional - args, unknown = parse(['--opt1', 'yes', 'arg']) - assert args.opt1 == 'yes' - assert unknown == ['arg'] + args, unknown = parse(["--opt1", "yes", "arg"]) + assert args.opt1 == "yes" + assert unknown == ["arg"] # known optionals parsing stops at the first positional arg - args, unknown = parse(['arg', '--opt1', 'yes']) + args, unknown = parse(["arg", "--opt1", "yes"]) assert args.opt1 is None - assert unknown == ['arg', '--opt1', 'yes'] + assert unknown == ["arg", "--opt1", "yes"] class TestCsvActionsParser: @@ -134,20 +140,19 @@ class TestCsvActionsParser: def test_bad_action(self): with pytest.raises(ValueError) as excinfo: - self.csv_parser.add_argument('--arg1', action='unknown') + self.csv_parser.add_argument("--arg1", action="unknown") assert 'unknown action "unknown"' == str(excinfo.value) def test_csv_actions(self): - self.csv_parser.add_argument('--arg1', action='csv') - self.csv_parser.add_argument('--arg2', action='csv_append') - self.csv_parser.add_argument('--arg3', action='csv_negations') - self.csv_parser.add_argument('--arg4', action='csv_negations_append') - self.csv_parser.add_argument('--arg5', action='csv_elements') - self.csv_parser.add_argument('--arg6', action='csv_elements_append') + self.csv_parser.add_argument("--arg1", action="csv") + self.csv_parser.add_argument("--arg2", action="csv_append") + self.csv_parser.add_argument("--arg3", action="csv_negations") + self.csv_parser.add_argument("--arg4", action="csv_negations_append") + self.csv_parser.add_argument("--arg5", action="csv_elements") + self.csv_parser.add_argument("--arg6", action="csv_elements_append") class TestArgumentParser(TestCsvActionsParser, TestOptionalsParser): - def test_debug(self): # debug passed parser = argparse_helpers.mangle_parser(arghparse.ArgumentParser(debug=True)) @@ -161,8 +166,10 @@ class TestArgumentParser(TestCsvActionsParser, TestOptionalsParser): assert namespace.debug is False # debug passed in sys.argv -- early debug attr on the parser instance is set - with mock.patch('sys.argv', ['script', '--debug']): - parser = argparse_helpers.mangle_parser(arghparse.ArgumentParser(debug=True)) + with mock.patch("sys.argv", ["script", "--debug"]): + parser = argparse_helpers.mangle_parser( + arghparse.ArgumentParser(debug=True) + ) assert parser.debug is True def test_debug_disabled(self): @@ -176,34 +183,36 @@ class TestArgumentParser(TestCsvActionsParser, TestOptionalsParser): # parser attribute still exists assert parser.debug is False # but namespace attribute doesn't - assert not hasattr(namespace, 'debug') + assert not hasattr(namespace, "debug") def test_verbosity(self): values = ( ([], 0), - (['-q'], -1), - (['--quiet'], -1), - (['-v'], 1), - (['--verbose'], 1), - (['-q', '-v'], 0), - (['--quiet', '--verbose'], 0), - (['-q', '-q'], -2), - (['-v', '-v'], 2), + (["-q"], -1), + (["--quiet"], -1), + (["-v"], 1), + (["--verbose"], 1), + (["-q", "-v"], 0), + (["--quiet", "--verbose"], 0), + (["-q", "-q"], -2), + (["-v", "-v"], 2), ) for args, val in values: - with mock.patch('sys.argv', ['script'] + args): + with mock.patch("sys.argv", ["script"] + args): parser = argparse_helpers.mangle_parser( - arghparse.ArgumentParser(quiet=True, verbose=True)) + arghparse.ArgumentParser(quiet=True, verbose=True) + ) namespace = parser.parse_args(args) - assert parser.verbosity == val, '{} failed'.format(args) - assert namespace.verbosity == val, '{} failed'.format(args) + assert parser.verbosity == val, "{} failed".format(args) + assert namespace.verbosity == val, "{} failed".format(args) def test_verbosity_disabled(self): parser = argparse_helpers.mangle_parser( - arghparse.ArgumentParser(quiet=False, verbose=False)) + arghparse.ArgumentParser(quiet=False, verbose=False) + ) # ensure the options aren't there if disabled - for args in ('-q', '--quiet', '-v', '--verbose'): + for args in ("-q", "--quiet", "-v", "--verbose"): with pytest.raises(argparse_helpers.Error): namespace = parser.parse_args([args]) @@ -211,17 +220,15 @@ class TestArgumentParser(TestCsvActionsParser, TestOptionalsParser): # parser attribute still exists assert parser.verbosity == 0 # but namespace attribute doesn't - assert not hasattr(namespace, 'verbosity') + assert not hasattr(namespace, "verbosity") class BaseArgparseOptions: - def setup_method(self, method): self.parser = argparse_helpers.mangle_parser(arghparse.ArgumentParser()) class TestStoreBoolAction(BaseArgparseOptions): - def setup_method(self, method): super().setup_method(method) self.parser.add_argument("--testing", action=arghparse.StoreBool, default=None) @@ -229,13 +236,13 @@ class TestStoreBoolAction(BaseArgparseOptions): def test_bool_disabled(self): for raw_val in ("n", "no", "false"): for allowed in (raw_val.upper(), raw_val.lower()): - namespace = self.parser.parse_args(['--testing=' + allowed]) + namespace = self.parser.parse_args(["--testing=" + allowed]) assert namespace.testing is False def test_bool_enabled(self): for raw_val in ("y", "yes", "true"): for allowed in (raw_val.upper(), raw_val.lower()): - namespace = self.parser.parse_args(['--testing=' + allowed]) + namespace = self.parser.parse_args(["--testing=" + allowed]) assert namespace.testing is True def test_bool_invalid(self): @@ -244,249 +251,244 @@ class TestStoreBoolAction(BaseArgparseOptions): class ParseStdinTest(BaseArgparseOptions): - def setup_method(self, method): super().setup_method(method) - self.parser.add_argument( - "testing", nargs='+', action=arghparse.ParseStdin) + self.parser.add_argument("testing", nargs="+", action=arghparse.ParseStdin) def test_none_invalid(self): with pytest.raises(argparse_helpers.Error): self.parser.parse_args([]) def test_non_stdin(self): - namespace = self.parser.parse_args(['foo']) - assert namespace.testing == ['foo'] + namespace = self.parser.parse_args(["foo"]) + assert namespace.testing == ["foo"] def test_non_stdin_multiple(self): - namespace = self.parser.parse_args(['foo', 'bar']) - assert namespace.testing == ['foo', 'bar'] + namespace = self.parser.parse_args(["foo", "bar"]) + assert namespace.testing == ["foo", "bar"] def test_stdin(self): # stdin is an interactive tty - with mock.patch('sys.stdin.isatty', return_value=True): + with mock.patch("sys.stdin.isatty", return_value=True): with pytest.raises(argparse_helpers.Error) as excinfo: - namespace = self.parser.parse_args(['-']) - assert 'only valid when piping data in' in str(excinfo.value) + namespace = self.parser.parse_args(["-"]) + assert "only valid when piping data in" in str(excinfo.value) # fake piping data in for readlines, expected in ( - ([], []), - ([' '], []), - (['\n'], []), - (['\n', '\n'], []), - (['foo'], ['foo']), - (['foo '], ['foo']), - (['foo\n'], ['foo']), - (['foo', 'bar', 'baz'], ['foo', 'bar', 'baz']), - (['\nfoo\n', ' bar ', '\nbaz'], ['\nfoo', ' bar', '\nbaz']), + ([], []), + ([" "], []), + (["\n"], []), + (["\n", "\n"], []), + (["foo"], ["foo"]), + (["foo "], ["foo"]), + (["foo\n"], ["foo"]), + (["foo", "bar", "baz"], ["foo", "bar", "baz"]), + (["\nfoo\n", " bar ", "\nbaz"], ["\nfoo", " bar", "\nbaz"]), ): - with mock.patch('sys.stdin') as stdin, \ - mock.patch("builtins.open", mock.mock_open()) as mock_file: + with mock.patch("sys.stdin") as stdin, mock.patch( + "builtins.open", mock.mock_open() + ) as mock_file: stdin.readlines.return_value = readlines stdin.isatty.return_value = False - namespace = self.parser.parse_args(['-']) + namespace = self.parser.parse_args(["-"]) mock_file.assert_called_once_with("/dev/tty") assert namespace.testing == expected class TestCommaSeparatedValuesAction(BaseArgparseOptions): - def setup_method(self, method): super().setup_method(method) self.test_values = ( - ('', []), - (',', []), - (',,', []), - ('a', ['a']), - ('a,b,-c', ['a', 'b', '-c']), + ("", []), + (",", []), + (",,", []), + ("a", ["a"]), + ("a,b,-c", ["a", "b", "-c"]), ) - self.action = 'csv' + self.action = "csv" self.single_expected = lambda x: x self.multi_expected = lambda x: x def test_parse_args(self): - self.parser.add_argument('--testing', action=self.action) + self.parser.add_argument("--testing", action=self.action) for raw_val, expected in self.test_values: - namespace = self.parser.parse_args(['--testing=' + raw_val]) + namespace = self.parser.parse_args(["--testing=" + raw_val]) assert namespace.testing == self.single_expected(expected) def test_parse_multi_args(self): - self.parser.add_argument('--testing', action=self.action) + self.parser.add_argument("--testing", action=self.action) for raw_val, expected in self.test_values: - namespace = self.parser.parse_args([ - '--testing=' + raw_val, '--testing=' + raw_val, - ]) + namespace = self.parser.parse_args( + [ + "--testing=" + raw_val, + "--testing=" + raw_val, + ] + ) assert namespace.testing == self.multi_expected(expected) class TestCommaSeparatedValuesAppendAction(TestCommaSeparatedValuesAction): - def setup_method(self, method): super().setup_method(method) - self.action = 'csv_append' + self.action = "csv_append" self.multi_expected = lambda x: x + x class TestCommaSeparatedNegationsAction(TestCommaSeparatedValuesAction): - def setup_method(self, method): super().setup_method(method) self.test_values = ( - ('', ([], [])), - (',', ([], [])), - (',,', ([], [])), - ('a', ([], ['a'])), - ('-a', (['a'], [])), - ('a,-b,-c,d', (['b', 'c'], ['a', 'd'])), + ("", ([], [])), + (",", ([], [])), + (",,", ([], [])), + ("a", ([], ["a"])), + ("-a", (["a"], [])), + ("a,-b,-c,d", (["b", "c"], ["a", "d"])), ) - self.bad_args = ('-',) - self.action = 'csv_negations' + self.bad_args = ("-",) + self.action = "csv_negations" def test_parse_bad_args(self): - self.parser.add_argument('--testing', action=self.action) + self.parser.add_argument("--testing", action=self.action) for arg in self.bad_args: with pytest.raises(argparse.ArgumentTypeError) as excinfo: - namespace = self.parser.parse_args(['--testing=' + arg]) - assert 'without a token' in str(excinfo.value) + namespace = self.parser.parse_args(["--testing=" + arg]) + assert "without a token" in str(excinfo.value) class TestCommaSeparatedNegationsAppendAction(TestCommaSeparatedNegationsAction): - def setup_method(self, method): super().setup_method(method) - self.action = 'csv_negations_append' + self.action = "csv_negations_append" self.multi_expected = lambda x: tuple(x + y for x, y in zip(x, x)) class TestCommaSeparatedElementsAction(TestCommaSeparatedNegationsAction): - def setup_method(self, method): super().setup_method(method) self.test_values = ( - ('', ([], [], [])), - (',', ([], [], [])), - (',,', ([], [], [])), - ('-a', (['a'], [], [])), - ('a', ([], ['a'], [])), - ('+a', ([], [], ['a'])), - ('a,-b,-c,d', (['b', 'c'], ['a', 'd'], [])), - ('a,-b,+c,-d,+e,f', (['b', 'd'], ['a', 'f'], ['c', 'e'])), + ("", ([], [], [])), + (",", ([], [], [])), + (",,", ([], [], [])), + ("-a", (["a"], [], [])), + ("a", ([], ["a"], [])), + ("+a", ([], [], ["a"])), + ("a,-b,-c,d", (["b", "c"], ["a", "d"], [])), + ("a,-b,+c,-d,+e,f", (["b", "d"], ["a", "f"], ["c", "e"])), ) - self.bad_values = ('-', '+') - self.action = 'csv_elements' + self.bad_values = ("-", "+") + self.action = "csv_elements" class TestCommaSeparatedElementsAppendAction(TestCommaSeparatedElementsAction): - def setup_method(self, method): super().setup_method(method) - self.action = 'csv_elements_append' + self.action = "csv_elements_append" self.multi_expected = lambda x: tuple(x + y for x, y in zip(x, x)) class TestExistentPathType(BaseArgparseOptions): - def setup_method(self, method): super().setup_method(method) - self.parser.add_argument('--path', type=arghparse.existent_path) + self.parser.add_argument("--path", type=arghparse.existent_path) def test_nonexistent(self): # nonexistent path arg raises an error with pytest.raises(argparse_helpers.Error): - self.parser.parse_args(['--path=/path/to/nowhere']) + self.parser.parse_args(["--path=/path/to/nowhere"]) def test_os_errors(self, tmpdir): # random OS/FS issues raise errors - with mock.patch('os.path.realpath') as realpath: - realpath.side_effect = OSError(19, 'Random OS error') + with mock.patch("os.path.realpath") as realpath: + realpath.side_effect = OSError(19, "Random OS error") with pytest.raises(argparse_helpers.Error): - self.parser.parse_args(['--path=%s' % tmpdir]) + self.parser.parse_args(["--path=%s" % tmpdir]) def test_regular_usage(self, tmpdir): - namespace = self.parser.parse_args(['--path=%s' % tmpdir]) + namespace = self.parser.parse_args(["--path=%s" % tmpdir]) assert namespace.path == str(tmpdir) class TestExistentDirType(BaseArgparseOptions): - def setup_method(self, method): super().setup_method(method) - self.parser.add_argument('--path', type=arghparse.existent_dir) + self.parser.add_argument("--path", type=arghparse.existent_dir) def test_nonexistent(self): # nonexistent path arg raises an error with pytest.raises(argparse_helpers.Error): - self.parser.parse_args(['--path=/path/to/nowhere']) + self.parser.parse_args(["--path=/path/to/nowhere"]) def test_os_errors(self, tmp_path): # random OS/FS issues raise errors - with mock.patch('os.path.realpath') as realpath: - realpath.side_effect = OSError(19, 'Random OS error') + with mock.patch("os.path.realpath") as realpath: + realpath.side_effect = OSError(19, "Random OS error") with pytest.raises(argparse_helpers.Error): - self.parser.parse_args([f'--path={tmp_path}']) + self.parser.parse_args([f"--path={tmp_path}"]) def test_file_path(self, tmp_path): - f = tmp_path / 'file' + f = tmp_path / "file" f.touch() with pytest.raises(argparse_helpers.Error): - self.parser.parse_args([f'--path={f}']) + self.parser.parse_args([f"--path={f}"]) def test_regular_usage(self, tmp_path): - namespace = self.parser.parse_args([f'--path={tmp_path}']) + namespace = self.parser.parse_args([f"--path={tmp_path}"]) assert namespace.path == str(tmp_path) class TestNamespace: - def setup_method(self, method): self.parser = argparse_helpers.mangle_parser(arghparse.ArgumentParser()) def test_pop(self): - self.parser.set_defaults(test='test') + self.parser.set_defaults(test="test") namespace = self.parser.parse_args([]) - assert namespace.pop('test') == 'test' + assert namespace.pop("test") == "test" # re-popping raises an exception since the attr has been removed with pytest.raises(AttributeError): - namespace.pop('test') + namespace.pop("test") # popping a nonexistent attr with a fallback returns the fallback - assert namespace.pop('nonexistent', 'foo') == 'foo' + assert namespace.pop("nonexistent", "foo") == "foo" def test_collapse_delayed(self): def _delayed_val(namespace, attr, val): setattr(namespace, attr, val) - self.parser.set_defaults(delayed=arghparse.DelayedValue(partial(_delayed_val, val=42))) + + self.parser.set_defaults( + delayed=arghparse.DelayedValue(partial(_delayed_val, val=42)) + ) namespace = self.parser.parse_args([]) assert namespace.delayed == 42 def test_bool(self): namespace = arghparse.Namespace() assert not namespace - namespace.arg = 'foo' + namespace.arg = "foo" assert namespace class TestManHelpAction: - def test_help(self, capsys): parser = argparse_helpers.mangle_parser(arghparse.ArgumentParser()) - with mock.patch('subprocess.Popen') as popen: + with mock.patch("subprocess.Popen") as popen: # --help long option tries man page first before falling back to help output with pytest.raises(argparse_helpers.Exit): - namespace = parser.parse_args(['--help']) + namespace = parser.parse_args(["--help"]) popen.assert_called_once() - assert popen.call_args[0][0][0] == 'man' + assert popen.call_args[0][0][0] == "man" captured = capsys.readouterr() - assert captured.out.strip().startswith('usage: ') + assert captured.out.strip().startswith("usage: ") popen.reset_mock() # -h short option just displays the regular help output with pytest.raises(argparse_helpers.Exit): - namespace = parser.parse_args(['-h']) + namespace = parser.parse_args(["-h"]) popen.assert_not_called() captured = capsys.readouterr() - assert captured.out.strip().startswith('usage: ') + assert captured.out.strip().startswith("usage: ") popen.reset_mock() diff --git a/tests/cli/test_input.py b/tests/cli/test_input.py index 8efb1f5..2f15f25 100644 --- a/tests/cli/test_input.py +++ b/tests/cli/test_input.py @@ -9,12 +9,11 @@ from snakeoil.test.argparse_helpers import FakeStreamFormatter @pytest.fixture def mocked_input(): - with mock.patch('builtins.input') as mocked_input: + with mock.patch("builtins.input") as mocked_input: yield mocked_input class TestUserQuery: - @pytest.fixture(autouse=True) def __setup(self): self.out = FakeStreamFormatter() @@ -22,98 +21,104 @@ class TestUserQuery: self.query = partial(userquery, out=self.out, err=self.err) def test_default_answer(self, mocked_input): - mocked_input.return_value = '' - assert self.query('foo') == True + mocked_input.return_value = "" + assert self.query("foo") == True def test_tuple_prompt(self, mocked_input): - mocked_input.return_value = '' - prompt = 'perhaps a tuple' + mocked_input.return_value = "" + prompt = "perhaps a tuple" assert self.query(tuple(prompt.split())) == True - output = ''.join(prompt.split()) - assert self.out.get_text_stream().strip().split('\n')[0][:len(output)] == output + output = "".join(prompt.split()) + assert ( + self.out.get_text_stream().strip().split("\n")[0][: len(output)] == output + ) def test_no_default_answer(self, mocked_input): responses = { - 'a': ('z', 'Yes'), - 'b': ('y', 'No'), + "a": ("z", "Yes"), + "b": ("y", "No"), } # no default answer returns None for empty input - mocked_input.return_value = '' - assert self.query('foo', responses=responses) == None - mocked_input.return_value = 'a' - assert self.query('foo', responses=responses) == 'z' - mocked_input.return_value = 'b' - assert self.query('foo', responses=responses) == 'y' + mocked_input.return_value = "" + assert self.query("foo", responses=responses) == None + mocked_input.return_value = "a" + assert self.query("foo", responses=responses) == "z" + mocked_input.return_value = "b" + assert self.query("foo", responses=responses) == "y" def test_ambiguous_input(self, mocked_input): responses = { - 'a': ('z', 'Yes'), - 'A': ('y', 'No'), + "a": ("z", "Yes"), + "A": ("y", "No"), } - mocked_input.return_value = 'a' + mocked_input.return_value = "a" with pytest.raises(NoChoice): - self.query('foo', responses=responses) - error_output = self.err.get_text_stream().strip().split('\n')[1] - expected = 'Response %r is ambiguous (%s)' % ( - mocked_input.return_value, ', '.join(sorted(responses.keys()))) + self.query("foo", responses=responses) + error_output = self.err.get_text_stream().strip().split("\n")[1] + expected = "Response %r is ambiguous (%s)" % ( + mocked_input.return_value, + ", ".join(sorted(responses.keys())), + ) assert error_output == expected def test_default_correct_input(self, mocked_input): - for input, output in (('no', False), - ('No', False), - ('yes', True), - ('Yes', True)): + for input, output in ( + ("no", False), + ("No", False), + ("yes", True), + ("Yes", True), + ): mocked_input.return_value = input - assert self.query('foo') == output + assert self.query("foo") == output def test_default_answer_no_matches(self, mocked_input): - mocked_input.return_value = '' + mocked_input.return_value = "" with pytest.raises(ValueError): - self.query('foo', default_answer='foo') + self.query("foo", default_answer="foo") assert self.out.stream == [] def test_custom_default_answer(self, mocked_input): - mocked_input.return_value = '' - assert self.query('foo', default_answer=False) == False + mocked_input.return_value = "" + assert self.query("foo", default_answer=False) == False def test_eof_nochoice(self, mocked_input): # user hits ctrl-d mocked_input.side_effect = EOFError with pytest.raises(NoChoice): - self.query('foo') - output = self.out.get_text_stream().strip().split('\n')[1] - expected = 'Not answerable: EOF on STDIN' + self.query("foo") + output = self.out.get_text_stream().strip().split("\n")[1] + expected = "Not answerable: EOF on STDIN" assert output == expected def test_stdin_closed_nochoice(self, mocked_input): - mocked_input.side_effect = IOError(errno.EBADF, '') + mocked_input.side_effect = IOError(errno.EBADF, "") with pytest.raises(NoChoice): - self.query('foo') - output = self.out.get_text_stream().strip().split('\n')[1] - expected = 'Not answerable: STDIN is either closed, or not readable' + self.query("foo") + output = self.out.get_text_stream().strip().split("\n")[1] + expected = "Not answerable: STDIN is either closed, or not readable" assert output == expected def test_unhandled_ioerror(self, mocked_input): - mocked_input.side_effect = IOError(errno.ENODEV, '') + mocked_input.side_effect = IOError(errno.ENODEV, "") with pytest.raises(IOError): - self.query('foo') + self.query("foo") def test_bad_choice_limit(self, mocked_input): # user hits enters a bad choice 3 times in a row - mocked_input.return_value = 'bad' + mocked_input.return_value = "bad" with pytest.raises(NoChoice): - self.query('foo') + self.query("foo") assert mocked_input.call_count == 3 - output = self.err.get_text_stream().strip().split('\n')[1] + output = self.err.get_text_stream().strip().split("\n")[1] expected = "Sorry, response %r not understood." % (mocked_input.return_value,) assert output == expected def test_custom_choice_limit(self, mocked_input): # user hits enters a bad choice 5 times in a row - mocked_input.return_value = 'haha' + mocked_input.return_value = "haha" with pytest.raises(NoChoice): - self.query('foo', limit=5) + self.query("foo", limit=5) assert mocked_input.call_count == 5 - output = self.err.get_text_stream().strip().split('\n')[1] + output = self.err.get_text_stream().strip().split("\n")[1] expected = "Sorry, response %r not understood." % (mocked_input.return_value,) assert output == expected diff --git a/tests/compression/__init__.py b/tests/compression/__init__.py index 3b70dcb..0bf26d0 100644 --- a/tests/compression/__init__.py +++ b/tests/compression/__init__.py @@ -4,78 +4,100 @@ import pytest from snakeoil import compression from snakeoil.process import CommandNotFound, find_binary + def hide_binary(*binaries: str): def mock_find_binary(name): if name in binaries: raise CommandNotFound(name) return find_binary(name) - return patch('snakeoil.process.find_binary', side_effect=mock_find_binary) + return patch("snakeoil.process.find_binary", side_effect=mock_find_binary) class Base: - module: str = '' - decompressed_test_data: bytes = b'' - compressed_test_data: bytes = b'' + module: str = "" + decompressed_test_data: bytes = b"" + compressed_test_data: bytes = b"" def decompress(self, data: bytes) -> bytes: - raise NotImplementedError(self, 'decompress') + raise NotImplementedError(self, "decompress") - @pytest.mark.parametrize('parallelize', (True, False)) - @pytest.mark.parametrize('level', (1, 9)) + @pytest.mark.parametrize("parallelize", (True, False)) + @pytest.mark.parametrize("level", (1, 9)) def test_compress_data(self, level, parallelize): - compressed = compression.compress_data(self.module, self.decompressed_test_data, level=level, parallelize=parallelize) + compressed = compression.compress_data( + self.module, + self.decompressed_test_data, + level=level, + parallelize=parallelize, + ) assert compressed assert self.decompress(compressed) == self.decompressed_test_data - @pytest.mark.parametrize('parallelize', (True, False)) + @pytest.mark.parametrize("parallelize", (True, False)) def test_decompress_data(self, parallelize): - assert self.decompressed_test_data == compression.decompress_data(self.module, self.compressed_test_data, parallelize=parallelize) + assert self.decompressed_test_data == compression.decompress_data( + self.module, self.compressed_test_data, parallelize=parallelize + ) - @pytest.mark.parametrize('parallelize', (True, False)) - @pytest.mark.parametrize('level', (1, 9)) + @pytest.mark.parametrize("parallelize", (True, False)) + @pytest.mark.parametrize("level", (1, 9)) def test_compress_handle(self, tmp_path, level, parallelize): - path = tmp_path / f'test.{self.module}' + path = tmp_path / f"test.{self.module}" - stream = compression.compress_handle(self.module, str(path), level=level, parallelize=parallelize) + stream = compression.compress_handle( + self.module, str(path), level=level, parallelize=parallelize + ) stream.write(self.decompressed_test_data) stream.close() assert self.decompress(path.read_bytes()) == self.decompressed_test_data with path.open("wb") as file: - stream = compression.compress_handle(self.module, file, level=level, parallelize=parallelize) + stream = compression.compress_handle( + self.module, file, level=level, parallelize=parallelize + ) stream.write(self.decompressed_test_data) stream.close() assert self.decompress(path.read_bytes()) == self.decompressed_test_data with path.open("wb") as file: - stream = compression.compress_handle(self.module, file.fileno(), level=level, parallelize=parallelize) + stream = compression.compress_handle( + self.module, file.fileno(), level=level, parallelize=parallelize + ) stream.write(self.decompressed_test_data) stream.close() assert self.decompress(path.read_bytes()) == self.decompressed_test_data with pytest.raises(TypeError): - compression.compress_handle(self.module, b'', level=level, parallelize=parallelize) + compression.compress_handle( + self.module, b"", level=level, parallelize=parallelize + ) - @pytest.mark.parametrize('parallelize', (True, False)) + @pytest.mark.parametrize("parallelize", (True, False)) def test_decompress_handle(self, tmp_path, parallelize): - path = tmp_path / f'test.{self.module}' + path = tmp_path / f"test.{self.module}" path.write_bytes(self.compressed_test_data) - stream = compression.decompress_handle(self.module, str(path), parallelize=parallelize) + stream = compression.decompress_handle( + self.module, str(path), parallelize=parallelize + ) assert stream.read() == self.decompressed_test_data stream.close() with path.open("rb") as file: - stream = compression.decompress_handle(self.module, file, parallelize=parallelize) + stream = compression.decompress_handle( + self.module, file, parallelize=parallelize + ) assert stream.read() == self.decompressed_test_data stream.close() with path.open("rb") as file: - stream = compression.decompress_handle(self.module, file.fileno(), parallelize=parallelize) + stream = compression.decompress_handle( + self.module, file.fileno(), parallelize=parallelize + ) assert stream.read() == self.decompressed_test_data stream.close() with pytest.raises(TypeError): - compression.decompress_handle(self.module, b'', parallelize=parallelize) + compression.decompress_handle(self.module, b"", parallelize=parallelize) diff --git a/tests/compression/test_bzip2.py b/tests/compression/test_bzip2.py index f3093d0..9fdffd9 100644 --- a/tests/compression/test_bzip2.py +++ b/tests/compression/test_bzip2.py @@ -10,28 +10,29 @@ from . import Base, hide_binary def test_no_native(): - with hide_imports('bz2'): + with hide_imports("bz2"): importlib.reload(_bzip2) assert not _bzip2.native def test_missing_bzip2_binary(): - with hide_binary('bzip2'): - with pytest.raises(CommandNotFound, match='bzip2'): + with hide_binary("bzip2"): + with pytest.raises(CommandNotFound, match="bzip2"): importlib.reload(_bzip2) def test_missing_lbzip2_binary(): - with hide_binary('lbzip2'): + with hide_binary("lbzip2"): importlib.reload(_bzip2) assert not _bzip2.parallelizable + class Bzip2Base(Base): - module = 'bzip2' - decompressed_test_data = b'Some text here\n' + module = "bzip2" + decompressed_test_data = b"Some text here\n" compressed_test_data = ( - b'BZh91AY&SY\x1bM\x00\x02\x00\x00\x01\xd3\x80\x00\x10@\x00\x08\x00\x02' + b"BZh91AY&SY\x1bM\x00\x02\x00\x00\x01\xd3\x80\x00\x10@\x00\x08\x00\x02" b'B\x94@ \x00"\r\x03\xd4\x0c \t!\x1b\xb7\x80u/\x17rE8P\x90\x1bM\x00\x02' ) @@ -40,37 +41,36 @@ class Bzip2Base(Base): class TestStdlib(Bzip2Base): - - @pytest.fixture(autouse=True, scope='class') + @pytest.fixture(autouse=True, scope="class") def _setup(self): try: - find_binary('bzip2') + find_binary("bzip2") except CommandNotFound: - pytest.skip('bzip2 binary not found') - with hide_binary('lbzip2'): + pytest.skip("bzip2 binary not found") + with hide_binary("lbzip2"): importlib.reload(_bzip2) yield class TestBzip2(Bzip2Base): - - @pytest.fixture(autouse=True, scope='class') + @pytest.fixture(autouse=True, scope="class") def _setup(self): - with hide_binary('lbzip2'): + with hide_binary("lbzip2"): importlib.reload(_bzip2) yield class TestLbzip2(Bzip2Base): - - @pytest.fixture(autouse=True, scope='class') + @pytest.fixture(autouse=True, scope="class") def _setup(self): try: - find_binary('lbzip2') + find_binary("lbzip2") except CommandNotFound: - pytest.skip('lbzip2 binary not found') + pytest.skip("lbzip2 binary not found") importlib.reload(_bzip2) def test_bad_level(self): with pytest.raises(ValueError, match='unknown option "-0"'): - _bzip2.compress_data(self.decompressed_test_data, level=90, parallelize=True) + _bzip2.compress_data( + self.decompressed_test_data, level=90, parallelize=True + ) diff --git a/tests/compression/test_init.py b/tests/compression/test_init.py index f3a4027..f1fe5bd 100644 --- a/tests/compression/test_init.py +++ b/tests/compression/test_init.py @@ -11,78 +11,77 @@ from . import hide_binary @pytest.mark.skipif(sys.platform == "darwin", reason="darwin fails with bzip2") class TestArComp: - - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def tar_file(self, tmp_path_factory): data = tmp_path_factory.mktemp("data") - (data / 'file1').write_text('Hello world') - (data / 'file2').write_text('Larry the Cow') - path = data / 'test 1.tar' - subprocess.run(['tar', 'cf', str(path), 'file1', 'file2'], cwd=data, check=True) - (data / 'file1').unlink() - (data / 'file2').unlink() + (data / "file1").write_text("Hello world") + (data / "file2").write_text("Larry the Cow") + path = data / "test 1.tar" + subprocess.run(["tar", "cf", str(path), "file1", "file2"], cwd=data, check=True) + (data / "file1").unlink() + (data / "file2").unlink() return str(path) - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def tar_bz2_file(self, tar_file): - subprocess.run(['bzip2', '-z', '-k', tar_file], check=True) + subprocess.run(["bzip2", "-z", "-k", tar_file], check=True) return tar_file + ".bz2" - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def tbz2_file(self, tar_bz2_file): - new_path = tar_bz2_file.replace('.tar.bz2', '.tbz2') + new_path = tar_bz2_file.replace(".tar.bz2", ".tbz2") shutil.copyfile(tar_bz2_file, new_path) return new_path - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def lzma_file(self, tmp_path_factory): - data = (tmp_path_factory.mktemp("data") / 'test 2.lzma') - with data.open('wb') as f: - subprocess.run(['lzma'], check=True, input=b'Hello world', stdout=f) + data = tmp_path_factory.mktemp("data") / "test 2.lzma" + with data.open("wb") as f: + subprocess.run(["lzma"], check=True, input=b"Hello world", stdout=f) return str(data) def test_unknown_extenstion(self, tmp_path): - file = tmp_path / 'test.file' - with pytest.raises(ArCompError, match='unknown compression file extension'): - ArComp(file, ext='.foo') + file = tmp_path / "test.file" + with pytest.raises(ArCompError, match="unknown compression file extension"): + ArComp(file, ext=".foo") def test_missing_tar(self, tmp_path, tar_file): - with hide_binary('tar'), chdir(tmp_path): - with pytest.raises(ArCompError, match='required binary not found'): - ArComp(tar_file, ext='.tar').unpack(dest=tmp_path) + with hide_binary("tar"), chdir(tmp_path): + with pytest.raises(ArCompError, match="required binary not found"): + ArComp(tar_file, ext=".tar").unpack(dest=tmp_path) def test_tar(self, tmp_path, tar_file): with chdir(tmp_path): - ArComp(tar_file, ext='.tar').unpack(dest=tmp_path) - assert (tmp_path / 'file1').read_text() == 'Hello world' - assert (tmp_path / 'file2').read_text() == 'Larry the Cow' + ArComp(tar_file, ext=".tar").unpack(dest=tmp_path) + assert (tmp_path / "file1").read_text() == "Hello world" + assert (tmp_path / "file2").read_text() == "Larry the Cow" def test_tar_bz2(self, tmp_path, tar_bz2_file): with chdir(tmp_path): - ArComp(tar_bz2_file, ext='.tar.bz2').unpack(dest=tmp_path) - assert (tmp_path / 'file1').read_text() == 'Hello world' - assert (tmp_path / 'file2').read_text() == 'Larry the Cow' + ArComp(tar_bz2_file, ext=".tar.bz2").unpack(dest=tmp_path) + assert (tmp_path / "file1").read_text() == "Hello world" + assert (tmp_path / "file2").read_text() == "Larry the Cow" def test_tbz2(self, tmp_path, tbz2_file): with chdir(tmp_path): - ArComp(tbz2_file, ext='.tbz2').unpack(dest=tmp_path) - assert (tmp_path / 'file1').read_text() == 'Hello world' - assert (tmp_path / 'file2').read_text() == 'Larry the Cow' + ArComp(tbz2_file, ext=".tbz2").unpack(dest=tmp_path) + assert (tmp_path / "file1").read_text() == "Hello world" + assert (tmp_path / "file2").read_text() == "Larry the Cow" def test_fallback_tbz2(self, tmp_path, tbz2_file): with hide_binary(*next(zip(*_TarBZ2.compress_binary[:-1]))): with chdir(tmp_path): - ArComp(tbz2_file, ext='.tbz2').unpack(dest=tmp_path) - assert (tmp_path / 'file1').read_text() == 'Hello world' - assert (tmp_path / 'file2').read_text() == 'Larry the Cow' + ArComp(tbz2_file, ext=".tbz2").unpack(dest=tmp_path) + assert (tmp_path / "file1").read_text() == "Hello world" + assert (tmp_path / "file2").read_text() == "Larry the Cow" def test_no_fallback_tbz2(self, tmp_path, tbz2_file): with hide_binary(*next(zip(*_TarBZ2.compress_binary))), chdir(tmp_path): - with pytest.raises(ArCompError, match='no compression binary'): - ArComp(tbz2_file, ext='.tbz2').unpack(dest=tmp_path) + with pytest.raises(ArCompError, match="no compression binary"): + ArComp(tbz2_file, ext=".tbz2").unpack(dest=tmp_path) def test_lzma(self, tmp_path, lzma_file): - dest = tmp_path / 'file' + dest = tmp_path / "file" with chdir(tmp_path): - ArComp(lzma_file, ext='.lzma').unpack(dest=dest) - assert (dest).read_bytes() == b'Hello world' + ArComp(lzma_file, ext=".lzma").unpack(dest=dest) + assert (dest).read_bytes() == b"Hello world" diff --git a/tests/compression/test_xz.py b/tests/compression/test_xz.py index f8417b3..0af7c64 100644 --- a/tests/compression/test_xz.py +++ b/tests/compression/test_xz.py @@ -10,26 +10,26 @@ from . import Base, hide_binary def test_no_native(): - with hide_imports('lzma'): + with hide_imports("lzma"): importlib.reload(_xz) assert not _xz.native def test_missing_xz_binary(): - with hide_binary('xz'): - with pytest.raises(CommandNotFound, match='xz'): + with hide_binary("xz"): + with pytest.raises(CommandNotFound, match="xz"): importlib.reload(_xz) class XzBase(Base): - module = 'xz' - decompressed_test_data = b'Some text here\n' * 2 + module = "xz" + decompressed_test_data = b"Some text here\n" * 2 compressed_test_data = ( - b'\xfd7zXZ\x00\x00\x04\xe6\xd6\xb4F\x04\xc0\x1e\x1e!\x01\x16\x00\x00\x00' - b'\x00\x00\x00\x00\x00\x00j\xf6\x947\xe0\x00\x1d\x00\x16]\x00)\x9b\xc9\xa6g' - b'Bw\x8c\xb3\x9eA\x9a\xbeT\xc9\xfa\xe3\x19\x8f(\x00\x00\x00\x00\x00\x96N' - b'\xa8\x8ed\xa2WH\x00\x01:\x1e1V \xff\x1f\xb6\xf3}\x01\x00\x00\x00\x00\x04YZ' + b"\xfd7zXZ\x00\x00\x04\xe6\xd6\xb4F\x04\xc0\x1e\x1e!\x01\x16\x00\x00\x00" + b"\x00\x00\x00\x00\x00\x00j\xf6\x947\xe0\x00\x1d\x00\x16]\x00)\x9b\xc9\xa6g" + b"Bw\x8c\xb3\x9eA\x9a\xbeT\xc9\xfa\xe3\x19\x8f(\x00\x00\x00\x00\x00\x96N" + b"\xa8\x8ed\xa2WH\x00\x01:\x1e1V \xff\x1f\xb6\xf3}\x01\x00\x00\x00\x00\x04YZ" ) def decompress(self, data: bytes) -> bytes: @@ -37,20 +37,18 @@ class XzBase(Base): class TestStdlib(XzBase): - - @pytest.fixture(autouse=True, scope='class') + @pytest.fixture(autouse=True, scope="class") def _setup(self): try: - find_binary('xz') + find_binary("xz") except CommandNotFound: - pytest.skip('xz binary not found') + pytest.skip("xz binary not found") importlib.reload(_xz) class TestXz(XzBase): - - @pytest.fixture(autouse=True, scope='class') + @pytest.fixture(autouse=True, scope="class") def _setup(self): - with hide_imports('lzma'): + with hide_imports("lzma"): importlib.reload(_xz) yield diff --git a/tests/test_bash.py b/tests/test_bash.py index ec9df53..3d2157b 100644 --- a/tests/test_bash.py +++ b/tests/test_bash.py @@ -1,128 +1,138 @@ from io import StringIO import pytest -from snakeoil.bash import (BashParseError, iter_read_bash, read_bash, - read_bash_dict, read_dict) +from snakeoil.bash import ( + BashParseError, + iter_read_bash, + read_bash, + read_bash_dict, + read_dict, +) class TestBashCommentStripping: - def test_iter_read_bash(self): - output = iter_read_bash(StringIO( - '\n' - '# hi I am a comment\n' - 'I am not \n' - ' asdf # inline comment\n')) - assert list(output) == ['I am not', 'asdf'] + output = iter_read_bash( + StringIO( + "\n" "# hi I am a comment\n" "I am not \n" " asdf # inline comment\n" + ) + ) + assert list(output) == ["I am not", "asdf"] - output = iter_read_bash(StringIO( - 'inline # comment '), allow_inline_comments=False) - assert list(output) == ['inline # comment'] + output = iter_read_bash( + StringIO("inline # comment "), allow_inline_comments=False + ) + assert list(output) == ["inline # comment"] def test_iter_read_bash_line_cont(self): - output = iter_read_bash(StringIO( - '\n' - '# hi I am a comment\\\n' - 'I am not \\\n' - 'a comment \n' - ' asdf # inline comment\\\n'), - allow_line_cont=True) - assert list(output) == ['I am not a comment', 'asdf'] + output = iter_read_bash( + StringIO( + "\n" + "# hi I am a comment\\\n" + "I am not \\\n" + "a comment \n" + " asdf # inline comment\\\n" + ), + allow_line_cont=True, + ) + assert list(output) == ["I am not a comment", "asdf"] # continuation into inline comment - output = iter_read_bash(StringIO( - '\n' - '# hi I am a comment\n' - 'I am \\\n' - 'not a \\\n' - 'comment # inline comment\n'), - allow_line_cont=True) - assert list(output) == ['I am not a comment'] + output = iter_read_bash( + StringIO( + "\n" + "# hi I am a comment\n" + "I am \\\n" + "not a \\\n" + "comment # inline comment\n" + ), + allow_line_cont=True, + ) + assert list(output) == ["I am not a comment"] # ends with continuation - output = iter_read_bash(StringIO( - '\n' - '# hi I am a comment\n' - 'I am \\\n' - '\\\n' - 'not a \\\n' - 'comment\\\n' - '\\\n'), - allow_line_cont=True) - assert list(output) == ['I am not a comment'] + output = iter_read_bash( + StringIO( + "\n" + "# hi I am a comment\n" + "I am \\\n" + "\\\n" + "not a \\\n" + "comment\\\n" + "\\\n" + ), + allow_line_cont=True, + ) + assert list(output) == ["I am not a comment"] # embedded comment prefix via continued lines - output = iter_read_bash(StringIO( - '\\\n' - '# comment\\\n' - ' not a comment\n' - '\\\n' - ' # inner comment\n' - 'also not\\\n' - '#\\\n' - 'a comment\n'), - allow_line_cont=True) - assert list(output) == ['not a comment', 'also not#a comment'] + output = iter_read_bash( + StringIO( + "\\\n" + "# comment\\\n" + " not a comment\n" + "\\\n" + " # inner comment\n" + "also not\\\n" + "#\\\n" + "a comment\n" + ), + allow_line_cont=True, + ) + assert list(output) == ["not a comment", "also not#a comment"] # Line continuations have to end with \ without any backslash # before the pattern. - output = iter_read_bash(StringIO( - 'I am \\ \n' - 'not a comment'), - allow_line_cont=True) - assert list(output) == ['I am \\', 'not a comment'] - output = iter_read_bash(StringIO( - '\\\n' - 'I am \\\\\n' - 'not a comment'), - allow_line_cont=True) - assert list(output) == ['I am \\\\', 'not a comment'] + output = iter_read_bash( + StringIO("I am \\ \n" "not a comment"), allow_line_cont=True + ) + assert list(output) == ["I am \\", "not a comment"] + output = iter_read_bash( + StringIO("\\\n" "I am \\\\\n" "not a comment"), allow_line_cont=True + ) + assert list(output) == ["I am \\\\", "not a comment"] def test_read_bash(self): - output = read_bash(StringIO( - '\n' - '# hi I am a comment\n' - 'I am not\n')) - assert output == ['I am not'] + output = read_bash(StringIO("\n" "# hi I am a comment\n" "I am not\n")) + assert output == ["I am not"] class TestReadDictConfig: - def test_read_dict(self): - bash_dict = read_dict(StringIO( - '\n' - '# hi I am a comment\n' - 'foo1=bar\n' - 'foo2="bar"\n' - 'foo3=\'bar"\n')) + bash_dict = read_dict( + StringIO( + "\n" "# hi I am a comment\n" "foo1=bar\n" 'foo2="bar"\n' "foo3='bar\"\n" + ) + ) assert bash_dict == { - 'foo1': 'bar', - 'foo2': 'bar', - 'foo3': '\'bar"', - } - assert read_dict(['foo=bar'], source_isiter=True) == {'foo': 'bar'} + "foo1": "bar", + "foo2": "bar", + "foo3": "'bar\"", + } + assert read_dict(["foo=bar"], source_isiter=True) == {"foo": "bar"} with pytest.raises(BashParseError): - read_dict(['invalid'], source_isiter=True) + read_dict(["invalid"], source_isiter=True) - bash_dict = read_dict(StringIO("foo bar\nfoo2 bar\nfoo3\tbar\n"), splitter=None) - assert bash_dict == dict.fromkeys(('foo', 'foo2', 'foo3'), 'bar') - bash_dict = read_dict(['foo = blah', 'foo2= blah ', 'foo3=blah'], strip=True) - assert bash_dict == dict.fromkeys(('foo', 'foo2', 'foo3'), 'blah') + bash_dict = read_dict( + StringIO("foo bar\nfoo2 bar\nfoo3\tbar\n"), splitter=None + ) + assert bash_dict == dict.fromkeys(("foo", "foo2", "foo3"), "bar") + bash_dict = read_dict(["foo = blah", "foo2= blah ", "foo3=blah"], strip=True) + assert bash_dict == dict.fromkeys(("foo", "foo2", "foo3"), "blah") class TestReadBashDict: - @pytest.fixture(autouse=True) def _setup(self, tmp_path): self.valid_file = tmp_path / "valid" self.valid_file.write_text( - '# hi I am a comment\n' - 'foo1=bar\n' + "# hi I am a comment\n" + "foo1=bar\n" "foo2='bar'\n" 'foo3="bar"\n' - 'foo4=-/:j4\n' - 'foo5=\n' + "foo4=-/:j4\n" + "foo5=\n" 'export foo6="bar"\n' ) self.sourcing_file = tmp_path / "sourcing" @@ -131,18 +141,13 @@ class TestReadBashDict: self.sourcing_file2.write_text(f'source "{self.valid_file}"\n') self.advanced_file = tmp_path / "advanced" self.advanced_file.write_text( - 'one1=1\n' - 'one_=$one1\n' - 'two1=2\n' - 'two_=${two1}\n' + "one1=1\n" "one_=$one1\n" "two1=2\n" "two_=${two1}\n" ) self.env_file = tmp_path / "env" - self.env_file.write_text('imported=${external}\n') + self.env_file.write_text("imported=${external}\n") self.escaped_file = tmp_path / "escaped" self.escaped_file.write_text( - 'end=bye\n' - 'quoteddollar="\\${dollar}"\n' - 'quotedexpansion="\\${${end}}"\n' + "end=bye\n" 'quoteddollar="\\${dollar}"\n' 'quotedexpansion="\\${${end}}"\n' ) self.unclosed_file = tmp_path / "unclosed" self.unclosed_file.write_text('foo="bar') @@ -151,19 +156,19 @@ class TestReadBashDict: try: return read_bash_dict(handle, *args, **kwds) finally: - if hasattr(handle, 'close'): + if hasattr(handle, "close"): handle.close() def test_read_bash_dict(self): # TODO this is not even close to complete bash_dict = self.invoke_and_close(str(self.valid_file)) d = { - 'foo1': 'bar', - 'foo2': 'bar', - 'foo3': 'bar', - 'foo4': '-/:j4', - 'foo5': '', - 'foo6': 'bar', + "foo1": "bar", + "foo2": "bar", + "foo3": "bar", + "foo4": "-/:j4", + "foo5": "", + "foo6": "bar", } assert bash_dict == d @@ -171,59 +176,81 @@ class TestReadBashDict: self.invoke_and_close(StringIO("a=b\ny='")) def test_var_read(self): - assert self.invoke_and_close(StringIO("x=y@a\n")) == {'x': 'y@a'} - assert self.invoke_and_close(StringIO("x=y~a\n")) == {'x': 'y~a'} - assert self.invoke_and_close(StringIO("x=y^a\n")) == {'x': 'y^a'} - assert self.invoke_and_close(StringIO('x="\nasdf\nfdsa"')) == {'x': '\nasdf\nfdsa'} + assert self.invoke_and_close(StringIO("x=y@a\n")) == {"x": "y@a"} + assert self.invoke_and_close(StringIO("x=y~a\n")) == {"x": "y~a"} + assert self.invoke_and_close(StringIO("x=y^a\n")) == {"x": "y^a"} + assert self.invoke_and_close(StringIO('x="\nasdf\nfdsa"')) == { + "x": "\nasdf\nfdsa" + } def test_empty_assign(self): self.valid_file.write_text("foo=\ndar=blah\n") - assert self.invoke_and_close(str(self.valid_file)) == {'foo': '', 'dar': 'blah'} + assert self.invoke_and_close(str(self.valid_file)) == {"foo": "", "dar": "blah"} self.valid_file.write_text("foo=\ndar=\n") - assert self.invoke_and_close(str(self.valid_file)) == {'foo': '', 'dar': ''} + assert self.invoke_and_close(str(self.valid_file)) == {"foo": "", "dar": ""} self.valid_file.write_text("foo=blah\ndar=\n") - assert self.invoke_and_close(str(self.valid_file)) == {'foo': 'blah', 'dar': ''} + assert self.invoke_and_close(str(self.valid_file)) == {"foo": "blah", "dar": ""} def test_quoting(self): - assert self.invoke_and_close(StringIO("x='y \\\na'")) == {'x': 'y \\\na'} - assert self.invoke_and_close(StringIO("x='y'a\n")) == {'x': "ya"} - assert self.invoke_and_close(StringIO('x="y \\\nasdf"')) == {'x': 'y asdf'} + assert self.invoke_and_close(StringIO("x='y \\\na'")) == {"x": "y \\\na"} + assert self.invoke_and_close(StringIO("x='y'a\n")) == {"x": "ya"} + assert self.invoke_and_close(StringIO('x="y \\\nasdf"')) == {"x": "y asdf"} def test_eof_without_newline(self): - assert self.invoke_and_close(StringIO("x=y")) == {'x': 'y'} - assert self.invoke_and_close(StringIO("x='y'a")) == {'x': 'ya'} + assert self.invoke_and_close(StringIO("x=y")) == {"x": "y"} + assert self.invoke_and_close(StringIO("x='y'a")) == {"x": "ya"} def test_sourcing(self): - output = self.invoke_and_close(str(self.sourcing_file), sourcing_command='source') - expected = {'foo1': 'bar', 'foo2': 'bar', 'foo3': 'bar', 'foo4': '-/:j4', 'foo5': '', 'foo6': 'bar'} + output = self.invoke_and_close( + str(self.sourcing_file), sourcing_command="source" + ) + expected = { + "foo1": "bar", + "foo2": "bar", + "foo3": "bar", + "foo4": "-/:j4", + "foo5": "", + "foo6": "bar", + } assert output == expected - output = self.invoke_and_close(str(self.sourcing_file2), sourcing_command='source') - expected = {'foo1': 'bar', 'foo2': 'bar', 'foo3': 'bar', 'foo4': '-/:j4', 'foo5': '', 'foo6': 'bar'} + output = self.invoke_and_close( + str(self.sourcing_file2), sourcing_command="source" + ) + expected = { + "foo1": "bar", + "foo2": "bar", + "foo3": "bar", + "foo4": "-/:j4", + "foo5": "", + "foo6": "bar", + } assert output == expected def test_read_advanced(self): output = self.invoke_and_close(str(self.advanced_file)) expected = { - 'one1': '1', - 'one_': '1', - 'two1': '2', - 'two_': '2', + "one1": "1", + "one_": "1", + "two1": "2", + "two_": "2", } assert output == expected def test_env(self): - assert self.invoke_and_close(str(self.env_file)) == {'imported': ''} - env = {'external': 'imported foo'} + assert self.invoke_and_close(str(self.env_file)) == {"imported": ""} + env = {"external": "imported foo"} env_backup = env.copy() - assert self.invoke_and_close(str(self.env_file), env) == {'imported': 'imported foo'} + assert self.invoke_and_close(str(self.env_file), env) == { + "imported": "imported foo" + } assert env_backup == env def test_escaping(self): output = self.invoke_and_close(str(self.escaped_file)) expected = { - 'end': 'bye', - 'quoteddollar': '${dollar}', - 'quotedexpansion': '${bye}', + "end": "bye", + "quoteddollar": "${dollar}", + "quotedexpansion": "${bye}", } assert output == expected diff --git a/tests/test_caching.py b/tests/test_caching.py index eaa5014..06615d3 100644 --- a/tests/test_caching.py +++ b/tests/test_caching.py @@ -5,17 +5,20 @@ from snakeoil.caching import WeakInstMeta class weak_slotted(metaclass=WeakInstMeta): __inst_caching__ = True - __slots__ = ('one',) + __slots__ = ("one",) class weak_inst(metaclass=WeakInstMeta): __inst_caching__ = True counter = 0 + def __new__(cls, *args, **kwargs): cls.counter += 1 return object.__new__(cls) + def __init__(self, *args, **kwargs): pass + @classmethod def reset(cls): cls.counter = 0 @@ -34,7 +37,6 @@ class reenabled_weak_inst(automatic_disabled_weak_inst): class TestWeakInstMeta: - def test_reuse(self, kls=weak_inst): kls.reset() o = kls() @@ -99,8 +101,8 @@ class TestWeakInstMeta: # (RaisingHashFor...). # UserWarning is ignored and everything other warning is an error. - @pytest.mark.filterwarnings('ignore::UserWarning') - @pytest.mark.filterwarnings('error') + @pytest.mark.filterwarnings("ignore::UserWarning") + @pytest.mark.filterwarnings("error") def test_uncachable(self): weak_inst.reset() @@ -108,21 +110,24 @@ class TestWeakInstMeta: class RaisingHashForTestUncachable: def __init__(self, error): self.error = error + def __hash__(self): raise self.error assert weak_inst([]) is not weak_inst([]) assert weak_inst.counter == 2 for x in (TypeError, NotImplementedError): - assert weak_inst(RaisingHashForTestUncachable(x)) is not \ - weak_inst(RaisingHashForTestUncachable(x)) + assert weak_inst(RaisingHashForTestUncachable(x)) is not weak_inst( + RaisingHashForTestUncachable(x) + ) - @pytest.mark.filterwarnings('error::UserWarning') + @pytest.mark.filterwarnings("error::UserWarning") def test_uncachable_warning_msg(self): # This name is *important*, see above. class RaisingHashForTestUncachableWarnings: def __init__(self, error): self.error = error + def __hash__(self): raise self.error @@ -134,6 +139,7 @@ class TestWeakInstMeta: class BrokenHash: def __hash__(self): return 1 + assert weak_inst(BrokenHash()) is not weak_inst(BrokenHash()) def test_weak_slot(self): @@ -148,7 +154,7 @@ class TestWeakInstMeta: # The actual test is that the class definition works. class ExistingWeakrefSlot: __inst_caching__ = True - __slots__ = ('one', '__weakref__') + __slots__ = ("one", "__weakref__") assert ExistingWeakrefSlot() diff --git a/tests/test_chksum.py b/tests/test_chksum.py index 016e3f7..b4c1ab2 100644 --- a/tests/test_chksum.py +++ b/tests/test_chksum.py @@ -3,15 +3,16 @@ from snakeoil import chksum class Test_funcs: - def setup_method(self, method): chksum.__inited__ = False chksum.chksum_types.clear() self._saved_init = chksum.init self._inited_count = 0 + def f(): self._inited_count += 1 chksum.__inited__ = True + chksum.init = f # ensure we aren't mangling chksum state for other tests. @@ -41,4 +42,3 @@ class Test_funcs: assert chksum.get_handler("x") == 1 assert chksum.get_handler("y") == 2 assert self._inited_count == 1 - diff --git a/tests/test_chksum_defaults.py b/tests/test_chksum_defaults.py index 7f867d8..a22d339 100644 --- a/tests/test_chksum_defaults.py +++ b/tests/test_chksum_defaults.py @@ -14,14 +14,15 @@ def require_chf(func): def subfunc(self): if self.chf is None: pytest.skip( - 'no handler for %s, do you need to install PyCrypto or mhash?' - % (self.chf_type,)) + "no handler for %s, do you need to install PyCrypto or mhash?" + % (self.chf_type,) + ) func(self) + return subfunc class base: - def get_chf(self): try: self.chf = chksum.get_handler(self.chf_type) @@ -53,14 +54,17 @@ class base: @require_chf def test_data_source_check(self): assert self.chf(local_source(self.fn)) == self.expected_long - assert self.chf(data_source(fileutils.readfile_ascii(self.fn))) == self.expected_long + assert ( + self.chf(data_source(fileutils.readfile_ascii(self.fn))) + == self.expected_long + ) -class ChksumTest(base): +class ChksumTest(base): @require_chf def test_str2long(self): assert self.chf.str2long(self.expected_str) == self.expected_long - if self.chf_type == 'size': + if self.chf_type == "size": return for x in extra_chksums.get(self.chf_type, ()): assert self.chf.str2long(x) == int(x, 16) @@ -68,11 +72,12 @@ class ChksumTest(base): @require_chf def test_long2str(self): assert self.chf.long2str(self.expected_long) == self.expected_str - if self.chf_type == 'size': + if self.chf_type == "size": return for x in extra_chksums.get(self.chf_type, ()): assert self.chf.long2str(int(x == 16)), x + checksums = { "rmd160": "b83ad488d624e7911f886420ab230f78f6368b9f", "sha1": "63cd8cce8a1773dffb400ee184be3ec7d89791f5", @@ -87,22 +92,22 @@ checksums = { checksums.update((k, (int(v, 16), v)) for k, v in checksums.items()) checksums["size"] = (int(len(data) * multi), str(int(len(data) * multi))) -extra_chksums = { - "md5": - ["2dfd84279314a178d0fa842af3a40e25577e1bc"] -} +extra_chksums = {"md5": ["2dfd84279314a178d0fa842af3a40e25577e1bc"]} for k, v in checksums.items(): - extra_chksums.setdefault(k, []).extend((''.rjust(len(v[1]), '0'), '01'.rjust(len(v[1]), '0'))) + extra_chksums.setdefault(k, []).extend( + ("".rjust(len(v[1]), "0"), "01".rjust(len(v[1]), "0")) + ) # trick: create subclasses for each checksum with a useful class name. for chf_type, expected in checksums.items(): expectedsum = expected[0] expectedstr = expected[1] - globals()['TestChksum' + chf_type.capitalize()] = type( - 'TestChksum' + chf_type.capitalize(), + globals()["TestChksum" + chf_type.capitalize()] = type( + "TestChksum" + chf_type.capitalize(), (ChksumTest,), - dict(chf_type=chf_type, expected_long=expectedsum, expected_str=expectedstr)) + dict(chf_type=chf_type, expected_long=expectedsum, expected_str=expectedstr), + ) # pylint: disable=undefined-loop-variable del chf_type, expected @@ -110,7 +115,7 @@ del chf_type, expected class TestGetChksums(base): - chfs = [k for k in sorted(checksums) if k in ('md5', 'sha1')] + chfs = [k for k in sorted(checksums) if k in ("md5", "sha1")] expected_long = [checksums[k][0] for k in chfs] def get_chf(self): diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 5e938a9..15d360c 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -2,61 +2,75 @@ import pytest from snakeoil.constraints import Problem + def any_of(**kwargs): return any(kwargs.values()) + def all_of(**kwargs): return all(kwargs.values()) + def test_readd_variables(): p = Problem() - p.add_variable((True, False), 'x', 'y') + p.add_variable((True, False), "x", "y") with pytest.raises(AssertionError, match="variable 'y' was already added"): - p.add_variable((True, False), 'y', 'z') + p.add_variable((True, False), "y", "z") + def test_constraint_unknown_variable(): p = Problem() - p.add_variable((True, False), 'x', 'y') + p.add_variable((True, False), "x", "y") with pytest.raises(AssertionError, match="unknown variable 'z'"): - p.add_constraint(any_of, ('y', 'z')) + p.add_constraint(any_of, ("y", "z")) + def test_empty_problem(): p = Problem() - assert tuple(p) == ({}, ) + assert tuple(p) == ({},) + def test_empty_constraints(): p = Problem() - p.add_variable((True, False), 'x', 'y') - p.add_variable((True, ), 'z') + p.add_variable((True, False), "x", "y") + p.add_variable((True,), "z") assert len(tuple(p)) == 4 + def test_domain_prefer_later(): p = Problem() - p.add_variable((False, True), 'x', 'y') - p.add_constraint(any_of, ('x', 'y')) - assert next(iter(p)) == {'x': True, 'y': True} + p.add_variable((False, True), "x", "y") + p.add_constraint(any_of, ("x", "y")) + assert next(iter(p)) == {"x": True, "y": True} + def test_constraint_single_variable(): p = Problem() - p.add_variable((True, False), 'x', 'y') - p.add_constraint(lambda x: x, ('x', )) - p.add_constraint(lambda y: not y, ('y', )) - assert tuple(p) == ({'x': True, 'y': False}, ) + p.add_variable((True, False), "x", "y") + p.add_constraint(lambda x: x, ("x",)) + p.add_constraint(lambda y: not y, ("y",)) + assert tuple(p) == ({"x": True, "y": False},) + def test_no_solution(): p = Problem() - p.add_variable((True, ), 'x') - p.add_variable((True, False), 'y', 'z') - p.add_constraint(lambda x, y: not x or y, ('x', 'y')) - p.add_constraint(lambda y, z: not y or not z, ('y', 'z')) - p.add_constraint(lambda x, z: not x or z, ('x', 'z')) + p.add_variable((True,), "x") + p.add_variable((True, False), "y", "z") + p.add_constraint(lambda x, y: not x or y, ("x", "y")) + p.add_constraint(lambda y, z: not y or not z, ("y", "z")) + p.add_constraint(lambda x, z: not x or z, ("x", "z")) assert not tuple(p) + def test_forward_check(): p = Problem() - p.add_variable(range(2, 10), 'x', 'y', 'z') - p.add_constraint(lambda x, y: (x + y) % 2 == 0, ('x', 'y')) - p.add_constraint(lambda x, y, z: (x * y * z) % 2 != 0, ('x', 'y', 'z')) - p.add_constraint(lambda y, z: y < z, ('y', 'z')) - p.add_constraint(lambda z, x: x ** 2 <= z, ('x', 'z')) - assert tuple(p) == ({'x': 3, 'y': 7, 'z': 9}, {'x': 3, 'y': 5, 'z': 9}, {'x': 3, 'y': 3, 'z': 9}) + p.add_variable(range(2, 10), "x", "y", "z") + p.add_constraint(lambda x, y: (x + y) % 2 == 0, ("x", "y")) + p.add_constraint(lambda x, y, z: (x * y * z) % 2 != 0, ("x", "y", "z")) + p.add_constraint(lambda y, z: y < z, ("y", "z")) + p.add_constraint(lambda z, x: x**2 <= z, ("x", "z")) + assert tuple(p) == ( + {"x": 3, "y": 7, "z": 9}, + {"x": 3, "y": 5, "z": 9}, + {"x": 3, "y": 3, "z": 9}, + ) diff --git a/tests/test_containers.py b/tests/test_containers.py index f6940c9..9df3258 100644 --- a/tests/test_containers.py +++ b/tests/test_containers.py @@ -5,7 +5,6 @@ from snakeoil import containers class TestInvertedContains: - def setup_method(self, method): self.set = containers.InvertedContains(range(12)) @@ -17,7 +16,7 @@ class TestInvertedContains: class BasicSet(containers.SetMixin): - __slots__ = ('_data',) + __slots__ = ("_data",) def __init__(self, data): self._data = set(data) @@ -28,7 +27,7 @@ class BasicSet(containers.SetMixin): def __contains__(self, other): return other in self._data - #def __str__(self): + # def __str__(self): # return 'BasicSet([%s])' % ', '.join((str(x) for x in self._data)) def __eq__(self, other): @@ -43,7 +42,6 @@ class BasicSet(containers.SetMixin): class TestSetMethods: - def test_and(self): c = BasicSet(range(100)) s = set(range(25, 75)) @@ -80,8 +78,8 @@ class TestSetMethods: assert c - s == r1 assert s - c == r2 -class TestLimitedChangeSet: +class TestLimitedChangeSet: def setup_method(self, method): self.set = containers.LimitedChangeSet(range(12)) @@ -89,17 +87,18 @@ class TestLimitedChangeSet: def f(val): assert isinstance(val, int) return val + self.set = containers.LimitedChangeSet(range(12), key_validator=f) self.set.add(13) self.set.add(14) self.set.remove(11) assert 5 in self.set with pytest.raises(AssertionError): - self.set.add('2') + self.set.add("2") with pytest.raises(AssertionError): - self.set.remove('2') + self.set.remove("2") with pytest.raises(AssertionError): - self.set.__contains__('2') + self.set.__contains__("2") def test_basic(self, changes=0): # this should be a no-op @@ -188,7 +187,7 @@ class TestLimitedChangeSet: assert sorted(list(self.set)) == list(range(-1, 13)) def test_str(self): - assert str(containers.LimitedChangeSet([7])) == 'LimitedChangeSet([7])' + assert str(containers.LimitedChangeSet([7])) == "LimitedChangeSet([7])" def test__eq__(self): c = containers.LimitedChangeSet(range(99)) @@ -199,7 +198,6 @@ class TestLimitedChangeSet: class TestLimitedChangeSetWithBlacklist: - def setup_method(self, method): self.set = containers.LimitedChangeSet(range(12), [3, 13]) @@ -222,7 +220,6 @@ class TestLimitedChangeSetWithBlacklist: class TestProtectedSet: - def setup_method(self, method): self.set = containers.ProtectedSet(set(range(12))) diff --git a/tests/test_contexts.py b/tests/test_contexts.py index 219d5ee..be212a2 100644 --- a/tests/test_contexts.py +++ b/tests/test_contexts.py @@ -44,9 +44,10 @@ def test_syspath(tmpdir): assert mangled_syspath == tuple(sys.path) -@pytest.mark.skip(reason='this currently breaks on github ci, https://github.com/pkgcore/snakeoil/issues/68') +@pytest.mark.skip( + reason="this currently breaks on github ci, https://github.com/pkgcore/snakeoil/issues/68" +) class TestSplitExec: - def test_context_process(self): # code inside the with statement is run in a separate process pid = os.getpid() @@ -77,9 +78,9 @@ class TestSplitExec: b = 3 # changes to locals aren't propagated back assert a == 1 - assert 'b' not in locals() + assert "b" not in locals() # but they're accessible via the 'locals' attr - expected = {'a': 2, 'b': 3} + expected = {"a": 2, "b": 3} for k, v in expected.items(): assert c.locals[k] == v @@ -87,20 +88,21 @@ class TestSplitExec: with SplitExec() as c: func = lambda x: x from sys import implementation + a = 4 - assert c.locals == {'a': 4} + assert c.locals == {"a": 4} def test_context_exceptions(self): # exceptions in the child process are sent back to the parent and re-raised with pytest.raises(IOError) as e: with SplitExec() as c: - raise IOError(errno.EBUSY, 'random error') + raise IOError(errno.EBUSY, "random error") assert e.value.errno == errno.EBUSY def test_child_setup_raises_exception(self): class ChildSetupException(SplitExec): def _child_setup(self): - raise IOError(errno.EBUSY, 'random error') + raise IOError(errno.EBUSY, "random error") with pytest.raises(IOError) as e: with ChildSetupException() as c: @@ -108,26 +110,33 @@ class TestSplitExec: assert e.value.errno == errno.EBUSY -@pytest.mark.skipif(not sys.platform.startswith('linux'), reason='supported on Linux only') -@pytest.mark.xfail(platform.python_implementation() == "PyPy", reason='Fails on PyPy') +@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="supported on Linux only" +) +@pytest.mark.xfail(platform.python_implementation() == "PyPy", reason="Fails on PyPy") class TestNamespace: - - @pytest.mark.skipif(not os.path.exists('/proc/self/ns/user'), - reason='user namespace support required') + @pytest.mark.skipif( + not os.path.exists("/proc/self/ns/user"), + reason="user namespace support required", + ) def test_user_namespace(self): try: with Namespace(user=True) as ns: assert os.getuid() == 0 except PermissionError: - pytest.skip('No permission to use user namespace') - - @pytest.mark.skipif(not (os.path.exists('/proc/self/ns/user') and os.path.exists('/proc/self/ns/uts')), - reason='user and uts namespace support required') + pytest.skip("No permission to use user namespace") + + @pytest.mark.skipif( + not ( + os.path.exists("/proc/self/ns/user") and os.path.exists("/proc/self/ns/uts") + ), + reason="user and uts namespace support required", + ) def test_uts_namespace(self): try: - with Namespace(user=True, uts=True, hostname='host') as ns: - ns_hostname, _, ns_domainname = socket.getfqdn().partition('.') - assert ns_hostname == 'host' - assert ns_domainname == '' + with Namespace(user=True, uts=True, hostname="host") as ns: + ns_hostname, _, ns_domainname = socket.getfqdn().partition(".") + assert ns_hostname == "host" + assert ns_domainname == "" except PermissionError: - pytest.skip('No permission to use user and uts namespace') + pytest.skip("No permission to use user and uts namespace") diff --git a/tests/test_currying.py b/tests/test_currying.py index d14f0fc..7f8618f 100644 --- a/tests/test_currying.py +++ b/tests/test_currying.py @@ -5,8 +5,10 @@ from snakeoil import currying def passthrough(*args, **kwargs): return args, kwargs + # docstring is part of the test + def documented(): """original docstring""" @@ -18,36 +20,37 @@ class TestPreCurry: def test_pre_curry(self): noop = self.pre_curry(passthrough) assert noop() == ((), {}) - assert noop('foo', 'bar') == (('foo', 'bar'), {}) - assert noop(foo='bar') == ((), {'foo': 'bar'}) - assert noop('foo', bar='baz') == (('foo',), {'bar': 'baz'}) + assert noop("foo", "bar") == (("foo", "bar"), {}) + assert noop(foo="bar") == ((), {"foo": "bar"}) + assert noop("foo", bar="baz") == (("foo",), {"bar": "baz"}) one_arg = self.pre_curry(passthrough, 42) assert one_arg() == ((42,), {}) - assert one_arg('foo', 'bar') == ((42, 'foo', 'bar'), {}) - assert one_arg(foo='bar') == ((42,), {'foo': 'bar'}) - assert one_arg('foo', bar='baz') == ((42, 'foo'), {'bar': 'baz'}) + assert one_arg("foo", "bar") == ((42, "foo", "bar"), {}) + assert one_arg(foo="bar") == ((42,), {"foo": "bar"}) + assert one_arg("foo", bar="baz") == ((42, "foo"), {"bar": "baz"}) keyword_arg = self.pre_curry(passthrough, foo=42) - assert keyword_arg() == ((), {'foo': 42}) - assert keyword_arg('foo', 'bar') == (('foo', 'bar'), {'foo': 42}) - assert keyword_arg(foo='bar') == ((), {'foo': 'bar'}) - assert keyword_arg('foo', bar='baz') == (('foo',), {'bar': 'baz', 'foo': 42}) + assert keyword_arg() == ((), {"foo": 42}) + assert keyword_arg("foo", "bar") == (("foo", "bar"), {"foo": 42}) + assert keyword_arg(foo="bar") == ((), {"foo": "bar"}) + assert keyword_arg("foo", bar="baz") == (("foo",), {"bar": "baz", "foo": 42}) both = self.pre_curry(passthrough, 42, foo=42) - assert both() == ((42,), {'foo': 42}) - assert both('foo', 'bar') == ((42, 'foo', 'bar'), {'foo': 42}) - assert both(foo='bar') == ((42,), {'foo': 'bar'}) - assert both('foo', bar='baz') == ((42, 'foo'), {'bar': 'baz', 'foo': 42}) + assert both() == ((42,), {"foo": 42}) + assert both("foo", "bar") == ((42, "foo", "bar"), {"foo": 42}) + assert both(foo="bar") == ((42,), {"foo": "bar"}) + assert both("foo", bar="baz") == ((42, "foo"), {"bar": "baz", "foo": 42}) def test_curry_original(self): assert self.pre_curry(passthrough).func is passthrough def test_instancemethod(self): class Test: - method = self.pre_curry(passthrough, 'test') + method = self.pre_curry(passthrough, "test") + test = Test() - assert (('test', test), {}) == test.method() + assert (("test", test), {}) == test.method() class Test_pretty_docs: @@ -56,58 +59,63 @@ class Test_pretty_docs: def test_module_magic(self): for target in self.currying_targets: - assert currying.pretty_docs(target(passthrough)).__module__ is \ - passthrough.__module__ + assert ( + currying.pretty_docs(target(passthrough)).__module__ + is passthrough.__module__ + ) # test is kinda useless if they are identical without pretty_docs - assert getattr(target(passthrough), '__module__', None) is not \ - passthrough.__module__ + assert ( + getattr(target(passthrough), "__module__", None) + is not passthrough.__module__ + ) def test_pretty_docs(self): for target in self.currying_targets: for func in (passthrough, documented): - assert currying.pretty_docs(target(func), 'new doc').__doc__ == 'new doc' + assert ( + currying.pretty_docs(target(func), "new doc").__doc__ == "new doc" + ) assert currying.pretty_docs(target(func)).__doc__ is func.__doc__ class TestPostCurry: - def test_post_curry(self): noop = currying.post_curry(passthrough) assert noop() == ((), {}) - assert noop('foo', 'bar') == (('foo', 'bar'), {}) - assert noop(foo='bar') == ((), {'foo': 'bar'}) - assert noop('foo', bar='baz') == (('foo',), {'bar': 'baz'}) + assert noop("foo", "bar") == (("foo", "bar"), {}) + assert noop(foo="bar") == ((), {"foo": "bar"}) + assert noop("foo", bar="baz") == (("foo",), {"bar": "baz"}) one_arg = currying.post_curry(passthrough, 42) assert one_arg() == ((42,), {}) - assert one_arg('foo', 'bar') == (('foo', 'bar', 42), {}) - assert one_arg(foo='bar') == ((42,), {'foo': 'bar'}) - assert one_arg('foo', bar='baz') == (('foo', 42), {'bar': 'baz'}) + assert one_arg("foo", "bar") == (("foo", "bar", 42), {}) + assert one_arg(foo="bar") == ((42,), {"foo": "bar"}) + assert one_arg("foo", bar="baz") == (("foo", 42), {"bar": "baz"}) keyword_arg = currying.post_curry(passthrough, foo=42) - assert keyword_arg() == ((), {'foo': 42}) - assert keyword_arg('foo', 'bar') == (('foo', 'bar'), {'foo': 42}) - assert keyword_arg(foo='bar') == ((), {'foo': 42}) - assert keyword_arg('foo', bar='baz') == (('foo',), {'bar': 'baz', 'foo': 42}) + assert keyword_arg() == ((), {"foo": 42}) + assert keyword_arg("foo", "bar") == (("foo", "bar"), {"foo": 42}) + assert keyword_arg(foo="bar") == ((), {"foo": 42}) + assert keyword_arg("foo", bar="baz") == (("foo",), {"bar": "baz", "foo": 42}) both = currying.post_curry(passthrough, 42, foo=42) - assert both() == ((42,), {'foo': 42}) - assert both('foo', 'bar') == (('foo', 'bar', 42), {'foo': 42}) - assert both(foo='bar') == ((42,), {'foo': 42}) - assert both('foo', bar='baz') == (('foo', 42), {'bar': 'baz', 'foo': 42}) + assert both() == ((42,), {"foo": 42}) + assert both("foo", "bar") == (("foo", "bar", 42), {"foo": 42}) + assert both(foo="bar") == ((42,), {"foo": 42}) + assert both("foo", bar="baz") == (("foo", 42), {"bar": "baz", "foo": 42}) def test_curry_original(self): assert currying.post_curry(passthrough).func is passthrough def test_instancemethod(self): class Test: - method = currying.post_curry(passthrough, 'test') + method = currying.post_curry(passthrough, "test") + test = Test() - assert ((test, 'test'), {}) == test.method() + assert ((test, "test"), {}) == test.method() class Test_wrap_exception: - def test_wrap_exception_complex(self): inner, outer = [], [] @@ -118,33 +126,33 @@ class Test_wrap_exception: assert isinstance(exception, inner_exception) assert functor is throwing_func assert fargs == (False,) - assert fkwds == {'monkey': 'bone'} + assert fkwds == {"monkey": "bone"} outer.append(True) raise wrapping_exception() def throwing_func(*args, **kwds): assert args == (False,) - assert kwds == {'monkey': 'bone'} + assert kwds == {"monkey": "bone"} inner.append(True) raise inner_exception() func = currying.wrap_exception_complex(f, IndexError)(throwing_func) # basic behaviour - pytest.raises(IndexError, func, False, monkey='bone') + pytest.raises(IndexError, func, False, monkey="bone") assert len(inner) == 1 assert len(outer) == 1 # ensure pass thru if it's an allowed exception inner_exception = IndexError - pytest.raises(IndexError, func, False, monkey='bone') + pytest.raises(IndexError, func, False, monkey="bone") assert len(inner) == 2 assert len(outer) == 1 # finally, ensure it doesn't intercept, and passes thru for # exceptions it shouldn't handle inner_exception = MemoryError - pytest.raises(MemoryError, func, False, monkey='bone') + pytest.raises(MemoryError, func, False, monkey="bone") assert len(inner) == 3 assert len(outer) == 1 @@ -159,9 +167,10 @@ class Test_wrap_exception: self.args = args self.kwds = kwds - func = currying.wrap_exception(my_exception, 1, 3, 2, monkey='bone', - ignores=ValueError)(throwing_func) - assert func.__name__ == 'throwing_func' + func = currying.wrap_exception( + my_exception, 1, 3, 2, monkey="bone", ignores=ValueError + )(throwing_func) + assert func.__name__ == "throwing_func" pytest.raises(ValueError, func) throw_kls = IndexError pytest.raises(my_exception, func) @@ -170,17 +179,23 @@ class Test_wrap_exception: raise AssertionError("shouldn't have been able to reach here") except my_exception as e: assert e.args == (1, 3, 2) - assert e.kwds == {'monkey': 'bone'} + assert e.kwds == {"monkey": "bone"} # finally, verify that the exception can be pased in. func = currying.wrap_exception( - my_exception, 1, 3, 2, monkey='bone', - ignores=ValueError, pass_error="the_exception")(throwing_func) - assert func.__name__ == 'throwing_func' + my_exception, + 1, + 3, + 2, + monkey="bone", + ignores=ValueError, + pass_error="the_exception", + )(throwing_func) + assert func.__name__ == "throwing_func" pytest.raises(my_exception, func) try: func() raise AssertionError("shouldn't have been able to reach here") except my_exception as e: assert e.args == (1, 3, 2) - assert e.kwds == {'monkey': 'bone', 'the_exception': e.__cause__} + assert e.kwds == {"monkey": "bone", "the_exception": e.__cause__} diff --git a/tests/test_data_source.py b/tests/test_data_source.py index ddd3eee..1ede9aa 100644 --- a/tests/test_data_source.py +++ b/tests/test_data_source.py @@ -53,15 +53,15 @@ class TestDataSource: assert reader_data == writer_data def _mk_data(self, size=(100000)): - return ''.join(str(x % 10) for x in range(size)) + return "".join(str(x % 10) for x in range(size)) def test_transfer_to_data_source(self): data = self._mk_data() reader = self.get_obj(data=data) if self.supports_mutable: - writer = self.get_obj(data='', mutable=True) + writer = self.get_obj(data="", mutable=True) else: - writer = data_source.data_source('', mutable=True) + writer = data_source.data_source("", mutable=True) reader.transfer_to_data_source(writer) self.assertContents(reader, writer) @@ -70,9 +70,11 @@ class TestDataSource: data = self._mk_data() reader = self.get_obj(data=data) if isinstance(reader, data_source.bz2_source): - writer = data_source.bz2_source(tmp_path / 'transfer_to_path', mutable=True) + writer = data_source.bz2_source(tmp_path / "transfer_to_path", mutable=True) else: - writer = data_source.local_source(tmp_path / 'transfer_to_path', mutable=True) + writer = data_source.local_source( + tmp_path / "transfer_to_path", mutable=True + ) reader.transfer_to_path(writer.path) @@ -82,9 +84,9 @@ class TestDataSource: data = self._mk_data() reader = self.get_obj(data=data) if self.supports_mutable: - writer = self.get_obj(data='', mutable=True) + writer = self.get_obj(data="", mutable=True) else: - writer = data_source.data_source('', mutable=True) + writer = data_source.data_source("", mutable=True) with reader.bytes_fileobj() as reader_f, writer.bytes_fileobj(True) as writer_f: data_source.transfer_between_files(reader_f, writer_f) @@ -93,15 +95,14 @@ class TestDataSource: class TestLocalSource(TestDataSource): - def get_obj(self, data="foonani", mutable=False, test_creation=False): self.fp = self.dir / "localsource.test" if not test_creation: mode = None if isinstance(data, bytes): - mode = 'wb' + mode = "wb" elif mode is None: - mode = 'w' + mode = "w" with open(self.fp, mode) as f: f.write(data) return data_source.local_source(self.fp, mutable=mutable) @@ -118,21 +119,20 @@ class TestLocalSource(TestDataSource): obj = self.get_obj(test_creation=True, mutable=True) # this will blow up if tries to ascii decode it. with obj.bytes_fileobj(True) as f: - assert f.read() == b'' + assert f.read() == b"" f.write(data) with obj.bytes_fileobj() as f: assert f.read() == data class TestBz2Source(TestDataSource): - def get_obj(self, data="foonani", mutable=False, test_creation=False): self.fp = self.dir / "bz2source.test.bz2" if not test_creation: if isinstance(data, str): data = data.encode() - with open(self.fp, 'wb') as f: - f.write(compression.compress_data('bzip2', data)) + with open(self.fp, "wb") as f: + f.write(compression.compress_data("bzip2", data)) return data_source.bz2_source(self.fp, mutable=mutable) def test_bytes_fileobj(self): @@ -150,8 +150,7 @@ class Test_invokable_data_source(TestDataSource): def get_obj(self, data="foonani", mutable=False): if isinstance(data, str): data = data.encode("utf8") - return data_source.invokable_data_source( - partial(self._get_data, data)) + return data_source.invokable_data_source(partial(self._get_data, data)) @staticmethod def _get_data(data, is_text=False): @@ -168,10 +167,10 @@ class Test_invokable_data_source_wrapper_text(Test_invokable_data_source): def get_obj(self, mutable=False, data="foonani"): return data_source.invokable_data_source.wrap_function( - partial(self._get_data, data), - self.text_mode) + partial(self._get_data, data), self.text_mode + ) - def _get_data(self, data='foonani'): + def _get_data(self, data="foonani"): if isinstance(data, str): if not self.text_mode: return data.encode("utf8") diff --git a/tests/test_decorators.py b/tests/test_decorators.py index ce00ac4..92c0fb8 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -8,7 +8,6 @@ from snakeoil.decorators import coroutine, namespace, splitexec class TestSplitExecDecorator: - def setup_method(self, method): self.pid = os.getpid() @@ -18,11 +17,14 @@ class TestSplitExecDecorator: assert self.pid != os.getpid() -@pytest.mark.skipif(not sys.platform.startswith('linux'), reason='supported on Linux only') +@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="supported on Linux only" +) class TestNamespaceDecorator: - - @pytest.mark.skipif(not os.path.exists('/proc/self/ns/user'), - reason='user namespace support required') + @pytest.mark.skipif( + not os.path.exists("/proc/self/ns/user"), + reason="user namespace support required", + ) def test_user_namespace(self): @namespace(user=True) def do_test(): @@ -31,31 +33,34 @@ class TestNamespaceDecorator: try: do_test() except PermissionError: - pytest.skip('No permission to use user namespace') - - @pytest.mark.skipif(not (os.path.exists('/proc/self/ns/user') and os.path.exists('/proc/self/ns/uts')), - reason='user and uts namespace support required') + pytest.skip("No permission to use user namespace") + + @pytest.mark.skipif( + not ( + os.path.exists("/proc/self/ns/user") and os.path.exists("/proc/self/ns/uts") + ), + reason="user and uts namespace support required", + ) def test_uts_namespace(self): - @namespace(user=True, uts=True, hostname='host') + @namespace(user=True, uts=True, hostname="host") def do_test(): - ns_hostname, _, ns_domainname = socket.getfqdn().partition('.') - assert ns_hostname == 'host' - assert ns_domainname == '' + ns_hostname, _, ns_domainname = socket.getfqdn().partition(".") + assert ns_hostname == "host" + assert ns_domainname == "" try: do_test() except PermissionError: - pytest.skip('No permission to use user and uts namespace') + pytest.skip("No permission to use user and uts namespace") class TestCoroutineDecorator: - def test_coroutine(self): @coroutine def count(): i = 0 while True: - val = (yield i) + val = yield i i = val if val is not None else i + 1 cr = count() diff --git a/tests/test_demandload.py b/tests/test_demandload.py index ec2661c..843cf80 100644 --- a/tests/test_demandload.py +++ b/tests/test_demandload.py @@ -9,6 +9,7 @@ from snakeoil import demandload # setup is what the test expects. # it also explicitly resets the state on the way out. + def reset_globals(functor): def f(*args, **kwds): orig_demandload = demandload.demandload @@ -22,60 +23,61 @@ def reset_globals(functor): demandload.demand_compile_regexp = orig_demand_compile demandload._protection_enabled = orig_protection demandload._noisy_protection = orig_noisy + return f class TestParser: - @reset_globals def test_parse(self): for input, output in [ - ('foo', [('foo', 'foo')]), - ('foo:bar', [('foo.bar', 'bar')]), - ('foo:bar,baz@spork', [('foo.bar', 'bar'), ('foo.baz', 'spork')]), - ('foo@bar', [('foo', 'bar')]), - ('foo_bar', [('foo_bar', 'foo_bar')]), - ]: + ("foo", [("foo", "foo")]), + ("foo:bar", [("foo.bar", "bar")]), + ("foo:bar,baz@spork", [("foo.bar", "bar"), ("foo.baz", "spork")]), + ("foo@bar", [("foo", "bar")]), + ("foo_bar", [("foo_bar", "foo_bar")]), + ]: assert output == list(demandload.parse_imports([input])) - pytest.raises(ValueError, list, demandload.parse_imports(['a.b'])) - pytest.raises(ValueError, list, demandload.parse_imports(['a:,'])) - pytest.raises(ValueError, list, demandload.parse_imports(['a:b,x@'])) - pytest.raises(ValueError, list, demandload.parse_imports(['b-x'])) - pytest.raises(ValueError, list, demandload.parse_imports([' b_x'])) + pytest.raises(ValueError, list, demandload.parse_imports(["a.b"])) + pytest.raises(ValueError, list, demandload.parse_imports(["a:,"])) + pytest.raises(ValueError, list, demandload.parse_imports(["a:b,x@"])) + pytest.raises(ValueError, list, demandload.parse_imports(["b-x"])) + pytest.raises(ValueError, list, demandload.parse_imports([" b_x"])) class TestPlaceholder: - @reset_globals def test_getattr(self): scope = {} - placeholder = demandload.Placeholder(scope, 'foo', list) - assert scope == object.__getattribute__(placeholder, '_scope') + placeholder = demandload.Placeholder(scope, "foo", list) + assert scope == object.__getattribute__(placeholder, "_scope") assert placeholder.__doc__ == [].__doc__ - assert scope['foo'] == [] + assert scope["foo"] == [] demandload._protection_enabled = lambda: True with pytest.raises(ValueError): - getattr(placeholder, '__doc__') + getattr(placeholder, "__doc__") @reset_globals def test__str__(self): scope = {} - placeholder = demandload.Placeholder(scope, 'foo', list) - assert scope == object.__getattribute__(placeholder, '_scope') + placeholder = demandload.Placeholder(scope, "foo", list) + assert scope == object.__getattribute__(placeholder, "_scope") assert str(placeholder) == str([]) - assert scope['foo'] == [] + assert scope["foo"] == [] @reset_globals def test_call(self): def passthrough(*args, **kwargs): return args, kwargs + def get_func(): return passthrough + scope = {} - placeholder = demandload.Placeholder(scope, 'foo', get_func) - assert scope == object.__getattribute__(placeholder, '_scope') - assert (('arg',), {'kwarg': 42}) == placeholder('arg', kwarg=42) - assert passthrough is scope['foo'] + placeholder = demandload.Placeholder(scope, "foo", get_func) + assert scope == object.__getattribute__(placeholder, "_scope") + assert (("arg",), {"kwarg": 42}) == placeholder("arg", kwarg=42) + assert passthrough is scope["foo"] @reset_globals def test_setattr(self): @@ -83,45 +85,43 @@ class TestPlaceholder: pass scope = {} - placeholder = demandload.Placeholder(scope, 'foo', Struct) + placeholder = demandload.Placeholder(scope, "foo", Struct) placeholder.val = 7 demandload._protection_enabled = lambda: True with pytest.raises(ValueError): - getattr(placeholder, 'val') - assert 7 == scope['foo'].val + getattr(placeholder, "val") + assert 7 == scope["foo"].val class TestImport: - @reset_globals def test_demandload(self): scope = {} - demandload.demandload('snakeoil:demandload', scope=scope) - assert demandload is not scope['demandload'] - assert demandload.demandload is scope['demandload'].demandload - assert demandload is scope['demandload'] + demandload.demandload("snakeoil:demandload", scope=scope) + assert demandload is not scope["demandload"] + assert demandload.demandload is scope["demandload"].demandload + assert demandload is scope["demandload"] @reset_globals def test_disabled_demandload(self): scope = {} - demandload.disabled_demandload('snakeoil:demandload', scope=scope) - assert demandload is scope['demandload'] + demandload.disabled_demandload("snakeoil:demandload", scope=scope) + assert demandload is scope["demandload"] class TestDemandCompileRegexp: - @reset_globals def test_demand_compile_regexp(self): scope = {} - demandload.demand_compile_regexp('foo', 'frob', scope=scope) - assert list(scope.keys()) == ['foo'] - assert 'frob' == scope['foo'].pattern - assert 'frob' == scope['foo'].pattern + demandload.demand_compile_regexp("foo", "frob", scope=scope) + assert list(scope.keys()) == ["foo"] + assert "frob" == scope["foo"].pattern + assert "frob" == scope["foo"].pattern # verify it's delayed via a bad regex. - demandload.demand_compile_regexp('foo', 'f(', scope=scope) - assert list(scope.keys()) == ['foo'] + demandload.demand_compile_regexp("foo", "f(", scope=scope) + assert list(scope.keys()) == ["foo"] # should blow up on accessing an attribute. - obj = scope['foo'] + obj = scope["foo"] with pytest.raises(sre_constants.error): - getattr(obj, 'pattern') + getattr(obj, "pattern") diff --git a/tests/test_demandload_usage.py b/tests/test_demandload_usage.py index 79ccfe0..ddde056 100644 --- a/tests/test_demandload_usage.py +++ b/tests/test_demandload_usage.py @@ -4,7 +4,7 @@ from snakeoil.test import mixins class TestDemandLoadTargets(mixins.PythonNamespaceWalker): - target_namespace = 'snakeoil' + target_namespace = "snakeoil" ignore_all_import_failures = False @pytest.fixture(autouse=True) @@ -16,8 +16,8 @@ class TestDemandLoadTargets(mixins.PythonNamespaceWalker): def test_demandload_targets(self): for x in self.walk_namespace( - self.target_namespace, - ignore_failed_imports=self.ignore_all_import_failures): + self.target_namespace, ignore_failed_imports=self.ignore_all_import_failures + ): self.check_space(x) def check_space(self, mod): diff --git a/tests/test_dependant_methods.py b/tests/test_dependant_methods.py index ffd6d36..186e175 100644 --- a/tests/test_dependant_methods.py +++ b/tests/test_dependant_methods.py @@ -8,7 +8,6 @@ def func(self, seq, data, val=True): class TestDependantMethods: - @staticmethod def generate_instance(methods, dependencies): class Class(metaclass=dm.ForcedDepends): @@ -25,13 +24,15 @@ class TestDependantMethods: results = [] o = self.generate_instance( {str(x): currying.post_curry(func, results, x) for x in range(10)}, - {str(x): str(x - 1) for x in range(1, 10)}) + {str(x): str(x - 1) for x in range(1, 10)}, + ) getattr(o, "9")() assert results == list(range(10)) results = [] o = self.generate_instance( {str(x): currying.post_curry(func, results, x, False) for x in range(10)}, - {str(x): str(x - 1) for x in range(1, 10)}) + {str(x): str(x - 1) for x in range(1, 10)}, + ) getattr(o, "9")() assert results == [0] getattr(o, "9")() @@ -41,7 +42,8 @@ class TestDependantMethods: results = [] o = self.generate_instance( {str(x): currying.post_curry(func, results, x) for x in range(10)}, - {str(x): str(x - 1) for x in range(1, 10)}) + {str(x): str(x - 1) for x in range(1, 10)}, + ) getattr(o, "1")() assert results == [0, 1] getattr(o, "2")() @@ -71,14 +73,15 @@ class TestDependantMethods: results = [] o = self.generate_instance( {str(x): currying.post_curry(func, results, x) for x in range(10)}, - {str(x): str(x - 1) for x in range(1, 10)}) - getattr(o, '2')(ignore_deps=True) + {str(x): str(x - 1) for x in range(1, 10)}, + ) + getattr(o, "2")(ignore_deps=True) assert [2] == results def test_no_deps(self): results = [] o = self.generate_instance( - {str(x): currying.post_curry(func, results, x) for x in range(10)}, - {}) - getattr(o, '2')() + {str(x): currying.post_curry(func, results, x) for x in range(10)}, {} + ) + getattr(o, "2")() assert [2] == results diff --git a/tests/test_fileutils.py b/tests/test_fileutils.py index a4555f8..356eb74 100644 --- a/tests/test_fileutils.py +++ b/tests/test_fileutils.py @@ -13,7 +13,6 @@ from snakeoil.test import random_str class TestTouch: - @pytest.fixture def random_path(self, tmp_path): return tmp_path / random_str(10) @@ -124,19 +123,19 @@ class TestAtomicWriteFile: def cpy_setup_class(scope, func_name): - if getattr(fileutils, 'native_%s' % func_name) \ - is getattr(fileutils, func_name): - scope['skip'] = 'extensions disabled' + if getattr(fileutils, "native_%s" % func_name) is getattr(fileutils, func_name): + scope["skip"] = "extensions disabled" else: - scope['func'] = staticmethod(getattr(fileutils, func_name)) + scope["func"] = staticmethod(getattr(fileutils, func_name)) + class Test_readfile: func = staticmethod(fileutils.readfile) - test_cases = ['asdf\nfdasswer\1923', '', '987234'] + test_cases = ["asdf\nfdasswer\1923", "", "987234"] - default_encoding = 'ascii' - none_on_missing_ret_data = 'dar' + default_encoding = "ascii" + none_on_missing_ret_data = "dar" @staticmethod def convert_data(data, encoding): @@ -147,7 +146,7 @@ class Test_readfile: return data def test_it(self, tmp_path): - fp = tmp_path / 'testfile' + fp = tmp_path / "testfile" for expected in self.test_cases: raised = None encoding = self.default_encoding @@ -168,16 +167,16 @@ class Test_readfile: assert self.func(path) == expected def test_none_on_missing(self, tmp_path): - fp = tmp_path / 'nonexistent' + fp = tmp_path / "nonexistent" with pytest.raises(FileNotFoundError): self.func(fp) assert self.func(fp, True) is None - fp.write_bytes(self.convert_data('dar', 'ascii')) + fp.write_bytes(self.convert_data("dar", "ascii")) assert self.func(fp, True) == self.none_on_missing_ret_data # ensure it handles paths that go through files- # still should be suppress - assert self.func(fp / 'extra', True) is None + assert self.func(fp / "extra", True) is None class Test_readfile_ascii(Test_readfile): @@ -186,85 +185,86 @@ class Test_readfile_ascii(Test_readfile): class Test_readfile_utf8(Test_readfile): func = staticmethod(fileutils.readfile_utf8) - default_encoding = 'utf8' + default_encoding = "utf8" class Test_readfile_bytes(Test_readfile): func = staticmethod(fileutils.readfile_bytes) default_encoding = None - test_cases = list(map( - currying.post_curry(Test_readfile.convert_data, 'ascii'), - Test_readfile.test_cases)) - test_cases.append('\ua000fa'.encode("utf8")) + test_cases = list( + map( + currying.post_curry(Test_readfile.convert_data, "ascii"), + Test_readfile.test_cases, + ) + ) + test_cases.append("\ua000fa".encode("utf8")) none_on_missing_ret_data = Test_readfile.convert_data( - Test_readfile.none_on_missing_ret_data, 'ascii') + Test_readfile.none_on_missing_ret_data, "ascii" + ) class readlines_mixin: - def assertFunc(self, path, expected): expected = tuple(expected.split()) - if expected == ('',): + if expected == ("",): expected = () - if 'utf8' not in self.encoding_mode: + if "utf8" not in self.encoding_mode: assert tuple(self.func(path)) == expected return assert tuple(self.func(path)) == expected def test_none_on_missing(self, tmp_path): - fp = tmp_path / 'nonexistent' + fp = tmp_path / "nonexistent" with pytest.raises(FileNotFoundError): self.func(fp) assert not tuple(self.func(fp, False, True)) - fp.write_bytes(self.convert_data('dar', 'ascii')) + fp.write_bytes(self.convert_data("dar", "ascii")) assert tuple(self.func(fp, True)) == (self.none_on_missing_ret_data,) - assert not tuple(self.func(fp / 'missing', False, True)) + assert not tuple(self.func(fp / "missing", False, True)) def test_strip_whitespace(self, tmp_path): - fp = tmp_path / 'data' + fp = tmp_path / "data" - fp.write_bytes(self.convert_data(' dar1 \ndar2 \n dar3\n', - 'ascii')) + fp.write_bytes(self.convert_data(" dar1 \ndar2 \n dar3\n", "ascii")) results = tuple(self.func(fp, True)) - expected = ('dar1', 'dar2', 'dar3') - if self.encoding_mode == 'bytes': + expected = ("dar1", "dar2", "dar3") + if self.encoding_mode == "bytes": expected = tuple(x.encode("ascii") for x in expected) assert results == expected # this time without the trailing newline... - fp.write_bytes(self.convert_data(' dar1 \ndar2 \n dar3', - 'ascii')) + fp.write_bytes(self.convert_data(" dar1 \ndar2 \n dar3", "ascii")) results = tuple(self.func(fp, True)) assert results == expected # test a couple of edgecases; underly c extension has gotten these # wrong before. - fp.write_bytes(self.convert_data('0', 'ascii')) + fp.write_bytes(self.convert_data("0", "ascii")) results = tuple(self.func(fp, True)) - expected = ('0',) - if self.encoding_mode == 'bytes': + expected = ("0",) + if self.encoding_mode == "bytes": expected = tuple(x.encode("ascii") for x in expected) assert results == expected - fp.write_bytes(self.convert_data('0\n', 'ascii')) + fp.write_bytes(self.convert_data("0\n", "ascii")) results = tuple(self.func(fp, True)) - expected = ('0',) - if self.encoding_mode == 'bytes': + expected = ("0",) + if self.encoding_mode == "bytes": expected = tuple(x.encode("ascii") for x in expected) assert results == expected - fp.write_bytes(self.convert_data('0 ', 'ascii')) + fp.write_bytes(self.convert_data("0 ", "ascii")) results = tuple(self.func(fp, True)) - expected = ('0',) - if self.encoding_mode == 'bytes': + expected = ("0",) + if self.encoding_mode == "bytes": expected = tuple(x.encode("ascii") for x in expected) assert results == expected def mk_readlines_test(scope, mode): - func_name = 'readlines_%s' % mode - base = globals()['Test_readfile_%s' % mode] + func_name = "readlines_%s" % mode + base = globals()["Test_readfile_%s" % mode] class kls(readlines_mixin, base): func = staticmethod(getattr(fileutils, func_name)) @@ -273,14 +273,15 @@ def mk_readlines_test(scope, mode): kls.__name__ = "Test_%s" % func_name scope["Test_%s" % func_name] = kls + for case in ("ascii", "bytes", "utf8"): - name = 'readlines_%s' % case + name = "readlines_%s" % case mk_readlines_test(locals(), case) class TestBrokenStats: - test_cases = ['/proc/crypto', '/sys/devices/system/cpu/present'] + test_cases = ["/proc/crypto", "/sys/devices/system/cpu/present"] def test_readfile(self): for path in self.test_cases: @@ -292,7 +293,7 @@ class TestBrokenStats: def _check_path(self, path, func, split_it=False): try: - with open(path, 'r') as handle: + with open(path, "r") as handle: data = handle.read() except EnvironmentError as e: if e.errno not in (errno.ENOENT, errno.EPERM): @@ -302,7 +303,7 @@ class TestBrokenStats: func_data = func(path) if split_it: func_data = list(func_data) - data = [x for x in data.split('\n') if x] + data = [x for x in data.split("\n") if x] func_data = [x for x in func_data if x] assert func_data == data @@ -313,13 +314,13 @@ class Test_mmap_or_open_for_read: func = staticmethod(fileutils.mmap_or_open_for_read) def test_zero_length(self, tmp_path): - (path := tmp_path / "target").write_text('') + (path := tmp_path / "target").write_text("") m, f = self.func(path) assert m is None - assert f.read() == b'' + assert f.read() == b"" f.close() - def test_mmap(self, tmp_path, data=b'foonani'): + def test_mmap(self, tmp_path, data=b"foonani"): (path := tmp_path / "target").write_bytes(data) m, f = self.func(path) assert len(m) == len(data) @@ -329,14 +330,14 @@ class Test_mmap_or_open_for_read: class Test_mmap_and_close: - def test_it(self, tmp_path): - (path := tmp_path / "target").write_bytes(data := b'asdfasdf') + (path := tmp_path / "target").write_bytes(data := b"asdfasdf") fd, m = None, None try: fd = os.open(path, os.O_RDONLY) m = _fileutils.mmap_and_close( - fd, len(data), mmap.MAP_PRIVATE, mmap.PROT_READ) + fd, len(data), mmap.MAP_PRIVATE, mmap.PROT_READ + ) # and ensure it closed the fd... with pytest.raises(EnvironmentError): os.read(fd, 1) diff --git a/tests/test_formatters.py b/tests/test_formatters.py index 266ef1e..549f2ad 100644 --- a/tests/test_formatters.py +++ b/tests/test_formatters.py @@ -18,16 +18,16 @@ class TestPlainTextFormatter: def test_basics(self): # As many sporks as fit in 20 chars. - sporks = ' '.join(3 * ('spork',)) + sporks = " ".join(3 * ("spork",)) for inputs, output in [ - (('\N{SNOWMAN}',), '?'), - ((7 * 'spork ',), '%s\n%s\n%s' % (sporks, sporks, 'spork ')), - (7 * ('spork ',), '%s \n%s \n%s' % (sporks, sporks, 'spork ')), - ((30 * 'a'), 20 * 'a' + '\n' + 10 * 'a'), - (30 * ('a',), 20 * 'a' + '\n' + 10 * 'a'), - ]: + (("\N{SNOWMAN}",), "?"), + ((7 * "spork ",), "%s\n%s\n%s" % (sporks, sporks, "spork ")), + (7 * ("spork ",), "%s \n%s \n%s" % (sporks, sporks, "spork ")), + ((30 * "a"), 20 * "a" + "\n" + 10 * "a"), + (30 * ("a",), 20 * "a" + "\n" + 10 * "a"), + ]: stream = BytesIO() - formatter = self.kls(stream, encoding='ascii') + formatter = self.kls(stream, encoding="ascii") formatter.width = 20 formatter.write(autoline=False, wrap=True, *inputs) assert output.encode() == stream.getvalue() @@ -35,69 +35,70 @@ class TestPlainTextFormatter: def test_first_prefix(self): # As many sporks as fit in 20 chars. for inputs, output in [ - (('\N{SNOWMAN}',), 'foon:?'), - ((7 * 'spork ',), - 'foon:spork spork\n' - 'spork spork spork\n' - 'spork spork '), - (7 * ('spork ',), - 'foon:spork spork \n' - 'spork spork spork \n' - 'spork spork '), - ((30 * 'a'), 'foon:' + 15 * 'a' + '\n' + 15 * 'a'), - (30 * ('a',), 'foon:' + 15 * 'a' + '\n' + 15 * 'a'), - ]: + (("\N{SNOWMAN}",), "foon:?"), + ( + (7 * "spork ",), + "foon:spork spork\n" "spork spork spork\n" "spork spork ", + ), + ( + 7 * ("spork ",), + "foon:spork spork \n" "spork spork spork \n" "spork spork ", + ), + ((30 * "a"), "foon:" + 15 * "a" + "\n" + 15 * "a"), + (30 * ("a",), "foon:" + 15 * "a" + "\n" + 15 * "a"), + ]: stream = BytesIO() - formatter = self.kls(stream, encoding='ascii') + formatter = self.kls(stream, encoding="ascii") formatter.width = 20 - formatter.write(autoline=False, wrap=True, first_prefix='foon:', *inputs) + formatter.write(autoline=False, wrap=True, first_prefix="foon:", *inputs) assert output.encode() == stream.getvalue() def test_later_prefix(self): for inputs, output in [ - (('\N{SNOWMAN}',), '?'), - ((7 * 'spork ',), - 'spork spork spork\n' - 'foon:spork spork\n' - 'foon:spork spork '), - (7 * ('spork ',), - 'spork spork spork \n' - 'foon:spork spork \n' - 'foon:spork spork '), - ((30 * 'a'), 20 * 'a' + '\n' + 'foon:' + 10 * 'a'), - (30 * ('a',), 20 * 'a' + '\n' + 'foon:' + 10 * 'a'), - ]: + (("\N{SNOWMAN}",), "?"), + ( + (7 * "spork ",), + "spork spork spork\n" "foon:spork spork\n" "foon:spork spork ", + ), + ( + 7 * ("spork ",), + "spork spork spork \n" "foon:spork spork \n" "foon:spork spork ", + ), + ((30 * "a"), 20 * "a" + "\n" + "foon:" + 10 * "a"), + (30 * ("a",), 20 * "a" + "\n" + "foon:" + 10 * "a"), + ]: stream = BytesIO() - formatter = self.kls(stream, encoding='ascii') + formatter = self.kls(stream, encoding="ascii") formatter.width = 20 - formatter.later_prefix = ['foon:'] + formatter.later_prefix = ["foon:"] formatter.write(wrap=True, autoline=False, *inputs) assert output.encode() == stream.getvalue() def test_complex(self): stream = BytesIO() - formatter = self.kls(stream, encoding='ascii') + formatter = self.kls(stream, encoding="ascii") formatter.width = 9 - formatter.first_prefix = ['foo', None, ' d'] - formatter.later_prefix = ['dorkey'] + formatter.first_prefix = ["foo", None, " d"] + formatter.later_prefix = ["dorkey"] formatter.write("dar bl", wrap=True, autoline=False) assert "foo ddar\ndorkeybl".encode() == stream.getvalue() - formatter.write(" "*formatter.width, wrap=True, autoline=True) + formatter.write(" " * formatter.width, wrap=True, autoline=True) formatter.stream = stream = BytesIO() formatter.write("dar", " b", wrap=True, autoline=False) assert "foo ddar\ndorkeyb".encode() == stream.getvalue() - output = \ -""" rdepends: >=dev-lang/python-2.3 >=sys-apps/sed-4.0.5 + output = """ rdepends: >=dev-lang/python-2.3 >=sys-apps/sed-4.0.5 dev-python/python-fchksum """ stream = BytesIO() - formatter = self.kls(stream, encoding='ascii', width=80) + formatter = self.kls(stream, encoding="ascii", width=80) formatter.wrap = True assert formatter.autoline assert formatter.width == 80 - formatter.later_prefix = [' '] - formatter.write(" rdepends: >=dev-lang/python-2.3 " - ">=sys-apps/sed-4.0.5 dev-python/python-fchksum") + formatter.later_prefix = [" "] + formatter.write( + " rdepends: >=dev-lang/python-2.3 " + ">=sys-apps/sed-4.0.5 dev-python/python-fchksum" + ) assert len(formatter.first_prefix) == 0 assert len(formatter.later_prefix) == 1 assert output.encode() == stream.getvalue() @@ -105,148 +106,176 @@ class TestPlainTextFormatter: formatter.stream = stream = BytesIO() # push it right up to the limit. formatter.width = 82 - formatter.write(" rdepends: >=dev-lang/python-2.3 " - ">=sys-apps/sed-4.0.5 dev-python/python-fchksum") + formatter.write( + " rdepends: >=dev-lang/python-2.3 " + ">=sys-apps/sed-4.0.5 dev-python/python-fchksum" + ) assert output.encode() == stream.getvalue() formatter.first_prefix = [] - formatter.later_prefix = [' '] + formatter.later_prefix = [" "] formatter.width = 28 formatter.autoline = False formatter.wrap = True formatter.stream = stream = BytesIO() input = (" description: ", "The Portage") formatter.write(*input) - output = ''.join(input).rsplit(" ", 1) - output[1] = ' %s' % output[1] - assert '\n'.join(output).encode() == stream.getvalue() - + output = "".join(input).rsplit(" ", 1) + output[1] = " %s" % output[1] + assert "\n".join(output).encode() == stream.getvalue() def test_wrap_autoline(self): for inputs, output in [ - ((3 * ('spork',)), 'spork\nspork\nspork\n'), - (3 * (('spork',),), 'spork\nspork\nspork\n'), - (((3 * 'spork',),), - '\n' - 'foonsporks\n' - 'foonporksp\n' - 'foonork\n'), - ((('fo',), (2 * 'spork',),), 'fo\nsporkspork\n'), - ((('fo',), (3 * 'spork',),), - 'fo\n' - '\n' - 'foonsporks\n' - 'foonporksp\n' - 'foonork\n'), - ]: + ((3 * ("spork",)), "spork\nspork\nspork\n"), + (3 * (("spork",),), "spork\nspork\nspork\n"), + (((3 * "spork",),), "\n" "foonsporks\n" "foonporksp\n" "foonork\n"), + ( + ( + ("fo",), + (2 * "spork",), + ), + "fo\nsporkspork\n", + ), + ( + ( + ("fo",), + (3 * "spork",), + ), + "fo\n" "\n" "foonsporks\n" "foonporksp\n" "foonork\n", + ), + ]: stream = BytesIO() - formatter = self.kls(stream, encoding='ascii') + formatter = self.kls(stream, encoding="ascii") formatter.width = 10 for input in inputs: - formatter.write(wrap=True, later_prefix='foon', *input) + formatter.write(wrap=True, later_prefix="foon", *input) assert output.encode() == stream.getvalue() class TerminfoFormatterTest: - def _test_stream(self, stream, formatter, inputs, output): stream.seek(0) stream.truncate() formatter.write(*inputs) stream.seek(0) result = stream.read() - output = ''.join(output) - assert output.encode() == result, \ - "given(%r), expected(%r), got(%r)" % (inputs, output, result) + output = "".join(output) + assert output.encode() == result, "given(%r), expected(%r), got(%r)" % ( + inputs, + output, + result, + ) @issue7567 def test_terminfo(self): - esc = '\x1b[' + esc = "\x1b[" stream = TemporaryFile() - f = formatters.TerminfoFormatter(stream, 'ansi', True, 'ascii') + f = formatters.TerminfoFormatter(stream, "ansi", True, "ascii") f.autoline = False for inputs, output in ( - ((f.bold, 'bold'), (esc, '1m', 'bold', esc, '0;10m')), - ((f.underline, 'underline'), - (esc, '4m', 'underline', esc, '0;10m')), - ((f.fg('red'), 'red'), (esc, '31m', 'red', esc, '39;49m')), - ((f.fg('red'), 'red', f.bold, 'boldred', f.fg(), 'bold', - f.reset, 'done'), - (esc, '31m', 'red', esc, '1m', 'boldred', esc, '39;49m', 'bold', - esc, '0;10m', 'done')), - ((42,), ('42',)), - (('\N{SNOWMAN}',), ('?',)) - ): + ((f.bold, "bold"), (esc, "1m", "bold", esc, "0;10m")), + ((f.underline, "underline"), (esc, "4m", "underline", esc, "0;10m")), + ((f.fg("red"), "red"), (esc, "31m", "red", esc, "39;49m")), + ( + ( + f.fg("red"), + "red", + f.bold, + "boldred", + f.fg(), + "bold", + f.reset, + "done", + ), + ( + esc, + "31m", + "red", + esc, + "1m", + "boldred", + esc, + "39;49m", + "bold", + esc, + "0;10m", + "done", + ), + ), + ((42,), ("42",)), + (("\N{SNOWMAN}",), ("?",)), + ): self._test_stream(stream, f, inputs, output) f.autoline = True - self._test_stream( - stream, f, ('lala',), ('lala', '\n')) + self._test_stream(stream, f, ("lala",), ("lala", "\n")) def test_unsupported_term(self): stream = TemporaryFile() with pytest.raises(formatters.TerminfoUnsupported): - formatters.TerminfoFormatter(stream, term='dumb') + formatters.TerminfoFormatter(stream, term="dumb") @issue7567 def test_title(self): stream = TemporaryFile() try: - f = formatters.TerminfoFormatter(stream, 'xterm+sl', True, 'ascii') + f = formatters.TerminfoFormatter(stream, "xterm+sl", True, "ascii") except curses.error: pytest.skip("xterm+sl not in terminfo db") - f.title('TITLE') + f.title("TITLE") stream.seek(0) - assert b'\x1b]0;TITLE\x07' == stream.read() + assert b"\x1b]0;TITLE\x07" == stream.read() def _with_term(term, func, *args, **kwargs): - orig_term = os.environ.get('TERM') + orig_term = os.environ.get("TERM") try: - os.environ['TERM'] = term + os.environ["TERM"] = term return func(*args, **kwargs) finally: if orig_term is None: - del os.environ['TERM'] + del os.environ["TERM"] else: - os.environ['TERM'] = orig_term + os.environ["TERM"] = orig_term + # XXX ripped from pkgcore's test_commandline -def _get_pty_pair(encoding='ascii'): +def _get_pty_pair(encoding="ascii"): master_fd, slave_fd = pty.openpty() - master = os.fdopen(master_fd, 'rb', 0) - out = os.fdopen(slave_fd, 'wb', 0) + master = os.fdopen(master_fd, "rb", 0) + out = os.fdopen(slave_fd, "wb", 0) return master, out -@pytest.mark.skip(reason='this currently breaks on github ci due to the issue7567 workaround') +@pytest.mark.skip( + reason="this currently breaks on github ci due to the issue7567 workaround" +) class TestGetFormatter: - @issue7567 def test_dumb_terminal(self): master, _out = _get_pty_pair() - formatter = _with_term('dumb', formatters.get_formatter, master) + formatter = _with_term("dumb", formatters.get_formatter, master) assert isinstance(formatter, formatters.PlainTextFormatter) @issue7567 def test_vt100_terminal(self): master, _out = _get_pty_pair() - formatter = _with_term('vt100', formatters.get_formatter, master) + formatter = _with_term("vt100", formatters.get_formatter, master) assert isinstance(formatter, formatters.PlainTextFormatter) @issue7567 def test_smart_terminal(self): master, _out = _get_pty_pair() - formatter = _with_term('xterm', formatters.get_formatter, master) + formatter = _with_term("xterm", formatters.get_formatter, master) assert isinstance(formatter, formatters.TerminfoFormatter) @issue7567 def test_not_a_tty(self): stream = TemporaryFile() - formatter = _with_term('xterm', formatters.get_formatter, stream) + formatter = _with_term("xterm", formatters.get_formatter, stream) assert isinstance(formatter, formatters.PlainTextFormatter) @issue7567 def test_no_fd(self): stream = BytesIO() - formatter = _with_term('xterm', formatters.get_formatter, stream) + formatter = _with_term("xterm", formatters.get_formatter, stream) assert isinstance(formatter, formatters.PlainTextFormatter) diff --git a/tests/test_iterables.py b/tests/test_iterables.py index 3345c5e..d0d5768 100644 --- a/tests/test_iterables.py +++ b/tests/test_iterables.py @@ -1,12 +1,10 @@ import operator import pytest -from snakeoil.iterables import (caching_iter, expandable_chain, iter_sort, - partition) +from snakeoil.iterables import caching_iter, expandable_chain, iter_sort, partition class TestPartition: - def test_empty(self): a, b = partition(()) assert list(a) == [] @@ -23,19 +21,18 @@ class TestPartition: class TestExpandableChain: - def test_normal_function(self): i = [iter(range(100)) for x in range(3)] e = expandable_chain() e.extend(i) - assert list(e) == list(range(100))*3 + assert list(e) == list(range(100)) * 3 for x in i + [e]: pytest.raises(StopIteration, x.__next__) def test_extend(self): e = expandable_chain() e.extend(range(100) for i in (1, 2)) - assert list(e) == list(range(100))*2 + assert list(e) == list(range(100)) * 2 with pytest.raises(StopIteration): e.extend([[]]) @@ -62,7 +59,6 @@ class TestExpandableChain: class TestCachingIter: - def test_iter_consumption(self): i = iter(range(100)) c = caching_iter(i) @@ -147,6 +143,7 @@ class Test_iter_sort: def test_ordering(self): def f(l): return sorted(l, key=operator.itemgetter(0)) + result = list(iter_sort(f, *[iter(range(x, x + 10)) for x in (30, 20, 0, 10)])) expected = list(range(40)) assert result == expected diff --git a/tests/test_klass.py b/tests/test_klass.py index 773925d..25728fa 100644 --- a/tests/test_klass.py +++ b/tests/test_klass.py @@ -14,7 +14,8 @@ class Test_GetAttrProxy: class foo1: def __init__(self, obj): self.obj = obj - __getattr__ = self.kls('obj') + + __getattr__ = self.kls("obj") class foo2: pass @@ -27,18 +28,18 @@ class Test_GetAttrProxy: o2.foon = "dar" assert o.foon == "dar" o.foon = "foo" - assert o.foon == 'foo' + assert o.foon == "foo" def test_attrlist(self): def make_class(attr_list=None): class foo(metaclass=self.kls): if attr_list is not None: - locals()['__attr_comparison__'] = attr_list + locals()["__attr_comparison__"] = attr_list with pytest.raises(TypeError): make_class() with pytest.raises(TypeError): - make_class(['foon']) + make_class(["foon"]) with pytest.raises(TypeError): make_class([None]) @@ -47,38 +48,39 @@ class Test_GetAttrProxy: bar = "baz" class Test: - method = self.kls('test') + method = self.kls("test") test = foo() test = Test() - assert test.method('bar') == foo.bar + assert test.method("bar") == foo.bar class TestDirProxy: - @staticmethod def noninternal_attrs(obj): - return sorted(x for x in dir(obj) if not re.match(r'__\w+__', x)) + return sorted(x for x in dir(obj) if not re.match(r"__\w+__", x)) def test_combined(self): class foo1: def __init__(self, obj): self.obj = obj - __dir__ = klass.DirProxy('obj') + + __dir__ = klass.DirProxy("obj") class foo2: def __init__(self): - self.attr = 'foo' + self.attr = "foo" o2 = foo2() o = foo1(o2) - assert self.noninternal_attrs(o) == ['attr', 'obj'] + assert self.noninternal_attrs(o) == ["attr", "obj"] def test_empty(self): class foo1: def __init__(self, obj): self.obj = obj - __dir__ = klass.DirProxy('obj') + + __dir__ = klass.DirProxy("obj") class foo2: pass @@ -86,23 +88,26 @@ class TestDirProxy: o2 = foo2() o = foo1(o2) assert self.noninternal_attrs(o2) == [] - assert self.noninternal_attrs(o) == ['obj'] + assert self.noninternal_attrs(o) == ["obj"] def test_slots(self): class foo1: - __slots__ = ('obj',) + __slots__ = ("obj",) + def __init__(self, obj): self.obj = obj - __dir__ = klass.DirProxy('obj') + + __dir__ = klass.DirProxy("obj") class foo2: - __slots__ = ('attr',) + __slots__ = ("attr",) + def __init__(self): - self.attr = 'foo' + self.attr = "foo" o2 = foo2() o = foo1(o2) - assert self.noninternal_attrs(o) == ['attr', 'obj'] + assert self.noninternal_attrs(o) == ["attr", "obj"] class Test_contains: @@ -111,6 +116,7 @@ class Test_contains: def test_it(self): class c(dict): __contains__ = self.func + d = c({"1": 2}) assert "1" in d assert 1 not in d @@ -122,6 +128,7 @@ class Test_get: def test_it(self): class c(dict): get = self.func + d = c({"1": 2}) assert d.get("1") == 2 assert d.get("1", 3) == 2 @@ -142,11 +149,13 @@ class Test_chained_getter: assert id(self.kls("fa2341fa")) == l[0] def test_eq(self): - assert self.kls("asdf", disable_inst_caching=True) == \ - self.kls("asdf", disable_inst_caching=True) + assert self.kls("asdf", disable_inst_caching=True) == self.kls( + "asdf", disable_inst_caching=True + ) - assert self.kls("asdf2", disable_inst_caching=True) != \ - self.kls("asdf", disable_inst_caching=True) + assert self.kls("asdf2", disable_inst_caching=True) != self.kls( + "asdf", disable_inst_caching=True + ) def test_it(self): class maze: @@ -159,13 +168,13 @@ class Test_chained_getter: d = {} m = maze(d) f = self.kls - assert f('foon')(m) == m + assert f("foon")(m) == m d["foon"] = 1 - assert f('foon')(m) == 1 - assert f('dar.foon')(m) == 1 - assert f('.'.join(['blah']*10))(m) == m + assert f("foon")(m) == 1 + assert f("dar.foon")(m) == 1 + assert f(".".join(["blah"] * 10))(m) == m with pytest.raises(AttributeError): - f('foon.dar')(m) + f("foon.dar")(m) class Test_jit_attr: @@ -184,23 +193,28 @@ class Test_jit_attr: def jit_attr_ext_method(self): return partial(klass.jit_attr_ext_method, kls=self.kls) - def mk_inst(self, attrname='_attr', method_lookup=False, - use_cls_setattr=False, func=None, - singleton=klass._uncached_singleton): + def mk_inst( + self, + attrname="_attr", + method_lookup=False, + use_cls_setattr=False, + func=None, + singleton=klass._uncached_singleton, + ): f = func if not func: + def f(self): self._invokes.append(self) return 54321 class cls: - def __init__(self): sf = partial(object.__setattr__, self) - sf('_sets', []) - sf('_reflects', []) - sf('_invokes', []) + sf("_sets", []) + sf("_reflects", []) + sf("_invokes", []) attr = self.kls(f, attrname, singleton, use_cls_setattr) @@ -219,13 +233,22 @@ class Test_jit_attr: sets = [instance] * sets reflects = [instance] * reflects invokes = [instance] * invokes - msg = ("checking %s: got(%r), expected(%r); state was sets=%r, " - "reflects=%r, invokes=%r" % ( - "%s", "%s", "%s", instance._sets, instance._reflects, - instance._invokes)) + msg = ( + "checking %s: got(%r), expected(%r); state was sets=%r, " + "reflects=%r, invokes=%r" + % ("%s", "%s", "%s", instance._sets, instance._reflects, instance._invokes) + ) assert instance._sets == sets, msg % ("sets", instance._sets, sets) - assert instance._reflects == reflects, msg % ("reflects", instance._reflects, reflects) - assert instance._invokes == invokes, msg % ("invokes", instance._invokes, invokes) + assert instance._reflects == reflects, msg % ( + "reflects", + instance._reflects, + reflects, + ) + assert instance._invokes == invokes, msg % ( + "invokes", + instance._invokes, + invokes, + ) def test_implementation(self): obj = self.mk_inst() @@ -298,7 +321,7 @@ class Test_jit_attr: object.__setattr__(self, attr, value) o = cls() - assert not hasattr(o, 'invoked') + assert not hasattr(o, "invoked") assert o.my_attr == now assert o._blah2 == now assert o.invoked @@ -315,34 +338,34 @@ class Test_jit_attr: return now2 def __setattr__(self, attr, value): - if not getattr(self, '_setattr_allowed', False): + if not getattr(self, "_setattr_allowed", False): raise TypeError("setattr isn't allowed for %s" % attr) object.__setattr__(self, attr, value) - base.attr = self.jit_attr_ext_method('f1', '_attr') + base.attr = self.jit_attr_ext_method("f1", "_attr") o = base() assert o.attr == now assert o._attr == now assert o.attr == now - base.attr = self.jit_attr_ext_method('f1', '_attr', use_cls_setattr=True) + base.attr = self.jit_attr_ext_method("f1", "_attr", use_cls_setattr=True) o = base() with pytest.raises(TypeError): - getattr(o, 'attr') + getattr(o, "attr") base._setattr_allowed = True assert o.attr == now - base.attr = self.jit_attr_ext_method('f2', '_attr2') + base.attr = self.jit_attr_ext_method("f2", "_attr2") o = base() assert o.attr == now2 assert o._attr2 == now2 # finally, check that it's doing lookups rather then storing the func. - base.attr = self.jit_attr_ext_method('func', '_attr2') + base.attr = self.jit_attr_ext_method("func", "_attr2") o = base() # no func... with pytest.raises(AttributeError): - getattr(o, 'attr') + getattr(o, "attr") base.func = base.f1 assert o.attr == now assert o._attr2 == now @@ -354,7 +377,13 @@ class Test_jit_attr: def test_check_singleton_is_compare(self): def throw_assert(*args, **kwds): - raise AssertionError("I shouldn't be invoked: %s, %s" % (args, kwds,)) + raise AssertionError( + "I shouldn't be invoked: %s, %s" + % ( + args, + kwds, + ) + ) class puker: __eq__ = throw_assert @@ -369,11 +398,13 @@ class Test_jit_attr: def test_cached_property(self): l = [] + class foo: @klass.cached_property def blah(self, l=l, i=iter(range(5))): l.append(None) return next(i) + f = foo() assert f.blah == 0 assert len(l) == 1 @@ -413,15 +444,15 @@ class Test_aliased_attr: o = cls() with pytest.raises(AttributeError): - getattr(o, 'attr') + getattr(o, "attr") o.dar = "foon" with pytest.raises(AttributeError): - getattr(o, 'attr') + getattr(o, "attr") o.dar = o o.blah = "monkey" - assert o.attr == 'monkey' + assert o.attr == "monkey" # verify it'll cross properties... class blah: @@ -431,6 +462,7 @@ class Test_aliased_attr: @property def foon(self): return blah() + alias = self.func("foon.target") o = cls() @@ -442,12 +474,15 @@ class Test_cached_hash: def test_it(self): now = int(time()) + class cls: invoked = [] + @self.func def __hash__(self): self.invoked.append(self) return now + o = cls() assert hash(o) == now assert o.invoked == [o] @@ -462,7 +497,7 @@ class Test_reflective_hash: def test_it(self): class cls: - __hash__ = self.func('_hash') + __hash__ = self.func("_hash") obj = cls() with pytest.raises(AttributeError): @@ -477,7 +512,8 @@ class Test_reflective_hash: hash(obj) class cls2: - __hash__ = self.func('_dar') + __hash__ = self.func("_dar") + obj = cls2() with pytest.raises(AttributeError): hash(obj) @@ -486,7 +522,6 @@ class Test_reflective_hash: class TestImmutableInstance: - def test_metaclass(self): self.common_test(lambda x: x, metaclass=klass.immutable_instance) @@ -506,7 +541,7 @@ class TestImmutableInstance: with pytest.raises(AttributeError): delattr(o, "dar") - object.__setattr__(o, 'dar', 'foon') + object.__setattr__(o, "dar", "foon") with pytest.raises(AttributeError): delattr(o, "dar") @@ -541,7 +576,6 @@ class TestAliasMethod: class TestPatch: - def setup_method(self, method): # cache original methods self._math_ceil = math.ceil @@ -556,7 +590,7 @@ class TestPatch: n = 0.1 assert math.ceil(n) == 1 - @klass.patch('math.ceil') + @klass.patch("math.ceil") def ceil(orig_ceil, n): return math.floor(n) @@ -567,8 +601,8 @@ class TestPatch: assert math.ceil(n) == 2 assert math.floor(n) == 1 - @klass.patch('math.ceil') - @klass.patch('math.floor') + @klass.patch("math.ceil") + @klass.patch("math.floor") def zero(orig_func, n): return 0 diff --git a/tests/test_mappings.py b/tests/test_mappings.py index b1aef25..1ffe780 100644 --- a/tests/test_mappings.py +++ b/tests/test_mappings.py @@ -10,7 +10,6 @@ def a_dozen(): class BasicDict(mappings.DictMixin): - def __init__(self, i=None, **kwargs): self._d = {} mappings.DictMixin.__init__(self, i, **kwargs) @@ -20,7 +19,6 @@ class BasicDict(mappings.DictMixin): class MutableDict(BasicDict): - def __setitem__(self, key, val): self._d[key] = val @@ -36,7 +34,6 @@ class ImmutableDict(BasicDict): class TestDictMixin: - def test_immutability(self): d = ImmutableDict() pytest.raises(AttributeError, d.__setitem__, "spork", "foon") @@ -59,12 +56,12 @@ class TestDictMixin: pytest.raises(KeyError, d.pop, "spork") assert d.pop("spork", "bat") == "bat" assert d.pop("foo") == "bar" - assert d.popitem(), ("baz" == "cat") + assert d.popitem(), "baz" == "cat" pytest.raises(KeyError, d.popitem) assert d.pop("nonexistent", None) == None def test_init(self): - d = MutableDict((('foo', 'bar'), ('spork', 'foon')), baz="cat") + d = MutableDict((("foo", "bar"), ("spork", "foon")), baz="cat") assert d["foo"] == "bar" assert d["baz"] == "cat" d.clear() @@ -73,19 +70,20 @@ class TestDictMixin: def test_bool(self): d = MutableDict() assert not d - d['x'] = 1 + d["x"] = 1 assert d - del d['x'] + del d["x"] assert not d class RememberingNegateMixin: - def setup_method(self, method): self.negate_calls = [] + def negate(i): self.negate_calls.append(i) return -i + self.negate = negate def teardown_method(self, method): @@ -94,7 +92,6 @@ class RememberingNegateMixin: class LazyValDictTestMixin: - def test_invalid_operations(self): pytest.raises(AttributeError, operator.setitem, self.dict, 7, 7) pytest.raises(AttributeError, operator.delitem, self.dict, 7) @@ -118,6 +115,7 @@ class LazyValDictTestMixin: # missing key def get(): return self.dict[42] + pytest.raises(KeyError, get) def test_caching(self): @@ -129,7 +127,6 @@ class LazyValDictTestMixin: class TestLazyValDictWithList(LazyValDictTestMixin, RememberingNegateMixin): - def setup_method(self, method): super().setup_method(method) self.dict = mappings.LazyValDict(list(range(12)), self.negate) @@ -148,14 +145,12 @@ class TestLazyValDictWithList(LazyValDictTestMixin, RememberingNegateMixin): class TestLazyValDictWithFunc(LazyValDictTestMixin, RememberingNegateMixin): - def setup_method(self, method): super().setup_method(method) self.dict = mappings.LazyValDict(a_dozen, self.negate) class TestLazyValDict: - def test_invalid_init_args(self): pytest.raises(TypeError, mappings.LazyValDict, [1], 42) pytest.raises(TypeError, mappings.LazyValDict, 42, a_dozen) @@ -164,36 +159,43 @@ class TestLazyValDict: # TODO check for valid values for dict.new, since that seems to be # part of the interface? class TestProtectedDict: - def setup_method(self, method): self.orig = {1: -1, 2: -2} self.dict = mappings.ProtectedDict(self.orig) def test_basic_operations(self): assert self.dict[1] == -1 + def get(i): return self.dict[i] + pytest.raises(KeyError, get, 3) assert sorted(self.dict.keys()) == [1, 2] assert -1 not in self.dict assert 2 in self.dict + def remove(i): del self.dict[i] + pytest.raises(KeyError, remove, 50) def test_basic_mutating(self): # add something self.dict[7] = -7 + def check_after_adding(): assert self.dict[7] == -7 assert 7 in self.dict assert sorted(self.dict.keys()) == [1, 2, 7] + check_after_adding() # remove it again del self.dict[7] assert 7 not in self.dict + def get(i): return self.dict[i] + pytest.raises(KeyError, get, 7) assert sorted(self.dict.keys()) == [1, 2] # add it back @@ -214,7 +216,6 @@ class TestProtectedDict: class TestImmutableDict: - def test_init_iterator(self): d = mappings.ImmutableDict((x, x) for x in range(3)) assert dict(d) == {0: 0, 1: 1, 2: 2} @@ -239,7 +240,7 @@ class TestImmutableDict: def test_init_dictmixin(self): d = MutableDict(baz="cat") e = mappings.ImmutableDict(d) - assert dict(d) == {'baz': 'cat'} + assert dict(d) == {"baz": "cat"} def test_init_bad_data(self): for data in (range(10), list(range(10)), [([], 1)]): @@ -288,7 +289,6 @@ class TestImmutableDict: class TestOrderedFrozenSet: - def test_magic_methods(self): s = mappings.OrderedFrozenSet(range(9)) for x in range(9): @@ -299,7 +299,7 @@ class TestOrderedFrozenSet: for i in range(9): assert s[i] == i assert list(s[1:]) == list(range(1, 9)) - with pytest.raises(IndexError, match='index out of range'): + with pytest.raises(IndexError, match="index out of range"): s[9] assert s == set(range(9)) @@ -308,12 +308,12 @@ class TestOrderedFrozenSet: assert hash(s) def test_ordering(self): - s = mappings.OrderedFrozenSet('set') - assert 'set' == ''.join(s) - assert 'tes' == ''.join(reversed(s)) - s = mappings.OrderedFrozenSet('setordered') - assert 'setord' == ''.join(s) - assert 'drotes' == ''.join(reversed(s)) + s = mappings.OrderedFrozenSet("set") + assert "set" == "".join(s) + assert "tes" == "".join(reversed(s)) + s = mappings.OrderedFrozenSet("setordered") + assert "setord" == "".join(s) + assert "drotes" == "".join(reversed(s)) def test_immmutability(self): s = mappings.OrderedFrozenSet(range(9)) @@ -355,41 +355,40 @@ class TestOrderedFrozenSet: class TestOrderedSet(TestOrderedFrozenSet): - def test_hash(self): with pytest.raises(TypeError): - assert hash(mappings.OrderedSet('set')) + assert hash(mappings.OrderedSet("set")) def test_add(self): s = mappings.OrderedSet() - s.add('a') - assert 'a' in s + s.add("a") + assert "a" in s s.add(1) assert 1 in s - assert list(s) == ['a', 1] + assert list(s) == ["a", 1] def test_discard(self): s = mappings.OrderedSet() - s.discard('a') - s.add('a') + s.discard("a") + s.add("a") assert s - s.discard('a') + s.discard("a") assert not s def test_remove(self): s = mappings.OrderedSet() with pytest.raises(KeyError): - s.remove('a') - s.add('a') - assert 'a' in s - s.remove('a') - assert 'a' not in s + s.remove("a") + s.add("a") + assert "a" in s + s.remove("a") + assert "a" not in s def test_clear(self): s = mappings.OrderedSet() s.clear() assert len(s) == 0 - s.add('a') + s.add("a") assert len(s) == 1 s.clear() assert len(s) == 0 @@ -425,8 +424,9 @@ class TestStackedDict: assert x in std def test_len(self): - assert sum(map(len, (self.orig_dict, self.new_dict))) == \ - len(mappings.StackedDict(self.orig_dict, self.new_dict)) + assert sum(map(len, (self.orig_dict, self.new_dict))) == len( + mappings.StackedDict(self.orig_dict, self.new_dict) + ) def test_setattr(self): pytest.raises(TypeError, mappings.StackedDict().__setitem__, (1, 2)) @@ -447,24 +447,28 @@ class TestStackedDict: assert len(s) == 0 def test_keys(self): - assert sorted(mappings.StackedDict(self.orig_dict, self.new_dict)) == \ - sorted(list(self.orig_dict.keys()) + list(self.new_dict.keys())) + assert sorted(mappings.StackedDict(self.orig_dict, self.new_dict)) == sorted( + list(self.orig_dict.keys()) + list(self.new_dict.keys()) + ) class TestIndeterminantDict: - def test_disabled_methods(self): d = mappings.IndeterminantDict(lambda *a: None) for x in ( - "clear", - ("update", {}), - ("setdefault", 1), - "__iter__", "__len__", "__hash__", - ("__delitem__", 1), - ("__setitem__", 2), - ("popitem", 2), - "keys", "items", "values", - ): + "clear", + ("update", {}), + ("setdefault", 1), + "__iter__", + "__len__", + "__hash__", + ("__delitem__", 1), + ("__setitem__", 2), + ("popitem", 2), + "keys", + "items", + "values", + ): if isinstance(x, tuple): pytest.raises(TypeError, getattr(d, x[0]), x[1]) else: @@ -472,7 +476,8 @@ class TestIndeterminantDict: def test_starter_dict(self): d = mappings.IndeterminantDict( - lambda key: False, starter_dict={}.fromkeys(range(100), True)) + lambda key: False, starter_dict={}.fromkeys(range(100), True) + ) for x in range(100): assert d[x] == True for x in range(100, 110): @@ -481,21 +486,24 @@ class TestIndeterminantDict: def test_behaviour(self): val = [] d = mappings.IndeterminantDict( - lambda key: val.append(key), {}.fromkeys(range(10), True)) + lambda key: val.append(key), {}.fromkeys(range(10), True) + ) assert d[0] == True assert d[11] == None assert val == [11] + def func(*a): raise KeyError + with pytest.raises(KeyError): mappings.IndeterminantDict(func).__getitem__(1) - def test_get(self): def func(key): if key == 2: raise KeyError return True + d = mappings.IndeterminantDict(func, {1: 1}) assert d.get(1, 1) == 1 assert d.get(1, 2) == 1 @@ -505,41 +513,42 @@ class TestIndeterminantDict: class TestFoldingDict: - def test_preserve(self): dct = mappings.PreservingFoldingDict( - str.lower, list({'Foo': 'bar', 'fnz': 'donkey'}.items())) - assert dct['fnz'] == 'donkey' - assert dct['foo'] == 'bar' - assert sorted(['bar' == 'donkey']), sorted(dct.values()) + str.lower, list({"Foo": "bar", "fnz": "donkey"}.items()) + ) + assert dct["fnz"] == "donkey" + assert dct["foo"] == "bar" + assert sorted(["bar" == "donkey"]), sorted(dct.values()) assert dct.copy() == dct - assert dct['foo'] == dct.get('Foo') - assert 'foo' in dct - keys = ['Foo', 'fnz'] + assert dct["foo"] == dct.get("Foo") + assert "foo" in dct + keys = ["Foo", "fnz"] keysList = list(dct) for key in keys: assert key in list(dct.keys()) assert key in keysList assert (key, dct[key]) in list(dct.items()) assert len(keys) == len(dct) - assert dct.pop('foo') == 'bar' - assert 'foo' not in dct - del dct['fnz'] - assert 'fnz' not in dct - dct['Foo'] = 'bar' + assert dct.pop("foo") == "bar" + assert "foo" not in dct + del dct["fnz"] + assert "fnz" not in dct + dct["Foo"] = "bar" dct.refold(lambda _: _) - assert 'foo' not in dct - assert 'Foo' in dct - assert list(dct.items()) == [('Foo', 'bar')] + assert "foo" not in dct + assert "Foo" in dct + assert list(dct.items()) == [("Foo", "bar")] dct.clear() assert {} == dict(dct) def test_no_preserve(self): dct = mappings.NonPreservingFoldingDict( - str.lower, list({'Foo': 'bar', 'fnz': 'monkey'}.items())) - assert sorted(['bar', 'monkey']) == sorted(dct.values()) + str.lower, list({"Foo": "bar", "fnz": "monkey"}.items()) + ) + assert sorted(["bar", "monkey"]) == sorted(dct.values()) assert dct.copy() == dct - keys = ['foo', 'fnz'] + keys = ["foo", "fnz"] keysList = [key for key in dct] for key in keys: assert key in list(dct.keys()) @@ -547,8 +556,8 @@ class TestFoldingDict: assert key in keysList assert (key, dct[key]) in list(dct.items()) assert len(keys) == len(dct) - assert dct.pop('foo') == 'bar' - del dct['fnz'] + assert dct.pop("foo") == "bar" + del dct["fnz"] assert list(dct.keys()) == [] dct.clear() assert {} == dict(dct) @@ -580,20 +589,20 @@ class Test_attr_to_item_mapping: if kls is None: kls = self.kls o = kls(f=2, g=3) - assert ['f', 'g'] == sorted(o) - self.assertBoth(o, 'g', 3) + assert ["f", "g"] == sorted(o) + self.assertBoth(o, "g", 3) o.g = 4 - self.assertBoth(o, 'g', 4) + self.assertBoth(o, "g", 4) del o.g with pytest.raises(KeyError): - operator.__getitem__(o, 'g') + operator.__getitem__(o, "g") with pytest.raises(AttributeError): - getattr(o, 'g') - del o['f'] + getattr(o, "g") + del o["f"] with pytest.raises(KeyError): - operator.__getitem__(o, 'f') + operator.__getitem__(o, "f") with pytest.raises(AttributeError): - getattr(o, 'f') + getattr(o, "f") def test_inject(self): class foon(dict): @@ -611,30 +620,31 @@ class Test_ProxiedAttrs: def __init__(self, **kwargs): for attr, val in kwargs.items(): setattr(self, attr, val) + obj = foo() d = self.kls(obj) with pytest.raises(KeyError): - operator.__getitem__(d, 'x') + operator.__getitem__(d, "x") with pytest.raises(KeyError): - operator.__delitem__(d, 'x') - assert 'x' not in d - d['x'] = 1 - assert d['x'] == 1 - assert 'x' in d - assert ['x'] == list(x for x in d if not x.startswith("__")) - del d['x'] - assert 'x' not in d + operator.__delitem__(d, "x") + assert "x" not in d + d["x"] = 1 + assert d["x"] == 1 + assert "x" in d + assert ["x"] == list(x for x in d if not x.startswith("__")) + del d["x"] + assert "x" not in d with pytest.raises(KeyError): - operator.__delitem__(d, 'x') + operator.__delitem__(d, "x") with pytest.raises(KeyError): - operator.__getitem__(d, 'x') + operator.__getitem__(d, "x") # Finally, verify that immutable attribute errors are handled correctly. d = self.kls(object()) with pytest.raises(KeyError): - operator.__setitem__(d, 'x', 1) + operator.__setitem__(d, "x", 1) with pytest.raises(KeyError): - operator.__delitem__(d, 'x') + operator.__delitem__(d, "x") class TestSlottedDict: @@ -642,9 +652,9 @@ class TestSlottedDict: kls = staticmethod(mappings.make_SlottedDict_kls) def test_exceptions(self): - d = self.kls(['spork'])() + d = self.kls(["spork"])() for op in (operator.getitem, operator.delitem): with pytest.raises(KeyError): - op(d, 'spork') + op(d, "spork") with pytest.raises(KeyError): - op(d, 'foon') + op(d, "foon") diff --git a/tests/test_modules.py b/tests/test_modules.py index f417497..da20f4f 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -5,19 +5,18 @@ from snakeoil import modules class TestModules: - @pytest.fixture(autouse=True) def _setup(self, tmp_path): # set up some test modules for our use - packdir = tmp_path / 'mod_testpack' + packdir = tmp_path / "mod_testpack" packdir.mkdir() # create an empty file - (packdir / '__init__.py').touch() + (packdir / "__init__.py").touch() for directory in (tmp_path, packdir): for i in range(3): - (directory / f'mod_test{i}.py').write_text('def foo(): pass\n') - (directory / 'mod_horked.py').write_text('1/0\n') + (directory / f"mod_test{i}.py").write_text("def foo(): pass\n") + (directory / "mod_horked.py").write_text("1/0\n") # append them to path sys.path.insert(0, str(tmp_path)) @@ -27,89 +26,93 @@ class TestModules: sys.path.pop(0) # make sure we don't keep the sys.modules entries around for i in range(3): - sys.modules.pop('mod_test%s' % i, None) - sys.modules.pop('mod_testpack.mod_test%s' % i, None) - sys.modules.pop('mod_testpack', None) - sys.modules.pop('mod_horked', None) - sys.modules.pop('mod_testpack.mod_horked', None) + sys.modules.pop("mod_test%s" % i, None) + sys.modules.pop("mod_testpack.mod_test%s" % i, None) + sys.modules.pop("mod_testpack", None) + sys.modules.pop("mod_horked", None) + sys.modules.pop("mod_testpack.mod_horked", None) def test_load_module(self): # import an already-imported module - assert modules.load_module('snakeoil.modules') is modules + assert modules.load_module("snakeoil.modules") is modules # and a system one, just for kicks - assert modules.load_module('sys') is sys + assert modules.load_module("sys") is sys # non-existing module from an existing package with pytest.raises(modules.FailedImport): - modules.load_module('snakeoil.__not_there') + modules.load_module("snakeoil.__not_there") # (hopefully :) non-existing top-level module/package with pytest.raises(modules.FailedImport): - modules.load_module('__not_there') + modules.load_module("__not_there") # "Unable to import" # pylint: disable=F0401 # unimported toplevel module - modtest1 = modules.load_module('mod_test1') + modtest1 = modules.load_module("mod_test1") import mod_test1 + assert mod_test1 is modtest1 # unimported in-package module - packtest2 = modules.load_module('mod_testpack.mod_test2') + packtest2 = modules.load_module("mod_testpack.mod_test2") from mod_testpack import mod_test2 + assert mod_test2 is packtest2 def test_load_attribute(self): # already imported - assert modules.load_attribute('sys.path') is sys.path + assert modules.load_attribute("sys.path") is sys.path # unimported - myfoo = modules.load_attribute('mod_testpack.mod_test2.foo') + myfoo = modules.load_attribute("mod_testpack.mod_test2.foo") # "Unable to import" # pylint: disable=F0401 from mod_testpack.mod_test2 import foo + assert foo is myfoo # nonexisting attribute with pytest.raises(modules.FailedImport): - modules.load_attribute('snakeoil.froznicator') + modules.load_attribute("snakeoil.froznicator") # nonexisting top-level with pytest.raises(modules.FailedImport): - modules.load_attribute('spork_does_not_exist.foo') + modules.load_attribute("spork_does_not_exist.foo") # not an attr with pytest.raises(modules.FailedImport): - modules.load_attribute('sys') + modules.load_attribute("sys") # not imported yet with pytest.raises(modules.FailedImport): - modules.load_attribute('mod_testpack.mod_test3') + modules.load_attribute("mod_testpack.mod_test3") def test_load_any(self): # import an already-imported module - assert modules.load_any('snakeoil.modules') is modules + assert modules.load_any("snakeoil.modules") is modules # attribute of an already imported module - assert modules.load_any('sys.path') is sys.path + assert modules.load_any("sys.path") is sys.path # already imported toplevel. - assert sys is modules.load_any('sys') + assert sys is modules.load_any("sys") # unimported - myfoo = modules.load_any('mod_testpack.mod_test2.foo') + myfoo = modules.load_any("mod_testpack.mod_test2.foo") # "Unable to import" # pylint: disable=F0401 from mod_testpack.mod_test2 import foo + assert foo is myfoo # nonexisting attribute with pytest.raises(modules.FailedImport): - modules.load_any('snakeoil.froznicator') + modules.load_any("snakeoil.froznicator") # nonexisting top-level with pytest.raises(modules.FailedImport): - modules.load_any('spork_does_not_exist.foo') + modules.load_any("spork_does_not_exist.foo") with pytest.raises(modules.FailedImport): - modules.load_any('spork_does_not_exist') + modules.load_any("spork_does_not_exist") # not imported yet with pytest.raises(modules.FailedImport): - modules.load_any('mod_testpack.mod_test3') + modules.load_any("mod_testpack.mod_test3") def test_broken_module(self): for func in [modules.load_module, modules.load_any]: with pytest.raises(modules.FailedImport): - func('mod_testpack.mod_horked') - assert 'mod_testpack.mod_horked' not in sys.modules + func("mod_testpack.mod_horked") + assert "mod_testpack.mod_horked" not in sys.modules diff --git a/tests/test_obj.py b/tests/test_obj.py index 83f9f77..78083d8 100644 --- a/tests/test_obj.py +++ b/tests/test_obj.py @@ -7,7 +7,6 @@ make_DIkls = obj.DelayedInstantiation_kls class TestDelayedInstantiation: - def test_simple(self): t = tuple([1, 2, 3]) o = make_DI(tuple, lambda: t) @@ -19,19 +18,34 @@ class TestDelayedInstantiation: assert t >= o def test_descriptor_awareness(self): - def assertKls(cls, ignores=(), - default_ignores=("__new__", "__init__", "__init_subclass__", - "__getattribute__", "__class__", - "__getnewargs__", "__getstate__", - "__doc__", "__class_getitem__")): - required = set(x for x in dir(cls) - if x.startswith("__") and x.endswith("__")) + def assertKls( + cls, + ignores=(), + default_ignores=( + "__new__", + "__init__", + "__init_subclass__", + "__getattribute__", + "__class__", + "__getnewargs__", + "__getstate__", + "__doc__", + "__class_getitem__", + ), + ): + required = set( + x for x in dir(cls) if x.startswith("__") and x.endswith("__") + ) missing = required.difference(obj.kls_descriptors) missing.difference_update(obj.base_kls_descriptors) missing.difference_update(default_ignores) missing.difference_update(ignores) - assert not missing, ("object %r potentially has unsupported special " - "attributes: %s" % (cls, ', '.join(missing))) + assert ( + not missing + ), "object %r potentially has unsupported special " "attributes: %s" % ( + cls, + ", ".join(missing), + ) assertKls(object) assertKls(1) @@ -43,25 +57,38 @@ class TestDelayedInstantiation: def test_BaseDelayedObject(self): # assert that all methods/descriptors of object # are covered via the base. - o = set(dir(object)).difference(f"__{x}__" for x in ( - "class", "getattribute", "new", "init", "init_subclass", "getstate", "doc")) + o = set(dir(object)).difference( + f"__{x}__" + for x in ( + "class", + "getattribute", + "new", + "init", + "init_subclass", + "getstate", + "doc", + ) + ) diff = o.difference(obj.base_kls_descriptors) - assert not diff, ("base delayed instantiation class should cover all of object, but " - "%r was spotted" % (",".join(sorted(diff)),)) + assert not diff, ( + "base delayed instantiation class should cover all of object, but " + "%r was spotted" % (",".join(sorted(diff)),) + ) assert obj.DelayedInstantiation_kls(int, "1") + 2 == 3 - def test_klass_choice_optimization(self): """ensure that BaseDelayedObject is used whenever possible""" # note object is an odd one- it actually has a __doc__, thus # it must always be a custom o = make_DI(object, object) - assert object.__getattribute__(o, '__class__') is not obj.BaseDelayedObject + assert object.__getattribute__(o, "__class__") is not obj.BaseDelayedObject + class foon: pass + o = make_DI(foon, foon) - cls = object.__getattribute__(o, '__class__') + cls = object.__getattribute__(o, "__class__") assert cls is obj.BaseDelayedObject # now ensure we always get the same kls back for derivatives @@ -70,39 +97,43 @@ class TestDelayedInstantiation: return True o = make_DI(foon, foon) - cls = object.__getattribute__(o, '__class__') + cls = object.__getattribute__(o, "__class__") assert cls is not obj.BaseDelayedObject o = make_DI(foon, foon) - cls2 = object.__getattribute__(o, '__class__') + cls2 = object.__getattribute__(o, "__class__") assert cls is cls2 def test__class__(self): l = [] + def f(): l.append(False) return True + o = make_DI(bool, f) assert isinstance(o, bool) assert not l, "accessing __class__ shouldn't trigger instantiation" def test__doc__(self): l = [] + def f(): l.append(True) return foon() + class foon: __doc__ = "monkey" o = make_DI(foon, f) - assert o.__doc__ == 'monkey' + assert o.__doc__ == "monkey" assert not l, ( "in accessing __doc__, the instance was generated- " "this is a class level attribute, thus shouldn't " - "trigger instantiation") + "trigger instantiation" + ) class TestPopattr: - class Object: pass @@ -113,21 +144,21 @@ class TestPopattr: def test_no_attrs(self): # object without any attrs with pytest.raises(AttributeError): - obj.popattr(object(), 'nonexistent') + obj.popattr(object(), "nonexistent") def test_nonexistent_attr(self): # object with attr trying to get nonexistent attr with pytest.raises(AttributeError): - obj.popattr(self.o, 'nonexistent') + obj.popattr(self.o, "nonexistent") def test_fallback(self): # object with attr trying to get nonexistent attr using fallback - value = obj.popattr(self.o, 'nonexistent', 2) + value = obj.popattr(self.o, "nonexistent", 2) assert value == 2 def test_removed_attr(self): - value = obj.popattr(self.o, 'test') + value = obj.popattr(self.o, "test") assert value == 1 # verify that attr was removed from the object with pytest.raises(AttributeError): - obj.popattr(self.o, 'test') + obj.popattr(self.o, "test") diff --git a/tests/test_osutils.py b/tests/test_osutils.py index 1809282..264d670 100644 --- a/tests/test_osutils.py +++ b/tests/test_osutils.py @@ -16,44 +16,45 @@ from snakeoil.osutils.mount import MNT_DETACH, MS_BIND, mount, umount class ReaddirCommon: - @pytest.fixture def subdir(self, tmp_path): - subdir = tmp_path / 'dir' + subdir = tmp_path / "dir" subdir.mkdir() - (tmp_path / 'file').touch() - os.mkfifo((tmp_path / 'fifo')) + (tmp_path / "file").touch() + os.mkfifo((tmp_path / "fifo")) return subdir def _test_missing(self, tmp_path, funcs): for func in funcs: - pytest.raises(OSError, func, tmp_path / 'spork') + pytest.raises(OSError, func, tmp_path / "spork") class TestNativeListDir(ReaddirCommon): - def test_listdir(self, tmp_path, subdir): - assert set(native_readdir.listdir(tmp_path)) == {'dir', 'fifo', 'file'} + assert set(native_readdir.listdir(tmp_path)) == {"dir", "fifo", "file"} assert native_readdir.listdir(subdir) == [] def test_listdir_dirs(self, tmp_path, subdir): - assert native_readdir.listdir_dirs(tmp_path) == ['dir'] + assert native_readdir.listdir_dirs(tmp_path) == ["dir"] assert native_readdir.listdir_dirs(subdir) == [] def test_listdir_files(self, tmp_path, subdir): - assert native_readdir.listdir_files(tmp_path) == ['file'] + assert native_readdir.listdir_files(tmp_path) == ["file"] assert native_readdir.listdir_dirs(subdir) == [] def test_missing(self, tmp_path, subdir): - return self._test_missing(tmp_path, ( - native_readdir.listdir, - native_readdir.listdir_dirs, - native_readdir.listdir_files, - )) + return self._test_missing( + tmp_path, + ( + native_readdir.listdir, + native_readdir.listdir_dirs, + native_readdir.listdir_files, + ), + ) def test_dangling_sym(self, tmp_path, subdir): (tmp_path / "monkeys").symlink_to("foon") - assert native_readdir.listdir_files(tmp_path) == ['file'] + assert native_readdir.listdir_files(tmp_path) == ["file"] class TestNativeReaddir(ReaddirCommon): @@ -78,36 +79,37 @@ class TestNativeReaddir(ReaddirCommon): class TestEnsureDirs: - def check_dir(self, path, uid, gid, mode): assert path.is_dir() st = os.stat(path) - assert stat.S_IMODE(st.st_mode) == mode, \ - '0%o != 0%o' % (stat.S_IMODE(st.st_mode), mode) + assert stat.S_IMODE(st.st_mode) == mode, "0%o != 0%o" % ( + stat.S_IMODE(st.st_mode), + mode, + ) assert st.st_uid == uid assert st.st_gid == gid def test_ensure_dirs(self, tmp_path): # default settings - path = tmp_path / 'foo' / 'bar' + path = tmp_path / "foo" / "bar" assert osutils.ensure_dirs(path) self.check_dir(path, os.geteuid(), os.getegid(), 0o777) def test_minimal_nonmodifying(self, tmp_path): - path = tmp_path / 'foo' / 'bar' + path = tmp_path / "foo" / "bar" assert osutils.ensure_dirs(path, mode=0o755) os.chmod(path, 0o777) assert osutils.ensure_dirs(path, mode=0o755, minimal=True) self.check_dir(path, os.geteuid(), os.getegid(), 0o777) def test_minimal_modifying(self, tmp_path): - path = tmp_path / 'foo' / 'bar' + path = tmp_path / "foo" / "bar" assert osutils.ensure_dirs(path, mode=0o750) assert osutils.ensure_dirs(path, mode=0o005, minimal=True) self.check_dir(path, os.geteuid(), os.getegid(), 0o755) def test_create_unwritable_subdir(self, tmp_path): - path = tmp_path / 'restricted' / 'restricted' + path = tmp_path / "restricted" / "restricted" # create the subdirs without 020 first assert osutils.ensure_dirs(path.parent) assert osutils.ensure_dirs(path, mode=0o020) @@ -118,38 +120,39 @@ class TestEnsureDirs: def test_path_is_a_file(self, tmp_path): # fail if passed a path to an existing file - path = tmp_path / 'file' + path = tmp_path / "file" touch(path) assert path.is_file() assert not osutils.ensure_dirs(path, mode=0o700) def test_non_dir_in_path(self, tmp_path): # fail if one of the parts of the path isn't a dir - path = tmp_path / 'file' / 'dir' - (tmp_path / 'file').touch() + path = tmp_path / "file" / "dir" + (tmp_path / "file").touch() assert not osutils.ensure_dirs(path, mode=0o700) def test_mkdir_failing(self, tmp_path): # fail if os.mkdir fails - with mock.patch('snakeoil.osutils.os.mkdir') as mkdir: - mkdir.side_effect = OSError(30, 'Read-only file system') - path = tmp_path / 'dir' + with mock.patch("snakeoil.osutils.os.mkdir") as mkdir: + mkdir.side_effect = OSError(30, "Read-only file system") + path = tmp_path / "dir" assert not osutils.ensure_dirs(path, mode=0o700) # force temp perms assert not osutils.ensure_dirs(path, mode=0o400) - mkdir.side_effect = OSError(17, 'File exists') + mkdir.side_effect = OSError(17, "File exists") assert not osutils.ensure_dirs(path, mode=0o700) def test_chmod_or_chown_failing(self, tmp_path): # fail if chmod or chown fails - path = tmp_path / 'dir' + path = tmp_path / "dir" path.mkdir() path.chmod(0o750) - with mock.patch('snakeoil.osutils.os.chmod') as chmod, \ - mock.patch('snakeoil.osutils.os.chown') as chown: - chmod.side_effect = OSError(5, 'Input/output error') + with mock.patch("snakeoil.osutils.os.chmod") as chmod, mock.patch( + "snakeoil.osutils.os.chown" + ) as chown: + chmod.side_effect = OSError(5, "Input/output error") # chmod failure when file exists and trying to reset perms to match # the specified mode @@ -163,13 +166,13 @@ class TestEnsureDirs: # chown failure when resetting perms on parents chmod.side_effect = None - chown.side_effect = OSError(5, 'Input/output error') + chown.side_effect = OSError(5, "Input/output error") assert not osutils.ensure_dirs(path, uid=1000, gid=1000, mode=0o400) def test_reset_sticky_parent_perms(self, tmp_path): # make sure perms are reset after traversing over sticky parents - sticky_parent = tmp_path / 'dir' - path = sticky_parent / 'dir' + sticky_parent = tmp_path / "dir" + path = sticky_parent / "dir" sticky_parent.mkdir() sticky_parent.chmod(0o2755) pre_sticky_parent = os.stat(sticky_parent) @@ -178,7 +181,7 @@ class TestEnsureDirs: assert pre_sticky_parent.st_mode == post_sticky_parent.st_mode def test_mode(self, tmp_path): - path = tmp_path / 'mode' / 'mode' + path = tmp_path / "mode" / "mode" assert osutils.ensure_dirs(path, mode=0o700) self.check_dir(path, os.geteuid(), os.getegid(), 0o700) # unrestrict it @@ -188,12 +191,12 @@ class TestEnsureDirs: def test_gid(self, tmp_path): # abuse the portage group as secondary group try: - portage_gid = grp.getgrnam('portage').gr_gid + portage_gid = grp.getgrnam("portage").gr_gid except KeyError: - pytest.skip('the portage group does not exist') + pytest.skip("the portage group does not exist") if portage_gid not in os.getgroups(): - pytest.skip('you are not in the portage group') - path = tmp_path / 'group' / 'group' + pytest.skip("you are not in the portage group") + path = tmp_path / "group" / "group" assert osutils.ensure_dirs(path, gid=portage_gid) self.check_dir(path, os.geteuid(), portage_gid, 0o777) assert osutils.ensure_dirs(path) @@ -203,12 +206,11 @@ class TestEnsureDirs: class TestAbsSymlink: - def test_abssymlink(self, tmp_path): - target = tmp_path / 'target' - linkname = tmp_path / 'link' + target = tmp_path / "target" + linkname = tmp_path / "link" target.mkdir() - linkname.symlink_to('target') + linkname.symlink_to("target") assert osutils.abssymlink(linkname) == str(target) @@ -223,28 +225,30 @@ class Test_Native_NormPath: got = f(src) assert got == val, f"{src!r}: expected {val!r}, got {got!r}" - check('/foo/', '/foo') - check('//foo/', '/foo') - check('//foo/.', '/foo') - check('//..', '/') - check('//..//foo', '/foo') - check('/foo/..', '/') - check('..//foo', '../foo') - check('../foo/../', '..') - check('../', '..') - check('../foo/..', '..') - check('../foo/../dar', '../dar') - check('.//foo', 'foo') - check('/foo/../../', '/') - check('/foo/../../..', '/') - check('/tmp/foo/../dar/', '/tmp/dar') - check('/tmp/foo/../dar', '/tmp/dar') + check("/foo/", "/foo") + check("//foo/", "/foo") + check("//foo/.", "/foo") + check("//..", "/") + check("//..//foo", "/foo") + check("/foo/..", "/") + check("..//foo", "../foo") + check("../foo/../", "..") + check("../", "..") + check("../foo/..", "..") + check("../foo/../dar", "../dar") + check(".//foo", "foo") + check("/foo/../../", "/") + check("/foo/../../..", "/") + check("/tmp/foo/../dar/", "/tmp/dar") + check("/tmp/foo/../dar", "/tmp/dar") # explicit unicode and bytes - check('/tmṕ/föo//../dár', '/tmṕ/dár') - check(b'/tm\xe1\xb9\x95/f\xc3\xb6o//../d\xc3\xa1r', b'/tm\xe1\xb9\x95/d\xc3\xa1r') - check('/föó/..', '/') - check(b'/f\xc3\xb6\xc3\xb3/..', b'/') + check("/tmṕ/föo//../dár", "/tmṕ/dár") + check( + b"/tm\xe1\xb9\x95/f\xc3\xb6o//../d\xc3\xa1r", b"/tm\xe1\xb9\x95/d\xc3\xa1r" + ) + check("/föó/..", "/") + check(b"/f\xc3\xb6\xc3\xb3/..", b"/") @pytest.mark.skipif(os.getuid() != 0, reason="these tests must be ran as root") @@ -253,7 +257,7 @@ class TestAccess: func = staticmethod(osutils.fallback_access) def test_fallback(self, tmp_path): - fp = tmp_path / 'file' + fp = tmp_path / "file" # create the file fp.touch() fp.chmod(0o000) @@ -270,9 +274,9 @@ class Test_unlink_if_exists: def test_it(self, tmp_path): f = self.func - path = tmp_path / 'target' + path = tmp_path / "target" f(path) - path.write_text('') + path.write_text("") f(path) assert not path.exists() # and once more for good measure... @@ -280,18 +284,17 @@ class Test_unlink_if_exists: class TestSupportedSystems: - def test_supported_system(self): - @supported_systems('supported') + @supported_systems("supported") def func(): return True - with mock.patch('snakeoil.osutils.sys') as _sys: - _sys.configure_mock(platform='supported') + with mock.patch("snakeoil.osutils.sys") as _sys: + _sys.configure_mock(platform="supported") assert func() def test_unsupported_system(self): - @supported_systems('unsupported') + @supported_systems("unsupported") def func(): return True @@ -299,39 +302,39 @@ class TestSupportedSystems: func() # make sure we're iterating through the system params correctly - with mock.patch('snakeoil.osutils.sys') as _sys: - _sys.configure_mock(platform='u') + with mock.patch("snakeoil.osutils.sys") as _sys: + _sys.configure_mock(platform="u") with pytest.raises(NotImplementedError): func() def test_multiple_systems(self): - @supported_systems('darwin', 'linux') + @supported_systems("darwin", "linux") def func(): return True - with mock.patch('snakeoil.osutils.sys') as _sys: - _sys.configure_mock(platform='nonexistent') + with mock.patch("snakeoil.osutils.sys") as _sys: + _sys.configure_mock(platform="nonexistent") with pytest.raises(NotImplementedError): func() - for platform in ('linux2', 'darwin'): + for platform in ("linux2", "darwin"): _sys.configure_mock(platform=platform) assert func() -@pytest.mark.skipif(not sys.platform.startswith('linux'), - reason='supported on Linux only') +@pytest.mark.skipif( + not sys.platform.startswith("linux"), reason="supported on Linux only" +) class TestMount: - @pytest.fixture def source(self, tmp_path): - source = tmp_path / 'source' + source = tmp_path / "source" source.mkdir() return source @pytest.fixture def target(self, tmp_path): - target = tmp_path / 'target' + target = tmp_path / "target" target.mkdir() return target @@ -340,21 +343,25 @@ class TestMount: # byte strings; if they are unicode strings the arguments get mangled # leading to errors when the syscall is run. This confirms mount() from # snakeoil.osutils always converts the arguments into byte strings. - for source, target, fstype in ((b'source', b'target', b'fstype'), - ('source', 'target', 'fstype')): - with mock.patch('snakeoil.osutils.mount.ctypes') as mock_ctypes: + for source, target, fstype in ( + (b"source", b"target", b"fstype"), + ("source", "target", "fstype"), + ): + with mock.patch("snakeoil.osutils.mount.ctypes") as mock_ctypes: with pytest.raises(OSError): mount(str(source), str(target), fstype, MS_BIND) - mount_call = next(x for x in mock_ctypes.mock_calls if x[0] == 'CDLL().mount') + mount_call = next( + x for x in mock_ctypes.mock_calls if x[0] == "CDLL().mount" + ) for arg in mount_call[1][0:3]: assert isinstance(arg, bytes) def test_missing_dirs(self): with pytest.raises(OSError) as cm: - mount('source', 'target', None, MS_BIND) + mount("source", "target", None, MS_BIND) assert cm.value.errno in (errno.EPERM, errno.ENOENT) - @pytest.mark.skipif(os.getuid() == 0, reason='this test must be run as non-root') + @pytest.mark.skipif(os.getuid() == 0, reason="this test must be run as non-root") def test_no_perms(self, source, target): with pytest.raises(OSError) as cm: mount(str(source), str(target), None, MS_BIND) @@ -363,11 +370,15 @@ class TestMount: umount(str(target)) assert cm.value.errno in (errno.EPERM, errno.EINVAL) - @pytest.mark.skipif(not (os.path.exists('/proc/self/ns/mnt') and os.path.exists('/proc/self/ns/user')), - reason='user and mount namespace support required') + @pytest.mark.skipif( + not ( + os.path.exists("/proc/self/ns/mnt") and os.path.exists("/proc/self/ns/user") + ), + reason="user and mount namespace support required", + ) def test_bind_mount(self, source, target): - src_file = source / 'file' - bind_file = target / 'file' + src_file = source / "file" + bind_file = target / "file" src_file.touch() try: @@ -378,15 +389,19 @@ class TestMount: umount(str(target)) assert not bind_file.exists() except PermissionError: - pytest.skip('No permission to use user and mount namespace') - - @pytest.mark.skipif(not (os.path.exists('/proc/self/ns/mnt') and os.path.exists('/proc/self/ns/user')), - reason='user and mount namespace support required') + pytest.skip("No permission to use user and mount namespace") + + @pytest.mark.skipif( + not ( + os.path.exists("/proc/self/ns/mnt") and os.path.exists("/proc/self/ns/user") + ), + reason="user and mount namespace support required", + ) def test_lazy_unmount(self, source, target): - src_file = source / 'file' - bind_file = target / 'file' + src_file = source / "file" + bind_file = target / "file" src_file.touch() - src_file.write_text('foo') + src_file.write_text("foo") try: with Namespace(user=True, mount=True): @@ -403,14 +418,14 @@ class TestMount: # confirm the file doesn't exist in the bind mount anymore assert not bind_file.exists() # but the file is still accessible to the process - assert f.read() == 'foo' + assert f.read() == "foo" # trying to reopen causes IOError with pytest.raises(IOError) as cm: f = bind_file.open() assert cm.value.errno == errno.ENOENT except PermissionError: - pytest.skip('No permission to use user and mount namespace') + pytest.skip("No permission to use user and mount namespace") class TestSizeofFmt: diff --git a/tests/test_process.py b/tests/test_process.py index bb45712..488b7b0 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -30,8 +30,10 @@ class TestFindBinary: process.find_binary(self.script) def test_fallback(self): - fallback = process.find_binary(self.script, fallback=os.path.join('bin', self.script)) - assert fallback == os.path.join('bin', self.script) + fallback = process.find_binary( + self.script, fallback=os.path.join("bin", self.script) + ) + assert fallback == os.path.join("bin", self.script) def test_not_executable(self, tmp_path): fp = tmp_path / self.script diff --git a/tests/test_process_spawn.py b/tests/test_process_spawn.py index 8981c6e..556b34c 100644 --- a/tests/test_process_spawn.py +++ b/tests/test_process_spawn.py @@ -6,11 +6,11 @@ from snakeoil import process from snakeoil.contexts import chdir from snakeoil.process import spawn -BASH_BINARY = process.find_binary("bash", fallback='') +BASH_BINARY = process.find_binary("bash", fallback="") -@pytest.mark.skipif(not BASH_BINARY, reason='missing bash binary') -class TestSpawn: +@pytest.mark.skipif(not BASH_BINARY, reason="missing bash binary") +class TestSpawn: @pytest.fixture(autouse=True) def _setup(self, tmp_path): orig_path = os.environ["PATH"] @@ -37,21 +37,25 @@ class TestSpawn: def test_get_output(self, tmp_path, dev_null): filename = "spawn-getoutput.sh" for r, s, text, args in ( - [0, ["dar\n"], "echo dar\n", {}], - [0, ["dar"], "echo -n dar", {}], - [1, ["blah\n", "dar\n"], "echo blah\necho dar\nexit 1", {}], - [0, [], "echo dar 1>&2", {"fd_pipes": {1: 1, 2: dev_null}}]): + [0, ["dar\n"], "echo dar\n", {}], + [0, ["dar"], "echo -n dar", {}], + [1, ["blah\n", "dar\n"], "echo blah\necho dar\nexit 1", {}], + [0, [], "echo dar 1>&2", {"fd_pipes": {1: 1, 2: dev_null}}], + ): fp = self.generate_script(tmp_path, filename, text) - assert (r, s) == spawn.spawn_get_output(str(fp), spawn_type=spawn.spawn_bash, **args) + assert (r, s) == spawn.spawn_get_output( + str(fp), spawn_type=spawn.spawn_bash, **args + ) os.unlink(fp) @pytest.mark.skipif(not spawn.is_sandbox_capable(), reason="missing sandbox binary") def test_sandbox(self, tmp_path): - fp = self.generate_script( - tmp_path, "spawn-sandbox.sh", "echo $LD_PRELOAD") + fp = self.generate_script(tmp_path, "spawn-sandbox.sh", "echo $LD_PRELOAD") ret = spawn.spawn_get_output(str(fp), spawn_type=spawn.spawn_sandbox) assert ret[1], "no output; exit code was %s; script location %s" % (ret[0], fp) - assert "libsandbox.so" in [os.path.basename(x.strip()) for x in ret[1][0].split()] + assert "libsandbox.so" in [ + os.path.basename(x.strip()) for x in ret[1][0].split() + ] os.unlink(fp) @pytest.mark.skipif(not spawn.is_sandbox_capable(), reason="missing sandbox binary") @@ -60,15 +64,17 @@ class TestSpawn: this verifies our fix works. """ - fp = self.generate_script( - tmp_path, "spawn-sandbox.sh", "echo $LD_PRELOAD") + fp = self.generate_script(tmp_path, "spawn-sandbox.sh", "echo $LD_PRELOAD") dpath = tmp_path / "dar" dpath.mkdir() with chdir(dpath): dpath.rmdir() - assert "libsandbox.so" in \ - [os.path.basename(x.strip()) for x in spawn.spawn_get_output( - str(fp), spawn_type=spawn.spawn_sandbox, cwd='/')[1][0].split()] + assert "libsandbox.so" in [ + os.path.basename(x.strip()) + for x in spawn.spawn_get_output( + str(fp), spawn_type=spawn.spawn_sandbox, cwd="/" + )[1][0].split() + ] fp.unlink() def test_process_exit_code(self): @@ -98,13 +104,12 @@ class TestSpawn: def test_spawn_bash(self, capfd): # bash builtin for true without exec'ing true (eg, no path lookup) - assert 0 == spawn.spawn_bash('echo bash') + assert 0 == spawn.spawn_bash("echo bash") out, _err = capfd.readouterr() - assert out.strip() == 'bash' + assert out.strip() == "bash" def test_umask(self, tmp_path): - fp = self.generate_script( - tmp_path, "spawn-umask.sh", f"#!{BASH_BINARY}\numask") + fp = self.generate_script(tmp_path, "spawn-umask.sh", f"#!{BASH_BINARY}\numask") try: old_umask = os.umask(0) if old_umask == 0: @@ -113,7 +118,8 @@ class TestSpawn: os.umask(desired) else: desired = 0 - assert str(desired).lstrip("0") == \ - spawn.spawn_get_output(str(fp))[1][0].strip().lstrip("0") + assert str(desired).lstrip("0") == spawn.spawn_get_output(str(fp))[1][ + 0 + ].strip().lstrip("0") finally: os.umask(old_umask) diff --git a/tests/test_sequences.py b/tests/test_sequences.py index edbaa5a..0d8c5a6 100644 --- a/tests/test_sequences.py +++ b/tests/test_sequences.py @@ -8,13 +8,11 @@ from snakeoil.sequences import split_elements, split_negations class UnhashableComplex(complex): - def __hash__(self): raise TypeError class TestStableUnique: - def common_check(self, func): # silly assert func(()) == [] @@ -23,9 +21,10 @@ class TestStableUnique: # neither def test_stable_unique(self, func=sequences.stable_unique): - assert list(set([1, 2, 3])) == [1, 2, 3], \ - "this test is reliant on the interpreter hasing 1,2,3 into a specific ordering- " \ + assert list(set([1, 2, 3])) == [1, 2, 3], ( + "this test is reliant on the interpreter hasing 1,2,3 into a specific ordering- " "for whatever reason, ordering differs, thus this test can't verify it" + ) assert func([3, 2, 1]) == [3, 2, 1] def test_iter_stable_unique(self): @@ -43,20 +42,19 @@ class TestStableUnique: uc = UnhashableComplex res = sequences.unstable_unique([uc(1, 0), uc(0, 1), uc(1, 0)]) # sortable - assert sorted(sequences.unstable_unique( - [[1, 2], [1, 3], [1, 2], [1, 3]])) == [[1, 2], [1, 3]] + assert sorted(sequences.unstable_unique([[1, 2], [1, 3], [1, 2], [1, 3]])) == [ + [1, 2], + [1, 3], + ] assert res == [uc(1, 0), uc(0, 1)] or res == [uc(0, 1), uc(1, 0)] assert sorted(sequences.unstable_unique(self._generator())) == sorted(range(6)) class TestChainedLists: - @staticmethod def gen_cl(): return sequences.ChainedLists( - list(range(3)), - list(range(3, 6)), - list(range(6, 100)) + list(range(3)), list(range(3, 6)), list(range(6, 100)) ) def test_contains(self): @@ -72,7 +70,7 @@ class TestChainedLists: def test_str(self): l = sequences.ChainedLists(list(range(3)), list(range(3, 5))) - assert str(l) == '[ [0, 1, 2], [3, 4] ]' + assert str(l) == "[ [0, 1, 2], [3, 4] ]" def test_getitem(self): cl = self.gen_cl() @@ -108,15 +106,18 @@ class Test_iflatten_instance: def test_it(self): o = OrderedDict((k, None) for k in range(10)) for l, correct, skip in ( - (["asdf", ["asdf", "asdf"], 1, None], - ["asdf", "asdf", "asdf", 1, None], str), - ([o, 1, "fds"], [o, 1, "fds"], (str, OrderedDict)), - ([o, 1, "fds"], list(range(10)) + [1, "fds"], str), - ("fds", ["fds"], str), - ("fds", ["f", "d", "s"], int), - ('', [''], str), - (1, [1], int), - ): + ( + ["asdf", ["asdf", "asdf"], 1, None], + ["asdf", "asdf", "asdf", 1, None], + str, + ), + ([o, 1, "fds"], [o, 1, "fds"], (str, OrderedDict)), + ([o, 1, "fds"], list(range(10)) + [1, "fds"], str), + ("fds", ["fds"], str), + ("fds", ["f", "d", "s"], int), + ("", [""], str), + (1, [1], int), + ): iterator = self.func(l, skip) assert list(iterator) == correct assert list(iterator) == [] @@ -126,6 +127,7 @@ class Test_iflatten_instance: # have to iterate. def fail(): return list(self.func(None)) + with pytest.raises(TypeError): fail() @@ -148,13 +150,16 @@ class Test_iflatten_func: def test_it(self): o = OrderedDict((k, None) for k in range(10)) for l, correct, skip in ( - (["asdf", ["asdf", "asdf"], 1, None], - ["asdf", "asdf", "asdf", 1, None], str), - ([o, 1, "fds"], [o, 1, "fds"], (str, OrderedDict)), - ([o, 1, "fds"], list(range(10)) + [1, "fds"], str), - ("fds", ["fds"], str), - (1, [1], int), - ): + ( + ["asdf", ["asdf", "asdf"], 1, None], + ["asdf", "asdf", "asdf", 1, None], + str, + ), + ([o, 1, "fds"], [o, 1, "fds"], (str, OrderedDict)), + ([o, 1, "fds"], list(range(10)) + [1, "fds"], str), + ("fds", ["fds"], str), + (1, [1], int), + ): iterator = self.func(l, lambda x: isinstance(x, skip)) assert list(iterator) == correct assert list(iterator) == [] @@ -164,6 +169,7 @@ class Test_iflatten_func: # have to iterate. def fail(): return list(self.func(None, lambda x: False)) + with pytest.raises(TypeError): fail() @@ -189,25 +195,24 @@ class Test_predicate_split: assert true_l == list(range(0, 100, 2)) def test_key(self): - false_l, true_l = self.kls(lambda x: x % 2 == 0, - ([0, x] for x in range(100)), - key=itemgetter(1)) + false_l, true_l = self.kls( + lambda x: x % 2 == 0, ([0, x] for x in range(100)), key=itemgetter(1) + ) assert false_l == [[0, x] for x in range(1, 100, 2)] assert true_l == [[0, x] for x in range(0, 100, 2)] class TestSplitNegations: - def test_empty(self): # empty input - seq = '' + seq = "" assert split_negations(seq) == ((), ()) def test_bad_value(self): # no-value negation should raise a ValueError bad_values = ( - '-', - 'a b c - d f e', + "-", + "a b c - d f e", ) for s in bad_values: @@ -216,7 +221,7 @@ class TestSplitNegations: def test_negs(self): # all negs - seq = ('-' + str(x) for x in range(100)) + seq = ("-" + str(x) for x in range(100)) assert split_negations(seq) == (tuple(map(str, range(100))), ()) def test_pos(self): @@ -226,31 +231,33 @@ class TestSplitNegations: def test_neg_pos(self): # both - seq = (('-' + str(x), str(x)) for x in range(100)) + seq = (("-" + str(x), str(x)) for x in range(100)) seq = chain.from_iterable(seq) - assert split_negations(seq) == (tuple(map(str, range(100))), tuple(map(str, range(100)))) + assert split_negations(seq) == ( + tuple(map(str, range(100))), + tuple(map(str, range(100))), + ) def test_converter(self): # converter method - seq = (('-' + str(x), str(x)) for x in range(100)) + seq = (("-" + str(x), str(x)) for x in range(100)) seq = chain.from_iterable(seq) assert split_negations(seq, int) == (tuple(range(100)), tuple(range(100))) class TestSplitElements: - def test_empty(self): # empty input - seq = '' + seq = "" assert split_elements(seq) == ((), (), ()) def test_bad_value(self): # no-value neg/pos should raise ValueErrors bad_values = ( - '-', - '+', - 'a b c - d f e', - 'a b c + d f e', + "-", + "+", + "a b c - d f e", + "a b c + d f e", ) for s in bad_values: @@ -259,7 +266,7 @@ class TestSplitElements: def test_negs(self): # all negs - seq = ('-' + str(x) for x in range(100)) + seq = ("-" + str(x) for x in range(100)) assert split_elements(seq) == (tuple(map(str, range(100))), (), ()) def test_neutral(self): @@ -269,12 +276,12 @@ class TestSplitElements: def test_pos(self): # all pos - seq = ('+' + str(x) for x in range(100)) + seq = ("+" + str(x) for x in range(100)) assert split_elements(seq) == ((), (), tuple(map(str, range(100)))) def test_neg_pos(self): # both negative and positive values - seq = (('-' + str(x), '+' + str(x)) for x in range(100)) + seq = (("-" + str(x), "+" + str(x)) for x in range(100)) seq = chain.from_iterable(seq) assert split_elements(seq) == ( tuple(map(str, range(100))), @@ -284,7 +291,7 @@ class TestSplitElements: def test_neg_neu_pos(self): # all three value types - seq = (('-' + str(x), str(x), '+' + str(x)) for x in range(100)) + seq = (("-" + str(x), str(x), "+" + str(x)) for x in range(100)) seq = chain.from_iterable(seq) assert split_elements(seq) == ( tuple(map(str, range(100))), @@ -294,7 +301,10 @@ class TestSplitElements: def test_converter(self): # converter method - seq = (('-' + str(x), str(x), '+' + str(x)) for x in range(100)) + seq = (("-" + str(x), str(x), "+" + str(x)) for x in range(100)) seq = chain.from_iterable(seq) assert split_elements(seq, int) == ( - tuple(range(100)), tuple(range(100)), tuple(range(100))) + tuple(range(100)), + tuple(range(100)), + tuple(range(100)), + ) diff --git a/tests/test_stringio.py b/tests/test_stringio.py index 4fb7c78..1e6d1e5 100644 --- a/tests/test_stringio.py +++ b/tests/test_stringio.py @@ -34,6 +34,7 @@ class readonly_mixin: class Test_text_readonly(readonly_mixin): kls = stringio.text_readonly -class Test_bytes_readonly(readonly_mixin ): + +class Test_bytes_readonly(readonly_mixin): kls = stringio.bytes_readonly - encoding = 'utf8' + encoding = "utf8" diff --git a/tests/test_strings.py b/tests/test_strings.py index b55c230..1930570 100644 --- a/tests/test_strings.py +++ b/tests/test_strings.py @@ -3,38 +3,36 @@ from snakeoil.strings import doc_dedent, pluralism class TestPluralism: - def test_none(self): # default - assert pluralism([]) == 's' + assert pluralism([]) == "s" # different suffix for nonexistence - assert pluralism([], none='') == '' + assert pluralism([], none="") == "" def test_singular(self): # default - assert pluralism([1]) == '' + assert pluralism([1]) == "" # different suffix for singular existence - assert pluralism([1], singular='o') == 'o' + assert pluralism([1], singular="o") == "o" def test_plural(self): # default - assert pluralism([1, 2]) == 's' + assert pluralism([1, 2]) == "s" # different suffix for plural existence - assert pluralism([1, 2], plural='ies') == 'ies' + assert pluralism([1, 2], plural="ies") == "ies" def test_int(self): - assert pluralism(0) == 's' - assert pluralism(1) == '' - assert pluralism(2) == 's' + assert pluralism(0) == "s" + assert pluralism(1) == "" + assert pluralism(2) == "s" class TestDocDedent: - def test_empty(self): - s = '' + s = "" assert s == doc_dedent(s) def test_non_string(self): @@ -42,20 +40,20 @@ class TestDocDedent: doc_dedent(None) def test_line(self): - s = 'line' + s = "line" assert s == doc_dedent(s) def test_indented_line(self): - for indent in ('\t', ' '): - s = f'{indent}line' - assert 'line' == doc_dedent(s) + for indent in ("\t", " "): + s = f"{indent}line" + assert "line" == doc_dedent(s) def test_docstring(self): s = """Docstring to test. foo bar """ - assert 'Docstring to test.\n\nfoo bar\n' == doc_dedent(s) + assert "Docstring to test.\n\nfoo bar\n" == doc_dedent(s) def test_all_indented(self): s = """\ @@ -63,4 +61,4 @@ class TestDocDedent: foo bar """ - assert 'Docstring to test.\n\nfoo bar\n' == doc_dedent(s) + assert "Docstring to test.\n\nfoo bar\n" == doc_dedent(s) diff --git a/tests/test_version.py b/tests/test_version.py index 7dad73e..0992754 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -7,7 +7,6 @@ from snakeoil import __version__, version class TestVersion: - def setup_method(self, method): # reset the cached version in the module reload(version) @@ -17,124 +16,142 @@ class TestVersion: def test_get_version_unknown(self): with pytest.raises(ValueError): - version.get_version('snakeoilfoo', __file__) + version.get_version("snakeoilfoo", __file__) def test_get_version_api(self): - v = version.get_version('snakeoil', __file__, '9.9.9') - assert v.startswith('snakeoil 9.9.9') + v = version.get_version("snakeoil", __file__, "9.9.9") + assert v.startswith("snakeoil 9.9.9") def test_get_version_git_dev(self): - with mock.patch('snakeoil.version.import_module') as import_module, \ - mock.patch('snakeoil.version.get_git_version') as get_git_version: + with mock.patch("snakeoil.version.import_module") as import_module, mock.patch( + "snakeoil.version.get_git_version" + ) as get_git_version: import_module.side_effect = ImportError verinfo = { - 'rev': '1ff76b021d208f7df38ac524537b6419404f1c64', - 'date': 'Mon Sep 25 13:50:24 2017 -0400', - 'tag': None + "rev": "1ff76b021d208f7df38ac524537b6419404f1c64", + "date": "Mon Sep 25 13:50:24 2017 -0400", + "tag": None, } get_git_version.return_value = verinfo - result = version.get_version('snakeoil', __file__, __version__) - assert result == f"snakeoil {__version__}-g{verinfo['rev'][:7]} -- {verinfo['date']}" + result = version.get_version("snakeoil", __file__, __version__) + assert ( + result + == f"snakeoil {__version__}-g{verinfo['rev'][:7]} -- {verinfo['date']}" + ) def test_get_version_git_release(self): verinfo = { - 'rev': 'ab38751890efa8be96b7f95938d6b868b769bab6', - 'date': 'Thu Sep 21 15:57:38 2017 -0400', - 'tag': '2.3.4', + "rev": "ab38751890efa8be96b7f95938d6b868b769bab6", + "date": "Thu Sep 21 15:57:38 2017 -0400", + "tag": "2.3.4", } # fake snakeoil._verinfo module object class Verinfo: version_info = verinfo - with mock.patch('snakeoil.version.import_module') as import_module: + with mock.patch("snakeoil.version.import_module") as import_module: import_module.return_value = Verinfo() - result = version.get_version('snakeoil', __file__, verinfo['tag']) + result = version.get_version("snakeoil", __file__, verinfo["tag"]) assert result == f"snakeoil {verinfo['tag']} -- released {verinfo['date']}" def test_get_version_no_git_version(self): - with mock.patch('snakeoil.version.import_module') as import_module, \ - mock.patch('snakeoil.version.get_git_version') as get_git_version: + with mock.patch("snakeoil.version.import_module") as import_module, mock.patch( + "snakeoil.version.get_git_version" + ) as get_git_version: import_module.side_effect = ImportError get_git_version.return_value = None - result = version.get_version('snakeoil', 'nonexistent', __version__) - assert result == f'snakeoil {__version__}' + result = version.get_version("snakeoil", "nonexistent", __version__) + assert result == f"snakeoil {__version__}" def test_get_version_caching(self): # retrieved version info is cached in a module attr - v = version.get_version('snakeoil', __file__) - assert v.startswith(f'snakeoil {__version__}') + v = version.get_version("snakeoil", __file__) + assert v.startswith(f"snakeoil {__version__}") # re-running get_version returns the cached attr instead of reprocessing - with mock.patch('snakeoil.version.import_module') as import_module: - v = version.get_version('snakeoil', __file__) + with mock.patch("snakeoil.version.import_module") as import_module: + v = version.get_version("snakeoil", __file__) assert not import_module.called class TestGitVersion: - def test_get_git_version_not_available(self): - with mock.patch('snakeoil.version._run_git') as run_git: - run_git.side_effect = EnvironmentError(errno.ENOENT, 'git not found') - assert version.get_git_version('nonexistent') is None + with mock.patch("snakeoil.version._run_git") as run_git: + run_git.side_effect = EnvironmentError(errno.ENOENT, "git not found") + assert version.get_git_version("nonexistent") is None def test_get_git_version_error(self): - with mock.patch('snakeoil.version._run_git') as run_git: - run_git.return_value = (b'foo', 1) - assert version.get_git_version('nonexistent') is None + with mock.patch("snakeoil.version._run_git") as run_git: + run_git.return_value = (b"foo", 1) + assert version.get_git_version("nonexistent") is None def test_get_git_version_non_repo(self, tmpdir): assert version.get_git_version(str(tmpdir)) is None def test_get_git_version_exc(self): with pytest.raises(OSError): - with mock.patch('snakeoil.version._run_git') as run_git: - run_git.side_effect = OSError(errno.EIO, 'Input/output error') - version.get_git_version('nonexistent') + with mock.patch("snakeoil.version._run_git") as run_git: + run_git.side_effect = OSError(errno.EIO, "Input/output error") + version.get_git_version("nonexistent") def test_get_git_version_good_dev(self): - with mock.patch('snakeoil.version._run_git') as run_git: + with mock.patch("snakeoil.version._run_git") as run_git: # dev version run_git.return_value = ( - b'1ff76b021d208f7df38ac524537b6419404f1c64\nMon Sep 25 13:50:24 2017 -0400', 0) - result = version.get_git_version('nonexistent') + b"1ff76b021d208f7df38ac524537b6419404f1c64\nMon Sep 25 13:50:24 2017 -0400", + 0, + ) + result = version.get_git_version("nonexistent") expected = { - 'rev': '1ff76b021d208f7df38ac524537b6419404f1c64', - 'date': 'Mon Sep 25 13:50:24 2017 -0400', - 'tag': None, - 'commits': 2, + "rev": "1ff76b021d208f7df38ac524537b6419404f1c64", + "date": "Mon Sep 25 13:50:24 2017 -0400", + "tag": None, + "commits": 2, } assert result == expected def test_get_git_version_good_tag(self): - with mock.patch('snakeoil.version._run_git') as run_git, \ - mock.patch('snakeoil.version._get_git_tag') as get_git_tag: + with mock.patch("snakeoil.version._run_git") as run_git, mock.patch( + "snakeoil.version._get_git_tag" + ) as get_git_tag: # tagged, release version run_git.return_value = ( - b'1ff76b021d208f7df38ac524537b6419404f1c64\nMon Sep 25 13:50:24 2017 -0400', 0) - get_git_tag.return_value = '1.1.1' - result = version.get_git_version('nonexistent') + b"1ff76b021d208f7df38ac524537b6419404f1c64\nMon Sep 25 13:50:24 2017 -0400", + 0, + ) + get_git_tag.return_value = "1.1.1" + result = version.get_git_version("nonexistent") expected = { - 'rev': '1ff76b021d208f7df38ac524537b6419404f1c64', - 'date': 'Mon Sep 25 13:50:24 2017 -0400', - 'tag': '1.1.1', - 'commits': 2, + "rev": "1ff76b021d208f7df38ac524537b6419404f1c64", + "date": "Mon Sep 25 13:50:24 2017 -0400", + "tag": "1.1.1", + "commits": 2, } assert result == expected def test_get_git_tag_bad_output(self): - with mock.patch('snakeoil.version._run_git') as run_git: + with mock.patch("snakeoil.version._run_git") as run_git: # unknown git tag rev output - run_git.return_value = (b'a', 1) - assert version._get_git_tag('foo', 'bar') is None - run_git.return_value = (b'a foo/v0.7.2', 0) - assert version._get_git_tag('foo', 'bar') is None + run_git.return_value = (b"a", 1) + assert version._get_git_tag("foo", "bar") is None + run_git.return_value = (b"a foo/v0.7.2", 0) + assert version._get_git_tag("foo", "bar") is None # expected output formats - run_git.return_value = (b'ab38751890efa8be96b7f95938d6b868b769bab6 tags/v1.1.1^0', 0) - assert version._get_git_tag('foo', 'bar') == '1.1.1' - run_git.return_value = (b'ab38751890efa8be96b7f95938d6b868b769bab6 tags/v1.1.1', 0) - assert version._get_git_tag('foo', 'bar') == '1.1.1' - run_git.return_value = (b'ab38751890efa8be96b7f95938d6b868b769bab6 tags/1.1.1', 0) - assert version._get_git_tag('foo', 'bar') == '1.1.1' + run_git.return_value = ( + b"ab38751890efa8be96b7f95938d6b868b769bab6 tags/v1.1.1^0", + 0, + ) + assert version._get_git_tag("foo", "bar") == "1.1.1" + run_git.return_value = ( + b"ab38751890efa8be96b7f95938d6b868b769bab6 tags/v1.1.1", + 0, + ) + assert version._get_git_tag("foo", "bar") == "1.1.1" + run_git.return_value = ( + b"ab38751890efa8be96b7f95938d6b868b769bab6 tags/1.1.1", + 0, + ) + assert version._get_git_tag("foo", "bar") == "1.1.1" -- cgit v1.2.3-65-gdbad