Skip to content

Commit

Permalink
Additional getters and setters for operation and builder (#263)
Browse files Browse the repository at this point in the history
Added additional getters and setters for `Operation` such that you can
get operands, results and attributes, and modify attributes. Most of
these new functions are not strictly necessary, but they can be more
convenient than the existing functions.

I'm using the `AttributeLike` trait such that any attribute-like type
can be used. I wonder if we should also use this as generic return type
for `Operation::attribute`?

I also created a `OperationOperand` type, but currently it is simply and
alias for `Value`. Should `OperationOperand` be a more complex type like
`OperationResult`, or should I remove `OperationOperand` and replace all
its occurences with `Value`?

If I'm not mistaken, operands, results, successors and attributes can be
borrowed by the builder, while ownership of regions is moved to MLIR
(hence the call to `forget`).

And one last thing: should I use `&str` instead of `impl AsRef<str>`, or
is it fine like this?

Please let me know if there's anything else I should change.
  • Loading branch information
Danacus authored Jul 25, 2023
1 parent 9354e9a commit efe7930
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 29 deletions.
4 changes: 4 additions & 0 deletions melior/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub enum Error {
},
InvokeFunction,
OperationResultExpected(String),
OperationAttributeExpected(String),
PositionOutOfBounds {
name: &'static str,
value: String,
Expand Down Expand Up @@ -43,6 +44,9 @@ impl Display for Error {
Self::OperationResultExpected(value) => {
write!(formatter, "operation result expected: {value}")
}
Self::OperationAttributeExpected(value) => {
write!(formatter, "attribute {value} expected")
}
Self::ParsePassPipeline(message) => {
write!(formatter, "failed to parse pass pipeline:\n{}", message)
}
Expand Down
244 changes: 215 additions & 29 deletions melior/src/ir/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,25 @@ mod result;
pub use self::{
builder::OperationBuilder, printing_flags::OperationPrintingFlags, result::OperationResult,
};
use super::{BlockRef, Identifier, RegionRef};
use super::{Attribute, AttributeLike, BlockRef, Identifier, RegionRef, Value};
use crate::{
context::{Context, ContextRef},
utility::{print_callback, print_string_callback},
Error,
Error, StringRef,
};
use core::{
fmt,
mem::{forget, transmute},
};
use mlir_sys::{
mlirOperationClone, mlirOperationDestroy, mlirOperationDump, mlirOperationEqual,
mlirOperationGetBlock, mlirOperationGetContext, mlirOperationGetName,
mlirOperationGetNextInBlock, mlirOperationGetNumRegions, mlirOperationGetNumResults,
mlirOperationGetRegion, mlirOperationGetResult, mlirOperationPrint,
mlirOperationPrintWithFlags, mlirOperationVerify, MlirOperation,
mlirOperationGetAttribute, mlirOperationGetAttributeByName, mlirOperationGetBlock,
mlirOperationGetContext, mlirOperationGetName, mlirOperationGetNextInBlock,
mlirOperationGetNumAttributes, mlirOperationGetNumOperands, mlirOperationGetNumRegions,
mlirOperationGetNumResults, mlirOperationGetNumSuccessors, mlirOperationGetOperand,
mlirOperationGetRegion, mlirOperationGetResult, mlirOperationGetSuccessor, mlirOperationPrint,
mlirOperationPrintWithFlags, mlirOperationRemoveAttributeByName,
mlirOperationSetAttributeByName, mlirOperationVerify, MlirOperation,
};
use std::{
ffi::c_void,
Expand Down Expand Up @@ -55,50 +58,180 @@ impl<'c> Operation<'c> {
unsafe { BlockRef::from_option_raw(mlirOperationGetBlock(self.raw)) }
}

/// Gets the number of operands.
pub fn operand_count(&self) -> usize {
unsafe { mlirOperationGetNumOperands(self.raw) as usize }
}

/// Gets the operand at a position.
pub fn operand(&self, index: usize) -> Result<Value<'c, '_>, Error> {
if index < self.operand_count() {
unsafe {
Ok(Value::from_raw(mlirOperationGetOperand(
self.raw,
index as isize,
)))
}
} else {
Err(Error::PositionOutOfBounds {
name: "operation operand",
value: self.to_string(),
index,
})
}
}

/// Gets all operands.
pub fn operands(&self) -> impl Iterator<Item = Value<'c, '_>> {
(0..self.operand_count()).map(|index| self.operand(index).expect("valid operand index"))
}

/// Gets the number of results.
pub fn result_count(&self) -> usize {
unsafe { mlirOperationGetNumResults(self.raw) as usize }
}

/// Gets a result at a position.
pub fn result(&self, index: usize) -> Result<OperationResult<'c, '_>, Error> {
unsafe {
if index < self.result_count() {
if index < self.result_count() {
unsafe {
Ok(OperationResult::from_raw(mlirOperationGetResult(
self.raw,
index as isize,
)))
} else {
Err(Error::PositionOutOfBounds {
name: "operation result",
value: self.to_string(),
index,
})
}
} else {
Err(Error::PositionOutOfBounds {
name: "operation result",
value: self.to_string(),
index,
})
}
}

/// Gets a number of results.
pub fn result_count(&self) -> usize {
unsafe { mlirOperationGetNumResults(self.raw) as usize }
/// Gets all results.
pub fn results(&self) -> impl Iterator<Item = OperationResult<'c, '_>> {
(0..self.result_count()).map(|index| self.result(index).expect("valid result index"))
}

/// Gets the number of regions.
pub fn region_count(&self) -> usize {
unsafe { mlirOperationGetNumRegions(self.raw) as usize }
}

/// Gets a region at a position.
pub fn region(&self, index: usize) -> Result<RegionRef<'c, '_>, Error> {
unsafe {
if index < self.region_count() {
if index < self.region_count() {
unsafe {
Ok(RegionRef::from_raw(mlirOperationGetRegion(
self.raw,
index as isize,
)))
} else {
Err(Error::PositionOutOfBounds {
name: "region",
value: self.to_string(),
index,
})
}
} else {
Err(Error::PositionOutOfBounds {
name: "region",
value: self.to_string(),
index,
})
}
}

/// Gets a number of regions.
pub fn region_count(&self) -> usize {
unsafe { mlirOperationGetNumRegions(self.raw) as usize }
/// Gets all regions.
pub fn regions(&self) -> impl Iterator<Item = RegionRef<'c, '_>> {
(0..self.result_count()).map(|index| self.region(index).expect("valid result index"))
}

/// Gets the number of successors.
pub fn successor_count(&self) -> usize {
unsafe { mlirOperationGetNumSuccessors(self.raw) as usize }
}

/// Gets a successor at a position.
pub fn successor(&self, index: usize) -> Result<BlockRef<'c, '_>, Error> {
if index < self.successor_count() {
unsafe {
Ok(BlockRef::from_raw(mlirOperationGetSuccessor(
self.raw,
index as isize,
)))
}
} else {
Err(Error::PositionOutOfBounds {
name: "successor",
value: self.to_string(),
index,
})
}
}

/// Gets all successors.
pub fn successors(&self) -> impl Iterator<Item = BlockRef<'c, '_>> {
(0..self.successor_count())
.map(|index| self.successor(index).expect("valid successor index"))
}

/// Gets the number of attributes.
pub fn attribute_count(&self) -> usize {
unsafe { mlirOperationGetNumAttributes(self.raw) as usize }
}

/// Gets a attribute at a position.
pub fn attribute_at(&self, index: usize) -> Result<(Identifier<'c>, Attribute<'c>), Error> {
if index < self.attribute_count() {
unsafe {
let named_attribute = mlirOperationGetAttribute(self.raw, index as isize);
Ok((
Identifier::from_raw(named_attribute.name),
Attribute::from_raw(named_attribute.attribute),
))
}
} else {
Err(Error::PositionOutOfBounds {
name: "attribute",
value: self.to_string(),
index,
})
}
}

/// Gets all attributes.
pub fn attributes(&self) -> impl Iterator<Item = (Identifier<'c>, Attribute<'c>)> + '_ {
(0..self.attribute_count())
.map(|index| self.attribute_at(index).expect("valid attribute index"))
}

/// Gets a attribute with the given name.
pub fn attribute(&self, name: &str) -> Option<Attribute<'c>> {
unsafe {
Attribute::from_option_raw(mlirOperationGetAttributeByName(
self.raw,
StringRef::from(name).to_raw(),
))
}
}

/// Checks if the operation has a attribute with the given name.
pub fn has_attribute(&self, name: &str) -> bool {
self.attribute(name).is_some()
}

/// Sets the attribute with the given name to the given attribute.
pub fn set_attribute(&mut self, name: &str, attribute: &Attribute<'c>) {
unsafe {
mlirOperationSetAttributeByName(
self.raw,
StringRef::from(name).to_raw(),
attribute.to_raw(),
)
}
}

/// Removes the attribute with the given name.
pub fn remove_attribute(&mut self, name: &str) -> Result<(), Error> {
unsafe { mlirOperationRemoveAttributeByName(self.raw, StringRef::from(name).to_raw()) }
.then_some(())
.ok_or(Error::OperationAttributeExpected(name.into()))
}

/// Gets the next operation in the same block.
Expand Down Expand Up @@ -301,7 +434,7 @@ mod tests {
use super::*;
use crate::{
context::Context,
ir::{Block, Location},
ir::{attribute::StringAttribute, Block, Location, Type},
test::create_test_context,
};
use pretty_assertions::assert_eq;
Expand Down Expand Up @@ -382,6 +515,59 @@ mod tests {
);
}

#[test]
fn operands() {
let context = create_test_context();
context.set_allow_unregistered_dialects(true);

let location = Location::unknown(&context);
let r#type = Type::index(&context);
let block = Block::new(&[(r#type, location)]);
let argument: Value = block.argument(0).unwrap().into();

let operands = vec![argument.clone(), argument.clone(), argument.clone()];
let operation = OperationBuilder::new("foo", Location::unknown(&context))
.add_operands(&operands)
.build();

assert_eq!(
operation.operands().skip(1).collect::<Vec<_>>(),
vec![argument.clone(), argument.clone()]
);
}

#[test]
fn attribute() {
let context = create_test_context();
context.set_allow_unregistered_dialects(true);

let mut operation = OperationBuilder::new("foo", Location::unknown(&context))
.add_attributes(&[(
Identifier::new(&context, "foo"),
StringAttribute::new(&context, "bar").into(),
)])
.build();
assert!(operation.has_attribute("foo"));
assert_eq!(
operation.attribute("foo").map(|a| a.to_string()),
Some("\"bar\"".into())
);
assert!(operation.remove_attribute("foo").is_ok());
assert!(operation.remove_attribute("foo").is_err());
operation.set_attribute("foo", &StringAttribute::new(&context, "foo").into());
assert_eq!(
operation.attribute("foo").map(|a| a.to_string()),
Some("\"foo\"".into())
);
assert_eq!(
operation.attributes().next(),
Some((
Identifier::new(&context, "foo"),
StringAttribute::new(&context, "foo").into()
))
)
}

#[test]
fn clone() {
let context = create_test_context();
Expand Down
15 changes: 15 additions & 0 deletions melior/src/ir/operation/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,21 @@ mod tests {
OperationBuilder::new("foo", Location::unknown(&context)).build();
}

#[test]
fn add_operands() {
let context = create_test_context();
context.set_allow_unregistered_dialects(true);

let location = Location::unknown(&context);
let r#type = Type::index(&context);
let block = Block::new(&[(r#type, location)]);
let argument = block.argument(0).unwrap().into();

OperationBuilder::new("foo", Location::unknown(&context))
.add_operands(&[argument])
.build();
}

#[test]
fn add_results() {
let context = create_test_context();
Expand Down

0 comments on commit efe7930

Please sign in to comment.