diff --git a/Raytrace/PlotRays.py b/Raytrace/PlotRays.py index 745c271..6fa0dd2 100644 --- a/Raytrace/PlotRays.py +++ b/Raytrace/PlotRays.py @@ -1,5 +1,5 @@ from Raytrace.TriangleMesh import TriangleMesh -from Raytrace.SideScan import ScanReading +from Raytrace.SideScan import ScanReading, RawScanReading import numpy as np try: import plotly.graph_objs as go # type: ignore @@ -12,9 +12,12 @@ def plot_mesh(mesh:TriangleMesh, **kwargs) -> go.Mesh3d: i, j, k = [list(range(i,x.shape[0],3)) for i in range(3)] return go.Mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, **kwargs) -def plot_rays(reading:ScanReading, **kwargs) -> go.Scatter3d: - origin = reading.raw.origins[reading.raw.finite] - inters = reading.raw.intersections[reading.raw.finite] +def plot_rays(reading:ScanReading|RawScanReading, **kwargs) -> go.Scatter3d: + if isinstance(reading, ScanReading): + return plot_rays(reading.raw, **kwargs) + + origin = reading.origins[reading.finite] + inters = reading.intersections[reading.finite] x, y, z = np.stack((origin, inters, origin)).swapaxes(0,1).reshape((-1,3)).swapaxes(0,1) return go.Scatter3d(x=x, y=y, z=z, **kwargs) @@ -36,13 +39,16 @@ def plot_reading(reading:ScanReading) -> go.Figure: 'y': np.zeros(reading.raw.intersection_count), }) ]) -def plot_intersections(reading:ScanReading, **kwargs) -> go.Scatter3d: - inters = reading.raw.intersections[reading.raw.finite] +def plot_intersections(reading:ScanReading|RawScanReading, **kwargs) -> go.Scatter3d: + if isinstance(reading, ScanReading): + return plot_intersections(reading.raw, **kwargs) + + inters = reading.intersections[reading.finite] x, y, z = inters.reshape((-1,3)).swapaxes(0,1) return go.Scatter3d(x=x, y=y, z=z, **kwargs) -def plot_intersections_list(readings:list[ScanReading], **kwargs) -> list[go.Scatter3d]: +def plot_intersections_list(readings:list[ScanReading|RawScanReading], **kwargs) -> list[go.Scatter3d]: return [plot_intersections(reading, **kwargs) for reading in readings] def plot_readings_heatmap(readings) -> go.Figure: