-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
merge branch "faster-xml-writer" into trunk
- Loading branch information
Showing
1 changed file
with
158 additions
and
192 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,227 +1,193 @@ | ||
use quick_xml::se::to_writer_with_root; | ||
use serde::{Serialize, Serializer}; | ||
use std::collections::HashMap; | ||
use std::fmt::Write; | ||
use std::sync::mpsc::Receiver; | ||
|
||
use crate::elements::{Element, ElementType, Member, Metadata, SimpleElementType}; | ||
|
||
fn serialize_simple_element_type<S>( | ||
value: &Option<SimpleElementType>, | ||
serializer: S, | ||
) -> Result<S::Ok, S::Error> | ||
use quick_xml::escape::escape; | ||
use rayon::prelude::*; | ||
use std::fmt::{Error, Write}; | ||
use std::sync::mpsc::{channel, Receiver}; | ||
|
||
use crate::elements::{Element, ElementType, Metadata, SimpleElementType}; | ||
use crate::threadpools::WRITER_THREAD_POOL; | ||
|
||
// wrapper struct that implements std::fmt::Write for any type | ||
// that implements std::io::Write | ||
struct ToFmtWrite<T>(pub T); | ||
|
||
impl<T> Write for ToFmtWrite<T> | ||
where | ||
S: Serializer, | ||
T: std::io::Write, | ||
{ | ||
match value { | ||
Some(SimpleElementType::Node) => serializer.serialize_str("node"), | ||
Some(SimpleElementType::Way) => serializer.serialize_str("way"), | ||
Some(SimpleElementType::Relation) => serializer.serialize_str("relation"), | ||
None => serializer.serialize_none(), | ||
fn write_str(&mut self, s: &str) -> std::fmt::Result { | ||
self.0.write_all(s.as_bytes()).map_err(|_| std::fmt::Error) | ||
} | ||
} | ||
|
||
#[derive(Serialize)] | ||
#[serde(remote = "Member")] | ||
struct MemberDef { | ||
#[serde(rename = "@type", serialize_with = "serialize_simple_element_type")] | ||
t: Option<SimpleElementType>, | ||
#[serde(rename = "@ref")] | ||
id: i64, | ||
#[serde(rename = "@role")] | ||
role: Option<String>, | ||
} | ||
fn create_header(metadata: Metadata) -> String { | ||
let mut header = String::new(); | ||
header.push_str("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<osm version=\"0.6\""); | ||
|
||
#[derive(Serialize)] | ||
struct XmlTags { | ||
#[serde(rename = "@k")] | ||
k: String, | ||
#[serde(rename = "@v")] | ||
v: String, | ||
} | ||
macro_rules! append_attribute { | ||
($attr:ident) => { | ||
if let Some($attr) = &metadata.$attr { | ||
header.push_str(concat!(" ", stringify!($attr), "=\"")); | ||
header.push_str($attr); | ||
header.push('\"'); | ||
} | ||
}; | ||
} | ||
|
||
#[derive(Serialize)] | ||
pub struct XmlElementMeta { | ||
#[serde(rename = "@id")] | ||
id: i64, | ||
#[serde(rename = "@user")] | ||
user: Option<String>, | ||
#[serde(rename = "@uid")] | ||
uid: Option<i32>, | ||
#[serde(rename = "@visible")] | ||
visible: bool, | ||
#[serde(rename = "@version")] | ||
version: Option<i32>, | ||
#[serde(rename = "@changeset")] | ||
changeset: Option<i64>, | ||
#[serde(rename = "@timestamp")] | ||
timestamp: Option<String>, | ||
} | ||
append_attribute!(copyright); | ||
append_attribute!(generator); | ||
append_attribute!(license); | ||
append_attribute!(timestamp); | ||
append_attribute!(version); | ||
|
||
#[derive(Serialize)] | ||
struct XmlNode { | ||
#[serde(rename = "@lat")] | ||
lat: f64, | ||
#[serde(rename = "@lon")] | ||
lon: f64, | ||
#[serde(flatten)] | ||
meta: XmlElementMeta, | ||
#[serde(default, rename = "tag")] | ||
tags: Vec<XmlTags>, | ||
header.push_str(">\n"); | ||
header | ||
} | ||
|
||
#[derive(Serialize)] | ||
#[serde(rename = "nd")] | ||
struct XmlWayNode { | ||
#[serde(rename = "@ref")] | ||
nd_ref: i64, | ||
} | ||
fn append_serialized_metadata(base: &mut String, element: &Element) { | ||
base.push_str(" id=\""); | ||
base.push_str(&lexical::to_string(element.id)); | ||
base.push('\"'); | ||
|
||
#[derive(Serialize)] | ||
struct XmlWay { | ||
#[serde(flatten)] | ||
meta: XmlElementMeta, | ||
nd: Vec<XmlWayNode>, | ||
#[serde(default, rename = "tag")] | ||
tags: Vec<XmlTags>, | ||
} | ||
if let Some(c) = element.changeset { | ||
base.push_str(" changeset=\""); | ||
base.push_str(&lexical::to_string(c)); | ||
base.push('\"'); | ||
} | ||
|
||
fn serialize_member_vec<S: Serializer>(v: &[Member], serializer: S) -> Result<S::Ok, S::Error> { | ||
#[derive(Serialize)] | ||
struct Wrapper<'a>(#[serde(with = "MemberDef")] &'a Member); | ||
if let Some(t) = &element.timestamp { | ||
base.push_str(" timestamp=\""); | ||
base.push_str(t); | ||
base.push('\"'); | ||
} | ||
|
||
v.iter() | ||
.map(Wrapper) | ||
.collect::<Vec<_>>() | ||
.serialize(serializer) | ||
} | ||
if let Some(u) = element.uid { | ||
base.push_str(" uid=\""); | ||
base.push_str(&lexical::to_string(u)); | ||
base.push('\"'); | ||
} | ||
|
||
#[derive(Serialize)] | ||
struct XmlRelation { | ||
#[serde(flatten)] | ||
meta: XmlElementMeta, | ||
#[serde(serialize_with = "serialize_member_vec")] | ||
member: Vec<Member>, | ||
#[serde(default, rename = "tag")] | ||
tags: Vec<XmlTags>, | ||
} | ||
if let Some(u) = &element.user { | ||
base.push_str(" user=\""); | ||
base.push_str(u); | ||
base.push('\"'); | ||
} | ||
|
||
#[derive(Serialize)] | ||
pub struct XmlMetadata { | ||
#[serde(rename = "@version", skip_serializing_if = "Option::is_none")] | ||
pub version: Option<String>, | ||
#[serde(rename = "@generator", skip_serializing_if = "Option::is_none")] | ||
pub generator: Option<String>, | ||
#[serde(rename = "@copyright", skip_serializing_if = "Option::is_none")] | ||
pub copyright: Option<String>, | ||
#[serde(rename = "@license", skip_serializing_if = "Option::is_none")] | ||
pub license: Option<String>, | ||
if element.visible == Some(true) { | ||
base.push_str(" visible=\"true\""); | ||
} else if element.visible == Some(false) { | ||
base.push_str(" visible=\"false\""); | ||
} | ||
} | ||
|
||
#[derive(Serialize)] | ||
struct OsmXmlDocument { | ||
#[serde(flatten)] | ||
metadata: XmlMetadata, | ||
#[serde(default)] | ||
node: Vec<XmlNode>, | ||
#[serde(default)] | ||
way: Vec<XmlWay>, | ||
#[serde(default)] | ||
relation: Vec<XmlRelation>, | ||
fn append_serialized_tags(base: &mut String, element: &Element) { | ||
for (k, v) in &element.tags { | ||
base.push_str(" <tag k=\""); | ||
base.push_str(k); | ||
base.push_str("\" v=\""); | ||
base.push_str(v); | ||
base.push_str("\"/>\n"); | ||
} | ||
} | ||
|
||
struct ToFmtWrite<T>(pub T); | ||
fn append_serialized_element(base: &mut String, element: Element) { | ||
match &element.element_type { | ||
ElementType::Node { lat, lon } => { | ||
base.push_str(" <node lat=\""); | ||
base.push_str(&lexical::to_string(*lat)); | ||
base.push_str("\" lon=\""); | ||
base.push_str(&lexical::to_string(*lon)); | ||
base.push_str("\" "); | ||
|
||
append_serialized_metadata(base, &element); | ||
|
||
if element.tags.is_empty() { | ||
base.push('>'); | ||
append_serialized_tags(base, &element); | ||
base.push_str(" </node>\n"); | ||
} else { | ||
base.push_str("/>\n") | ||
} | ||
} | ||
ElementType::Way { nodes } => { | ||
// finish "type": "way", then start nodes dict | ||
base.push_str(" <way"); | ||
append_serialized_metadata(base, &element); | ||
base.push_str(">\n"); | ||
|
||
impl<T> Write for ToFmtWrite<T> | ||
where | ||
T: std::io::Write, | ||
{ | ||
fn write_str(&mut self, s: &str) -> std::fmt::Result { | ||
self.0.write_all(s.as_bytes()).map_err(|_| std::fmt::Error) | ||
} | ||
} | ||
for n in nodes { | ||
base.push_str(" <nd ref=\""); | ||
base.push_str(&lexical::to_string(*n)); | ||
base.push_str("\"/>\n") | ||
} | ||
|
||
fn convert_tags(element_tags: HashMap<String, String>) -> Vec<XmlTags> { | ||
element_tags | ||
.into_iter() | ||
.map(|(k, v)| XmlTags { k, v }) | ||
.collect() | ||
} | ||
append_serialized_tags(base, &element); | ||
|
||
fn convert_nodes(way_nodes: Vec<i64>) -> Vec<XmlWayNode> { | ||
way_nodes | ||
.into_iter() | ||
.map(|nd_ref| XmlWayNode { nd_ref }) | ||
.collect() | ||
base.push_str(" </way>\n"); | ||
} | ||
ElementType::Relation { members } => { | ||
base.push_str(" <relation"); | ||
append_serialized_metadata(base, &element); | ||
base.push_str(">\n"); | ||
|
||
for m in members { | ||
base.push_str(" <member "); | ||
match m.t { | ||
Some(SimpleElementType::Node) => base.push_str("type=\"node\""), | ||
Some(SimpleElementType::Way) => base.push_str("type=\"way\""), | ||
Some(SimpleElementType::Relation) => base.push_str("type=\"relation\""), | ||
None => (), | ||
} | ||
|
||
base.push_str(" ref=\""); | ||
base.push_str(&lexical::to_string(m.id)); | ||
base.push_str("\" role=\""); | ||
if let Some(ref r) = m.role { | ||
base.push_str(&escape(r.as_str())); | ||
} | ||
base.push_str("\"/>\n"); | ||
} | ||
base.push_str(" </relation>\n"); | ||
} | ||
} | ||
} | ||
|
||
fn split_and_convert_elements<I>( | ||
received_elements: I, | ||
) -> (Vec<XmlNode>, Vec<XmlWay>, Vec<XmlRelation>) | ||
where | ||
I: Iterator<Item = Element>, | ||
{ | ||
let mut nodes = Vec::new(); | ||
let mut ways = Vec::new(); | ||
let mut relations = Vec::new(); | ||
for e in received_elements { | ||
let meta = XmlElementMeta { | ||
id: e.id, | ||
user: e.user, | ||
uid: e.uid, | ||
visible: e.visible.unwrap_or(true), // TODO: better default behavior? | ||
version: e.version, | ||
changeset: e.changeset, | ||
timestamp: e.timestamp, | ||
}; | ||
let tags = convert_tags(e.tags); | ||
match e.element_type { | ||
ElementType::Node { lat, lon } => nodes.push(XmlNode { | ||
lat, | ||
lon, | ||
meta, | ||
tags, | ||
}), | ||
ElementType::Way { nodes } => ways.push(XmlWay { | ||
meta, | ||
nd: convert_nodes(nodes), | ||
tags, | ||
}), | ||
ElementType::Relation { members } => relations.push(XmlRelation { | ||
meta, | ||
member: members, | ||
tags, | ||
}), | ||
} | ||
fn serialize_chunk(chunk: Vec<Element>) -> Result<String, Error> { | ||
let mut output = String::new(); | ||
for element in chunk { | ||
append_serialized_element(&mut output, element); | ||
} | ||
(nodes, ways, relations) | ||
Ok(output) | ||
} | ||
|
||
pub fn write_xml<D: std::io::Write>(receiver: Receiver<Vec<Element>>, metadata: Metadata, dest: D) { | ||
let (node, way, relation) = split_and_convert_elements(receiver.iter().flatten()); | ||
|
||
let xml_osm_document = OsmXmlDocument { | ||
metadata: XmlMetadata { | ||
version: metadata.version, | ||
generator: metadata.generator, | ||
copyright: metadata.copyright, | ||
license: metadata.license, | ||
}, | ||
node, | ||
way, | ||
relation, | ||
}; | ||
|
||
let mut writer = ToFmtWrite(dest); | ||
|
||
let (output_sender, output_receiver) = channel(); | ||
WRITER_THREAD_POOL.install(move || { | ||
receiver | ||
.into_iter() | ||
.par_bridge() | ||
.map(serialize_chunk) | ||
.map(|result| result.expect("Failed to serialize chunk")) | ||
.for_each(|s| match output_sender.clone().send(s) { | ||
Ok(_) => (), | ||
Err(e) => panic!("Error passing output chunk between threads: {e:?}"), | ||
}); | ||
}); | ||
|
||
let header = create_header(metadata); | ||
|
||
writer | ||
.write_str("<?xml version=\"1.0\" encoding=\"UTF-8\"?>") | ||
.write_str(&header) | ||
.expect("Unable to write header to XML file!"); | ||
|
||
match to_writer_with_root(writer, "osm", &xml_osm_document) { | ||
Ok(_) => (), | ||
Err(e) => { | ||
panic!("XML serialization error: {e:?}"); | ||
} | ||
for output_string in output_receiver { | ||
writer | ||
.write_str(&output_string) | ||
.expect("Failed to write to output"); | ||
} | ||
|
||
writer | ||
.write_str("</osm>\n") | ||
.expect("Couldn't write final closing curly brace to output."); | ||
} |