diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/ValidatorUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/ValidatorUtil.scala index 3d93c4e8742ab..a49de687a27dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/ValidatorUtil.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/ValidatorUtil.scala @@ -42,7 +42,7 @@ object ValidatorUtil extends Logging { val in = openSchemaFile(new Path(key)) try { val schemaFactory = SchemaFactory.newInstance(XMLConstants.W3C_XML_SCHEMA_NS_URI) - schemaFactory.newSchema(new StreamSource(in)) + schemaFactory.newSchema(new StreamSource(in, key)) } finally { in.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XSDToSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XSDToSchema.scala index 87082299615c3..c03c0ba11de57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XSDToSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XSDToSchema.scala @@ -47,7 +47,7 @@ object XSDToSchema extends Logging{ def read(xsdPath: Path): StructType = { val in = ValidatorUtil.openSchemaFile(xsdPath) val xmlSchemaCollection = new XmlSchemaCollection() - xmlSchemaCollection.setBaseUri(xsdPath.getParent.toString) + xmlSchemaCollection.setBaseUri(xsdPath.toString) val xmlSchema = xmlSchemaCollection.read(new InputStreamReader(in)) getStructType(xmlSchema) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index 7df7c0d49d191..51e8cfc7f1030 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -1206,14 +1206,16 @@ class XmlSuite } test("test XSD validation") { - val basketDF = spark.read - .option("rowTag", "basket") - .option("inferSchema", true) - .option("rowValidationXSDPath", getTestResourcePath(resDir + "basket.xsd") - .replace("file:/", "/")) - .xml(getTestResourcePath(resDir + "basket.xml")) - // Mostly checking it doesn't fail - assert(basketDF.selectExpr("entry[0].key").head().getLong(0) === 9027) + Seq("basket.xsd", "include-example/first.xsd").foreach { xsdFile => + val basketDF = spark.read + .option("rowTag", "basket") + .option("inferSchema", true) + .option("rowValidationXSDPath", getTestResourcePath(resDir + xsdFile) + .replace("file:/", "/")) + .xml(getTestResourcePath(resDir + "basket.xml")) + // Mostly checking it doesn't fail + assert(basketDF.selectExpr("entry[0].key").head().getLong(0) === 9027) + } } test("test XSD validation with validation error") {