Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert simple_copy to use the snitch runtime dialect from xdsl #116

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

jorendumoulin
Copy link
Contributor

@jorendumoulin jorendumoulin commented Apr 17, 2024

First generate the snrt op:

builtin.module {
  func.func public @simple_copy(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) {
    %0 = func.call @snax_is_dm_core() : () -> i1
    "scf.if"(%0) ({
      %1 = arith.constant 0 : index
      %2 = "memref.dim"(%arg0, %1) : (memref<?xi32>, index) -> index
      %3 = arith.constant 4 : index
      %4 = arith.muli %2, %3 : index
      %5 = "memref.extract_aligned_pointer_as_index"(%arg0) : (memref<?xi32>) -> index
      %6 = "memref.extract_aligned_pointer_as_index"(%arg1) : (memref<?xi32>) -> index
      %7 = builtin.unrealized_conversion_cast %5 : index to i32
      %8 = builtin.unrealized_conversion_cast %6 : index to i32
      %9 = builtin.unrealized_conversion_cast %4 : index to i32
      %10 = "snrt.dma_start_1d"(%8, %7, %9) : (i32, i32, i32) -> i32
      scf.yield
    }, {
    }) : (i1) -> ()
    func.return
  }
  func.func private @snax_is_dm_core() -> i1
}

convert-snrt-to-riscv:

builtin.module {
  func.func public @simple_copy(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) {
    %0 = func.call @snax_is_dm_core() : () -> i1
    "scf.if"(%0) ({
      %1 = arith.constant 0 : index
      %2 = "memref.dim"(%arg0, %1) : (memref<?xi32>, index) -> index
      %3 = arith.constant 4 : index
      %4 = arith.muli %2, %3 : index
      %5 = "memref.extract_aligned_pointer_as_index"(%arg0) : (memref<?xi32>) -> index
      %6 = "memref.extract_aligned_pointer_as_index"(%arg1) : (memref<?xi32>) -> index
      %7 = builtin.unrealized_conversion_cast %5 : index to i32
      %8 = builtin.unrealized_conversion_cast %6 : index to i32
      %9 = builtin.unrealized_conversion_cast %4 : index to i32
      %10 = riscv.get_register : () -> !riscv.reg<zero>
      %11 = builtin.unrealized_conversion_cast %8 : i32 to !riscv.reg<>
      %12 = builtin.unrealized_conversion_cast %7 : i32 to !riscv.reg<>
      %13 = builtin.unrealized_conversion_cast %9 : i32 to !riscv.reg<>
      riscv_snitch.dmsrc %12, %10 : (!riscv.reg<>, !riscv.reg<zero>) -> ()
      riscv_snitch.dmdst %11, %10 : (!riscv.reg<>, !riscv.reg<zero>) -> ()
      %14 = riscv_snitch.dmcpyi %13, 0 : (!riscv.reg<>) -> !riscv.reg<>
      %15 = builtin.unrealized_conversion_cast %14 : !riscv.reg<> to i32
      scf.yield
    }, {
    }) : (i1) -> ()
    func.return
  }
  func.func private @snax_is_dm_core() -> i1
}

convert-riscv-to-llvm:

builtin.module {
  func.func public @simple_copy(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) {
    %0 = func.call @snax_is_dm_core() : () -> i1
    "scf.if"(%0) ({
      %1 = arith.constant 0 : index
      %2 = "memref.dim"(%arg0, %1) : (memref<?xi32>, index) -> index
      %3 = arith.constant 4 : index
      %4 = arith.muli %2, %3 : index
      %5 = "memref.extract_aligned_pointer_as_index"(%arg0) : (memref<?xi32>) -> index
      %6 = "memref.extract_aligned_pointer_as_index"(%arg1) : (memref<?xi32>) -> index
      %7 = builtin.unrealized_conversion_cast %5 : index to i32
      %8 = builtin.unrealized_conversion_cast %6 : index to i32
      %9 = builtin.unrealized_conversion_cast %4 : index to i32
      %10 = riscv.get_register : () -> !riscv.reg<zero>
      %11 = builtin.unrealized_conversion_cast %8 : i32 to !riscv.reg<>
      %12 = builtin.unrealized_conversion_cast %7 : i32 to !riscv.reg<>
      %13 = builtin.unrealized_conversion_cast %9 : i32 to !riscv.reg<>
      %14 = builtin.unrealized_conversion_cast %12 : !riscv.reg<> to i32
      "llvm.inline_asm"(%14) <{"asm_string" = ".insn r 0x2b, 0, 0, x0, $0, x0", "constraints" = "r", "asm_dialect" = 0 : i64}> : (i32) -> ()
      %15 = builtin.unrealized_conversion_cast %11 : !riscv.reg<> to i32
      "llvm.inline_asm"(%15) <{"asm_string" = ".insn r 0x2b, 0, 1, x0, $0, x0", "constraints" = "r", "asm_dialect" = 0 : i64}> : (i32) -> ()
      %16 = builtin.unrealized_conversion_cast %13 : !riscv.reg<> to i32
      %17 = "llvm.inline_asm"(%16) <{"asm_string" = ".insn r 0x2b, 0, 2, $0, $1, 0", "constraints" = "=r,r", "asm_dialect" = 0 : i64}> : (i32) -> i32
      %18 = builtin.unrealized_conversion_cast %17 : i32 to !riscv.reg<>
      %19 = builtin.unrealized_conversion_cast %18 : !riscv.reg<> to i32
      scf.yield
    }, {
    }) : (i1) -> ()
    func.return
  }
  func.func private @snax_is_dm_core() -> i1
}

Final manual changes to make this work:

  • remove riscv.get_register op manually (dce does not remove this)
  • fix .insn r 0x2b, 0, 2, $0, $1, 0 -> dmcpyi and dmstati are not in a valid risc-v register format so insn complains about this :(. We can either hack this by using the an i-type instruction and combining the two constants into one, or by converting the imm5 to register name which index corresponds to the correct immediate
  • use has_side_attributes

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant