Skip to main content

pg_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Fields, Meta, parse_macro_input};
4
5#[proc_macro_derive(FromRow, attributes(pg))]
6pub fn derive_from_row(input: TokenStream) -> TokenStream {
7    let input = parse_macro_input!(input as DeriveInput);
8    let name = &input.ident;
9
10    let rename_all = get_rename_all(&input.attrs);
11
12    let fields = match &input.data {
13        Data::Struct(data) => match &data.fields {
14            Fields::Named(fields) => &fields.named,
15            _ => panic!("FromRow only supports structs with named fields"),
16        },
17        _ => panic!("FromRow only supports structs"),
18    };
19
20    let field_assignments: Vec<_> = fields
21        .iter()
22        .map(|f| {
23            let field_name = &f.ident;
24            let field_name_str = field_name.as_ref().unwrap().to_string();
25
26            let (column_name, default) = parse_field_attrs(&f.attrs);
27            let column_name = column_name.unwrap_or_else(|| match rename_all.as_deref() {
28                Some("camelCase") => to_camel_case(&field_name_str),
29                Some("PascalCase") => to_pascal_case(&field_name_str),
30                Some("lowercase") => field_name_str.to_lowercase(),
31                Some("UPPERCASE") => field_name_str.to_uppercase(),
32                _ => to_snake_case(&field_name_str),
33            });
34
35            let ty = &f.ty;
36
37            if default {
38                quote! {
39                    #field_name: row.try_get::<#ty>(#column_name).unwrap_or_default(),
40                }
41            } else {
42                quote! {
43                    #field_name: row.try_get::<#ty>(#column_name)?,
44                }
45            }
46        })
47        .collect();
48
49    let expanded = quote! {
50        impl ::pg::FromRow for #name {
51            fn from_row(row: &::pg::Row) -> std::result::Result<Self, ::pg::PgError> {
52                Ok(Self {
53                    #(#field_assignments)*
54                })
55            }
56        }
57    };
58
59    TokenStream::from(expanded)
60}
61
62fn get_rename_all(attrs: &[syn::Attribute]) -> Option<String> {
63    for attr in attrs {
64        if attr.path().is_ident("pg") {
65            if let Ok(meta) = attr.parse_args::<Meta>() {
66                match meta {
67                    Meta::List(list) => {
68                        let inner = list.tokens.to_string();
69                        let inner = inner.trim().trim_start_matches('(').trim_end_matches(')');
70                        if let Some(val) = inner.strip_prefix("rename_all = ") {
71                            return Some(val.trim().trim_matches('"').to_string());
72                        }
73                    }
74                    _ => {}
75                }
76            }
77        }
78    }
79    None
80}
81
82fn parse_field_attrs(attrs: &[syn::Attribute]) -> (Option<String>, bool) {
83    let mut column = None;
84    let mut default = false;
85
86    for attr in attrs {
87        if attr.path().is_ident("pg") {
88            let inner = attr.meta.require_list().unwrap().tokens.to_string();
89            for part in inner.split(',') {
90                let part = part.trim();
91                if part == "default" {
92                    default = true;
93                } else if let Some(val) = part.strip_prefix("column = ") {
94                    column = Some(val.trim().trim_matches('"').to_string());
95                }
96            }
97        }
98    }
99
100    (column, default)
101}
102
/// Converts an underscore-separated identifier to `camelCase`.
///
/// All underscores are dropped. The first character following a run of
/// underscores is ASCII-uppercased; every other character is copied verbatim,
/// so the text before the first underscore keeps its original casing.
fn to_camel_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for (idx, segment) in s.split('_').enumerate() {
        let mut rest = segment.chars();
        match rest.next() {
            // Any non-empty segment after the first one was preceded by an
            // underscore, so its leading character gets capitalised.
            Some(first) if idx > 0 => {
                out.push(first.to_ascii_uppercase());
                out.push_str(rest.as_str());
            }
            // First segment (kept as-is) or an empty segment (adds nothing).
            _ => out.push_str(segment),
        }
    }
    out
}
118
119fn to_pascal_case(s: &str) -> String {
120    let camel = to_camel_case(s);
121    let mut chars = camel.chars();
122    match chars.next() {
123        Some(c) => c.to_ascii_uppercase().to_string() + chars.as_str(),
124        None => String::new(),
125    }
126}
127
/// Converts an identifier to `snake_case`.
///
/// An underscore is inserted before every uppercase character except one in
/// the very first position, and each character is ASCII-lowercased. Runs of
/// capitals are therefore split letter by letter (`"HTTPServer"` becomes
/// `"h_t_t_p_server"`).
fn to_snake_case(s: &str) -> String {
    s.chars()
        .enumerate()
        .fold(String::new(), |mut snake, (pos, ch)| {
            if pos != 0 && ch.is_uppercase() {
                snake.push('_');
            }
            snake.push(ch.to_ascii_lowercase());
            snake
        })
}