1use super::traj_it::TrajIterator;
20use super::{ExportCfg, INTERPOLATION_SAMPLES, InterpolationSnafu};
21use super::{Interpolatable, TrajError};
22use crate::errors::{NyxError, StateError};
23use crate::io::InputOutputError;
24use crate::io::watermark::pq_writer;
25use crate::linalg::DefaultAllocator;
26use crate::linalg::allocator::Allocator;
27use crate::md::prelude::{GuidanceMode, StateParameter};
28use crate::md::trajectory::smooth_state_diff_in_place;
29use crate::time::{Duration, Epoch, TimeSeries, TimeUnits};
30use anise::analysis::AnalysisError;
31use anise::analysis::specs::StateSpecTrait;
32use anise::astro::orbit::Orbit;
33use anise::errors::PhysicsError;
34use anise::prelude::{Aberration, Almanac};
35use arrow::array::{Array, Float64Builder, StringBuilder};
36use arrow::datatypes::{DataType, Field, Schema};
37use arrow::record_batch::RecordBatch;
38use hifitime::TimeScale;
39use log::{info, warn};
40use parquet::arrow::ArrowWriter;
41use snafu::ResultExt;
42use std::collections::HashMap;
43use std::error::Error;
44use std::fmt;
45use std::fs::File;
46use std::iter::Iterator;
47use std::ops;
48use std::ops::Bound::{Excluded, Included, Unbounded};
49use std::path::{Path, PathBuf};
50use std::sync::Arc;
51
52#[derive(Clone, PartialEq)]
54pub struct Traj<S: Interpolatable>
55where
56 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
57{
58 pub name: Option<String>,
60 pub states: Vec<S>,
62}
63
64impl<S: Interpolatable> Traj<S>
65where
66 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
67{
68 pub fn new() -> Self {
69 Self {
70 name: None,
71 states: Vec::new(),
72 }
73 }
74 pub fn finalize(&mut self) {
76 self.states.dedup_by(|a, b| a.epoch().eq(&b.epoch()));
78 self.states.sort_by_key(|a| a.epoch());
80 }
81
82 pub fn at(&self, epoch: Epoch) -> Result<S, TrajError> {
84 if self.states.is_empty() || self.first().epoch() > epoch || self.last().epoch() < epoch {
85 return Err(TrajError::NoInterpolationData { epoch });
86 }
87 match self
88 .states
89 .binary_search_by(|state| state.epoch().cmp(&epoch))
90 {
91 Ok(idx) => {
92 Ok(self.states[idx])
94 }
95 Err(idx) => {
96 if idx == 0 || idx >= self.states.len() {
97 return Err(TrajError::NoInterpolationData { epoch });
100 }
101 let num_left = INTERPOLATION_SAMPLES / 2;
106
107 let mut first_idx = idx.saturating_sub(num_left);
109 let last_idx = self.states.len().min(first_idx + INTERPOLATION_SAMPLES);
110
111 if last_idx == self.states.len() {
113 first_idx = last_idx.saturating_sub(2 * num_left);
114 }
115
116 let mut states = Vec::with_capacity(last_idx - first_idx);
117 for idx in first_idx..last_idx {
118 states.push(self.states[idx]);
119 }
120
121 self.states[idx]
122 .interpolate(epoch, &states)
123 .context(InterpolationSnafu)
124 }
125 }
126 }
127
128 pub fn first(&self) -> &S {
130 self.states.first().unwrap()
132 }
133
134 pub fn last(&self) -> &S {
136 self.states.last().unwrap()
137 }
138
139 pub fn start_epoch(&self) -> Epoch {
140 self.first().epoch()
141 }
142
143 pub fn end_epoch(&self) -> Epoch {
144 self.last().epoch()
145 }
146
147 pub fn every(&self, step: Duration) -> TrajIterator<'_, S> {
149 self.every_between(step, self.first().epoch(), self.last().epoch())
150 }
151
152 pub fn every_between(&self, step: Duration, start: Epoch, end: Epoch) -> TrajIterator<'_, S> {
154 TrajIterator {
155 time_series: TimeSeries::inclusive(
156 start.max(self.first().epoch()),
157 end.min(self.last().epoch()),
158 step,
159 ),
160 traj: self,
161 }
162 }
163
164 pub fn filter_by_epoch<R: ops::RangeBounds<Epoch>>(mut self, bound: R) -> Self {
166 self.states = self
167 .states
168 .iter()
169 .copied()
170 .filter(|s| bound.contains(&s.epoch()))
171 .collect::<Vec<_>>();
172 self
173 }
174
175 pub fn filter_by_offset<R: ops::RangeBounds<Duration>>(self, bound: R) -> Self {
178 if self.states.is_empty() {
179 return self;
180 }
181 let start = match bound.start_bound() {
183 Unbounded => self.states.first().unwrap().epoch(),
184 Included(offset) | Excluded(offset) => self.states.first().unwrap().epoch() + *offset,
185 };
186
187 let end = match bound.end_bound() {
188 Unbounded => self.states.last().unwrap().epoch(),
189 Included(offset) | Excluded(offset) => self.states.first().unwrap().epoch() + *offset,
190 };
191
192 self.filter_by_epoch(start..=end)
193 }
194 pub fn to_parquet_simple<P: AsRef<Path>>(&self, path: P) -> Result<PathBuf, Box<dyn Error>> {
196 self.to_parquet(path, ExportCfg::default())
197 }
198
199 pub fn to_parquet_with_cfg<P: AsRef<Path>>(
201 &self,
202 path: P,
203 cfg: ExportCfg,
204 ) -> Result<PathBuf, Box<dyn Error>> {
205 self.to_parquet(path, cfg)
206 }
207
208 pub fn to_parquet_with_step<P: AsRef<Path>>(
210 &self,
211 path: P,
212 step: Duration,
213 ) -> Result<(), Box<dyn Error>> {
214 self.to_parquet_with_cfg(
215 path,
216 ExportCfg {
217 step: Some(step),
218 ..Default::default()
219 },
220 )?;
221
222 Ok(())
223 }
224
225 pub fn to_parquet<P: AsRef<Path>>(
227 &self,
228 path: P,
229 cfg: ExportCfg,
230 ) -> Result<PathBuf, Box<dyn Error>> {
231 let tick = Epoch::now().unwrap();
232 info!("Exporting trajectory to parquet file...");
233
234 let path_buf = cfg.actual_path(path);
236
237 let states = if cfg.start_epoch.is_some() || cfg.end_epoch.is_some() || cfg.step.is_some() {
239 let start = cfg.start_epoch.unwrap_or_else(|| self.first().epoch());
241 let end = cfg.end_epoch.unwrap_or_else(|| self.last().epoch());
242 let step = cfg.step.unwrap_or_else(|| 1.minutes());
243 self.every_between(step, start, end).collect::<Vec<S>>()
244 } else {
245 self.states.to_vec()
246 };
247
248 let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
250
251 let frame = self.states[0].frame();
252 let more_meta = Some(vec![(
253 "Frame".to_string(),
254 serde_dhall::serialize(&frame)
255 .static_type_annotation()
256 .to_string()
257 .map_err(|e| {
258 Box::new(InputOutputError::SerializeDhall {
259 what: format!("frame `{frame}`"),
260 err: e.to_string(),
261 })
262 })?,
263 )]);
264
265 let requested_fields = match cfg.fields {
266 Some(fields) => fields,
267 None => S::export_params(),
268 };
269
270 let mut fields = Vec::new();
271 let mut field_nullable = Vec::new();
272 for field in requested_fields {
273 let mut any_ok = false;
274 let mut any_err = false;
275 for state in &states {
276 if state.value(field).is_ok() {
277 any_ok = true;
278 } else {
279 any_err = true;
280 }
281 }
282
283 if any_ok {
284 fields.push(field);
285 field_nullable.push(any_err);
286 }
287 }
288
289 for (field, nullable) in fields.iter().zip(field_nullable.iter().copied()) {
290 hdrs.push(field.to_field(more_meta.clone()).with_nullable(nullable));
291 }
292
293 let schema = Arc::new(Schema::new(hdrs));
295 let mut record: Vec<Arc<dyn Array>> = Vec::new();
296
297 let mut utc_epoch = StringBuilder::new();
301 for s in &states {
302 utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
303 }
304 record.push(Arc::new(utc_epoch.finish()));
305
306 for field in fields {
308 if field == StateParameter::GuidanceMode() {
309 let mut guid_mode = StringBuilder::new();
310 for s in &states {
311 match s.value(field) {
312 Ok(value) => {
313 guid_mode.append_value(format!("{:?}", GuidanceMode::from(value)));
314 }
315 Err(_) => guid_mode.append_null(),
316 }
317 }
318 record.push(Arc::new(guid_mode.finish()));
319 } else {
320 let mut data = Float64Builder::new();
321 for s in &states {
322 match s.value(field) {
323 Ok(value) => data.append_value(value),
324 Err(_) => data.append_null(),
325 };
326 }
327 record.push(Arc::new(data.finish()));
328 }
329 }
330
331 info!(
332 "Serialized {} states from {} to {}",
333 states.len(),
334 states.first().unwrap().epoch(),
335 states.last().unwrap().epoch()
336 );
337
338 let mut metadata = HashMap::new();
340 metadata.insert("Purpose".to_string(), "Trajectory data".to_string());
341 if let Some(add_meta) = cfg.metadata {
342 for (k, v) in add_meta {
343 metadata.insert(k, v);
344 }
345 }
346
347 let props = pq_writer(Some(metadata));
348
349 let file = File::create(&path_buf)?;
350 let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
351
352 let batch = RecordBatch::try_new(schema, record)?;
353 writer.write(&batch)?;
354 writer.close()?;
355
356 let tock_time = Epoch::now().unwrap() - tick;
358 info!(
359 "Trajectory written to {} in {tock_time}",
360 path_buf.display()
361 );
362 Ok(path_buf)
363 }
364
365 pub fn resample(&self, step: Duration) -> Result<Self, NyxError> {
368 if self.states.is_empty() {
369 return Err(NyxError::Trajectory {
370 source: TrajError::CreationError {
371 msg: "No trajectory to convert".to_string(),
372 },
373 });
374 }
375
376 let mut traj = Self::new();
377 for state in self.every(step) {
378 traj.states.push(state);
379 }
380
381 traj.finalize();
382
383 Ok(traj)
384 }
385
386 pub fn rebuild(&self, epochs: &[Epoch]) -> Result<Self, NyxError> {
389 if self.states.is_empty() {
390 return Err(NyxError::Trajectory {
391 source: TrajError::CreationError {
392 msg: "No trajectory to convert".to_string(),
393 },
394 });
395 }
396
397 let mut traj = Self::new();
398 for epoch in epochs {
399 traj.states.push(self.at(*epoch)?);
400 }
401
402 traj.finalize();
403
404 Ok(traj)
405 }
406
407 pub fn ric_diff_to_parquet<P: AsRef<Path>>(
412 &self,
413 other: &Self,
414 path: P,
415 cfg: ExportCfg,
416 ) -> Result<PathBuf, TrajError> {
417 let tick = Epoch::now().unwrap();
418 info!("Exporting trajectory to parquet file...");
419
420 let path_buf = cfg.actual_path(path);
422
423 let mut hdrs = vec![Field::new("Epoch (UTC)", DataType::Utf8, false)];
425
426 for coord in ["X", "Y", "Z"] {
428 let mut meta = HashMap::new();
429 meta.insert("unit".to_string(), "km".to_string());
430
431 let field = Field::new(
432 format!("Delta {coord} (RIC) (km)"),
433 DataType::Float64,
434 false,
435 )
436 .with_metadata(meta);
437
438 hdrs.push(field);
439 }
440
441 for coord in ["x", "y", "z"] {
442 let mut meta = HashMap::new();
443 meta.insert("unit".to_string(), "km/s".to_string());
444
445 let field = Field::new(
446 format!("Delta V{coord} (RIC) (km/s)"),
447 DataType::Float64,
448 false,
449 )
450 .with_metadata(meta);
451
452 hdrs.push(field);
453 }
454
455 let frame = self.states[0].frame();
456 let more_meta = Some(vec![(
457 "Frame".to_string(),
458 serde_dhall::serialize(&frame)
459 .static_type_annotation()
460 .to_string()
461 .unwrap_or(frame.to_string()),
462 )]);
463
464 let mut cfg = cfg;
465
466 let mut fields = match cfg.fields {
467 Some(fields) => fields,
468 None => S::export_params(),
469 };
470
471 fields.retain(|param| {
473 param != &StateParameter::GuidanceMode() && self.first().value(*param).is_ok()
474 });
475
476 for field in &fields {
477 hdrs.push(field.to_field(more_meta.clone()));
478 }
479
480 let schema = Arc::new(Schema::new(hdrs));
482 let mut record: Vec<Arc<dyn Array>> = Vec::new();
483
484 cfg.start_epoch = if self.first().epoch() > other.first().epoch() {
486 Some(self.first().epoch())
487 } else {
488 Some(other.first().epoch())
489 };
490
491 cfg.end_epoch = if self.last().epoch() > other.last().epoch() {
492 Some(other.last().epoch())
493 } else {
494 Some(self.last().epoch())
495 };
496
497 let step = cfg.step.unwrap_or_else(|| 1.minutes());
499 let self_states = self
500 .every_between(step, cfg.start_epoch.unwrap(), cfg.end_epoch.unwrap())
501 .collect::<Vec<S>>();
502
503 let other_states = other
504 .every_between(step, cfg.start_epoch.unwrap(), cfg.end_epoch.unwrap())
505 .collect::<Vec<S>>();
506
507 let mut ric_diff = Vec::with_capacity(other_states.len());
509 for (other_state, self_state) in other_states.iter().zip(self_states.iter()) {
510 let self_orbit = self_state.orbit();
511 let other_orbit = other_state.orbit();
512
513 let this_ric_diff = self_orbit
514 .ric_difference(&other_orbit)
515 .map_err(|source: PhysicsError| TrajError::TrajPhysics { source })?;
516
517 ric_diff.push(this_ric_diff);
518 }
519
520 smooth_state_diff_in_place(&mut ric_diff, if other_states.len() > 5 { 5 } else { 1 });
521
522 let mut utc_epoch = StringBuilder::new();
526 for s in &self_states {
527 utc_epoch.append_value(s.epoch().to_time_scale(TimeScale::UTC).to_isoformat());
528 }
529 record.push(Arc::new(utc_epoch.finish()));
530
531 for coord_no in 0..6 {
533 let mut data = Float64Builder::new();
534 for this_ric_dff in &ric_diff {
535 data.append_value(this_ric_dff.to_cartesian_pos_vel()[coord_no]);
536 }
537 record.push(Arc::new(data.finish()));
538 }
539
540 for field in fields {
542 let mut data = Float64Builder::new();
543 for (other_state, self_state) in other_states.iter().zip(self_states.iter()) {
544 let self_val =
545 self_state
546 .value(field)
547 .map_err(|err: StateError| TrajError::TrajGeneric {
548 err: err.to_string(),
549 })?;
550 let other_val =
551 other_state
552 .value(field)
553 .map_err(|err: StateError| TrajError::TrajGeneric {
554 err: err.to_string(),
555 })?;
556 data.append_value(self_val - other_val);
557 }
558
559 record.push(Arc::new(data.finish()));
560 }
561
562 info!("Serialized {} states differences", self_states.len());
563
564 let mut metadata = HashMap::new();
566 metadata.insert(
567 "Purpose".to_string(),
568 "Trajectory difference data".to_string(),
569 );
570 if let Some(add_meta) = cfg.metadata {
571 for (k, v) in add_meta {
572 metadata.insert(k, v);
573 }
574 }
575
576 let props = pq_writer(Some(metadata));
577
578 let file = File::create(&path_buf).map_err(|err| TrajError::TrajGeneric {
579 err: format!("{err:?}"),
580 })?;
581 let mut writer = ArrowWriter::try_new(file, schema.clone(), props).unwrap();
582
583 let batch = RecordBatch::try_new(schema, record).map_err(|err| TrajError::TrajGeneric {
584 err: format!("{err:?}"),
585 })?;
586 writer.write(&batch).map_err(|err| TrajError::TrajGeneric {
587 err: format!("{err:?}"),
588 })?;
589 writer.close().map_err(|err| TrajError::TrajGeneric {
590 err: format!("{err:?}"),
591 })?;
592
593 let tock_time = Epoch::now().unwrap() - tick;
595 info!(
596 "Trajectory written to {} in {tock_time}",
597 path_buf.display()
598 );
599 Ok(path_buf)
600 }
601}
602
603impl<S: Interpolatable> ops::Add for Traj<S>
604where
605 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
606{
607 type Output = Result<Traj<S>, NyxError>;
608
609 fn add(self, other: Traj<S>) -> Self::Output {
611 &self + &other
612 }
613}
614
615impl<S: Interpolatable> ops::Add<&Traj<S>> for &Traj<S>
616where
617 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
618{
619 type Output = Result<Traj<S>, NyxError>;
620
621 fn add(self, other: &Traj<S>) -> Self::Output {
623 if self.first().frame() != other.first().frame() {
624 Err(NyxError::Trajectory {
625 source: TrajError::CreationError {
626 msg: format!(
627 "Frame mismatch in add operation: {} != {}",
628 self.first().frame(),
629 other.first().frame()
630 ),
631 },
632 })
633 } else {
634 if self.last().epoch() < other.first().epoch() {
635 let gap = other.first().epoch() - self.last().epoch();
636 warn!(
637 "Resulting merged trajectory will have a time-gap of {} starting at {}",
638 gap,
639 self.last().epoch()
640 );
641 }
642
643 let mut me = self.clone();
644 for state in &other
646 .states
647 .iter()
648 .copied()
649 .filter(|s| s.epoch() > self.last().epoch())
650 .collect::<Vec<S>>()
651 {
652 me.states.push(*state);
653 }
654 me.finalize();
655
656 Ok(me)
657 }
658 }
659}
660
661impl<S: Interpolatable> ops::AddAssign<&Traj<S>> for Traj<S>
662where
663 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
664{
665 fn add_assign(&mut self, rhs: &Self) {
671 *self = (self.clone() + rhs.clone()).unwrap();
672 }
673}
674
675impl<S: Interpolatable> fmt::Display for Traj<S>
676where
677 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
678{
679 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
680 if self.states.is_empty() {
681 write!(f, "Empty Trajectory!")
682 } else {
683 let dur = self.last().epoch() - self.first().epoch();
684 write!(
685 f,
686 "Trajectory {}in {} from {} to {} ({}, or {:.3} s) [{} states]",
687 match &self.name {
688 Some(name) => format!("of {name} "),
689 None => String::new(),
690 },
691 self.first().frame(),
692 self.first().epoch(),
693 self.last().epoch(),
694 dur,
695 dur.to_seconds(),
696 self.states.len()
697 )
698 }
699 }
700}
701
702impl<S: Interpolatable> fmt::Debug for Traj<S>
703where
704 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
705{
706 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
707 write!(f, "{self}",)
708 }
709}
710
711impl<S: Interpolatable> Default for Traj<S>
712where
713 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
714{
715 fn default() -> Self {
716 Self::new()
717 }
718}
719
720impl<S: Interpolatable> StateSpecTrait for Traj<S>
721where
722 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
723{
724 fn ab_corr(&self) -> Option<Aberration> {
725 None
726 }
727
728 fn evaluate(&self, epoch: Epoch, _almanac: &Almanac) -> Result<Orbit, AnalysisError> {
729 self.at(epoch)
730 .map(|state| state.orbit())
731 .map_err(|e| AnalysisError::GenericAnalysisError {
732 err: format!("{e}"),
733 })
734 }
735}