Skip to main content

paiagram_core/graph/
arrange.rs

1use bevy::tasks::{AsyncComputeTaskPool, Task, block_on, futures_lite::future::poll_once};
2use bevy::{ecs::entity::EntityHashMap, prelude::*};
3use petgraph::graph::NodeIndex;
4use serde::Deserialize;
5use std::collections::{HashMap, HashSet, VecDeque};
6use std::sync::{
7    Arc,
8    atomic::{AtomicUsize, Ordering},
9};
10use visgraph::layout::force_directed::force_directed_layout;
11
12use super::{Graph, Node, NodeCoor};
13
/// Identifies which strategy produced a background graph layout.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GraphLayoutKind {
    /// Coordinates computed by a force-directed simulation over the graph.
    ForceDirected,
    /// Coordinates looked up from OpenStreetMap through the Overpass API.
    OSM,
}
19
20#[derive(Resource)]
21pub struct GraphLayoutTask {
22    pub task: Task<Vec<(Entity, NodeCoor)>>,
23    finished: Arc<AtomicUsize>,
24    queued_for_retry: Arc<AtomicUsize>,
25    pub total: usize,
26    pub kind: GraphLayoutKind,
27}
28
29impl GraphLayoutTask {
30    fn new(
31        task: Task<Vec<(Entity, NodeCoor)>>,
32        finished: Arc<AtomicUsize>,
33        queued_for_retry: Arc<AtomicUsize>,
34        total: usize,
35        kind: GraphLayoutKind,
36    ) -> Self {
37        Self {
38            task,
39            finished,
40            queued_for_retry,
41            total,
42            kind,
43        }
44    }
45
46    pub fn progress(&self) -> (usize, usize, usize) {
47        (
48            self.finished.load(Ordering::Relaxed),
49            self.total,
50            self.queued_for_retry.load(Ordering::Relaxed),
51        )
52    }
53}
54
55pub fn apply_graph_layout_task(
56    mut commands: Commands,
57    task: Option<ResMut<GraphLayoutTask>>,
58    mut nodes: Query<&mut Node>,
59) {
60    let Some(mut task) = task else {
61        return;
62    };
63    let Some(found) = block_on(poll_once(&mut task.task)) else {
64        return;
65    };
66    for (entity, coor) in found {
67        let Ok(mut node) = nodes.get_mut(entity) else {
68            continue;
69        };
70        node.coor = coor;
71    }
72    let (finished, total, queued_for_retry) = task.progress();
73    info!(
74        "Graph arrange completed: mode={:?}, mapped={finished}/{total}, retry_queued={queued_for_retry}",
75        task.kind
76    );
77    commands.remove_resource::<GraphLayoutTask>();
78}
79
80pub fn apply_force_directed_layout(
81    In(iterations): In<u32>,
82    graph_map: Res<Graph>,
83    mut nodes: Query<&mut Node>,
84) {
85    let graph: petgraph::Graph<_, _, _, usize> = graph_map.map.clone().into_graph();
86    let binding = &graph;
87    let entity_map: EntityHashMap<NodeIndex<usize>> = graph
88        .node_indices()
89        .map(|idx| (*graph.node_weight(idx).unwrap(), idx))
90        .collect();
91    let layout = force_directed_layout(&binding, iterations, 0.1);
92
93    for node_entity in graph_map.nodes() {
94        let Some(&idx) = entity_map.get(&node_entity) else {
95            continue;
96        };
97        let Ok(mut node) = nodes.get_mut(node_entity) else {
98            continue;
99        };
100        let (nx, ny) = layout(idx);
101        node.coor = NodeCoor::from_xy(nx as f64, ny as f64);
102    }
103}
104
105pub fn auto_arrange_graph(
106    (In(ctx), In(iterations)): (In<egui::Context>, In<u32>),
107    mut commands: Commands,
108    graph_map: Res<Graph>,
109) {
110    let graph: petgraph::Graph<_, _, _, usize> = graph_map.map.clone().into_graph();
111    let total = graph.node_count();
112    let finished = Arc::new(AtomicUsize::new(0));
113    let queued_for_retry = Arc::new(AtomicUsize::new(0));
114    let finished_in_task = Arc::clone(&finished);
115
116    info!(
117        "Starting force-directed arrange: nodes={}, iterations={}",
118        total, iterations
119    );
120
121    let task = AsyncComputeTaskPool::get().spawn(async move {
122        let binding = &graph;
123        let layout = force_directed_layout(&binding, iterations, 0.1);
124        let out: Vec<(Entity, NodeCoor)> = graph
125            .node_indices()
126            .map(|idx| {
127                let (x, y) = layout(idx);
128                (
129                    *graph.node_weight(idx).unwrap(),
130                    NodeCoor::from_xy(x as f64 * 10000.0, y as f64 * 10000.0),
131                )
132            })
133            .collect();
134        finished_in_task.store(total, Ordering::Relaxed);
135        ctx.request_repaint();
136        out
137    });
138
139    commands.insert_resource(GraphLayoutTask::new(
140        task,
141        finished,
142        queued_for_retry,
143        total,
144        GraphLayoutKind::ForceDirected,
145    ));
146}
147
148#[derive(Deserialize)]
149struct OSMResponse {
150    elements: Vec<OSMElement>,
151}
152
153#[derive(Deserialize)]
154struct OSMElement {
155    lat: Option<f64>,
156    lon: Option<f64>,
157    center: Option<OSMCenter>,
158    #[serde(default)]
159    tags: std::collections::HashMap<String, String>,
160}
161
162#[derive(Deserialize)]
163struct OSMCenter {
164    lat: f64,
165    lon: f64,
166}
167
168impl OSMElement {
169    fn coor(&self) -> Option<NodeCoor> {
170        match (self.lon, self.lat, self.center.as_ref()) {
171            (Some(lon), Some(lat), _) => Some(NodeCoor::new(lon, lat)),
172            (_, _, Some(center)) => Some(NodeCoor::new(center.lon, center.lat)),
173            _ => None,
174        }
175    }
176}
177
/// Escapes `input` so it can be spliced into an Overpass QL regex that is
/// itself written inside a double-quoted QL string literal.
///
/// Regex metacharacters are prefixed with a backslash so that each station
/// name matches literally. A double quote is backslash-escaped as well.
/// Left unescaped, it would terminate the surrounding QL string literal and
/// corrupt (or inject into) the generated query.
fn escape_overpass_regex(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    for c in input.chars() {
        let needs_escape = matches!(
            c,
            '\\' | '.' | '+' | '*' | '?' | '^' | '$' | '(' | ')' | '[' | ']' | '{' | '}' | '|' | '"'
        );
        if needs_escape {
            out.push('\\');
        }
        out.push(c);
    }
    out
}
191
/// Scores how authoritative an OSM name tag key is for station matching.
///
/// The scores are:
/// - Plain `name`: 0.06.
/// - A qualified `name:*` key (such as `name:en`): 0.05.
/// - `official_name`: 0.04, `short_name`: 0.03, `loc_name`: 0.02,
///   `alt_name`: 0.01, `old_name`: 0.0. Their `:*`-qualified variants score
///   the same as the bare key.
/// - Any other key: -1.0, which callers treat as "not a name tag".
fn name_tag_weight(key: &str) -> f64 {
    // Everything before the first ':' is the base key; a ':' means a qualified variant.
    let (base, qualified) = match key.split_once(':') {
        Some((base, _)) => (base, true),
        None => (key, false),
    };
    match base {
        "name" if qualified => 0.05,
        "name" => 0.06,
        "official_name" => 0.04,
        "short_name" => 0.03,
        "loc_name" => 0.02,
        "alt_name" => 0.01,
        "old_name" => 0.0,
        _ => -1.0,
    }
}
209
/// Scores how station-like an OSM element is, based on its tags.
///
/// Three tags are inspected: `railway`, `public_transport` and `station`.
/// Each value maps to a fixed weight, and unknown or missing values score
/// 0.0. The result is the largest of the three weights, so it always lies in
/// `0.0..=0.60`.
fn station_kind_weight(tags: &HashMap<String, String>) -> f64 {
    fn lookup<'t>(tags: &'t HashMap<String, String>, key: &str) -> Option<&'t str> {
        tags.get(key).map(String::as_str)
    }

    let railway: f64 = match lookup(tags, "railway") {
        Some("station") => 0.60,
        Some("halt") => 0.55,
        Some("tram_stop") => 0.45,
        Some("stop" | "light_rail" | "subway" | "monorail_station") => 0.40,
        Some("stop_position") => 0.20,
        Some("platform") => 0.15,
        Some("disused_station" | "preserved") => 0.10,
        _ => 0.0,
    };
    let public_transport: f64 = match lookup(tags, "public_transport") {
        Some("station") => 0.50,
        Some("stop_area") => 0.35,
        Some("platform") => 0.20,
        Some("stop_position") => 0.15,
        _ => 0.0,
    };
    let station: f64 = match lookup(tags, "station") {
        Some("subway" | "light_rail") => 0.20,
        _ => 0.0,
    };

    [public_transport, station]
        .iter()
        .copied()
        .fold(railway, f64::max)
}
237
238fn best_name_match<'a>(elements: &'a [OSMElement], station_name: &str) -> Option<&'a OSMElement> {
239    let mut best: Option<(&OSMElement, f64)> = None;
240    for element in elements {
241        if element.coor().is_none() {
242            continue;
243        }
244        let base_weight = station_kind_weight(&element.tags);
245        for (key, value) in &element.tags {
246            let name_weight = name_tag_weight(key);
247            if name_weight < 0.0 {
248                continue;
249            }
250
251            let score = if value == station_name {
252                2.0 + base_weight + name_weight
253            } else {
254                let similarity = strsim::jaro_winkler(station_name, value);
255                if similarity <= 0.9 {
256                    continue;
257                }
258                similarity + base_weight + name_weight
259            };
260
261            if best
262                .as_ref()
263                .is_none_or(|(_, best_score)| score > *best_score)
264            {
265                best = Some((element, score));
266            }
267        }
268    }
269    best.map(|(element, _)| element)
270}
271
272fn fill_unmatched_via_neighbors(
273    graph: &petgraph::Graph<Entity, Entity, petgraph::Directed, usize>,
274    known_positions: &mut HashMap<Entity, NodeCoor>,
275    all_stations: &[Entity],
276) -> usize {
277    let entity_to_index: HashMap<Entity, NodeIndex<usize>> = graph
278        .node_indices()
279        .map(|idx| (*graph.node_weight(idx).unwrap(), idx))
280        .collect();
281
282    let mut fallback_count = 0usize;
283    for &station in all_stations {
284        if known_positions.contains_key(&station) {
285            continue;
286        }
287        let Some(&start_idx) = entity_to_index.get(&station) else {
288            continue;
289        };
290
291        let mut queue = VecDeque::new();
292        let mut visited = HashSet::new();
293        let mut found_neighbor_positions = Vec::new();
294
295        queue.push_back(start_idx);
296        visited.insert(start_idx);
297
298        while let Some(current) = queue.pop_front() {
299            for neighbor in graph.neighbors_undirected(current) {
300                if !visited.insert(neighbor) {
301                    continue;
302                }
303                let neighbor_entity = *graph.node_weight(neighbor).unwrap();
304                if let Some(coor) = known_positions.get(&neighbor_entity) {
305                    found_neighbor_positions.push(*coor);
306                } else {
307                    queue.push_back(neighbor);
308                }
309            }
310        }
311
312        if found_neighbor_positions.is_empty() {
313            continue;
314        }
315
316        let count = found_neighbor_positions.len() as f64;
317        let avg_lon = found_neighbor_positions.iter().map(|p| p.lon).sum::<f64>() / count;
318        let avg_lat = found_neighbor_positions.iter().map(|p| p.lat).sum::<f64>() / count;
319        known_positions.insert(station, NodeCoor::new(avg_lon, avg_lat));
320        fallback_count += 1;
321    }
322
323    fallback_count
324}
325
326pub fn arrange_via_osm(
327    (In(ctx), In(area_name)): (In<egui::Context>, In<Option<String>>),
328    mut commands: Commands,
329    graph_map: Res<Graph>,
330    station_names: Query<(Entity, &Name), With<crate::station::Station>>,
331) {
332    const MAX_RETRY_COUNT: usize = 3;
333    const OVERPASS_ENDPOINTS: [&str; 2] = [
334        "https://maps.mail.ru/osm/tools/overpass/api/interpreter",
335        "https://overpass-api.de/api/interpreter",
336    ];
337    let stations: Vec<(Entity, String)> = station_names
338        .iter()
339        .map(|(entity, name)| (entity, name.to_string()))
340        .collect();
341    let total = stations.len();
342    let station_entities: Vec<Entity> = stations.iter().map(|(entity, _)| *entity).collect();
343    let graph: petgraph::Graph<_, _, _, usize> = graph_map.map.clone().into_graph();
344
345    info!(
346        "Starting OSM arrange: stations={}, area={}",
347        total,
348        area_name.as_deref().unwrap_or("<global>")
349    );
350
351    let finished = Arc::new(AtomicUsize::new(0));
352    let queued_for_retry = Arc::new(AtomicUsize::new(0));
353    let finished_in_task = Arc::clone(&finished);
354    let queued_in_task = Arc::clone(&queued_for_retry);
355
356    let mut task_queue: VecDeque<(Vec<(Entity, String)>, usize)> = stations
357        .chunks(100)
358        .map(|chunk| (chunk.to_vec(), 0))
359        .collect();
360
361    let (area_def, area_filter) = match area_name.as_ref() {
362        Some(area) => {
363            // Check if the input is a 2-letter ISO code (e.g., "CN", "US", "FR")
364            if area.len() == 2 && area.chars().all(|c| c.is_ascii_alphabetic()) {
365                let country_code = area.to_uppercase();
366                info!(?country_code);
367                (
368                    format!(r#"area["ISO3166-1"="{country_code}"]->.searchArea;"#),
369                    "(area.searchArea)",
370                )
371            } else {
372                info!(?area);
373                (
374                    format!(r#"area[name="{}"]->.searchArea;"#, area),
375                    "(area.searchArea)",
376                )
377            }
378        }
379        None => (String::new(), ""),
380    };
381
382    let task = AsyncComputeTaskPool::get().spawn(async move {
383        let mut known_positions: HashMap<Entity, NodeCoor> = HashMap::new();
384
385        while let Some((chunk, retry_count)) = task_queue.pop_front() {
386            if retry_count >= MAX_RETRY_COUNT {
387                finished_in_task.fetch_add(chunk.len(), Ordering::Relaxed);
388                continue;
389            }
390
391            let names_regex = chunk
392                .iter()
393                .map(|(_, name)| escape_overpass_regex(name))
394                .collect::<Vec<_>>()
395                .join("|");
396
397            let query = format!(
398                r#"[out:json];{area_def}(node[~"^(railway|public_transport|station|subway|light_rail)$"~"^(station|halt|stop|tram_stop|subway_entrance|monorail_station|light_rail_station|narrow_gauge_station|funicular_station|preserved|disused_station|stop_position|platform|stop_area|subway|railway|tram|yes)$"][~"name(:.*)?"~"^({names_regex})$"]{area_filter};);out;"#,
399            );
400
401            let mut osm_data: Option<OSMResponse> = None;
402            for endpoint in OVERPASS_ENDPOINTS {
403                let request = ehttp::Request::post(
404                    endpoint,
405                    format!("data={}", urlencoding::encode(&query)).into_bytes(),
406                );
407
408                let response = match ehttp::fetch_async(request).await {
409                    Ok(resp) => resp,
410                    Err(e) => {
411                        warn!(
412                            "OSM request failed: endpoint={}, chunk(size={}), retry={}/{} ({:?})",
413                            endpoint,
414                            chunk.len(),
415                            retry_count + 1,
416                            MAX_RETRY_COUNT,
417                            e
418                        );
419                        continue;
420                    }
421                };
422
423                if !response.ok {
424                    let body_preview = response
425                        .text()
426                        .map(|t| t.chars().take(200).collect::<String>())
427                        .unwrap_or_else(|| "<non-utf8>".to_string());
428                    warn!(
429                        "OSM bad response: endpoint={}, status={} {}, content_type={:?}, body_preview={:?}",
430                        endpoint,
431                        response.status,
432                        response.status_text,
433                        response.content_type(),
434                        body_preview
435                    );
436                    continue;
437                }
438
439                match response.json() {
440                    Ok(data) => {
441                        info!(
442                            "OSM chunk fetched: endpoint={}, chunk(size={}), retry={}/{}",
443                            endpoint,
444                            chunk.len(),
445                            retry_count,
446                            MAX_RETRY_COUNT
447                        );
448                        osm_data = Some(data);
449                        break;
450                    }
451                    Err(e) => {
452                        let body_preview = response
453                            .text()
454                            .map(|t| t.chars().take(200).collect::<String>())
455                            .unwrap_or_else(|| "<non-utf8>".to_string());
456                        warn!(
457                            "OSM response parse failed: endpoint={}, chunk(size={}), retry={}/{} ({:?}), content_type={:?}, body_preview={:?}",
458                            endpoint,
459                            chunk.len(),
460                            retry_count + 1,
461                            MAX_RETRY_COUNT,
462                            e,
463                            response.content_type(),
464                            body_preview
465                        );
466                    }
467                }
468            }
469
470            let Some(osm_data) = osm_data else {
471                queued_in_task.fetch_add(chunk.len(), Ordering::Relaxed);
472                task_queue.push_back((chunk, retry_count + 1));
473                continue;
474            };
475
476            let chunk_size = chunk.len();
477            let mut matched_count = 0usize;
478            for (entity, name) in chunk {
479                if let Some(element) = best_name_match(&osm_data.elements, &name) {
480                    if let Some(coor) = element.coor() {
481                        known_positions.insert(entity, coor);
482                        matched_count += 1;
483                    }
484                }
485                finished_in_task.fetch_add(1, Ordering::Relaxed);
486            }
487            info!(
488                "OSM chunk processed: matched={}/{}, progress={}/{}",
489                matched_count,
490                chunk_size,
491                finished_in_task.load(Ordering::Relaxed),
492                total
493            );
494            ctx.request_repaint();
495        }
496
497        let fallback_count = fill_unmatched_via_neighbors(&graph, &mut known_positions, &station_entities);
498        info!(
499            "OSM neighbour fallback applied: fallback_mapped={}, total_mapped={}/{}",
500            fallback_count,
501            known_positions.len(),
502            total
503        );
504
505        known_positions.into_iter().collect()
506    });
507
508    commands.insert_resource(GraphLayoutTask::new(
509        task,
510        finished,
511        queued_for_retry,
512        total,
513        GraphLayoutKind::OSM,
514    ));
515}