freya_router_macro/
route_tree.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3use slab::Slab;
4use syn::Ident;
5
6use crate::{
7    RouteEndpoint,
8    nest::{
9        Nest,
10        NestId,
11    },
12    redirect::Redirect,
13    route::{
14        Route,
15        RouteType,
16    },
17    segment::{
18        RouteSegment,
19        static_segment_idx,
20    },
21};
22
23#[derive(Debug, Clone, Default)]
24pub(crate) struct ParseRouteTree<'a> {
25    pub roots: Vec<usize>,
26    entries: Slab<RouteTreeSegmentData<'a>>,
27}
28
29impl<'a> ParseRouteTree<'a> {
30    pub fn get(&self, index: usize) -> Option<&RouteTreeSegmentData<'a>> {
31        self.entries.get(index)
32    }
33
34    pub fn get_mut(&mut self, element: usize) -> Option<&mut RouteTreeSegmentData<'a>> {
35        self.entries.get_mut(element)
36    }
37
38    fn sort_children(&mut self) {
39        let mut old_roots = self.roots.clone();
40        self.sort_ids(&mut old_roots);
41        self.roots = old_roots;
42
43        for id in self.roots.clone() {
44            self.sort_children_of_id(id);
45        }
46    }
47
48    fn sort_ids(&self, ids: &mut [usize]) {
49        ids.sort_by_key(|&seg| {
50            let seg = self.get(seg).unwrap();
51            match seg {
52                RouteTreeSegmentData::Static { .. } => 0,
53                RouteTreeSegmentData::Nest { .. } => 1,
54                RouteTreeSegmentData::Route(route) => {
55                    // Routes that end in a catch all segment should be checked last
56                    match route.segments.last() {
57                        Some(RouteSegment::CatchAll(..)) => 2,
58                        _ => 1,
59                    }
60                }
61                RouteTreeSegmentData::Redirect(redirect) => {
62                    // Routes that end in a catch all segment should be checked last
63                    match redirect.segments.last() {
64                        Some(RouteSegment::CatchAll(..)) => 2,
65                        _ => 1,
66                    }
67                }
68            }
69        });
70    }
71
72    fn sort_children_of_id(&mut self, id: usize) {
73        // Sort segments so that all static routes are checked before dynamic routes
74        let mut children = self.children(id);
75
76        self.sort_ids(&mut children);
77
78        if let Some(old) = self.try_children_mut(id) {
79            old.clone_from(&children)
80        }
81
82        for id in children {
83            self.sort_children_of_id(id);
84        }
85    }
86
87    fn children(&self, element: usize) -> Vec<usize> {
88        let element = self.entries.get(element).unwrap();
89        match element {
90            RouteTreeSegmentData::Static { children, .. } => children.clone(),
91            RouteTreeSegmentData::Nest { children, .. } => children.clone(),
92            _ => Vec::new(),
93        }
94    }
95
96    fn try_children_mut(&mut self, element: usize) -> Option<&mut Vec<usize>> {
97        let element = self.entries.get_mut(element).unwrap();
98        match element {
99            RouteTreeSegmentData::Static { children, .. } => Some(children),
100            RouteTreeSegmentData::Nest { children, .. } => Some(children),
101            _ => None,
102        }
103    }
104
105    fn children_mut(&mut self, element: usize) -> &mut Vec<usize> {
106        self.try_children_mut(element)
107            .expect("Cannot get children of non static or nest segment")
108    }
109
110    pub(crate) fn new(endpoints: &'a [RouteEndpoint], nests: &'a [Nest]) -> Self {
111        let routes = endpoints
112            .iter()
113            .map(|endpoint| match endpoint {
114                RouteEndpoint::Route(route) => PathIter::new_route(route, nests),
115                RouteEndpoint::Redirect(redirect) => PathIter::new_redirect(redirect, nests),
116            })
117            .collect::<Vec<_>>();
118
119        let mut myself = Self::default();
120        myself.roots = myself.construct(routes);
121        myself.sort_children();
122
123        myself
124    }
125
126    pub fn construct(&mut self, routes: Vec<PathIter<'a>>) -> Vec<usize> {
127        let mut segments = Vec::new();
128
129        // Add all routes to the tree
130        for mut route in routes {
131            let mut current_route: Option<usize> = None;
132
133            // First add all nests
134            while let Some(nest) = route.next_nest() {
135                let segments_iter = nest.segments.iter();
136
137                // Add all static segments of the nest
138                'o: for (index, segment) in segments_iter.enumerate() {
139                    match segment {
140                        RouteSegment::Static(segment) => {
141                            // Check if the segment already exists
142                            {
143                                // Either look for the segment in the current route or in the static segments
144                                let segments = current_route
145                                    .map(|id| self.children(id))
146                                    .unwrap_or_else(|| segments.clone());
147
148                                for &seg_id in segments.iter() {
149                                    let seg = self.get(seg_id).unwrap();
150                                    if let RouteTreeSegmentData::Static { segment: s, .. } = seg
151                                        && *s == segment
152                                    {
153                                        // If it does, just update the current route
154                                        current_route = Some(seg_id);
155                                        continue 'o;
156                                    }
157                                }
158                            }
159
160                            let static_segment = RouteTreeSegmentData::Static {
161                                segment,
162                                children: Vec::new(),
163                                error_variant: StaticErrorVariant {
164                                    variant_parse_error: nest.error_ident(),
165                                    enum_variant: nest.error_variant(),
166                                },
167                                index,
168                            };
169
170                            // If it doesn't, add the segment to the current route
171                            let static_segment = self.entries.insert(static_segment);
172
173                            let current_children = current_route
174                                .map(|id| self.children_mut(id))
175                                .unwrap_or_else(|| &mut segments);
176                            current_children.push(static_segment);
177
178                            // Update the current route
179                            current_route = Some(static_segment);
180                        }
181                        // If there is a dynamic segment, stop adding static segments
182                        RouteSegment::Dynamic(..) => break,
183                        RouteSegment::CatchAll(..) => {
184                            todo!("Catch all segments are not allowed in nests")
185                        }
186                    }
187                }
188
189                // Add the nest to the current route
190                let nest = RouteTreeSegmentData::Nest {
191                    nest,
192                    children: Vec::new(),
193                };
194
195                let nest = self.entries.insert(nest);
196                let segments = match current_route.and_then(|id| self.get_mut(id)) {
197                    Some(RouteTreeSegmentData::Static { children, .. }) => children,
198                    Some(RouteTreeSegmentData::Nest { children, .. }) => children,
199                    Some(r) => {
200                        unreachable!("{current_route:?}\n{r:?} is not a static or nest segment",)
201                    }
202                    None => &mut segments,
203                };
204                segments.push(nest);
205
206                // Update the current route
207                current_route = Some(nest);
208            }
209
210            match route.next_static_segment() {
211                // If there is a static segment, check if it already exists in the tree
212                Some((i, segment)) => {
213                    let current_children = current_route
214                        .map(|id| self.children(id))
215                        .unwrap_or_else(|| segments.clone());
216                    let found = current_children.iter().find_map(|&id| {
217                        let seg = self.get(id).unwrap();
218                        match seg {
219                            RouteTreeSegmentData::Static { segment: s, .. } => {
220                                (s == &segment).then_some(id)
221                            }
222                            _ => None,
223                        }
224                    });
225
226                    match found {
227                        Some(id) => {
228                            // If it exists, add the route to the children of the segment
229                            let new_children = self.construct(vec![route]);
230                            self.children_mut(id).extend(new_children);
231                        }
232                        None => {
233                            // If it doesn't exist, add the route as a new segment
234                            let data = RouteTreeSegmentData::Static {
235                                segment,
236                                error_variant: route.error_variant(),
237                                children: self.construct(vec![route]),
238                                index: i,
239                            };
240                            let id = self.entries.insert(data);
241                            let current_children_mut = current_route
242                                .map(|id| self.children_mut(id))
243                                .unwrap_or_else(|| &mut segments);
244                            current_children_mut.push(id);
245                        }
246                    }
247                }
248                // If there is no static segment, add the route to the current_route
249                None => {
250                    let id = self.entries.insert(route.final_segment);
251                    let current_children_mut = current_route
252                        .map(|id| self.children_mut(id))
253                        .unwrap_or_else(|| &mut segments);
254                    current_children_mut.push(id);
255                }
256            }
257        }
258
259        segments
260    }
261}
262
/// Names the error reported when a `Static` node's segment does not match.
///
/// The generated code pushes
/// `#error_enum_name::#enum_variant(#variant_parse_error::<segment ident>(segment))`
/// into `errors` on a mismatch.
#[derive(Debug, Clone)]
pub struct StaticErrorVariant {
    // Parse-error type of the route or nest that owns the static segment
    // (taken from its `error_ident()`).
    variant_parse_error: Ident,
    // Variant of the top-level error enum that wraps `variant_parse_error`.
    enum_variant: Ident,
}
268
/// A node of the [`ParseRouteTree`].
///
/// Routes are first deduplicated by the static part of their path: endpoints
/// that share leading static segments hang under the same `Static` node.
/// Child lists contain keys into the tree's slab.
#[derive(Debug, Clone)]
pub(crate) enum RouteTreeSegmentData<'a> {
    /// A literal path segment shared by all of its children.
    Static {
        // The literal text; an empty string matches without consuming a segment.
        segment: &'a str,
        // Error reported when the segment does not match.
        error_variant: StaticErrorVariant,
        // Position of the segment in the owning route/nest's segment list;
        // used to name the error variant via `static_segment_idx`.
        index: usize,
        // Slab keys of the child nodes.
        children: Vec<usize>,
    },
    /// A nest: parses the nest's non-static segments, then tries its children.
    Nest {
        nest: &'a Nest,
        // Slab keys of the nodes inside the nest.
        children: Vec<usize>,
    },
    /// A route endpoint (leaf).
    Route(&'a Route),
    /// A redirect endpoint (leaf).
    Redirect(&'a Redirect),
}
285
286impl RouteTreeSegmentData<'_> {
287    pub fn to_tokens(
288        &self,
289        nests: &[Nest],
290        tree: &ParseRouteTree,
291        enum_name: syn::Ident,
292        error_enum_name: syn::Ident,
293    ) -> TokenStream {
294        match self {
295            RouteTreeSegmentData::Static {
296                segment,
297                children,
298                index,
299                error_variant:
300                    StaticErrorVariant {
301                        variant_parse_error,
302                        enum_variant,
303                    },
304            } => {
305                let children = children.iter().map(|child| {
306                    let child = tree.get(*child).unwrap();
307                    child.to_tokens(nests, tree, enum_name.clone(), error_enum_name.clone())
308                });
309
310                if segment.is_empty() {
311                    return quote! {
312                        {
313                            #(#children)*
314                        }
315                    };
316                }
317
318                let error_ident = static_segment_idx(*index);
319
320                quote! {
321                    {
322                        let mut segments = segments.clone();
323                        let segment = segments.next();
324                        if let Some(segment) = segment.as_deref() {
325                            if #segment == segment {
326                                #(#children)*
327                            } else {
328                                errors.push(#error_enum_name::#enum_variant(#variant_parse_error::#error_ident(segment.to_string())))
329                            }
330                        }
331                    }
332                }
333            }
334            RouteTreeSegmentData::Route(route) => {
335                // At this point, we have matched all static segments, so we can just check if the remaining segments match the route
336                let variant_parse_error = route.error_ident();
337                let enum_variant = &route.route_name;
338
339                let route_segments = route
340                    .segments
341                    .iter()
342                    .enumerate()
343                    .skip_while(|(_, seg)| matches!(seg, RouteSegment::Static(_)))
344                    .filter(|(i, _)| {
345                        // Don't add any trailing static segments. We strip them during parsing so that routes can accept either `/route/` and `/route`
346                        !is_trailing_static_segment(&route.segments, *i)
347                    });
348
349                let construct_variant = route.construct(nests, enum_name);
350                let parse_query = route.parse_query();
351                let parse_hash = route.parse_hash();
352
353                let insure_not_trailing = match route.ty {
354                    RouteType::Leaf { .. } => route
355                        .segments
356                        .last()
357                        .map(|seg| !matches!(seg, RouteSegment::CatchAll(_, _)))
358                        .unwrap_or(true),
359                    RouteType::Child(_) => false,
360                };
361
362                let print_route_segment = print_route_segment(
363                    route_segments.peekable(),
364                    return_constructed(
365                        insure_not_trailing,
366                        construct_variant,
367                        &error_enum_name,
368                        enum_variant,
369                        &variant_parse_error,
370                        parse_query,
371                        parse_hash,
372                    ),
373                    &error_enum_name,
374                    enum_variant,
375                    &variant_parse_error,
376                );
377
378                match &route.ty {
379                    RouteType::Child(child) => {
380                        let ty = &child.ty;
381                        let child_name = &child.ident;
382
383                        quote! {
384                            let mut trailing = String::from("/");
385                            for seg in segments.clone() {
386                                trailing += &*seg;
387                                trailing += "/";
388                            }
389                            match #ty::from_str(&trailing).map_err(|err| #error_enum_name::#enum_variant(#variant_parse_error::ChildRoute(err))) {
390                                Ok(#child_name) => {
391                                    #print_route_segment
392                                }
393                                Err(err) => {
394                                    errors.push(err);
395                                }
396                            }
397                        }
398                    }
399                    RouteType::Leaf { .. } => print_route_segment,
400                }
401            }
402            Self::Nest { nest, children } => {
403                // At this point, we have matched all static segments, so we can just check if the remaining segments match the route
404                let variant_parse_error: Ident = nest.error_ident();
405                let enum_variant = nest.error_variant();
406
407                let route_segments = nest
408                    .segments
409                    .iter()
410                    .enumerate()
411                    .skip_while(|(_, seg)| matches!(seg, RouteSegment::Static(_)));
412
413                let parse_children = children
414                    .iter()
415                    .map(|child| {
416                        let child = tree.get(*child).unwrap();
417                        child.to_tokens(nests, tree, enum_name.clone(), error_enum_name.clone())
418                    })
419                    .collect();
420
421                print_route_segment(
422                    route_segments.peekable(),
423                    parse_children,
424                    &error_enum_name,
425                    &enum_variant,
426                    &variant_parse_error,
427                )
428            }
429            Self::Redirect(redirect) => {
430                // At this point, we have matched all static segments, so we can just check if the remaining segments match the route
431                let variant_parse_error = redirect.error_ident();
432                let enum_variant = &redirect.error_variant();
433
434                let route_segments = redirect
435                    .segments
436                    .iter()
437                    .enumerate()
438                    .skip_while(|(_, seg)| matches!(seg, RouteSegment::Static(_)));
439
440                let parse_query = redirect.parse_query();
441                let parse_hash = redirect.parse_hash();
442
443                let insure_not_trailing = redirect
444                    .segments
445                    .last()
446                    .map(|seg| !matches!(seg, RouteSegment::CatchAll(_, _)))
447                    .unwrap_or(true);
448
449                let redirect_function = &redirect.function;
450                let args = redirect_function.inputs.iter().map(|pat| match pat {
451                    syn::Pat::Type(ident) => {
452                        let name = &ident.pat;
453                        quote! {#name}
454                    }
455                    _ => panic!("Expected closure argument to be a typed pattern"),
456                });
457                let return_redirect = quote! {
458                    (#redirect_function)(#(#args,)*)
459                };
460
461                print_route_segment(
462                    route_segments.peekable(),
463                    return_constructed(
464                        insure_not_trailing,
465                        return_redirect,
466                        &error_enum_name,
467                        enum_variant,
468                        &variant_parse_error,
469                        parse_query,
470                        parse_hash,
471                    ),
472                    &error_enum_name,
473                    enum_variant,
474                    &variant_parse_error,
475                )
476            }
477        }
478    }
479}
480
481fn print_route_segment<'a, I: Iterator<Item = (usize, &'a RouteSegment)>>(
482    mut s: std::iter::Peekable<I>,
483    success_tokens: TokenStream,
484    error_enum_name: &Ident,
485    enum_variant: &Ident,
486    variant_parse_error: &Ident,
487) -> TokenStream {
488    if let Some((i, route)) = s.next() {
489        let children = print_route_segment(
490            s,
491            success_tokens,
492            error_enum_name,
493            enum_variant,
494            variant_parse_error,
495        );
496
497        route.try_parse(
498            i,
499            error_enum_name,
500            enum_variant,
501            variant_parse_error,
502            children,
503        )
504    } else {
505        quote! {
506            #success_tokens
507        }
508    }
509}
510
511fn return_constructed(
512    insure_not_trailing: bool,
513    construct_variant: TokenStream,
514    error_enum_name: &Ident,
515    enum_variant: &Ident,
516    variant_parse_error: &Ident,
517    parse_query: TokenStream,
518    parse_hash: TokenStream,
519) -> TokenStream {
520    if insure_not_trailing {
521        quote! {
522            let remaining_segments = segments.clone();
523            let mut segments_clone = segments.clone();
524            let next_segment = segments_clone.next();
525            // This is the last segment, return the parsed route
526            if next_segment.is_none() {
527                #parse_query
528                #parse_hash
529                return Ok(#construct_variant);
530            } else {
531                let mut trailing = String::new();
532                for seg in remaining_segments {
533                    trailing += &*seg;
534                    trailing += "/";
535                }
536                trailing.pop();
537                errors.push(#error_enum_name::#enum_variant(#variant_parse_error::ExtraSegments(trailing)))
538            }
539        }
540    } else {
541        quote! {
542            #parse_query
543            #parse_hash
544            return Ok(#construct_variant);
545        }
546    }
547}
548
549pub struct PathIter<'a> {
550    final_segment: RouteTreeSegmentData<'a>,
551    active_nests: &'a [NestId],
552    all_nests: &'a [Nest],
553    segments: &'a [RouteSegment],
554    error_ident: Ident,
555    error_variant: Ident,
556    nest_index: usize,
557    static_segment_index: usize,
558}
559
560impl<'a> PathIter<'a> {
561    fn new_route(route: &'a Route, nests: &'a [Nest]) -> Self {
562        Self {
563            final_segment: RouteTreeSegmentData::Route(route),
564            active_nests: &*route.nests,
565            segments: &*route.segments,
566            error_ident: route.error_ident(),
567            error_variant: route.route_name.clone(),
568            all_nests: nests,
569            nest_index: 0,
570            static_segment_index: 0,
571        }
572    }
573
574    fn new_redirect(redirect: &'a Redirect, nests: &'a [Nest]) -> Self {
575        Self {
576            final_segment: RouteTreeSegmentData::Redirect(redirect),
577            active_nests: &*redirect.nests,
578            segments: &*redirect.segments,
579            error_ident: redirect.error_ident(),
580            error_variant: redirect.error_variant(),
581            all_nests: nests,
582            nest_index: 0,
583            static_segment_index: 0,
584        }
585    }
586
587    fn next_nest(&mut self) -> Option<&'a Nest> {
588        let idx = self.nest_index;
589        let nest_index = self.active_nests.get(idx)?;
590        let nest = &self.all_nests[nest_index.0];
591        self.nest_index += 1;
592        Some(nest)
593    }
594
595    fn next_static_segment(&mut self) -> Option<(usize, &'a str)> {
596        let idx = self.static_segment_index;
597        let segment = self.segments.get(idx)?;
598        // Don't add any trailing static segments. We strip them during parsing so that routes can accept either `/route/` and `/route`
599        if is_trailing_static_segment(self.segments, idx) {
600            return None;
601        }
602        match segment {
603            RouteSegment::Static(segment) => {
604                self.static_segment_index += 1;
605                Some((idx, segment))
606            }
607            _ => None,
608        }
609    }
610
611    fn error_variant(&self) -> StaticErrorVariant {
612        StaticErrorVariant {
613            variant_parse_error: self.error_ident.clone(),
614            enum_variant: self.error_variant.clone(),
615        }
616    }
617}
618
619// If this is the last segment and it is an empty trailing segment, skip parsing it. The parsing code handles parsing /path/ and /path
620pub(crate) fn is_trailing_static_segment(segments: &[RouteSegment], index: usize) -> bool {
621    // This can only be a trailing segment if we have more than one segment and this is the last segment
622    matches!(segments.get(index), Some(RouteSegment::Static(segment)) if segment.is_empty() && index == segments.len() - 1 && segments.len() > 1)
623}