[ base ] Change how folds are defined for Data.Vect #2707

Open
wants to merge 1 commit into
base: main

0xd34df00d
Contributor

  • Adds dependent folds, where the type of the accumulator depends on the current length of the Vect.
  • Changes the Foldable instance for Vect to be expressed via the former.
  • Moves the tail-recursive version of foldr to a separate function, foldrTR.
  • As a sanity check, proves that foldrTR is equal to foldr.
  • Changes the definition of sumR in Data.Vect.Properties.Foldr to accept just Vects, since other foldables might define foldr differently, and this function lives in a module with Vect in name anyway.
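
To illustrate what an accumulator whose type depends on the current length makes possible, here is a minimal sketch. It assumes a `foldlD` with the shape quoted later in this thread (the definition in the PR itself may differ), and `rev` is a hypothetical name used only for this example. After `k` steps the accumulator is a `Vect k a`, so reversal needs no `rewrite`:

```idris
module DepFoldSketch

import Data.Vect

-- Dependent left fold, copied from the version quoted in this thread.
-- Assumption: the PR's own definition may differ in detail.
foldlD : (0 accTy : Nat -> Type) ->
         (f : forall k. accTy k -> a -> accTy (S k)) ->
         (acc : accTy Z) ->
         (xs : Vect n a) ->
         accTy n
foldlD _     _ acc []        = acc
foldlD accTy f acc (x :: xs) = foldlD (accTy . S) f (acc `f` x) xs

-- Hypothetical example: the accumulator after k steps has type Vect k a,
-- so consing each element onto it yields a Vect n a at the end.
rev : Vect n a -> Vect n a
rev xs = foldlD (\k => Vect k a) (\acc, x => x :: acc) [] xs
```

A non-dependent fold could not express this directly, because the accumulator would need a single fixed type for every step.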

@stefan-hoeck
Contributor

We have to be very careful about this, as it appears to have quadratic runtime complexity due to #2166. For instance, if I run the following program on my machine with an argument of "100000", it takes 5 seconds (!) to finish:

module Foldl

import Data.Vect
import System

foldlD : (0 accTy : Nat -> Type) ->
         (f : forall k. accTy k -> a -> accTy (S k)) ->
         (acc : accTy Z) ->
         (xs : Vect n a) ->
         accTy n
foldlD _     _ acc []        = acc
foldlD accTy f acc (x :: xs) = foldlD (accTy . S) f (acc `f` x) xs


fold : (acc -> e -> acc) -> acc -> Vect n e -> acc
fold f acc xs = foldlD (const _) f acc xs

main : IO ()
main = do
  [_,s] <- getArgs | _ => die "Invalid number of args"

  printLn $ fold (+) Z (replicate (cast s) 1)

If I replace fold on the last line with the current foldl, it is instantaneous even for n = 1000000 (less than 100 ms). So, at the very least, we must not implement foldl in terms of foldlD at the moment.

@stefan-hoeck
Contributor

Two more data points. The following, which applies the trick described in #2166 of passing a function with an explicit erased argument, does not work here:

foldlDImpl : (0 accTy : Nat -> Type) ->
             (f : (0 k : Nat) -> accTy k -> a -> accTy (S k)) ->
             (acc : accTy Z) ->
             (xs : Vect n a) ->
             accTy n
foldlDImpl _     _ acc []        = acc
foldlDImpl accTy f acc (x :: xs) = foldlDImpl (accTy . S) (\k => f (S k)) (f _ acc x) xs

The only way I got this to run in linear time is by means of believe_me, which is clearly unsatisfactory:

foldlDImpl : (0 accTy : Nat -> Type) ->
             (f : (0 k : Nat) -> accTy k -> a -> accTy (S k)) ->
             (acc : accTy Z) ->
             (xs : Vect n a) ->
             accTy n
foldlDImpl _     _ acc []        = acc
foldlDImpl accTy f acc (x :: xs) = foldlDImpl (accTy . S) (believe_me f) (f _ acc x) xs

foldlD : (0 accTy : Nat -> Type) ->
         (f : forall k. accTy k -> a -> accTy (S k)) ->
         (acc : accTy Z) ->
         (xs : Vect n a) ->
         accTy n
foldlD at f = foldlDImpl at $ \_ => f

@stefan-hoeck
Contributor

OK, I got an O(n) version without believe_me by using a helper function for the recursion:

foldlD : (0 accTy : Nat -> Type) ->
         (f : forall k. accTy k -> a -> accTy (S k)) ->
         (acc : accTy Z) ->
         (xs : Vect n a) ->
         accTy n
foldlD at f acc xs = go acc xs
  where go : at k -> Vect m a -> at (k + m)
        go           x []        =
          rewrite plusZeroRightNeutral k in x

        go {m = S l} x (y :: xs) =
          rewrite sym (plusSuccRightSucc k l) in go (f x y) xs
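
The two rewrites above rely on standard lemmas from `Data.Nat`. As a minimal sketch of what each one states and why it fits the goal, the following restates them under the hypothetical names `zeroCase` and `succCase` (these names are not part of the PR):

```idris
module NatLemmaSketch

import Data.Nat

-- Base case of `go`: rewriting with this turns the goal `at (k + 0)`
-- into `at k`, which is the type of the accumulator `x`.
zeroCase : (k : Nat) -> k + 0 = k
zeroCase k = plusZeroRightNeutral k

-- Step case of `go`: rewriting with `sym` of this turns the goal
-- `at (k + S l)` into `at (S (k + l))`. The recursive call has type
-- `at (S k + l)`, and `S k + l` reduces to `S (k + l)` by computation.
succCase : (k, l : Nat) -> S (k + l) = k + S l
succCase k l = plusSuccRightSucc k l
```

At the top-level call the accumulator index is `0`, and `0 + n` reduces to `n` by computation, which is why `foldlD` itself needs no rewrite.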

@0xd34df00d
Contributor Author

0xd34df00d commented Oct 12, 2022

Interesting, thanks for looking into this!

Since I'll eventually need to prove things about functions expressed with foldlD/foldrD, having a local where-bound helper would be unfortunate (how does one refer to it, after all?). Right now I have two options:

  1. Move the go out to a top-level foldlDhelper or something like that, hoping that the optimizer will be happy about that. I haven't tested yet whether the optimizer is actually happy.
  2. Leave fold{l,r} expressed directly instead of using fold{l,r}D, and not worry about the performance of the dependent folds until the bug is fixed (while still replacing the foldr implementation and keeping the tail-recursive version as a separate function).

Personally, I'd vote for (2), as it seems like a fair share of proofs about fold{l,r} implemented directly will still hold when the implementation is replaced with a call to fold{l,r}D, and the structure of foldlD proposed in the PR seems to be more obvious and more proof-friendly than whatever's needed to work around the optimizer bug.

What do you think?

@gallais
Member

gallais commented Nov 21, 2022

The whole point of this refactoring is that the foldl-based presentation does not require the use of rewrite. I wonder whether I could dig up my eta-contraction pass and see whether that's enough to get this more elegant version the perf it deserves.
