diff --git a/src/etc/htmldocck.py b/src/etc/htmldocck.py
index 6bb235b2c83..df215f31823 100644
--- a/src/etc/htmldocck.py
+++ b/src/etc/htmldocck.py
@@ -285,6 +285,11 @@ def flatten(node):
return ''.join(acc)
+def make_xml(text):
+ xml = ET.XML('%s' % text)
+ return xml
+
+
def normalize_xpath(path):
path = path.replace("{{channel}}", channel)
if path.startswith('//'):
@@ -401,7 +406,7 @@ def get_tree_count(tree, path):
return len(tree.findall(path))
-def check_snapshot(snapshot_name, tree, normalize_to_text):
+def check_snapshot(snapshot_name, actual_tree, normalize_to_text):
assert rust_test_path.endswith('.rs')
snapshot_path = '{}.{}.{}'.format(rust_test_path[:-3], snapshot_name, 'html')
try:
@@ -414,11 +419,15 @@ def check_snapshot(snapshot_name, tree, normalize_to_text):
raise FailedCheck('No saved snapshot value')
if not normalize_to_text:
- actual_str = ET.tostring(tree).decode('utf-8')
+ actual_str = ET.tostring(actual_tree).decode('utf-8')
else:
- actual_str = flatten(tree)
+ actual_str = flatten(actual_tree)
+
+ if not expected_str \
+ or (not normalize_to_text and
+ not compare_tree(make_xml(actual_str), make_xml(expected_str), stderr)) \
+ or (normalize_to_text and actual_str != expected_str):
- if expected_str != actual_str:
if bless:
with open(snapshot_path, 'w') as snapshot_file:
snapshot_file.write(actual_str)
@@ -430,6 +439,59 @@ def check_snapshot(snapshot_name, tree, normalize_to_text):
print()
raise FailedCheck('Actual snapshot value is different than expected')
+
+# Adapted from https://github.com/formencode/formencode/blob/3a1ba9de2fdd494dd945510a4568a3afeddb0b2e/formencode/doctest_xml_compare.py#L72-L120
+def compare_tree(x1, x2, reporter=None):
+ if x1.tag != x2.tag:
+ if reporter:
+ reporter('Tags do not match: %s and %s' % (x1.tag, x2.tag))
+ return False
+ for name, value in x1.attrib.items():
+ if x2.attrib.get(name) != value:
+ if reporter:
+ reporter('Attributes do not match: %s=%r, %s=%r'
+ % (name, value, name, x2.attrib.get(name)))
+ return False
+ for name in x2.attrib:
+ if name not in x1.attrib:
+ if reporter:
+ reporter('x2 has an attribute x1 is missing: %s'
+ % name)
+ return False
+ if not text_compare(x1.text, x2.text):
+ if reporter:
+ reporter('text: %r != %r' % (x1.text, x2.text))
+ return False
+ if not text_compare(x1.tail, x2.tail):
+ if reporter:
+ reporter('tail: %r != %r' % (x1.tail, x2.tail))
+ return False
+ cl1 = list(x1)
+ cl2 = list(x2)
+ if len(cl1) != len(cl2):
+ if reporter:
+ reporter('children length differs, %i != %i'
+ % (len(cl1), len(cl2)))
+ return False
+ i = 0
+ for c1, c2 in zip(cl1, cl2):
+ i += 1
+ if not compare_tree(c1, c2, reporter=reporter):
+ if reporter:
+ reporter('children %i do not match: %s'
+ % (i, c1.tag))
+ return False
+ return True
+
+
+def text_compare(t1, t2):
+ if not t1 and not t2:
+ return True
+ if t1 == '*' or t2 == '*':
+ return True
+ return (t1 or '').strip() == (t2 or '').strip()
+
+
def stderr(*args):
if sys.version_info.major < 3:
file = codecs.getwriter('utf-8')(sys.stderr)