Remove a_is_expected from combine relations

This commit is contained in:
Michael Goulet 2024-02-26 20:42:09 +00:00
parent 61daee66a8
commit 04e22627f5
9 changed files with 75 additions and 135 deletions

View file

@ -120,7 +120,6 @@ impl<'me, 'bccx, 'tcx> NllTypeRelating<'me, 'bccx, 'tcx> {
fn relate_opaques(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, ()> {
let infcx = self.type_checker.infcx;
debug_assert!(!infcx.next_trait_solver());
let (a, b) = if self.a_is_expected() { (a, b) } else { (b, a) };
// `handle_opaque_type` cannot handle subtyping, so to support subtyping
// we instead eagerly generalize here. This is a bit of a mess but will go
// away once we're using the new solver.

View file

@ -49,7 +49,6 @@ pub struct At<'a, 'tcx> {
pub struct Trace<'a, 'tcx> {
at: At<'a, 'tcx>,
a_is_expected: bool,
trace: TypeTrace<'tcx>,
}
@ -105,23 +104,6 @@ pub trait ToTrace<'tcx>: Relate<'tcx> + Copy {
}
impl<'a, 'tcx> At<'a, 'tcx> {
/// Makes `a <: b`, where `a` may or may not be expected.
///
/// See [`At::trace_exp`] and [`Trace::sub`] for a version of
/// this method that only requires `T: Relate<'tcx>`
pub fn sub_exp<T>(
self,
define_opaque_types: DefineOpaqueTypes,
a_is_expected: bool,
a: T,
b: T,
) -> InferResult<'tcx, ()>
where
T: ToTrace<'tcx>,
{
self.trace_exp(a_is_expected, a, b).sub(define_opaque_types, a, b)
}
/// Makes `actual <: expected`. For example, if type-checking a
/// call like `foo(x)`, where `foo: fn(i32)`, you might have
/// `sup(i32, x)`, since the "expected" type is the type that
@ -138,7 +120,7 @@ impl<'a, 'tcx> At<'a, 'tcx> {
where
T: ToTrace<'tcx>,
{
self.sub_exp(define_opaque_types, false, actual, expected)
self.trace(expected, actual).sup(define_opaque_types, expected, actual)
}
/// Makes `expected <: actual`.
@ -154,24 +136,7 @@ impl<'a, 'tcx> At<'a, 'tcx> {
where
T: ToTrace<'tcx>,
{
self.sub_exp(define_opaque_types, true, expected, actual)
}
/// Makes `expected <: actual`.
///
/// See [`At::trace_exp`] and [`Trace::eq`] for a version of
/// this method that only requires `T: Relate<'tcx>`
pub fn eq_exp<T>(
self,
define_opaque_types: DefineOpaqueTypes,
a_is_expected: bool,
a: T,
b: T,
) -> InferResult<'tcx, ()>
where
T: ToTrace<'tcx>,
{
self.trace_exp(a_is_expected, a, b).eq(define_opaque_types, a, b)
self.trace(expected, actual).sub(define_opaque_types, expected, actual)
}
/// Makes `expected <: actual`.
@ -260,48 +225,50 @@ impl<'a, 'tcx> At<'a, 'tcx> {
where
T: ToTrace<'tcx>,
{
self.trace_exp(true, expected, actual)
}
/// Like `trace`, but the expected value is determined by the
/// boolean argument (if true, then the first argument `a` is the
/// "expected" value).
pub fn trace_exp<T>(self, a_is_expected: bool, a: T, b: T) -> Trace<'a, 'tcx>
where
T: ToTrace<'tcx>,
{
let trace = ToTrace::to_trace(self.cause, a_is_expected, a, b);
Trace { at: self, trace, a_is_expected }
let trace = ToTrace::to_trace(self.cause, true, expected, actual);
Trace { at: self, trace }
}
}
impl<'a, 'tcx> Trace<'a, 'tcx> {
/// Makes `a <: b` where `a` may or may not be expected (if
/// `a_is_expected` is true, then `a` is expected).
/// Makes `a <: b`.
#[instrument(skip(self), level = "debug")]
pub fn sub<T>(self, define_opaque_types: DefineOpaqueTypes, a: T, b: T) -> InferResult<'tcx, ()>
where
T: Relate<'tcx>,
{
let Trace { at, trace, a_is_expected } = self;
let Trace { at, trace } = self;
let mut fields = at.infcx.combine_fields(trace, at.param_env, define_opaque_types);
fields
.sub(a_is_expected)
.sub()
.relate(a, b)
.map(move |_| InferOk { value: (), obligations: fields.obligations })
}
/// Makes `a == b`; the expectation is set by the call to
/// `trace()`.
/// Makes `a :> b`.
#[instrument(skip(self), level = "debug")]
pub fn sup<T>(self, define_opaque_types: DefineOpaqueTypes, a: T, b: T) -> InferResult<'tcx, ()>
where
T: Relate<'tcx>,
{
let Trace { at, trace } = self;
let mut fields = at.infcx.combine_fields(trace, at.param_env, define_opaque_types);
fields
.sup()
.relate(a, b)
.map(move |_| InferOk { value: (), obligations: fields.obligations })
}
/// Makes `a == b`.
#[instrument(skip(self), level = "debug")]
pub fn eq<T>(self, define_opaque_types: DefineOpaqueTypes, a: T, b: T) -> InferResult<'tcx, ()>
where
T: Relate<'tcx>,
{
let Trace { at, trace, a_is_expected } = self;
let Trace { at, trace } = self;
let mut fields = at.infcx.combine_fields(trace, at.param_env, define_opaque_types);
fields
.equate(StructurallyRelateAliases::No, a_is_expected)
.equate(StructurallyRelateAliases::No)
.relate(a, b)
.map(move |_| InferOk { value: (), obligations: fields.obligations })
}
@ -313,11 +280,11 @@ impl<'a, 'tcx> Trace<'a, 'tcx> {
where
T: Relate<'tcx>,
{
let Trace { at, trace, a_is_expected } = self;
let Trace { at, trace } = self;
debug_assert!(at.infcx.next_trait_solver());
let mut fields = at.infcx.combine_fields(trace, at.param_env, DefineOpaqueTypes::No);
fields
.equate(StructurallyRelateAliases::Yes, a_is_expected)
.equate(StructurallyRelateAliases::Yes)
.relate(a, b)
.map(move |_| InferOk { value: (), obligations: fields.obligations })
}
@ -327,10 +294,10 @@ impl<'a, 'tcx> Trace<'a, 'tcx> {
where
T: Relate<'tcx>,
{
let Trace { at, trace, a_is_expected } = self;
let Trace { at, trace } = self;
let mut fields = at.infcx.combine_fields(trace, at.param_env, define_opaque_types);
fields
.lub(a_is_expected)
.lub()
.relate(a, b)
.map(move |t| InferOk { value: t, obligations: fields.obligations })
}
@ -340,10 +307,10 @@ impl<'a, 'tcx> Trace<'a, 'tcx> {
where
T: Relate<'tcx>,
{
let Trace { at, trace, a_is_expected } = self;
let Trace { at, trace } = self;
let mut fields = at.infcx.combine_fields(trace, at.param_env, define_opaque_types);
fields
.glb(a_is_expected)
.glb()
.relate(a, b)
.map(move |t| InferOk { value: t, obligations: fields.obligations })
}

View file

@ -522,13 +522,7 @@ impl<'tcx> InferCtxt<'tcx> {
) -> InferResult<'tcx, ()> {
let mut obligations = Vec::new();
self.insert_hidden_type(
opaque_type_key,
&cause,
param_env,
hidden_ty,
&mut obligations,
)?;
self.insert_hidden_type(opaque_type_key, &cause, param_env, hidden_ty, &mut obligations)?;
self.add_item_bounds_for_hidden_type(
opaque_type_key.def_id.to_def_id(),

View file

@ -321,21 +321,24 @@ impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> {
pub fn equate<'a>(
&'a mut self,
structurally_relate_aliases: StructurallyRelateAliases,
a_is_expected: bool,
) -> TypeRelating<'a, 'infcx, 'tcx> {
TypeRelating::new(self, a_is_expected, structurally_relate_aliases, ty::Invariant)
TypeRelating::new(self, structurally_relate_aliases, ty::Invariant)
}
pub fn sub<'a>(&'a mut self, a_is_expected: bool) -> TypeRelating<'a, 'infcx, 'tcx> {
TypeRelating::new(self, a_is_expected, StructurallyRelateAliases::No, ty::Covariant)
pub fn sub<'a>(&'a mut self) -> TypeRelating<'a, 'infcx, 'tcx> {
TypeRelating::new(self, StructurallyRelateAliases::No, ty::Covariant)
}
pub fn lub<'a>(&'a mut self, a_is_expected: bool) -> Lub<'a, 'infcx, 'tcx> {
Lub::new(self, a_is_expected)
pub fn sup<'a>(&'a mut self) -> TypeRelating<'a, 'infcx, 'tcx> {
TypeRelating::new(self, StructurallyRelateAliases::No, ty::Contravariant)
}
pub fn glb<'a>(&'a mut self, a_is_expected: bool) -> Glb<'a, 'infcx, 'tcx> {
Glb::new(self, a_is_expected)
pub fn lub<'a>(&'a mut self) -> Lub<'a, 'infcx, 'tcx> {
Lub::new(self)
}
pub fn glb<'a>(&'a mut self) -> Glb<'a, 'infcx, 'tcx> {
Glb::new(self)
}
pub fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {

View file

@ -13,15 +13,11 @@ use crate::traits::{ObligationCause, PredicateObligations};
/// "Greatest lower bound" (common subtype)
pub struct Glb<'combine, 'infcx, 'tcx> {
fields: &'combine mut CombineFields<'infcx, 'tcx>,
a_is_expected: bool,
}
impl<'combine, 'infcx, 'tcx> Glb<'combine, 'infcx, 'tcx> {
pub fn new(
fields: &'combine mut CombineFields<'infcx, 'tcx>,
a_is_expected: bool,
) -> Glb<'combine, 'infcx, 'tcx> {
Glb { fields, a_is_expected }
pub fn new(fields: &'combine mut CombineFields<'infcx, 'tcx>) -> Glb<'combine, 'infcx, 'tcx> {
Glb { fields }
}
}
@ -35,7 +31,7 @@ impl<'tcx> TypeRelation<'tcx> for Glb<'_, '_, 'tcx> {
}
fn a_is_expected(&self) -> bool {
self.a_is_expected
true
}
fn relate_with_variance<T: Relate<'tcx>>(
@ -46,13 +42,11 @@ impl<'tcx> TypeRelation<'tcx> for Glb<'_, '_, 'tcx> {
b: T,
) -> RelateResult<'tcx, T> {
match variance {
ty::Invariant => {
self.fields.equate(StructurallyRelateAliases::No, self.a_is_expected).relate(a, b)
}
ty::Invariant => self.fields.equate(StructurallyRelateAliases::No).relate(a, b),
ty::Covariant => self.relate(a, b),
// FIXME(#41044) -- not correct, need test
ty::Bivariant => Ok(a),
ty::Contravariant => self.fields.lub(self.a_is_expected).relate(a, b),
ty::Contravariant => self.fields.lub().relate(a, b),
}
}
@ -126,7 +120,7 @@ impl<'combine, 'infcx, 'tcx> LatticeDir<'infcx, 'tcx> for Glb<'combine, 'infcx,
}
fn relate_bound(&mut self, v: Ty<'tcx>, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, ()> {
let mut sub = self.fields.sub(self.a_is_expected);
let mut sub = self.fields.sub();
sub.relate(v, a)?;
sub.relate(v, b)?;
Ok(())

View file

@ -49,7 +49,13 @@ impl<'a, 'tcx> CombineFields<'a, 'tcx> {
debug!("b_prime={:?}", sup_prime);
// Compare types now that bound regions have been replaced.
let result = self.sub(sub_is_expected).relate(sub_prime, sup_prime);
// Reorder the inputs so that the expected is passed first.
let result = if sub_is_expected {
self.sub().relate(sub_prime, sup_prime)
} else {
self.sup().relate(sup_prime, sub_prime)
};
if result.is_ok() {
debug!("OK result={result:?}");
}

View file

@ -13,15 +13,11 @@ use rustc_span::Span;
/// "Least upper bound" (common supertype)
pub struct Lub<'combine, 'infcx, 'tcx> {
fields: &'combine mut CombineFields<'infcx, 'tcx>,
a_is_expected: bool,
}
impl<'combine, 'infcx, 'tcx> Lub<'combine, 'infcx, 'tcx> {
pub fn new(
fields: &'combine mut CombineFields<'infcx, 'tcx>,
a_is_expected: bool,
) -> Lub<'combine, 'infcx, 'tcx> {
Lub { fields, a_is_expected }
pub fn new(fields: &'combine mut CombineFields<'infcx, 'tcx>) -> Lub<'combine, 'infcx, 'tcx> {
Lub { fields }
}
}
@ -35,7 +31,7 @@ impl<'tcx> TypeRelation<'tcx> for Lub<'_, '_, 'tcx> {
}
fn a_is_expected(&self) -> bool {
self.a_is_expected
true
}
fn relate_with_variance<T: Relate<'tcx>>(
@ -46,13 +42,11 @@ impl<'tcx> TypeRelation<'tcx> for Lub<'_, '_, 'tcx> {
b: T,
) -> RelateResult<'tcx, T> {
match variance {
ty::Invariant => {
self.fields.equate(StructurallyRelateAliases::No, self.a_is_expected).relate(a, b)
}
ty::Invariant => self.fields.equate(StructurallyRelateAliases::No).relate(a, b),
ty::Covariant => self.relate(a, b),
// FIXME(#41044) -- not correct, need test
ty::Bivariant => Ok(a),
ty::Contravariant => self.fields.glb(self.a_is_expected).relate(a, b),
ty::Contravariant => self.fields.glb().relate(a, b),
}
}
@ -126,7 +120,7 @@ impl<'combine, 'infcx, 'tcx> LatticeDir<'infcx, 'tcx> for Lub<'combine, 'infcx,
}
fn relate_bound(&mut self, v: Ty<'tcx>, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, ()> {
let mut sub = self.fields.sub(self.a_is_expected);
let mut sub = self.fields.sub();
sub.relate(a, v)?;
sub.relate(b, v)?;
Ok(())

View file

@ -12,7 +12,6 @@ use rustc_span::Span;
/// Enforce that `a` is equal to or a subtype of `b`.
pub struct TypeRelating<'combine, 'a, 'tcx> {
fields: &'combine mut CombineFields<'a, 'tcx>,
a_is_expected: bool,
structurally_relate_aliases: StructurallyRelateAliases,
ambient_variance: ty::Variance,
}
@ -20,11 +19,10 @@ pub struct TypeRelating<'combine, 'a, 'tcx> {
impl<'combine, 'infcx, 'tcx> TypeRelating<'combine, 'infcx, 'tcx> {
pub fn new(
f: &'combine mut CombineFields<'infcx, 'tcx>,
a_is_expected: bool,
structurally_relate_aliases: StructurallyRelateAliases,
ambient_variance: ty::Variance,
) -> TypeRelating<'combine, 'infcx, 'tcx> {
TypeRelating { fields: f, a_is_expected, structurally_relate_aliases, ambient_variance }
TypeRelating { fields: f, structurally_relate_aliases, ambient_variance }
}
}
@ -38,7 +36,7 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
}
fn a_is_expected(&self) -> bool {
self.a_is_expected
true
}
fn relate_with_variance<T: Relate<'tcx>>(
@ -79,7 +77,7 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
self.fields.trace.cause.clone(),
self.fields.param_env,
ty::Binder::dummy(ty::PredicateKind::Subtype(ty::SubtypePredicate {
a_is_expected: self.a_is_expected,
a_is_expected: true,
a,
b,
})),
@ -93,7 +91,7 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
self.fields.trace.cause.clone(),
self.fields.param_env,
ty::Binder::dummy(ty::PredicateKind::Subtype(ty::SubtypePredicate {
a_is_expected: !self.a_is_expected,
a_is_expected: false,
a: b,
b: a,
})),
@ -109,18 +107,12 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
}
(&ty::Infer(TyVar(a_vid)), _) => {
infcx.instantiate_ty_var(
self,
self.a_is_expected,
a_vid,
self.ambient_variance,
b,
)?;
infcx.instantiate_ty_var(self, true, a_vid, self.ambient_variance, b)?;
}
(_, &ty::Infer(TyVar(b_vid))) => {
infcx.instantiate_ty_var(
self,
!self.a_is_expected,
false,
b_vid,
self.ambient_variance.xform(ty::Contravariant),
a,
@ -147,13 +139,7 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
{
self.fields.obligations.extend(
infcx
.handle_opaque_type(
a,
b,
self.a_is_expected,
&self.fields.trace.cause,
self.param_env(),
)?
.handle_opaque_type(a, b, true, &self.fields.trace.cause, self.param_env())?
.obligations,
);
}
@ -239,14 +225,14 @@ impl<'tcx> TypeRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
} else {
match self.ambient_variance {
ty::Covariant => {
self.fields.higher_ranked_sub(a, b, self.a_is_expected)?;
self.fields.higher_ranked_sub(a, b, true)?;
}
ty::Contravariant => {
self.fields.higher_ranked_sub(b, a, !self.a_is_expected)?;
self.fields.higher_ranked_sub(b, a, false)?;
}
ty::Invariant => {
self.fields.higher_ranked_sub(a, b, self.a_is_expected)?;
self.fields.higher_ranked_sub(b, a, !self.a_is_expected)?;
self.fields.higher_ranked_sub(a, b, true)?;
self.fields.higher_ranked_sub(b, a, false)?;
}
ty::Bivariant => {
unreachable!("Expected bivariance to be handled in relate_with_variance")

View file

@ -1541,12 +1541,9 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> {
// since the normalization is just done to improve the error message.
let _ = ocx.select_where_possible();
if let Err(new_err) = ocx.eq(
&obligation.cause,
obligation.param_env,
expected,
actual,
) {
if let Err(new_err) =
ocx.eq(&obligation.cause, obligation.param_env, expected, actual)
{
(Some((data, is_normalized_term_expected, normalized_term, data.term)), new_err)
} else {
(None, error.err)