// File listing metadata: 260 lines, 10 KiB, Java

package eu.jonahbauer.raytracing.math;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.*;
import java.util.function.Predicate;
public final class Octree<T> {
/** Root of the tree; it is an inner node from the start, so the top level is always subdivided into octants. */
private final @NotNull NodeStorage<T> storage;

/**
 * Creates an empty octree.
 *
 * @param center the center of the root voxel
 * @param dimension the half edge length of the root voxel, i.e. the root voxel spans
 *                  {@code center - dimension} to {@code center + dimension} on every axis
 */
public Octree(final @NotNull Vec3 center, final double dimension) {
    this.storage = new NodeStorage<>(center, dimension);
}
/**
 * Inserts an object into the tree.
 *
 * @param bbox the bounding box that decides in which octant the object is stored
 * @param object the payload that is later handed to the predicate passed to {@code hit}
 */
public void add(final @NotNull BoundingBox bbox, final T object) {
    final var entry = new Entry<>(bbox, object);
    storage.add(entry);
}
/**
 * Uses the HERO algorithm to visit all elements that could possibly be hit by the given ray.
 * <p>
 * The tree only uses the ray to select voxels; the bounding boxes of the stored elements are not
 * tested against the ray, so {@code action} has to perform the actual intersection test and return
 * whether the element was hit. Once a child voxel reports a hit, the voxels behind it are not visited.
 *
 * @param ray the ray to trace through the tree
 * @param action the test that is invoked with every candidate element
 * @see <a href="https://doi.org/10.1007/978-3-642-76298-7_3">
 *     Agate, M., Grimsdale, R.L., Lister, P.F. (1991).
 *     The HERO Algorithm for Ray-Tracing Octrees.
 *     In: Grimsdale, R.L., Straßer, W. (eds) Advances in Computer Graphics Hardware IV. Eurographic Seminars. Springer, Berlin, Heidelberg.</a>
 */
public void hit(final @NotNull Ray ray, final @NotNull Predicate<T> action) {
    // the root reports whether anything was hit; callers of the tree do not need that result
    storage.hit(ray, action);
}
/**
 * Determines the octant of {@code pos} relative to {@code center}.
 * <p>
 * Bit 0 (value 1) describes the x axis, bit 1 (value 2) the y axis and bit 2 (value 4) the z axis.
 * A bit is cleared if the coordinate of {@code pos} is strictly less than that of {@code center}
 * and set otherwise, so a position lying exactly on a dividing plane belongs to the upper octant.
 *
 * @param center the point at which the three dividing planes meet
 * @param pos the position to classify
 * @return the octant index in the range {@code 0..7}
 */
public static int getOctantIndex(@NotNull Vec3 center, @NotNull Vec3 pos) {
    int octant = 0;
    // written as "not less than" so that the result for unordered (NaN) comparisons stays a set bit
    if (!(pos.x() < center.x())) {
        octant |= 1;
    }
    if (!(pos.y() < center.y())) {
        octant |= 2;
    }
    if (!(pos.z() < center.z())) {
        octant |= 4;
    }
    return octant;
}
/**
 * Common base of the tree's nodes. Every node describes a cubic voxel given by its center and its
 * half edge length.
 */
private static sealed abstract class Storage<T> {
    /** Number of entries a list node accepts before {@code add} replaces it with an inner node. */
    protected static final int LIST_SIZE_LIMIT = 32;

    /** Center of the voxel covered by this node. */
    protected final @NotNull Vec3 center;
    /** Half edge length of the voxel covered by this node. */
    protected final double dimension;

    /**
     * @param center the center of the voxel, must not be {@code null}
     * @param dimension the half edge length of the voxel
     * @throws NullPointerException if {@code center} is {@code null}
     */
    public Storage(final @NotNull Vec3 center, final double dimension) {
        this.center = Objects.requireNonNull(center);
        this.dimension = dimension;
    }

    /**
     * Stores the entry in this node.
     *
     * @return the node that has to take this node's place in the parent; this may be a different
     *         instance than the receiver
     */
    public abstract @NotNull Storage<T> add(@NotNull Entry<T> entry);

    /**
     * Offers the candidate elements of this node to {@code action}.
     *
     * @return {@code true} if {@code action} returned {@code true} for at least one element
     */
    protected abstract boolean hit(@NotNull Ray ray, @NotNull Predicate<T> action);

    /**
     * Variant of {@link #hit(Ray, Predicate)} used while descending the tree. The default
     * implementation ignores {@code vmask}, {@code tmin} and {@code tmax} and simply delegates
     * to {@link #hit(Ray, Predicate)}.
     */
    protected boolean hit0(
            final @NotNull Ray ray,
            final int vmask,
            final double tmin,
            final double tmax,
            final @NotNull Predicate<T> action
    ) {
        return hit(ray, action);
    }
}
/**
 * Leaf node that keeps its entries in a flat list. Once the list holds {@link #LIST_SIZE_LIMIT}
 * entries, the next insertion replaces the leaf with a {@link NodeStorage} covering the same voxel.
 */
private static final class ListStorage<T> extends Storage<T> {
    private final @NotNull List<Entry<T>> list = new ArrayList<>();

    public ListStorage(@NotNull Vec3 center, double dimension) {
        super(center, dimension);
    }

    /**
     * Adds the entry to the list or, if the list is full, converts this leaf into an inner node
     * that receives all previous entries plus the new one.
     * <p>
     * The conversion is skipped when the voxel cannot be subdivided any further. Without that
     * guard, entries that never end up in different octants (e.g. many copies of the same
     * degenerate bounding box, or many objects lying outside of the tree's bounds in the same
     * direction) would all be forwarded to the same child list, which would immediately split
     * again, recursing until the stack overflows.
     *
     * @return this leaf, or the inner node replacing it
     */
    @Override
    public @NotNull Storage<T> add(@NotNull Entry<T> entry) {
        if (list.size() < LIST_SIZE_LIMIT || !canSubdivide()) {
            list.add(entry);
            return this;
        }

        var node = new NodeStorage<T>(center, dimension);
        list.forEach(node::add);
        node.add(entry);
        return node;
    }

    /**
     * Checks whether splitting this voxel makes progress, i.e. whether at least one coordinate of a
     * child center ({@code center ± 0.5 * dimension}) differs from this voxel's center.
     */
    private boolean canSubdivide() {
        var half = 0.5 * dimension;
        return Double.isFinite(half)
                && (shifts(center.x(), half) || shifts(center.y(), half) || shifts(center.z(), half));
    }

    /**
     * Returns whether adding or subtracting {@code offset} changes the finite value {@code coordinate}.
     * Non-finite coordinates never count as shifted, because they cannot converge to anything.
     */
    private static boolean shifts(double coordinate, double offset) {
        return Double.isFinite(coordinate)
                && (coordinate + offset != coordinate || coordinate - offset != coordinate);
    }

    /**
     * Offers every entry of this leaf to {@code action}. All entries are tested, even after
     * {@code action} has already reported a hit.
     *
     * @return {@code true} if {@code action} returned {@code true} for at least one entry
     */
    @Override
    protected boolean hit(@NotNull Ray ray, @NotNull Predicate<T> action) {
        var hit = false;
        for (Entry<T> entry : list) {
            hit |= action.test(entry.object());
        }
        return hit;
    }
}
private static final class NodeStorage<T> extends Storage<T> {
/** The eight children, indexed by octant index; a slot stays {@code null} until an entry is stored in that octant. */
private final @Nullable Storage<T> @NotNull[] octants;
/** Entries whose bounding box reaches into more than one octant; they are kept on this level. */
private final @NotNull List<Entry<T>> list;

/**
 * Creates an inner node without children and without entries.
 *
 * @param center the center of the voxel
 * @param dimension the half edge length of the voxel
 */
@SuppressWarnings("unchecked") // generic array creation; the array is only ever filled with Storage<T>
public NodeStorage(@NotNull Vec3 center, double dimension) {
    super(center, dimension);
    this.octants = new Storage[8];
    this.list = new ArrayList<>();
}
/**
 * Stores the entry either in the child octant that contains both corners of its bounding box or,
 * if the corners fall into different octants, in this node's own list.
 *
 * @return always this node
 */
@Override
public @NotNull Storage<T> add(@NotNull Entry<T> entry) {
    var bounds = entry.bbox();
    var lowerOctant = getOctantIndex(center, bounds.min());
    var upperOctant = getOctantIndex(center, bounds.max());

    if (lowerOctant != upperOctant) {
        // the box crosses at least one dividing plane and cannot be pushed down
        list.add(entry);
        return this;
    }

    var child = octants[lowerOctant];
    if (child == null) {
        child = newOctant(lowerOctant);
    }
    // the child may replace itself (a full list turns into a node), so store whatever it returns
    octants[lowerOctant] = child.add(entry);
    return this;
}
/**
 * Creates the (initially empty) child for the given octant. The child has half the dimension of
 * this node and its center is offset from this node's center by that half dimension along every
 * axis, towards the side selected by the corresponding bit of {@code index}.
 *
 * @param index the octant index as produced by {@code getOctantIndex}
 * @return a new list node covering the octant
 */
private @NotNull Storage<T> newOctant(int index) {
    final double half = 0.5 * dimension;
    final double dx = (index & 1) != 0 ? half : -half;
    final double dy = (index & 2) != 0 ? half : -half;
    final double dz = (index & 4) != 0 ? half : -half;
    var childCenter = this.center.plus(new Vec3(dx, dy, dz));
    return new ListStorage<>(childCenter, half);
}
/**
 * Entry point of the traversal: intersects the ray with this node's voxel and, if the ray passes
 * through it, walks the children via {@code hit0}.
 *
 * @return {@code false} if the ray misses the voxel, otherwise the result of {@code hit0}
 */
@Override
protected boolean hit(@NotNull Ray ray, @NotNull Predicate<T> action) {
    // bit i of vmask is set if the ray travels in negative direction along axis i
    var direction = ray.direction();
    int vmask = 0;
    if (direction.x() < 0) vmask |= 1;
    if (direction.y() < 0) vmask |= 2;
    if (direction.z() < 0) vmask |= 4;

    // ray parameters at which the ray crosses the three planes through the lower / upper corner
    var tLower = calculatePlaneIntersections(center.minus(dimension, dimension, dimension), ray);
    var tUpper = calculatePlaneIntersections(center.plus(dimension, dimension, dimension), ray);

    // the ray is inside the voxel between the latest entry and the earliest exit over all axes
    double entry = Double.NEGATIVE_INFINITY;
    double exit = Double.POSITIVE_INFINITY;
    for (int axis = 0; axis < 3; axis++) {
        // for a negative direction the upper plane is crossed first
        boolean reversed = (vmask & (1 << axis)) != 0;
        entry = Math.max(entry, reversed ? tUpper[axis] : tLower[axis]);
        exit = Math.min(exit, reversed ? tLower[axis] : tUpper[axis]);
    }

    if (!(entry < exit)) {
        return false;
    }
    return hit0(ray, vmask, entry, exit, action);
}
/**
 * Tests the entries stored on this node and then visits the children in the order in which the
 * ray passes through them, stopping after the first child that reports a hit.
 *
 * @param ray the ray being traced
 * @param vmask bit i is set if the ray travels in negative direction along axis i (as computed by
 *              {@code hit}); it is used to mirror the traversal, which otherwise assumes a ray
 *              with positive direction on all axes
 * @param tmin ray parameter at which the ray enters this voxel
 * @param tmax ray parameter at which the ray leaves this voxel
 * @param action the test invoked with every candidate element
 * @return {@code true} if {@code action} returned {@code true} for at least one element
 */
@Override
protected boolean hit0(@NotNull Ray ray, int vmask, double tmin, double tmax, @NotNull Predicate<T> action) {
    // the voxel lies completely behind the ray's origin
    if (tmax < 0) return false;

    var hit = false;

    // entries spanning multiple children cannot be attributed to a single octant and are always tested
    for (Entry<T> entry : list) {
        hit |= action.test(entry.object());
    }

    // t values for intersection points of ray with planes through center
    var tmid = calculatePlaneIntersections(center, ray);

    // A ray that runs exactly inside one of the center planes (direction component 0 and origin
    // coordinate equal to the center coordinate) yields 0.0 / 0.0 = NaN for that plane. NaN
    // compares false against everything, which used to corrupt the plane ordering below: the NaN
    // plane could be ordered before planes that are actually crossed, childmask then never
    // matched lastmask and the traversal ran past the end of the mask list.
    // Such a ray never crosses the plane, so treat the plane as "already crossed" (-infinity).
    // As vmask is not set for a direction component of 0, this selects the upper octants, which
    // is where getOctantIndex places positions lying exactly on the plane.
    for (int axis = 0; axis < tmid.length; axis++) {
        if (Double.isNaN(tmid[axis])) {
            tmid[axis] = Double.NEGATIVE_INFINITY;
        }
    }

    // masks of planes in the order of intersection, e.g. [2, 1, 4] for a ray intersecting
    // y = center.y() then x = center.x() then z = center.z()
    var masklist = calculateMasklist(tmid);

    // the first child to be hit by the ray assuming a ray with positive x, y and z direction:
    // every plane crossed before entering the voxel puts the ray into the upper half of that axis
    var childmask = (tmid[0] < tmin ? 1 : 0)
            | (tmid[1] < tmin ? 2 : 0)
            | (tmid[2] < tmin ? 4 : 0);

    // the last child to be hit by the ray assuming a ray with positive x, y and z direction
    var lastmask = (tmid[0] < tmax ? 1 : 0)
            | (tmid[1] < tmax ? 2 : 0)
            | (tmid[2] < tmax ? 4 : 0);

    var childTmin = tmin;
    int i = 0;
    while (true) {
        // use vmask to nullify the assumption of a positive ray made for childmask
        var child = octants[childmask ^ vmask];

        // calculate t value for exit of child
        double childTmax;
        if (childmask == lastmask) {
            // last child shares tmax
            childTmax = tmax;
        } else {
            // determine next child: skip the planes that have already been crossed
            while ((masklist[i] & childmask) != 0) {
                i++;
            }
            childmask = childmask | masklist[i];

            // tmax of current child is the t value for the intersection with the plane dividing
            // the current and next child
            childTmax = tmid[Integer.numberOfTrailingZeros(masklist[i])];
        }

        // process child (empty octants have no storage)
        var childHit = child != null && child.hit0(ray, vmask, childTmin, childTmax, action);
        hit |= childHit;

        // break after last child has been processed or a hit has been found
        if (childTmax == tmax || childHit) break;

        // tmin of next child is tmax of current child
        childTmin = childTmax;
    }

    return hit;
}
/**
 * Computes, per axis, the ray parameter {@code t} at which the ray crosses the axis-aligned plane
 * through {@code position}, i.e. {@code (position - origin) / direction} component-wise.
 * <p>
 * No special handling is done for direction components of zero, so the result may contain
 * infinities or NaN.
 *
 * @return a new array holding the t values for the x, y and z plane, in that order
 */
private double @NotNull [] calculatePlaneIntersections(@NotNull Vec3 position, @NotNull Ray ray) {
    var origin = ray.origin();
    var direction = ray.direction();

    double tx = (position.x() - origin.x()) / direction.x();
    double ty = (position.y() - origin.y()) / direction.y();
    double tz = (position.z() - origin.z()) / direction.z();

    return new double[] {tx, ty, tz};
}
/**
 * The six possible orders in which the three center planes can be crossed, expressed as axis
 * masks (1 = x, 2 = y, 4 = z). The rows are shared and must not be modified.
 */
private static final int[][] MASKLISTS = new int[][] {
        {1, 2, 4},
        {1, 4, 2},
        {4, 1, 2},
        {2, 1, 4},
        {2, 4, 1},
        {4, 2, 1}
};

/**
 * Selects the row of {@link #MASKLISTS} that lists the axis masks ordered by ascending
 * {@code tmid} value.
 * <p>
 * Only strict {@code <} comparisons are used. A NaN value therefore compares as "not smaller"
 * and is not guaranteed to be ordered last.
 *
 * @param tmid the t values for the x, y and z plane, in that order
 * @return one of the shared rows of {@link #MASKLISTS}
 */
private static int @NotNull [] calculateMasklist(double @NotNull[] tmid) {
    final boolean xBeforeY = tmid[0] < tmid[1];
    final boolean yBeforeZ = tmid[1] < tmid[2];
    final boolean xBeforeZ = tmid[0] < tmid[2];

    final int row;
    if (xBeforeY) {
        // x precedes y; find the place of z
        row = yBeforeZ ? 0      // {1, 2, 4}
                : xBeforeZ ? 1  // {1, 4, 2}
                : 2;            // {4, 1, 2}
    } else {
        // y precedes (or ties with) x; find the place of z
        row = xBeforeZ ? 3      // {2, 1, 4}
                : yBeforeZ ? 4  // {2, 4, 1}
                : 5;            // {4, 2, 1}
    }
    return MASKLISTS[row];
}
}
/**
 * An element of the tree: the stored object together with the bounding box that determines in
 * which octant it is placed.
 *
 * @param bbox the bounding box of the object
 * @param object the stored object
 */
private record Entry<T>(@NotNull BoundingBox bbox, T object) {
}
}