[proposal] Support of functional interfaces #516

Open
andrii0lomakin opened this issue Jul 30, 2024 · 22 comments
Labels
discussion feature New feature proposal proposal New proposals

Comments

@andrii0lomakin
Contributor

Many heterogeneous computing tasks are variations of the same well-known patterns that differ only in the functions called per element.
Reduction is one example: SoftMax and RMS normalization, in a nutshell, differ only in the function applied to each element.

It would be beneficial to support functional interfaces as kernel parameters that are unwound into direct calls (with appropriate restrictions, of course).

That will minimize development time and increase the maintainability of kernels developed in TornadoVM.

@jjfumero
Member

Hi @andrii0lomakin , is this a proposal or an issue?

You can build libraries on top of TornadoVM that solve specific functionality for domains of applications such as LLMs, Linear Algebra, Graphics, Physics, etc.

@andrii0lomakin
Contributor Author

This is a feature request; at least, that is how I filed it :-). At the moment, I have to repeat the same boilerplate code over and over, and being able to pass functional interfaces would help me.

@jjfumero jjfumero added proposal New proposals feature New feature proposal discussion labels Jul 30, 2024
@jjfumero
Member

This is a feature request.

OK, I am opening this for discussion with all community members and TornadoVM maintainers.

At the moment, I have to repeat the same boilerplate code over and over, and being able to pass functional interfaces would help me.

What do you mean by passing functional interfaces? At which level? Tasks within TaskGraph receive functional interfaces already. Do you have any example in mind that you can share?

As I mentioned, one can build libraries on top. For example, an LLM and transformer library that contains softmax, normalization, reductions and matmul. Is this what you are referring to?

@jjfumero jjfumero changed the title Support of functional interfaces [proposal] Support of functional interfaces Jul 30, 2024
@andrii0lomakin
Contributor Author

Here is the simplest example.

I have an RMS norm layer that, in a nutshell, consists of reduce kernels that are called in layers as described here: https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf.

Initially, I reduce the squared values, and then I just call the plain sum kernel in the subsequent layers.
I could create a single kernel that accepts a Function as a parameter and inline it during compilation, but instead I have to repeat the same code.

The same is true for SoftMax, which uses a reduction to compute its denominator.
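The layered scheme described above can be sketched on the CPU in plain Java, without TornadoVM. This is a minimal sketch under stated assumptions: `LayeredReduction`, `reduceLayer`, `mapThenSum` and the `FloatUnaryOp` interface are hypothetical names, and each call to `reduceLayer` stands in for one kernel launch in which a work-group of `groupSize` elements produces one partial sum. Only the first layer applies the per-element function (squaring for RMS norm, `exp` for the SoftMax denominator); the later layers are the plain sum.

```java
final class LayeredReduction {

    // Hypothetical primitive functional interface: java.util.function has no float specialization.
    @FunctionalInterface
    interface FloatUnaryOp {
        float apply(float v);
    }

    // One "kernel launch": each group of groupSize inputs is mapped element-wise
    // and summed into a single partial result.
    static float[] reduceLayer(float[] in, int groupSize, FloatUnaryOp map) {
        int groups = (in.length + groupSize - 1) / groupSize;
        float[] partials = new float[groups];
        for (int g = 0; g < groups; g++) {
            int end = Math.min(in.length, (g + 1) * groupSize);
            float acc = 0f;
            for (int i = g * groupSize; i < end; i++) {
                acc += map.apply(in[i]);
            }
            partials[g] = acc;
        }
        return partials;
    }

    // Applies `map` in the first layer only, then sums partials layer by layer until one value is left.
    static float mapThenSum(float[] in, int groupSize, FloatUnaryOp map) {
        if (in.length == 0 || groupSize < 2) {
            throw new IllegalArgumentException("need a non-empty input and groupSize >= 2");
        }
        float[] layer = reduceLayer(in, groupSize, map);
        while (layer.length > 1) {
            layer = reduceLayer(layer, groupSize, v -> v);
        }
        return layer[0];
    }

    // Root mean square: the reduction at the heart of RMS normalization.
    static float rms(float[] x, int groupSize) {
        return (float) Math.sqrt(mapThenSum(x, groupSize, v -> v * v) / x.length);
    }

    // Denominator of SoftMax: the same reduction with exp as the per-element function.
    static float softmaxDenominator(float[] x, int groupSize) {
        return mapThenSum(x, groupSize, v -> (float) Math.exp(v));
    }
}
```

The two use cases share `mapThenSum` and differ only in the lambda passed to it, which is exactly the duplication the proposal wants to remove from device kernels.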

I suppose there are plenty of other such use cases.

P.S. I am ready to implement this feature myself, but I need your feedback, of course.

@jjfumero
Member

Reductions are supported in TornadoVM in two ways:

  1. Using the combination of @Parallel and @Reduce for loop-parallelism. The TornadoVM JIT Compiler is able to generate efficient reductions for simple test-cases (min, max, sum, etc).
    Details: https://www.researchgate.net/publication/327871451_Using_Compiler_Snippets_to_Exploit_Parallelism_on_Heterogeneous_Hardware_A_Java_Reduction_Case_Study

Note that the final reduction happens on the device. The TornadoVM JIT compiler generates two kernels from a single reduction kernel (one to run in parallel, and the second to perform the final reductions from the remaining work-groups).

  2. Using the Kernel API and the KernelContext to access local memory and barriers. In this case, the developer needs to implement the whole logic. This might also be the path to follow for more complex reductions.

@andrii0lomakin
Contributor Author

@jjfumero it seems I was not clear enough. I know how to implement a reduction and am familiar with barriers, etc.

I mean that if I could pass the function that I need to call when I perform a reduction, I would not need to repeat the same boilerplate code again and again.

For example, in my code I have two kernels. One does:

    localSnippet[context.localIdx] = inputTensor.get(currentInputOffset);

and the other does:

    float value = inputTensor.get(currentInputOffset);
    localSnippet[context.localIdx] = value * value;

In principle, I could just pass a function to a single kernel, but instead I have to repeat the same code again and again.

P.S. In general, when I raise an issue, I get the impression that I always receive information on how to implement the basics of a feature instead of an in-depth discussion of the concrete issue. Probably I need to change something in my conversational style to avoid this :-)

@andrii0lomakin
Contributor Author

As for the TornadoVM annotations, I find them a good entry point for developers who want to learn heterogeneous computing. However, since all those "complications" exist for no reason other than performance, at least at this stage they are not quite suitable for production use, at least in commercial applications or in libraries created to support commercial tools.

@jjfumero
Member

at least at this stage they are not quite suitable for production use

TornadoVM is an academic project fully developed and maintained by Master and PhD students, researchers and staff at The University of Manchester. It is not a product, at least yet. Feedback and contributions are very welcome.

In general, when I raise an issue, I get the impression that I always receive information on how to implement the basics of a feature instead of an in-depth discussion of the concrete issue. Probably I need to change something in my conversational style to avoid this :-)

More concrete questions with test cases would be useful.

If I sketch what you want (pseudocode):

public static void sampleReduction(KernelContext context,
                                   FloatArray a,
                                   FloatArray b,
                                   FunctionalInterface f) {

        int globalIdx = context.globalIdx;
        int localIdx = context.localIdx;
        int localGroupSize = context.localGroupSizeX;
        int groupID = context.groupIdx; // Expose Group ID

        float[] localA = context.allocateFloatLocalArray(256);
        localA[localIdx] = a.get(globalIdx);
        for (int stride = (localGroupSize / 2); stride > 0; stride /= 2) {
            context.localBarrier();
            if (localIdx < stride) {
                    localA[localIdx] = f.apply(localA[localIdx], localA[localIdx + stride]);     // Use of a functional interface
            }
        }
        if (localIdx == 0) {
            b.set(groupID, localA[0]);
        }
    }

The functional interface could potentially be used at any level and in different scenarios, not just for the compute. Is this a better approximation?
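For reference, the semantics of the sketched kernel can be emulated sequentially on the CPU. This is a plain-Java sketch, not TornadoVM code: the class and method names are hypothetical, plain arrays replace `FloatArray`, and `localGroupSize` is assumed to be a power of two that divides the input length. Running the work-items of one stride step one after another gives the same result as the device would, because each active item (`localIdx < stride`) writes only its own slot below `stride` and reads its second operand from at or above `stride`, which no item writes in that step.

```java
import java.util.Arrays;
import java.util.function.BinaryOperator;

final class WorkGroupReduction {

    // One work-group: the tree reduction from the sketch, with `f` as the pluggable combiner.
    static float reduceGroup(float[] group, BinaryOperator<Float> f) {
        float[] localA = Arrays.copyOf(group, group.length); // stands in for local memory
        for (int stride = localA.length / 2; stride > 0; stride /= 2) {
            // On a device, context.localBarrier() separates the stride steps.
            for (int localIdx = 0; localIdx < stride; localIdx++) {
                localA[localIdx] = f.apply(localA[localIdx], localA[localIdx + stride]);
            }
        }
        return localA[0];
    }

    // Mirrors sampleReduction: one partial result per work-group, stored at b[groupID].
    static float[] sampleReduction(float[] a, int localGroupSize, BinaryOperator<Float> f) {
        if (Integer.bitCount(localGroupSize) != 1 || a.length % localGroupSize != 0) {
            throw new IllegalArgumentException("localGroupSize must be a power of two dividing a.length");
        }
        float[] b = new float[a.length / localGroupSize];
        for (int groupID = 0; groupID < b.length; groupID++) {
            int from = groupID * localGroupSize;
            b[groupID] = reduceGroup(Arrays.copyOfRange(a, from, from + localGroupSize), f);
        }
        return b;
    }
}
```

Swapping `(x, y) -> x + y` for `(x, y) -> Math.max(x, y)` changes the reduction without touching the kernel body, which is the point of the proposal.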

@andrii0lomakin
Contributor Author

andrii0lomakin commented Jul 30, 2024

@jjfumero

First of all, I truly believe that TornadoVM has great potential, and I am doing a lot of advertising for this project, hopefully with some positive outcome. If all goes well, I hope there will be a lot of contributions from my side.

Don't get me wrong: my last remark was intended only to improve the quality of our discussion, nothing more.

The functional interface could potentially be used at any level and in different scenarios, not just for the compute. Is this a better approximation?

Absolutely, thank you for your summarization.

@andrii0lomakin
Contributor Author

As I mentioned above, I am ready to implement this feature myself, but I am not sure that it fits the project's design.

@jjfumero
Member

Comments and feedback are very valuable for us so we really appreciate your feedback. Hopefully with the help of community members like you, TornadoVM can improve in many aspects.

This feature looks like a great addition. If you want to implement these cases, feel free to open a PR.

@andrii0lomakin
Contributor Author

andrii0lomakin commented Jul 30, 2024

@jjfumero Cool, I am on it then. I will provide a PR in a few weeks.

@andrii0lomakin
Contributor Author

andrii0lomakin commented Jul 31, 2024

@jjfumero, here is a sketch of the steps that I will follow to implement the given issue.

First of all, only TaskX functional interfaces will be accepted as arguments.
I will also rewrite uk.ac.manchester.tornado.runtime.analyzer.TaskUtils#resolveMethodHandle to use ASM, which looks more robust and maintainable to me. I will likely add more checks to ensure that only lambdas of the correct form are passed.

The algorithm will be similar to the one already implemented in uk.ac.manchester.tornado.runtime.analyzer.TaskUtils#resolveMethodHandle:

  1. Find a static method in the task code passed to resolveMethodHandle.
  2. Find functional interface arguments in the list of arguments to the static method. Only TaskX instances will be allowed.
  3. If there are no instances of TaskX passed, just return the original method.
  4. Otherwise, use the ASM method visitor and class writer to traverse the passed-in static method and replace all calls to functional interfaces with calls to a static method.

As a result of the last step, the bytecode of a new class will be generated, and the class will be defined using the jdk.internal.misc.Unsafe#defineClass method.
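Steps 1 to 3 can be approximated with plain `java.lang.reflect` before any bytecode rewriting happens. This is a minimal sketch, not the actual TaskUtils code: `FunctionalParams`, `DemoKernels` and their methods are hypothetical, and a type counts as a functional interface if it is an interface with exactly one abstract method, ignoring abstract redeclarations of public `Object` methods (roughly the JLS rule, which is why `Comparator` qualifies). Step 4, the ASM rewrite, is out of scope here.

```java
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.BinaryOperator;

final class FunctionalParams {

    // An interface with exactly one abstract method, not counting abstract
    // redeclarations of public java.lang.Object methods (e.g. Comparator.equals).
    static boolean isFunctionalInterface(Class<?> type) {
        if (!type.isInterface()) {
            return false;
        }
        Set<String> abstractSignatures = new HashSet<>();
        for (Method m : type.getMethods()) {
            if (Modifier.isAbstract(m.getModifiers()) && !isPublicObjectMethod(m)) {
                abstractSignatures.add(m.getName() + Arrays.toString(m.getParameterTypes()));
            }
        }
        return abstractSignatures.size() == 1;
    }

    private static boolean isPublicObjectMethod(Method m) {
        try {
            Object.class.getMethod(m.getName(), m.getParameterTypes());
            return true;
        } catch (NoSuchMethodException e) {
            return false;
        }
    }

    // Step 2: positions of the kernel parameters whose declared type is a functional interface.
    static List<Integer> functionalParameterIndices(Method kernel) {
        List<Integer> indices = new ArrayList<>();
        Class<?>[] types = kernel.getParameterTypes();
        for (int i = 0; i < types.length; i++) {
            if (isFunctionalInterface(types[i])) {
                indices.add(i);
            }
        }
        return indices;
    }

    // Steps 1 and 3: only static kernels are accepted; without functional parameters
    // the original method can be used unchanged.
    static boolean needsRewrite(Method kernel) {
        if (!Modifier.isStatic(kernel.getModifiers())) {
            throw new IllegalArgumentException("kernel must be static: " + kernel);
        }
        return !functionalParameterIndices(kernel).isEmpty();
    }

    // Looks up a declared method by name (unique names assumed) without checked exceptions.
    static Method find(Class<?> owner, String name) {
        for (Method m : owner.getDeclaredMethods()) {
            if (m.getName().equals(name)) {
                return m;
            }
        }
        throw new IllegalArgumentException("no method " + name + " in " + owner);
    }
}

// Hypothetical kernels used only to exercise the helpers above.
final class DemoKernels {
    static void copy(float[] a, float[] b) { }
    static void sampleReduction(float[] a, float[] b, BinaryOperator<Float> f) { }
}
```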

P.S. With such an approach, I am thinking about loosening the requirements of the lambda code passed in the task.
In my view, it is enough for a TaskX to be stateless and not use the this reference of the callee. In such cases, we can always generate a class with a static method that will represent a given task and use it in TornadoVM.
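One way to check the "stateless, no this" condition on the JVM is the serialized form of a lambda. This sketch relies on assumptions worth stating loudly: the lambda's target type must be `Serializable` (for example via an intersection cast), and it calls the private `writeReplace()` method that the JDK's lambda metafactory generates for serializable lambdas, which is an implementation detail rather than a supported API. `LambdaInspector` and its methods are hypothetical. A lambda with no captured arguments whose implementation method kind is `REF_invokeStatic` neither holds state nor uses `this`, and its body lives in an ordinary static method that a rewriter could call directly.

```java
import java.io.Serializable;
import java.lang.invoke.MethodHandleInfo;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;

final class LambdaInspector {

    // Serializable lambda classes generated by the JDK carry a private writeReplace()
    // that returns a SerializedLambda (implementation detail, not a public API).
    static SerializedLambda describe(Serializable lambda) {
        try {
            Method writeReplace = lambda.getClass().getDeclaredMethod("writeReplace");
            writeReplace.setAccessible(true);
            return (SerializedLambda) writeReplace.invoke(lambda);
        } catch (ReflectiveOperationException | RuntimeException e) {
            throw new IllegalArgumentException("not a serializable lambda: " + lambda.getClass(), e);
        }
    }

    // "Stateless and no this": nothing captured, and the body compiles to a static method.
    static boolean isStatelessStatic(Serializable lambda) {
        SerializedLambda info = describe(lambda);
        return info.getCapturedArgCount() == 0
                && info.getImplMethodKind() == MethodHandleInfo.REF_invokeStatic;
    }

    // Owner class and name of the static method a rewriter could call instead of the interface.
    static String target(Serializable lambda) {
        SerializedLambda info = describe(lambda);
        return info.getImplClass().replace('/', '.') + "." + info.getImplMethodName();
    }
}
```

A lambda such as `(x, y) -> scale * (x + y)` that captures a local `scale` fails the check because its captured-argument count is 1.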

P.S.2 In later versions, I am going to validate the passed-in tasks using ASM so that more meaningful exceptions are thrown than today, when in many cases cryptic errors are thrown, leaving users puzzled about what is really going on.

@andrii0lomakin
Contributor Author

IMHO, the GraalVM JIT could be successfully replaced by the upcoming HAT project, but it seems it will take a while before the OpenJDK team provides the first version.

@jjfumero
Member

Follow up questions:

only TaskX functional interfaces will be accepted as arguments.

  • Does this proposal change the TornadoVM Task-Graph API?
  • Can you share test cases showing what the proposal looks like?
  • Does this proposal plan to change the TornadoVM Runtime component?
  • Are you planning to adapt/extend the codegen (backends)?

With such an approach, I am thinking about loosening the requirements of the lambda code passed in the task.

Which requirements are you referring to?

In such cases, we can always generate a class with a static method that will represent a given task and use it in TornadoVM.

If we have the reference to the method already, why do we need to generate host code to be able to compile with TornadoVM? I did not get this part.

@andrii0lomakin
Contributor Author

Hi @jjfumero

Does this proposal change the TornadoVM Task-Graph API?

No. API stays the same.

Can you share test cases showing what the proposal looks like?

I cannot provide a test case right now, but here is the conceptual usage:

Kernel:

public static void sampleReduction(KernelContext context,
                                   FloatArray a,
                                   FloatArray b,
                                   BiFunction<Float, Float, Float> f) {

        int globalIdx = context.globalIdx;
        int localIdx = context.localIdx;
        int localGroupSize = context.localGroupSizeX;
        int groupID = context.groupIdx; // Expose Group ID

        float[] localA = context.allocateFloatLocalArray(256);
        localA[localIdx] = a.get(globalIdx);
        for (int stride = (localGroupSize / 2); stride > 0; stride /= 2) {
            context.localBarrier();
            if (localIdx < stride) {
                    localA[localIdx] = f.apply(localA[localIdx], localA[localIdx + stride]);     // Use of a functional interface
            }
        }
        if (localIdx == 0) {
            b.set(groupID, localA[0]);
        }
    }


Usage:

    taskGraph.task("t0", C::sampleReduction, context, a, b, (BiFunction<Float, Float, Float>) (a1, a2) -> a1 + a2);

Does this proposal plan to change the TornadoVM Runtime component?

As far as I can see now, only TaskUtils#resolveMethodHandle will be changed.

Are you planning to adapt/extend the codegen (backends)?

No, but I will probably need to add support for handling primitive wrappers; we will see.

@andrii0lomakin
Contributor Author

If we already have the reference to the method, why do we need to generate host code to be able to compile it with TornadoVM? I did not get this part.

AFAIK, most backends do not support polymorphic calls, so passing a lambda essentially means implicitly generating a new kernel specialized for the passed function. Kotlin works in a similar way: it inlines lambdas passed to inline functions and generates synthetic functions to reduce object allocations.
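To make "implicitly generating a new kernel" concrete, here is a plain-Java sketch (not TornadoVM or Kotlin code) of the before and after of such a specialization for `(a1, a2) -> a1 + a2`. The names `Specialization`, `genericReduce`, `lambdaBody` and `specializedReduce` are hypothetical, and a simple sequential fold is used instead of a full kernel. The generic body calls through the `BinaryOperator` interface, which a backend without polymorphic calls cannot compile; the specialized body replaces that call with a direct call to the static method holding the lambda's code, which the compiler is free to inline.

```java
import java.util.function.BinaryOperator;

final class Specialization {

    // Generic form: each iteration dispatches through the interface and boxes floats.
    // Assumes a non-empty input.
    static float genericReduce(float[] in, BinaryOperator<Float> f) {
        float acc = in[0];
        for (int i = 1; i < in.length; i++) {
            acc = f.apply(acc, in[i]);
        }
        return acc;
    }

    // What the lambda (a1, a2) -> a1 + a2 would be lowered to: an ordinary static method.
    static float lambdaBody(float a1, float a2) {
        return a1 + a2;
    }

    // Specialized form: the same loop with the interface call replaced by a direct static call.
    static float specializedReduce(float[] in) {
        float acc = in[0];
        for (int i = 1; i < in.length; i++) {
            acc = lambdaBody(acc, in[i]);
        }
        return acc;
    }
}
```

The specialized version also works on primitive `float` values throughout, which is why the proposal needs rules for replacing `Float` wrappers with primitives.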

@andrii0lomakin
Contributor Author

Only TaskX instances will be allowed.

I do not think this requirement is actually needed. I will check during the actual implementation.

@andrii0lomakin
Contributor Author

andrii0lomakin commented Aug 1, 2024

Which requirements are you referring to?

Let us skip it for the moment. In my experience, the only robust way to pass kernels right now is to pass static methods. For example, this code fails:

  taskGraph.task("copyVector",
        (source, sourceOffset, destination, destinationOffset, length) -> {
          for (@Parallel int i = 0; i < length; i++) {
            destination.set(destinationOffset + i, source[sourceOffset + i]);
          }
        }, arrayToCopy, 0, resultArray, 0, arrayToCopy.length);

The same code in a static method works like a charm, but I need to investigate the reasons more deeply. It seems it all stems from the handling of unboxing, which I suppose I will need to deal with anyway during the implementation.

In general, I want to do the following with primitive wrappers:

Allow only the following operations on wrappers in the passed-in code:

  1. Pass them as parameters to functions.
  2. Assign them to another variable.
  3. Perform arithmetic operations on them.
  4. Unbox them (the xValue() method that corresponds to the passed type, e.g. floatValue(), according to the JLS).

These checks will be performed while resolving method handles and will essentially allow wrappers to be replaced by primitives. I will be ready to discuss this in more detail once I am further along with the implementation.

@andrii0lomakin
Contributor Author

@jjfumero Did I answer your questions?

@jjfumero
Member

jjfumero commented Aug 2, 2024

Yes. I think we have different views on implementation, which is normal. If you want to work on this, I suggest reviewing the proposal when you have a PoC, and we can iterate on this.

I think what you will see is a new parameter in the Graal IR that corresponds to your new lambda function. Then, in my view, you can use Graal to get access to that lambda. To me, the bulk of the work is in the codegen. But again, it could just be another way to implement the same functionality.

@andrii0lomakin
Contributor Author

Yes. I think we have different views on implementation, which is normal. If you want to work on this, I suggest reviewing the proposal when you have a PoC, and we can iterate on this.

Sounds like a plan, thank you.

Development

No branches or pull requests

5 participants