1use std::{
2 cmp::Ordering,
3 fmt::{self, Debug, Display},
4 hash::{Hash, Hasher},
5 mem,
6};
7
8use serde::{
9 de::{
10 Deserialize, DeserializeSeed, Deserializer, EnumAccess, Error as _, VariantAccess, Visitor,
11 value::{BorrowedStrDeserializer, StrDeserializer},
12 },
13 forward_to_deserialize_any,
14 ser::{Serialize, SerializeMap, Serializer},
15};
16
17use crate::{
18 Error,
19 value::{
20 Value,
21 de::{MapDeserializer, MapRefDeserializer, SeqDeserializer, SeqRefDeserializer},
22 },
23};
24
/// A YAML tag, such as the `!Thing` in `!Thing value`.
///
/// The stored text may or may not begin with `!`. Equality, ordering, hashing
/// and formatting all look at the text with a single leading `!` removed
/// (via `nobang`), so a tag built from `"foo"` and one built from `"!foo"`
/// behave identically.
#[derive(Clone)]
pub struct Tag {
    // Raw tag text exactly as passed to `Tag::new`; may include a leading `!`.
    pub(crate) string: String,
}
33
34#[derive(Clone, PartialEq, PartialOrd, Hash, Debug)]
61pub struct TaggedValue {
62 #[allow(missing_docs)]
63 pub tag: Tag,
64 #[allow(missing_docs)]
65 pub value: Value,
66}
67
68impl Tag {
69 pub fn new(string: impl Into<String>) -> Self {
98 let tag: String = string.into();
99 assert!(!tag.is_empty(), "empty YAML tag is not allowed");
100 Tag {
101 string: tag,
102 }
103 }
104}
105
106impl Value {
107 pub(crate) fn untag(self) -> Self {
108 let mut cur = self;
109 while let Value::Tagged(tagged) = cur {
110 cur = tagged.value;
111 }
112 cur
113 }
114
115 pub(crate) fn untag_ref(&self) -> &Self {
116 let mut cur = self;
117 while let Value::Tagged(tagged) = cur {
118 cur = &tagged.value;
119 }
120 cur
121 }
122
123 pub(crate) fn untag_mut(&mut self) -> &mut Self {
124 let mut cur = self;
125 while let Value::Tagged(tagged) = cur {
126 cur = &mut tagged.value;
127 }
128 cur
129 }
130}
131
/// Strips one leading `!` from a tag string.
///
/// Input without a leading `!`, and the lone string `"!"`, is returned
/// unchanged. As a result, a non-empty input never yields an empty result.
pub(crate) fn nobang(maybe_banged: &str) -> &str {
    maybe_banged
        .strip_prefix('!')
        .filter(|unbanged| !unbanged.is_empty())
        .unwrap_or(maybe_banged)
}
138
139impl Eq for Tag {}
140
141impl PartialEq for Tag {
142 fn eq(&self, other: &Tag) -> bool {
143 PartialEq::eq(nobang(&self.string), nobang(&other.string))
144 }
145}
146
147impl<T> PartialEq<T> for Tag
148where
149 T: ?Sized + AsRef<str>,
150{
151 fn eq(&self, other: &T) -> bool {
152 PartialEq::eq(nobang(&self.string), nobang(other.as_ref()))
153 }
154}
155
156impl Ord for Tag {
157 fn cmp(&self, other: &Self) -> Ordering {
158 Ord::cmp(nobang(&self.string), nobang(&other.string))
159 }
160}
161
162impl PartialOrd for Tag {
163 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
164 Some(self.cmp(other))
165 }
166}
167
168impl Hash for Tag {
169 fn hash<H: Hasher>(&self, hasher: &mut H) {
170 nobang(&self.string).hash(hasher);
171 }
172}
173
174impl Display for Tag {
175 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
176 write!(formatter, "!{}", nobang(&self.string))
177 }
178}
179
180impl Debug for Tag {
181 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
182 Display::fmt(self, formatter)
183 }
184}
185
186impl Serialize for TaggedValue {
187 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
188 where
189 S: Serializer,
190 {
191 struct SerializeTag<'a>(&'a Tag);
192
193 impl<'a> Serialize for SerializeTag<'a> {
194 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
195 where
196 S: Serializer,
197 {
198 serializer.collect_str(self.0)
199 }
200 }
201
202 let mut map = serializer.serialize_map(Some(1))?;
203 map.serialize_entry(&SerializeTag(&self.tag), &self.value)?;
204 map.end()
205 }
206}
207
208impl<'de> Deserialize<'de> for TaggedValue {
209 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
210 where
211 D: Deserializer<'de>,
212 {
213 struct TaggedValueVisitor;
214
215 impl<'de> Visitor<'de> for TaggedValueVisitor {
216 type Value = TaggedValue;
217
218 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
219 formatter.write_str("a YAML value with a !Tag")
220 }
221
222 fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
223 where
224 A: EnumAccess<'de>,
225 {
226 let (tag, contents) = data.variant_seed(TagStringVisitor)?;
227 let value = contents.newtype_variant()?;
228 Ok(TaggedValue {
229 tag,
230 value,
231 })
232 }
233 }
234
235 deserializer.deserialize_any(TaggedValueVisitor)
236 }
237}
238
239impl<'de> Deserializer<'de> for TaggedValue {
240 type Error = Error;
241
242 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Error>
243 where
244 V: Visitor<'de>,
245 {
246 visitor.visit_enum(self)
247 }
248
249 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Error>
250 where
251 V: Visitor<'de>,
252 {
253 drop(self);
254 visitor.visit_unit()
255 }
256
257 forward_to_deserialize_any! {
258 bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes
259 byte_buf option unit unit_struct newtype_struct seq tuple tuple_struct
260 map struct enum identifier
261 }
262}
263
264impl<'de> EnumAccess<'de> for TaggedValue {
265 type Error = Error;
266 type Variant = Value;
267
268 fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Error>
269 where
270 V: DeserializeSeed<'de>,
271 {
272 let tag = StrDeserializer::<Error>::new(nobang(&self.tag.string));
273 let value = seed.deserialize(tag)?;
274 Ok((value, self.value))
275 }
276}
277
278impl<'de> VariantAccess<'de> for Value {
279 type Error = Error;
280
281 fn unit_variant(self) -> Result<(), Error> {
282 Deserialize::deserialize(self)
283 }
284
285 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Error>
286 where
287 T: DeserializeSeed<'de>,
288 {
289 seed.deserialize(self)
290 }
291
292 fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Error>
293 where
294 V: Visitor<'de>,
295 {
296 if let Value::Sequence(v) = self {
297 Deserializer::deserialize_any(SeqDeserializer::new(v), visitor)
298 } else {
299 Err(Error::invalid_type(self.unexpected(), &"tuple variant"))
300 }
301 }
302
303 fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value, Error>
304 where
305 V: Visitor<'de>,
306 {
307 if let Value::Mapping(v) = self {
308 Deserializer::deserialize_any(MapDeserializer::new(v), visitor)
309 } else {
310 Err(Error::invalid_type(self.unexpected(), &"struct variant"))
311 }
312 }
313}
314
315impl<'de> Deserializer<'de> for &'de TaggedValue {
316 type Error = Error;
317
318 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Error>
319 where
320 V: Visitor<'de>,
321 {
322 visitor.visit_enum(self)
323 }
324
325 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Error>
326 where
327 V: Visitor<'de>,
328 {
329 visitor.visit_unit()
330 }
331
332 forward_to_deserialize_any! {
333 bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes
334 byte_buf option unit unit_struct newtype_struct seq tuple tuple_struct
335 map struct enum identifier
336 }
337}
338
339impl<'de> EnumAccess<'de> for &'de TaggedValue {
340 type Error = Error;
341 type Variant = &'de Value;
342
343 fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Error>
344 where
345 V: DeserializeSeed<'de>,
346 {
347 let tag = BorrowedStrDeserializer::<Error>::new(nobang(&self.tag.string));
348 let value = seed.deserialize(tag)?;
349 Ok((value, &self.value))
350 }
351}
352
353impl<'de> VariantAccess<'de> for &'de Value {
354 type Error = Error;
355
356 fn unit_variant(self) -> Result<(), Error> {
357 Deserialize::deserialize(self)
358 }
359
360 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Error>
361 where
362 T: DeserializeSeed<'de>,
363 {
364 seed.deserialize(self)
365 }
366
367 fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Error>
368 where
369 V: Visitor<'de>,
370 {
371 if let Value::Sequence(v) = self {
372 Deserializer::deserialize_any(SeqRefDeserializer::new(v), visitor)
373 } else {
374 Err(Error::invalid_type(self.unexpected(), &"tuple variant"))
375 }
376 }
377
378 fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value, Error>
379 where
380 V: Visitor<'de>,
381 {
382 if let Value::Mapping(v) = self {
383 Deserializer::deserialize_any(MapRefDeserializer::new(v), visitor)
384 } else {
385 Err(Error::invalid_type(self.unexpected(), &"struct variant"))
386 }
387 }
388}
389
390pub(crate) struct TagStringVisitor;
391
392impl<'de> Visitor<'de> for TagStringVisitor {
393 type Value = Tag;
394
395 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
396 formatter.write_str("a YAML tag string")
397 }
398
399 fn visit_str<E>(self, string: &str) -> Result<Self::Value, E>
400 where
401 E: serde::de::Error,
402 {
403 self.visit_string(string.to_owned())
404 }
405
406 fn visit_string<E>(self, string: String) -> Result<Self::Value, E>
407 where
408 E: serde::de::Error,
409 {
410 if string.is_empty() {
411 return Err(E::custom("empty YAML tag is not allowed"));
412 }
413 Ok(Tag::new(string))
414 }
415}
416
417impl<'de> DeserializeSeed<'de> for TagStringVisitor {
418 type Value = Tag;
419
420 fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
421 where
422 D: Deserializer<'de>,
423 {
424 deserializer.deserialize_string(self)
425 }
426}
427
/// Result of [`check_for_tag`].
///
/// Holds either the tag name (without its leading `!`) or the complete
/// rendered text of a value that is not a tag.
pub(crate) enum MaybeTag<T> {
    /// The value rendered as `"!"` followed by one name chunk; holds the name.
    Tag(String),
    /// The value did not look like a tag; holds its full rendering.
    NotTag(T),
}

/// Renders `value` through its `Display` impl and decides whether it is a tag.
///
/// A value counts as a tag only when its non-empty `Display` output arrives as
/// exactly two chunks: a lone `"!"` followed by the tag name. That is the
/// shape produced by `Tag`'s `Display` impl (`write!(f, "!{}", ...)`). A
/// string that merely starts with `!` but is written as a single chunk, such
/// as a plain `&str` `"!foo"`, is *not* treated as a tag.
///
/// Returns `MaybeTag::Tag(name)` for a tag. Otherwise it returns
/// `MaybeTag::NotTag` holding the complete rendered text, including any
/// leading `!`.
///
/// # Panics
///
/// Panics if `value`'s `Display` impl returns an error, just as
/// `ToString::to_string` does.
pub(crate) fn check_for_tag<T>(value: &T) -> MaybeTag<String>
where
    T: ?Sized + Display,
{
    // Incremental classifier fed through `fmt::Write`. It tracks which
    // non-empty chunks have been written so far.
    enum CheckForTag {
        // Nothing non-empty written yet.
        Empty,
        // Exactly one chunk so far, and it was `"!"`.
        Bang,
        // `"!"` followed by exactly one more chunk; holds that chunk (no bang).
        Tag(String),
        // Anything else; holds the full rendered text so far.
        NotTag(String),
    }

    impl fmt::Write for CheckForTag {
        fn write_str(&mut self, s: &str) -> fmt::Result {
            // Empty chunks carry no information; ignore them so they do not
            // disturb the chunk-shape heuristic.
            if s.is_empty() {
                return Ok(());
            }
            match self {
                CheckForTag::Empty => {
                    if s == "!" {
                        *self = CheckForTag::Bang;
                    } else {
                        *self = CheckForTag::NotTag(s.to_owned());
                    }
                }
                CheckForTag::Bang => {
                    *self = CheckForTag::Tag(s.to_owned());
                }
                CheckForTag::Tag(name) => {
                    // A third chunk means this was not a bare `!tag` after all.
                    // The stored name lacks the `!` consumed by the `Bang`
                    // state, so restore it; `NotTag` must hold the complete
                    // rendering, otherwise the leading bang is silently lost.
                    let mut rendered = mem::take(name);
                    rendered.insert(0, '!');
                    rendered.push_str(s);
                    *self = CheckForTag::NotTag(rendered);
                }
                CheckForTag::NotTag(rendered) => {
                    rendered.push_str(s);
                }
            }
            Ok(())
        }
    }

    let mut check_for_tag = CheckForTag::Empty;
    // `CheckForTag::write_str` never fails, so an error here can only come
    // from `value`'s own `Display` impl.
    fmt::write(&mut check_for_tag, format_args!("{}", value))
        .expect("a Display implementation returned an error unexpectedly");
    match check_for_tag {
        CheckForTag::Empty => MaybeTag::NotTag(String::new()),
        CheckForTag::Bang => MaybeTag::NotTag("!".to_owned()),
        CheckForTag::Tag(string) => MaybeTag::Tag(string),
        CheckForTag::NotTag(string) => MaybeTag::NotTag(string),
    }
}