Roadmap

This roadmap gives a rough idea of how the project will continue and which features will be implemented; think of it as a todo list.

  • Being able to perform some translations (PR#3).
  • Basic functionalities:
    • The @jace.jit annotation.
    • Composability with Jax, e.g. taking the Jax derivative of a JaCe-annotated function.
    • Implementing the stages model supported by Jax.
    • Handling Jax arrays as native input (single host only).
    • Caching the lowering and compilation results for later reuse. In Jax these parts (together with the dispatch) are written in C++, so in the beginning we will use a self-made cache.
  • Implementing some basic PrimitiveTranslators that allow us to run some early tests, such as:
    • Backporting the ones from the prototype.
    • Implementing the scatter primitive (needed for pyhpc).
    • Implementing the scan primitive (needed for pyhpc).
  • Initial optimization pipeline: in order to do benchmarks, we need to perform optimizations first. However, the ones offered by DaCe did not work that well, so for now we should backport the ones from the prototype.
  • Support GPU code (relatively simple, but needs some detection logic).
  • Initial benchmarks: in the beginning we will not have the same dispatching performance as Jax, but running these benchmarks should give us a better idea of how to proceed in this matter.
  • Support for static arguments.
  • Stop relying on jax.make_jaxpr(). Look at the jax._src.pjit.make_jit() function for how to hijack the staging process.
  • Implementing more advanced primitives:
    • Handling pytrees as arguments.
    • Implement random numbers.
    • jax.numpy.
    • jax.scipy.
  • Passing the single-host Jax unit tests.
  • Multi-device capabilities, i.e. multiple GPUs on the same host.
    • Passing the associated Jax unit tests.
  • Multi-host capabilities, i.e. using MPI.
    • Passing the associated Jax unit tests.
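
The self-made cache mentioned above could start out as a plain dictionary. The following is a minimal sketch, not JaCe code. CompilationCache, abstract_signature, and the compile_fn callback, which stands in for the lowering and DaCe compilation pipeline, are hypothetical names. The sketch assumes that the generated code depends only on the shape and dtype of array arguments and on the values of static arguments. For that reason, static arguments must be hashable.

```python
from typing import Any, Callable, Dict, Hashable, Tuple

# Hypothetical signature of the expensive lowering + compilation step.
CompileFn = Callable[[Tuple[Any, ...], Tuple[Hashable, ...]], Callable[..., Any]]


def abstract_signature(arg: Any) -> Hashable:
    """Reduce an argument to the information the generated code depends on.

    Array-like values (anything exposing shape and dtype) are described by
    their shape and dtype only, since their contents do not affect tracing.
    Any other value is described by its Python type.
    """
    if hasattr(arg, "shape") and hasattr(arg, "dtype"):
        return ("array", tuple(arg.shape), str(arg.dtype))
    return ("scalar", type(arg).__name__)


class CompilationCache:
    """Hypothetical self-made cache for lowering and compilation results.

    compile_fn(args, static_args) stands in for the expensive
    trace -> SDFG -> compile pipeline and must return a callable.
    """

    def __init__(self, compile_fn: CompileFn) -> None:
        self._compile_fn = compile_fn
        self._entries: Dict[Hashable, Callable[..., Any]] = {}
        self.compilations = 0  # number of cache misses

    def __call__(self, *args: Any, static_args: Tuple[Hashable, ...] = ()) -> Any:
        # Dynamic arguments enter the key only through their abstract
        # signature; static arguments enter it by value, so they must be hashable.
        key = (tuple(abstract_signature(a) for a in args), static_args)
        compiled = self._entries.get(key)
        if compiled is None:
            # Cache miss: run the expensive pipeline once and remember the result.
            self.compilations += 1
            compiled = self._compile_fn(args, static_args)
            self._entries[key] = compiled
        return compiled(*args)


# Usage with a trivial stand-in "compiler" that bakes the static argument in.
cache = CompilationCache(lambda args, static_args: (lambda x: x * static_args[0]))
print(cache(2, static_args=(3,)))  # 6, first compilation
print(cache(5, static_args=(3,)))  # 15, reuses the cached result
print(cache(5, static_args=(4,)))  # 20, new static value, second compilation
print(cache.compilations)          # 2
```

A dictionary lookup in Python will not match the C++ dispatch path of Jax, but it separates the "what to compile" key from the compilation itself. A real cache would probably need further key components, for instance the target device.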

General

These are more general topics that should be addressed at some point.

  • Better integration with Jax:
    • Support its array type (probably implement this in DaCe).
  • Increasing the dispatching speed and caching: Jax does this in C++, which is impossible to beat in Python, so we have to go that route as well.
  • Debugging information.
  • Dynamic shapes: this could be done by making the inputs fully dynamic and then using the primitives to simplify. For example, in an addition the two inputs and the output have the same shape, knowledge that is inherent to the primitive itself. However, the compiled object must know how to extract the sizes itself.
  • Defining a logo: it should be green with a nice curly font.
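
The dynamic-shapes idea can be illustrated with symbolic dimensions. In this minimal sketch, every input dimension starts as its own symbol. A primitive's shape rule then merges the symbols it knows to be equal, using a union-find structure. At call time, the sizes are read from the actual arguments and checked for consistency. All names are hypothetical and nothing here is JaCe or DaCe API. Only elementwise addition without broadcasting is modeled.

```python
from typing import Dict, List, Tuple

Shape = Tuple[str, ...]  # a symbolic shape: one symbol name per dimension


class ShapeSymbols:
    """Union-find over symbolic dimension names (hypothetical helper)."""

    def __init__(self) -> None:
        self._parent: Dict[str, str] = {}

    def fresh(self, name: str) -> str:
        """Create a new, unconstrained dimension symbol."""
        self._parent[name] = name
        return name

    def find(self, sym: str) -> str:
        """Return the representative symbol, halving paths on the way."""
        while self._parent[sym] != sym:
            self._parent[sym] = self._parent[self._parent[sym]]
            sym = self._parent[sym]
        return sym

    def unify(self, a: str, b: str) -> None:
        """Record that two dimensions must have the same size."""
        ra, rb = self.find(a), self.find(b)
        if ra != rb:
            self._parent[rb] = ra


def add_shape_rule(syms: ShapeSymbols, x: Shape, y: Shape) -> Shape:
    """Shape rule of an elementwise addition without broadcasting:
    both inputs and the output share every dimension."""
    if len(x) != len(y):
        raise ValueError("rank mismatch in add")
    for dx, dy in zip(x, y):
        syms.unify(dx, dy)
    return tuple(syms.find(d) for d in x)


def bind_sizes(
    syms: ShapeSymbols, inputs: List[Shape], concrete: List[Tuple[int, ...]]
) -> Dict[str, int]:
    """Call-time step: read the sizes from the actual arguments and check
    that dimensions merged by the shape rules really agree."""
    sizes: Dict[str, int] = {}
    for sym_shape, shape in zip(inputs, concrete):
        if len(sym_shape) != len(shape):
            raise ValueError("rank mismatch at call time")
        for sym, n in zip(sym_shape, shape):
            root = syms.find(sym)
            if sizes.setdefault(root, n) != n:
                raise ValueError(f"size mismatch for {root}: {sizes[root]} vs {n}")
    return sizes


syms = ShapeSymbols()
a = (syms.fresh("a0"), syms.fresh("a1"))
b = (syms.fresh("b0"), syms.fresh("b1"))
print(add_shape_rule(syms, a, b))                  # ('a0', 'a1')
print(bind_sizes(syms, [a, b], [(3, 4), (3, 4)]))  # {'a0': 3, 'a1': 4}
```

For two 2D inputs, the addition rule merges a0 with b0 and a1 with b1. As a result, only two sizes have to be extracted at call time instead of four, and mismatching arguments are rejected.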

Optimization & Transformations

The SDFGs generated by JaCe have a very particular structure, so we could, and probably should, write some highly targeted optimization passes for them. Our experiments with the prototype showed that the most important transformation is Map fusion, and the one in DaCe is essentially broken.

  • Modified state fusion: because of the structure of our SDFGs, this could make Simplify much more efficient.
  • Trivial Tasklet removal: since we will work a lot with trivial Maps (probably the best structure for fusing), we will end up with some trivial Tasklets, i.e. __out = __in. Thus, we need a good way to get rid of them.
  • Modified Map fusion transformation: we should still support both parallel and serial fusion, as the prototype did, but focus on serial fusion.
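
The trivial Tasklet removal can be sketched on a toy representation of a Map body. This is not DaCe's API. The Tasklet class below is a hypothetical stand-in that holds connector-to-data bindings and the Tasklet code. The sketch assumes that the Tasklets are listed in dataflow order and that every intermediate is written exactly once. A copy __out = __in into an intermediate is dropped, and later reads of that intermediate are redirected to the copy's source. Copies into data that must survive (keep) are retained.

```python
import re
from dataclasses import dataclass
from typing import Dict, List, Set

# Matches Tasklet code of the form "<out> = <in>" and nothing else.
_COPY = re.compile(r"^\s*(\w+)\s*=\s*(\w+)\s*$")


@dataclass
class Tasklet:
    """Toy stand-in for a Tasklet (not DaCe's class)."""
    inputs: Dict[str, str]   # input connector -> data container read
    outputs: Dict[str, str]  # output connector -> data container written
    code: str


def is_trivial(t: Tasklet) -> bool:
    """True if the Tasklet only copies its single input to its single output."""
    m = _COPY.match(t.code)
    return (
        m is not None
        and len(t.inputs) == 1
        and len(t.outputs) == 1
        and m.group(1) in t.outputs
        and m.group(2) in t.inputs
    )


def remove_trivial_tasklets(body: List[Tasklet], keep: Set[str]) -> List[Tasklet]:
    """Drop copy Tasklets whose target is an intermediate (not in keep) and
    redirect later reads of that intermediate to the copy's source.

    Assumes body is in dataflow order and each intermediate is written once.
    """
    rename: Dict[str, str] = {}
    result: List[Tasklet] = []
    for t in body:
        # Apply earlier removals first, so chains of copies collapse fully.
        t = Tasklet({c: rename.get(d, d) for c, d in t.inputs.items()}, dict(t.outputs), t.code)
        if is_trivial(t):
            (src,) = t.inputs.values()
            (dst,) = t.outputs.values()
            if dst not in keep:
                rename[dst] = src
                continue
        result.append(t)
    return result


body = [
    Tasklet({"__in": "a"}, {"__out": "tmp"}, "__out = __in"),
    Tasklet({"__in1": "tmp", "__in2": "b"}, {"__out": "c"}, "__out = __in1 + __in2"),
]
for t in remove_trivial_tasklets(body, keep={"c"}):
    print(t.inputs, t.outputs, t.code)
# {'__in1': 'a', '__in2': 'b'} {'__out': 'c'} __out = __in1 + __in2
```

In this example, the Map body first copies a into tmp and then computes c from tmp and b. After the pass, a single Tasklet remains that reads a and b directly. In a real SDFG pass, the same rewiring would be done on the memlets.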