This commit is contained in:
Denis Bilenko 2025-03-04 12:19:11 +01:00
parent ab1667814d
commit 98d2f38129
1 changed files with 27 additions and 14 deletions

View File

@ -7,6 +7,7 @@ import (
"context" "context"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -147,13 +148,13 @@ func PatchWheel(ctx context.Context, path, outputDir string) (string, error) {
// Target wheel doesn't exist, proceed with patching // Target wheel doesn't exist, proceed with patching
// Create a temporary file in the same directory with a unique name // Create a temporary file in the same directory with a unique name
tmpFile := outpath + fmt.Sprintf(".tmp%d", os.Getpid()) tmpFilename := outpath + fmt.Sprintf(".tmp%d", os.Getpid())
needRemoval := true needRemoval := true
defer func() { defer func() {
if needRemoval { if needRemoval {
_ = os.Remove(tmpFile) _ = os.Remove(tmpFilename)
} }
}() }()
@ -178,12 +179,6 @@ func PatchWheel(ctx context.Context, path, outputDir string) (string, error) {
} }
defer metadataReader.Close() defer metadataReader.Close()
recordReader, err := recordFile.Open()
if err != nil {
return "", err
}
defer recordReader.Close()
newMetadata, err := patchMetadata(metadataReader, wheelInfo.Version, newVersion) newMetadata, err := patchMetadata(metadataReader, wheelInfo.Version, newVersion)
if err != nil { if err != nil {
return "", err return "", err
@ -203,17 +198,26 @@ func PatchWheel(ctx context.Context, path, outputDir string) (string, error) {
return "", fmt.Errorf("unexpected dist-info directory format: %s", oldDistInfoPrefix) return "", fmt.Errorf("unexpected dist-info directory format: %s", oldDistInfoPrefix)
} }
recordReader, err := recordFile.Open()
if err != nil {
return "", err
}
defer recordReader.Close()
newRecord, err := patchRecord(recordReader, oldDistInfoPrefix, newDistInfoPrefix, metadataHash, metadataSize) newRecord, err := patchRecord(recordReader, oldDistInfoPrefix, newDistInfoPrefix, metadataHash, metadataSize)
if err != nil { if err != nil {
return "", err return "", err
} }
outFile, err := os.Create(tmpFile) outFile, err := os.Create(tmpFilename)
if err != nil { if err != nil {
return "", err return "", err
} }
defer outFile.Close() defer outFile.Close()
metadataUpdated := 0
recordUpdated := 0
zipw := zip.NewWriter(outFile) zipw := zip.NewWriter(outFile)
for _, f := range r.File { for _, f := range r.File {
// If the file is inside the old dist-info directory, update its name. // If the file is inside the old dist-info directory, update its name.
@ -237,17 +241,18 @@ func PatchWheel(ctx context.Context, path, outputDir string) (string, error) {
return "", err return "", err
} }
// For METADATA and RECORD files, write the modified content. if f.Name == metadataFile.Name {
if strings.HasSuffix(f.Name, "METADATA") && strings.HasPrefix(f.Name, oldDistInfoPrefix) {
_, err = writer.Write(newMetadata) _, err = writer.Write(newMetadata)
if err != nil { if err != nil {
return "", err return "", err
} }
} else if strings.HasSuffix(f.Name, "RECORD") && strings.HasPrefix(f.Name, oldDistInfoPrefix) { metadataUpdated += 1
} else if f.Name == recordFile.Name {
_, err = writer.Write(newRecord) _, err = writer.Write(newRecord)
if err != nil { if err != nil {
return "", err return "", err
} }
recordUpdated += 1
} else { } else {
rc, err := f.Open() rc, err := f.Open()
if err != nil { if err != nil {
@ -270,10 +275,18 @@ func PatchWheel(ctx context.Context, path, outputDir string) (string, error) {
outFile.Close() outFile.Close()
if err := os.Rename(tmpFile, outpath); err != nil { if metadataUpdated != 1 {
return "", err return "", errors.New("Could not update METADATA")
} }
if recordUpdated != 1 {
return "", errors.New("Could not update RECORD")
}
if err := os.Rename(tmpFilename, outpath); err != nil {
return "", err
}
needRemoval = false needRemoval = false
return outpath, nil return outpath, nil
} }