Skip to content
82 changes: 82 additions & 0 deletions src/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,88 @@ impl<'tcx> Analyzer<'tcx> {
ensure_annot
}

/// Collects every `#[thrust::refinement_path(..)]` path statement in the
/// function body, returning each `(type position, formula_fn DefId)`.
fn extract_refinement_paths(
&self,
local_def_id: LocalDefId,
) -> Vec<(rty::TypePosition, DefId)> {
let mut out = Vec::new();
let Some(body) = self.tcx.hir_maybe_body_owned_by(local_def_id) else {
return out;
};
let rustc_hir::ExprKind::Block(block, _) = body.value.kind else {
return out;
};
let attr_path = analyze::annot::refinement_path_path();
let typeck = self.tcx.typeck(local_def_id);
for stmt in block.stmts {
let Some(attr) = self
.tcx
.hir_attrs(stmt.hir_id)
.iter()
.find(|attr| attr.path_matches(&attr_path))
else {
continue;
};
let ts = analyze::annot::extract_annot_tokens(attr.clone());
let position = analyze::annot::parse_type_position(&ts);

let rustc_hir::StmtKind::Semi(expr) = stmt.kind else {
self.tcx.dcx().span_err(
stmt.span,
"annotated path is expected to be a semi statement",
);
continue;
};
let rustc_hir::ExprKind::Path(qpath) = expr.kind else {
self.tcx.dcx().span_err(
expr.span,
"annotated path is expected to be a path expression",
);
continue;
};
let rustc_hir::def::Res::Def(_, def_id) = typeck.qpath_res(&qpath, expr.hir_id) else {
self.tcx.dcx().span_err(
expr.span,
"annotated path is expected to refer to a definition",
);
continue;
};
out.push((position, def_id));
}
out
}

/// Resolves every `#[thrust::refinement_path(..)]` annotation into a
/// positioned refinement, by translating the referenced formula function.
pub fn extract_refinement_annots(
&self,
local_def_id: LocalDefId,
generic_args: mir_ty::GenericArgsRef<'tcx>,
) -> Vec<(rty::TypePosition, rty::Refinement<rty::FunctionParamIdx>)> {
let mut out = Vec::new();
for (position, def_id) in self.extract_refinement_paths(local_def_id) {
let Some(formula_def_id) = def_id.as_local() else {
panic!(
"refinement_path annotation is expected to refer to a local def, but found: {:?}",
def_id
);
};
let Some(formula_fn) = self.formula_fn_with_args(formula_def_id, generic_args) else {
panic!(
"refinement_path annotation {:?} is not a formula function",
formula_def_id
);
};
let AnnotFormula::Formula(formula) = formula_fn.to_ensure_annot() else {
panic!("refinement_path annotation must lower to a plain formula");
};
out.push((position, formula.into()));
}
out
}

/// Whether the given `def_id` corresponds to a method of one of the `Fn` traits.
fn is_fn_trait_method(&self, def_id: DefId) -> bool {
self.tcx
Expand Down
65 changes: 65 additions & 0 deletions src/analyze/annot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ pub fn ensures_path_path() -> [Symbol; 2] {
[Symbol::intern("thrust"), Symbol::intern("ensures_path")]
}

pub fn refinement_path_path() -> [Symbol; 2] {
[Symbol::intern("thrust"), Symbol::intern("refinement_path")]
}

pub fn model_ty_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Expand Down Expand Up @@ -207,6 +211,67 @@ pub fn extract_annot_tokens(attr: Attribute) -> TokenStream {
d.tokens
}

/// Parses a [`rty::TypePosition`] from the tokens of a
/// `#[thrust::refinement_path(..)]` attribute.
///
/// Tokens are comma-separated steps. Each step is one of:
/// - The keyword `result` → [`rty::TypePositionStep::Return`] (navigate to a
/// function type's return slot).
/// - `$i` (a `$` followed by an integer) → [`rty::TypePositionStep::Param`]`(i)`
/// (navigate to the `i`-th parameter of a function type).
/// - A bare integer `i` → [`rty::TypePositionStep::TypeArg`]`(i)` (navigate to
/// the `i`-th type argument of a generic type such as an enum or `Box`).
///
/// Examples: `result` is the return; `$0` is the first parameter; `$0, 0` is
/// the first type-argument of the first parameter; `$0, result` is the return
/// of a function-typed first parameter.
pub fn parse_type_position(ts: &TokenStream) -> rty::TypePosition {
use rustc_ast::token::{LitKind, TokenKind};
use rustc_ast::tokenstream::TokenTree;

let parse_int = |lit: &rustc_ast::token::Lit| -> usize {
assert_eq!(
lit.kind,
LitKind::Integer,
"expected an integer in type position"
);
lit.symbol
.as_str()
.parse()
.expect("invalid integer in type position")
};

let mut steps = Vec::new();
let mut iter = ts.iter();
while let Some(tt) = iter.next() {
let TokenTree::Token(t, _) = tt else {
panic!("unexpected token tree in type position");
};
match &t.kind {
TokenKind::Comma => {}
TokenKind::Ident(sym, _) if sym.as_str() == "result" => {
steps.push(rty::TypePositionStep::Return);
}
TokenKind::Dollar => {
let i = match iter.next() {
Some(TokenTree::Token(t, _)) => match &t.kind {
TokenKind::Literal(lit) => parse_int(lit),
_ => panic!("expected integer after `$` in type position: {:?}", t),
},
_ => panic!("expected integer after `$` in type position"),
};
steps.push(rty::TypePositionStep::Param(rty::FunctionParamIdx::from(i)));
}
TokenKind::Literal(lit) => {
steps.push(rty::TypePositionStep::TypeArg(parse_int(lit)));
}
_ => panic!("unexpected token in type position: {:?}", t),
}
}

rty::TypePosition::new(steps)
}

pub fn split_param(ts: &TokenStream) -> (Ident, TokenStream) {
use rustc_ast::token::TokenKind;
use rustc_ast::tokenstream::TokenTree;
Expand Down
7 changes: 7 additions & 0 deletions src/analyze/local_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
assert!(require_annot.is_none() || param_annots.is_empty());
assert!(ensure_annot.is_none() || ret_annot.is_none());

let refinement_annots = self
.ctx
.extract_refinement_annots(self.local_def_id, self.generic_args);

let trait_item_ty = self.trait_item_ty();
let is_fully_annotated = self.is_fully_annotated();

Expand All @@ -431,6 +435,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
if let Some(ret_rty) = ret_annot {
builder.ret_rty(ret_rty);
}
for (position, refinement) in refinement_annots {
builder.install_refinement_at(&position, refinement);
}

if is_fully_annotated {
rty::RefinedType::unrefined(builder.build().into())
Expand Down
43 changes: 43 additions & 0 deletions src/refine/template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,49 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> {
self.ret_rty = Some(rty);
self
}

/// Installs a refinement at a [`rty::TypePosition`].
///
/// The first step must be [`rty::TypePositionStep::Param`] or
/// [`rty::TypePositionStep::Return`]; the remaining steps are forwarded to
/// [`rty::RefinedType::install_refinement_at`].
pub fn install_refinement_at(
&mut self,
position: &rty::TypePosition,
refinement: rty::Refinement<rty::FunctionParamIdx>,
) -> &mut Self {
let (first, rest) = match position.steps().split_first() {
Some(pair) => pair,
None => panic!("type position applied to a function type must not be empty"),
};
match first {
rty::TypePositionStep::Param(idx) => {
if !self.param_rtys.contains_key(idx) {
let ty = self.inner.build(self.param_tys[idx.index()].ty).vacuous();
self.param_rtys
.insert(*idx, rty::RefinedType::unrefined(ty));
}
self.param_rtys
.get_mut(idx)
.unwrap()
.install_refinement_at(rest, refinement);
}
rty::TypePositionStep::Return => {
if self.ret_rty.is_none() {
let ty = self.inner.build(self.ret_ty).vacuous();
self.ret_rty = Some(rty::RefinedType::unrefined(ty));
}
self.ret_rty
.as_mut()
.unwrap()
.install_refinement_at(rest, refinement);
}
rty::TypePositionStep::TypeArg(_) => {
panic!("type position applied to a function type must start with a param or result step, not a type argument");
}
}
self
}
}

impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R>
Expand Down
125 changes: 125 additions & 0 deletions src/rty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,83 @@ where
}
}

/// One step in a [`TypePosition`] path.
///
/// A path is a sequence of steps that addresses a sub-type within a
/// (potentially nested) function signature:
/// - [`Param`](Self::Param) / [`Return`](Self::Return) navigate into a
/// function type's parameter or return slot respectively.
/// - [`TypeArg`](Self::TypeArg) navigates into the `i`-th type argument of a
/// generic type (enum, `Box`, etc.).
///
/// Using distinct variants for function navigation ([`Param`](Self::Param),
/// [`Return`](Self::Return)) and generic-arg navigation
/// ([`TypeArg`](Self::TypeArg)) allows the same path representation to address
/// positions inside higher-order function types. For example, `$0, result`
/// addresses the return type of a function-typed first parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TypePositionStep {
/// Navigate to the `i`-th parameter of a function type.
Param(FunctionParamIdx),
/// Navigate to the return type of a function type.
Return,
/// Navigate to the `i`-th type argument of a generic type (enum, `Box`, …).
TypeArg(usize),
}

impl std::fmt::Display for TypePositionStep {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
TypePositionStep::Param(idx) => write!(f, "{}", idx),
TypePositionStep::Return => f.write_str("result"),
TypePositionStep::TypeArg(i) => write!(f, "{}", i),
}
}
}

/// A path addressing a sub-type within a type, used to attach a refinement.
///
/// An empty path addresses the type itself. Each step descends one level:
/// [`TypePositionStep::Param`] / [`TypePositionStep::Return`] enter a function
/// type's parameter or return slot, and [`TypePositionStep::TypeArg`] enters a
/// generic type argument. Steps combine freely, so positions inside
/// higher-order function types are expressible. A path applied to a function
/// type is therefore non-empty, beginning with a `Param`/`Return` step.
///
/// Examples (function `fn f(x: List<T>) -> Box<T>`):
/// - `$0` — parameter `x`.
/// - `result` — the return type.
/// - `$0, 0` — the first type arg of `x`.
/// - `result, 0` — the pointee of the `Box` return.
/// - `$0, result` — the return of a function-typed param `x`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypePosition {
steps: Vec<TypePositionStep>,
}

impl TypePosition {
pub fn new(steps: Vec<TypePositionStep>) -> Self {
TypePosition { steps }
}

pub fn steps(&self) -> &[TypePositionStep] {
&self.steps
}
}

impl std::fmt::Display for TypePosition {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let mut iter = self.steps.iter();
if let Some(first) = iter.next() {
write!(f, "{}", first)?;
}
for s in iter {
write!(f, ", {}", s)?;
}
Ok(())
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum FunctionAbi {
#[default]
Expand Down Expand Up @@ -1475,6 +1552,54 @@ where
}
}

impl RefinedType<FunctionParamIdx> {
/// Installs `refinement` at the sub-type addressed by `steps`.
///
/// An empty `steps` slice replaces the refinement at this node. Each step
/// in the slice navigates one level deeper:
/// - [`TypePositionStep::TypeArg`] descends into enum type arguments or the
/// `Box` pointee.
/// - [`TypePositionStep::Param`] / [`TypePositionStep::Return`] descend
/// into a function-typed position's parameter or return slot.
pub fn install_refinement_at(
&mut self,
steps: &[TypePositionStep],
refinement: Refinement<FunctionParamIdx>,
) {
let Some((step, rest)) = steps.split_first() else {
self.refinement = refinement;
return;
};
match step {
TypePositionStep::TypeArg(i) => match &mut self.ty {
Type::Enum(e) => {
let arg = e.args.get_mut(TypeParamIdx::from(*i)).unwrap_or_else(|| {
panic!("refine step [{}] out of range for enum type", i)
});
arg.install_refinement_at(rest, refinement);
}
Type::Pointer(p) => {
assert_eq!(*i, 0, "Box type position must be [0]");
p.elem.install_refinement_at(rest, refinement);
}
ty => panic!("TypeArg step on unsupported type: {:?}", ty),
},
TypePositionStep::Param(idx) => match &mut self.ty {
Type::Function(func) => {
func.params[*idx].install_refinement_at(rest, refinement);
}
ty => panic!("Param step on non-function type: {:?}", ty),
},
TypePositionStep::Return => match &mut self.ty {
Type::Function(func) => {
func.ret.install_refinement_at(rest, refinement);
}
ty => panic!("Return step on non-function type: {:?}", ty),
},
}
}
}

impl<FV> RefinedType<FV> {
fn pretty_atom<'a, 'b, D>(
&'b self,
Expand Down
2 changes: 1 addition & 1 deletion tests/ui/fail/annot_box_term.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//@error-in-other-file: Unsat
//@compile-flags: -C debug-assertions=off

#[thrust::sig(fn(x: int) -> {r: Box<int> | r == <x>})]
#[thrust_macros::sig(fn(x: i64) -> { r: Box<i64> | r == thrust_models::model::Box::new(x) })]
fn box_create(x: i64) -> Box<i64> {
Box::new(x)
}
Expand Down
9 changes: 9 additions & 0 deletions tests/ui/fail/refine_param_simple.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
//@error-in-other-file: Unsat

#[thrust_macros::param(x: { v: i32 | v > 0 })]
#[thrust_macros::ret({ r: i32 | r > x })]
fn f(x: i32) -> i32 {
x
}

fn main() {}
Loading