Source code for src.util.dataclass

"""Extensions to :py:mod:`dataclasses`, for streamlined class definition.
"""
import collections
import copy
import dataclasses
import enum
import functools
import re
import typing
from . import basic
from . import exceptions

import logging
_log = logging.getLogger(__name__)

class RegexPatternBase:
    """Empty common parent of :class:`RegexPattern` and
    :class:`ChainedRegexPattern`; it defines no attributes or methods.
    """
[docs]class RegexPattern(collections.UserDict, RegexPatternBase): """Wraps :py:class:`re.Pattern` with more convenience methods. Extracts values of named fields from a string by parsing it with a regex with named capture groups, and stores those values in a dict. """ def __init__(self, regex, defaults=None, input_field=None, match_error_filter=None): """Constructor. Args: regex: str or :py:class:`re.Pattern`: regex to use for string parsing. Should contain named match groups corresponding to the fields to parse. defaults: dict, optional. If supplied, any fields not matched by the regex will be set equal to their values here. input_field: str, optional. If supplied, add a field to the match with the supplied name which will be set equal to the contents of the input string on a successful match. match_error_filter: optional, bool or :class:`RegexPattern` or :class:`ChainedRegexPattern`. If supplied, suppresses raising ValueErrors when match() fails. If boolean or none, either always or never raise ValueError. If a RegexPattern, try matching the input string that caused the failed match against it. If it matches, do not raise an error. Attributes: data: dict, either empty when unmatched, or containing the contents of the match. From :py:class:`collections.UserDict`. fields: frozenset of fields matched by the pattern. Consists of the *union* of named match groups in regex, and *all* keys in defaults. input_string: Contains string that was input to last call of match(), whether successful or not. 
""" try: if isinstance(regex, re.Pattern): self.regex = regex else: self.regex = re.compile(regex, re.VERBOSE) except re.error as exc: raise ValueError('Malformed input regex.') from exc if self.regex.groups != len(self.regex.groupindex): # _log.warning("Unnamed match groups in regex") pass if self.regex.groups == 0: # _log.warning("No named match groups in regex") pass if not defaults: self._defaults = dict() else: self._defaults = defaults.copy() self.input_field = input_field self._match_error_filter = match_error_filter self._update_fields()
[docs] def clear(self): """Erase an existing match. """ self.data = dict() self.input_string = "" self.is_matched = False
[docs] def _update_fields(self): self.regex_fields = frozenset(self.regex.groupindex.keys()) self.fields = self.regex_fields.union(self._defaults.keys()) if self.input_field: self.fields = self.fields.union((self.input_field, )) self.clear()
[docs] def update_defaults(self, d): """Update the default values used for the match with the values in d. """ if d: self._defaults.update(d) self._update_fields()
[docs] def match(self, str_, *args): self.clear() # to be safe self.input_string = str_ m = self.regex.fullmatch(str_, *args) if not m: self.is_matched = False if hasattr(self._match_error_filter, 'match'): try: self._match_error_filter.match(str_, *args) except Exception as exc: raise exceptions.RegexParseError( f"Couldn't match {str_} against {self.regex}.") raise exceptions.RegexSuppressedError(str_) elif self._match_error_filter: raise exceptions.RegexSuppressedError(str_) else: raise exceptions.RegexParseError( f"Couldn't match {str_} against {self.regex}.") else: self.data = m.groupdict(default=NOTSET) for k,v in self._defaults.items(): if self.data.get(k, NOTSET) is NOTSET: self.data[k] = v if self.input_field: self.data[self.input_field] = m.string self._validate_match(m) if any(self.data[f] is NOTSET for f in self.fields): bad_names = [f for f in self.fields if self.data[f] is NOTSET] raise exceptions.RegexParseError((f"Couldn't match the " f"following fields in {str_}: " + ', '.join(bad_names) )) self.is_matched = True
[docs] def _validate_match(self, match_obj): """Hook for post-processing of match, running after all fields are assigned but before final check that all fields are set. """ pass
def __str__(self): if not self.is_matched: str_ = ', '.join(self.fields) else: str_ = ', '.join([f'{k}={v}' for k,v in self.data.items()]) return f"<{self.__class__.__name__}({str_})>" def __copy__(self): if hasattr(self._match_error_filter, 'copy'): match_error_filter_copy = self._match_error_filter.copy() else: # bool or None match_error_filter_copy = self._match_error_filter obj = self.__class__( self.regex.pattern, defaults=self._defaults.copy(), input_field=self.input_field, match_error_filter=match_error_filter_copy, ) obj.data = self.data.copy() return obj def __deepcopy__(self, memo): obj = self.__class__( copy.deepcopy(self.regex.pattern, memo), defaults=copy.deepcopy(self._defaults, memo), input_field=copy.deepcopy(self.input_field, memo), match_error_filter=copy.deepcopy(self._match_error_filter, memo) ) obj.data = copy.deepcopy(self.data, memo) return obj
class RegexPatternWithTemplate(RegexPattern):
    """Adds formatted output to RegexPattern.

    Args:
        template: str, optional. Template string to use for formatting
            contents of match in format() method. Contents of the matched
            fields will be substituted using the {}-syntax of python string
            formatting.
        log: logger used to warn about fields whose name doesn't appear in
            *template*; defaults to the module-level logger.

    Other arguments are the same as for :class:`RegexPattern`.
    """

    def __init__(self, regex, defaults=None, input_field=None,
                 match_error_filter=None, template=None, log=_log):
        super().__init__(regex, defaults=defaults, input_field=input_field,
                         match_error_filter=match_error_filter)
        self.template = template
        # template is optional: only check field coverage when one is given
        # (a membership test against None would raise TypeError).
        if template is not None:
            for f in self.fields:
                if f not in template:
                    log.warning("Field %s not included in output.", f)

    def format(self):
        """Return the template filled in with the values of the current match.

        Raises:
            AssertionError: if no template was supplied.
            ValueError: if there is no successful match to format.
        """
        if self.template is None:
            raise AssertionError('Template string needs to be defined.')
        if not self.is_matched:
            raise ValueError('No match')
        return self.template.format(**self.data)

    def __copy__(self):
        if hasattr(self._match_error_filter, 'copy'):
            match_error_filter_copy = self._match_error_filter.copy()
        else:
            # bool or None
            match_error_filter_copy = self._match_error_filter
        # Pass the compiled pattern, not its source string: re-compiling the
        # string would force re.VERBOSE and lose the original flags.
        obj = self.__class__(
            self.regex,
            defaults=self._defaults.copy(),
            input_field=self.input_field,
            match_error_filter=match_error_filter_copy,
            template=self.template
        )
        obj.data = self.data.copy()
        # the constructor resets match state; carry it over with the data
        obj.input_string = self.input_string
        obj.is_matched = self.is_matched
        return obj

    def __deepcopy__(self, memo):
        # Compiled patterns are immutable, so the same object is reused.
        obj = self.__class__(
            self.regex,
            defaults=copy.deepcopy(self._defaults, memo),
            input_field=copy.deepcopy(self.input_field, memo),
            match_error_filter=copy.deepcopy(self._match_error_filter, memo),
            template=copy.deepcopy(self.template, memo)
        )
        obj.data = copy.deepcopy(self.data, memo)
        obj.input_string = self.input_string
        obj.is_matched = self.is_matched
        return obj
class ChainedRegexPattern(RegexPatternBase):
    """Class which takes an 'or' of multiple RegexPatterns.

    Matches are attempted on the supplied RegexPatterns in order, with the
    first one that succeeds determining the returned answer. Public methods
    work the same as on RegexPattern.
    """

    def __init__(self, *string_patterns, defaults=None, input_field=None,
                 match_error_filter=None):
        """Constructor.

        Args:
            *string_patterns: :class:`RegexPattern` or
                :class:`ChainedRegexPattern` instances, in the order they
                should be tried. Nested chains are flattened.
            defaults: dict, optional. Merged into the defaults of every
                pattern.
            input_field: str, optional. If supplied, overrides the
                ``input_field`` of every pattern.
            match_error_filter: same meaning as for :class:`RegexPattern`,
                applied to the chain as a whole.

        Raises:
            ValueError: if an argument is neither kind of pattern, or if the
                patterns don't all define the same set of fields.
        """
        # NB, changes attributes on patterns passed as arguments, so
        # once created they can't be used on their own
        new_pats = []
        for pat in string_patterns:
            if isinstance(pat, RegexPattern):
                new_pats.append(pat)
            elif isinstance(pat, ChainedRegexPattern):
                new_pats.extend(pat._patterns)
            else:
                raise ValueError("Bad input")
        # use the flattened list, so every entry is a plain RegexPattern
        self._patterns = tuple(new_pats)
        self.input_field = input_field
        self._match_error_filter = match_error_filter
        for pat in self._patterns:
            if defaults:
                pat.update_defaults(defaults)
            if input_field:
                pat.input_field = input_field
            # error suppression is decided by the chain, not its members
            pat._match_error_filter = None
            pat._update_fields()
        self._update_fields()

    @property
    def is_matched(self):
        return (self._match >= 0)

    @property
    def data(self):
        if self.is_matched:
            return self._patterns[self._match].data
        else:
            return dict()

    def clear(self):
        """Erase existing matches on the chain and on every pattern in it.
        """
        for pat in self._patterns:
            pat.clear()
        self._match = -1
        self.input_string = ""

    def _update_fields(self):
        """Check that all patterns share one set of fields, then erase any
        existing match.
        """
        self.fields = self._patterns[0].fields
        for pat in self._patterns:
            if pat.fields != self.fields:
                raise ValueError("Incompatible fields.")
        self.clear()

    def update_defaults(self, d):
        """Update the default values of every pattern with the values in d.
        """
        if d:
            for pat in self._patterns:
                pat.update_defaults(d)
            self._update_fields()

    def match(self, str_, *args):
        """Try each pattern in order on *str_*, stopping at the first success.

        Raises:
            RegexParseError: if no pattern matches (and the failure is not
                suppressed by ``match_error_filter``).
            RegexSuppressedError: if no pattern matches, but
                ``match_error_filter`` is truthy or itself matches the string.
        """
        self.clear()
        self.input_string = str_
        for i, pat in enumerate(self._patterns):
            try:
                pat.match(str_, *args)
                if not pat.is_matched:
                    raise ValueError()
            except ValueError:
                continue
            # first success wins; later patterns are left untouched (cleared)
            self._match = i
            break
        if not self.is_matched:
            if hasattr(self._match_error_filter, 'match'):
                try:
                    self._match_error_filter.match(str_, *args)
                except Exception as exc:
                    raise exceptions.RegexParseError((f"Couldn't match {str_} "
                        f"against any pattern in {self.__class__.__name__}.")
                    ) from exc
                raise exceptions.RegexSuppressedError(str_)
            elif self._match_error_filter:
                raise exceptions.RegexSuppressedError(str_)
            else:
                raise exceptions.RegexParseError((f"Couldn't match {str_} "
                    f"against any pattern in {self.__class__.__name__}."))

    def __str__(self):
        if not self.is_matched:
            str_ = ', '.join(self.fields)
        else:
            str_ = ', '.join([f'{k}={v}' for k, v in self.data.items()])
        return f"<{self.__class__.__name__}({str_})>"

    def format(self):
        """Return the formatted output of the pattern that matched.

        Raises:
            ValueError: if there is no successful match to format.
        """
        if not self.is_matched:
            raise ValueError('No match')
        return self._patterns[self._match].format()

    def __copy__(self):
        new_pats = (pat.copy() for pat in self._patterns)
        filter_ = self._match_error_filter
        # the filter may be None or a bool, neither of which has copy()
        if hasattr(filter_, 'copy'):
            filter_copy = filter_.copy()
        else:
            filter_copy = copy.copy(filter_)
        return self.__class__(
            *new_pats,
            match_error_filter=filter_copy
        )

    def __deepcopy__(self, memo):
        new_pats = (copy.deepcopy(pat, memo) for pat in self._patterns)
        return self.__class__(
            *new_pats,
            match_error_filter=copy.deepcopy(self._match_error_filter, memo)
        )
# --------------------------------------------------------- NOTSET = basic.sentinel_object_factory('NotSet') NOTSET.__doc__ = """ Sentinel object to detect uninitialized values, in cases where ``None`` is a valid value. """ MANDATORY = basic.sentinel_object_factory('Mandatory') MANDATORY.__doc__ = """ Sentinel object to mark :func:`mdtf_dataclass` fields that do not take a default value. This is a workaround to avoid errors with non-default fields coming after default fields in the dataclass-generated ``__init__`` method under `inheritance <https://docs.python.org/3/library/dataclasses.html#inheritance>`__: we use the second solution described in `<https://stackoverflow.com/a/53085935>`__. """
def _mdtf_dataclass_get_field_types(obj, f):
    """Common functionality for :func:`_mdtf_dataclass_type_coercion` and
    :func:`_mdtf_dataclass_type_check`.

    Args:
        obj: dataclass instance the field belongs to.
        f: :py:class:`dataclasses.Field` object to examine.

    Returns:
        Either a tuple ``(new_type, valid_types)`` of the type the field's
        value should be coerced to (None if coercion isn't possible) and a
        list of the valid types its value can have, or ``(None, None)`` to
        signal a case we don't handle.

    Raises:
        DataclassParseError: if the field's type is a dataclass that *obj*
            isn't an instance of.
    """
    if not f.init:
        # ignore fields that aren't handled at init
        return (None, None)

    value = getattr(obj, f.name)
    # ignore unset field values, regardless of type
    if value is None or value is NOTSET:
        return (None, None)

    # guess what types are valid
    new_type = None
    if f.type is typing.Any or isinstance(f.type, typing.TypeVar):
        return (None, None)
    if dataclasses.is_dataclass(f.type):
        # ignore if type is a dataclass: use this type annotation to
        # implement dataclass inheritance
        if not isinstance(obj, f.type):
            raise exceptions.DataclassParseError((f"Field {f.name} specified "
                f"as dataclass {f.type.__name__}, which isn't a parent class "
                f"of {obj.__class__.__name__}."))
        return (None, None)
    elif isinstance(f.type, (typing._GenericAlias, typing._SpecialForm)):
        # type is a generic from typing module, eg "typing.List"
        origin = getattr(f.type, '__origin__', None)
        if origin is None:
            # special form without an origin: nothing to check
            return (None, None)
        if origin is typing.Union:
            new_type = None  # can't do coercion, but can test type
            valid_types = list(f.type.__args__)
        elif not isinstance(origin, type) \
            or issubclass(origin, typing.Generic):
            # origin isn't a class (issubclass() would raise TypeError) or is
            # a user-defined generic
            return (None, None)  # can't do anything in this case
        else:
            new_type = origin
            valid_types = [new_type]
    else:
        new_type = f.type
        valid_types = [new_type]

    # Get types of field's default value, if present. Dataclass doesn't
    # require defaults to be same type as what's given for field.
    if not isinstance(f.default, dataclasses._MISSING_TYPE):
        valid_types.append(type(f.default))
    if not isinstance(f.default_factory, dataclasses._MISSING_TYPE):
        # the factory has to be called to learn the type it produces
        valid_types.append(type(f.default_factory()))
    return (new_type, valid_types)
def _mdtf_dataclass_type_coercion(self, log):
    """Do type coercion on all dataclass fields after the auto-generated
    ``__init__`` method, but before any ``__post_init__`` method.

    Fields whose value is already one of the valid types, or for which no
    target type can be determined, are left untouched. Otherwise the value is
    converted with the target type's ``from_struct`` method if it has one, by
    member name (falling back to member value) if the target is an
    :py:class:`enum.Enum` class, or by calling the target type on the value.

    Args:
        self: dataclass instance being initialized.
        log: logger used to report unexpected exceptions.

    Raises:
        DataclassParseError: if the conversion raises TypeError, ValueError or
            :py:class:`dataclasses.FrozenInstanceError`.

    .. warning::
       Type checking logic used is specific to the ``typing`` module in python
       3.7. It may or may not work on newer pythons, and definitely will not
       work with 3.5 or 3.6. See `<https://stackoverflow.com/a/52664522>`__.
    """
    for f in dataclasses.fields(self):
        value = getattr(self, f.name, NOTSET)
        new_type, valid_types = _mdtf_dataclass_get_field_types(self, f)
        try:
            if valid_types is None or isinstance(value, tuple(valid_types)):
                continue  # don't coerce if we're already a valid type
            # NOTE(review): ABCs record their abstract methods under
            # ``__abstractmethods__``; confirm the attribute name tested here
            # is the intended one.
            if new_type is None or hasattr(new_type, '__abstract_methods__'):
                continue  # can't do type coercion
            if hasattr(new_type, 'from_struct'):
                new_value = new_type.from_struct(value)
            elif isinstance(new_type, type) and issubclass(new_type, enum.Enum):
                # new_type is the Enum *class*, so it must be tested with
                # issubclass(). Use item syntax to create enum from name, and
                # fall back to lookup by value if there's no such name.
                try:
                    new_value = new_type[value]
                except (KeyError, TypeError):
                    new_value = new_type(value)
            else:
                new_value = new_type(value)
            # https://stackoverflow.com/a/54119384 for implementation
            object.__setattr__(self, f.name, new_value)
        except (TypeError, ValueError, dataclasses.FrozenInstanceError) as exc:
            raise exceptions.DataclassParseError((f"{self.__class__.__name__}: "
                f"Couldn't coerce value {repr(value)} for field {f.name} from "
                f"type {type(value)} to type {new_type}.")) from exc
        except Exception as exc:
            log.exception("%s: Caught exception: %r", self.__class__.__name__, exc)
            raise
def _mdtf_dataclass_type_check(self, log):
    """Do type checking on all dataclass fields after ``__init__`` and
    ``__post_init__`` methods.

    Fields whose value is None or ``NOTSET`` are skipped.

    Args:
        self: dataclass instance to check.
        log: logger used to report failed type checks.

    Raises:
        DataclassParseError: if a field is still set to ``MANDATORY``, or if a
            field's value isn't an instance of any of its valid types.

    .. warning::
       Type checking logic used is specific to the ``typing`` module in python
       3.7. It may or may not work on newer pythons, and definitely will not
       work with 3.5 or 3.6. See `<https://stackoverflow.com/a/52664522>`__.
    """
    for f in dataclasses.fields(self):
        value = getattr(self, f.name, NOTSET)
        if value is None or value is NOTSET:
            continue
        if value is MANDATORY:
            raise exceptions.DataclassParseError((f"{self.__class__.__name__}: "
                f"No value supplied for mandatory field {f.name}."))

        _, valid_types = _mdtf_dataclass_get_field_types(self, f)
        if valid_types is not None and not isinstance(value, tuple(valid_types)):
            # No exception is being handled here, so log with error() rather
            # than exception(), which would append an empty traceback.
            log.error("%s: Failed type check for field '%s': %s != %s.",
                self.__class__.__name__, f.name, type(value), valid_types)
            raise exceptions.DataclassParseError((f"{self.__class__.__name__}: "
                f"Expected {f.name} to be {f.type}, got {type(value)} "
                f"({repr(value)})."))
# Keyword arguments passed to dataclasses.dataclass() by the decorators in
# this module, unless overridden by the caller.
DEFAULT_MDTF_DATACLASS_KWARGS = dict(
    init=True,
    repr=True,
    eq=True,
    order=False,
    unsafe_hash=False,
    frozen=False,
)

# declaration to allow calling with and without args: python cookbook 9.6
# https://github.com/dabeaz/python-cookbook/blob/master/src/9/defining_a_decorator_that_takes_an_optional_argument/example.py
def mdtf_dataclass(cls=None, **deco_kwargs):
    """Class decorator wrapping :py:func:`~dataclasses.dataclass` to add
    (very) rudimentary type checking and conversion.

    This is hacky, since dataclasses don't enforce type annotations for their
    fields. A better solution would be to use a deserialization library like
    pydantic.

    May be applied with or without arguments; keyword arguments override the
    entries of ``DEFAULT_MDTF_DATACLASS_KWARGS`` and are passed on to
    :py:func:`~dataclasses.dataclass`.

    The class' ``__post_init__`` is wrapped so that the following happens
    after the auto-generated ``__init__``:

    1. Each field's value is checked against known type info, and coerced to
       that type if it doesn't match (using a ``from_struct`` method if one
       exists). This runs *before* the class' own ``__post_init__``.
    2. The class' own ``__post_init__`` runs (a do-nothing one is supplied if
       the class has none).
    3. Mandatory fields are verified to have values, and all field types are
       checked again. DataclassParseError is raised on failure.

    Mandatory fields can't be declared the usual
    :py:func:`~dataclasses.dataclass` way, because under inheritance that
    leads to errors in the signature of the generated ``__init__`` (mandatory
    fields can't come after optional fields.)

    .. warning::
       Unlike :py:func:`~dataclasses.dataclass`, all fields **must** have a
       *default* or *default_factory* defined. Fields which are mandatory must
       have their default value set to the sentinel object ``MANDATORY``.
    """
    dc_kwargs = {**DEFAULT_MDTF_DATACLASS_KWARGS, **deco_kwargs}
    if cls is None:
        # called without arguments
        return functools.partial(mdtf_dataclass, **dc_kwargs)

    if not hasattr(cls, '__post_init__'):
        # create dummy __post_init__ if none defined, so we can wrap it
        # contrast with what we do below in regex_dataclass()
        def _dummy_post_init(self, *args, **kwargs):
            pass
        type.__setattr__(cls, '__post_init__', _dummy_post_init)

    # apply dataclasses' decorator
    cls = dataclasses.dataclass(cls, **dc_kwargs)

    user_post_init = cls.__post_init__

    @functools.wraps(user_post_init)
    def _new_post_init(self, *args, **kwargs):
        # use the instance's logger if it has one (object hierarchy),
        # otherwise fall back to the module-level logger
        _post_init_log = getattr(self, 'log', _log)
        # coercion comes before the user's __post_init__, checking after it
        _mdtf_dataclass_type_coercion(self, _post_init_log)
        user_post_init(self, *args, **kwargs)
        _mdtf_dataclass_type_check(self, _post_init_log)

    type.__setattr__(cls, '__post_init__', _new_post_init)
    return cls
def is_regex_dataclass(obj):
    """Return whether *obj* has a ``_is_regex_dataclass`` attribute that
    compares equal to True (the marker set by :func:`regex_dataclass`).
    """
    if not hasattr(obj, '_is_regex_dataclass'):
        return False
    return obj._is_regex_dataclass == True
def _regex_dataclass_preprocess_kwargs(self, kwargs):
    """Prepare the kwargs for the auto-generated ``__init__`` of this
    dataclass.

    For every direct parent class that is a regex_dataclass, each field of
    *self* typed as that parent has its value parsed with the parent's
    ``from_string``, and the parsed field values are merged into the result.
    Values are collected in a :class:`~src.util.basic.ConsistentDict`.

    Args:
        self: regex_dataclass instance being initialized.
        kwargs: dict of field values supplied to the constructor.

    Returns:
        Tuple ``(new_kw, post_init)``: values for fields that are arguments
        to ``__init__``, and values for fields declared with ``init=False``.

    Raises:
        DataclassParseError: if no value can be found for a field to parse, or
            if merging parsed values raises WormKeyError (different
            regex_dataclasses assigning different values to the same field).
    """
    merged = basic.ConsistentDict.from_struct(
        filter_dataclass(kwargs, self, init='all')
    )
    own_fields = dataclasses.fields(self)
    for cls_ in self.__class__.__bases__:
        if not is_regex_dataclass(cls_):
            continue
        for f in own_fields:
            if not f.type == cls_:
                continue
            # value to parse: explicit argument first, then field defaults
            if f.name in kwargs:
                val = kwargs[f.name]
            elif not isinstance(f.default, dataclasses._MISSING_TYPE):
                val = f.default
            elif not isinstance(f.default_factory, dataclasses._MISSING_TYPE):
                val = f.default_factory()
            else:
                raise exceptions.DataclassParseError(f"Can't set value for {f.name}.")
            parsed = filter_dataclass(
                dataclasses.asdict(f.type.from_string(val)), self, init='all'
            )
            try:
                merged.update(parsed)
            except exceptions.WormKeyError as exc:
                raise exceptions.DataclassParseError((f"{self.__class__.__name__}: "
                    f"Tried to make inconsistent field assignment when parsing "
                    f"{f.name} as an instance of {f.type.__name__}.")) from exc
    # fields with init=False can't be passed to __init__; hand them back
    # separately
    post_init = {
        f.name: merged.pop(f.name)
        for f in dataclasses.fields(self)
        if not f.init and f.name in merged
    }
    return (merged, post_init)
def regex_dataclass(pattern, **deco_kwargs):
    """Decorator for a dataclass that adds a ``from_string`` classmethod,
    which creates instances of that dataclass by parsing an input string with
    a :class:`RegexPattern` or :class:`ChainedRegexPattern`.

    The values of all fields returned by the match() method of *pattern* are
    passed to the ``__init__`` method of the dataclass as kwargs. The wrapped
    ``__init__`` does the same parsing when called with a single string
    argument.

    Additionally, if the type of one or more fields is set to a class that's
    also been decorated with regex_dataclass, the parsing logic for that
    field's regex_dataclass will be invoked on that field's value (ie, a
    string obtained by regex matching in *this* regex_dataclass), and the
    parsed values of those fields will be supplied to this regex_dataclass
    constructor. This is our implementation of composition for
    regex_dataclasses.

    Args:
        pattern: pattern object stored on the class as ``_pattern``.
        **deco_kwargs: override the entries of
            ``DEFAULT_MDTF_DATACLASS_KWARGS`` passed to
            :py:func:`~dataclasses.dataclass`.

    Raises:
        TypeError: (when decorating) if a field's type is a regex_dataclass
            that the decorated class doesn't inherit from.

    .. note::
       Unlike :func:`mdtf_dataclass`, type coercion is done *after*
       ``__post_init__`` for these dataclasses. This is necessary due to
       composition: if a regex_dataclass is being instantiated as a field of
       another regex_dataclass, all values being passed to it will be strings
       (the regex fields), and type coercion is the job of ``__post_init__``.
    """
    dc_kwargs = {**DEFAULT_MDTF_DATACLASS_KWARGS, **deco_kwargs}

    def _dataclass_decorator(cls):
        if '__post_init__' not in cls.__dict__:
            # Prevent class from inheriting __post_init__ from parents if it
            # doesn't overload it (which is why we use __dict__ and not
            # hasattr().) __post_init__ of all parents will have been called
            # when the parent classes are instantiated by
            # _regex_dataclass_preprocess_kwargs.
            def _dummy_post_init(self, *args, **kwargs):
                pass
            type.__setattr__(cls, '__post_init__', _dummy_post_init)

        # apply dataclasses' decorator
        cls = dataclasses.dataclass(cls, **dc_kwargs)

        # check that all DCs specified as fields are also in class hierarchy
        # so that we inherit their fields; probably no way this could happen
        # though
        for f in dataclasses.fields(cls):
            if is_regex_dataclass(f.type) and f.type not in cls.__mro__:
                raise TypeError((f"{cls.__name__}: Field {f.name} specified as "
                    f"{f.type.__name__}, but we don't inherit from it."))

        generated_init = cls.__init__

        @functools.wraps(generated_init)
        def _new_init(self, first_arg=None, *args, **kwargs):
            if isinstance(first_arg, str) and not args and not kwargs:
                # instantiate from running regex on string, if a string is the
                # only argument to the constructor
                self._pattern.match(first_arg)
                first_arg = None
                kwargs = self._pattern.data
            init_kw, non_init_kw = _regex_dataclass_preprocess_kwargs(self, kwargs)
            for name, val in non_init_kw.items():
                # set field values that aren't arguments to the generated init
                object.__setattr__(self, name, val)
            # only pass first_arg along positionally if it was actually given
            positional = args if first_arg is None else (first_arg, *args)
            generated_init(self, *positional, **init_kw)
            _mdtf_dataclass_type_coercion(self, _log)
            _mdtf_dataclass_type_check(self, _log)

        def _from_string(cls_, str_, *args):
            cls_._pattern.match(str_, *args)
            return cls_(**cls_._pattern.data)

        for attr_name, attr_value in (
            ('__init__', _new_init),
            ('from_string', classmethod(_from_string)),
            ('_is_regex_dataclass', True),
            ('_pattern', pattern),
        ):
            type.__setattr__(cls, attr_name, attr_value)
        return cls

    return _dataclass_decorator
def dataclass_factory(dataclass_decorator, class_name, *parents, **kwargs):
    """Function that returns a dataclass (ie, a decorated class) whose fields
    are the union of the fields specified in its parent classes.

    The new class also gets a ``to_dataclass`` method, a ``from_dataclasses``
    classmethod, and one ``to_<ParentName>`` method per parent.

    Args:
        dataclass_decorator: decorator to apply to the new class.
        class_name: name of the new class.
        parents: collection of other mdtf_dataclasses to inherit from. Order
            in the collection determines the MRO.
        kwargs: optional; arguments to pass to dataclass_decorator when it's
            applied to produce the returned class.

    Returns:
        The result of applying *dataclass_decorator* to the new class.
    """
    def _to_dataclass(self, cls_, **kwargs_):
        new_kwargs = filter_dataclass(self, cls_)
        new_kwargs.update(kwargs_)
        return cls_(**new_kwargs)

    # An f-string can't serve as a docstring (it's just an expression
    # statement), so the class-specific docs are attached explicitly.
    _to_dataclass.__doc__ = (
        f"Method to create an instance of one of the parent classes of "
        f"{class_name} by copying over the relevant subset of fields."
    )

    def _from_dataclasses(cls_, *other_dcs, **kwargs_):
        new_kwargs = dict()
        for dc in other_dcs:
            new_kwargs.update(filter_dataclass(dc, cls_))
        new_kwargs.update(kwargs_)
        return cls_(**new_kwargs)

    _from_dataclasses.__doc__ = (
        f"Classmethod to create a new instance of {class_name} from instances "
        f"of its parents, along with any other field values passed in kwargs."
    )

    methods = {
        'to_dataclass': _to_dataclass,
        'from_dataclasses': classmethod(_from_dataclasses),
    }
    for dc in parents:
        # convenience wrapper: to_<ParentName>() == to_dataclass(<Parent>)
        method_nm = 'to_' + dc.__name__
        methods[method_nm] = functools.partialmethod(_to_dataclass, cls_=dc)
    new_cls = type(class_name, tuple(parents), methods)
    return dataclass_decorator(new_cls, **kwargs)
# ----------------------------------------------------
def filter_dataclass(d, dc, init=False):
    """Return a dict of the subset of fields or entries in d that correspond
    to the fields in dataclass dc.

    Args:
        d: (dict, dataclass or dataclass instance): source of values. A
            dataclass (the class itself) is instantiated with default field
            values first.
        dc: (dataclass or dataclass instance): determines which keys to keep.
        init: bool or 'all', default False:

            - If False: Include only the fields of dc (as returned by
              :py:func:`dataclasses.fields`.)
            - If True: Include only the arguments to dc's constructor (ie,
              include any `init-only fields
              <https://docs.python.org/3/library/dataclasses.html#init-only-variables>`__
              and exclude any of dc's fields with init=False.
            - If 'all': Include the union of the above two options.

    Returns:
        dict containing the subset of key:value pairs from d such that the
        keys are included in the set of dc's fields specified by the value of
        init.
    """
    assert dataclasses.is_dataclass(dc)
    if dataclasses.is_dataclass(d):
        if isinstance(d, type):
            d = d()  # d is a class; instantiate with default field values
        d = dataclasses.asdict(d)

    if not init or (init == 'all'):
        ans = {f.name: d[f.name] for f in dataclasses.fields(dc)
               if f.name in d}
    else:
        ans = {f.name: d[f.name] for f in dataclasses.fields(dc)
               if (f.name in d and f.init)}
    if init:
        # init is True or 'all': add init-only pseudo-fields, which
        # dataclasses.fields() leaves out. Depending on the python version, a
        # parametrized annotation (InitVar[int]) is stored either as the
        # InitVar class itself or as an InitVar instance, so test for both.
        init_fields = [
            f for f in dc.__dataclass_fields__.values()
            if f.type == dataclasses.InitVar
            or isinstance(f.type, dataclasses.InitVar)
        ]
        ans.update({f.name: d[f.name] for f in init_fields if f.name in d})
    return ans
def coerce_to_dataclass(d, dc, **kwargs):
    """Build an instance of dataclass *dc* from the values in *d*.

    Args:
        d: dict, dataclass or dataclass instance supplying field values.
        dc: target dataclass; may be the class or an instance of it.
        **kwargs: extra field values, taking precedence over those in *d*.

    Returns:
        Instance of dc's class, initialized from the entries of *d* and
        *kwargs* that are arguments to its constructor.
    """
    init_values = filter_dataclass(d, dc, init=True)
    if kwargs:
        # merge the overrides, then drop any that aren't constructor arguments
        init_values = filter_dataclass(
            {**init_values, **kwargs}, dc, init=True
        )
    target_cls = dc if isinstance(dc, type) else dc.__class__
    return target_cls(**init_values)