From 3b80f5c8e5a95ba31e92e4825ecc0ba3148b555a Mon Sep 17 00:00:00 2001 From: Brian Demers Date: Thu, 28 Sep 2023 13:15:44 -0400 Subject: [PATCH] The InvalidRequestFilter is more flexible Allowing encoded periods and forward slashes can now be independently enabled --- .../web/filter/InvalidRequestFilter.java | 56 ++++++++++++-- .../filter/InvalidRequestFilterTest.groovy | 75 ++++++++++++++++--- 2 files changed, 117 insertions(+), 14 deletions(-) diff --git a/web/src/main/java/org/apache/shiro/web/filter/InvalidRequestFilter.java b/web/src/main/java/org/apache/shiro/web/filter/InvalidRequestFilter.java index 97133500fe..55b831cce1 100644 --- a/web/src/main/java/org/apache/shiro/web/filter/InvalidRequestFilter.java +++ b/web/src/main/java/org/apache/shiro/web/filter/InvalidRequestFilter.java @@ -28,6 +28,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.stream.Stream; /** * A request filter that blocks malicious requests. Invalid request will respond with a 400 response code. @@ -61,6 +62,12 @@ public class InvalidRequestFilter extends AccessControlFilter { private boolean blockTraversal = true; + private boolean blockEncodedPeriod = true; + + private boolean blockEncodedForwardSlash = true; + + private boolean blockRewriteTraversal = true; + @Override protected boolean isAccessAllowed(ServletRequest req, ServletResponse response, Object mappedValue) throws Exception { HttpServletRequest request = WebUtils.toHttp(req); @@ -74,8 +81,10 @@ private boolean isValid(String uri) { return !StringUtils.hasText(uri) || ( !containsSemicolon(uri) && !containsBackslash(uri) - && !containsNonAsciiCharacters(uri)) - && !containsTraversal(uri); + && !containsNonAsciiCharacters(uri) + && !containsTraversal(uri) + && !containsEncodedPeriods(uri) + && !containsEncodedForwardSlash(uri)); } @Override @@ -118,9 +127,22 @@ private static boolean containsOnlyPrintableAsciiCharacters(String uri) { private boolean containsTraversal(String uri) { if (isBlockTraversal()) { - return !(isNormalized(uri) - && PERIOD.stream().noneMatch(uri::contains) - && FORWARDSLASH.stream().noneMatch(uri::contains)); + return !isNormalized(uri) + || (isBlockRewriteTraversal() && Stream.of("/..;", "/.;").anyMatch(uri::contains)); + } + return false; + } + + private boolean containsEncodedPeriods(String uri) { + if (isBlockEncodedPeriod()) { + return PERIOD.stream().anyMatch(uri::contains); + } + return false; + } + + private boolean containsEncodedForwardSlash(String uri) { + if (isBlockEncodedForwardSlash()) { + return FORWARDSLASH.stream().anyMatch(uri::contains); } return false; } @@ -180,4 +202,28 @@ public boolean isBlockTraversal() { public void setBlockTraversal(boolean blockTraversal) { this.blockTraversal = blockTraversal; } + + public boolean isBlockEncodedPeriod() { + return blockEncodedPeriod; + } + + public void setBlockEncodedPeriod(boolean blockEncodedPeriod) { + this.blockEncodedPeriod = blockEncodedPeriod; + } + + public boolean isBlockEncodedForwardSlash() { + return blockEncodedForwardSlash; + } + + public void setBlockEncodedForwardSlash(boolean blockEncodedForwardSlash) { + this.blockEncodedForwardSlash = blockEncodedForwardSlash; + } + + public boolean isBlockRewriteTraversal() { + return blockRewriteTraversal; + } + + public void setBlockRewriteTraversal(boolean blockRewriteTraversal) { + this.blockRewriteTraversal = blockRewriteTraversal; + } } diff --git a/web/src/test/groovy/org/apache/shiro/web/filter/InvalidRequestFilterTest.groovy b/web/src/test/groovy/org/apache/shiro/web/filter/InvalidRequestFilterTest.groovy index 9e37b3fef2..8777974c4f 100644 --- a/web/src/test/groovy/org/apache/shiro/web/filter/InvalidRequestFilterTest.groovy +++ b/web/src/test/groovy/org/apache/shiro/web/filter/InvalidRequestFilterTest.groovy @@ -38,6 +38,9 @@ class InvalidRequestFilterTest { assertThat "filter.blockNonAscii expected to be true", filter.isBlockNonAscii() assertThat "filter.blockSemicolon expected to be true", filter.isBlockSemicolon() assertThat "filter.blockTraversal expected to be true", filter.isBlockTraversal() + assertThat "filter.blockRewriteTraversal expected to be true", filter.isBlockRewriteTraversal() + assertThat "filter.blockEncodedPeriod expected to be true", filter.isBlockEncodedPeriod() + assertThat "filter.blockEncodedForwardSlash expected to be true", filter.isBlockEncodedForwardSlash() } @Test @@ -58,7 +61,6 @@ class InvalidRequestFilterTest { } } - @Test void testFilterBlocks() { InvalidRequestFilter filter = new InvalidRequestFilter() @@ -72,6 +74,7 @@ class InvalidRequestFilterTest { assertPathBlocked(filter, "/something", "/;something") assertPathBlocked(filter, "/something", "/something", "/;") + assertPathBlocked(filter, "/something", "/something", "/.;") } @Test @@ -80,23 +83,81 @@ class InvalidRequestFilterTest { assertPathBlocked(filter, "/something/../") assertPathBlocked(filter, "/something/../bar") assertPathBlocked(filter, "/something/../bar/") - assertPathBlocked(filter, "/something/%2e%2E/bar/") assertPathBlocked(filter, "/something/..") assertPathBlocked(filter, "/..") assertPathBlocked(filter, "..") assertPathBlocked(filter, "../") - assertPathBlocked(filter, "%2E./") assertPathBlocked(filter, "%2F./") assertPathBlocked(filter, "/something/./") assertPathBlocked(filter, "/something/./bar") assertPathBlocked(filter, "/something/\u002e/bar") assertPathBlocked(filter, "/something/./bar/") - assertPathBlocked(filter, "/something/%2e/bar/") - assertPathBlocked(filter, "/something/%2f/bar/") assertPathBlocked(filter, "/something/.") assertPathBlocked(filter, "/.") assertPathBlocked(filter, "/something/../something/.") assertPathBlocked(filter, "/something/../something/.") + assertPathBlocked(filter, "/something/.;") + assertPathBlocked(filter, "/something/%2e%3b") + + assertPathAllowed(filter, "/something/.bar") + assertPathAllowed(filter, "/.something") + assertPathAllowed(filter, ".something") + } + + @Test + void testBlocksEncodedPeriod() { + InvalidRequestFilter filter = new InvalidRequestFilter() + assertPathBlocked(filter, "/%2esomething") + assertPathBlocked(filter, "%2esomething") + assertPathBlocked(filter, "%2E./") + assertPathBlocked(filter, "%2F./") + assertPathBlocked(filter, "/something/%2e;") + assertPathBlocked(filter, "/something/%2e%3b") + assertPathBlocked(filter, "/something/%2e%2E/bar/") + assertPathBlocked(filter, "/something/%2e/bar/") + } + + @Test + void testAllowsEncodedPeriod() { + InvalidRequestFilter filter = new InvalidRequestFilter() + filter.setBlockEncodedPeriod(false) + assertPathAllowed(filter, "/%2esomething") + assertPathAllowed(filter, "%2esomething") + assertPathAllowed(filter, "%2E./") + assertPathAllowed(filter, "/something/%2e%2E/bar/") + assertPathAllowed(filter, "/something/%2e/bar/") + } + + @Test + void testBlocksEncodedForwardSlash() { + InvalidRequestFilter filter = new InvalidRequestFilter() + assertPathBlocked(filter, "%2F./") + assertPathBlocked(filter, "/something/%2f/bar/") + } + + @Test + void testAllowsEncodedForwardSlash() { + InvalidRequestFilter filter = new InvalidRequestFilter() + filter.setBlockEncodedForwardSlash(false) + assertPathAllowed(filter, "%2F./") + assertPathAllowed(filter, "/something/%2f/bar/") + } + + @Test + void testBlocksRewriteTraversal() { + InvalidRequestFilter filter = new InvalidRequestFilter() + filter.setBlockSemicolon(false) + assertPathBlocked(filter, "/something/..;jsessionid=foobar") + assertPathBlocked(filter, "/something/.;jsessionid=foobar") + } + + @Test + void testAllowRewriteTraversal() { + InvalidRequestFilter filter = new InvalidRequestFilter() + filter.setBlockSemicolon(false) + filter.setBlockRewriteTraversal(false) + assertPathAllowed(filter, "/something/..;jsessionid=foobar") + assertPathAllowed(filter, "/something/.;jsessionid=foobar") } @Test @@ -158,15 +219,11 @@ class InvalidRequestFilterTest { assertPathAllowed(filter, "/..") assertPathAllowed(filter, "..") assertPathAllowed(filter, "../") - assertPathAllowed(filter, "%2E./") - assertPathAllowed(filter, "%2F./") assertPathAllowed(filter, "/something/./") assertPathAllowed(filter, "/something/./bar") assertPathAllowed(filter, "/something/\u002e/bar") assertPathAllowed(filter, "/something\u002fbar") assertPathAllowed(filter, "/something/./bar/") - assertPathAllowed(filter, "/something/%2e/bar/") - assertPathAllowed(filter, "/something/%2f/bar/") assertPathAllowed(filter, "/something/.") assertPathAllowed(filter, "/.") assertPathAllowed(filter, "/something/../something/.")