add octree to improve rendering performance by reducing the number of ray-sphere-intersection calculations

This commit is contained in:
jbb01 2024-08-04 23:57:54 +02:00
parent a84ed5c050
commit d67f877428
3 changed files with 352 additions and 16 deletions

View File

@ -0,0 +1,253 @@
package eu.jonahbauer.raytracing.math;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.*;
import java.util.function.Consumer;
import java.util.stream.Collectors;
public final class Octree<T> {
private final @NotNull NodeStorage<T> storage;
public Octree(@NotNull Vec3 center, double dimension) {
this.storage = new NodeStorage<>(center, dimension);
}
public void add(@NotNull BoundingBox bbox, T object) {
storage.add(new Entry<>(bbox, object));
}
/**
* Use HERO algorithms to find all elements that could possibly be hit by the given ray.
* @see <a href="https://diglib.eg.org/server/api/core/bitstreams/33fe8d58-1101-40ff-878a-79d689a4607d/content">The HERO Algorithm for Ray-Tracing Octrees</a>
*/
public void hit(@NotNull Ray ray, @NotNull Consumer<T> action) {
storage.hit(ray, action);
}
@Override
public @NotNull String toString() {
return storage.toString();
}
public static int getOctantIndex(@NotNull Vec3 center, @NotNull Vec3 pos) {
return (pos.x() < center.x() ? 0 : 1)
| (pos.y() < center.y() ? 0 : 2)
| (pos.z() < center.z() ? 0 : 4);
}
private sealed interface Storage<T> {
int LIST_SIZE_LIMIT = 32;
@NotNull Storage<T> add(@NotNull Entry<T> entry);
}
private static final class ListStorage<T> implements Storage<T> {
private final @NotNull Vec3 center;
private final double dimension;
private final @NotNull List<Entry<T>> list = new ArrayList<>();
public ListStorage(@NotNull Vec3 center, double dimension) {
this.center = Objects.requireNonNull(center);
this.dimension = dimension;
}
@Override
public @NotNull Storage<T> add(@NotNull Entry<T> entry) {
if (list.size() >= LIST_SIZE_LIMIT) {
var node = new NodeStorage<T>(center, dimension);
list.forEach(node::add);
node.add(entry);
return node;
} else {
list.add(entry);
return this;
}
}
@Override
public String toString() {
return list.toString();
}
}
private static final class NodeStorage<T> implements Storage<T> {
private final @NotNull Vec3 center;
private final double dimension;
@SuppressWarnings("unchecked")
private final @Nullable Storage<T> @NotNull[] octants = new Storage[8];
private final @NotNull List<Entry<T>> list = new ArrayList<>(); // track elements spanning multiple octants separately
public NodeStorage(@NotNull Vec3 center, double dimension) {
this.center = Objects.requireNonNull(center);
this.dimension = dimension;
}
@Override
public @NotNull Storage<T> add(@NotNull Entry<T> entry) {
var index = getOctantIndex(center, entry.bbox().min());
if (index != getOctantIndex(center, entry.bbox().max())) {
list.add(entry);
} else {
var subnode = octants[index];
if (subnode == null) {
subnode = newOctant(index);
}
octants[index] = subnode.add(entry);
}
return this;
}
private @NotNull Storage<T> newOctant(int index) {
var newSize = 0.5 * dimension;
var newCenter = this.center
.plus(new Vec3(
(index & 1) == 0 ? -newSize : newSize,
(index & 2) == 0 ? -newSize : newSize,
(index & 4) == 0 ? -newSize : newSize
));
return new ListStorage<>(newCenter, newSize);
}
public void hit(@NotNull Ray ray, @NotNull Consumer<T> action) {
int vmask = (ray.direction().x() < 0 ? 1 : 0)
| (ray.direction().y() < 0 ? 2 : 0)
| (ray.direction().z() < 0 ? 4 : 0);
var min = center.minus(dimension, dimension, dimension);
var max = center.plus(dimension, dimension, dimension);
// calculate t values for intersection points of ray with planes through min
var tmin = calculatePlaneIntersections(min, ray);
// calculate t values for intersection points of ray with planes through max
var tmax = calculatePlaneIntersections(max, ray);
// determine range of t for which the ray is inside this voxel
double tlmax = Double.NEGATIVE_INFINITY; // lower limit maximum
double tumin = Double.POSITIVE_INFINITY; // upper limit minimum
for (int i = 0; i < 3; i++) {
// classify t values as lower or upper limit based on vmask
if ((vmask & (1 << i)) == 0) {
// min is lower limit and max is upper limit
tlmax = Math.max(tlmax, tmin[i]);
tumin = Math.min(tumin, tmax[i]);
} else {
// max is lower limit and min is upper limit
tlmax = Math.max(tlmax, tmax[i]);
tumin = Math.min(tumin, tmin[i]);
}
}
var hit = tlmax < tumin;
if (!hit) return;
hit0(ray, vmask, tlmax, tumin, action);
}
private void hit0(@NotNull Ray ray, int vmask, double tmin, double tmax, @NotNull Consumer<T> action) {
if (tmax < 0) return;
for (Entry<T> entry : list) {
action.accept(entry.object());
}
// t values for intersection points of ray with planes through center
var tmid = calculatePlaneIntersections(center, ray);
// masks of planes in the order of intersection, e.g. [2, 1, 4] for a ray intersection y = center.y() then x = center.x() then z = center.z()
var masklist = calculateMastlist(tmid);
var childmask = (tmid[0] < tmin ? 1 : 0)
| (tmid[1] < tmin ? 2 : 0)
| (tmid[2] < tmin ? 4 : 0);
var lastmask = (tmid[0] < tmax ? 1 : 0)
| (tmid[1] < tmax ? 2 : 0)
| (tmid[2] < tmax ? 4 : 0);
var childTmin = tmin;
int i = 0;
while (true) {
var child = octants[childmask ^ vmask];
double childTmax;
if (childmask == lastmask) {
childTmax = tmax;
} else {
while ((masklist[i] & childmask) != 0) {
i++;
}
childmask = childmask | masklist[i];
childTmax = tmid[Integer.numberOfTrailingZeros(masklist[i])];
}
if (child instanceof ListStorage<T> list) {
for (Entry<T> entry : list.list) {
action.accept(entry.object);
}
} else if (child instanceof NodeStorage<T> node) {
node.hit0(ray, vmask, childTmin, childTmax, action);
}
if (childTmax == tmax) break;
childTmin = childTmax;
}
}
private double @NotNull [] calculatePlaneIntersections(@NotNull Vec3 position, @NotNull Ray ray) {
return new double[] {
(position.x() - ray.origin().x()) / ray.direction().x(),
(position.y() - ray.origin().y()) / ray.direction().y(),
(position.z() - ray.origin().z()) / ray.direction().z(),
};
}
private int @NotNull [] calculateMastlist(double @NotNull[] tmid) {
var masklist = new int[3];
if (tmid[0] < tmid[1] && tmid[0] < tmid[2]) {
masklist[0] = 1;
if (tmid[1] < tmid[2]) {
masklist[1] = 2;
masklist[2] = 4;
} else {
masklist[1] = 4;
masklist[2] = 2;
}
} else if (tmid[1] < tmid[2]) {
masklist[0] = 2;
if (tmid[0] < tmid[2]) {
masklist[1] = 1;
masklist[2] = 4;
} else {
masklist[1] = 4;
masklist[2] = 1;
}
} else {
masklist[0] = 4;
if (tmid[0] < tmid[1]) {
masklist[1] = 1;
masklist[2] = 2;
} else {
masklist[1] = 2;
masklist[2] = 1;
}
}
return masklist;
}
@Override
public String toString() {
var out = new StringBuilder("Octree centered on " + center + " with dimension " + dimension + "\n");
for (int i = 0; i < 8; i++) {
out.append(i == 7 ? "\\- [" : "|- [").append(i).append("]: ");
var prefix = i == 7 ? " " : "| ";
out.append(Objects.toString(octants[i]).lines().map(str -> prefix + str).collect(Collectors.joining("\n")).substring(8));
out.append("\n");
}
return out.toString();
}
}
private record Entry<T>(@NotNull BoundingBox bbox, T object) { }
}

View File

@ -6,6 +6,8 @@ import java.util.Optional;
public record Vec3(double x, double y, double z) {
public static final Vec3 ZERO = new Vec3(0, 0, 0);
public static final Vec3 MAX = new Vec3(Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE);
public static final Vec3 MIN = new Vec3(-Double.MAX_VALUE, -Double.MAX_VALUE, -Double.MAX_VALUE);
public static final Vec3 UNIT_X = new Vec3(1, 0, 0);
public static final Vec3 UNIT_Y = new Vec3(0, 1, 0);
public static final Vec3 UNIT_Z = new Vec3(0, 0, 1);
@ -52,6 +54,38 @@ public record Vec3(double x, double y, double z) {
return a.minus(b).length();
}
public static @NotNull Vec3 average(@NotNull Vec3 current, @NotNull Vec3 next, int index) {
return new Vec3(
current.x() + (next.x() - current.x()) / index,
current.y() + (next.y() - current.y()) / index,
current.z() + (next.z() - current.z()) / index
);
}
public static @NotNull Vec3 max(@NotNull Vec3 a, @NotNull Vec3 b) {
return new Vec3(
Math.max(a.x(), b.x()),
Math.max(a.y(), b.y()),
Math.max(a.z(), b.z())
);
}
public static @NotNull Vec3 min(@NotNull Vec3 a, @NotNull Vec3 b) {
return new Vec3(
Math.min(a.x(), b.x()),
Math.min(a.y(), b.y()),
Math.min(a.z(), b.z())
);
}
public @NotNull Vec3 plus(double x, double y, double z) {
return new Vec3(this.x + x, this.y + y, this.z + z);
}
public @NotNull Vec3 minus(double x, double y, double z) {
return new Vec3(this.x - x, this.y - y, this.z - z);
}
public @NotNull Vec3 plus(@NotNull Vec3 b) {
return new Vec3(this.x + b.x, this.y + b.y, this.z + b.z);
}

View File

@ -1,31 +1,80 @@
package eu.jonahbauer.raytracing.scene;
import eu.jonahbauer.raytracing.math.Octree;
import eu.jonahbauer.raytracing.math.Range;
import eu.jonahbauer.raytracing.math.Ray;
import eu.jonahbauer.raytracing.math.Vec3;
import org.jetbrains.annotations.NotNull;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
public record Scene(@NotNull List<@NotNull Hittable> objects) implements Hittable {
public final class Scene implements Hittable {
private final @NotNull Octree<@NotNull Hittable> octree;
private final @NotNull List<@NotNull Hittable> list;
public Scene {
objects = List.copyOf(objects);
}
public Scene(@NotNull Hittable @NotNull ... objects) {
this(List.of(objects));
public Scene(@NotNull List<? extends @NotNull Hittable> objects) {
this.octree = newOctree(objects);
this.list = new ArrayList<>();
for (Hittable object : objects) {
var bbox = object.getBoundingBox();
if (bbox.isPresent()) {
octree.add(bbox.get(), object);
} else {
list.add(object);
}
}
}
@Override
public @NotNull Optional<HitResult> hit(@NotNull Ray ray, @NotNull Range range) {
var result = (HitResult) null;
for (var object : objects) {
var r = object.hit(ray, range);
if (r.isPresent() && range.surrounds(r.get().t())) {
result = r.get();
range = new Range(range.min(), result.t());
var state = new State();
state.range = range;
octree.hit(ray, object -> hit(state, ray, object));
list.forEach(object -> hit(state, ray, object));
return Optional.ofNullable(state.result);
}
private void hit(@NotNull State state, @NotNull Ray ray, @NotNull Hittable object) {
var r = object.hit(ray, state.range);
if (r.isPresent() && state.range.surrounds(r.get().t())) {
state.result = r.get();
state.range = new Range(state.range.min(), state.result.t());
}
}
return Optional.ofNullable(result);
private static @NotNull Octree<Hittable> newOctree(@NotNull List<? extends Hittable> objects) {
Vec3 center = Vec3.ZERO, max = Vec3.MIN, min = Vec3.MAX;
int i = 1;
for (Hittable object : objects) {
var bbox = object.getBoundingBox();
if (bbox.isPresent()) {
center = Vec3.average(center, bbox.get().center(), i++);
max = Vec3.max(max, bbox.get().max());
min = Vec3.min(min, bbox.get().min());
}
}
var dimension = Arrays.stream(new double[] {
Math.abs(max.x() - center.x()),
Math.abs(max.y() - center.y()),
Math.abs(max.z() - center.z()),
Math.abs(min.x() - center.x()),
Math.abs(min.y() - center.y()),
Math.abs(min.z() - center.z())
}).max().orElse(10);
return new Octree<Hittable>(center, dimension);
}
private static class State {
HitResult result;
Range range;
}
}