@@ -33,6 +33,8 @@ import (
3333 "sync"
3434 "testing"
3535 "time"
36+
37+ "github.com/GoogleCloudPlatform/osconfig/util/utiltest"
3638)
3739
3840// setupMockMetadataServer starts an httptest.Server with the provided handler and overrides the GCE_METADATA_HOST environment variable.
@@ -42,8 +44,7 @@ func setupMockMetadataServer(t *testing.T, handler http.HandlerFunc) *httptest.S
4244 ts := httptest .NewServer (handler )
4345 t .Cleanup (ts .Close )
4446
45- rollback := OverrideEnv (t , "GCE_METADATA_HOST" , strings .TrimPrefix (ts .URL , "http://" ))
46- t .Cleanup (rollback )
47+ utiltest .OverrideEnv (t , "GCE_METADATA_HOST" , strings .TrimPrefix (ts .URL , "http://" ))
4748
4849 return ts
4950}
@@ -238,7 +239,7 @@ func TestSetConfigDefaultValues(t *testing.T) {
238239// keep polling for real changes. This test verifies that the agent correctly
239240// continues to wait until its internal timeout runs out, and then exits normally.
240241func TestWatchConfigUnchangedConfigTimeout (t * testing.T ) {
241- defer OverrideWatchConfigTimeouts (1 * time .Millisecond , 10 * time .Millisecond )( )
242+ OverrideWatchConfigTimeouts (t , 1 * time .Millisecond , 10 * time .Millisecond )
242243
243244 var count int
244245 setupMockMetadataServer (t , func (w http.ResponseWriter , r * http.Request ) {
@@ -267,17 +268,17 @@ func TestWatchConfigUnchangedConfigTimeout(t *testing.T) {
267268// up to a limit of 12 times before giving up and reporting an error.
268269func TestWatchConfigWebErrorLimit (t * testing.T ) {
269270 lEtag .set ("0" )
270- defer OverrideWatchConfigTimeouts (1 * time .Millisecond , 1 * time .Second )( )
271- defer OverrideEnv (t , "GCE_METADATA_HOST" , "mock-host" )( )
271+ OverrideWatchConfigTimeouts (t , 1 * time .Millisecond , 1 * time .Second )
272+ utiltest . OverrideEnv (t , "GCE_METADATA_HOST" , "mock-host" )
272273
273274 mockNetErr := & net.OpError {
274275 Op : "dial" ,
275276 Net : "tcp" ,
276277 Err : errors .New ("connection refused" ),
277278 }
278- defer MockDefaultClientTransport (t , func (req * http.Request ) (* http.Response , error ) {
279+ MockDefaultClientTransport (t , func (req * http.Request ) (* http.Response , error ) {
279280 return nil , mockNetErr
280- })()
281+ })
281282
282283 err := WatchConfig (context .Background ())
283284 if err == nil {
@@ -290,17 +291,15 @@ func TestWatchConfigWebErrorLimit(t *testing.T) {
290291 Err : mockNetErr ,
291292 }
292293 expectedErr := fmt .Errorf ("network error when requesting metadata, make sure your instance has an active network and can reach the metadata server: %w" , expectedBaseErr )
293- if err .Error () != expectedErr .Error () {
294- t .Errorf ("Expected exact error:\n %q\n Got:\n %q" , expectedErr .Error (), err .Error ())
295- }
294+ utiltest .AssertErrorMatch (t , err , expectedErr )
296295}
297296
298297// TestWatchConfigUnmarshalErrorLimit tests how WatchConfig handles bad or incomplete
299298// data from the metadata server. The test gives the agent a broken configuration
300299// response and verifies that the agent tries to read it again up to a limit of 3
301300// times before it stops and reports an error.
302301func TestWatchConfigUnmarshalErrorLimit (t * testing.T ) {
303- defer OverrideWatchConfigTimeouts (1 * time .Millisecond , 1 * time .Second )( )
302+ OverrideWatchConfigTimeouts (t , 1 * time .Millisecond , 1 * time .Second )
304303
305304 badJSON := []byte (`{"bad json"` )
306305 setupMockMetadataServer (t , func (w http.ResponseWriter , r * http.Request ) {
@@ -316,17 +315,15 @@ func TestWatchConfigUnmarshalErrorLimit(t *testing.T) {
316315
317316 var dummy metadataJSON
318317 expectedErr := json .Unmarshal (badJSON , & dummy )
319- if err .Error () != expectedErr .Error () {
320- t .Errorf ("Expected exact error:\n %q\n Got:\n %q" , expectedErr .Error (), err .Error ())
321- }
318+ utiltest .AssertErrorMatch (t , err , expectedErr )
322319}
323320
324321// TestWatchConfigContextCancel tests that the WatchConfig function can be stopped
325322// correctly. It checks that if another part of the program tells WatchConfig to
326323// cancel, it stops immediately without waiting for a timeout or retrying failed
327324// requests.
328325func TestWatchConfigContextCancel (t * testing.T ) {
329- defer OverrideWatchConfigTimeouts (1 * time .Minute , 1 * time .Minute )( )
326+ OverrideWatchConfigTimeouts (t , 1 * time .Minute , 1 * time .Minute )
330327
331328 setupMockMetadataServer (t , func (w http.ResponseWriter , r * http.Request ) {
332329 w .Header ().Set ("Etag" , fmt .Sprintf ("cancel-etag-%d" , time .Now ().UnixNano ()))
@@ -418,7 +415,7 @@ func TestIDToken(t *testing.T) {
418415 handler http.HandlerFunc
419416 numCalls int
420417 wantToken string
421- wantErr bool
418+ wantErr error
422419 wantRequests int
423420 }{
424421 {
@@ -433,7 +430,7 @@ func TestIDToken(t *testing.T) {
433430 },
434431 numCalls : 2 ,
435432 wantToken : validToken ,
436- wantErr : false ,
433+ wantErr : nil ,
437434 wantRequests : 1 , // Only 1 request should be made due to caching
438435 },
439436 {
@@ -448,7 +445,7 @@ func TestIDToken(t *testing.T) {
448445 },
449446 numCalls : 2 ,
450447 wantToken : expiringToken ,
451- wantErr : false ,
448+ wantErr : nil ,
452449 wantRequests : 2 , // Token is within 10m of expiry, should trigger a fetch on every call
453450 },
454451 {
@@ -457,7 +454,7 @@ func TestIDToken(t *testing.T) {
457454 http .Error (w , "internal error" , http .StatusInternalServerError )
458455 },
459456 numCalls : 1 ,
460- wantErr : true ,
457+ wantErr : fmt . Errorf ( "error getting token from metadata: %w" , errors . New ( "compute: Received 500 `internal error \n `" )) ,
461458 // The compute/metadata client library automatically retries on 500 errors (1 initial + 5 retries).
462459 wantRequests : 6 ,
463460 },
@@ -468,7 +465,7 @@ func TestIDToken(t *testing.T) {
468465 fmt .Fprint (w , "not.a.valid.token" )
469466 },
470467 numCalls : 1 ,
471- wantErr : true ,
468+ wantErr : errors . New ( "jws: invalid token received" ) ,
472469 wantRequests : 1 ,
473470 },
474471 }
@@ -488,11 +485,8 @@ func TestIDToken(t *testing.T) {
488485 for i := 0 ; i < tt .numCalls ; i ++ {
489486 token , err = IDToken ()
490487 }
491-
492- if (err != nil ) != tt .wantErr {
493- t .Fatalf ("IDToken() error = %v, wantErr %v" , err , tt .wantErr )
494- }
495- if err == nil && token != tt .wantToken {
488+ utiltest .AssertErrorMatch (t , err , tt .wantErr )
489+ if token != tt .wantToken {
496490 t .Errorf ("IDToken() = %q, want %q" , token , tt .wantToken )
497491 }
498492 if requests != tt .wantRequests {
@@ -504,42 +498,36 @@ func TestIDToken(t *testing.T) {
504498
505499// TestFormatMetadataError verifies that network and DNS errors are wrapped with helpful context.
506500func TestFormatMetadataError (t * testing.T ) {
507- errStandard := fmt .Errorf ("standard error" )
508- errDNS := & url.Error {Err : & net.DNSError {Err : "no such host" }}
509- errNet := & url.Error {Err : & net.OpError {Op : "dial" , Net : "tcp" }}
501+ dnsErr := & url.Error {Err : & net.DNSError {Err : "no such host" }}
502+ netErr := & url.Error {Err : & net.OpError {Op : "dial" , Net : "tcp" , Err : errors .New ("connection refused" )}}
510503
511504 tests := []struct {
512- name string
513- inputErr error
514- wantExact error
515- wantContain string
505+ name string
506+ inputErr error
507+ wantErr error
516508 }{
517509 {
518- name : "standard error" ,
519- inputErr : errStandard ,
520- wantExact : errStandard ,
510+ name : "standard error" ,
511+ inputErr : fmt . Errorf ( "standard error" ) ,
512+ wantErr : fmt . Errorf ( "standard error" ) ,
521513 },
522514 {
523- name : "DNS error" ,
524- inputErr : errDNS ,
525- wantContain : "DNS error when requesting metadata" ,
515+ name : "DNS error" ,
516+ inputErr : dnsErr ,
517+ wantErr : fmt . Errorf ( "DNS error when requesting metadata, check DNS settings and ensure metadata.google.internal is setup in your hosts file: %w" , dnsErr ) ,
526518 },
527519 {
528- name : "network error" ,
529- inputErr : errNet ,
530- wantContain : "network error when requesting metadata" ,
520+ name : "network error" ,
521+ inputErr : netErr ,
522+ wantErr : fmt . Errorf ( "network error when requesting metadata, make sure your instance has an active network and can reach the metadata server: %w" , netErr ) ,
531523 },
532524 }
533525
534526 for _ , tt := range tests {
535527 t .Run (tt .name , func (t * testing.T ) {
536528 got := formatMetadataError (tt .inputErr )
537- if tt .wantExact != nil && got != tt .wantExact {
538- t .Errorf ("formatMetadataError() = %v, want exact %v" , got , tt .wantExact )
539- }
540- if tt .wantContain != "" && ! strings .Contains (got .Error (), tt .wantContain ) {
541- t .Errorf ("formatMetadataError() = %v, want to contain %q" , got , tt .wantContain )
542- }
529+
530+ utiltest .AssertErrorMatch (t , got , tt .wantErr )
543531 })
544532 }
545533}
@@ -608,13 +596,13 @@ func TestGetMetadata(t *testing.T) {
608596
609597// TestGetMetadataFallback verifies fallback to the default metadata IP address.
610598func TestGetMetadataFallback (t * testing.T ) {
611- defer UnsetEnv (t , metadataHostEnv )( )
599+ utiltest . UnsetEnv (t , metadataHostEnv )
612600
613601 var requestedURL string
614- defer MockDefaultClientTransport (t , func (req * http.Request ) (* http.Response , error ) {
602+ MockDefaultClientTransport (t , func (req * http.Request ) (* http.Response , error ) {
615603 requestedURL = req .URL .String ()
616604 return & http.Response {StatusCode : 200 , Body : ioutil .NopCloser (strings .NewReader ("mock response" ))}, nil
617- })()
605+ })
618606
619607 _ , _ , err := getMetadata ("test-suffix" )
620608 if err != nil {
@@ -632,7 +620,7 @@ func TestGetMetadataErrors(t *testing.T) {
632620 tests := []struct {
633621 name string
634622 suffix string
635- mockTransport func (t * testing.T ) ( rollback func ())
623+ mockTransport func (t * testing.T )
636624 wantErrContain string
637625 }{
638626 {
@@ -643,8 +631,8 @@ func TestGetMetadataErrors(t *testing.T) {
643631 {
644632 name : "client.Do error" ,
645633 suffix : "test-suffix" ,
646- mockTransport : func (t * testing.T ) func () {
647- return MockDefaultClientTransport (t , func (req * http.Request ) (* http.Response , error ) {
634+ mockTransport : func (t * testing.T ) {
635+ MockDefaultClientTransport (t , func (req * http.Request ) (* http.Response , error ) {
648636 return nil , fmt .Errorf ("mock dial error" )
649637 })
650638 },
@@ -655,7 +643,7 @@ func TestGetMetadataErrors(t *testing.T) {
655643 for _ , tt := range tests {
656644 t .Run (tt .name , func (t * testing.T ) {
657645 if tt .mockTransport != nil {
658- t . Cleanup ( tt .mockTransport (t ) )
646+ tt .mockTransport (t )
659647 }
660648 _ , _ , err := getMetadata (tt .suffix )
661649 if err == nil || ! strings .Contains (err .Error (), tt .wantErrContain ) {
@@ -1184,7 +1172,7 @@ func TestGetCacheDirWindows(t *testing.T) {
11841172 // that os.UserCacheDir relies on to generate paths.
11851173 envs := []string {"HOME" , "LocalAppData" , "XDG_CACHE_HOME" }
11861174 for _ , env := range envs {
1187- t . Cleanup ( UnsetEnv (t , env ) )
1175+ utiltest . UnsetEnv (t , env )
11881176 }
11891177 },
11901178 want : filepath .Join (os .TempDir (), windowsCacheDir ),
@@ -1318,70 +1306,27 @@ func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
13181306 return f (req )
13191307}
13201308
1321- // OverrideEnv sets an environment variable for the duration of a test and returns a rollback function to restore its original state.
1322- func OverrideEnv (t * testing.T , env , value string ) (rollback func ()) {
1323- orig , ok := os .LookupEnv (env )
1324- rollback = func () {
1325- if ok {
1326- if err := os .Setenv (env , orig ); err != nil {
1327- t .Fatalf ("Failed to restore environment variable %s: %v" , env , err )
1328- }
1329- } else {
1330- if err := os .Unsetenv (env ); err != nil {
1331- t .Fatalf ("Failed to unset environment variable %s: %v" , env , err )
1332- }
1333- }
1334- }
1335-
1336- if err := os .Setenv (env , value ); err != nil {
1337- t .Fatalf ("Failed to set environment variable %s: %v" , env , err )
1338- }
1339-
1340- return rollback
1341- }
1342-
1343- // UnsetEnv unsets an environment variable for the duration of a test and returns a rollback function to restore its original state.
1344- func UnsetEnv (t * testing.T , env string ) (rollback func ()) {
1345- orig , ok := os .LookupEnv (env )
1346- rollback = func () {
1347- if ok {
1348- if err := os .Setenv (env , orig ); err != nil {
1349- t .Fatalf ("Failed to restore environment variable %s: %v" , env , err )
1350- }
1351- } else {
1352- if err := os .Unsetenv (env ); err != nil {
1353- t .Fatalf ("Failed to unset environment variable %s: %v" , env , err )
1354- }
1355- }
1356- }
1357-
1358- if err := os .Unsetenv (env ); err != nil {
1359- t .Fatalf ("Failed to unset environment variable %s: %v" , env , err )
1360- }
1361-
1362- return rollback
1363- }
1364-
13651309// OverrideWatchConfigTimeouts temporarily overwrites the timeout and retry intervals for WatchConfig.
1366- func OverrideWatchConfigTimeouts (interval , timeout time.Duration ) (rollback func ()) {
1310+ func OverrideWatchConfigTimeouts (t * testing.T , interval , timeout time.Duration ) {
1311+ t .Helper ()
13671312 origInterval := watchConfigRetryInterval
13681313 origTimeout := osConfigWatchConfigTimeout
13691314
13701315 watchConfigRetryInterval = interval
13711316 osConfigWatchConfigTimeout = timeout
1372- return func () {
1317+ t . Cleanup ( func () {
13731318 watchConfigRetryInterval = origInterval
13741319 osConfigWatchConfigTimeout = origTimeout
1375- }
1320+ })
13761321}
13771322
13781323// MockDefaultClientTransport temporarily replaces the defaultClient's transport with a custom round tripper.
1379- func MockDefaultClientTransport (t * testing.T , roundTrip func (* http.Request ) (* http.Response , error )) ( rollback func ()) {
1324+ func MockDefaultClientTransport (t * testing.T , roundTrip func (* http.Request ) (* http.Response , error )) {
13801325 origClient := defaultClient
13811326 defaultClient = & http.Client {
13821327 Transport : roundTripperFunc (roundTrip ),
13831328 }
1384- return func () {
1329+ t . Cleanup ( func () {
13851330 defaultClient = origClient
1386- }
1331+ })
13871332}
0 commit comments