Skip to content

Commit

Permalink
Improve compatibility for dispatched servlet request in Spring Web ad…
Browse files Browse the repository at this point in the history
…apter (#1681)
  • Loading branch information
jasonjoo2010 authored Aug 20, 2020
1 parent 5905874 commit 656e3de
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,21 @@
import org.springframework.web.servlet.ModelAndView;

/**
* Since request may be reprocessed in flow if any forwarding or including or other action
* happened (see {@link javax.servlet.ServletRequest#getDispatcherType()}) we will only
* deal with the initial request. So we use <b>reference count</b> to track in
* dispathing "onion" though which we could figure out whether we are in initial type "REQUEST".
* That means the sub-requests which we rarely meet in practice will NOT be recorded in Sentinel.
* <p>
* How to implement a forward sub-request in your action:
* <pre>
* initalRequest() {
* ModelAndView mav = new ModelAndView();
* mav.setViewName("another");
* return mav;
* }
* </pre>
*
* @author kaizi2009
* @since 1.7.1
*/
Expand All @@ -49,20 +64,46 @@ public AbstractSentinelInterceptor(BaseWebMvcConfig config) {
AssertUtil.assertNotBlank(config.getRequestAttributeName(), "requestAttributeName should not be blank");
this.baseWebMvcConfig = config;
}


/**
* @param request
* @param rcKey
* @param step
* @return reference count after increasing (initial value as zero to be increased)
*/
private Integer increaseReferece(HttpServletRequest request, String rcKey, int step) {
Object obj = request.getAttribute(rcKey);

if (obj == null) {
// initial
obj = Integer.valueOf(0);
}

Integer newRc = (Integer)obj + step;
request.setAttribute(rcKey, newRc);
return newRc;
}

@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
throws Exception {
try {
String resourceName = getResourceName(request);

if (StringUtil.isNotEmpty(resourceName)) {
// Parse the request origin using registered origin parser.
String origin = parseOrigin(request);
String contextName = getContextName(request);
ContextUtil.enter(contextName, origin);
setEntryInRequest(request, baseWebMvcConfig.getRequestAttributeName(), resourceName);
if (StringUtil.isEmpty(resourceName)) {
return true;
}

if (increaseReferece(request, this.baseWebMvcConfig.getRequestRefName(), 1) != 1) {
return true;
}

// Parse the request origin using registered origin parser.
String origin = parseOrigin(request);
String contextName = getContextName(request);
ContextUtil.enter(contextName, origin);
Entry entry = SphU.entry(resourceName, ResourceTypeConstants.COMMON_WEB, EntryType.IN);
request.setAttribute(baseWebMvcConfig.getRequestAttributeName(), entry);
return true;
} catch (BlockException e) {
try {
Expand Down Expand Up @@ -95,11 +136,20 @@ protected String getContextName(HttpServletRequest request) {
@Override
public void afterCompletion(HttpServletRequest request, HttpServletResponse response,
Object handler, Exception ex) throws Exception {
if (increaseReferece(request, this.baseWebMvcConfig.getRequestRefName(), -1) != 0) {
return;
}

Entry entry = getEntryInRequest(request, baseWebMvcConfig.getRequestAttributeName());
if (entry != null) {
traceExceptionAndExit(entry, ex);
removeEntryInRequest(request);
if (entry == null) {
// should not happen
RecordLog.warn("[{}] No entry found in request, key: {}",
getClass().getSimpleName(), baseWebMvcConfig.getRequestAttributeName());
return;
}

traceExceptionAndExit(entry, ex);
removeEntryInRequest(request);
ContextUtil.exit();
}

Expand All @@ -108,26 +158,6 @@ public void postHandle(HttpServletRequest request, HttpServletResponse response,
ModelAndView modelAndView) throws Exception {
}

/**
* Note:
* If the attribute key already exists in request, don't create new {@link Entry},
* to guarantee the order of {@link Entry} in pair and avoid {@link com.alibaba.csp.sentinel.ErrorEntryFreeException}.
*
* Refer to:
* https://github.com/alibaba/Sentinel/issues/1531
* https://github.com/alibaba/Sentinel/issues/1482
*/
protected void setEntryInRequest(HttpServletRequest request, String name, String resourceName) throws BlockException {
Object attrVal = request.getAttribute(name);
if (attrVal != null) {
RecordLog.warn("[{}] The attribute key '{}' already exists in request, please set `requestAttributeName`",
getClass().getSimpleName(), name);
} else {
Entry entry = SphU.entry(resourceName, ResourceTypeConstants.COMMON_WEB, EntryType.IN);
request.setAttribute(name, entry);
}
}

protected Entry getEntryInRequest(HttpServletRequest request, String attrKey) {
Object entryObject = request.getAttribute(attrKey);
return entryObject == null ? null : (Entry)entryObject;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import com.alibaba.csp.sentinel.adapter.spring.webmvc.callback.BlockExceptionHandler;
import com.alibaba.csp.sentinel.adapter.spring.webmvc.callback.RequestOriginParser;
import com.alibaba.csp.sentinel.util.AssertUtil;

/**
* Common base configuration for Spring Web MVC adapter.
Expand All @@ -28,6 +27,7 @@
public abstract class BaseWebMvcConfig {

protected String requestAttributeName;
protected String requestRefName;
protected BlockExceptionHandler blockExceptionHandler;
protected RequestOriginParser originParser;

Expand All @@ -37,6 +37,16 @@ public String getRequestAttributeName() {

public void setRequestAttributeName(String requestAttributeName) {
this.requestAttributeName = requestAttributeName;
this.requestRefName = this.requestAttributeName + "-rc";
}

/**
* Paired with attr name used to track reference count.
*
* @return
*/
public String getRequestRefName() {
return requestRefName;
}

public BlockExceptionHandler getBlockExceptionHandler() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,40 +17,53 @@

import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.servlet.ModelAndView;

/**
* Test controller
* @author kaizi2009
*/
@RestController
@Controller
public class WebMvcTestController {

@GetMapping("/hello")
@ResponseBody
public String apiHello() {
doBusiness();
return "Hello!";
}

@GetMapping("/err")
@ResponseBody
public String apiError() {
doBusiness();
return "Oops...";
}

@GetMapping("/foo/{id}")
@ResponseBody
public String apiFoo(@PathVariable("id") Long id) {
doBusiness();
return "Hello " + id;
}

@GetMapping("/exclude/{id}")
@ResponseBody
public String apiExclude(@PathVariable("id") Long id) {
doBusiness();
return "Exclude " + id;
}

@GetMapping("/forward")
public ModelAndView apiForward() {
ModelAndView mav = new ModelAndView();
mav.setViewName("hello");
return mav;
}

private void doBusiness() {
Random random = new Random(1);
Expand Down

0 comments on commit 656e3de

Please sign in to comment.