From da99b81cbc227a32a8b0aaf8ff8381f8a0e38f59 Mon Sep 17 00:00:00 2001 From: Matt Davis Date: Mon, 27 Nov 2023 13:28:19 -0800 Subject: [PATCH] WIP nested class support * could be done with existing "all dots" tag syntax instead of the `@` separator, but discarding the explicit distinction between the module/package and type name makes the import/construction logic a lot less precise when dealing with nested types. --- lib/yaml/constructor.py | 15 +++++-- lib/yaml/representer.py | 42 +++++++++++-------- .../data/construct-python-object.code | 2 + .../data/construct-python-object.data | 1 + tests/legacy_tests/test_constructor.py | 12 +++++- 5 files changed, 50 insertions(+), 22 deletions(-) diff --git a/lib/yaml/constructor.py b/lib/yaml/constructor.py index 619acd30..2caa10c1 100644 --- a/lib/yaml/constructor.py +++ b/lib/yaml/constructor.py @@ -541,7 +541,9 @@ def find_python_name(self, name, mark, unsafe=False): if not name: raise ConstructorError("while constructing a Python object", mark, "expected non-empty name appended to the tag", mark) - if '.' in name: + if '@' in name: # handles nested objects via __qualname__ + module_name, object_name = name.rsplit('@', 1) + elif '.' in name: # handle old-style references module_name, object_name = name.rsplit('.', 1) else: module_name = 'builtins' @@ -556,11 +558,16 @@ def find_python_name(self, name, mark, unsafe=False): raise ConstructorError("while constructing a Python object", mark, "module %r is not imported" % module_name, mark) module = sys.modules[module_name] - if not hasattr(module, object_name): + + # descend multi-part object_name to support nested classes + cur_obj = module + for attr in object_name.split('.'): + cur_obj = getattr(cur_obj, attr, None) + if not cur_obj: raise ConstructorError("while constructing a Python object", mark, "cannot find %r in the module %r" - % (object_name, module.__name__), mark) - return getattr(module, object_name) + % (object_name, module_name), mark) + return cur_obj def construct_python_name(self, suffix, node): value = self.construct_scalar(node) diff --git a/lib/yaml/representer.py b/lib/yaml/representer.py index 808ca06d..6173ec38 100644 --- a/lib/yaml/representer.py +++ b/lib/yaml/representer.py @@ -336,24 +336,32 @@ def represent_object(self, data): else: tag = 'tag:yaml.org,2002:python/object/apply:' newobj = False - function_name = '%s.%s' % (function.__module__, function.__name__) - if not args and not listitems and not dictitems \ - and isinstance(state, dict) and newobj: - return self.represent_mapping( - 'tag:yaml.org,2002:python/object:'+function_name, state) - if not listitems and not dictitems \ - and isinstance(state, dict) and not state: - return self.represent_sequence(tag+function_name, args) + value = {} - if args: - value['args'] = args - if state or not isinstance(state, dict): - value['state'] = state - if listitems: - value['listitems'] = listitems - if dictitems: - value['dictitems'] = dictitems - return self.represent_mapping(tag+function_name, value) + + represent_impl = self.represent_mapping + + if not args and not listitems and not dictitems and isinstance(state, dict) and newobj: + # object supports simple object state w/ __newobj__ + tag = 'tag:yaml.org,2002:python/object:' + value = state + elif not listitems and not dictitems and isinstance(state, dict) and not state: + value = args + represent_impl = self.represent_sequence + else: + if args: + value['args'] = args + if state or not isinstance(state, dict): + value['state'] = state + if listitems: + value['listitems'] = listitems + if dictitems: + value['dictitems'] = dictitems + + type_qualname = getattr(function, '__qualname__', getattr(function, '__name__', None)) + type_separator = '@' if '.' in type_qualname else '.' # if nested class, use @ in tag to disambiguate module name and object qualname + tag = f'{tag}{function.__module__}{type_separator}{type_qualname}' + return represent_impl(tag, value) def represent_ordered_dict(self, data): # Provide uniform representation across different Python versions. diff --git a/tests/legacy_tests/data/construct-python-object.code b/tests/legacy_tests/data/construct-python-object.code index 9e611e43..c8f3c784 100644 --- a/tests/legacy_tests/data/construct-python-object.code +++ b/tests/legacy_tests/data/construct-python-object.code @@ -2,6 +2,8 @@ AnObject(1, 'two', [3,3,3]), AnInstance(1, 'two', [3,3,3]), +NestedOuterObject.NestedInnerObject1.NestedInnerObject2.NestedInnerObject3('hi mom'), + AnObject(1, 'two', [3,3,3]), AnInstance(1, 'two', [3,3,3]), diff --git a/tests/legacy_tests/data/construct-python-object.data b/tests/legacy_tests/data/construct-python-object.data index 66797e4c..0a5bbb72 100644 --- a/tests/legacy_tests/data/construct-python-object.data +++ b/tests/legacy_tests/data/construct-python-object.data @@ -1,5 +1,6 @@ - !!python/object:test_constructor.AnObject { foo: 1, bar: two, baz: [3,3,3] } - !!python/object:test_constructor.AnInstance { foo: 1, bar: two, baz: [3,3,3] } +- !!python/object:test_constructor@NestedOuterObject.NestedInnerObject1.NestedInnerObject2.NestedInnerObject3 { data: hi mom } - !!python/object/new:test_constructor.AnObject { args: [1, two], kwds: {baz: [3,3,3]} } - !!python/object/apply:test_constructor.AnInstance { args: [1, two], kwds: {baz: [3,3,3]} } diff --git a/tests/legacy_tests/test_constructor.py b/tests/legacy_tests/test_constructor.py index 0783a21b..376203e6 100644 --- a/tests/legacy_tests/test_constructor.py +++ b/tests/legacy_tests/test_constructor.py @@ -15,7 +15,7 @@ def execute(code): def _make_objects(): global MyLoader, MyDumper, MyTestClass1, MyTestClass2, MyTestClass3, YAMLObject1, YAMLObject2, \ - AnObject, AnInstance, AState, ACustomState, InitArgs, InitArgsWithState, \ + AnObject, AnInstance, NestedOuterObject, AState, ACustomState, InitArgs, InitArgsWithState, \ NewArgs, NewArgsWithState, Reduce, ReduceWithState, Slots, MyInt, MyList, MyDict, \ FixedOffset, today, execute, MyFullLoader @@ -128,6 +128,16 @@ def __eq__(self, other): return type(self) is type(other) and \ (self.foo, self.bar, self.baz) == (other.foo, other.bar, other.baz) + class NestedOuterObject: + class NestedInnerObject1: + class NestedInnerObject2: + class NestedInnerObject3: + def __init__(self, data): + self.data = data + def __eq__(self, other): + return type(self) is type(other) and self.data == other.data + + class AnInstance: def __init__(self, foo=None, bar=None, baz=None): self.foo = foo