From 507d734cb9c76a4850112181533208b5911dfb46 Mon Sep 17 00:00:00 2001 From: Anton Medvedev Date: Thu, 16 Nov 2023 16:22:34 +0100 Subject: [PATCH] Fix ast printing for ?? operator Fixes #442 --- ast/print.go | 21 +++++++++++++-------- ast/print_test.go | 6 ++++-- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/ast/print.go b/ast/print.go index dd9e0db0f..b4fb4f367 100644 --- a/ast/print.go +++ b/ast/print.go @@ -58,18 +58,23 @@ func (n *UnaryNode) String() string { } func (n *BinaryNode) String() string { - var left, right string - if b, ok := n.Left.(*BinaryNode); ok && operator.Less(b.Operator, n.Operator) { - left = fmt.Sprintf("(%s)", n.Left.String()) + var lhs, rhs string + + lb, ok := n.Left.(*BinaryNode) + if ok && (operator.Less(lb.Operator, n.Operator) || lb.Operator == "??") { + lhs = fmt.Sprintf("(%s)", n.Left.String()) } else { - left = n.Left.String() + lhs = n.Left.String() } - if b, ok := n.Right.(*BinaryNode); ok && operator.Less(b.Operator, n.Operator) { - right = fmt.Sprintf("(%s)", n.Right.String()) + + rb, ok := n.Right.(*BinaryNode) + if ok && operator.Less(rb.Operator, n.Operator) { + rhs = fmt.Sprintf("(%s)", n.Right.String()) } else { - right = n.Right.String() + rhs = n.Right.String() } - return fmt.Sprintf("%s %s %s", left, n.Operator, right) + + return fmt.Sprintf("%s %s %s", lhs, n.Operator, rhs) } func (n *ChainNode) String() string { diff --git a/ast/print_test.go b/ast/print_test.go index 7bb5c85b2..077e85582 100644 --- a/ast/print_test.go +++ b/ast/print_test.go @@ -3,10 +3,11 @@ package ast_test import ( "testing" - "github.com/antonmedv/expr/ast" - "github.com/antonmedv/expr/parser" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/antonmedv/expr/ast" + "github.com/antonmedv/expr/parser" ) func TestPrint(t *testing.T) { @@ -67,6 +68,7 @@ func TestPrint(t *testing.T) { {`a[1:]`, `a[1:]`}, {`a[1:]`, `a[1:]`}, {`a[:]`, `a[:]`}, + {`(nil ?? 1) > 0`, `(nil ?? 1) > 0`}, } for _, tt := range tests {