package yamlsaver import ( "fmt" "io" "os" "path/filepath" "sort" "strconv" "github.com/databricks/cli/libs/dyn" "gopkg.in/yaml.v3" ) type saver struct { nodesWithStyle map[string]yaml.Style } func NewSaver() *saver { return &saver{} } func NewSaverWithStyle(nodesWithStyle map[string]yaml.Style) *saver { return &saver{ nodesWithStyle: nodesWithStyle, } } func (s *saver) SaveAsYAML(data any, filename string, force bool) error { err := os.MkdirAll(filepath.Dir(filename), 0o755) if err != nil { return err } // check that file exists info, err := os.Stat(filename) if err == nil { if info.IsDir() { return fmt.Errorf("%s is a directory", filename) } if !force { return fmt.Errorf("%s already exists. Use --force to overwrite", filename) } } file, err := os.Create(filename) if err != nil { return err } defer file.Close() err = s.encode(data, file) if err != nil { return err } return nil } func (s *saver) encode(data any, w io.Writer) error { yamlNode, err := s.toYamlNode(dyn.V(data)) if err != nil { return err } enc := yaml.NewEncoder(w) enc.SetIndent(2) return enc.Encode(yamlNode) } func (s *saver) toYamlNode(v dyn.Value) (*yaml.Node, error) { return s.toYamlNodeWithStyle(v, yaml.Style(0)) } func (s *saver) toYamlNodeWithStyle(v dyn.Value, style yaml.Style) (*yaml.Node, error) { switch v.Kind() { case dyn.KindMap: m, _ := v.AsMap() // We're using location lines to define the order of keys in YAML. // The location is set when we convert API response struct to config.Value representation // See convert.convertMap for details pairs := m.Pairs() sort.SliceStable(pairs, func(i, j int) bool { return pairs[i].Value.Location().Line < pairs[j].Value.Location().Line }) content := make([]*yaml.Node, 0) for _, pair := range pairs { pk := pair.Key pv := pair.Value node := yaml.Node{Kind: yaml.ScalarNode, Value: pk.MustString(), Style: style} var nestedNodeStyle yaml.Style if customStyle, ok := s.hasStyle(pk.MustString()); ok { nestedNodeStyle = customStyle } else { nestedNodeStyle = style } c, err := s.toYamlNodeWithStyle(pv, nestedNodeStyle) if err != nil { return nil, err } content = append(content, &node) content = append(content, c) } return &yaml.Node{Kind: yaml.MappingNode, Content: content, Style: style}, nil case dyn.KindSequence: seq, _ := v.AsSequence() content := make([]*yaml.Node, 0) for _, item := range seq { node, err := s.toYamlNodeWithStyle(item, style) if err != nil { return nil, err } content = append(content, node) } return &yaml.Node{Kind: yaml.SequenceNode, Content: content, Style: style}, nil case dyn.KindNil: return &yaml.Node{Kind: yaml.ScalarNode, Value: "null", Style: style}, nil case dyn.KindString: // If the string is a scalar value (bool, int, float and etc.), we want to quote it. if isScalarValueInString(v) { return &yaml.Node{Kind: yaml.ScalarNode, Value: v.MustString(), Style: yaml.DoubleQuotedStyle}, nil } return &yaml.Node{Kind: yaml.ScalarNode, Value: v.MustString(), Style: style}, nil case dyn.KindBool: return &yaml.Node{Kind: yaml.ScalarNode, Value: strconv.FormatBool(v.MustBool()), Style: style}, nil case dyn.KindInt: return &yaml.Node{Kind: yaml.ScalarNode, Value: strconv.FormatInt(v.MustInt(), 10), Style: style}, nil case dyn.KindFloat: return &yaml.Node{Kind: yaml.ScalarNode, Value: fmt.Sprint(v.MustFloat()), Style: style}, nil case dyn.KindTime: return &yaml.Node{Kind: yaml.ScalarNode, Value: v.MustTime().String(), Style: style}, nil default: // Panic because we only want to deal with known types. panic(fmt.Sprintf("invalid kind: %d", v.Kind())) } } func (s *saver) hasStyle(key string) (yaml.Style, bool) { style, ok := s.nodesWithStyle[key] return style, ok } func isScalarValueInString(v dyn.Value) bool { if v.Kind() != dyn.KindString { return false } // Parse value of the string and check if it's a scalar value. // If it's a scalar value, we want to quote it. switch v.MustString() { case "true", "false": return true case "": return true default: _, err := parseNumber(v.MustString()) return err == nil } } func parseNumber(s string) (any, error) { if i, err := strconv.ParseInt(s, 0, 64); err == nil { return i, nil } if f, err := strconv.ParseFloat(s, 64); err == nil { return f, nil } return nil, fmt.Errorf("invalid number: %s", s) }