diff --git a/src/jit.rs b/src/jit.rs index 1bf8d7cf..d84a7a4c 100644 --- a/src/jit.rs +++ b/src/jit.rs @@ -1169,7 +1169,7 @@ impl<'a, V: Verifier, C: ContextObject> JitCompiler<'a, V, C> { // dst-in RAX RAX RAX RAX RAX RAX RAX // dst-out RAX RDX RDX RAX RAX RDX RDX - let signed = (opc & ebpf::BPF_ALU_OP_MASK) == ebpf::BPF_SDIV; + let signed = (opc & ebpf::BPF_ALU_OP_MASK) == ebpf::BPF_MUL || (opc & ebpf::BPF_ALU_OP_MASK) == ebpf::BPF_SDIV; let division = (opc & ebpf::BPF_ALU_OP_MASK) != ebpf::BPF_MUL; let alt_dst = (opc & ebpf::BPF_ALU_OP_MASK) == ebpf::BPF_MOD; let size = if (opc & ebpf::BPF_CLS_MASK) == ebpf::BPF_ALU64 { OperandSize::S64 } else { OperandSize::S32 }; @@ -1221,10 +1221,12 @@ impl<'a, V: Verifier, C: ContextObject> JitCompiler<'a, V, C> { if dst != RDX { self.emit_ins(X86Instruction::push(RDX, None)); } - if signed { - self.emit_ins(X86Instruction::sign_extend_rax_rdx(size)); - } else if division { - self.emit_ins(X86Instruction::alu(size, 0x31, RDX, RDX, 0, None)); // RDX = 0 + if division { + if signed { + self.emit_ins(X86Instruction::sign_extend_rax_rdx(size)); + } else { + self.emit_ins(X86Instruction::alu(size, 0x31, RDX, RDX, 0, None)); // RDX = 0 + } } self.emit_ins(X86Instruction::alu(size, 0xf7, 0x4 | (division as u8) << 1 | signed as u8, R11, 0, None));