diff --git a/sdsl/synthcl/examples/fastWalshTransform/synth/kernel.rkt b/sdsl/synthcl/examples/fastWalshTransform/synth/kernel.rkt index e90708b8..a1da95a0 100644 --- a/sdsl/synthcl/examples/fastWalshTransform/synth/kernel.rkt +++ b/sdsl/synthcl/examples/fastWalshTransform/synth/kernel.rkt @@ -5,8 +5,8 @@ (: int tid group pair match) (: float t1 t2) (= tid (get_global_id 0)) - (= group (idx tid step 1)) ; (% tid step) - (= pair (+ (* (<< step 1) (idx tid step 1)) group)) ; (/ tid step) + (= group (idx tid step)) ; (% tid step) + (= pair (+ (* (<< step 1) (idx tid step)) group)) ; (/ tid step) (= match (+ pair step)) (= t1 [tArray pair]) (= t2 [tArray match]) @@ -33,10 +33,9 @@ (* left right) (% left right)]) -(grammar* int (idx [int tid] [int step] [int depth]) - [choose tid step (?? int) - (op (idx tid step (- depth 1)) - (idx tid step (- depth 1)))]) +(grammar int (idx [int tid] [int step]) + (op (choose tid step (?? int)) + (choose tid step (?? int)))) (kernel void (fwtKernel [float* tArray] [int step]) (: int tid group pair match) diff --git a/sdsl/synthcl/lang/forms.rkt b/sdsl/synthcl/lang/forms.rkt index 0d3766ac..ebef69a6 100644 --- a/sdsl/synthcl/lang/forms.rkt +++ b/sdsl/synthcl/lang/forms.rkt @@ -4,12 +4,12 @@ (for-syntax "types.rkt" "errors.rkt" (only-in racket make-list) (only-in syntax/stx stx-null?)) "types.rkt" "util.rkt" (prefix-in rosette/ (only-in rosette if assert void)) - (only-in rosette/lib/synthax define-synthax [?? @??] choose) + (only-in rosette/lib/synthax define-simple-grammar [?? @??] choose) (only-in "../model/runtime.rkt" address-of malloc) (only-in "builtins.rkt" NULL clCreateProgramWithSource)) (provide assert - print procedure kernel grammar grammar* ?? choose + print procedure kernel grammar ?? choose sizeof @ : = app-or-ref locally-scoped if-statement for-statement range) @@ -156,35 +156,14 @@ (set! param ((type) param)) ... expr ...))])) -; Grammar syntax. We assume (but don't enforce) that -; grammar bodies are free of side effects, and that -; grammars are invoked on side-effect free expressions. -; This assumption is needed only for simplified code generation, -; where we treat each grammar application as a substition -; in the codegen phase. Synthesis works fine without this -; assumption, and a better code generator could get rid of it -; by lifting all the lets/lambda to the top level. -(define-syntax (grammar* stx) - (syntax-case stx () - [(_ out (id [type param] ... [int depth]) expr) - (quasisyntax/loc stx - (define-synthax id - [(param ... depth) - (assert (>= depth 0)) - expr] - (lambda (e sol) - (define vars (syntax->list #'(param ... depth))) - (define vals (cdr (syntax->list e))) - #`(let (#,@(map list vars vals)) expr))))])) - (define-syntax (grammar stx) (syntax-case stx () [(grammar out (id [type param] ...) expr) (quasisyntax/loc stx - (define-synthax (id param ...) expr))])) + (define-simple-grammar (id param ...) expr))])) ; Constant syntax. -(define-synthax (?? t) (@?? (type-base t))) +(define-simple-grammar (?? t) (@?? (type-base t))) ; Syntax for creating a local scope for a sequence of statements. (define-syntax (locally-scoped stx) diff --git a/sdsl/synthcl/lang/main.rkt b/sdsl/synthcl/lang/main.rkt index 0b30b23d..98c97d6a 100644 --- a/sdsl/synthcl/lang/main.rkt +++ b/sdsl/synthcl/lang/main.rkt @@ -30,7 +30,7 @@ locally-scoped range @ sizeof ; Solver-aided statements and forms - assert verify synth choose ?? grammar grammar* + assert verify synth choose ?? grammar ; Real operators = += -= *= /= %= &= $= ^= <<= >>= diff --git a/sdsl/synthcl/lang/queries.rkt b/sdsl/synthcl/lang/queries.rkt index 0f1434a2..f30a11b2 100644 --- a/sdsl/synthcl/lang/queries.rkt +++ b/sdsl/synthcl/lang/queries.rkt @@ -4,7 +4,7 @@ (only-in rackunit test-pred) (for-syntax (only-in racket/syntax with-syntax*)) (only-in "forms.rkt" range :) - (only-in rosette/lib/synthax generate-forms) + (only-in rosette/lib/synthax current-grammar-depth print-forms) (prefix-in @ (only-in rosette verify synthesize))) (provide verify synth expected? query-output-port) @@ -39,25 +39,10 @@ (printf "No counterexample found.\n") (unsat))))))))))])) -(define (inline-let f [env (hash)]) - (syntax-case f (let) - [(let ([x e] ...) body) - (let ([vars (syntax->datum #'(x ...))] - [exps (map (curryr inline-let env) (syntax->list #'(e ...)))]) - (inline-let #'body (apply hash-set* env (flatten (map cons vars exps)))))] - [(_ ...) - #`(#,@(map (curryr inline-let env) (syntax->list f)))] - [_ (hash-ref env (syntax->datum f) f)])) - -(define (print-forms sol) - (for ([f (generate-forms sol)]) - (printf "~a:~a:~a\n" (syntax-source f) (syntax-line f) (syntax-column f)) - (printf "~a\n" (pretty-format (syntax->datum (inline-let f)))))) - ; The synthesize form. (define-syntax (synth stx) (syntax-case stx () - [(synthesize #:forall [decl ...] #:bitwidth bw #:ensure expr) + [(synthesize #:forall [decl ...] #:bitwidth bw #:grammar-depth depth #:ensure expr) (with-syntax* ([([id seq] ...) (map id&range (syntax->list #'(decl ...)))] [(tmp ...) (generate-temporaries #'(id ...))]) (quasisyntax/loc stx @@ -68,6 +53,7 @@ (expected?) (with-terms (parameterize ([current-bitwidth bw] + [current-grammar-depth depth] [current-output-port (query-output-port)]) (printf "Synthesizing ~a\n" (source-of #'synthesize)) (define-values (id ...) @@ -79,9 +65,11 @@ (if (sat? m) (print-forms m) (printf "No solution found.\n")) - m))))))))] + m))))))))] + [(synthesize #:forall ds #:bitwidth bw #:ensure e) + (syntax/loc stx (synthesize #:forall ds #:bitwidth bw #:grammar-depth 3 #:ensure e))] [(synthesize #:forall ds #:ensure e) - (syntax/loc stx (synthesize #:forall ds #:bitwidth 8 #:ensure e))])) + (syntax/loc stx (synthesize #:forall ds #:bitwidth 8 #:grammar-depth 3 #:ensure e))])) ; Returns the declared id and the Racket sequence diff --git a/sdsl/synthcl/lang/typecheck.rkt b/sdsl/synthcl/lang/typecheck.rkt index a44f83e5..c6b9fdf9 100644 --- a/sdsl/synthcl/lang/typecheck.rkt +++ b/sdsl/synthcl/lang/typecheck.rkt @@ -35,8 +35,7 @@ (and (identifier? #'proc) (or (free-label-identifier=? #'proc #'procedure) (free-label-identifier=? #'proc #'kernel) - (free-label-identifier=? #'proc #'grammar) - (free-label-identifier=? #'proc #'grammar*))) + (free-label-identifier=? #'proc #'grammar))) (let ([out-type (identifier->type #'out stx)] [arg-types (map (curryr identifier->type stx) (syntax->list #'(type ...)))]) (when (free-label-identifier=? #'proc #'kernel) @@ -202,13 +201,29 @@ ; Typechecks the verify or synthesis query statement. (define (typecheck-query stx) (syntax-case stx () + [(query #:forall [decl ...] #:bitwidth bw #:grammar-depth depth #:ensure form) + (parameterize ([current-env (env)]) + (with-syntax ([(typed-decl ...) (map typecheck-query-declaration (syntax->list #'(decl ...)))] + [typed-form (typecheck #'form)] + [typed-bw (typecheck #'bw)] + [typed-depth (typecheck #'depth)]) + (check-no-conversion (type-ref #'typed-bw) int #'typed-bw stx) + (check-no-conversion (type-ref #'typed-depth) int #'typed-depth stx) + (type-set (syntax/loc stx (query #:forall [typed-decl ...] + #:bitwidth typed-bw + #:grammar-depth typed-depth + #:ensure typed-form)) + void)))] [(query #:forall [decl ...] #:bitwidth bw #:ensure form) (parameterize ([current-env (env)]) (with-syntax ([(typed-decl ...) (map typecheck-query-declaration (syntax->list #'(decl ...)))] [typed-form (typecheck #'form)] [typed-bw (typecheck #'bw)]) (check-no-conversion (type-ref #'typed-bw) int #'typed-bw stx) - (type-set (syntax/loc stx (query #:forall [typed-decl ...] #:bitwidth typed-bw #:ensure typed-form)) void)))] + (type-set (syntax/loc stx (query #:forall [typed-decl ...] + #:bitwidth typed-bw + #:ensure typed-form)) + void)))] [(query #:forall [decl ...] #:ensure form) (parameterize ([current-env (env)]) (with-syntax ([(typed-decl ...) (map typecheck-query-declaration (syntax->list #'(decl ...)))] @@ -435,7 +450,6 @@ (dict-set! procs #'choose typecheck-choose) (dict-set! procs #'?? typecheck-??) (dict-set! procs #'grammar typecheck-grammar) - (dict-set! procs #'grammar* typecheck-grammar) (dict-set! procs #'procedure typecheck-procedure) (dict-set! procs #'kernel typecheck-procedure)