Skip to content

Commit 5fe9dbc

Browse files
committed
[GR-62315] Fix compilation bailout in WasmFunctionInstance.execute.
PullRequest: graal/20069
2 parents e4f91a8 + a257aeb commit 5fe9dbc

File tree

7 files changed

+343
-109
lines changed

7 files changed

+343
-109
lines changed

wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,8 @@ void resolveFunctionImport(WasmContext context, WasmInstance instance, WasmFunct
469469
}
470470

471471
void resolveFunctionExport(WasmModule module, int functionIndex, String exportedFunctionName) {
472-
final ImportDescriptor importDescriptor = module.symbolTable().function(functionIndex).importDescriptor();
472+
final WasmFunction function = module.symbolTable().function(functionIndex);
473+
final ImportDescriptor importDescriptor = function.importDescriptor();
473474
final Sym[] dependencies = (importDescriptor != null) ? new Sym[]{new ImportFunctionSym(module.name(), importDescriptor, functionIndex)} : ResolutionDag.NO_DEPENDENCIES;
474475
resolutionDag.resolveLater(new ExportFunctionSym(module.name(), exportedFunctionName), dependencies, NO_RESOLVE_ACTION);
475476
}

wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java

+5-6
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,9 @@ public abstract class SymbolTable {
8686
private static final int NO_EQUIVALENCE_CLASS = 0;
8787
static final int FIRST_EQUIVALENCE_CLASS = NO_EQUIVALENCE_CLASS + 1;
8888

89-
public static class FunctionType {
90-
private final byte[] paramTypes;
91-
private final byte[] resultTypes;
89+
public static final class FunctionType {
90+
@CompilationFinal(dimensions = 1) private final byte[] paramTypes;
91+
@CompilationFinal(dimensions = 1) private final byte[] resultTypes;
9292
private final int hashCode;
9393

9494
FunctionType(byte[] paramTypes, byte[] resultTypes) {
@@ -112,10 +112,9 @@ public int hashCode() {
112112

113113
@Override
114114
public boolean equals(Object object) {
115-
if (!(object instanceof FunctionType)) {
115+
if (!(object instanceof FunctionType that)) {
116116
return false;
117117
}
118-
FunctionType that = (FunctionType) object;
119118
if (this.paramTypes.length != that.paramTypes.length) {
120119
return false;
121120
}
@@ -146,7 +145,7 @@ public String toString() {
146145
for (int i = 0; i < resultTypes.length; i++) {
147146
resultNames[i] = WasmType.toString(resultTypes[i]);
148147
}
149-
return Arrays.toString(paramNames) + " -> " + Arrays.toString(resultNames);
148+
return "(" + String.join(" ", paramNames) + ")->(" + String.join(" ", resultNames) + ")";
150149
}
151150
}
152151

wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java

+26-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* The Universal Permissive License (UPL), Version 1.0
@@ -52,6 +52,8 @@ public final class WasmFunction {
5252
@CompilationFinal private int typeEquivalenceClass;
5353
@CompilationFinal private String debugName;
5454
@CompilationFinal private CallTarget callTarget;
55+
/** Interop call adapter for argument and return value validation and conversion. */
56+
@CompilationFinal private volatile CallTarget interopCallAdapter;
5557

5658
/**
5759
* Represents a WebAssembly function.
@@ -128,6 +130,14 @@ public String importedFunctionName() {
128130
return isImported() ? importDescriptor.memberName() : null;
129131
}
130132

133+
public String exportedFunctionName() {
134+
return symbolTable.exportedFunctionName(index);
135+
}
136+
137+
public boolean isExported() {
138+
return exportedFunctionName() != null;
139+
}
140+
131141
public int typeIndex() {
132142
return typeIndex;
133143
}
@@ -158,4 +168,19 @@ void setImportedFunctionCallTarget(CallTarget callTarget) {
158168
assert isImported() : this;
159169
this.callTarget = callTarget;
160170
}
171+
172+
public CallTarget getInteropCallAdapter() {
173+
return interopCallAdapter;
174+
}
175+
176+
@TruffleBoundary
177+
public CallTarget getOrCreateInteropCallAdapter(WasmLanguage language) {
178+
CallTarget callAdapter = this.interopCallAdapter;
179+
if (callAdapter == null) {
180+
// Benign initialization race: The call target will be the same each time.
181+
callAdapter = language.interopCallAdapterFor(type());
182+
this.interopCallAdapter = callAdapter;
183+
}
184+
return callAdapter;
185+
}
161186
}

wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunctionInstance.java

+48-100
Original file line numberDiff line numberDiff line change
@@ -42,24 +42,20 @@
4242

4343
import java.util.Objects;
4444

45-
import org.graalvm.wasm.api.InteropArray;
46-
import org.graalvm.wasm.api.Vector128;
47-
import org.graalvm.wasm.exception.Failure;
48-
import org.graalvm.wasm.exception.WasmException;
49-
import org.graalvm.wasm.nodes.WasmIndirectCallNode;
50-
5145
import com.oracle.truffle.api.CallTarget;
46+
import com.oracle.truffle.api.CompilerDirectives;
5247
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
5348
import com.oracle.truffle.api.RootCallTarget;
5449
import com.oracle.truffle.api.TruffleContext;
50+
import com.oracle.truffle.api.dsl.Bind;
5551
import com.oracle.truffle.api.dsl.Cached;
56-
import com.oracle.truffle.api.interop.ArityException;
52+
import com.oracle.truffle.api.dsl.Specialization;
5753
import com.oracle.truffle.api.interop.InteropLibrary;
5854
import com.oracle.truffle.api.interop.TruffleObject;
59-
import com.oracle.truffle.api.interop.UnsupportedTypeException;
60-
import com.oracle.truffle.api.library.CachedLibrary;
6155
import com.oracle.truffle.api.library.ExportLibrary;
6256
import com.oracle.truffle.api.library.ExportMessage;
57+
import com.oracle.truffle.api.nodes.EncapsulatingNodeReference;
58+
import com.oracle.truffle.api.nodes.Node;
6359

6460
@ExportLibrary(InteropLibrary.class)
6561
public final class WasmFunctionInstance extends EmbedderDataHolder implements TruffleObject {
@@ -69,6 +65,10 @@ public final class WasmFunctionInstance extends EmbedderDataHolder implements Tr
6965
private final WasmFunction function;
7066
private final CallTarget target;
7167
private final TruffleContext truffleContext;
68+
/**
69+
* Stores the imported function object for {@link org.graalvm.wasm.api.ExecuteHostFunctionNode}.
70+
* Initialized during linking.
71+
*/
7272
private Object importedFunction;
7373

7474
/**
@@ -135,104 +135,52 @@ boolean isExecutable() {
135135
}
136136

137137
@ExportMessage
138-
Object execute(Object[] arguments,
139-
@CachedLibrary("this") InteropLibrary self,
140-
@Cached WasmIndirectCallNode callNode) throws ArityException, UnsupportedTypeException {
141-
TruffleContext c = getTruffleContext();
142-
Object prev = c.enter(self);
143-
try {
144-
Object result = callNode.execute(target, WasmArguments.create(moduleInstance, validateArguments(arguments)));
145-
146-
// For external calls of a WebAssembly function we have to materialize the multi-value
147-
// stack.
148-
// At this point the multi-value stack has already been populated, therefore, we don't
149-
// have to check the size of the multi-value stack.
150-
if (result == WasmConstant.MULTI_VALUE) {
151-
WasmLanguage language = context.language();
152-
assert language == WasmLanguage.get(null);
153-
return multiValueStackAsArray(language);
138+
static class Execute {
139+
private static Object execute(WasmFunctionInstance functionInstance, Object[] arguments, CallTarget callAdapter, Node node, Node callNode) {
140+
TruffleContext c = functionInstance.getTruffleContext();
141+
Object prev = c.enter(node);
142+
try {
143+
return callAdapter.call(callNode, WasmArguments.create(functionInstance, arguments));
144+
// throws ArityException, UnsupportedTypeException
145+
} finally {
146+
c.leave(node, prev);
154147
}
155-
return result;
156-
} finally {
157-
c.leave(self, prev);
158148
}
159-
}
160149

161-
private Object[] validateArguments(Object[] arguments) throws ArityException, UnsupportedTypeException {
162-
if (function == null) {
163-
return arguments;
150+
@SuppressWarnings("unused")
151+
@Specialization(guards = {"actualFunction == cachedFunction"}, limit = "2")
152+
static Object direct(WasmFunctionInstance functionInstance, Object[] arguments,
153+
@Bind("functionInstance.function()") WasmFunction actualFunction,
154+
@Cached("actualFunction") WasmFunction cachedFunction,
155+
@Cached("getOrCreateInteropCallAdapter(functionInstance)") CallTarget cachedCallAdapter,
156+
@Bind Node node) {
157+
return execute(functionInstance, arguments, cachedCallAdapter, node, node);
164158
}
165-
final int paramCount = function.paramCount();
166-
if (arguments.length != paramCount) {
167-
throw ArityException.create(paramCount, paramCount, arguments.length);
159+
160+
@SuppressWarnings("unused")
161+
@Specialization(guards = {"actualCallAdapter == cachedCallAdapter"}, limit = "3", replaces = "direct")
162+
static Object directAdapter(WasmFunctionInstance functionInstance, Object[] arguments,
163+
@Bind("getOrCreateInteropCallAdapter(functionInstance)") CallTarget actualCallAdapter,
164+
@Cached("actualCallAdapter") CallTarget cachedCallAdapter,
165+
@Bind Node node) {
166+
return execute(functionInstance, arguments, cachedCallAdapter, node, node);
168167
}
169-
for (int i = 0; i < paramCount; i++) {
170-
byte paramType = function.paramTypeAt(i);
171-
Object value = arguments[i];
172-
switch (paramType) {
173-
case WasmType.I32_TYPE -> {
174-
if (value instanceof Integer) {
175-
continue;
176-
}
177-
}
178-
case WasmType.I64_TYPE -> {
179-
if (value instanceof Long) {
180-
continue;
181-
}
182-
}
183-
case WasmType.F32_TYPE -> {
184-
if (value instanceof Float) {
185-
continue;
186-
}
187-
}
188-
case WasmType.F64_TYPE -> {
189-
if (value instanceof Double) {
190-
continue;
191-
}
192-
}
193-
case WasmType.V128_TYPE -> {
194-
if (value instanceof Vector128) {
195-
continue;
196-
}
197-
}
198-
case WasmType.FUNCREF_TYPE -> {
199-
if (value instanceof WasmFunctionInstance || value == WasmConstant.NULL) {
200-
continue;
201-
}
202-
}
203-
case WasmType.EXTERNREF_TYPE -> {
204-
continue;
205-
}
206-
default -> throw WasmException.create(Failure.UNKNOWN_TYPE);
207-
}
208-
throw UnsupportedTypeException.create(arguments);
168+
169+
@Specialization(replaces = "directAdapter")
170+
static Object indirect(WasmFunctionInstance functionInstance, Object[] arguments,
171+
@Bind Node node) {
172+
CallTarget callAdapter = getOrCreateInteropCallAdapter(functionInstance);
173+
Node callNode = node.isAdoptable() ? node : EncapsulatingNodeReference.getCurrent().get();
174+
return execute(functionInstance, arguments, callAdapter, node, callNode);
209175
}
210-
return arguments;
211-
}
212176

213-
private Object multiValueStackAsArray(WasmLanguage language) {
214-
final var multiValueStack = language.multiValueStack();
215-
final long[] primitiveMultiValueStack = multiValueStack.primitiveStack();
216-
final Object[] objectMultiValueStack = multiValueStack.objectStack();
217-
final int resultCount = function.resultCount();
218-
assert primitiveMultiValueStack.length >= resultCount;
219-
assert objectMultiValueStack.length >= resultCount;
220-
final Object[] values = new Object[resultCount];
221-
for (int i = 0; i < resultCount; i++) {
222-
byte resultType = function.resultTypeAt(i);
223-
values[i] = switch (resultType) {
224-
case WasmType.I32_TYPE -> (int) primitiveMultiValueStack[i];
225-
case WasmType.I64_TYPE -> primitiveMultiValueStack[i];
226-
case WasmType.F32_TYPE -> Float.intBitsToFloat((int) primitiveMultiValueStack[i]);
227-
case WasmType.F64_TYPE -> Double.longBitsToDouble(primitiveMultiValueStack[i]);
228-
case WasmType.V128_TYPE, WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE -> {
229-
Object obj = objectMultiValueStack[i];
230-
objectMultiValueStack[i] = null;
231-
yield obj;
232-
}
233-
default -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL);
234-
};
177+
static CallTarget getOrCreateInteropCallAdapter(WasmFunctionInstance functionInstance) {
178+
WasmFunction function = functionInstance.function();
179+
CallTarget callAdapter = function.getInteropCallAdapter();
180+
if (CompilerDirectives.injectBranchProbability(CompilerDirectives.SLOWPATH_PROBABILITY, callAdapter == null)) {
181+
return function.getOrCreateInteropCallAdapter(functionInstance.context().language());
182+
}
183+
return callAdapter;
235184
}
236-
return InteropArray.create(values);
237185
}
238186
}

wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmLanguage.java

+23
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import org.graalvm.options.OptionDescriptors;
4949
import org.graalvm.options.OptionValues;
5050
import org.graalvm.polyglot.SandboxPolicy;
51+
import org.graalvm.wasm.api.InteropCallAdapterNode;
5152
import org.graalvm.wasm.api.JsConstants;
5253
import org.graalvm.wasm.api.WebAssembly;
5354
import org.graalvm.wasm.exception.WasmJsApiException;
@@ -56,6 +57,7 @@
5657
import org.graalvm.wasm.predefined.BuiltinModule;
5758

5859
import com.oracle.truffle.api.CallTarget;
60+
import com.oracle.truffle.api.CompilerAsserts;
5961
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
6062
import com.oracle.truffle.api.ContextThreadLocal;
6163
import com.oracle.truffle.api.RootCallTarget;
@@ -96,8 +98,10 @@ public final class WasmLanguage extends TruffleLanguage<WasmContext> {
9698

9799
private final Map<SymbolTable.FunctionType, Integer> equivalenceClasses = new ConcurrentHashMap<>();
98100
private int nextEquivalenceClass = SymbolTable.FIRST_EQUIVALENCE_CLASS;
101+
private final Map<SymbolTable.FunctionType, CallTarget> interopCallAdapters = new ConcurrentHashMap<>();
99102

100103
public int equivalenceClassFor(SymbolTable.FunctionType type) {
104+
CompilerAsserts.neverPartOfCompilation();
101105
Integer equivalenceClass = equivalenceClasses.get(type);
102106
if (equivalenceClass == null) {
103107
synchronized (this) {
@@ -112,6 +116,20 @@ public int equivalenceClassFor(SymbolTable.FunctionType type) {
112116
return equivalenceClass;
113117
}
114118

119+
/**
120+
* Gets or creates the interop call adapter for a function type. Always returns the same call
121+
* target for any particular type.
122+
*/
123+
public CallTarget interopCallAdapterFor(SymbolTable.FunctionType type) {
124+
CompilerAsserts.neverPartOfCompilation();
125+
CallTarget callAdapter = interopCallAdapters.get(type);
126+
if (callAdapter == null) {
127+
callAdapter = interopCallAdapters.computeIfAbsent(type,
128+
k -> new InteropCallAdapterNode(this, k).getCallTarget());
129+
}
130+
return callAdapter;
131+
}
132+
115133
@Override
116134
protected WasmContext createContext(Env env) {
117135
WasmContext context = new WasmContext(env, this);
@@ -249,6 +267,11 @@ protected boolean areOptionsCompatible(OptionValues firstOptions, OptionValues n
249267
}
250268
}
251269

270+
@SuppressWarnings("unchecked")
271+
public static <E extends Throwable> RuntimeException rethrow(Throwable ex) throws E {
272+
throw (E) ex;
273+
}
274+
252275
public MultiValueStack multiValueStack() {
253276
return multiValueStackThreadLocal.get();
254277
}

0 commit comments

Comments
 (0)