diff --git a/libs/patchwheel/patch.go b/libs/patchwheel/patch.go index 36f380eac..956871eb6 100644 --- a/libs/patchwheel/patch.go +++ b/libs/patchwheel/patch.go @@ -18,10 +18,7 @@ import ( "github.com/databricks/cli/libs/log" ) -const ( - versionKey = "Version:" - nameKey = "Name:" -) +var versionKey []byte = []byte("Version:") // findFiles returns a slice with a *zip.File for every filename in the arguments slice. // The order of the return value matches the order of the arguments. @@ -51,15 +48,17 @@ func patchMetadata(r io.Reader, oldVersion, newVersion string) ([]byte, error) { scanner := bufio.NewScanner(r) var buf bytes.Buffer for scanner.Scan() { - line := scanner.Text() - if versionValue, ok := strings.CutPrefix(line, versionKey); ok { - foundVersion := strings.TrimSpace(versionValue) + line := scanner.Bytes() + if versionValue, ok := bytes.CutPrefix(line, versionKey); ok { + foundVersion := string(bytes.TrimSpace(versionValue)) if foundVersion != oldVersion { - return nil, fmt.Errorf("Unexpected version in METADATA: %s (expected %s)", strings.TrimSpace(line), oldVersion) + return nil, fmt.Errorf("Unexpected version in METADATA: %s (expected %s)", strings.TrimSpace(string(line)), oldVersion) } - line = versionKey + newVersion + buf.WriteString(string(versionKey) + newVersion) + } else { + buf.Write(line) + buf.WriteString("\n") } - buf.WriteString(line + "\n") } if err := scanner.Err(); err != nil { return nil, err @@ -70,34 +69,42 @@ func patchMetadata(r io.Reader, oldVersion, newVersion string) ([]byte, error) { // patchRecord updates RECORD content: it replaces the old dist-info prefix with the new one // in all file paths and, for the METADATA entry, updates the hash and size. func patchRecord(r io.Reader, oldDistInfoPrefix, newDistInfoPrefix, metadataHash string, metadataSize int) ([]byte, error) { + metadataPath := newDistInfoPrefix + "METADATA" scanner := bufio.NewScanner(r) - var newLines []string + var buf bytes.Buffer for scanner.Scan() { - line := scanner.Text() - if strings.TrimSpace(line) == "" { + line := scanner.Bytes() + if len(bytes.TrimSpace(line)) == 0 { continue } - parts := strings.Split(line, ",") + + parts := strings.Split(string(line), ",") + if len(parts) < 3 { // If the line doesn't have enough parts, preserve it as-is - newLines = append(newLines, line) + buf.Write(line) + buf.WriteString("\n") continue } + origPath := parts[0] - if strings.HasPrefix(origPath, oldDistInfoPrefix) { - parts[0] = newDistInfoPrefix + origPath[len(oldDistInfoPrefix):] + pathSuffix, hasDistPrefix := strings.CutPrefix(origPath, oldDistInfoPrefix) + if hasDistPrefix { + parts[0] = newDistInfoPrefix + pathSuffix } - // For the METADATA file entry, update hash and size. - if strings.HasSuffix(parts[0], "METADATA") { + + if metadataPath == parts[0] { parts[1] = "sha256=" + metadataHash parts[2] = strconv.Itoa(metadataSize) } - newLines = append(newLines, strings.Join(parts, ",")) + + buf.WriteString(strings.Join(parts, ",") + "\n") } if err := scanner.Err(); err != nil { return nil, err } - return []byte(strings.Join(newLines, "\n") + "\n"), nil + buf.WriteString("\n") + return buf.Bytes(), nil } // PatchWheel patches a Python wheel file by updating its version in METADATA and RECORD.