diff --git a/engine/src/main/java/io/seldon/engine/service/InternalPredictionService.java b/engine/src/main/java/io/seldon/engine/service/InternalPredictionService.java index 667c684301..f8efcd18d7 100644 --- a/engine/src/main/java/io/seldon/engine/service/InternalPredictionService.java +++ b/engine/src/main/java/io/seldon/engine/service/InternalPredictionService.java @@ -65,6 +65,7 @@ import org.springframework.stereotype.Service; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.web.client.HttpStatusCodeException; import org.springframework.web.client.ResourceAccessException; import org.springframework.web.client.RestTemplate; @@ -495,27 +496,7 @@ private SeldonMessage queryREST( logger.debug(response); SeldonMessage.Builder builder = SeldonMessage.newBuilder(); JsonFormat.parser().ignoringUnknownFields().merge(response, builder); - SeldonMessage seldonMessage = builder.build(); - if (httpResponse.getStatusCode().is2xxSuccessful()) { - return seldonMessage; - } else { - logger.error( - "Couldn't retrieve prediction from external prediction server -- bad http return code: " - + httpResponse.getStatusCode()); - Status seldonMessageStatus = seldonMessage.getStatus(); - if (seldonMessageStatus == null) { - throw new APIException( - APIException.ApiExceptionType.ENGINE_MICROSERVICE_ERROR, - String.format("Bad return code %d", httpResponse.getStatusCodeValue())); - } - else - { - throw new APIException(seldonMessageStatus.getCode(), - seldonMessageStatus.getReason(), - 200, - seldonMessageStatus.getInfo()); - } - } + return builder.build(); } finally { if (logger.isDebugEnabled()) { logger.debug( @@ -528,9 +509,12 @@ private SeldonMessage queryREST( logger.error("Invalid protocol buffer during Json Format merge - ", e); throw new APIException( APIException.ApiExceptionType.ENGINE_MICROSERVICE_ERROR, e.toString()); - } catch (APIException e) + } catch (HttpStatusCodeException e) { - throw e; + logger.error( + "Couldn't retrieve prediction from external prediction server -- bad http return code: " + + e.getRawStatusCode()); + handleHttpStatusCodeError(e); } catch (Exception e) { logger.error("Couldn't retrieve prediction from external prediction server - ", e); @@ -543,4 +527,31 @@ private SeldonMessage queryREST( APIException.ApiExceptionType.ENGINE_MICROSERVICE_ERROR, String.format("Failed to retrieve predictions after %d attempts", restRetries)); } + + private void handleHttpStatusCodeError(HttpStatusCodeException exception) { + String response = exception.getResponseBodyAsString(); + SeldonMessage.Builder builder = SeldonMessage.newBuilder(); + try { + JsonFormat.parser().ignoringUnknownFields().merge(response, builder); + SeldonMessage seldonMessage = builder.build(); + Status seldonMessageStatus = seldonMessage.getStatus(); + if (seldonMessageStatus == null) { + throw new APIException( + APIException.ApiExceptionType.ENGINE_MICROSERVICE_ERROR, + String.format("Bad return code %d", exception.getRawStatusCode())); + } + else + { + throw new APIException(seldonMessageStatus.getCode(), + seldonMessageStatus.getReason(), + 200, + seldonMessageStatus.getInfo()); + } + } catch (InvalidProtocolBufferException ex) + { + logger.error("Invalid protocol buffer during Json Format merge - ", ex); + throw new APIException( + APIException.ApiExceptionType.ENGINE_MICROSERVICE_ERROR, ex.toString()); + } + } } diff --git a/engine/src/test/java/io/seldon/engine/api/rest/TestRestClientControllerExternalGraphs.java b/engine/src/test/java/io/seldon/engine/api/rest/TestRestClientControllerExternalGraphs.java index 7cdf11cd3e..aaa41c5d37 100644 --- a/engine/src/test/java/io/seldon/engine/api/rest/TestRestClientControllerExternalGraphs.java +++ b/engine/src/test/java/io/seldon/engine/api/rest/TestRestClientControllerExternalGraphs.java @@ -10,6 +10,8 @@ import io.seldon.protos.PredictionProtos.SeldonMessage; import java.net.URI; import java.nio.charset.StandardCharsets; + +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -25,6 +27,7 @@ import org.springframework.boot.test.web.client.TestRestTemplate; import org.springframework.boot.web.server.LocalServerPort; import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; @@ -36,6 +39,8 @@ import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.util.MultiValueMap; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.client.HttpStatusCodeException; import org.springframework.web.context.WebApplicationContext; @RunWith(SpringRunner.class) @@ -59,6 +64,11 @@ public void setup() throws Exception { mvc = MockMvcBuilders.webAppContextSetup(context).addFilters(new XSSFilter()).build(); } + @After + public void resetMocks() { + Mockito.reset(testRestTemplate.getRestTemplate()); + } + @LocalServerPort private int port; @Autowired private TestRestTemplate testRestTemplate; @@ -751,4 +761,40 @@ public ResponseEntity answer(InvocationOnMock invocation) { > -1); System.out.println(response); } + + @Test + public void testModelPredictionNon200Response() throws Exception { + String jsonStr = readFile("src/test/resources/model_simple.json", StandardCharsets.UTF_8); + String responseStr = + readFile("src/test/resources/response_status.json", StandardCharsets.UTF_8); + io.seldon.protos.DeploymentProtos.PredictorSpec.Builder PredictorSpecBuilder = io.seldon.protos.DeploymentProtos.PredictorSpec.newBuilder(); + EnginePredictor.updateMessageBuilderFromJson(PredictorSpecBuilder, jsonStr); + io.seldon.protos.DeploymentProtos.PredictorSpec predictorSpec = PredictorSpecBuilder.build(); + final String predictJson = "{" + "\"binData\": \"MTIz\"" + "}"; + ReflectionTestUtils.setField(enginePredictor, "predictorSpec", predictorSpec); + + HttpStatusCodeException exception = HttpServerErrorException.InternalServerError + .create(HttpStatus.BAD_REQUEST, "status text", HttpHeaders.EMPTY, responseStr.getBytes(StandardCharsets.UTF_8), StandardCharsets.UTF_8); + + Mockito.when( + testRestTemplate + .getRestTemplate() + .postForEntity( + Matchers.any(), + Matchers.>>any(), + Matchers.>any())) + .thenThrow(exception); + + MvcResult res = + mvc.perform( + MockMvcRequestBuilders.post("/api/v0.1/predictions") + .accept(MediaType.APPLICATION_JSON_UTF8) + .content(predictJson) + .contentType(MediaType.APPLICATION_JSON_UTF8)) + .andReturn(); + + // Check for returned response that wraps the ApiException into SeldonMessage + Assert.assertEquals(200, res.getResponse().getStatus()); + Assert.assertEquals(responseStr, res.getResponse().getContentAsString()); + } } diff --git a/engine/src/test/resources/response_status.json b/engine/src/test/resources/response_status.json new file mode 100644 index 0000000000..ff1bfebb19 --- /dev/null +++ b/engine/src/test/resources/response_status.json @@ -0,0 +1,8 @@ +{ + "status": { + "code": 400, + "info": "test error message", + "reason": "exception in prediction", + "status": "FAILURE" + } +} \ No newline at end of file