diff --git a/compiler/rustc_middle/src/ty/fold.rs b/compiler/rustc_middle/src/ty/fold.rs index 5469aeb4c2c..794c0ddf97d 100644 --- a/compiler/rustc_middle/src/ty/fold.rs +++ b/compiler/rustc_middle/src/ty/fold.rs @@ -656,17 +656,17 @@ struct BoundVarReplacer<'a, 'tcx> { /// the ones we have visited. current_index: ty::DebruijnIndex, - fld_r: Option<&'a mut (dyn FnMut(ty::BoundRegion) -> ty::Region<'tcx> + 'a)>, - fld_t: Option<&'a mut (dyn FnMut(ty::BoundTy) -> Ty<'tcx> + 'a)>, - fld_c: Option<&'a mut (dyn FnMut(ty::BoundVar, Ty<'tcx>) -> ty::Const<'tcx> + 'a)>, + fld_r: &'a mut (dyn FnMut(ty::BoundRegion) -> ty::Region<'tcx> + 'a), + fld_t: &'a mut (dyn FnMut(ty::BoundTy) -> Ty<'tcx> + 'a), + fld_c: &'a mut (dyn FnMut(ty::BoundVar, Ty<'tcx>) -> ty::Const<'tcx> + 'a), } impl<'a, 'tcx> BoundVarReplacer<'a, 'tcx> { fn new( tcx: TyCtxt<'tcx>, - fld_r: Option<&'a mut (dyn FnMut(ty::BoundRegion) -> ty::Region<'tcx> + 'a)>, - fld_t: Option<&'a mut (dyn FnMut(ty::BoundTy) -> Ty<'tcx> + 'a)>, - fld_c: Option<&'a mut (dyn FnMut(ty::BoundVar, Ty<'tcx>) -> ty::Const<'tcx> + 'a)>, + fld_r: &'a mut (dyn FnMut(ty::BoundRegion) -> ty::Region<'tcx> + 'a), + fld_t: &'a mut (dyn FnMut(ty::BoundTy) -> Ty<'tcx> + 'a), + fld_c: &'a mut (dyn FnMut(ty::BoundVar, Ty<'tcx>) -> ty::Const<'tcx> + 'a), ) -> Self { BoundVarReplacer { tcx, current_index: ty::INNERMOST, fld_r, fld_t, fld_c } } @@ -690,55 +690,42 @@ impl<'a, 'tcx> TypeFolder<'tcx> for BoundVarReplacer<'a, 'tcx> { fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> { match *t.kind() { ty::Bound(debruijn, bound_ty) if debruijn == self.current_index => { - if let Some(fld_t) = self.fld_t.as_mut() { - let ty = fld_t(bound_ty); - return ty::fold::shift_vars(self.tcx, ty, self.current_index.as_u32()); - } + let ty = (self.fld_t)(bound_ty); + ty::fold::shift_vars(self.tcx, ty, self.current_index.as_u32()) } - _ if t.has_vars_bound_at_or_above(self.current_index) => { - return t.super_fold_with(self); - } - _ => {} + _ if t.has_vars_bound_at_or_above(self.current_index) => t.super_fold_with(self), + _ => t, } - t } fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> { match *r { ty::ReLateBound(debruijn, br) if debruijn == self.current_index => { - if let Some(fld_r) = self.fld_r.as_mut() { - let region = fld_r(br); - return if let ty::ReLateBound(debruijn1, br) = *region { - // If the callback returns a late-bound region, - // that region should always use the INNERMOST - // debruijn index. Then we adjust it to the - // correct depth. - assert_eq!(debruijn1, ty::INNERMOST); - self.tcx.mk_region(ty::ReLateBound(debruijn, br)) - } else { - region - }; + let region = (self.fld_r)(br); + if let ty::ReLateBound(debruijn1, br) = *region { + // If the callback returns a late-bound region, + // that region should always use the INNERMOST + // debruijn index. Then we adjust it to the + // correct depth. + assert_eq!(debruijn1, ty::INNERMOST); + self.tcx.mk_region(ty::ReLateBound(debruijn, br)) + } else { + region } } - _ => {} + _ => r, } - r } fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> { match ct.val() { ty::ConstKind::Bound(debruijn, bound_const) if debruijn == self.current_index => { - if let Some(fld_c) = self.fld_c.as_mut() { - let ct = fld_c(bound_const, ct.ty()); - return ty::fold::shift_vars(self.tcx, ct, self.current_index.as_u32()); - } + let ct = (self.fld_c)(bound_const, ct.ty()); + ty::fold::shift_vars(self.tcx, ct, self.current_index.as_u32()) } - _ if ct.has_vars_bound_at_or_above(self.current_index) => { - return ct.super_fold_with(self); - } - _ => {} + _ if ct.has_vars_bound_at_or_above(self.current_index) => ct.super_fold_with(self), + _ => ct, } - ct } } @@ -752,8 +739,10 @@ impl<'tcx> TyCtxt<'tcx> { /// returned at the end with each bound region and the free region /// that replaced it. /// - /// This method only replaces late bound regions and the result may still - /// contain escaping bound types. + /// # Panics + /// + /// This method only replaces late bound regions. Any types or + /// constants bound by `value` will cause an ICE. pub fn replace_late_bound_regions( self, value: Binder<'tcx, T>, @@ -766,11 +755,14 @@ impl<'tcx> TyCtxt<'tcx> { let mut region_map = BTreeMap::new(); let mut real_fld_r = |br: ty::BoundRegion| *region_map.entry(br).or_insert_with(|| fld_r(br)); + let mut fld_t = |b| bug!("unexpected bound ty in binder: {b:?}"); + let mut fld_c = |b, ty| bug!("unexpected bound ct in binder: {b:?} {ty}"); + let value = value.skip_binder(); let value = if !value.has_escaping_bound_vars() { value } else { - let mut replacer = BoundVarReplacer::new(self, Some(&mut real_fld_r), None, None); + let mut replacer = BoundVarReplacer::new(self, &mut real_fld_r, &mut fld_t, &mut fld_c); value.fold_with(&mut replacer) }; (value, region_map) @@ -795,15 +787,14 @@ impl<'tcx> TyCtxt<'tcx> { if !value.has_escaping_bound_vars() { value } else { - let mut replacer = - BoundVarReplacer::new(self, Some(&mut fld_r), Some(&mut fld_t), Some(&mut fld_c)); + let mut replacer = BoundVarReplacer::new(self, &mut fld_r, &mut fld_t, &mut fld_c); value.fold_with(&mut replacer) } } /// Replaces all types or regions bound by the given `Binder`. The `fld_r` - /// closure replaces bound regions while the `fld_t` closure replaces bound - /// types. + /// closure replaces bound regions, the `fld_t` closure replaces bound + /// types, and `fld_c` replaces bound constants. pub fn replace_bound_vars( self, value: Binder<'tcx, T>,