Skip to content

Commit

Permalink
Add compile
Browse files Browse the repository at this point in the history
  • Loading branch information
zcbenz committed Apr 28, 2024
1 parent 6f4634f commit ca2c0bc
Show file tree
Hide file tree
Showing 5 changed files with 442 additions and 10 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ There are a few exceptions due to limitations of JavaScript:

Some features are not supported yet and will be implemented in future:

* The function passed to `mx.grad`/`mx.valueAndGrad`/`mx.vmap` must have all its
parameters taking `mx.array`.
* The function passed to `mx.grad`/`mx.valueAndGrad`/`mx.vmap`/`mx.compile` must
have all its parameters taking `mx.array`.
* When creating a `mx.array` from JavaScript Array, the Array must only include
primitive values.

Expand Down
2 changes: 1 addition & 1 deletion deps/kizunapi
14 changes: 8 additions & 6 deletions lib/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -319,12 +319,14 @@ type ValueAndGradFunctionScalar = (...args: array[]) => [array, array]
type ValueAndGradFunctionGeneric = (...args: array[]) => [array[], array[]]
export function valueAndGrad(func: (...args: array[]) => array, argnums?: number | number[]): ValueAndGradFunctionScalar;
export function valueAndGrad(func: (...args: array[]) => array[], argnums?: number | number[]): ValueAndGradFunctionGeneric;
type GradFunctionScalar = (...args: array[]) => array
type GradFunctionGeneric = (...args: array[]) => array[]
export function grad(func: (...args: array[]) => array, argnums?: number | number[]): GradFunctionScalar;
export function grad(func: (...args: array[]) => array[], argnums?: number | number[]): GradFunctionGeneric;
export function vmap(func: (...args: array[]) => array, inAxes?: number | number[], outAxis?: number): GradFunctionScalar;
export function vmap(func: (...args: array[]) => array[], inAxes?: number | number[], outAxes?: number[]): GradFunctionGeneric;
type ComputeFunctionScalar = (...args: array[]) => array
type ComputeFunctionGeneric = (...args: array[]) => array[]
export function grad(func: ComputeFunctionScalar, argnums?: number | number[]): ComputeFunctionScalar;
export function grad(func: ComputeFunctionGeneric, argnums?: number | number[]): ComputeFunctionGeneric;
export function vmap(func: ComputeFunctionScalar, inAxes?: number | number[], outAxis?: number): ComputeFunctionScalar;
export function vmap(func: ComputeFunctionGeneric, inAxes?: number | number[], outAxes?: number[]): ComputeFunctionGeneric;
export function compile(func: ComputeFunctionScalar, shapeless?: boolean): ComputeFunctionScalar;
export function compile(func: ComputeFunctionGeneric, shapeless?: boolean): ComputeFunctionGeneric;

// Metal.
export namespace metal {
Expand Down
34 changes: 33 additions & 1 deletion src/transforms.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#include "src/array.h"
#include "src/ops.h"

// Needed for detail::compile.
#include "mlx/transforms_impl.h"

namespace {

// Unflatten the function call result.
Expand Down Expand Up @@ -138,6 +141,32 @@ VMap(napi_env env,
};
}

std::function<napi_value(ki::Arguments*)>
Compile(napi_env env,
napi_value value,
std::optional<bool> shapeless) {
// Reference the JS function as napi_value only lives at current tick.
ki::Persistent js_func(env, value);
std::uintptr_t func_id = reinterpret_cast<std::uintptr_t>(js_func.Id());
// Call compile with the JS function.
auto func = mx::detail::compile(
[js_func = std::move(js_func)](const std::vector<mx::array>& primals) {
return ExecuteWithPrimals(js_func.Env(), js_func.Value(), primals);
},
func_id,
shapeless.value_or(false));
// Return a JS function that converts JS args into primals.
return [env, func = std::move(func)](ki::Arguments* args) -> napi_value {
std::vector<mx::array> arrays;
if (!ReadArgs(args, &arrays))
return nullptr;
auto results = func(std::move(arrays));
if (ki::IsExceptionPending(env))
return nullptr;
return UnflattenResults(env, results);
};
}

} // namespace transforms_ops

void InitTransforms(napi_env env, napi_value exports) {
Expand All @@ -148,5 +177,8 @@ void InitTransforms(napi_env env, napi_value exports) {
"vjp", JVPOpWrapper(&mx::vjp),
"valueAndGrad", &transforms_ops::ValueAndGrad,
"grad", &transforms_ops::Grad,
"vmap", &transforms_ops::VMap);
"vmap", &transforms_ops::VMap,
"compile", &transforms_ops::Compile,
"disableCompile", &mx::disable_compile,
"enableCompile", &mx::enable_compile);
}
Loading

0 comments on commit ca2c0bc

Please sign in to comment.