1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Fields, Meta, parse_macro_input};
4
5#[proc_macro_derive(FromRow, attributes(pg))]
6pub fn derive_from_row(input: TokenStream) -> TokenStream {
7 let input = parse_macro_input!(input as DeriveInput);
8 let name = &input.ident;
9
10 let rename_all = get_rename_all(&input.attrs);
11
12 let fields = match &input.data {
13 Data::Struct(data) => match &data.fields {
14 Fields::Named(fields) => &fields.named,
15 _ => panic!("FromRow only supports structs with named fields"),
16 },
17 _ => panic!("FromRow only supports structs"),
18 };
19
20 let field_assignments: Vec<_> = fields
21 .iter()
22 .map(|f| {
23 let field_name = &f.ident;
24 let field_name_str = field_name.as_ref().unwrap().to_string();
25
26 let (column_name, default) = parse_field_attrs(&f.attrs);
27 let column_name = column_name.unwrap_or_else(|| match rename_all.as_deref() {
28 Some("camelCase") => to_camel_case(&field_name_str),
29 Some("PascalCase") => to_pascal_case(&field_name_str),
30 Some("lowercase") => field_name_str.to_lowercase(),
31 Some("UPPERCASE") => field_name_str.to_uppercase(),
32 _ => to_snake_case(&field_name_str),
33 });
34
35 let ty = &f.ty;
36
37 if default {
38 quote! {
39 #field_name: row.try_get::<#ty>(#column_name).unwrap_or_default(),
40 }
41 } else {
42 quote! {
43 #field_name: row.try_get::<#ty>(#column_name)?,
44 }
45 }
46 })
47 .collect();
48
49 let expanded = quote! {
50 impl ::pg::FromRow for #name {
51 fn from_row(row: &::pg::Row) -> std::result::Result<Self, ::pg::PgError> {
52 Ok(Self {
53 #(#field_assignments)*
54 })
55 }
56 }
57 };
58
59 TokenStream::from(expanded)
60}
61
62fn get_rename_all(attrs: &[syn::Attribute]) -> Option<String> {
63 for attr in attrs {
64 if attr.path().is_ident("pg") {
65 if let Ok(meta) = attr.parse_args::<Meta>() {
66 match meta {
67 Meta::List(list) => {
68 let inner = list.tokens.to_string();
69 let inner = inner.trim().trim_start_matches('(').trim_end_matches(')');
70 if let Some(val) = inner.strip_prefix("rename_all = ") {
71 return Some(val.trim().trim_matches('"').to_string());
72 }
73 }
74 _ => {}
75 }
76 }
77 }
78 }
79 None
80}
81
82fn parse_field_attrs(attrs: &[syn::Attribute]) -> (Option<String>, bool) {
83 let mut column = None;
84 let mut default = false;
85
86 for attr in attrs {
87 if attr.path().is_ident("pg") {
88 let inner = attr.meta.require_list().unwrap().tokens.to_string();
89 for part in inner.split(',') {
90 let part = part.trim();
91 if part == "default" {
92 default = true;
93 } else if let Some(val) = part.strip_prefix("column = ") {
94 column = Some(val.trim().trim_matches('"').to_string());
95 }
96 }
97 }
98 }
99
100 (column, default)
101}
102
/// Converts a `snake_case` identifier to `camelCase`.
///
/// Every `_` is dropped, and the character that follows it is upper-cased. All other
/// characters are copied unchanged, so a leading `_` capitalises the first letter
/// (`_private` → `Private`). Upper-casing is Unicode-aware: `user_élan` → `userÉlan`.
fn to_camel_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut capitalize_next = false;
    for c in s.chars() {
        if c == '_' {
            capitalize_next = true;
        } else if capitalize_next {
            // `char::to_uppercase` may yield several chars (e.g. 'ß' -> "SS").
            result.extend(c.to_uppercase());
            capitalize_next = false;
        } else {
            result.push(c);
        }
    }
    result
}
118
119fn to_pascal_case(s: &str) -> String {
120 let camel = to_camel_case(s);
121 let mut chars = camel.chars();
122 match chars.next() {
123 Some(c) => c.to_ascii_uppercase().to_string() + chars.as_str(),
124 None => String::new(),
125 }
126}
127
/// Converts an identifier to `snake_case`.
///
/// Every uppercase character after the first position is preceded by `_`, unless the output
/// already ends in `_`, so `foo_Bar` → `foo_bar` rather than `foo__bar`. Every character is
/// lower-cased. Both the uppercase test and the lower-casing are Unicode-aware. Runs of
/// capitals are split letter by letter: `ID` → `i_d`.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() && i > 0 && !result.ends_with('_') {
            result.push('_');
        }
        // `char::to_lowercase` matches the Unicode-aware `is_uppercase` check above; the
        // ASCII-only variant would leave e.g. 'É' uppercase.
        result.extend(c.to_lowercase());
    }
    result
}