diff --git a/CSSLayout/CSSLayout.c b/CSSLayout/CSSLayout.c index dd63fb25..07a0c695 100644 --- a/CSSLayout/CSSLayout.c +++ b/CSSLayout/CSSLayout.c @@ -255,8 +255,18 @@ static void _CSSNodeMarkDirty(const CSSNodeRef node) { } } +void CSSNodeSetMeasureFunc(const CSSNodeRef node, CSSMeasureFunc measureFunc) { + CSS_ASSERT(CSSNodeChildCount(node) == 0, "Cannot set measure function: Nodes with measure functions cannot have children."); + node->measure = measureFunc; +} + +CSSMeasureFunc CSSNodeGetMeasureFunc(const CSSNodeRef node) { + return node->measure; +} + void CSSNodeInsertChild(const CSSNodeRef node, const CSSNodeRef child, const uint32_t index) { CSS_ASSERT(child->parent == NULL, "Child already has a parent, it must be removed first."); + CSS_ASSERT(node->measure == NULL, "Cannot add child: Nodes with measure functions cannot have children."); CSSNodeListInsert(&node->children, child, index); child->parent = node; _CSSNodeMarkDirty(node); @@ -278,7 +288,7 @@ inline uint32_t CSSNodeChildCount(const CSSNodeRef node) { } void CSSNodeMarkDirty(const CSSNodeRef node) { - CSS_ASSERT(node->measure != NULL || CSSNodeChildCount(node) > 0, + CSS_ASSERT(node->measure != NULL, "Only leaf nodes with custom measure functions" "should manually mark themselves as dirty"); _CSSNodeMarkDirty(node); @@ -367,7 +377,6 @@ void CSSNodeStyleSetFlex(const CSSNodeRef node, const float flex) { } CSS_NODE_PROPERTY_IMPL(void *, Context, context, context); -CSS_NODE_PROPERTY_IMPL(CSSMeasureFunc, MeasureFunc, measureFunc, measure); CSS_NODE_PROPERTY_IMPL(CSSPrintFunc, PrintFunc, printFunc, print); CSS_NODE_PROPERTY_IMPL(bool, IsTextnode, isTextNode, isTextNode); CSS_NODE_PROPERTY_IMPL(bool, HasNewLayout, hasNewLayout, hasNewLayout); @@ -1211,7 +1220,7 @@ static void layoutNodeImpl(const CSSNodeRef node, // For content (text) nodes, determine the dimensions based on the text // contents. - if (node->measure && CSSNodeChildCount(node) == 0) { + if (node->measure) { const float innerWidth = availableWidth - marginAxisRow - paddingAndBorderAxisRow; const float innerHeight = availableHeight - marginAxisColumn - paddingAndBorderAxisColumn; @@ -2185,7 +2194,7 @@ bool layoutNodeInternal(const CSSNodeRef node, // most // expensive to measure, so it's worth avoiding redundant measurements if at // all possible. - if (node->measure && CSSNodeChildCount(node) == 0) { + if (node->measure) { const float marginAxisRow = getMarginAxis(node, CSSFlexDirectionRow); const float marginAxisColumn = getMarginAxis(node, CSSFlexDirectionColumn); diff --git a/tests/CSSLayoutMeasureTest.cpp b/tests/CSSLayoutMeasureTest.cpp index e8362d0d..624ca125 100644 --- a/tests/CSSLayoutMeasureTest.cpp +++ b/tests/CSSLayoutMeasureTest.cpp @@ -15,30 +15,24 @@ static CSSSize _measure(CSSNodeRef node, CSSMeasureMode widthMode, float height, CSSMeasureMode heightMode) { - int *measureCount = (int *)CSSNodeGetContext(node); - *measureCount = *measureCount + 1; return CSSSize { - .width = widthMode == CSSMeasureModeUndefined ? 10 : width, - .height = heightMode == CSSMeasureModeUndefined ? 10 : width, + .width = 0, + .height = 0, }; } -TEST(CSSLayoutTest, ignore_measure_on_non_leaf_node) { +TEST(CSSLayoutTest, cannot_add_child_to_node_with_measure_func) { const CSSNodeRef root = CSSNodeNew(); - int measureCount = 0; - CSSNodeSetContext(root, &measureCount); CSSNodeSetMeasureFunc(root, _measure); const CSSNodeRef root_child0 = CSSNodeNew(); - int childMeasureCount = 0; - CSSNodeSetContext(root_child0, &childMeasureCount); - CSSNodeSetMeasureFunc(root_child0, _measure); + ASSERT_DEATH(CSSNodeInsertChild(root, root_child0, 0), "Cannot add child.*"); +} + +TEST(CSSLayoutTest, cannot_add_measure_func_to_non_leaf_node) { + const CSSNodeRef root = CSSNodeNew(); + const CSSNodeRef root_child0 = CSSNodeNew(); CSSNodeInsertChild(root, root_child0, 0); - CSSNodeCalculateLayout(root, CSSUndefined, CSSUndefined, CSSDirectionLTR); - - ASSERT_EQ(0, measureCount); - ASSERT_EQ(1, childMeasureCount); - - CSSNodeFreeRecursive(root); + ASSERT_DEATH(CSSNodeSetMeasureFunc(root, _measure), "Cannot set measure function.*"); }