From d8f4b6073342b2688cc3dff5be94ecc03138c15e Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Sun, 10 Jul 2016 16:17:46 -0400 Subject: [PATCH] compress fused broadcast args by eliminating literals and some (pure) duplicates --- src/julia-syntax.scm | 49 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index 9edda07871fd1..9574af45482bb 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -1554,13 +1554,12 @@ (define (fuse? e) (and (pair? e) (eq? (car e) 'fuse))) (define (anyfuse? exprs) (if (null? exprs) #f (if (fuse? (car exprs)) #t (anyfuse? (cdr exprs))))) - (define (...? e) (and (pair? e) (eq? (car e) '...))) (define (to-lambda f args) ; convert f to anonymous function with hygienic tuple args - (define (genarg arg) (if (...? arg) (list '... (gensy)) (gensy))) + (define (genarg arg) (if (vararg? arg) (list '... (gensy)) (gensy))) (define (hygienic f) ; rename args of f == (-> (tuple args...) body) (let* ((oldargs (cdadr f)) (newargs (map genarg oldargs)) - (renames (map (lambda (oldarg newarg) (if (...? oldarg) + (renames (map (lambda (oldarg newarg) (if (vararg? oldarg) (cons (cadr oldarg) (cadr newarg)) (cons oldarg newarg))) oldargs newargs))) @@ -1610,7 +1609,49 @@ (if (anyfuse? args_) `(fuse ,(fuse-funcs (to-lambda f args) args_) ,(fuse-args args_)) `(fuse ,(to-lambda f args) ,args_)))) - (let ((e (make-fuse f args))) ; an expression '(fuse func args) + ; given e == (fuse lambda args), compress the argument list by removing (pure) + ; duplicates in args, inlining literals, and moving any varargs to the end: + (define (compress-fuse e) + (define (findfarg arg args fargs) ; for arg in args, return corresponding farg + (if (eq? arg (car args)) + (car fargs) + (findfarg arg (cdr args) (cdr fargs)))) + (let ((f (cadr e)) + (args (caddr e))) + (define (cf old-fargs old-args new-fargs new-args renames varfarg vararg) + (if (null? old-args) + (let ((nfargs (if (null? varfarg) new-fargs (cons varfarg new-fargs))) + (nargs (if (null? vararg) new-args (cons vararg new-args)))) + `(fuse (-> (tuple ,@(reverse nfargs)) ,(replace-vars (caddr f) renames)) + ,(reverse nargs))) + (let ((farg (car old-fargs)) (arg (car old-args))) + (cond + ((and (vararg? farg) (vararg? arg)) ; arg... must be the last argument + (if (null? varfarg) + (cf (cdr old-fargs) (cdr old-args) + new-fargs new-args renames farg arg) + (if (eq? (cadr vararg) (cadr arg)) + (cf (cdr old-fargs) (cdr old-args) + new-fargs new-args (cons (cons (cadr farg) (cadr varfarg)) renames) + varfarg vararg) + (error "multiple splatted args cannot be fused into a single broadcast")))) + ((number? arg) ; inline numeric literals + (cf (cdr old-fargs) (cdr old-args) + new-fargs new-args + (cons (cons farg arg) renames) + varfarg vararg)) + ((and (symbol? arg) (memq arg new-args)) ; combine duplicate args + ; (note: calling memq for every arg is O(length(args)^2) ... + ; ... would be better to replace with a hash table if args is long) + (cf (cdr old-fargs) (cdr old-args) + new-fargs new-args + (cons (cons farg (findfarg arg new-args new-fargs)) renames) + varfarg vararg)) + (else + (cf (cdr old-fargs) (cdr old-args) + (cons farg new-fargs) (cons arg new-args) renames varfarg vararg)))))) + (cf (cdadr f) args '() '() '() '() '()))) + (let ((e (compress-fuse (make-fuse f args)))) ; an expression '(fuse func args) (expand-forms `(call broadcast ,(from-lambda (cadr e)) ,@(caddr e))))) ;; table mapping expression head to a function expanding that form