Single commit implementing the enzyme/autodiff frontend
Co-authored-by: Lorenz Schmidt <bytesnake@mailbox.org>
This commit is contained in:
parent
52fd998399
commit
624c071b99
17 changed files with 1384 additions and 1 deletions
283
compiler/rustc_ast/src/expand/autodiff_attrs.rs
Normal file
283
compiler/rustc_ast/src/expand/autodiff_attrs.rs
Normal file
|
@ -0,0 +1,283 @@
|
||||||
|
//! This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute,
|
||||||
|
//! we create an [`AutoDiffItem`] which contains the source and target function names. The source
|
||||||
|
//! is the function to which the autodiff attribute is applied, and the target is the function
|
||||||
|
//! getting generated by us (with a name given by the user as the first autodiff arg).
|
||||||
|
|
||||||
|
use std::fmt::{self, Display, Formatter};
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
use crate::expand::typetree::TypeTree;
|
||||||
|
use crate::expand::{Decodable, Encodable, HashStable_Generic};
|
||||||
|
use crate::ptr::P;
|
||||||
|
use crate::{Ty, TyKind};
|
||||||
|
|
||||||
|
/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
|
||||||
|
/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
|
||||||
|
/// are a hack to support higher order derivatives. We need to compute first order derivatives
|
||||||
|
/// before we compute second order derivatives, otherwise we would differentiate our placeholder
|
||||||
|
/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
|
||||||
|
/// as it's already done in the C++ and Julia frontend of Enzyme.
|
||||||
|
///
|
||||||
|
/// (FIXME) remove *First variants.
|
||||||
|
/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
|
||||||
|
/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
|
||||||
|
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||||
|
pub enum DiffMode {
|
||||||
|
/// No autodiff is applied (used during error handling).
|
||||||
|
Error,
|
||||||
|
/// The primal function which we will differentiate.
|
||||||
|
Source,
|
||||||
|
/// The target function, to be created using forward mode AD.
|
||||||
|
Forward,
|
||||||
|
/// The target function, to be created using reverse mode AD.
|
||||||
|
Reverse,
|
||||||
|
/// The target function, to be created using forward mode AD.
|
||||||
|
/// This target function will also be used as a source for higher order derivatives,
|
||||||
|
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
|
||||||
|
ForwardFirst,
|
||||||
|
/// The target function, to be created using reverse mode AD.
|
||||||
|
/// This target function will also be used as a source for higher order derivatives,
|
||||||
|
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
|
||||||
|
ReverseFirst,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
|
||||||
|
/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
|
||||||
|
/// we add to the previous shadow value. To not surprise users, we picked different names.
|
||||||
|
/// Dual numbers is also a quite well known name for forward mode AD types.
|
||||||
|
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||||
|
pub enum DiffActivity {
|
||||||
|
/// Implicit or Explicit () return type, so a special case of Const.
|
||||||
|
None,
|
||||||
|
/// Don't compute derivatives with respect to this input/output.
|
||||||
|
Const,
|
||||||
|
/// Reverse Mode, Compute derivatives for this scalar input/output.
|
||||||
|
Active,
|
||||||
|
/// Reverse Mode, Compute derivatives for this scalar output, but don't compute
|
||||||
|
/// the original return value.
|
||||||
|
ActiveOnly,
|
||||||
|
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
|
||||||
|
/// with it.
|
||||||
|
Dual,
|
||||||
|
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
|
||||||
|
/// with it. Drop the code which updates the original input/output for maximum performance.
|
||||||
|
DualOnly,
|
||||||
|
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
|
||||||
|
Duplicated,
|
||||||
|
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
|
||||||
|
/// Drop the code which updates the original input for maximum performance.
|
||||||
|
DuplicatedOnly,
|
||||||
|
/// All Integers must be Const, but these are used to mark the integer which represents the
|
||||||
|
/// length of a slice/vec. This is used for safety checks on slices.
|
||||||
|
FakeActivitySize,
|
||||||
|
}
|
||||||
|
/// We generate one of these structs for each `#[autodiff(...)]` attribute.
|
||||||
|
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||||
|
pub struct AutoDiffItem {
|
||||||
|
/// The name of the function getting differentiated
|
||||||
|
pub source: String,
|
||||||
|
/// The name of the function being generated
|
||||||
|
pub target: String,
|
||||||
|
pub attrs: AutoDiffAttrs,
|
||||||
|
/// Describe the memory layout of input types
|
||||||
|
pub inputs: Vec<TypeTree>,
|
||||||
|
/// Describe the memory layout of the output type
|
||||||
|
pub output: TypeTree,
|
||||||
|
}
|
||||||
|
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||||
|
pub struct AutoDiffAttrs {
|
||||||
|
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
|
||||||
|
/// e.g. in the [JAX
|
||||||
|
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
|
||||||
|
pub mode: DiffMode,
|
||||||
|
pub ret_activity: DiffActivity,
|
||||||
|
pub input_activity: Vec<DiffActivity>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DiffMode {
|
||||||
|
pub fn is_rev(&self) -> bool {
|
||||||
|
matches!(self, DiffMode::Reverse | DiffMode::ReverseFirst)
|
||||||
|
}
|
||||||
|
pub fn is_fwd(&self) -> bool {
|
||||||
|
matches!(self, DiffMode::Forward | DiffMode::ForwardFirst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for DiffMode {
|
||||||
|
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
DiffMode::Error => write!(f, "Error"),
|
||||||
|
DiffMode::Source => write!(f, "Source"),
|
||||||
|
DiffMode::Forward => write!(f, "Forward"),
|
||||||
|
DiffMode::Reverse => write!(f, "Reverse"),
|
||||||
|
DiffMode::ForwardFirst => write!(f, "ForwardFirst"),
|
||||||
|
DiffMode::ReverseFirst => write!(f, "ReverseFirst"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).
|
||||||
|
/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).
|
||||||
|
/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.
|
||||||
|
/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,
|
||||||
|
/// but this is too complex to verify here. Also it's just a logic error if users get this wrong.
|
||||||
|
pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
|
||||||
|
if activity == DiffActivity::None {
|
||||||
|
// Only valid if primal returns (), but we can't check that here.
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
match mode {
|
||||||
|
DiffMode::Error => false,
|
||||||
|
DiffMode::Source => false,
|
||||||
|
DiffMode::Forward | DiffMode::ForwardFirst => {
|
||||||
|
activity == DiffActivity::Dual
|
||||||
|
|| activity == DiffActivity::DualOnly
|
||||||
|
|| activity == DiffActivity::Const
|
||||||
|
}
|
||||||
|
DiffMode::Reverse | DiffMode::ReverseFirst => {
|
||||||
|
activity == DiffActivity::Const
|
||||||
|
|| activity == DiffActivity::Active
|
||||||
|
|| activity == DiffActivity::ActiveOnly
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value
|
||||||
|
/// for the given argument, but we generally can't know the size of such a type.
|
||||||
|
/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,
|
||||||
|
/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value
|
||||||
|
/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent
|
||||||
|
/// users here from marking scalars as Duplicated, due to type aliases.
|
||||||
|
pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
|
||||||
|
use DiffActivity::*;
|
||||||
|
// It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
|
||||||
|
if matches!(activity, Const) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if matches!(activity, Dual | DualOnly) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// FIXME(ZuseZ4) We should make this more robust to also
|
||||||
|
// handle type aliases. Once that is done, we can be more restrictive here.
|
||||||
|
if matches!(activity, Active | ActiveOnly) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))
|
||||||
|
&& matches!(activity, Duplicated | DuplicatedOnly)
|
||||||
|
}
|
||||||
|
pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
|
||||||
|
use DiffActivity::*;
|
||||||
|
return match mode {
|
||||||
|
DiffMode::Error => false,
|
||||||
|
DiffMode::Source => false,
|
||||||
|
DiffMode::Forward | DiffMode::ForwardFirst => {
|
||||||
|
matches!(activity, Dual | DualOnly | Const)
|
||||||
|
}
|
||||||
|
DiffMode::Reverse | DiffMode::ReverseFirst => {
|
||||||
|
matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for DiffActivity {
|
||||||
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
DiffActivity::None => write!(f, "None"),
|
||||||
|
DiffActivity::Const => write!(f, "Const"),
|
||||||
|
DiffActivity::Active => write!(f, "Active"),
|
||||||
|
DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
|
||||||
|
DiffActivity::Dual => write!(f, "Dual"),
|
||||||
|
DiffActivity::DualOnly => write!(f, "DualOnly"),
|
||||||
|
DiffActivity::Duplicated => write!(f, "Duplicated"),
|
||||||
|
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
|
||||||
|
DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromStr for DiffMode {
|
||||||
|
type Err = ();
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> Result<DiffMode, ()> {
|
||||||
|
match s {
|
||||||
|
"Error" => Ok(DiffMode::Error),
|
||||||
|
"Source" => Ok(DiffMode::Source),
|
||||||
|
"Forward" => Ok(DiffMode::Forward),
|
||||||
|
"Reverse" => Ok(DiffMode::Reverse),
|
||||||
|
"ForwardFirst" => Ok(DiffMode::ForwardFirst),
|
||||||
|
"ReverseFirst" => Ok(DiffMode::ReverseFirst),
|
||||||
|
_ => Err(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl FromStr for DiffActivity {
|
||||||
|
type Err = ();
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> Result<DiffActivity, ()> {
|
||||||
|
match s {
|
||||||
|
"None" => Ok(DiffActivity::None),
|
||||||
|
"Active" => Ok(DiffActivity::Active),
|
||||||
|
"ActiveOnly" => Ok(DiffActivity::ActiveOnly),
|
||||||
|
"Const" => Ok(DiffActivity::Const),
|
||||||
|
"Dual" => Ok(DiffActivity::Dual),
|
||||||
|
"DualOnly" => Ok(DiffActivity::DualOnly),
|
||||||
|
"Duplicated" => Ok(DiffActivity::Duplicated),
|
||||||
|
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
|
||||||
|
_ => Err(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AutoDiffAttrs {
|
||||||
|
pub fn has_ret_activity(&self) -> bool {
|
||||||
|
self.ret_activity != DiffActivity::None
|
||||||
|
}
|
||||||
|
pub fn has_active_only_ret(&self) -> bool {
|
||||||
|
self.ret_activity == DiffActivity::ActiveOnly
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn error() -> Self {
|
||||||
|
AutoDiffAttrs {
|
||||||
|
mode: DiffMode::Error,
|
||||||
|
ret_activity: DiffActivity::None,
|
||||||
|
input_activity: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn source() -> Self {
|
||||||
|
AutoDiffAttrs {
|
||||||
|
mode: DiffMode::Source,
|
||||||
|
ret_activity: DiffActivity::None,
|
||||||
|
input_activity: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_active(&self) -> bool {
|
||||||
|
self.mode != DiffMode::Error
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_source(&self) -> bool {
|
||||||
|
self.mode == DiffMode::Source
|
||||||
|
}
|
||||||
|
pub fn apply_autodiff(&self) -> bool {
|
||||||
|
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_item(
|
||||||
|
self,
|
||||||
|
source: String,
|
||||||
|
target: String,
|
||||||
|
inputs: Vec<TypeTree>,
|
||||||
|
output: TypeTree,
|
||||||
|
) -> AutoDiffItem {
|
||||||
|
AutoDiffItem { source, target, inputs, output, attrs: self }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for AutoDiffItem {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
|
||||||
|
write!(f, " with attributes: {:?}", self.attrs)?;
|
||||||
|
write!(f, " with inputs: {:?}", self.inputs)?;
|
||||||
|
write!(f, " with output: {:?}", self.output)
|
||||||
|
}
|
||||||
|
}
|
|
@ -7,6 +7,8 @@ use rustc_span::symbol::Ident;
|
||||||
use crate::MetaItem;
|
use crate::MetaItem;
|
||||||
|
|
||||||
pub mod allocator;
|
pub mod allocator;
|
||||||
|
pub mod autodiff_attrs;
|
||||||
|
pub mod typetree;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)]
|
#[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)]
|
||||||
pub struct StrippedCfgItem<ModId = DefId> {
|
pub struct StrippedCfgItem<ModId = DefId> {
|
||||||
|
|
90
compiler/rustc_ast/src/expand/typetree.rs
Normal file
90
compiler/rustc_ast/src/expand/typetree.rs
Normal file
|
@ -0,0 +1,90 @@
|
||||||
|
//! This module contains the definition of the `TypeTree` and `Type` structs.
|
||||||
|
//! They are thin Rust wrappers around the TypeTrees used by Enzyme as the LLVM based autodiff
|
||||||
|
//! backend. The Enzyme TypeTrees currently have various limitations and should be rewritten, so the
|
||||||
|
//! Rust frontend obviously has the same limitations. The main motivation of TypeTrees is to
|
||||||
|
//! represent how a type looks like "in memory". Enzyme can deduce this based on usage patterns in
|
||||||
|
//! the user code, but this is extremely slow and not even always sufficient. As such we lower some
|
||||||
|
//! information from rustc to help Enzyme. For a full explanation of their design it is necessary to
|
||||||
|
//! analyze the implementation in Enzyme core itself. As a rough summary, `-1` in Enzyme speech means
|
||||||
|
//! everywhere. That is `{0:-1: Float}` means at index 0 you have a ptr, if you dereference it it
|
||||||
|
//! will be floats everywhere. Thus `* f32`. If you have `{-1:int}` it means int's everywhere,
|
||||||
|
//! e.g. [i32; N]. `{0:-1:-1 float}` then means one pointer at offset 0, if you dereference it there
|
||||||
|
//! will be only pointers, if you dereference these new pointers they will point to array of floats.
|
||||||
|
//! Generally, it allows byte-specific descriptions.
|
||||||
|
//! FIXME: This description might be partly inaccurate and should be extended, along with
|
||||||
|
//! adding documentation to the corresponding Enzyme core code.
|
||||||
|
//! FIXME: Rewrite the TypeTree logic in Enzyme core to reduce the need for the rustc frontend to
|
||||||
|
//! provide typetree information.
|
||||||
|
//! FIXME: We should also re-evaluate where we create TypeTrees from Rust types, since MIR
|
||||||
|
//! representations of some types might not be accurate. For example a vector of floats might be
|
||||||
|
//! represented as a vector of u8s in MIR in some cases.
|
||||||
|
|
||||||
|
use std::fmt;
|
||||||
|
|
||||||
|
use crate::expand::{Decodable, Encodable, HashStable_Generic};
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||||
|
pub enum Kind {
|
||||||
|
Anything,
|
||||||
|
Integer,
|
||||||
|
Pointer,
|
||||||
|
Half,
|
||||||
|
Float,
|
||||||
|
Double,
|
||||||
|
Unknown,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||||
|
pub struct TypeTree(pub Vec<Type>);
|
||||||
|
|
||||||
|
impl TypeTree {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self(Vec::new())
|
||||||
|
}
|
||||||
|
pub fn all_ints() -> Self {
|
||||||
|
Self(vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }])
|
||||||
|
}
|
||||||
|
pub fn int(size: usize) -> Self {
|
||||||
|
let mut ints = Vec::with_capacity(size);
|
||||||
|
for i in 0..size {
|
||||||
|
ints.push(Type {
|
||||||
|
offset: i as isize,
|
||||||
|
size: 1,
|
||||||
|
kind: Kind::Integer,
|
||||||
|
child: TypeTree::new(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Self(ints)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||||
|
pub struct FncTree {
|
||||||
|
pub args: Vec<TypeTree>,
|
||||||
|
pub ret: TypeTree,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
||||||
|
pub struct Type {
|
||||||
|
pub offset: isize,
|
||||||
|
pub size: usize,
|
||||||
|
pub kind: Kind,
|
||||||
|
pub child: TypeTree,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Type {
|
||||||
|
pub fn add_offset(self, add: isize) -> Self {
|
||||||
|
let offset = match self.offset {
|
||||||
|
-1 => add,
|
||||||
|
x => add + x,
|
||||||
|
};
|
||||||
|
|
||||||
|
Self { size: self.size, kind: self.kind, child: self.child, offset }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for Type {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
<Self as fmt::Debug>::fmt(self, f)
|
||||||
|
}
|
||||||
|
}
|
|
@ -3,6 +3,10 @@ name = "rustc_builtin_macros"
|
||||||
version = "0.0.0"
|
version = "0.0.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
|
|
||||||
|
[lints.rust]
|
||||||
|
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(llvm_enzyme)'] }
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
doctest = false
|
doctest = false
|
||||||
|
|
||||||
|
|
|
@ -69,6 +69,15 @@ builtin_macros_assert_requires_boolean = macro requires a boolean expression as
|
||||||
builtin_macros_assert_requires_expression = macro requires an expression as an argument
|
builtin_macros_assert_requires_expression = macro requires an expression as an argument
|
||||||
.suggestion = try removing semicolon
|
.suggestion = try removing semicolon
|
||||||
|
|
||||||
|
builtin_macros_autodiff = autodiff must be applied to function
|
||||||
|
builtin_macros_autodiff_missing_config = autodiff requires at least a name and mode
|
||||||
|
builtin_macros_autodiff_mode = unknown Mode: `{$mode}`. Use `Forward` or `Reverse`
|
||||||
|
builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode
|
||||||
|
builtin_macros_autodiff_not_build = this rustc version does not support autodiff
|
||||||
|
builtin_macros_autodiff_number_activities = expected {$expected} activities, but found {$found}
|
||||||
|
builtin_macros_autodiff_ty_activity = {$act} can not be used for this type
|
||||||
|
|
||||||
|
builtin_macros_autodiff_unknown_activity = did not recognize Activity: `{$act}`
|
||||||
builtin_macros_bad_derive_target = `derive` may only be applied to `struct`s, `enum`s and `union`s
|
builtin_macros_bad_derive_target = `derive` may only be applied to `struct`s, `enum`s and `union`s
|
||||||
.label = not applicable here
|
.label = not applicable here
|
||||||
.label2 = not a `struct`, `enum` or `union`
|
.label2 = not a `struct`, `enum` or `union`
|
||||||
|
|
820
compiler/rustc_builtin_macros/src/autodiff.rs
Normal file
820
compiler/rustc_builtin_macros/src/autodiff.rs
Normal file
|
@ -0,0 +1,820 @@
|
||||||
|
//! This module contains the implementation of the `#[autodiff]` attribute.
|
||||||
|
//! Currently our linter isn't smart enough to see that each import is used in one of the two
|
||||||
|
//! configs (autodiff enabled or disabled), so we have to add cfg's to each import.
|
||||||
|
//! FIXME(ZuseZ4): Remove this once we have a smarter linter.
|
||||||
|
|
||||||
|
#[cfg(llvm_enzyme)]
|
||||||
|
mod llvm_enzyme {
|
||||||
|
use std::str::FromStr;
|
||||||
|
use std::string::String;
|
||||||
|
|
||||||
|
use rustc_ast::expand::autodiff_attrs::{
|
||||||
|
AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ty_for_activity,
|
||||||
|
};
|
||||||
|
use rustc_ast::ptr::P;
|
||||||
|
use rustc_ast::token::{Token, TokenKind};
|
||||||
|
use rustc_ast::tokenstream::*;
|
||||||
|
use rustc_ast::visit::AssocCtxt::*;
|
||||||
|
use rustc_ast::{
|
||||||
|
self as ast, AssocItemKind, BindingMode, FnRetTy, FnSig, Generics, ItemKind, MetaItemInner,
|
||||||
|
PatKind, TyKind,
|
||||||
|
};
|
||||||
|
use rustc_expand::base::{Annotatable, ExtCtxt};
|
||||||
|
use rustc_span::symbol::{Ident, kw, sym};
|
||||||
|
use rustc_span::{Span, Symbol};
|
||||||
|
use thin_vec::{ThinVec, thin_vec};
|
||||||
|
use tracing::{debug, trace};
|
||||||
|
|
||||||
|
use crate::errors;
|
||||||
|
|
||||||
|
// If we have a default `()` return type or explicitley `()` return type,
|
||||||
|
// then we often can skip doing some work.
|
||||||
|
fn has_ret(ty: &FnRetTy) -> bool {
|
||||||
|
match ty {
|
||||||
|
FnRetTy::Ty(ty) => !ty.kind.is_unit(),
|
||||||
|
FnRetTy::Default(_) => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn first_ident(x: &MetaItemInner) -> rustc_span::symbol::Ident {
|
||||||
|
let segments = &x.meta_item().unwrap().path.segments;
|
||||||
|
assert!(segments.len() == 1);
|
||||||
|
segments[0].ident
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name(x: &MetaItemInner) -> String {
|
||||||
|
first_ident(x).name.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn from_ast(
|
||||||
|
ecx: &mut ExtCtxt<'_>,
|
||||||
|
meta_item: &ThinVec<MetaItemInner>,
|
||||||
|
has_ret: bool,
|
||||||
|
) -> AutoDiffAttrs {
|
||||||
|
let dcx = ecx.sess.dcx();
|
||||||
|
let mode = name(&meta_item[1]);
|
||||||
|
let Ok(mode) = DiffMode::from_str(&mode) else {
|
||||||
|
dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
|
||||||
|
return AutoDiffAttrs::error();
|
||||||
|
};
|
||||||
|
let mut activities: Vec<DiffActivity> = vec![];
|
||||||
|
let mut errors = false;
|
||||||
|
for x in &meta_item[2..] {
|
||||||
|
let activity_str = name(&x);
|
||||||
|
let res = DiffActivity::from_str(&activity_str);
|
||||||
|
match res {
|
||||||
|
Ok(x) => activities.push(x),
|
||||||
|
Err(_) => {
|
||||||
|
dcx.emit_err(errors::AutoDiffUnknownActivity {
|
||||||
|
span: x.span(),
|
||||||
|
act: activity_str,
|
||||||
|
});
|
||||||
|
errors = true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
if errors {
|
||||||
|
return AutoDiffAttrs::error();
|
||||||
|
}
|
||||||
|
|
||||||
|
// If a return type exist, we need to split the last activity,
|
||||||
|
// otherwise we return None as placeholder.
|
||||||
|
let (ret_activity, input_activity) = if has_ret {
|
||||||
|
let Some((last, rest)) = activities.split_last() else {
|
||||||
|
unreachable!(
|
||||||
|
"should not be reachable because we counted the number of activities previously"
|
||||||
|
);
|
||||||
|
};
|
||||||
|
(last, rest)
|
||||||
|
} else {
|
||||||
|
(&DiffActivity::None, activities.as_slice())
|
||||||
|
};
|
||||||
|
|
||||||
|
AutoDiffAttrs { mode, ret_activity: *ret_activity, input_activity: input_activity.to_vec() }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// We expand the autodiff macro to generate a new placeholder function which passes
|
||||||
|
/// type-checking and can be called by users. The function body of the placeholder function will
|
||||||
|
/// later be replaced on LLVM-IR level, so the design of the body is less important and for now
|
||||||
|
/// should just prevent early inlining and optimizations which alter the function signature.
|
||||||
|
/// The exact signature of the generated function depends on the configuration provided by the
|
||||||
|
/// user, but here is an example:
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// #[autodiff(cos_box, Reverse, Duplicated, Active)]
|
||||||
|
/// fn sin(x: &Box<f32>) -> f32 {
|
||||||
|
/// f32::sin(**x)
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
/// which becomes expanded to:
|
||||||
|
/// ```
|
||||||
|
/// #[rustc_autodiff]
|
||||||
|
/// #[inline(never)]
|
||||||
|
/// fn sin(x: &Box<f32>) -> f32 {
|
||||||
|
/// f32::sin(**x)
|
||||||
|
/// }
|
||||||
|
/// #[rustc_autodiff(Reverse, Duplicated, Active)]
|
||||||
|
/// #[inline(never)]
|
||||||
|
/// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
|
||||||
|
/// unsafe {
|
||||||
|
/// asm!("NOP");
|
||||||
|
/// };
|
||||||
|
/// ::core::hint::black_box(sin(x));
|
||||||
|
/// ::core::hint::black_box((dx, dret));
|
||||||
|
/// ::core::hint::black_box(sin(x))
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
|
||||||
|
/// in CI.
|
||||||
|
pub(crate) fn expand(
|
||||||
|
ecx: &mut ExtCtxt<'_>,
|
||||||
|
expand_span: Span,
|
||||||
|
meta_item: &ast::MetaItem,
|
||||||
|
mut item: Annotatable,
|
||||||
|
) -> Vec<Annotatable> {
|
||||||
|
let dcx = ecx.sess.dcx();
|
||||||
|
// first get the annotable item:
|
||||||
|
let (sig, is_impl): (FnSig, bool) = match &item {
|
||||||
|
Annotatable::Item(ref iitem) => {
|
||||||
|
let sig = match &iitem.kind {
|
||||||
|
ItemKind::Fn(box ast::Fn { sig, .. }) => sig,
|
||||||
|
_ => {
|
||||||
|
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
||||||
|
return vec![item];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
(sig.clone(), false)
|
||||||
|
}
|
||||||
|
Annotatable::AssocItem(ref assoc_item, _) => {
|
||||||
|
let sig = match &assoc_item.kind {
|
||||||
|
ast::AssocItemKind::Fn(box ast::Fn { sig, .. }) => sig,
|
||||||
|
_ => {
|
||||||
|
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
||||||
|
return vec![item];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
(sig.clone(), true)
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
||||||
|
return vec![item];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
|
||||||
|
ast::MetaItemKind::List(ref vec) => vec.clone(),
|
||||||
|
_ => {
|
||||||
|
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
||||||
|
return vec![item];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let has_ret = has_ret(&sig.decl.output);
|
||||||
|
let sig_span = ecx.with_call_site_ctxt(sig.span);
|
||||||
|
|
||||||
|
let (vis, primal) = match &item {
|
||||||
|
Annotatable::Item(ref iitem) => (iitem.vis.clone(), iitem.ident.clone()),
|
||||||
|
Annotatable::AssocItem(ref assoc_item, _) => {
|
||||||
|
(assoc_item.vis.clone(), assoc_item.ident.clone())
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
|
||||||
|
return vec![item];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// create TokenStream from vec elemtents:
|
||||||
|
// meta_item doesn't have a .tokens field
|
||||||
|
let comma: Token = Token::new(TokenKind::Comma, Span::default());
|
||||||
|
let mut ts: Vec<TokenTree> = vec![];
|
||||||
|
if meta_item_vec.len() < 2 {
|
||||||
|
// At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
|
||||||
|
// input and output args.
|
||||||
|
dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
|
||||||
|
return vec![item];
|
||||||
|
} else {
|
||||||
|
for t in meta_item_vec.clone()[1..].iter() {
|
||||||
|
let val = first_ident(t);
|
||||||
|
let t = Token::from_ast_ident(val);
|
||||||
|
ts.push(TokenTree::Token(t, Spacing::Joint));
|
||||||
|
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !has_ret {
|
||||||
|
// We don't want users to provide a return activity if the function doesn't return anything.
|
||||||
|
// For simplicity, we just add a dummy token to the end of the list.
|
||||||
|
let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
|
||||||
|
ts.push(TokenTree::Token(t, Spacing::Joint));
|
||||||
|
}
|
||||||
|
let ts: TokenStream = TokenStream::from_iter(ts);
|
||||||
|
|
||||||
|
let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
|
||||||
|
if !x.is_active() {
|
||||||
|
// We encountered an error, so we return the original item.
|
||||||
|
// This allows us to potentially parse other attributes.
|
||||||
|
return vec![item];
|
||||||
|
}
|
||||||
|
let span = ecx.with_def_site_ctxt(expand_span);
|
||||||
|
|
||||||
|
let n_active: u32 = x
|
||||||
|
.input_activity
|
||||||
|
.iter()
|
||||||
|
.filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
|
||||||
|
.count() as u32;
|
||||||
|
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
|
||||||
|
let new_decl_span = d_sig.span;
|
||||||
|
let d_body = gen_enzyme_body(
|
||||||
|
ecx,
|
||||||
|
&x,
|
||||||
|
n_active,
|
||||||
|
&sig,
|
||||||
|
&d_sig,
|
||||||
|
primal,
|
||||||
|
&new_args,
|
||||||
|
span,
|
||||||
|
sig_span,
|
||||||
|
new_decl_span,
|
||||||
|
idents,
|
||||||
|
errored,
|
||||||
|
);
|
||||||
|
let d_ident = first_ident(&meta_item_vec[0]);
|
||||||
|
|
||||||
|
// The first element of it is the name of the function to be generated
|
||||||
|
let asdf = Box::new(ast::Fn {
|
||||||
|
defaultness: ast::Defaultness::Final,
|
||||||
|
sig: d_sig,
|
||||||
|
generics: Generics::default(),
|
||||||
|
body: Some(d_body),
|
||||||
|
});
|
||||||
|
let mut rustc_ad_attr =
|
||||||
|
P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
|
||||||
|
|
||||||
|
let ts2: Vec<TokenTree> = vec![TokenTree::Token(
|
||||||
|
Token::new(TokenKind::Ident(sym::never, false.into()), span),
|
||||||
|
Spacing::Joint,
|
||||||
|
)];
|
||||||
|
let never_arg = ast::DelimArgs {
|
||||||
|
dspan: ast::tokenstream::DelimSpan::from_single(span),
|
||||||
|
delim: ast::token::Delimiter::Parenthesis,
|
||||||
|
tokens: ast::tokenstream::TokenStream::from_iter(ts2),
|
||||||
|
};
|
||||||
|
let inline_item = ast::AttrItem {
|
||||||
|
unsafety: ast::Safety::Default,
|
||||||
|
path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),
|
||||||
|
args: ast::AttrArgs::Delimited(never_arg),
|
||||||
|
tokens: None,
|
||||||
|
};
|
||||||
|
let inline_never_attr = P(ast::NormalAttr { item: inline_item, tokens: None });
|
||||||
|
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
|
||||||
|
let attr: ast::Attribute = ast::Attribute {
|
||||||
|
kind: ast::AttrKind::Normal(rustc_ad_attr.clone()),
|
||||||
|
id: new_id,
|
||||||
|
style: ast::AttrStyle::Outer,
|
||||||
|
span,
|
||||||
|
};
|
||||||
|
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
|
||||||
|
let inline_never: ast::Attribute = ast::Attribute {
|
||||||
|
kind: ast::AttrKind::Normal(inline_never_attr),
|
||||||
|
id: new_id,
|
||||||
|
style: ast::AttrStyle::Outer,
|
||||||
|
span,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Don't add it multiple times:
|
||||||
|
let orig_annotatable: Annotatable = match item {
|
||||||
|
Annotatable::Item(ref mut iitem) => {
|
||||||
|
if !iitem.attrs.iter().any(|a| a.id == attr.id) {
|
||||||
|
iitem.attrs.push(attr.clone());
|
||||||
|
}
|
||||||
|
if !iitem.attrs.iter().any(|a| a.id == inline_never.id) {
|
||||||
|
iitem.attrs.push(inline_never.clone());
|
||||||
|
}
|
||||||
|
Annotatable::Item(iitem.clone())
|
||||||
|
}
|
||||||
|
Annotatable::AssocItem(ref mut assoc_item, i @ Impl) => {
|
||||||
|
if !assoc_item.attrs.iter().any(|a| a.id == attr.id) {
|
||||||
|
assoc_item.attrs.push(attr.clone());
|
||||||
|
}
|
||||||
|
if !assoc_item.attrs.iter().any(|a| a.id == inline_never.id) {
|
||||||
|
assoc_item.attrs.push(inline_never.clone());
|
||||||
|
}
|
||||||
|
Annotatable::AssocItem(assoc_item.clone(), i)
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
unreachable!("annotatable kind checked previously")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// Now update for d_fn
|
||||||
|
rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
|
||||||
|
dspan: DelimSpan::dummy(),
|
||||||
|
delim: rustc_ast::token::Delimiter::Parenthesis,
|
||||||
|
tokens: ts,
|
||||||
|
});
|
||||||
|
let d_attr: ast::Attribute = ast::Attribute {
|
||||||
|
kind: ast::AttrKind::Normal(rustc_ad_attr.clone()),
|
||||||
|
id: new_id,
|
||||||
|
style: ast::AttrStyle::Outer,
|
||||||
|
span,
|
||||||
|
};
|
||||||
|
|
||||||
|
let d_annotatable = if is_impl {
|
||||||
|
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
|
||||||
|
let d_fn = P(ast::AssocItem {
|
||||||
|
attrs: thin_vec![d_attr.clone(), inline_never],
|
||||||
|
id: ast::DUMMY_NODE_ID,
|
||||||
|
span,
|
||||||
|
vis,
|
||||||
|
ident: d_ident,
|
||||||
|
kind: assoc_item,
|
||||||
|
tokens: None,
|
||||||
|
});
|
||||||
|
Annotatable::AssocItem(d_fn, Impl)
|
||||||
|
} else {
|
||||||
|
let mut d_fn = ecx.item(
|
||||||
|
span,
|
||||||
|
d_ident,
|
||||||
|
thin_vec![d_attr.clone(), inline_never],
|
||||||
|
ItemKind::Fn(asdf),
|
||||||
|
);
|
||||||
|
d_fn.vis = vis;
|
||||||
|
Annotatable::Item(d_fn)
|
||||||
|
};
|
||||||
|
|
||||||
|
return vec![orig_annotatable, d_annotatable];
|
||||||
|
}
|
||||||
|
|
||||||
|
// shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
|
||||||
|
// mutable references or ptrs, because Enzyme will write into them.
|
||||||
|
fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
|
||||||
|
let mut ty = ty.clone();
|
||||||
|
match ty.kind {
|
||||||
|
TyKind::Ptr(ref mut mut_ty) => {
|
||||||
|
mut_ty.mutbl = ast::Mutability::Mut;
|
||||||
|
}
|
||||||
|
TyKind::Ref(_, ref mut mut_ty) => {
|
||||||
|
mut_ty.mutbl = ast::Mutability::Mut;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
panic!("unsupported type: {:?}", ty);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ty
|
||||||
|
}
|
||||||
|
|
||||||
|
/// We only want this function to type-check, since we will replace the body
|
||||||
|
/// later on llvm level. Using `loop {}` does not cover all return types anymore,
|
||||||
|
/// so instead we build something that should pass. We also add a inline_asm
|
||||||
|
/// line, as one more barrier for rustc to prevent inlining of this function.
|
||||||
|
/// FIXME(ZuseZ4): We still have cases of incorrect inlining across modules, see
|
||||||
|
/// <https://github.com/EnzymeAD/rust/issues/173>, so this isn't sufficient.
|
||||||
|
/// It also triggers an Enzyme crash if we due to a bug ever try to differentiate
|
||||||
|
/// this function (which should never happen, since it is only a placeholder).
|
||||||
|
/// Finally, we also add back_box usages of all input arguments, to prevent rustc
|
||||||
|
/// from optimizing any arguments away.
|
||||||
|
fn gen_enzyme_body(
|
||||||
|
ecx: &ExtCtxt<'_>,
|
||||||
|
x: &AutoDiffAttrs,
|
||||||
|
n_active: u32,
|
||||||
|
sig: &ast::FnSig,
|
||||||
|
d_sig: &ast::FnSig,
|
||||||
|
primal: Ident,
|
||||||
|
new_names: &[String],
|
||||||
|
span: Span,
|
||||||
|
sig_span: Span,
|
||||||
|
new_decl_span: Span,
|
||||||
|
idents: Vec<Ident>,
|
||||||
|
errored: bool,
|
||||||
|
) -> P<ast::Block> {
|
||||||
|
let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
|
||||||
|
let noop = ast::InlineAsm {
|
||||||
|
asm_macro: ast::AsmMacro::Asm,
|
||||||
|
template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())],
|
||||||
|
template_strs: Box::new([]),
|
||||||
|
operands: vec![],
|
||||||
|
clobber_abis: vec![],
|
||||||
|
options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM,
|
||||||
|
line_spans: vec![],
|
||||||
|
};
|
||||||
|
let noop_expr = ecx.expr_asm(span, P(noop));
|
||||||
|
let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated);
|
||||||
|
let unsf_block = ast::Block {
|
||||||
|
stmts: thin_vec![ecx.stmt_semi(noop_expr)],
|
||||||
|
id: ast::DUMMY_NODE_ID,
|
||||||
|
tokens: None,
|
||||||
|
rules: unsf,
|
||||||
|
span,
|
||||||
|
could_be_bare_literal: false,
|
||||||
|
};
|
||||||
|
let unsf_expr = ecx.expr_block(P(unsf_block));
|
||||||
|
let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
|
||||||
|
let primal_call = gen_primal_call(ecx, span, primal, idents);
|
||||||
|
let black_box_primal_call =
|
||||||
|
ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![
|
||||||
|
primal_call.clone()
|
||||||
|
]);
|
||||||
|
let tup_args = new_names
|
||||||
|
.iter()
|
||||||
|
.map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let black_box_remaining_args =
|
||||||
|
ecx.expr_call(sig_span, blackbox_call_expr.clone(), thin_vec![
|
||||||
|
ecx.expr_tuple(sig_span, tup_args)
|
||||||
|
]);
|
||||||
|
|
||||||
|
let mut body = ecx.block(span, ThinVec::new());
|
||||||
|
body.stmts.push(ecx.stmt_semi(unsf_expr));
|
||||||
|
|
||||||
|
// This uses primal args which won't be available if we errored before
|
||||||
|
if !errored {
|
||||||
|
body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone()));
|
||||||
|
}
|
||||||
|
body.stmts.push(ecx.stmt_semi(black_box_remaining_args));
|
||||||
|
|
||||||
|
if !has_ret(&d_sig.decl.output) {
|
||||||
|
// there is no return type that we have to match, () works fine.
|
||||||
|
return body;
|
||||||
|
}
|
||||||
|
|
||||||
|
// having an active-only return means we'll drop the original return type.
|
||||||
|
// So that can be treated identical to not having one in the first place.
|
||||||
|
let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
|
||||||
|
|
||||||
|
if primal_ret && n_active == 0 && x.mode.is_rev() {
|
||||||
|
// We only have the primal ret.
|
||||||
|
body.stmts.push(ecx.stmt_expr(black_box_primal_call.clone()));
|
||||||
|
return body;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !primal_ret && n_active == 1 {
|
||||||
|
// Again no tuple return, so return default float val.
|
||||||
|
let ty = match d_sig.decl.output {
|
||||||
|
FnRetTy::Ty(ref ty) => ty.clone(),
|
||||||
|
FnRetTy::Default(span) => {
|
||||||
|
panic!("Did not expect Default ret ty: {:?}", span);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let arg = ty.kind.is_simple_path().unwrap();
|
||||||
|
let sl: Vec<Symbol> = vec![arg, kw::Default];
|
||||||
|
let tmp = ecx.def_site_path(&sl);
|
||||||
|
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
|
||||||
|
let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
||||||
|
body.stmts.push(ecx.stmt_expr(default_call_expr));
|
||||||
|
return body;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut exprs = ThinVec::<P<ast::Expr>>::new();
|
||||||
|
if primal_ret {
|
||||||
|
// We have both primal ret and active floats.
|
||||||
|
// primal ret is first, by construction.
|
||||||
|
exprs.push(primal_call.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now construct default placeholder for each active float.
|
||||||
|
// Is there something nicer than f32::default() and f64::default()?
|
||||||
|
let d_ret_ty = match d_sig.decl.output {
|
||||||
|
FnRetTy::Ty(ref ty) => ty.clone(),
|
||||||
|
FnRetTy::Default(span) => {
|
||||||
|
panic!("Did not expect Default ret ty: {:?}", span);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let mut d_ret_ty = match d_ret_ty.kind.clone() {
|
||||||
|
TyKind::Tup(ref tys) => tys.clone(),
|
||||||
|
TyKind::Path(_, rustc_ast::Path { segments, .. }) => {
|
||||||
|
if let [segment] = &segments[..]
|
||||||
|
&& segment.args.is_none()
|
||||||
|
{
|
||||||
|
let id = vec![segments[0].ident];
|
||||||
|
let kind = TyKind::Path(None, ecx.path(span, id));
|
||||||
|
let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
|
||||||
|
thin_vec![ty]
|
||||||
|
} else {
|
||||||
|
panic!("Expected tuple or simple path return type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// We messed up construction of d_sig
|
||||||
|
panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if x.mode.is_fwd() && x.ret_activity == DiffActivity::Dual {
|
||||||
|
assert!(d_ret_ty.len() == 2);
|
||||||
|
// both should be identical, by construction
|
||||||
|
let arg = d_ret_ty[0].kind.is_simple_path().unwrap();
|
||||||
|
let arg2 = d_ret_ty[1].kind.is_simple_path().unwrap();
|
||||||
|
assert!(arg == arg2);
|
||||||
|
let sl: Vec<Symbol> = vec![arg, kw::Default];
|
||||||
|
let tmp = ecx.def_site_path(&sl);
|
||||||
|
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
|
||||||
|
let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
||||||
|
exprs.push(default_call_expr);
|
||||||
|
} else if x.mode.is_rev() {
|
||||||
|
if primal_ret {
|
||||||
|
// We have extra handling above for the primal ret
|
||||||
|
d_ret_ty = d_ret_ty[1..].to_vec().into();
|
||||||
|
}
|
||||||
|
|
||||||
|
for arg in d_ret_ty.iter() {
|
||||||
|
let arg = arg.kind.is_simple_path().unwrap();
|
||||||
|
let sl: Vec<Symbol> = vec![arg, kw::Default];
|
||||||
|
let tmp = ecx.def_site_path(&sl);
|
||||||
|
let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
|
||||||
|
let default_call_expr =
|
||||||
|
ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
|
||||||
|
exprs.push(default_call_expr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let ret: P<ast::Expr>;
|
||||||
|
match &exprs[..] {
|
||||||
|
[] => {
|
||||||
|
assert!(!has_ret(&d_sig.decl.output));
|
||||||
|
// We don't have to match the return type.
|
||||||
|
return body;
|
||||||
|
}
|
||||||
|
[arg] => {
|
||||||
|
ret = ecx
|
||||||
|
.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![arg.clone()]);
|
||||||
|
}
|
||||||
|
args => {
|
||||||
|
let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, args.into());
|
||||||
|
ret =
|
||||||
|
ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ret_tuple]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert!(has_ret(&d_sig.decl.output));
|
||||||
|
body.stmts.push(ecx.stmt_expr(ret));
|
||||||
|
|
||||||
|
body
|
||||||
|
}
|
||||||
|
|
||||||
|
fn gen_primal_call(
|
||||||
|
ecx: &ExtCtxt<'_>,
|
||||||
|
span: Span,
|
||||||
|
primal: Ident,
|
||||||
|
idents: Vec<Ident>,
|
||||||
|
) -> P<ast::Expr> {
|
||||||
|
let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
|
||||||
|
if has_self {
|
||||||
|
let args: ThinVec<_> =
|
||||||
|
idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
|
||||||
|
let self_expr = ecx.expr_self(span);
|
||||||
|
ecx.expr_method_call(span, self_expr, primal, args.clone())
|
||||||
|
} else {
|
||||||
|
let args: ThinVec<_> =
|
||||||
|
idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
|
||||||
|
let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
|
||||||
|
ecx.expr_call(span, primal_call_expr, args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
|
||||||
|
// be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
|
||||||
|
// Active arguments must be scalars. Their shadow argument is added to the return type (and will be
|
||||||
|
// zero-initialized by Enzyme).
|
||||||
|
// Each argument of the primal function (and the return type if existing) must be annotated with an
|
||||||
|
// activity.
|
||||||
|
//
|
||||||
|
// Error handling: If the user provides an invalid configuration (incorrect numbers, types, or
|
||||||
|
// both), we emit an error and return the original signature. This allows us to continue parsing.
|
||||||
|
fn gen_enzyme_decl(
|
||||||
|
ecx: &ExtCtxt<'_>,
|
||||||
|
sig: &ast::FnSig,
|
||||||
|
x: &AutoDiffAttrs,
|
||||||
|
span: Span,
|
||||||
|
) -> (ast::FnSig, Vec<String>, Vec<Ident>, bool) {
|
||||||
|
let dcx = ecx.sess.dcx();
|
||||||
|
let has_ret = has_ret(&sig.decl.output);
|
||||||
|
let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
|
||||||
|
let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 };
|
||||||
|
if sig_args != num_activities {
|
||||||
|
dcx.emit_err(errors::AutoDiffInvalidNumberActivities {
|
||||||
|
span,
|
||||||
|
expected: sig_args,
|
||||||
|
found: num_activities,
|
||||||
|
});
|
||||||
|
// This is not the right signature, but we can continue parsing.
|
||||||
|
return (sig.clone(), vec![], vec![], true);
|
||||||
|
}
|
||||||
|
assert!(sig.decl.inputs.len() == x.input_activity.len());
|
||||||
|
assert!(has_ret == x.has_ret_activity());
|
||||||
|
let mut d_decl = sig.decl.clone();
|
||||||
|
let mut d_inputs = Vec::new();
|
||||||
|
let mut new_inputs = Vec::new();
|
||||||
|
let mut idents = Vec::new();
|
||||||
|
let mut act_ret = ThinVec::new();
|
||||||
|
|
||||||
|
// We have two loops, a first one just to check the activities and types and possibly report
|
||||||
|
// multiple errors in one compilation session.
|
||||||
|
let mut errors = false;
|
||||||
|
for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
|
||||||
|
if !valid_input_activity(x.mode, *activity) {
|
||||||
|
dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct {
|
||||||
|
span,
|
||||||
|
mode: x.mode.to_string(),
|
||||||
|
act: activity.to_string(),
|
||||||
|
});
|
||||||
|
errors = true;
|
||||||
|
}
|
||||||
|
if !valid_ty_for_activity(&arg.ty, *activity) {
|
||||||
|
dcx.emit_err(errors::AutoDiffInvalidTypeForActivity {
|
||||||
|
span: arg.ty.span,
|
||||||
|
act: activity.to_string(),
|
||||||
|
});
|
||||||
|
errors = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if errors {
|
||||||
|
// This is not the right signature, but we can continue parsing.
|
||||||
|
return (sig.clone(), new_inputs, idents, true);
|
||||||
|
}
|
||||||
|
let unsafe_activities = x
|
||||||
|
.input_activity
|
||||||
|
.iter()
|
||||||
|
.any(|&act| matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly));
|
||||||
|
for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
|
||||||
|
d_inputs.push(arg.clone());
|
||||||
|
match activity {
|
||||||
|
DiffActivity::Active => {
|
||||||
|
act_ret.push(arg.ty.clone());
|
||||||
|
}
|
||||||
|
DiffActivity::ActiveOnly => {
|
||||||
|
// We will add the active scalar to the return type.
|
||||||
|
// This is handled later.
|
||||||
|
}
|
||||||
|
DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
|
||||||
|
let mut shadow_arg = arg.clone();
|
||||||
|
// We += into the shadow in reverse mode.
|
||||||
|
shadow_arg.ty = P(assure_mut_ref(&arg.ty));
|
||||||
|
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
||||||
|
ident.name
|
||||||
|
} else {
|
||||||
|
debug!("{:#?}", &shadow_arg.pat);
|
||||||
|
panic!("not an ident?");
|
||||||
|
};
|
||||||
|
let name: String = format!("d{}", old_name);
|
||||||
|
new_inputs.push(name.clone());
|
||||||
|
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
|
||||||
|
shadow_arg.pat = P(ast::Pat {
|
||||||
|
id: ast::DUMMY_NODE_ID,
|
||||||
|
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
||||||
|
span: shadow_arg.pat.span,
|
||||||
|
tokens: shadow_arg.pat.tokens.clone(),
|
||||||
|
});
|
||||||
|
d_inputs.push(shadow_arg);
|
||||||
|
}
|
||||||
|
DiffActivity::Dual | DiffActivity::DualOnly => {
|
||||||
|
let mut shadow_arg = arg.clone();
|
||||||
|
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
||||||
|
ident.name
|
||||||
|
} else {
|
||||||
|
debug!("{:#?}", &shadow_arg.pat);
|
||||||
|
panic!("not an ident?");
|
||||||
|
};
|
||||||
|
let name: String = format!("b{}", old_name);
|
||||||
|
new_inputs.push(name.clone());
|
||||||
|
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
|
||||||
|
shadow_arg.pat = P(ast::Pat {
|
||||||
|
id: ast::DUMMY_NODE_ID,
|
||||||
|
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
||||||
|
span: shadow_arg.pat.span,
|
||||||
|
tokens: shadow_arg.pat.tokens.clone(),
|
||||||
|
});
|
||||||
|
d_inputs.push(shadow_arg);
|
||||||
|
}
|
||||||
|
DiffActivity::Const => {
|
||||||
|
// Nothing to do here.
|
||||||
|
}
|
||||||
|
DiffActivity::None | DiffActivity::FakeActivitySize => {
|
||||||
|
panic!("Should not happen");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let PatKind::Ident(_, ident, _) = arg.pat.kind {
|
||||||
|
idents.push(ident.clone());
|
||||||
|
} else {
|
||||||
|
panic!("not an ident?");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;
|
||||||
|
if active_only_ret {
|
||||||
|
assert!(x.mode.is_rev());
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we return a scalar in the primal and the scalar is active,
|
||||||
|
// then add it as last arg to the inputs.
|
||||||
|
if x.mode.is_rev() {
|
||||||
|
match x.ret_activity {
|
||||||
|
DiffActivity::Active | DiffActivity::ActiveOnly => {
|
||||||
|
let ty = match d_decl.output {
|
||||||
|
FnRetTy::Ty(ref ty) => ty.clone(),
|
||||||
|
FnRetTy::Default(span) => {
|
||||||
|
panic!("Did not expect Default ret ty: {:?}", span);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let name = "dret".to_string();
|
||||||
|
let ident = Ident::from_str_and_span(&name, ty.span);
|
||||||
|
let shadow_arg = ast::Param {
|
||||||
|
attrs: ThinVec::new(),
|
||||||
|
ty: ty.clone(),
|
||||||
|
pat: P(ast::Pat {
|
||||||
|
id: ast::DUMMY_NODE_ID,
|
||||||
|
kind: PatKind::Ident(BindingMode::NONE, ident, None),
|
||||||
|
span: ty.span,
|
||||||
|
tokens: None,
|
||||||
|
}),
|
||||||
|
id: ast::DUMMY_NODE_ID,
|
||||||
|
span: ty.span,
|
||||||
|
is_placeholder: false,
|
||||||
|
};
|
||||||
|
d_inputs.push(shadow_arg);
|
||||||
|
new_inputs.push(name);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
d_decl.inputs = d_inputs.into();
|
||||||
|
|
||||||
|
if x.mode.is_fwd() {
|
||||||
|
if let DiffActivity::Dual = x.ret_activity {
|
||||||
|
let ty = match d_decl.output {
|
||||||
|
FnRetTy::Ty(ref ty) => ty.clone(),
|
||||||
|
FnRetTy::Default(span) => {
|
||||||
|
panic!("Did not expect Default ret ty: {:?}", span);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// Dual can only be used for f32/f64 ret.
|
||||||
|
// In that case we return now a tuple with two floats.
|
||||||
|
let kind = TyKind::Tup(thin_vec![ty.clone(), ty.clone()]);
|
||||||
|
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
|
||||||
|
d_decl.output = FnRetTy::Ty(ty);
|
||||||
|
}
|
||||||
|
if let DiffActivity::DualOnly = x.ret_activity {
|
||||||
|
// No need to change the return type,
|
||||||
|
// we will just return the shadow in place
|
||||||
|
// of the primal return.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we use ActiveOnly, drop the original return value.
|
||||||
|
d_decl.output =
|
||||||
|
if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() };
|
||||||
|
|
||||||
|
trace!("act_ret: {:?}", act_ret);
|
||||||
|
|
||||||
|
// If we have an active input scalar, add it's gradient to the
|
||||||
|
// return type. This might require changing the return type to a
|
||||||
|
// tuple.
|
||||||
|
if act_ret.len() > 0 {
|
||||||
|
let ret_ty = match d_decl.output {
|
||||||
|
FnRetTy::Ty(ref ty) => {
|
||||||
|
if !active_only_ret {
|
||||||
|
act_ret.insert(0, ty.clone());
|
||||||
|
}
|
||||||
|
let kind = TyKind::Tup(act_ret);
|
||||||
|
P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })
|
||||||
|
}
|
||||||
|
FnRetTy::Default(span) => {
|
||||||
|
if act_ret.len() == 1 {
|
||||||
|
act_ret[0].clone()
|
||||||
|
} else {
|
||||||
|
let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect());
|
||||||
|
P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
d_decl.output = FnRetTy::Ty(ret_ty);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut d_header = sig.header.clone();
|
||||||
|
if unsafe_activities {
|
||||||
|
d_header.safety = rustc_ast::Safety::Unsafe(span);
|
||||||
|
}
|
||||||
|
let d_sig = FnSig { header: d_header, decl: d_decl, span };
|
||||||
|
trace!("Generated signature: {:?}", d_sig);
|
||||||
|
(d_sig, new_inputs, idents, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(llvm_enzyme))]
|
||||||
|
mod ad_fallback {
|
||||||
|
use rustc_ast::ast;
|
||||||
|
use rustc_expand::base::{Annotatable, ExtCtxt};
|
||||||
|
use rustc_span::Span;
|
||||||
|
|
||||||
|
use crate::errors;
|
||||||
|
pub(crate) fn expand(
|
||||||
|
ecx: &mut ExtCtxt<'_>,
|
||||||
|
_expand_span: Span,
|
||||||
|
meta_item: &ast::MetaItem,
|
||||||
|
item: Annotatable,
|
||||||
|
) -> Vec<Annotatable> {
|
||||||
|
ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
|
||||||
|
return vec![item];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(llvm_enzyme))]
|
||||||
|
pub(crate) use ad_fallback::expand;
|
||||||
|
#[cfg(llvm_enzyme)]
|
||||||
|
pub(crate) use llvm_enzyme::expand;
|
|
@ -145,6 +145,78 @@ pub(crate) struct AllocMustStatics {
|
||||||
pub(crate) span: Span,
|
pub(crate) span: Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(llvm_enzyme)]
|
||||||
|
pub(crate) use autodiff::*;
|
||||||
|
|
||||||
|
#[cfg(llvm_enzyme)]
|
||||||
|
mod autodiff {
|
||||||
|
use super::*;
|
||||||
|
#[derive(Diagnostic)]
|
||||||
|
#[diag(builtin_macros_autodiff_missing_config)]
|
||||||
|
pub(crate) struct AutoDiffMissingConfig {
|
||||||
|
#[primary_span]
|
||||||
|
pub(crate) span: Span,
|
||||||
|
}
|
||||||
|
#[derive(Diagnostic)]
|
||||||
|
#[diag(builtin_macros_autodiff_unknown_activity)]
|
||||||
|
pub(crate) struct AutoDiffUnknownActivity {
|
||||||
|
#[primary_span]
|
||||||
|
pub(crate) span: Span,
|
||||||
|
pub(crate) act: String,
|
||||||
|
}
|
||||||
|
#[derive(Diagnostic)]
|
||||||
|
#[diag(builtin_macros_autodiff_ty_activity)]
|
||||||
|
pub(crate) struct AutoDiffInvalidTypeForActivity {
|
||||||
|
#[primary_span]
|
||||||
|
pub(crate) span: Span,
|
||||||
|
pub(crate) act: String,
|
||||||
|
}
|
||||||
|
#[derive(Diagnostic)]
|
||||||
|
#[diag(builtin_macros_autodiff_number_activities)]
|
||||||
|
pub(crate) struct AutoDiffInvalidNumberActivities {
|
||||||
|
#[primary_span]
|
||||||
|
pub(crate) span: Span,
|
||||||
|
pub(crate) expected: usize,
|
||||||
|
pub(crate) found: usize,
|
||||||
|
}
|
||||||
|
#[derive(Diagnostic)]
|
||||||
|
#[diag(builtin_macros_autodiff_mode_activity)]
|
||||||
|
pub(crate) struct AutoDiffInvalidApplicationModeAct {
|
||||||
|
#[primary_span]
|
||||||
|
pub(crate) span: Span,
|
||||||
|
pub(crate) mode: String,
|
||||||
|
pub(crate) act: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Diagnostic)]
|
||||||
|
#[diag(builtin_macros_autodiff_mode)]
|
||||||
|
pub(crate) struct AutoDiffInvalidMode {
|
||||||
|
#[primary_span]
|
||||||
|
pub(crate) span: Span,
|
||||||
|
pub(crate) mode: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Diagnostic)]
|
||||||
|
#[diag(builtin_macros_autodiff)]
|
||||||
|
pub(crate) struct AutoDiffInvalidApplication {
|
||||||
|
#[primary_span]
|
||||||
|
pub(crate) span: Span,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(llvm_enzyme))]
|
||||||
|
pub(crate) use ad_fallback::*;
|
||||||
|
#[cfg(not(llvm_enzyme))]
|
||||||
|
mod ad_fallback {
|
||||||
|
use super::*;
|
||||||
|
#[derive(Diagnostic)]
|
||||||
|
#[diag(builtin_macros_autodiff_not_build)]
|
||||||
|
pub(crate) struct AutoDiffSupportNotBuild {
|
||||||
|
#[primary_span]
|
||||||
|
pub(crate) span: Span,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Diagnostic)]
|
#[derive(Diagnostic)]
|
||||||
#[diag(builtin_macros_concat_bytes_invalid)]
|
#[diag(builtin_macros_concat_bytes_invalid)]
|
||||||
pub(crate) struct ConcatBytesInvalid {
|
pub(crate) struct ConcatBytesInvalid {
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
#![allow(internal_features)]
|
#![allow(internal_features)]
|
||||||
#![allow(rustc::diagnostic_outside_of_impl)]
|
#![allow(rustc::diagnostic_outside_of_impl)]
|
||||||
#![allow(rustc::untranslatable_diagnostic)]
|
#![allow(rustc::untranslatable_diagnostic)]
|
||||||
|
#![cfg_attr(not(bootstrap), feature(autodiff))]
|
||||||
#![doc(html_root_url = "https://doc.rust-lang.org/nightly/nightly-rustc/")]
|
#![doc(html_root_url = "https://doc.rust-lang.org/nightly/nightly-rustc/")]
|
||||||
#![doc(rust_logo)]
|
#![doc(rust_logo)]
|
||||||
#![feature(assert_matches)]
|
#![feature(assert_matches)]
|
||||||
|
@ -29,6 +30,7 @@ use crate::deriving::*;
|
||||||
|
|
||||||
mod alloc_error_handler;
|
mod alloc_error_handler;
|
||||||
mod assert;
|
mod assert;
|
||||||
|
mod autodiff;
|
||||||
mod cfg;
|
mod cfg;
|
||||||
mod cfg_accessible;
|
mod cfg_accessible;
|
||||||
mod cfg_eval;
|
mod cfg_eval;
|
||||||
|
@ -106,6 +108,7 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) {
|
||||||
|
|
||||||
register_attr! {
|
register_attr! {
|
||||||
alloc_error_handler: alloc_error_handler::expand,
|
alloc_error_handler: alloc_error_handler::expand,
|
||||||
|
autodiff: autodiff::expand,
|
||||||
bench: test::expand_bench,
|
bench: test::expand_bench,
|
||||||
cfg_accessible: cfg_accessible::Expander,
|
cfg_accessible: cfg_accessible::Expander,
|
||||||
cfg_eval: cfg_eval::expand,
|
cfg_eval: cfg_eval::expand,
|
||||||
|
|
|
@ -220,6 +220,10 @@ impl<'a> ExtCtxt<'a> {
|
||||||
self.stmt_local(local, span)
|
self.stmt_local(local, span)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn stmt_semi(&self, expr: P<ast::Expr>) -> ast::Stmt {
|
||||||
|
ast::Stmt { id: ast::DUMMY_NODE_ID, span: expr.span, kind: ast::StmtKind::Semi(expr) }
|
||||||
|
}
|
||||||
|
|
||||||
pub fn stmt_local(&self, local: P<ast::Local>, span: Span) -> ast::Stmt {
|
pub fn stmt_local(&self, local: P<ast::Local>, span: Span) -> ast::Stmt {
|
||||||
ast::Stmt { id: ast::DUMMY_NODE_ID, kind: ast::StmtKind::Let(local), span }
|
ast::Stmt { id: ast::DUMMY_NODE_ID, kind: ast::StmtKind::Let(local), span }
|
||||||
}
|
}
|
||||||
|
@ -287,6 +291,25 @@ impl<'a> ExtCtxt<'a> {
|
||||||
self.expr(sp, ast::ExprKind::Paren(e))
|
self.expr(sp, ast::ExprKind::Paren(e))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn expr_method_call(
|
||||||
|
&self,
|
||||||
|
span: Span,
|
||||||
|
expr: P<ast::Expr>,
|
||||||
|
ident: Ident,
|
||||||
|
args: ThinVec<P<ast::Expr>>,
|
||||||
|
) -> P<ast::Expr> {
|
||||||
|
let seg = ast::PathSegment::from_ident(ident);
|
||||||
|
self.expr(
|
||||||
|
span,
|
||||||
|
ast::ExprKind::MethodCall(Box::new(ast::MethodCall {
|
||||||
|
seg,
|
||||||
|
receiver: expr,
|
||||||
|
args,
|
||||||
|
span,
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn expr_call(
|
pub fn expr_call(
|
||||||
&self,
|
&self,
|
||||||
span: Span,
|
span: Span,
|
||||||
|
@ -295,6 +318,12 @@ impl<'a> ExtCtxt<'a> {
|
||||||
) -> P<ast::Expr> {
|
) -> P<ast::Expr> {
|
||||||
self.expr(span, ast::ExprKind::Call(expr, args))
|
self.expr(span, ast::ExprKind::Call(expr, args))
|
||||||
}
|
}
|
||||||
|
pub fn expr_loop(&self, sp: Span, block: P<ast::Block>) -> P<ast::Expr> {
|
||||||
|
self.expr(sp, ast::ExprKind::Loop(block, None, sp))
|
||||||
|
}
|
||||||
|
pub fn expr_asm(&self, sp: Span, expr: P<ast::InlineAsm>) -> P<ast::Expr> {
|
||||||
|
self.expr(sp, ast::ExprKind::InlineAsm(expr))
|
||||||
|
}
|
||||||
pub fn expr_call_ident(
|
pub fn expr_call_ident(
|
||||||
&self,
|
&self,
|
||||||
span: Span,
|
span: Span,
|
||||||
|
|
|
@ -752,6 +752,11 @@ pub const BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[
|
||||||
template!(NameValueStr: "transparent|semitransparent|opaque"), ErrorFollowing,
|
template!(NameValueStr: "transparent|semitransparent|opaque"), ErrorFollowing,
|
||||||
EncodeCrossCrate::Yes, "used internally for testing macro hygiene",
|
EncodeCrossCrate::Yes, "used internally for testing macro hygiene",
|
||||||
),
|
),
|
||||||
|
rustc_attr!(
|
||||||
|
rustc_autodiff, Normal,
|
||||||
|
template!(Word, List: r#""...""#), DuplicatesOk,
|
||||||
|
EncodeCrossCrate::No, INTERNAL_UNSTABLE
|
||||||
|
),
|
||||||
|
|
||||||
// ==========================================================================
|
// ==========================================================================
|
||||||
// Internal attributes, Diagnostics related:
|
// Internal attributes, Diagnostics related:
|
||||||
|
|
|
@ -49,6 +49,10 @@ passes_attr_crate_level =
|
||||||
passes_attr_only_in_functions =
|
passes_attr_only_in_functions =
|
||||||
`{$attr}` attribute can only be used on functions
|
`{$attr}` attribute can only be used on functions
|
||||||
|
|
||||||
|
passes_autodiff_attr =
|
||||||
|
`#[autodiff]` should be applied to a function
|
||||||
|
.label = not a function
|
||||||
|
|
||||||
passes_both_ffi_const_and_pure =
|
passes_both_ffi_const_and_pure =
|
||||||
`#[ffi_const]` function cannot be `#[ffi_pure]`
|
`#[ffi_const]` function cannot be `#[ffi_pure]`
|
||||||
|
|
||||||
|
|
|
@ -243,6 +243,9 @@ impl<'tcx> CheckAttrVisitor<'tcx> {
|
||||||
self.check_generic_attr(hir_id, attr, target, Target::Fn);
|
self.check_generic_attr(hir_id, attr, target, Target::Fn);
|
||||||
self.check_proc_macro(hir_id, target, ProcMacroKind::Derive)
|
self.check_proc_macro(hir_id, target, ProcMacroKind::Derive)
|
||||||
}
|
}
|
||||||
|
[sym::autodiff, ..] => {
|
||||||
|
self.check_autodiff(hir_id, attr, span, target)
|
||||||
|
}
|
||||||
[sym::coroutine, ..] => {
|
[sym::coroutine, ..] => {
|
||||||
self.check_coroutine(attr, target);
|
self.check_coroutine(attr, target);
|
||||||
}
|
}
|
||||||
|
@ -2345,6 +2348,18 @@ impl<'tcx> CheckAttrVisitor<'tcx> {
|
||||||
self.dcx().emit_err(errors::RustcPubTransparent { span, attr_span });
|
self.dcx().emit_err(errors::RustcPubTransparent { span, attr_span });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Checks if `#[autodiff]` is applied to an item other than a function item.
|
||||||
|
fn check_autodiff(&self, _hir_id: HirId, _attr: &Attribute, span: Span, target: Target) {
|
||||||
|
debug!("check_autodiff");
|
||||||
|
match target {
|
||||||
|
Target::Fn => {}
|
||||||
|
_ => {
|
||||||
|
self.dcx().emit_err(errors::AutoDiffAttr { attr_span: span });
|
||||||
|
self.abort.set(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'tcx> Visitor<'tcx> for CheckAttrVisitor<'tcx> {
|
impl<'tcx> Visitor<'tcx> for CheckAttrVisitor<'tcx> {
|
||||||
|
|
|
@ -20,6 +20,14 @@ use crate::lang_items::Duplicate;
|
||||||
#[diag(passes_incorrect_do_not_recommend_location)]
|
#[diag(passes_incorrect_do_not_recommend_location)]
|
||||||
pub(crate) struct IncorrectDoNotRecommendLocation;
|
pub(crate) struct IncorrectDoNotRecommendLocation;
|
||||||
|
|
||||||
|
#[derive(Diagnostic)]
|
||||||
|
#[diag(passes_autodiff_attr)]
|
||||||
|
pub(crate) struct AutoDiffAttr {
|
||||||
|
#[primary_span]
|
||||||
|
#[label]
|
||||||
|
pub attr_span: Span,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(LintDiagnostic)]
|
#[derive(LintDiagnostic)]
|
||||||
#[diag(passes_outer_crate_level_attr)]
|
#[diag(passes_outer_crate_level_attr)]
|
||||||
pub(crate) struct OuterCrateLevelAttr;
|
pub(crate) struct OuterCrateLevelAttr;
|
||||||
|
|
|
@ -481,6 +481,8 @@ symbols! {
|
||||||
audit_that,
|
audit_that,
|
||||||
augmented_assignments,
|
augmented_assignments,
|
||||||
auto_traits,
|
auto_traits,
|
||||||
|
autodiff,
|
||||||
|
autodiff_fallback,
|
||||||
automatically_derived,
|
automatically_derived,
|
||||||
avx,
|
avx,
|
||||||
avx512_target_feature,
|
avx512_target_feature,
|
||||||
|
@ -544,6 +546,7 @@ symbols! {
|
||||||
cfg_accessible,
|
cfg_accessible,
|
||||||
cfg_attr,
|
cfg_attr,
|
||||||
cfg_attr_multi,
|
cfg_attr_multi,
|
||||||
|
cfg_autodiff_fallback,
|
||||||
cfg_boolean_literals,
|
cfg_boolean_literals,
|
||||||
cfg_doctest,
|
cfg_doctest,
|
||||||
cfg_eval,
|
cfg_eval,
|
||||||
|
@ -998,6 +1001,7 @@ symbols! {
|
||||||
hashset_iter_ty,
|
hashset_iter_ty,
|
||||||
hexagon_target_feature,
|
hexagon_target_feature,
|
||||||
hidden,
|
hidden,
|
||||||
|
hint,
|
||||||
homogeneous_aggregate,
|
homogeneous_aggregate,
|
||||||
host,
|
host,
|
||||||
html_favicon_url,
|
html_favicon_url,
|
||||||
|
@ -1650,6 +1654,7 @@ symbols! {
|
||||||
rustc_allow_incoherent_impl,
|
rustc_allow_incoherent_impl,
|
||||||
rustc_allowed_through_unstable_modules,
|
rustc_allowed_through_unstable_modules,
|
||||||
rustc_attrs,
|
rustc_attrs,
|
||||||
|
rustc_autodiff,
|
||||||
rustc_box,
|
rustc_box,
|
||||||
rustc_builtin_macro,
|
rustc_builtin_macro,
|
||||||
rustc_capture_analysis,
|
rustc_capture_analysis,
|
||||||
|
|
|
@ -278,6 +278,15 @@ pub mod assert_matches {
|
||||||
pub use crate::macros::{assert_matches, debug_assert_matches};
|
pub use crate::macros::{assert_matches, debug_assert_matches};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We don't export this through #[macro_export] for now, to avoid breakage.
|
||||||
|
#[cfg(not(bootstrap))]
|
||||||
|
#[unstable(feature = "autodiff", issue = "124509")]
|
||||||
|
/// Unstable module containing the unstable `autodiff` macro.
|
||||||
|
pub mod autodiff {
|
||||||
|
#[unstable(feature = "autodiff", issue = "124509")]
|
||||||
|
pub use crate::macros::builtin::autodiff;
|
||||||
|
}
|
||||||
|
|
||||||
#[unstable(feature = "cfg_match", issue = "115585")]
|
#[unstable(feature = "cfg_match", issue = "115585")]
|
||||||
pub use crate::macros::cfg_match;
|
pub use crate::macros::cfg_match;
|
||||||
|
|
||||||
|
|
|
@ -1539,6 +1539,24 @@ pub(crate) mod builtin {
|
||||||
($file:expr $(,)?) => {{ /* compiler built-in */ }};
|
($file:expr $(,)?) => {{ /* compiler built-in */ }};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Automatic Differentiation macro which allows generating a new function to compute
|
||||||
|
/// the derivative of a given function. It may only be applied to a function.
|
||||||
|
/// The expected usage syntax is
|
||||||
|
/// `#[autodiff(NAME, MODE, INPUT_ACTIVITIES, OUTPUT_ACTIVITY)]`
|
||||||
|
/// where:
|
||||||
|
/// NAME is a string that represents a valid function name.
|
||||||
|
/// MODE is any of Forward, Reverse, ForwardFirst, ReverseFirst.
|
||||||
|
/// INPUT_ACTIVITIES consists of one valid activity for each input parameter.
|
||||||
|
/// OUTPUT_ACTIVITY must not be set if we implicitely return nothing (or explicitely return
|
||||||
|
/// `-> ()`. Otherwise it must be set to one of the allowed activities.
|
||||||
|
#[unstable(feature = "autodiff", issue = "124509")]
|
||||||
|
#[allow_internal_unstable(rustc_attrs)]
|
||||||
|
#[rustc_builtin_macro]
|
||||||
|
#[cfg(not(bootstrap))]
|
||||||
|
pub macro autodiff($item:item) {
|
||||||
|
/* compiler built-in */
|
||||||
|
}
|
||||||
|
|
||||||
/// Asserts that a boolean expression is `true` at runtime.
|
/// Asserts that a boolean expression is `true` at runtime.
|
||||||
///
|
///
|
||||||
/// This will invoke the [`panic!`] macro if the provided expression cannot be
|
/// This will invoke the [`panic!`] macro if the provided expression cannot be
|
||||||
|
|
|
@ -267,6 +267,7 @@
|
||||||
#![allow(unused_features)]
|
#![allow(unused_features)]
|
||||||
//
|
//
|
||||||
// Features:
|
// Features:
|
||||||
|
#![cfg_attr(not(bootstrap), feature(autodiff))]
|
||||||
#![cfg_attr(test, feature(internal_output_capture, print_internals, update_panic_count, rt))]
|
#![cfg_attr(test, feature(internal_output_capture, print_internals, update_panic_count, rt))]
|
||||||
#![cfg_attr(
|
#![cfg_attr(
|
||||||
all(target_vendor = "fortanix", target_env = "sgx"),
|
all(target_vendor = "fortanix", target_env = "sgx"),
|
||||||
|
@ -627,7 +628,13 @@ pub mod simd {
|
||||||
#[doc(inline)]
|
#[doc(inline)]
|
||||||
pub use crate::std_float::StdFloat;
|
pub use crate::std_float::StdFloat;
|
||||||
}
|
}
|
||||||
|
#[cfg(not(bootstrap))]
|
||||||
|
#[unstable(feature = "autodiff", issue = "124509")]
|
||||||
|
/// This module provides support for automatic differentiation.
|
||||||
|
pub mod autodiff {
|
||||||
|
/// This macro handles automatic differentiation.
|
||||||
|
pub use core::autodiff::autodiff;
|
||||||
|
}
|
||||||
#[stable(feature = "futures_api", since = "1.36.0")]
|
#[stable(feature = "futures_api", since = "1.36.0")]
|
||||||
pub mod task {
|
pub mod task {
|
||||||
//! Types and Traits for working with asynchronous tasks.
|
//! Types and Traits for working with asynchronous tasks.
|
||||||
|
|
Loading…
Add table
Reference in a new issue