From 09ed93896da9f76422f66720c6eea94798417ba1 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 26 May 2026 05:07:35 +0000 Subject: [PATCH 1/8] Add param/ret/sig refinement-type annotations via thrust_macros MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce `thrust_macros::param`, `thrust_macros::ret`, and `thrust_macros::sig` attribute macros that lower refinement types (e.g. `{ v: i32 | v > 0 }`) into `#[thrust::formula_fn]`s, giving refinement formulas the same rustc-typechecked treatment as requires/ensures. Each refinement is placed via a new "type position" — a path addressing into the function type (parameter index or return slot, then generic-argument indices for enum args and Box pointees) — emitted as a `#[thrust::refine(..)]` path statement and installed into the parameter or return RefinedType template. Migrate the existing `thrust::sig` tests to `thrust_macros::sig`. --- src/analyze.rs | 79 ++++++ src/analyze/annot.rs | 31 +++ src/analyze/local_def.rs | 7 + src/refine/template.rs | 33 +++ src/rty.rs | 40 +++ tests/ui/fail/annot_box_term.rs | 2 +- tests/ui/fail/refine_param_simple.rs | 9 + tests/ui/fail/refine_sig.rs | 8 + tests/ui/pass/annot_box_term.rs | 2 +- tests/ui/pass/refine_param_simple.rs | 9 + tests/ui/pass/refine_sig.rs | 8 + thrust-macros/src/lib.rs | 397 ++++++++++++++++++++++++++- 12 files changed, 622 insertions(+), 3 deletions(-) create mode 100644 tests/ui/fail/refine_param_simple.rs create mode 100644 tests/ui/fail/refine_sig.rs create mode 100644 tests/ui/pass/refine_param_simple.rs create mode 100644 tests/ui/pass/refine_sig.rs diff --git a/src/analyze.rs b/src/analyze.rs index 45f47558..65df3712 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -790,6 +790,85 @@ impl<'tcx> Analyzer<'tcx> { ensure_annot } + /// Collects every `#[thrust::refine(..)]` path statement in the function + /// body, returning each `(type position, formula_fn DefId)`. 
+ fn extract_refine_paths(&self, local_def_id: LocalDefId) -> Vec<(Vec, DefId)> { + let mut out = Vec::new(); + let Some(body) = self.tcx.hir_maybe_body_owned_by(local_def_id) else { + return out; + }; + let rustc_hir::ExprKind::Block(block, _) = body.value.kind else { + return out; + }; + let attr_path = analyze::annot::refine_path(); + let typeck = self.tcx.typeck(local_def_id); + for stmt in block.stmts { + let Some(attr) = self + .tcx + .hir_attrs(stmt.hir_id) + .iter() + .find(|attr| attr.path_matches(&attr_path)) + else { + continue; + }; + let ts = analyze::annot::extract_annot_tokens(attr.clone()); + let position = analyze::annot::parse_position(&ts); + + let rustc_hir::StmtKind::Semi(expr) = stmt.kind else { + self.tcx.dcx().span_err( + stmt.span, + "annotated path is expected to be a semi statement", + ); + continue; + }; + let rustc_hir::ExprKind::Path(qpath) = expr.kind else { + self.tcx.dcx().span_err( + expr.span, + "annotated path is expected to be a path expression", + ); + continue; + }; + let rustc_hir::def::Res::Def(_, def_id) = typeck.qpath_res(&qpath, expr.hir_id) else { + self.tcx.dcx().span_err( + expr.span, + "annotated path is expected to refer to a definition", + ); + continue; + }; + out.push((position, def_id)); + } + out + } + + /// Resolves every `#[thrust::refine(..)]` annotation into a positioned + /// refinement, by translating the referenced formula function. 
+ pub fn extract_refine_annots( + &self, + local_def_id: LocalDefId, + generic_args: mir_ty::GenericArgsRef<'tcx>, + ) -> Vec<(Vec, rty::Refinement)> { + let mut out = Vec::new(); + for (position, def_id) in self.extract_refine_paths(local_def_id) { + let Some(formula_def_id) = def_id.as_local() else { + panic!( + "refine annotation with path is expected to refer to a local def, but found: {:?}", + def_id + ); + }; + let Some(formula_fn) = self.formula_fn_with_args(formula_def_id, generic_args) else { + panic!( + "refine annotation {:?} is not a formula function", + formula_def_id + ); + }; + let AnnotFormula::Formula(formula) = formula_fn.to_ensure_annot() else { + panic!("refine annotation must lower to a plain formula"); + }; + out.push((position, formula.into())); + } + out + } + /// Whether the given `def_id` corresponds to a method of one of the `Fn` traits. fn is_fn_trait_method(&self, def_id: DefId) -> bool { self.tcx diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index ec8465c9..0d539a8f 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -61,6 +61,10 @@ pub fn ensures_path_path() -> [Symbol; 2] { [Symbol::intern("thrust"), Symbol::intern("ensures_path")] } +pub fn refine_path() -> [Symbol; 2] { + [Symbol::intern("thrust"), Symbol::intern("refine")] +} + pub fn model_ty_path() -> [Symbol; 3] { [ Symbol::intern("thrust"), @@ -207,6 +211,33 @@ pub fn extract_annot_tokens(attr: Attribute) -> TokenStream { d.tokens } +/// Parses a comma-separated list of integer literals (a "type position") from +/// the tokens of a `#[thrust::refine(..)]` attribute. 
+pub fn parse_position(ts: &TokenStream) -> Vec { + use rustc_ast::token::{LitKind, TokenKind}; + use rustc_ast::tokenstream::TokenTree; + + let mut out = Vec::new(); + for tt in ts.iter() { + match tt { + TokenTree::Token(t, _) => match &t.kind { + TokenKind::Comma => {} + TokenKind::Literal(lit) if lit.kind == LitKind::Integer => { + out.push( + lit.symbol + .as_str() + .parse() + .expect("invalid integer in refine position"), + ); + } + _ => panic!("unexpected token in refine position: {:?}", t), + }, + _ => panic!("unexpected token tree in refine position"), + } + } + out +} + pub fn split_param(ts: &TokenStream) -> (Ident, TokenStream) { use rustc_ast::token::TokenKind; use rustc_ast::tokenstream::TokenTree; diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 9f7f57ee..6ccff823 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -406,6 +406,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { assert!(require_annot.is_none() || param_annots.is_empty()); assert!(ensure_annot.is_none() || ret_annot.is_none()); + let refine_annots = self + .ctx + .extract_refine_annots(self.local_def_id, self.generic_args); + let trait_item_ty = self.trait_item_ty(); let is_fully_annotated = self.is_fully_annotated(); @@ -431,6 +435,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { if let Some(ret_rty) = ret_annot { builder.ret_rty(ret_rty); } + for (position, refinement) in refine_annots { + builder.refine(&position, refinement); + } if is_fully_annotated { rty::RefinedType::unrefined(builder.build().into()) diff --git a/src/refine/template.rs b/src/refine/template.rs index ed0762ed..074b8bb7 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -565,6 +565,39 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { self.ret_rty = Some(rty); self } + + /// Installs a refinement at a function-type position. 
The first index of + /// `path` selects a parameter (by index) or the return slot (when it equals + /// the parameter count); the remaining indices descend into the slot's type. + pub fn refine( + &mut self, + path: &[usize], + refinement: rty::Refinement, + ) -> &mut Self { + let (&slot, sub) = path.split_first().expect("refine path must be non-empty"); + let n = self.param_tys.len(); + if slot < n { + let idx = rty::FunctionParamIdx::from(slot); + if !self.param_rtys.contains_key(&idx) { + let ty = self.inner.build(self.param_tys[slot].ty).vacuous(); + self.param_rtys.insert(idx, rty::RefinedType::unrefined(ty)); + } + self.param_rtys + .get_mut(&idx) + .unwrap() + .install_refinement_at(sub, refinement); + } else { + if self.ret_rty.is_none() { + let ty = self.inner.build(self.ret_ty).vacuous(); + self.ret_rty = Some(rty::RefinedType::unrefined(ty)); + } + self.ret_rty + .as_mut() + .unwrap() + .install_refinement_at(sub, refinement); + } + self + } } impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> diff --git a/src/rty.rs b/src/rty.rs index c9a3249a..a7c543c0 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -1475,6 +1475,46 @@ where } } +impl RefinedType { + /// Installs `refinement` at the given type position, descending through the + /// function type (parameters then return), enum type arguments, and `Box` + /// pointees. An empty path replaces the refinement at the current node. 
+ pub fn install_refinement_at( + &mut self, + path: &[usize], + refinement: Refinement, + ) { + let Some((&step, rest)) = path.split_first() else { + self.refinement = refinement; + return; + }; + match &mut self.ty { + Type::Enum(e) => { + let arg = e.args.get_mut(TypeParamIdx::from(step)).unwrap_or_else(|| { + panic!("refine position {} out of range for enum type", step) + }); + arg.install_refinement_at(rest, refinement); + } + Type::Pointer(p) => { + assert_eq!(step, 0, "Box type position must be 0"); + p.elem.install_refinement_at(rest, refinement); + } + Type::Function(f) => { + let n = f.params.len(); + if step < n { + f.params[FunctionParamIdx::from(step)].install_refinement_at(rest, refinement); + } else { + f.ret.install_refinement_at(rest, refinement); + } + } + ty => panic!( + "unsupported type at refine position step {}: {:?}", + step, ty + ), + } + } +} + impl RefinedType { fn pretty_atom<'a, 'b, D>( &'b self, diff --git a/tests/ui/fail/annot_box_term.rs b/tests/ui/fail/annot_box_term.rs index 27184bc6..4391c35e 100644 --- a/tests/ui/fail/annot_box_term.rs +++ b/tests/ui/fail/annot_box_term.rs @@ -1,7 +1,7 @@ //@error-in-other-file: Unsat //@compile-flags: -C debug-assertions=off -#[thrust::sig(fn(x: int) -> {r: Box | r == })] +#[thrust_macros::sig(fn(x: i64) -> { r: Box | r == thrust_models::model::Box::new(x) })] fn box_create(x: i64) -> Box { Box::new(x) } diff --git a/tests/ui/fail/refine_param_simple.rs b/tests/ui/fail/refine_param_simple.rs new file mode 100644 index 00000000..11697be7 --- /dev/null +++ b/tests/ui/fail/refine_param_simple.rs @@ -0,0 +1,9 @@ +//@error-in-other-file: Unsat + +#[thrust_macros::param(x: { v: i32 | v > 0 })] +#[thrust_macros::ret({ r: i32 | r > x })] +fn f(x: i32) -> i32 { + x +} + +fn main() {} diff --git a/tests/ui/fail/refine_sig.rs b/tests/ui/fail/refine_sig.rs new file mode 100644 index 00000000..17b36888 --- /dev/null +++ b/tests/ui/fail/refine_sig.rs @@ -0,0 +1,8 @@ +//@error-in-other-file: Unsat + 
+#[thrust_macros::sig(fn(x: { v: i32 | v > 0 }) -> { r: i32 | r > x })] +fn g(x: i32) -> i32 { + x +} + +fn main() {} diff --git a/tests/ui/pass/annot_box_term.rs b/tests/ui/pass/annot_box_term.rs index 67c71e16..b96d49e8 100644 --- a/tests/ui/pass/annot_box_term.rs +++ b/tests/ui/pass/annot_box_term.rs @@ -1,7 +1,7 @@ //@check-pass //@compile-flags: -C debug-assertions=off -#[thrust::sig(fn(x: int) -> {r: Box | r == })] +#[thrust_macros::sig(fn(x: i64) -> { r: Box | r == thrust_models::model::Box::new(x) })] fn box_create(x: i64) -> Box { Box::new(x) } diff --git a/tests/ui/pass/refine_param_simple.rs b/tests/ui/pass/refine_param_simple.rs new file mode 100644 index 00000000..bd732583 --- /dev/null +++ b/tests/ui/pass/refine_param_simple.rs @@ -0,0 +1,9 @@ +//@check-pass + +#[thrust_macros::param(x: { v: i32 | v > 0 })] +#[thrust_macros::ret({ r: i32 | r >= x })] +fn f(x: i32) -> i32 { + x +} + +fn main() {} diff --git a/tests/ui/pass/refine_sig.rs b/tests/ui/pass/refine_sig.rs new file mode 100644 index 00000000..db233970 --- /dev/null +++ b/tests/ui/pass/refine_sig.rs @@ -0,0 +1,8 @@ +//@check-pass + +#[thrust_macros::sig(fn(x: { v: i32 | v > 0 }) -> { r: i32 | r >= x })] +fn g(x: i32) -> i32 { + x +} + +fn main() {} diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index 7bc0c042..d17c914a 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -1,5 +1,5 @@ use proc_macro::TokenStream; -use proc_macro2::TokenStream as TokenStream2; +use proc_macro2::{TokenStream as TokenStream2, TokenTree as TokenTree2}; use quote::{format_ident, quote, ToTokens}; use syn::{ parse_macro_input, punctuated::Punctuated, FnArg, GenericParam, Generics, TypeParamBound, @@ -550,6 +550,401 @@ impl ExpandedTokens { } } +// --------------------------------------------------------------------------- +// Refinement-type annotations: `param`, `ret`, `sig`. +// +// These lower refinement types (e.g. 
`List<{ v: i32 | v > 0 }>`) into +// `#[thrust::formula_fn]`s plus positioned `#[thrust::refine(..)]` path +// statements injected into the function body. The "type position" addresses +// into the function type: the first index selects a parameter (by index) or +// the return slot (== parameter count), and subsequent indices descend into +// generic arguments (enum args / `Box` pointee). +// --------------------------------------------------------------------------- + +#[derive(Clone)] +struct Refinement { + path: Vec, + binder: syn::Ident, + binder_ty: TokenStream2, + formula: TokenStream2, +} + +#[proc_macro_attribute] +pub fn param(attr: TokenStream, item: TokenStream) -> TokenStream { + expand_refine(RefineKind::Param, attr, item) +} + +#[proc_macro_attribute] +pub fn ret(attr: TokenStream, item: TokenStream) -> TokenStream { + expand_refine(RefineKind::Ret, attr, item) +} + +#[proc_macro_attribute] +pub fn sig(attr: TokenStream, item: TokenStream) -> TokenStream { + expand_refine(RefineKind::Sig, attr, item) +} + +enum RefineKind { + Param, + Ret, + Sig, +} + +fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStream) -> TokenStream { + let mut func = parse_macro_input!(item as FnItemWithSignature); + + let outer_context = match extract_outer_context(&func) { + Ok(ctx) => ctx, + Err(e) => { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } + }; + + let attr_tokens: Vec = TokenStream2::from(attr).into_iter().collect(); + let jobs = match build_refine_jobs(kind, &func, &attr_tokens) { + Ok(jobs) => jobs, + Err(e) => { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } + }; + + let mut refinements = Vec::new(); + for (root, ty_tokens) in jobs { + if let Err(e) = scan_type(&ty_tokens, &root, &mut refinements) { + let err = e.to_compile_error(); + return quote! 
{ #err #func }.into(); + } + } + + if refinements.is_empty() { + return func.into_token_stream().into(); + } + + let has_receiver = func.sig().receiver().is_some(); + let mut formula_fns = Vec::new(); + let mut path_stmts = Vec::new(); + for mut r in refinements { + if has_receiver { + r.formula = rewrite_self_in_tokens(r.formula); + } + formula_fns.push(refine_formula_fn(&func, outer_context.as_ref(), &r)); + path_stmts.push(refine_path_stmt(&func, &r)); + } + + let Some(block) = func.block_mut() else { + let err = syn::Error::new_spanned( + func.sig().ident.clone(), + "refinement-type annotations require a function body", + ) + .into_compile_error(); + return quote! { #err #func }.into(); + }; + let orig_stmts = block.stmts.drain(..).collect::>(); + *block = syn::parse_quote!({ + #(#path_stmts)* + #(#orig_stmts)* + }); + func.attrs_mut() + .push(syn::parse_quote!(#[allow(path_statements)])); + + quote! { + #(#formula_fns)* + #func + } + .into() +} + +/// Builds `(root_path, type_tokens)` jobs to scan from the attribute tokens. 
+fn build_refine_jobs( + kind: RefineKind, + func: &FnItemWithSignature, + attr_tokens: &[TokenTree2], +) -> syn::Result, Vec)>> { + let param_count = func.sig().inputs.len(); + match kind { + RefineKind::Param => { + let (name, ty_tokens) = split_name_type(attr_tokens)?; + let idx = param_index(func, &name)?; + Ok(vec![(vec![idx], ty_tokens)]) + } + RefineKind::Ret => Ok(vec![(vec![param_count], attr_tokens.to_vec())]), + RefineKind::Sig => { + let (args, ret_tokens) = parse_sig_attr(attr_tokens)?; + let mut jobs = Vec::new(); + for (name, ty_tokens) in args { + let idx = param_index(func, &name)?; + jobs.push((vec![idx], ty_tokens)); + } + jobs.push((vec![param_count], ret_tokens)); + Ok(jobs) + } + } +} + +fn param_index(func: &FnItemWithSignature, name: &syn::Ident) -> syn::Result { + let pos = func.sig().inputs.iter().position(|arg| match arg { + FnArg::Receiver(_) => name == "self", + FnArg::Typed(pt) => matches!(&*pt.pat, syn::Pat::Ident(pi) if &pi.ident == name), + }); + pos.ok_or_else(|| { + syn::Error::new_spanned(name, format!("no parameter named `{}` in signature", name)) + }) +} + +/// Parses `name : ` from a flat token slice. +fn split_name_type(tokens: &[TokenTree2]) -> syn::Result<(syn::Ident, Vec)> { + let name = match tokens.first() { + Some(TokenTree2::Ident(id)) => id.clone(), + _ => return Err(err_tokens(tokens, "expected a parameter name")), + }; + match tokens.get(1) { + Some(TokenTree2::Punct(p)) if p.as_char() == ':' => {} + _ => return Err(err_tokens(tokens, "expected `:` after parameter name")), + } + Ok((name, tokens[2..].to_vec())) +} + +/// Parses `fn ( n0: t0 , ... ) -> ret` into `((name, ty_tokens)*, ret_tokens)`. 
+#[allow(clippy::type_complexity)] +fn parse_sig_attr( + tokens: &[TokenTree2], +) -> syn::Result<(Vec<(syn::Ident, Vec)>, Vec)> { + match tokens.first() { + Some(TokenTree2::Ident(id)) if id == "fn" => {} + _ => return Err(err_tokens(tokens, "expected `fn` in sig annotation")), + } + let arg_group = match tokens.get(1) { + Some(TokenTree2::Group(g)) if g.delimiter() == proc_macro2::Delimiter::Parenthesis => g, + _ => return Err(err_tokens(tokens, "expected `(..)` after `fn`")), + }; + + let mut args = Vec::new(); + let arg_tokens: Vec = arg_group.stream().into_iter().collect(); + for arg in split_top_level_commas(&arg_tokens) { + if arg.is_empty() { + continue; + } + args.push(split_name_type(&arg)?); + } + + // expect `->` then the return type + let mut rest = &tokens[2..]; + match (rest.first(), rest.get(1)) { + (Some(TokenTree2::Punct(a)), Some(TokenTree2::Punct(b))) + if a.as_char() == '-' && b.as_char() == '>' => + { + rest = &rest[2..]; + } + _ => { + return Err(err_tokens( + tokens, + "expected `->` and a return type in sig annotation", + )) + } + } + Ok((args, rest.to_vec())) +} + +/// Scans a single type expression, recording every refinement node together +/// with its type position. +fn scan_type(tokens: &[TokenTree2], path: &[usize], out: &mut Vec) -> syn::Result<()> { + if tokens.is_empty() { + return Ok(()); + } + + // A refinement node is exactly a brace-delimited group. + if tokens.len() == 1 { + if let TokenTree2::Group(g) = &tokens[0] { + if g.delimiter() == proc_macro2::Delimiter::Brace { + let (binder, binder_ty, formula) = split_refinement(g.stream())?; + out.push(Refinement { + path: path.to_vec(), + binder, + binder_ty: binder_ty.iter().cloned().collect(), + formula, + }); + // The refinement's own type sits at the same position; descend + // into it to find further nested refinements. + scan_type(&binder_ty, path, out)?; + return Ok(()); + } + } + } + + // A nominal type `Name` (`Box` included). 
+ if let TokenTree2::Ident(_) = &tokens[0] { + if let Some(TokenTree2::Punct(p)) = tokens.get(1) { + if p.as_char() == '<' { + let mut type_idx = 0; + for arg in split_angle_args(&tokens[2..]) { + if is_lifetime(&arg) { + continue; + } + let mut child = path.to_vec(); + child.push(type_idx); + scan_type(&arg, &child, out)?; + type_idx += 1; + } + } + } + } + + Ok(()) +} + +/// Splits `{ binder : ty | formula }` contents into its parts. +fn split_refinement( + stream: TokenStream2, +) -> syn::Result<(syn::Ident, Vec, TokenStream2)> { + let toks: Vec = stream.into_iter().collect(); + let bar = toks + .iter() + .position(|tt| matches!(tt, TokenTree2::Punct(p) if p.as_char() == '|')) + .ok_or_else(|| err_tokens(&toks, "refinement type must contain `|`"))?; + let (binder, binder_ty) = split_name_type(&toks[..bar])?; + let formula: TokenStream2 = toks[bar + 1..].iter().cloned().collect(); + Ok((binder, binder_ty, formula)) +} + +/// Splits the tokens following an opening `<` at top level by commas, stopping +/// at the matching `>`. 
+fn split_angle_args(tokens: &[TokenTree2]) -> Vec> { + let mut args = Vec::new(); + let mut cur = Vec::new(); + let mut depth = 1usize; + for tt in tokens { + if let TokenTree2::Punct(p) = tt { + match p.as_char() { + '<' => { + depth += 1; + cur.push(tt.clone()); + continue; + } + '>' => { + depth -= 1; + if depth == 0 { + break; + } + cur.push(tt.clone()); + continue; + } + ',' if depth == 1 => { + args.push(std::mem::take(&mut cur)); + continue; + } + _ => {} + } + } + cur.push(tt.clone()); + } + if !cur.is_empty() { + args.push(cur); + } + args +} + +fn split_top_level_commas(tokens: &[TokenTree2]) -> Vec> { + let mut out = Vec::new(); + let mut cur = Vec::new(); + let mut depth = 0i32; + for tt in tokens { + if let TokenTree2::Punct(p) = tt { + match p.as_char() { + '<' => depth += 1, + '>' => depth -= 1, + ',' if depth == 0 => { + out.push(std::mem::take(&mut cur)); + continue; + } + _ => {} + } + } + cur.push(tt.clone()); + } + out.push(cur); + out +} + +fn is_lifetime(tokens: &[TokenTree2]) -> bool { + matches!(tokens.first(), Some(TokenTree2::Punct(p)) if p.as_char() == '\'') +} + +fn err_tokens(tokens: &[TokenTree2], msg: &str) -> syn::Error { + let stream: TokenStream2 = tokens.iter().cloned().collect(); + syn::Error::new_spanned(stream, msg) +} + +fn rewrite_self_in_tokens(tokens: TokenStream2) -> TokenStream2 { + tokens + .into_iter() + .map(|tt| match tt { + TokenTree2::Ident(id) if id == "self" => TokenTree2::Ident(format_ident!("self_")), + TokenTree2::Group(g) => { + let inner = rewrite_self_in_tokens(g.stream()); + TokenTree2::Group(proc_macro2::Group::new(g.delimiter(), inner)) + } + other => other, + }) + .collect() +} + +fn refine_fn_name(func: &FnItemWithSignature, path: &[usize]) -> syn::Ident { + let pos = path + .iter() + .map(|i| i.to_string()) + .collect::>() + .join("_"); + format_ident!("_thrust_refine_{}_{}", func.sig().ident, pos) +} + +fn refine_formula_fn( + func: &FnItemWithSignature, + outer_context: Option<&FnOuterItem>, + r: 
&Refinement, +) -> TokenStream2 { + let name = refine_fn_name(func, &r.path); + let def_generics = generic_params_tokens(&func.sig().generics); + let model_params = fn_params_with_model_ty(&func.sig().inputs); + let model_preds = model_where_predicates(func, outer_context); + let extended_where = extended_where_clause(func, &model_preds); + let binder = &r.binder; + let binder_ty = &r.binder_ty; + let formula = &r.formula; + + quote! { + #[allow(unused_variables)] + #[allow(non_snake_case)] + #[thrust::formula_fn] + fn #name #def_generics( + #binder: <#binder_ty as thrust_models::Model>::Ty, + #model_params + ) -> bool #extended_where { + #formula + } + } +} + +fn refine_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 { + let name = refine_fn_name(func, &r.path); + let turbofish = generic_turbofish(&func.sig().generics); + let path_prefix = if func.sig().receiver().is_some() { + quote!(Self::) + } else { + quote!() + }; + let pos = r + .path + .iter() + .map(|i| proc_macro2::Literal::usize_unsuffixed(*i)) + .collect::>(); + quote! { + #[thrust::refine(#(#pos),*)] + #path_prefix #name #turbofish; + } +} + fn mentions_self(sig: &syn::Signature) -> bool { struct Visitor { mentions_self: bool, From 302f1b8bd04181210fab6f4b175e693f9eb22a12 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 26 May 2026 16:09:32 +0000 Subject: [PATCH 2/8] Model refinement type positions with a structured TypePosition Replace the bare `Vec` type position with `rty::TypePosition` (a `TypePositionRoot` of `Param(idx)` / `Return`, plus a projection of nested type-argument indices). The `#[thrust::refine(..)]` attribute now uses the `result` keyword to select the return instead of the parameter count, which was unintuitive. Adds `Display` (`$1`, `result.0`). 
--- src/analyze.rs | 6 +-- src/analyze/annot.rs | 50 ++++++++++++++++------- src/refine/template.rs | 46 +++++++++++----------- src/rty.rs | 69 +++++++++++++++++++++++++------- thrust-macros/src/lib.rs | 85 +++++++++++++++++++++++++--------------- 5 files changed, 170 insertions(+), 86 deletions(-) diff --git a/src/analyze.rs b/src/analyze.rs index 65df3712..2111c39f 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -792,7 +792,7 @@ impl<'tcx> Analyzer<'tcx> { /// Collects every `#[thrust::refine(..)]` path statement in the function /// body, returning each `(type position, formula_fn DefId)`. - fn extract_refine_paths(&self, local_def_id: LocalDefId) -> Vec<(Vec, DefId)> { + fn extract_refine_paths(&self, local_def_id: LocalDefId) -> Vec<(rty::TypePosition, DefId)> { let mut out = Vec::new(); let Some(body) = self.tcx.hir_maybe_body_owned_by(local_def_id) else { return out; @@ -812,7 +812,7 @@ impl<'tcx> Analyzer<'tcx> { continue; }; let ts = analyze::annot::extract_annot_tokens(attr.clone()); - let position = analyze::annot::parse_position(&ts); + let position = analyze::annot::parse_type_position(&ts); let rustc_hir::StmtKind::Semi(expr) = stmt.kind else { self.tcx.dcx().span_err( @@ -846,7 +846,7 @@ impl<'tcx> Analyzer<'tcx> { &self, local_def_id: LocalDefId, generic_args: mir_ty::GenericArgsRef<'tcx>, - ) -> Vec<(Vec, rty::Refinement)> { + ) -> Vec<(rty::TypePosition, rty::Refinement)> { let mut out = Vec::new(); for (position, def_id) in self.extract_refine_paths(local_def_id) { let Some(formula_def_id) = def_id.as_local() else { diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 0d539a8f..6572da5c 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -211,31 +211,53 @@ pub fn extract_annot_tokens(attr: Attribute) -> TokenStream { d.tokens } -/// Parses a comma-separated list of integer literals (a "type position") from -/// the tokens of a `#[thrust::refine(..)]` attribute. 
-pub fn parse_position(ts: &TokenStream) -> Vec { +/// Parses a [`rty::TypePosition`] from the tokens of a `#[thrust::refine(..)]` +/// attribute. +/// +/// The first token is the root: the keyword `result` for the return, or an +/// integer for a parameter index. The remaining comma-separated integers form +/// the projection into nested type arguments. For example `result, 0` is the +/// first type-argument of the return, and `1` is the second parameter. +pub fn parse_type_position(ts: &TokenStream) -> rty::TypePosition { use rustc_ast::token::{LitKind, TokenKind}; use rustc_ast::tokenstream::TokenTree; - let mut out = Vec::new(); - for tt in ts.iter() { + let parse_int = |lit: &rustc_ast::token::Lit| -> usize { + assert_eq!( + lit.kind, + LitKind::Integer, + "expected an integer in refine position" + ); + lit.symbol + .as_str() + .parse() + .expect("invalid integer in refine position") + }; + + let mut iter = ts.iter(); + let root = match iter.next() { + Some(TokenTree::Token(t, _)) => match &t.kind { + TokenKind::Ident(sym, _) if sym.as_str() == "result" => rty::TypePositionRoot::Return, + TokenKind::Literal(lit) => { + rty::TypePositionRoot::Param(rty::FunctionParamIdx::from(parse_int(lit))) + } + _ => panic!("unexpected refine position root: {:?}", t), + }, + _ => panic!("empty refine position"), + }; + + let mut projection = Vec::new(); + for tt in iter { match tt { TokenTree::Token(t, _) => match &t.kind { TokenKind::Comma => {} - TokenKind::Literal(lit) if lit.kind == LitKind::Integer => { - out.push( - lit.symbol - .as_str() - .parse() - .expect("invalid integer in refine position"), - ); - } + TokenKind::Literal(lit) => projection.push(parse_int(lit)), _ => panic!("unexpected token in refine position: {:?}", t), }, _ => panic!("unexpected token tree in refine position"), } } - out + rty::TypePosition::new(root, projection) } pub fn split_param(ts: &TokenStream) -> (Ident, TokenStream) { diff --git a/src/refine/template.rs b/src/refine/template.rs index 
074b8bb7..04fb8eae 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -566,35 +566,35 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { self } - /// Installs a refinement at a function-type position. The first index of - /// `path` selects a parameter (by index) or the return slot (when it equals - /// the parameter count); the remaining indices descend into the slot's type. + /// Installs a refinement at a [`rty::TypePosition`]. The root selects a + /// parameter or the return slot; the projection then descends into the + /// slot's nested type arguments. pub fn refine( &mut self, - path: &[usize], + position: &rty::TypePosition, refinement: rty::Refinement, ) -> &mut Self { - let (&slot, sub) = path.split_first().expect("refine path must be non-empty"); - let n = self.param_tys.len(); - if slot < n { - let idx = rty::FunctionParamIdx::from(slot); - if !self.param_rtys.contains_key(&idx) { - let ty = self.inner.build(self.param_tys[slot].ty).vacuous(); - self.param_rtys.insert(idx, rty::RefinedType::unrefined(ty)); + match position.root { + rty::TypePositionRoot::Param(idx) => { + if !self.param_rtys.contains_key(&idx) { + let ty = self.inner.build(self.param_tys[idx.index()].ty).vacuous(); + self.param_rtys.insert(idx, rty::RefinedType::unrefined(ty)); + } + self.param_rtys + .get_mut(&idx) + .unwrap() + .install_refinement_at(&position.projection, refinement); } - self.param_rtys - .get_mut(&idx) - .unwrap() - .install_refinement_at(sub, refinement); - } else { - if self.ret_rty.is_none() { - let ty = self.inner.build(self.ret_ty).vacuous(); - self.ret_rty = Some(rty::RefinedType::unrefined(ty)); + rty::TypePositionRoot::Return => { + if self.ret_rty.is_none() { + let ty = self.inner.build(self.ret_ty).vacuous(); + self.ret_rty = Some(rty::RefinedType::unrefined(ty)); + } + self.ret_rty + .as_mut() + .unwrap() + .install_refinement_at(&position.projection, refinement); } - self.ret_rty - .as_mut() - .unwrap() - 
.install_refinement_at(sub, refinement); } self } diff --git a/src/rty.rs b/src/rty.rs index a7c543c0..e80ef866 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -83,6 +83,53 @@ where } } +/// Selects a parameter or the return of a function type — the root of a +/// [`TypePosition`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TypePositionRoot { + Param(FunctionParamIdx), + Return, +} + +impl std::fmt::Display for TypePositionRoot { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + TypePositionRoot::Param(idx) => write!(f, "{}", idx), + TypePositionRoot::Return => f.write_str("result"), + } + } +} + +/// A position addressing a sub-type within a function type, used to attach a +/// refinement. +/// +/// The [`root`](Self::root) selects a parameter or the return; the +/// [`projection`](Self::projection) then descends into nested type arguments +/// (enum type-arguments, `Box` pointee). For example, `result.0` addresses the +/// first type-argument of the return type, and `$1` addresses the second +/// parameter. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TypePosition { + pub root: TypePositionRoot, + pub projection: Vec, +} + +impl TypePosition { + pub fn new(root: TypePositionRoot, projection: Vec) -> Self { + TypePosition { root, projection } + } +} + +impl std::fmt::Display for TypePosition { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.root)?; + for p in &self.projection { + write!(f, ".{}", p)?; + } + Ok(()) + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] pub enum FunctionAbi { #[default] @@ -1476,22 +1523,22 @@ where } impl RefinedType { - /// Installs `refinement` at the given type position, descending through the - /// function type (parameters then return), enum type arguments, and `Box` - /// pointees. An empty path replaces the refinement at the current node. 
+ /// Installs `refinement` at the given projection — a path of nested + /// type-argument indices descending through enum type arguments and `Box` + /// pointees. An empty projection replaces the refinement at this node. pub fn install_refinement_at( &mut self, - path: &[usize], + projection: &[usize], refinement: Refinement, ) { - let Some((&step, rest)) = path.split_first() else { + let Some((&step, rest)) = projection.split_first() else { self.refinement = refinement; return; }; match &mut self.ty { Type::Enum(e) => { let arg = e.args.get_mut(TypeParamIdx::from(step)).unwrap_or_else(|| { - panic!("refine position {} out of range for enum type", step) + panic!("refine projection {} out of range for enum type", step) }); arg.install_refinement_at(rest, refinement); } @@ -1499,16 +1546,8 @@ impl RefinedType { assert_eq!(step, 0, "Box type position must be 0"); p.elem.install_refinement_at(rest, refinement); } - Type::Function(f) => { - let n = f.params.len(); - if step < n { - f.params[FunctionParamIdx::from(step)].install_refinement_at(rest, refinement); - } else { - f.ret.install_refinement_at(rest, refinement); - } - } ty => panic!( - "unsupported type at refine position step {}: {:?}", + "unsupported type at refine projection step {}: {:?}", step, ty ), } diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index d17c914a..50052d5f 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -556,14 +556,23 @@ impl ExpandedTokens { // These lower refinement types (e.g. `List<{ v: i32 | v > 0 }>`) into // `#[thrust::formula_fn]`s plus positioned `#[thrust::refine(..)]` path // statements injected into the function body. The "type position" addresses -// into the function type: the first index selects a parameter (by index) or -// the return slot (== parameter count), and subsequent indices descend into -// generic arguments (enum args / `Box` pointee). 
+// into the function type: its root selects a parameter (by index) or the +// return (the `result` keyword), and the projection (the remaining indices) +// descends into generic arguments (enum args / `Box` pointee). For example, +// `#[thrust::refine(result, 0)]` is the first type-argument of the return. // --------------------------------------------------------------------------- +/// Root of a refinement's type position: a parameter (by index) or the return. +#[derive(Clone, Copy)] +enum RefineRoot { + Param(usize), + Return, +} + #[derive(Clone)] struct Refinement { - path: Vec, + root: RefineRoot, + projection: Vec, binder: syn::Ident, binder_ty: TokenStream2, formula: TokenStream2, @@ -612,7 +621,7 @@ fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStream) -> Toke let mut refinements = Vec::new(); for (root, ty_tokens) in jobs { - if let Err(e) = scan_type(&ty_tokens, &root, &mut refinements) { + if let Err(e) = scan_type(&ty_tokens, root, &[], &mut refinements) { let err = e.to_compile_error(); return quote! { #err #func }.into(); } @@ -656,28 +665,27 @@ fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStream) -> Toke .into() } -/// Builds `(root_path, type_tokens)` jobs to scan from the attribute tokens. +/// Builds `(root, type_tokens)` jobs to scan from the attribute tokens. 
fn build_refine_jobs( kind: RefineKind, func: &FnItemWithSignature, attr_tokens: &[TokenTree2], -) -> syn::Result, Vec)>> { - let param_count = func.sig().inputs.len(); +) -> syn::Result)>> { match kind { RefineKind::Param => { let (name, ty_tokens) = split_name_type(attr_tokens)?; let idx = param_index(func, &name)?; - Ok(vec![(vec![idx], ty_tokens)]) + Ok(vec![(RefineRoot::Param(idx), ty_tokens)]) } - RefineKind::Ret => Ok(vec![(vec![param_count], attr_tokens.to_vec())]), + RefineKind::Ret => Ok(vec![(RefineRoot::Return, attr_tokens.to_vec())]), RefineKind::Sig => { let (args, ret_tokens) = parse_sig_attr(attr_tokens)?; let mut jobs = Vec::new(); for (name, ty_tokens) in args { let idx = param_index(func, &name)?; - jobs.push((vec![idx], ty_tokens)); + jobs.push((RefineRoot::Param(idx), ty_tokens)); } - jobs.push((vec![param_count], ret_tokens)); + jobs.push((RefineRoot::Return, ret_tokens)); Ok(jobs) } } @@ -748,8 +756,14 @@ fn parse_sig_attr( } /// Scans a single type expression, recording every refinement node together -/// with its type position. -fn scan_type(tokens: &[TokenTree2], path: &[usize], out: &mut Vec) -> syn::Result<()> { +/// with its type position (a fixed `root` plus the `projection` accumulated +/// while descending into nested type arguments). +fn scan_type( + tokens: &[TokenTree2], + root: RefineRoot, + projection: &[usize], + out: &mut Vec, +) -> syn::Result<()> { if tokens.is_empty() { return Ok(()); } @@ -760,14 +774,15 @@ fn scan_type(tokens: &[TokenTree2], path: &[usize], out: &mut Vec) - if g.delimiter() == proc_macro2::Delimiter::Brace { let (binder, binder_ty, formula) = split_refinement(g.stream())?; out.push(Refinement { - path: path.to_vec(), + root, + projection: projection.to_vec(), binder, binder_ty: binder_ty.iter().cloned().collect(), formula, }); // The refinement's own type sits at the same position; descend // into it to find further nested refinements. 
- scan_type(&binder_ty, path, out)?; + scan_type(&binder_ty, root, projection, out)?; return Ok(()); } } @@ -782,9 +797,9 @@ fn scan_type(tokens: &[TokenTree2], path: &[usize], out: &mut Vec) - if is_lifetime(&arg) { continue; } - let mut child = path.to_vec(); + let mut child = projection.to_vec(); child.push(type_idx); - scan_type(&arg, &child, out)?; + scan_type(&arg, root, &child, out)?; type_idx += 1; } } @@ -890,12 +905,14 @@ fn rewrite_self_in_tokens(tokens: TokenStream2) -> TokenStream2 { .collect() } -fn refine_fn_name(func: &FnItemWithSignature, path: &[usize]) -> syn::Ident { - let pos = path - .iter() - .map(|i| i.to_string()) - .collect::>() - .join("_"); +fn refine_fn_name(func: &FnItemWithSignature, r: &Refinement) -> syn::Ident { + let mut pos = match r.root { + RefineRoot::Param(i) => format!("p{}", i), + RefineRoot::Return => "ret".to_string(), + }; + for p in &r.projection { + pos.push_str(&format!("_{}", p)); + } format_ident!("_thrust_refine_{}_{}", func.sig().ident, pos) } @@ -904,7 +921,7 @@ fn refine_formula_fn( outer_context: Option<&FnOuterItem>, r: &Refinement, ) -> TokenStream2 { - let name = refine_fn_name(func, &r.path); + let name = refine_fn_name(func, r); let def_generics = generic_params_tokens(&func.sig().generics); let model_params = fn_params_with_model_ty(&func.sig().inputs); let model_preds = model_where_predicates(func, outer_context); @@ -927,20 +944,26 @@ fn refine_formula_fn( } fn refine_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 { - let name = refine_fn_name(func, &r.path); + let name = refine_fn_name(func, r); let turbofish = generic_turbofish(&func.sig().generics); let path_prefix = if func.sig().receiver().is_some() { quote!(Self::) } else { quote!() }; - let pos = r - .path + let root = match r.root { + RefineRoot::Param(i) => { + let lit = proc_macro2::Literal::usize_unsuffixed(i); + quote!(#lit) + } + RefineRoot::Return => quote!(result), + }; + let projection = r + .projection .iter() - 
.map(|i| proc_macro2::Literal::usize_unsuffixed(*i)) - .collect::>(); + .map(|i| proc_macro2::Literal::usize_unsuffixed(*i)); quote! { - #[thrust::refine(#(#pos),*)] + #[thrust::refine(#root #(, #projection)*)] #path_prefix #name #turbofish; } } From b889d56f407252657ae1e1e032e1841980f2ee3b Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 27 May 2026 01:56:35 +0000 Subject: [PATCH 3/8] Redesign TypePosition as a flat sequence of TypePositionStep Replace the previous TypePositionRoot + projection: Vec<usize> split with a flat Vec<TypePositionStep>, where each step is one of: - Param(FunctionParamIdx): navigate into a function type's parameter - Return: navigate into a function type's return slot - TypeArg(usize): navigate into a generic type's type argument This makes the path representation uniform across all type levels, enabling future support for refinements on positions inside higher-order function types (e.g. [$0, result] for the return type of a function-typed first parameter). The attribute encoding changes accordingly: - result (ident) => Return - integer i => Param(i) - bracket group [i] => TypeArg(i) The macro-side RefineRoot+projection split is similarly replaced with a flat Vec<RefineStep> in the Refinement struct, and scan_type now threads the full steps Vec through recursive calls. https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- src/analyze/annot.rs | 62 ++++++++++++------- src/refine/template.rs | 32 ++++++---- src/rty.rs | 129 +++++++++++++++++++++++++++------------ thrust-macros/src/lib.rs | 101 +++++++++++++++++------------- 4 files changed, 210 insertions(+), 114 deletions(-) diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 6572da5c..91ce4072 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -214,12 +214,19 @@ pub fn extract_annot_tokens(attr: Attribute) -> TokenStream { /// Parses a [`rty::TypePosition`] from the tokens of a `#[thrust::refine(..)]` /// attribute.
/// -/// The first token is the root: the keyword `result` for the return, or an -/// integer for a parameter index. The remaining comma-separated integers form -/// the projection into nested type arguments. For example `result, 0` is the -/// first type-argument of the return, and `1` is the second parameter. +/// Tokens are comma-separated steps. Each step is one of: +/// - The keyword `result` → [`rty::TypePositionStep::Return`] (navigate to a +/// function type's return slot). +/// - An integer literal `i` → [`rty::TypePositionStep::Param`]`(i)` (navigate +/// to the `i`-th parameter of a function type). +/// - A bracket group `[i]` → [`rty::TypePositionStep::TypeArg`]`(i)` (navigate +/// to the `i`-th type argument of a generic type such as an enum or `Box`). +/// +/// Examples: `result` is the return; `0` is the first parameter; `0, [0]` is +/// the first type-argument of the first parameter; `0, result` is the return of +/// a function-typed first parameter. pub fn parse_type_position(ts: &TokenStream) -> rty::TypePosition { - use rustc_ast::token::{LitKind, TokenKind}; + use rustc_ast::token::{Delimiter, LitKind, TokenKind}; use rustc_ast::tokenstream::TokenTree; let parse_int = |lit: &rustc_ast::token::Lit| -> usize { @@ -234,30 +241,43 @@ pub fn parse_type_position(ts: &TokenStream) -> rty::TypePosition { .expect("invalid integer in refine position") }; - let mut iter = ts.iter(); - let root = match iter.next() { - Some(TokenTree::Token(t, _)) => match &t.kind { - TokenKind::Ident(sym, _) if sym.as_str() == "result" => rty::TypePositionRoot::Return, - TokenKind::Literal(lit) => { - rty::TypePositionRoot::Param(rty::FunctionParamIdx::from(parse_int(lit))) - } - _ => panic!("unexpected refine position root: {:?}", t), - }, - _ => panic!("empty refine position"), - }; - - let mut projection = Vec::new(); - for tt in iter { + let mut steps = Vec::new(); + for tt in ts.iter() { match tt { TokenTree::Token(t, _) => match &t.kind { TokenKind::Comma => {} - 
TokenKind::Literal(lit) => projection.push(parse_int(lit)), + TokenKind::Ident(sym, _) if sym.as_str() == "result" => { + steps.push(rty::TypePositionStep::Return); + } + TokenKind::Literal(lit) => { + steps.push(rty::TypePositionStep::Param(rty::FunctionParamIdx::from( + parse_int(lit), + ))); + } _ => panic!("unexpected token in refine position: {:?}", t), }, + TokenTree::Delimited(_, _, Delimiter::Bracket, inner) => { + let mut inner_iter = inner.iter(); + let i = match inner_iter.next() { + Some(TokenTree::Token(t, _)) => match &t.kind { + TokenKind::Literal(lit) => parse_int(lit), + _ => panic!("expected integer inside [..] refine step: {:?}", t), + }, + _ => panic!("expected integer inside [..] refine step"), + }; + assert!( + inner_iter.next().is_none(), + "expected exactly one integer inside [..] refine step" + ); + steps.push(rty::TypePositionStep::TypeArg(i)); + } _ => panic!("unexpected token tree in refine position"), } } - rty::TypePosition::new(root, projection) + + assert!(!steps.is_empty(), "empty refine position"); + let first = steps.remove(0); + rty::TypePosition::new(first, steps) } pub fn split_param(ts: &TokenStream) -> (Ident, TokenStream) { diff --git a/src/refine/template.rs b/src/refine/template.rs index 04fb8eae..50408612 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -566,26 +566,33 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { self } - /// Installs a refinement at a [`rty::TypePosition`]. The root selects a - /// parameter or the return slot; the projection then descends into the - /// slot's nested type arguments. + /// Installs a refinement at a [`rty::TypePosition`]. + /// + /// The first step must be [`rty::TypePositionStep::Param`] or + /// [`rty::TypePositionStep::Return`]; the remaining steps are forwarded to + /// [`rty::RefinedType::install_refinement_at`]. 
pub fn refine( &mut self, position: &rty::TypePosition, refinement: rty::Refinement, ) -> &mut Self { - match position.root { - rty::TypePositionRoot::Param(idx) => { - if !self.param_rtys.contains_key(&idx) { + let (first, rest) = match position.steps().split_first() { + Some(pair) => pair, + None => panic!("empty TypePosition"), + }; + match first { + rty::TypePositionStep::Param(idx) => { + if !self.param_rtys.contains_key(idx) { let ty = self.inner.build(self.param_tys[idx.index()].ty).vacuous(); - self.param_rtys.insert(idx, rty::RefinedType::unrefined(ty)); + self.param_rtys + .insert(*idx, rty::RefinedType::unrefined(ty)); } self.param_rtys - .get_mut(&idx) + .get_mut(idx) .unwrap() - .install_refinement_at(&position.projection, refinement); + .install_refinement_at(rest, refinement); } - rty::TypePositionRoot::Return => { + rty::TypePositionStep::Return => { if self.ret_rty.is_none() { let ty = self.inner.build(self.ret_ty).vacuous(); self.ret_rty = Some(rty::RefinedType::unrefined(ty)); @@ -593,7 +600,10 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { self.ret_rty .as_mut() .unwrap() - .install_refinement_at(&position.projection, refinement); + .install_refinement_at(rest, refinement); + } + rty::TypePositionStep::TypeArg(_) => { + panic!("TypePosition must start with Param or Return, not TypeArg"); } } self diff --git a/src/rty.rs b/src/rty.rs index e80ef866..b0f10f3d 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -83,48 +83,81 @@ where } } -/// Selects a parameter or the return of a function type — the root of a -/// [`TypePosition`]. +/// One step in a [`TypePosition`] path. +/// +/// A path is a sequence of steps that addresses a sub-type within a +/// (potentially nested) function signature: +/// - [`Param`](Self::Param) / [`Return`](Self::Return) navigate into a +/// function type's parameter or return slot respectively. +/// - [`TypeArg`](Self::TypeArg) navigates into the `i`-th type argument of a +/// generic type (enum, `Box`, etc.). 
+/// +/// Using distinct variants for function navigation ([`Param`](Self::Param), +/// [`Return`](Self::Return)) and generic-arg navigation +/// ([`TypeArg`](Self::TypeArg)) allows the same path representation to address +/// positions inside higher-order function types. For example, `[$0, result]` +/// addresses the return type of a function-typed first parameter. #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum TypePositionRoot { +pub enum TypePositionStep { + /// Navigate to the `i`-th parameter of a function type. Param(FunctionParamIdx), + /// Navigate to the return type of a function type. Return, + /// Navigate to the `i`-th type argument of a generic type (enum, `Box`, …). + TypeArg(usize), } -impl std::fmt::Display for TypePositionRoot { +impl std::fmt::Display for TypePositionStep { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { - TypePositionRoot::Param(idx) => write!(f, "{}", idx), - TypePositionRoot::Return => f.write_str("result"), + TypePositionStep::Param(idx) => write!(f, "{}", idx), + TypePositionStep::Return => f.write_str("result"), + TypePositionStep::TypeArg(i) => write!(f, "[{}]", i), } } } -/// A position addressing a sub-type within a function type, used to attach a -/// refinement. +/// A path addressing a sub-type in a function's type signature, used to attach +/// a refinement. /// -/// The [`root`](Self::root) selects a parameter or the return; the -/// [`projection`](Self::projection) then descends into nested type arguments -/// (enum type-arguments, `Box` pointee). For example, `result.0` addresses the -/// first type-argument of the return type, and `$1` addresses the second -/// parameter. +/// The first step must be [`TypePositionStep::Param`] or +/// [`TypePositionStep::Return`] (selecting which slot of the top-level function +/// type to enter). 
Subsequent steps can freely combine +/// [`TypePositionStep::Param`] / [`TypePositionStep::Return`] (for +/// function-typed positions) and [`TypePositionStep::TypeArg`] (for generic +/// types), enabling positions inside higher-order function types. +/// +/// Examples (function `fn f(x: List) -> Box`): +/// - `[$0]` — parameter `x`. +/// - `[result]` — the return type. +/// - `[$0, [0]]` — the first type arg of `x`. +/// - `[result, [0]]` — the pointee of the `Box` return. +/// - `[$0, result]` — the return of a function-typed param `x`. #[derive(Debug, Clone, PartialEq, Eq)] pub struct TypePosition { - pub root: TypePositionRoot, - pub projection: Vec, + steps: Vec, } impl TypePosition { - pub fn new(root: TypePositionRoot, projection: Vec) -> Self { - TypePosition { root, projection } + pub fn new(first: TypePositionStep, rest: Vec) -> Self { + let mut steps = vec![first]; + steps.extend(rest); + TypePosition { steps } + } + + pub fn steps(&self) -> &[TypePositionStep] { + &self.steps } } impl std::fmt::Display for TypePosition { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}", self.root)?; - for p in &self.projection { - write!(f, ".{}", p)?; + let mut iter = self.steps.iter(); + if let Some(first) = iter.next() { + write!(f, "{}", first)?; + } + for s in iter { + write!(f, ".{}", s)?; } Ok(()) } @@ -1523,33 +1556,49 @@ where } impl RefinedType { - /// Installs `refinement` at the given projection — a path of nested - /// type-argument indices descending through enum type arguments and `Box` - /// pointees. An empty projection replaces the refinement at this node. + /// Installs `refinement` at the sub-type addressed by `steps`. + /// + /// An empty `steps` slice replaces the refinement at this node. Each step + /// in the slice navigates one level deeper: + /// - [`TypePositionStep::TypeArg`] descends into enum type arguments or the + /// `Box` pointee. 
+ /// - [`TypePositionStep::Param`] / [`TypePositionStep::Return`] descend + /// into a function-typed position's parameter or return slot. pub fn install_refinement_at( &mut self, - projection: &[usize], + steps: &[TypePositionStep], refinement: Refinement, ) { - let Some((&step, rest)) = projection.split_first() else { + let Some((step, rest)) = steps.split_first() else { self.refinement = refinement; return; }; - match &mut self.ty { - Type::Enum(e) => { - let arg = e.args.get_mut(TypeParamIdx::from(step)).unwrap_or_else(|| { - panic!("refine projection {} out of range for enum type", step) - }); - arg.install_refinement_at(rest, refinement); - } - Type::Pointer(p) => { - assert_eq!(step, 0, "Box type position must be 0"); - p.elem.install_refinement_at(rest, refinement); - } - ty => panic!( - "unsupported type at refine projection step {}: {:?}", - step, ty - ), + match step { + TypePositionStep::TypeArg(i) => match &mut self.ty { + Type::Enum(e) => { + let arg = e.args.get_mut(TypeParamIdx::from(*i)).unwrap_or_else(|| { + panic!("refine step [{}] out of range for enum type", i) + }); + arg.install_refinement_at(rest, refinement); + } + Type::Pointer(p) => { + assert_eq!(*i, 0, "Box type position must be [0]"); + p.elem.install_refinement_at(rest, refinement); + } + ty => panic!("TypeArg step on unsupported type: {:?}", ty), + }, + TypePositionStep::Param(idx) => match &mut self.ty { + Type::Function(func) => { + func.params[*idx].install_refinement_at(rest, refinement); + } + ty => panic!("Param step on non-function type: {:?}", ty), + }, + TypePositionStep::Return => match &mut self.ty { + Type::Function(func) => { + func.ret.install_refinement_at(rest, refinement); + } + ty => panic!("Return step on non-function type: {:?}", ty), + }, } } } diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index 50052d5f..f4619d13 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -562,17 +562,25 @@ impl ExpandedTokens { // 
`#[thrust::refine(result, 0)]` is the first type-argument of the return. // --------------------------------------------------------------------------- -/// Root of a refinement's type position: a parameter (by index) or the return. +/// One step in a refinement's type-position path. +/// +/// Mirrors [`rty::TypePositionStep`] on the plugin side and uses the same +/// attribute encoding: +/// - [`Param`](Self::Param) / [`Return`](Self::Return) navigate into a function +/// type; encoded as an integer literal / the `result` keyword. +/// - [`TypeArg`](Self::TypeArg) navigates into a generic type argument; encoded +/// as a bracket group `[i]`. #[derive(Clone, Copy)] -enum RefineRoot { +enum RefineStep { Param(usize), Return, + TypeArg(usize), } #[derive(Clone)] struct Refinement { - root: RefineRoot, - projection: Vec, + /// Full type-position path from the function root to the refined type. + steps: Vec, binder: syn::Ident, binder_ty: TokenStream2, formula: TokenStream2, @@ -620,8 +628,8 @@ fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStream) -> Toke }; let mut refinements = Vec::new(); - for (root, ty_tokens) in jobs { - if let Err(e) = scan_type(&ty_tokens, root, &[], &mut refinements) { + for (root_steps, ty_tokens) in jobs { + if let Err(e) = scan_type(&ty_tokens, root_steps, &mut refinements) { let err = e.to_compile_error(); return quote! { #err #func }.into(); } @@ -665,27 +673,32 @@ fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStream) -> Toke .into() } -/// Builds `(root, type_tokens)` jobs to scan from the attribute tokens. +/// Builds `(root_steps, type_tokens)` jobs to scan from the attribute tokens. +/// +/// Each job's `root_steps` contains the initial [`RefineStep`]s that fix the +/// position of the type expression within the function signature (e.g. +/// `[Param(0)]` for the first parameter). [`scan_type`] will append further +/// [`RefineStep::TypeArg`] steps as it descends into generic type arguments. 
fn build_refine_jobs( kind: RefineKind, func: &FnItemWithSignature, attr_tokens: &[TokenTree2], -) -> syn::Result)>> { +) -> syn::Result, Vec)>> { match kind { RefineKind::Param => { let (name, ty_tokens) = split_name_type(attr_tokens)?; let idx = param_index(func, &name)?; - Ok(vec![(RefineRoot::Param(idx), ty_tokens)]) + Ok(vec![(vec![RefineStep::Param(idx)], ty_tokens)]) } - RefineKind::Ret => Ok(vec![(RefineRoot::Return, attr_tokens.to_vec())]), + RefineKind::Ret => Ok(vec![(vec![RefineStep::Return], attr_tokens.to_vec())]), RefineKind::Sig => { let (args, ret_tokens) = parse_sig_attr(attr_tokens)?; let mut jobs = Vec::new(); for (name, ty_tokens) in args { let idx = param_index(func, &name)?; - jobs.push((RefineRoot::Param(idx), ty_tokens)); + jobs.push((vec![RefineStep::Param(idx)], ty_tokens)); } - jobs.push((RefineRoot::Return, ret_tokens)); + jobs.push((vec![RefineStep::Return], ret_tokens)); Ok(jobs) } } @@ -755,13 +768,16 @@ fn parse_sig_attr( Ok((args, rest.to_vec())) } -/// Scans a single type expression, recording every refinement node together -/// with its type position (a fixed `root` plus the `projection` accumulated -/// while descending into nested type arguments). +/// Scans a type expression and records every refinement node with its full +/// type-position path (`steps`). +/// +/// `steps` holds the path from the function root to the current type node. +/// When a refinement `{binder: ty | formula}` is found the current `steps` are +/// recorded; when descending into generic type arguments a +/// [`RefineStep::TypeArg`]`(i)` step is appended to `steps`. 
fn scan_type( tokens: &[TokenTree2], - root: RefineRoot, - projection: &[usize], + steps: Vec, out: &mut Vec, ) -> syn::Result<()> { if tokens.is_empty() { @@ -774,15 +790,13 @@ fn scan_type( if g.delimiter() == proc_macro2::Delimiter::Brace { let (binder, binder_ty, formula) = split_refinement(g.stream())?; out.push(Refinement { - root, - projection: projection.to_vec(), + steps: steps.clone(), binder, binder_ty: binder_ty.iter().cloned().collect(), formula, }); - // The refinement's own type sits at the same position; descend - // into it to find further nested refinements. - scan_type(&binder_ty, root, projection, out)?; + // Descend into the binder type for nested refinements. + scan_type(&binder_ty, steps, out)?; return Ok(()); } } @@ -797,9 +811,9 @@ fn scan_type( if is_lifetime(&arg) { continue; } - let mut child = projection.to_vec(); - child.push(type_idx); - scan_type(&arg, root, &child, out)?; + let mut child = steps.clone(); + child.push(RefineStep::TypeArg(type_idx)); + scan_type(&arg, child, out)?; type_idx += 1; } } @@ -906,13 +920,16 @@ fn rewrite_self_in_tokens(tokens: TokenStream2) -> TokenStream2 { } fn refine_fn_name(func: &FnItemWithSignature, r: &Refinement) -> syn::Ident { - let mut pos = match r.root { - RefineRoot::Param(i) => format!("p{}", i), - RefineRoot::Return => "ret".to_string(), - }; - for p in &r.projection { - pos.push_str(&format!("_{}", p)); - } + let pos = r + .steps + .iter() + .map(|s| match s { + RefineStep::Param(i) => format!("p{}", i), + RefineStep::Return => "ret".to_string(), + RefineStep::TypeArg(i) => format!("t{}", i), + }) + .collect::>() + .join("_"); format_ident!("_thrust_refine_{}_{}", func.sig().ident, pos) } @@ -951,19 +968,19 @@ fn refine_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 } else { quote!() }; - let root = match r.root { - RefineRoot::Param(i) => { - let lit = proc_macro2::Literal::usize_unsuffixed(i); + let encoded_steps = r.steps.iter().map(|s| match s { + 
RefineStep::Param(i) => { + let lit = proc_macro2::Literal::usize_unsuffixed(*i); quote!(#lit) } - RefineRoot::Return => quote!(result), - }; - let projection = r - .projection - .iter() - .map(|i| proc_macro2::Literal::usize_unsuffixed(*i)); + RefineStep::Return => quote!(result), + RefineStep::TypeArg(i) => { + let lit = proc_macro2::Literal::usize_unsuffixed(*i); + quote!([#lit]) + } + }); quote! { - #[thrust::refine(#root #(, #projection)*)] + #[thrust::refine(#(#encoded_steps),*)] #path_prefix #name #turbofish; } } From c2c0381450d3109934e898ee94ef6b4ad549fe00 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 27 May 2026 08:49:56 +0000 Subject: [PATCH 4/8] Rename thrust::refine attribute to thrust::refinement_path Align the macro-emitted attribute with the existing requires_path / ensures_path convention: the _path suffix conveys that the attribute's target is a path to a formula_fn. The symbol-path helper follows suit (refine_path -> refinement_path_path). https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- src/analyze.rs | 10 +++++----- src/analyze/annot.rs | 8 ++++---- thrust-macros/src/lib.rs | 15 ++++++++------- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/analyze.rs b/src/analyze.rs index 2111c39f..e16fb08b 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -790,8 +790,8 @@ impl<'tcx> Analyzer<'tcx> { ensure_annot } - /// Collects every `#[thrust::refine(..)]` path statement in the function - /// body, returning each `(type position, formula_fn DefId)`. + /// Collects every `#[thrust::refinement_path(..)]` path statement in the + /// function body, returning each `(type position, formula_fn DefId)`. 
fn extract_refine_paths(&self, local_def_id: LocalDefId) -> Vec<(rty::TypePosition, DefId)> { let mut out = Vec::new(); let Some(body) = self.tcx.hir_maybe_body_owned_by(local_def_id) else { @@ -800,7 +800,7 @@ impl<'tcx> Analyzer<'tcx> { let rustc_hir::ExprKind::Block(block, _) = body.value.kind else { return out; }; - let attr_path = analyze::annot::refine_path(); + let attr_path = analyze::annot::refinement_path_path(); let typeck = self.tcx.typeck(local_def_id); for stmt in block.stmts { let Some(attr) = self @@ -840,8 +840,8 @@ impl<'tcx> Analyzer<'tcx> { out } - /// Resolves every `#[thrust::refine(..)]` annotation into a positioned - /// refinement, by translating the referenced formula function. + /// Resolves every `#[thrust::refinement_path(..)]` annotation into a + /// positioned refinement, by translating the referenced formula function. pub fn extract_refine_annots( &self, local_def_id: LocalDefId, diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 91ce4072..6a5da38c 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -61,8 +61,8 @@ pub fn ensures_path_path() -> [Symbol; 2] { [Symbol::intern("thrust"), Symbol::intern("ensures_path")] } -pub fn refine_path() -> [Symbol; 2] { - [Symbol::intern("thrust"), Symbol::intern("refine")] +pub fn refinement_path_path() -> [Symbol; 2] { + [Symbol::intern("thrust"), Symbol::intern("refinement_path")] } pub fn model_ty_path() -> [Symbol; 3] { @@ -211,8 +211,8 @@ pub fn extract_annot_tokens(attr: Attribute) -> TokenStream { d.tokens } -/// Parses a [`rty::TypePosition`] from the tokens of a `#[thrust::refine(..)]` -/// attribute. +/// Parses a [`rty::TypePosition`] from the tokens of a +/// `#[thrust::refinement_path(..)]` attribute. /// /// Tokens are comma-separated steps. 
Each step is one of: /// - The keyword `result` → [`rty::TypePositionStep::Return`] (navigate to a diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index f4619d13..03e351e8 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -554,12 +554,13 @@ impl ExpandedTokens { // Refinement-type annotations: `param`, `ret`, `sig`. // // These lower refinement types (e.g. `List<{ v: i32 | v > 0 }>`) into -// `#[thrust::formula_fn]`s plus positioned `#[thrust::refine(..)]` path -// statements injected into the function body. The "type position" addresses -// into the function type: its root selects a parameter (by index) or the -// return (the `result` keyword), and the projection (the remaining indices) -// descends into generic arguments (enum args / `Box` pointee). For example, -// `#[thrust::refine(result, 0)]` is the first type-argument of the return. +// `#[thrust::formula_fn]`s plus positioned `#[thrust::refinement_path(..)]` +// path statements injected into the function body. The "type position" +// addresses into the function type: a parameter (by index) or the return (the +// `result` keyword) selects a function slot, and bracket steps (`[i]`) descend +// into generic arguments (enum args / `Box` pointee). For example, +// `#[thrust::refinement_path(result, [0])]` is the first type-argument of the +// return. // --------------------------------------------------------------------------- /// One step in a refinement's type-position path. @@ -980,7 +981,7 @@ fn refine_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 } }); quote! 
{ - #[thrust::refine(#(#encoded_steps),*)] + #[thrust::refinement_path(#(#encoded_steps),*)] #path_prefix #name #turbofish; } } From fad37e24b044b82877c23aa2c8079bb32f78bd15 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 27 May 2026 12:34:27 +0000 Subject: [PATCH 5/8] Extract refinement-type annotation macros into refine module Move the param/ret/sig expansion logic and its token-scanning helpers out of the crate root into a dedicated refine module, leaving only the thin proc-macro entry points in lib.rs. Shrinks lib.rs from ~1232 to ~813 lines. https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- thrust-macros/src/lib.rs | 431 +----------------------------------- thrust-macros/src/refine.rs | 428 +++++++++++++++++++++++++++++++++++ 2 files changed, 434 insertions(+), 425 deletions(-) create mode 100644 thrust-macros/src/refine.rs diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index 03e351e8..70969bb9 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -1,11 +1,13 @@ use proc_macro::TokenStream; -use proc_macro2::{TokenStream as TokenStream2, TokenTree as TokenTree2}; +use proc_macro2::TokenStream as TokenStream2; use quote::{format_ident, quote, ToTokens}; use syn::{ parse_macro_input, punctuated::Punctuated, FnArg, GenericParam, Generics, TypeParamBound, WherePredicate, }; +mod refine; + #[derive(Debug, Clone)] enum FnOuterItem { ItemImpl(syn::ItemImpl), @@ -550,440 +552,19 @@ impl ExpandedTokens { } } -// --------------------------------------------------------------------------- -// Refinement-type annotations: `param`, `ret`, `sig`. -// -// These lower refinement types (e.g. `List<{ v: i32 | v > 0 }>`) into -// `#[thrust::formula_fn]`s plus positioned `#[thrust::refinement_path(..)]` -// path statements injected into the function body. 
The "type position" -// addresses into the function type: a parameter (by index) or the return (the -// `result` keyword) selects a function slot, and bracket steps (`[i]`) descend -// into generic arguments (enum args / `Box` pointee). For example, -// `#[thrust::refinement_path(result, [0])]` is the first type-argument of the -// return. -// --------------------------------------------------------------------------- - -/// One step in a refinement's type-position path. -/// -/// Mirrors [`rty::TypePositionStep`] on the plugin side and uses the same -/// attribute encoding: -/// - [`Param`](Self::Param) / [`Return`](Self::Return) navigate into a function -/// type; encoded as an integer literal / the `result` keyword. -/// - [`TypeArg`](Self::TypeArg) navigates into a generic type argument; encoded -/// as a bracket group `[i]`. -#[derive(Clone, Copy)] -enum RefineStep { - Param(usize), - Return, - TypeArg(usize), -} - -#[derive(Clone)] -struct Refinement { - /// Full type-position path from the function root to the refined type. 
- steps: Vec, - binder: syn::Ident, - binder_ty: TokenStream2, - formula: TokenStream2, -} - #[proc_macro_attribute] pub fn param(attr: TokenStream, item: TokenStream) -> TokenStream { - expand_refine(RefineKind::Param, attr, item) + refine::expand_refine(refine::RefineKind::Param, attr, item) } #[proc_macro_attribute] pub fn ret(attr: TokenStream, item: TokenStream) -> TokenStream { - expand_refine(RefineKind::Ret, attr, item) + refine::expand_refine(refine::RefineKind::Ret, attr, item) } #[proc_macro_attribute] pub fn sig(attr: TokenStream, item: TokenStream) -> TokenStream { - expand_refine(RefineKind::Sig, attr, item) -} - -enum RefineKind { - Param, - Ret, - Sig, -} - -fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStream) -> TokenStream { - let mut func = parse_macro_input!(item as FnItemWithSignature); - - let outer_context = match extract_outer_context(&func) { - Ok(ctx) => ctx, - Err(e) => { - let err = e.to_compile_error(); - return quote! { #err #func }.into(); - } - }; - - let attr_tokens: Vec = TokenStream2::from(attr).into_iter().collect(); - let jobs = match build_refine_jobs(kind, &func, &attr_tokens) { - Ok(jobs) => jobs, - Err(e) => { - let err = e.to_compile_error(); - return quote! { #err #func }.into(); - } - }; - - let mut refinements = Vec::new(); - for (root_steps, ty_tokens) in jobs { - if let Err(e) = scan_type(&ty_tokens, root_steps, &mut refinements) { - let err = e.to_compile_error(); - return quote! 
{ #err #func }.into(); - } - } - - if refinements.is_empty() { - return func.into_token_stream().into(); - } - - let has_receiver = func.sig().receiver().is_some(); - let mut formula_fns = Vec::new(); - let mut path_stmts = Vec::new(); - for mut r in refinements { - if has_receiver { - r.formula = rewrite_self_in_tokens(r.formula); - } - formula_fns.push(refine_formula_fn(&func, outer_context.as_ref(), &r)); - path_stmts.push(refine_path_stmt(&func, &r)); - } - - let Some(block) = func.block_mut() else { - let err = syn::Error::new_spanned( - func.sig().ident.clone(), - "refinement-type annotations require a function body", - ) - .into_compile_error(); - return quote! { #err #func }.into(); - }; - let orig_stmts = block.stmts.drain(..).collect::>(); - *block = syn::parse_quote!({ - #(#path_stmts)* - #(#orig_stmts)* - }); - func.attrs_mut() - .push(syn::parse_quote!(#[allow(path_statements)])); - - quote! { - #(#formula_fns)* - #func - } - .into() -} - -/// Builds `(root_steps, type_tokens)` jobs to scan from the attribute tokens. -/// -/// Each job's `root_steps` contains the initial [`RefineStep`]s that fix the -/// position of the type expression within the function signature (e.g. -/// `[Param(0)]` for the first parameter). [`scan_type`] will append further -/// [`RefineStep::TypeArg`] steps as it descends into generic type arguments. 
-fn build_refine_jobs( - kind: RefineKind, - func: &FnItemWithSignature, - attr_tokens: &[TokenTree2], -) -> syn::Result, Vec)>> { - match kind { - RefineKind::Param => { - let (name, ty_tokens) = split_name_type(attr_tokens)?; - let idx = param_index(func, &name)?; - Ok(vec![(vec![RefineStep::Param(idx)], ty_tokens)]) - } - RefineKind::Ret => Ok(vec![(vec![RefineStep::Return], attr_tokens.to_vec())]), - RefineKind::Sig => { - let (args, ret_tokens) = parse_sig_attr(attr_tokens)?; - let mut jobs = Vec::new(); - for (name, ty_tokens) in args { - let idx = param_index(func, &name)?; - jobs.push((vec![RefineStep::Param(idx)], ty_tokens)); - } - jobs.push((vec![RefineStep::Return], ret_tokens)); - Ok(jobs) - } - } -} - -fn param_index(func: &FnItemWithSignature, name: &syn::Ident) -> syn::Result { - let pos = func.sig().inputs.iter().position(|arg| match arg { - FnArg::Receiver(_) => name == "self", - FnArg::Typed(pt) => matches!(&*pt.pat, syn::Pat::Ident(pi) if &pi.ident == name), - }); - pos.ok_or_else(|| { - syn::Error::new_spanned(name, format!("no parameter named `{}` in signature", name)) - }) -} - -/// Parses `name : ` from a flat token slice. -fn split_name_type(tokens: &[TokenTree2]) -> syn::Result<(syn::Ident, Vec)> { - let name = match tokens.first() { - Some(TokenTree2::Ident(id)) => id.clone(), - _ => return Err(err_tokens(tokens, "expected a parameter name")), - }; - match tokens.get(1) { - Some(TokenTree2::Punct(p)) if p.as_char() == ':' => {} - _ => return Err(err_tokens(tokens, "expected `:` after parameter name")), - } - Ok((name, tokens[2..].to_vec())) -} - -/// Parses `fn ( n0: t0 , ... ) -> ret` into `((name, ty_tokens)*, ret_tokens)`. 
-#[allow(clippy::type_complexity)] -fn parse_sig_attr( - tokens: &[TokenTree2], -) -> syn::Result<(Vec<(syn::Ident, Vec)>, Vec)> { - match tokens.first() { - Some(TokenTree2::Ident(id)) if id == "fn" => {} - _ => return Err(err_tokens(tokens, "expected `fn` in sig annotation")), - } - let arg_group = match tokens.get(1) { - Some(TokenTree2::Group(g)) if g.delimiter() == proc_macro2::Delimiter::Parenthesis => g, - _ => return Err(err_tokens(tokens, "expected `(..)` after `fn`")), - }; - - let mut args = Vec::new(); - let arg_tokens: Vec = arg_group.stream().into_iter().collect(); - for arg in split_top_level_commas(&arg_tokens) { - if arg.is_empty() { - continue; - } - args.push(split_name_type(&arg)?); - } - - // expect `->` then the return type - let mut rest = &tokens[2..]; - match (rest.first(), rest.get(1)) { - (Some(TokenTree2::Punct(a)), Some(TokenTree2::Punct(b))) - if a.as_char() == '-' && b.as_char() == '>' => - { - rest = &rest[2..]; - } - _ => { - return Err(err_tokens( - tokens, - "expected `->` and a return type in sig annotation", - )) - } - } - Ok((args, rest.to_vec())) -} - -/// Scans a type expression and records every refinement node with its full -/// type-position path (`steps`). -/// -/// `steps` holds the path from the function root to the current type node. -/// When a refinement `{binder: ty | formula}` is found the current `steps` are -/// recorded; when descending into generic type arguments a -/// [`RefineStep::TypeArg`]`(i)` step is appended to `steps`. -fn scan_type( - tokens: &[TokenTree2], - steps: Vec, - out: &mut Vec, -) -> syn::Result<()> { - if tokens.is_empty() { - return Ok(()); - } - - // A refinement node is exactly a brace-delimited group. 
- if tokens.len() == 1 { - if let TokenTree2::Group(g) = &tokens[0] { - if g.delimiter() == proc_macro2::Delimiter::Brace { - let (binder, binder_ty, formula) = split_refinement(g.stream())?; - out.push(Refinement { - steps: steps.clone(), - binder, - binder_ty: binder_ty.iter().cloned().collect(), - formula, - }); - // Descend into the binder type for nested refinements. - scan_type(&binder_ty, steps, out)?; - return Ok(()); - } - } - } - - // A nominal type `Name` (`Box` included). - if let TokenTree2::Ident(_) = &tokens[0] { - if let Some(TokenTree2::Punct(p)) = tokens.get(1) { - if p.as_char() == '<' { - let mut type_idx = 0; - for arg in split_angle_args(&tokens[2..]) { - if is_lifetime(&arg) { - continue; - } - let mut child = steps.clone(); - child.push(RefineStep::TypeArg(type_idx)); - scan_type(&arg, child, out)?; - type_idx += 1; - } - } - } - } - - Ok(()) -} - -/// Splits `{ binder : ty | formula }` contents into its parts. -fn split_refinement( - stream: TokenStream2, -) -> syn::Result<(syn::Ident, Vec, TokenStream2)> { - let toks: Vec = stream.into_iter().collect(); - let bar = toks - .iter() - .position(|tt| matches!(tt, TokenTree2::Punct(p) if p.as_char() == '|')) - .ok_or_else(|| err_tokens(&toks, "refinement type must contain `|`"))?; - let (binder, binder_ty) = split_name_type(&toks[..bar])?; - let formula: TokenStream2 = toks[bar + 1..].iter().cloned().collect(); - Ok((binder, binder_ty, formula)) -} - -/// Splits the tokens following an opening `<` at top level by commas, stopping -/// at the matching `>`. 
-fn split_angle_args(tokens: &[TokenTree2]) -> Vec> { - let mut args = Vec::new(); - let mut cur = Vec::new(); - let mut depth = 1usize; - for tt in tokens { - if let TokenTree2::Punct(p) = tt { - match p.as_char() { - '<' => { - depth += 1; - cur.push(tt.clone()); - continue; - } - '>' => { - depth -= 1; - if depth == 0 { - break; - } - cur.push(tt.clone()); - continue; - } - ',' if depth == 1 => { - args.push(std::mem::take(&mut cur)); - continue; - } - _ => {} - } - } - cur.push(tt.clone()); - } - if !cur.is_empty() { - args.push(cur); - } - args -} - -fn split_top_level_commas(tokens: &[TokenTree2]) -> Vec> { - let mut out = Vec::new(); - let mut cur = Vec::new(); - let mut depth = 0i32; - for tt in tokens { - if let TokenTree2::Punct(p) = tt { - match p.as_char() { - '<' => depth += 1, - '>' => depth -= 1, - ',' if depth == 0 => { - out.push(std::mem::take(&mut cur)); - continue; - } - _ => {} - } - } - cur.push(tt.clone()); - } - out.push(cur); - out -} - -fn is_lifetime(tokens: &[TokenTree2]) -> bool { - matches!(tokens.first(), Some(TokenTree2::Punct(p)) if p.as_char() == '\'') -} - -fn err_tokens(tokens: &[TokenTree2], msg: &str) -> syn::Error { - let stream: TokenStream2 = tokens.iter().cloned().collect(); - syn::Error::new_spanned(stream, msg) -} - -fn rewrite_self_in_tokens(tokens: TokenStream2) -> TokenStream2 { - tokens - .into_iter() - .map(|tt| match tt { - TokenTree2::Ident(id) if id == "self" => TokenTree2::Ident(format_ident!("self_")), - TokenTree2::Group(g) => { - let inner = rewrite_self_in_tokens(g.stream()); - TokenTree2::Group(proc_macro2::Group::new(g.delimiter(), inner)) - } - other => other, - }) - .collect() -} - -fn refine_fn_name(func: &FnItemWithSignature, r: &Refinement) -> syn::Ident { - let pos = r - .steps - .iter() - .map(|s| match s { - RefineStep::Param(i) => format!("p{}", i), - RefineStep::Return => "ret".to_string(), - RefineStep::TypeArg(i) => format!("t{}", i), - }) - .collect::>() - .join("_"); - 
format_ident!("_thrust_refine_{}_{}", func.sig().ident, pos) -} - -fn refine_formula_fn( - func: &FnItemWithSignature, - outer_context: Option<&FnOuterItem>, - r: &Refinement, -) -> TokenStream2 { - let name = refine_fn_name(func, r); - let def_generics = generic_params_tokens(&func.sig().generics); - let model_params = fn_params_with_model_ty(&func.sig().inputs); - let model_preds = model_where_predicates(func, outer_context); - let extended_where = extended_where_clause(func, &model_preds); - let binder = &r.binder; - let binder_ty = &r.binder_ty; - let formula = &r.formula; - - quote! { - #[allow(unused_variables)] - #[allow(non_snake_case)] - #[thrust::formula_fn] - fn #name #def_generics( - #binder: <#binder_ty as thrust_models::Model>::Ty, - #model_params - ) -> bool #extended_where { - #formula - } - } -} - -fn refine_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 { - let name = refine_fn_name(func, r); - let turbofish = generic_turbofish(&func.sig().generics); - let path_prefix = if func.sig().receiver().is_some() { - quote!(Self::) - } else { - quote!() - }; - let encoded_steps = r.steps.iter().map(|s| match s { - RefineStep::Param(i) => { - let lit = proc_macro2::Literal::usize_unsuffixed(*i); - quote!(#lit) - } - RefineStep::Return => quote!(result), - RefineStep::TypeArg(i) => { - let lit = proc_macro2::Literal::usize_unsuffixed(*i); - quote!([#lit]) - } - }); - quote! { - #[thrust::refinement_path(#(#encoded_steps),*)] - #path_prefix #name #turbofish; - } + refine::expand_refine(refine::RefineKind::Sig, attr, item) } fn mentions_self(sig: &syn::Signature) -> bool { diff --git a/thrust-macros/src/refine.rs b/thrust-macros/src/refine.rs new file mode 100644 index 00000000..9199e660 --- /dev/null +++ b/thrust-macros/src/refine.rs @@ -0,0 +1,428 @@ +//! Refinement-type annotations: `param`, `ret`, `sig`. +//! +//! These lower refinement types (e.g. `List<{ v: i32 | v > 0 }>`) into +//! 
`#[thrust::formula_fn]`s plus positioned `#[thrust::refinement_path(..)]` +//! path statements injected into the function body. The "type position" +//! addresses into the function type: a parameter (by index) or the return (the +//! `result` keyword) selects a function slot, and bracket steps (`[i]`) descend +//! into generic arguments (enum args / `Box` pointee). For example, +//! `#[thrust::refinement_path(result, [0])]` is the first type-argument of the +//! return. + +use proc_macro::TokenStream; +use proc_macro2::{TokenStream as TokenStream2, TokenTree as TokenTree2}; +use quote::{format_ident, quote, ToTokens}; +use syn::{parse_macro_input, FnArg}; + +use super::{ + extended_where_clause, extract_outer_context, fn_params_with_model_ty, generic_params_tokens, + generic_turbofish, model_where_predicates, FnItemWithSignature, FnOuterItem, +}; + +/// One step in a refinement's type-position path. +/// +/// Mirrors [`rty::TypePositionStep`] on the plugin side and uses the same +/// attribute encoding: +/// - [`Param`](Self::Param) / [`Return`](Self::Return) navigate into a function +/// type; encoded as an integer literal / the `result` keyword. +/// - [`TypeArg`](Self::TypeArg) navigates into a generic type argument; encoded +/// as a bracket group `[i]`. +#[derive(Clone, Copy)] +enum RefineStep { + Param(usize), + Return, + TypeArg(usize), +} + +#[derive(Clone)] +struct Refinement { + /// Full type-position path from the function root to the refined type. + steps: Vec, + binder: syn::Ident, + binder_ty: TokenStream2, + formula: TokenStream2, +} + +pub(crate) enum RefineKind { + Param, + Ret, + Sig, +} + +pub(crate) fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStream) -> TokenStream { + let mut func = parse_macro_input!(item as FnItemWithSignature); + + let outer_context = match extract_outer_context(&func) { + Ok(ctx) => ctx, + Err(e) => { + let err = e.to_compile_error(); + return quote! 
{ #err #func }.into(); + } + }; + + let attr_tokens: Vec = TokenStream2::from(attr).into_iter().collect(); + let jobs = match build_refine_jobs(kind, &func, &attr_tokens) { + Ok(jobs) => jobs, + Err(e) => { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } + }; + + let mut refinements = Vec::new(); + for (root_steps, ty_tokens) in jobs { + if let Err(e) = scan_type(&ty_tokens, root_steps, &mut refinements) { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } + } + + if refinements.is_empty() { + return func.into_token_stream().into(); + } + + let has_receiver = func.sig().receiver().is_some(); + let mut formula_fns = Vec::new(); + let mut path_stmts = Vec::new(); + for mut r in refinements { + if has_receiver { + r.formula = rewrite_self_in_tokens(r.formula); + } + formula_fns.push(refine_formula_fn(&func, outer_context.as_ref(), &r)); + path_stmts.push(refine_path_stmt(&func, &r)); + } + + let Some(block) = func.block_mut() else { + let err = syn::Error::new_spanned( + func.sig().ident.clone(), + "refinement-type annotations require a function body", + ) + .into_compile_error(); + return quote! { #err #func }.into(); + }; + let orig_stmts = block.stmts.drain(..).collect::>(); + *block = syn::parse_quote!({ + #(#path_stmts)* + #(#orig_stmts)* + }); + func.attrs_mut() + .push(syn::parse_quote!(#[allow(path_statements)])); + + quote! { + #(#formula_fns)* + #func + } + .into() +} + +/// Builds `(root_steps, type_tokens)` jobs to scan from the attribute tokens. +/// +/// Each job's `root_steps` contains the initial [`RefineStep`]s that fix the +/// position of the type expression within the function signature (e.g. +/// `[Param(0)]` for the first parameter). [`scan_type`] will append further +/// [`RefineStep::TypeArg`] steps as it descends into generic type arguments. 
+fn build_refine_jobs( + kind: RefineKind, + func: &FnItemWithSignature, + attr_tokens: &[TokenTree2], +) -> syn::Result, Vec)>> { + match kind { + RefineKind::Param => { + let (name, ty_tokens) = split_name_type(attr_tokens)?; + let idx = param_index(func, &name)?; + Ok(vec![(vec![RefineStep::Param(idx)], ty_tokens)]) + } + RefineKind::Ret => Ok(vec![(vec![RefineStep::Return], attr_tokens.to_vec())]), + RefineKind::Sig => { + let (args, ret_tokens) = parse_sig_attr(attr_tokens)?; + let mut jobs = Vec::new(); + for (name, ty_tokens) in args { + let idx = param_index(func, &name)?; + jobs.push((vec![RefineStep::Param(idx)], ty_tokens)); + } + jobs.push((vec![RefineStep::Return], ret_tokens)); + Ok(jobs) + } + } +} + +fn param_index(func: &FnItemWithSignature, name: &syn::Ident) -> syn::Result { + let pos = func.sig().inputs.iter().position(|arg| match arg { + FnArg::Receiver(_) => name == "self", + FnArg::Typed(pt) => matches!(&*pt.pat, syn::Pat::Ident(pi) if &pi.ident == name), + }); + pos.ok_or_else(|| { + syn::Error::new_spanned(name, format!("no parameter named `{}` in signature", name)) + }) +} + +/// Parses `name : ` from a flat token slice. +fn split_name_type(tokens: &[TokenTree2]) -> syn::Result<(syn::Ident, Vec)> { + let name = match tokens.first() { + Some(TokenTree2::Ident(id)) => id.clone(), + _ => return Err(err_tokens(tokens, "expected a parameter name")), + }; + match tokens.get(1) { + Some(TokenTree2::Punct(p)) if p.as_char() == ':' => {} + _ => return Err(err_tokens(tokens, "expected `:` after parameter name")), + } + Ok((name, tokens[2..].to_vec())) +} + +/// Parses `fn ( n0: t0 , ... ) -> ret` into `((name, ty_tokens)*, ret_tokens)`. 
+#[allow(clippy::type_complexity)] +fn parse_sig_attr( + tokens: &[TokenTree2], +) -> syn::Result<(Vec<(syn::Ident, Vec)>, Vec)> { + match tokens.first() { + Some(TokenTree2::Ident(id)) if id == "fn" => {} + _ => return Err(err_tokens(tokens, "expected `fn` in sig annotation")), + } + let arg_group = match tokens.get(1) { + Some(TokenTree2::Group(g)) if g.delimiter() == proc_macro2::Delimiter::Parenthesis => g, + _ => return Err(err_tokens(tokens, "expected `(..)` after `fn`")), + }; + + let mut args = Vec::new(); + let arg_tokens: Vec = arg_group.stream().into_iter().collect(); + for arg in split_top_level_commas(&arg_tokens) { + if arg.is_empty() { + continue; + } + args.push(split_name_type(&arg)?); + } + + // expect `->` then the return type + let mut rest = &tokens[2..]; + match (rest.first(), rest.get(1)) { + (Some(TokenTree2::Punct(a)), Some(TokenTree2::Punct(b))) + if a.as_char() == '-' && b.as_char() == '>' => + { + rest = &rest[2..]; + } + _ => { + return Err(err_tokens( + tokens, + "expected `->` and a return type in sig annotation", + )) + } + } + Ok((args, rest.to_vec())) +} + +/// Scans a type expression and records every refinement node with its full +/// type-position path (`steps`). +/// +/// `steps` holds the path from the function root to the current type node. +/// When a refinement `{binder: ty | formula}` is found the current `steps` are +/// recorded; when descending into generic type arguments a +/// [`RefineStep::TypeArg`]`(i)` step is appended to `steps`. +fn scan_type( + tokens: &[TokenTree2], + steps: Vec, + out: &mut Vec, +) -> syn::Result<()> { + if tokens.is_empty() { + return Ok(()); + } + + // A refinement node is exactly a brace-delimited group. 
+ if tokens.len() == 1 { + if let TokenTree2::Group(g) = &tokens[0] { + if g.delimiter() == proc_macro2::Delimiter::Brace { + let (binder, binder_ty, formula) = split_refinement(g.stream())?; + out.push(Refinement { + steps: steps.clone(), + binder, + binder_ty: binder_ty.iter().cloned().collect(), + formula, + }); + // Descend into the binder type for nested refinements. + scan_type(&binder_ty, steps, out)?; + return Ok(()); + } + } + } + + // A nominal type `Name` (`Box` included). + if let TokenTree2::Ident(_) = &tokens[0] { + if let Some(TokenTree2::Punct(p)) = tokens.get(1) { + if p.as_char() == '<' { + let mut type_idx = 0; + for arg in split_angle_args(&tokens[2..]) { + if is_lifetime(&arg) { + continue; + } + let mut child = steps.clone(); + child.push(RefineStep::TypeArg(type_idx)); + scan_type(&arg, child, out)?; + type_idx += 1; + } + } + } + } + + Ok(()) +} + +/// Splits `{ binder : ty | formula }` contents into its parts. +fn split_refinement( + stream: TokenStream2, +) -> syn::Result<(syn::Ident, Vec, TokenStream2)> { + let toks: Vec = stream.into_iter().collect(); + let bar = toks + .iter() + .position(|tt| matches!(tt, TokenTree2::Punct(p) if p.as_char() == '|')) + .ok_or_else(|| err_tokens(&toks, "refinement type must contain `|`"))?; + let (binder, binder_ty) = split_name_type(&toks[..bar])?; + let formula: TokenStream2 = toks[bar + 1..].iter().cloned().collect(); + Ok((binder, binder_ty, formula)) +} + +/// Splits the tokens following an opening `<` at top level by commas, stopping +/// at the matching `>`. 
+fn split_angle_args(tokens: &[TokenTree2]) -> Vec> { + let mut args = Vec::new(); + let mut cur = Vec::new(); + let mut depth = 1usize; + for tt in tokens { + if let TokenTree2::Punct(p) = tt { + match p.as_char() { + '<' => { + depth += 1; + cur.push(tt.clone()); + continue; + } + '>' => { + depth -= 1; + if depth == 0 { + break; + } + cur.push(tt.clone()); + continue; + } + ',' if depth == 1 => { + args.push(std::mem::take(&mut cur)); + continue; + } + _ => {} + } + } + cur.push(tt.clone()); + } + if !cur.is_empty() { + args.push(cur); + } + args +} + +fn split_top_level_commas(tokens: &[TokenTree2]) -> Vec> { + let mut out = Vec::new(); + let mut cur = Vec::new(); + let mut depth = 0i32; + for tt in tokens { + if let TokenTree2::Punct(p) = tt { + match p.as_char() { + '<' => depth += 1, + '>' => depth -= 1, + ',' if depth == 0 => { + out.push(std::mem::take(&mut cur)); + continue; + } + _ => {} + } + } + cur.push(tt.clone()); + } + out.push(cur); + out +} + +fn is_lifetime(tokens: &[TokenTree2]) -> bool { + matches!(tokens.first(), Some(TokenTree2::Punct(p)) if p.as_char() == '\'') +} + +fn err_tokens(tokens: &[TokenTree2], msg: &str) -> syn::Error { + let stream: TokenStream2 = tokens.iter().cloned().collect(); + syn::Error::new_spanned(stream, msg) +} + +fn rewrite_self_in_tokens(tokens: TokenStream2) -> TokenStream2 { + tokens + .into_iter() + .map(|tt| match tt { + TokenTree2::Ident(id) if id == "self" => TokenTree2::Ident(format_ident!("self_")), + TokenTree2::Group(g) => { + let inner = rewrite_self_in_tokens(g.stream()); + TokenTree2::Group(proc_macro2::Group::new(g.delimiter(), inner)) + } + other => other, + }) + .collect() +} + +fn refine_fn_name(func: &FnItemWithSignature, r: &Refinement) -> syn::Ident { + let pos = r + .steps + .iter() + .map(|s| match s { + RefineStep::Param(i) => format!("p{}", i), + RefineStep::Return => "ret".to_string(), + RefineStep::TypeArg(i) => format!("t{}", i), + }) + .collect::>() + .join("_"); + 
format_ident!("_thrust_refine_{}_{}", func.sig().ident, pos) +} + +fn refine_formula_fn( + func: &FnItemWithSignature, + outer_context: Option<&FnOuterItem>, + r: &Refinement, +) -> TokenStream2 { + let name = refine_fn_name(func, r); + let def_generics = generic_params_tokens(&func.sig().generics); + let model_params = fn_params_with_model_ty(&func.sig().inputs); + let model_preds = model_where_predicates(func, outer_context); + let extended_where = extended_where_clause(func, &model_preds); + let binder = &r.binder; + let binder_ty = &r.binder_ty; + let formula = &r.formula; + + quote! { + #[allow(unused_variables)] + #[allow(non_snake_case)] + #[thrust::formula_fn] + fn #name #def_generics( + #binder: <#binder_ty as thrust_models::Model>::Ty, + #model_params + ) -> bool #extended_where { + #formula + } + } +} + +fn refine_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 { + let name = refine_fn_name(func, r); + let turbofish = generic_turbofish(&func.sig().generics); + let path_prefix = if func.sig().receiver().is_some() { + quote!(Self::) + } else { + quote!() + }; + let encoded_steps = r.steps.iter().map(|s| match s { + RefineStep::Param(i) => { + let lit = proc_macro2::Literal::usize_unsuffixed(*i); + quote!(#lit) + } + RefineStep::Return => quote!(result), + RefineStep::TypeArg(i) => { + let lit = proc_macro2::Literal::usize_unsuffixed(*i); + quote!([#lit]) + } + }); + quote! { + #[thrust::refinement_path(#(#encoded_steps),*)] + #path_prefix #name #turbofish; + } +} From dc425fb9f82abbfbce59b763a758b09962c97192 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 27 May 2026 15:30:34 +0000 Subject: [PATCH 6/8] Use $i / bare integer syntax for refinement_path type positions Encode function parameters as $i (matching FunctionParamIdx's Display) and type arguments as bare integers, instead of bare integers for params and bracketed [i] for type args. 
Reads more naturally and keeps the attribute syntax consistent with how parameters are displayed elsewhere. https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- src/analyze/annot.rs | 60 +++++++++++++++++-------------------- src/rty.rs | 14 ++++----- thrust-macros/src/refine.rs | 16 +++++----- 3 files changed, 42 insertions(+), 48 deletions(-) diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 6a5da38c..80196ba4 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -217,16 +217,16 @@ pub fn extract_annot_tokens(attr: Attribute) -> TokenStream { /// Tokens are comma-separated steps. Each step is one of: /// - The keyword `result` → [`rty::TypePositionStep::Return`] (navigate to a /// function type's return slot). -/// - An integer literal `i` → [`rty::TypePositionStep::Param`]`(i)` (navigate -/// to the `i`-th parameter of a function type). -/// - A bracket group `[i]` → [`rty::TypePositionStep::TypeArg`]`(i)` (navigate -/// to the `i`-th type argument of a generic type such as an enum or `Box`). +/// - `$i` (a `$` followed by an integer) → [`rty::TypePositionStep::Param`]`(i)` +/// (navigate to the `i`-th parameter of a function type). +/// - A bare integer `i` → [`rty::TypePositionStep::TypeArg`]`(i)` (navigate to +/// the `i`-th type argument of a generic type such as an enum or `Box`). /// -/// Examples: `result` is the return; `0` is the first parameter; `0, [0]` is -/// the first type-argument of the first parameter; `0, result` is the return of -/// a function-typed first parameter. +/// Examples: `result` is the return; `$0` is the first parameter; `$0, 0` is +/// the first type-argument of the first parameter; `$0, result` is the return +/// of a function-typed first parameter. 
pub fn parse_type_position(ts: &TokenStream) -> rty::TypePosition { - use rustc_ast::token::{Delimiter, LitKind, TokenKind}; + use rustc_ast::token::{LitKind, TokenKind}; use rustc_ast::tokenstream::TokenTree; let parse_int = |lit: &rustc_ast::token::Lit| -> usize { @@ -242,36 +242,30 @@ pub fn parse_type_position(ts: &TokenStream) -> rty::TypePosition { }; let mut steps = Vec::new(); - for tt in ts.iter() { - match tt { - TokenTree::Token(t, _) => match &t.kind { - TokenKind::Comma => {} - TokenKind::Ident(sym, _) if sym.as_str() == "result" => { - steps.push(rty::TypePositionStep::Return); - } - TokenKind::Literal(lit) => { - steps.push(rty::TypePositionStep::Param(rty::FunctionParamIdx::from( - parse_int(lit), - ))); - } - _ => panic!("unexpected token in refine position: {:?}", t), - }, - TokenTree::Delimited(_, _, Delimiter::Bracket, inner) => { - let mut inner_iter = inner.iter(); - let i = match inner_iter.next() { + let mut iter = ts.iter(); + while let Some(tt) = iter.next() { + let TokenTree::Token(t, _) = tt else { + panic!("unexpected token tree in refine position"); + }; + match &t.kind { + TokenKind::Comma => {} + TokenKind::Ident(sym, _) if sym.as_str() == "result" => { + steps.push(rty::TypePositionStep::Return); + } + TokenKind::Dollar => { + let i = match iter.next() { Some(TokenTree::Token(t, _)) => match &t.kind { TokenKind::Literal(lit) => parse_int(lit), - _ => panic!("expected integer inside [..] refine step: {:?}", t), + _ => panic!("expected integer after `$` in refine position: {:?}", t), }, - _ => panic!("expected integer inside [..] refine step"), + _ => panic!("expected integer after `$` in refine position"), }; - assert!( - inner_iter.next().is_none(), - "expected exactly one integer inside [..] 
refine step" - ); - steps.push(rty::TypePositionStep::TypeArg(i)); + steps.push(rty::TypePositionStep::Param(rty::FunctionParamIdx::from(i))); + } + TokenKind::Literal(lit) => { + steps.push(rty::TypePositionStep::TypeArg(parse_int(lit))); } - _ => panic!("unexpected token tree in refine position"), + _ => panic!("unexpected token in refine position: {:?}", t), } } diff --git a/src/rty.rs b/src/rty.rs index b0f10f3d..9466eaac 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -95,7 +95,7 @@ where /// Using distinct variants for function navigation ([`Param`](Self::Param), /// [`Return`](Self::Return)) and generic-arg navigation /// ([`TypeArg`](Self::TypeArg)) allows the same path representation to address -/// positions inside higher-order function types. For example, `[$0, result]` +/// positions inside higher-order function types. For example, `$0.result` /// addresses the return type of a function-typed first parameter. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum TypePositionStep { @@ -112,7 +112,7 @@ impl std::fmt::Display for TypePositionStep { match self { TypePositionStep::Param(idx) => write!(f, "{}", idx), TypePositionStep::Return => f.write_str("result"), - TypePositionStep::TypeArg(i) => write!(f, "[{}]", i), + TypePositionStep::TypeArg(i) => write!(f, "{}", i), } } } @@ -128,11 +128,11 @@ impl std::fmt::Display for TypePositionStep { /// types), enabling positions inside higher-order function types. /// /// Examples (function `fn f(x: List) -> Box`): -/// - `[$0]` — parameter `x`. -/// - `[result]` — the return type. -/// - `[$0, [0]]` — the first type arg of `x`. -/// - `[result, [0]]` — the pointee of the `Box` return. -/// - `[$0, result]` — the return of a function-typed param `x`. +/// - `$0` — parameter `x`. +/// - `result` — the return type. +/// - `$0.0` — the first type arg of `x`. +/// - `result.0` — the pointee of the `Box` return. +/// - `$0.result` — the return of a function-typed param `x`. 
#[derive(Debug, Clone, PartialEq, Eq)] pub struct TypePosition { steps: Vec, diff --git a/thrust-macros/src/refine.rs b/thrust-macros/src/refine.rs index 9199e660..06353817 100644 --- a/thrust-macros/src/refine.rs +++ b/thrust-macros/src/refine.rs @@ -3,10 +3,10 @@ //! These lower refinement types (e.g. `List<{ v: i32 | v > 0 }>`) into //! `#[thrust::formula_fn]`s plus positioned `#[thrust::refinement_path(..)]` //! path statements injected into the function body. The "type position" -//! addresses into the function type: a parameter (by index) or the return (the -//! `result` keyword) selects a function slot, and bracket steps (`[i]`) descend -//! into generic arguments (enum args / `Box` pointee). For example, -//! `#[thrust::refinement_path(result, [0])]` is the first type-argument of the +//! addresses into the function type: a parameter (`$i`) or the return (the +//! `result` keyword) selects a function slot, and bare integer steps (`i`) +//! descend into generic arguments (enum args / `Box` pointee). For example, +//! `#[thrust::refinement_path(result, 0)]` is the first type-argument of the //! return. use proc_macro::TokenStream; @@ -24,9 +24,9 @@ use super::{ /// Mirrors [`rty::TypePositionStep`] on the plugin side and uses the same /// attribute encoding: /// - [`Param`](Self::Param) / [`Return`](Self::Return) navigate into a function -/// type; encoded as an integer literal / the `result` keyword. +/// type; encoded as `$i` / the `result` keyword. /// - [`TypeArg`](Self::TypeArg) navigates into a generic type argument; encoded -/// as a bracket group `[i]`. +/// as a bare integer `i`. 
#[derive(Clone, Copy)] enum RefineStep { Param(usize), @@ -413,12 +413,12 @@ fn refine_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 let encoded_steps = r.steps.iter().map(|s| match s { RefineStep::Param(i) => { let lit = proc_macro2::Literal::usize_unsuffixed(*i); - quote!(#lit) + quote!($#lit) } RefineStep::Return => quote!(result), RefineStep::TypeArg(i) => { let lit = proc_macro2::Literal::usize_unsuffixed(*i); - quote!([#lit]) + quote!(#lit) } }); quote! { From d5a15760abf6d236f7b3210a0762e89fd03e313a Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 27 May 2026 15:47:38 +0000 Subject: [PATCH 7/8] Clarify refinement type-position naming and align Display with syntax - Rename FunctionTemplateTypeBuilder::refine to install_refinement_at, matching the RefinedType method it delegates to. - Rename extract_refine_paths / extract_refine_annots to extract_refinement_paths / extract_refinement_annots, and clarify the related panic messages, to reduce the overloaded use of "refine". - Make TypePosition's Display match the refinement_path(..) surface syntax (comma-separated steps) instead of dot-separated. - Replace the macro's build_refine_jobs (and its complex tuple return type) with annotated_type_exprs returning PositionedTypeExpr, and introduce NamedType / SigAnnotation structs so parse_sig_attr no longer needs an allow(clippy::type_complexity). Rename RefineStep to PositionStep, RefineKind to AnnotationKind, and expand_refine to expand for clarity. 
https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- src/analyze.rs | 15 ++-- src/analyze/local_def.rs | 8 +- src/refine/template.rs | 2 +- src/rty.rs | 10 +-- thrust-macros/src/lib.rs | 6 +- thrust-macros/src/refine.rs | 149 +++++++++++++++++++++--------------- 6 files changed, 111 insertions(+), 79 deletions(-) diff --git a/src/analyze.rs b/src/analyze.rs index e16fb08b..e7420df9 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -792,7 +792,10 @@ impl<'tcx> Analyzer<'tcx> { /// Collects every `#[thrust::refinement_path(..)]` path statement in the /// function body, returning each `(type position, formula_fn DefId)`. - fn extract_refine_paths(&self, local_def_id: LocalDefId) -> Vec<(rty::TypePosition, DefId)> { + fn extract_refinement_paths( + &self, + local_def_id: LocalDefId, + ) -> Vec<(rty::TypePosition, DefId)> { let mut out = Vec::new(); let Some(body) = self.tcx.hir_maybe_body_owned_by(local_def_id) else { return out; @@ -842,27 +845,27 @@ impl<'tcx> Analyzer<'tcx> { /// Resolves every `#[thrust::refinement_path(..)]` annotation into a /// positioned refinement, by translating the referenced formula function. 
- pub fn extract_refine_annots( + pub fn extract_refinement_annots( &self, local_def_id: LocalDefId, generic_args: mir_ty::GenericArgsRef<'tcx>, ) -> Vec<(rty::TypePosition, rty::Refinement)> { let mut out = Vec::new(); - for (position, def_id) in self.extract_refine_paths(local_def_id) { + for (position, def_id) in self.extract_refinement_paths(local_def_id) { let Some(formula_def_id) = def_id.as_local() else { panic!( - "refine annotation with path is expected to refer to a local def, but found: {:?}", + "refinement_path annotation is expected to refer to a local def, but found: {:?}", def_id ); }; let Some(formula_fn) = self.formula_fn_with_args(formula_def_id, generic_args) else { panic!( - "refine annotation {:?} is not a formula function", + "refinement_path annotation {:?} is not a formula function", formula_def_id ); }; let AnnotFormula::Formula(formula) = formula_fn.to_ensure_annot() else { - panic!("refine annotation must lower to a plain formula"); + panic!("refinement_path annotation must lower to a plain formula"); }; out.push((position, formula.into())); } diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 6ccff823..1cb273f3 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -406,9 +406,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { assert!(require_annot.is_none() || param_annots.is_empty()); assert!(ensure_annot.is_none() || ret_annot.is_none()); - let refine_annots = self + let refinement_annots = self .ctx - .extract_refine_annots(self.local_def_id, self.generic_args); + .extract_refinement_annots(self.local_def_id, self.generic_args); let trait_item_ty = self.trait_item_ty(); let is_fully_annotated = self.is_fully_annotated(); @@ -435,8 +435,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { if let Some(ret_rty) = ret_annot { builder.ret_rty(ret_rty); } - for (position, refinement) in refine_annots { - builder.refine(&position, refinement); + for (position, refinement) in refinement_annots { + 
builder.install_refinement_at(&position, refinement); } if is_fully_annotated { diff --git a/src/refine/template.rs b/src/refine/template.rs index 50408612..e88008a6 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -571,7 +571,7 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { /// The first step must be [`rty::TypePositionStep::Param`] or /// [`rty::TypePositionStep::Return`]; the remaining steps are forwarded to /// [`rty::RefinedType::install_refinement_at`]. - pub fn refine( + pub fn install_refinement_at( &mut self, position: &rty::TypePosition, refinement: rty::Refinement, diff --git a/src/rty.rs b/src/rty.rs index 9466eaac..c0c4efd7 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -95,7 +95,7 @@ where /// Using distinct variants for function navigation ([`Param`](Self::Param), /// [`Return`](Self::Return)) and generic-arg navigation /// ([`TypeArg`](Self::TypeArg)) allows the same path representation to address -/// positions inside higher-order function types. For example, `$0.result` +/// positions inside higher-order function types. For example, `$0, result` /// addresses the return type of a function-typed first parameter. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum TypePositionStep { @@ -130,9 +130,9 @@ impl std::fmt::Display for TypePositionStep { /// Examples (function `fn f(x: List) -> Box`): /// - `$0` — parameter `x`. /// - `result` — the return type. -/// - `$0.0` — the first type arg of `x`. -/// - `result.0` — the pointee of the `Box` return. -/// - `$0.result` — the return of a function-typed param `x`. +/// - `$0, 0` — the first type arg of `x`. +/// - `result, 0` — the pointee of the `Box` return. +/// - `$0, result` — the return of a function-typed param `x`. 
#[derive(Debug, Clone, PartialEq, Eq)] pub struct TypePosition { steps: Vec, @@ -157,7 +157,7 @@ impl std::fmt::Display for TypePosition { write!(f, "{}", first)?; } for s in iter { - write!(f, ".{}", s)?; + write!(f, ", {}", s)?; } Ok(()) } diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index 70969bb9..ab1edfec 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -554,17 +554,17 @@ impl ExpandedTokens { #[proc_macro_attribute] pub fn param(attr: TokenStream, item: TokenStream) -> TokenStream { - refine::expand_refine(refine::RefineKind::Param, attr, item) + refine::expand(refine::AnnotationKind::Param, attr, item) } #[proc_macro_attribute] pub fn ret(attr: TokenStream, item: TokenStream) -> TokenStream { - refine::expand_refine(refine::RefineKind::Ret, attr, item) + refine::expand(refine::AnnotationKind::Ret, attr, item) } #[proc_macro_attribute] pub fn sig(attr: TokenStream, item: TokenStream) -> TokenStream { - refine::expand_refine(refine::RefineKind::Sig, attr, item) + refine::expand(refine::AnnotationKind::Sig, attr, item) } fn mentions_self(sig: &syn::Signature) -> bool { diff --git a/thrust-macros/src/refine.rs b/thrust-macros/src/refine.rs index 06353817..102f53e4 100644 --- a/thrust-macros/src/refine.rs +++ b/thrust-macros/src/refine.rs @@ -28,7 +28,7 @@ use super::{ /// - [`TypeArg`](Self::TypeArg) navigates into a generic type argument; encoded /// as a bare integer `i`. #[derive(Clone, Copy)] -enum RefineStep { +enum PositionStep { Param(usize), Return, TypeArg(usize), @@ -37,19 +37,37 @@ enum RefineStep { #[derive(Clone)] struct Refinement { /// Full type-position path from the function root to the refined type. - steps: Vec, + steps: Vec, binder: syn::Ident, binder_ty: TokenStream2, formula: TokenStream2, } -pub(crate) enum RefineKind { +/// Which refinement-type annotation is being expanded. 
+pub(crate) enum AnnotationKind { Param, Ret, Sig, } -pub(crate) fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStream) -> TokenStream { +/// A type expression from the annotation, paired with the position of its root +/// within the function signature. [`scan_type`] walks each one to extract the +/// refinements it contains. +struct PositionedTypeExpr { + /// Steps locating the root of `tokens` (e.g. `[Param(0)]` for the first + /// parameter); [`scan_type`] appends `TypeArg` steps as it descends. + root: Vec, + tokens: Vec, +} + +/// A `name : type` binding, e.g. a parameter in a `sig` annotation or the +/// binder of a refinement `{ name: type | .. }`. +struct NamedType { + name: syn::Ident, + tokens: Vec, +} + +pub(crate) fn expand(kind: AnnotationKind, attr: TokenStream, item: TokenStream) -> TokenStream { let mut func = parse_macro_input!(item as FnItemWithSignature); let outer_context = match extract_outer_context(&func) { @@ -61,8 +79,8 @@ pub(crate) fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStre }; let attr_tokens: Vec = TokenStream2::from(attr).into_iter().collect(); - let jobs = match build_refine_jobs(kind, &func, &attr_tokens) { - Ok(jobs) => jobs, + let type_exprs = match annotated_type_exprs(kind, &func, &attr_tokens) { + Ok(exprs) => exprs, Err(e) => { let err = e.to_compile_error(); return quote! { #err #func }.into(); @@ -70,8 +88,8 @@ pub(crate) fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStre }; let mut refinements = Vec::new(); - for (root_steps, ty_tokens) in jobs { - if let Err(e) = scan_type(&ty_tokens, root_steps, &mut refinements) { + for expr in type_exprs { + if let Err(e) = scan_type(&expr.tokens, expr.root, &mut refinements) { let err = e.to_compile_error(); return quote! 
{ #err #func }.into(); } @@ -115,33 +133,37 @@ pub(crate) fn expand_refine(kind: RefineKind, attr: TokenStream, item: TokenStre .into() } -/// Builds `(root_steps, type_tokens)` jobs to scan from the attribute tokens. -/// -/// Each job's `root_steps` contains the initial [`RefineStep`]s that fix the -/// position of the type expression within the function signature (e.g. -/// `[Param(0)]` for the first parameter). [`scan_type`] will append further -/// [`RefineStep::TypeArg`] steps as it descends into generic type arguments. -fn build_refine_jobs( - kind: RefineKind, +/// Turns an annotation into the type expressions to scan, each anchored at its +/// root position within the function signature. +fn annotated_type_exprs( + kind: AnnotationKind, func: &FnItemWithSignature, attr_tokens: &[TokenTree2], -) -> syn::Result, Vec)>> { +) -> syn::Result> { + let at_param = |func: &FnItemWithSignature, nt: NamedType| -> syn::Result { + let idx = param_index(func, &nt.name)?; + Ok(PositionedTypeExpr { + root: vec![PositionStep::Param(idx)], + tokens: nt.tokens, + }) + }; match kind { - RefineKind::Param => { - let (name, ty_tokens) = split_name_type(attr_tokens)?; - let idx = param_index(func, &name)?; - Ok(vec![(vec![RefineStep::Param(idx)], ty_tokens)]) - } - RefineKind::Ret => Ok(vec![(vec![RefineStep::Return], attr_tokens.to_vec())]), - RefineKind::Sig => { - let (args, ret_tokens) = parse_sig_attr(attr_tokens)?; - let mut jobs = Vec::new(); - for (name, ty_tokens) in args { - let idx = param_index(func, &name)?; - jobs.push((vec![RefineStep::Param(idx)], ty_tokens)); + AnnotationKind::Param => Ok(vec![at_param(func, split_name_type(attr_tokens)?)?]), + AnnotationKind::Ret => Ok(vec![PositionedTypeExpr { + root: vec![PositionStep::Return], + tokens: attr_tokens.to_vec(), + }]), + AnnotationKind::Sig => { + let sig = parse_sig_attr(attr_tokens)?; + let mut exprs = Vec::new(); + for param in sig.params { + exprs.push(at_param(func, param)?); } - 
jobs.push((vec![RefineStep::Return], ret_tokens)); - Ok(jobs) + exprs.push(PositionedTypeExpr { + root: vec![PositionStep::Return], + tokens: sig.ret, + }); + Ok(exprs) } } } @@ -157,7 +179,7 @@ fn param_index(func: &FnItemWithSignature, name: &syn::Ident) -> syn::Result` from a flat token slice. -fn split_name_type(tokens: &[TokenTree2]) -> syn::Result<(syn::Ident, Vec)> { +fn split_name_type(tokens: &[TokenTree2]) -> syn::Result { let name = match tokens.first() { Some(TokenTree2::Ident(id)) => id.clone(), _ => return Err(err_tokens(tokens, "expected a parameter name")), @@ -166,14 +188,20 @@ fn split_name_type(tokens: &[TokenTree2]) -> syn::Result<(syn::Ident, Vec {} _ => return Err(err_tokens(tokens, "expected `:` after parameter name")), } - Ok((name, tokens[2..].to_vec())) + Ok(NamedType { + name, + tokens: tokens[2..].to_vec(), + }) } -/// Parses `fn ( n0: t0 , ... ) -> ret` into `((name, ty_tokens)*, ret_tokens)`. -#[allow(clippy::type_complexity)] -fn parse_sig_attr( - tokens: &[TokenTree2], -) -> syn::Result<(Vec<(syn::Ident, Vec)>, Vec)> { +/// The parsed parts of a `fn ( n0: t0 , ... ) -> ret` signature annotation. +struct SigAnnotation { + params: Vec, + ret: Vec, +} + +/// Parses `fn ( n0: t0 , ... ) -> ret`. 
+fn parse_sig_attr(tokens: &[TokenTree2]) -> syn::Result { match tokens.first() { Some(TokenTree2::Ident(id)) if id == "fn" => {} _ => return Err(err_tokens(tokens, "expected `fn` in sig annotation")), @@ -183,13 +211,13 @@ fn parse_sig_attr( _ => return Err(err_tokens(tokens, "expected `(..)` after `fn`")), }; - let mut args = Vec::new(); + let mut params = Vec::new(); let arg_tokens: Vec = arg_group.stream().into_iter().collect(); for arg in split_top_level_commas(&arg_tokens) { if arg.is_empty() { continue; } - args.push(split_name_type(&arg)?); + params.push(split_name_type(&arg)?); } // expect `->` then the return type @@ -207,7 +235,10 @@ fn parse_sig_attr( )) } } - Ok((args, rest.to_vec())) + Ok(SigAnnotation { + params, + ret: rest.to_vec(), + }) } /// Scans a type expression and records every refinement node with its full @@ -216,10 +247,10 @@ fn parse_sig_attr( /// `steps` holds the path from the function root to the current type node. /// When a refinement `{binder: ty | formula}` is found the current `steps` are /// recorded; when descending into generic type arguments a -/// [`RefineStep::TypeArg`]`(i)` step is appended to `steps`. +/// [`PositionStep::TypeArg`]`(i)` step is appended to `steps`. fn scan_type( tokens: &[TokenTree2], - steps: Vec, + steps: Vec, out: &mut Vec, ) -> syn::Result<()> { if tokens.is_empty() { @@ -230,15 +261,15 @@ fn scan_type( if tokens.len() == 1 { if let TokenTree2::Group(g) = &tokens[0] { if g.delimiter() == proc_macro2::Delimiter::Brace { - let (binder, binder_ty, formula) = split_refinement(g.stream())?; + let (binder, formula) = split_refinement(g.stream())?; out.push(Refinement { steps: steps.clone(), - binder, - binder_ty: binder_ty.iter().cloned().collect(), + binder: binder.name, + binder_ty: binder.tokens.iter().cloned().collect(), formula, }); // Descend into the binder type for nested refinements. 
- scan_type(&binder_ty, steps, out)?; + scan_type(&binder.tokens, steps, out)?; return Ok(()); } } @@ -254,7 +285,7 @@ fn scan_type( continue; } let mut child = steps.clone(); - child.push(RefineStep::TypeArg(type_idx)); + child.push(PositionStep::TypeArg(type_idx)); scan_type(&arg, child, out)?; type_idx += 1; } @@ -265,18 +296,16 @@ fn scan_type( Ok(()) } -/// Splits `{ binder : ty | formula }` contents into its parts. -fn split_refinement( - stream: TokenStream2, -) -> syn::Result<(syn::Ident, Vec, TokenStream2)> { +/// Splits `{ binder : ty | formula }` into its binder and formula expression. +fn split_refinement(stream: TokenStream2) -> syn::Result<(NamedType, TokenStream2)> { let toks: Vec = stream.into_iter().collect(); let bar = toks .iter() .position(|tt| matches!(tt, TokenTree2::Punct(p) if p.as_char() == '|')) .ok_or_else(|| err_tokens(&toks, "refinement type must contain `|`"))?; - let (binder, binder_ty) = split_name_type(&toks[..bar])?; + let binder = split_name_type(&toks[..bar])?; let formula: TokenStream2 = toks[bar + 1..].iter().cloned().collect(); - Ok((binder, binder_ty, formula)) + Ok((binder, formula)) } /// Splits the tokens following an opening `<` at top level by commas, stopping @@ -366,9 +395,9 @@ fn refine_fn_name(func: &FnItemWithSignature, r: &Refinement) -> syn::Ident { .steps .iter() .map(|s| match s { - RefineStep::Param(i) => format!("p{}", i), - RefineStep::Return => "ret".to_string(), - RefineStep::TypeArg(i) => format!("t{}", i), + PositionStep::Param(i) => format!("p{}", i), + PositionStep::Return => "ret".to_string(), + PositionStep::TypeArg(i) => format!("t{}", i), }) .collect::>() .join("_"); @@ -411,12 +440,12 @@ fn refine_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 quote!() }; let encoded_steps = r.steps.iter().map(|s| match s { - RefineStep::Param(i) => { + PositionStep::Param(i) => { let lit = proc_macro2::Literal::usize_unsuffixed(*i); quote!($#lit) } - RefineStep::Return => quote!(result), - 
RefineStep::TypeArg(i) => { + PositionStep::Return => quote!(result), + PositionStep::TypeArg(i) => { let lit = proc_macro2::Literal::usize_unsuffixed(*i); quote!(#lit) } From e39d33c119594607e8e3d48628426d3a8a645026 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 27 May 2026 16:23:55 +0000 Subject: [PATCH 8/8] Refine type-position terminology and drop non-empty invariant - Replace the undefined phrase "refine position" in parser diagnostics with "type position". - Drop the non-empty invariant from TypePosition: an empty path is a valid notion (it addresses the type itself); a path is only non-empty when applied to a function type. TypePosition::new now takes the full step vector, and that non-emptiness is checked where it matters (the function-type builder). - Rename the macro module file thrust-macros/src/refine.rs to rty.rs and use TypePositionStep there, mirroring the plugin's rty module instead of naming the same concept differently. https://claude.ai/code/session_01Km9xwaVaGQjnAy1iHsPewa --- src/analyze/annot.rs | 16 +++++------- src/refine/template.rs | 4 +-- src/rty.rs | 19 ++++++-------- thrust-macros/src/lib.rs | 8 +++--- thrust-macros/src/{refine.rs => rty.rs} | 34 ++++++++++++------------- 5 files changed, 38 insertions(+), 43 deletions(-) rename thrust-macros/src/{refine.rs => rty.rs} (94%) diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 80196ba4..5b409f35 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -233,19 +233,19 @@ pub fn parse_type_position(ts: &TokenStream) -> rty::TypePosition { assert_eq!( lit.kind, LitKind::Integer, - "expected an integer in refine position" + "expected an integer in type position" ); lit.symbol .as_str() .parse() - .expect("invalid integer in refine position") + .expect("invalid integer in type position") }; let mut steps = Vec::new(); let mut iter = ts.iter(); while let Some(tt) = iter.next() { let TokenTree::Token(t, _) = tt else { - panic!("unexpected token tree in refine position"); + 
panic!("unexpected token tree in type position"); }; match &t.kind { TokenKind::Comma => {} @@ -256,22 +256,20 @@ pub fn parse_type_position(ts: &TokenStream) -> rty::TypePosition { let i = match iter.next() { Some(TokenTree::Token(t, _)) => match &t.kind { TokenKind::Literal(lit) => parse_int(lit), - _ => panic!("expected integer after `$` in refine position: {:?}", t), + _ => panic!("expected integer after `$` in type position: {:?}", t), }, - _ => panic!("expected integer after `$` in refine position"), + _ => panic!("expected integer after `$` in type position"), }; steps.push(rty::TypePositionStep::Param(rty::FunctionParamIdx::from(i))); } TokenKind::Literal(lit) => { steps.push(rty::TypePositionStep::TypeArg(parse_int(lit))); } - _ => panic!("unexpected token in refine position: {:?}", t), + _ => panic!("unexpected token in type position: {:?}", t), } } - assert!(!steps.is_empty(), "empty refine position"); - let first = steps.remove(0); - rty::TypePosition::new(first, steps) + rty::TypePosition::new(steps) } pub fn split_param(ts: &TokenStream) -> (Ident, TokenStream) { diff --git a/src/refine/template.rs b/src/refine/template.rs index e88008a6..37956784 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -578,7 +578,7 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { ) -> &mut Self { let (first, rest) = match position.steps().split_first() { Some(pair) => pair, - None => panic!("empty TypePosition"), + None => panic!("type position applied to a function type must not be empty"), }; match first { rty::TypePositionStep::Param(idx) => { @@ -603,7 +603,7 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { .install_refinement_at(rest, refinement); } rty::TypePositionStep::TypeArg(_) => { - panic!("TypePosition must start with Param or Return, not TypeArg"); + panic!("type position applied to a function type must start with a param or result step, not a type argument"); } } self diff --git a/src/rty.rs b/src/rty.rs index 
c0c4efd7..abfff7bf 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -117,15 +117,14 @@ impl std::fmt::Display for TypePositionStep { } } -/// A path addressing a sub-type in a function's type signature, used to attach -/// a refinement. +/// A path addressing a sub-type within a type, used to attach a refinement. /// -/// The first step must be [`TypePositionStep::Param`] or -/// [`TypePositionStep::Return`] (selecting which slot of the top-level function -/// type to enter). Subsequent steps can freely combine -/// [`TypePositionStep::Param`] / [`TypePositionStep::Return`] (for -/// function-typed positions) and [`TypePositionStep::TypeArg`] (for generic -/// types), enabling positions inside higher-order function types. +/// An empty path addresses the type itself. Each step descends one level: +/// [`TypePositionStep::Param`] / [`TypePositionStep::Return`] enter a function +/// type's parameter or return slot, and [`TypePositionStep::TypeArg`] enters a +/// generic type argument. Steps combine freely, so positions inside +/// higher-order function types are expressible. A path applied to a function +/// type is therefore non-empty, beginning with a `Param`/`Return` step. /// /// Examples (function `fn f(x: List) -> Box`): /// - `$0` — parameter `x`. 
@@ -139,9 +138,7 @@ pub struct TypePosition { } impl TypePosition { - pub fn new(first: TypePositionStep, rest: Vec) -> Self { - let mut steps = vec![first]; - steps.extend(rest); + pub fn new(steps: Vec) -> Self { TypePosition { steps } } diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index ab1edfec..3399f9f8 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -6,7 +6,7 @@ use syn::{ WherePredicate, }; -mod refine; +mod rty; #[derive(Debug, Clone)] enum FnOuterItem { @@ -554,17 +554,17 @@ impl ExpandedTokens { #[proc_macro_attribute] pub fn param(attr: TokenStream, item: TokenStream) -> TokenStream { - refine::expand(refine::AnnotationKind::Param, attr, item) + rty::expand(rty::AnnotationKind::Param, attr, item) } #[proc_macro_attribute] pub fn ret(attr: TokenStream, item: TokenStream) -> TokenStream { - refine::expand(refine::AnnotationKind::Ret, attr, item) + rty::expand(rty::AnnotationKind::Ret, attr, item) } #[proc_macro_attribute] pub fn sig(attr: TokenStream, item: TokenStream) -> TokenStream { - refine::expand(refine::AnnotationKind::Sig, attr, item) + rty::expand(rty::AnnotationKind::Sig, attr, item) } fn mentions_self(sig: &syn::Signature) -> bool { diff --git a/thrust-macros/src/refine.rs b/thrust-macros/src/rty.rs similarity index 94% rename from thrust-macros/src/refine.rs rename to thrust-macros/src/rty.rs index 102f53e4..e32031e7 100644 --- a/thrust-macros/src/refine.rs +++ b/thrust-macros/src/rty.rs @@ -21,14 +21,14 @@ use super::{ /// One step in a refinement's type-position path. /// -/// Mirrors [`rty::TypePositionStep`] on the plugin side and uses the same -/// attribute encoding: +/// Mirrors the plugin's `rty::TypePositionStep` and uses the same attribute +/// encoding: /// - [`Param`](Self::Param) / [`Return`](Self::Return) navigate into a function /// type; encoded as `$i` / the `result` keyword. /// - [`TypeArg`](Self::TypeArg) navigates into a generic type argument; encoded /// as a bare integer `i`. 
#[derive(Clone, Copy)] -enum PositionStep { +enum TypePositionStep { Param(usize), Return, TypeArg(usize), @@ -37,7 +37,7 @@ enum PositionStep { #[derive(Clone)] struct Refinement { /// Full type-position path from the function root to the refined type. - steps: Vec, + steps: Vec, binder: syn::Ident, binder_ty: TokenStream2, formula: TokenStream2, @@ -56,7 +56,7 @@ pub(crate) enum AnnotationKind { struct PositionedTypeExpr { /// Steps locating the root of `tokens` (e.g. `[Param(0)]` for the first /// parameter); [`scan_type`] appends `TypeArg` steps as it descends. - root: Vec, + root: Vec, tokens: Vec, } @@ -143,14 +143,14 @@ fn annotated_type_exprs( let at_param = |func: &FnItemWithSignature, nt: NamedType| -> syn::Result { let idx = param_index(func, &nt.name)?; Ok(PositionedTypeExpr { - root: vec![PositionStep::Param(idx)], + root: vec![TypePositionStep::Param(idx)], tokens: nt.tokens, }) }; match kind { AnnotationKind::Param => Ok(vec![at_param(func, split_name_type(attr_tokens)?)?]), AnnotationKind::Ret => Ok(vec![PositionedTypeExpr { - root: vec![PositionStep::Return], + root: vec![TypePositionStep::Return], tokens: attr_tokens.to_vec(), }]), AnnotationKind::Sig => { @@ -160,7 +160,7 @@ fn annotated_type_exprs( exprs.push(at_param(func, param)?); } exprs.push(PositionedTypeExpr { - root: vec![PositionStep::Return], + root: vec![TypePositionStep::Return], tokens: sig.ret, }); Ok(exprs) @@ -247,10 +247,10 @@ fn parse_sig_attr(tokens: &[TokenTree2]) -> syn::Result { /// `steps` holds the path from the function root to the current type node. /// When a refinement `{binder: ty | formula}` is found the current `steps` are /// recorded; when descending into generic type arguments a -/// [`PositionStep::TypeArg`]`(i)` step is appended to `steps`. +/// [`TypePositionStep::TypeArg`]`(i)` step is appended to `steps`. 
fn scan_type( tokens: &[TokenTree2], - steps: Vec, + steps: Vec, out: &mut Vec, ) -> syn::Result<()> { if tokens.is_empty() { @@ -285,7 +285,7 @@ fn scan_type( continue; } let mut child = steps.clone(); - child.push(PositionStep::TypeArg(type_idx)); + child.push(TypePositionStep::TypeArg(type_idx)); scan_type(&arg, child, out)?; type_idx += 1; } @@ -395,9 +395,9 @@ fn refine_fn_name(func: &FnItemWithSignature, r: &Refinement) -> syn::Ident { .steps .iter() .map(|s| match s { - PositionStep::Param(i) => format!("p{}", i), - PositionStep::Return => "ret".to_string(), - PositionStep::TypeArg(i) => format!("t{}", i), + TypePositionStep::Param(i) => format!("p{}", i), + TypePositionStep::Return => "ret".to_string(), + TypePositionStep::TypeArg(i) => format!("t{}", i), }) .collect::>() .join("_"); @@ -440,12 +440,12 @@ fn refine_path_stmt(func: &FnItemWithSignature, r: &Refinement) -> TokenStream2 quote!() }; let encoded_steps = r.steps.iter().map(|s| match s { - PositionStep::Param(i) => { + TypePositionStep::Param(i) => { let lit = proc_macro2::Literal::usize_unsuffixed(*i); quote!($#lit) } - PositionStep::Return => quote!(result), - PositionStep::TypeArg(i) => { + TypePositionStep::Return => quote!(result), + TypePositionStep::TypeArg(i) => { let lit = proc_macro2::Literal::usize_unsuffixed(*i); quote!(#lit) }