add octree to improve rendering performance by reducing the number of ray-sphere-intersection calculations
This commit is contained in:
parent
a84ed5c050
commit
d67f877428
253
src/main/java/eu/jonahbauer/raytracing/math/Octree.java
Normal file
253
src/main/java/eu/jonahbauer/raytracing/math/Octree.java
Normal file
@ -0,0 +1,253 @@
|
||||
package eu.jonahbauer.raytracing.math;
|
||||
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public final class Octree<T> {
|
||||
private final @NotNull NodeStorage<T> storage;
|
||||
|
||||
public Octree(@NotNull Vec3 center, double dimension) {
|
||||
this.storage = new NodeStorage<>(center, dimension);
|
||||
}
|
||||
|
||||
public void add(@NotNull BoundingBox bbox, T object) {
|
||||
storage.add(new Entry<>(bbox, object));
|
||||
}
|
||||
|
||||
/**
|
||||
* Use HERO algorithms to find all elements that could possibly be hit by the given ray.
|
||||
* @see <a href="https://diglib.eg.org/server/api/core/bitstreams/33fe8d58-1101-40ff-878a-79d689a4607d/content">The HERO Algorithm for Ray-Tracing Octrees</a>
|
||||
*/
|
||||
public void hit(@NotNull Ray ray, @NotNull Consumer<T> action) {
|
||||
storage.hit(ray, action);
|
||||
}
|
||||
|
||||
@Override
|
||||
public @NotNull String toString() {
|
||||
return storage.toString();
|
||||
}
|
||||
|
||||
public static int getOctantIndex(@NotNull Vec3 center, @NotNull Vec3 pos) {
|
||||
return (pos.x() < center.x() ? 0 : 1)
|
||||
| (pos.y() < center.y() ? 0 : 2)
|
||||
| (pos.z() < center.z() ? 0 : 4);
|
||||
|
||||
}
|
||||
|
||||
private sealed interface Storage<T> {
|
||||
int LIST_SIZE_LIMIT = 32;
|
||||
|
||||
@NotNull Storage<T> add(@NotNull Entry<T> entry);
|
||||
}
|
||||
|
||||
private static final class ListStorage<T> implements Storage<T> {
|
||||
private final @NotNull Vec3 center;
|
||||
private final double dimension;
|
||||
|
||||
private final @NotNull List<Entry<T>> list = new ArrayList<>();
|
||||
|
||||
public ListStorage(@NotNull Vec3 center, double dimension) {
|
||||
this.center = Objects.requireNonNull(center);
|
||||
this.dimension = dimension;
|
||||
}
|
||||
|
||||
@Override
|
||||
public @NotNull Storage<T> add(@NotNull Entry<T> entry) {
|
||||
if (list.size() >= LIST_SIZE_LIMIT) {
|
||||
var node = new NodeStorage<T>(center, dimension);
|
||||
list.forEach(node::add);
|
||||
node.add(entry);
|
||||
return node;
|
||||
} else {
|
||||
list.add(entry);
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return list.toString();
|
||||
}
|
||||
}
|
||||
|
||||
private static final class NodeStorage<T> implements Storage<T> {
|
||||
private final @NotNull Vec3 center;
|
||||
private final double dimension;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private final @Nullable Storage<T> @NotNull[] octants = new Storage[8];
|
||||
private final @NotNull List<Entry<T>> list = new ArrayList<>(); // track elements spanning multiple octants separately
|
||||
|
||||
public NodeStorage(@NotNull Vec3 center, double dimension) {
|
||||
this.center = Objects.requireNonNull(center);
|
||||
this.dimension = dimension;
|
||||
}
|
||||
|
||||
@Override
|
||||
public @NotNull Storage<T> add(@NotNull Entry<T> entry) {
|
||||
var index = getOctantIndex(center, entry.bbox().min());
|
||||
if (index != getOctantIndex(center, entry.bbox().max())) {
|
||||
list.add(entry);
|
||||
} else {
|
||||
var subnode = octants[index];
|
||||
if (subnode == null) {
|
||||
subnode = newOctant(index);
|
||||
}
|
||||
octants[index] = subnode.add(entry);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
private @NotNull Storage<T> newOctant(int index) {
|
||||
var newSize = 0.5 * dimension;
|
||||
var newCenter = this.center
|
||||
.plus(new Vec3(
|
||||
(index & 1) == 0 ? -newSize : newSize,
|
||||
(index & 2) == 0 ? -newSize : newSize,
|
||||
(index & 4) == 0 ? -newSize : newSize
|
||||
));
|
||||
return new ListStorage<>(newCenter, newSize);
|
||||
}
|
||||
|
||||
public void hit(@NotNull Ray ray, @NotNull Consumer<T> action) {
|
||||
int vmask = (ray.direction().x() < 0 ? 1 : 0)
|
||||
| (ray.direction().y() < 0 ? 2 : 0)
|
||||
| (ray.direction().z() < 0 ? 4 : 0);
|
||||
|
||||
var min = center.minus(dimension, dimension, dimension);
|
||||
var max = center.plus(dimension, dimension, dimension);
|
||||
|
||||
// calculate t values for intersection points of ray with planes through min
|
||||
var tmin = calculatePlaneIntersections(min, ray);
|
||||
// calculate t values for intersection points of ray with planes through max
|
||||
var tmax = calculatePlaneIntersections(max, ray);
|
||||
|
||||
// determine range of t for which the ray is inside this voxel
|
||||
double tlmax = Double.NEGATIVE_INFINITY; // lower limit maximum
|
||||
double tumin = Double.POSITIVE_INFINITY; // upper limit minimum
|
||||
for (int i = 0; i < 3; i++) {
|
||||
// classify t values as lower or upper limit based on vmask
|
||||
if ((vmask & (1 << i)) == 0) {
|
||||
// min is lower limit and max is upper limit
|
||||
tlmax = Math.max(tlmax, tmin[i]);
|
||||
tumin = Math.min(tumin, tmax[i]);
|
||||
} else {
|
||||
// max is lower limit and min is upper limit
|
||||
tlmax = Math.max(tlmax, tmax[i]);
|
||||
tumin = Math.min(tumin, tmin[i]);
|
||||
}
|
||||
}
|
||||
|
||||
var hit = tlmax < tumin;
|
||||
if (!hit) return;
|
||||
|
||||
hit0(ray, vmask, tlmax, tumin, action);
|
||||
}
|
||||
|
||||
private void hit0(@NotNull Ray ray, int vmask, double tmin, double tmax, @NotNull Consumer<T> action) {
|
||||
if (tmax < 0) return;
|
||||
for (Entry<T> entry : list) {
|
||||
action.accept(entry.object());
|
||||
}
|
||||
|
||||
// t values for intersection points of ray with planes through center
|
||||
var tmid = calculatePlaneIntersections(center, ray);
|
||||
// masks of planes in the order of intersection, e.g. [2, 1, 4] for a ray intersection y = center.y() then x = center.x() then z = center.z()
|
||||
var masklist = calculateMastlist(tmid);
|
||||
var childmask = (tmid[0] < tmin ? 1 : 0)
|
||||
| (tmid[1] < tmin ? 2 : 0)
|
||||
| (tmid[2] < tmin ? 4 : 0);
|
||||
var lastmask = (tmid[0] < tmax ? 1 : 0)
|
||||
| (tmid[1] < tmax ? 2 : 0)
|
||||
| (tmid[2] < tmax ? 4 : 0);
|
||||
|
||||
var childTmin = tmin;
|
||||
|
||||
int i = 0;
|
||||
while (true) {
|
||||
var child = octants[childmask ^ vmask];
|
||||
double childTmax;
|
||||
if (childmask == lastmask) {
|
||||
childTmax = tmax;
|
||||
} else {
|
||||
while ((masklist[i] & childmask) != 0) {
|
||||
i++;
|
||||
}
|
||||
childmask = childmask | masklist[i];
|
||||
childTmax = tmid[Integer.numberOfTrailingZeros(masklist[i])];
|
||||
}
|
||||
|
||||
if (child instanceof ListStorage<T> list) {
|
||||
for (Entry<T> entry : list.list) {
|
||||
action.accept(entry.object);
|
||||
}
|
||||
} else if (child instanceof NodeStorage<T> node) {
|
||||
node.hit0(ray, vmask, childTmin, childTmax, action);
|
||||
}
|
||||
|
||||
if (childTmax == tmax) break;
|
||||
childTmin = childTmax;
|
||||
}
|
||||
}
|
||||
|
||||
private double @NotNull [] calculatePlaneIntersections(@NotNull Vec3 position, @NotNull Ray ray) {
|
||||
return new double[] {
|
||||
(position.x() - ray.origin().x()) / ray.direction().x(),
|
||||
(position.y() - ray.origin().y()) / ray.direction().y(),
|
||||
(position.z() - ray.origin().z()) / ray.direction().z(),
|
||||
};
|
||||
}
|
||||
|
||||
private int @NotNull [] calculateMastlist(double @NotNull[] tmid) {
|
||||
var masklist = new int[3];
|
||||
if (tmid[0] < tmid[1] && tmid[0] < tmid[2]) {
|
||||
masklist[0] = 1;
|
||||
if (tmid[1] < tmid[2]) {
|
||||
masklist[1] = 2;
|
||||
masklist[2] = 4;
|
||||
} else {
|
||||
masklist[1] = 4;
|
||||
masklist[2] = 2;
|
||||
}
|
||||
} else if (tmid[1] < tmid[2]) {
|
||||
masklist[0] = 2;
|
||||
if (tmid[0] < tmid[2]) {
|
||||
masklist[1] = 1;
|
||||
masklist[2] = 4;
|
||||
} else {
|
||||
masklist[1] = 4;
|
||||
masklist[2] = 1;
|
||||
}
|
||||
} else {
|
||||
masklist[0] = 4;
|
||||
if (tmid[0] < tmid[1]) {
|
||||
masklist[1] = 1;
|
||||
masklist[2] = 2;
|
||||
} else {
|
||||
masklist[1] = 2;
|
||||
masklist[2] = 1;
|
||||
}
|
||||
}
|
||||
return masklist;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
var out = new StringBuilder("Octree centered on " + center + " with dimension " + dimension + "\n");
|
||||
for (int i = 0; i < 8; i++) {
|
||||
out.append(i == 7 ? "\\- [" : "|- [").append(i).append("]: ");
|
||||
|
||||
var prefix = i == 7 ? " " : "| ";
|
||||
out.append(Objects.toString(octants[i]).lines().map(str -> prefix + str).collect(Collectors.joining("\n")).substring(8));
|
||||
out.append("\n");
|
||||
}
|
||||
return out.toString();
|
||||
}
|
||||
}
|
||||
|
||||
private record Entry<T>(@NotNull BoundingBox bbox, T object) { }
|
||||
}
|
@ -6,6 +6,8 @@ import java.util.Optional;
|
||||
|
||||
public record Vec3(double x, double y, double z) {
|
||||
public static final Vec3 ZERO = new Vec3(0, 0, 0);
|
||||
public static final Vec3 MAX = new Vec3(Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE);
|
||||
public static final Vec3 MIN = new Vec3(-Double.MAX_VALUE, -Double.MAX_VALUE, -Double.MAX_VALUE);
|
||||
public static final Vec3 UNIT_X = new Vec3(1, 0, 0);
|
||||
public static final Vec3 UNIT_Y = new Vec3(0, 1, 0);
|
||||
public static final Vec3 UNIT_Z = new Vec3(0, 0, 1);
|
||||
@ -52,6 +54,38 @@ public record Vec3(double x, double y, double z) {
|
||||
return a.minus(b).length();
|
||||
}
|
||||
|
||||
public static @NotNull Vec3 average(@NotNull Vec3 current, @NotNull Vec3 next, int index) {
|
||||
return new Vec3(
|
||||
current.x() + (next.x() - current.x()) / index,
|
||||
current.y() + (next.y() - current.y()) / index,
|
||||
current.z() + (next.z() - current.z()) / index
|
||||
);
|
||||
}
|
||||
|
||||
public static @NotNull Vec3 max(@NotNull Vec3 a, @NotNull Vec3 b) {
|
||||
return new Vec3(
|
||||
Math.max(a.x(), b.x()),
|
||||
Math.max(a.y(), b.y()),
|
||||
Math.max(a.z(), b.z())
|
||||
);
|
||||
}
|
||||
|
||||
public static @NotNull Vec3 min(@NotNull Vec3 a, @NotNull Vec3 b) {
|
||||
return new Vec3(
|
||||
Math.min(a.x(), b.x()),
|
||||
Math.min(a.y(), b.y()),
|
||||
Math.min(a.z(), b.z())
|
||||
);
|
||||
}
|
||||
|
||||
public @NotNull Vec3 plus(double x, double y, double z) {
|
||||
return new Vec3(this.x + x, this.y + y, this.z + z);
|
||||
}
|
||||
|
||||
public @NotNull Vec3 minus(double x, double y, double z) {
|
||||
return new Vec3(this.x - x, this.y - y, this.z - z);
|
||||
}
|
||||
|
||||
public @NotNull Vec3 plus(@NotNull Vec3 b) {
|
||||
return new Vec3(this.x + b.x, this.y + b.y, this.z + b.z);
|
||||
}
|
||||
|
@ -1,31 +1,80 @@
|
||||
package eu.jonahbauer.raytracing.scene;
|
||||
|
||||
import eu.jonahbauer.raytracing.math.Octree;
|
||||
import eu.jonahbauer.raytracing.math.Range;
|
||||
import eu.jonahbauer.raytracing.math.Ray;
|
||||
import eu.jonahbauer.raytracing.math.Vec3;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
public record Scene(@NotNull List<@NotNull Hittable> objects) implements Hittable {
|
||||
public final class Scene implements Hittable {
|
||||
private final @NotNull Octree<@NotNull Hittable> octree;
|
||||
private final @NotNull List<@NotNull Hittable> list;
|
||||
|
||||
public Scene {
|
||||
objects = List.copyOf(objects);
|
||||
}
|
||||
public Scene(@NotNull List<? extends @NotNull Hittable> objects) {
|
||||
this.octree = newOctree(objects);
|
||||
this.list = new ArrayList<>();
|
||||
|
||||
public Scene(@NotNull Hittable @NotNull ... objects) {
|
||||
this(List.of(objects));
|
||||
}
|
||||
|
||||
public @NotNull Optional<HitResult> hit(@NotNull Ray ray, @NotNull Range range) {
|
||||
var result = (HitResult) null;
|
||||
for (var object : objects) {
|
||||
var r = object.hit(ray, range);
|
||||
if (r.isPresent() && range.surrounds(r.get().t())) {
|
||||
result = r.get();
|
||||
range = new Range(range.min(), result.t());
|
||||
for (Hittable object : objects) {
|
||||
var bbox = object.getBoundingBox();
|
||||
if (bbox.isPresent()) {
|
||||
octree.add(bbox.get(), object);
|
||||
} else {
|
||||
list.add(object);
|
||||
}
|
||||
}
|
||||
return Optional.ofNullable(result);
|
||||
}
|
||||
|
||||
@Override
|
||||
public @NotNull Optional<HitResult> hit(@NotNull Ray ray, @NotNull Range range) {
|
||||
var state = new State();
|
||||
state.range = range;
|
||||
|
||||
octree.hit(ray, object -> hit(state, ray, object));
|
||||
list.forEach(object -> hit(state, ray, object));
|
||||
|
||||
return Optional.ofNullable(state.result);
|
||||
}
|
||||
|
||||
private void hit(@NotNull State state, @NotNull Ray ray, @NotNull Hittable object) {
|
||||
var r = object.hit(ray, state.range);
|
||||
if (r.isPresent() && state.range.surrounds(r.get().t())) {
|
||||
state.result = r.get();
|
||||
state.range = new Range(state.range.min(), state.result.t());
|
||||
}
|
||||
}
|
||||
|
||||
private static @NotNull Octree<Hittable> newOctree(@NotNull List<? extends Hittable> objects) {
|
||||
Vec3 center = Vec3.ZERO, max = Vec3.MIN, min = Vec3.MAX;
|
||||
|
||||
int i = 1;
|
||||
for (Hittable object : objects) {
|
||||
var bbox = object.getBoundingBox();
|
||||
if (bbox.isPresent()) {
|
||||
center = Vec3.average(center, bbox.get().center(), i++);
|
||||
max = Vec3.max(max, bbox.get().max());
|
||||
min = Vec3.min(min, bbox.get().min());
|
||||
}
|
||||
}
|
||||
|
||||
var dimension = Arrays.stream(new double[] {
|
||||
Math.abs(max.x() - center.x()),
|
||||
Math.abs(max.y() - center.y()),
|
||||
Math.abs(max.z() - center.z()),
|
||||
Math.abs(min.x() - center.x()),
|
||||
Math.abs(min.y() - center.y()),
|
||||
Math.abs(min.z() - center.z())
|
||||
}).max().orElse(10);
|
||||
|
||||
return new Octree<Hittable>(center, dimension);
|
||||
}
|
||||
|
||||
private static class State {
|
||||
HitResult result;
|
||||
Range range;
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user