Source code for src.util.tests.test_dataclass

import unittest
import unittest.mock as mock
import dataclasses
import typing
from src.util import basic, exceptions
from src.util import dataclass as util
from src.util import datelabel as dt # only used to construct one test instance


class TestRegexPattern(unittest.TestCase):
    """Tests for ``util.RegexPattern`` combined with ``util.regex_dataclass``."""
    # TODO: need many more tests for RegexPattern et al.

    def test_regex_dataclass(self):
        """Match named groups and build dataclass instances from strings.

        Checks that ``match`` exposes the captured groups as strings on
        ``data``, that ``from_string`` yields instances whose fields compare
        equal to ints, and that building a second instance leaves the first
        one unchanged.
        """
        pattern = util.RegexPattern(r"/(?P<foo>\d+)/(?P<bar>\d+)/other_text")

        # Field names mirror the named groups of the pattern above.
        @util.regex_dataclass(pattern)
        class A():
            foo: int
            bar: int

        # Direct use of the pattern: captured values are still strings.
        pattern.match('/123/456/other_text')
        self.assertDictEqual(pattern.data, {'foo': '123', 'bar': '456'})

        first = A.from_string('/1/2/other_text')
        self.assertEqual((first.foo, first.bar), (1, 2))

        second = A.from_string('/3/4/other_text')
        # Re-check the first instance: the second parse must not alter it.
        self.assertEqual((first.foo, first.bar), (1, 2))
        self.assertEqual((second.foo, second.bar), (3, 4))
class TestRegexDataclassInheritance(unittest.TestCase):
    """Tests for ``util.regex_dataclass`` classes that subclass one another."""

    def test_initvar(self):
        """Check a child regex dataclass whose parent uses an ``InitVar``.

        The parent turns the ``global_mean`` init-only value into the
        ``spatial_avg`` field inside ``__post_init__``. The child parses a
        directory string and declares a ``grid_label`` field annotated with
        the parent class; the test asserts that the child's ``asdict`` output
        contains the parent's fields alongside its own.
        """
        label_pattern = util.RegexPattern(
            r""" g(?P<global_mean>m?)(?P<grid_number>\d?) """,
            input_field="grid_label"
        )

        @util.regex_dataclass(label_pattern)
        class CMIP6_GridLabel():
            grid_label: str = util.MANDATORY
            global_mean: dataclasses.InitVar = ""
            grid_number: int = 0
            spatial_avg: str = dataclasses.field(init=False)

            def __post_init__(self, global_mean=None):
                # spatial_avg is derived from the init-only flag.
                self.spatial_avg = (
                    'global_mean' if global_mean == 'm' else None
                )

        directory_pattern = util.RegexPattern(
            r""" /?(CMIP6/)?(?P<activity_id>\w+)/(?P<grid_label>\w+)/ """,
            input_field="directory"
        )

        @util.regex_dataclass(directory_pattern)
        class CMIP6_DRSDirectory(CMIP6_GridLabel):
            directory: str = ""
            activity_id: str = ""
            grid_label: CMIP6_GridLabel = ""

        expected_label = {
            'grid_label': 'gm6',
            'grid_number': 6,
            'spatial_avg': 'global_mean',
        }
        label = CMIP6_GridLabel('gm6')
        self.assertDictEqual(dataclasses.asdict(label), expected_label)

        expected_directory = {
            'grid_label': 'gm6',
            'grid_number': 6,
            'spatial_avg': 'global_mean',
            'directory': '/CMIP6/bazinga/gm6/',
            'activity_id': 'bazinga',
        }
        drs_dir = CMIP6_DRSDirectory('/CMIP6/bazinga/gm6/')
        self.assertDictEqual(dataclasses.asdict(drs_dir), expected_directory)

    def test_conflicts(self):
        """Check a child of two parents that share field names.

        Both parents define ``grid_number`` and ``spatial_avg``. When the two
        parsed sub-strings agree on those values, construction succeeds; when
        they disagree, ``exceptions.DataclassParseError`` is expected.
        """
        first_parent_pattern = util.RegexPattern(
            r""" g(?P<global_mean>m?)(?P<grid_number>\d?) """,
            input_field="parent1"
        )

        @util.regex_dataclass(first_parent_pattern)
        class Parent1():
            parent1: str = util.MANDATORY
            global_mean: dataclasses.InitVar = ""
            grid_number: int = 0
            spatial_avg: str = dataclasses.field(init=False)

            def __post_init__(self, global_mean=None):
                self.spatial_avg = 'global_mean' if global_mean else None

        second_parent_pattern = util.RegexPattern(
            r""" x(?P<grid_number>\d?)x(?P<spatial_avg>\w+)x """,
            input_field="parent2"
        )

        @util.regex_dataclass(second_parent_pattern)
        class Parent2():
            parent2: str = util.MANDATORY
            grid_number: int = 0
            spatial_avg: str = ""

            def __post_init__(self):
                # 'global' -> 'global_mean', matching what Parent1 produces.
                if self.spatial_avg:
                    self.spatial_avg += '_mean'

        child_pattern = util.RegexPattern(
            r""" (?P<activity_id>\w+)/(?P<grid_label>\w+)/(?P<redundant_label>\w+)/ """,
            input_field="directory"
        )

        @util.regex_dataclass(child_pattern)
        class Child(Parent1, Parent2):
            directory: str = ""
            activity_id: str = ""
            grid_label: Parent1 = ""
            redundant_label: Parent2 = ""

        # consistent assignment to fields of same name in parent dataclasses
        expected = {
            'parent2': 'x6xglobalx',
            'grid_number': 6,
            'spatial_avg': 'global_mean',
            'parent1': 'gm6',
            'directory': 'bazinga/gm6/x6xglobalx/',
            'activity_id': 'bazinga',
            'grid_label': 'gm6',
            'redundant_label': 'x6xglobalx',
        }
        child = Child('bazinga/gm6/x6xglobalx/')
        self.assertDictEqual(dataclasses.asdict(child), expected)

        # conflict in assignment to fields of same name in parent dataclasses
        with self.assertRaises(exceptions.DataclassParseError):
            _ = Child('bazinga/gm6/x5xglobalx/')
        with self.assertRaises(exceptions.DataclassParseError):
            _ = Child('bazinga/gm6/x6xNOT_THE_SAMEx/')
class TestMDTFDataclass(unittest.TestCase):
    """Tests for the ``util.mdtf_dataclass`` decorator.

    The tests cover the values stored on fields after construction for
    various annotations and inputs, the handling of the ``util.MANDATORY``
    and ``util.NOTSET`` sentinels, and the cases expected to raise
    ``exceptions.DataclassParseError``.
    """

    def test_builtin_coerce(self):
        """Values given for ``int``/``list`` fields are stored as that type.

        ``"5"`` for an ``int`` field is expected to come back as ``5`` and a
        tuple for a ``list`` field as a list.
        """
        @util.mdtf_dataclass
        class Dummy(object):
            a: str = None
            b: int = None
            c: list = None

        dummy = Dummy(a="foo", b="5", c=(1, 2, 3))
        self.assertEqual(dummy.a, "foo")
        self.assertEqual(dummy.b, 5)
        self.assertEqual(dummy.c, [1, 2, 3])

    def test_builtin_coerce_pre_postinit(self):
        """The ``int`` field is usable as an int inside ``__post_init__``.

        ``b="3"`` followed by ``self.b += 5`` must give ``8``; passing the
        ``Exception`` class for an ``int`` field must raise
        ``DataclassParseError``.
        """
        @util.mdtf_dataclass
        class Dummy(object):
            b: int = None

            def __post_init__(self):
                self.b += 5

        dummy = Dummy(b="3")
        self.assertEqual(dummy.b, 8)
        with self.assertRaises(exceptions.DataclassParseError):
            _ = Dummy(b=Exception)

    def test_builtin_check_post_postinit_1(self):
        """A ``str`` field set to an int in ``__post_init__`` is rejected."""
        @util.mdtf_dataclass
        class Dummy(object):
            a: str = None

            def __post_init__(self):
                self.a = 5

        with self.assertRaises(exceptions.DataclassParseError):
            _ = Dummy(a="a string")

    def test_builtin_check_post_postinit_2(self):
        """A field reset to ``util.MANDATORY`` in ``__post_init__`` is rejected."""
        @util.mdtf_dataclass
        class Dummy(object):
            a: str = None

            def __post_init__(self):
                self.a = util.MANDATORY

        with self.assertRaises(exceptions.DataclassParseError):
            _ = Dummy(a="a string")

    def test_decorator_args(self):
        """The decorator accepts keyword arguments such as ``frozen=True``.

        The resulting instance must be hashable and must raise
        ``dataclasses.FrozenInstanceError`` on attribute assignment.
        """
        @util.mdtf_dataclass(frozen=True)
        class Dummy(object):
            a: str = None
            b: int = None

        dummy = Dummy(a="foo", b=5)
        # hasattr(obj, '__hash__') is true for every object (unhashable
        # classes merely set the attribute to None), so call hash() to
        # actually verify hashability; it raises TypeError otherwise.
        self.assertIsInstance(hash(dummy), int)
        self.assertEqual(dummy.a, "foo")
        self.assertEqual(dummy.b, 5)
        with self.assertRaises(dataclasses.FrozenInstanceError):
            dummy.b = 7

    def test_mandatory_args(self):
        """Omitting a ``util.MANDATORY`` field raises ``DataclassParseError``.

        Fields defaulting to ``util.NOTSET`` or to a ``default_factory`` may
        be omitted and keep their defaults.
        """
        @util.mdtf_dataclass
        class Dummy(object):
            a: str = util.MANDATORY
            b: int = util.NOTSET
            c: list = dataclasses.field(default_factory=list)

        dummy = Dummy(a="foo")
        self.assertEqual(dummy.a, "foo")
        self.assertEqual(dummy.b, util.NOTSET)
        self.assertEqual(dummy.c, [])
        with self.assertRaises(exceptions.DataclassParseError):
            _ = Dummy(b=5)

    def test_mandatory_arg_inheritance(self):
        """Mandatory fields are enforced on subclasses in either base order."""
        @util.mdtf_dataclass
        class Dummy1(object):
            a: str = util.MANDATORY

        @util.mdtf_dataclass
        class Dummy2(object):
            b: int = util.NOTSET

        @util.mdtf_dataclass
        class Dummy12(Dummy1, Dummy2):
            pass

        @util.mdtf_dataclass
        class Dummy21(Dummy2, Dummy1):
            pass

        dummy = Dummy12(a="foo")
        self.assertEqual(dummy.a, "foo")
        self.assertEqual(dummy.b, util.NOTSET)
        with self.assertRaises(exceptions.DataclassParseError):
            _ = Dummy12(b=5)

        dummy = Dummy21(a="foo")
        self.assertEqual(dummy.a, "foo")
        self.assertEqual(dummy.b, util.NOTSET)
        with self.assertRaises(exceptions.DataclassParseError):
            _ = Dummy21(b=5)

    def test_defaults_coerce(self):
        """Default values are left as declared, whatever the annotation.

        Every field is annotated ``int`` but the defaults (``None``,
        ``util.NOTSET``, a string, an empty list from a factory) are expected
        to come back unchanged when no arguments are passed.
        """
        @util.mdtf_dataclass()
        class Dummy(object):
            a: int = 5
            b: int = None
            c: int = util.NOTSET
            d: int = "not_an_int_but_python_don't_care"
            e: int = dataclasses.field(default_factory=list)

        dummy = Dummy()
        self.assertEqual(dummy.a, 5)
        self.assertIsNone(dummy.b)
        self.assertEqual(dummy.c, util.NOTSET)
        self.assertEqual(dummy.d, "not_an_int_but_python_don't_care")
        self.assertEqual(dummy.e, [])

    def test_ignore_noninit_values(self):
        """``None``/``NOTSET`` arguments and non-init fields are accepted as-is.

        ``a=None`` and ``b=util.NOTSET`` for ``int`` fields are stored
        unchanged, and the strings assigned in ``__post_init__`` to the
        ``init=False`` field ``c`` and the ``InitVar`` ``d`` do not raise.
        """
        @util.mdtf_dataclass
        class Dummy(object):
            a: int = 5
            b: int = 6
            c: int = dataclasses.field(init=False)
            d: dataclasses.InitVar[int] = None

            def __post_init__(self, d):
                self.c = "foo"
                self.d = d

        dummy = Dummy(a=None, b=util.NOTSET, d="bar")
        self.assertIsNone(dummy.a)
        self.assertEqual(dummy.b, util.NOTSET)
        self.assertEqual(dummy.c, "foo")
        self.assertEqual(dummy.d, "bar")

    def test_from_struct(self):
        """String arguments are converted for enum and date-typed fields.

        ``"X"`` must compare equal to ``FooEnum.X``, ``"2010"`` to
        ``dt.Date(2010)`` and ``"6hr"`` to ``dt.DateFrequency(6, 'hr')``.
        """
        FooEnum = basic.MDTFEnum('FooEnum', 'X Y Z')

        @util.mdtf_dataclass
        class Dummy(object):
            a: FooEnum = None
            b: dt.Date = None
            c: dt.DateFrequency = None

        dummy = Dummy(a="X", b="2010", c="6hr")
        self.assertEqual(dummy.a, FooEnum.X)
        self.assertEqual(dummy.b, dt.Date(2010))
        self.assertEqual(dummy.c, dt.DateFrequency(6, 'hr'))

    def test_typing_generics(self):
        """Fields annotated with ``typing`` generics.

        Tuples passed for ``typing.List`` / ``typing.List[int]`` fields are
        expected back as lists; a ``typing.Union[int, list]`` field keeps
        either an int or a list; a tuple passed for a
        ``typing.MutableSequence`` field must raise ``DataclassParseError``.
        """
        @util.mdtf_dataclass
        class Dummy(object):
            a: typing.List = None
            b: typing.List[int] = None
            c: typing.Union[int, list] = 6
            d: typing.MutableSequence = dataclasses.field(default_factory=list)
            e: typing.Text = "foo"

        dummy = Dummy(a=(1, 2), b=(1, 2))
        self.assertEqual(dummy.a, [1, 2])
        self.assertEqual(dummy.b, [1, 2])
        self.assertEqual(dummy.c, 6)

        dummy = Dummy(a=(1, 2), b=(1, 2), c=[1, 2])
        self.assertEqual(dummy.c, [1, 2])

        dummy = Dummy(a=(1, 2), b=(1, 2), c=5)
        self.assertEqual(dummy.c, 5)

        dummy = Dummy(a=(1, 2), b=(1, 2), d=[1, 2])
        self.assertEqual(dummy.d, [1, 2])

        with self.assertRaises(exceptions.DataclassParseError):
            _ = Dummy(a=(1, 2), b=(1, 2), d=(1, 2))

    def test_typing_generics_2(self):
        """Fields annotated with ``Any``, ``TypeVar``, ``Callable``, ``Generic``, ``Tuple``.

        Values for the first four annotations are expected back unchanged; a
        list passed for the ``typing.Tuple[int, int]`` field is expected back
        as a tuple.
        """
        def dummy_f(x: str) -> int:
            return int(x)

        @util.mdtf_dataclass
        class Dummy(object):
            a: typing.Any = None
            b: typing.TypeVar('foo') = None
            c: typing.Callable[[int], str] = util.NOTSET
            d: typing.Generic[typing.TypeVar('X'), typing.TypeVar('X')] = None
            e: typing.Tuple[int, int] = (5, 6)

        dummy = Dummy(a="a")
        self.assertEqual(dummy.a, "a")
        self.assertIsNone(dummy.b)
        self.assertEqual(dummy.c, util.NOTSET)
        self.assertIsNone(dummy.d)
        self.assertEqual(dummy.e, (5, 6))

        dummy = Dummy(a="a", b="bar", c=dummy_f, d="also_ignored", e=[1, 2])
        self.assertEqual(dummy.a, "a")
        self.assertEqual(dummy.b, "bar")
        self.assertEqual(dummy.c, dummy_f)
        self.assertEqual(dummy.d, "also_ignored")
        self.assertEqual(dummy.e, (1, 2))
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()