From 137c0b21908c07e21a2a28c9727c1ee72f1d59fd Mon Sep 17 00:00:00 2001 From: jbb01 <32650546+jbb01@users.noreply.github.com> Date: Sun, 4 Aug 2024 23:57:54 +0200 Subject: [PATCH] add octree to improve rendering performance by reducing the number of ray-sphere-intersection calculations --- .../java/eu/jonahbauer/raytracing/Main.java | 3 + .../eu/jonahbauer/raytracing/math/Octree.java | 259 ++++++++++++++++++ .../eu/jonahbauer/raytracing/math/Vec3.java | 34 +++ .../eu/jonahbauer/raytracing/scene/Scene.java | 80 +++++- 4 files changed, 363 insertions(+), 13 deletions(-) create mode 100644 src/main/java/eu/jonahbauer/raytracing/math/Octree.java diff --git a/src/main/java/eu/jonahbauer/raytracing/Main.java b/src/main/java/eu/jonahbauer/raytracing/Main.java index 1a82296..75db7d0 100644 --- a/src/main/java/eu/jonahbauer/raytracing/Main.java +++ b/src/main/java/eu/jonahbauer/raytracing/Main.java @@ -44,7 +44,10 @@ public class Main { var image = new LiveCanvas(new Image(camera.getWidth(), camera.getHeight())); image.preview(); + long time = System.nanoTime(); renderer.render(camera, scene, image); + System.out.printf("rendering finished after %dms", (System.nanoTime() - time) / 1_000_000); + ImageFormat.PNG.write(image, Path.of("scene-" + System.currentTimeMillis() + ".png")); } diff --git a/src/main/java/eu/jonahbauer/raytracing/math/Octree.java b/src/main/java/eu/jonahbauer/raytracing/math/Octree.java new file mode 100644 index 0000000..cbfc38d --- /dev/null +++ b/src/main/java/eu/jonahbauer/raytracing/math/Octree.java @@ -0,0 +1,259 @@ +package eu.jonahbauer.raytracing.math; + +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; + +import java.util.*; +import java.util.function.Predicate; + +public final class Octree { + private final @NotNull NodeStorage storage; + + public Octree(@NotNull Vec3 center, double dimension) { + this.storage = new NodeStorage<>(center, dimension); + } + + public void add(@NotNull BoundingBox bbox, T object) { + storage.add(new Entry<>(bbox, object)); + } + + /** + * Use HERO algorithms to find all elements that could possibly be hit by the given ray. + * @see + * Agate, M., Grimsdale, R.L., Lister, P.F. (1991). + * The HERO Algorithm for Ray-Tracing Octrees. + * In: Grimsdale, R.L., Straßer, W. (eds) Advances in Computer Graphics Hardware IV. Eurographic Seminars. Springer, Berlin, Heidelberg. + */ + public void hit(@NotNull Ray ray, @NotNull Predicate action) { + storage.hit(ray, action); + } + + public static int getOctantIndex(@NotNull Vec3 center, @NotNull Vec3 pos) { + return (pos.x() < center.x() ? 0 : 1) + | (pos.y() < center.y() ? 0 : 2) + | (pos.z() < center.z() ? 0 : 4); + + } + + private static sealed abstract class Storage { + protected static final int LIST_SIZE_LIMIT = 32; + + protected final @NotNull Vec3 center; + protected final double dimension; + + public Storage(@NotNull Vec3 center, double dimension) { + this.center = Objects.requireNonNull(center); + this.dimension = dimension; + } + + public abstract @NotNull Storage add(@NotNull Entry entry); + + protected abstract boolean hit(@NotNull Ray ray, @NotNull Predicate action); + + protected boolean hit0(@NotNull Ray ray, int vmask, double tmin, double tmax, @NotNull Predicate action) { + return hit(ray, action); + } + } + + private static final class ListStorage extends Storage { + private final @NotNull List> list = new ArrayList<>(); + + public ListStorage(@NotNull Vec3 center, double dimension) { + super(center, dimension); + } + + @Override + public @NotNull Storage add(@NotNull Entry entry) { + if (list.size() >= LIST_SIZE_LIMIT) { + var node = new NodeStorage(center, dimension); + list.forEach(node::add); + node.add(entry); + return node; + } else { + list.add(entry); + return this; + } + } + + @Override + protected boolean hit(@NotNull Ray ray, @NotNull Predicate action) { + var hit = false; + for (Entry entry : list) { + hit |= action.test(entry.object()); + } + return hit; + } + } + + private static final class NodeStorage extends Storage { + @SuppressWarnings("unchecked") + private final @Nullable Storage @NotNull[] octants = new Storage[8]; + private final @NotNull List> list = new ArrayList<>(); // track elements spanning multiple octants separately + + public NodeStorage(@NotNull Vec3 center, double dimension) { + super(center, dimension); + } + + @Override + public @NotNull Storage add(@NotNull Entry entry) { + var index = getOctantIndex(center, entry.bbox().min()); + if (index != getOctantIndex(center, entry.bbox().max())) { + list.add(entry); + } else { + var subnode = octants[index]; + if (subnode == null) { + subnode = newOctant(index); + } + octants[index] = subnode.add(entry); + } + return this; + } + + private @NotNull Storage newOctant(int index) { + var newSize = 0.5 * dimension; + var newCenter = this.center + .plus(new Vec3( + (index & 1) == 0 ? -newSize : newSize, + (index & 2) == 0 ? -newSize : newSize, + (index & 4) == 0 ? -newSize : newSize + )); + return new ListStorage<>(newCenter, newSize); + } + + @Override + protected boolean hit(@NotNull Ray ray, @NotNull Predicate action) { + int vmask = (ray.direction().x() < 0 ? 1 : 0) + | (ray.direction().y() < 0 ? 2 : 0) + | (ray.direction().z() < 0 ? 4 : 0); + + var min = center.minus(dimension, dimension, dimension); + var max = center.plus(dimension, dimension, dimension); + + // calculate t values for intersection points of ray with planes through min + var tmin = calculatePlaneIntersections(min, ray); + // calculate t values for intersection points of ray with planes through max + var tmax = calculatePlaneIntersections(max, ray); + + // determine range of t for which the ray is inside this voxel + double tlmax = Double.NEGATIVE_INFINITY; // lower limit maximum + double tumin = Double.POSITIVE_INFINITY; // upper limit minimum + for (int i = 0; i < 3; i++) { + // classify t values as lower or upper limit based on vmask + if ((vmask & (1 << i)) == 0) { + // min is lower limit and max is upper limit + tlmax = Math.max(tlmax, tmin[i]); + tumin = Math.min(tumin, tmax[i]); + } else { + // max is lower limit and min is upper limit + tlmax = Math.max(tlmax, tmax[i]); + tumin = Math.min(tumin, tmin[i]); + } + } + + var hit = tlmax < tumin; + if (!hit) return false; + + return hit0(ray, vmask, tlmax, tumin, action); + } + + @Override + protected boolean hit0(@NotNull Ray ray, int vmask, double tmin, double tmax, @NotNull Predicate action) { + if (tmax < 0) return false; + + // check for hit + var hit = false; + + // process entries spanning multiple children + for (Entry entry : list) { + hit |= action.test(entry.object()); + } + + // t values for intersection points of ray with planes through center + var tmid = calculatePlaneIntersections(center, ray); + // masks of planes in the order of intersection, e.g. [2, 1, 4] for a ray intersection y = center.y() then x = center.x() then z = center.z() + var masklist = calculateMasklist(tmid); + // the first child to be hit by the ray assuming a ray with positive x, y and z coordinates + var childmask = (tmid[0] < tmin ? 1 : 0) + | (tmid[1] < tmin ? 2 : 0) + | (tmid[2] < tmin ? 4 : 0); + // the last child to be hit by the ray assuming a ray with positive x, y and z coordinates + var lastmask = (tmid[0] < tmax ? 1 : 0) + | (tmid[1] < tmax ? 2 : 0) + | (tmid[2] < tmax ? 4 : 0); + + var childTmin = tmin; + + int i = 0; + while (true) { + // use vmask to nullify the assumption of a positive ray made for childmask + var child = octants[childmask ^ vmask]; + + // calculate t value for exit of child + double childTmax; + if (childmask == lastmask) { + // last child shares tmax + childTmax = tmax; + } else { + // determine next child + while ((masklist[i] & childmask) != 0) { + i++; + } + childmask = childmask | masklist[i]; + // tmax of current child is the t value for the intersection with the plane dividing the current and next child + childTmax = tmid[Integer.numberOfTrailingZeros(masklist[i])]; + } + + // process child + var childHit = child != null && child.hit0(ray, vmask, childTmin, childTmax, action); + hit |= childHit; + + // break after last child has been processed or a hit has been found + if (childTmax == tmax || childHit) break; + + // tmin of next child is tmax of current child + childTmin = childTmax; + } + + return hit; + } + + private double @NotNull [] calculatePlaneIntersections(@NotNull Vec3 position, @NotNull Ray ray) { + return new double[] { + (position.x() - ray.origin().x()) / ray.direction().x(), + (position.y() - ray.origin().y()) / ray.direction().y(), + (position.z() - ray.origin().z()) / ray.direction().z(), + }; + } + + private static final int[][] MASKLISTS = new int[][] { + {1, 2, 4}, + {1, 4, 2}, + {4, 1, 2}, + {2, 1, 4}, + {2, 4, 1}, + {4, 2, 1} + }; + + private static int @NotNull [] calculateMasklist(double @NotNull[] tmid) { + if (tmid[0] < tmid[1]) { + if (tmid[1] < tmid[2]) { + return MASKLISTS[0]; // {1, 2, 4} + } else if (tmid[0] < tmid[2]) { + return MASKLISTS[1]; // {1, 4, 2} + } else { + return MASKLISTS[2]; // {4, 1, 2} + } + } else { + if (tmid[0] < tmid[2]) { + return MASKLISTS[3]; // {2, 1, 4} + } else if (tmid[1] < tmid[2]) { + return MASKLISTS[4]; // {2, 4, 1} + } else { + return MASKLISTS[5]; // {4, 2, 1} + } + } + } + } + + private record Entry(@NotNull BoundingBox bbox, T object) { } +} diff --git a/src/main/java/eu/jonahbauer/raytracing/math/Vec3.java b/src/main/java/eu/jonahbauer/raytracing/math/Vec3.java index 76c8d41..35bc699 100644 --- a/src/main/java/eu/jonahbauer/raytracing/math/Vec3.java +++ b/src/main/java/eu/jonahbauer/raytracing/math/Vec3.java @@ -6,6 +6,8 @@ import java.util.Optional; public record Vec3(double x, double y, double z) { public static final Vec3 ZERO = new Vec3(0, 0, 0); + public static final Vec3 MAX = new Vec3(Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE); + public static final Vec3 MIN = new Vec3(-Double.MAX_VALUE, -Double.MAX_VALUE, -Double.MAX_VALUE); public static final Vec3 UNIT_X = new Vec3(1, 0, 0); public static final Vec3 UNIT_Y = new Vec3(0, 1, 0); public static final Vec3 UNIT_Z = new Vec3(0, 0, 1); @@ -52,6 +54,38 @@ public record Vec3(double x, double y, double z) { return a.minus(b).length(); } + public static @NotNull Vec3 average(@NotNull Vec3 current, @NotNull Vec3 next, int index) { + return new Vec3( + current.x() + (next.x() - current.x()) / index, + current.y() + (next.y() - current.y()) / index, + current.z() + (next.z() - current.z()) / index + ); + } + + public static @NotNull Vec3 max(@NotNull Vec3 a, @NotNull Vec3 b) { + return new Vec3( + Math.max(a.x(), b.x()), + Math.max(a.y(), b.y()), + Math.max(a.z(), b.z()) + ); + } + + public static @NotNull Vec3 min(@NotNull Vec3 a, @NotNull Vec3 b) { + return new Vec3( + Math.min(a.x(), b.x()), + Math.min(a.y(), b.y()), + Math.min(a.z(), b.z()) + ); + } + + public @NotNull Vec3 plus(double x, double y, double z) { + return new Vec3(this.x + x, this.y + y, this.z + z); + } + + public @NotNull Vec3 minus(double x, double y, double z) { + return new Vec3(this.x - x, this.y - y, this.z - z); + } + public @NotNull Vec3 plus(@NotNull Vec3 b) { return new Vec3(this.x + b.x, this.y + b.y, this.z + b.z); } diff --git a/src/main/java/eu/jonahbauer/raytracing/scene/Scene.java b/src/main/java/eu/jonahbauer/raytracing/scene/Scene.java index 7430895..bfba547 100644 --- a/src/main/java/eu/jonahbauer/raytracing/scene/Scene.java +++ b/src/main/java/eu/jonahbauer/raytracing/scene/Scene.java @@ -1,31 +1,85 @@ package eu.jonahbauer.raytracing.scene; +import eu.jonahbauer.raytracing.math.Octree; import eu.jonahbauer.raytracing.math.Range; import eu.jonahbauer.raytracing.math.Ray; +import eu.jonahbauer.raytracing.math.Vec3; import org.jetbrains.annotations.NotNull; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Optional; -public record Scene(@NotNull List<@NotNull Hittable> objects) implements Hittable { +public final class Scene implements Hittable { + private final @NotNull Octree<@NotNull Hittable> octree; + private final @NotNull List<@NotNull Hittable> list; - public Scene { - objects = List.copyOf(objects); - } + public Scene(@NotNull List objects) { + this.octree = newOctree(objects); + this.list = new ArrayList<>(); - public Scene(@NotNull Hittable @NotNull ... objects) { - this(List.of(objects)); + for (Hittable object : objects) { + var bbox = object.getBoundingBox(); + if (bbox.isPresent()) { + octree.add(bbox.get(), object); + } else { + list.add(object); + } + } } + @Override public @NotNull Optional hit(@NotNull Ray ray, @NotNull Range range) { - var result = (HitResult) null; - for (var object : objects) { - var r = object.hit(ray, range); - if (r.isPresent() && range.surrounds(r.get().t())) { - result = r.get(); - range = new Range(range.min(), result.t()); + var state = new State(); + state.range = range; + + octree.hit(ray, object -> hit(state, ray, object)); + list.forEach(object -> hit(state, ray, object)); + + return Optional.ofNullable(state.result); + } + + private boolean hit(@NotNull State state, @NotNull Ray ray, @NotNull Hittable object) { + var r = object.hit(ray, state.range); + if (r.isPresent()) { + if (state.range.surrounds(r.get().t())){ + state.result = r.get(); + state.range = new Range(state.range.min(), state.result.t()); + } + return true; + } else { + return false; + } + } + + private static @NotNull Octree newOctree(@NotNull List objects) { + Vec3 center = Vec3.ZERO, max = Vec3.MIN, min = Vec3.MAX; + + int i = 1; + for (Hittable object : objects) { + var bbox = object.getBoundingBox(); + if (bbox.isPresent()) { + center = Vec3.average(center, bbox.get().center(), i++); + max = Vec3.max(max, bbox.get().max()); + min = Vec3.min(min, bbox.get().min()); } } - return Optional.ofNullable(result); + + var dimension = Arrays.stream(new double[] { + Math.abs(max.x() - center.x()), + Math.abs(max.y() - center.y()), + Math.abs(max.z() - center.z()), + Math.abs(min.x() - center.x()), + Math.abs(min.y() - center.y()), + Math.abs(min.z() - center.z()) + }).max().orElse(10); + + return new Octree<>(center, dimension); + } + + private static class State { + HitResult result; + Range range; } }