Skip to content

Commit

Permalink
Merge pull request #2 from tjs20011/raytrace-update
Browse files Browse the repository at this point in the history
Raytrace update
  • Loading branch information
jrb20008 authored Mar 3, 2025
2 parents 3ffe3b9 + 446bc2d commit 4840eeb
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 35 deletions.
2 changes: 1 addition & 1 deletion ExampleMultipleSonarReadings.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
"readings:list[ScanReading] = []\n",
"for n,rays in enumerate(rays_list):\n",
" print(n + 1, len(rays_list), sep='/', end='\\r')\n",
" readings.append(SideScan(mesh).scan_rays(rays))\n",
" readings.append(ScanReading(SideScan(mesh).scan_rays(rays)))\n",
"\n",
"print()\n",
"\n",
Expand Down
9 changes: 5 additions & 4 deletions ExampleSingleSonarReading.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"source": [
"from Raytrace.TriangleMesh import Ray\n",
"from Raytrace.BVHMesh import BVHMesh\n",
"from Raytrace.SideScan import SideScan\n",
"from Raytrace.SideScan import SideScan, ScanReading\n",
"import numpy as np\n",
"import time"
]
Expand Down Expand Up @@ -43,10 +43,11 @@
"\n",
"rays = SideScan.generate_rays(orientation, min_angle, max_angle, sample_ray_count)\n",
"\n",
"reading = SideScan(mesh).scan_rays(rays)\n",
"raw_reading = SideScan(mesh).scan_rays(rays)\n",
"reading = ScanReading(raw_reading)\n",
"\n",
"print('Triangles:', mesh.triangles.shape[0])\n",
"reading.print_summary()"
"raw_reading.print_summary()"
]
},
{
Expand All @@ -56,7 +57,7 @@
"outputs": [],
"source": [
"import plotly.graph_objs as go\n",
"from PlotRays import plot_mesh, plot_rays, plot_reading"
"from Raytrace.PlotRays import plot_mesh, plot_rays, plot_reading"
]
},
{
Expand Down
11 changes: 11 additions & 0 deletions Raytrace/CompositeMesh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import numpy as np
from Raytrace.TriangleMesh import TriangleMesh, Ray

class CompositeMesh(TriangleMesh):
meshes: list[TriangleMesh]
def __init__(self, meshes:list[TriangleMesh]) -> None:
self.meshes = meshes
self.triangles = np.concatenate([mesh.triangles for mesh in self.meshes])
def raytrace(self, ray:Ray) -> float:
distances = [mesh.raytrace(ray) for mesh in self.meshes]
return min(distances, default = np.inf)
18 changes: 12 additions & 6 deletions Raytrace/PlotRays.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from Raytrace.TriangleMesh import TriangleMesh
from Raytrace.SideScan import ScanReading
from Raytrace.SideScan import ScanReading, RawScanReading
import numpy as np
try:
import plotly.graph_objs as go # type: ignore
Expand All @@ -12,7 +12,10 @@ def plot_mesh(mesh:TriangleMesh, **kwargs) -> go.Mesh3d:
i, j, k = [list(range(i,x.shape[0],3)) for i in range(3)]
return go.Mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, **kwargs)

def plot_rays(reading:ScanReading, **kwargs) -> go.Scatter3d:
def plot_rays(reading:ScanReading|RawScanReading, **kwargs) -> go.Scatter3d:
if isinstance(reading, ScanReading):
return plot_rays(reading.raw, **kwargs)

origin = reading.origins[reading.finite]
inters = reading.intersections[reading.finite]

Expand All @@ -32,17 +35,20 @@ def plot_reading(reading:ScanReading) -> go.Figure:
'mode': 'markers',
'name': 'Raw',
'showlegend': True,
'x': reading.distances[reading.finite],
'y': np.zeros(reading.intersection_count),
'x': reading.raw.distances[reading.raw.finite],
'y': np.zeros(reading.raw.intersection_count),
})
])
def plot_intersections(reading:ScanReading, **kwargs) -> go.Scatter3d:
def plot_intersections(reading:ScanReading|RawScanReading, **kwargs) -> go.Scatter3d:
if isinstance(reading, ScanReading):
return plot_intersections(reading.raw, **kwargs)

inters = reading.intersections[reading.finite]

x, y, z = inters.reshape((-1,3)).swapaxes(0,1)
return go.Scatter3d(x=x, y=y, z=z, **kwargs)

def plot_intersections_list(readings:list[ScanReading], **kwargs) -> list[go.Scatter3d]:
def plot_intersections_list(readings:list[ScanReading|RawScanReading], **kwargs) -> list[go.Scatter3d]:
return [plot_intersections(reading, **kwargs) for reading in readings]

def plot_readings_heatmap(readings) -> go.Figure:
Expand Down
92 changes: 69 additions & 23 deletions Raytrace/SideScan.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,108 @@
from Raytrace.TriangleMesh import TriangleMesh, Ray
from typing import Self
import numpy as np
import time
import pickle

class ScanReading:
class RawScanReading:
distances:np.ndarray
intersections:np.ndarray
origins:np.ndarray
directions:np.ndarray
finite:np.ndarray
intersection_count:int

start_time:float = -1
end_time:float = -1

def __init__(self, distances:np.ndarray, rays:list[Ray]):
old_error_state = np.seterr(all='ignore')
self.distances = distances
self.origins = np.array([r.origin for r in rays])
self.directions = np.array([r.direction for r in rays])
self.intersections = self.origins + self.directions * self.distances.reshape((-1,1))
self.finite = np.isfinite(self.distances)
self.intersection_count = np.count_nonzero(self.finite)
np.seterr(**old_error_state)
def combine_min(self, other:Self) -> Self:
if len(self.distances) != len(other.distances):
raise ValueError("Cannot combine two readings of different sizes")
if np.any(self.origins != other.origins) or np.any(self.directions != other.directions):
pass # TODO: add warning
new_distances = self.distances.copy()
other_smaller = new_distances > other.distances
new_distances[other_smaller] = other.distances[other_smaller]

# HACK: new rays should be generated based on which distance was closer
new_rays = [Ray(i, j) for i, j in zip(self.directions, self.origins)]
return type(self)(new_distances, new_rays)

def print_summary(self) -> None:
print('Intersections:', self.intersection_count , '/', len(self.distances))
if self.end_time == -1 or self.start_time == -1: return
print('Time:', self.end_time - self.start_time, 'seconds')
print('Speed:', len(self.distances)/(self.end_time - self.start_time), 'rays/seconds')

def save(self, filename, override = False) -> None:
with open(filename, 'wb' if override else 'xb') as file:
pickle.dump(self, file)
@staticmethod
def load(filename) -> 'RawScanReading':
with open(filename, 'rb') as file:
obj = pickle.load(file)
if isinstance(obj, RawScanReading):
return obj
raise TypeError(f'The object saved in {filename} is type {type(obj)} not RawScanReading')
class ScanReading:
raw:RawScanReading

smooth_dist:float
result_reselution:int
min_dist:float
max_dist:float

result:np.ndarray

start_time:float = -1
end_time:float = -1

def __init__(self,
distances:np.ndarray, rays:list[Ray],
raw_reading:RawScanReading,
smooth_dist:float = 0.05,
result_reselution:int = 1000,
min_dist:float = 0,
max_dist:float = 2
):
old_error_state = np.seterr(all='ignore')
self.distances = distances
self.origins = np.array([r.origin for r in rays])
self.directions = np.array([r.direction for r in rays])
self.intersections = self.origins + self.directions * self.distances.reshape((-1,1))
self.finite = np.isfinite(self.distances)
self.intersection_count = np.count_nonzero(self.finite)

self.raw = raw_reading
self.smooth_dist = smooth_dist
self.result_reselution = result_reselution
self.min_dist = min_dist
self.max_dist = max_dist
self.convert_distances()
self.process_raw()
np.seterr(**old_error_state)

def convert_distances(self) -> None:
def process_raw(self) -> None:
old_error_state = np.seterr(all='ignore')
norm = (self.distances[self.finite] - self.min_dist) / (self.max_dist - self.min_dist)
norm = (self.raw.distances[self.raw.finite] - self.min_dist) / (self.max_dist - self.min_dist)

ldist = norm - (np.arange(0,1,1/self.result_reselution) + 0.5/self.result_reselution).reshape((-1,1))
smooth_val = self.smooth_dist / (self.max_dist - self.min_dist)
lval = np.pow(np.maximum(0, np.square(smooth_val) - np.square(ldist)),3) / (32/35*smooth_val**7) / len(self.distances)
lval = np.pow(np.maximum(0, np.square(smooth_val) - np.square(ldist)),3) / (32/35*smooth_val**7) / len(self.raw.distances)

self.result = np.sum(lval, axis = 1)
np.seterr(**old_error_state)
def print_summary(self) -> None:
print('Intersections:', self.intersection_count , '/', len(self.distances))
if self.end_time == -1 or self.start_time == -1: return
print('Time:', self.end_time - self.start_time, 'seconds')
print('Speed:', len(self.distances)/(self.end_time - self.start_time), 'rays/seconds')

def save(self, filename, override = False) -> None:
with open(filename, 'wb' if override else 'xb') as file:
pickle.dump(self, file)
@staticmethod
def load(filename) -> 'ScanReading':
with open(filename, 'rb') as file:
obj = pickle.load(file)
if isinstance(obj, ScanReading):
return obj
raise TypeError(f'The object saved in {filename} is type {type(obj)} not ScanReading')



class SideScan:
mesh:TriangleMesh
smooth_dist:float
Expand All @@ -65,15 +111,15 @@ def __init__(self, mesh:TriangleMesh, smooth_dist:float = 0.05, result_reselutio
self.mesh = mesh
self.smooth_dist = smooth_dist
self.result_reselution = result_reselution
def scan_rays(self, rays:list[Ray]) -> ScanReading:
def scan_rays(self, rays:list[Ray]) -> RawScanReading:
distances = np.empty((len(rays),), np.float32)

start_time = time.time()
for n, ray in enumerate(rays):
distances[n] = self.mesh.raytrace(ray)
end_time = time.time()

out = ScanReading(distances, rays)
out = RawScanReading(distances, rays)

out.start_time = start_time
out.end_time = end_time
Expand Down
13 changes: 12 additions & 1 deletion Raytrace/TriangleMesh.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np

import pickle

class Ray:
direction:np.ndarray
Expand Down Expand Up @@ -115,3 +115,14 @@ def batch_triangle_ray_intersection(triangle_array, ray:Ray, epsilon = 1e-10) ->

np.seterr(**old_error_state)
return t

def save(self, filename, override = False) -> None:
with open(filename, 'wb' if override else 'xb') as file:
pickle.dump(self, file)
@staticmethod
def load(filename) -> 'TriangleMesh':
with open(filename, 'rb') as file:
obj = pickle.load(file)
if isinstance(obj, TriangleMesh):
return obj
raise TypeError(f'The object saved in {filename} is type {type(obj)} not TriangleMesh')

0 comments on commit 4840eeb

Please sign in to comment.