From 32b27e22256b2e6eef9d59c9715b98a0acb5b053 Mon Sep 17 00:00:00 2001 From: jbb01 <32650546+jbb01@users.noreply.github.com> Date: Sun, 18 Aug 2024 14:21:28 +0200 Subject: [PATCH] improve performance by precomputing parts of the AABB intersection algorithm --- .../eu/jonahbauer/raytracing/math/AABB.java | 54 +++++++------ .../eu/jonahbauer/raytracing/math/Ray.java | 76 +++++++++++++++++-- .../eu/jonahbauer/raytracing/math/Vec3.java | 8 ++ .../raytracing/scene/HitResult.java | 8 +- 4 files changed, 108 insertions(+), 38 deletions(-) diff --git a/src/main/java/eu/jonahbauer/raytracing/math/AABB.java b/src/main/java/eu/jonahbauer/raytracing/math/AABB.java index 76cc556..559e9c7 100644 --- a/src/main/java/eu/jonahbauer/raytracing/math/AABB.java +++ b/src/main/java/eu/jonahbauer/raytracing/math/AABB.java @@ -79,35 +79,39 @@ public record AABB(@NotNull Vec3 min, @NotNull Vec3 max) { * @return {@code true} iff the ray intersects this bounding box, {@code false} otherwise */ public boolean hit(@NotNull Ray ray, @NotNull Range range) { - var origin = ray.origin(); - var direction = ray.direction(); - var invDirection = direction.inv(); - - // calculate t values for intersection points of ray with planes through min - var tmin = intersect(min(), origin, invDirection); - // calculate t values for intersection points of ray with planes through max - var tmax = intersect(max(), origin, invDirection); - - // determine range of t for which the ray is inside this voxel - double tlmax = Double.NEGATIVE_INFINITY; // lower limit maximum - double tumin = Double.POSITIVE_INFINITY; // upper limit minimum - - for (int i = 0; i < 3; i++) { - // classify t values as lower or upper limit based on ray direction - if (direction.get(i) >= 0) { - // min is lower limit and max is upper limit - if (tmin[i] > tlmax) tlmax = tmin[i]; - if (tmax[i] < tumin) tumin = tmax[i]; - } else { - // max is lower limit and min is upper limit - if (tmax[i] > tlmax) tlmax = tmax[i]; - if (tmin[i] < tumin) tumin = tmin[i]; - } - } + var invDirection = ray.getInvDirection(); + var negInvOrigin = ray.getNegInvOrigin(); + + var tminX = Math.fma(min.x(), invDirection.x(), negInvOrigin.x()); + var tminY = Math.fma(min.y(), invDirection.y(), negInvOrigin.y()); + var tminZ = Math.fma(min.z(), invDirection.z(), negInvOrigin.z()); + + var tmaxX = Math.fma(max.x(), invDirection.x(), negInvOrigin.x()); + var tmaxY = Math.fma(max.y(), invDirection.y(), negInvOrigin.y()); + var tmaxZ = Math.fma(max.z(), invDirection.z(), negInvOrigin.z()); + + var tlmax = max( + Math.min(tminX, tmaxX), + Math.min(tminY, tmaxY), + Math.min(tminZ, tmaxZ) + ); + var tumin = min( + Math.max(tminX, tmaxX), + Math.max(tminY, tmaxY), + Math.max(tminZ, tmaxZ) + ); return tlmax < tumin && tumin >= range.min() && tlmax <= range.max(); } + private static double max(double a, double b, double c) { + return Math.max(a, Math.max(b, c)); + } + + private static double min(double a, double b, double c) { + return Math.min(a, Math.min(b, c)); + } + /** * Computes the {@code t} values of the intersections of a ray with the axis-aligned planes through a point. * @param corner the point diff --git a/src/main/java/eu/jonahbauer/raytracing/math/Ray.java b/src/main/java/eu/jonahbauer/raytracing/math/Ray.java index acbd811..4a023f2 100644 --- a/src/main/java/eu/jonahbauer/raytracing/math/Ray.java +++ b/src/main/java/eu/jonahbauer/raytracing/math/Ray.java @@ -6,17 +6,35 @@ import org.jetbrains.annotations.NotNull; import java.util.Objects; -public record Ray(@NotNull Vec3 origin, @NotNull Vec3 direction, @NotNull SampledWavelengths lambda) { - public Ray { - Objects.requireNonNull(origin, "origin"); - Objects.requireNonNull(direction, "direction"); - Objects.requireNonNull(lambda, "lambda"); - } +public final class Ray { + private final @NotNull Vec3 origin; + private final @NotNull Vec3 direction; + private final @NotNull SampledWavelengths lambda; + + private final @NotNull Vec3 inv; + private final @NotNull Vec3 negInvOrigin; public Ray(@NotNull Vec3 origin, @NotNull Vec3 direction) { this(origin, direction, SampledWavelengths.EMPTY); } + public Ray(@NotNull Vec3 origin, @NotNull Vec3 direction, @NotNull SampledWavelengths lambda) { + this.origin = Objects.requireNonNull(origin, "origin"); + this.direction = Objects.requireNonNull(direction, "direction"); + this.lambda = Objects.requireNonNull(lambda, "lambda"); + + this.inv = direction.inv(); + this.negInvOrigin = inv.neg().times(origin); + } + + private Ray(@NotNull Vec3 origin, @NotNull Vec3 direction, @NotNull SampledWavelengths lambda, @NotNull Vec3 inv, @NotNull Vec3 negInvOrigin) { + this.origin = origin; + this.direction = direction; + this.lambda = lambda; + this.inv = inv; + this.negInvOrigin = negInvOrigin; + } + public @NotNull Vec3 at(double t) { return Vec3.fma(t, direction, origin); } @@ -30,6 +48,50 @@ public record Ray(@NotNull Vec3 origin, @NotNull Vec3 direction, @NotNull Sample } public @NotNull Ray with(@NotNull SampledWavelengths lambda) { - return new Ray(origin, direction, lambda); + return new Ray(origin, direction, lambda, inv, negInvOrigin); + } + + public @NotNull Vec3 origin() { + return origin; } + + public @NotNull Vec3 direction() { + return direction; + } + + public @NotNull SampledWavelengths lambda() { + return lambda; + } + + public @NotNull Vec3 getInvDirection() { + return inv; + } + + public @NotNull Vec3 getNegInvOrigin() { + return negInvOrigin; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) return true; + if (obj == null || obj.getClass() != this.getClass()) return false; + var that = (Ray) obj; + return Objects.equals(this.origin, that.origin) && + Objects.equals(this.direction, that.direction) && + Objects.equals(this.lambda, that.lambda); + } + + @Override + public int hashCode() { + return Objects.hash(origin, direction, lambda); + } + + @Override + public @NotNull String toString() { + return "Ray[" + + "origin=" + origin + ", " + + "direction=" + direction + ", " + + "lambda=" + lambda + ']'; + } + } diff --git a/src/main/java/eu/jonahbauer/raytracing/math/Vec3.java b/src/main/java/eu/jonahbauer/raytracing/math/Vec3.java index 0ae4512..eaeaa60 100644 --- a/src/main/java/eu/jonahbauer/raytracing/math/Vec3.java +++ b/src/main/java/eu/jonahbauer/raytracing/math/Vec3.java @@ -160,6 +160,14 @@ public record Vec3(double x, double y, double z) implements IVec3 { ); } + public static @NotNull Vec3 fma(@NotNull Vec3 a, @NotNull Vec3 b, @NotNull Vec3 c) { + return new Vec3( + Math.fma(a.x(), b.x(), c.x()), + Math.fma(a.y(), b.y(), c.y()), + Math.fma(a.z(), b.z(), c.z()) + ); + } + public static double tripleProduct(@NotNull Vec3 a, @NotNull Vec3 b, @NotNull Vec3 c) { return a.x * b.y * c.z + a.y * b.z * c.x + a.z * b.x * c.y - c.x * b.y * a.z - c.y * b.z * a.x - c.z * b.x * a.y; } diff --git a/src/main/java/eu/jonahbauer/raytracing/scene/HitResult.java b/src/main/java/eu/jonahbauer/raytracing/scene/HitResult.java index 83f4a00..e83921b 100644 --- a/src/main/java/eu/jonahbauer/raytracing/scene/HitResult.java +++ b/src/main/java/eu/jonahbauer/raytracing/scene/HitResult.java @@ -7,10 +7,10 @@ import eu.jonahbauer.raytracing.render.material.Material; import eu.jonahbauer.raytracing.render.texture.Texture; import org.jetbrains.annotations.NotNull; -import java.util.Objects; +import java.util.random.RandomGenerator; /** - * The result of a {@linkplain Hittable#hit(Ray, Range) hit}. + * The result of a {@linkplain Hittable#hit(Ray, Range, RandomGenerator) hit}. * @param t the {@code t} value at which the hit occurs * @param position the position of the hit * @param normal the surface normal at the hit position @@ -24,10 +24,6 @@ public record HitResult( double t, @NotNull Vec3 position, @NotNull Vec3 normal, @NotNull Hittable target, @NotNull Material material, double u, double v, boolean isFrontFace ) implements Comparable { - public HitResult { - Objects.requireNonNull(position, "position"); - normal = normal.unit(); - } public @NotNull HitResult withPositionAndNormal(@NotNull Vec3 position, @NotNull Vec3 normal) { return new HitResult(t, position, normal, target, material, u, v, isFrontFace);