1use bevy::tasks::{AsyncComputeTaskPool, Task, block_on, futures_lite::future::poll_once};
2use bevy::{ecs::entity::EntityHashMap, prelude::*};
3use petgraph::graph::NodeIndex;
4use serde::Deserialize;
5use std::collections::{HashMap, HashSet, VecDeque};
6use std::sync::{
7 Arc,
8 atomic::{AtomicUsize, Ordering},
9};
10use visgraph::layout::force_directed::force_directed_layout;
11
12use super::{Graph, Node, NodeCoor};
13
/// Which arrangement strategy produced (or is producing) a graph layout.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum GraphLayoutKind {
    /// Positions computed by a force-directed simulation over the graph topology.
    ForceDirected,
    /// Positions looked up from OpenStreetMap data via the Overpass API.
    OSM,
}
19
20#[derive(Resource)]
21pub struct GraphLayoutTask {
22 pub task: Task<Vec<(Entity, NodeCoor)>>,
23 finished: Arc<AtomicUsize>,
24 queued_for_retry: Arc<AtomicUsize>,
25 pub total: usize,
26 pub kind: GraphLayoutKind,
27}
28
29impl GraphLayoutTask {
30 fn new(
31 task: Task<Vec<(Entity, NodeCoor)>>,
32 finished: Arc<AtomicUsize>,
33 queued_for_retry: Arc<AtomicUsize>,
34 total: usize,
35 kind: GraphLayoutKind,
36 ) -> Self {
37 Self {
38 task,
39 finished,
40 queued_for_retry,
41 total,
42 kind,
43 }
44 }
45
46 pub fn progress(&self) -> (usize, usize, usize) {
47 (
48 self.finished.load(Ordering::Relaxed),
49 self.total,
50 self.queued_for_retry.load(Ordering::Relaxed),
51 )
52 }
53}
54
55pub fn apply_graph_layout_task(
56 mut commands: Commands,
57 task: Option<ResMut<GraphLayoutTask>>,
58 mut nodes: Query<&mut Node>,
59) {
60 let Some(mut task) = task else {
61 return;
62 };
63 let Some(found) = block_on(poll_once(&mut task.task)) else {
64 return;
65 };
66 for (entity, coor) in found {
67 let Ok(mut node) = nodes.get_mut(entity) else {
68 continue;
69 };
70 node.coor = coor;
71 }
72 let (finished, total, queued_for_retry) = task.progress();
73 info!(
74 "Graph arrange completed: mode={:?}, mapped={finished}/{total}, retry_queued={queued_for_retry}",
75 task.kind
76 );
77 commands.remove_resource::<GraphLayoutTask>();
78}
79
80pub fn apply_force_directed_layout(
81 In(iterations): In<u32>,
82 graph_map: Res<Graph>,
83 mut nodes: Query<&mut Node>,
84) {
85 let graph: petgraph::Graph<_, _, _, usize> = graph_map.map.clone().into_graph();
86 let binding = &graph;
87 let entity_map: EntityHashMap<NodeIndex<usize>> = graph
88 .node_indices()
89 .map(|idx| (*graph.node_weight(idx).unwrap(), idx))
90 .collect();
91 let layout = force_directed_layout(&binding, iterations, 0.1);
92
93 for node_entity in graph_map.nodes() {
94 let Some(&idx) = entity_map.get(&node_entity) else {
95 continue;
96 };
97 let Ok(mut node) = nodes.get_mut(node_entity) else {
98 continue;
99 };
100 let (nx, ny) = layout(idx);
101 node.coor = NodeCoor::from_xy(nx as f64, ny as f64);
102 }
103}
104
105pub fn auto_arrange_graph(
106 (In(ctx), In(iterations)): (In<egui::Context>, In<u32>),
107 mut commands: Commands,
108 graph_map: Res<Graph>,
109) {
110 let graph: petgraph::Graph<_, _, _, usize> = graph_map.map.clone().into_graph();
111 let total = graph.node_count();
112 let finished = Arc::new(AtomicUsize::new(0));
113 let queued_for_retry = Arc::new(AtomicUsize::new(0));
114 let finished_in_task = Arc::clone(&finished);
115
116 info!(
117 "Starting force-directed arrange: nodes={}, iterations={}",
118 total, iterations
119 );
120
121 let task = AsyncComputeTaskPool::get().spawn(async move {
122 let binding = &graph;
123 let layout = force_directed_layout(&binding, iterations, 0.1);
124 let out: Vec<(Entity, NodeCoor)> = graph
125 .node_indices()
126 .map(|idx| {
127 let (x, y) = layout(idx);
128 (
129 *graph.node_weight(idx).unwrap(),
130 NodeCoor::from_xy(x as f64 * 10000.0, y as f64 * 10000.0),
131 )
132 })
133 .collect();
134 finished_in_task.store(total, Ordering::Relaxed);
135 ctx.request_repaint();
136 out
137 });
138
139 commands.insert_resource(GraphLayoutTask::new(
140 task,
141 finished,
142 queued_for_retry,
143 total,
144 GraphLayoutKind::ForceDirected,
145 ));
146}
147
148#[derive(Deserialize)]
149struct OSMResponse {
150 elements: Vec<OSMElement>,
151}
152
153#[derive(Deserialize)]
154struct OSMElement {
155 lat: Option<f64>,
156 lon: Option<f64>,
157 center: Option<OSMCenter>,
158 #[serde(default)]
159 tags: std::collections::HashMap<String, String>,
160}
161
162#[derive(Deserialize)]
163struct OSMCenter {
164 lat: f64,
165 lon: f64,
166}
167
168impl OSMElement {
169 fn coor(&self) -> Option<NodeCoor> {
170 match (self.lon, self.lat, self.center.as_ref()) {
171 (Some(lon), Some(lat), _) => Some(NodeCoor::new(lon, lat)),
172 (_, _, Some(center)) => Some(NodeCoor::new(center.lon, center.lat)),
173 _ => None,
174 }
175 }
176}
177
/// Escapes `input` so it matches literally inside an Overpass regex filter.
///
/// Regex metacharacters are prefixed with a backslash. The regex is embedded
/// in a double-quoted Overpass QL string literal, so a `"` is escaped as
/// well. Left unescaped, it would end the QL string early and turn the whole
/// query into a syntax error, losing every station in that chunk. `\"` is a
/// QL string escape that yields a plain `"`, which the regex treats literally.
fn escape_overpass_regex(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    for c in input.chars() {
        if matches!(
            c,
            '\\' | '.' | '+' | '*' | '?' | '^' | '$' | '(' | ')' | '[' | ']' | '{' | '}' | '|' | '"'
        ) {
            out.push('\\');
        }
        out.push(c);
    }
    out
}
191
/// Preference bonus for a name-like OSM tag key, or `-1.0` if the key is not
/// a name tag.
///
/// `name` ranks highest (`0.06`). A qualified `name:*` key gets `0.05`. For
/// `official_name`, `short_name`, `loc_name`, `alt_name` and `old_name`, the
/// plain key and every `:*`-qualified variant share one weight, decreasing in
/// that order from `0.04` to `0.0`.
fn name_tag_weight(key: &str) -> f64 {
    // Split off a `:<qualifier>` suffix (e.g. a language code), if any.
    let (base, qualified) = match key.split_once(':') {
        Some((base, _)) => (base, true),
        None => (key, false),
    };
    match base {
        "name" if qualified => 0.05,
        "name" => 0.06,
        "official_name" => 0.04,
        "short_name" => 0.03,
        "loc_name" => 0.02,
        "alt_name" => 0.01,
        "old_name" => 0.0,
        _ => -1.0,
    }
}
209
/// How strongly an element's tags mark it as a station, from `0.0` to `0.60`.
///
/// The `railway`, `public_transport` and `station` tags are each scored
/// independently, and the largest of the three scores is returned.
fn station_kind_weight(tags: &HashMap<String, String>) -> f64 {
    let tag = |key: &str| tags.get(key).map(String::as_str);

    let railway: f64 = match tag("railway") {
        Some("station") => 0.60,
        Some("halt") => 0.55,
        Some("tram_stop") => 0.45,
        Some("stop" | "light_rail" | "subway" | "monorail_station") => 0.40,
        Some("stop_position") => 0.20,
        Some("platform") => 0.15,
        Some("disused_station" | "preserved") => 0.10,
        _ => 0.0,
    };
    let public_transport: f64 = match tag("public_transport") {
        Some("station") => 0.50,
        Some("stop_area") => 0.35,
        Some("platform") => 0.20,
        Some("stop_position") => 0.15,
        _ => 0.0,
    };
    let station: f64 = match tag("station") {
        Some("subway" | "light_rail") => 0.20,
        _ => 0.0,
    };

    [railway, public_transport, station]
        .into_iter()
        .fold(f64::NEG_INFINITY, f64::max)
}
237
238fn best_name_match<'a>(elements: &'a [OSMElement], station_name: &str) -> Option<&'a OSMElement> {
239 let mut best: Option<(&OSMElement, f64)> = None;
240 for element in elements {
241 if element.coor().is_none() {
242 continue;
243 }
244 let base_weight = station_kind_weight(&element.tags);
245 for (key, value) in &element.tags {
246 let name_weight = name_tag_weight(key);
247 if name_weight < 0.0 {
248 continue;
249 }
250
251 let score = if value == station_name {
252 2.0 + base_weight + name_weight
253 } else {
254 let similarity = strsim::jaro_winkler(station_name, value);
255 if similarity <= 0.9 {
256 continue;
257 }
258 similarity + base_weight + name_weight
259 };
260
261 if best
262 .as_ref()
263 .is_none_or(|(_, best_score)| score > *best_score)
264 {
265 best = Some((element, score));
266 }
267 }
268 }
269 best.map(|(element, _)| element)
270}
271
272fn fill_unmatched_via_neighbors(
273 graph: &petgraph::Graph<Entity, Entity, petgraph::Directed, usize>,
274 known_positions: &mut HashMap<Entity, NodeCoor>,
275 all_stations: &[Entity],
276) -> usize {
277 let entity_to_index: HashMap<Entity, NodeIndex<usize>> = graph
278 .node_indices()
279 .map(|idx| (*graph.node_weight(idx).unwrap(), idx))
280 .collect();
281
282 let mut fallback_count = 0usize;
283 for &station in all_stations {
284 if known_positions.contains_key(&station) {
285 continue;
286 }
287 let Some(&start_idx) = entity_to_index.get(&station) else {
288 continue;
289 };
290
291 let mut queue = VecDeque::new();
292 let mut visited = HashSet::new();
293 let mut found_neighbor_positions = Vec::new();
294
295 queue.push_back(start_idx);
296 visited.insert(start_idx);
297
298 while let Some(current) = queue.pop_front() {
299 for neighbor in graph.neighbors_undirected(current) {
300 if !visited.insert(neighbor) {
301 continue;
302 }
303 let neighbor_entity = *graph.node_weight(neighbor).unwrap();
304 if let Some(coor) = known_positions.get(&neighbor_entity) {
305 found_neighbor_positions.push(*coor);
306 } else {
307 queue.push_back(neighbor);
308 }
309 }
310 }
311
312 if found_neighbor_positions.is_empty() {
313 continue;
314 }
315
316 let count = found_neighbor_positions.len() as f64;
317 let avg_lon = found_neighbor_positions.iter().map(|p| p.lon).sum::<f64>() / count;
318 let avg_lat = found_neighbor_positions.iter().map(|p| p.lat).sum::<f64>() / count;
319 known_positions.insert(station, NodeCoor::new(avg_lon, avg_lat));
320 fallback_count += 1;
321 }
322
323 fallback_count
324}
325
326pub fn arrange_via_osm(
327 (In(ctx), In(area_name)): (In<egui::Context>, In<Option<String>>),
328 mut commands: Commands,
329 graph_map: Res<Graph>,
330 station_names: Query<(Entity, &Name), With<crate::station::Station>>,
331) {
332 const MAX_RETRY_COUNT: usize = 3;
333 const OVERPASS_ENDPOINTS: [&str; 2] = [
334 "https://maps.mail.ru/osm/tools/overpass/api/interpreter",
335 "https://overpass-api.de/api/interpreter",
336 ];
337 let stations: Vec<(Entity, String)> = station_names
338 .iter()
339 .map(|(entity, name)| (entity, name.to_string()))
340 .collect();
341 let total = stations.len();
342 let station_entities: Vec<Entity> = stations.iter().map(|(entity, _)| *entity).collect();
343 let graph: petgraph::Graph<_, _, _, usize> = graph_map.map.clone().into_graph();
344
345 info!(
346 "Starting OSM arrange: stations={}, area={}",
347 total,
348 area_name.as_deref().unwrap_or("<global>")
349 );
350
351 let finished = Arc::new(AtomicUsize::new(0));
352 let queued_for_retry = Arc::new(AtomicUsize::new(0));
353 let finished_in_task = Arc::clone(&finished);
354 let queued_in_task = Arc::clone(&queued_for_retry);
355
356 let mut task_queue: VecDeque<(Vec<(Entity, String)>, usize)> = stations
357 .chunks(100)
358 .map(|chunk| (chunk.to_vec(), 0))
359 .collect();
360
361 let (area_def, area_filter) = match area_name.as_ref() {
362 Some(area) => {
363 if area.len() == 2 && area.chars().all(|c| c.is_ascii_alphabetic()) {
365 let country_code = area.to_uppercase();
366 info!(?country_code);
367 (
368 format!(r#"area["ISO3166-1"="{country_code}"]->.searchArea;"#),
369 "(area.searchArea)",
370 )
371 } else {
372 info!(?area);
373 (
374 format!(r#"area[name="{}"]->.searchArea;"#, area),
375 "(area.searchArea)",
376 )
377 }
378 }
379 None => (String::new(), ""),
380 };
381
382 let task = AsyncComputeTaskPool::get().spawn(async move {
383 let mut known_positions: HashMap<Entity, NodeCoor> = HashMap::new();
384
385 while let Some((chunk, retry_count)) = task_queue.pop_front() {
386 if retry_count >= MAX_RETRY_COUNT {
387 finished_in_task.fetch_add(chunk.len(), Ordering::Relaxed);
388 continue;
389 }
390
391 let names_regex = chunk
392 .iter()
393 .map(|(_, name)| escape_overpass_regex(name))
394 .collect::<Vec<_>>()
395 .join("|");
396
397 let query = format!(
398 r#"[out:json];{area_def}(node[~"^(railway|public_transport|station|subway|light_rail)$"~"^(station|halt|stop|tram_stop|subway_entrance|monorail_station|light_rail_station|narrow_gauge_station|funicular_station|preserved|disused_station|stop_position|platform|stop_area|subway|railway|tram|yes)$"][~"name(:.*)?"~"^({names_regex})$"]{area_filter};);out;"#,
399 );
400
401 let mut osm_data: Option<OSMResponse> = None;
402 for endpoint in OVERPASS_ENDPOINTS {
403 let request = ehttp::Request::post(
404 endpoint,
405 format!("data={}", urlencoding::encode(&query)).into_bytes(),
406 );
407
408 let response = match ehttp::fetch_async(request).await {
409 Ok(resp) => resp,
410 Err(e) => {
411 warn!(
412 "OSM request failed: endpoint={}, chunk(size={}), retry={}/{} ({:?})",
413 endpoint,
414 chunk.len(),
415 retry_count + 1,
416 MAX_RETRY_COUNT,
417 e
418 );
419 continue;
420 }
421 };
422
423 if !response.ok {
424 let body_preview = response
425 .text()
426 .map(|t| t.chars().take(200).collect::<String>())
427 .unwrap_or_else(|| "<non-utf8>".to_string());
428 warn!(
429 "OSM bad response: endpoint={}, status={} {}, content_type={:?}, body_preview={:?}",
430 endpoint,
431 response.status,
432 response.status_text,
433 response.content_type(),
434 body_preview
435 );
436 continue;
437 }
438
439 match response.json() {
440 Ok(data) => {
441 info!(
442 "OSM chunk fetched: endpoint={}, chunk(size={}), retry={}/{}",
443 endpoint,
444 chunk.len(),
445 retry_count,
446 MAX_RETRY_COUNT
447 );
448 osm_data = Some(data);
449 break;
450 }
451 Err(e) => {
452 let body_preview = response
453 .text()
454 .map(|t| t.chars().take(200).collect::<String>())
455 .unwrap_or_else(|| "<non-utf8>".to_string());
456 warn!(
457 "OSM response parse failed: endpoint={}, chunk(size={}), retry={}/{} ({:?}), content_type={:?}, body_preview={:?}",
458 endpoint,
459 chunk.len(),
460 retry_count + 1,
461 MAX_RETRY_COUNT,
462 e,
463 response.content_type(),
464 body_preview
465 );
466 }
467 }
468 }
469
470 let Some(osm_data) = osm_data else {
471 queued_in_task.fetch_add(chunk.len(), Ordering::Relaxed);
472 task_queue.push_back((chunk, retry_count + 1));
473 continue;
474 };
475
476 let chunk_size = chunk.len();
477 let mut matched_count = 0usize;
478 for (entity, name) in chunk {
479 if let Some(element) = best_name_match(&osm_data.elements, &name) {
480 if let Some(coor) = element.coor() {
481 known_positions.insert(entity, coor);
482 matched_count += 1;
483 }
484 }
485 finished_in_task.fetch_add(1, Ordering::Relaxed);
486 }
487 info!(
488 "OSM chunk processed: matched={}/{}, progress={}/{}",
489 matched_count,
490 chunk_size,
491 finished_in_task.load(Ordering::Relaxed),
492 total
493 );
494 ctx.request_repaint();
495 }
496
497 let fallback_count = fill_unmatched_via_neighbors(&graph, &mut known_positions, &station_entities);
498 info!(
499 "OSM neighbour fallback applied: fallback_mapped={}, total_mapped={}/{}",
500 fallback_count,
501 known_positions.len(),
502 total
503 );
504
505 known_positions.into_iter().collect()
506 });
507
508 commands.insert_resource(GraphLayoutTask::new(
509 task,
510 finished,
511 queued_for_retry,
512 total,
513 GraphLayoutKind::OSM,
514 ));
515}