From d67f877428ef2801f58beb1d860dd031e3889996 Mon Sep 17 00:00:00 2001 From: jbb01 <32650546+jbb01@users.noreply.github.com> Date: Sun, 4 Aug 2024 23:57:54 +0200 Subject: [PATCH] add octree to improve rendering performance by reducing the number of ray-sphere-intersection calculations --- .../eu/jonahbauer/raytracing/math/Octree.java | 253 ++++++++++++++++++ .../eu/jonahbauer/raytracing/math/Vec3.java | 34 +++ .../eu/jonahbauer/raytracing/scene/Scene.java | 81 ++++-- 3 files changed, 352 insertions(+), 16 deletions(-) create mode 100644 src/main/java/eu/jonahbauer/raytracing/math/Octree.java diff --git a/src/main/java/eu/jonahbauer/raytracing/math/Octree.java b/src/main/java/eu/jonahbauer/raytracing/math/Octree.java new file mode 100644 index 0000000..97e673e --- /dev/null +++ b/src/main/java/eu/jonahbauer/raytracing/math/Octree.java @@ -0,0 +1,253 @@ +package eu.jonahbauer.raytracing.math; + +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; + +import java.util.*; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +public final class Octree { + private final @NotNull NodeStorage storage; + + public Octree(@NotNull Vec3 center, double dimension) { + this.storage = new NodeStorage<>(center, dimension); + } + + public void add(@NotNull BoundingBox bbox, T object) { + storage.add(new Entry<>(bbox, object)); + } + + /** + * Use HERO algorithms to find all elements that could possibly be hit by the given ray. + * @see The HERO Algorithm for Ray-Tracing Octrees + */ + public void hit(@NotNull Ray ray, @NotNull Consumer action) { + storage.hit(ray, action); + } + + @Override + public @NotNull String toString() { + return storage.toString(); + } + + public static int getOctantIndex(@NotNull Vec3 center, @NotNull Vec3 pos) { + return (pos.x() < center.x() ? 0 : 1) + | (pos.y() < center.y() ? 0 : 2) + | (pos.z() < center.z() ? 0 : 4); + + } + + private sealed interface Storage { + int LIST_SIZE_LIMIT = 32; + + @NotNull Storage add(@NotNull Entry entry); + } + + private static final class ListStorage implements Storage { + private final @NotNull Vec3 center; + private final double dimension; + + private final @NotNull List> list = new ArrayList<>(); + + public ListStorage(@NotNull Vec3 center, double dimension) { + this.center = Objects.requireNonNull(center); + this.dimension = dimension; + } + + @Override + public @NotNull Storage add(@NotNull Entry entry) { + if (list.size() >= LIST_SIZE_LIMIT) { + var node = new NodeStorage(center, dimension); + list.forEach(node::add); + node.add(entry); + return node; + } else { + list.add(entry); + return this; + } + } + + @Override + public String toString() { + return list.toString(); + } + } + + private static final class NodeStorage implements Storage { + private final @NotNull Vec3 center; + private final double dimension; + + @SuppressWarnings("unchecked") + private final @Nullable Storage @NotNull[] octants = new Storage[8]; + private final @NotNull List> list = new ArrayList<>(); // track elements spanning multiple octants separately + + public NodeStorage(@NotNull Vec3 center, double dimension) { + this.center = Objects.requireNonNull(center); + this.dimension = dimension; + } + + @Override + public @NotNull Storage add(@NotNull Entry entry) { + var index = getOctantIndex(center, entry.bbox().min()); + if (index != getOctantIndex(center, entry.bbox().max())) { + list.add(entry); + } else { + var subnode = octants[index]; + if (subnode == null) { + subnode = newOctant(index); + } + octants[index] = subnode.add(entry); + } + return this; + } + + private @NotNull Storage newOctant(int index) { + var newSize = 0.5 * dimension; + var newCenter = this.center + .plus(new Vec3( + (index & 1) == 0 ? -newSize : newSize, + (index & 2) == 0 ? -newSize : newSize, + (index & 4) == 0 ? -newSize : newSize + )); + return new ListStorage<>(newCenter, newSize); + } + + public void hit(@NotNull Ray ray, @NotNull Consumer action) { + int vmask = (ray.direction().x() < 0 ? 1 : 0) + | (ray.direction().y() < 0 ? 2 : 0) + | (ray.direction().z() < 0 ? 4 : 0); + + var min = center.minus(dimension, dimension, dimension); + var max = center.plus(dimension, dimension, dimension); + + // calculate t values for intersection points of ray with planes through min + var tmin = calculatePlaneIntersections(min, ray); + // calculate t values for intersection points of ray with planes through max + var tmax = calculatePlaneIntersections(max, ray); + + // determine range of t for which the ray is inside this voxel + double tlmax = Double.NEGATIVE_INFINITY; // lower limit maximum + double tumin = Double.POSITIVE_INFINITY; // upper limit minimum + for (int i = 0; i < 3; i++) { + // classify t values as lower or upper limit based on vmask + if ((vmask & (1 << i)) == 0) { + // min is lower limit and max is upper limit + tlmax = Math.max(tlmax, tmin[i]); + tumin = Math.min(tumin, tmax[i]); + } else { + // max is lower limit and min is upper limit + tlmax = Math.max(tlmax, tmax[i]); + tumin = Math.min(tumin, tmin[i]); + } + } + + var hit = tlmax < tumin; + if (!hit) return; + + hit0(ray, vmask, tlmax, tumin, action); + } + + private void hit0(@NotNull Ray ray, int vmask, double tmin, double tmax, @NotNull Consumer action) { + if (tmax < 0) return; + for (Entry entry : list) { + action.accept(entry.object()); + } + + // t values for intersection points of ray with planes through center + var tmid = calculatePlaneIntersections(center, ray); + // masks of planes in the order of intersection, e.g. [2, 1, 4] for a ray intersection y = center.y() then x = center.x() then z = center.z() + var masklist = calculateMastlist(tmid); + var childmask = (tmid[0] < tmin ? 1 : 0) + | (tmid[1] < tmin ? 2 : 0) + | (tmid[2] < tmin ? 4 : 0); + var lastmask = (tmid[0] < tmax ? 1 : 0) + | (tmid[1] < tmax ? 2 : 0) + | (tmid[2] < tmax ? 4 : 0); + + var childTmin = tmin; + + int i = 0; + while (true) { + var child = octants[childmask ^ vmask]; + double childTmax; + if (childmask == lastmask) { + childTmax = tmax; + } else { + while ((masklist[i] & childmask) != 0) { + i++; + } + childmask = childmask | masklist[i]; + childTmax = tmid[Integer.numberOfTrailingZeros(masklist[i])]; + } + + if (child instanceof ListStorage list) { + for (Entry entry : list.list) { + action.accept(entry.object); + } + } else if (child instanceof NodeStorage node) { + node.hit0(ray, vmask, childTmin, childTmax, action); + } + + if (childTmax == tmax) break; + childTmin = childTmax; + } + } + + private double @NotNull [] calculatePlaneIntersections(@NotNull Vec3 position, @NotNull Ray ray) { + return new double[] { + (position.x() - ray.origin().x()) / ray.direction().x(), + (position.y() - ray.origin().y()) / ray.direction().y(), + (position.z() - ray.origin().z()) / ray.direction().z(), + }; + } + + private int @NotNull [] calculateMastlist(double @NotNull[] tmid) { + var masklist = new int[3]; + if (tmid[0] < tmid[1] && tmid[0] < tmid[2]) { + masklist[0] = 1; + if (tmid[1] < tmid[2]) { + masklist[1] = 2; + masklist[2] = 4; + } else { + masklist[1] = 4; + masklist[2] = 2; + } + } else if (tmid[1] < tmid[2]) { + masklist[0] = 2; + if (tmid[0] < tmid[2]) { + masklist[1] = 1; + masklist[2] = 4; + } else { + masklist[1] = 4; + masklist[2] = 1; + } + } else { + masklist[0] = 4; + if (tmid[0] < tmid[1]) { + masklist[1] = 1; + masklist[2] = 2; + } else { + masklist[1] = 2; + masklist[2] = 1; + } + } + return masklist; + } + + @Override + public String toString() { + var out = new StringBuilder("Octree centered on " + center + " with dimension " + dimension + "\n"); + for (int i = 0; i < 8; i++) { + out.append(i == 7 ? "\\- [" : "|- [").append(i).append("]: "); + + var prefix = i == 7 ? " " : "| "; + out.append(Objects.toString(octants[i]).lines().map(str -> prefix + str).collect(Collectors.joining("\n")).substring(8)); + out.append("\n"); + } + return out.toString(); + } + } + + private record Entry(@NotNull BoundingBox bbox, T object) { } +} diff --git a/src/main/java/eu/jonahbauer/raytracing/math/Vec3.java b/src/main/java/eu/jonahbauer/raytracing/math/Vec3.java index 76c8d41..35bc699 100644 --- a/src/main/java/eu/jonahbauer/raytracing/math/Vec3.java +++ b/src/main/java/eu/jonahbauer/raytracing/math/Vec3.java @@ -6,6 +6,8 @@ import java.util.Optional; public record Vec3(double x, double y, double z) { public static final Vec3 ZERO = new Vec3(0, 0, 0); + public static final Vec3 MAX = new Vec3(Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE); + public static final Vec3 MIN = new Vec3(-Double.MAX_VALUE, -Double.MAX_VALUE, -Double.MAX_VALUE); public static final Vec3 UNIT_X = new Vec3(1, 0, 0); public static final Vec3 UNIT_Y = new Vec3(0, 1, 0); public static final Vec3 UNIT_Z = new Vec3(0, 0, 1); @@ -52,6 +54,38 @@ public record Vec3(double x, double y, double z) { return a.minus(b).length(); } + public static @NotNull Vec3 average(@NotNull Vec3 current, @NotNull Vec3 next, int index) { + return new Vec3( + current.x() + (next.x() - current.x()) / index, + current.y() + (next.y() - current.y()) / index, + current.z() + (next.z() - current.z()) / index + ); + } + + public static @NotNull Vec3 max(@NotNull Vec3 a, @NotNull Vec3 b) { + return new Vec3( + Math.max(a.x(), b.x()), + Math.max(a.y(), b.y()), + Math.max(a.z(), b.z()) + ); + } + + public static @NotNull Vec3 min(@NotNull Vec3 a, @NotNull Vec3 b) { + return new Vec3( + Math.min(a.x(), b.x()), + Math.min(a.y(), b.y()), + Math.min(a.z(), b.z()) + ); + } + + public @NotNull Vec3 plus(double x, double y, double z) { + return new Vec3(this.x + x, this.y + y, this.z + z); + } + + public @NotNull Vec3 minus(double x, double y, double z) { + return new Vec3(this.x - x, this.y - y, this.z - z); + } + public @NotNull Vec3 plus(@NotNull Vec3 b) { return new Vec3(this.x + b.x, this.y + b.y, this.z + b.z); } diff --git a/src/main/java/eu/jonahbauer/raytracing/scene/Scene.java b/src/main/java/eu/jonahbauer/raytracing/scene/Scene.java index 7430895..9be9235 100644 --- a/src/main/java/eu/jonahbauer/raytracing/scene/Scene.java +++ b/src/main/java/eu/jonahbauer/raytracing/scene/Scene.java @@ -1,31 +1,80 @@ package eu.jonahbauer.raytracing.scene; +import eu.jonahbauer.raytracing.math.Octree; import eu.jonahbauer.raytracing.math.Range; import eu.jonahbauer.raytracing.math.Ray; +import eu.jonahbauer.raytracing.math.Vec3; import org.jetbrains.annotations.NotNull; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Optional; -public record Scene(@NotNull List<@NotNull Hittable> objects) implements Hittable { +public final class Scene implements Hittable { + private final @NotNull Octree<@NotNull Hittable> octree; + private final @NotNull List<@NotNull Hittable> list; - public Scene { - objects = List.copyOf(objects); - } + public Scene(@NotNull List objects) { + this.octree = newOctree(objects); + this.list = new ArrayList<>(); - public Scene(@NotNull Hittable @NotNull ... objects) { - this(List.of(objects)); - } - - public @NotNull Optional hit(@NotNull Ray ray, @NotNull Range range) { - var result = (HitResult) null; - for (var object : objects) { - var r = object.hit(ray, range); - if (r.isPresent() && range.surrounds(r.get().t())) { - result = r.get(); - range = new Range(range.min(), result.t()); + for (Hittable object : objects) { + var bbox = object.getBoundingBox(); + if (bbox.isPresent()) { + octree.add(bbox.get(), object); + } else { + list.add(object); } } - return Optional.ofNullable(result); + } + + @Override + public @NotNull Optional hit(@NotNull Ray ray, @NotNull Range range) { + var state = new State(); + state.range = range; + + octree.hit(ray, object -> hit(state, ray, object)); + list.forEach(object -> hit(state, ray, object)); + + return Optional.ofNullable(state.result); + } + + private void hit(@NotNull State state, @NotNull Ray ray, @NotNull Hittable object) { + var r = object.hit(ray, state.range); + if (r.isPresent() && state.range.surrounds(r.get().t())) { + state.result = r.get(); + state.range = new Range(state.range.min(), state.result.t()); + } + } + + private static @NotNull Octree newOctree(@NotNull List objects) { + Vec3 center = Vec3.ZERO, max = Vec3.MIN, min = Vec3.MAX; + + int i = 1; + for (Hittable object : objects) { + var bbox = object.getBoundingBox(); + if (bbox.isPresent()) { + center = Vec3.average(center, bbox.get().center(), i++); + max = Vec3.max(max, bbox.get().max()); + min = Vec3.min(min, bbox.get().min()); + } + } + + var dimension = Arrays.stream(new double[] { + Math.abs(max.x() - center.x()), + Math.abs(max.y() - center.y()), + Math.abs(max.z() - center.z()), + Math.abs(min.x() - center.x()), + Math.abs(min.y() - center.y()), + Math.abs(min.z() - center.z()) + }).max().orElse(10); + + return new Octree(center, dimension); + } + + private static class State { + HitResult result; + Range range; } }