diff options
author | Brian Harring <ferringb@gmail.com> | 2022-12-24 13:14:53 -0800 |
---|---|---|
committer | Arthur Zamarin <arthurzam@gentoo.org> | 2022-12-25 19:49:11 +0200 |
commit | d6a7c2e44b4f497357f8569d423104232a58f384 (patch) | |
tree | 625ac52169356714a9f5e69e11f2b6cc2d72355a /src/snakeoil | |
parent | compression: prefer gtar over tar if available (diff) | |
download | snakeoil-d6a7c2e44b4f497357f8569d423104232a58f384.tar.gz snakeoil-d6a7c2e44b4f497357f8569d423104232a58f384.tar.bz2 snakeoil-d6a7c2e44b4f497357f8569d423104232a58f384.zip |
Reformat w/ black 22.12.0 for consistency.
Signed-off-by: Brian Harring <ferringb@gmail.com>
Signed-off-by: Arthur Zamarin <arthurzam@gentoo.org>
Diffstat (limited to 'src/snakeoil')
51 files changed, 1614 insertions, 1006 deletions
diff --git a/src/snakeoil/__init__.py b/src/snakeoil/__init__.py index 636d663..a1c1f20 100644 --- a/src/snakeoil/__init__.py +++ b/src/snakeoil/__init__.py @@ -10,5 +10,5 @@ This library is a bit of a grabbag of the following: * optimized implementations of common patterns """ -__title__ = 'snakeoil' -__version__ = '0.10.4' +__title__ = "snakeoil" +__version__ = "0.10.4" diff --git a/src/snakeoil/_fileutils.py b/src/snakeoil/_fileutils.py index 51f82c7..4a22648 100644 --- a/src/snakeoil/_fileutils.py +++ b/src/snakeoil/_fileutils.py @@ -5,7 +5,9 @@ Access this functionality from :py:module:`snakeoil.osutils` instead """ __all__ = ( - "mmap_and_close", "readlines_iter", "native_readlines", + "mmap_and_close", + "readlines_iter", + "native_readlines", "native_readfile", ) @@ -31,6 +33,7 @@ def mmap_and_close(fd, *args, **kwargs): class readlines_iter: __slots__ = ("iterable", "mtime", "source") + def __init__(self, iterable, mtime, close=True, source=None): if source is None: source = iterable @@ -54,17 +57,25 @@ class readlines_iter: source.close() def close(self): - if hasattr(self.source, 'close'): + if hasattr(self.source, "close"): self.source.close() def __iter__(self): return self.iterable + def _native_readlines_shim(*args, **kwds): - return native_readlines('r', *args, **kwds) + return native_readlines("r", *args, **kwds) + -def native_readlines(mode, mypath, strip_whitespace=True, swallow_missing=False, - none_on_missing=False, encoding=None): +def native_readlines( + mode, + mypath, + strip_whitespace=True, + swallow_missing=False, + none_on_missing=False, + encoding=None, +): """Read a file, yielding each line. :param mypath: fs path for the file to read @@ -102,8 +113,10 @@ def _py2k_ascii_strict_filter(source): raise ValueError("character ordinal over 127") yield line + def _native_readfile_shim(*args, **kwds): - return native_readfile('r', *args, **kwds) + return native_readfile("r", *args, **kwds) + def native_readfile(mode, mypath, none_on_missing=False, encoding=None): """Read a file, returning the contents. diff --git a/src/snakeoil/bash.py b/src/snakeoil/bash.py index 5eaaca2..3ca737c 100644 --- a/src/snakeoil/bash.py +++ b/src/snakeoil/bash.py @@ -15,19 +15,25 @@ from .fileutils import readlines from .log import logger from .mappings import ProtectedDict -demand_compile_regexp('line_cont_regexp', r'^(.*[^\\]|)\\$') -demand_compile_regexp('inline_comment_regexp', r'^.*\s#.*$') -demand_compile_regexp('var_find', r'\\?(\${\w+}|\$\w+)') -demand_compile_regexp('backslash_find', r'\\.') -demand_compile_regexp('ansi_escape_re', r'(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]') +demand_compile_regexp("line_cont_regexp", r"^(.*[^\\]|)\\$") +demand_compile_regexp("inline_comment_regexp", r"^.*\s#.*$") +demand_compile_regexp("var_find", r"\\?(\${\w+}|\$\w+)") +demand_compile_regexp("backslash_find", r"\\.") +demand_compile_regexp("ansi_escape_re", r"(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]") __all__ = ( - "iter_read_bash", "read_bash", "read_dict", "read_bash_dict", - "bash_parser", "BashParseError") - - -def iter_read_bash(bash_source, allow_inline_comments=True, - allow_line_cont=False, enum_line=False): + "iter_read_bash", + "read_bash", + "read_dict", + "read_bash_dict", + "bash_parser", + "BashParseError", +) + + +def iter_read_bash( + bash_source, allow_inline_comments=True, allow_line_cont=False, enum_line=False +): """Iterate over a file honoring bash commenting rules and line continuations. Note that it's considered good behaviour to close filehandles, as @@ -45,7 +51,7 @@ def iter_read_bash(bash_source, allow_inline_comments=True, """ if isinstance(bash_source, str): bash_source = readlines(bash_source, True) - s = '' + s = "" for lineno, line in enumerate(bash_source, 1): if allow_line_cont and s: s += line @@ -53,19 +59,20 @@ def iter_read_bash(bash_source, allow_inline_comments=True, s = line.lstrip() if s: - if s[0] != '#': + if s[0] != "#": if allow_inline_comments: - if (not allow_line_cont or - (allow_line_cont and inline_comment_regexp.match(line))): + if not allow_line_cont or ( + allow_line_cont and inline_comment_regexp.match(line) + ): s = s.split("#", 1)[0].rstrip() if allow_line_cont and line_cont_regexp.match(line): - s = s.rstrip('\\\n') + s = s.rstrip("\\\n") continue if enum_line: yield lineno, s.rstrip() else: yield s.rstrip() - s = '' + s = "" if s: if enum_line: yield lineno, s @@ -122,7 +129,7 @@ def read_bash_dict(bash_source, vars_dict=None, sourcing_command=None): try: while tok is not None: key = s.get_token() - if key == 'export': + if key == "export": # discard 'export' token from "export VAR=VALUE" lines key = s.get_token() if key is None: @@ -133,23 +140,23 @@ def read_bash_dict(bash_source, vars_dict=None, sourcing_command=None): # detect empty assigns continue eq = s.get_token() - if eq != '=': + if eq != "=": raise BashParseError( - bash_source, s.lineno, - "got token %r, was expecting '='" % eq) + bash_source, s.lineno, "got token %r, was expecting '='" % eq + ) val = s.get_token() if val is None: - val = '' - elif val == 'export': + val = "" + elif val == "export": val = s.get_token() # look ahead to see if we just got an empty assign. next_tok = s.get_token() - if next_tok == '=': + if next_tok == "=": # ... we did. # leftmost insertions, thus reversed ordering s.push_token(next_tok) s.push_token(val) - val = '' + val = "" else: s.push_token(next_tok) d[key] = val @@ -163,9 +170,15 @@ def read_bash_dict(bash_source, vars_dict=None, sourcing_command=None): return d -def read_dict(bash_source, splitter="=", source_isiter=False, - allow_inline_comments=True, strip=False, filename=None, - ignore_errors=False): +def read_dict( + bash_source, + splitter="=", + source_isiter=False, + allow_inline_comments=True, + strip=False, + filename=None, + ignore_errors=False, +): """Read key value pairs from a file, ignoring bash-style comments. :param splitter: the string to split on. Can be None to @@ -180,12 +193,11 @@ def read_dict(bash_source, splitter="=", source_isiter=False, d = {} if not source_isiter: filename = bash_source - i = iter_read_bash( - bash_source, allow_inline_comments=allow_inline_comments) + i = iter_read_bash(bash_source, allow_inline_comments=allow_inline_comments) else: if filename is None: # XXX what to do? - filename = '<unknown>' + filename = "<unknown>" i = bash_source line_count = 0 try: @@ -195,10 +207,11 @@ def read_dict(bash_source, splitter="=", source_isiter=False, k, v = k.split(splitter, 1) except ValueError as e: if filename == "<unknown>": - filename = getattr(bash_source, 'name', bash_source) + filename = getattr(bash_source, "name", bash_source) if ignore_errors: logger.error( - 'bash parse error in %r, line %s', filename, line_count) + "bash parse error in %r, line %s", filename, line_count + ) continue else: raise BashParseError(filename, line_count) from e @@ -239,7 +252,7 @@ class bash_parser(shlex): :param env: initial environment to use for variable interpolation :type env: must be a mapping; if None, an empty dict is used """ - self.__dict__['state'] = ' ' + self.__dict__["state"] = " " super().__init__(source, posix=True, infile=infile) self.wordchars += "@${}/.-+/:~^*" self.wordchars = frozenset(self.wordchars) @@ -252,12 +265,10 @@ class bash_parser(shlex): def __setattr__(self, attr, val): if attr == "state": - if (self.state, val) in ( - ('"', 'a'), ('a', '"'), ('a', ' '), ("'", 'a')): + if (self.state, val) in (('"', "a"), ("a", '"'), ("a", " "), ("'", "a")): strl = len(self.token) if self.__pos != strl: - self.changed_state.append( - (self.state, self.token[self.__pos:])) + self.changed_state.append((self.state, self.token[self.__pos :])) self.__pos = strl self.__dict__[attr] = val @@ -275,13 +286,13 @@ class bash_parser(shlex): return token if self.state is None: # eof reached. - self.changed_state.append((self.state, token[self.__pos:])) + self.changed_state.append((self.state, token[self.__pos :])) else: - self.changed_state.append((self.state, self.token[self.__pos:])) - tok = '' + self.changed_state.append((self.state, self.token[self.__pos :])) + tok = "" for s, t in self.changed_state: if s in ('"', "a"): - tok += self.var_expand(t).replace("\\\n", '') + tok += self.var_expand(t).replace("\\\n", "") else: tok += t return tok @@ -291,26 +302,27 @@ class bash_parser(shlex): l = [] while match := var_find.search(val, pos): pos = match.start() - if val[pos] == '\\': + if val[pos] == "\\": # it's escaped. either it's \\$ or \\${ , either way, # skipping two ahead handles it. pos += 2 else: - var = val[match.start():match.end()].strip("${}") + var = val[match.start() : match.end()].strip("${}") if prev != pos: l.append(val[prev:pos]) if var in self.env: if not isinstance(self.env[var], str): raise ValueError( - "env key %r must be a string, not %s: %r" % ( - var, type(self.env[var]), self.env[var])) + "env key %r must be a string, not %s: %r" + % (var, type(self.env[var]), self.env[var]) + ) l.append(self.env[var]) else: l.append("") prev = pos = match.end() # do \\ cleansing, collapsing val down also. - val = backslash_find.sub(_nuke_backslash, ''.join(l) + val[prev:]) + val = backslash_find.sub(_nuke_backslash, "".join(l) + val[prev:]) return val @@ -320,10 +332,11 @@ class BashParseError(Exception): def __init__(self, filename, line, errmsg=None): if errmsg is not None: super().__init__( - "error parsing '%s' on or before line %i: err %s" % - (filename, line, errmsg)) + "error parsing '%s' on or before line %i: err %s" + % (filename, line, errmsg) + ) else: super().__init__( - "error parsing '%s' on or before line %i" % - (filename, line)) + "error parsing '%s' on or before line %i" % (filename, line) + ) self.file, self.line, self.errmsg = filename, line, errmsg diff --git a/src/snakeoil/caching.py b/src/snakeoil/caching.py index 3ed75d2..b20f074 100644 --- a/src/snakeoil/caching.py +++ b/src/snakeoil/caching.py @@ -79,21 +79,22 @@ class WeakInstMeta(type): Examples of usage is the restrictions subsystem for U{pkgcore project<http://pkgcore.org>} """ + def __new__(cls, name, bases, d): if d.get("__inst_caching__", False): d["__inst_caching__"] = True d["__inst_dict__"] = WeakValueDictionary() else: d["__inst_caching__"] = False - slots = d.get('__slots__') + slots = d.get("__slots__") # get ourselves a singleton to be safe... o = object() if slots is not None: for base in bases: - if getattr(base, '__weakref__', o) is not o: + if getattr(base, "__weakref__", o) is not o: break else: - d['__slots__'] = tuple(slots) + ('__weakref__',) + d["__slots__"] = tuple(slots) + ("__weakref__",) return type.__new__(cls, name, bases, d) def __call__(cls, *a, **kw): @@ -105,8 +106,7 @@ class WeakInstMeta(type): try: instance = cls.__inst_dict__.get(key) except (NotImplementedError, TypeError) as t: - warnings.warn( - f"caching keys for {cls}, got {t} for a={a}, kw={kw}") + warnings.warn(f"caching keys for {cls}, got {t} for a={a}, kw={kw}") del t key = instance = None diff --git a/src/snakeoil/chksum/__init__.py b/src/snakeoil/chksum/__init__.py index 8f2bc73..5bd6a7c 100644 --- a/src/snakeoil/chksum/__init__.py +++ b/src/snakeoil/chksum/__init__.py @@ -64,10 +64,9 @@ def init(additional_handlers=None): :param additional_handlers: None, or pass in a dict of type:func """ - global __inited__ # pylint: disable=global-statement + global __inited__ # pylint: disable=global-statement - if additional_handlers is not None and not isinstance( - additional_handlers, dict): + if additional_handlers is not None and not isinstance(additional_handlers, dict): raise TypeError("additional handlers must be a dict!") chksum_types.clear() @@ -119,15 +118,19 @@ def get_chksums(location, *chksums, **kwds): # try to hand off to the per file handler, may be faster. if len(chksums) == 1: return [handlers[chksums[0]](location)] - if len(chksums) == 2 and 'size' in chksums: + if len(chksums) == 2 and "size" in chksums: parallelize = False else: parallelize = kwds.get("parallelize", True) can_mmap = True for k in chksums: can_mmap &= handlers[k].can_mmap - return chksum_loop_over_file(location, [handlers[k].new() for k in chksums], - parallelize=parallelize, can_mmap=can_mmap) + return chksum_loop_over_file( + location, + [handlers[k].new() for k in chksums], + parallelize=parallelize, + can_mmap=can_mmap, + ) class LazilyHashedPath(metaclass=klass.immutable_instance): @@ -135,7 +138,7 @@ class LazilyHashedPath(metaclass=klass.immutable_instance): def __init__(self, path, **initial_values): f = object.__setattr__ - f(self, 'path', path) + f(self, "path", path) for attr, val in initial_values.items(): f(self, attr, val) @@ -143,7 +146,7 @@ class LazilyHashedPath(metaclass=klass.immutable_instance): if not attr.islower(): # Disallow sHa1. raise AttributeError(attr) - elif attr == 'mtime': + elif attr == "mtime": val = osutils.stat_mtime_long(self.path) else: try: diff --git a/src/snakeoil/chksum/defaults.py b/src/snakeoil/chksum/defaults.py index bf2be41..fe01377 100644 --- a/src/snakeoil/chksum/defaults.py +++ b/src/snakeoil/chksum/defaults.py @@ -11,7 +11,7 @@ from sys import intern from ..data_source import base as base_data_source from ..fileutils import mmap_or_open_for_read -blocksize = 2 ** 17 +blocksize = 2**17 blake2b_size = 128 blake2s_size = 64 @@ -36,8 +36,11 @@ def chf_thread(queue, callback): def chksum_loop_over_file(filename, chfs, parallelize=True, can_mmap=True): chfs = [chf() for chf in chfs] loop_over_file( - filename, [chf.update for chf in chfs], - parallelize=parallelize, can_mmap=can_mmap) + filename, + [chf.update for chf in chfs], + parallelize=parallelize, + can_mmap=can_mmap, + ) return [int(chf.hexdigest(), 16) for chf in chfs] @@ -54,7 +57,7 @@ def loop_over_file(handle, callbacks, parallelize=True, can_mmap=True): else: f = handle close_f = False - if getattr(handle, 'encoding', None): + if getattr(handle, "encoding", None): # wanker. bypass the encoding, go straight to the raw source. f = f.buffer # reset; we do it for compat, but it also avoids unpleasant issues from @@ -68,8 +71,10 @@ def loop_over_file(handle, callbacks, parallelize=True, can_mmap=True): if parallelize: queues = [queue.Queue(8) for _ in callbacks] - threads = [threading.Thread(target=chf_thread, args=(queue, functor)) - for queue, functor in zip(queues, callbacks)] + threads = [ + threading.Thread(target=chf_thread, args=(queue, functor)) + for queue, functor in zip(queues, callbacks) + ] for thread in threads: thread.start() @@ -79,7 +84,7 @@ def loop_over_file(handle, callbacks, parallelize=True, can_mmap=True): if m is not None: for callback in callbacks: callback(m) - elif hasattr(f, 'getvalue'): + elif hasattr(f, "getvalue"): data = f.getvalue() if not isinstance(data, bytes): data = data.encode() @@ -107,7 +112,6 @@ def loop_over_file(handle, callbacks, parallelize=True, can_mmap=True): class Chksummer: - def __init__(self, chf_type, obj, str_size, can_mmap=True): self.obj = obj self.chf_type = chf_type @@ -118,15 +122,14 @@ class Chksummer: return self.obj def long2str(self, val): - return ("%x" % val).rjust(self.str_size, '0') + return ("%x" % val).rjust(self.str_size, "0") @staticmethod def str2long(val): return int(val, 16) def __call__(self, filename): - return chksum_loop_over_file( - filename, [self.obj], can_mmap=self.can_mmap)[0] + return chksum_loop_over_file(filename, [self.obj], can_mmap=self.can_mmap)[0] def __str__(self): return "%s chksummer" % self.chf_type @@ -134,31 +137,26 @@ class Chksummer: chksum_types = { chksumname: Chksummer(chksumname, partial(hashlib.new, hashlibname), size) - for hashlibname, chksumname, size in [ # conditional upon FIPS, but available in >3.8. - ('md5', 'md5', md5_size), - + ("md5", "md5", md5_size), # Guaranteed as of python 3.8 - ('blake2b', 'blake2b', blake2b_size), - ('blake2s', 'blake2s', blake2s_size), - ('sha1', 'sha1', sha1_size), - ('sha256', 'sha256', sha256_size), - ('sha3_256', 'sha3_256', sha3_256_size), - ('sha3_512', 'sha3_512', sha3_512_size), - ('sha512', 'sha512', sha512_size), - + ("blake2b", "blake2b", blake2b_size), + ("blake2s", "blake2s", blake2s_size), + ("sha1", "sha1", sha1_size), + ("sha256", "sha256", sha256_size), + ("sha3_256", "sha3_256", sha3_256_size), + ("sha3_512", "sha3_512", sha3_512_size), + ("sha512", "sha512", sha512_size), # not guaranteed, but may be available. - ('whirlpool', 'whirlpool', whirlpool_size), - ('ripemd160', 'rmd160', rmd160_size), - + ("whirlpool", "whirlpool", whirlpool_size), + ("ripemd160", "rmd160", rmd160_size), ] if hashlibname in hashlib.algorithms_available } class SizeUpdater: - def __init__(self): self.count = 0 @@ -176,8 +174,7 @@ class SizeChksummer(Chksummer): """ def __init__(self): - super().__init__( - chf_type='size', obj=SizeUpdater, str_size=1000000000) + super().__init__(chf_type="size", obj=SizeUpdater, str_size=1000000000) @staticmethod def long2str(val): diff --git a/src/snakeoil/cli/arghparse.py b/src/snakeoil/cli/arghparse.py index 1c832c3..774699e 100644 --- a/src/snakeoil/cli/arghparse.py +++ b/src/snakeoil/cli/arghparse.py @@ -9,8 +9,18 @@ import pkgutil import subprocess import sys import traceback -from argparse import (_UNRECOGNIZED_ARGS_ATTR, OPTIONAL, PARSER, REMAINDER, SUPPRESS, ZERO_OR_MORE, - ArgumentError, _, _get_action_name, _SubParsersAction) +from argparse import ( + _UNRECOGNIZED_ARGS_ATTR, + OPTIONAL, + PARSER, + REMAINDER, + SUPPRESS, + ZERO_OR_MORE, + ArgumentError, + _, + _get_action_name, + _SubParsersAction, +) from collections import Counter from functools import partial from itertools import chain @@ -31,11 +41,11 @@ from ..version import get_version _generate_docs = False -@klass.patch('argparse.ArgumentParser.add_subparsers') -@klass.patch('argparse._SubParsersAction.add_parser') -@klass.patch('argparse._ActionsContainer.add_mutually_exclusive_group') -@klass.patch('argparse._ActionsContainer.add_argument_group') -@klass.patch('argparse._ActionsContainer.add_argument') +@klass.patch("argparse.ArgumentParser.add_subparsers") +@klass.patch("argparse._SubParsersAction.add_parser") +@klass.patch("argparse._ActionsContainer.add_mutually_exclusive_group") +@klass.patch("argparse._ActionsContainer.add_argument_group") +@klass.patch("argparse._ActionsContainer.add_argument") def _add_argument_docs(orig_func, self, *args, **kwargs): """Enable docs keyword argument support for argparse arguments. @@ -48,16 +58,16 @@ def _add_argument_docs(orig_func, self, *args, **kwargs): enable the global _generate_docs variable in order to replace the summarized help strings with the extended doc strings. """ - docs = kwargs.pop('docs', None) + docs = kwargs.pop("docs", None) obj = orig_func(self, *args, **kwargs) if _generate_docs and docs is not None: if isinstance(docs, (list, tuple)): # list args are often used if originator wanted to strip # off first description summary line - docs = '\n'.join(docs) - docs = '\n'.join(dedent(docs).strip().split('\n')) + docs = "\n".join(docs) + docs = "\n".join(dedent(docs).strip().split("\n")) - if orig_func.__name__ == 'add_subparsers': + if orig_func.__name__ == "add_subparsers": # store original description before overriding it with extended # docs for general subparsers argument groups self._subparsers._description = self._subparsers.description @@ -93,7 +103,7 @@ class ParseNonblockingStdin(argparse.Action): """Accept arguments from standard input in a non-blocking fashion.""" def __init__(self, *args, **kwargs): - self.filter_func = kwargs.pop('filter_func', lambda x: x.strip()) + self.filter_func = kwargs.pop("filter_func", lambda x: x.strip()) super().__init__(*args, **kwargs) def _stdin(self): @@ -106,9 +116,11 @@ class ParseNonblockingStdin(argparse.Action): break def __call__(self, parser, namespace, values, option_string=None): - if values is not None and len(values) == 1 and values[0] == '-': + if values is not None and len(values) == 1 and values[0] == "-": if sys.stdin.isatty(): - raise argparse.ArgumentError(self, "'-' is only valid when piping data in") + raise argparse.ArgumentError( + self, "'-' is only valid when piping data in" + ) values = self._stdin() setattr(namespace, self.dest, values) @@ -117,16 +129,18 @@ class ParseStdin(ExtendAction): """Accept arguments from standard input in a blocking fashion.""" def __init__(self, *args, **kwargs): - self.filter_func = kwargs.pop('filter_func', lambda x: x.strip()) + self.filter_func = kwargs.pop("filter_func", lambda x: x.strip()) super().__init__(*args, **kwargs) def __call__(self, parser, namespace, values, option_string=None): - if values is not None and len(values) == 1 and values[0] == '-': + if values is not None and len(values) == 1 and values[0] == "-": if sys.stdin.isatty(): - raise argparse.ArgumentError(self, "'-' is only valid when piping data in") + raise argparse.ArgumentError( + self, "'-' is only valid when piping data in" + ) values = [x.rstrip() for x in sys.stdin.readlines() if self.filter_func(x)] # reassign stdin to allow interactivity (currently only works for unix) - sys.stdin = open('/dev/tty') + sys.stdin = open("/dev/tty") super().__call__(parser, namespace, values, option_string) @@ -136,10 +150,10 @@ class CommaSeparatedValues(argparse._AppendAction): def parse_values(self, values): items = [] if isinstance(values, str): - items.extend(x for x in values.split(',') if x) + items.extend(x for x in values.split(",") if x) else: for value in values: - items.extend(x for x in value.split(',') if x) + items.extend(x for x in value.split(",") if x) return items def __call__(self, parser, namespace, values, option_string=None): @@ -174,16 +188,16 @@ class CommaSeparatedNegations(argparse._AppendAction): values = [values] for value in values: try: - neg, pos = split_negations(x for x in value.split(',') if x) + neg, pos = split_negations(x for x in value.split(",") if x) except ValueError as e: raise argparse.ArgumentTypeError(e) disabled.extend(neg) enabled.extend(pos) if colliding := set(disabled).intersection(enabled): - collisions = ', '.join(map(repr, sorted(colliding))) + collisions = ", ".join(map(repr, sorted(colliding))) s = pluralism(colliding) - msg = f'colliding value{s}: {collisions}' + msg = f"colliding value{s}: {collisions}" raise argparse.ArgumentError(self, msg) return disabled, enabled @@ -222,7 +236,7 @@ class CommaSeparatedElements(argparse._AppendAction): values = [values] for value in values: try: - neg, neu, pos = split_elements(x for x in value.split(',') if x) + neg, neu, pos = split_elements(x for x in value.split(",") if x) except ValueError as e: raise argparse.ArgumentTypeError(e) disabled.extend(neg) @@ -231,9 +245,9 @@ class CommaSeparatedElements(argparse._AppendAction): elements = [set(x) for x in (disabled, neutral, enabled) if x] if len(elements) > 1 and (colliding := set.intersection(*elements)): - collisions = ', '.join(map(repr, sorted(colliding))) + collisions = ", ".join(map(repr, sorted(colliding))) s = pluralism(colliding) - msg = f'colliding value{s}: {collisions}' + msg = f"colliding value{s}: {collisions}" raise argparse.ArgumentError(self, msg) return disabled, neutral, enabled @@ -260,14 +274,14 @@ class ManHelpAction(argparse._HelpAction): """Display man pages for long --help option and abbreviated output for -h.""" def __call__(self, parser, namespace, values, option_string=None): - if option_string == '--help': + if option_string == "--help": # Try spawning man page -- assumes one level deep for subcommand # specific man pages with commands separated by hyphen. For example # running `pinspect profile --help` tries to open pinspect-profile # man page, but `pinspect profile masks --help` also tries to open # pinspect-profile. - man_page = '-'.join(parser.prog.split()[:2]) - p = subprocess.Popen(['man', man_page], stderr=subprocess.DEVNULL) + man_page = "-".join(parser.prog.split()[:2]) + p = subprocess.Popen(["man", man_page], stderr=subprocess.DEVNULL) p.communicate() if p.returncode == 0: parser.exit() @@ -279,16 +293,17 @@ class ManHelpAction(argparse._HelpAction): class StoreBool(argparse._StoreAction): - - def __init__(self, - option_strings, - dest, - nargs=None, - const=None, - default=None, - required=False, - help=None, - metavar='BOOLEAN'): + def __init__( + self, + option_strings, + dest, + nargs=None, + const=None, + default=None, + required=False, + help=None, + metavar="BOOLEAN", + ): super().__init__( option_strings=option_strings, dest=dest, @@ -298,38 +313,42 @@ class StoreBool(argparse._StoreAction): type=self.boolean, required=required, help=help, - metavar=metavar) + metavar=metavar, + ) @staticmethod def boolean(value): value = value.lower() - if value in ('y', 'yes', 'true', '1'): + if value in ("y", "yes", "true", "1"): return True - elif value in ('n', 'no', 'false', '0'): + elif value in ("n", "no", "false", "0"): return False raise ValueError("value %r must be [y|yes|true|1|n|no|false|0]" % (value,)) class EnableDebug(argparse._StoreTrueAction): - def __call__(self, parser, namespace, values, option_string=None): super().__call__(parser, namespace, values, option_string=option_string) logging.root.setLevel(logging.DEBUG) class Verbosity(argparse.Action): - def __init__(self, option_strings, dest, default=None, required=False, help=None): super().__init__( - option_strings=option_strings, dest=dest, nargs=0, - default=default, required=required, help=help) + option_strings=option_strings, + dest=dest, + nargs=0, + default=default, + required=required, + help=help, + ) # map verbose/quiet args to increment/decrement the underlying verbosity value self.value_map = { - '-q': -1, - '--quiet': -1, - '-v': 1, - '--verbose': 1, + "-q": -1, + "--quiet": -1, + "-v": 1, + "--verbose": 1, } def __call__(self, parser, namespace, values, option_string=None): @@ -343,7 +362,6 @@ class Verbosity(argparse.Action): class DelayedValue: - def __init__(self, invokable, priority=0): self.priority = priority if not callable(invokable): @@ -355,7 +373,6 @@ class DelayedValue: class DelayedDefault(DelayedValue): - @classmethod def wipe(cls, attrs, priority): if isinstance(attrs, str): @@ -376,20 +393,17 @@ class DelayedDefault(DelayedValue): class DelayedParse(DelayedValue): - def __call__(self, namespace, attr): self.invokable() class OrderedParse(DelayedValue): - def __call__(self, namespace, attr): self.invokable(namespace) delattr(namespace, attr) class Delayed(argparse.Action): - def __init__(self, option_strings, dest, target=None, priority=0, **kwargs): if target is None: raise ValueError("target must be non None for Delayed") @@ -397,21 +411,30 @@ class Delayed(argparse.Action): self.priority = int(priority) self.target = target(option_strings=option_strings, dest=dest, **kwargs.copy()) super().__init__( - option_strings=option_strings[:], dest=dest, - nargs=kwargs.get("nargs", None), required=kwargs.get("required", None), - help=kwargs.get("help", None), metavar=kwargs.get("metavar", None), - default=kwargs.get("default", None)) + option_strings=option_strings[:], + dest=dest, + nargs=kwargs.get("nargs", None), + required=kwargs.get("required", None), + help=kwargs.get("help", None), + metavar=kwargs.get("metavar", None), + default=kwargs.get("default", None), + ) def __call__(self, parser, namespace, values, option_string=None): - setattr(namespace, self.dest, DelayedParse( - partial(self.target, parser, namespace, values, option_string), - self.priority)) + setattr( + namespace, + self.dest, + DelayedParse( + partial(self.target, parser, namespace, values, option_string), + self.priority, + ), + ) class Expansion(argparse.Action): - - def __init__(self, option_strings, dest, nargs=None, help=None, - required=None, subst=None): + def __init__( + self, option_strings, dest, nargs=None, help=None, required=None, subst=None + ): if subst is None: raise TypeError("substitution string must be set") # simple aliases with no required arguments shouldn't need to specify nargs @@ -424,7 +447,8 @@ class Expansion(argparse.Action): help=help, required=required, default=False, - nargs=nargs) + nargs=nargs, + ) self.subst = tuple(subst) def __call__(self, parser, namespace, values, option_string=None): @@ -434,7 +458,7 @@ class Expansion(argparse.Action): if isinstance(values, str): vals = [vals] dvals = {str(idx): val for idx, val in enumerate(vals)} - dvals['*'] = ' '.join(vals) + dvals["*"] = " ".join(vals) for action in actions: action_map.update((option, action) for option in action.option_strings) @@ -445,8 +469,8 @@ class Expansion(argparse.Action): args = [x % dvals for x in args] if not action: raise ValueError( - "unable to find option %r for %r" % - (option, self.option_strings)) + "unable to find option %r for %r" % (option, self.option_strings) + ) if action.type is not None: args = list(map(action.type, args)) if action.nargs in (1, None): @@ -456,7 +480,6 @@ class Expansion(argparse.Action): class _SubParser(argparse._SubParsersAction): - def add_parser(self, name, cls=None, **kwargs): """Subparser that links description/help if one is specified.""" description = kwargs.get("description") @@ -465,7 +488,7 @@ class _SubParser(argparse._SubParsersAction): if help_txt is not None: kwargs["description"] = help_txt elif help_txt is None: - kwargs["help"] = description.split('\n', 1)[0] + kwargs["help"] = description.split("\n", 1)[0] # support using a custom parser class for the subparser orig_class = self._parser_class @@ -486,7 +509,7 @@ class _SubParser(argparse._SubParsersAction): Note that this assumes a specific module naming and layout scheme for commands. """ prog = self._prog_prefix - module = f'{prog}.scripts.{prog}_{subcmd}' + module = f"{prog}.scripts.{prog}_{subcmd}" func = partial(self._lazy_parser, module, subcmd) self._name_parser_map[subcmd] = lazy_object_proxy.Proxy(func) @@ -507,8 +530,8 @@ class _SubParser(argparse._SubParsersAction): try: parser = self._name_parser_map[parser_name] except KeyError: - tup = parser_name, ', '.join(self._name_parser_map) - msg = _('unknown parser %r (choices: %s)') % tup + tup = parser_name, ", ".join(self._name_parser_map) + msg = _("unknown parser %r (choices: %s)") % tup raise argparse.ArgumentError(self, msg) # parse all the remaining options into the namespace @@ -526,11 +549,13 @@ class CsvHelpFormatter(argparse.HelpFormatter): def _format_args(self, action, default_metavar): get_metavar = self._metavar_formatter(action, default_metavar) if isinstance(action, (CommaSeparatedValues, CommaSeparatedValuesAppend)): - result = '%s[,%s,...]' % get_metavar(2) - elif isinstance(action, (CommaSeparatedNegations, CommaSeparatedNegationsAppend)): - result = '%s[,-%s,...]' % get_metavar(2) + result = "%s[,%s,...]" % get_metavar(2) + elif isinstance( + action, (CommaSeparatedNegations, CommaSeparatedNegationsAppend) + ): + result = "%s[,-%s,...]" % get_metavar(2) elif isinstance(action, (CommaSeparatedElements, CommaSeparatedElementsAppend)): - result = '%s[,-%s,+%s...]' % get_metavar(3) + result = "%s[,-%s,+%s...]" % get_metavar(3) else: result = super()._format_args(action, default_metavar) return result @@ -540,7 +565,7 @@ class SortedHelpFormatter(CsvHelpFormatter): """Help formatter that sorts arguments by option strings.""" def add_arguments(self, actions): - actions = sorted(actions, key=attrgetter('option_strings')) + actions = sorted(actions, key=attrgetter("option_strings")) super().add_arguments(actions) @@ -576,7 +601,7 @@ class SubcmdAbbrevArgumentParser(argparse.ArgumentParser): # for everything but PARSER, REMAINDER args, strip out first '--' if action.nargs not in [PARSER, REMAINDER]: try: - arg_strings.remove('--') + arg_strings.remove("--") except ValueError: pass @@ -592,8 +617,11 @@ class SubcmdAbbrevArgumentParser(argparse.ArgumentParser): # when nargs='*' on a positional, if there were no command-line # args, use the default if it is anything other than None - elif (not arg_strings and action.nargs == ZERO_OR_MORE and - not action.option_strings): + elif ( + not arg_strings + and action.nargs == ZERO_OR_MORE + and not action.option_strings + ): if action.default is not None: value = action.default else: @@ -602,7 +630,7 @@ class SubcmdAbbrevArgumentParser(argparse.ArgumentParser): # single argument or optional argument produces a single value elif len(arg_strings) == 1 and action.nargs in [None, OPTIONAL]: - arg_string, = arg_strings + (arg_string,) = arg_strings value = self._get_value(action, arg_string) self._check_value(action, value) @@ -688,7 +716,7 @@ class OptionalsParser(argparse.ArgumentParser): for i, mutex_action in enumerate(mutex_group._group_actions): conflicts = action_conflicts.setdefault(mutex_action, []) conflicts.extend(group_actions[:i]) - conflicts.extend(group_actions[i + 1:]) + conflicts.extend(group_actions[i + 1 :]) # find all option indices, and determine the arg_string_pattern # which has an 'O' if there is an option at an index, @@ -699,24 +727,24 @@ class OptionalsParser(argparse.ArgumentParser): for i, arg_string in enumerate(arg_strings_iter): # all args after -- are non-options - if arg_string == '--': - arg_string_pattern_parts.append('-') + if arg_string == "--": + arg_string_pattern_parts.append("-") for arg_string in arg_strings_iter: - arg_string_pattern_parts.append('A') + arg_string_pattern_parts.append("A") # otherwise, add the arg to the arg strings # and note the index if it was an option else: option_tuple = self._parse_optional(arg_string) if option_tuple is None: - pattern = 'A' + pattern = "A" else: option_string_indices[i] = option_tuple - pattern = 'O' + pattern = "O" arg_string_pattern_parts.append(pattern) # join the pieces together to form the pattern - arg_strings_pattern = ''.join(arg_string_pattern_parts) + arg_strings_pattern = "".join(arg_string_pattern_parts) # converts arg strings to the appropriate and then takes the action seen_actions = set() @@ -733,7 +761,7 @@ class OptionalsParser(argparse.ArgumentParser): seen_non_default_actions.add(action) for conflict_action in action_conflicts.get(action, []): if conflict_action in seen_non_default_actions: - msg = _('not allowed with argument %s') + msg = _("not allowed with argument %s") action_name = _get_action_name(conflict_action) raise ArgumentError(action, msg % action_name) @@ -762,14 +790,14 @@ class OptionalsParser(argparse.ArgumentParser): # if we match help options, skip them for now so subparsers # show up in the help output - if arg_strings[start_index] in ('-h', '--help'): + if arg_strings[start_index] in ("-h", "--help"): extras.append(arg_strings[start_index]) return start_index + 1 # if there is an explicit argument, try to match the # optional's string arguments to only this if explicit_arg is not None: - arg_count = match_argument(action, 'A') + arg_count = match_argument(action, "A") # if the action is a single-dash option and takes no # arguments, try to parse more single-dash options out @@ -785,7 +813,7 @@ class OptionalsParser(argparse.ArgumentParser): action = optionals_map[option_string] explicit_arg = new_explicit_arg else: - msg = _('ignored explicit argument %r') + msg = _("ignored explicit argument %r") raise ArgumentError(action, msg % explicit_arg) # if the action expect exactly one argument, we've @@ -799,7 +827,7 @@ class OptionalsParser(argparse.ArgumentParser): # error if a double-dash option did not use the # explicit argument else: - msg = _('ignored explicit argument %r') + msg = _("ignored explicit argument %r") raise ArgumentError(action, msg % explicit_arg) # if there is no explicit argument, try to match the @@ -835,13 +863,13 @@ class OptionalsParser(argparse.ArgumentParser): # slice off the appropriate arg strings for each Positional # and add the Positional and its args to the list for action, arg_count in zip(positionals, arg_counts): - args = arg_strings[start_index: start_index + arg_count] + args = arg_strings[start_index : start_index + arg_count] start_index += arg_count take_action(action, args) # slice off the Positionals that we just parsed and return the # index at which the Positionals' string args stopped - positionals[:] = positionals[len(arg_counts):] + positionals[:] = positionals[len(arg_counts) :] return start_index # consume Positionals and Optionals alternately, until we have @@ -855,10 +883,9 @@ class OptionalsParser(argparse.ArgumentParser): while start_index <= max_option_string_index: # consume any Positionals preceding the next option - next_option_string_index = min([ - index - for index in option_string_indices - if index >= start_index]) + next_option_string_index = min( + [index for index in option_string_indices if index >= start_index] + ) if start_index != next_option_string_index: # positionals_end_index = consume_positionals(start_index) positionals_end_index = start_index @@ -894,7 +921,9 @@ class OptionalsParser(argparse.ArgumentParser): for action in self._actions: if action not in seen_actions: # ignore required subcommands and positionals as they'll be handled later - skip = not action.option_strings or isinstance(action, _SubParsersAction) + skip = not action.option_strings or isinstance( + action, _SubParsersAction + ) if action.required and not skip: required_actions.append(_get_action_name(action)) else: @@ -902,16 +931,23 @@ class OptionalsParser(argparse.ArgumentParser): # parsing arguments to avoid calling convert functions # twice (which may fail) if the argument was given, but # only if it was defined already in the namespace - if (action.default is not None and - isinstance(action.default, str) and - hasattr(namespace, action.dest) and - action.default is getattr(namespace, action.dest)): - setattr(namespace, action.dest, - self._get_value(action, action.default)) + if ( + action.default is not None + and isinstance(action.default, str) + and hasattr(namespace, action.dest) + and action.default is getattr(namespace, action.dest) + ): + setattr( + namespace, + action.dest, + self._get_value(action, action.default), + ) if required_actions: - self.error(_('the following arguments are required: %s') % - ', '.join(required_actions)) + self.error( + _("the following arguments are required: %s") + % ", ".join(required_actions) + ) # make sure all required groups had one option present for group in self._mutually_exclusive_groups: @@ -922,11 +958,13 @@ class OptionalsParser(argparse.ArgumentParser): # if no actions were used, report the error else: - names = [_get_action_name(action) - for action in group._group_actions - if action.help is not SUPPRESS] - msg = _('one of the arguments %s is required') - self.error(msg % ' '.join(names)) + names = [ + _get_action_name(action) + for action in group._group_actions + if action.help is not SUPPRESS + ] + msg = _("one of the arguments %s is required") + self.error(msg % " ".join(names)) # return the updated namespace and the extra arguments return namespace, extras @@ -937,31 +975,49 @@ class CsvActionsParser(argparse.ArgumentParser): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.register('action', 'csv', CommaSeparatedValues) - self.register('action', 'csv_append', CommaSeparatedValuesAppend) - self.register('action', 'csv_negations', CommaSeparatedNegations) - self.register('action', 'csv_negations_append', CommaSeparatedNegationsAppend) - self.register('action', 'csv_elements', CommaSeparatedElements) - self.register('action', 'csv_elements_append', CommaSeparatedElementsAppend) + self.register("action", "csv", CommaSeparatedValues) + self.register("action", "csv_append", CommaSeparatedValuesAppend) + self.register("action", "csv_negations", CommaSeparatedNegations) + self.register("action", "csv_negations_append", CommaSeparatedNegationsAppend) + self.register("action", "csv_elements", CommaSeparatedElements) + self.register("action", "csv_elements_append", CommaSeparatedElementsAppend) class ArgumentParser(OptionalsParser, CsvActionsParser): """Extended, argparse-compatible argument parser.""" - def __init__(self, suppress=False, subcmds=False, color=True, debug=True, quiet=True, - verbose=True, version=True, add_help=True, sorted_help=False, - description=None, docs=None, script=None, prog=None, **kwargs): - self.debug = debug and '--debug' in sys.argv[1:] + def __init__( + self, + suppress=False, + subcmds=False, + color=True, + debug=True, + quiet=True, + verbose=True, + version=True, + add_help=True, + sorted_help=False, + description=None, + docs=None, + script=None, + prog=None, + **kwargs, + ): + self.debug = debug and "--debug" in sys.argv[1:] self.verbosity = int(verbose) if self.verbosity: argv = Counter(sys.argv[1:]) # Only supports single, short opts (i.e. -vv isn't recognized), # post argparsing the proper value supporting those kind of args is # in the options namespace. - self.verbosity = sum(chain.from_iterable(( - (-1 for x in range(argv['-q'] + argv['--quiet'])), - (1 for x in range(argv['-v'] + argv['--verbose'])), - ))) + self.verbosity = sum( + chain.from_iterable( + ( + (-1 for x in range(argv["-q"] + argv["--quiet"])), + (1 for x in range(argv["-v"] + argv["--verbose"])), + ) + ) + ) # subparsers action object from calling add_subparsers() self.__subparsers = None @@ -979,7 +1035,7 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): # usage such as adding conflicting options to both the root command and # subcommands without causing issues in addition to helping support # default subparsers. - self._parents = tuple(kwargs.get('parents', ())) + self._parents = tuple(kwargs.get("parents", ())) # extract the description to use and set docs for doc generation description = self._update_desc(description, docs) @@ -993,11 +1049,12 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): raise TypeError except TypeError: raise ValueError( - "invalid script parameter, should be (__file__, __name__)") + "invalid script parameter, should be (__file__, __name__)" + ) - project = script_module.split('.')[0] + project = script_module.split(".")[0] if prog is None: - prog = script_module.split('.')[-1] + prog = script_module.split(".")[-1] if sorted_help: formatter = SortedHelpFormatter @@ -1005,27 +1062,36 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): formatter = CsvHelpFormatter super().__init__( - description=description, formatter_class=formatter, - prog=prog, add_help=False, **kwargs) + description=description, + formatter_class=formatter, + prog=prog, + add_help=False, + **kwargs, + ) # register custom actions - self.register('action', 'parsers', _SubParser) + self.register("action", "parsers", _SubParser) if not suppress: - base_opts = self.add_argument_group('base options') + base_opts = self.add_argument_group("base options") if add_help: base_opts.add_argument( - '-h', '--help', action=ManHelpAction, default=argparse.SUPPRESS, - help='show this help message and exit', + "-h", + "--help", + action=ManHelpAction, + default=argparse.SUPPRESS, + help="show this help message and exit", docs=""" Show this help message and exit. To get more information see the related man page. - """) + """, + ) if version and script is not None: # Note that this option will currently only be available on the # base command, not on subcommands. base_opts.add_argument( - '--version', action='version', + "--version", + action="version", version=get_version(project, script_path), help="show this program's version info and exit", docs=""" @@ -1034,39 +1100,58 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): When running from within a git repo or a version installed from git the latest commit hash and date will be shown. - """) + """, + ) if debug: base_opts.add_argument( - '--debug', action=EnableDebug, help='enable debugging checks', - docs='Enable debug checks and show verbose debug output.') + "--debug", + action=EnableDebug, + help="enable debugging checks", + docs="Enable debug checks and show verbose debug output.", + ) if quiet: base_opts.add_argument( - '-q', '--quiet', action=Verbosity, dest='verbosity', default=0, - help='suppress non-error messages', - docs="Suppress non-error, informational messages.") + "-q", + "--quiet", + action=Verbosity, + dest="verbosity", + default=0, + help="suppress non-error messages", + docs="Suppress non-error, informational messages.", + ) if verbose: base_opts.add_argument( - '-v', '--verbose', action=Verbosity, dest='verbosity', default=0, - help='show verbose output', - docs="Increase the verbosity of various output.") + "-v", + "--verbose", + action=Verbosity, + dest="verbosity", + default=0, + help="show verbose output", + docs="Increase the verbosity of various output.", + ) if color: base_opts.add_argument( - '--color', action=StoreBool, + "--color", + action=StoreBool, default=sys.stdout.isatty(), - help='enable/disable color support', + help="enable/disable color support", docs=""" Toggle colored output support. This can be used to forcibly enable color support when piping output or other sitations where stdout is not a tty. - """) + """, + ) # register existing subcommands if subcmds: - prefix = f'{prog}.scripts.{prog}_' + prefix = f"{prog}.scripts.{prog}_" if subcmd_modules := [ - name[len(prefix):] for _, name, _ in - pkgutil.walk_packages([os.path.dirname(script_path)], f'{prog}.scripts.') - if name.startswith(prefix)]: + name[len(prefix) :] + for _, name, _ in pkgutil.walk_packages( + [os.path.dirname(script_path)], f"{prog}.scripts." + ) + if name.startswith(prefix) + ]: subparsers = self.add_subparsers() for subcmd in subcmd_modules: subparsers.add_command(subcmd) @@ -1080,7 +1165,7 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): """ description_lines = [] if description is not None: - description_lines = description.strip().split('\n', 1) + description_lines = description.strip().split("\n", 1) description = description_lines[0] if _generate_docs: if docs is None and len(description_lines) == 2: @@ -1156,7 +1241,9 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): try: # run registered early parse functions from all parsers - for functor, parser in chain.from_iterable(x.__early_parse for x in self.parsers): + for functor, parser in chain.from_iterable( + x.__early_parse for x in self.parsers + ): namespace, args = functor(parser, namespace, args) # parse the arguments and exit if there are any errors @@ -1176,7 +1263,7 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): args, unknown_args = self.parse_known_args(args, namespace) # make sure the correct function and prog are set if running a subcommand - subcmd_parser = self.subparsers.get(getattr(args, 'subcommand', None), None) + subcmd_parser = self.subparsers.get(getattr(args, "subcommand", None), None) if subcmd_parser is not None: # override the running program with full subcommand self.prog = subcmd_parser.prog @@ -1186,7 +1273,7 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): namespace.main_func = subcmd_parser.__main_func if unknown_args: - self.error('unrecognized arguments: %s' % ' '.join(unknown_args)) + self.error("unrecognized arguments: %s" % " ".join(unknown_args)) # Two runs are required; first, handle any suppression defaults # introduced. Subparsers defaults cannot override the parent parser, as @@ -1198,14 +1285,20 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): # intentionally no protection of suppression code; this should # just work. - i = ((attr, val) for attr, val in args.__dict__.items() - if isinstance(val, DelayedDefault)) + i = ( + (attr, val) + for attr, val in args.__dict__.items() + if isinstance(val, DelayedDefault) + ) for attr, functor in sorted(i, key=lambda val: val[1].priority): functor(args, attr) # now run the delays - i = ((attr, val) for attr, val in args.__dict__.items() - if isinstance(val, DelayedValue)) + i = ( + (attr, val) + for attr, val in args.__dict__.items() + if isinstance(val, DelayedValue) + ) try: for attr, delayed in sorted(i, key=lambda val: val[1].priority): delayed(args, attr) @@ -1216,7 +1309,9 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): self.error(str(e)) # run final arg validation - final_checks = [k for k in args.__dict__.keys() if k.startswith('__final_check__')] + final_checks = [ + k for k in args.__dict__.keys() if k.startswith("__final_check__") + ] for check in final_checks: functor = args.pop(check) functor(self, args) @@ -1232,7 +1327,7 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): if self.debug and sys.exc_info() != (None, None, None): # output traceback if any exception is on the stack traceback.print_exc() - self.exit(status, '%s: error: %s\n' % (self.prog, message)) + self.exit(status, "%s: error: %s\n" % (self.prog, message)) def bind_main_func(self, functor): """Decorator to set a main function for the parser.""" @@ -1245,8 +1340,8 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): def bind_class(self, obj): if not isinstance(obj, ArgparseCommand): raise ValueError( - "expected obj to be an instance of " - "ArgparseCommand; got %r" % (obj,)) + "expected obj to be an instance of " "ArgparseCommand; got %r" % (obj,) + ) obj.bind_to_parser(self) return self @@ -1261,10 +1356,12 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): """Only run delayed default functor if the attribute isn't set.""" if isinstance(object.__getattribute__(namespace, attr), DelayedValue): functor(namespace, attr) + if name is None: name = functor.__name__ self.set_defaults(**{name: DelayedValue(default, priority)}) return functor + return f def bind_parse_priority(self, priority): @@ -1272,6 +1369,7 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): name = functor.__name__ self.set_defaults(**{name: OrderedParse(functor, priority)}) return functor + return f def add_subparsers(self, **kwargs): @@ -1280,9 +1378,9 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): if self.__subparsers is not None: return self.__subparsers - kwargs.setdefault('title', 'subcommands') - kwargs.setdefault('dest', 'subcommand') - kwargs.setdefault('prog', self.prog) + kwargs.setdefault("title", "subcommands") + kwargs.setdefault("dest", "subcommand") + kwargs.setdefault("prog", self.prog) subparsers = argparse.ArgumentParser.add_subparsers(self, **kwargs) subparsers.required = True self.__subparsers = subparsers @@ -1300,18 +1398,17 @@ class ArgumentParser(OptionalsParser, CsvActionsParser): def bind_final_check(self, functor): """Decorator to bind a function for argument validation.""" - name = f'__final_check__{functor.__name__}' + name = f"__final_check__{functor.__name__}" self.set_defaults(**{name: functor}) return functor class ArgparseCommand: - def bind_to_parser(self, parser): parser.bind_main_func(self) def __call__(self, namespace, out, err): - raise NotImplementedError(self, '__call__') + raise NotImplementedError(self, "__call__") class FileType(argparse.FileType): @@ -1322,11 +1419,11 @@ class FileType(argparse.FileType): def __call__(self, string): # the special argument "-" means sys.std{in,out} - if string == '-': - if 'r' in self._mode: - return sys.stdin.buffer if 'b' in self._mode else sys.stdin - elif any(c in self._mode for c in 'wax'): - return sys.stdout.buffer if 'b' in self._mode else sys.stdout + if string == "-": + if "r" in self._mode: + return sys.stdin.buffer if "b" in self._mode else sys.stdin + elif any(c in self._mode for c in "wax"): + return sys.stdout.buffer if "b" in self._mode else sys.stdout else: msg = _('argument "-" with mode %r') % self._mode raise ValueError(msg) @@ -1342,23 +1439,27 @@ class FileType(argparse.FileType): def existent_path(value): """Check if file argument path exists.""" if not os.path.exists(value): - raise argparse.ArgumentTypeError(f'nonexistent path: {value!r}') + raise argparse.ArgumentTypeError(f"nonexistent path: {value!r}") try: return os.path.realpath(value) except EnvironmentError as e: - raise ValueError(f'while resolving path {value!r}, encountered error: {e}') from e + raise ValueError( + f"while resolving path {value!r}, encountered error: {e}" + ) from e def existent_dir(value): """Check if argument path exists and is a directory.""" if not os.path.exists(value): - raise argparse.ArgumentTypeError(f'nonexistent dir: {value!r}') + raise argparse.ArgumentTypeError(f"nonexistent dir: {value!r}") elif not os.path.isdir(value): - raise argparse.ArgumentTypeError(f'file already exists: {value!r}') + raise argparse.ArgumentTypeError(f"file already exists: {value!r}") try: return os.path.realpath(value) except EnvironmentError as e: - raise ValueError(f'while resolving path {value!r}, encountered error: {e}') from e + raise ValueError( + f"while resolving path {value!r}, encountered error: {e}" + ) from e def create_dir(value): @@ -1367,9 +1468,9 @@ def create_dir(value): try: os.makedirs(path, exist_ok=True) except FileExistsError: - raise argparse.ArgumentTypeError(f'file already exists: {value!r}') + raise argparse.ArgumentTypeError(f"file already exists: {value!r}") except IOError as e: - raise argparse.ArgumentTypeError(f'failed creating dir: {e}') + raise argparse.ArgumentTypeError(f"failed creating dir: {e}") return path @@ -1378,12 +1479,12 @@ def bounded_int(func, desc, x): try: n = int(x) except ValueError: - raise argparse.ArgumentTypeError('invalid integer value') + raise argparse.ArgumentTypeError("invalid integer value") if not func(n): - raise argparse.ArgumentTypeError(f'must be {desc}') + raise argparse.ArgumentTypeError(f"must be {desc}") return n def positive_int(x): """Check if argument is a positive integer.""" - return bounded_int(lambda n: n >= 1, '>= 1', x) + return bounded_int(lambda n: n >= 1, ">= 1", x) diff --git a/src/snakeoil/cli/exceptions.py b/src/snakeoil/cli/exceptions.py index fb0149f..42b8d67 100644 --- a/src/snakeoil/cli/exceptions.py +++ b/src/snakeoil/cli/exceptions.py @@ -11,7 +11,7 @@ class UserException(Exception): self._verbosity = verbosity def msg(self, verbosity=0): - return '' + return "" class ExitException(Exception): @@ -30,6 +30,8 @@ class ExitException(Exception): def find_user_exception(exc): """Find the UserException related to a given exception if one exists.""" try: - return next(e for e in walk_exception_chain(exc) if isinstance(e, UserException)) + return next( + e for e in walk_exception_chain(exc) if isinstance(e, UserException) + ) except StopIteration: return None diff --git a/src/snakeoil/cli/input.py b/src/snakeoil/cli/input.py index a89d61e..9db9ea8 100644 --- a/src/snakeoil/cli/input.py +++ b/src/snakeoil/cli/input.py @@ -41,9 +41,9 @@ def userquery(prompt, out, err, responses=None, default_answer=None, limit=3): """ if responses is None: responses = { - 'yes': (True, out.fg('green'), 'Yes'), - 'no': (False, out.fg('red'), 'No'), - } + "yes": (True, out.fg("green"), "Yes"), + "no": (False, out.fg("red"), "No"), + } if default_answer is None: default_answer = True if default_answer is not None: @@ -52,25 +52,25 @@ def userquery(prompt, out, err, responses=None, default_answer=None, limit=3): default_answer_name = val[1:] break else: - raise ValueError('default answer matches no responses') + raise ValueError("default answer matches no responses") for i in range(limit): # XXX see docstring about crummyness if isinstance(prompt, tuple): out.write(autoline=False, *prompt) else: out.write(prompt, autoline=False) - out.write(' [', autoline=False) + out.write(" [", autoline=False) prompts = list(responses.values()) for choice in prompts[:-1]: out.write(autoline=False, *choice[1:]) - out.write(out.reset, '/', autoline=False) + out.write(out.reset, "/", autoline=False) out.write(autoline=False, *prompts[-1][1:]) - out.write(out.reset, ']', autoline=False) + out.write(out.reset, "]", autoline=False) if default_answer is not None: - out.write(' (default: ', autoline=False) + out.write(" (default: ", autoline=False) out.write(autoline=False, *default_answer_name) - out.write(')', autoline=False) - out.write(': ', autoline=False) + out.write(")", autoline=False) + out.write(": ", autoline=False) try: response = input() except EOFError as e: @@ -83,15 +83,20 @@ def userquery(prompt, out, err, responses=None, default_answer=None, limit=3): raise if not response: return default_answer - results = sorted(set( - (key, value) for key, value in responses.items() - if key[:len(response)].lower() == response.lower())) + results = sorted( + set( + (key, value) + for key, value in responses.items() + if key[: len(response)].lower() == response.lower() + ) + ) if not results: - err.write('Sorry, response %r not understood.' % (response,)) + err.write("Sorry, response %r not understood." % (response,)) elif len(results) > 1: err.write( - 'Response %r is ambiguous (%s)' % - (response, ', '.join(key for key, val in results))) + "Response %r is ambiguous (%s)" + % (response, ", ".join(key for key, val in results)) + ) else: return list(results)[0][1][0] diff --git a/src/snakeoil/cli/tool.py b/src/snakeoil/cli/tool.py index 2a142b7..d00b3cf 100644 --- a/src/snakeoil/cli/tool.py +++ b/src/snakeoil/cli/tool.py @@ -36,14 +36,14 @@ class Tool: if not sys.stdout.isatty() and sys.stdout == sys.__stdout__: # if redirecting/piping stdout use line buffering, skip if # stdout has been set to some non-standard object - outfile = os.fdopen(sys.stdout.fileno(), 'w', 1) + outfile = os.fdopen(sys.stdout.fileno(), "w", 1) else: outfile = sys.stdout if errfile is None: errfile = sys.stderr out_fd = err_fd = None - if hasattr(outfile, 'fileno') and hasattr(errfile, 'fileno'): + if hasattr(outfile, "fileno") and hasattr(errfile, "fileno"): # annoyingly, fileno can exist but through unsupport try: out_fd, err_fd = outfile.fileno(), errfile.fileno() @@ -52,9 +52,11 @@ class Tool: if out_fd is not None and err_fd is not None: out_stat, err_stat = os.fstat(out_fd), os.fstat(err_fd) - if out_stat.st_dev == err_stat.st_dev \ - and out_stat.st_ino == err_stat.st_ino and \ - not errfile.isatty(): + if ( + out_stat.st_dev == err_stat.st_dev + and out_stat.st_ino == err_stat.st_ino + and not errfile.isatty() + ): # they're the same underlying fd. thus # point the handles at the same so we don't # get intermixed buffering issues. @@ -64,7 +66,7 @@ class Tool: self._errfile = errfile self.out = self.parser.out = formatters.PlainTextFormatter(outfile) self.err = self.parser.err = formatters.PlainTextFormatter(errfile) - self.out.verbosity = self.err.verbosity = getattr(self.parser, 'verbosity', 0) + self.out.verbosity = self.err.verbosity = getattr(self.parser, "verbosity", 0) def __call__(self, args=None): """Run the utility. @@ -98,19 +100,21 @@ class Tool: try: self.pre_parse(args, namespace) options = self.parser.parse_args(args=args, namespace=namespace) - main_func = options.pop('main_func', None) + main_func = options.pop("main_func", None) if main_func is None: raise RuntimeError("argparser missing main method") # reconfigure formatters for colored output if enabled - if getattr(options, 'color', True): + if getattr(options, "color", True): formatter_factory = partial( - formatters.get_formatter, force_color=getattr(options, 'color', False)) + formatters.get_formatter, + force_color=getattr(options, "color", False), + ) self.out = formatter_factory(self._outfile) self.err = formatter_factory(self._errfile) # reconfigure formatters with properly parsed output verbosity - self.out.verbosity = self.err.verbosity = getattr(options, 'verbosity', 0) + self.out.verbosity = self.err.verbosity = getattr(options, "verbosity", 0) if logging.root.handlers: # Remove the default handler. @@ -138,13 +142,13 @@ class Tool: exc = find_user_exception(e) if exc is not None: # allow exception attribute to override user verbosity level - if getattr(exc, '_verbosity', None) is not None: + if getattr(exc, "_verbosity", None) is not None: verbosity = exc._verbosity else: - verbosity = getattr(self.parser, 'verbosity', 0) + verbosity = getattr(self.parser, "verbosity", 0) # output verbose error message if it exists if verbosity > 0: - msg = exc.msg(verbosity).strip('\n') + msg = exc.msg(verbosity).strip("\n") if msg: self.err.write(msg) raise SystemExit @@ -166,15 +170,17 @@ class Tool: try: with suppress_warnings: - self.options, func = self.parse_args(args=self.args, namespace=self.options) + self.options, func = self.parse_args( + args=self.args, namespace=self.options + ) exitstatus = func(self.options, self.out, self.err) except SystemExit as e: # handle argparse or other third party modules using sys.exit internally exitstatus = e.code except KeyboardInterrupt: - self._errfile.write('keyboard interrupted- exiting') + self._errfile.write("keyboard interrupted- exiting") if self.parser.debug: - self._errfile.write('\n') + self._errfile.write("\n") traceback.print_exc() signal(SIGINT, SIG_DFL) os.killpg(os.getpgid(0), SIGINT) @@ -187,9 +193,9 @@ class Tool: if self.options is not None: # set terminal title on exit if exitstatus: - self.out.title(f'{self.options.prog} failed') + self.out.title(f"{self.options.prog} failed") else: - self.out.title(f'{self.options.prog} succeeded') + self.out.title(f"{self.options.prog} succeeded") return exitstatus @@ -204,18 +210,25 @@ class FormattingHandler(logging.Handler): def emit(self, record): if record.levelno >= logging.ERROR: - color = 'red' + color = "red" elif record.levelno >= logging.WARNING: - color = 'yellow' + color = "yellow" else: - color = 'cyan' - first_prefix = (self.out.fg(color), self.out.bold, record.levelname, - self.out.reset, ' ', record.name, ': ') - later_prefix = (len(record.levelname) + len(record.name)) * ' ' + ' : ' + color = "cyan" + first_prefix = ( + self.out.fg(color), + self.out.bold, + record.levelname, + self.out.reset, + " ", + record.name, + ": ", + ) + later_prefix = (len(record.levelname) + len(record.name)) * " " + " : " self.out.first_prefix.extend(first_prefix) self.out.later_prefix.append(later_prefix) try: - for line in self.format(record).split('\n'): + for line in self.format(record).split("\n"): self.out.write(line, wrap=True) except Exception: self.handleError(record) diff --git a/src/snakeoil/compatibility.py b/src/snakeoil/compatibility.py index c9a8a54..4fee417 100644 --- a/src/snakeoil/compatibility.py +++ b/src/snakeoil/compatibility.py @@ -8,12 +8,15 @@ __all__ = ("cmp", "sorted_cmp", "sort_cmp") def sorted_key_from_cmp(cmp_func, key_func=None): class _key_proxy: - __slots__ = ('_obj',) + __slots__ = ("_obj",) + + if key_func: # done this way for speed reasons. - if key_func: # done this way for speed reasons. def __init__(self, obj, key_convert=key_func): self._obj = key_convert(obj) + else: + def __init__(self, obj): self._obj = obj @@ -40,13 +43,13 @@ def cmp(obj1, obj2, raw_cmp=_raw_cmp): def sorted_cmp(sequence, func, key=None, reverse=False): - return sorted(sequence, reverse=reverse, - key=sorted_key_from_cmp(func, key_func=key)) + return sorted( + sequence, reverse=reverse, key=sorted_key_from_cmp(func, key_func=key) + ) def sort_cmp(list_inst, func, key=None, reverse=False): - list_inst.sort(reverse=reverse, - key=sorted_key_from_cmp(func, key_func=key)) + list_inst.sort(reverse=reverse, key=sorted_key_from_cmp(func, key_func=key)) IGNORED_EXCEPTIONS = (RuntimeError, MemoryError, SystemExit, KeyboardInterrupt) diff --git a/src/snakeoil/compression/__init__.py b/src/snakeoil/compression/__init__.py index 77c0631..9eee103 100644 --- a/src/snakeoil/compression/__init__.py +++ b/src/snakeoil/compression/__init__.py @@ -9,13 +9,12 @@ from ..process.spawn import spawn_get_output class _transform_source: - def __init__(self, name): self.name = name @cached_property def module(self): - return import_module(f'snakeoil.compression._{self.name}') + return import_module(f"snakeoil.compression._{self.name}") def compress_data(self, data, level, parallelize=False): parallelize = parallelize and self.module.parallelizable @@ -34,7 +33,7 @@ class _transform_source: return self.module.decompress_handle(handle, parallelize=parallelize) -_transforms = {name: _transform_source(name) for name in ('bzip2', 'xz')} +_transforms = {name: _transform_source(name) for name in ("bzip2", "xz")} def compress_data(compressor_type, data, level=9, **kwds): @@ -73,13 +72,13 @@ class ArComp: cls = cls.known_exts[ext] return super(ArComp, cls).__new__(cls) except KeyError: - raise ArCompError(f'unknown compression file extension: {ext!r}') + raise ArCompError(f"unknown compression file extension: {ext!r}") def __init_subclass__(cls, **kwargs): """Initialize result subclasses and register archive extensions.""" super().__init_subclass__(**kwargs) - if not all((cls.binary, cls.default_unpack_cmd, cls.exts)): # pragma: no cover - raise ValueError(f'class missing required attrs: {cls!r}') + if not all((cls.binary, cls.default_unpack_cmd, cls.exts)): # pragma: no cover + raise ValueError(f"class missing required attrs: {cls!r}") for ext in cls.exts: cls.known_exts[ext] = cls @@ -95,9 +94,10 @@ class ArComp: except process.CommandNotFound: continue else: - choices = ', '.join(self.binary) + choices = ", ".join(self.binary) raise ArCompError( - f'required binary not found from the following choices: {choices}') + f"required binary not found from the following choices: {choices}" + ) cmd = self.default_unpack_cmd.format(binary=binary, path=self.path) return cmd @@ -112,7 +112,7 @@ class _Archive: cmd = shlex.split(self._unpack_cmd.format(path=self.path)) ret, output = spawn_get_output(cmd, collect_fds=(2,), **kwargs) if ret: - msg = '\n'.join(output) if output else f'unpacking failed: {self.path!r}' + msg = "\n".join(output) if output else f"unpacking failed: {self.path!r}" raise ArCompError(msg, code=ret) @@ -121,11 +121,12 @@ class _CompressedFile: def unpack(self, dest=None, **kwargs): cmd = shlex.split(self._unpack_cmd.format(path=self.path)) - with open(dest, 'wb') as f: + with open(dest, "wb") as f: ret, output = spawn_get_output( - cmd, collect_fds=(2,), fd_pipes={1: f.fileno()}, **kwargs) + cmd, collect_fds=(2,), fd_pipes={1: f.fileno()}, **kwargs + ) if ret: - msg = '\n'.join(output) if output else f'unpacking failed: {self.path!r}' + msg = "\n".join(output) if output else f"unpacking failed: {self.path!r}" raise ArCompError(msg, code=ret) @@ -134,18 +135,25 @@ class _CompressedStdin: def unpack(self, dest=None, **kwargs): cmd = shlex.split(self._unpack_cmd) - with open(self.path, 'rb') as src, open(dest, 'wb') as f: + with open(self.path, "rb") as src, open(dest, "wb") as f: ret, output = spawn_get_output( - cmd, collect_fds=(2,), fd_pipes={0: src.fileno(), 1: f.fileno()}, **kwargs) + cmd, + collect_fds=(2,), + fd_pipes={0: src.fileno(), 1: f.fileno()}, + **kwargs, + ) if ret: - msg = '\n'.join(output) if output else f'unpacking failed: {self.path!r}' + msg = "\n".join(output) if output else f"unpacking failed: {self.path!r}" raise ArCompError(msg, code=ret) class _Tar(_Archive, ArComp): - exts = frozenset(['.tar']) - binary = ('gtar', 'tar',) + exts = frozenset([".tar"]) + binary = ( + "gtar", + "tar", + ) compress_binary = None default_unpack_cmd = '{binary} xf "{path}"' @@ -162,95 +170,96 @@ class _Tar(_Archive, ArComp): except process.CommandNotFound: pass else: - choices = ', '.join(next(zip(*self.compress_binary))) + choices = ", ".join(next(zip(*self.compress_binary))) raise ArCompError( - 'no compression binary found from the ' - f'following choices: {choices}') + "no compression binary found from the " + f"following choices: {choices}" + ) return cmd class _TarGZ(_Tar): - exts = frozenset(['.tar.gz', '.tgz', '.tar.Z', '.tar.z']) - compress_binary = (('pigz',), ('gzip',)) + exts = frozenset([".tar.gz", ".tgz", ".tar.Z", ".tar.z"]) + compress_binary = (("pigz",), ("gzip",)) class _TarBZ2(_Tar): - exts = frozenset(['.tar.bz2', '.tbz2', '.tbz']) - compress_binary = (('lbzip2',), ('pbzip2',), ('bzip2',)) + exts = frozenset([".tar.bz2", ".tbz2", ".tbz"]) + compress_binary = (("lbzip2",), ("pbzip2",), ("bzip2",)) class _TarLZMA(_Tar): - exts = frozenset(['.tar.lzma']) - compress_binary = (('lzma',)) + exts = frozenset([".tar.lzma"]) + compress_binary = ("lzma",) class _TarXZ(_Tar): - exts = frozenset(['.tar.xz', '.txz']) - compress_binary = (('pixz',), ('xz', f'-T{multiprocessing.cpu_count()}')) + exts = frozenset([".tar.xz", ".txz"]) + compress_binary = (("pixz",), ("xz", f"-T{multiprocessing.cpu_count()}")) class _Zip(_Archive, ArComp): - exts = frozenset(['.ZIP', '.zip', '.jar']) - binary = ('unzip',) + exts = frozenset([".ZIP", ".zip", ".jar"]) + binary = ("unzip",) default_unpack_cmd = '{binary} -qo "{path}"' class _GZ(_CompressedStdin, ArComp): - exts = frozenset(['.gz', '.Z', '.z']) - binary = ('pigz', 'gzip') - default_unpack_cmd = '{binary} -d -c' + exts = frozenset([".gz", ".Z", ".z"]) + binary = ("pigz", "gzip") + default_unpack_cmd = "{binary} -d -c" class _BZ2(_CompressedStdin, ArComp): - exts = frozenset(['.bz2', '.bz']) - binary = ('lbzip2', 'pbzip2', 'bzip2') - default_unpack_cmd = '{binary} -d -c' + exts = frozenset([".bz2", ".bz"]) + binary = ("lbzip2", "pbzip2", "bzip2") + default_unpack_cmd = "{binary} -d -c" class _XZ(_CompressedStdin, ArComp): - exts = frozenset(['.xz']) - binary = ('pixz', 'xz') - default_unpack_cmd = '{binary} -d -c' + exts = frozenset([".xz"]) + binary = ("pixz", "xz") + default_unpack_cmd = "{binary} -d -c" class _7Z(_Archive, ArComp): - exts = frozenset(['.7Z', '.7z']) - binary = ('7z',) + exts = frozenset([".7Z", ".7z"]) + binary = ("7z",) default_unpack_cmd = '{binary} x -y "{path}"' class _Rar(_Archive, ArComp): - exts = frozenset(['.RAR', '.rar']) - binary = ('unrar',) + exts = frozenset([".RAR", ".rar"]) + binary = ("unrar",) default_unpack_cmd = '{binary} x -idq -o+ "{path}"' class _LHA(_Archive, ArComp): - exts = frozenset(['.LHa', '.LHA', '.lha', '.lzh']) - binary = ('lha',) + exts = frozenset([".LHa", ".LHA", ".lha", ".lzh"]) + binary = ("lha",) default_unpack_cmd = '{binary} xfq "{path}"' class _Ar(_Archive, ArComp): - exts = frozenset(['.a', '.deb']) - binary = ('ar',) + exts = frozenset([".a", ".deb"]) + binary = ("ar",) default_unpack_cmd = '{binary} x "{path}"' class _LZMA(_CompressedFile, ArComp): - exts = frozenset(['.lzma']) - binary = ('lzma',) + exts = frozenset([".lzma"]) + binary = ("lzma",) default_unpack_cmd = '{binary} -dc "{path}"' diff --git a/src/snakeoil/compression/_bzip2.py b/src/snakeoil/compression/_bzip2.py index 1a38922..122debd 100644 --- a/src/snakeoil/compression/_bzip2.py +++ b/src/snakeoil/compression/_bzip2.py @@ -25,6 +25,7 @@ bz2_path = process.find_binary("bzip2") try: from bz2 import BZ2File, compress as _compress_data, decompress as _decompress_data + native = True except ImportError: @@ -40,7 +41,7 @@ _decompress_handle = partial(_util.decompress_handle, bz2_path) try: lbzip2_path = process.find_binary("lbzip2") - lbzip2_compress_args = (f'-n{multiprocessing.cpu_count()}', ) + lbzip2_compress_args = (f"-n{multiprocessing.cpu_count()}",) lbzip2_decompress_args = lbzip2_compress_args parallelizable = True except process.CommandNotFound: @@ -51,28 +52,35 @@ except process.CommandNotFound: def compress_data(data, level=9, parallelize=False): if parallelize and parallelizable: - return _util.compress_data(lbzip2_path, data, compresslevel=level, - extra_args=lbzip2_compress_args) + return _util.compress_data( + lbzip2_path, data, compresslevel=level, extra_args=lbzip2_compress_args + ) return _compress_data(data, compresslevel=level) + def decompress_data(data, parallelize=False): if parallelize and parallelizable: - return _util.decompress_data(lbzip2_path, data, - extra_args=lbzip2_decompress_args) + return _util.decompress_data( + lbzip2_path, data, extra_args=lbzip2_decompress_args + ) return _decompress_data(data) + def compress_handle(handle, level=9, parallelize=False): if parallelize and parallelizable: - return _util.compress_handle(lbzip2_path, handle, compresslevel=level, - extra_args=lbzip2_compress_args) + return _util.compress_handle( + lbzip2_path, handle, compresslevel=level, extra_args=lbzip2_compress_args + ) elif native and isinstance(handle, str): - return BZ2File(handle, mode='w', compresslevel=level) + return BZ2File(handle, mode="w", compresslevel=level) return _compress_handle(handle, compresslevel=level) + def decompress_handle(handle, parallelize=False): if parallelize and parallelizable: - return _util.decompress_handle(lbzip2_path, handle, - extra_args=lbzip2_decompress_args) + return _util.decompress_handle( + lbzip2_path, handle, extra_args=lbzip2_decompress_args + ) elif native and isinstance(handle, str): - return BZ2File(handle, mode='r') + return BZ2File(handle, mode="r") return _decompress_handle(handle) diff --git a/src/snakeoil/compression/_util.py b/src/snakeoil/compression/_util.py index e1af5ae..b95d80c 100644 --- a/src/snakeoil/compression/_util.py +++ b/src/snakeoil/compression/_util.py @@ -6,15 +6,20 @@ import subprocess def _drive_process(args, mode, data): - p = subprocess.Popen(args, - stdin=subprocess.PIPE, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, close_fds=True) + p = subprocess.Popen( + args, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + close_fds=True, + ) try: stdout, stderr = p.communicate(data) if p.returncode != 0: - args = ' '.join(args) + args = " ".join(args) raise ValueError( - f"{mode} returned {p.returncode} exitcode from '{args}', stderr={stderr.decode()}") + f"{mode} returned {p.returncode} exitcode from '{args}', stderr={stderr.decode()}" + ) return stdout finally: if p is not None and p.returncode is None: @@ -22,21 +27,20 @@ def _drive_process(args, mode, data): def compress_data(binary, data, compresslevel=9, extra_args=()): - args = [binary, f'-{compresslevel}c'] + args = [binary, f"-{compresslevel}c"] args.extend(extra_args) - return _drive_process(args, 'compression', data) + return _drive_process(args, "compression", data) def decompress_data(binary, data, extra_args=()): - args = [binary, '-dc'] + args = [binary, "-dc"] args.extend(extra_args) - return _drive_process(args, 'decompression', data) + return _drive_process(args, "decompression", data) class _process_handle: - def __init__(self, handle, args, is_read=False): - self.mode = 'rb' if is_read else 'wb' + self.mode = "rb" if is_read else "wb" self.args = tuple(args) self.is_read = is_read @@ -51,9 +55,10 @@ class _process_handle: handle = open(handle, mode=self.mode) close = True elif not isinstance(handle, int): - if not hasattr(handle, 'fileno'): + if not hasattr(handle, "fileno"): raise TypeError( - f"handle {handle!r} isn't a string, integer, and lacks a fileno method") + f"handle {handle!r} isn't a string, integer, and lacks a fileno method" + ) handle = handle.fileno() try: @@ -64,18 +69,17 @@ class _process_handle: def _setup_process(self, handle): self.position = 0 - stderr = open(os.devnull, 'wb') + stderr = open(os.devnull, "wb") kwds = dict(stderr=stderr) if self.is_read: - kwds['stdin'] = handle - kwds['stdout'] = subprocess.PIPE + kwds["stdin"] = handle + kwds["stdout"] = subprocess.PIPE else: - kwds['stdout'] = handle - kwds['stdin'] = subprocess.PIPE + kwds["stdout"] = handle + kwds["stdin"] = subprocess.PIPE try: - self._process = subprocess.Popen( - self.args, close_fds=True, **kwds) + self._process = subprocess.Popen(self.args, close_fds=True, **kwds) finally: stderr.close() @@ -106,7 +110,8 @@ class _process_handle: if self._allow_reopen is None: raise TypeError( f"instance {self} can't do negative seeks: " - f"asked for {position}, was at {self.position}") + f"asked for {position}, was at {self.position}" + ) self._terminate() self._open_handle(self._allow_reopen) return self.seek(position) @@ -130,7 +135,7 @@ class _process_handle: # reallocating it continually; via this usage, we # only slice once the val is less than seek_size; # iow, two allocations worst case. - null_block = '\0' * seek_size + null_block = "\0" * seek_size while val: self.write(null_block[:val]) offset -= val @@ -145,11 +150,13 @@ class _process_handle: raise def close(self): - if not hasattr(self, '_process'): + if not hasattr(self, "_process"): return if self._process.returncode is not None: if self._process.returncode != 0: - raise Exception(f"{self.args} invocation had non zero exit: {self._process.returncode}") + raise Exception( + f"{self.args} invocation had non zero exit: {self._process.returncode}" + ) return self.handle.close() @@ -163,12 +170,12 @@ class _process_handle: def compress_handle(binary_path, handle, compresslevel=9, extra_args=()): - args = [binary_path, f'-{compresslevel}c'] + args = [binary_path, f"-{compresslevel}c"] args.extend(extra_args) return _process_handle(handle, args, False) def decompress_handle(binary_path, handle, extra_args=()): - args = [binary_path, '-dc'] + args = [binary_path, "-dc"] args.extend(extra_args) return _process_handle(handle, args, True) diff --git a/src/snakeoil/compression/_xz.py b/src/snakeoil/compression/_xz.py index 4707737..9a91a4c 100644 --- a/src/snakeoil/compression/_xz.py +++ b/src/snakeoil/compression/_xz.py @@ -21,7 +21,7 @@ from ..compression import _util # if xz can't be found, throw an error. xz_path = process.find_binary("xz") -xz_compress_args = (f'-T{multiprocessing.cpu_count()}',) +xz_compress_args = (f"-T{multiprocessing.cpu_count()}",) xz_decompress_args = xz_compress_args parallelizable = True @@ -29,6 +29,7 @@ try: from lzma import LZMAFile from lzma import compress as _compress_data from lzma import decompress as _decompress_data + native = True except ImportError: @@ -45,30 +46,33 @@ _decompress_handle = partial(_util.decompress_handle, xz_path) def compress_data(data, level=9, parallelize=False): if parallelize and parallelizable: - return _util.compress_data(xz_path, data, compresslevel=level, - extra_args=xz_compress_args) + return _util.compress_data( + xz_path, data, compresslevel=level, extra_args=xz_compress_args + ) if native: return _compress_data(data, preset=level) return _compress_data(data, compresslevel=level) + def decompress_data(data, parallelize=False): if parallelize and parallelizable: - return _util.decompress_data(xz_path, data, - extra_args=xz_decompress_args) + return _util.decompress_data(xz_path, data, extra_args=xz_decompress_args) return _decompress_data(data) + def compress_handle(handle, level=9, parallelize=False): if parallelize and parallelizable: - return _util.compress_handle(xz_path, handle, compresslevel=level, - extra_args=xz_compress_args) + return _util.compress_handle( + xz_path, handle, compresslevel=level, extra_args=xz_compress_args + ) elif native and isinstance(handle, str): - return LZMAFile(handle, mode='w', preset=level) + return LZMAFile(handle, mode="w", preset=level) return _compress_handle(handle, compresslevel=level) + def decompress_handle(handle, parallelize=False): if parallelize and parallelizable: - return _util.decompress_handle(xz_path, handle, - extra_args=xz_decompress_args) - elif (native and isinstance(handle, str)): - return LZMAFile(handle, mode='r') + return _util.decompress_handle(xz_path, handle, extra_args=xz_decompress_args) + elif native and isinstance(handle, str): + return LZMAFile(handle, mode="r") return _decompress_handle(handle) diff --git a/src/snakeoil/constraints.py b/src/snakeoil/constraints.py index 63e4671..c239727 100644 --- a/src/snakeoil/constraints.py +++ b/src/snakeoil/constraints.py @@ -31,8 +31,9 @@ class Constraint(Protocol): domain. :return: ``True`` if the assignment is satisfied. """ + def __call__(self, **kwargs: Any) -> bool: - raise NotImplementedError('Constraint', '__call__') + raise NotImplementedError("Constraint", "__call__") class _Domain(list): @@ -75,10 +76,13 @@ class Problem: of a :py:class:`dict` assigning to each variable in the problem a single value from it's domain. """ + def __init__(self): self.variables: dict[str, _Domain] = {} self.constraints: list[tuple[Constraint, frozenset[str]]] = [] - self.vconstraints: dict[str, list[tuple[Constraint, frozenset[str]]]] = defaultdict(list) + self.vconstraints: dict[ + str, list[tuple[Constraint, frozenset[str]]] + ] = defaultdict(list) def add_variable(self, domain: Iterable[Any], *variables: str): """Add variables to the problem, which use the specified domain. @@ -94,7 +98,9 @@ class Problem: from each domain. """ for variable in variables: - assert variable not in self.variables, f'variable {variable!r} was already added' + assert ( + variable not in self.variables + ), f"variable {variable!r} was already added" self.variables[variable] = _Domain(domain) def add_constraint(self, constraint: Constraint, variables: frozenset[str]): @@ -110,10 +116,15 @@ class Problem: """ self.constraints.append((constraint, variables)) for variable in variables: - assert variable in self.variables, f'unknown variable {variable!r}' + assert variable in self.variables, f"unknown variable {variable!r}" self.vconstraints[variable].append((constraint, variables)) - def __check(self, constraint: Constraint, variables: frozenset[str], assignments: dict[str, Any]) -> bool: + def __check( + self, + constraint: Constraint, + variables: frozenset[str], + assignments: dict[str, Any], + ) -> bool: assignments = {k: v for k, v in assignments.items() if k in variables} unassigned = variables - assignments.keys() if not unassigned: @@ -147,14 +158,17 @@ class Problem: # mix the Degree and Minimum Remaining Values (MRV) heuristics lst = sorted( (-len(self.vconstraints[name]), len(domain), name) - for name, domain in self.variables.items()) + for name, domain in self.variables.items() + ) for _, _, variable in lst: if variable not in assignments: values = self.variables[variable][:] push_domains = tuple( - domain for name, domain in self.variables.items() - if name != variable and name not in assignments) + domain + for name, domain in self.variables.items() + if name != variable and name not in assignments + ) break else: # no unassigned variables, we've got a solution. diff --git a/src/snakeoil/containers.py b/src/snakeoil/containers.py index ba21155..ebfc1b3 100644 --- a/src/snakeoil/containers.py +++ b/src/snakeoil/containers.py @@ -4,8 +4,12 @@ Container classes and functionality for implementing them """ __all__ = ( - "InvertedContains", "SetMixin", "LimitedChangeSet", "Unchangable", - "ProtectedSet", "RefCountingSet" + "InvertedContains", + "SetMixin", + "LimitedChangeSet", + "Unchangable", + "ProtectedSet", + "RefCountingSet", ) from itertools import chain, filterfalse @@ -70,9 +74,11 @@ class SetMixin: @steal_docs(set) def __xor__(self, other, kls=None): - return (kls or self.__class__)(chain( - (x for x in self if x not in other), - (x for x in other if x not in self))) + return (kls or self.__class__)( + chain( + (x for x in self if x not in other), (x for x in other if x not in self) + ) + ) @steal_docs(set) def __rxor__(self, other): @@ -120,8 +126,7 @@ class LimitedChangeSet(SetMixin): def _default_key_validator(val): return val - def __init__(self, initial_keys, unchangable_keys=None, - key_validator=None): + def __init__(self, initial_keys, unchangable_keys=None, key_validator=None): """ :param initial_keys: iterable holding the initial values to set :param unchangable_keys: container holding keys that cannot be changed @@ -185,8 +190,7 @@ class LimitedChangeSet(SetMixin): def rollback(self, point=0): l = self.changes_count() if point < 0 or point > l: - raise TypeError( - "%s point must be >=0 and <= changes_count()" % point) + raise TypeError("%s point must be >=0 and <= changes_count()" % point) while l > point: change, key = self._change_order.pop(-1) self._changed.remove(key) @@ -221,9 +225,8 @@ class LimitedChangeSet(SetMixin): class Unchangable(Exception): - def __init__(self, key): - super().__init__(f'key {key!r} is unchangable') + super().__init__(f"key {key!r} is unchangable") self.key = key @@ -240,6 +243,7 @@ class ProtectedSet(SetMixin): >>> myset.remove(2) >>> assert 2 not in protected """ + def __init__(self, orig_set): self._orig = orig_set self._new = set() diff --git a/src/snakeoil/contexts.py b/src/snakeoil/contexts.py index 57092b4..394574b 100644 --- a/src/snakeoil/contexts.py +++ b/src/snakeoil/contexts.py @@ -41,6 +41,7 @@ from .sequences import predicate_split # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. + class SplitExec: """Context manager separating code execution across parent/child processes. @@ -48,6 +49,7 @@ class SplitExec: of the context are executed only on the forked child. Exceptions are pickled and passed back to the parent. """ + def __init__(self): self.__trace_lock = threading.Lock() self.__orig_sys_trace = None @@ -184,7 +186,7 @@ class SplitExec: @staticmethod def __excepthook(_exc_type, exc_value, exc_traceback): """Output the proper traceback information from the chroot context.""" - if hasattr(exc_value, '__traceback_list__'): + if hasattr(exc_value, "__traceback_list__"): sys.stderr.write(exc_value.__traceback_list__) else: traceback.print_tb(exc_traceback) @@ -253,7 +255,7 @@ class SplitExec: except AttributeError: # an offset of two accounts for this method and its caller frame = inspect.stack(0)[2][0] - while frame.f_locals.get('self') is self: + while frame.f_locals.get("self") is self: frame = frame.f_back self.__frame = frame # pylint: disable=W0201 return frame @@ -262,11 +264,24 @@ class SplitExec: class Namespace(SplitExec): """Context manager that provides Linux namespace support.""" - def __init__(self, mount=False, uts=True, ipc=False, net=False, pid=False, - user=False, hostname=None): + def __init__( + self, + mount=False, + uts=True, + ipc=False, + net=False, + pid=False, + user=False, + hostname=None, + ): self._hostname = hostname self._namespaces = { - 'mount': mount, 'uts': uts, 'ipc': ipc, 'net': net, 'pid': pid, 'user': user, + "mount": mount, + "uts": uts, + "ipc": ipc, + "net": net, + "pid": pid, + "user": user, } super().__init__() @@ -279,8 +294,8 @@ class GitStash(AbstractContextManager): def __init__(self, path, pathspecs=None, staged=False): self.path = path - self.pathspecs = ['--'] + pathspecs if pathspecs else [] - self._staged = ['--keep-index'] if staged else [] + self.pathspecs = ["--"] + pathspecs if pathspecs else [] + self._staged = ["--keep-index"] if staged else [] self._stashed = False def __enter__(self): @@ -288,14 +303,18 @@ class GitStash(AbstractContextManager): # check for untracked or modified/uncommitted files try: p = subprocess.run( - ['git', 'status', '--porcelain=1', '-u'] + self.pathspecs, - stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, - cwd=self.path, encoding='utf8', check=True) + ["git", "status", "--porcelain=1", "-u"] + self.pathspecs, + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + cwd=self.path, + encoding="utf8", + check=True, + ) except subprocess.CalledProcessError: - raise ValueError(f'not a git repo: {self.path}') + raise ValueError(f"not a git repo: {self.path}") # split file changes into unstaged vs staged - unstaged, staged = predicate_split(lambda x: x[1] == ' ', p.stdout.splitlines()) + unstaged, staged = predicate_split(lambda x: x[1] == " ", p.stdout.splitlines()) # don't stash when no relevant changes exist if self._staged: @@ -306,14 +325,18 @@ class GitStash(AbstractContextManager): # stash all existing untracked or modified/uncommitted files try: - stash_cmd = ['git', 'stash', 'push', '-u', '-m', 'pkgcheck scan --commits'] + stash_cmd = ["git", "stash", "push", "-u", "-m", "pkgcheck scan --commits"] subprocess.run( stash_cmd + self._staged + self.pathspecs, - stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, - cwd=self.path, check=True, encoding='utf8') + stdout=subprocess.DEVNULL, + stderr=subprocess.PIPE, + cwd=self.path, + check=True, + encoding="utf8", + ) except subprocess.CalledProcessError as e: error = e.stderr.splitlines()[0] - raise UserException(f'git failed stashing files: {error}') + raise UserException(f"git failed stashing files: {error}") self._stashed = True def __exit__(self, _exc_type, _exc_value, _traceback): @@ -321,12 +344,16 @@ class GitStash(AbstractContextManager): if self._stashed: try: subprocess.run( - ['git', 'stash', 'pop'], - stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, - cwd=self.path, check=True, encoding='utf8') + ["git", "stash", "pop"], + stdout=subprocess.DEVNULL, + stderr=subprocess.PIPE, + cwd=self.path, + check=True, + encoding="utf8", + ) except subprocess.CalledProcessError as e: error = e.stderr.splitlines()[0] - raise UserException(f'git failed applying stash: {error}') + raise UserException(f"git failed applying stash: {error}") @contextmanager @@ -347,7 +374,7 @@ def chdir(path): @contextmanager -def syspath(path: str, condition: bool=True, position: int=0): +def syspath(path: str, condition: bool = True, position: int = 0): """Context manager that mangles ``sys.path`` and then reverts on exit. :param path: The directory path to add to ``sys.path``. @@ -425,6 +452,7 @@ def os_environ(*remove, **update): # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + @contextmanager def patch(target, new): """Simplified module monkey patching via context manager. @@ -434,7 +462,7 @@ def patch(target, new): """ def _import_module(target): - components = target.split('.') + components = target.split(".") import_path = components.pop(0) module = import_module(import_path) for comp in components: @@ -448,16 +476,16 @@ def patch(target, new): def _get_target(target): if isinstance(target, str): try: - module, attr = target.rsplit('.', 1) + module, attr = target.rsplit(".", 1) except (TypeError, ValueError): - raise TypeError(f'invalid target: {target!r}') + raise TypeError(f"invalid target: {target!r}") module = _import_module(module) return module, attr else: try: obj, attr = target except (TypeError, ValueError): - raise TypeError(f'invalid target: {target!r}') + raise TypeError(f"invalid target: {target!r}") return obj, attr obj, attr = _get_target(target) diff --git a/src/snakeoil/currying.py b/src/snakeoil/currying.py index f8d8d97..bc16dc0 100644 --- a/src/snakeoil/currying.py +++ b/src/snakeoil/currying.py @@ -46,14 +46,19 @@ def pre_curry(func, *args, **kwargs): """ if not kwargs: + def callit(*moreargs, **morekwargs): return func(*(args + moreargs), **morekwargs) + elif not args: + def callit(*moreargs, **morekwargs): kw = kwargs.copy() kw.update(morekwargs) return func(*moreargs, **kw) + else: + def callit(*moreargs, **morekwargs): kw = kwargs.copy() kw.update(morekwargs) @@ -67,14 +72,19 @@ def post_curry(func, *args, **kwargs): """passed in args are appended to any further args supplied""" if not kwargs: + def callit(*moreargs, **morekwargs): return func(*(moreargs + args), **morekwargs) + elif not args: + def callit(*moreargs, **morekwargs): kw = morekwargs.copy() kw.update(kwargs) return func(*moreargs, **kw) + else: + def callit(*moreargs, **morekwargs): kw = morekwargs.copy() kw.update(kwargs) @@ -112,18 +122,32 @@ def wrap_exception(recast_exception, *args, **kwds): # set this here so that 2to3 will rewrite it. try: if not issubclass(recast_exception, Exception): - raise ValueError("recast_exception must be an %s derivative: got %r" % - (Exception, recast_exception)) + raise ValueError( + "recast_exception must be an %s derivative: got %r" + % (Exception, recast_exception) + ) except TypeError as e: - raise TypeError("recast_exception must be an %s derivative; got %r, failed %r", - (Exception.__name__, recast_exception, e)) + raise TypeError( + "recast_exception must be an %s derivative; got %r, failed %r", + (Exception.__name__, recast_exception, e), + ) ignores = kwds.pop("ignores", (recast_exception,)) pass_error = kwds.pop("pass_error", None) - return wrap_exception_complex(partial(_simple_throw, recast_exception, args, kwds, pass_error), ignores) - - -def _simple_throw(recast_exception, recast_args, recast_kwds, pass_error, - exception, functor, args, kwds): + return wrap_exception_complex( + partial(_simple_throw, recast_exception, args, kwds, pass_error), ignores + ) + + +def _simple_throw( + recast_exception, + recast_args, + recast_kwds, + pass_error, + exception, + functor, + args, + kwds, +): if pass_error: recast_kwds[pass_error] = exception return recast_exception(*recast_args, **recast_kwds) @@ -131,15 +155,22 @@ def _simple_throw(recast_exception, recast_args, recast_kwds, pass_error, def wrap_exception_complex(creation_func, ignores): try: - if not hasattr(ignores, '__iter__') and issubclass(ignores, Exception) or ignores is Exception: + if ( + not hasattr(ignores, "__iter__") + and issubclass(ignores, Exception) + or ignores is Exception + ): ignores = (ignores,) ignores = tuple(ignores) except TypeError as e: - raise TypeError("ignores must be either a tuple of %s, or a %s: got %r, error %r" - % (Exception.__name__, Exception.__name__, ignores, e)) + raise TypeError( + "ignores must be either a tuple of %s, or a %s: got %r, error %r" + % (Exception.__name__, Exception.__name__, ignores, e) + ) if not all(issubclass(x, Exception) for x in ignores): - raise TypeError("ignores has a non %s derivative in it: %r" % - (Exception.__name__, ignores)) + raise TypeError( + "ignores has a non %s derivative in it: %r" % (Exception.__name__, ignores) + ) return partial(_inner_wrap_exception, creation_func, ignores) @@ -153,5 +184,6 @@ def _inner_wrap_exception(exception_maker, ignores, functor): raise except Exception as e: raise exception_maker(e, functor, args, kwargs) from e + _wrap_exception.func = functor return pretty_docs(_wrap_exception, name=functor.__name__) diff --git a/src/snakeoil/data_source.py b/src/snakeoil/data_source.py index 1faa540..3b0ccb1 100644 --- a/src/snakeoil/data_source.py +++ b/src/snakeoil/data_source.py @@ -33,8 +33,13 @@ we caught the exception. """ __all__ = ( - "base", "bz2_source", "data_source", "local_source", "text_data_source", - "bytes_data_source", "invokable_data_source", + "base", + "bz2_source", + "data_source", + "local_source", + "text_data_source", + "bytes_data_source", + "invokable_data_source", ) import errno @@ -62,10 +67,9 @@ def _mk_writable_cls(base, name): exceptions attribute """ - base_cls = base exceptions = (MemoryError,) - __slots__ = ('_callback',) + __slots__ = ("_callback",) def __init__(self, callback, data): """ @@ -85,6 +89,7 @@ def _mk_writable_cls(base, name): self._callback(self.read()) self._callback = None self.base_cls.close(self) + kls.__name__ = name return kls @@ -100,6 +105,7 @@ class text_ro_StringIO(stringio.text_readonly): Specifically this adds the necessary `exceptions` attribute; see :py:class:`snakeoil.stringio.text_readonly` for methods details. """ + __slots__ = () exceptions = (MemoryError, TypeError) @@ -111,6 +117,7 @@ class bytes_ro_StringIO(stringio.bytes_readonly): Specifically this adds the necessary `exceptions` attribute; see :py:class:`snakeoil.stringio.bytes_readonly` for methods details. """ + __slots__ = () exceptions = (MemoryError, TypeError) @@ -131,6 +138,7 @@ class base: :ivar path: If None, no local path is available- else it's the ondisk path to the data """ + __slots__ = ("weakref",) path = None @@ -155,7 +163,8 @@ class base: def transfer_to_path(self, path): return self.transfer_to_data_source( - local_source(path, mutable=True, encoding=None)) + local_source(path, mutable=True, encoding=None) + ) def transfer_to_data_source(self, write_source): read_f, m, write_f = None, None, None @@ -208,31 +217,32 @@ class local_source(base): raise TypeError("data source %s is immutable" % (self,)) if self.encoding: opener = open_file - opener = post_curry(opener, buffering=self.buffering_window, - encoding=self.encoding) + opener = post_curry( + opener, buffering=self.buffering_window, encoding=self.encoding + ) else: opener = post_curry(open_file, self.buffering_window) if not writable: - return opener(self.path, 'r') + return opener(self.path, "r") try: return opener(self.path, "r+") except IOError as ie: if ie.errno != errno.ENOENT: raise - return opener(self.path, 'w+') + return opener(self.path, "w+") @klass.steal_docs(base) def bytes_fileobj(self, writable=False): if not writable: - return open_file(self.path, 'rb', self.buffering_window) + return open_file(self.path, "rb", self.buffering_window) if not self.mutable: raise TypeError("data source %s is immutable" % (self,)) try: - return open_file(self.path, 'rb+', self.buffering_window) + return open_file(self.path, "rb+", self.buffering_window) except IOError as ie: if ie.errno != errno.ENOENT: raise - return open_file(self.path, 'wb+', self.buffering_window) + return open_file(self.path, "wb+", self.buffering_window) class bz2_source(base): @@ -255,7 +265,8 @@ class bz2_source(base): def text_fileobj(self, writable=False): data = compression.decompress_data( - 'bzip2', fileutils.readfile_bytes(self.path)).decode() + "bzip2", fileutils.readfile_bytes(self.path) + ).decode() if writable: if not self.mutable: raise TypeError(f"data source {self} is not mutable") @@ -263,8 +274,7 @@ class bz2_source(base): return text_ro_StringIO(data) def bytes_fileobj(self, writable=False): - data = compression.decompress_data( - 'bzip2', fileutils.readfile_bytes(self.path)) + data = compression.decompress_data("bzip2", fileutils.readfile_bytes(self.path)) if writable: if not self.mutable: raise TypeError(f"data source {self} is not mutable") @@ -275,7 +285,7 @@ class bz2_source(base): if isinstance(data, str): data = data.encode() with open(self.path, "wb") as f: - f.write(compression.compress_data('bzip2', data)) + f.write(compression.compress_data("bzip2", data)) class data_source(base): @@ -293,7 +303,7 @@ class data_source(base): :ivar path: note that path is None for this class- no on disk location available. """ - __slots__ = ('data', 'mutable') + __slots__ = ("data", "mutable") def __init__(self, data, mutable=False): """ @@ -305,7 +315,7 @@ class data_source(base): self.mutable = mutable def _convert_data(self, mode): - if mode == 'bytes': + if mode == "bytes": if isinstance(self.data, bytes): return self.data return self.data.encode() @@ -318,9 +328,8 @@ class data_source(base): if writable: if not self.mutable: raise TypeError(f"data source {self} is not mutable") - return text_wr_StringIO(self._reset_data, - self._convert_data('text')) - return text_ro_StringIO(self._convert_data('text')) + return text_wr_StringIO(self._reset_data, self._convert_data("text")) + return text_ro_StringIO(self._convert_data("text")) def _reset_data(self, data): if isinstance(self.data, bytes): @@ -335,9 +344,8 @@ class data_source(base): if writable: if not self.mutable: raise TypeError(f"data source {self} is not mutable") - return bytes_wr_StringIO(self._reset_data, - self._convert_data('bytes')) - return bytes_ro_StringIO(self._convert_data('bytes')) + return bytes_wr_StringIO(self._reset_data, self._convert_data("bytes")) + return bytes_ro_StringIO(self._convert_data("bytes")) class text_data_source(data_source): @@ -355,7 +363,7 @@ class text_data_source(data_source): data_source.__init__(self, data, mutable=mutable) def _convert_data(self, mode): - if mode != 'bytes': + if mode != "bytes": return self.data return self.data.encode() @@ -375,7 +383,7 @@ class bytes_data_source(data_source): data_source.__init__(self, data, mutable=mutable) def _convert_data(self, mode): - if mode == 'bytes': + if mode == "bytes": return self.data return self.data.decode() @@ -390,6 +398,7 @@ class invokable_data_source(data_source): Note that this instance is explicitly readonly. """ + __slots__ = () def __init__(self, data): @@ -412,7 +421,9 @@ class invokable_data_source(data_source): return self.data(False) @classmethod - def wrap_function(cls, invokable, returns_text=True, returns_handle=False, encoding_hint=None): + def wrap_function( + cls, invokable, returns_text=True, returns_handle=False, encoding_hint=None + ): """ Helper function to automatically convert a function that returns text or bytes into appropriate callable @@ -425,10 +436,20 @@ class invokable_data_source(data_source): :param encoding_hint: the preferred encoding to use for encoding :return: invokable_data_source instance """ - return cls(partial(cls._simple_wrapper, invokable, encoding_hint, returns_text, returns_handle)) + return cls( + partial( + cls._simple_wrapper, + invokable, + encoding_hint, + returns_text, + returns_handle, + ) + ) @staticmethod - def _simple_wrapper(invokable, encoding_hint, returns_text, returns_handle, text_wanted): + def _simple_wrapper( + invokable, encoding_hint, returns_text, returns_handle, text_wanted + ): data = invokable() if returns_text != text_wanted: if text_wanted: @@ -446,7 +467,7 @@ class invokable_data_source(data_source): data = data.read() if encoding_hint is None: # fallback to utf8 - encoding_hint = 'utf8' + encoding_hint = "utf8" data = data.encode(encoding_hint) elif returns_handle: return data diff --git a/src/snakeoil/decorators.py b/src/snakeoil/decorators.py index 550d19c..c0e0429 100644 --- a/src/snakeoil/decorators.py +++ b/src/snakeoil/decorators.py @@ -7,29 +7,36 @@ from .contexts import Namespace, SplitExec def splitexec(func): """Run the decorated function in another process.""" + @wraps(func) def wrapper(*args, **kwargs): with SplitExec(): return func(*args, **kwargs) + return wrapper def namespace(**namespaces): """Run the decorated function in a specified namespace.""" + def decorator(func): @wraps(func) def wrapper(*args, **kwargs): with Namespace(**namespaces): return func(*args, **kwargs) + return wrapper + return decorator def coroutine(func): """Prime a coroutine for input.""" + @wraps(func) def prime(*args, **kwargs): cr = func(*args, **kwargs) next(cr) return cr + return prime diff --git a/src/snakeoil/demandimport.py b/src/snakeoil/demandimport.py index c87e114..0a24ccc 100644 --- a/src/snakeoil/demandimport.py +++ b/src/snakeoil/demandimport.py @@ -13,14 +13,16 @@ from importlib.util import LazyLoader _disabled = False # modules that have issues when lazily imported -_skip = frozenset([ - '__builtin__', - '__future__', - 'builtins', - 'grp', - 'pwd', - 'OpenSSL.SSL', # pyopenssl -]) +_skip = frozenset( + [ + "__builtin__", + "__future__", + "builtins", + "grp", + "pwd", + "OpenSSL.SSL", # pyopenssl + ] +) class _LazyLoader(LazyLoader): @@ -35,10 +37,8 @@ class _LazyLoader(LazyLoader): # custom loaders using our extended LazyLoader -_extensions_loader = _LazyLoader.factory( - importlib.machinery.ExtensionFileLoader) -_bytecode_loader = _LazyLoader.factory( - importlib.machinery.SourcelessFileLoader) +_extensions_loader = _LazyLoader.factory(importlib.machinery.ExtensionFileLoader) +_bytecode_loader = _LazyLoader.factory(importlib.machinery.SourcelessFileLoader) _source_loader = _LazyLoader.factory(importlib.machinery.SourceFileLoader) @@ -54,7 +54,11 @@ def _filefinder(path): def enable(): """Enable lazy loading for all future module imports.""" - if os.environ.get('SNAKEOIL_DEMANDIMPORT', 'y').lower() not in ('n', 'no' '0', 'false'): + if os.environ.get("SNAKEOIL_DEMANDIMPORT", "y").lower() not in ( + "n", + "no" "0", + "false", + ): sys.path_hooks.insert(0, _filefinder) diff --git a/src/snakeoil/demandload.py b/src/snakeoil/demandload.py index 0c84e8c..3800622 100644 --- a/src/snakeoil/demandload.py +++ b/src/snakeoil/demandload.py @@ -49,8 +49,9 @@ from .modules import load_any # There are some demandloaded imports below the definition of demandload. -_allowed_chars = "".join((x.isalnum() or x in "_.") and " " or "a" - for x in map(chr, range(256))) +_allowed_chars = "".join( + (x.isalnum() or x in "_.") and " " or "a" for x in map(chr, range(256)) +) def parse_imports(imports): @@ -72,15 +73,16 @@ def parse_imports(imports): :rtype: iterable of tuples of two C{str} objects. """ for s in imports: - fromlist = s.split(':', 1) + fromlist = s.split(":", 1) if len(fromlist) == 1: # Not a "from" import. - if '.' in s: + if "." in s: raise ValueError( "dotted imports are disallowed; see " "snakeoil.demandload docstring for " - f"details; {s!r}") - split = s.split('@', 1) + f"details; {s!r}" + ) + split = s.split("@", 1) for s in split: if not s.translate(_allowed_chars).isspace(): raise ValueError(f"bad target: {s}") @@ -94,28 +96,33 @@ def parse_imports(imports): base, targets = fromlist if not base.translate(_allowed_chars).isspace(): raise ValueError(f"bad target: {base}") - for target in targets.split(','): - split = target.split('@', 1) + for target in targets.split(","): + split = target.split("@", 1) for s in split: if not s.translate(_allowed_chars).isspace(): raise ValueError(f"bad target: {s}") - yield base + '.' + split[0], split[-1] + yield base + "." + split[0], split[-1] + def _protection_enabled_disabled(): return False + def _noisy_protection_disabled(): return False + def _protection_enabled_enabled(): val = os.environ.get("SNAKEOIL_DEMANDLOAD_PROTECTION", "n").lower() return val in ("yes", "true", "1", "y") + def _noisy_protection_enabled(): val = os.environ.get("SNAKEOIL_DEMANDLOAD_WARN", "y").lower() return val in ("yes", "true", "1", "y") -if 'pydoc' in sys.modules or 'epydoc' in sys.modules: + +if "pydoc" in sys.modules or "epydoc" in sys.modules: _protection_enabled = _protection_enabled_disabled _noisy_protection = _noisy_protection_disabled else: @@ -164,15 +171,15 @@ class Placeholder: """ if not callable(load_func): raise TypeError(f"load_func must be callable; got {load_func!r}") - object.__setattr__(self, '_scope', scope) - object.__setattr__(self, '_name', name) - object.__setattr__(self, '_replacing_tids', []) - object.__setattr__(self, '_load_func', load_func) - object.__setattr__(self, '_loading_lock', threading.Lock()) + object.__setattr__(self, "_scope", scope) + object.__setattr__(self, "_name", name) + object.__setattr__(self, "_replacing_tids", []) + object.__setattr__(self, "_load_func", load_func) + object.__setattr__(self, "_loading_lock", threading.Lock()) def _target_already_loaded(self, complain=True): - name = object.__getattribute__(self, '_name') - scope = object.__getattribute__(self, '_scope') + name = object.__getattribute__(self, "_name") + scope = object.__getattribute__(self, "_scope") # in a threaded environment, it's possible for tid1 to get the # placeholder from globals, python switches to tid2, which triggers @@ -188,13 +195,16 @@ class Placeholder: # it's impossible for this pathway to accidentally be triggered twice- # meaning it is a misuse by the consuming client code. if complain: - tids_to_complain_about = object.__getattribute__(self, '_replacing_tids') + tids_to_complain_about = object.__getattribute__(self, "_replacing_tids") if threading.current_thread().ident in tids_to_complain_about: if _protection_enabled(): - raise ValueError(f'Placeholder for {name!r} was triggered twice') + raise ValueError(f"Placeholder for {name!r} was triggered twice") elif _noisy_protection(): - logging.warning('Placeholder for %r was triggered multiple times ' - 'in file %r', name, scope.get("__file__", "unknown")) + logging.warning( + "Placeholder for %r was triggered multiple times " "in file %r", + name, + scope.get("__file__", "unknown"), + ) return scope[name] def _get_target(self): @@ -202,9 +212,9 @@ class Placeholder: :return: the result of calling C{_load_func}. """ - preloaded_func = object.__getattribute__(self, '_target_already_loaded') - with object.__getattribute__(self, '_loading_lock'): - load_func = object.__getattribute__(self, '_load_func') + preloaded_func = object.__getattribute__(self, "_target_already_loaded") + with object.__getattribute__(self, "_loading_lock"): + load_func = object.__getattribute__(self, "_load_func") if load_func is None: # This means that there was contention; two threads made it into # _get_target. That's fine; suppress complaints, and return the @@ -215,18 +225,17 @@ class Placeholder: # fix the scope, and replace this method with one that shortcircuits # (and appropriately complains) the lookup. result = load_func() - scope = object.__getattribute__(self, '_scope') - name = object.__getattribute__(self, '_name') + scope = object.__getattribute__(self, "_scope") + name = object.__getattribute__(self, "_name") scope[name] = result # Replace this method with the fast path/preloaded one; this # is to ensure complaints get leveled if needed. - object.__setattr__(self, '_get_target', preloaded_func) - object.__setattr__(self, '_load_func', None) - + object.__setattr__(self, "_get_target", preloaded_func) + object.__setattr__(self, "_load_func", None) # note this step *has* to follow scope modification; else it # will go maximum depth recursion. - tids = object.__getattribute__(self, '_replacing_tids') + tids = object.__getattribute__(self, "_replacing_tids") tids.append(threading.current_thread().ident) return result @@ -237,18 +246,18 @@ class Placeholder: # Various methods proxied to our replacement. def __str__(self): - return self.__getattribute__('__str__')() + return self.__getattribute__("__str__")() def __getattribute__(self, attr): - result = object.__getattribute__(self, '_get_target')() + result = object.__getattribute__(self, "_get_target")() return getattr(result, attr) def __setattr__(self, attr, value): - result = object.__getattribute__(self, '_get_target')() + result = object.__getattribute__(self, "_get_target")() setattr(result, attr, value) def __call__(self, *args, **kwargs): - result = object.__getattribute__(self, '_get_target')() + result = object.__getattribute__(self, "_get_target")() return result(*args, **kwargs) @@ -267,7 +276,7 @@ def demandload(*imports, **kwargs): """ # pull the caller's global namespace if undefined - scope = kwargs.pop('scope', sys._getframe(1).f_globals) + scope = kwargs.pop("scope", sys._getframe(1).f_globals) for source, target in parse_imports(imports): scope[target] = Placeholder.load_namespace(scope, target, source) @@ -280,7 +289,7 @@ enabled_demandload = demandload def disabled_demandload(*imports, **kwargs): """Exactly like :py:func:`demandload` but does all imports immediately.""" - scope = kwargs.pop('scope', sys._getframe(1).f_globals) + scope = kwargs.pop("scope", sys._getframe(1).f_globals) for source, target in parse_imports(imports): scope[target] = load_any(source) @@ -292,21 +301,25 @@ def demand_compile_regexp(name, *args, **kwargs): :param name: the name of the compiled re object in that scope. """ - scope = kwargs.pop('scope', sys._getframe(1).f_globals) + scope = kwargs.pop("scope", sys._getframe(1).f_globals) scope[name] = Placeholder.load_regex(scope, name, *args, **kwargs) def disabled_demand_compile_regexp(name, *args, **kwargs): """Exactly like :py:func:`demand_compile_regexp` but does all imports immediately.""" - scope = kwargs.pop('scope', sys._getframe(1).f_globals) + scope = kwargs.pop("scope", sys._getframe(1).f_globals) scope[name] = re.compile(*args, **kwargs) -if os.environ.get("SNAKEOIL_DEMANDLOAD_DISABLED", 'n').lower() in ('y', 'yes' '1', 'true'): +if os.environ.get("SNAKEOIL_DEMANDLOAD_DISABLED", "n").lower() in ( + "y", + "yes" "1", + "true", +): demandload = disabled_demandload demand_compile_regexp = disabled_demand_compile_regexp demandload( - 'logging', - 're', + "logging", + "re", ) diff --git a/src/snakeoil/dependant_methods.py b/src/snakeoil/dependant_methods.py index 551af1e..031ad63 100644 --- a/src/snakeoil/dependant_methods.py +++ b/src/snakeoil/dependant_methods.py @@ -63,7 +63,7 @@ def _ensure_deps(cls_id, name, func, self, *a, **kw): s = _yield_deps(self, self.stage_depends, name) r = True - if not hasattr(self, '_stage_state'): + if not hasattr(self, "_stage_state"): self._stage_state = set() for dep in s: if dep not in self._stage_state: @@ -108,8 +108,8 @@ def __wrap_stage_dependencies__(cls): f = getattr(cls, x) except AttributeError: raise TypeError( - "class %r stage_depends specifies %r, which doesn't exist" % - (cls, x)) + "class %r stage_depends specifies %r, which doesn't exist" % (cls, x) + ) f2 = pre_curry(_ensure_deps, cls_id, x, f) f2.sd_raw_func = f setattr(cls, x, f2) @@ -122,9 +122,9 @@ def __unwrap_stage_dependencies__(cls): f = getattr(cls, x) except AttributeError: raise TypeError( - "class %r stage_depends specifies %r, which doesn't exist" % - (cls, x)) - setattr(cls, x, getattr(f, 'sd_raw_func', f)) + "class %r stage_depends specifies %r, which doesn't exist" % (cls, x) + ) + setattr(cls, x, getattr(f, "sd_raw_func", f)) def __set_stage_state__(self, state): @@ -165,17 +165,17 @@ class ForcedDepends(type): def __new__(cls, name, bases, d): obj = super(ForcedDepends, cls).__new__(cls, name, bases, d) - if not hasattr(obj, 'stage_depends'): + if not hasattr(obj, "stage_depends"): obj.stage_depends = {} for x in ("wrap", "unwrap"): - s = '__%s_stage_dependencies__' % x + s = "__%s_stage_dependencies__" % x if not hasattr(obj, s): setattr(obj, s, classmethod(globals()[s])) obj.__unwrap_stage_dependencies__() obj.__wrap_stage_dependencies__() - if not hasattr(obj, '__force_stage_state__'): + if not hasattr(obj, "__force_stage_state__"): obj.__set_stage_state__ = __set_stage_state__ - if not hasattr(obj, '__stage_step_callback__'): + if not hasattr(obj, "__stage_step_callback__"): obj.__stage_step_callback__ = __stage_step_callback__ return obj diff --git a/src/snakeoil/errors.py b/src/snakeoil/errors.py index 9b7b541..ab37348 100644 --- a/src/snakeoil/errors.py +++ b/src/snakeoil/errors.py @@ -12,20 +12,20 @@ def walk_exception_chain(exc, ignore_first=False, reverse=False): def _inner_walk_exception_chain(exc, ignore_first): if not ignore_first: yield exc - exc = getattr(exc, '__cause__', None) + exc = getattr(exc, "__cause__", None) while exc is not None: yield exc - exc = getattr(exc, '__cause__', None) + exc = getattr(exc, "__cause__", None) def dump_error(raw_exc, msg=None, handle=sys.stderr, tb=None): # force default output for exceptions - if getattr(handle, 'reset', False): + if getattr(handle, "reset", False): handle.write(handle.reset) - prefix = '' + prefix = "" if msg: - prefix = ' ' + prefix = " " handle.write(msg.rstrip("\n") + ":\n") if tb: handle.write("Traceback follows:\n") @@ -34,8 +34,8 @@ def dump_error(raw_exc, msg=None, handle=sys.stderr, tb=None): if raw_exc is not None: for exc in walk_exception_chain(raw_exc): exc_strings.extend( - prefix + x.strip() - for x in (x for x in str(exc).split("\n") if x)) + prefix + x.strip() for x in (x for x in str(exc).split("\n") if x) + ) if exc_strings: if msg and tb: handle.write(f"\n{raw_exc.__class__.__name__}:\n") diff --git a/src/snakeoil/fileutils.py b/src/snakeoil/fileutils.py index ed3a034..b0aa724 100644 --- a/src/snakeoil/fileutils.py +++ b/src/snakeoil/fileutils.py @@ -23,23 +23,28 @@ def touch(fname, mode=0o644, **kwargs): See os.utime for other supported arguments. """ flags = os.O_CREAT | os.O_APPEND - dir_fd = kwargs.get('dir_fd', None) + dir_fd = kwargs.get("dir_fd", None) os_open = partial(os.open, dir_fd=dir_fd) with os.fdopen(os_open(fname, flags, mode)) as f: os.utime( f.fileno() if os.utime in os.supports_fd else fname, - dir_fd=None if os.supports_fd else dir_fd, **kwargs) + dir_fd=None if os.supports_fd else dir_fd, + **kwargs + ) + def mmap_or_open_for_read(path): size = os.stat(path).st_size if size == 0: - return (None, data_source.bytes_ro_StringIO(b'')) + return (None, data_source.bytes_ro_StringIO(b"")) fd = None try: fd = os.open(path, os.O_RDONLY) - return (_fileutils.mmap_and_close( - fd, size, mmap.MAP_SHARED, mmap.PROT_READ), None) + return ( + _fileutils.mmap_and_close(fd, size, mmap.MAP_SHARED, mmap.PROT_READ), + None, + ) except IGNORED_EXCEPTIONS: raise except: @@ -85,7 +90,8 @@ class AtomicWriteFile_mixin: fp = os.path.realpath(fp) self._original_fp = fp self._temp_fp = os.path.join( - os.path.dirname(fp), ".update." + os.path.basename(fp)) + os.path.dirname(fp), ".update." + os.path.basename(fp) + ) old_umask = None if perms: # give it just write perms @@ -140,7 +146,7 @@ class AtomicWriteFile(AtomicWriteFile_mixin): self.raw = open(self._temp_fp, mode=self._computed_mode) def _real_close(self): - if hasattr(self, 'raw'): + if hasattr(self, "raw"): return self.raw.close() return None @@ -149,24 +155,23 @@ class AtomicWriteFile(AtomicWriteFile_mixin): def _mk_pretty_derived_func(func, name_base, name, *args, **kwds): if name: - name = '_' + name - return pretty_docs(partial(func, *args, **kwds), - name='%s%s' % (name_base, name)) + name = "_" + name + return pretty_docs(partial(func, *args, **kwds), name="%s%s" % (name_base, name)) -_mk_readfile = partial( - _mk_pretty_derived_func, _fileutils.native_readfile, 'readfile') +_mk_readfile = partial(_mk_pretty_derived_func, _fileutils.native_readfile, "readfile") -readfile_ascii = _mk_readfile('ascii', 'rt') -readfile_bytes = _mk_readfile('bytes', 'rb') -readfile_utf8 = _mk_readfile('utf8', 'r', encoding='utf8') +readfile_ascii = _mk_readfile("ascii", "rt") +readfile_bytes = _mk_readfile("bytes", "rb") +readfile_utf8 = _mk_readfile("utf8", "r", encoding="utf8") readfile = readfile_utf8 _mk_readlines = partial( - _mk_pretty_derived_func, _fileutils.native_readlines, 'readlines') + _mk_pretty_derived_func, _fileutils.native_readlines, "readlines" +) -readlines_ascii = _mk_readlines('ascii', 'r', encoding='ascii') -readlines_bytes = _mk_readlines('bytes', 'rb') -readlines_utf8 = _mk_readlines('utf8', 'r', encoding='utf8') +readlines_ascii = _mk_readlines("ascii", "r", encoding="ascii") +readlines_bytes = _mk_readlines("bytes", "rb") +readlines_utf8 = _mk_readlines("utf8", "r", encoding="utf8") readlines = readlines_utf8 diff --git a/src/snakeoil/formatters.py b/src/snakeoil/formatters.py index 397667f..3faf661 100644 --- a/src/snakeoil/formatters.py +++ b/src/snakeoil/formatters.py @@ -10,7 +10,9 @@ from .klass import GetAttrProxy, steal_docs from .mappings import defaultdictkey __all__ = ( - "Formatter", "PlainTextFormatter", "get_formatter", + "Formatter", + "PlainTextFormatter", + "get_formatter", "decorate_forced_wrapping", ) @@ -98,13 +100,11 @@ class Formatter: def error(self, message): """Format a string as an error message.""" - self.write(message, prefixes=( - self.fg('red'), self.bold, '!!! ', self.reset)) + self.write(message, prefixes=(self.fg("red"), self.bold, "!!! ", self.reset)) def warn(self, message): """Format a string as a warning message.""" - self.write(message, prefixes=( - self.fg('yellow'), self.bold, '*** ', self.reset)) + self.write(message, prefixes=(self.fg("yellow"), self.bold, "*** ", self.reset)) def title(self, string): """Set the title to string""" @@ -123,7 +123,7 @@ class PlainTextFormatter(Formatter): every write. """ - bold = underline = reset = '' + bold = underline = reset = "" def __init__(self, stream, width=79, encoding=None): """Initialize. @@ -144,12 +144,12 @@ class PlainTextFormatter(Formatter): else: self.stream = stream if encoding is None: - encoding = getattr(self.stream, 'encoding', None) + encoding = getattr(self.stream, "encoding", None) if encoding is None: try: encoding = locale.getpreferredencoding() except locale.Error: - encoding = 'ascii' + encoding = "ascii" self.encoding = encoding self.width = width self._pos = 0 @@ -162,7 +162,7 @@ class PlainTextFormatter(Formatter): return True def _force_encoding(self, val): - return val.encode(self.encoding, 'replace') + return val.encode(self.encoding, "replace") def _write_prefix(self, wrap): if self._in_first_line: @@ -190,34 +190,32 @@ class PlainTextFormatter(Formatter): @steal_docs(Formatter) def write(self, *args, **kwargs): - wrap = kwargs.get('wrap', self.wrap) - autoline = kwargs.get('autoline', self.autoline) - prefixes = kwargs.get('prefixes') - first_prefixes = kwargs.get('first_prefixes') - later_prefixes = kwargs.get('later_prefixes') + wrap = kwargs.get("wrap", self.wrap) + autoline = kwargs.get("autoline", self.autoline) + prefixes = kwargs.get("prefixes") + first_prefixes = kwargs.get("first_prefixes") + later_prefixes = kwargs.get("later_prefixes") if prefixes is not None: if first_prefixes is not None or later_prefixes is not None: raise TypeError( - 'do not pass first_prefixes or later_prefixes ' - 'if prefixes is passed') + "do not pass first_prefixes or later_prefixes " + "if prefixes is passed" + ) first_prefixes = later_prefixes = prefixes - prefix = kwargs.get('prefix') - first_prefix = kwargs.get('first_prefix') - later_prefix = kwargs.get('later_prefix') + prefix = kwargs.get("prefix") + first_prefix = kwargs.get("first_prefix") + later_prefix = kwargs.get("later_prefix") if prefix is not None: if first_prefix is not None or later_prefix is not None: - raise TypeError( - 'do not pass first_prefix or later_prefix with prefix') + raise TypeError("do not pass first_prefix or later_prefix with prefix") first_prefix = later_prefix = prefix if first_prefix is not None: if first_prefixes is not None: - raise TypeError( - 'do not pass both first_prefix and first_prefixes') + raise TypeError("do not pass both first_prefix and first_prefixes") first_prefixes = (first_prefix,) if later_prefix is not None: if later_prefixes is not None: - raise TypeError( - 'do not pass both later_prefix and later_prefixes') + raise TypeError("do not pass both later_prefix and later_prefixes") later_prefixes = (later_prefix,) if first_prefixes is not None: self.first_prefix.extend(first_prefixes) @@ -242,7 +240,7 @@ class PlainTextFormatter(Formatter): while wrap and self._pos + len(arg) > self.width: # We have to split. maxlen = self.width - self._pos - space = arg.rfind(' ', 0, maxlen) + space = arg.rfind(" ", 0, maxlen) if space == -1: # No space to split on. @@ -254,7 +252,7 @@ class PlainTextFormatter(Formatter): # written something we can also go to the next # line. if self._in_first_line or self._wrote_something: - bit = '' + bit = "" else: # Forcibly split this as far to the right as # possible. @@ -263,11 +261,11 @@ class PlainTextFormatter(Formatter): else: bit = arg[:space] # Omit the space we split on. - arg = arg[space + 1:] + arg = arg[space + 1 :] if conversion_needed: bit = self._force_encoding(bit) self.stream.write(bit) - self.stream.write(self._force_encoding('\n')) + self.stream.write(self._force_encoding("\n")) self._pos = 0 self._in_first_line = False self._wrote_something = False @@ -280,7 +278,7 @@ class PlainTextFormatter(Formatter): arg = self._force_encoding(arg) self.stream.write(arg) if autoline: - self.stream.write(self._force_encoding('\n')) + self.stream.write(self._force_encoding("\n")) self._wrote_something = False self._pos = 0 self._in_first_line = True @@ -290,32 +288,28 @@ class PlainTextFormatter(Formatter): raise finally: if first_prefixes is not None: - self.first_prefix = self.first_prefix[:-len(first_prefixes)] + self.first_prefix = self.first_prefix[: -len(first_prefixes)] if later_prefixes is not None: - self.later_prefix = self.later_prefix[:-len(later_prefixes)] + self.later_prefix = self.later_prefix[: -len(later_prefixes)] def fg(self, color=None): """change fg color Compatibility method- no coloring escapes are returned from it. """ - return '' + return "" def bg(self, color=None): """change bg color Compatibility method- no coloring escapes are returned from it. """ - return '' + return "" def flush(self): self.stream.flush() - - - - class TerminfoDisabled(Exception): """Raised if Terminfo is disabled.""" @@ -331,7 +325,7 @@ class TerminfoUnsupported(Exception): self.term = term def __str__(self): - return f'unsupported terminal type: {self.term!r}' + return f"unsupported terminal type: {self.term!r}" # This is necessary because the curses module is optional (and we @@ -341,6 +335,7 @@ try: except ImportError: TerminfoColor = None else: + class TerminfoColor: """Class encapsulating a specific terminfo entry for a color. @@ -351,8 +346,8 @@ else: __slots__ = ("mode", "color", "__weakref__") def __init__(self, mode, color): - object.__setattr__(self, 'mode', mode) - object.__setattr__(self, 'color', color) + object.__setattr__(self, "mode", mode) + object.__setattr__(self, "color", color) def __call__(self, formatter): if self.color is None: @@ -374,7 +369,7 @@ else: if template: res = curses.tparm(template, color) else: - res = b'' + res = b"" formatter._current_colors[self.mode] = res formatter.stream.write(res) @@ -393,7 +388,7 @@ else: def __init__(self, value): if value is None: raise _BogusTerminfo() - object.__setattr__(self, 'value', value) + object.__setattr__(self, "value", value) def __setattr__(self, key, value): raise AttributeError(f"{self.__class__.__name__} instances are immutable") @@ -441,33 +436,32 @@ else: super().__init__(stream, encoding=encoding) fd = stream.fileno() if term is None: - if term := os.environ.get('TERM'): + if term := os.environ.get("TERM"): try: curses.setupterm(fd=fd, term=term) except curses.error: pass else: - raise TerminfoDisabled('no terminfo entries') + raise TerminfoDisabled("no terminfo entries") else: # TODO maybe do something more useful than raising curses.error # if term is not in the terminfo db here? curses.setupterm(fd=fd, term=term) self._term = term - self.width = curses.tigetnum('cols') + self.width = curses.tigetnum("cols") try: - self.reset = TerminfoReset(curses.tigetstr('sgr0')) - self.bold = TerminfoMode(curses.tigetstr('bold')) - self.underline = TerminfoMode(curses.tigetstr('smul')) - self._color_reset = curses.tigetstr('op') - self._set_color = ( - curses.tigetstr('setaf'), - curses.tigetstr('setab')) + self.reset = TerminfoReset(curses.tigetstr("sgr0")) + self.bold = TerminfoMode(curses.tigetstr("bold")) + self.underline = TerminfoMode(curses.tigetstr("smul")) + self._color_reset = curses.tigetstr("op") + self._set_color = (curses.tigetstr("setaf"), curses.tigetstr("setab")) except (_BogusTerminfo, curses.error) as e: raise TerminfoUnsupported(self._term) from e if not all(self._set_color): raise TerminfoDisabled( - 'setting background/foreground colors is not supported') + "setting background/foreground colors is not supported" + ) curses.tparm(self._set_color[0], curses.COLOR_WHITE) @@ -507,16 +501,14 @@ else: # not set the hs flag. So just check for the ability to # jump to and out of the status line, without checking if # the status line we're using exists. - tsl = curses.tigetstr('tsl') - fsl = curses.tigetstr('fsl') + tsl = curses.tigetstr("tsl") + fsl = curses.tigetstr("fsl") if tsl and fsl: - self.stream.write( - tsl + string.encode(self.encoding, 'replace') + fsl) + self.stream.write(tsl + string.encode(self.encoding, "replace") + fsl) self.stream.flush() class ObserverFormatter: - def __init__(self, real_formatter): self._formatter = real_formatter @@ -542,7 +534,7 @@ def get_formatter(stream, force_color=False): # needs an fd to pass to curses, not just a filelike talking to a tty. if os.isatty(fd) or force_color: try: - term = 'ansi' if force_color else None + term = "ansi" if force_color else None return TerminfoFormatter(stream, term=term) except (curses.error, TerminfoDisabled, TerminfoUnsupported): # This happens if TERM is unset and possibly in more cases. @@ -553,6 +545,7 @@ def get_formatter(stream, force_color=False): def decorate_forced_wrapping(setting=True): """Decorator to force a specific line wrapping state for the duration of invocation.""" + def wrapped_func(func): def f(out, *args, **kwds): oldwrap = out.wrap @@ -561,5 +554,7 @@ def decorate_forced_wrapping(setting=True): return func(out, *args, **kwds) finally: out.wrap = oldwrap + return f + return wrapped_func diff --git a/src/snakeoil/iterables.py b/src/snakeoil/iterables.py index 787af67..7176a2b 100644 --- a/src/snakeoil/iterables.py +++ b/src/snakeoil/iterables.py @@ -19,8 +19,7 @@ def partition(iterable, predicate=bool): filter and the second the matched items. """ a, b = itertools.tee((predicate(x), x) for x in iterable) - return ((x for pred, x in a if not pred), - (x for pred, x in b if pred)) + return ((x for pred, x in a if not pred), (x for pred, x in b if pred)) class expandable_chain: @@ -107,6 +106,7 @@ class caching_iter: 3 """ + __slots__ = ("iterable", "__weakref__", "cached_list", "sorter") def __init__(self, iterable, sorter=None): @@ -139,7 +139,7 @@ class caching_iter: if self.iterable is not None: i = itertools.islice(self.iterable, 0, index - (existing_len - 1)) self.cached_list.extend(i) - if len(self.cached_list) -1 != index: + if len(self.cached_list) - 1 != index: # consumed, baby. self.iterable = None self.cached_list = tuple(self.cached_list) @@ -209,8 +209,7 @@ class caching_iter: return len(self.cached_list) def __iter__(self): - if (self.sorter is not None and - self.iterable is not None): + if self.sorter is not None and self.iterable is not None: if self.cached_list: self.cached_list.extend(self.iterable) self.cached_list = tuple(self.sorter(self.cached_list)) @@ -237,8 +236,7 @@ class caching_iter: return hash(self.cached_list) def __str__(self): - return "iterable(%s), cached: %s" % ( - self.iterable, str(self.cached_list)) + return "iterable(%s), cached: %s" % (self.iterable, str(self.cached_list)) def iter_sort(sorter, *iterables): diff --git a/src/snakeoil/klass.py b/src/snakeoil/klass.py index 0e59258..23d6d3c 100644 --- a/src/snakeoil/klass.py +++ b/src/snakeoil/klass.py @@ -7,12 +7,27 @@ involved in writing classes. """ __all__ = ( - "generic_equality", "reflective_hash", "inject_richcmp_methods_from_cmp", - "static_attrgetter", "instance_attrgetter", "jit_attr", "jit_attr_none", - "jit_attr_named", "jit_attr_ext_method", "alias_attr", "cached_hash", - "cached_property", "cached_property_named", - "steal_docs", "immutable_instance", "inject_immutable_instance", - "alias_method", "aliased", "alias", "patch", "SlotsPicklingMixin", + "generic_equality", + "reflective_hash", + "inject_richcmp_methods_from_cmp", + "static_attrgetter", + "instance_attrgetter", + "jit_attr", + "jit_attr_none", + "jit_attr_named", + "jit_attr_ext_method", + "alias_attr", + "cached_hash", + "cached_property", + "cached_property_named", + "steal_docs", + "immutable_instance", + "inject_immutable_instance", + "alias_method", + "aliased", + "alias", + "patch", + "SlotsPicklingMixin", ) import inspect @@ -32,6 +47,7 @@ sentinel = object() def GetAttrProxy(target): def reflected_getattr(self, attr): return getattr(object.__getattribute__(self, target), attr) + return reflected_getattr @@ -43,6 +59,7 @@ def DirProxy(target): except AttributeError: attrs.extend(obj.__slots__) return sorted(set(attrs)) + return combined_dir @@ -69,6 +86,8 @@ def get(self, key, default=None): _attrlist_getter = attrgetter("__attr_comparison__") + + def generic_attr_eq(inst1, inst2): """ compare inst1 to inst2, returning True if equal, False if not. @@ -78,8 +97,7 @@ def generic_attr_eq(inst1, inst2): if inst1 is inst2: return True for attr in _attrlist_getter(inst1): - if getattr(inst1, attr, sentinel) != \ - getattr(inst2, attr, sentinel): + if getattr(inst1, attr, sentinel) != getattr(inst2, attr, sentinel): return False return True @@ -105,28 +123,36 @@ def reflective_hash(attr): :param attr: attribute name to pull the hash from on the instance :return: hash value for instance this func is used in. """ + def __hash__(self): return getattr(self, attr) + return __hash__ + def _internal_jit_attr( - func, attr_name, singleton=None, - use_cls_setattr=False, use_singleton=True, doc=None): + func, attr_name, singleton=None, use_cls_setattr=False, use_singleton=True, doc=None +): """Object implementing the descriptor protocol for use in Just In Time access to attributes. Consumers should likely be using the :py:func:`jit_func` line of helper functions instead of directly consuming this. """ - doc = getattr(func, '__doc__', None) if doc is None else doc + doc = getattr(func, "__doc__", None) if doc is None else doc class _internal_jit_attr(_raw_internal_jit_attr): __doc__ = doc __slots__ = () + kls = _internal_jit_attr return kls( - func, attr_name, singleton=singleton, use_cls_setattr=use_cls_setattr, - use_singleton=use_singleton) + func, + attr_name, + singleton=singleton, + use_cls_setattr=use_cls_setattr, + use_singleton=use_singleton, + ) class _raw_internal_jit_attr: @@ -134,8 +160,9 @@ class _raw_internal_jit_attr: __slots__ = ("storage_attr", "function", "_setter", "singleton", "use_singleton") - def __init__(self, func, attr_name, singleton=None, - use_cls_setattr=False, use_singleton=True): + def __init__( + self, func, attr_name, singleton=None, use_cls_setattr=False, use_singleton=True + ): """ :param func: function to invoke upon first request for this content :param attr_name: attribute name to store the generated value in @@ -178,8 +205,9 @@ class _raw_internal_jit_attr: return obj -def generic_equality(name, bases, scope, real_type=type, - eq=generic_attr_eq, ne=generic_attr_ne): +def generic_equality( + name, bases, scope, real_type=type, eq=generic_attr_eq, ne=generic_attr_ne +): """ metaclass generating __eq__/__ne__ methods from an attribute list @@ -208,7 +236,9 @@ def generic_equality(name, bases, scope, real_type=type, attrlist = scope[attrlist] for x in attrlist: if not isinstance(x, str): - raise TypeError(f"all members of attrlist must be strings- got {type(x)!r} {x!r}") + raise TypeError( + f"all members of attrlist must be strings- got {type(x)!r} {x!r}" + ) scope["__attr_comparison__"] = tuple(attrlist) scope.setdefault("__eq__", eq) @@ -285,9 +315,14 @@ def inject_richcmp_methods_from_cmp(scope): :param scope: the modifiable scope of a class namespace to work on """ - for key, func in (("__lt__", generic_lt), ("__le__", generic_le), - ("__eq__", generic_eq), ("__ne__", generic_ne), - ("__ge__", generic_ge), ("__gt__", generic_gt)): + for key, func in ( + ("__lt__", generic_lt), + ("__le__", generic_le), + ("__eq__", generic_eq), + ("__ne__", generic_ne), + ("__ge__", generic_ge), + ("__gt__", generic_gt), + ): scope.setdefault(key, func) @@ -329,7 +364,8 @@ class chained_getter(metaclass=partial(generic_equality, real_type=WeakInstMeta) >>> print(o.recursive == foo.seq.__hash__) True """ - __slots__ = ('namespace', 'getter') + + __slots__ = ("namespace", "getter") __fifo_cache__ = deque() __inst_caching__ = True __attr_comparison__ = ("namespace",) @@ -361,16 +397,20 @@ instance_attrgetter = chained_getter # this annoyingly means our docs have to be recommitted every change, # even if no real code changed (since the id() continually moves)... class _singleton_kls: - def __str__(self): return "uncached singleton instance" _uncached_singleton = _singleton_kls -T = typing.TypeVar('T') +T = typing.TypeVar("T") + -def jit_attr(func: typing.Callable[[typing.Any], T], kls=_internal_jit_attr, uncached_val: typing.Any=_uncached_singleton) -> T: +def jit_attr( + func: typing.Callable[[typing.Any], T], + kls=_internal_jit_attr, + uncached_val: typing.Any = _uncached_singleton, +) -> T: """ decorator to JIT generate, and cache the wrapped functions result in '_' + func.__name__ on the instance. @@ -399,8 +439,13 @@ def jit_attr_none(func: typing.Callable[[typing.Any], T], kls=_internal_jit_attr return jit_attr(func, kls=kls, uncached_val=None) -def jit_attr_named(stored_attr_name: str, use_cls_setattr=False, kls=_internal_jit_attr, - uncached_val: typing.Any=_uncached_singleton, doc=None): +def jit_attr_named( + stored_attr_name: str, + use_cls_setattr=False, + kls=_internal_jit_attr, + uncached_val: typing.Any = _uncached_singleton, + doc=None, +): """ Version of :py:func:`jit_attr` decorator that allows for explicit control over the attribute name used to store the cache value. @@ -410,9 +455,14 @@ def jit_attr_named(stored_attr_name: str, use_cls_setattr=False, kls=_internal_j return post_curry(kls, stored_attr_name, uncached_val, use_cls_setattr, doc=doc) -def jit_attr_ext_method(func_name: str, stored_attr_name: str, - use_cls_setattr=False, kls=_internal_jit_attr, - uncached_val: typing.Any=_uncached_singleton, doc=None): +def jit_attr_ext_method( + func_name: str, + stored_attr_name: str, + use_cls_setattr=False, + kls=_internal_jit_attr, + uncached_val: typing.Any = _uncached_singleton, + doc=None, +): """ Decorator handing maximal control of attribute JIT'ing to the invoker. @@ -421,11 +471,20 @@ def jit_attr_ext_method(func_name: str, stored_attr_name: str, Generally speaking, you only need this when you are doing something rather *special*. """ - return kls(alias_method(func_name), stored_attr_name, - uncached_val, use_cls_setattr, doc=doc) + return kls( + alias_method(func_name), + stored_attr_name, + uncached_val, + use_cls_setattr, + doc=doc, + ) -def cached_property(func: typing.Callable[[typing.Any], T], kls=_internal_jit_attr, use_cls_setattr=False) -> T: +def cached_property( + func: typing.Callable[[typing.Any], T], + kls=_internal_jit_attr, + use_cls_setattr=False, +) -> T: """ like `property`, just with caching @@ -454,8 +513,9 @@ def cached_property(func: typing.Callable[[typing.Any], T], kls=_internal_jit_at >>> print(obj.attr) 1 """ - return kls(func, func.__name__, None, use_singleton=False, - use_cls_setattr=use_cls_setattr) + return kls( + func, func.__name__, None, use_singleton=False, use_cls_setattr=use_cls_setattr + ) def cached_property_named(name: str, kls=_internal_jit_attr, use_cls_setattr=False): @@ -538,11 +598,13 @@ def cached_hash(func): >>> assert hash(f) == 12345 # note we still get the same value >>> assert f.hash_invocations == 1 # and that the function was invoked only once. """ + def __hash__(self): - val = getattr(self, '_hash', None) + val = getattr(self, "_hash", None) if val is None: - object.__setattr__(self, '_hash', val := func(self)) + object.__setattr__(self, "_hash", val := func(self)) return val + return __hash__ @@ -574,6 +636,7 @@ def steal_docs(target, ignore_missing=False, name=None): >>> f = foo([1,2,3]) >>> assert f.extend.__doc__ == list.extend.__doc__ """ + def inner(functor): if inspect.isclass(target): if name is not None: @@ -590,6 +653,7 @@ def steal_docs(target, ignore_missing=False, name=None): obj = target functor.__doc__ = obj.__doc__ return functor + return inner @@ -611,7 +675,7 @@ def patch(target, external_decorator=None): """ def _import_module(target): - components = target.split('.') + components = target.split(".") import_path = components.pop(0) module = import_module(import_path) for comp in components: @@ -624,7 +688,7 @@ def patch(target, external_decorator=None): def _get_target(target): try: - module, attr = target.rsplit('.', 1) + module, attr = target.rsplit(".", 1) except (TypeError, ValueError): raise TypeError(f"invalid target: {target!r}") module = _import_module(module) @@ -632,7 +696,7 @@ def patch(target, external_decorator=None): def decorator(func): # use the original function wrapper - func = getattr(func, '_func', func) + func = getattr(func, "_func", func) module, attr = _get_target(target) orig_func = getattr(module, attr) @@ -749,6 +813,7 @@ class alias: >>> speak = Speak() >>> assert speak.shout('foo') == speak.yell('foo') == speak.scream('foo') """ + def __init__(self, *aliases): self.aliases = set(aliases) @@ -762,11 +827,14 @@ def aliased(cls): orig_methods = cls.__dict__.copy() seen_aliases = set() for _name, method in orig_methods.items(): - if hasattr(method, '_aliases'): - collisions = method._aliases.intersection(orig_methods.keys() | seen_aliases) + if hasattr(method, "_aliases"): + collisions = method._aliases.intersection( + orig_methods.keys() | seen_aliases + ) if collisions: raise ValueError( - f"aliases collide with existing attributes: {', '.join(collisions)}") + f"aliases collide with existing attributes: {', '.join(collisions)}" + ) seen_aliases |= method._aliases for alias in method._aliases: setattr(cls, alias, method) @@ -780,9 +848,13 @@ class SlotsPicklingMixin: def __getstate__(self): all_slots = itertools.chain.from_iterable( - getattr(t, '__slots__', ()) for t in type(self).__mro__) - state = {attr: getattr(self, attr) for attr in all_slots - if hasattr(self, attr) and attr != '__weakref__'} + getattr(t, "__slots__", ()) for t in type(self).__mro__ + ) + state = { + attr: getattr(self, attr) + for attr in all_slots + if hasattr(self, attr) and attr != "__weakref__" + } return state def __setstate__(self, state): diff --git a/src/snakeoil/mappings.py b/src/snakeoil/mappings.py index c349897..d4ac221 100644 --- a/src/snakeoil/mappings.py +++ b/src/snakeoil/mappings.py @@ -3,10 +3,17 @@ Miscellaneous mapping related classes and functionality """ __all__ = ( - "DictMixin", "LazyValDict", "LazyFullValLoadDict", - "ProtectedDict", "ImmutableDict", "IndeterminantDict", - "defaultdictkey", "AttrAccessible", "StackedDict", - "make_SlottedDict_kls", "ProxiedAttrs", + "DictMixin", + "LazyValDict", + "LazyFullValLoadDict", + "ProtectedDict", + "ImmutableDict", + "IndeterminantDict", + "defaultdictkey", + "AttrAccessible", + "StackedDict", + "make_SlottedDict_kls", + "ProxiedAttrs", ) import operator @@ -168,6 +175,7 @@ class LazyValDict(DictMixin): given a function to get keys, and to look up the val for those keys, it'll lazily load key definitions and values as requested """ + __slots__ = ("_keys", "_keys_func", "_vals", "_val_func") __externally_mutable__ = False @@ -184,8 +192,7 @@ class LazyValDict(DictMixin): self._keys_func = None else: if not callable(get_keys_func): - raise TypeError( - "get_keys_func isn't iterable or callable") + raise TypeError("get_keys_func isn't iterable or callable") self._keys_func = get_keys_func self._val_func = get_val_func self._vals = {} @@ -234,6 +241,7 @@ class LazyFullValLoadDict(LazyValDict): The val function must still return values one by one per key. """ + __slots__ = () def __getitem__(self, key): @@ -297,8 +305,7 @@ class ProtectedDict(DictMixin): yield k def __contains__(self, key): - return key in self.new or (key not in self.blacklist and - key in self.orig) + return key in self.new or (key not in self.blacklist and key in self.orig) class ImmutableDict(Mapping): @@ -320,14 +327,14 @@ class ImmutableDict(Mapping): try: mapping = {k: v for k, v in data} except TypeError as e: - raise TypeError(f'unsupported data format: {e}') - object.__setattr__(self, '_dict', mapping) + raise TypeError(f"unsupported data format: {e}") + object.__setattr__(self, "_dict", mapping) def __getitem__(self, key): # hack to avoid recursion exceptions for subclasses that use # inject_getitem_as_getattr() - if key == '_dict': - return object.__getattribute__(self, '_dict') + if key == "_dict": + return object.__getattribute__(self, "_dict") return self._dict[key] def __iter__(self): @@ -356,7 +363,7 @@ class OrderedFrozenSet(Set): try: self._dict = ImmutableDict({x: None for x in iterable}) except TypeError as e: - raise TypeError('not iterable') from e + raise TypeError("not iterable") from e def __contains__(self, key): return key in self._dict @@ -369,7 +376,7 @@ class OrderedFrozenSet(Set): try: return next(islice(self._dict, key, None)) except StopIteration: - raise IndexError('index out of range') + raise IndexError("index out of range") # handle keys using slice notation return self.__class__(list(self._dict)[key]) @@ -384,8 +391,8 @@ class OrderedFrozenSet(Set): return set(self._dict) == other def __str__(self): - elements_str = ', '.join(map(repr, self._dict)) - return f'{{{elements_str}}}' + elements_str = ", ".join(map(repr, self._dict)) + return f"{{{elements_str}}}" def __repr__(self): return self.__str__() @@ -413,7 +420,7 @@ class OrderedSet(OrderedFrozenSet, MutableSet): try: self._dict = {x: None for x in iterable} except TypeError as e: - raise TypeError('not iterable') from e + raise TypeError("not iterable") from e def add(self, value): self._dict[value] = None @@ -434,7 +441,7 @@ class OrderedSet(OrderedFrozenSet, MutableSet): self._dict.update((x, None) for x in iterable) def __hash__(self): - raise TypeError(f'unhashable type: {self.__class__.__name__!r}') + raise TypeError(f"unhashable type: {self.__class__.__name__!r}") class IndeterminantDict: @@ -473,12 +480,21 @@ class IndeterminantDict: def __unmodifiable(func, *args): raise TypeError(f"indeterminate dict: '{func}()' can't modify {args!r}") - for func in ('__delitem__', '__setitem__', 'setdefault', 'popitem', 'update', 'clear'): + + for func in ( + "__delitem__", + "__setitem__", + "setdefault", + "popitem", + "update", + "clear", + ): locals()[func] = partial(__unmodifiable, func) def __indeterminate(func, *args): raise TypeError(f"indeterminate dict: '{func}()' is inaccessible") - for func in ('__iter__', '__len__', 'keys', 'values', 'items'): + + for func in ("__iter__", "__len__", "keys", "values", "items"): locals()[func] = partial(__indeterminate, func) @@ -650,6 +666,7 @@ def _KeyError_to_Attr(functor): return functor(self, *args) except KeyError: raise AttributeError(args[0]) + inner.__name__ = functor.__name__ inner.__doc__ = functor.__doc__ return inner @@ -681,9 +698,9 @@ def inject_getitem_as_getattr(scope): :param scope: the scope of a class to modify, adding methods as needed """ - scope.setdefault('__getattr__', _KeyError_to_Attr(operator.__getitem__)) - scope.setdefault('__delattr__', _KeyError_to_Attr(operator.__delitem__)) - scope.setdefault('__setattr__', _KeyError_to_Attr(operator.__setitem__)) + scope.setdefault("__getattr__", _KeyError_to_Attr(operator.__getitem__)) + scope.setdefault("__delattr__", _KeyError_to_Attr(operator.__delitem__)) + scope.setdefault("__setattr__", _KeyError_to_Attr(operator.__setitem__)) class AttrAccessible(dict): @@ -713,7 +730,7 @@ class ProxiedAttrs(DictMixin): :param target: The object to wrap. """ - __slots__ = ('__target__',) + __slots__ = ("__target__",) def __init__(self, target): self.__target__ = target @@ -860,7 +877,7 @@ class _SlottedDict(DictMixin): def make_SlottedDict_kls(keys): """Create a space efficient mapping class with a limited set of keys.""" new_keys = tuple(sorted(keys)) - cls_name = f'SlottedDict_{hash(new_keys)}' + cls_name = f"SlottedDict_{hash(new_keys)}" o = globals().get(cls_name, None) if o is None: o = type(cls_name, (_SlottedDict,), {}) diff --git a/src/snakeoil/modules.py b/src/snakeoil/modules.py index 740ea24..ec69701 100644 --- a/src/snakeoil/modules.py +++ b/src/snakeoil/modules.py @@ -14,6 +14,7 @@ class FailedImport(ImportError): """ Raised when a requested target cannot be imported """ + def __init__(self, trg, e): super().__init__(self, f"Failed importing target '{trg}': '{e}'") self.trg, self.e = trg, e diff --git a/src/snakeoil/obj.py b/src/snakeoil/obj.py index 73cede9..a8598bf 100644 --- a/src/snakeoil/obj.py +++ b/src/snakeoil/obj.py @@ -74,7 +74,6 @@ try to proxy builtin objects like tuples, lists, dicts, sets, etc. """ - __all__ = ("DelayedInstantiation", "DelayedInstantiation_kls", "make_kls", "popattr") from . import klass @@ -87,14 +86,25 @@ from . import klass # pointless class creation- thus having two separate lists. base_kls_descriptors = [ - '__delattr__', '__hash__', '__reduce__', - '__reduce_ex__', '__repr__', '__setattr__', '__str__', - '__format__', '__subclasshook__', # >=py2.6 - '__le__', '__lt__', '__ge__', '__gt__', '__eq__', '__ne__', # py3 - '__dir__', # >=py3.3 + "__delattr__", + "__hash__", + "__reduce__", + "__reduce_ex__", + "__repr__", + "__setattr__", + "__str__", + "__format__", + "__subclasshook__", # >=py2.6 + "__le__", + "__lt__", + "__ge__", + "__gt__", + "__eq__", + "__ne__", # py3 + "__dir__", # >=py3.3 ] -if hasattr(object, '__sizeof__'): - base_kls_descriptors.append('__sizeof__') +if hasattr(object, "__sizeof__"): + base_kls_descriptors.append("__sizeof__") base_kls_descriptors = frozenset(base_kls_descriptors) @@ -134,13 +144,13 @@ class BaseDelayedObject: def __getattribute__(self, attr): obj = object.__getattribute__(self, "__obj__") if obj is None: - if attr == '__class__': + if attr == "__class__": return object.__getattribute__(self, "__delayed__")[0] - elif attr == '__doc__': + elif attr == "__doc__": kls = object.__getattribute__(self, "__delayed__")[0] - return getattr(kls, '__doc__', None) + return getattr(kls, "__doc__", None) - obj = object.__getattribute__(self, '__instantiate_proxy_instance__')() + obj = object.__getattribute__(self, "__instantiate_proxy_instance__")() if attr == "__obj__": # special casing for klass.alias_method @@ -157,61 +167,122 @@ class BaseDelayedObject: # special case the normal descriptors for x in base_kls_descriptors: locals()[x] = klass.alias_method( - "__obj__.%s" % (x,), - doc=getattr(getattr(object, x), '__doc__', None)) + "__obj__.%s" % (x,), doc=getattr(getattr(object, x), "__doc__", None) + ) # pylint: disable=undefined-loop-variable del x # note that we ignore __getattribute__; we already handle it. -kls_descriptors = frozenset([ - # rich comparison protocol... - '__le__', '__lt__', '__eq__', '__ne__', '__gt__', '__ge__', - # unicode conversion - '__unicode__', - # truth... - '__bool__', - # container protocol... - '__len__', '__getitem__', '__setitem__', '__delitem__', - '__iter__', '__contains__', '__index__', '__reversed__', - # deprecated sequence protocol bits... - '__getslice__', '__setslice__', '__delslice__', - # numeric... - '__add__', '__sub__', '__mul__', '__floordiv__', '__mod__', - '__divmod__', '__pow__', '__lshift__', '__rshift__', - '__and__', '__xor__', '__or__', '__div__', '__truediv__', - '__rad__', '__rsub__', '__rmul__', '__rdiv__', '__rtruediv__', - '__rfloordiv__', '__rmod__', '__rdivmod__', '__rpow__', - '__rlshift__', '__rrshift__', '__rand__', '__rxor__', '__ror__', - '__iadd__', '__isub__', '__imul__', '__idiv__', '__itruediv__', - '__ifloordiv__', '__imod__', '__ipow__', '__ilshift__', - '__irshift__', '__iand__', '__ixor__', '__ior__', - '__neg__', '__pos__', '__abs__', '__invert__', '__complex__', - '__int__', '__long__', '__float__', '__oct__', '__hex__', - '__coerce__', '__trunc__', '__radd__', '__floor__', '__ceil__', - '__round__', - # remaining... - '__call__', '__sizeof__', -]) +kls_descriptors = frozenset( + [ + # rich comparison protocol... + "__le__", + "__lt__", + "__eq__", + "__ne__", + "__gt__", + "__ge__", + # unicode conversion + "__unicode__", + # truth... + "__bool__", + # container protocol... + "__len__", + "__getitem__", + "__setitem__", + "__delitem__", + "__iter__", + "__contains__", + "__index__", + "__reversed__", + # deprecated sequence protocol bits... + "__getslice__", + "__setslice__", + "__delslice__", + # numeric... + "__add__", + "__sub__", + "__mul__", + "__floordiv__", + "__mod__", + "__divmod__", + "__pow__", + "__lshift__", + "__rshift__", + "__and__", + "__xor__", + "__or__", + "__div__", + "__truediv__", + "__rad__", + "__rsub__", + "__rmul__", + "__rdiv__", + "__rtruediv__", + "__rfloordiv__", + "__rmod__", + "__rdivmod__", + "__rpow__", + "__rlshift__", + "__rrshift__", + "__rand__", + "__rxor__", + "__ror__", + "__iadd__", + "__isub__", + "__imul__", + "__idiv__", + "__itruediv__", + "__ifloordiv__", + "__imod__", + "__ipow__", + "__ilshift__", + "__irshift__", + "__iand__", + "__ixor__", + "__ior__", + "__neg__", + "__pos__", + "__abs__", + "__invert__", + "__complex__", + "__int__", + "__long__", + "__float__", + "__oct__", + "__hex__", + "__coerce__", + "__trunc__", + "__radd__", + "__floor__", + "__ceil__", + "__round__", + # remaining... + "__call__", + "__sizeof__", + ] +) kls_descriptors = kls_descriptors.difference(base_kls_descriptors) -descriptor_overrides = {k: klass.alias_method(f"__obj__.{k}") - for k in kls_descriptors} +descriptor_overrides = {k: klass.alias_method(f"__obj__.{k}") for k in kls_descriptors} _method_cache = {} + + def make_kls(kls, proxy_base=BaseDelayedObject): special_descriptors = kls_descriptors.intersection(dir(kls)) - doc = getattr(kls, '__doc__', None) + doc = getattr(kls, "__doc__", None) if not special_descriptors and doc is None: return proxy_base key = (tuple(sorted(special_descriptors)), doc) o = _method_cache.get(key, None) if o is None: + class CustomDelayedObject(proxy_base): - locals().update((k, descriptor_overrides[k]) - for k in special_descriptors) + locals().update((k, descriptor_overrides[k]) for k in special_descriptors) __doc__ = doc o = CustomDelayedObject @@ -230,6 +301,8 @@ def DelayedInstantiation_kls(kls, *a, **kwd): _class_cache = {} + + def DelayedInstantiation(resultant_kls, func, *a, **kwd): """Generate an objects that does not get initialized before it is used. diff --git a/src/snakeoil/osutils/__init__.py b/src/snakeoil/osutils/__init__.py index f0ea498..036a78d 100644 --- a/src/snakeoil/osutils/__init__.py +++ b/src/snakeoil/osutils/__init__.py @@ -34,9 +34,18 @@ pretty quickly. """ __all__ = ( - 'abspath', 'abssymlink', 'ensure_dirs', 'join', 'pjoin', 'listdir_files', - 'listdir_dirs', 'listdir', 'readdir', 'normpath', 'unlink_if_exists', - 'supported_systems', + "abspath", + "abssymlink", + "ensure_dirs", + "join", + "pjoin", + "listdir_files", + "listdir_dirs", + "listdir", + "readdir", + "normpath", + "unlink_if_exists", + "supported_systems", ) import errno @@ -86,14 +95,18 @@ def supported_systems(*systems): ... NotImplementedError: func2 not supported on nonexistent """ + def _decorator(f): def _wrapper(*args, **kwargs): if sys.platform.startswith(systems): return f(*args, **kwargs) else: - raise NotImplementedError('%s not supported on %s' - % (f.__name__, sys.platform)) + raise NotImplementedError( + "%s not supported on %s" % (f.__name__, sys.platform) + ) + return _wrapper + return _decorator @@ -134,7 +147,7 @@ def ensure_dirs(path, gid=-1, uid=-1, mode=0o777, minimal=True): try: um = os.umask(0) # if the dir perms would lack +wx, we have to force it - force_temp_perms = ((mode & 0o300) != 0o300) + force_temp_perms = (mode & 0o300) != 0o300 resets = [] apath = normpath(os.path.abspath(path)) sticky_parent = False @@ -149,7 +162,7 @@ def ensure_dirs(path, gid=-1, uid=-1, mode=0o777, minimal=True): # if it's a subdir, we need +wx at least if apath != base: - sticky_parent = (st.st_mode & stat.S_ISGID) + sticky_parent = st.st_mode & stat.S_ISGID except OSError: # nothing exists. @@ -185,8 +198,7 @@ def ensure_dirs(path, gid=-1, uid=-1, mode=0o777, minimal=True): return False try: - if ((gid != -1 and gid != st.st_gid) or - (uid != -1 and uid != st.st_uid)): + if (gid != -1 and gid != st.st_gid) or (uid != -1 and uid != st.st_uid): os.chown(path, uid, gid) if minimal: if mode != (st.st_mode & mode): @@ -207,9 +219,9 @@ def abssymlink(path): a symlink """ mylink = os.readlink(path) - if mylink[0] != '/': + if mylink[0] != "/": mydir = os.path.dirname(path) - mylink = mydir + '/' + mylink + mylink = mydir + "/" + mylink return normpath(mylink) @@ -256,7 +268,7 @@ def normpath(mypath: str) -> str: `os.path.normpath` only in that it'll convert leading '//' into '/' """ newpath = os.path.normpath(mypath) - double_sep = b'//' if isinstance(newpath, bytes) else '//' + double_sep = b"//" if isinstance(newpath, bytes) else "//" if newpath.startswith(double_sep): return newpath[1:] return newpath @@ -306,9 +318,9 @@ def fallback_access(path, mode, root=0): return mode == (mode & (st.st_mode & 0x7)) -if os.uname()[0].lower() == 'sunos': +if os.uname()[0].lower() == "sunos": access = fallback_access - access.__name__ = 'access' + access.__name__ = "access" else: access = os.access diff --git a/src/snakeoil/osutils/mount.py b/src/snakeoil/osutils/mount.py index 05eb47d..b28853d 100644 --- a/src/snakeoil/osutils/mount.py +++ b/src/snakeoil/osutils/mount.py @@ -1,4 +1,4 @@ -__all__ = ('mount', 'umount') +__all__ = ("mount", "umount") import ctypes import os @@ -40,10 +40,10 @@ MNT_EXPIRE = 4 UMOUNT_NOFOLLOW = 8 -@supported_systems('linux') +@supported_systems("linux") def mount(source, target, fstype, flags, data=None): """Call mount(2); see the man page for details.""" - libc = ctypes.CDLL(find_library('c'), use_errno=True) + libc = ctypes.CDLL(find_library("c"), use_errno=True) source = source.encode() if isinstance(source, str) else source target = target.encode() if isinstance(target, str) else target fstype = fstype.encode() if isinstance(fstype, str) else fstype @@ -52,10 +52,10 @@ def mount(source, target, fstype, flags, data=None): raise OSError(e, os.strerror(e)) -@supported_systems('linux') +@supported_systems("linux") def umount(target, flags=None): """Call umount or umount2; see the umount(2) man page for details.""" - libc = ctypes.CDLL(find_library('c'), use_errno=True) + libc = ctypes.CDLL(find_library("c"), use_errno=True) target = target.encode() if isinstance(target, str) else target args = [] func = libc.umount diff --git a/src/snakeoil/osutils/native_readdir.py b/src/snakeoil/osutils/native_readdir.py index 0efb9c9..b129a9b 100644 --- a/src/snakeoil/osutils/native_readdir.py +++ b/src/snakeoil/osutils/native_readdir.py @@ -3,8 +3,18 @@ import errno import os -from stat import (S_IFBLK, S_IFCHR, S_IFDIR, S_IFIFO, S_IFLNK, S_IFMT, S_IFREG, S_IFSOCK, S_ISDIR, - S_ISREG) +from stat import ( + S_IFBLK, + S_IFCHR, + S_IFDIR, + S_IFIFO, + S_IFLNK, + S_IFMT, + S_IFREG, + S_IFSOCK, + S_ISDIR, + S_ISREG, +) from ..mappings import ProtectedDict @@ -14,6 +24,7 @@ listdir = os.listdir # import cycle. pjoin = os.path.join + def stat_swallow_enoent(path, check, default=False, stat=os.stat): try: return check(stat(path).st_mode) @@ -22,6 +33,7 @@ def stat_swallow_enoent(path, check, default=False, stat=os.stat): return default raise + def listdir_dirs(path, followSymlinks=True): """ Return a list of all subdirectories within a directory @@ -36,11 +48,12 @@ def listdir_dirs(path, followSymlinks=True): pjf = pjoin lstat = os.lstat if followSymlinks: - return [x for x in os.listdir(path) if - stat_swallow_enoent(pjf(path, x), scheck)] + return [ + x for x in os.listdir(path) if stat_swallow_enoent(pjf(path, x), scheck) + ] lstat = os.lstat - return [x for x in os.listdir(path) if - scheck(lstat(pjf(path, x)).st_mode)] + return [x for x in os.listdir(path) if scheck(lstat(pjf(path, x)).st_mode)] + def listdir_files(path, followSymlinks=True): """ @@ -56,24 +69,28 @@ def listdir_files(path, followSymlinks=True): scheck = S_ISREG pjf = pjoin if followSymlinks: - return [x for x in os.listdir(path) if - stat_swallow_enoent(pjf(path, x), scheck)] + return [ + x for x in os.listdir(path) if stat_swallow_enoent(pjf(path, x), scheck) + ] lstat = os.lstat - return [x for x in os.listdir(path) if - scheck(lstat(pjf(path, x)).st_mode)] + return [x for x in os.listdir(path) if scheck(lstat(pjf(path, x)).st_mode)] + # we store this outside the function to ensure that # the strings used are reused, thus avoiding unneeded # allocations -d_type_mapping = ProtectedDict({ - S_IFREG: "file", - S_IFDIR: "directory", - S_IFLNK: "symlink", - S_IFCHR: "chardev", - S_IFBLK: "block", - S_IFSOCK: "socket", - S_IFIFO: "fifo", -}) +d_type_mapping = ProtectedDict( + { + S_IFREG: "file", + S_IFDIR: "directory", + S_IFLNK: "symlink", + S_IFCHR: "chardev", + S_IFBLK: "block", + S_IFSOCK: "socket", + S_IFIFO: "fifo", + } +) + def readdir(path): """ diff --git a/src/snakeoil/pickling.py b/src/snakeoil/pickling.py index e212e6d..9707ecd 100644 --- a/src/snakeoil/pickling.py +++ b/src/snakeoil/pickling.py @@ -3,7 +3,8 @@ pickling convenience module """ __all__ = ( - "iter_stream", "dump_stream", + "iter_stream", + "dump_stream", ) from pickle import dump, load diff --git a/src/snakeoil/process/__init__.py b/src/snakeoil/process/__init__.py index f242269..8966fd6 100644 --- a/src/snakeoil/process/__init__.py +++ b/src/snakeoil/process/__init__.py @@ -44,7 +44,7 @@ def get_exit_status(status: int): if os.WIFSIGNALED(status): return 128 + os.WTERMSIG(status) else: - assert os.WIFEXITED(status), 'Unexpected exit status %r' % status + assert os.WIFEXITED(status), "Unexpected exit status %r" % status return os.WEXITSTATUS(status) @@ -84,16 +84,14 @@ def exit_as_status(status: int): class CommandNotFound(Exception): - def __init__(self, command): - super().__init__(f'failed to find binary: {command!r}') + super().__init__(f"failed to find binary: {command!r}") self.command = command class ProcessNotFound(Exception): - def __init__(self, pid): - super().__init__(f'nonexistent process: {pid}') + super().__init__(f"nonexistent process: {pid}") closerange = os.closerange diff --git a/src/snakeoil/process/namespaces.py b/src/snakeoil/process/namespaces.py index fd6bd74..6a823c6 100644 --- a/src/snakeoil/process/namespaces.py +++ b/src/snakeoil/process/namespaces.py @@ -11,8 +11,15 @@ import socket import subprocess import sys -from ..osutils.mount import (MS_NODEV, MS_NOEXEC, MS_NOSUID, MS_PRIVATE, MS_REC, MS_RELATIME, - MS_SLAVE) +from ..osutils.mount import ( + MS_NODEV, + MS_NOEXEC, + MS_NOSUID, + MS_PRIVATE, + MS_REC, + MS_RELATIME, + MS_SLAVE, +) from ..osutils.mount import mount as _mount from . import exit_as_status @@ -39,7 +46,7 @@ def setns(fd, nstype): fp = open(fd) fd = fp.fileno() - libc = ctypes.CDLL(ctypes.util.find_library('c'), use_errno=True) + libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True) if libc.setns(ctypes.c_int(fd), ctypes.c_int(nstype)) != 0: e = ctypes.get_errno() raise OSError(e, os.strerror(e)) @@ -54,7 +61,7 @@ def unshare(flags): :param flags: Namespaces to unshare; bitwise OR of CLONE_* flags. :raises OSError: if unshare failed. """ - libc = ctypes.CDLL(ctypes.util.find_library('c'), use_errno=True) + libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True) if libc.unshare(ctypes.c_int(flags)) != 0: e = ctypes.get_errno() raise OSError(e, os.strerror(e)) @@ -140,15 +147,13 @@ def create_pidns(): # Make sure to unshare the existing mount point if needed. Some distros # create shared mount points everywhere by default. try: - _mount(None, '/proc', 'proc', MS_PRIVATE | MS_REC) + _mount(None, "/proc", "proc", MS_PRIVATE | MS_REC) except OSError as e: if e.errno != errno.EINVAL: raise # The child needs its own proc mount as it'll be different. - _mount( - 'proc', '/proc', 'proc', - MS_NOSUID | MS_NODEV | MS_NOEXEC | MS_RELATIME) + _mount("proc", "/proc", "proc", MS_NOSUID | MS_NODEV | MS_NOEXEC | MS_RELATIME) if pid := os.fork(): # Mask SIGINT with the assumption that the child will catch & process it. @@ -195,12 +200,13 @@ def create_netns(): # Since we've unshared the net namespace, we need to bring up loopback. # The kernel automatically adds the various ip addresses, so skip that. try: - subprocess.call(['ip', 'link', 'set', 'up', 'lo']) + subprocess.call(["ip", "link", "set", "up", "lo"]) except OSError as e: if e.errno == errno.ENOENT: sys.stderr.write( - 'warning: could not bring up loopback for network; ' - 'install the iproute2 package\n') + "warning: could not bring up loopback for network; " + "install the iproute2 package\n" + ) else: raise @@ -243,16 +249,17 @@ def create_userns(): # For all other errors, abort. They shouldn't happen. raise - with open('/proc/self/setgroups', 'w') as f: - f.write('deny') - with open('/proc/self/uid_map', 'w') as f: - f.write('0 %s 1\n' % uid) - with open('/proc/self/gid_map', 'w') as f: - f.write('0 %s 1\n' % gid) + with open("/proc/self/setgroups", "w") as f: + f.write("deny") + with open("/proc/self/uid_map", "w") as f: + f.write("0 %s 1\n" % uid) + with open("/proc/self/gid_map", "w") as f: + f.write("0 %s 1\n" % gid) -def simple_unshare(mount=True, uts=True, ipc=True, net=False, pid=False, - user=False, hostname=None): +def simple_unshare( + mount=True, uts=True, ipc=True, net=False, pid=False, user=False, hostname=None +): """Simpler helper for setting up namespaces quickly. If support for any namespace type is not available, we'll silently skip it. @@ -278,7 +285,7 @@ def simple_unshare(mount=True, uts=True, ipc=True, net=False, pid=False, # on systems that share the rootfs by default, but allow events in the # parent to propagate down. try: - _mount(None, '/', None, MS_REC | MS_SLAVE) + _mount(None, "/", None, MS_REC | MS_SLAVE) except OSError as e: if e.errno != errno.EINVAL: raise diff --git a/src/snakeoil/process/spawn.py b/src/snakeoil/process/spawn.py index 3413e3f..48b60b1 100644 --- a/src/snakeoil/process/spawn.py +++ b/src/snakeoil/process/spawn.py @@ -3,8 +3,12 @@ subprocess related functionality """ __all__ = [ - "cleanup_pids", "spawn", "spawn_sandbox", "spawn_bash", - "spawn_get_output", "bash_version", + "cleanup_pids", + "spawn", + "spawn_sandbox", + "spawn_bash", + "spawn_get_output", + "bash_version", ] import atexit @@ -17,11 +21,12 @@ from ..mappings import ProtectedDict from ..osutils import access from . import CommandNotFound, closerange, find_binary -BASH_BINARY = find_binary('bash', fallback='/bin/bash') -SANDBOX_BINARY = find_binary('sandbox', fallback='/usr/bin/sandbox') +BASH_BINARY = find_binary("bash", fallback="/bin/bash") +SANDBOX_BINARY = find_binary("sandbox", fallback="/usr/bin/sandbox") try: import resource + max_fd_limit = resource.getrlimit(resource.RLIMIT_NOFILE)[0] except ImportError: max_fd_limit = 256 @@ -36,8 +41,14 @@ def bash_version(force=False): pass try: ret, ver = spawn_get_output( - [BASH_BINARY, '--norc', '--noprofile', '-c', - 'printf ${BASH_VERSINFO[0]}.${BASH_VERSINFO[1]}.${BASH_VERSINFO[2]}']) + [ + BASH_BINARY, + "--norc", + "--noprofile", + "-c", + "printf ${BASH_VERSINFO[0]}.${BASH_VERSINFO[1]}.${BASH_VERSINFO[2]}", + ] + ) if ret == 0: try: ver = ver[0] @@ -54,7 +65,7 @@ def bash_version(force=False): def spawn_bash(mycommand, debug=False, name=None, **kwds): """spawn the command via bash -c""" - args = [BASH_BINARY, '--norc', '--noprofile'] + args = [BASH_BINARY, "--norc", "--noprofile"] if debug: # Print commands and their arguments as they are executed. args.append("-x") @@ -84,6 +95,8 @@ def spawn_sandbox(mycommand, name=None, **kwds): _exithandlers = [] + + def atexit_register(func, *args, **kargs): """Wrapper around atexit.register that is needed in order to track what is registered. For example, when portage restarts itself via @@ -119,6 +132,8 @@ atexit.register(run_exitfuncs) # we exit. spawn() takes care of adding and removing pids to this list # as it creates and cleans up processes. spawned_pids = [] + + def cleanup_pids(pids=None): """reap list of pids if specified, else all children""" @@ -146,8 +161,19 @@ def cleanup_pids(pids=None): pass -def spawn(mycommand, env=None, name=None, fd_pipes=None, returnpid=False, - uid=None, gid=None, groups=None, umask=None, cwd=None, pgid=None): +def spawn( + mycommand, + env=None, + name=None, + fd_pipes=None, + returnpid=False, + uid=None, + gid=None, + groups=None, + umask=None, + cwd=None, + pgid=None, +): """wrapper around execve @@ -177,8 +203,19 @@ def spawn(mycommand, env=None, name=None, fd_pipes=None, returnpid=False, # 'Catch "Exception"' # pylint: disable-msg=W0703 try: - _exec(binary, mycommand, name, fd_pipes, env, gid, groups, - uid, umask, cwd, pgid) + _exec( + binary, + mycommand, + name, + fd_pipes, + env, + gid, + groups, + uid, + umask, + cwd, + pgid, + ) except Exception as e: # We need to catch _any_ exception so that it doesn't # propogate out of this function and cause exiting @@ -228,8 +265,19 @@ def spawn(mycommand, env=None, name=None, fd_pipes=None, returnpid=False, return 0 -def _exec(binary, mycommand, name=None, fd_pipes=None, env=None, gid=None, - groups=None, uid=None, umask=None, cwd=None, pgid=None): +def _exec( + binary, + mycommand, + name=None, + fd_pipes=None, + env=None, + gid=None, + groups=None, + uid=None, + umask=None, + cwd=None, + pgid=None, +): """internal function to handle exec'ing the child process. If it succeeds this function does not return. It might raise an @@ -321,8 +369,15 @@ def _exec(binary, mycommand, name=None, fd_pipes=None, env=None, gid=None, os.execve(binary, myargs, env) -def spawn_get_output(mycommand, spawn_type=None, raw_exit_code=False, collect_fds=(1,), - fd_pipes=None, split_lines=True, **kwds): +def spawn_get_output( + mycommand, + spawn_type=None, + raw_exit_code=False, + collect_fds=(1,), + fd_pipes=None, + split_lines=True, + **kwds, +): """Call spawn, collecting the output to fd's specified in collect_fds list. @@ -386,8 +441,8 @@ def process_exit_code(retval): :return: The exit code if it exit'd, the signal if it died from signalling. """ # If it got a signal, return the signal that was sent. - if retval & 0xff: - return (retval & 0xff) << 8 + if retval & 0xFF: + return (retval & 0xFF) << 8 # Otherwise, return its exit code. return retval >> 8 @@ -399,7 +454,7 @@ class ExecutionFailure(Exception): self.msg = msg def __str__(self): - return f'Execution Failure: {self.msg}' + return f"Execution Failure: {self.msg}" # cached capabilities @@ -411,7 +466,7 @@ def is_sandbox_capable(force=False): return is_sandbox_capable.cached_result except AttributeError: pass - if 'SANDBOX_ACTIVE' in os.environ: + if "SANDBOX_ACTIVE" in os.environ: # we can not spawn a sandbox inside another one res = False elif not (os.path.isfile(SANDBOX_BINARY) and access(SANDBOX_BINARY, os.X_OK)): @@ -432,5 +487,5 @@ def is_userpriv_capable(force=False): return is_userpriv_capable.cached_result except AttributeError: pass - res = is_userpriv_capable.cached_result = (os.getuid() == 0) + res = is_userpriv_capable.cached_result = os.getuid() == 0 return res diff --git a/src/snakeoil/sequences.py b/src/snakeoil/sequences.py index c041136..b80b101 100644 --- a/src/snakeoil/sequences.py +++ b/src/snakeoil/sequences.py @@ -1,9 +1,14 @@ """sequence related operations and classes""" __all__ = ( - 'unstable_unique', 'stable_unique', 'iter_stable_unique', - 'iflatten_instance', 'iflatten_func', 'ChainedLists', 'predicate_split', - 'split_negations', + "unstable_unique", + "stable_unique", + "iter_stable_unique", + "iflatten_instance", + "iflatten_func", + "ChainedLists", + "predicate_split", + "split_negations", ) from typing import Any, Callable, Iterable, Type @@ -89,7 +94,9 @@ def iter_stable_unique(iterable): break -def iflatten_instance(l: Iterable, skip_flattening: Iterable[Type] = (str, bytes)) -> Iterable: +def iflatten_instance( + l: Iterable, skip_flattening: Iterable[Type] = (str, bytes) +) -> Iterable: """collapse [[1],2] into [1,2] :param skip_flattening: list of classes to not descend through @@ -103,9 +110,10 @@ def iflatten_instance(l: Iterable, skip_flattening: Iterable[Type] = (str, bytes try: while True: x = next(iters) - if (hasattr(x, '__iter__') and not ( - isinstance(x, skip_flattening) or ( - isinstance(x, (str, bytes)) and len(x) == 1))): + if hasattr(x, "__iter__") and not ( + isinstance(x, skip_flattening) + or (isinstance(x, (str, bytes)) and len(x) == 1) + ): iters.appendleft(x) else: yield x @@ -128,7 +136,7 @@ def iflatten_func(l: Iterable, skip_func: Callable[[Any], bool]) -> Iterable: try: while True: x = next(iters) - if hasattr(x, '__iter__') and not skip_func(x): + if hasattr(x, "__iter__") and not skip_func(x): iters.appendleft(x) else: yield x @@ -164,6 +172,7 @@ class ChainedLists: ... TypeError: not mutable """ + __slots__ = ("_lists", "__weakref__") def __init__(self, *lists): @@ -258,7 +267,7 @@ def predicate_split(func, stream, key=None): def split_negations(iterable, func=str): - """"Split an iterable into negative and positive elements. + """ "Split an iterable into negative and positive elements. :param iterable: iterable targeted for splitting :param func: wrapper method to modify tokens @@ -267,7 +276,7 @@ def split_negations(iterable, func=str): """ neg, pos = [], [] for token in iterable: - if token[0] == '-': + if token[0] == "-": if len(token) == 1: raise ValueError("'-' negation without a token") token = token[1:] @@ -281,7 +290,7 @@ def split_negations(iterable, func=str): def split_elements(iterable, func=str): - """"Split an iterable into negative, neutral, and positive elements. + """ "Split an iterable into negative, neutral, and positive elements. :param iterable: iterable targeted for splitting :param func: wrapper method to modify tokens @@ -289,11 +298,11 @@ def split_elements(iterable, func=str): :return: Tuple containing negative, neutral, and positive element tuples, respectively. """ neg, neu, pos = [], [], [] - token_map = {'-': neg, '+': pos} + token_map = {"-": neg, "+": pos} for token in iterable: if token[0] in token_map: if len(token) == 1: - raise ValueError('%r without a token' % (token[0],)) + raise ValueError("%r without a token" % (token[0],)) l = token_map[token[0]] token = token[1:] else: diff --git a/src/snakeoil/stringio.py b/src/snakeoil/stringio.py index c17db92..1a392eb 100644 --- a/src/snakeoil/stringio.py +++ b/src/snakeoil/stringio.py @@ -27,7 +27,7 @@ is usable under both py2k and py3k. """ # TODO: deprecated, remove in 0.9.0 -__all__ = ('text_readonly', 'bytes_readonly') +__all__ = ("text_readonly", "bytes_readonly") import io diff --git a/src/snakeoil/strings.py b/src/snakeoil/strings.py index 7d7b2a8..d0dc58a 100644 --- a/src/snakeoil/strings.py +++ b/src/snakeoil/strings.py @@ -2,10 +2,10 @@ from .demandload import demand_compile_regexp -demand_compile_regexp('_whitespace_regex', r'^(?P<indent>\s+)') +demand_compile_regexp("_whitespace_regex", r"^(?P<indent>\s+)") -def pluralism(obj, none=None, singular='', plural='s'): +def pluralism(obj, none=None, singular="", plural="s"): """Return singular or plural suffix depending on object's length or value.""" # default to plural for empty objects, e.g. there are 0 repos if none is None: @@ -27,16 +27,16 @@ def pluralism(obj, none=None, singular='', plural='s'): def doc_dedent(s): """Support dedenting docstrings with initial line having no indentation.""" try: - lines = s.split('\n') + lines = s.split("\n") except AttributeError: - raise TypeError(f'{s!r} is not a string') + raise TypeError(f"{s!r} is not a string") if lines: # find first line with an indent if one exists for line in lines: if mo := _whitespace_regex.match(line): - indent = mo.group('indent') + indent = mo.group("indent") break else: - indent = '' + indent = "" len_i = len(indent) - return '\n'.join(x[len_i:] if x.startswith(indent) else x for x in lines) + return "\n".join(x[len_i:] if x.startswith(indent) else x for x in lines) diff --git a/src/snakeoil/tar.py b/src/snakeoil/tar.py index 7daa9eb..74061a5 100644 --- a/src/snakeoil/tar.py +++ b/src/snakeoil/tar.py @@ -43,13 +43,30 @@ class TarInfo(tarfile.TarInfo): :ivar uname: same as TarInfo.uname, just interned via a property. """ - if not hasattr(tarfile.TarInfo, '__slots__'): + if not hasattr(tarfile.TarInfo, "__slots__"): __slots__ = ( - "name", "mode", "uid", "gid", "size", "mtime", "chksum", "type", - "linkname", "_uname", "_gname", "devmajor", "devminor", "prefix", - "offset", "offset_data", "_buf", "sparse", "_link_target") + "name", + "mode", + "uid", + "gid", + "size", + "mtime", + "chksum", + "type", + "linkname", + "_uname", + "_gname", + "devmajor", + "devminor", + "prefix", + "offset", + "offset_data", + "_buf", + "sparse", + "_link_target", + ) else: - __slots__ = ('_buf', '_uname', '_gname') + __slots__ = ("_buf", "_uname", "_gname") def get_buf(self): return self.tobuf() diff --git a/src/snakeoil/test/__init__.py b/src/snakeoil/test/__init__.py index bb9381f..b93094a 100644 --- a/src/snakeoil/test/__init__.py +++ b/src/snakeoil/test/__init__.py @@ -13,18 +13,20 @@ from snakeoil import klass def random_str(length): """Return a random string of specified length.""" - return ''.join(random.choices(string.ascii_letters + string.digits, k=length)) + return "".join(random.choices(string.ascii_letters + string.digits, k=length)) def coverage(): """Extract coverage instance (if it exists) from the current running context.""" cov = None import inspect + try: import coverage + frame = inspect.currentframe() while frame is not None: - cov = getattr(frame.f_locals.get('self'), 'coverage', None) + cov = getattr(frame.f_locals.get("self"), "coverage", None) if isinstance(cov, coverage.coverage): break frame = frame.f_back @@ -33,7 +35,7 @@ def coverage(): return cov -@klass.patch('os._exit') +@klass.patch("os._exit") def _os_exit(orig_exit, val): """Monkeypatch os._exit() to save coverage data before exit.""" cov = coverage() @@ -51,7 +53,9 @@ def protect_process(functor, name=None): if os.environ.get(_PROTECT_ENV_VAR, False): return functor(self) if name is None: - name = f"{self.__class__.__module__}.{self.__class__.__name__}.{method_name}" + name = ( + f"{self.__class__.__module__}.{self.__class__.__name__}.{method_name}" + ) runner_path = __file__ if runner_path.endswith(".pyc") or runner_path.endswith(".pyo"): runner_path = runner_path.rsplit(".", maxsplit=1)[0] + ".py" @@ -59,11 +63,18 @@ def protect_process(functor, name=None): try: os.environ[_PROTECT_ENV_VAR] = "yes" args = [sys.executable, __file__, name] - p = subprocess.Popen(args, shell=False, env=os.environ.copy(), - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + p = subprocess.Popen( + args, + shell=False, + env=os.environ.copy(), + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) stdout, _ = p.communicate() ret = p.wait() - assert ret == 0, f"subprocess run: {args!r}\nnon zero exit: {ret}\nstdout:\n{stdout}" + assert ( + ret == 0 + ), f"subprocess run: {args!r}\nnon zero exit: {ret}\nstdout:\n{stdout}" finally: if wipe: os.environ.pop(_PROTECT_ENV_VAR, None) @@ -71,7 +82,7 @@ def protect_process(functor, name=None): for x in ("__doc__", "__name__"): if hasattr(functor, x): setattr(_inner_run, x, getattr(functor, x)) - method_name = getattr(functor, '__name__', None) + method_name = getattr(functor, "__name__", None) return _inner_run @@ -84,4 +95,4 @@ def hide_imports(*import_names: str): raise ImportError() return orig_import(name, *args, **kwargs) - return patch('builtins.__import__', side_effect=mock_import) + return patch("builtins.__import__", side_effect=mock_import) diff --git a/src/snakeoil/test/argparse_helpers.py b/src/snakeoil/test/argparse_helpers.py index 152e335..bfcdc3b 100644 --- a/src/snakeoil/test/argparse_helpers.py +++ b/src/snakeoil/test/argparse_helpers.py @@ -56,25 +56,24 @@ class Color(FormatterObject): self.color = color def __repr__(self): - return f'<Color: mode - {self.mode}; color - {self.color}>' + return f"<Color: mode - {self.mode}; color - {self.color}>" class Reset(FormatterObject): __inst_caching__ = True def __repr__(self): - return '<Reset>' + return "<Reset>" class Bold(FormatterObject): __inst_caching__ = True def __repr__(self): - return '<Bold>' + return "<Bold>" class ListStream(list): - def write(self, *args): stringlist = [] objectlist = [] @@ -82,13 +81,16 @@ class ListStream(list): if isinstance(arg, bytes): stringlist.append(arg) else: - objectlist.append(b''.join(stringlist)) + objectlist.append(b"".join(stringlist)) stringlist = [] objectlist.append(arg) - objectlist.append(b''.join(stringlist)) + objectlist.append(b"".join(stringlist)) # We use len because boolean ops shortcircuit - if (len(self) and isinstance(self[-1], bytes) and - isinstance(objectlist[0], bytes)): + if ( + len(self) + and isinstance(self[-1], bytes) + and isinstance(objectlist[0], bytes) + ): self[-1] = self[-1] + objectlist.pop(0) self.extend(objectlist) @@ -97,7 +99,6 @@ class ListStream(list): class FakeStreamFormatter(PlainTextFormatter): - def __init__(self): super().__init__(ListStream([])) self.reset = Reset() @@ -108,15 +109,15 @@ class FakeStreamFormatter(PlainTextFormatter): self.stream = ListStream([]) def fg(self, color=None): - return Color('fg', color) + return Color("fg", color) def bg(self, color=None): - return Color('bg', color) + return Color("bg", color) def get_text_stream(self): - return b''.join( - (x for x in self.stream - if not isinstance(x, FormatterObject))).decode('ascii') + return b"".join( + (x for x in self.stream if not isinstance(x, FormatterObject)) + ).decode("ascii") class ArgParseMixin: @@ -148,7 +149,7 @@ class ArgParseMixin: except Error as e: assert message == e.message else: - raise AssertionError('no error triggered') + raise AssertionError("no error triggered") def assertExit(self, status, message, *args, **kwargs): """Pass args, assert they trigger the right exit condition.""" @@ -158,7 +159,7 @@ class ArgParseMixin: assert message == e.message.strip() assert status == e.status else: - raise AssertionError('no exit triggered') + raise AssertionError("no exit triggered") def assertOut(self, out, *args, **kwargs): """Like :obj:`assertOutAndErr` but without err.""" @@ -182,17 +183,25 @@ class ArgParseMixin: main = self.get_main(options) main(options, outformatter, errformatter) diffs = [] - for name, strings, formatter in [('out', out, outformatter), - ('err', err, errformatter)]: + for name, strings, formatter in [ + ("out", out, outformatter), + ("err", err, errformatter), + ]: actual = formatter.get_text_stream() if strings: - expected = '\n'.join(strings) + expected = "\n".join(strings) else: - expected = '' + expected = "" if expected != actual: - diffs.extend(difflib.unified_diff( - strings, actual.split('\n')[:-1], - 'expected %s' % (name,), 'actual', lineterm='')) + diffs.extend( + difflib.unified_diff( + strings, + actual.split("\n")[:-1], + "expected %s" % (name,), + "actual", + lineterm="", + ) + ) if diffs: - raise AssertionError('\n' + '\n'.join(diffs)) + raise AssertionError("\n" + "\n".join(diffs)) return options diff --git a/src/snakeoil/test/eq_hash_inheritance.py b/src/snakeoil/test/eq_hash_inheritance.py index 5012f9d..eaa42a3 100644 --- a/src/snakeoil/test/eq_hash_inheritance.py +++ b/src/snakeoil/test/eq_hash_inheritance.py @@ -3,7 +3,7 @@ from . import mixins class Test(mixins.TargetedNamespaceWalker, mixins.KlassWalker): - target_namespace = 'snakeoil' + target_namespace = "snakeoil" singleton = object() @@ -26,8 +26,8 @@ class Test(mixins.TargetedNamespaceWalker, mixins.KlassWalker): # object sets __hash__/__eq__, which isn't usually # intended to be inherited/reused continue - eq = getattr(parent, '__eq__', self.singleton) - h = getattr(parent, '__hash__', self.singleton) + eq = getattr(parent, "__eq__", self.singleton) + h = getattr(parent, "__hash__", self.singleton) if eq == object.__eq__ and h == object.__hash__: continue if eq and h: @@ -37,10 +37,11 @@ class Test(mixins.TargetedNamespaceWalker, mixins.KlassWalker): # pylint: disable=undefined-loop-variable # 'parent' is guaranteed to be defined due to the 'else' clause above - assert getattr(cls, '__hash__') is not None, ( + assert getattr(cls, "__hash__") is not None, ( f"class '{cls.__module__}.{cls.__name__}' had its __hash__ reset, " "while it would've inherited __hash__ from parent " f"'{parent.__module__}.{parent.__name__}'; this occurs in py3k when " "__eq__ is defined alone. If this is desired behaviour, set " "__hash__intentionally_disabled__ to True to explicitly ignore this" - " class") + " class" + ) diff --git a/src/snakeoil/test/mixins.py b/src/snakeoil/test/mixins.py index aa66839..0648de1 100644 --- a/src/snakeoil/test/mixins.py +++ b/src/snakeoil/test/mixins.py @@ -16,23 +16,27 @@ class PythonNamespaceWalker: # This is for py3.2/PEP3149; dso's now have the interp + major/minor embedded # in the name. # TODO: update this for pypy's naming - abi_target = 'cpython-%i%i' % tuple(sys.version_info[:2]) + abi_target = "cpython-%i%i" % tuple(sys.version_info[:2]) - module_blacklist = frozenset({ - 'snakeoil.cli.arghparse', 'snakeoil.pickling', - }) + module_blacklist = frozenset( + { + "snakeoil.cli.arghparse", + "snakeoil.pickling", + } + ) def _default_module_blacklister(self, target): - return target in self.module_blacklist or target.startswith('snakeoil.dist') + return target in self.module_blacklist or target.startswith("snakeoil.dist") def walk_namespace(self, namespace, **kwds): - location = os.path.abspath(os.path.dirname( - self.poor_mans_load(namespace).__file__)) - return self.get_modules(self.recurse(location), namespace=namespace, - **kwds) - - def get_modules(self, feed, namespace=None, blacklist_func=None, - ignore_failed_imports=None): + location = os.path.abspath( + os.path.dirname(self.poor_mans_load(namespace).__file__) + ) + return self.get_modules(self.recurse(location), namespace=namespace, **kwds) + + def get_modules( + self, feed, namespace=None, blacklist_func=None, ignore_failed_imports=None + ): if ignore_failed_imports is None: ignore_failed_imports = self.ignore_all_import_failures if namespace is None: @@ -57,7 +61,7 @@ class PythonNamespaceWalker: raise def recurse(self, location, valid_namespace=True): - if os.path.dirname(location) == '__pycache__': + if os.path.dirname(location) == "__pycache__": # Shouldn't be possible, but make sure we avoid this if it manages # to occur. return @@ -78,10 +82,13 @@ class PythonNamespaceWalker: # file disappeared under our feet... lock file from # trial can cause this. ignore. import logging - logging.debug("file %r disappeared under our feet, ignoring", - os.path.join(location, x)) - seen = set(['__init__']) + logging.debug( + "file %r disappeared under our feet, ignoring", + os.path.join(location, x), + ) + + seen = set(["__init__"]) for x, st in stats: if not (x.startswith(".") or x.endswith("~")) and stat.S_ISREG(st): if x.endswith((".py", ".pyc", ".pyo", ".so")): @@ -89,8 +96,8 @@ class PythonNamespaceWalker: # Ensure we're not looking at a >=py3k .so which injects # the version name in... if y not in seen: - if '.' in y and x.endswith('.so'): - y, abi = x.rsplit('.', 1) + if "." in y and x.endswith(".so"): + y, abi = x.rsplit(".", 1) if abi != self.abi_target: continue seen.add(y) @@ -135,6 +142,7 @@ class TargetedNamespaceWalker(PythonNamespaceWalker): for _mod in self.walk_namespace(namespace): pass + class _classWalker: cls_blacklist = frozenset() @@ -173,7 +181,6 @@ class _classWalker: class SubclassWalker(_classWalker): - def walk_derivatives(self, cls, seen=None): if len(inspect.signature(cls.__subclasses__).parameters) != 0: return @@ -193,7 +200,6 @@ class SubclassWalker(_classWalker): class KlassWalker(_classWalker): - def walk_derivatives(self, cls, seen=None): if len(inspect.signature(cls.__subclasses__).parameters) != 0: return diff --git a/src/snakeoil/test/modules.py b/src/snakeoil/test/modules.py index 6b3a6a1..0ae116b 100644 --- a/src/snakeoil/test/modules.py +++ b/src/snakeoil/test/modules.py @@ -3,12 +3,12 @@ from . import mixins class ExportedModules(mixins.PythonNamespaceWalker): - target_namespace = 'snakeoil' + target_namespace = "snakeoil" def test__all__accuracy(self): failures = [] for module in self.walk_namespace(self.target_namespace): - for target in getattr(module, '__all__', ()): + for target in getattr(module, "__all__", ()): if not hasattr(module, target): failures.append((module, target)) assert not failures, f"nonexistent __all__ targets spotted: {failures}" diff --git a/src/snakeoil/test/slot_shadowing.py b/src/snakeoil/test/slot_shadowing.py index 3e260c2..fac6619 100644 --- a/src/snakeoil/test/slot_shadowing.py +++ b/src/snakeoil/test/slot_shadowing.py @@ -5,7 +5,7 @@ from . import mixins class SlotShadowing(mixins.TargetedNamespaceWalker, mixins.SubclassWalker): - target_namespace = 'snakeoil' + target_namespace = "snakeoil" err_if_slots_is_str = True err_if_slots_is_mutable = True @@ -22,20 +22,20 @@ class SlotShadowing(mixins.TargetedNamespaceWalker, mixins.SubclassWalker): @staticmethod def mk_name(kls): - return f'{kls.__module__}.{kls.__name__}' + return f"{kls.__module__}.{kls.__name__}" def _should_ignore(self, kls): return self.mk_name(kls).split(".")[0] != self.target_namespace def run_check(self, kls): - if getattr(kls, '__slotting_intentionally_disabled__', False): + if getattr(kls, "__slotting_intentionally_disabled__", False): return slotting = {} raw_slottings = {} for parent in self.recurse_parents(kls): - slots = getattr(parent, '__slots__', None) + slots = getattr(parent, "__slots__", None) if slots is None: continue @@ -49,14 +49,15 @@ class SlotShadowing(mixins.TargetedNamespaceWalker, mixins.SubclassWalker): for slot in slots: slotting.setdefault(slot, parent) - slots = getattr(kls, '__slots__', None) + slots = getattr(kls, "__slots__", None) if slots is None and not slotting: return if isinstance(slots, str): if self.err_if_slots_is_str: pytest.fail( - f"cls {kls!r}; slots is {slots!r} (should be a tuple or list)") + f"cls {kls!r}; slots is {slots!r} (should be a tuple or list)" + ) slots = (slots,) if slots is None: @@ -64,8 +65,7 @@ class SlotShadowing(mixins.TargetedNamespaceWalker, mixins.SubclassWalker): if not isinstance(slots, tuple): if self.err_if_slots_is_mutable: - pytest.fail( - f"cls {kls!r}; slots is {slots!r}- - should be a tuple") + pytest.fail(f"cls {kls!r}; slots is {slots!r}- - should be a tuple") slots = tuple(slots) if slots is None or (slots and slots in raw_slottings): @@ -74,9 +74,11 @@ class SlotShadowing(mixins.TargetedNamespaceWalker, mixins.SubclassWalker): # daftly copied the parents... thus defeating the purpose. pytest.fail( f"cls {kls!r}; slots is {slots!r}, seemingly inherited from " - f"{raw_slottings[slots]!r}; the derivative class should be __slots__ = ()") + f"{raw_slottings[slots]!r}; the derivative class should be __slots__ = ()" + ) for slot in slots: if slot in slotting: pytest.fail( - f"cls {kls!r}; slot {slot!r} was already defined at {slotting[slot]!r}") + f"cls {kls!r}; slot {slot!r} was already defined at {slotting[slot]!r}" + ) diff --git a/src/snakeoil/version.py b/src/snakeoil/version.py index fa3fc96..2fd65d2 100644 --- a/src/snakeoil/version.py +++ b/src/snakeoil/version.py @@ -28,43 +28,46 @@ def get_version(project, repo_file, api_version=None): version_info = None if api_version is None: try: - api_version = getattr(import_module(project), '__version__') + api_version = getattr(import_module(project), "__version__") except ImportError: - raise ValueError(f'no {project} module in the syspath') + raise ValueError(f"no {project} module in the syspath") try: - version_info = getattr( - import_module(f'{project}._verinfo'), 'version_info') + version_info = getattr(import_module(f"{project}._verinfo"), "version_info") except ImportError: # we're probably in a git repo path = os.path.dirname(os.path.abspath(repo_file)) version_info = get_git_version(path) if version_info is None: - s = '' - elif version_info['tag'] == api_version: + s = "" + elif version_info["tag"] == api_version: s = f" -- released {version_info['date']}" else: - rev = version_info['rev'][:7] - date = version_info['date'] - commits = version_info.get('commits', None) - commits = f'-{commits}' if commits is not None else '' - s = f'{commits}-g{rev} -- {date}' + rev = version_info["rev"][:7] + date = version_info["date"] + commits = version_info.get("commits", None) + commits = f"-{commits}" if commits is not None else "" + s = f"{commits}-g{rev} -- {date}" - _ver = f'{project} {api_version}{s}' + _ver = f"{project} {api_version}{s}" return _ver def _run_git(path, cmd): env = dict(os.environ) - for key in env.copy(): # pragma: no cover + for key in env.copy(): # pragma: no cover if key.startswith("LC_"): del env[key] env["LC_CTYPE"] = "C" env["LC_ALL"] = "C" r = subprocess.Popen( - ['git'] + list(cmd), stdout=subprocess.PIPE, env=env, - stderr=subprocess.DEVNULL, cwd=path) + ["git"] + list(cmd), + stdout=subprocess.PIPE, + env=env, + stderr=subprocess.DEVNULL, + cwd=path, + ) stdout = r.communicate()[0] return stdout, r.returncode @@ -83,21 +86,20 @@ def get_git_version(path): tag = _get_git_tag(path, data[0]) # get number of commits since most recent tag - stdout, ret = _run_git(path, ['describe', '--tags', '--abbrev=0']) + stdout, ret = _run_git(path, ["describe", "--tags", "--abbrev=0"]) prev_tag = None commits = None if ret == 0: prev_tag = stdout.decode().strip() - stdout, ret = _run_git( - path, ['log', '--oneline', f'{prev_tag}..HEAD']) + stdout, ret = _run_git(path, ["log", "--oneline", f"{prev_tag}..HEAD"]) if ret == 0: commits = len(stdout.decode().splitlines()) return { - 'rev': data[0], - 'date': data[1], - 'tag': tag, - 'commits': commits, + "rev": data[0], + "date": data[1], + "tag": tag, + "commits": commits, } except EnvironmentError as exc: # ENOENT is thrown when the git binary can't be found. @@ -107,14 +109,14 @@ def get_git_version(path): def _get_git_tag(path, rev): - stdout, _ = _run_git(path, ['name-rev', '--tag', rev]) + stdout, _ = _run_git(path, ["name-rev", "--tag", rev]) tag = stdout.decode().split() if len(tag) != 2: return None tag = tag[1] if not tag.startswith("tags/"): return None - tag = tag[len("tags/"):] + tag = tag[len("tags/") :] if tag.endswith("^0"): tag = tag[:-2] if tag.startswith("v"): diff --git a/src/snakeoil/weakrefs.py b/src/snakeoil/weakrefs.py index 1fb7e0b..83a79c0 100644 --- a/src/snakeoil/weakrefs.py +++ b/src/snakeoil/weakrefs.py @@ -13,7 +13,6 @@ def finalize_instance(obj, weakref_inst): class WeakRefProxy(BaseDelayedObject): - def __instantiate_proxy_instance__(self): obj = BaseDelayedObject.__instantiate_proxy_instance__(self) weakref = ref(self, partial(finalize_instance, obj)) |