Skip to content

Commit ca292c6

Browse files
committed
Use utiltest helpers in new tests
1 parent 98ea44d commit ca292c6

2 files changed

Lines changed: 95 additions & 108 deletions

File tree

agentconfig/agentconfig_test.go

Lines changed: 50 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ import (
3333
"sync"
3434
"testing"
3535
"time"
36+
37+
"github.com/GoogleCloudPlatform/osconfig/util/utiltest"
3638
)
3739

3840
// setupMockMetadataServer starts an httptest.Server with the provided handler and overrides the GCE_METADATA_HOST environment variable.
@@ -42,8 +44,7 @@ func setupMockMetadataServer(t *testing.T, handler http.HandlerFunc) *httptest.S
4244
ts := httptest.NewServer(handler)
4345
t.Cleanup(ts.Close)
4446

45-
rollback := OverrideEnv(t, "GCE_METADATA_HOST", strings.TrimPrefix(ts.URL, "http://"))
46-
t.Cleanup(rollback)
47+
utiltest.OverrideEnv(t, "GCE_METADATA_HOST", strings.TrimPrefix(ts.URL, "http://"))
4748

4849
return ts
4950
}
@@ -238,7 +239,7 @@ func TestSetConfigDefaultValues(t *testing.T) {
238239
// keep polling for real changes. This test verifies that the agent correctly
239240
// continues to wait until its internal timeout runs out, and then exits normally.
240241
func TestWatchConfigUnchangedConfigTimeout(t *testing.T) {
241-
defer OverrideWatchConfigTimeouts(1*time.Millisecond, 10*time.Millisecond)()
242+
OverrideWatchConfigTimeouts(t, 1*time.Millisecond, 10*time.Millisecond)
242243

243244
var count int
244245
setupMockMetadataServer(t, func(w http.ResponseWriter, r *http.Request) {
@@ -267,17 +268,17 @@ func TestWatchConfigUnchangedConfigTimeout(t *testing.T) {
267268
// up to a limit of 12 times before giving up and reporting an error.
268269
func TestWatchConfigWebErrorLimit(t *testing.T) {
269270
lEtag.set("0")
270-
defer OverrideWatchConfigTimeouts(1*time.Millisecond, 1*time.Second)()
271-
defer OverrideEnv(t, "GCE_METADATA_HOST", "mock-host")()
271+
OverrideWatchConfigTimeouts(t, 1*time.Millisecond, 1*time.Second)
272+
utiltest.OverrideEnv(t, "GCE_METADATA_HOST", "mock-host")
272273

273274
mockNetErr := &net.OpError{
274275
Op: "dial",
275276
Net: "tcp",
276277
Err: errors.New("connection refused"),
277278
}
278-
defer MockDefaultClientTransport(t, func(req *http.Request) (*http.Response, error) {
279+
MockDefaultClientTransport(t, func(req *http.Request) (*http.Response, error) {
279280
return nil, mockNetErr
280-
})()
281+
})
281282

282283
err := WatchConfig(context.Background())
283284
if err == nil {
@@ -290,17 +291,15 @@ func TestWatchConfigWebErrorLimit(t *testing.T) {
290291
Err: mockNetErr,
291292
}
292293
expectedErr := fmt.Errorf("network error when requesting metadata, make sure your instance has an active network and can reach the metadata server: %w", expectedBaseErr)
293-
if err.Error() != expectedErr.Error() {
294-
t.Errorf("Expected exact error:\n%q\nGot:\n%q", expectedErr.Error(), err.Error())
295-
}
294+
utiltest.AssertErrorMatch(t, err, expectedErr)
296295
}
297296

298297
// TestWatchConfigUnmarshalErrorLimit tests how WatchConfig handles bad or incomplete
299298
// data from the metadata server. The test gives the agent a broken configuration
300299
// response and verifies that the agent tries to read it again up to a limit of 3
301300
// times before it stops and reports an error.
302301
func TestWatchConfigUnmarshalErrorLimit(t *testing.T) {
303-
defer OverrideWatchConfigTimeouts(1*time.Millisecond, 1*time.Second)()
302+
OverrideWatchConfigTimeouts(t, 1*time.Millisecond, 1*time.Second)
304303

305304
badJSON := []byte(`{"bad json"`)
306305
setupMockMetadataServer(t, func(w http.ResponseWriter, r *http.Request) {
@@ -316,17 +315,15 @@ func TestWatchConfigUnmarshalErrorLimit(t *testing.T) {
316315

317316
var dummy metadataJSON
318317
expectedErr := json.Unmarshal(badJSON, &dummy)
319-
if err.Error() != expectedErr.Error() {
320-
t.Errorf("Expected exact error:\n%q\nGot:\n%q", expectedErr.Error(), err.Error())
321-
}
318+
utiltest.AssertErrorMatch(t, err, expectedErr)
322319
}
323320

324321
// TestWatchConfigContextCancel tests that the WatchConfig function can be stopped
325322
// correctly. It checks that if another part of the program tells WatchConfig to
326323
// cancel, it stops immediately without waiting for a timeout or retrying failed
327324
// requests.
328325
func TestWatchConfigContextCancel(t *testing.T) {
329-
defer OverrideWatchConfigTimeouts(1*time.Minute, 1*time.Minute)()
326+
OverrideWatchConfigTimeouts(t, 1*time.Minute, 1*time.Minute)
330327

331328
setupMockMetadataServer(t, func(w http.ResponseWriter, r *http.Request) {
332329
w.Header().Set("Etag", fmt.Sprintf("cancel-etag-%d", time.Now().UnixNano()))
@@ -418,7 +415,7 @@ func TestIDToken(t *testing.T) {
418415
handler http.HandlerFunc
419416
numCalls int
420417
wantToken string
421-
wantErr bool
418+
wantErr error
422419
wantRequests int
423420
}{
424421
{
@@ -433,7 +430,7 @@ func TestIDToken(t *testing.T) {
433430
},
434431
numCalls: 2,
435432
wantToken: validToken,
436-
wantErr: false,
433+
wantErr: nil,
437434
wantRequests: 1, // Only 1 request should be made due to caching
438435
},
439436
{
@@ -448,7 +445,7 @@ func TestIDToken(t *testing.T) {
448445
},
449446
numCalls: 2,
450447
wantToken: expiringToken,
451-
wantErr: false,
448+
wantErr: nil,
452449
wantRequests: 2, // Token is within 10m of expiry, should trigger a fetch on every call
453450
},
454451
{
@@ -457,7 +454,7 @@ func TestIDToken(t *testing.T) {
457454
http.Error(w, "internal error", http.StatusInternalServerError)
458455
},
459456
numCalls: 1,
460-
wantErr: true,
457+
wantErr: fmt.Errorf("error getting token from metadata: %w", errors.New("compute: Received 500 `internal error\n`")),
461458
// The compute/metadata client library automatically retries on 500 errors (1 initial + 5 retries).
462459
wantRequests: 6,
463460
},
@@ -468,7 +465,7 @@ func TestIDToken(t *testing.T) {
468465
fmt.Fprint(w, "not.a.valid.token")
469466
},
470467
numCalls: 1,
471-
wantErr: true,
468+
wantErr: errors.New("jws: invalid token received"),
472469
wantRequests: 1,
473470
},
474471
}
@@ -488,11 +485,8 @@ func TestIDToken(t *testing.T) {
488485
for i := 0; i < tt.numCalls; i++ {
489486
token, err = IDToken()
490487
}
491-
492-
if (err != nil) != tt.wantErr {
493-
t.Fatalf("IDToken() error = %v, wantErr %v", err, tt.wantErr)
494-
}
495-
if err == nil && token != tt.wantToken {
488+
utiltest.AssertErrorMatch(t, err, tt.wantErr)
489+
if token != tt.wantToken {
496490
t.Errorf("IDToken() = %q, want %q", token, tt.wantToken)
497491
}
498492
if requests != tt.wantRequests {
@@ -504,42 +498,36 @@ func TestIDToken(t *testing.T) {
504498

505499
// TestFormatMetadataError verifies that network and DNS errors are wrapped with helpful context.
506500
func TestFormatMetadataError(t *testing.T) {
507-
errStandard := fmt.Errorf("standard error")
508-
errDNS := &url.Error{Err: &net.DNSError{Err: "no such host"}}
509-
errNet := &url.Error{Err: &net.OpError{Op: "dial", Net: "tcp"}}
501+
dnsErr := &url.Error{Err: &net.DNSError{Err: "no such host"}}
502+
netErr := &url.Error{Err: &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")}}
510503

511504
tests := []struct {
512-
name string
513-
inputErr error
514-
wantExact error
515-
wantContain string
505+
name string
506+
inputErr error
507+
wantErr error
516508
}{
517509
{
518-
name: "standard error",
519-
inputErr: errStandard,
520-
wantExact: errStandard,
510+
name: "standard error",
511+
inputErr: fmt.Errorf("standard error"),
512+
wantErr: fmt.Errorf("standard error"),
521513
},
522514
{
523-
name: "DNS error",
524-
inputErr: errDNS,
525-
wantContain: "DNS error when requesting metadata",
515+
name: "DNS error",
516+
inputErr: dnsErr,
517+
wantErr: fmt.Errorf("DNS error when requesting metadata, check DNS settings and ensure metadata.google.internal is setup in your hosts file: %w", dnsErr),
526518
},
527519
{
528-
name: "network error",
529-
inputErr: errNet,
530-
wantContain: "network error when requesting metadata",
520+
name: "network error",
521+
inputErr: netErr,
522+
wantErr: fmt.Errorf("network error when requesting metadata, make sure your instance has an active network and can reach the metadata server: %w", netErr),
531523
},
532524
}
533525

534526
for _, tt := range tests {
535527
t.Run(tt.name, func(t *testing.T) {
536528
got := formatMetadataError(tt.inputErr)
537-
if tt.wantExact != nil && got != tt.wantExact {
538-
t.Errorf("formatMetadataError() = %v, want exact %v", got, tt.wantExact)
539-
}
540-
if tt.wantContain != "" && !strings.Contains(got.Error(), tt.wantContain) {
541-
t.Errorf("formatMetadataError() = %v, want to contain %q", got, tt.wantContain)
542-
}
529+
530+
utiltest.AssertErrorMatch(t, got, tt.wantErr)
543531
})
544532
}
545533
}
@@ -608,13 +596,13 @@ func TestGetMetadata(t *testing.T) {
608596

609597
// TestGetMetadataFallback verifies fallback to the default metadata IP address.
610598
func TestGetMetadataFallback(t *testing.T) {
611-
defer UnsetEnv(t, metadataHostEnv)()
599+
utiltest.UnsetEnv(t, metadataHostEnv)
612600

613601
var requestedURL string
614-
defer MockDefaultClientTransport(t, func(req *http.Request) (*http.Response, error) {
602+
MockDefaultClientTransport(t, func(req *http.Request) (*http.Response, error) {
615603
requestedURL = req.URL.String()
616604
return &http.Response{StatusCode: 200, Body: ioutil.NopCloser(strings.NewReader("mock response"))}, nil
617-
})()
605+
})
618606

619607
_, _, err := getMetadata("test-suffix")
620608
if err != nil {
@@ -632,7 +620,7 @@ func TestGetMetadataErrors(t *testing.T) {
632620
tests := []struct {
633621
name string
634622
suffix string
635-
mockTransport func(t *testing.T) (rollback func())
623+
mockTransport func(t *testing.T)
636624
wantErrContain string
637625
}{
638626
{
@@ -643,8 +631,8 @@ func TestGetMetadataErrors(t *testing.T) {
643631
{
644632
name: "client.Do error",
645633
suffix: "test-suffix",
646-
mockTransport: func(t *testing.T) func() {
647-
return MockDefaultClientTransport(t, func(req *http.Request) (*http.Response, error) {
634+
mockTransport: func(t *testing.T) {
635+
MockDefaultClientTransport(t, func(req *http.Request) (*http.Response, error) {
648636
return nil, fmt.Errorf("mock dial error")
649637
})
650638
},
@@ -655,7 +643,7 @@ func TestGetMetadataErrors(t *testing.T) {
655643
for _, tt := range tests {
656644
t.Run(tt.name, func(t *testing.T) {
657645
if tt.mockTransport != nil {
658-
t.Cleanup(tt.mockTransport(t))
646+
tt.mockTransport(t)
659647
}
660648
_, _, err := getMetadata(tt.suffix)
661649
if err == nil || !strings.Contains(err.Error(), tt.wantErrContain) {
@@ -1184,7 +1172,7 @@ func TestGetCacheDirWindows(t *testing.T) {
11841172
// that os.UserCacheDir relies on to generate paths.
11851173
envs := []string{"HOME", "LocalAppData", "XDG_CACHE_HOME"}
11861174
for _, env := range envs {
1187-
t.Cleanup(UnsetEnv(t, env))
1175+
utiltest.UnsetEnv(t, env)
11881176
}
11891177
},
11901178
want: filepath.Join(os.TempDir(), windowsCacheDir),
@@ -1318,70 +1306,27 @@ func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
13181306
return f(req)
13191307
}
13201308

1321-
// OverrideEnv sets an environment variable for the duration of a test and returns a rollback function to restore its original state.
1322-
func OverrideEnv(t *testing.T, env, value string) (rollback func()) {
1323-
orig, ok := os.LookupEnv(env)
1324-
rollback = func() {
1325-
if ok {
1326-
if err := os.Setenv(env, orig); err != nil {
1327-
t.Fatalf("Failed to restore environment variable %s: %v", env, err)
1328-
}
1329-
} else {
1330-
if err := os.Unsetenv(env); err != nil {
1331-
t.Fatalf("Failed to unset environment variable %s: %v", env, err)
1332-
}
1333-
}
1334-
}
1335-
1336-
if err := os.Setenv(env, value); err != nil {
1337-
t.Fatalf("Failed to set environment variable %s: %v", env, err)
1338-
}
1339-
1340-
return rollback
1341-
}
1342-
1343-
// UnsetEnv unsets an environment variable for the duration of a test and returns a rollback function to restore its original state.
1344-
func UnsetEnv(t *testing.T, env string) (rollback func()) {
1345-
orig, ok := os.LookupEnv(env)
1346-
rollback = func() {
1347-
if ok {
1348-
if err := os.Setenv(env, orig); err != nil {
1349-
t.Fatalf("Failed to restore environment variable %s: %v", env, err)
1350-
}
1351-
} else {
1352-
if err := os.Unsetenv(env); err != nil {
1353-
t.Fatalf("Failed to unset environment variable %s: %v", env, err)
1354-
}
1355-
}
1356-
}
1357-
1358-
if err := os.Unsetenv(env); err != nil {
1359-
t.Fatalf("Failed to unset environment variable %s: %v", env, err)
1360-
}
1361-
1362-
return rollback
1363-
}
1364-
13651309
// OverrideWatchConfigTimeouts temporarily overwrites the timeout and retry intervals for WatchConfig.
1366-
func OverrideWatchConfigTimeouts(interval, timeout time.Duration) (rollback func()) {
1310+
func OverrideWatchConfigTimeouts(t *testing.T, interval, timeout time.Duration) {
1311+
t.Helper()
13671312
origInterval := watchConfigRetryInterval
13681313
origTimeout := osConfigWatchConfigTimeout
13691314

13701315
watchConfigRetryInterval = interval
13711316
osConfigWatchConfigTimeout = timeout
1372-
return func() {
1317+
t.Cleanup(func() {
13731318
watchConfigRetryInterval = origInterval
13741319
osConfigWatchConfigTimeout = origTimeout
1375-
}
1320+
})
13761321
}
13771322

13781323
// MockDefaultClientTransport temporarily replaces the defaultClient's transport with a custom round tripper.
1379-
func MockDefaultClientTransport(t *testing.T, roundTrip func(*http.Request) (*http.Response, error)) (rollback func()) {
1324+
func MockDefaultClientTransport(t *testing.T, roundTrip func(*http.Request) (*http.Response, error)) {
13801325
origClient := defaultClient
13811326
defaultClient = &http.Client{
13821327
Transport: roundTripperFunc(roundTrip),
13831328
}
1384-
return func() {
1329+
t.Cleanup(func() {
13851330
defaultClient = origClient
1386-
}
1331+
})
13871332
}

0 commit comments

Comments
 (0)