2024-01-17 14:26:33 +00:00
|
|
|
package yamlsaver
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
"os"
|
|
|
|
"path/filepath"
|
|
|
|
"sort"
|
|
|
|
"strconv"
|
|
|
|
|
|
|
|
"github.com/databricks/cli/libs/dyn"
|
|
|
|
"gopkg.in/yaml.v3"
|
|
|
|
)
|
|
|
|
|
2024-02-15 15:03:19 +00:00
|
|
|
type saver struct {
|
|
|
|
nodesWithStyle map[string]yaml.Style
|
|
|
|
}
|
|
|
|
|
|
|
|
func NewSaver() *saver {
|
|
|
|
return &saver{}
|
|
|
|
}
|
|
|
|
|
|
|
|
func NewSaverWithStyle(nodesWithStyle map[string]yaml.Style) *saver {
|
|
|
|
return &saver{
|
|
|
|
nodesWithStyle: nodesWithStyle,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *saver) SaveAsYAML(data any, filename string, force bool) error {
|
2024-01-17 14:26:33 +00:00
|
|
|
err := os.MkdirAll(filepath.Dir(filename), 0755)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// check that file exists
|
|
|
|
info, err := os.Stat(filename)
|
|
|
|
if err == nil {
|
|
|
|
if info.IsDir() {
|
|
|
|
return fmt.Errorf("%s is a directory", filename)
|
|
|
|
}
|
|
|
|
if !force {
|
|
|
|
return fmt.Errorf("%s already exists. Use --force to overwrite", filename)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
file, err := os.Create(filename)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
defer file.Close()
|
|
|
|
|
2024-02-15 15:03:19 +00:00
|
|
|
err = s.encode(data, file)
|
2024-01-17 14:26:33 +00:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2024-02-15 15:03:19 +00:00
|
|
|
func (s *saver) encode(data any, w io.Writer) error {
|
|
|
|
yamlNode, err := s.toYamlNode(dyn.V(data))
|
2024-01-17 14:26:33 +00:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
enc := yaml.NewEncoder(w)
|
|
|
|
enc.SetIndent(2)
|
|
|
|
return enc.Encode(yamlNode)
|
|
|
|
}
|
|
|
|
|
2024-02-15 15:03:19 +00:00
|
|
|
func (s *saver) toYamlNode(v dyn.Value) (*yaml.Node, error) {
|
|
|
|
return s.toYamlNodeWithStyle(v, yaml.Style(0))
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *saver) toYamlNodeWithStyle(v dyn.Value, style yaml.Style) (*yaml.Node, error) {
|
2024-01-17 14:26:33 +00:00
|
|
|
switch v.Kind() {
|
|
|
|
case dyn.KindMap:
|
|
|
|
m, _ := v.AsMap()
|
2024-03-25 11:01:09 +00:00
|
|
|
|
2024-01-17 14:26:33 +00:00
|
|
|
// We're using location lines to define the order of keys in YAML.
|
|
|
|
// The location is set when we convert API response struct to config.Value representation
|
|
|
|
// See convert.convertMap for details
|
2024-03-25 11:01:09 +00:00
|
|
|
pairs := m.Pairs()
|
|
|
|
sort.SliceStable(pairs, func(i, j int) bool {
|
|
|
|
return pairs[i].Value.Location().Line < pairs[j].Value.Location().Line
|
2024-01-17 14:26:33 +00:00
|
|
|
})
|
|
|
|
|
|
|
|
content := make([]*yaml.Node, 0)
|
2024-03-25 11:01:09 +00:00
|
|
|
for _, pair := range pairs {
|
|
|
|
pk := pair.Key
|
|
|
|
pv := pair.Value
|
|
|
|
node := yaml.Node{Kind: yaml.ScalarNode, Value: pk.MustString(), Style: style}
|
2024-02-15 15:03:19 +00:00
|
|
|
var nestedNodeStyle yaml.Style
|
2024-03-25 11:01:09 +00:00
|
|
|
if customStyle, ok := s.hasStyle(pk.MustString()); ok {
|
2024-02-15 15:03:19 +00:00
|
|
|
nestedNodeStyle = customStyle
|
|
|
|
} else {
|
|
|
|
nestedNodeStyle = style
|
|
|
|
}
|
2024-03-25 11:01:09 +00:00
|
|
|
c, err := s.toYamlNodeWithStyle(pv, nestedNodeStyle)
|
2024-01-17 14:26:33 +00:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
content = append(content, &node)
|
|
|
|
content = append(content, c)
|
|
|
|
}
|
|
|
|
|
2024-02-15 15:03:19 +00:00
|
|
|
return &yaml.Node{Kind: yaml.MappingNode, Content: content, Style: style}, nil
|
2024-01-17 14:26:33 +00:00
|
|
|
case dyn.KindSequence:
|
2024-02-15 15:03:19 +00:00
|
|
|
seq, _ := v.AsSequence()
|
2024-01-17 14:26:33 +00:00
|
|
|
content := make([]*yaml.Node, 0)
|
2024-02-15 15:03:19 +00:00
|
|
|
for _, item := range seq {
|
|
|
|
node, err := s.toYamlNodeWithStyle(item, style)
|
2024-01-17 14:26:33 +00:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
content = append(content, node)
|
|
|
|
}
|
2024-02-15 15:03:19 +00:00
|
|
|
return &yaml.Node{Kind: yaml.SequenceNode, Content: content, Style: style}, nil
|
2024-01-17 14:26:33 +00:00
|
|
|
case dyn.KindNil:
|
2024-02-15 15:03:19 +00:00
|
|
|
return &yaml.Node{Kind: yaml.ScalarNode, Value: "null", Style: style}, nil
|
2024-01-17 14:26:33 +00:00
|
|
|
case dyn.KindString:
|
|
|
|
// If the string is a scalar value (bool, int, float and etc.), we want to quote it.
|
|
|
|
if isScalarValueInString(v) {
|
|
|
|
return &yaml.Node{Kind: yaml.ScalarNode, Value: v.MustString(), Style: yaml.DoubleQuotedStyle}, nil
|
|
|
|
}
|
2024-02-15 15:03:19 +00:00
|
|
|
return &yaml.Node{Kind: yaml.ScalarNode, Value: v.MustString(), Style: style}, nil
|
2024-01-17 14:26:33 +00:00
|
|
|
case dyn.KindBool:
|
2024-02-15 15:03:19 +00:00
|
|
|
return &yaml.Node{Kind: yaml.ScalarNode, Value: fmt.Sprint(v.MustBool()), Style: style}, nil
|
2024-01-17 14:26:33 +00:00
|
|
|
case dyn.KindInt:
|
2024-02-15 15:03:19 +00:00
|
|
|
return &yaml.Node{Kind: yaml.ScalarNode, Value: fmt.Sprint(v.MustInt()), Style: style}, nil
|
2024-01-17 14:26:33 +00:00
|
|
|
case dyn.KindFloat:
|
2024-02-15 15:03:19 +00:00
|
|
|
return &yaml.Node{Kind: yaml.ScalarNode, Value: fmt.Sprint(v.MustFloat()), Style: style}, nil
|
2024-01-17 14:26:33 +00:00
|
|
|
case dyn.KindTime:
|
2024-08-29 13:02:34 +00:00
|
|
|
return &yaml.Node{Kind: yaml.ScalarNode, Value: v.MustTime().String(), Style: style}, nil
|
2024-01-17 14:26:33 +00:00
|
|
|
default:
|
|
|
|
// Panic because we only want to deal with known types.
|
|
|
|
panic(fmt.Sprintf("invalid kind: %d", v.Kind()))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-02-15 15:03:19 +00:00
|
|
|
func (s *saver) hasStyle(key string) (yaml.Style, bool) {
|
|
|
|
style, ok := s.nodesWithStyle[key]
|
|
|
|
return style, ok
|
|
|
|
}
|
|
|
|
|
2024-01-17 14:26:33 +00:00
|
|
|
func isScalarValueInString(v dyn.Value) bool {
|
|
|
|
if v.Kind() != dyn.KindString {
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
|
|
|
|
// Parse value of the string and check if it's a scalar value.
|
|
|
|
// If it's a scalar value, we want to quote it.
|
|
|
|
switch v.MustString() {
|
|
|
|
case "true", "false":
|
|
|
|
return true
|
2024-09-11 09:49:58 +00:00
|
|
|
case "":
|
|
|
|
return true
|
2024-01-17 14:26:33 +00:00
|
|
|
default:
|
|
|
|
_, err := parseNumber(v.MustString())
|
|
|
|
return err == nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// parseNumber interprets s as a number. It first tries a 64-bit integer with
// the base taken from the literal's prefix (strconv.ParseInt with base 0) and
// returns an int64 on success. Failing that, it tries a 64-bit float and
// returns a float64. If neither parse succeeds it returns a nil value and an
// error naming the input.
func parseNumber(s string) (any, error) {
	intValue, intErr := strconv.ParseInt(s, 0, 64)
	if intErr == nil {
		return intValue, nil
	}

	floatValue, floatErr := strconv.ParseFloat(s, 64)
	if floatErr == nil {
		return floatValue, nil
	}

	return nil, fmt.Errorf("invalid number: %s", s)
}
|