From 4c67ff566f6453fda3063d3315db1a2a71ed861f Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Thu, 29 Jan 2026 16:17:36 +0900 Subject: [PATCH 1/6] add: tests for annotations in traits --- tests/ui/fail/annot_preds_trait.rs | 43 ++++++++++++++++++++++++++++++ tests/ui/pass/annot_preds_trait.rs | 43 ++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+) create mode 100644 tests/ui/fail/annot_preds_trait.rs create mode 100644 tests/ui/pass/annot_preds_trait.rs diff --git a/tests/ui/fail/annot_preds_trait.rs b/tests/ui/fail/annot_preds_trait.rs new file mode 100644 index 0000000..b8eb4a1 --- /dev/null +++ b/tests/ui/fail/annot_preds_trait.rs @@ -0,0 +1,43 @@ +//@error-in-other-file: Unsat +//@compile-flags: -Adead_code -C debug-assertions=off + +// A is represented as Tuple in SMT-LIB2 format. +struct A { + x: i64, +} + +trait Double { + // Support annotations in trait definitions + #[thrust::predicate] + fn is_double(self, doubled: Self) -> bool; + + // This annotations are applied to all implementors of the `Double` trait. + #[thrust::requires(true)] + #[thrust::ensures(is_double(*self, ^self))] + fn double(&mut self); +} + +impl Double for A { + // Write concrete definitions for predicates in `impl` blocks + #[thrust::predicate] + fn is_double(self, doubled: Self) -> bool { + // (tuple_proj.0 self) is equivalent to self.x + // self.x * 3 == doubled.x (this isn't actually doubled!) is written as following: + "(= + (* (tuple_proj.0 self) 3) + (tuple_proj.0 doubled) + )"; true // This definition does not comply with annotations in trait! + } + + // Check if this method complies with annotations in + // trait definition. + fn double(&mut self) { + self.x += self.x; + } +} + +fn main() { + let mut a = A { x: 3 }; + a.double(); + assert!(a.x == 6); +} diff --git a/tests/ui/pass/annot_preds_trait.rs b/tests/ui/pass/annot_preds_trait.rs new file mode 100644 index 0000000..5f38585 --- /dev/null +++ b/tests/ui/pass/annot_preds_trait.rs @@ -0,0 +1,43 @@ +//@check-pass +//@compile-flags: -Adead_code -C debug-assertions=off + +// A is represented as Tuple in SMT-LIB2 format. +struct A { + x: i64, +} + +trait Double { + // Support annotations in trait definitions + #[thrust::predicate] + fn is_double(self, doubled: Self) -> bool; + + // This annotations are applied to all implementors of the `Double` trait. + #[thrust::requires(true)] + #[thrust::ensures(is_double(*self, ^self))] + fn double(&mut self); +} + +impl Double for A { + // Write concrete definitions for predicates in `impl` blocks + #[thrust::predicate] + fn is_double(self, doubled: Self) -> bool { + // (tuple_proj.0 self) is equivalent to self.x + // self.x * 2 == doubled.x is written as following: + "(= + (* (tuple_proj.0 self) 2) + (tuple_proj.0 doubled) + )"; true + } + + // Check if this method complies with annotations in + // trait definition. + fn double(&mut self) { + self.x += self.x; + } +} + +fn main() { + let mut a = A { x: 3 }; + a.double(); + assert!(a.x == 6); +} From 87d818c12cbab7d1cfa7643dcbdfb6d83ad99146 Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Sun, 1 Feb 2026 19:02:15 +0900 Subject: [PATCH 2/6] add: reference trait-side definitions for require/ensure annotations of functions in impl blocks Update src/analyze/local_def.rs Co-authored-by: Hiromi Ogawa --- src/analyze/local_def.rs | 70 +++++++++++++++++++++++++++++----------- 1 file changed, 52 insertions(+), 18 deletions(-) diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 67033eb..85efdbb 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -53,15 +53,19 @@ pub struct Analyzer<'tcx, 'ctx> { } impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { - fn extract_require_annot(&self, resolver: T) -> Option> + fn extract_require_annot( + &self, + def_id: DefId, + resolver: T, + ) -> Option> where T: annot::Resolver, { let mut require_annot = None; - for attrs in self.tcx.get_attrs_by_path( - self.local_def_id.to_def_id(), - &analyze::annot::requires_path(), - ) { + for attrs in self + .tcx + .get_attrs_by_path(def_id, &analyze::annot::requires_path()) + { if require_annot.is_some() { unimplemented!(); } @@ -72,15 +76,15 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { require_annot } - fn extract_ensure_annot(&self, resolver: T) -> Option> + fn extract_ensure_annot(&self, def_id: DefId, resolver: T) -> Option> where T: annot::Resolver, { let mut ensure_annot = None; - for attrs in self.tcx.get_attrs_by_path( - self.local_def_id.to_def_id(), - &analyze::annot::ensures_path(), - ) { + for attrs in self + .tcx + .get_attrs_by_path(def_id, &analyze::annot::ensures_path()) + { if ensure_annot.is_some() { unimplemented!(); } @@ -252,6 +256,17 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { || (all_params_annotated && has_ret) } + pub fn trait_item_id(&mut self) -> Option { + let impl_item_assoc = self + .tcx + .opt_associated_item(self.local_def_id.to_def_id())?; + let trait_item_id = impl_item_assoc + .trait_item_def_id + .and_then(|id| id.as_local())?; + + Some(trait_item_id) + } + pub fn expected_ty(&mut self) -> rty::RefinedType { let sig = self .ctx @@ -268,14 +283,33 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { param_resolver.push_param(input_ident.name, input_ty.to_sort()); } - let mut require_annot = self.extract_require_annot(¶m_resolver); - let mut ensure_annot = { - let output_ty = self.type_builder.build(sig.output()); - let resolver = annot::StackedResolver::default() - .resolver(analyze::annot::ResultResolver::new(output_ty.to_sort())) - .resolver((¶m_resolver).map(rty::RefinedTypeVar::Free)); - self.extract_ensure_annot(resolver) - }; + let mut require_annot = + self.extract_require_annot(self.local_def_id.to_def_id(), ¶m_resolver); + + let output_ty = self.type_builder.build(sig.output()); + let result_param_resolver = annot::StackedResolver::default() + .resolver(analyze::annot::ResultResolver::new(output_ty.to_sort())) + .resolver((¶m_resolver).map(rty::RefinedTypeVar::Free)); + let mut ensure_annot = + self.extract_ensure_annot(self.local_def_id.to_def_id(), &result_param_resolver); + + if let Some(trait_item_id) = self.trait_item_id() { + tracing::info!( + "trait item fonud: {:?}", + trait_item_id, + ); + let trait_require_annot = + self.extract_require_annot(trait_item_id.into(), ¶m_resolver); + let trait_ensure_annot = + self.extract_ensure_annot(trait_item_id.into(), &result_param_resolver); + + assert!(require_annot.is_none() || trait_require_annot.is_none()); + require_annot = require_annot.or(trait_require_annot); + + assert!(ensure_annot.is_none() || trait_ensure_annot.is_none()); + ensure_annot = ensure_annot.or(trait_ensure_annot); + } + let param_annots = self.extract_param_annots(¶m_resolver); let ret_annot = self.extract_ret_annot(¶m_resolver); From 35fe6d4b7ee971dfdfb938dcf0b45caab3101ab6 Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Tue, 3 Feb 2026 15:49:58 +0900 Subject: [PATCH 3/6] change: move extract_*_annot()s to analyze::Analyzer from analyze::local_def::Analyzer --- src/analyze.rs | 44 ++++++++++++++++++++++++++ src/analyze/local_def.rs | 67 ++++++++-------------------------------- 2 files changed, 57 insertions(+), 54 deletions(-) diff --git a/src/analyze.rs b/src/analyze.rs index 574a44b..db294a9 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -16,6 +16,8 @@ use rustc_middle::mir::{self, BasicBlock, Local}; use rustc_middle::ty::{self as mir_ty, TyCtxt}; use rustc_span::def_id::{DefId, LocalDefId}; +use crate::analyze; +use crate::annot::{AnnotFormula, AnnotParser, Resolver}; use crate::chc; use crate::pretty::PrettyDisplayExt as _; use crate::refine::{self, BasicBlockType, TypeBuilder}; @@ -435,4 +437,46 @@ impl<'tcx> Analyzer<'tcx> { let body = self.tcx.optimized_mir(local_def_id); self.local_fn_sig_with_body(local_def_id, body) } + + fn extract_require_annot( + &self, + def_id: DefId, + resolver: T, + ) -> Option> + where + T: Resolver, + { + let mut require_annot = None; + for attrs in self + .tcx + .get_attrs_by_path(def_id, &analyze::annot::requires_path()) + { + if require_annot.is_some() { + unimplemented!(); + } + let ts = analyze::annot::extract_annot_tokens(attrs.clone()); + let require = AnnotParser::new(&resolver).parse_formula(ts).unwrap(); + require_annot = Some(require); + } + require_annot + } + + fn extract_ensure_annot(&self, def_id: DefId, resolver: T) -> Option> + where + T: Resolver, + { + let mut ensure_annot = None; + for attrs in self + .tcx + .get_attrs_by_path(def_id, &analyze::annot::ensures_path()) + { + if ensure_annot.is_some() { + unimplemented!(); + } + let ts = analyze::annot::extract_annot_tokens(attrs.clone()); + let ensure = AnnotParser::new(&resolver).parse_formula(ts).unwrap(); + ensure_annot = Some(ensure); + } + ensure_annot + } } diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 85efdbb..805063d 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -53,48 +53,6 @@ pub struct Analyzer<'tcx, 'ctx> { } impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { - fn extract_require_annot( - &self, - def_id: DefId, - resolver: T, - ) -> Option> - where - T: annot::Resolver, - { - let mut require_annot = None; - for attrs in self - .tcx - .get_attrs_by_path(def_id, &analyze::annot::requires_path()) - { - if require_annot.is_some() { - unimplemented!(); - } - let ts = analyze::annot::extract_annot_tokens(attrs.clone()); - let require = AnnotParser::new(&resolver).parse_formula(ts).unwrap(); - require_annot = Some(require); - } - require_annot - } - - fn extract_ensure_annot(&self, def_id: DefId, resolver: T) -> Option> - where - T: annot::Resolver, - { - let mut ensure_annot = None; - for attrs in self - .tcx - .get_attrs_by_path(def_id, &analyze::annot::ensures_path()) - { - if ensure_annot.is_some() { - unimplemented!(); - } - let ts = analyze::annot::extract_annot_tokens(attrs.clone()); - let ensure = AnnotParser::new(&resolver).parse_formula(ts).unwrap(); - ensure_annot = Some(ensure); - } - ensure_annot - } - fn extract_param_annots(&self, resolver: T) -> Vec<(Ident, rty::RefinedType)> where T: annot::Resolver, @@ -283,25 +241,26 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { param_resolver.push_param(input_ident.name, input_ty.to_sort()); } - let mut require_annot = - self.extract_require_annot(self.local_def_id.to_def_id(), ¶m_resolver); + let mut require_annot = self + .ctx + .extract_require_annot(self.local_def_id.to_def_id(), ¶m_resolver); let output_ty = self.type_builder.build(sig.output()); let result_param_resolver = annot::StackedResolver::default() .resolver(analyze::annot::ResultResolver::new(output_ty.to_sort())) .resolver((¶m_resolver).map(rty::RefinedTypeVar::Free)); - let mut ensure_annot = - self.extract_ensure_annot(self.local_def_id.to_def_id(), &result_param_resolver); + let mut ensure_annot = self + .ctx + .extract_ensure_annot(self.local_def_id.to_def_id(), &result_param_resolver); if let Some(trait_item_id) = self.trait_item_id() { - tracing::info!( - "trait item fonud: {:?}", - trait_item_id, - ); - let trait_require_annot = - self.extract_require_annot(trait_item_id.into(), ¶m_resolver); - let trait_ensure_annot = - self.extract_ensure_annot(trait_item_id.into(), &result_param_resolver); + tracing::info!("trait item fonud: {:?}", trait_item_id); + let trait_require_annot = self + .ctx + .extract_require_annot(trait_item_id.into(), ¶m_resolver); + let trait_ensure_annot = self + .ctx + .extract_ensure_annot(trait_item_id.into(), &result_param_resolver); assert!(require_annot.is_none() || trait_require_annot.is_none()); require_annot = require_annot.or(trait_require_annot); From 2e8d7b1ba3b646afc18327fe29b97347b33c95cb Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Wed, 4 Feb 2026 23:45:56 +0900 Subject: [PATCH 4/6] change: insert type names as prefix of name for predicates in impl blocks --- src/analyze/local_def.rs | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 805063d..ee32c7b 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -89,8 +89,36 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { ret_annot } + fn impl_type(&self) -> Option> { + use rustc_hir::def::DefKind; + + let parent_def_id = self + .tcx + .parent(self.local_def_id.to_def_id()); + + if !matches!(self.tcx.def_kind(parent_def_id), DefKind::Impl { .. }) { + return None; + } + + let self_ty = self + .tcx + .type_of(parent_def_id) + .instantiate_identity(); + + Some(self_ty) + } + pub fn analyze_predicate_definition(&self, local_def_id: LocalDefId) { - let pred_name = self.tcx.item_name(local_def_id.to_def_id()).to_string(); + // predicate's name + let impl_type = self.impl_type(); + let pred_item_name = self + .tcx + .item_name(local_def_id.to_def_id()) + .to_string(); + let pred_name = match impl_type { + Some(t) => t.to_string() + "_" + &pred_item_name, + None => pred_item_name, + }; // function's body use rustc_hir::{Block, Expr, ExprKind}; @@ -214,7 +242,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { || (all_params_annotated && has_ret) } - pub fn trait_item_id(&mut self) -> Option { + pub fn trait_item_id(&self) -> Option { let impl_item_assoc = self .tcx .opt_associated_item(self.local_def_id.to_def_id())?; From 2afd7a4491f8ad7a218f86cdf98aaf0cb9f1d329 Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Sat, 7 Feb 2026 17:16:10 +0900 Subject: [PATCH 5/6] =?UTF-8?q?add:=20tests=20for=20identifying=20struct?= =?UTF-8?q?=E2=80=91bound=20predicates=20using=20`Self::`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/ui/fail/annot_preds_trait.rs | 2 +- tests/ui/fail/annot_preds_trait_multi.rs | 71 ++++++++++++++++++++++++ tests/ui/pass/annot_preds_trait.rs | 2 +- tests/ui/pass/annot_preds_trait_multi.rs | 71 ++++++++++++++++++++++++ 4 files changed, 144 insertions(+), 2 deletions(-) create mode 100644 tests/ui/fail/annot_preds_trait_multi.rs create mode 100644 tests/ui/pass/annot_preds_trait_multi.rs diff --git a/tests/ui/fail/annot_preds_trait.rs b/tests/ui/fail/annot_preds_trait.rs index b8eb4a1..bd0bdbc 100644 --- a/tests/ui/fail/annot_preds_trait.rs +++ b/tests/ui/fail/annot_preds_trait.rs @@ -13,7 +13,7 @@ trait Double { // This annotations are applied to all implementors of the `Double` trait. #[thrust::requires(true)] - #[thrust::ensures(is_double(*self, ^self))] + #[thrust::ensures(Self::is_double(*self, ^self))] fn double(&mut self); } diff --git a/tests/ui/fail/annot_preds_trait_multi.rs b/tests/ui/fail/annot_preds_trait_multi.rs new file mode 100644 index 0000000..a277358 --- /dev/null +++ b/tests/ui/fail/annot_preds_trait_multi.rs @@ -0,0 +1,71 @@ +//@error-in-other-file: Unsat +//@compile-flags: -Adead_code -C debug-assertions=off + +trait Double { + // Support annotations in trait definitions + #[thrust::predicate] + fn is_double(self, doubled: Self) -> bool; + + // This annotations are applied to all implementors of the `Double` trait. + #[thrust::requires(true)] + #[thrust::ensures(Self::is_double(*self, ^self))] + fn double(&mut self); +} + +// A is represented as Tuple in SMT-LIB2 format. +struct A { + x: i64, +} + +impl Double for A { + #[thrust::predicate] + fn is_double(self, doubled: Self) -> bool { + // self.x * 2 == doubled.x + "(= + (* (tuple_proj.0 self) 2) + (tuple_proj.0 doubled) + )"; true + } + + fn double(&mut self) { + self.x += self.x; + } +} + +// B is represented as Tuple in SMT-LIB2 format. +struct B { + x: i64, + y: i64, +} + +impl Double for B { + #[thrust::predicate] + fn is_double(self, doubled: Self) -> bool { + // self.x * 3 == doubled.x && self.y * 2 == doubled.y (this isn't actually doubled!) + "(and + (= + (* (tuple_proj.0 self) 3) + (tuple_proj.0 doubled) + ) + (= + (* (tuple_proj.1 self) 2) + (tuple_proj.1 doubled) + ) + )"; true // This definition does not comply with annotations in trait! + } + + fn double(&mut self) { + self.x += self.x; + self.y += self.y; + } +} + +fn main() { + let mut a = A { x: 3 }; + a.double(); + assert!(a.x == 6); + + let mut b = B { x: 2, y: 5 }; + b.double(); + assert!(b.x == 4 && b.y == 10); +} diff --git a/tests/ui/pass/annot_preds_trait.rs b/tests/ui/pass/annot_preds_trait.rs index 5f38585..bb2e37c 100644 --- a/tests/ui/pass/annot_preds_trait.rs +++ b/tests/ui/pass/annot_preds_trait.rs @@ -13,7 +13,7 @@ trait Double { // This annotations are applied to all implementors of the `Double` trait. #[thrust::requires(true)] - #[thrust::ensures(is_double(*self, ^self))] + #[thrust::ensures(Self::is_double(*self, ^self))] fn double(&mut self); } diff --git a/tests/ui/pass/annot_preds_trait_multi.rs b/tests/ui/pass/annot_preds_trait_multi.rs new file mode 100644 index 0000000..a51ff84 --- /dev/null +++ b/tests/ui/pass/annot_preds_trait_multi.rs @@ -0,0 +1,71 @@ +//@check-pass +//@compile-flags: -Adead_code -C debug-assertions=off + +trait Double { + // Support annotations in trait definitions + #[thrust::predicate] + fn is_double(self, doubled: Self) -> bool; + + // This annotations are applied to all implementors of the `Double` trait. + #[thrust::requires(true)] + #[thrust::ensures(Self::is_double(*self, ^self))] + fn double(&mut self); +} + +// A is represented as Tuple in SMT-LIB2 format. +struct A { + x: i64, +} + +impl Double for A { + #[thrust::predicate] + fn is_double(self, doubled: Self) -> bool { + // self.x * 2 == doubled.x + "(= + (* (tuple_proj.0 self) 2) + (tuple_proj.0 doubled) + )"; true + } + + fn double(&mut self) { + self.x += self.x; + } +} + +// B is represented as Tuple in SMT-LIB2 format. +struct B { + x: i64, + y: i64, +} + +impl Double for B { + #[thrust::predicate] + fn is_double(self, doubled: Self) -> bool { + // self.x * 2 == doubled.x && self.y * 2 == doubled.y + "(and + (= + (* (tuple_proj.0 self) 2) + (tuple_proj.0 doubled) + ) + (= + (* (tuple_proj.1 self) 2) + (tuple_proj.1 doubled) + ) + )"; true + } + + fn double(&mut self) { + self.x += self.x; + self.y += self.y; + } +} + +fn main() { + let mut a = A { x: 3 }; + a.double(); + assert!(a.x == 6); + + let mut b = B { x: 2, y: 5 }; + b.double(); + assert!(b.x == 4 && b.y == 10); +} From 23b3d9a26edd1a09c3c8b90bf1168a8f7b9c064f Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Sat, 7 Feb 2026 17:39:25 +0900 Subject: [PATCH 6/6] add: Identify struct-bounded predicates using `Self::` prefix --- src/analyze.rs | 14 ++++++-- src/analyze/local_def.rs | 75 +++++++++++++++++++++++----------------- src/annot.rs | 46 ++++++++++++++++++++++-- 3 files changed, 99 insertions(+), 36 deletions(-) diff --git a/src/analyze.rs b/src/analyze.rs index db294a9..f8078d5 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -442,11 +442,13 @@ impl<'tcx> Analyzer<'tcx> { &self, def_id: DefId, resolver: T, + self_type_name: Option, ) -> Option> where T: Resolver, { let mut require_annot = None; + let parser = AnnotParser::new(&resolver, self_type_name); for attrs in self .tcx .get_attrs_by_path(def_id, &analyze::annot::requires_path()) @@ -455,17 +457,23 @@ impl<'tcx> Analyzer<'tcx> { unimplemented!(); } let ts = analyze::annot::extract_annot_tokens(attrs.clone()); - let require = AnnotParser::new(&resolver).parse_formula(ts).unwrap(); + let require = parser.parse_formula(ts).unwrap(); require_annot = Some(require); } require_annot } - fn extract_ensure_annot(&self, def_id: DefId, resolver: T) -> Option> + fn extract_ensure_annot( + &self, + def_id: DefId, + resolver: T, + self_type_name: Option, + ) -> Option> where T: Resolver, { let mut ensure_annot = None; + let parser = AnnotParser::new(&resolver, self_type_name); for attrs in self .tcx .get_attrs_by_path(def_id, &analyze::annot::ensures_path()) @@ -474,7 +482,7 @@ impl<'tcx> Analyzer<'tcx> { unimplemented!(); } let ts = analyze::annot::extract_annot_tokens(attrs.clone()); - let ensure = AnnotParser::new(&resolver).parse_formula(ts).unwrap(); + let ensure = parser.parse_formula(ts).unwrap(); ensure_annot = Some(ensure); } ensure_annot diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index ee32c7b..97a20fe 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -53,28 +53,38 @@ pub struct Analyzer<'tcx, 'ctx> { } impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { - fn extract_param_annots(&self, resolver: T) -> Vec<(Ident, rty::RefinedType)> + fn extract_param_annots( + &self, + resolver: T, + self_type_name: Option, + ) -> Vec<(Ident, rty::RefinedType)> where T: annot::Resolver, { let mut param_annots = Vec::new(); + let parser = AnnotParser::new(&resolver, self_type_name); for attrs in self .tcx .get_attrs_by_path(self.local_def_id.to_def_id(), &analyze::annot::param_path()) { let ts = analyze::annot::extract_annot_tokens(attrs.clone()); let (ident, ts) = analyze::annot::split_param(&ts); - let param = AnnotParser::new(&resolver).parse_rty(ts).unwrap(); + let param = parser.parse_rty(ts).unwrap(); param_annots.push((ident, param)); } param_annots } - fn extract_ret_annot(&self, resolver: T) -> Option> + fn extract_ret_annot( + &self, + resolver: T, + self_type_name: Option, + ) -> Option> where T: annot::Resolver, { let mut ret_annot = None; + let parser = AnnotParser::new(&resolver, self_type_name); for attrs in self .tcx .get_attrs_by_path(self.local_def_id.to_def_id(), &analyze::annot::ret_path()) @@ -83,7 +93,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { unimplemented!(); } let ts = analyze::annot::extract_annot_tokens(attrs.clone()); - let ret = AnnotParser::new(&resolver).parse_rty(ts).unwrap(); + let ret = parser.parse_rty(ts).unwrap(); ret_annot = Some(ret); } ret_annot @@ -92,29 +102,21 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { fn impl_type(&self) -> Option> { use rustc_hir::def::DefKind; - let parent_def_id = self - .tcx - .parent(self.local_def_id.to_def_id()); + let parent_def_id = self.tcx.parent(self.local_def_id.to_def_id()); if !matches!(self.tcx.def_kind(parent_def_id), DefKind::Impl { .. }) { return None; } - let self_ty = self - .tcx - .type_of(parent_def_id) - .instantiate_identity(); - + let self_ty = self.tcx.type_of(parent_def_id).instantiate_identity(); + Some(self_ty) } pub fn analyze_predicate_definition(&self, local_def_id: LocalDefId) { // predicate's name let impl_type = self.impl_type(); - let pred_item_name = self - .tcx - .item_name(local_def_id.to_def_id()) - .to_string(); + let pred_item_name = self.tcx.item_name(local_def_id.to_def_id()).to_string(); let pred_name = match impl_type { Some(t) => t.to_string() + "_" + &pred_item_name, None => pred_item_name, @@ -269,26 +271,37 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { param_resolver.push_param(input_ident.name, input_ty.to_sort()); } - let mut require_annot = self - .ctx - .extract_require_annot(self.local_def_id.to_def_id(), ¶m_resolver); - let output_ty = self.type_builder.build(sig.output()); let result_param_resolver = annot::StackedResolver::default() .resolver(analyze::annot::ResultResolver::new(output_ty.to_sort())) .resolver((¶m_resolver).map(rty::RefinedTypeVar::Free)); - let mut ensure_annot = self - .ctx - .extract_ensure_annot(self.local_def_id.to_def_id(), &result_param_resolver); + + let self_type_name = self.impl_type().map(|ty| ty.to_string()); + + let mut require_annot = self.ctx.extract_require_annot( + self.local_def_id.to_def_id(), + ¶m_resolver, + self_type_name.clone(), + ); + + let mut ensure_annot = self.ctx.extract_ensure_annot( + self.local_def_id.to_def_id(), + &result_param_resolver, + self_type_name.clone(), + ); if let Some(trait_item_id) = self.trait_item_id() { tracing::info!("trait item fonud: {:?}", trait_item_id); - let trait_require_annot = self - .ctx - .extract_require_annot(trait_item_id.into(), ¶m_resolver); - let trait_ensure_annot = self - .ctx - .extract_ensure_annot(trait_item_id.into(), &result_param_resolver); + let trait_require_annot = self.ctx.extract_require_annot( + trait_item_id.into(), + ¶m_resolver, + self_type_name.clone(), + ); + let trait_ensure_annot = self.ctx.extract_ensure_annot( + trait_item_id.into(), + &result_param_resolver, + self_type_name.clone(), + ); assert!(require_annot.is_none() || trait_require_annot.is_none()); require_annot = require_annot.or(trait_require_annot); @@ -297,8 +310,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { ensure_annot = ensure_annot.or(trait_ensure_annot); } - let param_annots = self.extract_param_annots(¶m_resolver); - let ret_annot = self.extract_ret_annot(¶m_resolver); + let param_annots = self.extract_param_annots(¶m_resolver, self_type_name.clone()); + let ret_annot = self.extract_ret_annot(¶m_resolver, self_type_name); if self.is_annotated_as_callable() { if require_annot.is_some() || ensure_annot.is_some() { diff --git a/src/annot.rs b/src/annot.rs index 30bf7d3..769e0e5 100644 --- a/src/annot.rs +++ b/src/annot.rs @@ -250,6 +250,7 @@ impl FormulaOrTerm { /// A parser for refinement type annotations and formula annotations. struct Parser<'a, T> { resolver: T, + self_type_name: Option, cursor: RefTokenTreeCursor<'a>, formula_existentials: HashMap, } @@ -453,6 +454,7 @@ where TokenTree::Delimited(_, _, Delimiter::Parenthesis, s) => { let mut parser = Parser { resolver: self.boxed_resolver(), + self_type_name: self.self_type_name.clone(), cursor: s.trees(), formula_existentials: self.formula_existentials.clone(), }; @@ -493,6 +495,7 @@ where let mut parser = Parser { resolver: self.boxed_resolver(), + self_type_name: self.self_type_name.clone(), cursor: args.trees(), formula_existentials: self.formula_existentials.clone(), }; @@ -518,11 +521,40 @@ where }; let mut parser = Parser { resolver: self.boxed_resolver(), + self_type_name: self.self_type_name.clone(), cursor: s.trees(), formula_existentials: self.formula_existentials.clone(), }; let args = parser.parse_arg_terms()?; parser.end_of_input()?; + + // Identify struct-bound predicates call such as `Self::pred()` + match path.segments.first() { + Some(AnnotPathSegment { + ident: Ident { name: symbol, .. }, + generic_args, + }) if symbol.as_str() == "Self" && generic_args.is_empty() => { + if path.segments.len() != 2 { + unimplemented!("long path beginning with `Self::`"); + } + + let func_name = path.segments.get(1).unwrap().ident.name.as_str(); + let pred_name = if let Some(self_type_name) = &self.self_type_name { + self_type_name.clone() + "_" + func_name + } else { + func_name.to_string() + }; + + let pred_symbol = chc::UserDefinedPred::new(pred_name); + let pred = chc::Pred::UserDefined(pred_symbol); + + let atom = chc::Atom::new(pred, args); + let formula = chc::Formula::Atom(atom); + return Ok(FormulaOrTerm::Formula(formula)); + } + _ => {} + } + let (term, sort) = path.to_datatype_ctor(args); FormulaOrTerm::Term(term, sort) } @@ -908,6 +940,7 @@ where TokenTree::Delimited(_, _, Delimiter::Parenthesis, ts) => { let mut parser = Parser { resolver: self.boxed_resolver(), + self_type_name: self.self_type_name.clone(), cursor: ts.trees(), formula_existentials: self.formula_existentials.clone(), }; @@ -1014,6 +1047,7 @@ where TokenTree::Delimited(_, _, Delimiter::Parenthesis, ts) => { let mut parser = Parser { resolver: self.boxed_resolver(), + self_type_name: self.self_type_name.clone(), cursor: ts.trees(), formula_existentials: self.formula_existentials.clone(), }; @@ -1050,6 +1084,7 @@ where let mut parser = Parser { resolver: self.boxed_resolver(), + self_type_name: self.self_type_name.clone(), cursor: ts.trees(), formula_existentials: self.formula_existentials.clone(), }; @@ -1074,6 +1109,7 @@ where let mut parser = Parser { resolver: RefinementResolver::new(self.boxed_resolver()), + self_type_name: self.self_type_name.clone(), cursor: parser.cursor, formula_existentials: self.formula_existentials.clone(), }; @@ -1199,11 +1235,15 @@ impl<'a, T> StackedResolver<'a, T> { #[derive(Debug, Clone)] pub struct AnnotParser { resolver: T, + self_type_name: Option, } impl AnnotParser { - pub fn new(resolver: T) -> Self { - Self { resolver } + pub fn new(resolver: T, self_type_name: Option) -> Self { + Self { + resolver, + self_type_name, + } } } @@ -1214,6 +1254,7 @@ where pub fn parse_rty(&self, ts: TokenStream) -> Result> { let mut parser = Parser { resolver: &self.resolver, + self_type_name: self.self_type_name.clone(), cursor: ts.trees(), formula_existentials: Default::default(), }; @@ -1225,6 +1266,7 @@ where pub fn parse_formula(&self, ts: TokenStream) -> Result> { let mut parser = Parser { resolver: &self.resolver, + self_type_name: self.self_type_name.clone(), cursor: ts.trees(), formula_existentials: Default::default(), };