Skip to content

Commit 9af71d1

Browse files
Add EncryptReader
Add EncryptReader function to encrypt io.Reader (pull-based encryption) and use it in to get rid of io.Copy. Updates #644 Signed-off-by: Alexander Yastrebov <[email protected]>
1 parent 15153e6 commit 9af71d1

File tree

4 files changed

+145
-31
lines changed

4 files changed

+145
-31
lines changed

age.go

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ type Stanza struct {
113113
const fileKeySize = 16
114114
const streamNonceSize = 16
115115

116+
var errNoRecipients = errors.New("no recipients specified")
117+
116118
// Encrypt encrypts a file to one or more recipients.
117119
//
118120
// Writes to the returned WriteCloser are encrypted and written to dst as an age
@@ -122,49 +124,81 @@ const streamNonceSize = 16
122124
// be encrypted and flushed to dst.
123125
func Encrypt(dst io.Writer, recipients ...Recipient) (io.WriteCloser, error) {
124126
if len(recipients) == 0 {
125-
return nil, errors.New("no recipients specified")
127+
return nil, errNoRecipients
126128
}
127129

128130
fileKey := make([]byte, fileKeySize)
129131
if _, err := rand.Read(fileKey); err != nil {
130132
return nil, err
131133
}
132134

135+
if err := writeHeader(dst, recipients, fileKey); err != nil {
136+
return nil, fmt.Errorf("failed to write header: %v", err)
137+
}
138+
139+
nonce := make([]byte, streamNonceSize)
140+
if _, err := rand.Read(nonce); err != nil {
141+
return nil, err
142+
}
143+
if _, err := dst.Write(nonce); err != nil {
144+
return nil, fmt.Errorf("failed to write nonce: %v", err)
145+
}
146+
147+
return stream.NewWriter(streamKey(fileKey, nonce), dst)
148+
}
149+
150+
// EncryptReader encrypts a reader to one or more recipients.
151+
//
152+
// It encrypts src into dst as an age file. Every recipient will be able to decrypt the file.
153+
func EncryptReader(src io.Reader, dst io.Writer, recipients ...Recipient) error {
154+
if len(recipients) == 0 {
155+
return errNoRecipients
156+
}
157+
158+
fileKey := make([]byte, fileKeySize)
159+
if _, err := rand.Read(fileKey); err != nil {
160+
return err
161+
}
162+
163+
if err := writeHeader(dst, recipients, fileKey); err != nil {
164+
return fmt.Errorf("failed to write header: %v", err)
165+
}
166+
167+
nonce := make([]byte, streamNonceSize)
168+
if _, err := rand.Read(nonce); err != nil {
169+
return err
170+
}
171+
if _, err := dst.Write(nonce); err != nil {
172+
return fmt.Errorf("failed to write nonce: %v", err)
173+
}
174+
175+
return stream.Encrypt(streamKey(fileKey, nonce), src, dst)
176+
}
177+
178+
func writeHeader(dst io.Writer, recipients []Recipient, fileKey []byte) error {
133179
hdr := &format.Header{}
134180
var labels []string
135181
for i, r := range recipients {
136182
stanzas, l, err := wrapWithLabels(r, fileKey)
137183
if err != nil {
138-
return nil, fmt.Errorf("failed to wrap key for recipient #%d: %v", i, err)
184+
return fmt.Errorf("failed to wrap key for recipient #%d: %v", i, err)
139185
}
140186
sort.Strings(l)
141187
if i == 0 {
142188
labels = l
143189
} else if !slicesEqual(labels, l) {
144-
return nil, fmt.Errorf("incompatible recipients")
190+
return fmt.Errorf("incompatible recipients")
145191
}
146192
for _, s := range stanzas {
147193
hdr.Recipients = append(hdr.Recipients, (*format.Stanza)(s))
148194
}
149195
}
150196
if mac, err := headerMAC(fileKey, hdr); err != nil {
151-
return nil, fmt.Errorf("failed to compute header MAC: %v", err)
197+
return fmt.Errorf("failed to compute header MAC: %v", err)
152198
} else {
153199
hdr.MAC = mac
154200
}
155-
if err := hdr.Marshal(dst); err != nil {
156-
return nil, fmt.Errorf("failed to write header: %v", err)
157-
}
158-
159-
nonce := make([]byte, streamNonceSize)
160-
if _, err := rand.Read(nonce); err != nil {
161-
return nil, err
162-
}
163-
if _, err := dst.Write(nonce); err != nil {
164-
return nil, fmt.Errorf("failed to write nonce: %v", err)
165-
}
166-
167-
return stream.NewWriter(streamKey(fileKey, nonce), dst)
201+
return hdr.Marshal(dst)
168202
}
169203

170204
func wrapWithLabels(r Recipient, fileKey []byte) (s []*Stanza, labels []string, err error) {

cmd/age/age.go

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -402,16 +402,10 @@ func encrypt(recipients []age.Recipient, in io.Reader, out io.Writer, withArmor
402402
}()
403403
out = a
404404
}
405-
w, err := age.Encrypt(out, recipients...)
405+
err := age.EncryptReader(in, out, recipients...)
406406
if err != nil {
407407
errorf("%v", err)
408408
}
409-
if _, err := io.Copy(w, in); err != nil {
410-
errorf("%v", err)
411-
}
412-
if err := w.Close(); err != nil {
413-
errorf("%v", err)
414-
}
415409
}
416410

417411
// crlfMangledIntro and utf16MangledIntro are the intro lines of the age format

internal/stream/stream.go

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -215,16 +215,62 @@ const (
215215
)
216216

217217
func (w *Writer) flushChunk(last bool) error {
218-
if !last && len(w.unwritten) != ChunkSize {
219-
panic("stream: internal error: flush called with partial chunk")
218+
err := writeChunk(w.dst, w.a, &w.nonce, w.unwritten, last)
219+
w.unwritten = w.buf[:0]
220+
return err
221+
}
222+
223+
func writeChunk(dst io.Writer, aead cipher.AEAD, nonce *[chacha20poly1305.NonceSize]byte, plaintext []byte, last bool) error {
224+
if !last && len(plaintext) != ChunkSize {
225+
panic("stream: internal error: writeChunk called with partial chunk")
220226
}
221227

222228
if last {
223-
setLastChunkFlag(&w.nonce)
229+
setLastChunkFlag(nonce)
224230
}
225-
buf := w.a.Seal(w.buf[:0], w.nonce[:], w.unwritten, nil)
226-
_, err := w.dst.Write(buf)
227-
w.unwritten = w.buf[:0]
228-
incNonce(&w.nonce)
231+
buf := aead.Seal(plaintext[:0], nonce[:], plaintext, nil)
232+
_, err := dst.Write(buf)
233+
incNonce(nonce)
229234
return err
230235
}
236+
237+
func Encrypt(key []byte, src io.Reader, dst io.Writer) error {
238+
aead, err := chacha20poly1305.New(key)
239+
if err != nil {
240+
return err
241+
}
242+
243+
var nonce [chacha20poly1305.NonceSize]byte
244+
var bufs [2][encChunkSize]byte
245+
prevPending := false
246+
247+
for prev, curr := 0, 1; ; prev, curr = prev^1, curr^1 {
248+
n, err := io.ReadFull(src, bufs[curr][:ChunkSize])
249+
if err == io.EOF {
250+
if prevPending {
251+
return writeChunk(dst, aead, &nonce, bufs[prev][:ChunkSize], lastChunk)
252+
} else { // empty payload
253+
return writeChunk(dst, aead, &nonce, bufs[prev][:0], lastChunk)
254+
}
255+
} else if err == io.ErrUnexpectedEOF {
256+
if prevPending {
257+
err := writeChunk(dst, aead, &nonce, bufs[prev][:ChunkSize], notLastChunk)
258+
if err != nil {
259+
return err
260+
}
261+
}
262+
return writeChunk(dst, aead, &nonce, bufs[curr][:n], lastChunk)
263+
} else if err != nil {
264+
return err
265+
}
266+
267+
if prevPending {
268+
err := writeChunk(dst, aead, &nonce, bufs[prev][:ChunkSize], notLastChunk)
269+
if err != nil {
270+
return err
271+
}
272+
} else {
273+
prevPending = true
274+
}
275+
}
276+
}

internal/stream/stream_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"bytes"
99
"crypto/rand"
1010
"fmt"
11+
"io"
1112
"testing"
1213

1314
"filippo.io/age/internal/stream"
@@ -91,3 +92,42 @@ func testRoundTrip(t *testing.T, stepSize, length int) {
9192
n += nn
9293
}
9394
}
95+
96+
func TestEncrypt(t *testing.T) {
97+
for _, mul := range []int{0, 1, 2, 3} {
98+
for _, add := range []int{0, 1, 2, 3, stream.ChunkSize - 1} {
99+
length := mul*stream.ChunkSize + add
100+
101+
t.Run(fmt.Sprintf("length=%d", length), func(t *testing.T) {
102+
src := make([]byte, length)
103+
if _, err := rand.Read(src); err != nil {
104+
t.Fatal(err)
105+
}
106+
buf := &bytes.Buffer{}
107+
key := make([]byte, chacha20poly1305.KeySize)
108+
if _, err := rand.Read(key); err != nil {
109+
t.Fatal(err)
110+
}
111+
112+
err := stream.Encrypt(key, bytes.NewReader(src), buf)
113+
if err != nil {
114+
t.Fatal(err)
115+
}
116+
117+
r, err := stream.NewReader(key, buf)
118+
if err != nil {
119+
t.Fatal(err)
120+
}
121+
122+
dec, err := io.ReadAll(r)
123+
if err != nil {
124+
t.Fatal(err)
125+
}
126+
127+
if !bytes.Equal(src, dec) {
128+
t.Errorf("Wrong decrypted data")
129+
}
130+
})
131+
}
132+
}
133+
}

0 commit comments

Comments
 (0)