diff --git a/Cargo.toml b/Cargo.toml index 8445a188a..040724e1c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,8 @@ members = [ "folding-schemes", "solidity-verifiers", - "cli" + "cli", + "frontend-macro" ] resolver = "2" diff --git a/examples/multi_inputs.rs b/examples/multi_inputs.rs index bb1508291..170b75823 100644 --- a/examples/multi_inputs.rs +++ b/examples/multi_inputs.rs @@ -8,6 +8,7 @@ use ark_r1cs_std::alloc::AllocVar; use ark_r1cs_std::fields::fp::FpVar; use ark_relations::r1cs::{ConstraintSystemRef, SynthesisError}; use core::marker::PhantomData; +use frontend_macro::Flatten; use std::time::Instant; use ark_bn254::{constraints::GVar, Bn254, Fr, G1Projective as Projective}; @@ -25,6 +26,17 @@ use utils::init_nova_ivc_params; /// we get by applying the step. /// In this example we set z_i and z_{i+1} to have five elements, and at each step we do different /// operations on each of them. +/// + +#[derive(Flatten)] +pub struct State { + pub a: F, + pub b: F, + pub c: F, + pub d: F, + pub e: F, +} + #[derive(Clone, Copy, Debug)] pub struct MultiInputsFCircuit { _f: PhantomData, @@ -35,9 +47,11 @@ impl FCircuit for MultiInputsFCircuit { fn new(_params: Self::Params) -> Result { Ok(Self { _f: PhantomData }) } + fn state_len(&self) -> usize { - 5 + State::::state_number() } + fn external_inputs_len(&self) -> usize { 0 } @@ -50,13 +64,16 @@ impl FCircuit for MultiInputsFCircuit { z_i: Vec, _external_inputs: Vec, ) -> Result, Error> { - let a = z_i[0] + F::from(4_u32); - let b = z_i[1] + F::from(40_u32); - let c = z_i[2] * F::from(4_u32); - let d = z_i[3] * F::from(40_u32); - let e = z_i[4] + F::from(100_u32); - - Ok(vec![a, b, c, d, e]) + let state = State::from(z_i); + + let next_state = State { + a: state.a + F::from(4_u32), + b: state.b + F::from(40_u32), + c: state.c * F::from(4_u32), + d: state.d * F::from(40_u32), + e: state.e + F::from(100_u32), + }; + Ok(Vec::from(next_state)) } /// generates the constraints for the step of F for the given z_i @@ -67,16 +84,21 @@ impl FCircuit for MultiInputsFCircuit { z_i: Vec>, _external_inputs: Vec>, ) -> Result>, SynthesisError> { + let cs_state = State::cs_state(z_i.clone()); + let four = FpVar::::new_constant(cs.clone(), F::from(4u32))?; let forty = FpVar::::new_constant(cs.clone(), F::from(40u32))?; let onehundred = FpVar::::new_constant(cs.clone(), F::from(100u32))?; - let a = z_i[0].clone() + four.clone(); - let b = z_i[1].clone() + forty.clone(); - let c = z_i[2].clone() * four; - let d = z_i[3].clone() * forty; - let e = z_i[4].clone() + onehundred; - Ok(vec![a, b, c, d, e]) + let next_cs_state = StateConstraint { + a: cs_state.a.clone() + four.clone(), + b: cs_state.b.clone() + forty.clone(), + c: cs_state.c.clone() * four, + d: cs_state.d.clone() * forty, + e: cs_state.e.clone() + onehundred, + }; + + Ok(Vec::from(next_cs_state)) } } diff --git a/folding-schemes/Cargo.toml b/folding-schemes/Cargo.toml index 964a633d2..65766190a 100644 --- a/folding-schemes/Cargo.toml +++ b/folding-schemes/Cargo.toml @@ -34,6 +34,7 @@ ark-grumpkin = {version="0.4.0", features=["r1cs"]} rand = "0.8.5" tracing = { version = "0.1", default-features = false, features = [ "attributes" ] } tracing-subscriber = { version = "0.2" } +frontend-macro = { path = "../frontend-macro/"} [features] default = ["parallel"] diff --git a/folding-schemes/src/frontend/mod.rs b/folding-schemes/src/frontend/mod.rs index 59f18eb16..ab85e69a3 100644 --- a/folding-schemes/src/frontend/mod.rs +++ b/folding-schemes/src/frontend/mod.rs @@ -22,7 +22,10 @@ pub trait FCircuit: Clone + Debug { /// returns the number of elements in the external inputs used by the FCircuit. External inputs /// are optional, and in case no external inputs are used, this method should return 0. - fn external_inputs_len(&self) -> usize; + /// default is zero + fn external_inputs_len(&self) -> usize { + 0 + } /// computes the next state values in place, assigning z_{i+1} into z_i, and computing the new /// z_{i+1} diff --git a/frontend-macro/Cargo.toml b/frontend-macro/Cargo.toml new file mode 100644 index 000000000..17c0fbe7e --- /dev/null +++ b/frontend-macro/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "frontend-macro" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +syn = { version = "0.15", features = ["extra-traits"] } +quote = "0.6" +proc-macro2 = "0.4" + +[dev-dependencies] +trybuild = "1.0" +ark-ff = "^0.4.0" +ark-bn254 = {version="0.4.0", features=["r1cs"]} +ark-r1cs-std = { version = "0.4.0", default-features = false } # this is patched at the workspace level + + +[lib] +proc-macro = true \ No newline at end of file diff --git a/frontend-macro/src/lib.rs b/frontend-macro/src/lib.rs new file mode 100644 index 000000000..36a17b08d --- /dev/null +++ b/frontend-macro/src/lib.rs @@ -0,0 +1,109 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::{parse_macro_input, DeriveInput}; + +#[proc_macro_derive(Flatten)] +pub fn derive(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as DeriveInput); + let iden = &ast.ident; + let cs_iden_name = format!("{}Constraint", iden); + let cs_iden = syn::Ident::new(&cs_iden_name, iden.span()); + + let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl(); + + let fields = if let syn::Data::Struct(syn::DataStruct { + fields: syn::Fields::Named(syn::FieldsNamed { ref named, .. }), + .. + }) = ast.data + { + named + } else { + unimplemented!(); + }; + + let cs_fields_to_vec = fields.iter().map(|f| { + let name = &f.ident; + quote! { value.#name } + }); + + let cs_fields = fields.iter().map(|f| { + let name = &f.ident; + quote! {pub #name: ark_r1cs_std::fields::fp::FpVar } + }); + + let builder_fields = fields.iter().map(|f| { + let name = &f.ident; + quote! { value.#name } + }); + + let cs_vec_to_fields = fields.iter().enumerate().map(|(id, f)| { + let name = &f.ident; + quote! {#name: vec[#id].clone()} + }); + + let vec_to_fields = fields.iter().enumerate().map(|(id, f)| { + let name = &f.ident; + quote! {#name: vec[#id]} + }); + + let state_number = fields.len(); + + let state_macro = quote! { + impl #impl_generics #iden #ty_generics #where_clause { + pub fn state_number() -> usize { + #state_number + } + } + + impl #impl_generics From<#iden #ty_generics> for Vec #ty_generics #where_clause { + fn from(value: #iden #ty_generics) -> Vec #ty_generics { + vec![#(#builder_fields,)*] + } + } + + impl #impl_generics From for #iden #ty_generics { + fn from(vec: Vec #ty_generics) -> #iden #ty_generics { + assert!(vec.len() == #iden::#ty_generics::state_number()); + #iden { + #(#vec_to_fields,)* + } + } + } + }; + + let constraint_macro = quote! { + pub struct #cs_iden { + #(#cs_fields,)* + } + + + impl From>> for #cs_iden { + fn from(vec: Vec>) -> #cs_iden{ + #cs_iden { + #(#cs_vec_to_fields,)* + } + } + } + + impl From<#cs_iden> for Vec> { + fn from(value: #cs_iden) -> Vec> { + vec![#(#cs_fields_to_vec,)*] + } + } + + + impl #iden{ + pub fn cs_state(v: Vec< ark_r1cs_std::fields::fp::FpVar>) -> #cs_iden { + #cs_iden::from(v) + } + } + }; + + let expanded = quote! { + #state_macro + + #constraint_macro + }; + + expanded.into() +} diff --git a/frontend-macro/tests/parse.rs b/frontend-macro/tests/parse.rs new file mode 100644 index 000000000..665a4d7b1 --- /dev/null +++ b/frontend-macro/tests/parse.rs @@ -0,0 +1,21 @@ +use ark_bn254::Fr; +use ark_ff::PrimeField; +use frontend_macro::Flatten; +#[derive(Flatten, Debug)] +struct State { + a: F, + b: F, +} + +fn main() { + let s = State:: { + a: Fr::from(1u32), + b: Fr::from(1u32), + }; + + let v: Vec = Vec::from(s); + + println!("{:?}", State::::state_number()); + println!("{:?}", v); + println!("{:?}", State::from(v)); +} diff --git a/frontend-macro/tests/runner.rs b/frontend-macro/tests/runner.rs new file mode 100644 index 000000000..b12fe9608 --- /dev/null +++ b/frontend-macro/tests/runner.rs @@ -0,0 +1,32 @@ +#[cfg(test)] +mod test { + + #[test] + fn try_compile() { + let t = trybuild::TestCases::new(); + t.pass("tests/parse.rs"); + } + + #[test] + fn try_run_test() { + use ark_bn254::Fr; + use ark_ff::PrimeField; + use frontend_macro::Flatten; + #[derive(Flatten, Debug, PartialEq, Clone)] + struct State { + a: F, + b: F, + } + + let s = State:: { + a: Fr::from(1u32), + b: Fr::from(1u32), + }; + + let v: Vec = Vec::from(s.clone()); + + assert_eq!(2, State::::state_number()); + assert_eq!(vec![s.a, s.b], v); + assert_eq!(s, State::from(v)); + } +}