diff --git a/src/analyze.rs b/src/analyze.rs index 45f47558..e7420df9 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -790,6 +790,88 @@ impl<'tcx> Analyzer<'tcx> { ensure_annot } + /// Collects every `#[thrust::refinement_path(..)]` path statement in the + /// function body, returning each `(type position, formula_fn DefId)`. + fn extract_refinement_paths( + &self, + local_def_id: LocalDefId, + ) -> Vec<(rty::TypePosition, DefId)> { + let mut out = Vec::new(); + let Some(body) = self.tcx.hir_maybe_body_owned_by(local_def_id) else { + return out; + }; + let rustc_hir::ExprKind::Block(block, _) = body.value.kind else { + return out; + }; + let attr_path = analyze::annot::refinement_path_path(); + let typeck = self.tcx.typeck(local_def_id); + for stmt in block.stmts { + let Some(attr) = self + .tcx + .hir_attrs(stmt.hir_id) + .iter() + .find(|attr| attr.path_matches(&attr_path)) + else { + continue; + }; + let ts = analyze::annot::extract_annot_tokens(attr.clone()); + let position = analyze::annot::parse_type_position(&ts); + + let rustc_hir::StmtKind::Semi(expr) = stmt.kind else { + self.tcx.dcx().span_err( + stmt.span, + "annotated path is expected to be a semi statement", + ); + continue; + }; + let rustc_hir::ExprKind::Path(qpath) = expr.kind else { + self.tcx.dcx().span_err( + expr.span, + "annotated path is expected to be a path expression", + ); + continue; + }; + let rustc_hir::def::Res::Def(_, def_id) = typeck.qpath_res(&qpath, expr.hir_id) else { + self.tcx.dcx().span_err( + expr.span, + "annotated path is expected to refer to a definition", + ); + continue; + }; + out.push((position, def_id)); + } + out + } + + /// Resolves every `#[thrust::refinement_path(..)]` annotation into a + /// positioned refinement, by translating the referenced formula function. + pub fn extract_refinement_annots( + &self, + local_def_id: LocalDefId, + generic_args: mir_ty::GenericArgsRef<'tcx>, + ) -> Vec<(rty::TypePosition, rty::Refinement)> { + let mut out = Vec::new(); + for (position, def_id) in self.extract_refinement_paths(local_def_id) { + let Some(formula_def_id) = def_id.as_local() else { + panic!( + "refinement_path annotation is expected to refer to a local def, but found: {:?}", + def_id + ); + }; + let Some(formula_fn) = self.formula_fn_with_args(formula_def_id, generic_args) else { + panic!( + "refinement_path annotation {:?} is not a formula function", + formula_def_id + ); + }; + let AnnotFormula::Formula(formula) = formula_fn.to_ensure_annot() else { + panic!("refinement_path annotation must lower to a plain formula"); + }; + out.push((position, formula.into())); + } + out + } + /// Whether the given `def_id` corresponds to a method of one of the `Fn` traits. fn is_fn_trait_method(&self, def_id: DefId) -> bool { self.tcx diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index ec8465c9..5b409f35 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -61,6 +61,10 @@ pub fn ensures_path_path() -> [Symbol; 2] { [Symbol::intern("thrust"), Symbol::intern("ensures_path")] } +pub fn refinement_path_path() -> [Symbol; 2] { + [Symbol::intern("thrust"), Symbol::intern("refinement_path")] +} + pub fn model_ty_path() -> [Symbol; 3] { [ Symbol::intern("thrust"), @@ -207,6 +211,67 @@ pub fn extract_annot_tokens(attr: Attribute) -> TokenStream { d.tokens } +/// Parses a [`rty::TypePosition`] from the tokens of a +/// `#[thrust::refinement_path(..)]` attribute. +/// +/// Tokens are comma-separated steps. Each step is one of: +/// - The keyword `result` → [`rty::TypePositionStep::Return`] (navigate to a +/// function type's return slot). +/// - `$i` (a `$` followed by an integer) → [`rty::TypePositionStep::Param`]`(i)` +/// (navigate to the `i`-th parameter of a function type). +/// - A bare integer `i` → [`rty::TypePositionStep::TypeArg`]`(i)` (navigate to +/// the `i`-th type argument of a generic type such as an enum or `Box`). +/// +/// Examples: `result` is the return; `$0` is the first parameter; `$0, 0` is +/// the first type-argument of the first parameter; `$0, result` is the return +/// of a function-typed first parameter. +pub fn parse_type_position(ts: &TokenStream) -> rty::TypePosition { + use rustc_ast::token::{LitKind, TokenKind}; + use rustc_ast::tokenstream::TokenTree; + + let parse_int = |lit: &rustc_ast::token::Lit| -> usize { + assert_eq!( + lit.kind, + LitKind::Integer, + "expected an integer in type position" + ); + lit.symbol + .as_str() + .parse() + .expect("invalid integer in type position") + }; + + let mut steps = Vec::new(); + let mut iter = ts.iter(); + while let Some(tt) = iter.next() { + let TokenTree::Token(t, _) = tt else { + panic!("unexpected token tree in type position"); + }; + match &t.kind { + TokenKind::Comma => {} + TokenKind::Ident(sym, _) if sym.as_str() == "result" => { + steps.push(rty::TypePositionStep::Return); + } + TokenKind::Dollar => { + let i = match iter.next() { + Some(TokenTree::Token(t, _)) => match &t.kind { + TokenKind::Literal(lit) => parse_int(lit), + _ => panic!("expected integer after `$` in type position: {:?}", t), + }, + _ => panic!("expected integer after `$` in type position"), + }; + steps.push(rty::TypePositionStep::Param(rty::FunctionParamIdx::from(i))); + } + TokenKind::Literal(lit) => { + steps.push(rty::TypePositionStep::TypeArg(parse_int(lit))); + } + _ => panic!("unexpected token in type position: {:?}", t), + } + } + + rty::TypePosition::new(steps) +} + pub fn split_param(ts: &TokenStream) -> (Ident, TokenStream) { use rustc_ast::token::TokenKind; use rustc_ast::tokenstream::TokenTree; diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 9f7f57ee..1cb273f3 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -406,6 +406,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { assert!(require_annot.is_none() || param_annots.is_empty()); assert!(ensure_annot.is_none() || ret_annot.is_none()); + let refinement_annots = self + .ctx + .extract_refinement_annots(self.local_def_id, self.generic_args); + let trait_item_ty = self.trait_item_ty(); let is_fully_annotated = self.is_fully_annotated(); @@ -431,6 +435,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { if let Some(ret_rty) = ret_annot { builder.ret_rty(ret_rty); } + for (position, refinement) in refinement_annots { + builder.install_refinement_at(&position, refinement); + } if is_fully_annotated { rty::RefinedType::unrefined(builder.build().into()) diff --git a/src/refine/template.rs b/src/refine/template.rs index ed0762ed..37956784 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -565,6 +565,49 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { self.ret_rty = Some(rty); self } + + /// Installs a refinement at a [`rty::TypePosition`]. + /// + /// The first step must be [`rty::TypePositionStep::Param`] or + /// [`rty::TypePositionStep::Return`]; the remaining steps are forwarded to + /// [`rty::RefinedType::install_refinement_at`]. + pub fn install_refinement_at( + &mut self, + position: &rty::TypePosition, + refinement: rty::Refinement, + ) -> &mut Self { + let (first, rest) = match position.steps().split_first() { + Some(pair) => pair, + None => panic!("type position applied to a function type must not be empty"), + }; + match first { + rty::TypePositionStep::Param(idx) => { + if !self.param_rtys.contains_key(idx) { + let ty = self.inner.build(self.param_tys[idx.index()].ty).vacuous(); + self.param_rtys + .insert(*idx, rty::RefinedType::unrefined(ty)); + } + self.param_rtys + .get_mut(idx) + .unwrap() + .install_refinement_at(rest, refinement); + } + rty::TypePositionStep::Return => { + if self.ret_rty.is_none() { + let ty = self.inner.build(self.ret_ty).vacuous(); + self.ret_rty = Some(rty::RefinedType::unrefined(ty)); + } + self.ret_rty + .as_mut() + .unwrap() + .install_refinement_at(rest, refinement); + } + rty::TypePositionStep::TypeArg(_) => { + panic!("type position applied to a function type must start with a param or result step, not a type argument"); + } + } + self + } } impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> diff --git a/src/rty.rs b/src/rty.rs index c9a3249a..abfff7bf 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -83,6 +83,83 @@ where } } +/// One step in a [`TypePosition`] path. +/// +/// A path is a sequence of steps that addresses a sub-type within a +/// (potentially nested) function signature: +/// - [`Param`](Self::Param) / [`Return`](Self::Return) navigate into a +/// function type's parameter or return slot respectively. +/// - [`TypeArg`](Self::TypeArg) navigates into the `i`-th type argument of a +/// generic type (enum, `Box`, etc.). +/// +/// Using distinct variants for function navigation ([`Param`](Self::Param), +/// [`Return`](Self::Return)) and generic-arg navigation +/// ([`TypeArg`](Self::TypeArg)) allows the same path representation to address +/// positions inside higher-order function types. For example, `$0, result` +/// addresses the return type of a function-typed first parameter. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TypePositionStep { + /// Navigate to the `i`-th parameter of a function type. + Param(FunctionParamIdx), + /// Navigate to the return type of a function type. + Return, + /// Navigate to the `i`-th type argument of a generic type (enum, `Box`, …). + TypeArg(usize), +} + +impl std::fmt::Display for TypePositionStep { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + TypePositionStep::Param(idx) => write!(f, "{}", idx), + TypePositionStep::Return => f.write_str("result"), + TypePositionStep::TypeArg(i) => write!(f, "{}", i), + } + } +} + +/// A path addressing a sub-type within a type, used to attach a refinement. +/// +/// An empty path addresses the type itself. Each step descends one level: +/// [`TypePositionStep::Param`] / [`TypePositionStep::Return`] enter a function +/// type's parameter or return slot, and [`TypePositionStep::TypeArg`] enters a +/// generic type argument. Steps combine freely, so positions inside +/// higher-order function types are expressible. A path applied to a function +/// type is therefore non-empty, beginning with a `Param`/`Return` step. +/// +/// Examples (function `fn f(x: List) -> Box`): +/// - `$0` — parameter `x`. +/// - `result` — the return type. +/// - `$0, 0` — the first type arg of `x`. +/// - `result, 0` — the pointee of the `Box` return. +/// - `$0, result` — the return of a function-typed param `x`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TypePosition { + steps: Vec, +} + +impl TypePosition { + pub fn new(steps: Vec) -> Self { + TypePosition { steps } + } + + pub fn steps(&self) -> &[TypePositionStep] { + &self.steps + } +} + +impl std::fmt::Display for TypePosition { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let mut iter = self.steps.iter(); + if let Some(first) = iter.next() { + write!(f, "{}", first)?; + } + for s in iter { + write!(f, ", {}", s)?; + } + Ok(()) + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] pub enum FunctionAbi { #[default] @@ -1475,6 +1552,54 @@ where } } +impl RefinedType { + /// Installs `refinement` at the sub-type addressed by `steps`. + /// + /// An empty `steps` slice replaces the refinement at this node. Each step + /// in the slice navigates one level deeper: + /// - [`TypePositionStep::TypeArg`] descends into enum type arguments or the + /// `Box` pointee. + /// - [`TypePositionStep::Param`] / [`TypePositionStep::Return`] descend + /// into a function-typed position's parameter or return slot. + pub fn install_refinement_at( + &mut self, + steps: &[TypePositionStep], + refinement: Refinement, + ) { + let Some((step, rest)) = steps.split_first() else { + self.refinement = refinement; + return; + }; + match step { + TypePositionStep::TypeArg(i) => match &mut self.ty { + Type::Enum(e) => { + let arg = e.args.get_mut(TypeParamIdx::from(*i)).unwrap_or_else(|| { + panic!("refine step [{}] out of range for enum type", i) + }); + arg.install_refinement_at(rest, refinement); + } + Type::Pointer(p) => { + assert_eq!(*i, 0, "Box type position must be [0]"); + p.elem.install_refinement_at(rest, refinement); + } + ty => panic!("TypeArg step on unsupported type: {:?}", ty), + }, + TypePositionStep::Param(idx) => match &mut self.ty { + Type::Function(func) => { + func.params[*idx].install_refinement_at(rest, refinement); + } + ty => panic!("Param step on non-function type: {:?}", ty), + }, + TypePositionStep::Return => match &mut self.ty { + Type::Function(func) => { + func.ret.install_refinement_at(rest, refinement); + } + ty => panic!("Return step on non-function type: {:?}", ty), + }, + } + } +} + impl RefinedType { fn pretty_atom<'a, 'b, D>( &'b self, diff --git a/tests/ui/fail/annot_box_term.rs b/tests/ui/fail/annot_box_term.rs index 27184bc6..4391c35e 100644 --- a/tests/ui/fail/annot_box_term.rs +++ b/tests/ui/fail/annot_box_term.rs @@ -1,7 +1,7 @@ //@error-in-other-file: Unsat //@compile-flags: -C debug-assertions=off -#[thrust::sig(fn(x: int) -> {r: Box | r == })] +#[thrust_macros::sig(fn(x: i64) -> { r: Box | r == thrust_models::model::Box::new(x) })] fn box_create(x: i64) -> Box { Box::new(x) } diff --git a/tests/ui/fail/refine_param_simple.rs b/tests/ui/fail/refine_param_simple.rs new file mode 100644 index 00000000..11697be7 --- /dev/null +++ b/tests/ui/fail/refine_param_simple.rs @@ -0,0 +1,9 @@ +//@error-in-other-file: Unsat + +#[thrust_macros::param(x: { v: i32 | v > 0 })] +#[thrust_macros::ret({ r: i32 | r > x })] +fn f(x: i32) -> i32 { + x +} + +fn main() {} diff --git a/tests/ui/fail/refine_sig.rs b/tests/ui/fail/refine_sig.rs new file mode 100644 index 00000000..17b36888 --- /dev/null +++ b/tests/ui/fail/refine_sig.rs @@ -0,0 +1,8 @@ +//@error-in-other-file: Unsat + +#[thrust_macros::sig(fn(x: { v: i32 | v > 0 }) -> { r: i32 | r > x })] +fn g(x: i32) -> i32 { + x +} + +fn main() {} diff --git a/tests/ui/pass/annot_box_term.rs b/tests/ui/pass/annot_box_term.rs index 67c71e16..b96d49e8 100644 --- a/tests/ui/pass/annot_box_term.rs +++ b/tests/ui/pass/annot_box_term.rs @@ -1,7 +1,7 @@ //@check-pass //@compile-flags: -C debug-assertions=off -#[thrust::sig(fn(x: int) -> {r: Box | r == })] +#[thrust_macros::sig(fn(x: i64) -> { r: Box | r == thrust_models::model::Box::new(x) })] fn box_create(x: i64) -> Box { Box::new(x) } diff --git a/tests/ui/pass/refine_param_simple.rs b/tests/ui/pass/refine_param_simple.rs new file mode 100644 index 00000000..bd732583 --- /dev/null +++ b/tests/ui/pass/refine_param_simple.rs @@ -0,0 +1,9 @@ +//@check-pass + +#[thrust_macros::param(x: { v: i32 | v > 0 })] +#[thrust_macros::ret({ r: i32 | r >= x })] +fn f(x: i32) -> i32 { + x +} + +fn main() {} diff --git a/tests/ui/pass/refine_sig.rs b/tests/ui/pass/refine_sig.rs new file mode 100644 index 00000000..db233970 --- /dev/null +++ b/tests/ui/pass/refine_sig.rs @@ -0,0 +1,8 @@ +//@check-pass + +#[thrust_macros::sig(fn(x: { v: i32 | v > 0 }) -> { r: i32 | r >= x })] +fn g(x: i32) -> i32 { + x +} + +fn main() {} diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index 7bc0c042..3399f9f8 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -6,6 +6,8 @@ use syn::{ WherePredicate, }; +mod rty; + #[derive(Debug, Clone)] enum FnOuterItem { ItemImpl(syn::ItemImpl), @@ -550,6 +552,21 @@ impl ExpandedTokens { } } +#[proc_macro_attribute] +pub fn param(attr: TokenStream, item: TokenStream) -> TokenStream { + rty::expand(rty::AnnotationKind::Param, attr, item) +} + +#[proc_macro_attribute] +pub fn ret(attr: TokenStream, item: TokenStream) -> TokenStream { + rty::expand(rty::AnnotationKind::Ret, attr, item) +} + +#[proc_macro_attribute] +pub fn sig(attr: TokenStream, item: TokenStream) -> TokenStream { + rty::expand(rty::AnnotationKind::Sig, attr, item) +} + fn mentions_self(sig: &syn::Signature) -> bool { struct Visitor { mentions_self: bool, diff --git a/thrust-macros/src/rty.rs b/thrust-macros/src/rty.rs new file mode 100644 index 00000000..e32031e7 --- /dev/null +++ b/thrust-macros/src/rty.rs @@ -0,0 +1,457 @@ +//! Refinement-type annotations: `param`, `ret`, `sig`. +//! +//! These lower refinement types (e.g. `List<{ v: i32 | v > 0 }>`) into +//! `#[thrust::formula_fn]`s plus positioned `#[thrust::refinement_path(..)]` +//! path statements injected into the function body. The "type position" +//! addresses into the function type: a parameter (`$i`) or the return (the +//! `result` keyword) selects a function slot, and bare integer steps (`i`) +//! descend into generic arguments (enum args / `Box` pointee). For example, +//! `#[thrust::refinement_path(result, 0)]` is the first type-argument of the +//! return. + +use proc_macro::TokenStream; +use proc_macro2::{TokenStream as TokenStream2, TokenTree as TokenTree2}; +use quote::{format_ident, quote, ToTokens}; +use syn::{parse_macro_input, FnArg}; + +use super::{ + extended_where_clause, extract_outer_context, fn_params_with_model_ty, generic_params_tokens, + generic_turbofish, model_where_predicates, FnItemWithSignature, FnOuterItem, +}; + +/// One step in a refinement's type-position path. +/// +/// Mirrors the plugin's `rty::TypePositionStep` and uses the same attribute +/// encoding: +/// - [`Param`](Self::Param) / [`Return`](Self::Return) navigate into a function +/// type; encoded as `$i` / the `result` keyword. +/// - [`TypeArg`](Self::TypeArg) navigates into a generic type argument; encoded +/// as a bare integer `i`. +#[derive(Clone, Copy)] +enum TypePositionStep { + Param(usize), + Return, + TypeArg(usize), +} + +#[derive(Clone)] +struct Refinement { + /// Full type-position path from the function root to the refined type. + steps: Vec, + binder: syn::Ident, + binder_ty: TokenStream2, + formula: TokenStream2, +} + +/// Which refinement-type annotation is being expanded. +pub(crate) enum AnnotationKind { + Param, + Ret, + Sig, +} + +/// A type expression from the annotation, paired with the position of its root +/// within the function signature. [`scan_type`] walks each one to extract the +/// refinements it contains. +struct PositionedTypeExpr { + /// Steps locating the root of `tokens` (e.g. `[Param(0)]` for the first + /// parameter); [`scan_type`] appends `TypeArg` steps as it descends. + root: Vec, + tokens: Vec, +} + +/// A `name : type` binding, e.g. a parameter in a `sig` annotation or the +/// binder of a refinement `{ name: type | .. }`. +struct NamedType { + name: syn::Ident, + tokens: Vec, +} + +pub(crate) fn expand(kind: AnnotationKind, attr: TokenStream, item: TokenStream) -> TokenStream { + let mut func = parse_macro_input!(item as FnItemWithSignature); + + let outer_context = match extract_outer_context(&func) { + Ok(ctx) => ctx, + Err(e) => { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } + }; + + let attr_tokens: Vec = TokenStream2::from(attr).into_iter().collect(); + let type_exprs = match annotated_type_exprs(kind, &func, &attr_tokens) { + Ok(exprs) => exprs, + Err(e) => { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } + }; + + let mut refinements = Vec::new(); + for expr in type_exprs { + if let Err(e) = scan_type(&expr.tokens, expr.root, &mut refinements) { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } + } + + if refinements.is_empty() { + return func.into_token_stream().into(); + } + + let has_receiver = func.sig().receiver().is_some(); + let mut formula_fns = Vec::new(); + let mut path_stmts = Vec::new(); + for mut r in refinements { + if has_receiver { + r.formula = rewrite_self_in_tokens(r.formula); + } + formula_fns.push(refine_formula_fn(&func, outer_context.as_ref(), &r)); + path_stmts.push(refine_path_stmt(&func, &r)); + } + + let Some(block) = func.block_mut() else { + let err = syn::Error::new_spanned( + func.sig().ident.clone(), + "refinement-type annotations require a function body", + ) + .into_compile_error(); + return quote! { #err #func }.into(); + }; + let orig_stmts = block.stmts.drain(..).collect::>(); + *block = syn::parse_quote!({ + #(#path_stmts)* + #(#orig_stmts)* + }); + func.attrs_mut() + .push(syn::parse_quote!(#[allow(path_statements)])); + + quote! { + #(#formula_fns)* + #func + } + .into() +} + +/// Turns an annotation into the type expressions to scan, each anchored at its +/// root position within the function signature. +fn annotated_type_exprs( + kind: AnnotationKind, + func: &FnItemWithSignature, + attr_tokens: &[TokenTree2], +) -> syn::Result> { + let at_param = |func: &FnItemWithSignature, nt: NamedType| -> syn::Result { + let idx = param_index(func, &nt.name)?; + Ok(PositionedTypeExpr { + root: vec![TypePositionStep::Param(idx)], + tokens: nt.tokens, + }) + }; + match kind { + AnnotationKind::Param => Ok(vec![at_param(func, split_name_type(attr_tokens)?)?]), + AnnotationKind::Ret => Ok(vec![PositionedTypeExpr { + root: vec![TypePositionStep::Return], + tokens: attr_tokens.to_vec(), + }]), + AnnotationKind::Sig => { + let sig = parse_sig_attr(attr_tokens)?; + let mut exprs = Vec::new(); + for param in sig.params { + exprs.push(at_param(func, param)?); + } + exprs.push(PositionedTypeExpr { + root: vec![TypePositionStep::Return], + tokens: sig.ret, + }); + Ok(exprs) + } + } +} + +fn param_index(func: &FnItemWithSignature, name: &syn::Ident) -> syn::Result { + let pos = func.sig().inputs.iter().position(|arg| match arg { + FnArg::Receiver(_) => name == "self", + FnArg::Typed(pt) => matches!(&*pt.pat, syn::Pat::Ident(pi) if &pi.ident == name), + }); + pos.ok_or_else(|| { + syn::Error::new_spanned(name, format!("no parameter named `{}` in signature", name)) + }) +} + +/// Parses `name : ` from a flat token slice. +fn split_name_type(tokens: &[TokenTree2]) -> syn::Result { + let name = match tokens.first() { + Some(TokenTree2::Ident(id)) => id.clone(), + _ => return Err(err_tokens(tokens, "expected a parameter name")), + }; + match tokens.get(1) { + Some(TokenTree2::Punct(p)) if p.as_char() == ':' => {} + _ => return Err(err_tokens(tokens, "expected `:` after parameter name")), + } + Ok(NamedType { + name, + tokens: tokens[2..].to_vec(), + }) +} + +/// The parsed parts of a `fn ( n0: t0 , ... ) -> ret` signature annotation. +struct SigAnnotation { + params: Vec, + ret: Vec, +} + +/// Parses `fn ( n0: t0 , ... ) -> ret`. +fn parse_sig_attr(tokens: &[TokenTree2]) -> syn::Result { + match tokens.first() { + Some(TokenTree2::Ident(id)) if id == "fn" => {} + _ => return Err(err_tokens(tokens, "expected `fn` in sig annotation")), + } + let arg_group = match tokens.get(1) { + Some(TokenTree2::Group(g)) if g.delimiter() == proc_macro2::Delimiter::Parenthesis => g, + _ => return Err(err_tokens(tokens, "expected `(..)` after `fn`")), + }; + + let mut params = Vec::new(); + let arg_tokens: Vec = arg_group.stream().into_iter().collect(); + for arg in split_top_level_commas(&arg_tokens) { + if arg.is_empty() { + continue; + } + params.push(split_name_type(&arg)?); + } + + // expect `->` then the return type + let mut rest = &tokens[2..]; + match (rest.first(), rest.get(1)) { + (Some(TokenTree2::Punct(a)), Some(TokenTree2::Punct(b))) + if a.as_char() == '-' && b.as_char() == '>' => + { + rest = &rest[2..]; + } + _ => { + return Err(err_tokens( + tokens, + "expected `->` and a return type in sig annotation", + )) + } + } + Ok(SigAnnotation { + params, + ret: rest.to_vec(), + }) +} + +/// Scans a type expression and records every refinement node with its full +/// type-position path (`steps`). +/// +/// `steps` holds the path from the function root to the current type node. +/// When a refinement `{binder: ty | formula}` is found the current `steps` are +/// recorded; when descending into generic type arguments a +/// [`TypePositionStep::TypeArg`]`(i)` step is appended to `steps`. +fn scan_type( + tokens: &[TokenTree2], + steps: Vec, + out: &mut Vec, +) -> syn::Result<()> { + if tokens.is_empty() { + return Ok(()); + } + + // A refinement node is exactly a brace-delimited group. + if tokens.len() == 1 { + if let TokenTree2::Group(g) = &tokens[0] { + if g.delimiter() == proc_macro2::Delimiter::Brace { + let (binder, formula) = split_refinement(g.stream())?; + out.push(Refinement { + steps: steps.clone(), + binder: binder.name, + binder_ty: binder.tokens.iter().cloned().collect(), + formula, + }); + // Descend into the binder type for nested refinements. + scan_type(&binder.tokens, steps, out)?; + return Ok(()); + } + } + } + + // A nominal type `Name` (`Box` included). + if let TokenTree2::Ident(_) = &tokens[0] { + if let Some(TokenTree2::Punct(p)) = tokens.get(1) { + if p.as_char() == '<' { + let mut type_idx = 0; + for arg in split_angle_args(&tokens[2..]) { + if is_lifetime(&arg) { + continue; + } + let mut child = steps.clone(); + child.push(TypePositionStep::TypeArg(type_idx)); + scan_type(&arg, child, out)?; + type_idx += 1; + } + } + } + } + + Ok(()) +} + +/// Splits `{ binder : ty | formula }` into its binder and formula expression. +fn split_refinement(stream: TokenStream2) -> syn::Result<(NamedType, TokenStream2)> { + let toks: Vec = stream.into_iter().collect(); + let bar = toks + .iter() + .position(|tt| matches!(tt, TokenTree2::Punct(p) if p.as_char() == '|')) + .ok_or_else(|| err_tokens(&toks, "refinement type must contain `|`"))?; + let binder = split_name_type(&toks[..bar])?; + let formula: TokenStream2 = toks[bar + 1..].iter().cloned().collect(); + Ok((binder, formula)) +} + +/// Splits the tokens following an opening `<` at top level by commas, stopping +/// at the matching `>`. +fn split_angle_args(tokens: &[TokenTree2]) -> Vec> { + let mut args = Vec::new(); + let mut cur = Vec::new(); + let mut depth = 1usize; + for tt in tokens { + if let TokenTree2::Punct(p) = tt { + match p.as_char() { + '<' => { + depth += 1; + cur.push(tt.clone()); + continue; + } + '>' => { + depth -= 1; + if depth == 0 { + break; + } + cur.push(tt.clone()); + continue; + } + ',' if depth == 1 => { + args.push(std::mem::take(&mut cur)); + continue; + } + _ => {} + } + } + cur.push(tt.clone()); + } + if !cur.is_empty() { + args.push(cur); + } + args +} + +fn split_top_level_commas(tokens: &[TokenTree2]) -> Vec> { + let mut out = Vec::new(); + let mut cur = Vec::new(); + let mut depth = 0i32; + for tt in tokens { + if let TokenTree2::Punct(p) = tt { + match p.as_char() { + '<' => depth += 1, + '>' => depth -= 1, + ',' if depth == 0 => { + out.push(std::mem::take(&mut cur)); + continue; + } + _ => {} + } + } + cur.push(tt.clone()); + } + out.push(cur); + out +} + +fn is_lifetime(tokens: &[TokenTree2]) -> bool { + matches!(tokens.first(), Some(TokenTree2::Punct(p)) if p.as_char() == '\'') +} + +fn err_tokens(tokens: &[TokenTree2], msg: &str) -> syn::Error { + let stream: TokenStream2 = tokens.iter().cloned().collect(); + syn::Error::new_spanned(stream, msg) +} + +fn rewrite_self_in_tokens(tokens: TokenStream2) -> TokenStream2 { + tokens + .into_iter() + .map(|tt| match tt { + TokenTree2::Ident(id) if id == "self" => TokenTree2::Ident(format_ident!("self_")), + TokenTree2::Group(g) => { + let inner = rewrite_self_in_tokens(g.stream()); + TokenTree2::Group(proc_macro2::Group::new(g.delimiter(), inner)) + } + other => other, + }) + .collect() +} + +fn refine_fn_name(func: &FnItemWithSignature, r: &Refinement) -> syn::Ident { + let pos = r + .steps + .iter() + .map(|s| match s { + TypePositionStep::Param(i) => format!("p{}", i), + TypePositionStep::Return => "ret".to_string(), + TypePositionStep::TypeArg(i) => format!("t{}", i), + }) + .collect::>() + .join("_"); + format_ident!("_thrust_refine_{}_{}", func.sig().ident, pos) +} + +fn refine_formula_fn( + func: &FnItemWithSignature, + outer_context: Option<&FnOuterItem>, + r: &Refinement, +) -> TokenStream2 { + let name = refine_fn_name(func, r); + let def_generics = generic_params_tokens(&func.sig().generics); + let model_params = fn_params_with_model_ty(&func.sig().inputs); + let model_preds = model_where_predicates(func, outer_context); + let extended_where = extended_where_clause(func, &model_preds); + let binder = &r.binder; + let binder_ty = &r.binder_ty; + let formula = &r.formula; + + quote! { + #[allow(unused_variables)] + #[allow(non_snake_case)] + #[thrust::formula_fn] + fn #name #def_generics( + #binder: <#binder_ty as thrust_models::Model>::Ty, + #model_params + ) -> bool #extended_where { + #formula + } + } +} + +fn refine_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 { + let name = refine_fn_name(func, r); + let turbofish = generic_turbofish(&func.sig().generics); + let path_prefix = if func.sig().receiver().is_some() { + quote!(Self::) + } else { + quote!() + }; + let encoded_steps = r.steps.iter().map(|s| match s { + TypePositionStep::Param(i) => { + let lit = proc_macro2::Literal::usize_unsuffixed(*i); + quote!($#lit) + } + TypePositionStep::Return => quote!(result), + TypePositionStep::TypeArg(i) => { + let lit = proc_macro2::Literal::usize_unsuffixed(*i); + quote!(#lit) + } + }); + quote! { + #[thrust::refinement_path(#(#encoded_steps),*)] + #path_prefix #name #turbofish; + } +}