Lower return types for gen fn to impl Iterator

This commit is contained in:
Eric Holk 2023-11-29 12:07:43 -08:00
parent bc0d10d4b0
commit c104f3b629
No known key found for this signature in database
GPG key ID: 8EA6B43ED4CE0911
7 changed files with 167 additions and 80 deletions
compiler
rustc_ast_lowering/src
rustc_hir/src
rustc_hir_typeck/src
rustc_parse/src/parser
rustc_resolve/src

View file

@ -1,3 +1,5 @@
use crate::FnReturnTransformation;
use super::errors::{InvalidAbi, InvalidAbiReason, InvalidAbiSuggestion, MisplacedRelaxTraitBound};
use super::ResolverAstLoweringExt;
use super::{AstOwner, ImplTraitContext, ImplTraitPosition};
@ -207,13 +209,33 @@ impl<'hir> LoweringContext<'_, 'hir> {
// only cares about the input argument patterns in the function
// declaration (decl), not the return types.
let asyncness = header.asyncness;
let body_id =
this.lower_maybe_async_body(span, hir_id, decl, asyncness, body.as_deref());
let genness = header.genness;
let body_id = this.lower_maybe_coroutine_body(
span,
hir_id,
decl,
asyncness,
genness,
body.as_deref(),
);
let itctx = ImplTraitContext::Universal;
let (generics, decl) =
this.lower_generics(generics, header.constness, id, &itctx, |this| {
let ret_id = asyncness.opt_return_id();
let ret_id = asyncness
.opt_return_id()
.map(|(node_id, span)| {
crate::FnReturnTransformation::Async(node_id, span)
})
.or_else(|| match genness {
Gen::Yes { span, closure_id: _, return_impl_trait_id } => {
Some(crate::FnReturnTransformation::Iterator(
return_impl_trait_id,
span,
))
}
_ => None,
});
this.lower_fn_decl(decl, id, *fn_sig_span, FnDeclKind::Fn, ret_id)
});
let sig = hir::FnSig {
@ -732,20 +754,31 @@ impl<'hir> LoweringContext<'_, 'hir> {
sig,
i.id,
FnDeclKind::Trait,
asyncness.opt_return_id(),
asyncness
.opt_return_id()
.map(|(node_id, span)| crate::FnReturnTransformation::Async(node_id, span)),
);
(generics, hir::TraitItemKind::Fn(sig, hir::TraitFn::Required(names)), false)
}
AssocItemKind::Fn(box Fn { sig, generics, body: Some(body), .. }) => {
let asyncness = sig.header.asyncness;
let body_id =
self.lower_maybe_async_body(i.span, hir_id, &sig.decl, asyncness, Some(body));
let genness = sig.header.genness;
let body_id = self.lower_maybe_coroutine_body(
i.span,
hir_id,
&sig.decl,
asyncness,
genness,
Some(body),
);
let (generics, sig) = self.lower_method_sig(
generics,
sig,
i.id,
FnDeclKind::Trait,
asyncness.opt_return_id(),
asyncness
.opt_return_id()
.map(|(node_id, span)| crate::FnReturnTransformation::Async(node_id, span)),
);
(generics, hir::TraitItemKind::Fn(sig, hir::TraitFn::Provided(body_id)), true)
}
@ -835,11 +868,13 @@ impl<'hir> LoweringContext<'_, 'hir> {
),
AssocItemKind::Fn(box Fn { sig, generics, body, .. }) => {
let asyncness = sig.header.asyncness;
let body_id = self.lower_maybe_async_body(
let genness = sig.header.genness;
let body_id = self.lower_maybe_coroutine_body(
i.span,
hir_id,
&sig.decl,
asyncness,
genness,
body.as_deref(),
);
let (generics, sig) = self.lower_method_sig(
@ -847,7 +882,9 @@ impl<'hir> LoweringContext<'_, 'hir> {
sig,
i.id,
if self.is_in_trait_impl { FnDeclKind::Impl } else { FnDeclKind::Inherent },
asyncness.opt_return_id(),
asyncness
.opt_return_id()
.map(|(node_id, span)| crate::FnReturnTransformation::Async(node_id, span)),
);
(generics, hir::ImplItemKind::Fn(sig, body_id))
@ -1011,16 +1048,22 @@ impl<'hir> LoweringContext<'_, 'hir> {
})
}
fn lower_maybe_async_body(
/// Takes what may be the body of an `async fn` or a `gen fn` and wraps it in an `async {}` or
/// `gen {}` block as appropriate.
fn lower_maybe_coroutine_body(
&mut self,
span: Span,
fn_id: hir::HirId,
decl: &FnDecl,
asyncness: Async,
genness: Gen,
body: Option<&Block>,
) -> hir::BodyId {
let (closure_id, body) = match (asyncness, body) {
(Async::Yes { closure_id, .. }, Some(body)) => (closure_id, body),
let (closure_id, body) = match (asyncness, genness, body) {
// FIXME(eholk): do something reasonable for `async gen fn`. Probably that's an error
// for now since it's not supported.
(Async::Yes { closure_id, .. }, _, Some(body))
| (_, Gen::Yes { closure_id, .. }, Some(body)) => (closure_id, body),
_ => return self.lower_fn_body_block(span, decl, body),
};
@ -1163,44 +1206,55 @@ impl<'hir> LoweringContext<'_, 'hir> {
parameters.push(new_parameter);
}
let async_expr = this.make_async_expr(
CaptureBy::Value { move_kw: rustc_span::DUMMY_SP },
closure_id,
None,
body.span,
hir::CoroutineSource::Fn,
|this| {
// Create a block from the user's function body:
let user_body = this.lower_block_expr(body);
let mkbody = |this: &mut LoweringContext<'_, 'hir>| {
// Create a block from the user's function body:
let user_body = this.lower_block_expr(body);
// Transform into `drop-temps { <user-body> }`, an expression:
let desugared_span =
this.mark_span_with_reason(DesugaringKind::Async, user_body.span, None);
let user_body =
this.expr_drop_temps(desugared_span, this.arena.alloc(user_body));
// Transform into `drop-temps { <user-body> }`, an expression:
let desugared_span =
this.mark_span_with_reason(DesugaringKind::Async, user_body.span, None);
let user_body = this.expr_drop_temps(desugared_span, this.arena.alloc(user_body));
// As noted above, create the final block like
//
// ```
// {
// let $param_pattern = $raw_param;
// ...
// drop-temps { <user-body> }
// }
// ```
let body = this.block_all(
desugared_span,
this.arena.alloc_from_iter(statements),
Some(user_body),
);
// As noted above, create the final block like
//
// ```
// {
// let $param_pattern = $raw_param;
// ...
// drop-temps { <user-body> }
// }
// ```
let body = this.block_all(
desugared_span,
this.arena.alloc_from_iter(statements),
Some(user_body),
);
this.expr_block(body)
},
);
this.expr_block(body)
};
let coroutine_expr = match (asyncness, genness) {
(Async::Yes { .. }, _) => this.make_async_expr(
CaptureBy::Value { move_kw: rustc_span::DUMMY_SP },
closure_id,
None,
body.span,
hir::CoroutineSource::Fn,
mkbody,
),
(_, Gen::Yes { .. }) => this.make_gen_expr(
CaptureBy::Value { move_kw: rustc_span::DUMMY_SP },
closure_id,
None,
body.span,
hir::CoroutineSource::Fn,
mkbody,
),
_ => unreachable!("we must have either an async fn or a gen fn"),
};
let hir_id = this.lower_node_id(closure_id);
this.maybe_forward_track_caller(body.span, fn_id, hir_id);
let expr = hir::Expr { hir_id, kind: async_expr, span: this.lower_span(body.span) };
let expr = hir::Expr { hir_id, kind: coroutine_expr, span: this.lower_span(body.span) };
(this.arena.alloc_from_iter(parameters), expr)
})
@ -1212,13 +1266,13 @@ impl<'hir> LoweringContext<'_, 'hir> {
sig: &FnSig,
id: NodeId,
kind: FnDeclKind,
is_async: Option<(NodeId, Span)>,
transform_return_type: Option<FnReturnTransformation>,
) -> (&'hir hir::Generics<'hir>, hir::FnSig<'hir>) {
let header = self.lower_fn_header(sig.header);
let itctx = ImplTraitContext::Universal;
let (generics, decl) =
self.lower_generics(generics, sig.header.constness, id, &itctx, |this| {
this.lower_fn_decl(&sig.decl, id, sig.span, kind, is_async)
this.lower_fn_decl(&sig.decl, id, sig.span, kind, transform_return_type)
});
(generics, hir::FnSig { header, decl, span: self.lower_span(sig.span) })
}

View file

@ -493,6 +493,21 @@ enum ParenthesizedGenericArgs {
Err,
}
/// Describes a return type transformation that can be performed by `LoweringContext::lower_fn_decl`
#[derive(Debug)]
enum FnReturnTransformation {
/// Replaces a return type `T` with `impl Future<Output = T>`.
///
/// The `NodeId` is the ID of the return type `impl Trait` item, and the `Span` points to the
/// `async` keyword.
Async(NodeId, Span),
/// Replaces a return type `T` with `impl Iterator<Item = T>`.
///
/// The `NodeId` is the ID of the return type `impl Trait` item, and the `Span` points to the
/// `gen` keyword.
Iterator(NodeId, Span),
}
impl<'a, 'hir> LoweringContext<'a, 'hir> {
fn create_def(
&mut self,
@ -1778,13 +1793,15 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
}))
}
// Lowers a function declaration.
//
// `decl`: the unlowered (AST) function declaration.
// `fn_node_id`: `impl Trait` arguments are lowered into generic parameters on the given `NodeId`.
// `make_ret_async`: if `Some`, converts `-> T` into `-> impl Future<Output = T>` in the
// return type. This is used for `async fn` declarations. The `NodeId` is the ID of the
// return type `impl Trait` item, and the `Span` points to the `async` keyword.
/// Lowers a function declaration.
///
/// `decl`: the unlowered (AST) function declaration.
///
/// `fn_node_id`: `impl Trait` arguments are lowered into generic parameters on the given
/// `NodeId`.
///
/// `transform_return_type`: if `Some`, applies some conversion to the return type, such as is
/// needed for `async fn` and `gen fn`. See [`FnReturnTransformation`] for more details.
#[instrument(level = "debug", skip(self))]
fn lower_fn_decl(
&mut self,
@ -1792,7 +1809,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
fn_node_id: NodeId,
fn_span: Span,
kind: FnDeclKind,
make_ret_async: Option<(NodeId, Span)>,
transform_return_type: Option<FnReturnTransformation>,
) -> &'hir hir::FnDecl<'hir> {
let c_variadic = decl.c_variadic();
@ -1821,11 +1838,12 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
self.lower_ty_direct(&param.ty, &itctx)
}));
let output = if let Some((ret_id, _span)) = make_ret_async {
let fn_def_id = self.local_def_id(fn_node_id);
self.lower_async_fn_ret_ty(&decl.output, fn_def_id, ret_id, kind, fn_span)
} else {
match &decl.output {
let output = match transform_return_type {
Some(transform) => {
let fn_def_id = self.local_def_id(fn_node_id);
self.lower_coroutine_fn_ret_ty(&decl.output, fn_def_id, transform, kind, fn_span)
}
None => match &decl.output {
FnRetTy::Ty(ty) => {
let context = if kind.return_impl_trait_allowed() {
let fn_def_id = self.local_def_id(fn_node_id);
@ -1849,7 +1867,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
hir::FnRetTy::Return(self.lower_ty(ty, &context))
}
FnRetTy::Default(span) => hir::FnRetTy::DefaultReturn(self.lower_span(*span)),
}
},
};
self.arena.alloc(hir::FnDecl {
@ -1888,17 +1906,22 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
// `fn_node_id`: `NodeId` of the parent function (used to create child impl trait definition)
// `opaque_ty_node_id`: `NodeId` of the opaque `impl Trait` type that should be created
#[instrument(level = "debug", skip(self))]
fn lower_async_fn_ret_ty(
fn lower_coroutine_fn_ret_ty(
&mut self,
output: &FnRetTy,
fn_def_id: LocalDefId,
opaque_ty_node_id: NodeId,
transform: FnReturnTransformation,
fn_kind: FnDeclKind,
fn_span: Span,
) -> hir::FnRetTy<'hir> {
let span = self.lower_span(fn_span);
let opaque_ty_span = self.mark_span_with_reason(DesugaringKind::Async, span, None);
let opaque_ty_node_id = match transform {
FnReturnTransformation::Async(opaque_ty_node_id, _)
| FnReturnTransformation::Iterator(opaque_ty_node_id, _) => opaque_ty_node_id,
};
let captured_lifetimes: Vec<_> = self
.resolver
.take_extra_lifetime_params(opaque_ty_node_id)
@ -1914,8 +1937,9 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
span,
opaque_ty_span,
|this| {
let future_bound = this.lower_async_fn_output_type_to_future_bound(
let future_bound = this.lower_coroutine_fn_output_type_to_future_bound(
output,
transform,
span,
ImplTraitContext::ReturnPositionOpaqueTy {
origin: hir::OpaqueTyOrigin::FnReturn(fn_def_id),
@ -1931,9 +1955,10 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
}
/// Transforms `-> T` into `Future<Output = T>`.
fn lower_async_fn_output_type_to_future_bound(
fn lower_coroutine_fn_output_type_to_future_bound(
&mut self,
output: &FnRetTy,
transform: FnReturnTransformation,
span: Span,
nested_impl_trait_context: ImplTraitContext,
) -> hir::GenericBound<'hir> {
@ -1948,17 +1973,23 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
FnRetTy::Default(ret_ty_span) => self.arena.alloc(self.ty_tup(*ret_ty_span, &[])),
};
// "<Output = T>"
// "<Output|Item = T>"
let (symbol, lang_item) = match transform {
FnReturnTransformation::Async(..) => (hir::FN_OUTPUT_NAME, hir::LangItem::Future),
FnReturnTransformation::Iterator(..) => {
(hir::ITERATOR_ITEM_NAME, hir::LangItem::Iterator)
}
};
let future_args = self.arena.alloc(hir::GenericArgs {
args: &[],
bindings: arena_vec![self; self.output_ty_binding(span, output_ty)],
bindings: arena_vec![self; self.assoc_ty_binding(symbol, span, output_ty)],
parenthesized: hir::GenericArgsParentheses::No,
span_ext: DUMMY_SP,
});
hir::GenericBound::LangItemTrait(
// ::std::future::Future<future_params>
hir::LangItem::Future,
lang_item,
self.lower_span(span),
self.next_id(),
future_args,

View file

@ -389,7 +389,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
FnRetTy::Default(_) => self.arena.alloc(self.ty_tup(*span, &[])),
};
let args = smallvec![GenericArg::Type(self.arena.alloc(self.ty_tup(*inputs_span, inputs)))];
let binding = self.output_ty_binding(output_ty.span, output_ty);
let binding = self.assoc_ty_binding(hir::FN_OUTPUT_NAME, output_ty.span, output_ty);
(
GenericArgsCtor {
args,
@ -401,13 +401,14 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
)
}
/// An associated type binding `Output = $ty`.
pub(crate) fn output_ty_binding(
/// An associated type binding `$symbol = $ty`.
pub(crate) fn assoc_ty_binding(
&mut self,
symbol: rustc_span::Symbol,
span: Span,
ty: &'hir hir::Ty<'hir>,
) -> hir::TypeBinding<'hir> {
let ident = Ident::with_dummy_span(hir::FN_OUTPUT_NAME);
let ident = Ident::with_dummy_span(symbol);
let kind = hir::TypeBindingKind::Equality { term: ty.into() };
let args = arena_vec![self;];
let bindings = arena_vec![self;];

View file

@ -2255,6 +2255,8 @@ pub enum ImplItemKind<'hir> {
/// The name of the associated type for `Fn` return types.
pub const FN_OUTPUT_NAME: Symbol = sym::Output;
/// The name of the associated type for `Iterator` item types.
pub const ITERATOR_ITEM_NAME: Symbol = sym::Item;
/// Bind a type to an associated type (i.e., `A = Foo`).
///

View file

@ -651,9 +651,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
},
)
}
Some(hir::CoroutineKind::Gen(hir::CoroutineSource::Fn)) => {
todo!("gen closures do not exist yet")
}
// For a `gen {}` block created as a `gen fn` body, we need the return type to be
// ().
Some(hir::CoroutineKind::Gen(hir::CoroutineSource::Fn)) => self.tcx.types.unit,
_ => astconv.ty_infer(None, decl.output.span()),
},

View file

@ -2410,10 +2410,6 @@ impl<'a> Parser<'a> {
}
}
if let Gen::Yes { span, .. } = genness {
self.sess.emit_err(errors::GenFn { span });
}
if !self.eat_keyword_case(kw::Fn, case) {
// It is possible for `expect_one_of` to recover given the contents of
// `self.expected_tokens`, therefore, do not use `self.unexpected()` which doesn't

View file

@ -156,7 +156,10 @@ impl<'a, 'b, 'tcx> visit::Visitor<'a> for DefCollector<'a, 'b, 'tcx> {
fn visit_fn(&mut self, fn_kind: FnKind<'a>, span: Span, _: NodeId) {
if let FnKind::Fn(_, _, sig, _, generics, body) = fn_kind {
if let Async::Yes { closure_id, .. } = sig.header.asyncness {
// FIXME(eholk): handle `async gen fn`
if let (Async::Yes { closure_id, .. }, _) | (_, Gen::Yes { closure_id, .. }) =
(sig.header.asyncness, sig.header.genness)
{
self.visit_generics(generics);
// For async functions, we need to create their inner defs inside of a