Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make RuleTransformer fully recursive [#257] #421

Merged
merged 1 commit into from
Dec 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions shared/src/main/scala/scala/xml/transform/NestingTransformer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/* __ *\
** ________ ___ / / ___ Scala API **
** / __/ __// _ | / / / _ | (c) 2002-2020, LAMP/EPFL **
** __\ \/ /__/ __ |/ /__/ __ | (c) 2011-2020, Lightbend, Inc. **
** /____/\___/_/ |_/____/_/ | | http://scala-lang.org/ **
** |/ **
\* */

package scala
package xml
package transform

import scala.collection.Seq

class NestingTransformer(rule: RewriteRule) extends BasicTransformer {
override def transform(n: Node): Seq[Node] = {
rule.transform(super.transform(n))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ package transform
import scala.collection.Seq

class RuleTransformer(rules: RewriteRule*) extends BasicTransformer {
override def transform(n: Node): Seq[Node] =
rules.foldLeft(super.transform(n)) { (res, rule) => rule transform res }
private val transformers = rules.map(new NestingTransformer(_))
override def transform(n: Node): Seq[Node] = {
if (transformers.isEmpty) n
else transformers.tail.foldLeft(transformers.head.transform(n)) { (res, transformer) => transformer.transform(res) }
}
}
17 changes: 16 additions & 1 deletion shared/src/test/scala-2.x/scala/xml/TransformersTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class TransformersTest {
@Test
def preserveReferentialComplexityInLinearComplexity = { // SI-4528
var i = 0

val xmlNode = <a><b><c><h1>Hello Example</h1></c></b></a>

new RuleTransformer(new RewriteRule {
Expand All @@ -77,4 +77,19 @@ class TransformersTest {

assertEquals(1, i)
}

@Test
def appliesRulesRecursivelyOnPreviousChanges = { // #257
def add(outer: Elem, inner: Node) = new RewriteRule {
override def transform(n: Node): Seq[Node] = n match {
case e: Elem if e.label == outer.label => e.copy(child = e.child ++ inner)
case other => other
}
}

def transformer = new RuleTransformer(add(<element/>, <new/>), add(<new/>, <thing/>))

assertEquals(<element><new><thing/></new></element>, transformer(<element/>))
}
}