
Zune: Minor refactoring in mod.rs
saidinesh5 committed Jun 13, 2022
1 parent 88048cf commit 1b86369
Showing 1 changed file with 83 additions and 78 deletions: crates/runtime/src/zune/mod.rs
@@ -334,89 +334,94 @@ fn instantiate_nodes(

for (stage_name, stage) in pipeline {
// Collect each output tensor into tensors
match stage {
// Models are handled on the host side, so we treat them separately
Stage::Capability(stage) => {
let wasm =
read_zip_resource_by_path(&stage.capability.to_string())
.context("Unable to load the capability")?;

let pb = ProcBlockNode::load(
&stage_name,
&wasm,
&runtime,
&input_tensors,
&output_tensors,
)?;
nodes.insert(stage_name.to_string(), Box::new(pb));
},
Stage::Model(stage) => {
// Instantiating the model's inference context here because that
// way model_data gets deallocated once we are done with it
// This way memory usage is under control
let model_data =
read_zip_resource_by_path(&stage.model.to_string())
.with_context(|| {
format!(
"Unable to read model from zune {}",
stage.model
)
})?;

let model_format =
stage.args.get("model-format").map(|f| f.to_string());
let node = load_model(
&model_data,
model_format.as_deref().unwrap_or("tflite"),
stage_name,
stage,
shared_state,
&input_tensors,
&output_tensors,
)?;
nodes.insert(stage_name.to_string(), node);
},
Stage::ProcBlock(stage) => {
let wasm =
read_zip_resource_by_path(&stage.proc_block.to_string())
.context("Unable to load the proc_block")?;

let pb = ProcBlockNode::load(
&stage_name,
&wasm,
&runtime,
&input_tensors,
&output_tensors,
)?;
nodes.insert(stage_name.to_string(), Box::new(pb));
},
Stage::Out(stage) => {
shared_state
.lock()
.unwrap()
.graph_contexts
.get_mut(stage_name)
.and_then(|c| {
for input in stage.inputs.iter() {
let tensor_key = key(&input.name, input.index);
let tensor_id = output_tensors.get(&tensor_key).copied();
c.input_tensors.insert(tensor_key,
TensorConstraint {
tensor_id,
element_type: ElementType::U8,
dimensions: Dimensions::Dynamic
}
);
}
Some(())
});
}, // Do nothing for capabilities/outputs
}
        instantiate_node(
            stage,
            &mut read_zip_resource_by_path,
            stage_name,
            &runtime,
            &input_tensors,
            &output_tensors,
            &mut nodes,
            shared_state,
        )
        .with_context(|| format!("Unable to load node \"{stage_name}\""))?;
}

Ok(nodes)
}

fn instantiate_node(
    stage: &Stage,
    read_zip_resource_by_path: &mut impl FnMut(&str) -> Result<Vec<u8>, Error>,
    stage_name: &String,
    runtime: &Runtime,
    input_tensors: &IndexMap<String, usize>,
    output_tensors: &IndexMap<String, usize>,
    nodes: &mut IndexMap<String, Box<dyn Node>>,
    shared_state: &Arc<Mutex<State>>,
) -> Result<(), Error> {
Ok(match stage {
// Models are handled on the host side, so we treat them separately
Stage::Capability(stage) => {
let wasm =
read_zip_resource_by_path(&stage.capability.to_string())
.context("Unable to load the capability")?;

let pb = ProcBlockNode::load(
&stage_name,
&wasm,
runtime,
input_tensors,
output_tensors,
)?;
nodes.insert(stage_name.to_string(), Box::new(pb));
},
Stage::Model(stage) => {
// Instantiating the model's inference context here because that
// way model_data gets deallocated once we are done with it
// This way memory usage is under control
let model_data =
read_zip_resource_by_path(&stage.model.to_string())
.with_context(|| {
format!(
"Unable to read model from zune {}",
stage.model
)
})?;

let model_format =
stage.args.get("model-format").map(|f| f.to_string());
let node = load_model(
&model_data,
model_format.as_deref().unwrap_or("tflite"),
stage_name,
stage,
shared_state,
input_tensors,
output_tensors,
)?;
nodes.insert(stage_name.to_string(), node);
},
Stage::ProcBlock(stage) => {
let wasm =
read_zip_resource_by_path(&stage.proc_block.to_string())
.context("Unable to load the proc_block")?;

let pb = ProcBlockNode::load(
&stage_name,
&wasm,
runtime,
input_tensors,
output_tensors,
)?;
nodes.insert(stage_name.to_string(), Box::new(pb));
},
Stage::Out(stage) => {
shared_state
.lock()
.unwrap()
.graph_contexts
.get_mut(stage_name)
.and_then(|c| {
for input in stage.inputs.iter() {
let tensor_key = key(&input.name, input.index);
let tensor_id = output_tensors.get(&tensor_key).copied();
c.input_tensors.insert(tensor_key,
TensorConstraint {
tensor_id,
element_type: ElementType::U8,
dimensions: Dimensions::Dynamic
}
);
}
Some(())
});
}, // Do nothing for capabilities/outputs
})
}

fn load_model(
model_data: &[u8],
model_format: &str,
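The comment in the Model arm ("model_data gets deallocated once we are done with it") relies on ordinary Rust scoping: the buffer is owned inside the helper call, so it is freed as soon as the node has been built and before the next stage is processed. A minimal, self-contained sketch of that behaviour, with every name invented for illustration rather than taken from the crate:

fn process_stage(load: &mut impl FnMut(&str) -> Vec<u8>, path: &str) -> usize {
    let model_data = load(path); // owned by this call alone
    model_data.len() // stand-in for "build the inference context"
} // `model_data` is deallocated here, before the next stage is handled

fn main() {
    // A toy loader standing in for `read_zip_resource_by_path`.
    let mut load = |path: &str| vec![0u8; path.len() * 1024];
    for path in ["sine.tflite", "gesture.tflite"] {
        let size = process_stage(&mut load, path);
        println!("{path}: {size} bytes loaded and released");
    }
}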

0 comments on commit 1b86369
