From badb73b921c95d51e91fc034c408008e54e46cf7 Mon Sep 17 00:00:00 2001 From: DianQK Date: Tue, 20 Feb 2024 21:47:27 +0800 Subject: [PATCH 1/7] Update matches_reduce_branches.rs --- ...h_i128_u128.MatchBranchSimplification.diff | 42 ++++ ...atch_i16_i8.MatchBranchSimplification.diff | 37 +++ ...atch_i8_i16.MatchBranchSimplification.diff | 37 +++ ..._i16_failed.MatchBranchSimplification.diff | 37 +++ ...atch_u8_i16.MatchBranchSimplification.diff | 32 +++ ...ch_u8_i16_2.MatchBranchSimplification.diff | 26 ++ ..._i16_failed.MatchBranchSimplification.diff | 32 +++ ...16_fallback.MatchBranchSimplification.diff | 31 +++ ...atch_u8_u16.MatchBranchSimplification.diff | 37 +++ ...ch_u8_u16_2.MatchBranchSimplification.diff | 37 +++ tests/mir-opt/matches_reduce_branches.rs | 223 +++++++++++++++++- 11 files changed, 567 insertions(+), 4 deletions(-) create mode 100644 tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff create mode 100644 tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff create mode 100644 tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff create mode 100644 tests/mir-opt/matches_reduce_branches.match_i8_i16_failed.MatchBranchSimplification.diff create mode 100644 tests/mir-opt/matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff create mode 100644 tests/mir-opt/matches_reduce_branches.match_u8_i16_2.MatchBranchSimplification.diff create mode 100644 tests/mir-opt/matches_reduce_branches.match_u8_i16_failed.MatchBranchSimplification.diff create mode 100644 tests/mir-opt/matches_reduce_branches.match_u8_i16_fallback.MatchBranchSimplification.diff create mode 100644 tests/mir-opt/matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff create mode 100644 tests/mir-opt/matches_reduce_branches.match_u8_u16_2.MatchBranchSimplification.diff diff --git a/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff new file mode 100644 index 00000000000..1f20349fdec --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff @@ -0,0 +1,42 @@ +- // MIR for `match_i128_u128` before MatchBranchSimplification ++ // MIR for `match_i128_u128` after MatchBranchSimplification + + fn match_i128_u128(_1: EnumAi128) -> u128 { + debug i => _1; + let mut _0: u128; + let mut _2: i128; + + bb0: { + _2 = discriminant(_1); + switchInt(move _2) -> [1: bb3, 2: bb4, 3: bb5, 340282366920938463463374607431768211455: bb2, otherwise: bb1]; + } + + bb1: { + unreachable; + } + + bb2: { + _0 = const core::num::::MAX; + goto -> bb6; + } + + bb3: { + _0 = const 1_u128; + goto -> bb6; + } + + bb4: { + _0 = const 2_u128; + goto -> bb6; + } + + bb5: { + _0 = const 3_u128; + goto -> bb6; + } + + bb6: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff new file mode 100644 index 00000000000..4b435310916 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff @@ -0,0 +1,37 @@ +- // MIR for `match_i16_i8` before MatchBranchSimplification ++ // MIR for `match_i16_i8` after MatchBranchSimplification + + fn match_i16_i8(_1: EnumAi16) -> i8 { + debug i => _1; + let mut _0: i8; + let mut _2: i16; + + bb0: { + _2 = discriminant(_1); + switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb2, otherwise: bb1]; + } + + bb1: { + unreachable; + } + + bb2: { + _0 = const -3_i8; + goto -> bb5; + } + + bb3: { + _0 = const -1_i8; + goto -> bb5; + } + + bb4: { + _0 = const 2_i8; + goto -> bb5; + } + + bb5: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff new file mode 100644 index 00000000000..8a390736add --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff @@ -0,0 +1,37 @@ +- // MIR for `match_i8_i16` before MatchBranchSimplification ++ // MIR for `match_i8_i16` after MatchBranchSimplification + + fn match_i8_i16(_1: EnumAi8) -> i16 { + debug i => _1; + let mut _0: i16; + let mut _2: i8; + + bb0: { + _2 = discriminant(_1); + switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb2, otherwise: bb1]; + } + + bb1: { + unreachable; + } + + bb2: { + _0 = const -3_i16; + goto -> bb5; + } + + bb3: { + _0 = const -1_i16; + goto -> bb5; + } + + bb4: { + _0 = const 2_i16; + goto -> bb5; + } + + bb5: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_i8_i16_failed.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_i8_i16_failed.MatchBranchSimplification.diff new file mode 100644 index 00000000000..b0217792294 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_i8_i16_failed.MatchBranchSimplification.diff @@ -0,0 +1,37 @@ +- // MIR for `match_i8_i16_failed` before MatchBranchSimplification ++ // MIR for `match_i8_i16_failed` after MatchBranchSimplification + + fn match_i8_i16_failed(_1: EnumAi8) -> i16 { + debug i => _1; + let mut _0: i16; + let mut _2: i8; + + bb0: { + _2 = discriminant(_1); + switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb2, otherwise: bb1]; + } + + bb1: { + unreachable; + } + + bb2: { + _0 = const 3_i16; + goto -> bb5; + } + + bb3: { + _0 = const -1_i16; + goto -> bb5; + } + + bb4: { + _0 = const 2_i16; + goto -> bb5; + } + + bb5: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff new file mode 100644 index 00000000000..72ad60956ab --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff @@ -0,0 +1,32 @@ +- // MIR for `match_u8_i16` before MatchBranchSimplification ++ // MIR for `match_u8_i16` after MatchBranchSimplification + + fn match_u8_i16(_1: EnumAu8) -> i16 { + debug i => _1; + let mut _0: i16; + let mut _2: u8; + + bb0: { + _2 = discriminant(_1); + switchInt(move _2) -> [1: bb3, 2: bb2, otherwise: bb1]; + } + + bb1: { + unreachable; + } + + bb2: { + _0 = const 2_i16; + goto -> bb4; + } + + bb3: { + _0 = const 1_i16; + goto -> bb4; + } + + bb4: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_u8_i16_2.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_u8_i16_2.MatchBranchSimplification.diff new file mode 100644 index 00000000000..3333cd765a8 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_u8_i16_2.MatchBranchSimplification.diff @@ -0,0 +1,26 @@ +- // MIR for `match_u8_i16_2` before MatchBranchSimplification ++ // MIR for `match_u8_i16_2` after MatchBranchSimplification + + fn match_u8_i16_2(_1: EnumAu8) -> i16 { + let mut _0: i16; + let mut _2: u8; + + bb0: { + _2 = discriminant(_1); + switchInt(_2) -> [1: bb3, 2: bb1, otherwise: bb2]; + } + + bb1: { + _0 = const 2_i16; + goto -> bb3; + } + + bb2: { + unreachable; + } + + bb3: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_u8_i16_failed.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_u8_i16_failed.MatchBranchSimplification.diff new file mode 100644 index 00000000000..6da19e46dab --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_u8_i16_failed.MatchBranchSimplification.diff @@ -0,0 +1,32 @@ +- // MIR for `match_u8_i16_failed` before MatchBranchSimplification ++ // MIR for `match_u8_i16_failed` after MatchBranchSimplification + + fn match_u8_i16_failed(_1: EnumAu8) -> i16 { + debug i => _1; + let mut _0: i16; + let mut _2: u8; + + bb0: { + _2 = discriminant(_1); + switchInt(move _2) -> [1: bb3, 2: bb2, otherwise: bb1]; + } + + bb1: { + unreachable; + } + + bb2: { + _0 = const 3_i16; + goto -> bb4; + } + + bb3: { + _0 = const 1_i16; + goto -> bb4; + } + + bb4: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_u8_i16_fallback.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_u8_i16_fallback.MatchBranchSimplification.diff new file mode 100644 index 00000000000..8fa497fe890 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_u8_i16_fallback.MatchBranchSimplification.diff @@ -0,0 +1,31 @@ +- // MIR for `match_u8_i16_fallback` before MatchBranchSimplification ++ // MIR for `match_u8_i16_fallback` after MatchBranchSimplification + + fn match_u8_i16_fallback(_1: u8) -> i16 { + debug i => _1; + let mut _0: i16; + + bb0: { + switchInt(_1) -> [1: bb2, 2: bb3, otherwise: bb1]; + } + + bb1: { + _0 = const 3_i16; + goto -> bb4; + } + + bb2: { + _0 = const 1_i16; + goto -> bb4; + } + + bb3: { + _0 = const 2_i16; + goto -> bb4; + } + + bb4: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff new file mode 100644 index 00000000000..043fdb197a3 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff @@ -0,0 +1,37 @@ +- // MIR for `match_u8_u16` before MatchBranchSimplification ++ // MIR for `match_u8_u16` after MatchBranchSimplification + + fn match_u8_u16(_1: EnumBu8) -> u16 { + debug i => _1; + let mut _0: u16; + let mut _2: u8; + + bb0: { + _2 = discriminant(_1); + switchInt(move _2) -> [1: bb3, 2: bb4, 5: bb2, otherwise: bb1]; + } + + bb1: { + unreachable; + } + + bb2: { + _0 = const 5_u16; + goto -> bb5; + } + + bb3: { + _0 = const 1_u16; + goto -> bb5; + } + + bb4: { + _0 = const 2_u16; + goto -> bb5; + } + + bb5: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_u8_u16_2.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_u8_u16_2.MatchBranchSimplification.diff new file mode 100644 index 00000000000..b47de6a52b7 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_u8_u16_2.MatchBranchSimplification.diff @@ -0,0 +1,37 @@ +- // MIR for `match_u8_u16_2` before MatchBranchSimplification ++ // MIR for `match_u8_u16_2` after MatchBranchSimplification + + fn match_u8_u16_2(_1: EnumBu8) -> i16 { + let mut _0: i16; + let mut _2: u8; + + bb0: { + _2 = discriminant(_1); + switchInt(_2) -> [1: bb1, 2: bb2, 5: bb3, otherwise: bb4]; + } + + bb1: { + _0 = const 1_i16; + goto -> bb5; + } + + bb2: { + _0 = const 2_i16; + goto -> bb5; + } + + bb3: { + _0 = const 5_i16; + _0 = const 5_i16; + goto -> bb5; + } + + bb4: { + unreachable; + } + + bb5: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.rs b/tests/mir-opt/matches_reduce_branches.rs index 4bf14e5a7bd..2e7b7d4e600 100644 --- a/tests/mir-opt/matches_reduce_branches.rs +++ b/tests/mir-opt/matches_reduce_branches.rs @@ -1,18 +1,28 @@ -// skip-filecheck //@ unit-test: MatchBranchSimplification +#![feature(repr128)] +#![feature(core_intrinsics)] +#![feature(custom_mir)] + +use std::intrinsics::mir::*; // EMIT_MIR matches_reduce_branches.foo.MatchBranchSimplification.diff -// EMIT_MIR matches_reduce_branches.bar.MatchBranchSimplification.diff -// EMIT_MIR matches_reduce_branches.match_nested_if.MatchBranchSimplification.diff - fn foo(bar: Option<()>) { + // CHECK-LABEL: fn foo( + // CHECK: = Eq( + // CHECK: switchInt + // CHECK-NOT: switchInt if matches!(bar, None) { () } } +// EMIT_MIR matches_reduce_branches.bar.MatchBranchSimplification.diff fn bar(i: i32) -> (bool, bool, bool, bool) { + // CHECK-LABEL: fn bar( + // CHECK: = Ne( + // CHECK: = Eq( + // CHECK-NOT: switchInt let a; let b; let c; @@ -38,7 +48,10 @@ fn bar(i: i32) -> (bool, bool, bool, bool) { (a, b, c, d) } +// EMIT_MIR matches_reduce_branches.match_nested_if.MatchBranchSimplification.diff fn match_nested_if() -> bool { + // CHECK-LABEL: fn match_nested_if( + // CHECK-NOT: switchInt let val = match () { () if if if if true { true } else { false } { true } else { false } { true @@ -53,9 +66,211 @@ fn match_nested_if() -> bool { val } +#[repr(u8)] +enum EnumAu8 { + A = 1, + B = 2, +} + +// EMIT_MIR matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff +fn match_u8_i16(i: EnumAu8) -> i16 { + // CHECK-LABEL: fn match_u8_i16( + // CHECK: switchInt + match i { + EnumAu8::A => 1, + EnumAu8::B => 2, + } +} + +// EMIT_MIR matches_reduce_branches.match_u8_i16_2.MatchBranchSimplification.diff +// Check for different instruction lengths +#[custom_mir(dialect = "built")] +fn match_u8_i16_2(i: EnumAu8) -> i16 { + // CHECK-LABEL: fn match_u8_i16_2( + // CHECK: switchInt + mir!( + { + let a = Discriminant(i); + match a { + 1 => bb1, + 2 => bb2, + _ => unreachable_bb, + } + } + bb1 = { + Goto(ret) + } + bb2 = { + RET = 2; + Goto(ret) + } + unreachable_bb = { + Unreachable() + } + ret = { + Return() + } + ) +} + +// EMIT_MIR matches_reduce_branches.match_u8_i16_failed.MatchBranchSimplification.diff +fn match_u8_i16_failed(i: EnumAu8) -> i16 { + // CHECK-LABEL: fn match_u8_i16_failed( + // CHECK: switchInt + match i { + EnumAu8::A => 1, + EnumAu8::B => 3, + } +} + +// EMIT_MIR matches_reduce_branches.match_u8_i16_fallback.MatchBranchSimplification.diff +fn match_u8_i16_fallback(i: u8) -> i16 { + // CHECK-LABEL: fn match_u8_i16_fallback( + // CHECK: switchInt + match i { + 1 => 1, + 2 => 2, + _ => 3, + } +} + +#[repr(u8)] +enum EnumBu8 { + A = 1, + B = 2, + C = 5, +} + +// EMIT_MIR matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff +fn match_u8_u16(i: EnumBu8) -> u16 { + // CHECK-LABEL: fn match_u8_u16( + // CHECK: switchInt + match i { + EnumBu8::A => 1, + EnumBu8::B => 2, + EnumBu8::C => 5, + } +} + +// EMIT_MIR matches_reduce_branches.match_u8_u16_2.MatchBranchSimplification.diff +// Check for different instruction lengths +#[custom_mir(dialect = "built")] +fn match_u8_u16_2(i: EnumBu8) -> i16 { + // CHECK-LABEL: fn match_u8_u16_2( + // CHECK: switchInt + mir!( + { + let a = Discriminant(i); + match a { + 1 => bb1, + 2 => bb2, + 5 => bb5, + _ => unreachable_bb, + } + } + bb1 = { + RET = 1; + Goto(ret) + } + bb2 = { + RET = 2; + Goto(ret) + } + bb5 = { + RET = 5; + RET = 5; + Goto(ret) + } + unreachable_bb = { + Unreachable() + } + ret = { + Return() + } + ) +} + +#[repr(i8)] +enum EnumAi8 { + A = -1, + B = 2, + C = -3, +} + +// EMIT_MIR matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff +fn match_i8_i16(i: EnumAi8) -> i16 { + // CHECK-LABEL: fn match_i8_i16( + // CHECK: switchInt + match i { + EnumAi8::A => -1, + EnumAi8::B => 2, + EnumAi8::C => -3, + } +} + +// EMIT_MIR matches_reduce_branches.match_i8_i16_failed.MatchBranchSimplification.diff +fn match_i8_i16_failed(i: EnumAi8) -> i16 { + // CHECK-LABEL: fn match_i8_i16_failed( + // CHECK: switchInt + match i { + EnumAi8::A => -1, + EnumAi8::B => 2, + EnumAi8::C => 3, + } +} + +#[repr(i16)] +enum EnumAi16 { + A = -1, + B = 2, + C = -3, +} + +// EMIT_MIR matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff +fn match_i16_i8(i: EnumAi16) -> i8 { + // CHECK-LABEL: fn match_i16_i8( + // CHECK: switchInt + match i { + EnumAi16::A => -1, + EnumAi16::B => 2, + EnumAi16::C => -3, + } +} + +#[repr(i128)] +enum EnumAi128 { + A = 1, + B = 2, + C = 3, + D = -1, +} + +// EMIT_MIR matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff +fn match_i128_u128(i: EnumAi128) -> u128 { + // CHECK-LABEL: fn match_i128_u128( + // CHECK: switchInt + match i { + EnumAi128::A => 1, + EnumAi128::B => 2, + EnumAi128::C => 3, + EnumAi128::D => u128::MAX, + } +} + fn main() { let _ = foo(None); let _ = foo(Some(())); let _ = bar(0); let _ = match_nested_if(); + let _ = match_u8_i16(EnumAu8::A); + let _ = match_u8_i16_2(EnumAu8::A); + let _ = match_u8_i16_failed(EnumAu8::A); + let _ = match_u8_i16_fallback(1); + let _ = match_u8_u16(EnumBu8::A); + let _ = match_u8_u16_2(EnumBu8::A); + let _ = match_i8_i16(EnumAi8::A); + let _ = match_i8_i16_failed(EnumAi8::A); + let _ = match_i8_i16(EnumAi8::A); + let _ = match_i16_i8(EnumAi16::A); + let _ = match_i128_u128(EnumAi128::A); } From 7af74584533fddfc80b6c29394e5b8a088be68bc Mon Sep 17 00:00:00 2001 From: DianQK Date: Tue, 20 Feb 2024 21:47:42 +0800 Subject: [PATCH 2/7] Refactor `MatchBranchSimplification` --- .../rustc_mir_transform/src/match_branches.rs | 334 +++++++++++------- 1 file changed, 203 insertions(+), 131 deletions(-) diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index 6d4332793af..be1158683ac 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -1,11 +1,116 @@ +use rustc_index::IndexVec; use rustc_middle::mir::*; -use rustc_middle::ty::TyCtxt; +use rustc_middle::ty::{ParamEnv, Ty, TyCtxt}; use std::iter; use super::simplify::simplify_cfg; pub struct MatchBranchSimplification; +impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 1 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let def_id = body.source.def_id(); + let param_env = tcx.param_env_reveal_all_normalized(def_id); + + let bbs = body.basic_blocks.as_mut(); + let mut should_cleanup = false; + for bb_idx in bbs.indices() { + if !tcx.consider_optimizing(|| format!("MatchBranchSimplification {def_id:?} ")) { + continue; + } + + match bbs[bb_idx].terminator().kind { + TerminatorKind::SwitchInt { + discr: ref _discr @ (Operand::Copy(_) | Operand::Move(_)), + ref targets, + .. + // We require that the possible target blocks don't contain this block. + } if !targets.all_targets().contains(&bb_idx) => {} + // Only optimize switch int statements + _ => continue, + }; + + if SimplifyToIf.simplify(tcx, &mut body.local_decls, bbs, bb_idx, param_env) { + should_cleanup = true; + continue; + } + } + + if should_cleanup { + simplify_cfg(body); + } + } +} + +trait SimplifyMatch<'tcx> { + fn simplify( + &self, + tcx: TyCtxt<'tcx>, + local_decls: &mut IndexVec>, + bbs: &mut IndexVec>, + switch_bb_idx: BasicBlock, + param_env: ParamEnv<'tcx>, + ) -> bool { + let (discr, targets) = match bbs[switch_bb_idx].terminator().kind { + TerminatorKind::SwitchInt { ref discr, ref targets, .. } => (discr, targets), + _ => unreachable!(), + }; + + if !self.can_simplify(tcx, targets, param_env, bbs) { + return false; + } + + // Take ownership of items now that we know we can optimize. + let discr = discr.clone(); + let discr_ty = discr.ty(local_decls, tcx); + + // Introduce a temporary for the discriminant value. + let source_info = bbs[switch_bb_idx].terminator().source_info; + let discr_local = local_decls.push(LocalDecl::new(discr_ty, source_info.span)); + + // We already checked that first and second are different blocks, + // and bb_idx has a different terminator from both of them. + let new_stmts = self.new_stmts(tcx, targets, param_env, bbs, discr_local.clone(), discr_ty); + let (_, first) = targets.iter().next().unwrap(); + let (from, first) = bbs.pick2_mut(switch_bb_idx, first); + from.statements + .push(Statement { source_info, kind: StatementKind::StorageLive(discr_local) }); + from.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new((Place::from(discr_local), Rvalue::Use(discr)))), + }); + from.statements.extend(new_stmts); + from.statements + .push(Statement { source_info, kind: StatementKind::StorageDead(discr_local) }); + from.terminator_mut().kind = first.terminator().kind.clone(); + true + } + + fn can_simplify( + &self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + param_env: ParamEnv<'tcx>, + bbs: &IndexVec>, + ) -> bool; + + fn new_stmts( + &self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + param_env: ParamEnv<'tcx>, + bbs: &IndexVec>, + discr_local: Local, + discr_ty: Ty<'tcx>, + ) -> Vec>; +} + +struct SimplifyToIf; + /// If a source block is found that switches between two blocks that are exactly /// the same modulo const bool assignments (e.g., one assigns true another false /// to the same place), merge a target block statements into the source block, @@ -37,144 +142,111 @@ pub struct MatchBranchSimplification; /// goto -> bb3; /// } /// ``` +impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { + fn can_simplify( + &self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + param_env: ParamEnv<'tcx>, + bbs: &IndexVec>, + ) -> bool { + if targets.iter().len() != 1 { + return false; + } + // We require that the possible target blocks all be distinct. + let (_, first) = targets.iter().next().unwrap(); + let second = targets.otherwise(); + if first == second { + return false; + } + // Check that destinations are identical, and if not, then don't optimize this block + if bbs[first].terminator().kind != bbs[second].terminator().kind { + return false; + } -impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { - fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - sess.mir_opt_level() >= 1 + // Check that blocks are assignments of consts to the same place or same statement, + // and match up 1-1, if not don't optimize this block. + let first_stmts = &bbs[first].statements; + let second_stmts = &bbs[second].statements; + if first_stmts.len() != second_stmts.len() { + return false; + } + for (f, s) in iter::zip(first_stmts, second_stmts) { + match (&f.kind, &s.kind) { + // If two statements are exactly the same, we can optimize. + (f_s, s_s) if f_s == s_s => {} + + // If two statements are const bool assignments to the same place, we can optimize. + ( + StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if lhs_f == lhs_s + && f_c.const_.ty().is_bool() + && s_c.const_.ty().is_bool() + && f_c.const_.try_eval_bool(tcx, param_env).is_some() + && s_c.const_.try_eval_bool(tcx, param_env).is_some() => {} + + // Otherwise we cannot optimize. Try another block. + _ => return false, + } + } + true } - fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - let def_id = body.source.def_id(); - let param_env = tcx.param_env_reveal_all_normalized(def_id); + fn new_stmts( + &self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + param_env: ParamEnv<'tcx>, + bbs: &IndexVec>, + discr_local: Local, + discr_ty: Ty<'tcx>, + ) -> Vec> { + let (val, first) = targets.iter().next().unwrap(); + let second = targets.otherwise(); + // We already checked that first and second are different blocks, + // and bb_idx has a different terminator from both of them. + let first = &bbs[first]; + let second = &bbs[second]; - let bbs = body.basic_blocks.as_mut(); - let mut should_cleanup = false; - 'outer: for bb_idx in bbs.indices() { - if !tcx.consider_optimizing(|| format!("MatchBranchSimplification {def_id:?} ")) { - continue; - } + let new_stmts = iter::zip(&first.statements, &second.statements).map(|(f, s)| { + match (&f.kind, &s.kind) { + (f_s, s_s) if f_s == s_s => (*f).clone(), - let (discr, val, first, second) = match bbs[bb_idx].terminator().kind { - TerminatorKind::SwitchInt { - discr: ref discr @ (Operand::Copy(_) | Operand::Move(_)), - ref targets, - .. - } if targets.iter().len() == 1 => { - let (value, target) = targets.iter().next().unwrap(); - // We require that this block and the two possible target blocks all be - // distinct. - if target == targets.otherwise() - || bb_idx == target - || bb_idx == targets.otherwise() - { - continue; - } - (discr, value, target, targets.otherwise()) - } - // Only optimize switch int statements - _ => continue, - }; - - // Check that destinations are identical, and if not, then don't optimize this block - if bbs[first].terminator().kind != bbs[second].terminator().kind { - continue; - } - - // Check that blocks are assignments of consts to the same place or same statement, - // and match up 1-1, if not don't optimize this block. - let first_stmts = &bbs[first].statements; - let scnd_stmts = &bbs[second].statements; - if first_stmts.len() != scnd_stmts.len() { - continue; - } - for (f, s) in iter::zip(first_stmts, scnd_stmts) { - match (&f.kind, &s.kind) { - // If two statements are exactly the same, we can optimize. - (f_s, s_s) if f_s == s_s => {} - - // If two statements are const bool assignments to the same place, we can optimize. - ( - StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), - StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), - ) if lhs_f == lhs_s - && f_c.const_.ty().is_bool() - && s_c.const_.ty().is_bool() - && f_c.const_.try_eval_bool(tcx, param_env).is_some() - && s_c.const_.try_eval_bool(tcx, param_env).is_some() => {} - - // Otherwise we cannot optimize. Try another block. - _ => continue 'outer, - } - } - // Take ownership of items now that we know we can optimize. - let discr = discr.clone(); - let discr_ty = discr.ty(&body.local_decls, tcx); - - // Introduce a temporary for the discriminant value. - let source_info = bbs[bb_idx].terminator().source_info; - let discr_local = body.local_decls.push(LocalDecl::new(discr_ty, source_info.span)); - - // We already checked that first and second are different blocks, - // and bb_idx has a different terminator from both of them. - let (from, first, second) = bbs.pick3_mut(bb_idx, first, second); - - let new_stmts = iter::zip(&first.statements, &second.statements).map(|(f, s)| { - match (&f.kind, &s.kind) { - (f_s, s_s) if f_s == s_s => (*f).clone(), - - ( - StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), - StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))), - ) => { - // From earlier loop we know that we are dealing with bool constants only: - let f_b = f_c.const_.try_eval_bool(tcx, param_env).unwrap(); - let s_b = s_c.const_.try_eval_bool(tcx, param_env).unwrap(); - if f_b == s_b { - // Same value in both blocks. Use statement as is. - (*f).clone() - } else { - // Different value between blocks. Make value conditional on switch condition. - let size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size; - let const_cmp = Operand::const_from_scalar( - tcx, - discr_ty, - rustc_const_eval::interpret::Scalar::from_uint(val, size), - rustc_span::DUMMY_SP, - ); - let op = if f_b { BinOp::Eq } else { BinOp::Ne }; - let rhs = Rvalue::BinaryOp( - op, - Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)), - ); - Statement { - source_info: f.source_info, - kind: StatementKind::Assign(Box::new((*lhs, rhs))), - } + ( + StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), + StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))), + ) => { + // From earlier loop we know that we are dealing with bool constants only: + let f_b = f_c.const_.try_eval_bool(tcx, param_env).unwrap(); + let s_b = s_c.const_.try_eval_bool(tcx, param_env).unwrap(); + if f_b == s_b { + // Same value in both blocks. Use statement as is. + (*f).clone() + } else { + // Different value between blocks. Make value conditional on switch condition. + let size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size; + let const_cmp = Operand::const_from_scalar( + tcx, + discr_ty, + rustc_const_eval::interpret::Scalar::from_uint(val, size), + rustc_span::DUMMY_SP, + ); + let op = if f_b { BinOp::Eq } else { BinOp::Ne }; + let rhs = Rvalue::BinaryOp( + op, + Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)), + ); + Statement { + source_info: f.source_info, + kind: StatementKind::Assign(Box::new((*lhs, rhs))), } } - - _ => unreachable!(), } - }); - from.statements - .push(Statement { source_info, kind: StatementKind::StorageLive(discr_local) }); - from.statements.push(Statement { - source_info, - kind: StatementKind::Assign(Box::new(( - Place::from(discr_local), - Rvalue::Use(discr), - ))), - }); - from.statements.extend(new_stmts); - from.statements - .push(Statement { source_info, kind: StatementKind::StorageDead(discr_local) }); - from.terminator_mut().kind = first.terminator().kind.clone(); - should_cleanup = true; - } - - if should_cleanup { - simplify_cfg(body); - } + _ => unreachable!(), + } + }); + new_stmts.collect() } } From 1f061f47e2903e90651f63368e3ff0aebac8e3e6 Mon Sep 17 00:00:00 2001 From: DianQK Date: Tue, 20 Feb 2024 22:07:09 +0800 Subject: [PATCH 3/7] Transforms match into an assignment statement --- compiler/rustc_middle/src/mir/terminator.rs | 6 + .../rustc_mir_transform/src/match_branches.rs | 230 +++++++++++++++++- tests/codegen/match-optimized.rs | 4 +- ...h_i128_u128.MatchBranchSimplification.diff | 61 ++--- ...atch_u8_i16.MatchBranchSimplification.diff | 41 ++-- ...atch_u8_u16.MatchBranchSimplification.diff | 51 ++-- tests/mir-opt/matches_reduce_branches.rs | 12 +- ...stive_match.MatchBranchSimplification.diff | 41 ++-- ...ve_match_i8.MatchBranchSimplification.diff | 41 ++-- 9 files changed, 370 insertions(+), 117 deletions(-) diff --git a/compiler/rustc_middle/src/mir/terminator.rs b/compiler/rustc_middle/src/mir/terminator.rs index f116347cc2b..58a27c1f9ef 100644 --- a/compiler/rustc_middle/src/mir/terminator.rs +++ b/compiler/rustc_middle/src/mir/terminator.rs @@ -85,6 +85,12 @@ impl SwitchTargets { self.values.push(value); self.targets.insert(self.targets.len() - 1, bb); } + + /// Returns true if all targets (including the fallback target) are distinct. + #[inline] + pub fn is_distinct(&self) -> bool { + self.targets.iter().collect::>().len() == self.targets.len() + } } pub struct SwitchTargetsIter<'a> { diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index be1158683ac..e766c1ae0f6 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -1,6 +1,6 @@ use rustc_index::IndexVec; use rustc_middle::mir::*; -use rustc_middle::ty::{ParamEnv, Ty, TyCtxt}; +use rustc_middle::ty::{ParamEnv, ScalarInt, Ty, TyCtxt}; use std::iter; use super::simplify::simplify_cfg; @@ -38,6 +38,11 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { should_cleanup = true; continue; } + if SimplifyToExp::default().simplify(tcx, &mut body.local_decls, bbs, bb_idx, param_env) + { + should_cleanup = true; + continue; + } } if should_cleanup { @@ -47,8 +52,10 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { } trait SimplifyMatch<'tcx> { + /// Simplifies a match statement, returning true if the simplification succeeds, false otherwise. + /// Generic code is written here, and we generally don't need a custom implementation. fn simplify( - &self, + &mut self, tcx: TyCtxt<'tcx>, local_decls: &mut IndexVec>, bbs: &mut IndexVec>, @@ -72,9 +79,7 @@ trait SimplifyMatch<'tcx> { let source_info = bbs[switch_bb_idx].terminator().source_info; let discr_local = local_decls.push(LocalDecl::new(discr_ty, source_info.span)); - // We already checked that first and second are different blocks, - // and bb_idx has a different terminator from both of them. - let new_stmts = self.new_stmts(tcx, targets, param_env, bbs, discr_local.clone(), discr_ty); + let new_stmts = self.new_stmts(tcx, targets, param_env, bbs, discr_local, discr_ty); let (_, first) = targets.iter().next().unwrap(); let (from, first) = bbs.pick2_mut(switch_bb_idx, first); from.statements @@ -90,8 +95,11 @@ trait SimplifyMatch<'tcx> { true } + /// Check that the BBs to be simplified satisfies all distinct and + /// that the terminator are the same. + /// There are also conditions for different ways of simplification. fn can_simplify( - &self, + &mut self, tcx: TyCtxt<'tcx>, targets: &SwitchTargets, param_env: ParamEnv<'tcx>, @@ -144,7 +152,7 @@ struct SimplifyToIf; /// ``` impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { fn can_simplify( - &self, + &mut self, tcx: TyCtxt<'tcx>, targets: &SwitchTargets, param_env: ParamEnv<'tcx>, @@ -250,3 +258,211 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { new_stmts.collect() } } + +#[derive(Default)] +struct SimplifyToExp { + transfrom_types: Vec, +} + +#[derive(Clone, Copy)] +enum CompareType<'tcx, 'a> { + Same(&'a StatementKind<'tcx>), + Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt), + Discr(&'a Place<'tcx>, Ty<'tcx>), +} + +enum TransfromType { + Same, + Eq, + Discr, +} + +impl From> for TransfromType { + fn from(compare_type: CompareType<'_, '_>) -> Self { + match compare_type { + CompareType::Same(_) => TransfromType::Same, + CompareType::Eq(_, _, _) => TransfromType::Eq, + CompareType::Discr(_, _) => TransfromType::Discr, + } + } +} + +/// If we find that the value of match is the same as the assignment, +/// merge a target block statements into the source block, +/// using cast to transform different integer types. +/// +/// For example: +/// +/// ```ignore (MIR) +/// bb0: { +/// switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1]; +/// } +/// +/// bb1: { +/// unreachable; +/// } +/// +/// bb2: { +/// _0 = const 1_i16; +/// goto -> bb5; +/// } +/// +/// bb3: { +/// _0 = const 2_i16; +/// goto -> bb5; +/// } +/// +/// bb4: { +/// _0 = const 3_i16; +/// goto -> bb5; +/// } +/// ``` +/// +/// into: +/// +/// ```ignore (MIR) +/// bb0: { +/// _0 = _3 as i16 (IntToInt); +/// goto -> bb5; +/// } +/// ``` +impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { + fn can_simplify( + &mut self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + param_env: ParamEnv<'tcx>, + bbs: &IndexVec>, + ) -> bool { + if targets.iter().len() < 2 || targets.iter().len() > 64 { + return false; + } + // We require that the possible target blocks all be distinct. + if !targets.is_distinct() { + return false; + } + if !bbs[targets.otherwise()].is_empty_unreachable() { + return false; + } + let mut target_iter = targets.iter(); + let (first_val, first_target) = target_iter.next().unwrap(); + let first_terminator_kind = &bbs[first_target].terminator().kind; + // Check that destinations are identical, and if not, then don't optimize this block + if !targets + .iter() + .all(|(_, other_target)| first_terminator_kind == &bbs[other_target].terminator().kind) + { + return false; + } + + let first_stmts = &bbs[first_target].statements; + let (second_val, second_target) = target_iter.next().unwrap(); + let second_stmts = &bbs[second_target].statements; + if first_stmts.len() != second_stmts.len() { + return false; + } + + let mut compare_types = Vec::new(); + for (f, s) in iter::zip(first_stmts, second_stmts) { + let compare_type = match (&f.kind, &s.kind) { + // If two statements are exactly the same, we can optimize. + (f_s, s_s) if f_s == s_s => CompareType::Same(f_s), + + // If two statements are assignments with the match values to the same place, we can optimize. + ( + StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if lhs_f == lhs_s + && f_c.const_.ty() == s_c.const_.ty() + && f_c.const_.ty().is_integral() => + { + match ( + f_c.const_.try_eval_scalar_int(tcx, param_env), + s_c.const_.try_eval_scalar_int(tcx, param_env), + ) { + (Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f), + (Some(f), Some(s)) + if Some(f) == ScalarInt::try_from_uint(first_val, f.size()) + && Some(s) == ScalarInt::try_from_uint(second_val, s.size()) => + { + CompareType::Discr(lhs_f, f_c.const_.ty()) + } + _ => return false, + } + } + + // Otherwise we cannot optimize. Try another block. + _ => return false, + }; + compare_types.push(compare_type); + } + + // All remaining BBs need to fulfill the same pattern as the two BBs from the previous step. + for (other_val, other_target) in target_iter { + let other_stmts = &bbs[other_target].statements; + if compare_types.len() != other_stmts.len() { + return false; + } + for (f, s) in iter::zip(&compare_types, other_stmts) { + match (*f, &s.kind) { + (CompareType::Same(f_s), s_s) if f_s == s_s => {} + ( + CompareType::Eq(lhs_f, f_ty, val), + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if lhs_f == lhs_s + && s_c.const_.ty() == f_ty + && s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {} + ( + CompareType::Discr(lhs_f, f_ty), + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => { + let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else { + return false; + }; + if Some(f) != ScalarInt::try_from_uint(other_val, f.size()) { + return false; + } + } + _ => return false, + } + } + } + self.transfrom_types = compare_types.into_iter().map(|c| c.into()).collect(); + true + } + + fn new_stmts( + &self, + _tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + _param_env: ParamEnv<'tcx>, + bbs: &IndexVec>, + discr_local: Local, + discr_ty: Ty<'tcx>, + ) -> Vec> { + let (_, first) = targets.iter().next().unwrap(); + let first = &bbs[first]; + + let new_stmts = + iter::zip(&self.transfrom_types, &first.statements).map(|(t, s)| match (t, &s.kind) { + (TransfromType::Same, _) | (TransfromType::Eq, _) => (*s).clone(), + ( + TransfromType::Discr, + StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), + ) => { + let operand = Operand::Copy(Place::from(discr_local)); + let r_val = if f_c.const_.ty() == discr_ty { + Rvalue::Use(operand) + } else { + Rvalue::Cast(CastKind::IntToInt, operand, f_c.const_.ty()) + }; + Statement { + source_info: s.source_info, + kind: StatementKind::Assign(Box::new((*lhs, r_val))), + } + } + _ => unreachable!(), + }); + new_stmts.collect() + } +} diff --git a/tests/codegen/match-optimized.rs b/tests/codegen/match-optimized.rs index 09907edf8f2..5cecafb9f29 100644 --- a/tests/codegen/match-optimized.rs +++ b/tests/codegen/match-optimized.rs @@ -26,12 +26,12 @@ pub fn exhaustive_match(e: E) -> u8 { // CHECK-NEXT: store i8 1, ptr %_0, align 1 // CHECK-NEXT: br label %[[EXIT]] // CHECK: [[C]]: -// CHECK-NEXT: store i8 2, ptr %_0, align 1 +// CHECK-NEXT: store i8 3, ptr %_0, align 1 // CHECK-NEXT: br label %[[EXIT]] match e { E::A => 0, E::B => 1, - E::C => 2, + E::C => 3, } } diff --git a/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff index 1f20349fdec..31ce51dc6de 100644 --- a/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff @@ -5,37 +5,42 @@ debug i => _1; let mut _0: u128; let mut _2: i128; ++ let mut _3: i128; bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [1: bb3, 2: bb4, 3: bb5, 340282366920938463463374607431768211455: bb2, otherwise: bb1]; - } - - bb1: { - unreachable; - } - - bb2: { - _0 = const core::num::::MAX; - goto -> bb6; - } - - bb3: { - _0 = const 1_u128; - goto -> bb6; - } - - bb4: { - _0 = const 2_u128; - goto -> bb6; - } - - bb5: { - _0 = const 3_u128; - goto -> bb6; - } - - bb6: { +- switchInt(move _2) -> [1: bb3, 2: bb4, 3: bb5, 340282366920938463463374607431768211455: bb2, otherwise: bb1]; +- } +- +- bb1: { +- unreachable; +- } +- +- bb2: { +- _0 = const core::num::::MAX; +- goto -> bb6; +- } +- +- bb3: { +- _0 = const 1_u128; +- goto -> bb6; +- } +- +- bb4: { +- _0 = const 2_u128; +- goto -> bb6; +- } +- +- bb5: { +- _0 = const 3_u128; +- goto -> bb6; +- } +- +- bb6: { ++ StorageLive(_3); ++ _3 = move _2; ++ _0 = _3 as u128 (IntToInt); ++ StorageDead(_3); return; } } diff --git a/tests/mir-opt/matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff index 72ad60956ab..9ee01a87a91 100644 --- a/tests/mir-opt/matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff @@ -5,27 +5,32 @@ debug i => _1; let mut _0: i16; let mut _2: u8; ++ let mut _3: u8; bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [1: bb3, 2: bb2, otherwise: bb1]; - } - - bb1: { - unreachable; - } - - bb2: { - _0 = const 2_i16; - goto -> bb4; - } - - bb3: { - _0 = const 1_i16; - goto -> bb4; - } - - bb4: { +- switchInt(move _2) -> [1: bb3, 2: bb2, otherwise: bb1]; +- } +- +- bb1: { +- unreachable; +- } +- +- bb2: { +- _0 = const 2_i16; +- goto -> bb4; +- } +- +- bb3: { +- _0 = const 1_i16; +- goto -> bb4; +- } +- +- bb4: { ++ StorageLive(_3); ++ _3 = move _2; ++ _0 = _3 as i16 (IntToInt); ++ StorageDead(_3); return; } } diff --git a/tests/mir-opt/matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff index 043fdb197a3..aa9fcc60a3e 100644 --- a/tests/mir-opt/matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff @@ -5,32 +5,37 @@ debug i => _1; let mut _0: u16; let mut _2: u8; ++ let mut _3: u8; bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [1: bb3, 2: bb4, 5: bb2, otherwise: bb1]; - } - - bb1: { - unreachable; - } - - bb2: { - _0 = const 5_u16; - goto -> bb5; - } - - bb3: { - _0 = const 1_u16; - goto -> bb5; - } - - bb4: { - _0 = const 2_u16; - goto -> bb5; - } - - bb5: { +- switchInt(move _2) -> [1: bb3, 2: bb4, 5: bb2, otherwise: bb1]; +- } +- +- bb1: { +- unreachable; +- } +- +- bb2: { +- _0 = const 5_u16; +- goto -> bb5; +- } +- +- bb3: { +- _0 = const 1_u16; +- goto -> bb5; +- } +- +- bb4: { +- _0 = const 2_u16; +- goto -> bb5; +- } +- +- bb5: { ++ StorageLive(_3); ++ _3 = move _2; ++ _0 = _3 as u16 (IntToInt); ++ StorageDead(_3); return; } } diff --git a/tests/mir-opt/matches_reduce_branches.rs b/tests/mir-opt/matches_reduce_branches.rs index 2e7b7d4e600..d51dd7c5873 100644 --- a/tests/mir-opt/matches_reduce_branches.rs +++ b/tests/mir-opt/matches_reduce_branches.rs @@ -75,7 +75,9 @@ enum EnumAu8 { // EMIT_MIR matches_reduce_branches.match_u8_i16.MatchBranchSimplification.diff fn match_u8_i16(i: EnumAu8) -> i16 { // CHECK-LABEL: fn match_u8_i16( - // CHECK: switchInt + // CHECK-NOT: switchInt + // CHECK: _0 = _3 as i16 (IntToInt); + // CHECH: return match i { EnumAu8::A => 1, EnumAu8::B => 2, @@ -144,7 +146,9 @@ enum EnumBu8 { // EMIT_MIR matches_reduce_branches.match_u8_u16.MatchBranchSimplification.diff fn match_u8_u16(i: EnumBu8) -> u16 { // CHECK-LABEL: fn match_u8_u16( - // CHECK: switchInt + // CHECK-NOT: switchInt + // CHECK: _0 = _3 as u16 (IntToInt); + // CHECH: return match i { EnumBu8::A => 1, EnumBu8::B => 2, @@ -248,7 +252,9 @@ enum EnumAi128 { // EMIT_MIR matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff fn match_i128_u128(i: EnumAi128) -> u128 { // CHECK-LABEL: fn match_i128_u128( - // CHECK: switchInt + // CHECK-NOT: switchInt + // CHECK: _0 = _3 as u128 (IntToInt); + // CHECH: return match i { EnumAi128::A => 1, EnumAi128::B => 2, diff --git a/tests/mir-opt/matches_u8.exhaustive_match.MatchBranchSimplification.diff b/tests/mir-opt/matches_u8.exhaustive_match.MatchBranchSimplification.diff index 157f9c98353..11a18f58e3a 100644 --- a/tests/mir-opt/matches_u8.exhaustive_match.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_u8.exhaustive_match.MatchBranchSimplification.diff @@ -5,27 +5,32 @@ debug e => _1; let mut _0: u8; let mut _2: isize; ++ let mut _3: isize; bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1]; - } - - bb1: { - unreachable; - } - - bb2: { - _0 = const 1_u8; - goto -> bb4; - } - - bb3: { - _0 = const 0_u8; - goto -> bb4; - } - - bb4: { +- switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1]; +- } +- +- bb1: { +- unreachable; +- } +- +- bb2: { +- _0 = const 1_u8; +- goto -> bb4; +- } +- +- bb3: { +- _0 = const 0_u8; +- goto -> bb4; +- } +- +- bb4: { ++ StorageLive(_3); ++ _3 = move _2; ++ _0 = _3 as u8 (IntToInt); ++ StorageDead(_3); return; } } diff --git a/tests/mir-opt/matches_u8.exhaustive_match_i8.MatchBranchSimplification.diff b/tests/mir-opt/matches_u8.exhaustive_match_i8.MatchBranchSimplification.diff index 19083771fd9..809badc41ba 100644 --- a/tests/mir-opt/matches_u8.exhaustive_match_i8.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_u8.exhaustive_match_i8.MatchBranchSimplification.diff @@ -5,27 +5,32 @@ debug e => _1; let mut _0: i8; let mut _2: isize; ++ let mut _3: isize; bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1]; - } - - bb1: { - unreachable; - } - - bb2: { - _0 = const 1_i8; - goto -> bb4; - } - - bb3: { - _0 = const 0_i8; - goto -> bb4; - } - - bb4: { +- switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1]; +- } +- +- bb1: { +- unreachable; +- } +- +- bb2: { +- _0 = const 1_i8; +- goto -> bb4; +- } +- +- bb3: { +- _0 = const 0_i8; +- goto -> bb4; +- } +- +- bb4: { ++ StorageLive(_3); ++ _3 = move _2; ++ _0 = _3 as i8 (IntToInt); ++ StorageDead(_3); return; } } From e752af765ea04ba663d82524cfdcc2b7b6cb58aa Mon Sep 17 00:00:00 2001 From: DianQK Date: Tue, 20 Feb 2024 21:55:46 +0800 Subject: [PATCH 4/7] Transforms a match containing negative numbers into an assignment statement as well --- .../rustc_mir_transform/src/match_branches.rs | 49 ++++++++++++++---- ...atch_i16_i8.MatchBranchSimplification.diff | 51 ++++++++++--------- ...atch_i8_i16.MatchBranchSimplification.diff | 51 ++++++++++--------- tests/mir-opt/matches_reduce_branches.rs | 8 ++- 4 files changed, 100 insertions(+), 59 deletions(-) diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index e766c1ae0f6..a444df34048 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -1,6 +1,7 @@ use rustc_index::IndexVec; use rustc_middle::mir::*; use rustc_middle::ty::{ParamEnv, ScalarInt, Ty, TyCtxt}; +use rustc_target::abi::Size; use std::iter; use super::simplify::simplify_cfg; @@ -67,13 +68,13 @@ trait SimplifyMatch<'tcx> { _ => unreachable!(), }; - if !self.can_simplify(tcx, targets, param_env, bbs) { + let discr_ty = discr.ty(local_decls, tcx); + if !self.can_simplify(tcx, targets, param_env, bbs, discr_ty) { return false; } // Take ownership of items now that we know we can optimize. let discr = discr.clone(); - let discr_ty = discr.ty(local_decls, tcx); // Introduce a temporary for the discriminant value. let source_info = bbs[switch_bb_idx].terminator().source_info; @@ -104,6 +105,7 @@ trait SimplifyMatch<'tcx> { targets: &SwitchTargets, param_env: ParamEnv<'tcx>, bbs: &IndexVec>, + discr_ty: Ty<'tcx>, ) -> bool; fn new_stmts( @@ -157,6 +159,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { targets: &SwitchTargets, param_env: ParamEnv<'tcx>, bbs: &IndexVec>, + _discr_ty: Ty<'tcx>, ) -> bool { if targets.iter().len() != 1 { return false; @@ -268,7 +271,7 @@ struct SimplifyToExp { enum CompareType<'tcx, 'a> { Same(&'a StatementKind<'tcx>), Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt), - Discr(&'a Place<'tcx>, Ty<'tcx>), + Discr(&'a Place<'tcx>, Ty<'tcx>, bool), } enum TransfromType { @@ -282,7 +285,7 @@ impl From> for TransfromType { match compare_type { CompareType::Same(_) => TransfromType::Same, CompareType::Eq(_, _, _) => TransfromType::Eq, - CompareType::Discr(_, _) => TransfromType::Discr, + CompareType::Discr(_, _, _) => TransfromType::Discr, } } } @@ -333,6 +336,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { targets: &SwitchTargets, param_env: ParamEnv<'tcx>, bbs: &IndexVec>, + discr_ty: Ty<'tcx>, ) -> bool { if targets.iter().len() < 2 || targets.iter().len() > 64 { return false; @@ -355,6 +359,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { return false; } + let discr_size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size; let first_stmts = &bbs[first_target].statements; let (second_val, second_target) = target_iter.next().unwrap(); let second_stmts = &bbs[second_target].statements; @@ -362,6 +367,11 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { return false; } + fn int_equal(l: ScalarInt, r: impl Into, size: Size) -> bool { + l.try_to_int(l.size()).unwrap() + == ScalarInt::try_from_uint(r, size).unwrap().try_to_int(size).unwrap() + } + let mut compare_types = Vec::new(); for (f, s) in iter::zip(first_stmts, second_stmts) { let compare_type = match (&f.kind, &s.kind) { @@ -382,12 +392,22 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { ) { (Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f), (Some(f), Some(s)) - if Some(f) == ScalarInt::try_from_uint(first_val, f.size()) - && Some(s) == ScalarInt::try_from_uint(second_val, s.size()) => + if ((f_c.const_.ty().is_signed() || discr_ty.is_signed()) + && int_equal(f, first_val, discr_size) + && int_equal(s, second_val, discr_size)) + || (Some(f) == ScalarInt::try_from_uint(first_val, f.size()) + && Some(s) + == ScalarInt::try_from_uint(second_val, s.size())) => { - CompareType::Discr(lhs_f, f_c.const_.ty()) + CompareType::Discr( + lhs_f, + f_c.const_.ty(), + f_c.const_.ty().is_signed() || discr_ty.is_signed(), + ) + } + _ => { + return false; } - _ => return false, } } @@ -413,15 +433,22 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { && s_c.const_.ty() == f_ty && s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {} ( - CompareType::Discr(lhs_f, f_ty), + CompareType::Discr(lhs_f, f_ty, is_signed), StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), ) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => { let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else { return false; }; - if Some(f) != ScalarInt::try_from_uint(other_val, f.size()) { - return false; + if is_signed + && s_c.const_.ty().is_signed() + && int_equal(f, other_val, discr_size) + { + continue; } + if Some(f) == ScalarInt::try_from_uint(other_val, f.size()) { + continue; + } + return false; } _ => return false, } diff --git a/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff index 4b435310916..e1b537b1b71 100644 --- a/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff @@ -5,32 +5,37 @@ debug i => _1; let mut _0: i8; let mut _2: i16; ++ let mut _3: i16; bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb2, otherwise: bb1]; - } - - bb1: { - unreachable; - } - - bb2: { - _0 = const -3_i8; - goto -> bb5; - } - - bb3: { - _0 = const -1_i8; - goto -> bb5; - } - - bb4: { - _0 = const 2_i8; - goto -> bb5; - } - - bb5: { +- switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb2, otherwise: bb1]; +- } +- +- bb1: { +- unreachable; +- } +- +- bb2: { +- _0 = const -3_i8; +- goto -> bb5; +- } +- +- bb3: { +- _0 = const -1_i8; +- goto -> bb5; +- } +- +- bb4: { +- _0 = const 2_i8; +- goto -> bb5; +- } +- +- bb5: { ++ StorageLive(_3); ++ _3 = move _2; ++ _0 = _3 as i8 (IntToInt); ++ StorageDead(_3); return; } } diff --git a/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff index 8a390736add..cabc5a44cd8 100644 --- a/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff @@ -5,32 +5,37 @@ debug i => _1; let mut _0: i16; let mut _2: i8; ++ let mut _3: i8; bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb2, otherwise: bb1]; - } - - bb1: { - unreachable; - } - - bb2: { - _0 = const -3_i16; - goto -> bb5; - } - - bb3: { - _0 = const -1_i16; - goto -> bb5; - } - - bb4: { - _0 = const 2_i16; - goto -> bb5; - } - - bb5: { +- switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb2, otherwise: bb1]; +- } +- +- bb1: { +- unreachable; +- } +- +- bb2: { +- _0 = const -3_i16; +- goto -> bb5; +- } +- +- bb3: { +- _0 = const -1_i16; +- goto -> bb5; +- } +- +- bb4: { +- _0 = const 2_i16; +- goto -> bb5; +- } +- +- bb5: { ++ StorageLive(_3); ++ _3 = move _2; ++ _0 = _3 as i16 (IntToInt); ++ StorageDead(_3); return; } } diff --git a/tests/mir-opt/matches_reduce_branches.rs b/tests/mir-opt/matches_reduce_branches.rs index d51dd7c5873..ca3e5f747d1 100644 --- a/tests/mir-opt/matches_reduce_branches.rs +++ b/tests/mir-opt/matches_reduce_branches.rs @@ -204,7 +204,9 @@ enum EnumAi8 { // EMIT_MIR matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff fn match_i8_i16(i: EnumAi8) -> i16 { // CHECK-LABEL: fn match_i8_i16( - // CHECK: switchInt + // CHECK-NOT: switchInt + // CHECK: _0 = _3 as i16 (IntToInt); + // CHECH: return match i { EnumAi8::A => -1, EnumAi8::B => 2, @@ -233,7 +235,9 @@ enum EnumAi16 { // EMIT_MIR matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff fn match_i16_i8(i: EnumAi16) -> i8 { // CHECK-LABEL: fn match_i16_i8( - // CHECK: switchInt + // CHECK-NOT: switchInt + // CHECK: _0 = _3 as i8 (IntToInt); + // CHECH: return match i { EnumAi16::A => -1, EnumAi16::B => 2, From 254289a16e318bbbbecadb05c14abbc07f16d2b4 Mon Sep 17 00:00:00 2001 From: DianQK Date: Wed, 21 Feb 2024 18:51:51 +0800 Subject: [PATCH 5/7] Updating the MIR with MirPatch --- .../rustc_mir_transform/src/match_branches.rs | 104 +++++++++--------- 1 file changed, 54 insertions(+), 50 deletions(-) diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index a444df34048..e9203043769 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -1,4 +1,5 @@ -use rustc_index::IndexVec; +use rustc_index::IndexSlice; +use rustc_middle::mir::patch::MirPatch; use rustc_middle::mir::*; use rustc_middle::ty::{ParamEnv, ScalarInt, Ty, TyCtxt}; use rustc_target::abi::Size; @@ -17,9 +18,10 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { let def_id = body.source.def_id(); let param_env = tcx.param_env_reveal_all_normalized(def_id); - let bbs = body.basic_blocks.as_mut(); let mut should_cleanup = false; - for bb_idx in bbs.indices() { + for i in 0..body.basic_blocks.len() { + let bbs = &*body.basic_blocks; + let bb_idx = BasicBlock::from_usize(i); if !tcx.consider_optimizing(|| format!("MatchBranchSimplification {def_id:?} ")) { continue; } @@ -35,12 +37,11 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { _ => continue, }; - if SimplifyToIf.simplify(tcx, &mut body.local_decls, bbs, bb_idx, param_env) { + if SimplifyToIf.simplify(tcx, body, bb_idx, param_env) { should_cleanup = true; continue; } - if SimplifyToExp::default().simplify(tcx, &mut body.local_decls, bbs, bb_idx, param_env) - { + if SimplifyToExp::default().simplify(tcx, body, bb_idx, param_env) { should_cleanup = true; continue; } @@ -58,41 +59,39 @@ trait SimplifyMatch<'tcx> { fn simplify( &mut self, tcx: TyCtxt<'tcx>, - local_decls: &mut IndexVec>, - bbs: &mut IndexVec>, + body: &mut Body<'tcx>, switch_bb_idx: BasicBlock, param_env: ParamEnv<'tcx>, ) -> bool { + let bbs = &body.basic_blocks; let (discr, targets) = match bbs[switch_bb_idx].terminator().kind { TerminatorKind::SwitchInt { ref discr, ref targets, .. } => (discr, targets), _ => unreachable!(), }; - let discr_ty = discr.ty(local_decls, tcx); + let discr_ty = discr.ty(body.local_decls(), tcx); if !self.can_simplify(tcx, targets, param_env, bbs, discr_ty) { return false; } + let mut patch = MirPatch::new(body); + // Take ownership of items now that we know we can optimize. let discr = discr.clone(); // Introduce a temporary for the discriminant value. let source_info = bbs[switch_bb_idx].terminator().source_info; - let discr_local = local_decls.push(LocalDecl::new(discr_ty, source_info.span)); + let discr_local = patch.new_temp(discr_ty, source_info.span); - let new_stmts = self.new_stmts(tcx, targets, param_env, bbs, discr_local, discr_ty); let (_, first) = targets.iter().next().unwrap(); - let (from, first) = bbs.pick2_mut(switch_bb_idx, first); - from.statements - .push(Statement { source_info, kind: StatementKind::StorageLive(discr_local) }); - from.statements.push(Statement { - source_info, - kind: StatementKind::Assign(Box::new((Place::from(discr_local), Rvalue::Use(discr)))), - }); - from.statements.extend(new_stmts); - from.statements - .push(Statement { source_info, kind: StatementKind::StorageDead(discr_local) }); - from.terminator_mut().kind = first.terminator().kind.clone(); + let statement_index = bbs[switch_bb_idx].statements.len(); + let parent_end = Location { block: switch_bb_idx, statement_index }; + patch.add_statement(parent_end, StatementKind::StorageLive(discr_local)); + patch.add_assign(parent_end, Place::from(discr_local), Rvalue::Use(discr)); + self.new_stmts(tcx, targets, param_env, &mut patch, parent_end, bbs, discr_local, discr_ty); + patch.add_statement(parent_end, StatementKind::StorageDead(discr_local)); + patch.patch_terminator(switch_bb_idx, bbs[first].terminator().kind.clone()); + patch.apply(body); true } @@ -104,7 +103,7 @@ trait SimplifyMatch<'tcx> { tcx: TyCtxt<'tcx>, targets: &SwitchTargets, param_env: ParamEnv<'tcx>, - bbs: &IndexVec>, + bbs: &IndexSlice>, discr_ty: Ty<'tcx>, ) -> bool; @@ -113,10 +112,12 @@ trait SimplifyMatch<'tcx> { tcx: TyCtxt<'tcx>, targets: &SwitchTargets, param_env: ParamEnv<'tcx>, - bbs: &IndexVec>, + patch: &mut MirPatch<'tcx>, + parent_end: Location, + bbs: &IndexSlice>, discr_local: Local, discr_ty: Ty<'tcx>, - ) -> Vec>; + ); } struct SimplifyToIf; @@ -158,7 +159,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { tcx: TyCtxt<'tcx>, targets: &SwitchTargets, param_env: ParamEnv<'tcx>, - bbs: &IndexVec>, + bbs: &IndexSlice>, _discr_ty: Ty<'tcx>, ) -> bool { if targets.iter().len() != 1 { @@ -209,20 +210,23 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { tcx: TyCtxt<'tcx>, targets: &SwitchTargets, param_env: ParamEnv<'tcx>, - bbs: &IndexVec>, + patch: &mut MirPatch<'tcx>, + parent_end: Location, + bbs: &IndexSlice>, discr_local: Local, discr_ty: Ty<'tcx>, - ) -> Vec> { + ) { let (val, first) = targets.iter().next().unwrap(); let second = targets.otherwise(); // We already checked that first and second are different blocks, // and bb_idx has a different terminator from both of them. let first = &bbs[first]; let second = &bbs[second]; - - let new_stmts = iter::zip(&first.statements, &second.statements).map(|(f, s)| { + for (f, s) in iter::zip(&first.statements, &second.statements) { match (&f.kind, &s.kind) { - (f_s, s_s) if f_s == s_s => (*f).clone(), + (f_s, s_s) if f_s == s_s => { + patch.add_statement(parent_end, f.kind.clone()); + } ( StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), @@ -233,7 +237,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { let s_b = s_c.const_.try_eval_bool(tcx, param_env).unwrap(); if f_b == s_b { // Same value in both blocks. Use statement as is. - (*f).clone() + patch.add_statement(parent_end, f.kind.clone()); } else { // Different value between blocks. Make value conditional on switch condition. let size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size; @@ -248,17 +252,13 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { op, Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)), ); - Statement { - source_info: f.source_info, - kind: StatementKind::Assign(Box::new((*lhs, rhs))), - } + patch.add_assign(parent_end, *lhs, rhs); } } _ => unreachable!(), } - }); - new_stmts.collect() + } } } @@ -335,7 +335,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { tcx: TyCtxt<'tcx>, targets: &SwitchTargets, param_env: ParamEnv<'tcx>, - bbs: &IndexVec>, + bbs: &IndexSlice>, discr_ty: Ty<'tcx>, ) -> bool { if targets.iter().len() < 2 || targets.iter().len() > 64 { @@ -372,6 +372,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { == ScalarInt::try_from_uint(r, size).unwrap().try_to_int(size).unwrap() } + // We first compare the two branches, and then the other branches need to fulfill the same conditions. let mut compare_types = Vec::new(); for (f, s) in iter::zip(first_stmts, second_stmts) { let compare_type = match (&f.kind, &s.kind) { @@ -391,6 +392,8 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { s_c.const_.try_eval_scalar_int(tcx, param_env), ) { (Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f), + // Enum variants can also be simplified to an assignment statement if their values are equal. + // We need to consider both unsigned and signed scenarios here. (Some(f), Some(s)) if ((f_c.const_.ty().is_signed() || discr_ty.is_signed()) && int_equal(f, first_val, discr_size) @@ -463,16 +466,20 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { _tcx: TyCtxt<'tcx>, targets: &SwitchTargets, _param_env: ParamEnv<'tcx>, - bbs: &IndexVec>, + patch: &mut MirPatch<'tcx>, + parent_end: Location, + bbs: &IndexSlice>, discr_local: Local, discr_ty: Ty<'tcx>, - ) -> Vec> { + ) { let (_, first) = targets.iter().next().unwrap(); let first = &bbs[first]; - let new_stmts = - iter::zip(&self.transfrom_types, &first.statements).map(|(t, s)| match (t, &s.kind) { - (TransfromType::Same, _) | (TransfromType::Eq, _) => (*s).clone(), + for (t, s) in iter::zip(&self.transfrom_types, &first.statements) { + match (t, &s.kind) { + (TransfromType::Same, _) | (TransfromType::Eq, _) => { + patch.add_statement(parent_end, s.kind.clone()); + } ( TransfromType::Discr, StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), @@ -483,13 +490,10 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { } else { Rvalue::Cast(CastKind::IntToInt, operand, f_c.const_.ty()) }; - Statement { - source_info: s.source_info, - kind: StatementKind::Assign(Box::new((*lhs, r_val))), - } + patch.add_assign(parent_end, *lhs, r_val); } _ => unreachable!(), - }); - new_stmts.collect() + } + } } } From 032bb742ab537ac41ce5428b9c344a8c348bd2c9 Mon Sep 17 00:00:00 2001 From: DianQK Date: Sun, 10 Mar 2024 22:07:41 +0800 Subject: [PATCH 6/7] Add comments for `CompareType` --- .../rustc_mir_transform/src/match_branches.rs | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index e9203043769..2b6589100c8 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -269,9 +269,12 @@ struct SimplifyToExp { #[derive(Clone, Copy)] enum CompareType<'tcx, 'a> { + /// Identical statements. Same(&'a StatementKind<'tcx>), + /// Assignment statements have the same value. Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt), - Discr(&'a Place<'tcx>, Ty<'tcx>, bool), + /// Enum variant comparison type. + Discr { place: &'a Place<'tcx>, ty: Ty<'tcx>, is_signed: bool }, } enum TransfromType { @@ -285,7 +288,7 @@ impl From> for TransfromType { match compare_type { CompareType::Same(_) => TransfromType::Same, CompareType::Eq(_, _, _) => TransfromType::Eq, - CompareType::Discr(_, _, _) => TransfromType::Discr, + CompareType::Discr { .. } => TransfromType::Discr, } } } @@ -402,11 +405,11 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { && Some(s) == ScalarInt::try_from_uint(second_val, s.size())) => { - CompareType::Discr( - lhs_f, - f_c.const_.ty(), - f_c.const_.ty().is_signed() || discr_ty.is_signed(), - ) + CompareType::Discr { + place: lhs_f, + ty: f_c.const_.ty(), + is_signed: f_c.const_.ty().is_signed() || discr_ty.is_signed(), + } } _ => { return false; @@ -436,7 +439,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { && s_c.const_.ty() == f_ty && s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {} ( - CompareType::Discr(lhs_f, f_ty, is_signed), + CompareType::Discr { place: lhs_f, ty: f_ty, is_signed }, StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), ) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => { let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else { From f4407370dbb67115bc4acc97dbbdceba0f6d17f3 Mon Sep 17 00:00:00 2001 From: DianQK Date: Sun, 10 Mar 2024 22:23:55 +0800 Subject: [PATCH 7/7] Change the return type of `can_simplify` to `Option<()>` --- .../rustc_mir_transform/src/match_branches.rs | 54 +++++++++---------- 1 file changed, 26 insertions(+), 28 deletions(-) diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index 2b6589100c8..4d9a198eeb2 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -37,11 +37,11 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { _ => continue, }; - if SimplifyToIf.simplify(tcx, body, bb_idx, param_env) { + if SimplifyToIf.simplify(tcx, body, bb_idx, param_env).is_some() { should_cleanup = true; continue; } - if SimplifyToExp::default().simplify(tcx, body, bb_idx, param_env) { + if SimplifyToExp::default().simplify(tcx, body, bb_idx, param_env).is_some() { should_cleanup = true; continue; } @@ -62,7 +62,7 @@ trait SimplifyMatch<'tcx> { body: &mut Body<'tcx>, switch_bb_idx: BasicBlock, param_env: ParamEnv<'tcx>, - ) -> bool { + ) -> Option<()> { let bbs = &body.basic_blocks; let (discr, targets) = match bbs[switch_bb_idx].terminator().kind { TerminatorKind::SwitchInt { ref discr, ref targets, .. } => (discr, targets), @@ -70,9 +70,7 @@ trait SimplifyMatch<'tcx> { }; let discr_ty = discr.ty(body.local_decls(), tcx); - if !self.can_simplify(tcx, targets, param_env, bbs, discr_ty) { - return false; - } + self.can_simplify(tcx, targets, param_env, bbs, discr_ty)?; let mut patch = MirPatch::new(body); @@ -92,7 +90,7 @@ trait SimplifyMatch<'tcx> { patch.add_statement(parent_end, StatementKind::StorageDead(discr_local)); patch.patch_terminator(switch_bb_idx, bbs[first].terminator().kind.clone()); patch.apply(body); - true + Some(()) } /// Check that the BBs to be simplified satisfies all distinct and @@ -105,7 +103,7 @@ trait SimplifyMatch<'tcx> { param_env: ParamEnv<'tcx>, bbs: &IndexSlice>, discr_ty: Ty<'tcx>, - ) -> bool; + ) -> Option<()>; fn new_stmts( &self, @@ -161,19 +159,19 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { param_env: ParamEnv<'tcx>, bbs: &IndexSlice>, _discr_ty: Ty<'tcx>, - ) -> bool { + ) -> Option<()> { if targets.iter().len() != 1 { - return false; + return None; } // We require that the possible target blocks all be distinct. let (_, first) = targets.iter().next().unwrap(); let second = targets.otherwise(); if first == second { - return false; + return None; } // Check that destinations are identical, and if not, then don't optimize this block if bbs[first].terminator().kind != bbs[second].terminator().kind { - return false; + return None; } // Check that blocks are assignments of consts to the same place or same statement, @@ -181,7 +179,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { let first_stmts = &bbs[first].statements; let second_stmts = &bbs[second].statements; if first_stmts.len() != second_stmts.len() { - return false; + return None; } for (f, s) in iter::zip(first_stmts, second_stmts) { match (&f.kind, &s.kind) { @@ -199,10 +197,10 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { && s_c.const_.try_eval_bool(tcx, param_env).is_some() => {} // Otherwise we cannot optimize. Try another block. - _ => return false, + _ => return None, } } - true + Some(()) } fn new_stmts( @@ -340,16 +338,16 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { param_env: ParamEnv<'tcx>, bbs: &IndexSlice>, discr_ty: Ty<'tcx>, - ) -> bool { + ) -> Option<()> { if targets.iter().len() < 2 || targets.iter().len() > 64 { - return false; + return None; } // We require that the possible target blocks all be distinct. if !targets.is_distinct() { - return false; + return None; } if !bbs[targets.otherwise()].is_empty_unreachable() { - return false; + return None; } let mut target_iter = targets.iter(); let (first_val, first_target) = target_iter.next().unwrap(); @@ -359,7 +357,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { .iter() .all(|(_, other_target)| first_terminator_kind == &bbs[other_target].terminator().kind) { - return false; + return None; } let discr_size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size; @@ -367,7 +365,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { let (second_val, second_target) = target_iter.next().unwrap(); let second_stmts = &bbs[second_target].statements; if first_stmts.len() != second_stmts.len() { - return false; + return None; } fn int_equal(l: ScalarInt, r: impl Into, size: Size) -> bool { @@ -412,13 +410,13 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { } } _ => { - return false; + return None; } } } // Otherwise we cannot optimize. Try another block. - _ => return false, + _ => return None, }; compare_types.push(compare_type); } @@ -427,7 +425,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { for (other_val, other_target) in target_iter { let other_stmts = &bbs[other_target].statements; if compare_types.len() != other_stmts.len() { - return false; + return None; } for (f, s) in iter::zip(&compare_types, other_stmts) { match (*f, &s.kind) { @@ -443,7 +441,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), ) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => { let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else { - return false; + return None; }; if is_signed && s_c.const_.ty().is_signed() @@ -454,14 +452,14 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { if Some(f) == ScalarInt::try_from_uint(other_val, f.size()) { continue; } - return false; + return None; } - _ => return false, + _ => return None, } } } self.transfrom_types = compare_types.into_iter().map(|c| c.into()).collect(); - true + Some(()) } fn new_stmts(