Skip to content

Commit

Permalink
wasm: use non-trapping conversion instructions when casting floats an…
Browse files Browse the repository at this point in the history
…d doubles to ints and longs

Fix #976
  • Loading branch information
konsoletyper committed Nov 30, 2024
1 parent 1460835 commit f11a547
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -641,26 +641,29 @@ public void binary(WasmFloatBinaryOperation opcode, WasmFloatType type) {
}

@Override
public void convert(WasmNumType sourceType, WasmNumType targetType, boolean signed, boolean reinterpret) {
public void convert(WasmNumType sourceType, WasmNumType targetType, boolean signed, boolean reinterpret,
boolean nonTrapping) {
switch (targetType) {
case INT32:
writer.write("i32.");
switch (sourceType) {
case FLOAT32:
if (reinterpret) {
writer.write("reinterpret_f32");
} else if (signed) {
writer.write("trunc_f32_s");
} else {
writer.write("trunc_f32_u");
writer.write("trunc_");
if (nonTrapping) {
writer.write("sat_");
}
writer.write("f32_").write(signed ? "s" : "u");
}
break;
case FLOAT64:
if (signed) {
writer.write("trunc_f64_s");
} else {
writer.write("trunc_f64_u");
writer.write("trunc_");
if (nonTrapping) {
writer.write("sat_");
}
writer.write("f64_").write(signed ? "s" : "u");
break;
case INT64:
writer.write("wrap_i64");
Expand All @@ -674,19 +677,21 @@ public void convert(WasmNumType sourceType, WasmNumType targetType, boolean sign
writer.write("i64.");
switch (sourceType) {
case FLOAT32:
if (signed) {
writer.write("trunc_f32_s");
} else {
writer.write("trunc_f32_u");
writer.write("trunc_");
if (nonTrapping) {
writer.write("sat_");
}
writer.write("f32_").write(signed ? "s" : "u");
break;
case FLOAT64:
if (reinterpret) {
writer.write("reinterpret_f64");
} else if (signed) {
writer.write("trunc_f64_s");
} else {
writer.write("trunc_f64_u");
writer.write("trunc_");
if (nonTrapping) {
writer.write("sat_");
}
writer.write("f64_").write(signed ? "s" : "u");
}
break;
case INT32:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1172,9 +1172,11 @@ public void visit(InitClassStatement statement) {
@Override
public void visit(PrimitiveCastExpr expr) {
accept(expr.getValue());
result = new WasmConversion(WasmGeneratorUtil.mapType(expr.getSource()),
var conversion = new WasmConversion(WasmGeneratorUtil.mapType(expr.getSource()),
WasmGeneratorUtil.mapType(expr.getTarget()), true, result);
result.setLocation(expr.getLocation());
conversion.setNonTrapping(true);
conversion.setLocation(expr.getLocation());
result = conversion;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public class WasmConversion extends WasmExpression {
private boolean signed;
private WasmExpression operand;
private boolean reinterpret;
private boolean nonTrapping;

public WasmConversion(WasmNumType sourceType, WasmNumType targetType, boolean signed, WasmExpression operand) {
Objects.requireNonNull(sourceType);
Expand Down Expand Up @@ -78,6 +79,14 @@ public void setOperand(WasmExpression operand) {
this.operand = operand;
}

public boolean isNonTrapping() {
return nonTrapping;
}

public void setNonTrapping(boolean nonTrapping) {
this.nonTrapping = nonTrapping;
}

@Override
public void acceptVisitor(WasmExpressionVisitor visitor) {
visitor.visit(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ default void loadFloat64(int align, int offset) {
default void storeFloat64(int align, int offset) {
}

default void convert(WasmNumType sourceType, WasmNumType targetType, boolean signed, boolean reinterpret) {
default void convert(WasmNumType sourceType, WasmNumType targetType, boolean signed, boolean reinterpret,
boolean nonTrapping) {
}

default void memoryGrow() {
Expand Down
73 changes: 49 additions & 24 deletions core/src/main/java/org/teavm/backend/wasm/parser/CodeParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -516,76 +516,76 @@ private boolean parseExpr() {
break;

case 0xA7:
codeListener.convert(WasmNumType.INT64, WasmNumType.INT32, false, false);
codeListener.convert(WasmNumType.INT64, WasmNumType.INT32, false, false, false);
break;
case 0xA8:
codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT32, false, false);
codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT32, false, false, false);
break;
case 0xA9:
codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT32, true, false);
codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT32, true, false, false);
break;
case 0xAA:
codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT32, false, false);
codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT32, false, false, false);
break;
case 0xAB:
codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT32, true, false);
codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT32, true, false, false);
break;
case 0xAC:
codeListener.convert(WasmNumType.INT32, WasmNumType.INT64, false, false);
codeListener.convert(WasmNumType.INT32, WasmNumType.INT64, false, false, false);
break;
case 0xAD:
codeListener.convert(WasmNumType.INT32, WasmNumType.INT64, true, false);
codeListener.convert(WasmNumType.INT32, WasmNumType.INT64, true, false, false);
break;
case 0xAE:
codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT64, false, false);
codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT64, false, false, false);
break;
case 0xAF:
codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT64, true, false);
codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT64, true, false, false);
break;
case 0xB0:
codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT64, false, false);
codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT64, false, false, false);
break;
case 0xB1:
codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT64, true, false);
codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT64, true, false, false);
break;
case 0xB2:
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, false, false);
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, false, false, false);
break;
case 0xB3:
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, true, false);
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, true, false, false);
break;
case 0xB4:
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT32, false, false);
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT32, false, false, false);
break;
case 0xB5:
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT32, true, false);
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT32, true, false, false);
break;
case 0xB6:
codeListener.convert(WasmNumType.FLOAT64, WasmNumType.FLOAT32, true, false);
codeListener.convert(WasmNumType.FLOAT64, WasmNumType.FLOAT32, true, false, false);
break;
case 0xB7:
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT64, false, false);
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT64, false, false, false);
break;
case 0xB8:
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT64, true, false);
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT64, true, false, false);
break;
case 0xB9:
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, false, false);
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, false, false, false);
break;
case 0xBA:
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, true, false);
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, true, false, false);
break;
case 0xBC:
codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT32, false, true);
codeListener.convert(WasmNumType.FLOAT32, WasmNumType.INT32, false, true, false);
break;
case 0xBD:
codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT64, false, true);
codeListener.convert(WasmNumType.FLOAT64, WasmNumType.INT64, false, true, false);
break;
case 0xBE:
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, false, true);
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, false, true, false);
break;
case 0xBF:
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, false, true);
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, false, true, false);
break;

case 0xD0:
Expand Down Expand Up @@ -623,6 +623,31 @@ private boolean parseExpr() {

private boolean parseExtExpr() {
switch (readLEB()) {
case 0:
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, true, false, true);
return true;
case 1:
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT32, false, false, true);
return true;
case 2:
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT64, true, false, true);
return true;
case 3:
codeListener.convert(WasmNumType.INT32, WasmNumType.FLOAT64, false, false, true);
return true;
case 4:
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT32, true, false, true);
return true;
case 5:
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT32, false, false, true);
return true;
case 6:
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, true, false, true);
return true;
case 7:
codeListener.convert(WasmNumType.INT64, WasmNumType.FLOAT64, false, false, true);
return true;

case 10: {
if (reader.data[reader.ptr++] != 0 || reader.data[reader.ptr++] != 0) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -839,12 +839,20 @@ public void visit(WasmConversion expression) {
case INT32:
if (expression.isReinterpret()) {
writer.writeByte(0xBC);
} else if (expression.isNonTrapping()) {
writer.writeByte(0xFC);
writer.writeByte(expression.isSigned() ? 0 : 1);
} else {
writer.writeByte(expression.isSigned() ? 0xA8 : 0xA9);
}
break;
case INT64:
writer.writeByte(expression.isSigned() ? 0xAE : 0xAF);
if (expression.isNonTrapping()) {
writer.writeByte(0xFC);
writer.writeByte(expression.isSigned() ? 4 : 5);
} else {
writer.writeByte(expression.isSigned() ? 0xAE : 0xAF);
}
break;
case FLOAT32:
break;
Expand All @@ -856,11 +864,19 @@ public void visit(WasmConversion expression) {
case FLOAT64:
switch (expression.getTargetType()) {
case INT32:
writer.writeByte(expression.isSigned() ? 0xAA : 0xAB);
if (expression.isNonTrapping()) {
writer.writeByte(0xFC);
writer.writeByte(expression.isSigned() ? 2 : 3);
} else {
writer.writeByte(expression.isSigned() ? 0xAA : 0xAB);
}
break;
case INT64:
if (expression.isReinterpret()) {
writer.writeByte(0xBD);
} else if (expression.isNonTrapping()) {
writer.writeByte(0xFC);
writer.writeByte(expression.isSigned() ? 6 : 7);
} else {
writer.writeByte(expression.isSigned() ? 0xB0 : 0xB1);
}
Expand Down
66 changes: 66 additions & 0 deletions tests/src/test/java/org/teavm/vm/NumericConversionTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright 2024 Alexey Andreev.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.teavm.vm;

import static org.junit.Assert.assertEquals;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.teavm.junit.SkipPlatform;
import org.teavm.junit.TeaVMTestRunner;
import org.teavm.junit.TestPlatform;

@RunWith(TeaVMTestRunner.class)
public class NumericConversionTest {
@Test
@SkipPlatform({TestPlatform.JAVASCRIPT, TestPlatform.C})
public void floatOverflow() {
assertEquals(2147483647, (int) (floatOne() * (1 << 30) * (1 << 3)));
assertEquals(2147483647, (int) (floatOne() * Float.POSITIVE_INFINITY));
assertEquals(-2147483648, (int) (-floatOne() * (1 << 30) * (1 << 3)));
assertEquals(-2147483648, (int) (-floatOne() * Float.POSITIVE_INFINITY));
assertEquals(0, (int) (floatOne() * Float.NaN));

assertEquals((1L << 63) - 1, (long) (floatOne() * (1L << 60) * (1 << 5)));
assertEquals((1L << 63) - 1, (long) (floatOne() * Float.POSITIVE_INFINITY));
assertEquals(1L << 63, (long) (-floatOne() * (1L << 60) * (1 << 5)));
assertEquals(1L << 63, (long) (-floatOne() * Float.POSITIVE_INFINITY));
assertEquals(0, (long) (floatOne() * Float.NaN));
}

@Test
@SkipPlatform({TestPlatform.JAVASCRIPT, TestPlatform.C})
public void doubleOverflow() {
assertEquals(2147483647, (int) (doubleOne() * (1 << 30) * (1 << 3)));
assertEquals(2147483647, (int) (doubleOne() * Float.POSITIVE_INFINITY));
assertEquals(-2147483648, (int) (-doubleOne() * (1 << 30) * (1 << 3)));
assertEquals(-2147483648, (int) (-doubleOne() * Float.POSITIVE_INFINITY));
assertEquals(0, (int) (doubleOne() * Double.NaN));

assertEquals((1L << 63) - 1, (long) (doubleOne() * (1L << 60) * (1 << 5)));
assertEquals((1L << 63) - 1, (long) (doubleOne() * Double.POSITIVE_INFINITY));
assertEquals(1L << 63, (long) (-doubleOne() * (1L << 60) * (1 << 5)));
assertEquals(1L << 63, (long) (-doubleOne() * Double.POSITIVE_INFINITY));
assertEquals(0, (long) (doubleOne() * Double.NaN));
}

private float floatOne() {
return 1;
}

private double doubleOne() {
return 1;
}
}

0 comments on commit f11a547

Please sign in to comment.