Skip to main content

paiagram_core/
import.rs

1//! # Import
2//! Handles foreign formats such as GTFS Static, qETRC/pyETRC, and OuDiaSecond.
3
4use std::path::PathBuf;
5
6use crate::{
7    graph::Graph,
8    interval::Interval,
9    station::Station,
10    trip::class::{Class, ClassBundle},
11    units::{
12        distance::Distance,
13        time::{Duration, TimetableTime},
14    },
15};
16use anyhow::{Result, anyhow};
17use bevy::{
18    platform::collections::HashMap,
19    prelude::*,
20    tasks::{AsyncComputeTaskPool, Task, block_on, futures_lite::future::poll_once},
21};
22use moonshine_core::kind::*;
23use paiagram_rw::save::{LoadCandidate, SaveData};
24
25mod gtfs;
26mod llt;
27mod oudia;
28mod qetrc;
29
30pub struct ImportPlugin;
31impl Plugin for ImportPlugin {
32    fn build(&self, app: &mut App) {
33        app.add_observer(qetrc::load_qetrc)
34            .add_observer(oudia::load_oud)
35            .add_observer(gtfs::load_gtfs_static)
36            .add_observer(llt::load_llt)
37            .add_observer(download_file)
38            .add_systems(Update, pull_file);
39    }
40}
41
42#[derive(Event)]
43pub struct LoadQETRC {
44    pub content: String,
45}
46
47#[derive(Event)]
48pub struct LoadLlt {
49    pub content: String,
50}
51
52pub enum OuDiaContentType {
53    OuDiaSecond(String),
54    OuDia(Vec<u8>),
55}
56
57#[derive(Event)]
58pub struct LoadOuDia {
59    pub content: OuDiaContentType,
60}
61
62impl LoadOuDia {
63    pub fn original(data: Vec<u8>) -> Self {
64        Self {
65            content: OuDiaContentType::OuDia(data),
66        }
67    }
68    pub fn second(data: String) -> Self {
69        Self {
70            content: OuDiaContentType::OuDiaSecond(data),
71        }
72    }
73}
74
75#[derive(Event)]
76pub struct LoadGTFS {
77    pub content: Vec<u8>,
78}
79
80#[derive(Event)]
81pub struct DownloadFile {
82    pub url: String,
83}
84
85fn normalize_times<'a>(mut time_iter: impl Iterator<Item = &'a mut TimetableTime> + 'a) {
86    let Some(mut previous_time) = time_iter.next().copied() else {
87        return;
88    };
89    for time in time_iter {
90        while *time < previous_time {
91            *time += Duration(86400);
92        }
93        previous_time = *time;
94    }
95}
96
97pub(crate) fn make_station(
98    name: &str,
99    station_map: &mut HashMap<String, Instance<Station>>,
100    graph: &mut Graph,
101    commands: &mut Commands,
102) -> Instance<Station> {
103    if let Some(&entity) = station_map.get(name) {
104        return entity;
105    }
106    let station_entity = commands
107        .spawn(Name::new(name.to_string()))
108        .insert_instance(Station::default())
109        .into();
110    station_map.insert(name.to_string(), station_entity);
111    graph.add_node(station_entity.entity());
112    station_entity
113}
114
115pub(crate) fn make_class(
116    name: &str,
117    class_map: &mut HashMap<String, Instance<Class>>,
118    commands: &mut Commands,
119    mut make_class: impl FnMut() -> ClassBundle,
120) -> Instance<Class> {
121    if let Some(&entity) = class_map.get(name) {
122        return entity;
123    };
124    let class_bundle = make_class();
125    let class_entity = commands
126        .spawn((class_bundle.name, class_bundle.stroke))
127        .insert_instance(class_bundle.class)
128        .into();
129    class_map.insert(name.to_string(), class_entity);
130    class_entity
131}
132
133// TODO: remove this function
134pub(crate) fn add_interval_pair(
135    graph: &mut Graph,
136    commands: &mut Commands,
137    from: Entity,
138    to: Entity,
139    length: Distance,
140) {
141    if !graph.contains_edge(from, to) {
142        let e1: Instance<Interval> = commands.spawn_instance(Interval { length }).into();
143        graph.add_edge(from, to, e1.entity());
144    }
145    if !graph.contains_edge(to, from) {
146        let e2: Instance<Interval> = commands.spawn_instance(Interval { length }).into();
147        graph.add_edge(to, from, e2.entity());
148    }
149}
150
151#[derive(Component)]
152pub struct FileDownloadTask {
153    task: Option<Task<(Vec<u8>, String)>>,
154    url: String,
155}
156
157pub fn download_file(event: On<DownloadFile>, mut commands: Commands) {
158    commands.spawn(FileDownloadTask {
159        task: None,
160        url: event.url.clone(),
161    });
162}
163
164fn pull_file(mut commands: Commands, tasks: Populated<(Entity, &mut FileDownloadTask)>) {
165    for (task_entity, mut task) in tasks {
166        if task.task.is_none() {
167            let url = task.url.clone();
168            task.task = Some(AsyncComputeTaskPool::get().spawn(async move {
169                let response = ehttp::fetch_async(ehttp::Request::get(&url))
170                    .await
171                    .unwrap_or_else(|e| panic!("Failed to download file from {url}: {e:?}"));
172                if !response.ok {
173                    panic!(
174                        "Failed to download file from {url}: status={} {}",
175                        response.status, response.status_text
176                    );
177                }
178                (response.bytes, response.url)
179            }));
180            continue;
181        }
182
183        let Some(task_handle) = task.task.as_mut() else {
184            continue;
185        };
186        let Some((content, final_url)) = block_on(poll_once(task_handle)) else {
187            continue;
188        };
189
190        let path = infer_path_from_url(&final_url)
191            .or_else(|| infer_path_from_url(&task.url))
192            .unwrap_or_else(|| PathBuf::from(task.url.clone()));
193        if let Err(e) = load_and_trigger(&path, content, &mut commands) {
194            error!(
195                "Failed to load downloaded file from {} (resolved as {}): {e:#}",
196                task.url,
197                path.display(),
198            );
199        }
200        commands.entity(task_entity).despawn();
201    }
202}
203
/// Guesses a file name (for extension-based format detection) from a URL.
///
/// The query (`?…`) and fragment (`#…`) are dropped, along with any
/// `scheme://authority` prefix. The last `/`-separated path segment is then trimmed
/// and returned. Returns `None` when that segment is empty, e.g. for
/// `https://example.com/dir/` or a bare host such as `https://example.com?x=1`. A bare
/// host used to be mistaken for a file name like `example.com`.
fn infer_path_from_url(url: &str) -> Option<PathBuf> {
    // Query and fragment both follow the path, so cut at whichever comes first.
    let end = url.find(['?', '#']).unwrap_or(url.len());
    let url = &url[..end];

    // Skip "scheme://authority"; if nothing follows the authority there is no path.
    let path = match url.split_once("://") {
        Some((_, rest)) => rest.find('/').map_or("", |slash| &rest[slash..]),
        None => url,
    };

    let filename = path.rsplit('/').next()?.trim();
    (!filename.is_empty()).then(|| PathBuf::from(filename))
}
213
214pub fn load_and_trigger(path: &PathBuf, content: Vec<u8>, commands: &mut Commands) -> Result<()> {
215    match path.extension().and_then(|s| s.to_str()) {
216        Some("paia") => {
217            commands.insert_resource(LoadCandidate(SaveData::CompressedCbor(content)));
218        }
219        Some("pyetgr") | Some("json") => {
220            let content = String::from_utf8(content)?;
221            commands.trigger(LoadQETRC { content });
222        }
223        Some("oud2") => {
224            let content = String::from_utf8(content)?;
225            commands.trigger(LoadOuDia::second(content));
226        }
227        Some("zip") => {
228            commands.trigger(LoadGTFS { content });
229        }
230        Some("oud") => {
231            // oudia does not use utf-8
232            commands.trigger(LoadOuDia::original(content))
233        }
234        Some("ron") => {
235            commands.insert_resource(LoadCandidate(SaveData::Ron(content)));
236        }
237        Some(e) => return Err(anyhow!("Unexpected extension: {e}")),
238        None => return Err(anyhow!("Path does not have an extension")),
239    }
240    return Ok(());
241}