Skip to main content

paiagram_core/
graph.rs

1use std::collections::HashMap;
2
3pub mod arrange;
4
5use crate::entry::EntryStop;
6use crate::interval::Interval;
7use crate::interval::IntervalQuery;
8use crate::route::Route;
9use crate::station::Platforms;
10use crate::station::Station;
11use crate::units::distance::Distance;
12use bevy::ecs::entity::EntityHashMap;
13use bevy::ecs::entity::EntityHashSet;
14use bevy::tasks::{AsyncComputeTaskPool, Task, block_on, futures_lite::future::poll_once};
15use bevy::{ecs::entity::EntityHash, prelude::*};
16use moonshine_core::kind::Instance;
17use moonshine_core::kind::SpawnInstance;
18use moonshine_core::prelude::{MapEntities, ReflectMapEntities};
19use petgraph::prelude::DiGraphMap;
20use petgraph::{algo::astar, visit::EdgeRef};
21use rstar::{AABB, PointDistance, RTree, RTreeObject};
22use serde::{Deserialize, Serialize};
23use smallvec::SmallVec;
24
25pub struct GraphPlugin;
26impl Plugin for GraphPlugin {
27    fn build(&self, app: &mut App) {
28        app.init_resource::<Graph>()
29            .init_resource::<GraphSpatialIndex>()
30            .init_resource::<GraphSpatialIndexState>()
31            .init_resource::<GraphIntervalSpatialIndex>()
32            .init_resource::<GraphIntervalSpatialIndexState>()
33            .add_systems(Update, arrange::apply_graph_layout_task)
34            .add_systems(
35                Update,
36                (
37                    mark_graph_spatial_index_dirty,
38                    start_graph_spatial_index_rebuild,
39                    apply_graph_spatial_index_task,
40                )
41                    .chain(),
42            )
43            .add_systems(
44                Update,
45                (
46                    mark_graph_interval_spatial_index_dirty,
47                    start_graph_interval_spatial_index_rebuild,
48                    apply_graph_interval_spatial_index_task,
49                )
50                    .chain(),
51            )
52            .add_observer(update_graph_on_station_removal)
53            .add_observer(update_graph_on_interval_removal)
54            .add_observer(add_interval_pair);
55        #[cfg(debug_assertions)]
56        {
57            use bevy::time::common_conditions::on_real_timer;
58            app.add_systems(
59                PostUpdate,
60                check_stations_in_graph.run_if(on_real_timer(std::time::Duration::from_secs(10))),
61            );
62        }
63    }
64}
65
66#[derive(Reflect, Clone, Resource, Serialize, Deserialize, Default, Deref, DerefMut)]
67#[reflect(Resource, opaque, Serialize, Deserialize, MapEntities)]
68pub struct Graph {
69    pub map: DiGraphMap<Entity, Entity, EntityHash>,
70}
71
72impl MapEntities for Graph {
73    fn map_entities<E: EntityMapper>(&mut self, entity_mapper: &mut E) {
74        // construct a new graph instead
75        let (nodes, edges) = self.capacity();
76        let mut new_graph = DiGraphMap::with_capacity(nodes, edges);
77        for mut node in self.nodes() {
78            node.map_entities(entity_mapper);
79            new_graph.add_node(node);
80        }
81        for (mut source, mut target, weight) in self.all_edges() {
82            let mut weight = *weight;
83            source.map_entities(entity_mapper);
84            target.map_entities(entity_mapper);
85            weight.map_entities(entity_mapper);
86            new_graph.add_edge(source, target, weight);
87        }
88        self.map = new_graph;
89    }
90}
91
92impl Graph {
93    pub fn route_between(
94        &self,
95        source: Entity,
96        target: Entity,
97        interval_q: &Query<IntervalQuery>,
98    ) -> Option<(i32, Vec<Entity>)> {
99        astar(
100            &self.map,
101            source,
102            |f| f == target,
103            |e| {
104                let Ok(i) = interval_q.get(*e.weight()) else {
105                    return i32::MAX;
106                };
107                i.distance().0
108            },
109            |_| 0,
110        )
111    }
112    pub fn route_between_source_waypoint_target(
113        &self,
114        mut points: impl Iterator<Item = Entity>,
115        interval_q: &Query<IntervalQuery>,
116    ) -> Option<(i32, Vec<Entity>)> {
117        let mut prev = points.next()?;
118        let mut total_length = 0;
119        let mut passes = vec![prev];
120        for curr in points {
121            let (leg_length, leg_points) = astar(
122                &self.map,
123                prev,
124                |f| f == curr,
125                |e| {
126                    let Ok(i) = interval_q.get(*e.weight()) else {
127                        return i32::MAX;
128                    };
129                    i.distance().0
130                },
131                |_| 0,
132            )?;
133            total_length += leg_length;
134            passes.extend_from_slice(&leg_points[1..]);
135            prev = curr;
136        }
137        Some((total_length, passes))
138    }
139    pub fn into_graph(self) -> petgraph::Graph<Entity, Entity> {
140        self.map.into_graph()
141    }
142}
143
144#[derive(Clone, Copy, Debug)]
145struct SpatialIndexedEntity {
146    entity: Entity,
147    point: [f64; 2],
148}
149
150#[derive(Clone, Copy, Debug)]
151struct IntervalSpatialIndexedEntity {
152    interval: Entity,
153    p0: [f64; 2],
154    p1: [f64; 2],
155}
156
157impl RTreeObject for SpatialIndexedEntity {
158    type Envelope = AABB<[f64; 2]>;
159
160    fn envelope(&self) -> Self::Envelope {
161        AABB::from_point(self.point)
162    }
163}
164
165impl PointDistance for SpatialIndexedEntity {
166    fn distance_2(&self, point: &[f64; 2]) -> f64 {
167        let dx = self.point[0] - point[0];
168        let dy = self.point[1] - point[1];
169        dx * dx + dy * dy
170    }
171}
172
173impl RTreeObject for IntervalSpatialIndexedEntity {
174    type Envelope = AABB<[f64; 2]>;
175
176    fn envelope(&self) -> Self::Envelope {
177        AABB::from_corners(
178            [self.p0[0].min(self.p1[0]), self.p0[1].min(self.p1[1])],
179            [self.p0[0].max(self.p1[0]), self.p0[1].max(self.p1[1])],
180        )
181    }
182}
183
184#[derive(Resource, Default)]
185pub struct GraphSpatialIndex {
186    tree: RTree<SpatialIndexedEntity>,
187}
188
189#[derive(Clone, Copy, Debug)]
190pub struct GraphIntervalSpatialSample {
191    pub interval: Entity,
192    pub p0: [f64; 2],
193    pub p1: [f64; 2],
194}
195
196#[derive(Resource, Default)]
197pub struct GraphIntervalSpatialIndex {
198    tree: RTree<IntervalSpatialIndexedEntity>,
199}
200
201impl GraphSpatialIndex {
202    pub fn is_empty(&self) -> bool {
203        self.tree.size() == 0
204    }
205
206    pub fn clear(&mut self) {
207        self.tree = RTree::new();
208    }
209
210    pub fn insert_xy(&mut self, entity: Entity, x: f64, y: f64) {
211        self.tree.insert(SpatialIndexedEntity {
212            entity,
213            point: [x, y],
214        });
215    }
216
217    pub fn insert_lon_lat(&mut self, entity: Entity, lon: f64, lat: f64) {
218        let (x, y) = lon_lat_to_xy(lon, lat);
219        self.insert_xy(entity, x, y);
220    }
221
222    pub fn entities_in_xy_aabb(
223        &self,
224        min_x: f64,
225        min_y: f64,
226        max_x: f64,
227        max_y: f64,
228    ) -> Vec<Entity> {
229        let envelope = AABB::from_corners(
230            [min_x.min(max_x), min_y.min(max_y)],
231            [min_x.max(max_x), min_y.max(max_y)],
232        );
233        self.tree
234            .locate_in_envelope_intersecting(&envelope)
235            .map(|entry| entry.entity)
236            .collect()
237    }
238
239    pub fn entities_in_lon_lat_aabb(
240        &self,
241        min_lon: f64,
242        min_lat: f64,
243        max_lon: f64,
244        max_lat: f64,
245    ) -> Vec<Entity> {
246        let (x0, y0) = lon_lat_to_xy(min_lon, min_lat);
247        let (x1, y1) = lon_lat_to_xy(max_lon, max_lat);
248        self.entities_in_xy_aabb(x0, y0, x1, y1)
249    }
250
251    pub fn nearest_in_xy(&self, x: f64, y: f64) -> Option<Entity> {
252        self.tree
253            .nearest_neighbor(&[x, y])
254            .map(|entry| entry.entity)
255    }
256
257    pub fn nearest_in_lon_lat(&self, lon: f64, lat: f64) -> Option<Entity> {
258        let (x, y) = lon_lat_to_xy(lon, lat);
259        self.nearest_in_xy(x, y)
260    }
261
262    fn replace_tree(&mut self, tree: RTree<SpatialIndexedEntity>) {
263        self.tree = tree;
264    }
265}
266
267impl GraphIntervalSpatialIndex {
268    pub fn is_empty(&self) -> bool {
269        self.tree.size() == 0
270    }
271
272    pub fn query_xy_aabb(
273        &self,
274        min_x: f64,
275        min_y: f64,
276        max_x: f64,
277        max_y: f64,
278    ) -> Vec<GraphIntervalSpatialSample> {
279        if self.is_empty() {
280            return Vec::new();
281        }
282
283        let envelope = AABB::from_corners(
284            [min_x.min(max_x), min_y.min(max_y)],
285            [min_x.max(max_x), min_y.max(max_y)],
286        );
287
288        self.tree
289            .locate_in_envelope_intersecting(&envelope)
290            .map(|item| GraphIntervalSpatialSample {
291                interval: item.interval,
292                p0: item.p0,
293                p1: item.p1,
294            })
295            .collect()
296    }
297
298    fn replace_tree(&mut self, tree: RTree<IntervalSpatialIndexedEntity>) {
299        self.tree = tree;
300    }
301}
302
303#[derive(Resource)]
304struct GraphSpatialIndexState {
305    dirty: bool,
306    task: Option<Task<RTree<SpatialIndexedEntity>>>,
307}
308
309#[derive(Resource)]
310struct GraphIntervalSpatialIndexState {
311    dirty: bool,
312    task: Option<Task<RTree<IntervalSpatialIndexedEntity>>>,
313}
314
315impl Default for GraphSpatialIndexState {
316    fn default() -> Self {
317        Self {
318            dirty: true,
319            task: None,
320        }
321    }
322}
323
324impl Default for GraphIntervalSpatialIndexState {
325    fn default() -> Self {
326        Self {
327            dirty: true,
328            task: None,
329        }
330    }
331}
332
// EPSG:3857
/// Sphere radius used by the projection, in metres.
const EARTH_RADIUS_METERS: f64 = 6_378_137.0;
/// Latitudes are clamped to ± this many degrees before projecting.
const WEB_MERCATOR_MAX_LAT: f64 = 85.051_128_78;

/// Projects a longitude/latitude pair (degrees) to planar `(x, y)`.
///
/// The latitude is clamped to ±[`WEB_MERCATOR_MAX_LAT`]. The `y` axis is
/// negated, so northern latitudes produce negative `y`.
pub fn lon_lat_to_xy(lon: f64, lat: f64) -> (f64, f64) {
    let clamped_lat = lat.clamp(-WEB_MERCATOR_MAX_LAT, WEB_MERCATOR_MAX_LAT);
    let phi = clamped_lat.to_radians();
    let mercator = (std::f64::consts::FRAC_PI_4 + phi / 2.0).tan().ln();
    (
        EARTH_RADIUS_METERS * lon.to_radians(),
        -EARTH_RADIUS_METERS * mercator,
    )
}

/// Inverse of [`lon_lat_to_xy`]: converts planar `(x, y)` back to a
/// longitude/latitude pair in degrees.
pub fn xy_to_lon_lat(x: f64, y: f64) -> (f64, f64) {
    let lon_rad = x / EARTH_RADIUS_METERS;
    let lat_rad = 2.0 * (-y / EARTH_RADIUS_METERS).exp().atan() - std::f64::consts::FRAC_PI_2;
    (lon_rad.to_degrees(), lat_rad.to_degrees())
}
351
352// TODO: partial update
353fn mark_graph_spatial_index_dirty(
354    mut state: ResMut<GraphSpatialIndexState>,
355    changed_nodes: Query<(), Or<(Added<Node>, Changed<Node>)>>,
356    mut removed_nodes: RemovedComponents<Node>,
357) {
358    if !changed_nodes.is_empty() || removed_nodes.read().next().is_some() {
359        state.dirty = true;
360    }
361}
362
363fn mark_graph_interval_spatial_index_dirty(
364    mut state: ResMut<GraphIntervalSpatialIndexState>,
365    graph: Res<Graph>,
366    changed_nodes: Query<(), Or<(Added<Node>, Changed<Node>)>>,
367    changed_intervals: Query<(), Or<(Added<Interval>, Changed<Interval>)>>,
368    mut removed_nodes: RemovedComponents<Node>,
369    mut removed_intervals: RemovedComponents<Interval>,
370) {
371    if graph.is_added()
372        || graph.is_changed()
373        || !changed_nodes.is_empty()
374        || !changed_intervals.is_empty()
375        || removed_nodes.read().next().is_some()
376        || removed_intervals.read().next().is_some()
377    {
378        state.dirty = true;
379    }
380}
381
382fn start_graph_spatial_index_rebuild(
383    mut state: ResMut<GraphSpatialIndexState>,
384    nodes: Query<(Entity, &Node)>,
385) {
386    if !state.dirty || state.task.is_some() {
387        return;
388    }
389    state.dirty = false;
390
391    let snapshot: Vec<(Entity, [f64; 2])> = nodes
392        .iter()
393        .map(|(entity, node)| (entity, node.coor.to_xy_arr()))
394        .collect();
395    state.task = Some(AsyncComputeTaskPool::get().spawn(async move {
396        let entries: Vec<SpatialIndexedEntity> = snapshot
397            .into_iter()
398            .map(|(entity, point)| SpatialIndexedEntity { entity, point })
399            .collect();
400        RTree::bulk_load(entries)
401    }));
402}
403
404fn start_graph_interval_spatial_index_rebuild(
405    mut state: ResMut<GraphIntervalSpatialIndexState>,
406    graph: Res<Graph>,
407    nodes: Query<&Node>,
408) {
409    if !state.dirty || state.task.is_some() {
410        return;
411    }
412    state.dirty = false;
413
414    let mut snapshot = Vec::<IntervalSpatialIndexedEntity>::new();
415    for (source, target, interval) in graph.all_edges() {
416        let Ok(source_node) = nodes.get(source) else {
417            continue;
418        };
419        let Ok(target_node) = nodes.get(target) else {
420            continue;
421        };
422        snapshot.push(IntervalSpatialIndexedEntity {
423            interval: *interval,
424            p0: source_node.coor.to_xy_arr(),
425            p1: target_node.coor.to_xy_arr(),
426        });
427    }
428
429    state.task = Some(AsyncComputeTaskPool::get().spawn(async move { RTree::bulk_load(snapshot) }));
430}
431
432fn apply_graph_spatial_index_task(
433    mut state: ResMut<GraphSpatialIndexState>,
434    mut index: ResMut<GraphSpatialIndex>,
435) {
436    let Some(task) = state.task.as_mut() else {
437        return;
438    };
439    let Some(tree) = block_on(poll_once(task)) else {
440        return;
441    };
442    index.replace_tree(tree);
443    state.task = None;
444}
445
446fn apply_graph_interval_spatial_index_task(
447    mut state: ResMut<GraphIntervalSpatialIndexState>,
448    mut index: ResMut<GraphIntervalSpatialIndex>,
449) {
450    let Some(task) = state.task.as_mut() else {
451        return;
452    };
453    let Some(tree) = block_on(poll_once(task)) else {
454        return;
455    };
456    index.replace_tree(tree);
457    state.task = None;
458}
459
460/// The position of the node.
461///
462/// This stores longitude and latitude values only.
463#[derive(Reflect, Clone, Copy, Debug)]
464pub struct NodeCoor {
465    pub lon: f64,
466    pub lat: f64,
467}
468
469impl std::fmt::Display for NodeCoor {
470    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
471        let lat_dir = if self.lat < 0.0 { 'S' } else { 'N' };
472        let lon_dir = if self.lon < 0.0 { 'W' } else { 'E' };
473        write!(
474            f,
475            "{:.4}°{}, {:.4}°{}",
476            self.lat.abs(),
477            lat_dir,
478            self.lon.abs(),
479            lon_dir
480        )
481    }
482}
483
484impl Default for NodeCoor {
485    fn default() -> Self {
486        Self::new(0.0, 0.0)
487    }
488}
489
490impl NodeCoor {
491    pub fn new(lon: f64, lat: f64) -> Self {
492        Self { lon, lat }
493    }
494    pub fn from_xy(x: f64, y: f64) -> Self {
495        let (lon, lat) = xy_to_lon_lat(x, y);
496        Self::new(lon, lat)
497    }
498    pub fn to_xy(&self) -> (f64, f64) {
499        lon_lat_to_xy(self.lon, self.lat)
500    }
501    pub fn to_xy_arr(&self) -> [f64; 2] {
502        let (x, y) = self.to_xy();
503        [x, y]
504    }
505    /// Shift the node on the canvas by x and y
506    pub fn shift(&mut self, dx: f64, dy: f64) {
507        self.lon += dx;
508        self.lat += dy;
509    }
510    /// Linearly interpolates between `self` and `other` by fraction `t`.
511    /// `t` is typically between 0.0 and 1.0.
512    pub fn lerp(&self, other: &Self, t: f64) -> Self {
513        let end_lon = other.lon;
514        let end_lat = other.lat;
515
516        Self {
517            lon: self.lon + (end_lon - self.lon) * t,
518            lat: self.lat + (end_lat - self.lat) * t,
519        }
520    }
521}
522
523#[derive(Default, Reflect, Component, Debug)]
524#[reflect(Component)]
525pub struct Node {
526    pub coor: NodeCoor,
527}
528
529fn update_graph_on_station_removal(
530    removed_station: On<Remove, Station>,
531    mut commands: Commands,
532    mut graph: ResMut<Graph>,
533) {
534    let s = removed_station.entity;
535    for e in graph
536        .neighbors_directed(s, petgraph::Direction::Incoming)
537        .chain(graph.neighbors_directed(s, petgraph::Direction::Outgoing))
538    {
539        commands.entity(e).despawn();
540    }
541    graph.remove_node(s);
542}
543
544fn update_graph_on_interval_removal(
545    removed_interval: On<Remove, Interval>,
546    mut graph: ResMut<Graph>,
547) {
548    let i = removed_interval.entity;
549    let mut source = None;
550    let mut target = None;
551    for (s, t, weight) in graph.all_edges() {
552        if i != *weight {
553            continue;
554        }
555        source = Some(s);
556        target = Some(t);
557        break;
558    }
559    let (Some(s), Some(t)) = (source, target) else {
560        return;
561    };
562    graph.remove_edge(s, t);
563}
564
565#[cfg(debug_assertions)]
566fn check_stations_in_graph(
567    graph: Res<Graph>,
568    stations: Populated<Entity, With<Station>>,
569    intervals: Populated<Entity, With<Interval>>,
570    names: Query<&Name>,
571) {
572    let queried_station_set: EntityHashSet = stations.iter().collect();
573    let queried_interval_set: EntityHashSet = intervals.iter().collect();
574    let mut graphed_station_set = EntityHashSet::new();
575    let mut graphed_interval_set = EntityHashSet::new();
576    for (_, _, w) in graph.all_edges() {
577        graphed_interval_set.insert(*w);
578    }
579    for node in graph.nodes() {
580        graphed_station_set.insert(node);
581    }
582    if queried_station_set != graphed_station_set {
583        debug_graph_set_diff(
584            "station",
585            &queried_station_set,
586            &graphed_station_set,
587            &names,
588        );
589    }
590    if queried_interval_set != graphed_interval_set {
591        debug_graph_set_diff(
592            "interval",
593            &queried_interval_set,
594            &graphed_interval_set,
595            &names,
596        );
597    }
598    debug_assert_eq!(queried_station_set, graphed_station_set);
599    debug_assert_eq!(queried_interval_set, graphed_interval_set);
600}
601
602#[cfg(debug_assertions)]
603fn debug_graph_set_diff(
604    label: &str,
605    queried: &EntityHashSet,
606    graphed: &EntityHashSet,
607    names: &Query<&Name>,
608) {
609    let intersection: EntityHashSet = queried.intersection(graphed).copied().collect();
610    let only_queried: EntityHashSet = queried.difference(graphed).copied().collect();
611    let only_graphed: EntityHashSet = graphed.difference(queried).copied().collect();
612
613    let list_with_names = |set: &EntityHashSet| -> Vec<String> {
614        let mut out: Vec<String> = set
615            .iter()
616            .map(|e| match names.get(*e) {
617                Ok(name) => format!("{} ({})", name.as_str(), e.index()),
618                Err(_) => format!("<??> ({})", e.index()),
619            })
620            .collect();
621        out.sort_unstable();
622        out
623    };
624
625    warn!(
626        "Graph {label} set mismatch: intersection={:#?} | only_queried={:#?} | only_graphed={:#?}",
627        list_with_names(&intersection),
628        list_with_names(&only_queried),
629        list_with_names(&only_graphed)
630    );
631}
632
633// TODO: instead of merging them, make stations platforms instead
634pub fn merge_station_by_name(
635    mut commands: Commands,
636    mut graph: ResMut<Graph>,
637    stations: Query<(Entity, &Name, &Platforms), With<Station>>,
638    entry_stops: Query<(Entity, &EntryStop)>,
639    mut routes: Query<&mut Route>,
640) {
641    let mut name_map: HashMap<&str, SmallVec<[Entity; 1]>> = HashMap::new();
642    for (entity, name, _) in &stations {
643        let v = name_map.entry(name.as_str()).or_default();
644        v.push(entity);
645    }
646
647    let mut remap: EntityHashMap<Entity> = EntityHashMap::default();
648
649    for (_name, mut entities) in name_map.into_iter().filter(|(_, v)| v.len() > 1) {
650        entities.sort_unstable_by_key(|entity| entity.index());
651        let keep = entities[0];
652
653        for duplicate in entities.into_iter().skip(1) {
654            if let Ok((_, _, platforms)) = stations.get(duplicate) {
655                let to_move: SmallVec<[Entity; 8]> = platforms.iter().collect();
656                if !to_move.is_empty() {
657                    commands.entity(keep).add_children(&to_move);
658                }
659            }
660            remap.insert(duplicate, keep);
661        }
662    }
663
664    if remap.is_empty() {
665        return;
666    }
667
668    let (nodes, edges) = graph.capacity();
669    let mut new_graph = DiGraphMap::with_capacity(nodes, edges);
670    let mut removed_intervals = EntityHashSet::default();
671    for (source, target, weight) in graph.all_edges() {
672        let source = remap.get(&source).copied().unwrap_or(source);
673        let target = remap.get(&target).copied().unwrap_or(target);
674        if let Some(existing_weight) = new_graph.edge_weight(source, target) {
675            if *existing_weight != *weight {
676                removed_intervals.insert(*weight);
677            }
678            continue;
679        }
680        new_graph.add_edge(source, target, *weight);
681    }
682    for node in graph.nodes() {
683        let node = remap.get(&node).copied().unwrap_or(node);
684        new_graph.add_node(node);
685    }
686    graph.map = new_graph;
687
688    for (entry, stop) in &entry_stops {
689        if let Some(new_stop) = remap.get(&stop.0) {
690            commands.entity(entry).insert(EntryStop(*new_stop));
691        }
692    }
693
694    for mut route in &mut routes {
695        for stop in &mut route.stops {
696            if let Some(new_stop) = remap.get(stop) {
697                *stop = *new_stop;
698            }
699        }
700    }
701
702    for interval in removed_intervals {
703        commands.entity(interval).despawn();
704    }
705
706    for duplicate in remap.into_keys() {
707        commands.entity(duplicate).despawn();
708    }
709}
710
711#[derive(Event, Clone, Copy)]
712pub struct AddIntervalPair {
713    pub source: Entity,
714    pub target: Entity,
715    pub length: Distance,
716}
717
718fn add_interval_pair(msg: On<AddIntervalPair>, mut graph: ResMut<Graph>, mut commands: Commands) {
719    if !graph.contains_edge(msg.source, msg.target) {
720        let e1: Instance<Interval> = commands
721            .spawn_instance(Interval { length: msg.length })
722            .into();
723        graph.add_edge(msg.source, msg.target, e1.entity());
724    }
725    if !graph.contains_edge(msg.target, msg.source) {
726        let e2: Instance<Interval> = commands
727            .spawn_instance(Interval { length: msg.length })
728            .into();
729        graph.add_edge(msg.target, msg.source, e2.entity());
730    }
731}