Skip to content

Commit c5e0077

Browse files
zstd: Fix encoder changing dictionary with same ID (#1135)
* zstd: Fix encoder changing dictionary with same ID Fix crash from `Encoder.ResetWithOptions` when replacing a dictionary with another one with the same ID. Dictionary tables would not get reset appropriately when there was a new ID, leading to references outside of valid entries. * Fix changing non-dict <-> dict. * Update zstd/encoder.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * Better tests. --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent fd3f23e commit c5e0077

File tree

8 files changed

+191
-15
lines changed

8 files changed

+191
-15
lines changed

zstd/blockenc.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ func (b *blockEnc) initNewEncode() {
7878
b.recentOffsets = [3]uint32{1, 4, 8}
7979
b.litEnc.Reuse = huff0.ReusePolicyNone
8080
b.coders.setPrev(nil, nil, nil)
81+
b.dictLitEnc = nil
8182
}
8283

8384
// reset will reset the block for a new encode, but in the same stream,

zstd/dict_test.go

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,3 +638,153 @@ func TestDecoderRawDict(t *testing.T) {
638638
t.Errorf("mismatch: got %q, wanted %q", out, ref)
639639
}
640640
}
641+
642+
// TestEncoderDictResetDifferentContent verifies that ResetWithOptions correctly
643+
// handles switching between raw dicts that share the same ID but have different
644+
// content lengths. Previously, the encoder cached dict tables by ID only, so a
645+
// shorter dict reusing the same ID would leave stale table entries pointing
646+
// beyond the new (shorter) history, causing an out-of-bounds panic in matchlen.
647+
func TestEncoderDictResetDifferentContent(t *testing.T) {
648+
// Two raw dicts: same ID, different content lengths.
649+
longDict := make([]byte, 700)
650+
for i := range longDict {
651+
longDict[i] = byte(i * 3)
652+
}
653+
shortDict := make([]byte, 120)
654+
for i := range shortDict {
655+
shortDict[i] = byte(i * 7)
656+
}
657+
658+
const dictID = 42
659+
// Payload reuses bytes from the tail of longDict (beyond shortDict's length).
660+
// This makes stale dict table entries match during encoding, triggering the
661+
// out-of-bounds access when the encoder uses the stale offset.
662+
payload := make([]byte, 200)
663+
copy(payload, longDict[500:])
664+
665+
for level := SpeedFastest; level < speedLast; level++ {
666+
t.Run(level.String(), func(t *testing.T) {
667+
enc, err := NewWriter(nil, WithEncoderConcurrency(1), WithEncoderLevel(level), WithEncoderDictRaw(dictID, longDict))
668+
if err != nil {
669+
t.Fatal(err)
670+
}
671+
672+
// Encode with long dict first to populate table entries at high offsets.
673+
enc.EncodeAll(payload, nil)
674+
675+
// Switch to shorter dict with same ID. This must rebuild the tables.
676+
if err := enc.ResetWithOptions(nil, WithEncoderDictRaw(dictID, shortDict)); err != nil {
677+
t.Fatal(err)
678+
}
679+
compressed := enc.EncodeAll(payload, nil)
680+
681+
// Verify round-trip with matching dict.
682+
dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderDictRaw(dictID, shortDict))
683+
if err != nil {
684+
t.Fatal(err)
685+
}
686+
defer dec.Close()
687+
got, err := dec.DecodeAll(compressed, nil)
688+
if err != nil {
689+
t.Fatal(err)
690+
}
691+
if !bytes.Equal(got, payload) {
692+
t.Errorf("round-trip mismatch: got %q, want %q", got, payload)
693+
}
694+
})
695+
}
696+
}
697+
698+
// TestEncoderDictAddViaReset verifies that adding/removing a dict via
699+
// ResetWithOptions works (requires recreating the encoder type).
700+
func TestEncoderDictAddViaReset(t *testing.T) {
701+
dict := make([]byte, 120)
702+
for i := range dict {
703+
dict[i] = byte(i)
704+
}
705+
payload := []byte("hello world, this is a test payload!!")
706+
707+
for level := SpeedFastest; level < speedLast; level++ {
708+
t.Run("nil-to-dict/"+level.String(), func(t *testing.T) {
709+
enc, err := NewWriter(nil, WithEncoderConcurrency(1), WithEncoderLevel(level))
710+
if err != nil {
711+
t.Fatal(err)
712+
}
713+
if err := enc.ResetWithOptions(nil, WithEncoderDictRaw(42, dict)); err != nil {
714+
t.Fatal(err)
715+
}
716+
compressed := enc.EncodeAll(payload, nil)
717+
718+
dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderDictRaw(42, dict))
719+
if err != nil {
720+
t.Fatal(err)
721+
}
722+
defer dec.Close()
723+
got, err := dec.DecodeAll(compressed, nil)
724+
if err != nil {
725+
t.Fatal(err)
726+
}
727+
if !bytes.Equal(got, payload) {
728+
t.Errorf("round-trip mismatch: got %q, want %q", got, payload)
729+
}
730+
})
731+
732+
t.Run("dict-to-nil/"+level.String(), func(t *testing.T) {
733+
enc, err := NewWriter(nil, WithEncoderConcurrency(1), WithEncoderLevel(level), WithEncoderDictRaw(42, dict))
734+
if err != nil {
735+
t.Fatal(err)
736+
}
737+
if err := enc.ResetWithOptions(nil, WithEncoderDictDelete()); err != nil {
738+
t.Fatal(err)
739+
}
740+
compressed := enc.EncodeAll(payload, nil)
741+
742+
dec, err := NewReader(nil, WithDecoderConcurrency(1))
743+
if err != nil {
744+
t.Fatal(err)
745+
}
746+
defer dec.Close()
747+
got, err := dec.DecodeAll(compressed, nil)
748+
if err != nil {
749+
t.Fatal(err)
750+
}
751+
if !bytes.Equal(got, payload) {
752+
t.Errorf("round-trip mismatch: got %q, want %q", got, payload)
753+
}
754+
})
755+
756+
t.Run("streaming-dict-to-nil/"+level.String(), func(t *testing.T) {
757+
var buf bytes.Buffer
758+
enc, err := NewWriter(&buf, WithEncoderConcurrency(2), WithEncoderLevel(level), WithEncoderDictRaw(42, dict))
759+
if err != nil {
760+
t.Fatal(err)
761+
}
762+
enc.Close()
763+
764+
buf.Reset()
765+
if err := enc.ResetWithOptions(&buf, WithEncoderDictDelete()); err != nil {
766+
t.Fatal(err)
767+
}
768+
_, err = enc.Write(payload)
769+
if err != nil {
770+
t.Fatal(err)
771+
}
772+
if err := enc.Close(); err != nil {
773+
t.Fatal(err)
774+
}
775+
776+
dec, err := NewReader(nil, WithDecoderConcurrency(1))
777+
if err != nil {
778+
t.Fatal(err)
779+
}
780+
defer dec.Close()
781+
got, err := dec.DecodeAll(buf.Bytes(), nil)
782+
if err != nil {
783+
t.Fatal(err)
784+
}
785+
if !bytes.Equal(got, payload) {
786+
t.Errorf("round-trip mismatch: got %q, want %q", got, payload)
787+
}
788+
})
789+
}
790+
}

zstd/enc_base.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ type fastBase struct {
2121
crc *xxhash.Digest
2222
tmp [8]byte
2323
blk *blockEnc
24-
lastDictID uint32
24+
lastDict *dict
2525
lowMem bool
2626
}
2727

zstd/enc_best.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -479,10 +479,13 @@ func (e *bestFastEncoder) Reset(d *dict, singleBlock bool) {
479479
if d == nil {
480480
return
481481
}
482+
dictChanged := d != e.lastDict
482483
// Init or copy dict table
483-
if len(e.dictTable) != len(e.table) || d.id != e.lastDictID {
484+
if len(e.dictTable) != len(e.table) || dictChanged {
484485
if len(e.dictTable) != len(e.table) {
485486
e.dictTable = make([]prevEntry, len(e.table))
487+
} else {
488+
clear(e.dictTable)
486489
}
487490
end := int32(len(d.content)) - 8 + e.maxMatchOff
488491
for i := e.maxMatchOff; i < end; i += 4 {
@@ -510,13 +513,14 @@ func (e *bestFastEncoder) Reset(d *dict, singleBlock bool) {
510513
offset: i + 3,
511514
}
512515
}
513-
e.lastDictID = d.id
514516
}
515517

516-
// Init or copy dict table
517-
if len(e.dictLongTable) != len(e.longTable) || d.id != e.lastDictID {
518+
// Init or copy dict long table
519+
if len(e.dictLongTable) != len(e.longTable) || dictChanged {
518520
if len(e.dictLongTable) != len(e.longTable) {
519521
e.dictLongTable = make([]prevEntry, len(e.longTable))
522+
} else {
523+
clear(e.dictLongTable)
520524
}
521525
if len(d.content) >= 8 {
522526
cv := load6432(d.content, 0)
@@ -538,8 +542,8 @@ func (e *bestFastEncoder) Reset(d *dict, singleBlock bool) {
538542
off++
539543
}
540544
}
541-
e.lastDictID = d.id
542545
}
546+
e.lastDict = d
543547
// Reset table to initial state
544548
copy(e.longTable[:], e.dictLongTable)
545549

zstd/enc_better.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,10 +1102,13 @@ func (e *betterFastEncoderDict) Reset(d *dict, singleBlock bool) {
11021102
if d == nil {
11031103
return
11041104
}
1105+
dictChanged := d != e.lastDict
11051106
// Init or copy dict table
1106-
if len(e.dictTable) != len(e.table) || d.id != e.lastDictID {
1107+
if len(e.dictTable) != len(e.table) || dictChanged {
11071108
if len(e.dictTable) != len(e.table) {
11081109
e.dictTable = make([]tableEntry, len(e.table))
1110+
} else {
1111+
clear(e.dictTable)
11091112
}
11101113
end := int32(len(d.content)) - 8 + e.maxMatchOff
11111114
for i := e.maxMatchOff; i < end; i += 4 {
@@ -1133,14 +1136,15 @@ func (e *betterFastEncoderDict) Reset(d *dict, singleBlock bool) {
11331136
offset: i + 3,
11341137
}
11351138
}
1136-
e.lastDictID = d.id
11371139
e.allDirty = true
11381140
}
11391141

1140-
// Init or copy dict table
1141-
if len(e.dictLongTable) != len(e.longTable) || d.id != e.lastDictID {
1142+
// Init or copy dict long table
1143+
if len(e.dictLongTable) != len(e.longTable) || dictChanged {
11421144
if len(e.dictLongTable) != len(e.longTable) {
11431145
e.dictLongTable = make([]prevEntry, len(e.longTable))
1146+
} else {
1147+
clear(e.dictLongTable)
11441148
}
11451149
if len(d.content) >= 8 {
11461150
cv := load6432(d.content, 0)
@@ -1162,9 +1166,9 @@ func (e *betterFastEncoderDict) Reset(d *dict, singleBlock bool) {
11621166
off++
11631167
}
11641168
}
1165-
e.lastDictID = d.id
11661169
e.allDirty = true
11671170
}
1171+
e.lastDict = d
11681172

11691173
// Reset table to initial state
11701174
{

zstd/enc_dfast.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,15 +1040,18 @@ func (e *doubleFastEncoder) Reset(d *dict, singleBlock bool) {
10401040
// Reset will reset and set a dictionary if not nil
10411041
func (e *doubleFastEncoderDict) Reset(d *dict, singleBlock bool) {
10421042
allDirty := e.allDirty
1043+
dictChanged := d != e.lastDict
10431044
e.fastEncoderDict.Reset(d, singleBlock)
10441045
if d == nil {
10451046
return
10461047
}
10471048

10481049
// Init or copy dict table
1049-
if len(e.dictLongTable) != len(e.longTable) || d.id != e.lastDictID {
1050+
if len(e.dictLongTable) != len(e.longTable) || dictChanged {
10501051
if len(e.dictLongTable) != len(e.longTable) {
10511052
e.dictLongTable = make([]tableEntry, len(e.longTable))
1053+
} else {
1054+
clear(e.dictLongTable)
10521055
}
10531056
if len(d.content) >= 8 {
10541057
cv := load6432(d.content, 0)
@@ -1065,7 +1068,6 @@ func (e *doubleFastEncoderDict) Reset(d *dict, singleBlock bool) {
10651068
}
10661069
}
10671070
}
1068-
e.lastDictID = d.id
10691071
allDirty = true
10701072
}
10711073
// Reset table to initial state

zstd/enc_fast.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -805,9 +805,11 @@ func (e *fastEncoderDict) Reset(d *dict, singleBlock bool) {
805805
}
806806

807807
// Init or copy dict table
808-
if len(e.dictTable) != len(e.table) || d.id != e.lastDictID {
808+
if len(e.dictTable) != len(e.table) || d != e.lastDict {
809809
if len(e.dictTable) != len(e.table) {
810810
e.dictTable = make([]tableEntry, len(e.table))
811+
} else {
812+
clear(e.dictTable)
811813
}
812814
if true {
813815
end := e.maxMatchOff + int32(len(d.content)) - 8
@@ -827,7 +829,7 @@ func (e *fastEncoderDict) Reset(d *dict, singleBlock bool) {
827829
}
828830
}
829831
}
830-
e.lastDictID = d.id
832+
e.lastDict = d
831833
e.allDirty = true
832834
}
833835

zstd/encoder.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,18 @@ func (e *Encoder) Reset(w io.Writer) {
138138
func (e *Encoder) ResetWithOptions(w io.Writer, opts ...EOption) error {
139139
e.o.resetOpt = true
140140
defer func() { e.o.resetOpt = false }()
141+
hadDict := e.o.dict != nil
141142
for _, o := range opts {
142143
if err := o(&e.o); err != nil {
143144
return err
144145
}
145146
}
147+
hasDict := e.o.dict != nil
148+
if hadDict != hasDict {
149+
// Dict presence changed — encoder type must be recreated.
150+
e.state.encoder = nil
151+
e.init = sync.Once{}
152+
}
146153
e.Reset(w)
147154
return nil
148155
}
@@ -448,6 +455,12 @@ func (e *Encoder) Close() error {
448455
if s.encoder == nil {
449456
return nil
450457
}
458+
if s.w == nil {
459+
if len(s.filling) == 0 && !s.headerWritten && !s.eofWritten && s.nInput == 0 {
460+
return nil
461+
}
462+
return errors.New("zstd: encoder has no writer")
463+
}
451464
err := e.nextBlock(true)
452465
if err != nil {
453466
if errors.Is(s.err, ErrEncoderClosed) {

0 commit comments

Comments
 (0)