Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] OpenXLA PJRT plugin #33

Merged
merged 3 commits into from
Feb 11, 2023
Merged

[RFC] OpenXLA PJRT plugin #33

merged 3 commits into from
Feb 11, 2023

Conversation

jpienaar
Copy link
Member

@jpienaar jpienaar commented Jan 24, 2023

Request for comment OpenXLA PJRT plugin along with creation of new repo for its development.

Present proposal for OpenXLA PJRT plugin along with creation of new repo for
its development.
@jpienaar jpienaar marked this pull request as ready for review January 24, 2023 03:33
@stellaraccident
Copy link
Contributor

Thanks for the writeup, Jacques!

@vinodgro
Copy link

nice to have windows support. Some products need Windows implementation. I had mentioned a few times that Windows availability would make OpenXLA more attractive.

@stellaraccident
Copy link
Contributor

nice to have windows support. Some products need Windows implementation. I had mentioned a few times that Windows availability would make OpenXLA more attractive.

I looked at this, and the plugin infra itself looks like it needs a bit of work for Windows compatibility still. But the plugin implementation is tested and deployed on Windows currently. So this doesn't seem too far to expect from were I see.

@vinodgro
Copy link

What is the process here? Create an issue?

@stellaraccident
Copy link
Contributor

stellaraccident commented Jan 24, 2023

What is the process here? Create an issue?

Process for this RFC to progress or requesting windows support for the implementation generally? For the latter, I would create a dedicated issue for platform policy overall (vs asking component by component).

@stellaraccident
Copy link
Contributor

Related: jax-ml/jax#438

I don't see any reason why the plugin infra and the openxla plugin can't support Windows.

@vinodgro
Copy link

vinodgro commented Jan 24, 2023

For the latter, I would create a dedicated issue for platform policy overall (vs asking component by component).

Which repository? JAX or ???

@vinodgro
Copy link

For the latter, I would create a dedicated issue for platform policy overall (vs asking component by component).

Which repository? JAX or ???

#34

@jakeh-gc
Copy link

jakeh-gc commented Jan 24, 2023

I would like to see something about how a plugin is tested.

Having a reference test suite that checks a PJRT implementation is "correct" would be very useful.

@stellaraccident
Copy link
Contributor

I would like to see is something about how a plugin is tested.

Having a reference test suite that checks a PJRT implementation is "correct" would be very useful.

Agreed. Ime, the Jax test suite is a useful component of this, but it doesn't have sufficient coverage for more integration-heavy tests (ie. Multi device, exotic compilation modes, out of the ordinary data transfer scenarios). Probably makes sense to have an additional layer of testing, and then we call some union of test suites an interim CTS and work on unifying it in a next phase.

@bhack
Copy link
Contributor

bhack commented Jan 24, 2023

What is exactly a PJRT_Program?
Does every registered PJRT plugin need to be able to full compile any PJRT_Program?

@jyingl3
Copy link

jyingl3 commented Jan 24, 2023

PJRT_Program is the input for PJRT compile, provide by the ML framework. Currently it be either serialized HloModuleProto/HloModuleProtoWithConfig or MLIR module.

It is likely not required for the plugin to compile any PJRT_Program. It is still evolving in terms of how we want to design the compile API. @skye

@bhack
Copy link
Contributor

bhack commented Jan 24, 2023

It is likely not required for the plugin to compile any PJRT_Program

Would it be possible to comunicate with the plugin to understand if an input program could be compiled (or is it supported) without failing (fast) at "runtime" on compilation?

I am asking about this with the scope of having a complete overview, or at least a preliminary check, on the PJRT plugin coverage over a specific program.

I don't know if it make sense at this level but it was something that I've expected at the current framework bridge level and we still don't have it. So I don't know if the topic is still valid or not at this lower level with PJRT programs.

ManfeiBai
ManfeiBai approved these changes Jan 24, 2023
@jpienaar
Copy link
Member Author

It is likely not required for the plugin to compile any PJRT_Program

Would it be possible to comunicate with the plugin to understand if an input program could be compiled (or is it supported) without failing (fast) at "runtime" on compilation?

That is a good suggestion. Ability to flag compatibility for a given plugin instance [there are potentially conservative and aggressive options too - e.g., an unsupported op that could be optimized away after a few rounds of optimizations]. We should keep that in mind for the compile API discussion.

@joker-eph
Copy link
Contributor

PJRT_Program is the input for PJRT compile, provide by the ML framework. Currently it be either serialized HloModuleProto/HloModuleProtoWithConfig or MLIR module.

Ultimately is seems that it should be a StableHLO module in the OpenXLA architecture?

It is likely not required for the plugin to compile any PJRT_Program. It is still evolving in terms of how we want to design the compile API. @skye

This ties pretty well with a discussion this morning in the open meeting: someone asked about dynamic shape support and how some platforms won't be able to support it. So as @jpienaar mentions above, we need to have this in mind when designing the APIs: "program capabilities" or "features" that aren't uniformly supported. There is a range of possibilities in the API design to express this...

@bhack
Copy link
Contributor

bhack commented Jan 24, 2023

@jpienaar Yes exactly these arguments are all related to what I meant.
I hope we could have a specific ticket/discussion in the new repo.

@jpienaar
Copy link
Member Author

PJRT_Program is the input for PJRT compile, provide by the ML framework. Currently it be either serialized HloModuleProto/HloModuleProtoWithConfig or MLIR module.

Ultimately is seems that it should be a StableHLO module in the OpenXLA architecture?

+1, PJRT is not OpenXLA/XLA specific but within the context of OpenXLA this is the supported input format.

@jakeh-gc
Copy link

Agreed. Ime, the Jax test suite is a useful component of this, but it doesn't have sufficient coverage for more integration-heavy tests (ie. Multi device, exotic compilation modes, out of the ordinary data transfer scenarios).

Yes, I think having some of those tests would be good. A problem I've had with using the ML Frameworks tests for the IPU PJRT implementation is it's unclear what the "canonical" set of tests is. There are a lot of exceptions and exclusions for CPU, TPU, and GPU.

@ManfeiBai
Copy link

ManfeiBai commented Jan 24, 2023

It is likely not required for the plugin to compile any PJRT_Program

Would it be possible to comunicate with the plugin to understand if an input program could be compiled (or is it supported) without failing (fast) at "runtime" on compilation?

That is a good suggestion. Ability to flag compatibility for a given plugin instance [there are potentially conservative and aggressive options too - e.g., an unsupported op that could be optimized away after a few rounds of optimizations]. We should keep that in mind for the compile API discussion.

If the op means the op supported by framework, would an unsupported op that could be optimized away might means the PJRT plugin will return error more early with the op info?

@jpienaar
Copy link
Member Author

It is likely not required for the plugin to compile any PJRT_Program

Would it be possible to comunicate with the plugin to understand if an input program could be compiled (or is it supported) without failing (fast) at "runtime" on compilation?

That is a good suggestion. Ability to flag compatibility for a given plugin instance [there are potentially conservative and aggressive options too - e.g., an unsupported op that could be optimized away after a few rounds of optimizations]. We should keep that in mind for the compile API discussion.

If the op means the op supported by framework, would an unsupported op that could be optimized away might means the PJRT plugin will return error more early with the op info?

We are going into the compiler API design and query functionality a bit, which is a separate discussion that needs to happen still. We can have multiple layers and expensiveness of tests. So the most simple one could be purely on op names, more advanced/expensive checks the attributes, more check the types too, then checking usage within a context (e.g., can I elide asserts?), then trying to run some initial optimizations to see. As a straw man one could add yes, no, maybe results for query "is supported model" and the cost of the query or amount of effort to try can be configured additionally - and trying to do a couple of optimizations would be on the expensive path. The more information the backend could provide the earlier and easier. For backends that are complete (which is a much more tractable target for HLOs!) one might not need to dig in too deep to get a result (exception is ops like scatter which are not a simpler ops). But there may be value in being able to give a very quick response even if conservative for some use cases, while in others conservative isn't useful. Larger discussion :) (I've seen multiple different attempts here and this will be design question in the compiler API work).

@bhack
Copy link
Contributor

bhack commented Jan 25, 2023

It will be super useful also for CI jobs or other development activities if for an early check you will not need to retrieve/run the verification on the device specific resources.
I still hope that this layer could not fail fast as It would be useful to open feature request tickets/issues on the specific compiler or eventually have a full overview to find workarounds on your program.

Then honestly we need something similar on Frameworks bridges for failures om generating StableHLO programs.. but this is another story.

Also I don't know if we could have some complications with custom calls.

@mjsML
Copy link
Member

mjsML commented Jan 25, 2023

*   Registration of multiple PJRT plugins.

This a braindump of things to consider will keep adding :

  1. Versioning/variants ( e.g. x86, x86-64 and ARM CPUs, different CUDA compute capability GPUs, DPUs..etc),
  2. How does a device broadcast its capability/supported API calls? Not sure the balance is here between going full VM contract vs. "DEVICE_VARIANT_VERSION" naming.
  3. "Hot-swap/plug" PjRT plugins (i.e. in the runtime being able to unload and reload.)
  4. Single vs Multi-process access/ ownership (maybe covered by IRFT)
  5. Command queues/ buffer management APIs access, simple commands vs. transactional.

cc: @nouiz @stellaraccident

if (function_prt != nullptr) {
PJRT_Api* api = function_prt();
plugin_name = parse(library_path);
PluginInfo plugin_info(api, config_values);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to check if plugin_name is already loaded and if so, do an early return.
I wouldn't return an error. Allowing to load the same pluging many times isn't wrong.

This will also allows function_prt() to do all the initialization that it needs. (related to some open questions bellow)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion! Edited.

PJRT TPU client is created. Shall we add another method InitializePlugin and
only run it once? Alternatively, the plugin can implement it in
`PJRT_Client_Create` and run it once in the first time a client was created.
* Do we want to create PJRT clients for every plugin that is found? Will that
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't do that automatically as this will increase some resource utilization even if the end user doesn't want to use it.
I would let frameworks decide which behavior they want. I wouldn't impose that decision at the PJRT level.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. Updated the text accordingly.

pjrt_client = create_pjrt_client(plugin_name)
```

For TensorFlow, discovery will be added to
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if 2 frameworks create to clients for the same device. Is this supported? If so, it would be great to specify it.
If this isn't supported, it will be harder to have in the same python script different frameworks. So supporting multiple clients would be great.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is up to the specific hardware and plugin implementation. If the hardware only allows exclusive access, then the software will abide by that constraint. Otherwise, some plugins may use the hardware in an exclusive way (ie. Allocate all memory). The openxla plugin that we envision now will default to allowing multiple clients and supporting on demand memory allocation (with a possible option for more greedy access as an optimization for cases that need it).

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sound good. Thanks.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is more to this than greedy access as an optimisation. A Graphcore IPU can only be owned by a single context on the host. So two processes, or indeed two clients in a single process, can't share an IPU.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that becomes a restriction to use of a Graphcore IPU then -- the PJRT API layer isn't going to do any kind of virtualization or remoting. If a software mechanism is needed to arbitrate multi-party access to a device, then that would be up to the device implementation.

Side note: this is currently an issue when using Jax by default with the XLA GPU backend as it allocates all memory, effectively making it impossible to share. There are environment variable workarounds to cause it to dynamically allocate.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that becomes a platform specific restriction, and I don't want a virtualisation layer. My concern is this becomes an implicit assumption in all users of this API.

The openxla plugin that we envision now will default to allowing multiple clients

My only point being that greedy or exclusive access isn't necessarily an optimisation, it can be a requirement (there's a reason in silicon for it).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see us doing anything in the software stack that makes multi-tenancy either harder or easier, but we will probably seek to make the default openxla implementation more user friendly on this front by default as it is a frequent pain point.

There really should be one Client per process, and if there can only be one Client per system then that would limit to only launching one process for that category of devices.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a paragraph to summarize this discussion thread.

@mjsML
Copy link
Member

mjsML commented Feb 1, 2023

  1. Offloading on the device API level, especially on Grace-Hopper+ (Page 19 - Figure 9) systems, which has a core design issue (CPU and GPU viewed as a single device?, hierarchical caching?)

@stellaraccident
Copy link
Contributor

  1. Offloading on the device API level, especially on Grace-Hopper+ (Page 19 - Figure 9) systems, which has a core design issue (CPU and GPU viewed as a single device?, hierarchical caching?)

I'm having trouble tracking the exact use case that is promoting the question (and can imagine different directions). Could you elaborate?

@janpfeifer
Copy link

janpfeifer commented Feb 1, 2023 via email

@bhack
Copy link
Contributor

bhack commented Feb 1, 2023

For those watching from the sides, could someone point to a public documentation on what is PjRT ? Thanks!

Also what is the extended version of this acronym?

jax-ml/jax#11439 (comment)

@joker-eph
Copy link
Contributor

For those watching from the sides, could someone point to a public documentation on what is PjRT ? Thanks!

I don't have a pointer to a documentation immediately, but think of PJRT right now as the public API for setting up XLA (other than building the input graph).

See this basic example: https://github.com/openxla/xla/blob/main/xla/examples/axpy/stablehlo_compile_test.cc#L60-L79

The PjRtStreamExecutorDevice and PjRtStreamExecutorClient classes are defined in https://github.com/openxla/xla/blob/main/xla/pjrt/pjrt_stream_executor_client.h and are just some specific implementations of PjRtDevice and PjRtClient.
These abstract base class are really good references I think to figure out what is the PJRT api right now. See this long description for PjRtClient: https://github.com/openxla/xla/blob/main/xla/pjrt/pjrt_client.h#L339-L386

Also what is the extended version of this acronym?

d8c85bc19e29fdff0aa3d03065cdf79cef6c0fb9: PjRt stands for "pretty much just another runtime".

@jpienaar
Copy link
Member Author

jpienaar commented Feb 1, 2023

For those watching from the sides, could someone point to a public documentation on what is PjRT ? Thanks!

I think the best on these are those listed at the bottom of the RFC and the headers here, let us know if too low level. As mentioned in intro PJRT is a device API that will provide an easy interface with which frameworks can integrate a packaged compiler and runtime solution.

Also what is the extended version of this acronym?

google/jax#11439 (comment)

From the header: 'PjRt stands for "Pretty much Just another RunTime"' (this makes a lot more sense if one considers the originally JAX expansion).

@Jianhui-Li
Copy link

PJRT_Program is the input for PJRT compile, provide by the ML framework. Currently it be either serialized HloModuleProto/HloModuleProtoWithConfig or MLIR module.

It is important to support versioning of the program representation so that the plugin can check before processing the program. Do you consider versioning and what is the runtime behavior if the version doesn't exactly match?

@stellaraccident
Copy link
Contributor

PJRT_Program is the input for PJRT compile, provide by the ML framework. Currently it be either serialized HloModuleProto/HloModuleProtoWithConfig or MLIR module.

It is important to support versioning of the program representation so that the plugin can check before processing the program. Do you consider versioning and what is the runtime behavior if the version doesn't exactly match?

We are driving this towards StableHLO which is developing version constraints as part of its design. New implementations should use that.

@burmako

@mjsML
Copy link
Member

mjsML commented Feb 7, 2023

I'm having trouble tracking the exact use case promoting the question (and can imagine different directions). Could you elaborate?

I'm sorry for the vague statement; I'm trying to jot down the key points so I can remember to bring them up when there is a higher bandwidth format.

A small write-up for 7 is:

For Grace Hopper + based systems, the NVL bandwidth is much higher than PCIe, opening the door for profitable offloading of tensors to the HBM. The next version of the runtime API would need to consider the programming model for such hybrid systems while tackling the traditional Host <-> GPU typical design. One example is if one is doing a "one layer at a time" style of computation for a massive model with a limited set of GPUs (model_size> All GPU memories)

Current frameworks (e.g. JAX) don't have an ergonomic path for this hybrid memory model; assuming they do, it will likely need an abstraction at the runtime level too.

@jyingl3
Copy link

jyingl3 commented Feb 8, 2023

*   Registration of multiple PJRT plugins.

This a braindump of things to consider will keep adding :

  1. Versioning/variants ( e.g. x86, x86-64 and ARM CPUs, different CUDA compute capability GPUs, DPUs..etc),
  2. How does a device broadcast its capability/supported API calls? Not sure the balance is here between going full VM contract vs. "DEVICE_VARIANT_VERSION" naming.
  3. "Hot-swap/plug" PjRT plugins (i.e. in the runtime being able to unload and reload.)
  4. Single vs Multi-process access/ ownership (maybe covered by IRFT)
  5. Command queues/ buffer management APIs access, simple commands vs. transactional.

cc: @nouiz @stellaraccident

Added a section to capture it.

@stellaraccident
Copy link
Contributor

@jyingl3 Can you catch me up on comm channels for this work? I just got a basic CI going and found that there was some API drift around Executables/LoadedExecutables that I am adapting to. As you know, it is basically impossible to keep up with the noise in the TF repo, and without seeing the dev process, this isn't a great experience for collaborators.

@theadactyl Any objection to using the IREE #pjrt-plugin channel for now to coordinate on API updates and such? Open to other options.

@stellaraccident
Copy link
Contributor

stellaraccident commented Feb 11, 2023

Can one of the PJRT owners please advise on where to file bugs and how to reach the engineers. I have filed this one: openxla/xla#1237 but the engineers who work on this do not appear to be members of the repo.

Thanks.
(ftr - I work at Google and could have pinged them privately but am trying to flush out how to actually interact with our OSS projects without such privileged access)

@bhack
Copy link
Contributor

bhack commented Feb 11, 2023

(ftr - I work at Google and could have pinged them privately but am trying to flush out how to actually interact with our OSS projects without such privileged access)

openxla/xla#448

@theadactyl theadactyl self-requested a review February 11, 2023 18:09
Copy link
Member

@theadactyl theadactyl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This RFC is approved. To be clear, the scope of the RFC is creating the repo, not approving any particular design decision. I created this issue in the new repo so that the design feedback can be appropriately logged: https://github.com/openxla/openxla-pjrt-plugin/issues/3

@theadactyl theadactyl merged commit 89d40f4 into openxla:main Feb 11, 2023
@jyingl3
Copy link

jyingl3 commented Feb 22, 2023

@jyingl3 Can you catch me up on comm channels for this work? I just got a basic CI going and found that there was some API drift around Executables/LoadedExecutables that I am adapting to. As you know, it is basically impossible to keep up with the noise in the TF repo, and without seeing the dev process, this isn't a great experience for collaborators.

Thanks for checking! We have finalized the communication channel:

We also just added a README to https://github.com/openxla/xla/tree/main/xla/pjrt/c about the communication channel and some resources.

@skye

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.