diff --git a/execution/engine/testdata/complex_nesting_query_with_art.json b/execution/engine/testdata/complex_nesting_query_with_art.json index 69a208fe4..ec85c1e5c 100644 --- a/execution/engine/testdata/complex_nesting_query_with_art.json +++ b/execution/engine/testdata/complex_nesting_query_with_art.json @@ -170,7 +170,7 @@ "duration_since_start_pretty": "1ns", "duration_load_nanoseconds": 1, "duration_load_pretty": "1ns", - "single_flight_used": false, + "single_flight_used": true, "single_flight_shared_response": false, "load_skipped": false, "load_stats": { @@ -310,7 +310,7 @@ "duration_since_start_pretty": "1ns", "duration_load_nanoseconds": 1, "duration_load_pretty": "1ns", - "single_flight_used": false, + "single_flight_used": true, "single_flight_shared_response": false, "load_skipped": false, "load_stats": { @@ -496,7 +496,7 @@ "duration_since_start_pretty": "1ns", "duration_load_nanoseconds": 1, "duration_load_pretty": "1ns", - "single_flight_used": false, + "single_flight_used": true, "single_flight_shared_response": false, "load_skipped": false, "load_stats": { diff --git a/go.work.sum b/go.work.sum index d66f874bf..1aecd8d22 100644 --- a/go.work.sum +++ b/go.work.sum @@ -246,6 +246,8 @@ github.com/twmb/franz-go/pkg/kmsg v1.7.0/go.mod h1:se9Mjdt0Nwzc9lnjJ0HyDtLyBnaBD github.com/vbatts/tar-split v0.12.1 h1:CqKoORW7BUWBe7UL/iqTVvkTBOF8UvOMKOIZykxnnbo= github.com/vbatts/tar-split v0.12.1/go.mod h1:eF6B6i6ftWQcDqEn3/iGFRFRo8cBIMSJVOpnNdfTMFA= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wundergraph/go-arena v0.0.0-20251008210416-55cb97e6f68f h1:5snewyMaIpajTu4wj22L/DgrGimICqXtUVjkZInBH3Y= +github.com/wundergraph/go-arena v0.0.0-20251008210416-55cb97e6f68f/go.mod h1:ROOysEHWJjLQ8FSfNxZCziagb7Qw2nXY3/vgKRh7eWw= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= diff --git a/v2/go.mod b/v2/go.mod index 83fbcc291..43ada453b 100644 --- a/v2/go.mod +++ b/v2/go.mod @@ -28,7 +28,8 @@ require ( github.com/tidwall/gjson v1.17.0 github.com/tidwall/sjson v1.2.5 github.com/vektah/gqlparser/v2 v2.5.30 - github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 + github.com/wundergraph/astjson v1.0.0 + github.com/wundergraph/go-arena v1.0.0 go.uber.org/atomic v1.11.0 go.uber.org/goleak v1.3.0 go.uber.org/zap v1.26.0 diff --git a/v2/go.sum b/v2/go.sum index 690d15a88..6d0fb3636 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -134,8 +134,10 @@ github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU= github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4= github.com/vektah/gqlparser/v2 v2.5.30 h1:EqLwGAFLIzt1wpx1IPpY67DwUujF1OfzgEyDsLrN6kE= github.com/vektah/gqlparser/v2 v2.5.30/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo= -github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= -github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= +github.com/wundergraph/astjson v1.0.0 h1:rETLJuQkMWWW03HCF6WBttEBOu8gi5vznj5KEUPVV2Q= +github.com/wundergraph/astjson v1.0.0/go.mod h1:h12D/dxxnedtLzsKyBLK7/Oe4TAoGpRVC9nDpDrZSWw= +github.com/wundergraph/go-arena v1.0.0 h1:RVYWpDkJ1/6851BRHYehBeEcTLKmZygYIZsvBorcOjw= +github.com/wundergraph/go-arena v1.0.0/go.mod h1:ROOysEHWJjLQ8FSfNxZCziagb7Qw2nXY3/vgKRh7eWw= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= diff --git a/v2/pkg/astnormalization/uploads/upload_finder.go b/v2/pkg/astnormalization/uploads/upload_finder.go index b69a8bef2..0fd2d44c1 100644 --- a/v2/pkg/astnormalization/uploads/upload_finder.go +++ b/v2/pkg/astnormalization/uploads/upload_finder.go @@ -74,7 +74,7 @@ func (v *UploadFinder) FindUploads(operation, definition *ast.Document, variable variables = []byte("{}") } - v.variables, err = astjson.ParseBytesWithoutCache(variables) + v.variables, err = astjson.ParseBytes(variables) if err != nil { return nil, err } diff --git a/v2/pkg/astvisitor/visitor.go b/v2/pkg/astvisitor/visitor.go index 86f284c72..cae364bce 100644 --- a/v2/pkg/astvisitor/visitor.go +++ b/v2/pkg/astvisitor/visitor.go @@ -5,6 +5,8 @@ import ( "fmt" "sync" + "github.com/wundergraph/go-arena" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/lexer/literal" "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" @@ -95,6 +97,8 @@ type Walker struct { deferred []func() OnExternalError func(err *operationreport.ExternalError) + + arena arena.Arena } func NewWalkerWithID(ancestorSize int, id string) Walker { @@ -140,6 +144,9 @@ func WalkerFromPool() *Walker { } func (w *Walker) Release() { + if w.arena != nil { + w.arena.Reset() + } w.ResetVisitors() w.Report = nil w.document = nil @@ -1391,6 +1398,11 @@ func (w *Walker) Walk(document, definition *ast.Document, report *operationrepor } else { w.Report = report } + if w.arena == nil { + w.arena = arena.NewMonotonicArena(arena.WithMinBufferSize(64)) + } else { + w.arena.Reset() + } w.Ancestors = w.Ancestors[:0] w.Path = w.Path[:0] w.TypeDefinitions = w.TypeDefinitions[:0] @@ -1843,8 +1855,7 @@ func (w *Walker) walkSelectionSet(ref int, skipFor SkipVisitors) { RefsChanged: for { - refs := make([]int, 0, len(w.document.SelectionSets[ref].SelectionRefs)) - refs = append(refs, w.document.SelectionSets[ref].SelectionRefs...) + refs := arena.SliceAppend(w.arena, nil, w.document.SelectionSets[ref].SelectionRefs...) for i, j := range refs { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index 11c61d9e5..8f8fd1dcf 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -14,7 +14,6 @@ import ( "unicode" "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" "github.com/jensneuse/abstractlogger" "github.com/pkg/errors" "github.com/tidwall/sjson" @@ -1907,20 +1906,19 @@ func (s *Source) replaceEmptyObject(variables []byte) ([]byte, bool) { return variables, false } -func (s *Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { +func (s *Source) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { input = s.compactAndUnNullVariables(input) - return httpclient.DoMultipartForm(s.httpClient, ctx, input, files, out) + return httpclient.DoMultipartForm(s.httpClient, ctx, headers, input, files) } -func (s *Source) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { +func (s *Source) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { input = s.compactAndUnNullVariables(input) - return httpclient.Do(s.httpClient, ctx, input, out) + return httpclient.Do(s.httpClient, ctx, headers, input) } type GraphQLSubscriptionClient interface { // Subscribe to the origin source. The implementation must not block the calling goroutine. Subscribe(ctx *resolve.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error - UniqueRequestID(ctx *resolve.Context, options GraphQLSubscriptionOptions, hash *xxhash.Digest) (err error) SubscribeAsync(ctx *resolve.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error Unsubscribe(id uint64) } @@ -1956,12 +1954,13 @@ type SubscriptionSource struct { client GraphQLSubscriptionClient } -func (s *SubscriptionSource) AsyncStart(ctx *resolve.Context, id uint64, input []byte, updater resolve.SubscriptionUpdater) error { +func (s *SubscriptionSource) AsyncStart(ctx *resolve.Context, id uint64, headers http.Header, input []byte, updater resolve.SubscriptionUpdater) error { var options GraphQLSubscriptionOptions err := json.Unmarshal(input, &options) if err != nil { return err } + options.Header = headers if options.Body.Query == "" { return resolve.ErrUnableToResolve } @@ -1975,12 +1974,13 @@ func (s *SubscriptionSource) AsyncStop(id uint64) { } // Start the subscription. The updater is called on new events. Start needs to be called in a separate goroutine. -func (s *SubscriptionSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { +func (s *SubscriptionSource) Start(ctx *resolve.Context, headers http.Header, input []byte, updater resolve.SubscriptionUpdater) error { var options GraphQLSubscriptionOptions err := json.Unmarshal(input, &options) if err != nil { return err } + options.Header = headers if options.Body.Query == "" { return resolve.ErrUnableToResolve } @@ -1990,16 +1990,3 @@ func (s *SubscriptionSource) Start(ctx *resolve.Context, input []byte, updater r var ( dataSouceName = []byte("graphql") ) - -func (s *SubscriptionSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) (err error) { - _, err = xxh.Write(dataSouceName) - if err != nil { - return err - } - var options GraphQLSubscriptionOptions - err = json.Unmarshal(input, &options) - if err != nil { - return err - } - return s.client.UniqueRequestID(ctx, options, xxh) -} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go index 7f41943c8..cb54ab8a8 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go @@ -16,7 +16,6 @@ import ( "testing" "time" - "github.com/cespare/xxhash/v2" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -4009,6 +4008,8 @@ func TestGraphQLDataSource(t *testing.T) { NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, ctx), }, PostProcessing: DefaultPostProcessingConfiguration, + SourceName: "ds-id", + SourceID: "ds-id", }, Response: &resolve.GraphQLResponse{ Fetches: resolve.Sequence(), @@ -4050,6 +4051,8 @@ func TestGraphQLDataSource(t *testing.T) { client: NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, ctx), }, PostProcessing: DefaultPostProcessingConfiguration, + SourceName: "ds-id", + SourceID: "ds-id", }, Response: &resolve.GraphQLResponse{ Data: &resolve.Object{ @@ -8246,10 +8249,6 @@ func (f *FailingSubscriptionClient) Subscribe(ctx *resolve.Context, options Grap return errSubscriptionClientFail } -func (f *FailingSubscriptionClient) UniqueRequestID(ctx *resolve.Context, options GraphQLSubscriptionOptions, hash *xxhash.Digest) (err error) { - return errSubscriptionClientFail -} - type testSubscriptionUpdaterChan struct { updates chan string complete chan struct{} @@ -8441,13 +8440,13 @@ func TestSubscriptionSource_Start(t *testing.T) { t.Run("should return error when input is invalid", func(t *testing.T) { source := SubscriptionSource{client: &FailingSubscriptionClient{}} - err := source.Start(resolve.NewContext(context.Background()), []byte(`{"url": "", "body": "", "header": null}`), nil) + err := source.Start(resolve.NewContext(context.Background()), nil, []byte(`{"url": "", "body": "", "header": null}`), nil) assert.Error(t, err) }) t.Run("should return error when subscription client returns an error", func(t *testing.T) { source := SubscriptionSource{client: &FailingSubscriptionClient{}} - err := source.Start(resolve.NewContext(context.Background()), []byte(`{"url": "", "body": {}, "header": null}`), nil) + err := source.Start(resolve.NewContext(context.Background()), nil, []byte(`{"url": "", "body": {}, "header": null}`), nil) assert.Error(t, err) assert.Equal(t, resolve.ErrUnableToResolve, err) }) @@ -8460,7 +8459,7 @@ func TestSubscriptionSource_Start(t *testing.T) { source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: "#test") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, updater) + err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.ErrorIs(t, err, resolve.ErrUnableToResolve) }) @@ -8472,7 +8471,7 @@ func TestSubscriptionSource_Start(t *testing.T) { source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomNam: \"#test\") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, updater) + err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.NoError(t, err) updater.AwaitUpdates(t, time.Second, 1) assert.Len(t, updater.updates, 1) @@ -8490,7 +8489,7 @@ func TestSubscriptionSource_Start(t *testing.T) { source := newSubscriptionSource(resolverLifecycle) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`) - err := source.Start(resolve.NewContext(subscriptionLifecycle), chatSubscriptionOptions, updater) + err := source.Start(resolve.NewContext(subscriptionLifecycle), nil, chatSubscriptionOptions, updater) require.NoError(t, err) username := "myuser" @@ -8513,7 +8512,7 @@ func TestSubscriptionSource_Start(t *testing.T) { source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, updater) + err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.NoError(t, err) username := "myuser" @@ -8577,7 +8576,7 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) { source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomNam: \"#test\") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, updater) + err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.NoError(t, err) updater.AwaitUpdates(t, time.Second, 1) @@ -8597,7 +8596,7 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) { source := newSubscriptionSource(resolverLifecycle) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`) - err := source.Start(resolve.NewContext(subscriptionLifecycle), chatSubscriptionOptions, updater) + err := source.Start(resolve.NewContext(subscriptionLifecycle), nil, chatSubscriptionOptions, updater) require.NoError(t, err) username := "myuser" @@ -8621,7 +8620,7 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) { source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, updater) + err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.NoError(t, err) username := "myuser" @@ -8759,10 +8758,9 @@ func TestSource_Load(t *testing.T) { input = httpclient.SetInputBodyWithPath(input, variables, "variables") input = httpclient.SetInputURL(input, []byte(serverUrl)) - buf := bytes.NewBuffer(nil) - - require.NoError(t, src.Load(context.Background(), input, buf)) - assert.Equal(t, `{"variables":{"a":null,"b":"b","c":{}}}`, buf.String()) + data, err := src.Load(context.Background(), nil, input) + require.NoError(t, err) + assert.Equal(t, `{"variables":{"a":null,"b":"b","c":{}}}`, string(data)) }) }) t.Run("remove undefined variables", func(t *testing.T) { @@ -8775,7 +8773,6 @@ func TestSource_Load(t *testing.T) { var input []byte input = httpclient.SetInputBodyWithPath(input, variables, "variables") input = httpclient.SetInputURL(input, []byte(serverUrl)) - buf := bytes.NewBuffer(nil) undefinedVariables := []string{"a", "c"} ctx := context.Background() @@ -8783,8 +8780,9 @@ func TestSource_Load(t *testing.T) { input, err = httpclient.SetUndefinedVariables(input, undefinedVariables) assert.NoError(t, err) - require.NoError(t, src.Load(ctx, input, buf)) - assert.Equal(t, `{"variables":{"b":null}}`, buf.String()) + data, err := src.Load(ctx, nil, input) + require.NoError(t, err) + assert.Equal(t, `{"variables":{"b":null}}`, string(data)) }) }) } @@ -8866,10 +8864,12 @@ func TestLoadFiles(t *testing.T) { input = httpclient.SetInputBodyWithPath(input, variables, "variables") input = httpclient.SetInputBodyWithPath(input, query, "query") input = httpclient.SetInputURL(input, []byte(serverUrl)) - buf := bytes.NewBuffer(nil) ctx := context.Background() - require.NoError(t, src.LoadWithFiles(ctx, input, []*httpclient.FileUpload{httpclient.NewFileUpload(f.Name(), fileName, "variables.file")}, buf)) + got, err := src.LoadWithFiles(ctx, nil, input, []*httpclient.FileUpload{httpclient.NewFileUpload(f.Name(), fileName, "variables.file")}) + require.NoError(t, err) + require.Equal(t, []byte{}, got) + }) t.Run("multiple files", func(t *testing.T) { @@ -8910,7 +8910,6 @@ func TestLoadFiles(t *testing.T) { input = httpclient.SetInputBodyWithPath(input, variables, "variables") input = httpclient.SetInputBodyWithPath(input, query, "query") input = httpclient.SetInputURL(input, []byte(serverUrl)) - buf := bytes.NewBuffer(nil) dir := t.TempDir() f1, err := os.CreateTemp(dir, file1Name) @@ -8924,11 +8923,12 @@ func TestLoadFiles(t *testing.T) { assert.NoError(t, err) ctx := context.Background() - require.NoError(t, src.LoadWithFiles(ctx, input, + got, err := src.LoadWithFiles(ctx, nil, input, []*httpclient.FileUpload{ httpclient.NewFileUpload(f1.Name(), file1Name, "variables.files.0"), - httpclient.NewFileUpload(f2.Name(), file2Name, "variables.files.1")}, - buf)) + httpclient.NewFileUpload(f2.Name(), file2Name, "variables.files.1")}) + require.NoError(t, err) + require.Equal(t, []byte{}, got) }) } diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index c5a52a476..c8a08df03 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -9,13 +9,9 @@ import ( "errors" "fmt" "io" - "maps" "net" "net/http" "net/http/httptrace" - "net/textproto" - "slices" - "strconv" "strings" "sync" "syscall" @@ -295,27 +291,6 @@ func (c *subscriptionClient) Subscribe(ctx *resolve.Context, options GraphQLSubs return c.subscribeWS(ctx.Context(), c.engineCtx, options, updater) } -var ( - withSSE = []byte(`sse:true`) - withSSEMethodPost = []byte(`sse_method_post:true`) -) - -func (c *subscriptionClient) UniqueRequestID(ctx *resolve.Context, options GraphQLSubscriptionOptions, hash *xxhash.Digest) (err error) { - if options.UseSSE { - _, err = hash.Write(withSSE) - if err != nil { - return err - } - } - if options.SSEMethodPost { - _, err = hash.Write(withSSEMethodPost) - if err != nil { - return err - } - } - return c.requestHash(ctx, options, hash) -} - func (c *subscriptionClient) subscribeSSE(requestContext, engineContext context.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { options.readTimeout = c.readTimeout if c.streamingClient == nil { @@ -409,89 +384,6 @@ func (c *subscriptionClient) asyncSubscribeWS(requestContext, engineContext cont return nil } -// generateHandlerIDHash generates a Hash based on: URL and Headers to uniquely identify Upgrade Requests -func (c *subscriptionClient) requestHash(ctx *resolve.Context, options GraphQLSubscriptionOptions, xxh *xxhash.Digest) (err error) { - if _, err = xxh.WriteString(options.URL); err != nil { - return err - } - if err := options.Header.Write(xxh); err != nil { - return err - } - // Make sure any header that will be forwarded to the subgraph - // is hashed to create the handlerID, this way requests with - // different headers will use separate connections. - for _, headerName := range options.ForwardedClientHeaderNames { - if _, err = xxh.WriteString(headerName); err != nil { - return err - } - for _, val := range ctx.Request.Header[textproto.CanonicalMIMEHeaderKey(headerName)] { - if _, err = xxh.WriteString(val); err != nil { - return err - } - } - } - - // Sort header names for deterministic hashing since looping through maps - // results in a non-deterministic order of elements - headerKeys := slices.Sorted(maps.Keys(ctx.Request.Header)) - - for _, headerRegexp := range options.ForwardedClientHeaderRegularExpressions { - // Write header pattern - if _, err = xxh.WriteString(headerRegexp.Pattern.String()); err != nil { - return err - } - - // Write negate match - if _, err = xxh.WriteString(strconv.FormatBool(headerRegexp.NegateMatch)); err != nil { - return err - } - - for _, headerName := range headerKeys { - values := ctx.Request.Header[headerName] - result := headerRegexp.Pattern.MatchString(headerName) - if headerRegexp.NegateMatch { - result = !result - } - if result { - for _, val := range values { - if _, err = xxh.WriteString(val); err != nil { - return err - } - } - } - } - } - if len(ctx.InitialPayload) > 0 { - if _, err = xxh.Write(ctx.InitialPayload); err != nil { - return err - } - } - if options.Body.Extensions != nil { - if _, err = xxh.Write(options.Body.Extensions); err != nil { - return err - } - } - if options.Body.Query != "" { - _, err = xxh.WriteString(options.Body.Query) - if err != nil { - return err - } - } - if options.Body.Variables != nil { - _, err = xxh.Write(options.Body.Variables) - if err != nil { - return err - } - } - if options.Body.OperationName != "" { - _, err = xxh.WriteString(options.Body.OperationName) - if err != nil { - return err - } - } - return nil -} - type UpgradeRequestError struct { URL string StatusCode int diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index 279c4bfe8..86dd57c03 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -7,7 +7,6 @@ import ( "fmt" "net/http" "net/http/httptest" - "regexp" "runtime" "strings" "sync" @@ -15,7 +14,6 @@ import ( "time" "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" "github.com/coder/websocket" ll "github.com/jensneuse/abstractlogger" "github.com/stretchr/testify/assert" @@ -2439,7 +2437,7 @@ func TestWebSocketUpgradeFailures(t *testing.T) { w.Header().Set(key, value) } w.WriteHeader(tc.statusCode) - fmt.Fprintf(w, `{"error": "WebSocket upgrade failed", "status": %d}`, tc.statusCode) + _, _ = fmt.Fprintf(w, `{"error": "WebSocket upgrade failed", "status": %d}`, tc.statusCode) })) defer server.Close() @@ -2571,203 +2569,3 @@ func TestInvalidWebSocketAcceptKey(t *testing.T) { }) } } - -func TestRequestHash(t *testing.T) { - t.Parallel() - client := &subscriptionClient{} - - t.Run("basic request with URL and headers", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{}, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - Header: http.Header{ - "Authorization": []string{"Bearer token"}, - }, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0xacbca06c541c2a79), hash.Sum64()) - }) - - t.Run("request with forwarded client headers", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{ - "X-User-Id": []string{"123"}, - "X-Role": []string{"admin"}, - }, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - ForwardedClientHeaderNames: []string{"X-User-Id", "X-Role"}, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0xf428bef25952044c), hash.Sum64()) - }) - - t.Run("request with forwarded client header regex patterns", func(t *testing.T) { - t.Parallel() - - t.Run("with normal", func(t *testing.T) { - header := http.Header{ - "X-Custom-1": []string{"value1"}, - "X-There-2": []string{"value2"}, - "X-Alright-3": []string{"value3"}, - } - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: header, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - ForwardedClientHeaderRegularExpressions: []RegularExpression{ - { - Pattern: regexp.MustCompile("^X-Custom-.*$"), - NegateMatch: false, - }, - }, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0xb1557904bfa9d86a), hash.Sum64()) - }) - - t.Run("with negative", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{ - "X-Custom-1": []string{"valueThere1"}, - "X-Custom-2": []string{"valueThere2"}, - }, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - ForwardedClientHeaderRegularExpressions: []RegularExpression{ - { - Pattern: regexp.MustCompile("^X-Custom-2"), - NegateMatch: true, - }, - }, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0x5888642db454ccab), hash.Sum64()) - }) - - t.Run("with multiple tries to ensure the hash is idempotent", func(t *testing.T) { - for range 100 { - header := http.Header{ - "X-Custom-1": []string{"a1"}, - "X-There-2": []string{"a2"}, - "X-Custom-6": []string{"a3"}, - "X-Alright-3": []string{"a4"}, - "X-Custom-5": []string{"a5"}, - } - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: header, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - ForwardedClientHeaderRegularExpressions: []RegularExpression{ - { - Pattern: regexp.MustCompile("^X-Custom-.*$"), - NegateMatch: false, - }, - }, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0x6c9c1099adab987d), hash.Sum64()) - } - }) - }) - - t.Run("request with initial payload", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{}, - }, - InitialPayload: []byte(`{"auth": "token"}`), - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0x3c5af329478bfcce), hash.Sum64()) - - }) - - t.Run("request with body components", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{}, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - Body: GraphQLBody{ - Query: "query { hello }", - Variables: []byte(`{"var": "value"}`), - OperationName: "HelloQuery", - Extensions: []byte(`{"ext": "value"}`), - }, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0xd8d5588c8a466cf2), hash.Sum64()) - }) - - t.Run("empty components", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{}, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0x767db2231989769), hash.Sum64()) - }) - -} diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 8a196cbc6..6cbc4ca12 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -7,10 +7,12 @@ package grpcdatasource import ( - "bytes" "context" - "errors" + "encoding/binary" + "fmt" + "net/http" + "github.com/cespare/xxhash/v2" "github.com/tidwall/gjson" "golang.org/x/sync/errgroup" "google.golang.org/grpc" @@ -44,6 +46,8 @@ type DataSource struct { mapping *GRPCMapping federationConfigs plan.FederationFieldConfigurations disabled bool + + pool *resolve.ArenaPool } type ProtoConfig struct { @@ -79,28 +83,36 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D mapping: config.Mapping, federationConfigs: config.FederationConfigs, disabled: config.Disabled, + pool: resolve.NewArenaPool(), }, nil } // Load implements resolve.DataSource interface. -// It processes the input JSON data to make gRPC calls and writes -// the response to the output buffer. +// It processes the input JSON data to make gRPC calls and returns +// the response data. // // The input is expected to contain the necessary information to make // a gRPC call, including service name, method name, and request data. -func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { +func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { // get variables from input variables := gjson.Parse(unsafebytes.BytesToString(input)).Get("body.variables") - builder := newJSONBuilder(d.mapping, variables) + + var ( + poolItems []*resolve.ArenaPoolItem + ) + defer func() { + d.pool.ReleaseMany(poolItems) + }() + + item := d.acquirePoolItem(input, 0) + poolItems = append(poolItems, item) + builder := newJSONBuilder(item.Arena, d.mapping, variables) if d.disabled { - out.Write(builder.writeErrorBytes(errors.New("gRPC datasource needs to be enabled to be used"))) - return nil + return builder.writeErrorBytes(fmt.Errorf("gRPC datasource needs to be enabled to be used")), nil } - arena := astjson.Arena{} - defer arena.Reset() - root := arena.NewObject() + root := astjson.ObjectValue(nil) failed := false @@ -115,8 +127,10 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) // make gRPC calls for index, serviceCall := range serviceCalls { + item := d.acquirePoolItem(input, index) + poolItems = append(poolItems, item) + builder := newJSONBuilder(item.Arena, d.mapping, variables) errGrp.Go(func() error { - a := astjson.Arena{} // Invoke the gRPC method - this will populate serviceCall.Output err := d.cc.Invoke(errGrpCtx, serviceCall.MethodFullName(), serviceCall.Input, serviceCall.Output) @@ -124,7 +138,7 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) return err } - response, err := builder.marshalResponseJSON(&a, &serviceCall.RPC.Response, serviceCall.Output) + response, err := builder.marshalResponseJSON(&serviceCall.RPC.Response, serviceCall.Output) if err != nil { return err } @@ -149,7 +163,7 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) } if err := errGrp.Wait(); err != nil { - out.Write(builder.writeErrorBytes(err)) + data = builder.writeErrorBytes(err) failed = true return nil } @@ -162,19 +176,29 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) root, err = builder.mergeValues(root, result.response) } if err != nil { - out.Write(builder.writeErrorBytes(err)) + data = builder.writeErrorBytes(err) return err } } return nil }); err != nil || failed { - return err + return data, err } - data := builder.toDataObject(root) - out.Write(data.MarshalTo(nil)) - return nil + value := builder.toDataObject(root) + return value.MarshalTo(nil), err +} + +func (d *DataSource) acquirePoolItem(input []byte, index int) *resolve.ArenaPoolItem { + keyGen := xxhash.New() + _, _ = keyGen.Write(input) + var b [8]byte + binary.LittleEndian.PutUint64(b[:], uint64(index)) + _, _ = keyGen.Write(b[:]) + key := keyGen.Sum64() + item := d.pool.Acquire(key) + return item } // LoadWithFiles implements resolve.DataSource interface. @@ -184,6 +208,6 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) // might not be applicable for most gRPC use cases. // // Currently unimplemented. -func (d *DataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { +func (d *DataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("unimplemented") } diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index 8191b5b08..de66be94a 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -1,7 +1,6 @@ package grpcdatasource import ( - "bytes" "context" "encoding/json" "fmt" @@ -19,8 +18,6 @@ import ( protoref "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/dynamicpb" - "github.com/wundergraph/astjson" - "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" "github.com/wundergraph/graphql-go-tools/v2/pkg/grpctest" @@ -57,8 +54,7 @@ func Benchmark_DataSource_Load(b *testing.B) { b.ReportAllocs() b.ResetTimer() for b.Loop() { - output := new(bytes.Buffer) - err = ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`), output) + _, err = ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","body":`+variables+`}`)) require.NoError(b, err) } } @@ -96,7 +92,7 @@ func Benchmark_DataSource_Load_WithFieldArguments(b *testing.B) { }) require.NoError(b, err) - err = ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`), new(bytes.Buffer)) + _, err = ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","body":`+variables+`}`)) require.NoError(b, err) } } @@ -223,12 +219,8 @@ func Test_DataSource_Load(t *testing.T) { require.NoError(t, err) - output := new(bytes.Buffer) - - err = ds.Load(context.Background(), []byte(`{"query":"`+query+`","variables":`+variables+`}`), output) + _, err = ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","variables":`+variables+`}`)) require.NoError(t, err) - - fmt.Println(output.String()) } // Test_DataSource_Load_WithMockService tests the datasource.Load method with an actual gRPC server @@ -296,12 +288,11 @@ func Test_DataSource_Load_WithMockService(t *testing.T) { require.NoError(t, err) // 3. Execute the query through our datasource - output := new(bytes.Buffer) - err = ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`), output) + output, err := ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","body":`+variables+`}`)) require.NoError(t, err) // Print the response for debugging - // fmt.Println(output.String()) + // fmt.Println(string(output)) type response struct { Data struct { @@ -314,7 +305,7 @@ func Test_DataSource_Load_WithMockService(t *testing.T) { var resp response - bytes := output.Bytes() + bytes := output fmt.Println(string(bytes)) err = json.Unmarshal(bytes, &resp) @@ -386,12 +377,10 @@ func Test_DataSource_Load_WithMockService_WithResponseMapping(t *testing.T) { require.NoError(t, err) // 3. Execute the query through our datasource - output := new(bytes.Buffer) - // Format the input with query and variables inputJSON := fmt.Sprintf(`{"query":%q,"body":%s}`, query, variables) - err = ds.Load(context.Background(), []byte(inputJSON), output) + output, err := ds.Load(context.Background(), nil, []byte(inputJSON)) require.NoError(t, err) // Set up the correct response structure based on your GraphQL schema @@ -408,7 +397,7 @@ func Test_DataSource_Load_WithMockService_WithResponseMapping(t *testing.T) { } var resp response - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") // Check if there are any errors in the response @@ -483,11 +472,10 @@ func Test_DataSource_Load_WithGrpcError(t *testing.T) { require.NoError(t, err) // 4. Execute the query - output := new(bytes.Buffer) - err = ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`), output) + output, err := ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","body":`+variables+`}`)) require.NoError(t, err, "Load should not return an error even when the gRPC call fails") - responseJson := output.String() + responseJson := string(output) // 5. Verify the response format according to GraphQL specification // The response should have an "errors" array with the error message @@ -501,7 +489,7 @@ func Test_DataSource_Load_WithGrpcError(t *testing.T) { } `json:"errors"` } - err = json.Unmarshal(output.Bytes(), &response) + err = json.Unmarshal(output, &response) require.NoError(t, err, "Failed to parse response JSON") // Verify there's at least one error @@ -573,9 +561,8 @@ func TestMarshalResponseJSON(t *testing.T) { responseMessage := dynamicpb.NewMessage(responseMessageDesc) responseMessage.Mutable(responseMessageDesc.Fields().ByName("result")).List().Append(protoref.ValueOfMessage(productMessage)) - arena := astjson.Arena{} - jsonBuilder := newJSONBuilder(nil, gjson.Result{}) - responseJSON, err := jsonBuilder.marshalResponseJSON(&arena, &response, responseMessage) + jsonBuilder := newJSONBuilder(nil, nil, gjson.Result{}) + responseJSON, err := jsonBuilder.marshalResponseJSON(&response, responseMessage) require.NoError(t, err) require.Equal(t, `{"_entities":[{"__typename":"Product","id":"123","name_different":"test","price_different":123.45}]}`, responseJSON.String()) } @@ -810,9 +797,8 @@ func Test_DataSource_Load_WithAnimalInterface(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -823,7 +809,7 @@ func Test_DataSource_Load_WithAnimalInterface(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -1081,9 +1067,8 @@ func Test_Datasource_Load_WithUnionTypes(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -1094,7 +1079,7 @@ func Test_Datasource_Load_WithUnionTypes(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -1218,9 +1203,8 @@ func Test_DataSource_Load_WithCategoryQueries(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -1231,7 +1215,7 @@ func Test_DataSource_Load_WithCategoryQueries(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -1299,9 +1283,8 @@ func Test_DataSource_Load_WithTotalCalculation(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, query, variables) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -1323,7 +1306,7 @@ func Test_DataSource_Load_WithTotalCalculation(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") @@ -1390,9 +1373,8 @@ func Test_DataSource_Load_WithTypename(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":{}}`, query) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -1409,7 +1391,7 @@ func Test_DataSource_Load_WithTypename(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") @@ -1860,9 +1842,8 @@ func Test_DataSource_Load_WithAliases(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -1873,7 +1854,7 @@ func Test_DataSource_Load_WithAliases(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -2239,9 +2220,8 @@ func Test_DataSource_Load_WithNullableFieldsType(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -2252,7 +2232,7 @@ func Test_DataSource_Load_WithNullableFieldsType(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -3541,9 +3521,8 @@ func Test_DataSource_Load_WithNestedLists(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -3554,7 +3533,7 @@ func Test_DataSource_Load_WithNestedLists(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -3741,15 +3720,14 @@ func Test_DataSource_Load_WithEntity_Calls(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response var resp graphqlResponse - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") tc.validate(t, resp.Data) @@ -3928,15 +3906,14 @@ func Test_Datasource_Load_WithFieldResolvers(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response var resp graphqlResponse - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") tc.validate(t, resp.Data) diff --git a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go index 6e521b041..0b2edc07c 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go +++ b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go @@ -11,6 +11,7 @@ import ( protoref "google.golang.org/protobuf/reflect/protoreflect" "github.com/wundergraph/astjson" + "github.com/wundergraph/go-arena" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" ) @@ -107,16 +108,18 @@ type jsonBuilder struct { mapping *GRPCMapping // Mapping configuration for GraphQL to gRPC translation variables gjson.Result // GraphQL variables containing entity representations indexMap indexMap // Entity index mapping for federation ordering + jsonArena arena.Arena } // newJSONBuilder creates a new JSON builder instance with the provided mapping // and variables. The builder automatically creates an index map for proper // federation entity ordering if representations are present in the variables. -func newJSONBuilder(mapping *GRPCMapping, variables gjson.Result) *jsonBuilder { +func newJSONBuilder(a arena.Arena, mapping *GRPCMapping, variables gjson.Result) *jsonBuilder { return &jsonBuilder{ mapping: mapping, variables: variables, indexMap: createRepresentationIndexMap(variables), + jsonArena: a, } } @@ -163,7 +166,7 @@ func (j *jsonBuilder) mergeValues(left *astjson.Value, right *astjson.Value) (*a if len(j.indexMap) == 0 { // No federation index map available - use simple merge // This path is taken for non-federated queries - root, _, err := astjson.MergeValues(left, right) + root, _, err := astjson.MergeValues(j.jsonArena, left, right) if err != nil { return nil, err } @@ -189,11 +192,10 @@ func (j *jsonBuilder) mergeValues(left *astjson.Value, right *astjson.Value) (*a // This function ensures that entities are placed in the correct positions in the final response // array based on their original representation order, which is critical for GraphQL federation. func (j *jsonBuilder) mergeEntities(left *astjson.Value, right *astjson.Value) (*astjson.Value, error) { - arena := astjson.Arena{} // Create the response structure with _entities array - entities := arena.NewObject() - entities.Set(entityPath, arena.NewArray()) + entities := astjson.ObjectValue(j.jsonArena) + entities.Set(j.jsonArena, entityPath, astjson.ArrayValue(j.jsonArena)) arr := entities.Get(entityPath) // Extract entity arrays from both responses @@ -209,12 +211,12 @@ func (j *jsonBuilder) mergeEntities(left *astjson.Value, right *astjson.Value) ( // Merge left entities using index mapping to preserve order for index, lr := range leftRepresentations { - arr.SetArrayItem(j.indexMap.getResultIndex(lr, index), lr) + arr.SetArrayItem(j.jsonArena, j.indexMap.getResultIndex(lr, index), lr) } // Merge right entities using index mapping to preserve order for index, rr := range rightRepresentations { - arr.SetArrayItem(j.indexMap.getResultIndex(rr, index), rr) + arr.SetArrayItem(j.jsonArena, j.indexMap.getResultIndex(rr, index), rr) } return entities, nil @@ -257,7 +259,7 @@ func (j *jsonBuilder) mergeWithPath(base *astjson.Value, resolved *astjson.Value } for i := range responseValues { - responseValues[i].Set(elementName, resolvedValues[i].Get(elementName)) + responseValues[i].Set(j.jsonArena, elementName, resolvedValues[i].Get(elementName)) } return nil @@ -315,12 +317,12 @@ func (j *jsonBuilder) flattenList(items []*astjson.Value, path ast.Path) ([]*ast // marshalResponseJSON converts a protobuf message into a GraphQL-compatible JSON response. // This is the core marshaling function that handles all the complex type conversions, // including oneOf types, nested messages, lists, and scalar values. -func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMessage, data protoref.Message) (*astjson.Value, error) { +func (j *jsonBuilder) marshalResponseJSON(message *RPCMessage, data protoref.Message) (*astjson.Value, error) { if message == nil { - return arena.NewNull(), nil + return astjson.NullValue, nil } - root := arena.NewObject() + root := astjson.ObjectValue(j.jsonArena) // Handle protobuf oneOf types - these represent GraphQL union/interface types if message.IsOneOf() { @@ -354,14 +356,14 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess if field.StaticValue != "" { if len(message.MemberTypes) == 0 { // Simple static value - use as-is - root.Set(field.AliasOrPath(), arena.NewString(field.StaticValue)) + root.Set(j.jsonArena, field.AliasOrPath(), astjson.StringValue(j.jsonArena, field.StaticValue)) continue } // Type-specific static value - match against member types for _, memberTypes := range message.MemberTypes { if memberTypes == string(data.Type().Descriptor().Name()) { - root.Set(field.AliasOrPath(), arena.NewString(memberTypes)) + root.Set(j.jsonArena, field.AliasOrPath(), astjson.StringValue(j.jsonArena, memberTypes)) break } } @@ -379,8 +381,8 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess // Handle list fields (repeated in protobuf) if fd.IsList() { list := data.Get(fd).List() - arr := arena.NewArray() - root.Set(field.AliasOrPath(), arr) + arr := astjson.ArrayValue(j.jsonArena) + root.Set(j.jsonArena, field.AliasOrPath(), arr) if !list.IsValid() { // Invalid list - leave as empty array @@ -393,15 +395,15 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess case protoref.MessageKind: // List of messages - recursively marshal each message message := list.Get(i).Message() - value, err := j.marshalResponseJSON(arena, field.Message, message) + value, err := j.marshalResponseJSON(field.Message, message) if err != nil { return nil, err } - arr.SetArrayItem(i, value) + arr.SetArrayItem(j.jsonArena, i, value) default: // List of scalar values - convert directly - j.setArrayItem(i, arena, arr, list.Get(i), fd) + j.setArrayItem(i, arr, list.Get(i), fd) } } @@ -413,24 +415,24 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess msg := data.Get(fd).Message() if !msg.IsValid() { // Invalid message - set to null - root.Set(field.AliasOrPath(), arena.NewNull()) + root.Set(j.jsonArena, field.AliasOrPath(), astjson.NullValue) continue } // Handle special list wrapper types for complex nested lists if field.IsListType { - arr, err := j.flattenListStructure(arena, field.ListMetadata, msg, field.Message) + arr, err := j.flattenListStructure(field.ListMetadata, msg, field.Message) if err != nil { return nil, fmt.Errorf("unable to flatten list structure for field %q: %w", field.AliasOrPath(), err) } - root.Set(field.AliasOrPath(), arr) + root.Set(j.jsonArena, field.AliasOrPath(), arr) continue } // Handle optional scalar wrapper types (e.g., google.protobuf.StringValue) if field.IsOptionalScalar() { - err := j.resolveOptionalField(arena, root, field.AliasOrPath(), msg) + err := j.resolveOptionalField(root, field.AliasOrPath(), msg) if err != nil { return nil, err } @@ -439,27 +441,27 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess } // Regular nested message - recursively marshal - value, err := j.marshalResponseJSON(arena, field.Message, msg) + value, err := j.marshalResponseJSON(field.Message, msg) if err != nil { return nil, err } if field.JSONPath == "" { // Field should be merged into parent object (flattened) - root, _, err = astjson.MergeValues(root, value) + root, _, err = astjson.MergeValues(j.jsonArena, root, value) if err != nil { return nil, err } } else { // Field should be nested under its own key - root.Set(field.AliasOrPath(), value) + root.Set(j.jsonArena, field.AliasOrPath(), value) } continue } // Handle scalar fields (string, int, bool, etc.) - j.setJSONValue(arena, root, field.AliasOrPath(), data, fd) + j.setJSONValue(root, field.AliasOrPath(), data, fd) } return root, nil @@ -469,34 +471,34 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess // messages to support nullable and multi-dimensional lists. This is necessary because // protobuf doesn't directly support nullable list items or complex nesting scenarios // that GraphQL allows. -func (j *jsonBuilder) flattenListStructure(arena *astjson.Arena, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { +func (j *jsonBuilder) flattenListStructure(md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { if md == nil { - return arena.NewNull(), errors.New("list metadata not found") + return astjson.NullValue, errors.New("list metadata not found") } // Validate metadata consistency if len(md.LevelInfo) < md.NestingLevel { - return arena.NewNull(), errors.New("nesting level data does not match the number of levels in the list metadata") + return astjson.NullValue, errors.New("nesting level data does not match the number of levels in the list metadata") } // Handle null data with proper nullability checking if !data.IsValid() { if md.LevelInfo[0].Optional { - return arena.NewNull(), nil + return astjson.NullValue, nil } - return arena.NewNull(), errors.New("cannot add null item to response for non nullable list") + return astjson.NullValue, errors.New("cannot add null item to response for non nullable list") } // Start recursive traversal of the nested list structure - root := arena.NewArray() - return j.traverseList(0, arena, root, md, data, message) + root := astjson.ArrayValue(j.jsonArena) + return j.traverseList(0, root, md, data, message) } // traverseList recursively traverses nested list wrapper structures to extract the actual // list data. This handles multi-dimensional lists like [[String]] or [[[User]]] by // unwrapping the protobuf message wrappers at each level. -func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *astjson.Value, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { +func (j *jsonBuilder) traverseList(level int, current *astjson.Value, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { if level > md.NestingLevel { return current, nil } @@ -504,11 +506,11 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast // List wrappers always use field number 1 in the generated protobuf fd := data.Descriptor().Fields().ByNumber(1) if fd == nil { - return arena.NewNull(), fmt.Errorf("field with number %d not found in message %q", 1, data.Descriptor().Name()) + return astjson.NullValue, fmt.Errorf("field with number %d not found in message %q", 1, data.Descriptor().Name()) } if fd.Kind() != protoref.MessageKind { - return arena.NewNull(), fmt.Errorf("field %q is not a message", fd.Name()) + return astjson.NullValue, fmt.Errorf("field %q is not a message", fd.Name()) } // Get the wrapper message containing the list @@ -516,16 +518,16 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast if !msg.IsValid() { // Handle null wrapper based on nullability rules if md.LevelInfo[level].Optional { - return arena.NewNull(), nil + return astjson.NullValue, nil } - return arena.NewArray(), errors.New("cannot add null item to response for non nullable list") + return astjson.ArrayValue(j.jsonArena), errors.New("cannot add null item to response for non nullable list") } // The actual list is always at field number 1 in the wrapper fd = msg.Descriptor().Fields().ByNumber(1) if !fd.IsList() { - return arena.NewNull(), fmt.Errorf("field %q is not a list", fd.Name()) + return astjson.NullValue, fmt.Errorf("field %q is not a list", fd.Name()) } // Handle intermediate nesting levels (not the final level) @@ -533,13 +535,13 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast list := msg.Get(fd).List() for i := 0; i < list.Len(); i++ { // Create nested array for next level - next := arena.NewArray() - val, err := j.traverseList(level+1, arena, next, md, list.Get(i).Message(), message) + next := astjson.ArrayValue(j.jsonArena) + val, err := j.traverseList(level+1, next, md, list.Get(i).Message(), message) if err != nil { return nil, err } - current.SetArrayItem(i, val) + current.SetArrayItem(j.jsonArena, i, val) } return current, nil @@ -550,22 +552,22 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast if !list.IsValid() { // Invalid list at final level - return empty array // Nullability is checked at the wrapper level, not the list level - return arena.NewArray(), nil + return astjson.ArrayValue(j.jsonArena), nil } // Process each item in the final list for i := 0; i < list.Len(); i++ { if message != nil { // List of complex objects - recursively marshal each item - val, err := j.marshalResponseJSON(arena, message, list.Get(i).Message()) + val, err := j.marshalResponseJSON(message, list.Get(i).Message()) if err != nil { return nil, err } - current.SetArrayItem(i, val) + current.SetArrayItem(j.jsonArena, i, val) } else { // List of scalar values - convert directly - j.setArrayItem(i, arena, current, list.Get(i), fd) + j.setArrayItem(i, current, list.Get(i), fd) } } @@ -575,7 +577,7 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast // resolveOptionalField extracts the value from optional scalar wrapper types like // google.protobuf.StringValue, google.protobuf.Int32Value, etc. These wrappers // are used to represent nullable scalar values in protobuf. -func (j *jsonBuilder) resolveOptionalField(arena *astjson.Arena, root *astjson.Value, name string, data protoref.Message) error { +func (j *jsonBuilder) resolveOptionalField(root *astjson.Value, name string, data protoref.Message) error { // Optional scalar wrappers always have a "value" field fd := data.Descriptor().Fields().ByName(protoref.Name("value")) if fd == nil { @@ -583,16 +585,16 @@ func (j *jsonBuilder) resolveOptionalField(arena *astjson.Arena, root *astjson.V } // Extract and set the wrapped value - j.setJSONValue(arena, root, name, data, fd) + j.setJSONValue(root, name, data, fd) return nil } // setJSONValue converts a protobuf field value to the appropriate JSON representation // and sets it on the provided JSON object. This handles all protobuf scalar types // and enum values with proper GraphQL mapping. -func (j *jsonBuilder) setJSONValue(arena *astjson.Arena, root *astjson.Value, name string, data protoref.Message, fd protoref.FieldDescriptor) { +func (j *jsonBuilder) setJSONValue(root *astjson.Value, name string, data protoref.Message, fd protoref.FieldDescriptor) { if !data.IsValid() { - root.Set(name, arena.NewNull()) + root.Set(j.jsonArena, name, astjson.NullValue) return } @@ -600,27 +602,27 @@ func (j *jsonBuilder) setJSONValue(arena *astjson.Arena, root *astjson.Value, na case protoref.BoolKind: boolValue := data.Get(fd).Bool() if boolValue { - root.Set(name, arena.NewTrue()) + root.Set(j.jsonArena, name, astjson.TrueValue(j.jsonArena)) } else { - root.Set(name, arena.NewFalse()) + root.Set(j.jsonArena, name, astjson.FalseValue(j.jsonArena)) } case protoref.StringKind: - root.Set(name, arena.NewString(data.Get(fd).String())) + root.Set(j.jsonArena, name, astjson.StringValue(j.jsonArena, data.Get(fd).String())) case protoref.Int32Kind: - root.Set(name, arena.NewNumberInt(int(data.Get(fd).Int()))) + root.Set(j.jsonArena, name, astjson.IntValue(j.jsonArena, int(data.Get(fd).Int()))) case protoref.Int64Kind: - root.Set(name, arena.NewNumberString(strconv.FormatInt(data.Get(fd).Int(), 10))) + root.Set(j.jsonArena, name, astjson.NumberValue(j.jsonArena, strconv.FormatInt(data.Get(fd).Int(), 10))) case protoref.Uint32Kind, protoref.Uint64Kind: - root.Set(name, arena.NewNumberString(strconv.FormatUint(data.Get(fd).Uint(), 10))) + root.Set(j.jsonArena, name, astjson.NumberValue(j.jsonArena, strconv.FormatUint(data.Get(fd).Uint(), 10))) case protoref.FloatKind, protoref.DoubleKind: - root.Set(name, arena.NewNumberFloat64(data.Get(fd).Float())) + root.Set(j.jsonArena, name, astjson.FloatValue(j.jsonArena, data.Get(fd).Float())) case protoref.BytesKind: - root.Set(name, arena.NewStringBytes(data.Get(fd).Bytes())) + root.Set(j.jsonArena, name, astjson.StringValueBytes(j.jsonArena, data.Get(fd).Bytes())) case protoref.EnumKind: enumDesc := fd.Enum() enumValueDesc := enumDesc.Values().ByNumber(data.Get(fd).Enum()) if enumValueDesc == nil { - root.Set(name, arena.NewNull()) + root.Set(j.jsonArena, name, astjson.NullValue) return } @@ -628,20 +630,20 @@ func (j *jsonBuilder) setJSONValue(arena *astjson.Arena, root *astjson.Value, na graphqlValue, ok := j.mapping.FindEnumValueMapping(string(enumDesc.Name()), string(enumValueDesc.Name())) if !ok { // No mapping found - set to null - root.Set(name, arena.NewNull()) + root.Set(j.jsonArena, name, astjson.NullValue) return } - root.Set(name, arena.NewString(graphqlValue)) + root.Set(j.jsonArena, name, astjson.StringValue(j.jsonArena, graphqlValue)) } } // setArrayItem converts a protobuf list item value to JSON and sets it at the specified // array index. This is similar to setJSONValue but operates on array elements rather // than object properties, and works with protobuf Value types rather than Message types. -func (j *jsonBuilder) setArrayItem(index int, arena *astjson.Arena, array *astjson.Value, data protoref.Value, fd protoref.FieldDescriptor) { +func (j *jsonBuilder) setArrayItem(index int, array *astjson.Value, data protoref.Value, fd protoref.FieldDescriptor) { if !data.IsValid() { - array.SetArrayItem(index, arena.NewNull()) + array.SetArrayItem(j.jsonArena, index, astjson.NullValue) return } @@ -649,27 +651,27 @@ func (j *jsonBuilder) setArrayItem(index int, arena *astjson.Arena, array *astjs case protoref.BoolKind: boolValue := data.Bool() if boolValue { - array.SetArrayItem(index, arena.NewTrue()) + array.SetArrayItem(j.jsonArena, index, astjson.TrueValue(j.jsonArena)) } else { - array.SetArrayItem(index, arena.NewFalse()) + array.SetArrayItem(j.jsonArena, index, astjson.FalseValue(j.jsonArena)) } case protoref.StringKind: - array.SetArrayItem(index, arena.NewString(data.String())) + array.SetArrayItem(j.jsonArena, index, astjson.StringValue(j.jsonArena, data.String())) case protoref.Int32Kind: - array.SetArrayItem(index, arena.NewNumberInt(int(data.Int()))) + array.SetArrayItem(j.jsonArena, index, astjson.IntValue(j.jsonArena, int(data.Int()))) case protoref.Int64Kind: - array.SetArrayItem(index, arena.NewNumberString(strconv.FormatInt(data.Int(), 10))) + array.SetArrayItem(j.jsonArena, index, astjson.NumberValue(j.jsonArena, strconv.FormatInt(data.Int(), 10))) case protoref.Uint32Kind, protoref.Uint64Kind: - array.SetArrayItem(index, arena.NewNumberString(strconv.FormatUint(data.Uint(), 10))) + array.SetArrayItem(j.jsonArena, index, astjson.NumberValue(j.jsonArena, strconv.FormatUint(data.Uint(), 10))) case protoref.FloatKind, protoref.DoubleKind: - array.SetArrayItem(index, arena.NewNumberFloat64(data.Float())) + array.SetArrayItem(j.jsonArena, index, astjson.FloatValue(j.jsonArena, data.Float())) case protoref.BytesKind: - array.SetArrayItem(index, arena.NewStringBytes(data.Bytes())) + array.SetArrayItem(j.jsonArena, index, astjson.StringValueBytes(j.jsonArena, data.Bytes())) case protoref.EnumKind: enumDesc := fd.Enum() enumValueDesc := enumDesc.Values().ByNumber(data.Enum()) if enumValueDesc == nil { - array.SetArrayItem(index, arena.NewNull()) + array.SetArrayItem(j.jsonArena, index, astjson.NullValue) return } @@ -677,20 +679,19 @@ func (j *jsonBuilder) setArrayItem(index int, arena *astjson.Arena, array *astjs graphqlValue, ok := j.mapping.FindEnumValueMapping(string(enumDesc.Name()), string(enumValueDesc.Name())) if !ok { // No mapping found - use null - array.SetArrayItem(index, arena.NewNull()) + array.SetArrayItem(j.jsonArena, index, astjson.NullValue) return } - array.SetArrayItem(index, arena.NewString(graphqlValue)) + array.SetArrayItem(j.jsonArena, index, astjson.StringValue(j.jsonArena, graphqlValue)) } } // toDataObject wraps a response value in the standard GraphQL data envelope. // This creates the top-level structure { "data": ... } that GraphQL clients expect. func (j *jsonBuilder) toDataObject(root *astjson.Value) *astjson.Value { - a := astjson.Arena{} - data := a.NewObject() - data.Set(dataPath, root) + data := astjson.ObjectValue(j.jsonArena) + data.Set(j.jsonArena, dataPath, root) return data } @@ -698,30 +699,27 @@ func (j *jsonBuilder) toDataObject(root *astjson.Value) *astjson.Value { // This includes the error message and gRPC status code information in the extensions // field, following GraphQL error specification standards. func (j *jsonBuilder) writeErrorBytes(err error) []byte { - a := astjson.Arena{} - defer a.Reset() - // Create standard GraphQL error structure - errorRoot := a.NewObject() - errorArray := a.NewArray() - errorRoot.Set(errorsPath, errorArray) + errorRoot := astjson.ObjectValue(j.jsonArena) + errorArray := astjson.ArrayValue(j.jsonArena) + errorRoot.Set(j.jsonArena, errorsPath, errorArray) // Create individual error object - errorItem := a.NewObject() - errorItem.Set("message", a.NewString(err.Error())) + errorItem := astjson.ObjectValue(j.jsonArena) + errorItem.Set(j.jsonArena, "message", astjson.StringValue(j.jsonArena, err.Error())) // Add gRPC status code information to extensions - extensions := a.NewObject() + extensions := astjson.ObjectValue(j.jsonArena) if st, ok := status.FromError(err); ok { // gRPC error - include the specific status code - extensions.Set("code", a.NewString(st.Code().String())) + extensions.Set(j.jsonArena, "code", astjson.StringValue(j.jsonArena, st.Code().String())) } else { // Generic error - default to INTERNAL status - extensions.Set("code", a.NewString(codes.Internal.String())) + extensions.Set(j.jsonArena, "code", astjson.StringValue(j.jsonArena, codes.Internal.String())) } - errorItem.Set("extensions", extensions) - errorArray.SetArrayItem(0, errorItem) + errorItem.Set(j.jsonArena, "extensions", extensions) + errorArray.SetArrayItem(j.jsonArena, 0, errorItem) return errorRoot.MarshalTo(nil) } diff --git a/v2/pkg/engine/datasource/httpclient/httpclient_test.go b/v2/pkg/engine/datasource/httpclient/httpclient_test.go index 223e5d833..98685cece 100644 --- a/v2/pkg/engine/datasource/httpclient/httpclient_test.go +++ b/v2/pkg/engine/datasource/httpclient/httpclient_test.go @@ -1,7 +1,6 @@ package httpclient import ( - "bytes" "compress/gzip" "context" "io" @@ -80,10 +79,9 @@ func TestHttpClientDo(t *testing.T) { runTest := func(ctx context.Context, input []byte, expectedOutput string) func(t *testing.T) { return func(t *testing.T) { - out := &bytes.Buffer{} - err := Do(http.DefaultClient, ctx, input, out) + output, err := Do(http.DefaultClient, ctx, nil, input) assert.NoError(t, err) - assert.Equal(t, expectedOutput, out.String()) + assert.Equal(t, expectedOutput, string(output)) } } @@ -211,9 +209,8 @@ func TestHttpClientDo(t *testing.T) { input = SetInputURL(input, []byte(server.URL)) input, err := sjson.SetBytes(input, TRACE, true) assert.NoError(t, err) - out := &bytes.Buffer{} - err = Do(http.DefaultClient, context.Background(), input, out) + output, err := Do(http.DefaultClient, context.Background(), nil, input) assert.NoError(t, err) - assert.Contains(t, out.String(), `"Authorization":["****"]`) + assert.Contains(t, string(output), `"Authorization":["****"]`) }) } diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index 4e8ca9b31..46af845e4 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -20,7 +20,6 @@ import ( "github.com/buger/jsonparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/lexer/literal" - "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) const ( @@ -28,6 +27,7 @@ const ( AcceptEncodingHeader = "Accept-Encoding" AcceptHeader = "Accept" ContentTypeHeader = "Content-Type" + ContentLengthHeader = "Content-Length" EncodingGzip = "gzip" EncodingDeflate = "deflate" @@ -130,21 +130,38 @@ func respBodyReader(res *http.Response) (io.Reader, error) { } } -type bodyHashContextKey struct{} +type httpClientContext string -func BodyHashFromContext(ctx context.Context) (uint64, bool) { - value := ctx.Value(bodyHashContextKey{}) - if value == nil { - return 0, false +const ( + sizeHintKey httpClientContext = "size-hint" +) + +// WithHTTPClientSizeHint allows the engine to keep track of response sizes per subgraph fetch +// If a hint is supplied, we can create a buffer of size close to the required size +// This reduces allocations by reducing the buffer grow calls, which always copies the buffer +func WithHTTPClientSizeHint(ctx context.Context, size int) context.Context { + return context.WithValue(ctx, sizeHintKey, size) +} + +func buffer(ctx context.Context) *bytes.Buffer { + if sizeHint, ok := ctx.Value(sizeHintKey).(int); ok && sizeHint > 0 { + return bytes.NewBuffer(make([]byte, 0, sizeHint)) } - return value.(uint64), true + // if we start with zero, doubling will take a while until we reach the required size + // if we start with a high number, e.g. 1024, we just increase the memory usage of the engine + // 64 seems to be a healthy middle ground + return bytes.NewBuffer(make([]byte, 0, 64)) } -func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, headers, queryParams []byte, body io.Reader, enableTrace bool, out *bytes.Buffer, contentType string) (err error) { +func makeHTTPRequest(client *http.Client, ctx context.Context, baseHeaders http.Header, url, method, headers, queryParams []byte, body io.Reader, enableTrace bool, contentType string, contentLength int) ([]byte, error) { request, err := http.NewRequestWithContext(ctx, string(method), string(url), body) if err != nil { - return err + return nil, err + } + + if baseHeaders != nil { + request.Header = baseHeaders } if headers != nil { @@ -161,7 +178,7 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head return err }) if err != nil { - return err + return nil, err } } @@ -190,7 +207,7 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head } }) if err != nil { - return err + return nil, err } request.URL.RawQuery = query.Encode() } @@ -199,12 +216,17 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head request.Header.Add(ContentTypeHeader, contentType) request.Header.Set(AcceptEncodingHeader, EncodingGzip) request.Header.Add(AcceptEncodingHeader, EncodingDeflate) + if contentLength > 0 { + // always set the ContentLength field so that chunking can be avoided + // and other parties can more efficiently parse + request.ContentLength = int64(contentLength) + } setRequest(ctx, request) response, err := client.Do(request) if err != nil { - return err + return nil, err } defer response.Body.Close() @@ -212,23 +234,26 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head respReader, err := respBodyReader(response) if err != nil { - return err + return nil, err } - if !enableTrace { - if response.ContentLength > 0 { - out.Grow(int(response.ContentLength)) - } else { - out.Grow(1024 * 4) - } - _, err = out.ReadFrom(respReader) - return + // we intentionally don't use a pool of sorts here + // we're buffering the response and then later, in the engine, + // parse it into an JSON AST with the use of an arena, which is quite efficient + // Through trial and error it turned out that it's best to leave this buffer to the GC + // It'll know best the lifecycle of the buffer + // Using an arena here just increased overall memory usage + out := buffer(ctx) + _, err = out.ReadFrom(respReader) + if err != nil { + return nil, err } - data, err := io.ReadAll(respReader) - if err != nil { - return err + if !enableTrace { + return out.Bytes(), nil } + + data := out.Bytes() responseTrace := TraceHTTP{ Request: TraceHTTPRequest{ Method: request.Method, @@ -244,39 +269,29 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head } trace, err := json.Marshal(responseTrace) if err != nil { - return err + return nil, err } responseWithTraceExtension, err := jsonparser.Set(data, trace, "extensions", "trace") if err != nil { - return err + return nil, err } - _, err = out.Write(responseWithTraceExtension) - return err + return responseWithTraceExtension, nil } -func Do(client *http.Client, ctx context.Context, requestInput []byte, out *bytes.Buffer) (err error) { +func Do(client *http.Client, ctx context.Context, baseHeaders http.Header, requestInput []byte) (data []byte, err error) { url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput) - h := pool.Hash64.Get() - _, _ = h.Write(body) - bodyHash := h.Sum64() - pool.Hash64.Put(h) - ctx = context.WithValue(ctx, bodyHashContextKey{}, bodyHash) - return makeHTTPRequest(client, ctx, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, out, ContentTypeJSON) + return makeHTTPRequest(client, ctx, baseHeaders, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, ContentTypeJSON, len(body)) } func DoMultipartForm( - client *http.Client, ctx context.Context, requestInput []byte, files []*FileUpload, out *bytes.Buffer, -) (err error) { + client *http.Client, ctx context.Context, baseHeaders http.Header, requestInput []byte, files []*FileUpload, +) (data []byte, err error) { if len(files) == 0 { - return errors.New("no files provided") + return nil, errors.New("no files provided") } url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput) - h := pool.Hash64.Get() - defer pool.Hash64.Put(h) - _, _ = h.Write(body) - formValues := map[string]io.Reader{ "operations": bytes.NewReader(body), } @@ -293,14 +308,13 @@ func DoMultipartForm( } hasWrittenFileName = true - fmt.Fprintf(fileMap, `"%d":["%s"]`, i, file.variablePath) + _, _ = fmt.Fprintf(fileMap, `"%d":["%s"]`, i, file.variablePath) key := fmt.Sprintf("%d", i) - _, _ = h.WriteString(file.Path()) temporaryFile, err := os.Open(file.Path()) tempFiles = append(tempFiles, temporaryFile) if err != nil { - return err + return nil, err } formValues[key] = bufio.NewReader(temporaryFile) } @@ -309,7 +323,7 @@ func DoMultipartForm( multipartBody, contentType, err := multipartBytes(formValues, files) if err != nil { - return err + return nil, err } defer func() { @@ -324,10 +338,7 @@ func DoMultipartForm( } }() - bodyHash := h.Sum64() - ctx = context.WithValue(ctx, bodyHashContextKey{}, bodyHash) - - return makeHTTPRequest(client, ctx, url, method, headers, queryParams, multipartBody, enableTrace, out, contentType) + return makeHTTPRequest(client, ctx, baseHeaders, url, method, headers, queryParams, multipartBody, enableTrace, contentType, 0) } func multipartBytes(values map[string]io.Reader, files []*FileUpload) (*io.PipeReader, string, error) { diff --git a/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection.golden b/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection.golden index f6ab07228..d6f62343c 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection.golden +++ b/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection.golden @@ -363,4 +363,4 @@ } ], "__typename": "__Schema" -} +} \ No newline at end of file diff --git a/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection_with_custom_root_operation_types.golden b/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection_with_custom_root_operation_types.golden index 6b73ac8dc..f56fee360 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection_with_custom_root_operation_types.golden +++ b/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection_with_custom_root_operation_types.golden @@ -511,4 +511,4 @@ } ], "__typename": "__Schema" -} +} \ No newline at end of file diff --git a/v2/pkg/engine/datasource/introspection_datasource/fixtures/type_introspection.golden b/v2/pkg/engine/datasource/introspection_datasource/fixtures/type_introspection.golden index 41827c0f6..16017d131 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/fixtures/type_introspection.golden +++ b/v2/pkg/engine/datasource/introspection_datasource/fixtures/type_introspection.golden @@ -56,4 +56,4 @@ "interfaces": [], "possibleTypes": [], "__typename": "__Type" -} +} \ No newline at end of file diff --git a/v2/pkg/engine/datasource/introspection_datasource/source.go b/v2/pkg/engine/datasource/introspection_datasource/source.go index b9a06489d..67195e44a 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/source.go +++ b/v2/pkg/engine/datasource/introspection_datasource/source.go @@ -1,11 +1,11 @@ package introspection_datasource import ( - "bytes" "context" "encoding/json" "errors" "io" + "net/http" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/introspection" @@ -19,21 +19,21 @@ type Source struct { introspectionData *introspection.Data } -func (s *Source) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { +func (s *Source) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { var req introspectionInput if err := json.Unmarshal(input, &req); err != nil { - return err + return nil, err } if req.RequestType == TypeRequestType { - return s.singleType(out, req.TypeName) + return s.singleTypeBytes(req.TypeName) } - return json.NewEncoder(out).Encode(s.introspectionData.Schema) + return json.Marshal(s.introspectionData.Schema) } -func (s *Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { - return errors.New("introspection data source does not support file uploads") +func (s *Source) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { + return nil, errors.New("introspection data source does not support file uploads") } func (s *Source) typeInfo(typeName *string) *introspection.FullType { @@ -57,3 +57,12 @@ func (s *Source) singleType(w io.Writer, typeName *string) error { return json.NewEncoder(w).Encode(typeInfo) } + +func (s *Source) singleTypeBytes(typeName *string) ([]byte, error) { + typeInfo := s.typeInfo(typeName) + if typeInfo == nil { + return null, nil + } + + return json.Marshal(typeInfo) +} diff --git a/v2/pkg/engine/datasource/introspection_datasource/source_test.go b/v2/pkg/engine/datasource/introspection_datasource/source_test.go index bb4a91143..9737a4ee9 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/source_test.go +++ b/v2/pkg/engine/datasource/introspection_datasource/source_test.go @@ -27,13 +27,18 @@ func TestSource_Load(t *testing.T) { gen.Generate(&def, &report, &data) require.False(t, report.HasErrors()) - buf := &bytes.Buffer{} source := &Source{introspectionData: &data} - require.NoError(t, source.Load(context.Background(), []byte(input), buf)) + responseData, err := source.Load(context.Background(), nil, []byte(input)) + require.NoError(t, err) actualResponse := &bytes.Buffer{} - require.NoError(t, json.Indent(actualResponse, buf.Bytes(), "", " ")) - goldie.Assert(t, fixtureName, actualResponse.Bytes()) + require.NoError(t, json.Indent(actualResponse, responseData, "", " ")) + // Trim the trailing newline that json.Indent adds + responseBytes := actualResponse.Bytes() + if len(responseBytes) > 0 && responseBytes[len(responseBytes)-1] == '\n' { + responseBytes = responseBytes[:len(responseBytes)-1] + } + goldie.Assert(t, fixtureName, responseBytes) } } diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource_test.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource_test.go index 28a37df33..2ea8114ad 100644 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource_test.go +++ b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource_test.go @@ -424,6 +424,8 @@ func TestPubSub(t *testing.T) { PostProcessing: resolve.PostProcessingConfiguration{ MergePath: []string{"helloSubscription"}, }, + SourceName: "test", + SourceID: "test", }, Response: &resolve.GraphQLResponse{ Data: &resolve.Object{ @@ -487,6 +489,8 @@ func TestPubSub(t *testing.T) { PostProcessing: resolve.PostProcessingConfiguration{ MergePath: []string{"subscriptionWithMultipleSubjects"}, }, + SourceName: "test", + SourceID: "test", }, Response: &resolve.GraphQLResponse{ Data: &resolve.Object{ @@ -532,6 +536,8 @@ func TestPubSub(t *testing.T) { PostProcessing: resolve.PostProcessingConfiguration{ MergePath: []string{"subscriptionWithStaticValues"}, }, + SourceName: "test", + SourceID: "test", }, Response: &resolve.GraphQLResponse{ Data: &resolve.Object{ @@ -583,6 +589,8 @@ func TestPubSub(t *testing.T) { PostProcessing: resolve.PostProcessingConfiguration{ MergePath: []string{"subscriptionWithArgTemplateAndStaticValue"}, }, + SourceName: "test", + SourceID: "test", }, Response: &resolve.GraphQLResponse{ Data: &resolve.Object{ diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go index cc562b803..3f688b6b1 100644 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go +++ b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go @@ -1,13 +1,9 @@ package pubsub_datasource import ( - "bytes" "context" "encoding/json" - "io" - - "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" + "net/http" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" @@ -33,28 +29,7 @@ type KafkaSubscriptionSource struct { pubSub KafkaPubSub } -func (s *KafkaSubscriptionSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { - - val, _, _, err := jsonparser.Get(input, "topics") - if err != nil { - return err - } - - _, err = xxh.Write(val) - if err != nil { - return err - } - - val, _, _, err = jsonparser.Get(input, "providerId") - if err != nil { - return err - } - - _, err = xxh.Write(val) - return err -} - -func (s *KafkaSubscriptionSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { +func (s *KafkaSubscriptionSource) Start(ctx *resolve.Context, headers http.Header, input []byte, updater resolve.SubscriptionUpdater) error { var subscriptionConfiguration KafkaSubscriptionEventConfiguration err := json.Unmarshal(input, &subscriptionConfiguration) if err != nil { @@ -68,21 +43,19 @@ type KafkaPublishDataSource struct { pubSub KafkaPubSub } -func (s *KafkaPublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { +func (s *KafkaPublishDataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { var publishConfiguration KafkaPublishEventConfiguration - err := json.Unmarshal(input, &publishConfiguration) + err = json.Unmarshal(input, &publishConfiguration) if err != nil { - return err + return nil, err } if err := s.pubSub.Publish(ctx, publishConfiguration); err != nil { - _, err = io.WriteString(out, `{"success": false}`) - return err + return []byte(`{"success": false}`), err } - _, err = io.WriteString(out, `{"success": true}`) - return err + return []byte(`{"success": true}`), nil } -func (s *KafkaPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { +func (s *KafkaPublishDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("not implemented") } diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go index 31cb6d415..776b5deac 100644 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go +++ b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go @@ -5,9 +5,7 @@ import ( "context" "encoding/json" "io" - - "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" + "net/http" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" @@ -42,28 +40,7 @@ type NatsSubscriptionSource struct { pubSub NatsPubSub } -func (s *NatsSubscriptionSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { - - val, _, _, err := jsonparser.Get(input, "subjects") - if err != nil { - return err - } - - _, err = xxh.Write(val) - if err != nil { - return err - } - - val, _, _, err = jsonparser.Get(input, "providerId") - if err != nil { - return err - } - - _, err = xxh.Write(val) - return err -} - -func (s *NatsSubscriptionSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { +func (s *NatsSubscriptionSource) Start(ctx *resolve.Context, headers http.Header, input []byte, updater resolve.SubscriptionUpdater) error { var subscriptionConfiguration NatsSubscriptionEventConfiguration err := json.Unmarshal(input, &subscriptionConfiguration) if err != nil { @@ -77,23 +54,21 @@ type NatsPublishDataSource struct { pubSub NatsPubSub } -func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { +func (s *NatsPublishDataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { var publishConfiguration NatsPublishAndRequestEventConfiguration - err := json.Unmarshal(input, &publishConfiguration) + err = json.Unmarshal(input, &publishConfiguration) if err != nil { - return err + return nil, err } if err := s.pubSub.Publish(ctx, publishConfiguration); err != nil { - _, err = io.WriteString(out, `{"success": false}`) - return err + return []byte(`{"success": false}`), err } - _, err = io.WriteString(out, `{"success": true}`) - return err + return []byte(`{"success": true}`), nil } -func (s *NatsPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) error { +func (s *NatsPublishDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("not implemented") } @@ -101,16 +76,22 @@ type NatsRequestDataSource struct { pubSub NatsPubSub } -func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { +func (s *NatsRequestDataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { var subscriptionConfiguration NatsPublishAndRequestEventConfiguration - err := json.Unmarshal(input, &subscriptionConfiguration) + err = json.Unmarshal(input, &subscriptionConfiguration) if err != nil { - return err + return nil, err + } + + var buf bytes.Buffer + err = s.pubSub.Request(ctx, subscriptionConfiguration, &buf) + if err != nil { + return nil, err } - return s.pubSub.Request(ctx, subscriptionConfiguration, out) + return buf.Bytes(), nil } -func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) error { +func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("not implemented") } diff --git a/v2/pkg/engine/datasource/staticdatasource/static_datasource.go b/v2/pkg/engine/datasource/staticdatasource/static_datasource.go index e9074635c..3fb75c8b3 100644 --- a/v2/pkg/engine/datasource/staticdatasource/static_datasource.go +++ b/v2/pkg/engine/datasource/staticdatasource/static_datasource.go @@ -1,8 +1,8 @@ package staticdatasource import ( - "bytes" "context" + "net/http" "github.com/jensneuse/abstractlogger" @@ -71,11 +71,10 @@ func (p *Planner[T]) ConfigureSubscription() plan.SubscriptionConfiguration { type Source struct{} -func (Source) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { - _, err = out.Write(input) - return +func (Source) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { + return input, nil } -func (Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { +func (Source) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("not implemented") } diff --git a/v2/pkg/engine/plan/planner_test.go b/v2/pkg/engine/plan/planner_test.go index 270140381..b952107f0 100644 --- a/v2/pkg/engine/plan/planner_test.go +++ b/v2/pkg/engine/plan/planner_test.go @@ -1,10 +1,10 @@ package plan import ( - "bytes" "context" "encoding/json" "fmt" + "net/http" "reflect" "slices" "testing" @@ -1075,10 +1075,10 @@ type FakeDataSource struct { source *StatefulSource } -func (f *FakeDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { - return +func (f *FakeDataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { + return nil, nil } -func (f *FakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { - return +func (f *FakeDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { + return nil, nil } diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index ebbd0d5c1..9a1379e0b 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -1291,6 +1291,8 @@ func (v *Visitor) configureSubscription(config *objectFetchConfiguration) { v.subscription.Trigger.QueryPlan = subscription.QueryPlan v.resolveInputTemplates(config, &subscription.Input, &v.subscription.Trigger.Variables) v.subscription.Trigger.Input = []byte(subscription.Input) + v.subscription.Trigger.SourceName = config.sourceName + v.subscription.Trigger.SourceID = config.sourceID v.subscription.Filter = config.filter } diff --git a/v2/pkg/engine/postprocess/create_concrete_single_fetch_types.go b/v2/pkg/engine/postprocess/create_concrete_single_fetch_types.go index f5d0b2ae2..44b3225fb 100644 --- a/v2/pkg/engine/postprocess/create_concrete_single_fetch_types.go +++ b/v2/pkg/engine/postprocess/create_concrete_single_fetch_types.go @@ -51,19 +51,11 @@ func (d *createConcreteSingleFetchTypes) traverseSingleFetch(fetch *resolve.Sing return d.createEntityBatchFetch(fetch) case fetch.RequiresEntityFetch: return d.createEntityFetch(fetch) - case fetch.RequiresParallelListItemFetch: - return d.createParallelListItemFetch(fetch) default: return fetch } } -func (d *createConcreteSingleFetchTypes) createParallelListItemFetch(fetch *resolve.SingleFetch) resolve.Fetch { - return &resolve.ParallelListItemFetch{ - Fetch: fetch, - } -} - func (d *createConcreteSingleFetchTypes) createEntityBatchFetch(fetch *resolve.SingleFetch) resolve.Fetch { representationsVariableIndex := -1 for i, segment := range fetch.InputTemplate.Segments { diff --git a/v2/pkg/engine/resolve/arena.go b/v2/pkg/engine/resolve/arena.go new file mode 100644 index 000000000..7909460b2 --- /dev/null +++ b/v2/pkg/engine/resolve/arena.go @@ -0,0 +1,144 @@ +package resolve + +import ( + "sync" + "weak" + + "github.com/wundergraph/go-arena" +) + +// ArenaPool provides a thread-safe pool of arena.Arena instances for memory-efficient allocations. +// It uses weak pointers to allow garbage collection of unused arenas while maintaining +// a pool of reusable arenas for high-frequency allocation patterns. +// +// by storing ArenaPoolItem as weak pointers, the GC can collect them at any time +// before using an ArenaPoolItem, we try to get a strong pointer while removing it from the pool +// once we call Release, we turn the item back to the pool and make it a weak pointer again +// this means that at any time, GC can claim back the memory if required, +// allowing GC to automatically manage an appropriate pool size depending on available memory and GC pressure +type ArenaPool struct { + // pool is a slice of weak pointers to the struct holding the arena.Arena + pool []weak.Pointer[ArenaPoolItem] + sizes map[uint64]*arenaPoolItemSize + mu sync.Mutex +} + +// arenaPoolItemSize is used to track the required memory across the last 50 arenas in the pool +type arenaPoolItemSize struct { + count int + totalBytes int +} + +// ArenaPoolItem wraps an arena.Arena for use in the pool +type ArenaPoolItem struct { + Arena arena.Arena + Key uint64 +} + +// NewArenaPool creates a new ArenaPool instance +func NewArenaPool() *ArenaPool { + return &ArenaPool{ + sizes: make(map[uint64]*arenaPoolItemSize), + } +} + +// Acquire gets an arena from the pool or creates a new one if none are available. +// The id parameter is used to track arena sizes per use case for optimization. +func (p *ArenaPool) Acquire(key uint64) *ArenaPoolItem { + p.mu.Lock() + defer p.mu.Unlock() + + // Try to find an available arena in the pool + for len(p.pool) > 0 { + // Pop the last item + lastIdx := len(p.pool) - 1 + wp := p.pool[lastIdx] + p.pool = p.pool[:lastIdx] + + v := wp.Value() + if v != nil { + v.Key = key + return v + } + // If weak pointer was nil (GC collected), continue to next item + } + + // No arena available, create a new one + size := arena.WithMinBufferSize(p.getArenaSize(key)) + return &ArenaPoolItem{ + Arena: arena.NewMonotonicArena(size), + Key: key, + } +} + +// Release returns an arena to the pool for reuse. +// The peak memory usage is recorded to optimize future arena sizes for this use case. +func (p *ArenaPool) Release(item *ArenaPoolItem) { + peak := item.Arena.Peak() + item.Arena.Reset() + + p.mu.Lock() + defer p.mu.Unlock() + + // Record the peak usage for this use case + if size, ok := p.sizes[item.Key]; ok { + if size.count == 50 { + size.count = 1 + size.totalBytes = size.totalBytes / 50 + } + size.count++ + size.totalBytes += peak + } else { + p.sizes[item.Key] = &arenaPoolItemSize{ + count: 1, + totalBytes: peak, + } + } + + item.Key = 0 + + // Add the arena back to the pool using a weak pointer + w := weak.Make(item) + p.pool = append(p.pool, w) +} + +func (p *ArenaPool) ReleaseMany(items []*ArenaPoolItem) { + p.mu.Lock() + defer p.mu.Unlock() + + for _, item := range items { + + peak := item.Arena.Peak() + item.Arena.Reset() + + // Record the peak usage for this use case + if size, ok := p.sizes[item.Key]; ok { + if size.count == 50 { + size.count = 1 + size.totalBytes = size.totalBytes / 50 + } + size.count++ + size.totalBytes += peak + } else { + p.sizes[item.Key] = &arenaPoolItemSize{ + count: 1, + totalBytes: peak, + } + } + + item.Key = 0 + + // Add the arena back to the pool using a weak pointer + w := weak.Make(item) + p.pool = append(p.pool, w) + } +} + +// getArenaSize returns the optimal arena size for a given use case ID. +// If no size is recorded, it defaults to 1MB. +func (p *ArenaPool) getArenaSize(id uint64) int { + if size, ok := p.sizes[id]; ok { + return size.totalBytes / size.count + } + return 1024 * 1024 // Default 1MB +} diff --git a/v2/pkg/engine/resolve/arena_test.go b/v2/pkg/engine/resolve/arena_test.go new file mode 100644 index 000000000..c884434f1 --- /dev/null +++ b/v2/pkg/engine/resolve/arena_test.go @@ -0,0 +1,261 @@ +package resolve + +import ( + "runtime" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/wundergraph/go-arena" +) + +func TestNewArenaPool(t *testing.T) { + pool := NewArenaPool() + + require.NotNil(t, pool, "NewArenaPool returned nil") + assert.Equal(t, 0, len(pool.pool), "expected empty pool") + assert.Equal(t, 0, len(pool.sizes), "expected empty sizes map") +} + +func TestArenaPool_Acquire_EmptyPool(t *testing.T) { + pool := NewArenaPool() + + item := pool.Acquire(1) + + require.NotNil(t, item, "Acquire returned nil") + assert.NotNil(t, item.Arena, "Arena is nil") + + // Verify we can use the arena + buf := arena.NewArenaBuffer(item.Arena) + _, err := buf.WriteString("test") + assert.NoError(t, err) + + assert.Equal(t, 0, len(pool.pool), "pool should still be empty") +} + +func TestArenaPool_ReleaseAndAcquire(t *testing.T) { + pool := NewArenaPool() + id := uint64(42) + + // Acquire first arena + item1 := pool.Acquire(id) + + // Use the arena + buf := arena.NewArenaBuffer(item1.Arena) + _, err := buf.WriteString("test data") + assert.NoError(t, err) + + // Release it + pool.Release(item1) + + // Pool should have one item + assert.Equal(t, 1, len(pool.pool), "expected pool to have 1 item") + + // Acquire from pool + item2 := pool.Acquire(id) + + require.NotNil(t, item2, "Acquire returned nil") + + // Pool should be empty again + assert.Equal(t, 0, len(pool.pool), "expected empty pool after acquire") + + // The acquired arena should be reset and usable + buf2 := arena.NewArenaBuffer(item2.Arena) + _, err = buf2.WriteString("new data") + assert.NoError(t, err) + + assert.Equal(t, "new data", buf2.String()) +} + +func TestArenaPool_Acquire_ProvesBugFix(t *testing.T) { + // This test specifically proves the bug fix works + // Creates multiple items, clears some references, then acquires + // to ensure all items are checked without skipping + pool := NewArenaPool() + id := uint64(800) + + numItems := 10 + items := make([]*ArenaPoolItem, numItems) + + // Acquire all items + for i := 0; i < numItems; i++ { + items[i] = pool.Acquire(id) + buf := arena.NewArenaBuffer(items[i].Arena) + _, err := buf.WriteString("item data") + assert.NoError(t, err) + } + + // Release all while keeping strong references + for i := 0; i < numItems; i++ { + pool.Release(items[i]) + } + + // Pool should have all items + assert.Equal(t, numItems, len(pool.pool), "expected items in pool") + + // Clear every other item to simulate partial GC + for i := 0; i < numItems; i += 2 { + items[i] = nil + } + + // Force GC + runtime.GC() + runtime.GC() + + // Acquire items - should process ALL items without skipping + processed := 0 + acquired := 0 + + for len(pool.pool) > 0 && processed < numItems*2 { + poolSizeBefore := len(pool.pool) + item := pool.Acquire(id) + poolSizeAfter := len(pool.pool) + processed++ + + assert.Less(t, poolSizeAfter, poolSizeBefore, "Pool size did not decrease - item not removed properly!") + + if item != nil { + acquired++ + } + } + + // Pool should be empty + assert.Equal(t, 0, len(pool.pool), "expected empty pool") +} + +func TestArenaPool_Release_PeakTracking(t *testing.T) { + pool := NewArenaPool() + id := uint64(200) + + // First arena + item1 := pool.Acquire(id) + buf1 := arena.NewArenaBuffer(item1.Arena) + _, err := buf1.WriteString("small") + assert.NoError(t, err) + + peak1 := item1.Arena.Peak() + assert.Equal(t, peak1, 5) + + pool.Release(item1) + + // Check that size was tracked + size, exists := pool.sizes[id] + require.True(t, exists, "size tracking not created") + assert.Equal(t, 1, size.count, "expected count 1") + + // Second arena + item2 := pool.Acquire(id) + buf2 := arena.NewArenaBuffer(item2.Arena) + _, err = buf2.WriteString("larger data") + assert.NoError(t, err) + + pool.Release(item2) + + // Check updated tracking + assert.Equal(t, 2, size.count, "expected count 2") +} + +func TestArenaPool_GetArenaSize(t *testing.T) { + pool := NewArenaPool() + + // Test default size for unknown ID + size1 := pool.getArenaSize(999) + expectedDefault := 1024 * 1024 + assert.Equal(t, expectedDefault, size1, "expected default size") + + // Test calculated size after usage + id := uint64(400) + item := pool.Acquire(id) + buf := arena.NewArenaBuffer(item.Arena) + _, err := buf.WriteString("some data") + assert.NoError(t, err) + pool.Release(item) + + size2 := pool.getArenaSize(id) + assert.NotEqual(t, 0, size2, "expected non-zero size after usage") +} + +func TestArenaPool_MultipleItemsInPool(t *testing.T) { + pool := NewArenaPool() + id := uint64(500) + + // Acquire multiple distinct items + numItems := 3 + items := make([]*ArenaPoolItem, numItems) + + for i := 0; i < numItems; i++ { + items[i] = pool.Acquire(id) + buf := arena.NewArenaBuffer(items[i].Arena) + _, err := buf.WriteString("data") + assert.NoError(t, err) + } + + // Release all while keeping references + for i := 0; i < numItems; i++ { + pool.Release(items[i]) + } + + // Should have all items in pool + assert.Equal(t, numItems, len(pool.pool), "expected items in pool") + + // Acquire all back + acquired := 0 + for len(pool.pool) > 0 { + item := pool.Acquire(id) + if item != nil { + acquired++ + } + } + + assert.Equal(t, numItems, acquired, "expected to acquire all items") +} + +func TestArenaPool_Release_MovingWindow(t *testing.T) { + pool := NewArenaPool() + id := uint64(600) + + // Release exactly 50 items + for i := 0; i < 50; i++ { + item := pool.Acquire(id) + buf := arena.NewArenaBuffer(item.Arena) + _, err := buf.WriteString("test data") + assert.NoError(t, err) + pool.Release(item) + } + + // After 50 releases, verify count and total + size := pool.sizes[id] + require.NotNil(t, size, "size tracking should exist") + assert.Equal(t, 50, size.count, "expected count to be 50") + + totalBytesAfter50 := size.totalBytes + + // Release one more item to trigger the window reset + item51 := pool.Acquire(id) + buf51 := arena.NewArenaBuffer(item51.Arena) + _, err := buf51.WriteString("test data") + assert.NoError(t, err) + peak51 := item51.Arena.Peak() + pool.Release(item51) + + // After 51st release, verify the window was reset + // count should be 2 (reset to 1, then incremented) + // totalBytes should be (totalBytesAfter50 / 50) + peak51 + assert.Equal(t, 2, size.count, "expected count to be 2 after window reset") + + expectedTotalBytes := (totalBytesAfter50 / 50) + peak51 + assert.Equal(t, expectedTotalBytes, size.totalBytes, "expected totalBytes to be divided by 50 and new peak added") + + // Verify we can continue releasing and counting works correctly + for i := 0; i < 10; i++ { + item := pool.Acquire(id) + buf := arena.NewArenaBuffer(item.Arena) + _, err := buf.WriteString("more data") + assert.NoError(t, err) + pool.Release(item) + } + + // After 10 more releases, count should be 12 (2 + 10) + assert.Equal(t, 12, size.count, "expected count to continue incrementing after window reset") +} diff --git a/v2/pkg/engine/resolve/authorization_test.go b/v2/pkg/engine/resolve/authorization_test.go index 263724a77..95051def7 100644 --- a/v2/pkg/engine/resolve/authorization_test.go +++ b/v2/pkg/engine/resolve/authorization_test.go @@ -1,11 +1,11 @@ package resolve import ( - "bytes" "context" "encoding/json" "errors" "io" + "net/http" "sync/atomic" "testing" @@ -510,38 +510,32 @@ func TestAuthorization(t *testing.T) { func generateTestFederationGraphQLResponse(t *testing.T, ctrl *gomock.Controller) *GraphQLResponse { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename": "User"}}}`), nil }).AnyTimes() reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}}`), nil }).AnyTimes() productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}}`), nil }).AnyTimes() return &GraphQLResponse{ @@ -821,38 +815,32 @@ func generateTestFederationGraphQLResponse(t *testing.T, ctrl *gomock.Controller func generateTestFederationGraphQLResponseWithoutAuthorizationRules(t *testing.T, ctrl *gomock.Controller) *GraphQLResponse { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename": "User"}}}`), nil }).AnyTimes() reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}}`), nil }).AnyTimes() productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}}`), nil }).AnyTimes() return &GraphQLResponse{ diff --git a/v2/pkg/engine/resolve/const.go b/v2/pkg/engine/resolve/const.go index 8a259494e..2958fe1f5 100644 --- a/v2/pkg/engine/resolve/const.go +++ b/v2/pkg/engine/resolve/const.go @@ -8,6 +8,8 @@ var ( lBrack = []byte("[") rBrack = []byte("]") comma = []byte(",") + pipe = []byte("|") + dot = []byte(".") colon = []byte(":") quote = []byte("\"") null = []byte("null") diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go index 65d2d6b90..5783b29a5 100644 --- a/v2/pkg/engine/resolve/context.go +++ b/v2/pkg/engine/resolve/context.go @@ -16,6 +16,7 @@ import ( type Context struct { ctx context.Context Variables *astjson.Value + VariablesHash uint64 Files []*httpclient.FileUpload Request Request RenameTypeNames []RenameTypeName @@ -32,12 +33,48 @@ type Context struct { fieldRenderer FieldValueRenderer subgraphErrors error + + SubgraphHeadersBuilder SubgraphHeadersBuilder +} + +// SubgraphHeadersBuilder allows the user of the engine to "define" the headers for a subgraph request +// Instead of going back and forth between engine & transport, +// you can simply define a function that returns headers for a Subgraph request +// In addition to just the header, the implementer can return a hash for the header which will be used by request deduplication +type SubgraphHeadersBuilder interface { + // HeadersForSubgraph must return the headers and a hash for a Subgraph Request + // The hash will be used for request deduplication + HeadersForSubgraph(subgraphName string) (http.Header, uint64) + // HashAll must return the hash for all subgraph requests combined + HashAll() uint64 +} + +// HeadersForSubgraphRequest returns headers and a hash for a request that the engine will make to a subgraph +func (c *Context) HeadersForSubgraphRequest(subgraphName string) (http.Header, uint64) { + if c.SubgraphHeadersBuilder == nil { + return nil, 0 + } + return c.SubgraphHeadersBuilder.HeadersForSubgraph(subgraphName) } type ExecutionOptions struct { - SkipLoader bool + // SkipLoader will, as the name indicates, skip loading data + // However, it does indeed resolve a response + // This can be useful, e.g. in combination with IncludeQueryPlanInResponse + // The purpose is to get a QueryPlan (even for Subscriptions) + SkipLoader bool + // IncludeQueryPlanInResponse generates a QueryPlan as part of the response in Resolvable IncludeQueryPlanInResponse bool - SendHeartbeat bool + // SendHeartbeat sends regular HeartBeats for Subscriptions + SendHeartbeat bool + // DisableSubgraphRequestDeduplication disables deduplication of requests to the same subgraph with the same input within a single operation execution. + DisableSubgraphRequestDeduplication bool + // DisableInboundRequestDeduplication disables deduplication of inbound client requests + // The engine is hashing the normalized operation, variables, and forwarded headers to achieve robust deduplication + // By default, overhead is negligible and as such this should be false (not disabled) most of the time + // However, if you're benchmarking internals of the engine, it can be helpful to switch it off + // When disabled (set to true) the code becomes a no-op + DisableInboundRequestDeduplication bool } type FieldValue struct { @@ -146,7 +183,7 @@ func (c *Context) appendSubgraphErrors(errs ...error) { } type Request struct { - ID string + ID uint64 Header http.Header } diff --git a/v2/pkg/engine/resolve/datasource.go b/v2/pkg/engine/resolve/datasource.go index c679d7693..7855fa637 100644 --- a/v2/pkg/engine/resolve/datasource.go +++ b/v2/pkg/engine/resolve/datasource.go @@ -1,28 +1,24 @@ package resolve import ( - "bytes" "context" - - "github.com/cespare/xxhash/v2" + "net/http" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" ) type DataSource interface { - Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) - LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) + Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) + LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) } type SubscriptionDataSource interface { // Start is called when a new subscription is created. It establishes the connection to the data source. // The updater is used to send updates to the client. Deduplication of the request must be done before calling this method. - Start(ctx *Context, input []byte, updater SubscriptionUpdater) error - UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) + Start(ctx *Context, headers http.Header, input []byte, updater SubscriptionUpdater) error } type AsyncSubscriptionDataSource interface { - AsyncStart(ctx *Context, id uint64, input []byte, updater SubscriptionUpdater) error + AsyncStart(ctx *Context, id uint64, headers http.Header, input []byte, updater SubscriptionUpdater) error AsyncStop(id uint64) - UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) } diff --git a/v2/pkg/engine/resolve/event_loop_test.go b/v2/pkg/engine/resolve/event_loop_test.go index 11389630a..ba8b7c8e2 100644 --- a/v2/pkg/engine/resolve/event_loop_test.go +++ b/v2/pkg/engine/resolve/event_loop_test.go @@ -3,12 +3,12 @@ package resolve import ( "context" "io" + "net/http" "sync" "sync/atomic" "testing" "time" - "github.com/cespare/xxhash/v2" "github.com/stretchr/testify/require" ) @@ -71,12 +71,7 @@ type FakeSource struct { interval time.Duration } -func (f *FakeSource) UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) { - _, err = xxh.Write(input) - return err -} - -func (f *FakeSource) Start(ctx *Context, input []byte, updater SubscriptionUpdater) error { +func (f *FakeSource) Start(ctx *Context, headers http.Header, input []byte, updater SubscriptionUpdater) error { go func() { for i, u := range f.updates { updater.Update([]byte(u)) diff --git a/v2/pkg/engine/resolve/fetch.go b/v2/pkg/engine/resolve/fetch.go index deeea25a4..622e731c4 100644 --- a/v2/pkg/engine/resolve/fetch.go +++ b/v2/pkg/engine/resolve/fetch.go @@ -12,7 +12,6 @@ type FetchKind int const ( FetchKindSingle FetchKind = iota + 1 - FetchKindParallelListItem FetchKindEntity FetchKindEntityBatch ) @@ -227,27 +226,6 @@ func (*EntityFetch) FetchKind() FetchKind { return FetchKindEntity } -// The ParallelListItemFetch can be used to make nested parallel fetches within a list -// Usually, you want to batch fetches within a list, which is the default behavior of SingleFetch -// However, if the data source does not support batching, you can use this fetch to make parallel fetches within a list -type ParallelListItemFetch struct { - Fetch *SingleFetch - Traces []*SingleFetch - Trace *DataSourceLoadTrace -} - -func (p *ParallelListItemFetch) Dependencies() *FetchDependencies { - return &p.Fetch.FetchDependencies -} - -func (p *ParallelListItemFetch) FetchInfo() *FetchInfo { - return p.Fetch.Info -} - -func (*ParallelListItemFetch) FetchKind() FetchKind { - return FetchKindParallelListItem -} - type QueryPlan struct { DependsOnFields []Representation Query string @@ -272,12 +250,6 @@ type FetchConfiguration struct { Variables Variables DataSource DataSource - // RequiresParallelListItemFetch indicates that the single fetches should be executed without batching. - // If we have multiple fetches attached to the object, then after post-processing of a plan - // we will get ParallelListItemFetch instead of ParallelFetch. - // Happens only for objects under the array path and used only for the introspection. - RequiresParallelListItemFetch bool - // RequiresEntityFetch will be set to true if the fetch is an entity fetch on an object. // After post-processing, we will get EntityFetch. RequiresEntityFetch bool @@ -313,9 +285,6 @@ func (fc *FetchConfiguration) Equals(other *FetchConfiguration) bool { // Note: we do not compare datasources, as they will always be a different instance. - if fc.RequiresParallelListItemFetch != other.RequiresParallelListItemFetch { - return false - } if fc.RequiresEntityFetch != other.RequiresEntityFetch { return false } @@ -505,5 +474,4 @@ var ( _ Fetch = (*SingleFetch)(nil) _ Fetch = (*BatchEntityFetch)(nil) _ Fetch = (*EntityFetch)(nil) - _ Fetch = (*ParallelListItemFetch)(nil) ) diff --git a/v2/pkg/engine/resolve/fetchtree.go b/v2/pkg/engine/resolve/fetchtree.go index f4fd987ce..9bc38497c 100644 --- a/v2/pkg/engine/resolve/fetchtree.go +++ b/v2/pkg/engine/resolve/fetchtree.go @@ -130,17 +130,6 @@ func (n *FetchTreeNode) Trace() *FetchTreeTraceNode { Trace: f.Trace, Path: n.Item.ResponsePath, } - case *ParallelListItemFetch: - trace.Fetch = &FetchTraceNode{ - Kind: "ParallelList", - SourceID: f.Fetch.Info.DataSourceID, - SourceName: f.Fetch.Info.DataSourceName, - Traces: make([]*DataSourceLoadTrace, len(f.Traces)), - Path: n.Item.ResponsePath, - } - for i, t := range f.Traces { - trace.Fetch.Traces[i] = t.Trace - } default: } case FetchTreeNodeKindSequence, FetchTreeNodeKindParallel: @@ -253,20 +242,6 @@ func (n *FetchTreeNode) queryPlan() *FetchTreeQueryPlanNode { queryPlan.Fetch.Query = f.Info.QueryPlan.Query queryPlan.Fetch.Representations = f.Info.QueryPlan.DependsOnFields } - case *ParallelListItemFetch: - queryPlan.Fetch = &FetchTreeQueryPlan{ - Kind: "ParallelList", - FetchID: f.Fetch.FetchDependencies.FetchID, - DependsOnFetchIDs: f.Fetch.FetchDependencies.DependsOnFetchIDs, - SubgraphName: f.Fetch.Info.DataSourceName, - SubgraphID: f.Fetch.Info.DataSourceID, - Path: n.Item.ResponsePath, - } - - if f.Fetch.Info.QueryPlan != nil { - queryPlan.Fetch.Query = f.Fetch.Info.QueryPlan.Query - queryPlan.Fetch.Representations = f.Fetch.Info.QueryPlan.DependsOnFields - } default: } case FetchTreeNodeKindSequence, FetchTreeNodeKindParallel: diff --git a/v2/pkg/engine/resolve/inbound_request_singleflight.go b/v2/pkg/engine/resolve/inbound_request_singleflight.go new file mode 100644 index 000000000..2552a43fd --- /dev/null +++ b/v2/pkg/engine/resolve/inbound_request_singleflight.go @@ -0,0 +1,132 @@ +package resolve + +import ( + "encoding/binary" + "sync" + + "github.com/cespare/xxhash/v2" +) + +// InboundRequestSingleFlight is a sharded goroutine safe single flight implementation to de-couple inbound requests +// It's taking into consideration the normalized operation hash, variables hash and headers hash +// making it robust against collisions +// for scalability, you can add more shards in case the mutexes are a bottleneck +type InboundRequestSingleFlight struct { + shards []requestShard +} + +type requestShard struct { + m sync.Map +} + +const defaultRequestSingleFlightShardCount = 4 + +// NewRequestSingleFlight creates a InboundRequestSingleFlight with the provided +// number of shards. If shardCount <= 0, the default of 4 is used. +func NewRequestSingleFlight(shardCount int) *InboundRequestSingleFlight { + if shardCount <= 0 { + shardCount = defaultRequestSingleFlightShardCount + } + r := &InboundRequestSingleFlight{ + shards: make([]requestShard, shardCount), + } + return r +} + +type InflightRequest struct { + Done chan struct{} + Data []byte + Err error + ID uint64 + + HasFollowers bool + Mu sync.Mutex +} + +// GetOrCreate creates a new InflightRequest or returns an existing (shared) one +// The first caller to create an InflightRequest for a given key is a leader, everyone else a follower +// GetOrCreate blocks until ctx.ctx.Done() returns or InflightRequest.Done is closed +// It returns an error if the leader returned an error +// It returns nil,nil if the inbound request is not eligible for request deduplication +// or if DisableInboundRequestDeduplication is set to true on Context +func (r *InboundRequestSingleFlight) GetOrCreate(ctx *Context, response *GraphQLResponse) (*InflightRequest, error) { + + if ctx.ExecutionOptions.DisableInboundRequestDeduplication { + return nil, nil + } + + if !response.SingleFlightAllowed() { + return nil, nil + } + + // Derive a robust key from request ID, variables hash and (optional) headers hash + var b [24]byte + binary.LittleEndian.PutUint64(b[0:8], ctx.Request.ID) + binary.LittleEndian.PutUint64(b[8:16], ctx.VariablesHash) + hh := uint64(0) + if ctx.SubgraphHeadersBuilder != nil { + hh = ctx.SubgraphHeadersBuilder.HashAll() + } + binary.LittleEndian.PutUint64(b[16:24], hh) + key := xxhash.Sum64(b[:]) + + shard := r.shardFor(key) + req, shared := shard.m.Load(key) + if shared { + req := req.(*InflightRequest) + req.Mu.Lock() + req.HasFollowers = true + req.Mu.Unlock() + select { + case <-req.Done: + if req.Err != nil { + return nil, req.Err + } + return req, nil + case <-ctx.ctx.Done(): + return nil, ctx.ctx.Err() + } + } + + value := &InflightRequest{ + Done: make(chan struct{}), + ID: key, + } + + shard.m.Store(key, value) + return value, nil +} + +func (r *InboundRequestSingleFlight) FinishOk(req *InflightRequest, data []byte) { + if req == nil { + return + } + shard := r.shardFor(req.ID) + shard.m.Delete(req.ID) + req.Mu.Lock() + hasFollowers := req.HasFollowers + req.Mu.Unlock() + if hasFollowers { + // optimization to only copy when we actually have to + req.Data = make([]byte, len(data)) + copy(req.Data, data) + } + close(req.Done) +} + +func (r *InboundRequestSingleFlight) FinishErr(req *InflightRequest, err error) { + if req == nil { + return + } + shard := r.shardFor(req.ID) + shard.m.Delete(req.ID) + req.Err = err + close(req.Done) +} + +func (r *InboundRequestSingleFlight) shardFor(key uint64) *requestShard { + // Fast modulo using power-of-two shard count if desired in the future. + // For now, use standard modulo for clarity. + idx := int(key % uint64(len(r.shards))) + return &r.shards[idx] +} diff --git a/v2/pkg/engine/resolve/inputtemplate.go b/v2/pkg/engine/resolve/inputtemplate.go index 82825cac7..e0fc97aa6 100644 --- a/v2/pkg/engine/resolve/inputtemplate.go +++ b/v2/pkg/engine/resolve/inputtemplate.go @@ -1,10 +1,10 @@ package resolve import ( - "bytes" "context" "errors" "fmt" + "io" "github.com/wundergraph/astjson" @@ -36,7 +36,7 @@ type InputTemplate struct { SetTemplateOutputToNullOnVariableNull bool } -func SetInputUndefinedVariables(preparedInput *bytes.Buffer, undefinedVariables []string) error { +func SetInputUndefinedVariables(preparedInput InputTemplateWriter, undefinedVariables []string) error { if len(undefinedVariables) > 0 { output, err := httpclient.SetUndefinedVariables(preparedInput.Bytes(), undefinedVariables) if err != nil { @@ -55,7 +55,16 @@ func SetInputUndefinedVariables(preparedInput *bytes.Buffer, undefinedVariables // to callers; renderSegments intercepts it and writes literal.NULL instead. var errSetTemplateOutputNull = errors.New("set to null") -func (i *InputTemplate) Render(ctx *Context, data *astjson.Value, preparedInput *bytes.Buffer) error { +// InputTemplateWriter is used to decouple Buffer implementations from InputTemplate +// This way, the implementation can easily be swapped, e.g. between bytes.Buffer and similar implementations +type InputTemplateWriter interface { + io.Writer + io.StringWriter + Reset() + Bytes() []byte +} + +func (i *InputTemplate) Render(ctx *Context, data *astjson.Value, preparedInput InputTemplateWriter) error { var undefinedVariables []string if err := i.renderSegments(ctx, data, i.Segments, preparedInput, &undefinedVariables); err != nil { @@ -65,12 +74,12 @@ func (i *InputTemplate) Render(ctx *Context, data *astjson.Value, preparedInput return SetInputUndefinedVariables(preparedInput, undefinedVariables) } -func (i *InputTemplate) RenderAndCollectUndefinedVariables(ctx *Context, data *astjson.Value, preparedInput *bytes.Buffer, undefinedVariables *[]string) (err error) { +func (i *InputTemplate) RenderAndCollectUndefinedVariables(ctx *Context, data *astjson.Value, preparedInput InputTemplateWriter, undefinedVariables *[]string) (err error) { err = i.renderSegments(ctx, data, i.Segments, preparedInput, undefinedVariables) return } -func (i *InputTemplate) renderSegments(ctx *Context, data *astjson.Value, segments []TemplateSegment, preparedInput *bytes.Buffer, undefinedVariables *[]string) (err error) { +func (i *InputTemplate) renderSegments(ctx *Context, data *astjson.Value, segments []TemplateSegment, preparedInput InputTemplateWriter, undefinedVariables *[]string) (err error) { for _, segment := range segments { switch segment.SegmentType { case StaticSegmentType: @@ -107,7 +116,7 @@ func (i *InputTemplate) renderSegments(ctx *Context, data *astjson.Value, segmen return err } -func (i *InputTemplate) renderObjectVariable(ctx context.Context, variables *astjson.Value, segment TemplateSegment, preparedInput *bytes.Buffer) error { +func (i *InputTemplate) renderObjectVariable(ctx context.Context, variables *astjson.Value, segment TemplateSegment, preparedInput InputTemplateWriter) error { value := variables.Get(segment.VariableSourcePath...) if value == nil || value.Type() == astjson.TypeNull { if i.SetTemplateOutputToNullOnVariableNull { @@ -119,11 +128,11 @@ func (i *InputTemplate) renderObjectVariable(ctx context.Context, variables *ast return segment.Renderer.RenderVariable(ctx, value, preparedInput) } -func (i *InputTemplate) renderResolvableObjectVariable(ctx context.Context, objectData *astjson.Value, segment TemplateSegment, preparedInput *bytes.Buffer) error { +func (i *InputTemplate) renderResolvableObjectVariable(ctx context.Context, objectData *astjson.Value, segment TemplateSegment, preparedInput InputTemplateWriter) error { return segment.Renderer.RenderVariable(ctx, objectData, preparedInput) } -func (i *InputTemplate) renderContextVariable(ctx *Context, segment TemplateSegment, preparedInput *bytes.Buffer) (variableWasUndefined bool, err error) { +func (i *InputTemplate) renderContextVariable(ctx *Context, segment TemplateSegment, preparedInput InputTemplateWriter) (variableWasUndefined bool, err error) { variableSourcePath := segment.VariableSourcePath if len(variableSourcePath) == 1 && ctx.RemapVariables != nil { nameToUse, hasMapping := ctx.RemapVariables[variableSourcePath[0]] @@ -142,7 +151,7 @@ func (i *InputTemplate) renderContextVariable(ctx *Context, segment TemplateSegm return false, segment.Renderer.RenderVariable(ctx.Context(), value, preparedInput) } -func (i *InputTemplate) renderHeaderVariable(ctx *Context, path []string, preparedInput *bytes.Buffer) error { +func (i *InputTemplate) renderHeaderVariable(ctx *Context, path []string, preparedInput InputTemplateWriter) error { if len(path) != 1 { return errHeaderPathInvalid } @@ -151,14 +160,20 @@ func (i *InputTemplate) renderHeaderVariable(ctx *Context, path []string, prepar return nil } if len(value) == 1 { - preparedInput.WriteString(value[0]) + if _, err := preparedInput.WriteString(value[0]); err != nil { + return err + } return nil } for j := range value { if j != 0 { - _, _ = preparedInput.Write(literal.COMMA) + if _, err := preparedInput.Write(literal.COMMA); err != nil { + return err + } + } + if _, err := preparedInput.WriteString(value[j]); err != nil { + return err } - preparedInput.WriteString(value[j]) } return nil } diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 7a14d61dc..23f0b6d32 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -22,6 +22,7 @@ import ( "golang.org/x/sync/errgroup" "github.com/wundergraph/astjson" + "github.com/wundergraph/go-arena" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" @@ -56,11 +57,11 @@ type ResponseInfo struct { // ResponseHeaders contains a clone of the headers of the response from the subgraph. ResponseHeaders http.Header // This should be private as we do not want user's to access the raw responseBody directly - responseBody *bytes.Buffer + responseBody []byte } -func (ri *ResponseInfo) GetResponseBody() string { - return ri.responseBody.String() +func (r *ResponseInfo) GetResponseBody() string { + return string(r.responseBody) } func newResponseInfo(res *result, subgraphError error) *ResponseInfo { @@ -91,35 +92,21 @@ func newResponseInfo(res *result, subgraphError error) *ResponseInfo { return responseInfo } -// batchStats represents an index map for batched items. -// It is used to ensure that the correct json values will be merged with the correct items from the batch. -// -// Example: -// [[0],[1],[0],[1]] We originally have 4 items, but we have 2 unique indexes (0 and 1). -// This means we are deduplicating 2 items by merging them from their response entity indexes. -// 0 -> 0, 1 -> 1, 2 -> 0, 3 -> 1 -type batchStats [][]int - -// getUniqueIndexes returns the number of unique indexes in the batchStats. -// This is used to ensure that we can provide a valid error message in case of differing array lengths. -func (b *batchStats) getUniqueIndexes() int { - uniqueIndexes := make(map[int]struct{}) - for _, bi := range *b { - for _, index := range bi { - if index < 0 { - continue - } - uniqueIndexes[index] = struct{}{} - } - } - - return len(uniqueIndexes) -} - type result struct { - postProcessing PostProcessingConfiguration - out *bytes.Buffer - batchStats batchStats + postProcessing PostProcessingConfiguration + // batchStats represents per-unique-batch-item merge targets. + // Outer slice index corresponds to the unique representation index in the request batch, + // and the inner slice contains all target values that should be merged with the response at that index. + // + // Example: + // For 4 original items that deduplicate to 2 unique representations, we might have: + // [ + // + // [item0, item2], // merge response[0] into item0 and item2 + // [item1, item3], // merge response[1] into item1 and item3 + // + // ] + batchStats [][]*astjson.Value fetchSkipped bool nestedMergeItems []*result @@ -138,6 +125,10 @@ type result struct { loaderHookContext context.Context httpResponseContext *httpclient.ResponseContext + // out is the subgraph response body + out []byte + singleFlightStats *singleFlightStats + tools *batchEntityTools } func (r *result) init(postProcessing PostProcessingConfiguration, info *FetchInfo) { @@ -180,6 +171,19 @@ type Loader struct { validateRequiredExternalFields bool taintedObjs taintedObjects + + // jsonArena is the arena to allocation json, supplied by the Resolver + // Disclaimer: this arena is NOT thread safe! + // Only use from main goroutine + // Don't Reset or Release, the Resolver handles this + // Disclaimer: When parsing json into the arena, the underlying bytes must also be allocated on the arena! + // This is very important to "tie" their lifecycles together + // If you're not doing this, you will see segfaults + // Example of correct usage in func "mergeResult" + jsonArena arena.Arena + // sf is the SubgraphRequestSingleFlight object shared across all client requests + // it's thread safe and can be used to de-duplicate subgraph requests + sf *SubgraphRequestSingleFlight } func (l *Loader) Free() { @@ -218,6 +222,12 @@ func (l *Loader) resolveParallel(nodes []*FetchTreeNode) error { return nil } results := make([]*result, len(nodes)) + defer func() { + for i := range results { + // no-op if tools == nil + batchEntityToolPool.Put(results[i].tools) + } + }() itemsItems := make([][]*astjson.Value, len(nodes)) g, ctx := errgroup.WithContext(l.ctx.ctx) for i := range nodes { @@ -280,9 +290,7 @@ func (l *Loader) resolveSingle(item *FetchItem) error { switch f := item.Fetch.(type) { case *SingleFetch: - res := &result{ - out: &bytes.Buffer{}, - } + res := &result{} err := l.loadSingleFetch(l.ctx.ctx, f, item, items, res) if err != nil { return err @@ -291,12 +299,10 @@ func (l *Loader) resolveSingle(item *FetchItem) error { if l.ctx.LoaderHooks != nil { l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.ds, newResponseInfo(res, l.ctx.subgraphErrors)) } - return err case *BatchEntityFetch: - res := &result{ - out: &bytes.Buffer{}, - } + res := &result{} + defer batchEntityToolPool.Put(res.tools) err := l.loadBatchEntityFetch(l.ctx.ctx, item, f, items, res) if err != nil { return errors.WithStack(err) @@ -307,9 +313,7 @@ func (l *Loader) resolveSingle(item *FetchItem) error { } return err case *EntityFetch: - res := &result{ - out: &bytes.Buffer{}, - } + res := &result{} err := l.loadEntityFetch(l.ctx.ctx, item, f, items, res) if err != nil { return errors.WithStack(err) @@ -319,50 +323,15 @@ func (l *Loader) resolveSingle(item *FetchItem) error { l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.ds, newResponseInfo(res, l.ctx.subgraphErrors)) } return err - case *ParallelListItemFetch: - results := make([]*result, len(items)) - if l.ctx.TracingOptions.Enable { - f.Traces = make([]*SingleFetch, len(items)) - } - g, ctx := errgroup.WithContext(l.ctx.ctx) - for i := range items { - i := i - results[i] = &result{ - out: &bytes.Buffer{}, - } - if l.ctx.TracingOptions.Enable { - f.Traces[i] = new(SingleFetch) - *f.Traces[i] = *f.Fetch - g.Go(func() error { - return l.loadFetch(ctx, f.Traces[i], item, items[i:i+1], results[i]) - }) - continue - } - g.Go(func() error { - return l.loadFetch(ctx, f.Fetch, item, items[i:i+1], results[i]) - }) - } - err := g.Wait() - if err != nil { - return errors.WithStack(err) - } - for i := range results { - err = l.mergeResult(item, results[i], items[i:i+1]) - if l.ctx.LoaderHooks != nil { - l.ctx.LoaderHooks.OnFinished(results[i].loaderHookContext, results[i].ds, newResponseInfo(results[i], l.ctx.subgraphErrors)) - } - if err != nil { - return errors.WithStack(err) - } - } - return nil default: return nil } } func (l *Loader) selectItemsForPath(path []FetchItemPathElement) []*astjson.Value { - items := []*astjson.Value{l.resolvable.data} + // Use arena allocation for the initial items slice + items := arena.AllocateSlice[*astjson.Value](l.jsonArena, 1, 1) + items[0] = l.resolvable.data if len(path) == 0 { return l.taintedObjs.filterOutTainted(items) } @@ -370,7 +339,7 @@ func (l *Loader) selectItemsForPath(path []FetchItemPathElement) []*astjson.Valu if len(items) == 0 { break } - items = selectItems(items, path[i]) + items = selectItems(l.jsonArena, items, path[i]) } return l.taintedObjs.filterOutTainted(items) } @@ -391,7 +360,7 @@ func isItemAllowedByTypename(obj *astjson.Value, typeNames []string) bool { return slices.Contains(typeNames, __typeNameStr) } -func selectItems(items []*astjson.Value, element FetchItemPathElement) []*astjson.Value { +func selectItems(a arena.Arena, items []*astjson.Value, element FetchItemPathElement) []*astjson.Value { if len(items) == 0 { return nil } @@ -413,7 +382,7 @@ func selectItems(items []*astjson.Value, element FetchItemPathElement) []*astjso } return []*astjson.Value{field} } - selected := make([]*astjson.Value, 0, len(items)) + selected := arena.AllocateSlice[*astjson.Value](a, 0, len(items)) for _, item := range items { if !isItemAllowedByTypename(item, element.TypeNames) { continue @@ -423,15 +392,15 @@ func selectItems(items []*astjson.Value, element FetchItemPathElement) []*astjso continue } if field.Type() == astjson.TypeArray { - selected = append(selected, field.GetArray()...) + selected = arena.SliceAppend(a, selected, field.GetArray()...) continue } - selected = append(selected, field) + selected = arena.SliceAppend(a, selected, field) } return selected } -func itemsData(items []*astjson.Value) *astjson.Value { +func (l *Loader) itemsData(items []*astjson.Value) *astjson.Value { if len(items) == 0 { return astjson.NullValue } @@ -442,7 +411,7 @@ func itemsData(items []*astjson.Value) *astjson.Value { // however, itemsData can be called concurrently, so this might result in a race arr := astjson.MustParseBytes([]byte(`[]`)) for i, item := range items { - arr.SetArrayItem(i, item) + arr.SetArrayItem(nil, i, item) } return arr } @@ -450,42 +419,10 @@ func itemsData(items []*astjson.Value) *astjson.Value { func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, fetchItem *FetchItem, items []*astjson.Value, res *result) error { switch f := fetch.(type) { case *SingleFetch: - res.out = &bytes.Buffer{} return l.loadSingleFetch(ctx, f, fetchItem, items, res) - case *ParallelListItemFetch: - results := make([]*result, len(items)) - if l.ctx.TracingOptions.Enable { - f.Traces = make([]*SingleFetch, len(items)) - } - g, ctx := errgroup.WithContext(l.ctx.ctx) - for i := range items { - i := i - results[i] = &result{ - out: &bytes.Buffer{}, - } - if l.ctx.TracingOptions.Enable { - f.Traces[i] = new(SingleFetch) - *f.Traces[i] = *f.Fetch - g.Go(func() error { - return l.loadFetch(ctx, f.Traces[i], fetchItem, items[i:i+1], results[i]) - }) - continue - } - g.Go(func() error { - return l.loadFetch(ctx, f.Fetch, fetchItem, items[i:i+1], results[i]) - }) - } - err := g.Wait() - if err != nil { - return errors.WithStack(err) - } - res.nestedMergeItems = results - return nil case *EntityFetch: - res.out = &bytes.Buffer{} return l.loadEntityFetch(ctx, fetchItem, f, items, res) case *BatchEntityFetch: - res.out = &bytes.Buffer{} return l.loadBatchEntityFetch(ctx, fetchItem, f, items, res) } return nil @@ -548,11 +485,15 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson if res.fetchSkipped { return nil } - if res.out.Len() == 0 { + if len(res.out) == 0 { return l.renderErrorsFailedToFetch(fetchItem, res, emptyGraphQLResponse) } - - response, err := astjson.ParseBytesWithoutCache(res.out.Bytes()) + // before parsing bytes with an arena.Arena, it's important to first allocate the bytes ON the same arena.Arena + // this ties their lifecycles together + // if you don't do this, you'll get segfaults + slice := arena.AllocateSlice[byte](l.jsonArena, len(res.out), len(res.out)) + copy(slice, res.out) + response, err := astjson.ParseBytesWithArena(l.jsonArena, slice) if err != nil { // Fall back to status code if parsing fails and non-2XX if (res.statusCode > 0 && res.statusCode < 200) || res.statusCode >= 300 { @@ -633,7 +574,7 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson return nil } if len(items) == 1 && res.batchStats == nil { - items[0], _, err = astjson.MergeValuesWithPath(items[0], responseData, res.postProcessing.MergePath...) + items[0], _, err = astjson.MergeValuesWithPath(l.jsonArena, items[0], responseData, res.postProcessing.MergePath...) if err != nil { return errors.WithStack(ErrMergeResult{ Subgraph: res.ds.Name, @@ -652,26 +593,23 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson } if res.batchStats != nil { - uniqueIndexes := res.batchStats.getUniqueIndexes() - if uniqueIndexes != len(batch) { - return l.renderErrorsFailedToFetch(fetchItem, res, fmt.Sprintf(invalidBatchItemCount, uniqueIndexes, len(batch))) + if len(res.batchStats) != len(batch) { + return l.renderErrorsFailedToFetch(fetchItem, res, fmt.Sprintf(invalidBatchItemCount, len(res.batchStats), len(batch))) } - for i, stats := range res.batchStats { - for _, idx := range stats { - if idx == -1 { - continue - } - items[i], _, err = astjson.MergeValuesWithPath(items[i], batch[idx], res.postProcessing.MergePath...) - if err != nil { + for batchIndex, targets := range res.batchStats { + src := batch[batchIndex] + for _, target := range targets { + _, _, mErr := astjson.MergeValuesWithPath(l.jsonArena, target, src, res.postProcessing.MergePath...) + if mErr != nil { return errors.WithStack(ErrMergeResult{ Subgraph: res.ds.Name, - Reason: err, + Reason: mErr, Path: fetchItem.ResponsePath, }) } - if slices.Contains(taintedIndices, idx) { - l.taintedObjs.add(items[i]) + if slices.Contains(taintedIndices, batchIndex) { + l.taintedObjs.add(target) } } } @@ -683,7 +621,7 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson } for i := range items { - items[i], _, err = astjson.MergeValuesWithPath(items[i], batch[i], res.postProcessing.MergePath...) + items[i], _, err = astjson.MergeValuesWithPath(l.jsonArena, items[i], batch[i], res.postProcessing.MergePath...) if err != nil { return errors.WithStack(ErrMergeResult{ Subgraph: res.ds.Name, @@ -703,7 +641,8 @@ var ( errorsInvalidInputFooter = []byte(`]}]}`) ) -func (l *Loader) renderErrorsInvalidInput(fetchItem *FetchItem, out *bytes.Buffer) error { +func (l *Loader) renderErrorsInvalidInput(fetchItem *FetchItem) []byte { + out := bytes.NewBuffer(nil) elements := fetchItem.ResponsePathElements if len(elements) > 0 && elements[len(elements)-1] == "@" { elements = elements[:len(elements)-1] @@ -721,7 +660,7 @@ func (l *Loader) renderErrorsInvalidInput(fetchItem *FetchItem, out *bytes.Buffe _, _ = out.Write(quote) } _, _ = out.Write(errorsInvalidInputFooter) - return nil + return out.Bytes() } func (l *Loader) appendSubgraphError(res *result, fetchItem *FetchItem, value *astjson.Value, values []*astjson.Value) error { @@ -749,7 +688,7 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V values := value.GetArray() l.optionallyOmitErrorLocations(values) if l.rewriteSubgraphErrorPaths { - rewriteErrorPaths(fetchItem, values) + rewriteErrorPaths(l.jsonArena, fetchItem, values) } l.optionallyEnsureExtensionErrorCode(values) @@ -778,7 +717,10 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V return err } } - + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it + l.resolvable.ensureErrorsInitialized() // If the error propagation mode is pass-through, we append the errors to the root array l.resolvable.errors.AppendArrayItems(value) return nil @@ -792,7 +734,7 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V } // Wrap mode (default) - errorObject, err := astjson.ParseWithoutCache(l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, failedToFetchNoReason)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, failedToFetchNoReason)) if err != nil { return err } @@ -815,7 +757,10 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V if err := l.addApolloRouterCompatibilityError(res); err != nil { return err } - + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil @@ -861,17 +806,17 @@ func (l *Loader) optionallyEnsureExtensionErrorCode(values []*astjson.Value) { switch extensions.Type() { case astjson.TypeObject: if !extensions.Exists("code") { - extensions.Set("code", l.resolvable.astjsonArena.NewString(l.defaultErrorExtensionCode)) + extensions.Set(l.jsonArena, "code", astjson.StringValue(l.jsonArena, l.defaultErrorExtensionCode)) } case astjson.TypeNull: - extensionsObj := l.resolvable.astjsonArena.NewObject() - extensionsObj.Set("code", l.resolvable.astjsonArena.NewString(l.defaultErrorExtensionCode)) - value.Set("extensions", extensionsObj) + extensionsObj := astjson.ObjectValue(l.jsonArena) + extensionsObj.Set(l.jsonArena, "code", astjson.StringValue(l.jsonArena, l.defaultErrorExtensionCode)) + value.Set(l.jsonArena, "extensions", extensionsObj) } } else { - extensionsObj := l.resolvable.astjsonArena.NewObject() - extensionsObj.Set("code", l.resolvable.astjsonArena.NewString(l.defaultErrorExtensionCode)) - value.Set("extensions", extensionsObj) + extensionsObj := astjson.ObjectValue(l.jsonArena) + extensionsObj.Set(l.jsonArena, "code", astjson.StringValue(l.jsonArena, l.defaultErrorExtensionCode)) + value.Set(l.jsonArena, "extensions", extensionsObj) } } } @@ -888,16 +833,16 @@ func (l *Loader) optionallyAttachServiceNameToErrorExtension(values []*astjson.V extensions := value.Get("extensions") switch extensions.Type() { case astjson.TypeObject: - extensions.Set("serviceName", l.resolvable.astjsonArena.NewString(serviceName)) + extensions.Set(l.jsonArena, "serviceName", astjson.StringValue(l.jsonArena, serviceName)) case astjson.TypeNull: - extensionsObj := l.resolvable.astjsonArena.NewObject() - extensionsObj.Set("serviceName", l.resolvable.astjsonArena.NewString(serviceName)) - value.Set("extensions", extensionsObj) + extensionsObj := astjson.ObjectValue(l.jsonArena) + extensionsObj.Set(l.jsonArena, "serviceName", astjson.StringValue(l.jsonArena, serviceName)) + value.Set(l.jsonArena, "extensions", extensionsObj) } } else { - extensionsObj := l.resolvable.astjsonArena.NewObject() - extensionsObj.Set("serviceName", l.resolvable.astjsonArena.NewString(serviceName)) - value.Set("extensions", extensionsObj) + extensionsObj := astjson.ObjectValue(l.jsonArena) + extensionsObj.Set(l.jsonArena, "serviceName", astjson.StringValue(l.jsonArena, serviceName)) + value.Set(l.jsonArena, "extensions", extensionsObj) } } } @@ -951,7 +896,7 @@ func (l *Loader) optionallyOmitErrorLocations(values []*astjson.Value) { // - Drops the numeric index immediately following "_entities". // - Converts all subsequent numeric segments to strings (e.g., 1 -> "1"). // - Skips non-string/non-number segments. -func rewriteErrorPaths(fetchItem *FetchItem, values []*astjson.Value) { +func rewriteErrorPaths(a arena.Arena, fetchItem *FetchItem, values []*astjson.Value) { pathPrefix := make([]string, len(fetchItem.ResponsePathElements)) copy(pathPrefix, fetchItem.ResponsePathElements) // remove the trailing @ in case we're in an array as it looks weird in the path @@ -993,11 +938,11 @@ func rewriteErrorPaths(fetchItem *FetchItem, values []*astjson.Value) { } } newPathJSON, _ := json.Marshal(newPath) - pathBytes, err := astjson.ParseBytesWithoutCache(newPathJSON) + pathBytes, err := astjson.ParseBytesWithArena(a, newPathJSON) if err != nil { continue } - value.Set("path", pathBytes) + value.Set(a, "path", pathBytes) break } } @@ -1018,17 +963,17 @@ func (l *Loader) setSubgraphStatusCode(values []*astjson.Value, statusCode int) if extensions.Type() != astjson.TypeObject { continue } - v, err := astjson.ParseWithoutCache(strconv.Itoa(statusCode)) + v, err := astjson.ParseWithArena(l.jsonArena, strconv.Itoa(statusCode)) if err != nil { continue } - extensions.Set("statusCode", v) + extensions.Set(l.jsonArena, "statusCode", v) } else { - v, err := astjson.ParseWithoutCache(`{"statusCode":` + strconv.Itoa(statusCode) + `}`) + v, err := astjson.ParseWithArena(l.jsonArena, `{"statusCode":`+strconv.Itoa(statusCode)+`}`) if err != nil { continue } - value.Set("extensions", v) + value.Set(l.jsonArena, "extensions", v) } } } @@ -1065,11 +1010,14 @@ func (l *Loader) addApolloRouterCompatibilityError(res *result) error { } } }`, res.ds.Name, http.StatusText(res.statusCode), res.statusCode) - apolloRouterStatusError, err := astjson.ParseWithoutCache(apolloRouterStatusErrorJSON) + apolloRouterStatusError, err := astjson.ParseWithArena(l.jsonArena, apolloRouterStatusErrorJSON) if err != nil { return err } - + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, apolloRouterStatusError) return nil @@ -1078,22 +1026,30 @@ func (l *Loader) addApolloRouterCompatibilityError(res *result) error { func (l *Loader) renderErrorsFailedDeps(fetchItem *FetchItem, res *result) error { path := l.renderAtPathErrorPart(fetchItem.ResponsePath) msg := fmt.Sprintf(`{"message":"Failed to obtain field dependencies from Subgraph '%s'%s."}`, res.ds.Name, path) - errorObject, err := astjson.ParseWithoutCache(msg) + errorObject, err := astjson.ParseWithArena(l.jsonArena, msg) if err != nil { return err } l.setSubgraphStatusCode([]*astjson.Value{errorObject}, res.statusCode) + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil } func (l *Loader) renderErrorsFailedToFetch(fetchItem *FetchItem, res *result, reason string) error { l.ctx.appendSubgraphErrors(res.err, NewSubgraphError(res.ds, fetchItem.ResponsePath, reason, res.statusCode)) - errorObject, err := astjson.ParseWithoutCache(l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, reason)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, reason)) if err != nil { return err } l.setSubgraphStatusCode([]*astjson.Value{errorObject}, res.statusCode) + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil } @@ -1106,13 +1062,16 @@ func (l *Loader) renderErrorsStatusFallback(fetchItem *FetchItem, res *result, s l.ctx.appendSubgraphErrors(res.err, NewSubgraphError(res.ds, fetchItem.ResponsePath, reason, res.statusCode)) - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"%s"}`, reason)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"%s"}`, reason)) if err != nil { return err } l.setSubgraphStatusCode([]*astjson.Value{errorObject}, res.statusCode) - + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil } @@ -1137,16 +1096,20 @@ func (l *Loader) renderAuthorizationRejectedErrors(fetchItem *FetchItem, res *re } pathPart := l.renderAtPathErrorPart(fetchItem.ResponsePath) extensionErrorCode := fmt.Sprintf(`"extensions":{"code":"%s"}`, errorcodes.UnauthorizedFieldOrType) + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it + l.resolvable.ensureErrorsInitialized() if res.ds.Name == "" { for _, reason := range res.authorizationRejectedReasons { if reason == "" { - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s.",%s}`, pathPart, extensionErrorCode)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s.",%s}`, pathPart, extensionErrorCode)) if err != nil { continue } astjson.AppendToArray(l.resolvable.errors, errorObject) } else { - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s, Reason: %s.",%s}`, pathPart, reason, extensionErrorCode)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s, Reason: %s.",%s}`, pathPart, reason, extensionErrorCode)) if err != nil { continue } @@ -1156,13 +1119,13 @@ func (l *Loader) renderAuthorizationRejectedErrors(fetchItem *FetchItem, res *re } else { for _, reason := range res.authorizationRejectedReasons { if reason == "" { - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s.",%s}`, res.ds.Name, pathPart, extensionErrorCode)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s.",%s}`, res.ds.Name, pathPart, extensionErrorCode)) if err != nil { continue } astjson.AppendToArray(l.resolvable.errors, errorObject) } else { - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s, Reason: %s.",%s}`, res.ds.Name, pathPart, reason, extensionErrorCode)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s, Reason: %s.",%s}`, res.ds.Name, pathPart, reason, extensionErrorCode)) if err != nil { continue } @@ -1182,39 +1145,43 @@ func (l *Loader) renderRateLimitRejectedErrors(fetchItem *FetchItem, res *result ) if res.ds.Name == "" { if res.rateLimitRejectedReason == "" { - errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s."}`, pathPart)) + errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s."}`, pathPart)) if err != nil { return err } } else { - errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s, Reason: %s."}`, pathPart, res.rateLimitRejectedReason)) + errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s, Reason: %s."}`, pathPart, res.rateLimitRejectedReason)) if err != nil { return err } } } else { if res.rateLimitRejectedReason == "" { - errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s."}`, res.ds.Name, pathPart)) + errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s."}`, res.ds.Name, pathPart)) if err != nil { return err } } else { - errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s, Reason: %s."}`, res.ds.Name, pathPart, res.rateLimitRejectedReason)) + errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s, Reason: %s."}`, res.ds.Name, pathPart, res.rateLimitRejectedReason)) if err != nil { return err } } } if l.ctx.RateLimitOptions.ErrorExtensionCode.Enabled { - extension, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"code":"%s"}`, l.ctx.RateLimitOptions.ErrorExtensionCode.Code)) + extension, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"code":"%s"}`, l.ctx.RateLimitOptions.ErrorExtensionCode.Code)) if err != nil { return err } - errorObject, _, err = astjson.MergeValuesWithPath(errorObject, extension, "extensions") + errorObject, _, err = astjson.MergeValuesWithPath(l.jsonArena, errorObject, extension, "extensions") if err != nil { return err } } + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil } @@ -1285,9 +1252,9 @@ func (l *Loader) validatePreFetch(input []byte, info *FetchInfo, res *result) (a func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, fetchItem *FetchItem, items []*astjson.Value, res *result) error { res.init(fetch.PostProcessing, fetch.Info) - buf := &bytes.Buffer{} + buf := bytes.NewBuffer(nil) - inputData := itemsData(items) + inputData := l.itemsData(items) if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData && inputData != nil { @@ -1309,7 +1276,8 @@ func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, fetchI err := fetch.InputTemplate.Render(l.ctx, inputData, buf) if err != nil { - return l.renderErrorsInvalidInput(fetchItem, res.out) + res.out = l.renderErrorsInvalidInput(fetchItem) + return nil } fetchInput := buf.Bytes() allowed, err := l.validatePreFetch(fetchInput, fetch.Info, res) @@ -1323,37 +1291,9 @@ func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, fetchI return nil } -var ( - entityFetchPool = sync.Pool{ - New: func() any { - return &entityFetchBuffer{ - item: &bytes.Buffer{}, - preparedInput: &bytes.Buffer{}, - } - }, - } -) - -type entityFetchBuffer struct { - item *bytes.Buffer - preparedInput *bytes.Buffer -} - -func acquireEntityFetchBuffer() *entityFetchBuffer { - return entityFetchPool.Get().(*entityFetchBuffer) -} - -func releaseEntityFetchBuffer(buf *entityFetchBuffer) { - buf.item.Reset() - buf.preparedInput.Reset() - entityFetchPool.Put(buf) -} - func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetch *EntityFetch, items []*astjson.Value, res *result) error { res.init(fetch.PostProcessing, fetch.Info) - buf := acquireEntityFetchBuffer() - defer releaseEntityFetchBuffer(buf) - input := itemsData(items) + input := l.itemsData(items) if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData && input != nil { @@ -1361,14 +1301,17 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetc } } + preparedInput := bytes.NewBuffer(nil) + item := bytes.NewBuffer(nil) + var undefinedVariables []string - err := fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables) + err := fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables) if err != nil { return errors.WithStack(err) } - err = fetch.Input.Item.Render(l.ctx, input, buf.item) + err = fetch.Input.Item.Render(l.ctx, input, item) if err != nil { if fetch.Input.SkipErrItem { // skip fetch on render item error @@ -1380,7 +1323,7 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetc } return errors.WithStack(err) } - renderedItem := buf.item.Bytes() + renderedItem := item.Bytes() if bytes.Equal(renderedItem, null) { // skip fetch if item is null res.fetchSkipped = true @@ -1399,17 +1342,17 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetc return nil } } - _, _ = buf.item.WriteTo(buf.preparedInput) - err = fetch.Input.Footer.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables) + _, _ = item.WriteTo(preparedInput) + err = fetch.Input.Footer.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables) if err != nil { return errors.WithStack(err) } - err = SetInputUndefinedVariables(buf.preparedInput, undefinedVariables) + err = SetInputUndefinedVariables(preparedInput, undefinedVariables) if err != nil { return errors.WithStack(err) } - fetchInput := buf.preparedInput.Bytes() + fetchInput := preparedInput.Bytes() if l.ctx.TracingOptions.Enable && res.fetchSkipped { l.setTracingInput(fetchItem, fetchInput, fetch.Trace) @@ -1427,71 +1370,94 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetc return nil } -var ( - batchEntityFetchPool = sync.Pool{} -) +type batchEntityTools struct { + keyGen *xxhash.Digest + batchHashToIndex map[uint64]int + a arena.Arena +} -type batchEntityFetchBuffer struct { - preparedInput *bytes.Buffer - itemInput *bytes.Buffer - keyGen *xxhash.Digest +func (b *batchEntityTools) reset() { + b.keyGen.Reset() + b.a.Reset() + for i := range b.batchHashToIndex { + delete(b.batchHashToIndex, i) + } } -func acquireBatchEntityFetchBuffer() *batchEntityFetchBuffer { - buf := batchEntityFetchPool.Get() - if buf == nil { - return &batchEntityFetchBuffer{ - preparedInput: &bytes.Buffer{}, - itemInput: &bytes.Buffer{}, - keyGen: xxhash.New(), +type _batchEntityToolPool struct { + pool sync.Pool +} + +func (p *_batchEntityToolPool) Get(items int) *batchEntityTools { + item := p.pool.Get() + if item == nil { + return &batchEntityTools{ + keyGen: xxhash.New(), + batchHashToIndex: make(map[uint64]int, items), + a: arena.NewMonotonicArena(arena.WithMinBufferSize(1024)), } } - return buf.(*batchEntityFetchBuffer) + return item.(*batchEntityTools) } -func releaseBatchEntityFetchBuffer(buf *batchEntityFetchBuffer) { - buf.preparedInput.Reset() - buf.itemInput.Reset() - buf.keyGen.Reset() - batchEntityFetchPool.Put(buf) +func (p *_batchEntityToolPool) Put(item *batchEntityTools) { + if item == nil { + return + } + item.reset() + p.pool.Put(item) } +var ( + batchEntityToolPool = _batchEntityToolPool{} +) + func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, fetch *BatchEntityFetch, items []*astjson.Value, res *result) error { res.init(fetch.PostProcessing, fetch.Info) - buf := acquireBatchEntityFetchBuffer() - defer releaseBatchEntityFetchBuffer(buf) - if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData && len(items) != 0 { - data := itemsData(items) + data := l.itemsData(items) if data != nil { fetch.Trace.RawInputData, _ = l.compactJSON(data.MarshalTo(nil)) } } } + res.tools = batchEntityToolPool.Get(len(items)) + preparedInput := arena.NewArenaBuffer(res.tools.a) + itemInput := arena.NewArenaBuffer(res.tools.a) + batchStats := arena.AllocateSlice[[]*astjson.Value](res.tools.a, 0, len(items)) + defer func() { + // we need to clear the batchStats slice to avoid memory corruption + // once the outer func returns, we must not keep pointers to items on the arena + for i := range batchStats { + // nolint:ineffassign + batchStats[i] = nil + } + // nolint:ineffassign + batchStats = nil + }() + + // I tried using arena here, but it only worsened the situation var undefinedVariables []string - err := fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables) + err := fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables) if err != nil { return errors.WithStack(err) } - res.batchStats = make(batchStats, len(items)) - itemHashes := make([]uint64, 0, len(items)) batchItemIndex := 0 addSeparator := false WithNextItem: for i, item := range items { for j := range fetch.Input.Items { - buf.itemInput.Reset() - err = fetch.Input.Items[j].Render(l.ctx, item, buf.itemInput) + itemInput.Reset() + err = fetch.Input.Items[j].Render(l.ctx, item, itemInput) if err != nil { if fetch.Input.SkipErrItems { err = nil // nolint:ineffassign - res.batchStats[i] = append(res.batchStats[i], -1) continue } if l.ctx.TracingOptions.Enable { @@ -1499,39 +1465,38 @@ WithNextItem: } return errors.WithStack(err) } - if fetch.Input.SkipNullItems && buf.itemInput.Len() == 4 && bytes.Equal(buf.itemInput.Bytes(), null) { - res.batchStats[i] = append(res.batchStats[i], -1) + if fetch.Input.SkipNullItems && itemInput.Len() == 4 && bytes.Equal(itemInput.Bytes(), null) { continue } - if fetch.Input.SkipEmptyObjectItems && buf.itemInput.Len() == 2 && bytes.Equal(buf.itemInput.Bytes(), emptyObject) { - res.batchStats[i] = append(res.batchStats[i], -1) + if fetch.Input.SkipEmptyObjectItems && itemInput.Len() == 2 && bytes.Equal(itemInput.Bytes(), emptyObject) { continue } - buf.keyGen.Reset() - _, _ = buf.keyGen.Write(buf.itemInput.Bytes()) - itemHash := buf.keyGen.Sum64() - for k := range itemHashes { - if itemHashes[k] == itemHash { - res.batchStats[i] = append(res.batchStats[i], k) - continue WithNextItem - } - } - itemHashes = append(itemHashes, itemHash) - if addSeparator { - err = fetch.Input.Separator.Render(l.ctx, nil, buf.preparedInput) - if err != nil { - return errors.WithStack(err) + res.tools.keyGen.Reset() + _, _ = res.tools.keyGen.Write(itemInput.Bytes()) + itemHash := res.tools.keyGen.Sum64() + if existingIndex, ok := res.tools.batchHashToIndex[itemHash]; ok { + batchStats[existingIndex] = arena.SliceAppend(res.tools.a, batchStats[existingIndex], items[i]) + continue WithNextItem + } else { + if addSeparator { + err = fetch.Input.Separator.Render(l.ctx, nil, preparedInput) + if err != nil { + return errors.WithStack(err) + } } + _, _ = itemInput.WriteTo(preparedInput) + // new unique representation + res.tools.batchHashToIndex[itemHash] = batchItemIndex + // create a new targets bucket for this unique index + batchStats = arena.SliceAppend(res.tools.a, batchStats, []*astjson.Value{items[i]}) + batchItemIndex++ + addSeparator = true } - _, _ = buf.itemInput.WriteTo(buf.preparedInput) - res.batchStats[i] = append(res.batchStats[i], batchItemIndex) - batchItemIndex++ - addSeparator = true } } - if len(itemHashes) == 0 { + if len(batchStats) == 0 { // all items were skipped - discard fetch res.fetchSkipped = true if l.ctx.TracingOptions.Enable { @@ -1541,16 +1506,23 @@ WithNextItem: } } - err = fetch.Input.Footer.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables) + err = fetch.Input.Footer.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables) if err != nil { return errors.WithStack(err) } - err = SetInputUndefinedVariables(buf.preparedInput, undefinedVariables) + err = SetInputUndefinedVariables(preparedInput, undefinedVariables) if err != nil { return errors.WithStack(err) } - fetchInput := buf.preparedInput.Bytes() + + fetchInput := preparedInput.Bytes() + // it's important to copy the *astjson.Value's off the arena to avoid memory corruption + res.batchStats = make([][]*astjson.Value, len(batchStats)) + for i := range batchStats { + res.batchStats[i] = make([]*astjson.Value, len(batchStats[i])) + copy(res.batchStats[i], batchStats[i]) + } if l.ctx.TracingOptions.Enable && res.fetchSkipped { l.setTracingInput(fetchItem, fetchInput, fetch.Trace) @@ -1564,6 +1536,7 @@ WithNextItem: if !allowed { return nil } + l.executeSourceLoad(ctx, fetchItem, fetch.DataSource, fetchInput, res, fetch.Trace) return nil } @@ -1605,29 +1578,8 @@ func redactHeaders(rawJSON json.RawMessage) (json.RawMessage, error) { return redactedJSON, nil } -type disallowSingleFlightContextKey struct{} - -func SingleFlightDisallowed(ctx context.Context) bool { - return ctx.Value(disallowSingleFlightContextKey{}) != nil -} - -type singleFlightStatsKey struct{} - -type SingleFlightStats struct { - SingleFlightUsed bool - SingleFlightSharedResponse bool -} - -func GetSingleFlightStats(ctx context.Context) *SingleFlightStats { - maybeStats := ctx.Value(singleFlightStatsKey{}) - if maybeStats == nil { - return nil - } - return maybeStats.(*SingleFlightStats) -} - -func setSingleFlightStats(ctx context.Context, stats *SingleFlightStats) context.Context { - return context.WithValue(ctx, singleFlightStatsKey{}, stats) +type singleFlightStats struct { + used, shared bool } func (l *Loader) setTracingInput(fetchItem *FetchItem, input []byte, trace *DataSourceLoadTrace) { @@ -1643,11 +1595,120 @@ func (l *Loader) setTracingInput(fetchItem *FetchItem, input []byte, trace *Data } } -func (l *Loader) loadByContext(ctx context.Context, source DataSource, input []byte, res *result) error { +type loaderContextKey string + +const ( + operationTypeContextKey loaderContextKey = "operationType" +) + +// GetOperationTypeFromContext can be used, e.g. by the transport, to check if the operation is a Mutation +func GetOperationTypeFromContext(ctx context.Context) ast.OperationType { + if ctx == nil { + return ast.OperationTypeQuery + } + if v := ctx.Value(operationTypeContextKey); v != nil { + if opType, ok := v.(ast.OperationType); ok { + return opType + } + } + return ast.OperationTypeQuery +} + +func (l *Loader) headersForSubgraphRequest(fetchItem *FetchItem) (http.Header, uint64) { + if fetchItem == nil || fetchItem.Fetch == nil { + return nil, 0 + } + info := fetchItem.Fetch.FetchInfo() + if info == nil { + return nil, 0 + } + return l.ctx.HeadersForSubgraphRequest(info.DataSourceName) +} + +// singleFlightAllowed returns true if the specific GraphQL Operation is a Query +// even if the root operation type is a Mutation or Subscription +// sub-operations can still be of type Query +// even in such cases we allow request de-duplication because such requests are idempotent +func (l *Loader) singleFlightAllowed(fetchItem *FetchItem) bool { + if l.ctx.ExecutionOptions.DisableSubgraphRequestDeduplication { + return false + } + if fetchItem == nil { + return false + } + if fetchItem.Fetch == nil { + return false + } + info := fetchItem.Fetch.FetchInfo() + if info == nil { + return false + } + if info.OperationType == ast.OperationTypeQuery { + return true + } + return false +} + +func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem *FetchItem, input []byte, res *result) error { + + if l.info != nil { + ctx = context.WithValue(ctx, operationTypeContextKey, l.info.OperationType) + } + + headers, extraKey := l.headersForSubgraphRequest(fetchItem) + + if !l.singleFlightAllowed(fetchItem) { + // Disable single flight for mutations + return l.loadByContextDirect(ctx, source, headers, input, res) + } + + item, shared := l.sf.GetOrCreateItem(fetchItem, input, extraKey) + if res.singleFlightStats != nil { + res.singleFlightStats.used = true + res.singleFlightStats.shared = shared + } + + if shared { + select { + case <-item.loaded: + case <-ctx.Done(): + return ctx.Err() + } + + if item.err != nil { + return item.err + } + + res.out = item.response + return nil + } + + // helps the http client to create buffers at the right size + ctx = httpclient.WithHTTPClientSizeHint(ctx, item.sizeHint) + + defer l.sf.Finish(item) + + // Perform the actual load + err := l.loadByContextDirect(ctx, source, headers, input, res) + if err != nil { + item.err = err + return err + } + + item.response = res.out + return nil +} + +func (l *Loader) loadByContextDirect(ctx context.Context, source DataSource, headers http.Header, input []byte, res *result) error { if l.ctx.Files != nil { - return source.LoadWithFiles(ctx, input, l.ctx.Files, res.out) + res.out, res.err = source.LoadWithFiles(ctx, headers, input, l.ctx.Files) + } else { + res.out, res.err = source.Load(ctx, headers, input) } - return source.Load(ctx, input, res.out) + if res.err != nil { + return errors.WithStack(res.err) + } + return nil } func (l *Loader) executeSourceLoad(ctx context.Context, fetchItem *FetchItem, source DataSource, input []byte, res *result, trace *DataSourceLoadTrace) { @@ -1676,7 +1737,7 @@ func (l *Loader) executeSourceLoad(ctx context.Context, fetchItem *FetchItem, so } } if l.ctx.TracingOptions.Enable { - ctx = setSingleFlightStats(ctx, &SingleFlightStats{}) + res.singleFlightStats = &singleFlightStats{} trace.Path = fetchItem.ResponsePath if !l.ctx.TracingOptions.ExcludeInput { trace.Input = make([]byte, len(input)) @@ -1780,9 +1841,6 @@ func (l *Loader) executeSourceLoad(ctx context.Context, fetchItem *FetchItem, so ctx = httptrace.WithClientTrace(ctx, clientTrace) } } - if l.info != nil && l.info.OperationType == ast.OperationTypeMutation { - ctx = context.WithValue(ctx, disallowSingleFlightContextKey{}, true) - } var responseContext *httpclient.ResponseContext ctx, responseContext = httpclient.InjectResponseContext(ctx) @@ -1791,27 +1849,26 @@ func (l *Loader) executeSourceLoad(ctx context.Context, fetchItem *FetchItem, so // Prevent that the context is destroyed when the loader hook return an empty context if res.loaderHookContext != nil { - res.err = l.loadByContext(res.loaderHookContext, source, input, res) + res.err = l.loadByContext(res.loaderHookContext, source, fetchItem, input, res) } else { - res.err = l.loadByContext(ctx, source, input, res) + res.err = l.loadByContext(ctx, source, fetchItem, input, res) res.loaderHookContext = ctx // Set the context to the original context to ensure that OnFinished hook gets valid context } } else { - res.err = l.loadByContext(ctx, source, input, res) + res.err = l.loadByContext(ctx, source, fetchItem, input, res) } res.statusCode = responseContext.StatusCode res.httpResponseContext = responseContext if l.ctx.TracingOptions.Enable { - stats := GetSingleFlightStats(ctx) - if stats != nil { - trace.SingleFlightUsed = stats.SingleFlightUsed - trace.SingleFlightSharedResponse = stats.SingleFlightSharedResponse + if res.singleFlightStats != nil { + trace.SingleFlightUsed = res.singleFlightStats.used + trace.SingleFlightSharedResponse = res.singleFlightStats.shared } - if !l.ctx.TracingOptions.ExcludeOutput && res.out.Len() > 0 { - trace.Output, _ = l.compactJSON(res.out.Bytes()) + if !l.ctx.TracingOptions.ExcludeOutput && len(res.out) > 0 { + trace.Output, _ = l.compactJSON(res.out) if l.ctx.TracingOptions.EnablePredictableDebugTimings { trace.Output, _ = sjson.DeleteBytes(trace.Output, "extensions.trace.response.headers.Date") } @@ -1840,7 +1897,10 @@ func (l *Loader) compactJSON(data []byte) ([]byte, error) { return nil, err } out := dst.Bytes() - v, err := astjson.ParseBytesWithoutCache(out) + // don't use arena here or segfault + // it's also not a hot path and not important to optimize + // arena requires the parsed content to be on the arena as well + v, err := astjson.ParseBytes(out) if err != nil { return nil, err } diff --git a/v2/pkg/engine/resolve/loader_hooks_test.go b/v2/pkg/engine/resolve/loader_hooks_test.go index 4b7b3ea6c..4a2ce9cb2 100644 --- a/v2/pkg/engine/resolve/loader_hooks_test.go +++ b/v2/pkg/engine/resolve/loader_hooks_test.go @@ -3,7 +3,7 @@ package resolve import ( "bytes" "context" - "io" + "net/http" "sync" "sync/atomic" "testing" @@ -50,11 +50,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("simple fetch with simple subgraph error", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := Context{ ctx: context.Background(), @@ -124,11 +122,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := &Context{ ctx: context.Background(), @@ -192,11 +188,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("parallel fetch with simple subgraph error", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := &Context{ ctx: context.Background(), @@ -254,80 +248,12 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { } })) - t.Run("parallel list item fetch with simple subgraph error", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { - mockDataSource := NewMockDataSource(ctrl) - mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) - }) - resolveCtx := Context{ - ctx: context.Background(), - LoaderHooks: NewTestLoaderHooks(), - } - return &GraphQLResponse{ - Info: &GraphQLResponseInfo{ - OperationType: ast.OperationTypeQuery, - }, - Fetches: SingleWithPath(&ParallelListItemFetch{ - Fetch: &SingleFetch{ - FetchConfiguration: FetchConfiguration{ - DataSource: mockDataSource, - PostProcessing: PostProcessingConfiguration{ - SelectResponseErrorsPath: []string{"errors"}, - }, - }, - Info: &FetchInfo{ - DataSourceID: "Users", - DataSourceName: "Users", - }, - }, - }, "query"), - Data: &Object{ - Nullable: false, - Fields: []*Field{ - { - Name: []byte("name"), - Value: &String{ - Path: []string{"name"}, - Nullable: true, - }, - }, - }, - }, - }, &resolveCtx, `{"errors":[{"message":"Failed to fetch from Subgraph 'Users' at Path 'query'.","extensions":{"errors":[{"message":"errorMessage"}]}}],"data":{"name":null}}`, - func(t *testing.T) { - loaderHooks := resolveCtx.LoaderHooks.(*TestLoaderHooks) - - assert.Equal(t, int64(1), loaderHooks.preFetchCalls.Load()) - assert.Equal(t, int64(1), loaderHooks.postFetchCalls.Load()) - - var subgraphError *SubgraphError - assert.Len(t, loaderHooks.errors, 1) - assert.ErrorAs(t, loaderHooks.errors[0], &subgraphError) - assert.Equal(t, "Users", subgraphError.DataSourceInfo.Name) - assert.Equal(t, "query", subgraphError.Path) - assert.Equal(t, "", subgraphError.Reason) - assert.Equal(t, 0, subgraphError.ResponseCode) - assert.Len(t, subgraphError.DownstreamErrors, 1) - assert.Equal(t, "errorMessage", subgraphError.DownstreamErrors[0].Message) - assert.Nil(t, subgraphError.DownstreamErrors[0].Extensions) - - assert.NotNil(t, resolveCtx.SubgraphErrors()) - } - })) - t.Run("fetch with subgraph error and custom extension code. No extension fields are propagated by default", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{\"code\":\"GRAPHQL_VALIDATION_FAILED\"}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("{\"code\":\"BAD_USER_INPUT\"}")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) resolveCtx := Context{ ctx: context.Background(), @@ -388,12 +314,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Propagate only extension code field from subgraph errors", testFnSubgraphErrorsWithExtensionFieldCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{\"code\":\"GRAPHQL_VALIDATION_FAILED\",\"foo\":\"bar\"}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("{\"code\":\"BAD_USER_INPUT\"}")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED","foo":"bar"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -426,12 +349,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Propagate all extension fields from subgraph errors when allow all option is enabled", testFnSubgraphErrorsWithAllowAllExtensionFields(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{\"code\":\"GRAPHQL_VALIDATION_FAILED\",\"foo\":\"bar\"}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("{\"code\":\"BAD_USER_INPUT\"}")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED","foo":"bar"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -464,12 +384,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Include datasource name as serviceName extension field", testFnSubgraphErrorsWithExtensionFieldServiceName(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{\"code\":\"GRAPHQL_VALIDATION_FAILED\"}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("{\"code\":\"BAD_USER_INPUT\"}")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -502,12 +419,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Include datasource name as serviceName when extensions is null", testFnSubgraphErrorsWithExtensionFieldServiceName(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("null")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("null")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":null},{"message":"errorMessage2","extensions":null}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -540,12 +454,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Include datasource name as serviceName when extensions is an empty object", testFnSubgraphErrorsWithExtensionFieldServiceName(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("null")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{}},{"message":"errorMessage2","extensions":null}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -578,12 +489,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Fallback to default extension code value when no code field was set", testFnSubgraphErrorsWithExtensionDefaultCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{\"code\":\"GRAPHQL_VALIDATION_FAILED\"}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"errorMessage2"}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -616,12 +524,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Fallback to default extension code value when extensions is null", testFnSubgraphErrorsWithExtensionDefaultCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("null")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":null},{"message":"errorMessage2"}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -654,12 +559,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Fallback to default extension code value when extensions is an empty object", testFnSubgraphErrorsWithExtensionDefaultCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{}},{"message":"errorMessage2"}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ diff --git a/v2/pkg/engine/resolve/loader_test.go b/v2/pkg/engine/resolve/loader_test.go index 4ed83d444..f88d7227f 100644 --- a/v2/pkg/engine/resolve/loader_test.go +++ b/v2/pkg/engine/resolve/loader_test.go @@ -19,19 +19,19 @@ func TestLoader_LoadGraphQLResponseData(t *testing.T) { ctrl := gomock.NewController(t) productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`) + `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`) reviewsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}`) + `{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}}`) stockService := mockedDS(t, ctrl, `{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}`) + `{"data":{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}}`) usersService := mockedDS(t, ctrl, `{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"2"}]}}}`, - `{"_entities":[{"name":"user-1"},{"name":"user-2"}]}`) + `{"data":{"_entities":[{"name":"user-1"},{"name":"user-2"}]}}`) response := &GraphQLResponse{ Fetches: Sequence( Single(&SingleFetch{ @@ -287,7 +287,7 @@ func TestLoader_LoadGraphQLResponseData(t *testing.T) { ctx := &Context{ ctx: context.Background(), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) assert.NoError(t, err) @@ -296,7 +296,7 @@ func TestLoader_LoadGraphQLResponseData(t *testing.T) { ctrl.Finish() out := fastjsonext.PrintGraphQLResponse(resolvable.data, resolvable.errors) assert.NoError(t, err) - expected := `{"errors":[],"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` + expected := `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` assert.Equal(t, expected, out) } @@ -376,7 +376,7 @@ func TestLoader_MergeErrorDifferingTypes(t *testing.T) { ctx := &Context{ ctx: context.Background(), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) assert.NoError(t, err) @@ -467,7 +467,7 @@ func TestLoader_MergeErrorDifferingArrayLength(t *testing.T) { ctx := &Context{ ctx: context.Background(), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) assert.NoError(t, err) @@ -480,19 +480,19 @@ func TestLoader_LoadGraphQLResponseDataWithExtensions(t *testing.T) { ctrl := gomock.NewController(t) productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}","extensions":{"foo":"bar"}}}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`) + `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`) reviewsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]},"extensions":{"foo":"bar"}}}`, - `{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}`) + `{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}}`) stockService := mockedDS(t, ctrl, `{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]},"extensions":{"foo":"bar"}}}`, - `{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}`) + `{"data":{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}}`) usersService := mockedDS(t, ctrl, `{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"2"}]},"extensions":{"foo":"bar"}}}`, - `{"_entities":[{"name":"user-1"},{"name":"user-2"}]}`) + `{"data":{"_entities":[{"name":"user-1"},{"name":"user-2"}]}}`) response := &GraphQLResponse{ Fetches: Sequence( Single(&SingleFetch{ @@ -749,7 +749,7 @@ func TestLoader_LoadGraphQLResponseDataWithExtensions(t *testing.T) { ctx: context.Background(), Extensions: []byte(`{"foo":"bar"}`), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) assert.NoError(t, err) @@ -758,7 +758,7 @@ func TestLoader_LoadGraphQLResponseDataWithExtensions(t *testing.T) { ctrl.Finish() out := fastjsonext.PrintGraphQLResponse(resolvable.data, resolvable.errors) assert.NoError(t, err) - expected := `{"errors":[],"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` + expected := `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` assert.Equal(t, expected, out) } @@ -1024,9 +1024,9 @@ func BenchmarkLoader_LoadGraphQLResponseData(b *testing.B) { ctx := &Context{ ctx: context.Background(), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} - expected := `{"errors":[],"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` + expected := `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` b.SetBytes(int64(len(expected))) b.ReportAllocs() b.ResetTimer() @@ -1054,7 +1054,7 @@ func TestLoader_RedactHeaders(t *testing.T) { productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","header":{"Authorization":"value"},"body":{"query":"query{topProducts{name __typename upc}}"},"__trace__":true}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`) + `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`) response := &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -1125,7 +1125,7 @@ func TestLoader_RedactHeaders(t *testing.T) { Enable: true, }, } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) @@ -1153,19 +1153,19 @@ func TestLoader_InvalidBatchItemCount(t *testing.T) { ctrl := gomock.NewController(t) productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`) + `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`) reviewsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}`) + `{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}}`) stockService := mockedDS(t, ctrl, `{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"stock":8},{"stock":2}]}`) // 3 items expected, 2 returned + `{"data":{"_entities":[{"stock":8},{"stock":2}]}}`) // 3 items expected, 2 returned usersService := mockedDS(t, ctrl, `{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"2"}]}}}`, - `{"_entities":[{"name":"user-1"},{"name":"user-2"},{"name":"user-3"}]}`) // 2 items expected, 3 returned + `{"data":{"_entities":[{"name":"user-1"},{"name":"user-2"},{"name":"user-3"}]}}`) // 2 items expected, 3 returned response := &GraphQLResponse{ Fetches: Sequence( Single(&SingleFetch{ @@ -1421,7 +1421,7 @@ func TestLoader_InvalidBatchItemCount(t *testing.T) { ctx := &Context{ ctx: context.Background(), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) assert.NoError(t, err) @@ -1521,13 +1521,13 @@ func TestRewriteErrorPaths(t *testing.T) { for i, inputError := range tc.inputErrors { // Create a copy by marshaling and parsing again data := inputError.MarshalTo(nil) - value, err := astjson.ParseBytesWithoutCache(data) + value, err := astjson.ParseBytesWithArena(nil, data) assert.NoError(t, err, "Failed to copy input error") values[i] = value } // Call the function under test - rewriteErrorPaths(fetchItem, values) + rewriteErrorPaths(nil, fetchItem, values) // Compare the results assert.Equal(t, len(tc.expectedErrors), len(values), diff --git a/v2/pkg/engine/resolve/resolvable.go b/v2/pkg/engine/resolve/resolvable.go index 5219c910d..226705a70 100644 --- a/v2/pkg/engine/resolve/resolvable.go +++ b/v2/pkg/engine/resolve/resolvable.go @@ -13,6 +13,7 @@ import ( "github.com/tidwall/gjson" "github.com/wundergraph/astjson" + "github.com/wundergraph/go-arena" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/errorcodes" @@ -30,8 +31,9 @@ type Resolvable struct { errors *astjson.Value valueCompletion *astjson.Value skipAddingNullErrors bool - - astjsonArena *astjson.Arena + // astjsonArena is the arena to handle json, supplied by Resolver + // not thread safe, but Resolvable is single threaded anyways + astjsonArena arena.Arena parsers []*astjson.Parser print bool @@ -67,13 +69,13 @@ type ResolvableOptions struct { ApolloCompatibilityReplaceInvalidVarError bool } -func NewResolvable(options ResolvableOptions) *Resolvable { +func NewResolvable(a arena.Arena, options ResolvableOptions) *Resolvable { return &Resolvable{ options: options, xxh: xxhash.New(), authorizationAllow: make(map[uint64]struct{}), authorizationDeny: make(map[uint64]string), - astjsonArena: &astjson.Arena{}, + astjsonArena: a, } } @@ -95,7 +97,7 @@ func (r *Resolvable) Reset() { r.operationType = ast.OperationTypeUnknown r.renameTypeNames = r.renameTypeNames[:0] r.authorizationError = nil - r.astjsonArena.Reset() + r.astjsonArena = nil r.xxh.Reset() for k := range r.authorizationAllow { delete(r.authorizationAllow, k) @@ -109,14 +111,15 @@ func (r *Resolvable) Init(ctx *Context, initialData []byte, operationType ast.Op r.ctx = ctx r.operationType = operationType r.renameTypeNames = ctx.RenameTypeNames - r.data = r.astjsonArena.NewObject() - r.errors = r.astjsonArena.NewArray() + r.data = astjson.ObjectValue(r.astjsonArena) + // don't init errors! It will heavily increase memory usage + r.errors = nil if initialData != nil { - initialValue, err := astjson.ParseBytesWithoutCache(initialData) + initialValue, err := astjson.ParseBytesWithArena(r.astjsonArena, initialData) if err != nil { return err } - r.data, _, err = astjson.MergeValues(r.data, initialValue) + r.data, _, err = astjson.MergeValues(r.astjsonArena, r.data, initialValue) if err != nil { return err } @@ -128,20 +131,22 @@ func (r *Resolvable) InitSubscription(ctx *Context, initialData []byte, postProc r.ctx = ctx r.operationType = ast.OperationTypeSubscription r.renameTypeNames = ctx.RenameTypeNames + // don't init errors! It will heavily increase memory usage + r.errors = nil if initialData != nil { - initialValue, err := astjson.ParseBytesWithoutCache(initialData) + initialValue, err := astjson.ParseBytesWithArena(r.astjsonArena, initialData) if err != nil { return err } if postProcessing.SelectResponseDataPath == nil { - r.data, _, err = astjson.MergeValuesWithPath(r.data, initialValue, postProcessing.MergePath...) + r.data, _, err = astjson.MergeValuesWithPath(r.astjsonArena, r.data, initialValue, postProcessing.MergePath...) if err != nil { return err } } else { selectedInitialValue := initialValue.Get(postProcessing.SelectResponseDataPath...) if selectedInitialValue != nil { - r.data, _, err = astjson.MergeValuesWithPath(r.data, selectedInitialValue, postProcessing.MergePath...) + r.data, _, err = astjson.MergeValuesWithPath(r.astjsonArena, r.data, selectedInitialValue, postProcessing.MergePath...) if err != nil { return err } @@ -155,10 +160,7 @@ func (r *Resolvable) InitSubscription(ctx *Context, initialData []byte, postProc } } if r.data == nil { - r.data = r.astjsonArena.NewObject() - } - if r.errors == nil { - r.errors = r.astjsonArena.NewArray() + r.data = astjson.ObjectValue(r.astjsonArena) } return } @@ -168,7 +170,8 @@ func (r *Resolvable) ResolveNode(node Node, data *astjson.Value, out io.Writer) r.print = false r.printErr = nil r.authorizationError = nil - r.errors = r.astjsonArena.NewArray() + // don't init errors! It will heavily increase memory usage + r.errors = nil hasErrors := r.walkNode(node, data) if hasErrors { @@ -234,6 +237,13 @@ func (r *Resolvable) Resolve(ctx context.Context, rootData *Object, fetchTree *F return r.printErr } +// ensureErrorsInitialized is used to lazily init r.errors if needed +func (r *Resolvable) ensureErrorsInitialized() { + if r.errors == nil { + r.errors = astjson.ArrayValue(r.astjsonArena) + } +} + func (r *Resolvable) enclosingTypeName() string { if len(r.enclosingTypeNames) > 0 { return r.enclosingTypeNames[len(r.enclosingTypeNames)-1] @@ -464,7 +474,7 @@ func (r *Resolvable) renderScalarFieldValue(value *astjson.Value, nullable bool) // renderScalarFieldString - is used when value require some pre-processing, e.g. unescaping or custom rendering func (r *Resolvable) renderScalarFieldBytes(data []byte, nullable bool) { - value, err := astjson.ParseBytesWithoutCache(data) + value, err := astjson.ParseBytesWithArena(r.astjsonArena, data) if err != nil { r.printErr = err return @@ -760,6 +770,7 @@ func (r *Resolvable) addRejectFieldError(reason string, ds DataSourceInfo, field } r.ctx.appendSubgraphErrors(errors.New(errorMessage), NewSubgraphError(ds, fieldPath, reason, 0)) + r.ensureErrorsInitialized() fastjsonext.AppendErrorWithExtensionsCodeToArray(r.astjsonArena, r.errors, errorMessage, errorcodes.UnauthorizedFieldOrType, r.path) r.popNodePathElement(nodePath) } @@ -853,7 +864,7 @@ func (r *Resolvable) walkArray(arr *Array, value *astjson.Value) bool { r.popArrayPathElement() if err { if arr.Item.NodeKind() == NodeKindObject && arr.Item.NodeNullable() { - value.SetArrayItem(i, astjson.NullValue) + value.SetArrayItem(r.astjsonArena, i, astjson.NullValue) continue } if arr.Nullable { @@ -1201,6 +1212,7 @@ func (r *Resolvable) addNonNullableFieldError(fieldPath []string, parent *astjso r.addValueCompletion(r.renderApolloCompatibleNonNullableErrorMessage(), errorcodes.InvalidGraphql) } else { errorMessage := fmt.Sprintf("Cannot return null for non-nullable field '%s'.", r.renderFieldPath()) + r.ensureErrorsInitialized() fastjsonext.AppendErrorToArray(r.astjsonArena, r.errors, errorMessage, r.path) } r.popNodePathElement(fieldPath) @@ -1271,30 +1283,33 @@ func (r *Resolvable) renderFieldCoordinates() string { func (r *Resolvable) addError(message string, fieldPath []string) { r.pushNodePathElement(fieldPath) + r.ensureErrorsInitialized() fastjsonext.AppendErrorToArray(r.astjsonArena, r.errors, message, r.path) r.popNodePathElement(fieldPath) } func (r *Resolvable) addErrorWithCode(message, code string) { + r.ensureErrorsInitialized() fastjsonext.AppendErrorWithExtensionsCodeToArray(r.astjsonArena, r.errors, message, code, r.path) } func (r *Resolvable) addErrorWithCodeAndPath(message, code string, fieldPath []string) { r.pushNodePathElement(fieldPath) + r.ensureErrorsInitialized() fastjsonext.AppendErrorWithExtensionsCodeToArray(r.astjsonArena, r.errors, message, code, r.path) r.popNodePathElement(fieldPath) } func (r *Resolvable) addValueCompletion(message, code string) { if r.valueCompletion == nil { - r.valueCompletion = r.astjsonArena.NewArray() + r.valueCompletion = astjson.ArrayValue(r.astjsonArena) } fastjsonext.AppendErrorWithExtensionsCodeToArray(r.astjsonArena, r.valueCompletion, message, code, r.path) } func (r *Resolvable) addValueCompletionWithPath(message, code string, fieldPath []string) { if r.valueCompletion == nil { - r.valueCompletion = r.astjsonArena.NewArray() + r.valueCompletion = astjson.ArrayValue(r.astjsonArena) } r.pushNodePathElement(fieldPath) fastjsonext.AppendErrorWithExtensionsCodeToArray(r.astjsonArena, r.valueCompletion, message, code, r.path) diff --git a/v2/pkg/engine/resolve/resolvable_custom_field_renderer_test.go b/v2/pkg/engine/resolve/resolvable_custom_field_renderer_test.go index 843c6e696..0dbb0394b 100644 --- a/v2/pkg/engine/resolve/resolvable_custom_field_renderer_test.go +++ b/v2/pkg/engine/resolve/resolvable_custom_field_renderer_test.go @@ -440,7 +440,7 @@ func TestResolvable_CustomFieldRenderer(t *testing.T) { t.Parallel() // Setup - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{} var input []byte @@ -543,7 +543,7 @@ func TestResolvable_CustomFieldRenderer(t *testing.T) { t.Parallel() input := []byte(tc.input) - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{} err := res.Init(ctx, input, ast.OperationTypeQuery) assert.NoError(t, err) diff --git a/v2/pkg/engine/resolve/resolvable_test.go b/v2/pkg/engine/resolve/resolvable_test.go index 4b92f8591..aea4e78ef 100644 --- a/v2/pkg/engine/resolve/resolvable_test.go +++ b/v2/pkg/engine/resolve/resolvable_test.go @@ -12,7 +12,7 @@ import ( func TestResolvable_Resolve(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -84,7 +84,7 @@ func TestResolvable_Resolve(t *testing.T) { func TestResolvable_ResolveWithTypeMismatch(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":true}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -157,7 +157,7 @@ func TestResolvable_ResolveWithTypeMismatch(t *testing.T) { func TestResolvable_ResolveWithErrorBubbleUp(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -231,7 +231,7 @@ func TestResolvable_ResolveWithErrorBubbleUp(t *testing.T) { func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { t.Run("Non-nullable root field", func(t *testing.T) { topProducts := `{"topProducts":null}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -258,7 +258,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Non-Nullable root field and nested field", func(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -333,7 +333,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Nullable root field and non-Nullable nested field", func(t *testing.T) { topProducts := `{"topProduct":{"name":null}}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -370,7 +370,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Non-Nullable sibling field", func(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","reviews":[{"author":{"__typename":"User","name":"Bob"},"body":null}]}]}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -439,7 +439,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Non-nullable array and array item", func(t *testing.T) { topProducts := `{"topProducts":[null]}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -469,7 +469,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Nullable array and non-nullable array item", func(t *testing.T) { topProducts := `{"topProducts":[null]}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -500,7 +500,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Non-Nullable array, array item, and array item field", func(t *testing.T) { topProducts := `{"topProducts":[{"author":{"name":"Name"}},{"author":null}]}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -549,7 +549,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { func TestResolvable_ResolveWithErrorBubbleUpUntilData(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -622,7 +622,7 @@ func TestResolvable_ResolveWithErrorBubbleUpUntilData(t *testing.T) { func TestResolvable_InvalidEnumValues(t *testing.T) { t.Run("Invalid enum value", func(t *testing.T) { enum := `{"enum":"B"}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -653,7 +653,7 @@ func TestResolvable_InvalidEnumValues(t *testing.T) { t.Run("Inaccessible enum value", func(t *testing.T) { enum := `{"enum":"B"}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -686,7 +686,7 @@ func TestResolvable_InvalidEnumValues(t *testing.T) { t.Run("Invalid enum value with value completion Apollo compatibility flag", func(t *testing.T) { enum := `{"enum":"B"}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -719,7 +719,7 @@ func TestResolvable_InvalidEnumValues(t *testing.T) { t.Run("Inaccessible enum value with value completion Apollo compatibility flag", func(t *testing.T) { enum := `{"enum":"B"}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -755,7 +755,7 @@ func TestResolvable_InvalidEnumValues(t *testing.T) { func BenchmarkResolvable_Resolve(b *testing.B) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -838,7 +838,7 @@ func BenchmarkResolvable_Resolve(b *testing.B) { func BenchmarkResolvable_ResolveWithErrorBubbleUp(b *testing.B) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -923,7 +923,7 @@ func BenchmarkResolvable_ResolveWithErrorBubbleUp(b *testing.B) { } func TestResolvable_WithTracingNotStarted(t *testing.T) { - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) // Do not start a trace with SetTraceStart(), but request it to be output ctx := NewContext(context.Background()) ctx.TracingOptions.Enable = true @@ -950,7 +950,7 @@ func TestResolvable_WithTracingNotStarted(t *testing.T) { func TestResolveFloat(t *testing.T) { t.Run("default behaviour", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := NewContext(context.Background()) err := res.Init(ctx, []byte(`{"f":1.0}`), ast.OperationTypeQuery) assert.NoError(t, err) @@ -972,7 +972,7 @@ func TestResolveFloat(t *testing.T) { assert.Equal(t, `{"data":{"f":1.0}}`, out.String()) }) t.Run("invalid float", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := NewContext(context.Background()) err := res.Init(ctx, []byte(`{"f":false}`), ast.OperationTypeQuery) assert.NoError(t, err) @@ -994,7 +994,7 @@ func TestResolveFloat(t *testing.T) { assert.Equal(t, `{"errors":[{"message":"Float cannot represent non-float value: \"false\"","path":["f"]}],"data":null}`, out.String()) }) t.Run("truncate float", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityTruncateFloatValues: true, }) ctx := NewContext(context.Background()) @@ -1018,7 +1018,7 @@ func TestResolveFloat(t *testing.T) { assert.Equal(t, `{"data":{"f":1}}`, out.String()) }) t.Run("truncate float with decimal place", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityTruncateFloatValues: true, }) ctx := NewContext(context.Background()) @@ -1045,7 +1045,7 @@ func TestResolveFloat(t *testing.T) { func TestResolvable_ValueCompletion(t *testing.T) { t.Run("nested object", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := NewContext(context.Background()) @@ -1143,7 +1143,7 @@ func TestResolvable_ValueCompletion(t *testing.T) { }`) t.Run("nullable", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := NewContext(context.Background()) @@ -1241,7 +1241,7 @@ func TestResolvable_ValueCompletion(t *testing.T) { }) t.Run("mixed nullability", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := NewContext(context.Background()) @@ -1342,7 +1342,7 @@ func TestResolvable_ValueCompletion(t *testing.T) { func TestResolvable_WithTracing(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) background := SetTraceStart(context.Background(), true) ctx := NewContext(background) ctx.TracingOptions.Enable = true diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 14d8ad4b5..747ee02c4 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -5,14 +5,18 @@ package resolve import ( "bytes" "context" + "encoding/binary" "fmt" "io" + "net/http" "time" "github.com/buger/jsonparser" "github.com/pkg/errors" "go.uber.org/atomic" + "github.com/wundergraph/go-arena" + "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/xcontext" "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) @@ -69,6 +73,20 @@ type Resolver struct { heartbeatInterval time.Duration // maxSubscriptionFetchTimeout defines the maximum time a subscription fetch can take before it is considered timed out maxSubscriptionFetchTimeout time.Duration + + // resolveArenaPool is the arena pool dedicated for Loader & Resolvable + // ArenaPool automatically adjusts arena buffer sizes per workload + // resolving & response buffering are very different tasks + // as such, it was best to have two arena pools in terms of memory usage + // A single pool for both was much less efficient + resolveArenaPool *ArenaPool + // responseBufferPool is the arena pool dedicated for response buffering before sending to the client + responseBufferPool *ArenaPool + + // subgraphRequestSingleFlight is used to de-duplicate subgraph requests + subgraphRequestSingleFlight *SubgraphRequestSingleFlight + // inboundRequestSingleFlight is used to de-duplicate subgraph requests + inboundRequestSingleFlight *InboundRequestSingleFlight } func (r *Resolver) SetAsyncErrorWriter(w AsyncErrorWriter) { @@ -222,6 +240,10 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { allowedErrorFields: allowedErrorFields, heartbeatInterval: options.SubscriptionHeartbeatInterval, maxSubscriptionFetchTimeout: options.MaxSubscriptionFetchTimeout, + resolveArenaPool: NewArenaPool(), + responseBufferPool: NewArenaPool(), + subgraphRequestSingleFlight: NewSingleFlight(8), + inboundRequestSingleFlight: NewRequestSingleFlight(8), } resolver.maxConcurrency = make(chan struct{}, options.MaxConcurrency) for i := 0; i < options.MaxConcurrency; i++ { @@ -233,9 +255,9 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { return resolver } -func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{}, allowedErrorFields map[string]struct{}) *tools { +func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{}, allowedErrorFields map[string]struct{}, sf *SubgraphRequestSingleFlight, a arena.Arena) *tools { return &tools{ - resolvable: NewResolvable(options.ResolvableOptions), + resolvable: NewResolvable(a, options.ResolvableOptions), loader: &Loader{ propagateSubgraphErrors: options.PropagateSubgraphErrors, propagateSubgraphStatusCodes: options.PropagateSubgraphStatusCodes, @@ -251,6 +273,8 @@ func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{ apolloRouterCompatibilitySubrequestHTTPError: options.ApolloRouterCompatibilitySubrequestHTTPError, propagateFetchReasons: options.PropagateFetchReasons, validateRequiredExternalFields: options.ValidateRequiredExternalFields, + sf: sf, + jsonArena: a, }, } } @@ -269,7 +293,7 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons r.maxConcurrency <- struct{}{} }() - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.subgraphRequestSingleFlight, nil) err := t.resolvable.Init(ctx, data, response.Info.OperationType) if err != nil { @@ -291,6 +315,72 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons return resp, err } +func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLResponse, writer io.Writer) (*GraphQLResolveInfo, error) { + resp := &GraphQLResolveInfo{} + + inflight, err := r.inboundRequestSingleFlight.GetOrCreate(ctx, response) + if err != nil { + return nil, err + } + + if inflight != nil && inflight.Data != nil { // follower + _, err = writer.Write(inflight.Data) + return resp, err + } + + start := time.Now() + <-r.maxConcurrency + resp.ResolveAcquireWaitTime = time.Since(start) + defer func() { + r.maxConcurrency <- struct{}{} + }() + + resolveArena := r.resolveArenaPool.Acquire(ctx.Request.ID) + // we're intentionally not using defer Release to have more control over the timing (see below) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.subgraphRequestSingleFlight, resolveArena.Arena) + + err = t.resolvable.Init(ctx, nil, response.Info.OperationType) + if err != nil { + r.inboundRequestSingleFlight.FinishErr(inflight, err) + r.resolveArenaPool.Release(resolveArena) + return nil, err + } + + if !ctx.ExecutionOptions.SkipLoader { + err = t.loader.LoadGraphQLResponseData(ctx, response, t.resolvable) + if err != nil { + r.inboundRequestSingleFlight.FinishErr(inflight, err) + r.resolveArenaPool.Release(resolveArena) + return nil, err + } + } + + // only when loading is done, acquire an arena for the response buffer + responseArena := r.responseBufferPool.Acquire(ctx.Request.ID) + buf := arena.NewArenaBuffer(responseArena.Arena) + err = t.resolvable.Resolve(ctx.ctx, response.Data, response.Fetches, buf) + if err != nil { + r.inboundRequestSingleFlight.FinishErr(inflight, err) + r.resolveArenaPool.Release(resolveArena) + r.responseBufferPool.Release(responseArena) + return nil, err + } + + // first release resolverArena + // all data is resolved and written into the response arena + r.resolveArenaPool.Release(resolveArena) + // next we write back to the client + // this includes flushing and syscalls + // as such, it can take some time + // which is why we split the arenas and released the first one + _, err = writer.Write(buf.Bytes()) + r.inboundRequestSingleFlight.FinishOk(inflight, buf.Bytes()) + // all data is written to the client + // we're safe to release our buffer + r.responseBufferPool.Release(responseArena) + return resp, err +} + type trigger struct { id uint64 cancel context.CancelFunc @@ -421,9 +511,11 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar input := make([]byte, len(sharedInput)) copy(input, sharedInput) - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) + resolveArena := r.resolveArenaPool.Acquire(resolveCtx.Request.ID) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.subgraphRequestSingleFlight, resolveArena.Arena) if err := t.resolvable.InitSubscription(resolveCtx, input, sub.resolve.Trigger.PostProcessing); err != nil { + r.resolveArenaPool.Release(resolveArena) r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:init:failed:%d\n", sub.id.SubscriptionID) @@ -435,6 +527,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar } if err := t.loader.LoadGraphQLResponseData(resolveCtx, sub.resolve.Response, t.resolvable); err != nil { + r.resolveArenaPool.Release(resolveArena) r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:load:failed:%d\n", sub.id.SubscriptionID) @@ -446,6 +539,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar } if err := t.resolvable.Resolve(resolveCtx.ctx, sub.resolve.Response.Data, sub.resolve.Response.Fetches, sub.writer); err != nil { + r.resolveArenaPool.Release(resolveArena) r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:resolve:failed:%d\n", sub.id.SubscriptionID) @@ -456,6 +550,8 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar return } + r.resolveArenaPool.Release(resolveArena) + if err := sub.writer.Flush(); err != nil { // If flush fails (e.g. client disconnected), remove the subscription. _ = r.AsyncUnsubscribeSubscription(sub.id) @@ -656,9 +752,9 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) fmt.Printf("resolver:trigger:start:%d\n", triggerID) } if asyncDataSource != nil { - err = asyncDataSource.AsyncStart(cloneCtx, triggerID, add.input, updater) + err = asyncDataSource.AsyncStart(cloneCtx, triggerID, add.headers, add.input, updater) } else { - err = add.resolve.Trigger.Source.Start(cloneCtx, add.input, updater) + err = add.resolve.Trigger.Source.Start(cloneCtx, add.headers, add.input, updater) } if err != nil { if r.options.Debug { @@ -1001,6 +1097,24 @@ func (r *Resolver) AsyncUnsubscribeClient(connectionID int64) error { return nil } +// prepareTrigger safely gets the headers for the trigger Subgraph and computes the hash across headers and input +// the generated has is the unique triggerID +// the headers must be forwarded to the DataSource to create the trigger +func (r *Resolver) prepareTrigger(ctx *Context, sourceName string, input []byte) (headers http.Header, triggerID uint64) { + if ctx.SubgraphHeadersBuilder != nil { + header, headerHash := ctx.SubgraphHeadersBuilder.HeadersForSubgraph(sourceName) + keyGen := pool.Hash64.Get() + _, _ = keyGen.Write(input) + var b [8]byte + binary.LittleEndian.PutUint64(b[:], headerHash) + _, _ = keyGen.Write(b[:]) + triggerID = keyGen.Sum64() + pool.Hash64.Put(keyGen) + return header, triggerID + } + return nil, 0 +} + func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQLSubscription, writer SubscriptionResponseWriter) error { if subscription.Trigger.Source == nil { return errors.New("no data source found") @@ -1014,7 +1128,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ // If SkipLoader is enabled, we skip retrieving actual data. For example, this is useful when requesting a query plan. // By returning early, we avoid starting a subscription and resolve with empty data instead. if ctx.ExecutionOptions.SkipLoader { - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.subgraphRequestSingleFlight, nil) err = t.resolvable.InitSubscription(ctx, nil, subscription.Trigger.PostProcessing) if err != nil { @@ -1038,20 +1152,13 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ return nil } - xxh := pool.Hash64.Get() - defer pool.Hash64.Put(xxh) - err = subscription.Trigger.Source.UniqueRequestID(ctx, input, xxh) - if err != nil { - msg := []byte(`{"errors":[{"message":"unable to resolve"}]}`) - return writeFlushComplete(writer, msg) - } - uniqueID := xxh.Sum64() + headers, triggerID := r.prepareTrigger(ctx, subscription.Trigger.SourceName, input) id := SubscriptionIdentifier{ ConnectionID: ConnectionIDs.Inc(), SubscriptionID: 0, } if r.options.Debug { - fmt.Printf("resolver:trigger:subscribe:sync:%d:%d\n", uniqueID, id.SubscriptionID) + fmt.Printf("resolver:trigger:subscribe:sync:%d:%d\n", triggerID, id.SubscriptionID) } completed := make(chan struct{}) @@ -1061,15 +1168,17 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ // Stop processing if the resolver is shutting down return r.ctx.Err() case r.events <- subscriptionEvent{ - triggerID: uniqueID, + triggerID: triggerID, kind: subscriptionEventKindAddSubscription, addSubscription: &addSubscription{ - ctx: ctx, - input: input, - resolve: subscription, - writer: writer, - id: id, - completed: completed, + ctx: ctx, + input: input, + resolve: subscription, + writer: writer, + id: id, + completed: completed, + sourceName: subscription.Trigger.SourceName, + headers: headers, }, }: } @@ -1096,13 +1205,13 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ } if r.options.Debug { - fmt.Printf("resolver:trigger:unsubscribe:sync:%d:%d\n", uniqueID, id.SubscriptionID) + fmt.Printf("resolver:trigger:unsubscribe:sync:%d:%d\n", triggerID, id.SubscriptionID) } // Remove the subscription when the client disconnects. r.events <- subscriptionEvent{ - triggerID: uniqueID, + triggerID: triggerID, kind: subscriptionEventKindRemoveSubscription, id: id, } @@ -1123,7 +1232,7 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G // If SkipLoader is enabled, we skip retrieving actual data. For example, this is useful when requesting a query plan. // By returning early, we avoid starting a subscription and resolve with empty data instead. if ctx.ExecutionOptions.SkipLoader { - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.subgraphRequestSingleFlight, nil) err = t.resolvable.InitSubscription(ctx, nil, subscription.Trigger.PostProcessing) if err != nil { @@ -1147,13 +1256,7 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G return nil } - xxh := pool.Hash64.Get() - defer pool.Hash64.Put(xxh) - err = subscription.Trigger.Source.UniqueRequestID(ctx, input, xxh) - if err != nil { - msg := []byte(`{"errors":[{"message":"unable to resolve"}]}`) - return writeFlushComplete(writer, msg) - } + headers, triggerID := r.prepareTrigger(ctx, subscription.Trigger.SourceName, input) select { case <-r.ctx.Done(): @@ -1163,15 +1266,17 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G // Stop resolving if the client is gone return ctx.ctx.Err() case r.events <- subscriptionEvent{ - triggerID: xxh.Sum64(), + triggerID: triggerID, kind: subscriptionEventKindAddSubscription, addSubscription: &addSubscription{ - ctx: ctx, - input: input, - resolve: subscription, - writer: writer, - id: id, - completed: make(chan struct{}), + ctx: ctx, + input: input, + resolve: subscription, + writer: writer, + id: id, + completed: make(chan struct{}), + sourceName: subscription.Trigger.SourceName, + headers: headers, }, }: } @@ -1279,12 +1384,14 @@ type subscriptionEvent struct { } type addSubscription struct { - ctx *Context - input []byte - resolve *GraphQLSubscription - writer SubscriptionResponseWriter - id SubscriptionIdentifier - completed chan struct{} + ctx *Context + input []byte + resolve *GraphQLSubscription + writer SubscriptionResponseWriter + id SubscriptionIdentifier + completed chan struct{} + sourceName string + headers http.Header } type subscriptionEventKind int diff --git a/v2/pkg/engine/resolve/resolve_federation_test.go b/v2/pkg/engine/resolve/resolve_federation_test.go index 2547c6d10..1c32db689 100644 --- a/v2/pkg/engine/resolve/resolve_federation_test.go +++ b/v2/pkg/engine/resolve/resolve_federation_test.go @@ -1,9 +1,8 @@ package resolve import ( - "bytes" "context" - "io" + "net/http" "testing" "github.com/golang/mock/gomock" @@ -21,18 +20,11 @@ func mockedDS(t TestingTB, ctrl *gomock.Controller, expectedInput, responseData t.Helper() service := NewMockDataSource(ctrl) service.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - actual := string(input) - expected := expectedInput - - require.Equal(t, expected, actual) - - pair := NewBufPair() - pair.Data.WriteString(responseData) - - return writeGraphqlResponse(pair, w, false) - }).AnyTimes() + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + require.Equal(t, expectedInput, string(input)) + return []byte(responseData), nil + }).Times(1) return service } @@ -48,7 +40,7 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { DataSource: mockedDS( t, ctrl, `{"method":"POST","url":"http://user.service","body":{"query":"{user {account {__typename id info {a b}}}}"}}`, - `{"user":{"account":{"__typename":"Account","id":"1234","info":{"a":"foo","b":"bar"}}}}`, + `{"data":{"user":{"account":{"__typename":"Account","id":"1234","info":{"a":"foo","b":"bar"}}}}}`, ), Input: `{"method":"POST","url":"http://user.service","body":{"query":"{user {account {__typename id info {a b}}}}"}}`, PostProcessing: PostProcessingConfiguration{ @@ -70,7 +62,7 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { DataSource: mockedDS( t, ctrl, expectedAccountsQuery, - `{"_entities":[{"__typename":"Account","name":"John Doe","shippingInfo":{"zip":"12345"}}]}`, + `{"data":{"_entities":[{"__typename":"Account","name":"John Doe","shippingInfo":{"zip":"12345"}}]}}`, ), Input: `{"method":"POST","url":"http://account.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Account {name shippingInfo {zip}}}}","variables":{"representations":$$0$$}}}`, PostProcessing: PostProcessingConfiguration{ @@ -182,38 +174,38 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("federation with shareable", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { firstService := NewMockDataSource(ctrl) firstService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://first.service","body":{"query":"{me {details {forename middlename} __typename id}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"me": {"__typename": "User", "id": "1234", "details": {"forename": "John", "middlename": "A"}}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"me": {"__typename": "User", "id": "1234", "details": {"forename": "John", "middlename": "A"}}}}`) + return pair.Data.Bytes(), nil }) secondService := NewMockDataSource(ctrl) secondService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://second.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {details {surname}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename": "User", "details": {"surname": "Smith"}}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities": [{"__typename": "User", "details": {"surname": "Smith"}}]}}`) + return pair.Data.Bytes(), nil }) thirdService := NewMockDataSource(ctrl) thirdService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://third.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {details {age}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename": "User", "details": {"age": 21}}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities": [{"__typename": "User", "details": {"age": 21}}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -377,26 +369,26 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name infoOrAddress { ... on Info {id __typename} ... on Address {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"user":{"name":"Bill","infoOrAddress":[{"id":11,"__typename":"Info"},{"id": 55,"__typename":"Address"}]}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"user":{"name":"Bill","infoOrAddress":[{"id":11,"__typename":"Info"},{"id": 55,"__typename":"Address"}]}}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age } ... on Address { line1 }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"},{"id":55,"__typename":"Address"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"age":21,"__typename":"Info"},{"line1":"Munich","__typename":"Address"}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities":[{"age":21,"__typename":"Info"},{"line1":"Munich","__typename":"Address"}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -530,19 +522,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name infoOrAddress { ... on Info {id __typename} ... on Address {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"user":{"name":"Bill","infoOrAddress":[{"id":11,"__typename":"Whatever"},{"id": 55,"__typename":"Whatever"}]}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"user":{"name":"Bill","infoOrAddress":[{"id":11,"__typename":"Whatever"},{"id": 55,"__typename":"Whatever"}]}}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). + Load(gomock.Any(), gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -675,26 +667,26 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching on a field", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":{"id":12,"__typename":"Info"}},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":{"id":12,"__typename":"Info"}},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"},{"id":12,"__typename":"Info"},{"id":13,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"age":21,"__typename":"Info"},{"age":22,"__typename":"Info"},{"age":23,"__typename":"Info"}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities":[{"age":21,"__typename":"Info"},{"age":22,"__typename":"Info"},{"age":23,"__typename":"Info"}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -819,26 +811,26 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching with duplicates", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":{"id":11,"__typename":"Info"}},{"name":"Jane","info":{"id":11,"__typename":"Info"}}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":{"id":11,"__typename":"Info"}},{"name":"Jane","info":{"id":11,"__typename":"Info"}}]}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"age":77,"__typename":"Info"}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities":[{"age":77,"__typename":"Info"}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -960,26 +952,26 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching with null entry", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":null},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":null},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"},{"id":13,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"age":21,"__typename":"Info"},{"age":23,"__typename":"Info"}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities":[{"age":21,"__typename":"Info"},{"age":23,"__typename":"Info"}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -1105,19 +1097,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching with all null entries", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"users":[{"name":"Bill","info":null},{"name":"John","info":null},{"name":"Jane","info":null}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"users":[{"name":"Bill","info":null},{"name":"John","info":null},{"name":"Jane","info":null}]}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). + Load(gomock.Any(), gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -1243,27 +1235,27 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching with render error", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() // render error - first item id is boolean - pair.Data.WriteString(`{"users":[{"name":"Bill","info":{"id":true,"__typename":"Info"}},{"name":"John","info":{"id":12,"__typename":"Info"}},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"users":[{"name":"Bill","info":{"id":true,"__typename":"Info"}},{"name":"John","info":{"id":12,"__typename":"Info"}},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":12,"__typename":"Info"},{"id":13,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"age":21,"__typename":"Info"},{"age":22,"__typename":"Info"}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities":[{"age":21,"__typename":"Info"},{"age":22,"__typename":"Info"}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -1390,26 +1382,26 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("all data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"user":{"name":"Bill","info":{"id":11,"__typename":"Info"}}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"user":{"name":"Bill","info":{"id":11,"__typename":"Info"}}}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"age":21,"__typename":"Info"}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities":[{"age":21,"__typename":"Info"}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -1524,19 +1516,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("null info data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"user":{"name":"Bill","info":null}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"user":{"name":"Bill","info":null}}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). + Load(gomock.Any(), gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -1652,19 +1644,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("wrong type data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"user":{"name":"Bill","info":{"id":false,"__typename":"Info"}}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"user":{"name":"Bill","info":{"id":false,"__typename":"Info"}}}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). + Load(gomock.Any(), gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -1780,19 +1772,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("not matching type data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"user":{"name":"Bill","info":{"id":1,"__typename":"Whatever"}}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"user":{"name":"Bill","info":{"id":1,"__typename":"Whatever"}}}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). + Load(gomock.Any(), gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -1912,19 +1904,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { user := mockedDS(t, ctrl, `{"method":"POST","url":"http://user.service","body":{"query":"{user {account {address {__typename id line1 line2}}}}"}}`, - `{"user":{"account":{"address":{"__typename":"Address","id":"address-1","line1":"line1","line2":"line2"}}}}`) + `{"data":{"user":{"account":{"address":{"__typename":"Address","id":"address-1","line1":"line1","line2":"line2"}}}}}`) addressEnricher := mockedDS(t, ctrl, `{"method":"POST","url":"http://address-enricher.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Address {country city}}}","variables":{"representations":[{"__typename":"Address","id":"address-1"}]}}}`, - `{"__typename":"Address","country":"country-1","city":"city-1"}`) + `{"data":{"__typename":"Address","country":"country-1","city":"city-1"}}`) address := mockedDS(t, ctrl, `{"method":"POST","url":"http://address.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Address {line3(test: "BOOM") zip}}}","variables":{"representations":[{"__typename":"Address","id":"address-1","country":"country-1","city":"city-1"}]}}}`, - `{"__typename": "Address", "line3": "line3-1", "zip": "zip-1"}`) + `{"data":{"__typename": "Address", "line3": "line3-1", "zip": "zip-1"}}`) account := mockedDS(t, ctrl, `{"method":"POST","url":"http://account.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Address {fullAddress}}}","variables":{"representations":[{"__typename":"Address","id":"address-1","line1":"line1","line2":"line2","line3":"line3-1","zip":"zip-1"}]}}}`, - `{"__typename":"Address","fullAddress":"line1 line2 line3-1 city-1 country-1 zip-1"}`) + `{"data":{"__typename":"Address","fullAddress":"line1 line2 line3-1 city-1 country-1 zip-1"}}`) return &GraphQLResponse{ Fetches: Sequence( @@ -2152,19 +2144,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`) + `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`) reviewsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}`) + `{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}}`) stockService := mockedDS(t, ctrl, `{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}`) + `{"data":{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}}`) usersService := mockedDS(t, ctrl, `{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"2"}]}}}`, - `{"_entities":[{"name":"user-1"},{"name":"user-2"}]}`) + `{"data":{"_entities":[{"name":"user-1"},{"name":"user-2"}]}}`) return &GraphQLResponse{ Fetches: Sequence( @@ -2424,19 +2416,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"}]}`) + `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"}]}}`) reviewsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"}]}}}`, - `{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}}]}]}`) + `{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}}]}]}}`) stockService := mockedDS(t, ctrl, `{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"}]}}}`, - `{"_entities":[{"stock":8}]}`) + `{"data":{"_entities":[{"stock":8}]}}`) usersService := mockedDS(t, ctrl, `{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"}]}}}`, - `{"_entities":[{"name":"user-1"}]}`) + `{"data":{"_entities":[{"name":"user-1"}]}}`) return &GraphQLResponse{ Fetches: Sequence( @@ -2696,11 +2688,11 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { accountsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://accounts","body":{"query":"{accounts{__typename ... on User {__typename id} ... on Moderator {__typename moderatorID} ... on Admin {__typename adminID}}}"}}`, - `{"accounts":[{"__typename":"User","id":"3"},{"__typename":"Admin","adminID":"2"},{"__typename":"Moderator","moderatorID":"1"}]}`) + `{"data":{"accounts":[{"__typename":"User","id":"3"},{"__typename":"Admin","adminID":"2"},{"__typename":"Moderator","moderatorID":"1"}]}}`) namesService := mockedDS(t, ctrl, `{"method":"POST","url":"http://names","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name} ... on Moderator {subject} ... on Admin {type}}}","variables":{"representations":[{"__typename":"User","id":"3"},{"__typename":"Admin","adminID":"2"},{"__typename":"Moderator","moderatorID":"1"}]}}}`, - `{"_entities":[{"__typename":"User","name":"User"},{"__typename":"Admin","type":"super"},{"__typename":"Moderator","subject":"posts"}]}`) + `{"data":{"_entities":[{"__typename":"User","name":"User"},{"__typename":"Admin","type":"super"},{"__typename":"Moderator","subject":"posts"}]}}`) return &GraphQLResponse{ Fetches: Sequence( @@ -2836,11 +2828,11 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { accountsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://accounts","body":{"query":"{accounts {__typename ... on User {some {__typename id}} ... on Admin {some {__typename id}}}}"}}`, - `{"accounts":[{"__typename":"User","some":{"__typename":"User","id":"1"}},{"__typename":"Admin","some":{"__typename":"User","id":"2"}},{"__typename":"User","some":{"__typename":"User","id":"3"}}]}`) + `{"data":{"accounts":[{"__typename":"User","some":{"__typename":"User","id":"1"}},{"__typename":"Admin","some":{"__typename":"User","id":"2"}},{"__typename":"User","some":{"__typename":"User","id":"3"}}]}}`) namesService := mockedDS(t, ctrl, `{"method":"POST","url":"http://names","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {__typename title}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"3"}]}}}`, - `{"_entities":[{"__typename":"User","title":"User1"},{"__typename":"User","title":"User3"}]}`) + `{"data":{"_entities":[{"__typename":"User","title":"User1"},{"__typename":"User","title":"User3"}]}}`) return &GraphQLResponse{ Fetches: Sequence( diff --git a/v2/pkg/engine/resolve/resolve_mock_test.go b/v2/pkg/engine/resolve/resolve_mock_test.go index 3f72cc3d8..a64b7dd83 100644 --- a/v2/pkg/engine/resolve/resolve_mock_test.go +++ b/v2/pkg/engine/resolve/resolve_mock_test.go @@ -5,8 +5,8 @@ package resolve import ( - bytes "bytes" context "context" + http "net/http" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -37,11 +37,12 @@ func (m *MockDataSource) EXPECT() *MockDataSourceMockRecorder { } // Load mocks base method. -func (m *MockDataSource) Load(arg0 context.Context, arg1 []byte, arg2 *bytes.Buffer) error { +func (m *MockDataSource) Load(arg0 context.Context, arg1 http.Header, arg2 []byte) ([]byte, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Load", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 } // Load indicates an expected call of Load. @@ -51,11 +52,12 @@ func (mr *MockDataSourceMockRecorder) Load(arg0, arg1, arg2 interface{}) *gomock } // LoadWithFiles mocks base method. -func (m *MockDataSource) LoadWithFiles(arg0 context.Context, arg1 []byte, arg2 []*httpclient.FileUpload, arg3 *bytes.Buffer) error { +func (m *MockDataSource) LoadWithFiles(arg0 context.Context, arg1 http.Header, arg2 []byte, arg3 []*httpclient.FileUpload) ([]byte, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LoadWithFiles", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 } // LoadWithFiles indicates an expected call of LoadWithFiles. diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 8e15ff98a..112776037 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -13,7 +13,6 @@ import ( "testing" "time" - "github.com/cespare/xxhash/v2" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -32,7 +31,7 @@ type _fakeDataSource struct { artificialLatency time.Duration } -func (f *_fakeDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { +func (f *_fakeDataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { if f.artificialLatency != 0 { time.Sleep(f.artificialLatency) } @@ -41,11 +40,10 @@ func (f *_fakeDataSource) Load(ctx context.Context, input []byte, out *bytes.Buf require.Equal(f.t, string(f.input), string(input), "input mismatch") } } - _, err = out.Write(f.data) - return + return f.data, nil } -func (f *_fakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { +func (f *_fakeDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { if f.artificialLatency != 0 { time.Sleep(f.artificialLatency) } @@ -54,8 +52,7 @@ func (f *_fakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files require.Equal(f.t, string(f.input), string(input), "input mismatch") } } - _, err = out.Write(f.data) - return + return f.data, nil } func FakeDataSource(data string) *_fakeDataSource { @@ -351,12 +348,11 @@ func TestResolver_ResolveNode(t *testing.T) { t.Run("fetch with context variable resolver", testFn(true, func(t *testing.T, ctrl *gomock.Controller) (response *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), []byte(`{"id":1}`), gomock.AssignableToTypeOf(&bytes.Buffer{})). - Do(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { - _, err = w.Write([]byte(`{"name":"Jens"}`)) - return + Load(gomock.Any(), gomock.Any(), []byte(`{"id":1}`)). + Do(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"name":"Jens"}`), nil }). - Return(nil) + Return([]byte(`{"name":"Jens"}`), nil) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ FetchConfiguration: FetchConfiguration{ @@ -1802,11 +1798,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with simple error without datasource ID", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ @@ -1834,11 +1828,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with simple error without datasource ID no subgraph error forwarding", testFnNoSubgraphErrorForwarding(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ @@ -1866,11 +1858,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with simple error", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ @@ -1902,11 +1892,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with simple error in pass through Subgraph Error Mode", testFnSubgraphErrorsPassthrough(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -1938,10 +1926,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with pass through mode and omit custom fields", testFnSubgraphErrorsPassthroughAndOmitCustomFields(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) error { - _, err := w.Write([]byte(`{"errors":[{"message":"errorMessage","longMessage":"This is a long message","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}],"data":{"name":null}}`)) - return err + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","longMessage":"This is a long message","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}],"data":{"name":null}}`), nil }) return &GraphQLResponse{ Info: &GraphQLResponseInfo{ @@ -1976,9 +1963,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with returned err (with DataSourceID)", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - return &net.AddrError{} + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return nil, &net.AddrError{} }) return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ @@ -2010,9 +1997,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with returned err (no DataSourceID)", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - return &net.AddrError{} + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return nil, &net.AddrError{} }) return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ @@ -2040,9 +2027,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with returned err and non-nullable root field", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - return &net.AddrError{} + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return nil, &net.AddrError{} }) return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ @@ -2218,14 +2205,10 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with two Errors", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - Do(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage1"), nil, nil, nil) - pair.WriteErr([]byte("errorMessage2"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) - }). - Return(nil) + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage1"},{"message":"errorMessage2"}]}`), nil + }).Times(1) return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ FetchConfiguration: FetchConfiguration{ @@ -2578,39 +2561,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("complex GraphQL Server plan", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { serviceOne := NewMockDataSource(ctrl) serviceOne.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"url":"https://service.one","body":{"query":"query($firstArg: String, $thirdArg: Int){serviceOne(serviceOneArg: $firstArg){fieldOne} anotherServiceOne(anotherServiceOneArg: $thirdArg){fieldOne} reusingServiceOne(reusingServiceOneArg: $firstArg){fieldOne}}","variables":{"thirdArg":123,"firstArg":"firstArgValue"}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"serviceOne":{"fieldOne":"fieldOneValue"},"anotherServiceOne":{"fieldOne":"anotherFieldOneValue"},"reusingServiceOne":{"fieldOne":"reUsingFieldOneValue"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"serviceOne":{"fieldOne":"fieldOneValue"},"anotherServiceOne":{"fieldOne":"anotherFieldOneValue"},"reusingServiceOne":{"fieldOne":"reUsingFieldOneValue"}}}`), nil }) serviceTwo := NewMockDataSource(ctrl) serviceTwo.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"url":"https://service.two","body":{"query":"query($secondArg: Boolean, $fourthArg: Float){serviceTwo(serviceTwoArg: $secondArg){fieldTwo} secondServiceTwo(secondServiceTwoArg: $fourthArg){fieldTwo}}","variables":{"fourthArg":12.34,"secondArg":true}}}` assert.Equal(t, expected, actual) - - pair := NewBufPair() - pair.Data.WriteString(`{"serviceTwo":{"fieldTwo":"fieldTwoValue"},"secondServiceTwo":{"fieldTwo":"secondFieldTwoValue"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"serviceTwo":{"fieldTwo":"fieldTwoValue"},"secondServiceTwo":{"fieldTwo":"secondFieldTwoValue"}}}`), nil }) nestedServiceOne := NewMockDataSource(ctrl) nestedServiceOne.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"url":"https://service.one","body":{"query":"{serviceOne {fieldOne}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"serviceOne":{"fieldOne":"fieldOneValue"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"serviceOne":{"fieldOne":"fieldOneValue"}}}`), nil }) return &GraphQLResponse{ @@ -2817,259 +2793,35 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, Context{ctx: context.Background(), Variables: astjson.MustParseBytes([]byte(`{"firstArg":"firstArgValue","thirdArg":123,"secondArg": true, "fourthArg": 12.34}`))}, `{"data":{"serviceOne":{"fieldOne":"fieldOneValue"},"serviceTwo":{"fieldTwo":"fieldTwoValue","serviceOneResponse":{"fieldOne":"fieldOneValue"}},"anotherServiceOne":{"fieldOne":"anotherFieldOneValue"},"secondServiceTwo":{"fieldTwo":"secondFieldTwoValue"},"reusingServiceOne":{"fieldOne":"reUsingFieldOneValue"}}}` })) t.Run("federation", func(t *testing.T) { - t.Run("simple", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { - - userService := NewMockDataSource(ctrl) - userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - actual := string(input) - expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` - assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename":"User"}}`) - return writeGraphqlResponse(pair, w, false) - }) - - reviewsService := NewMockDataSource(ctrl) - reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - actual := string(input) - // {"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":["id":"1234","__typename":"User"]}}} - expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` - assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) - }) - - var productServiceCallCount atomic.Int64 - - productService := NewMockDataSource(ctrl) - productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - Do(func(ctx context.Context, input []byte, w io.Writer) (err error) { - actual := string(input) - productServiceCallCount.Add(1) - switch actual { - case `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"}]}}}`: - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"name": "Furby"}]}`) - return writeGraphqlResponse(pair, w, false) - case `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-2","__typename":"Product"}]}}}`: - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"name": "Trilby"}]}`) - return writeGraphqlResponse(pair, w, false) - default: - t.Fatalf("unexpected request: %s", actual) - } - return - }). - Return(nil).Times(2) - - return &GraphQLResponse{ - Fetches: Sequence( - SingleWithPath(&SingleFetch{ - InputTemplate: InputTemplate{ - Segments: []TemplateSegment{ - { - Data: []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}`), - SegmentType: StaticSegmentType, - }, - }, - }, - FetchConfiguration: FetchConfiguration{ - DataSource: userService, - PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data"}, - }, - }, - }, "query"), - SingleWithPath(&SingleFetch{ - InputTemplate: InputTemplate{ - Segments: []TemplateSegment{ - { - Data: []byte(`{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[`), - SegmentType: StaticSegmentType, - }, - { - SegmentType: VariableSegmentType, - VariableKind: ResolvableObjectVariableKind, - Renderer: NewGraphQLVariableResolveRenderer(&Object{ - Fields: []*Field{ - { - Name: []byte("id"), - Value: &String{ - Path: []string{"id"}, - }, - }, - { - Name: []byte("__typename"), - Value: &String{ - Path: []string{"__typename"}, - }, - }, - }, - }), - }, - { - Data: []byte(`]}}}`), - SegmentType: StaticSegmentType, - }, - }, - }, - FetchConfiguration: FetchConfiguration{ - DataSource: reviewsService, - PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities", "0"}, - }, - }, - }, "query.me", ObjectPath("me")), - SingleWithPath(&ParallelListItemFetch{ - Fetch: &SingleFetch{ - FetchConfiguration: FetchConfiguration{ - DataSource: productService, - PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities", "0"}, - }, - }, - InputTemplate: InputTemplate{ - Segments: []TemplateSegment{ - { - Data: []byte(`{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[`), - SegmentType: StaticSegmentType, - }, - { - SegmentType: VariableSegmentType, - VariableKind: ResolvableObjectVariableKind, - Renderer: NewGraphQLVariableResolveRenderer(&Object{ - Fields: []*Field{ - { - Name: []byte("upc"), - Value: &String{ - Path: []string{"upc"}, - }, - }, - { - Name: []byte("__typename"), - Value: &String{ - Path: []string{"__typename"}, - }, - }, - }, - }), - }, - { - Data: []byte(`]}}}`), - SegmentType: StaticSegmentType, - }, - }, - }, - }, - }, "query.me.reviews.@.product", ObjectPath("me"), ArrayPath("reviews"), ObjectPath("product")), - ), - Data: &Object{ - Fields: []*Field{ - { - Name: []byte("me"), - Value: &Object{ - Path: []string{"me"}, - Nullable: true, - Fields: []*Field{ - { - Name: []byte("id"), - Value: &String{ - Path: []string{"id"}, - }, - }, - { - Name: []byte("username"), - Value: &String{ - Path: []string{"username"}, - }, - }, - { - - Name: []byte("reviews"), - Value: &Array{ - Path: []string{"reviews"}, - Nullable: true, - Item: &Object{ - Nullable: true, - Fields: []*Field{ - { - Name: []byte("body"), - Value: &String{ - Path: []string{"body"}, - }, - }, - { - Name: []byte("product"), - Value: &Object{ - Path: []string{"product"}, - Fields: []*Field{ - { - Name: []byte("upc"), - Value: &String{ - Path: []string{"upc"}, - }, - }, - { - Name: []byte("name"), - Value: &String{ - Path: []string{"name"}, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, Context{ctx: context.Background(), Variables: nil}, `{"data":{"me":{"id":"1234","username":"Me","reviews":[{"body":"A highly effective form of birth control.","product":{"upc":"top-1","name":"Furby"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","name":"Trilby"}}]}}}` - })) t.Run("federation with batch", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename": "User"}}}`), nil }) reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}}`), nil }) productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}}`), nil }) return &GraphQLResponse{ @@ -3241,38 +2993,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("federation with merge paths", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename": "User"}}}`), nil }) reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}}`), nil }) productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}}`), nil }) return &GraphQLResponse{ @@ -3445,45 +3191,39 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("federation with null response", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename": "User"}}}`), nil }) reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"reviews": [ + return []byte(`{"data":{"_entities":[{"reviews": [ {"body": "foo","product": {"upc": "top-1","__typename": "Product"}}, {"body": "bar","product": {"upc": "top-2","__typename": "Product"}}, {"body": "baz","product": null}, {"body": "bat","product": {"upc": "top-4","__typename": "Product"}}, {"body": "bal","product": {"upc": "top-5","__typename": "Product"}}, {"body": "ban","product": {"upc": "top-6","__typename": "Product"}} -]}]}`) - return writeGraphqlResponse(pair, w, false) +]}]}}`), nil }) productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"},{"upc":"top-4","__typename":"Product"},{"upc":"top-5","__typename":"Product"},{"upc":"top-6","__typename":"Product"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"name":"Trilby"},{"name":"Fedora"},{"name":"Boater"},{"name":"Top Hat"},{"name":"Bowler"}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"name":"Trilby"},{"name":"Fedora"},{"name":"Boater"},{"name":"Top Hat"},{"name":"Bowler"}]}}`), nil }) return &GraphQLResponse{ @@ -3678,38 +3418,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me": {"id": "1234","username": "Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me": {"id": "1234","username": "Me","__typename": "User"}}}`), nil }) reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}}`), nil }) productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ @@ -3871,38 +3605,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me": {"id": "1234","username": "Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me": {"id": "1234","username": "Me","__typename": "User"}}}`), nil }) reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}}`), nil }) productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ @@ -4061,38 +3789,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("federation with optional variable", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:8080/query","body":{"query":"{me {id}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","__typename":"User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","__typename":"User"}}}`), nil }) employeeService := NewMockDataSource(ctrl) employeeService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:8081/query","body":{"query":"query($representations: [_Any!]!, $companyId: ID!){_entities(representations: $representations){... on User {employment(companyId: $companyId){id}}}}","variables":{"companyId":"abc123","representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"employment":{"id":"xyz987"}}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"employment":{"id":"xyz987"}}]}}`), nil }) timeService := NewMockDataSource(ctrl) timeService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:8082/query","body":{"query":"query($representations: [_Any!]!, $date: LocalTime){_entities(representations: $representations){... on Employee {times(date: $date){id employee {id} start end}}}}","variables":{"date":null,"representations":[{"id":"xyz987","__typename":"Employee"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"times":[{"id": "t1","employee":{"id":"xyz987"},"start":"2022-11-02T08:00:00","end":"2022-11-02T12:00:00"}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"times":[{"id": "t1","employee":{"id":"xyz987"},"start":"2022-11-02T08:00:00","end":"2022-11-02T12:00:00"}]}]}}`), nil }) return &GraphQLResponse{ @@ -4263,62 +3985,517 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }) } -func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { - options := apolloCompatibilityOptions{ - valueCompletion: true, - suppressFetchErrors: true, +// testFnArena is a helper function for testing ArenaResolveGraphQLResponse +func testFnArena(fn func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string)) func(t *testing.T) { + return func(t *testing.T) { + t.Helper() + + ctrl := gomock.NewController(t) + rCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + r := newResolver(rCtx) + node, ctx, expectedOutput := fn(t, ctrl) + + if node.Info == nil { + node.Info = &GraphQLResponseInfo{ + OperationType: ast.OperationTypeQuery, + } + } + + if t.Skipped() { + return + } + + buf := &bytes.Buffer{} + _, err := r.ArenaResolveGraphQLResponse(&ctx, node, buf) + assert.NoError(t, err) + assert.Equal(t, expectedOutput, buf.String()) + ctrl.Finish() } - t.Run("simple fetch with fetch error suppression - empty response", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { - mockDataSource := NewMockDataSource(ctrl) - mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - _, _ = w.Write([]byte("{}")) - return - }) +} + +func TestResolver_ArenaResolveGraphQLResponse(t *testing.T) { + + t.Run("empty graphql response", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ - Fetches: SingleWithPath(&SingleFetch{ - InputTemplate: InputTemplate{ - Segments: []TemplateSegment{ - { - Data: []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{query{name}}"}}`), - SegmentType: StaticSegmentType, - }, - }, - }, - FetchConfiguration: FetchConfiguration{ - DataSource: mockDataSource, - PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data"}, - SelectResponseErrorsPath: []string{"errors"}, - }, - }, - }, "query"), Data: &Object{ - Fields: []*Field{ - { - Name: []byte("name"), - Value: &String{ - Path: []string{"name"}, - }, - }, - }, + Nullable: true, }, - }, Context{ctx: context.Background()}, `{"data":null,"extensions":{"valueCompletion":[{"message":"Cannot return null for non-nullable field Query.name.","path":["name"],"extensions":{"code":"INVALID_GRAPHQL"}}]}}` - }, &options)) + }, Context{ctx: context.Background()}, `{"data":{}}` + })) - t.Run("simple fetch with fetch error suppression - response with error", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { - mockDataSource := NewMockDataSource(ctrl) - mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - _, _ = w.Write([]byte(`{"errors":[{"message":"Cannot query field 'name' on type 'Query'"}]}`)) - return - }) + t.Run("simple data source", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ - Fetches: SingleWithPath(&SingleFetch{ - InputTemplate: InputTemplate{ - Segments: []TemplateSegment{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"id":"1","name":"Jens","registered":true}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("registered"), + Value: &Boolean{ + Path: []string{"registered"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"user":{"id":"1","name":"Jens","registered":true}}}` + })) + + t.Run("array of strings", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"strings": ["Alex", "true", "123"]}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("strings"), + Value: &Array{ + Path: []string{"strings"}, + Item: &String{ + Nullable: false, + }, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"strings":["Alex","true","123"]}}` + })) + + t.Run("array of objects", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"friends":[{"id":1,"name":"Alex"},{"id":2,"name":"Patric"}]}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("friends"), + Value: &Array{ + Path: []string{"friends"}, + Item: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Integer{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"friends":[{"id":1,"name":"Alex"},{"id":2,"name":"Patric"}]}}` + })) + + t.Run("nested objects", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"id":"1","name":"Jens","pet":{"name":"Barky","kind":"Dog"}}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("pet"), + Value: &Object{ + Path: []string{"pet"}, + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("kind"), + Value: &String{ + Path: []string{"kind"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"user":{"id":"1","name":"Jens","pet":{"name":"Barky","kind":"Dog"}}}}` + })) + + t.Run("scalar types", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"int": 12345, "float": 3.5, "str":"value", "bool": true}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("int"), + Value: &Integer{ + Path: []string{"int"}, + Nullable: false, + }, + }, + { + Name: []byte("float"), + Value: &Float{ + Path: []string{"float"}, + Nullable: false, + }, + }, + { + Name: []byte("str"), + Value: &String{ + Path: []string{"str"}, + Nullable: false, + }, + }, + { + Name: []byte("bool"), + Value: &Boolean{ + Path: []string{"bool"}, + Nullable: false, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"int":12345,"float":3.5,"str":"value","bool":true}}` + })) + + t.Run("null field", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("foo"), + Value: &Null{}, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"foo":null}}` + })) + + t.Run("__typename field", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"id":1,"name":"Jannik","__typename":"User"}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Integer{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + Nullable: false, + IsTypeName: true, + }, + }, + }, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"user":{"id":1,"name":"Jannik","__typename":"User"}}}` + })) + + t.Run("multiple fetches", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"user1":{"id":1,"name":"User1"},"user2":{"id":2,"name":"User2"}}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("user1"), + Value: &Object{ + Path: []string{"user1"}, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Integer{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + { + Name: []byte("user2"), + Value: &Object{ + Path: []string{"user2"}, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Integer{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"user1":{"id":1,"name":"User1"},"user2":{"id":2,"name":"User2"}}}` + })) + + t.Run("with variables", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + mockDataSource := NewMockDataSource(ctrl) + mockDataSource.EXPECT(). + Load(gomock.Any(), gomock.Any(), []byte(`{"id":1}`)). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"name":"Jens"}`), nil + }) + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: mockDataSource}, + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"id":`), + SegmentType: StaticSegmentType, + }, + { + Data: []byte(`{{.arguments.id}}`), + SegmentType: VariableSegmentType, + VariableKind: ContextVariableKind, + VariableSourcePath: []string{"id"}, + Renderer: NewPlainVariableRenderer(), + }, + { + Data: []byte(`}`), + SegmentType: StaticSegmentType, + }, + }, + }, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, Context{ctx: context.Background(), Variables: astjson.MustParseBytes([]byte(`{"id":1}`))}, `{"data":{"name":"Jens"}}` + })) + + t.Run("error handling", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + mockDataSource := NewMockDataSource(ctrl) + mockDataSource.EXPECT(). + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return nil, errors.New("data source error") + }) + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: mockDataSource}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"errors":[{"message":"Failed to fetch from Subgraph."}],"data":null}` + })) + + t.Run("bigint handling", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"n": 12345, "ns_small": "12346", "ns_big": "1152921504606846976"}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("n"), + Value: &BigInt{ + Path: []string{"n"}, + Nullable: false, + }, + }, + { + Name: []byte("ns_small"), + Value: &BigInt{ + Path: []string{"ns_small"}, + Nullable: false, + }, + }, + { + Name: []byte("ns_big"), + Value: &BigInt{ + Path: []string{"ns_big"}, + Nullable: false, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"n":12345,"ns_small":"12346","ns_big":"1152921504606846976"}}` + })) + + t.Run("skip loader", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("static"), + Value: &Null{}, + }, + }, + }, + }, Context{ctx: context.Background(), ExecutionOptions: ExecutionOptions{SkipLoader: true}}, `{"data":null}` + })) +} + +func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { + options := apolloCompatibilityOptions{ + valueCompletion: true, + suppressFetchErrors: true, + } + t.Run("simple fetch with fetch error suppression - empty response", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + mockDataSource := NewMockDataSource(ctrl) + mockDataSource.EXPECT(). + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte("{}"), nil + }) + return &GraphQLResponse{ + Fetches: SingleWithPath(&SingleFetch{ + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{query{name}}"}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + FetchConfiguration: FetchConfiguration{ + DataSource: mockDataSource, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + SelectResponseErrorsPath: []string{"errors"}, + }, + }, + }, "query"), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":null,"extensions":{"valueCompletion":[{"message":"Cannot return null for non-nullable field Query.name.","path":["name"],"extensions":{"code":"INVALID_GRAPHQL"}}]}}` + }, &options)) + + t.Run("simple fetch with fetch error suppression - response with error", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + mockDataSource := NewMockDataSource(ctrl) + mockDataSource.EXPECT(). + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"Cannot query field 'name' on type 'Query'"}]}`), nil + }) + return &GraphQLResponse{ + Fetches: SingleWithPath(&SingleFetch{ + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ { Data: []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{query{name}}"}}`), SegmentType: StaticSegmentType, @@ -4349,38 +4526,32 @@ func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { t.Run("complex fetch with fetch error suppression", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me": {"id": "1234","username": "Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me": {"id": "1234","username": "Me","__typename": "User"}}}`), nil }) reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}}`), nil }) productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ @@ -4566,14 +4737,12 @@ func TestResolver_WithHeader(t *testing.T) { ctrl := gomock.NewController(t) fakeService := NewMockDataSource(ctrl) fakeService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - Do(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) assert.Equal(t, "foo", actual) - _, err = w.Write([]byte(`{"bar":"baz"}`)) - return - }). - Return(nil) + return []byte(`{"bar":"baz"}`), nil + }) out := &bytes.Buffer{} res := &GraphQLResponse{ @@ -4639,14 +4808,12 @@ func TestResolver_WithVariableRemapping(t *testing.T) { ctrl := gomock.NewController(t) fakeService := NewMockDataSource(ctrl) fakeService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - Do(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) assert.Equal(t, tc.expectedOutput, actual) - _, err = w.Write([]byte(`{"bar":"baz"}`)) - return - }). - Return(nil) + return []byte(`{"bar":"baz"}`), nil + }) out := &bytes.Buffer{} res := &GraphQLResponse{ @@ -4827,16 +4994,7 @@ func (f *_fakeStream) AwaitIsDone(t *testing.T, timeout time.Duration) { } } -func (f *_fakeStream) UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) { - _, err = fmt.Fprint(xxh, fakeStreamRequestId.Add(1)) - if err != nil { - return - } - _, err = xxh.Write(input) - return -} - -func (f *_fakeStream) Start(ctx *Context, input []byte, updater SubscriptionUpdater) error { +func (f *_fakeStream) Start(ctx *Context, headers http.Header, input []byte, updater SubscriptionUpdater) error { if f.onStart != nil { f.onStart(input) } @@ -5909,50 +6067,353 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { Data: []byte(`{"method":"POST","url":"http://localhost:4000"}`), }, }, - }, - }, - Filter: &SubscriptionFilter{ - In: &SubscriptionFieldFilter{ - FieldPath: []string{"id"}, - Values: []InputTemplate{ - { + }, + }, + Filter: &SubscriptionFilter{ + In: &SubscriptionFieldFilter{ + FieldPath: []string{"id"}, + Values: []InputTemplate{ + { + Segments: []TemplateSegment{ + { + SegmentType: StaticSegmentType, + Data: []byte(`x.`), + }, + { + SegmentType: VariableSegmentType, + VariableKind: ContextVariableKind, + VariableSourcePath: []string{"a"}, + Renderer: NewPlainVariableRenderer(), + }, + { + SegmentType: StaticSegmentType, + Data: []byte(`.`), + }, + { + SegmentType: VariableSegmentType, + VariableKind: ContextVariableKind, + VariableSourcePath: []string{"b"}, + Renderer: NewPlainVariableRenderer(), + }, + }, + }, + }, + }, + }, + Response: &GraphQLResponse{ + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("oneUserByID"), + Value: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + }, + }, + }, + }, + }, + }, + }, + }, + } + + out := &SubscriptionRecorder{ + buf: &bytes.Buffer{}, + messages: []string{}, + complete: atomic.Bool{}, + } + out.complete.Store(false) + + id := SubscriptionIdentifier{ + ConnectionID: 1, + SubscriptionID: 1, + } + + resolver := newResolver(c) + + ctx := &Context{ + ctx: context.Background(), + Variables: astjson.MustParseBytes([]byte(`{"a":[1,2],"b":[3,4]}`)), + } + + err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, out, id) + assert.NoError(t, err) + out.AwaitComplete(t, defaultTimeout) + assert.Equal(t, 4, len(out.Messages())) + assert.ElementsMatch(t, []string{ + `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, + `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, + `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, + `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, + }, out.Messages()) + }) +} + +func Benchmark_NestedBatching(b *testing.B) { + rCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := newResolver(rCtx) + + productsService := fakeDataSourceWithInputCheck(b, + []byte(`{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`), + []byte(`{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`)) + stockService := fakeDataSourceWithInputCheck(b, + []byte(`{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`), + []byte(`{"data":{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}}`)) + reviewsService := fakeDataSourceWithInputCheck(b, + []byte(`{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`), + []byte(`{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}}`)) + usersService := fakeDataSourceWithInputCheck(b, + []byte(`{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"2"}]}}}`), + []byte(`{"data":{"_entities":[{"name":"user-1"},{"name":"user-2"}]}}`)) + + plan := &GraphQLResponse{ + Fetches: Sequence( + SingleWithPath(&SingleFetch{ + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + FetchConfiguration: FetchConfiguration{ + DataSource: productsService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + }, + }, + }, ""), + Parallel( + SingleWithPath(&BatchEntityFetch{ + Input: BatchInput{ + Header: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[`), + SegmentType: StaticSegmentType, + }, + }, + }, + Items: []InputTemplate{ + { + Segments: []TemplateSegment{ + { + SegmentType: VariableSegmentType, + VariableKind: ResolvableObjectVariableKind, + Renderer: NewGraphQLVariableResolveRenderer(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("upc"), + Value: &String{ + Path: []string{"upc"}, + }, + }, + }, + }), + }, + }, + }, + }, + Separator: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`,`), + SegmentType: StaticSegmentType, + }, + }, + }, + Footer: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`]}}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + }, + DataSource: reviewsService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data", "_entities"}, + }, + }, "topProducts", ArrayPath("topProducts")), + SingleWithPath(&BatchEntityFetch{ + Input: BatchInput{ + Header: InputTemplate{ Segments: []TemplateSegment{ { + Data: []byte(`{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[`), SegmentType: StaticSegmentType, - Data: []byte(`x.`), }, + }, + }, + Items: []InputTemplate{ + { + Segments: []TemplateSegment{ + { + SegmentType: VariableSegmentType, + VariableKind: ResolvableObjectVariableKind, + Renderer: NewGraphQLVariableResolveRenderer(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("upc"), + Value: &String{ + Path: []string{"upc"}, + }, + }, + }, + }), + }, + }, + }, + }, + Separator: InputTemplate{ + Segments: []TemplateSegment{ { - SegmentType: VariableSegmentType, - VariableKind: ContextVariableKind, - VariableSourcePath: []string{"a"}, - Renderer: NewPlainVariableRenderer(), + Data: []byte(`,`), + SegmentType: StaticSegmentType, }, + }, + }, + Footer: InputTemplate{ + Segments: []TemplateSegment{ { + Data: []byte(`]}}}`), SegmentType: StaticSegmentType, - Data: []byte(`.`), }, + }, + }, + }, + DataSource: stockService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data", "_entities"}, + }, + }, "topProducts", ArrayPath("topProducts")), + ), + SingleWithPath(&BatchEntityFetch{ + Input: BatchInput{ + Header: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name}}}","variables":{"representations":[`), + SegmentType: StaticSegmentType, + }, + }, + }, + Items: []InputTemplate{ + { + Segments: []TemplateSegment{ { - SegmentType: VariableSegmentType, - VariableKind: ContextVariableKind, - VariableSourcePath: []string{"b"}, - Renderer: NewPlainVariableRenderer(), + SegmentType: VariableSegmentType, + VariableKind: ResolvableObjectVariableKind, + Renderer: NewGraphQLVariableResolveRenderer(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + }, + }, + }, + }), }, }, }, }, + Separator: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`,`), + SegmentType: StaticSegmentType, + }, + }, + }, + Footer: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`]}}}`), + SegmentType: StaticSegmentType, + }, + }, + }, }, - }, - Response: &GraphQLResponse{ - Data: &Object{ - Fields: []*Field{ - { - Name: []byte("oneUserByID"), - Value: &Object{ - Fields: []*Field{ - { - Name: []byte("id"), - Value: &String{ - Path: []string{"id"}, + DataSource: usersService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data", "_entities"}, + }, + }, "topProducts.@.reviews.@.author", ArrayPath("topProducts"), ArrayPath("reviews"), ObjectPath("author")), + ), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("topProducts"), + Value: &Array{ + Path: []string{"topProducts"}, + Item: &Object{ + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + }, + }, + { + Name: []byte("stock"), + Value: &Integer{ + Path: []string{"stock"}, + }, + }, + { + Name: []byte("reviews"), + Value: &Array{ + Path: []string{"reviews"}, + Item: &Object{ + Fields: []*Field{ + { + Name: []byte("body"), + Value: &String{ + Path: []string{"body"}, + }, + }, + { + Name: []byte("author"), + Value: &Object{ + Path: []string{"author"}, + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + }, + }, + }, + }, + }, + }, }, }, }, @@ -5961,41 +6422,53 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { }, }, }, - } + }, + Info: &GraphQLResponseInfo{ + OperationType: ast.OperationTypeQuery, + }, + } - out := &SubscriptionRecorder{ - buf: &bytes.Buffer{}, - messages: []string{}, - complete: atomic.Bool{}, - } - out.complete.Store(false) + expected := []byte(`{"data":{"topProducts":[{"name":"Table","stock":8,"reviews":[{"body":"Love Table!","author":{"name":"user-1"}},{"body":"Prefer other Table.","author":{"name":"user-2"}}]},{"name":"Couch","stock":2,"reviews":[{"body":"Couch Too expensive.","author":{"name":"user-1"}}]},{"name":"Chair","stock":5,"reviews":[{"body":"Chair Could be better.","author":{"name":"user-2"}}]}]}}`) - id := SubscriptionIdentifier{ - ConnectionID: 1, - SubscriptionID: 1, - } + pool := sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(make([]byte, 0, 1024)) + }, + } - resolver := newResolver(c) + ctxPool := sync.Pool{ + New: func() interface{} { + return NewContext(context.Background()) + }, + } - ctx := &Context{ - ctx: context.Background(), - Variables: astjson.MustParseBytes([]byte(`{"a":[1,2],"b":[3,4]}`)), - } + b.ReportAllocs() + b.SetBytes(int64(len(expected))) + b.ResetTimer() - err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, out, id) - assert.NoError(t, err) - out.AwaitComplete(t, defaultTimeout) - assert.Equal(t, 4, len(out.Messages())) - assert.ElementsMatch(t, []string{ - `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, - `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, - `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, - `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, - }, out.Messages()) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + ctx := ctxPool.Get().(*Context) + buf := pool.Get().(*bytes.Buffer) + ctx.ctx = context.Background() + _, err := resolver.ResolveGraphQLResponse(ctx, plan, nil, buf) + if err != nil { + b.Fatal(err) + } + if !bytes.Equal(expected, buf.Bytes()) { + require.Equal(b, string(expected), buf.String()) + } + + buf.Reset() + pool.Put(buf) + + ctx.Free() + ctxPool.Put(ctx) + } }) } -func Benchmark_NestedBatching(b *testing.B) { +func Benchmark_NestedBatchingArena(b *testing.B) { rCtx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -6293,7 +6766,7 @@ func Benchmark_NestedBatching(b *testing.B) { ctx := ctxPool.Get().(*Context) buf := pool.Get().(*bytes.Buffer) ctx.ctx = context.Background() - _, err := resolver.ResolveGraphQLResponse(ctx, plan, nil, buf) + _, err := resolver.ArenaResolveGraphQLResponse(ctx, plan, buf) if err != nil { b.Fatal(err) } @@ -6310,7 +6783,7 @@ func Benchmark_NestedBatching(b *testing.B) { }) } -func Benchmark_NestedBatchingWithoutChecks(b *testing.B) { +func Benchmark_NoCheckNestedBatching(b *testing.B) { rCtx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/v2/pkg/engine/resolve/response.go b/v2/pkg/engine/resolve/response.go index b98f4c00f..d8af8d017 100644 --- a/v2/pkg/engine/resolve/response.go +++ b/v2/pkg/engine/resolve/response.go @@ -22,6 +22,8 @@ type GraphQLSubscriptionTrigger struct { Source SubscriptionDataSource PostProcessing PostProcessingConfiguration QueryPlan *QueryPlan + SourceName string + SourceID string } // GraphQLResponse contains an ordered tree of fetches and the response shape. @@ -41,6 +43,19 @@ type GraphQLResponse struct { DataSources []DataSourceInfo } +func (g *GraphQLResponse) SingleFlightAllowed() bool { + if g == nil { + return false + } + if g.Info == nil { + return false + } + if g.Info.OperationType == ast.OperationTypeQuery { + return true + } + return false +} + type GraphQLResponseInfo struct { OperationType ast.OperationType } diff --git a/v2/pkg/engine/resolve/subgraph_request_singleflight.go b/v2/pkg/engine/resolve/subgraph_request_singleflight.go new file mode 100644 index 000000000..85f73d742 --- /dev/null +++ b/v2/pkg/engine/resolve/subgraph_request_singleflight.go @@ -0,0 +1,196 @@ +package resolve + +import ( + "sync" + + "github.com/cespare/xxhash/v2" +) + +// SubgraphRequestSingleFlight is a sharded, goroutine safe single flight implementation to de-duplicate subgraph requests +// It's hashing the input and adds the pre-computed subgraph headers hash to avoid collisions +// In addition to single flight, it provides size hints to create right-sized buffers for subgraph requests +type SubgraphRequestSingleFlight struct { + shards []singleFlightShard + xxPool *sync.Pool +} + +type singleFlightShard struct { + items sync.Map // map[uint64]*SingleFlightItem + sizes sync.Map // map[uint64]*fetchSize +} + +const defaultSingleFlightShardCount = 4 + +// SingleFlightItem is used to communicate between leader and followers +// If an Item for a key doesn't exist, the leader creates and followers can join +type SingleFlightItem struct { + // loaded will be closed by the leader to indicate to followers when the work is done + loaded chan struct{} + // response is the shared result, it must not be modified + response []byte + // err is non nil if the leader produced an error while doing the work + err error + // sizeHint keeps track of the last 50 responses per fetchKey to give an estimate on the size + // this gives a leader a hint on how much space it should pre-allocate for buffers when fetching + // this reduces memory usage + sizeHint int + // SFKey uniquely identifies a single flight request + SFKey uint64 + // FetchKey groups similar fetches for size hinting + FetchKey uint64 +} + +// fetchSize gives an estimate of required buffer size for a given fetchKey when dividing totalBytes / count +type fetchSize struct { + mu sync.Mutex + // count is the number of fetches tracked + count int + // totalBytes is the cumulative bytes across tracked fetches + totalBytes int +} + +func NewSingleFlight(shardCount int) *SubgraphRequestSingleFlight { + if shardCount <= 0 { + shardCount = defaultSingleFlightShardCount + } + s := &SubgraphRequestSingleFlight{ + shards: make([]singleFlightShard, shardCount), + xxPool: &sync.Pool{ + New: func() any { + return xxhash.New() + }, + }, + } + return s +} + +// GetOrCreateItem returns a SingleFlightItem, which contains the single flight key (100% identical fetches), +// a fetchKey (similar fetches, collisions possible but unproblematic because it's only used for size hints), +// and an indication if it is shared or not. +// If not shared, the caller is a leader, otherwise it is a follower. +// item.sizeHint can be used to create an optimal buffer for the fetch in case of a leader. +// item.err must always be checked. +// item.response must never be mutated. +func (s *SubgraphRequestSingleFlight) GetOrCreateItem(fetchItem *FetchItem, input []byte, extraKey uint64) (item *SingleFlightItem, shared bool) { + sfKey, fetchKey := s.computeKeys(fetchItem, input, extraKey) + + // Get shard based on sfKey for items + shard := s.shardFor(sfKey) + + if existing, ok := shard.items.Load(sfKey); ok { + return existing.(*SingleFlightItem), true + } + + item = &SingleFlightItem{ + // empty chan to indicate to all followers when we're done (close) + loaded: make(chan struct{}), + SFKey: sfKey, + FetchKey: fetchKey, + } + // Read size hint from the same shard (both items and sizes use the same shard now) + if sizeValue, ok := shard.sizes.Load(fetchKey); ok { + size := sizeValue.(*fetchSize) + size.mu.Lock() + if size.count > 0 { + item.sizeHint = size.totalBytes / size.count + } + size.mu.Unlock() + } + + actual, loaded := shard.items.LoadOrStore(sfKey, item) + if loaded { + return actual.(*SingleFlightItem), true + } + return item, false +} + +// Finish is for the leader to mark the SingleFlightItem as "done" +// trigger all followers to look at the err & response of the item +// and to update the size estimates +func (s *SubgraphRequestSingleFlight) Finish(item *SingleFlightItem) { + sfKey := item.SFKey + fetchKey := item.FetchKey + close(item.loaded) + // Update sizes in the same shard as the item (using sfKey to get the shard) + shard := s.shardFor(sfKey) + + shard.items.Delete(sfKey) + + sizeValue, ok := shard.sizes.Load(fetchKey) + if !ok { + newSize := &fetchSize{} + sizeValue, _ = shard.sizes.LoadOrStore(fetchKey, newSize) + } + size := sizeValue.(*fetchSize) + size.mu.Lock() + if size.count == 0 { + size.count = 1 + size.totalBytes = len(item.response) + size.mu.Unlock() + return + } + if size.count == 50 { + size.count = 1 + size.totalBytes = size.totalBytes / 50 + } + size.count++ + size.totalBytes += len(item.response) + size.mu.Unlock() +} + +func (s *SubgraphRequestSingleFlight) shardFor(key uint64) *singleFlightShard { + idx := int(key % uint64(len(s.shards))) + return &s.shards[idx] +} + +func (s *SubgraphRequestSingleFlight) computeKeys(fetchItem *FetchItem, input []byte, extraKey uint64) (sfKey, fetchKey uint64) { + h := s.xxPool.Get().(*xxhash.Digest) + sfKey = s.computeSFKey(fetchItem, input, extraKey) + h.Reset() + fetchKey = s.computeFetchKey(fetchItem) + h.Reset() + s.xxPool.Put(h) + return sfKey, fetchKey +} + +// computeSFKey returns a key that 100% uniquely identifies a fetch with no collision. +// Two sfKey values are only the same when the fetches are 100% equal. +func (s *SubgraphRequestSingleFlight) computeSFKey(fetchItem *FetchItem, input []byte, extraKey uint64) uint64 { + h := s.xxPool.Get().(*xxhash.Digest) + if fetchItem != nil && fetchItem.Fetch != nil { + info := fetchItem.Fetch.FetchInfo() + if info != nil { + _, _ = h.WriteString(info.DataSourceID) + _, _ = h.WriteString(":") + } + } + _, _ = h.Write(input) + return h.Sum64() + extraKey // extraKey in this case is the pre-generated hash for the headers +} + +// computeFetchKey is a less robust key compared to sfKey. +// The purpose is to create a key from the DataSourceID and root fields to have less cardinality. +// The goal is to get an estimate buffer size for similar fetches; hashing headers or the body is not needed. +func (s *SubgraphRequestSingleFlight) computeFetchKey(fetchItem *FetchItem) uint64 { + h := s.xxPool.Get().(*xxhash.Digest) + defer s.xxPool.Put(h) + if fetchItem == nil || fetchItem.Fetch == nil { + return 0 + } + info := fetchItem.Fetch.FetchInfo() + if info == nil { + return 0 + } + _, _ = h.WriteString(info.DataSourceID) + _, _ = h.Write(pipe) + for i := range info.RootFields { + if i != 0 { + _, _ = h.Write(comma) + } + _, _ = h.WriteString(info.RootFields[i].TypeName) + _, _ = h.Write(dot) + _, _ = h.WriteString(info.RootFields[i].FieldName) + } + sum := h.Sum64() + return sum +} diff --git a/v2/pkg/engine/resolve/subgraph_request_singleflight_test.go b/v2/pkg/engine/resolve/subgraph_request_singleflight_test.go new file mode 100644 index 000000000..312236359 --- /dev/null +++ b/v2/pkg/engine/resolve/subgraph_request_singleflight_test.go @@ -0,0 +1,209 @@ +package resolve + +import ( + "bytes" + "fmt" + "testing" +) + +type stubFetch struct { + info *FetchInfo +} + +func (s *stubFetch) FetchKind() FetchKind { + return FetchKindSingle +} + +func (s *stubFetch) Dependencies() *FetchDependencies { + return nil +} + +func (s *stubFetch) FetchInfo() *FetchInfo { + return s.info +} + +type nilInfoFetch struct{} + +func (n *nilInfoFetch) FetchKind() FetchKind { + return FetchKindSingle +} + +func (n *nilInfoFetch) Dependencies() *FetchDependencies { + return nil +} + +func (n *nilInfoFetch) FetchInfo() *FetchInfo { + return nil +} + +func newFetchItem(info *FetchInfo) *FetchItem { + return &FetchItem{ + Fetch: &stubFetch{ + info: info, + }, + } +} + +func TestSubgraphRequestSingleFlight_LeaderFollowerSizeHint(t *testing.T) { + flight := NewSingleFlight(2) + fetchInfo := &FetchInfo{ + DataSourceID: "accounts", + RootFields: []GraphCoordinate{ + {TypeName: "Query", FieldName: "viewer"}, + }, + } + fetchItem := newFetchItem(fetchInfo) + + item, shared := flight.GetOrCreateItem(fetchItem, []byte("query { viewer { id } }"), 42) + if shared { + t.Fatalf("expected leader to be first caller") + } + if item == nil { + t.Fatalf("expected item, got nil") + } + if item.sizeHint != 0 { + t.Fatalf("expected empty size hint, got %d", item.sizeHint) + } + + follower, followerShared := flight.GetOrCreateItem(fetchItem, []byte("query { viewer { id } }"), 42) + if !followerShared { + t.Fatalf("expected second caller to be follower") + } + if follower != item { + t.Fatalf("expected follower to receive same item instance") + } + + item.response = []byte("hello") + flight.Finish(item) + + select { + case <-item.loaded: + default: + t.Fatalf("expected leader to close loaded channel") + } + + next, nextShared := flight.GetOrCreateItem(fetchItem, []byte("query { viewer { id } }"), 42) + if nextShared { + t.Fatalf("expected new leader after finish") + } + if next == item { + t.Fatalf("expected new item after finish") + } + if next.sizeHint != len("hello") { + t.Fatalf("expected size hint %d, got %d", len("hello"), next.sizeHint) + } +} + +func TestSubgraphRequestSingleFlight_SimilarFetchesShareFetchKey(t *testing.T) { + flight := NewSingleFlight(1) + fetchInfo := &FetchInfo{ + DataSourceID: "reviews", + RootFields: []GraphCoordinate{ + {TypeName: "Query", FieldName: "reviews"}, + }, + } + fetchItem := newFetchItem(fetchInfo) + + item1, shared1 := flight.GetOrCreateItem(fetchItem, []byte("body-1"), 0) + if shared1 { + t.Fatalf("expected first call to be leader") + } + item1.response = []byte("first response") + flight.Finish(item1) + + item2, shared2 := flight.GetOrCreateItem(fetchItem, []byte("body-2"), 0) + if shared2 { + t.Fatalf("expected leader after finishing previous item") + } + if item1.FetchKey != item2.FetchKey { + t.Fatalf("expected identical fetch keys for similar fetches") + } + if item1.SFKey == item2.SFKey { + t.Fatalf("expected different single-flight keys for different request bodies") + } + item2.response = []byte("second response") + flight.Finish(item2) +} + +func TestSubgraphRequestSingleFlight_FetchKeyZeroWithoutFetchInfo(t *testing.T) { + t.Run("nil fetch item", func(t *testing.T) { + flight := NewSingleFlight(1) + item, shared := flight.GetOrCreateItem(nil, []byte("body"), 0) + if shared { + t.Fatalf("expected leader for nil fetch item") + } + if item.FetchKey != 0 { + t.Fatalf("expected fetch key 0, got %d", item.FetchKey) + } + flight.Finish(item) + }) + + t.Run("nil fetch", func(t *testing.T) { + flight := NewSingleFlight(1) + item, shared := flight.GetOrCreateItem(&FetchItem{}, []byte("body"), 0) + if shared { + t.Fatalf("expected leader for nil fetch") + } + if item.FetchKey != 0 { + t.Fatalf("expected fetch key 0, got %d", item.FetchKey) + } + flight.Finish(item) + }) + + t.Run("missing fetch info", func(t *testing.T) { + flight := NewSingleFlight(1) + item, shared := flight.GetOrCreateItem(&FetchItem{Fetch: &nilInfoFetch{}}, []byte("body"), 0) + if shared { + t.Fatalf("expected leader for missing fetch info") + } + if item.FetchKey != 0 { + t.Fatalf("expected fetch key 0, got %d", item.FetchKey) + } + flight.Finish(item) + }) +} + +func TestSubgraphRequestSingleFlight_SizeHintRollingWindow(t *testing.T) { + flight := NewSingleFlight(1) + fetchInfo := &FetchInfo{ + DataSourceID: "products", + RootFields: []GraphCoordinate{ + {TypeName: "Query", FieldName: "products"}, + }, + } + fetchItem := newFetchItem(fetchInfo) + + var fetchKey uint64 + for i := 0; i < 50; i++ { + item, shared := flight.GetOrCreateItem(fetchItem, []byte(fmt.Sprintf("body-%d", i)), 0) + if shared { + t.Fatalf("expected leader for iteration %d", i) + } + if i == 0 { + fetchKey = item.FetchKey + } else if item.FetchKey != fetchKey { + t.Fatalf("expected consistent fetch key across iterations, got %d and %d", fetchKey, item.FetchKey) + } + item.response = bytes.Repeat([]byte("a"), 100) + flight.Finish(item) + } + + item, shared := flight.GetOrCreateItem(fetchItem, []byte("body-50"), 0) + if shared { + t.Fatalf("expected leader for rolling window update") + } + if item.FetchKey != fetchKey { + t.Fatalf("expected same fetch key, got %d and %d", fetchKey, item.FetchKey) + } + item.response = bytes.Repeat([]byte("b"), 200) + flight.Finish(item) + + next, nextShared := flight.GetOrCreateItem(fetchItem, []byte("body-51"), 0) + if nextShared { + t.Fatalf("expected leader for new request") + } + expected := 150 + if next.sizeHint != expected { + t.Fatalf("expected rolling average size hint %d, got %d", expected, next.sizeHint) + } +} diff --git a/v2/pkg/engine/resolve/tainted_objects_test.go b/v2/pkg/engine/resolve/tainted_objects_test.go index 0eeb34440..b8205dc72 100644 --- a/v2/pkg/engine/resolve/tainted_objects_test.go +++ b/v2/pkg/engine/resolve/tainted_objects_test.go @@ -70,7 +70,7 @@ func TestSelectObjectAndIndex(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - response, err := astjson.ParseBytesWithoutCache([]byte(tt.responseJSON)) + response, err := astjson.ParseBytes([]byte(tt.responseJSON)) assert.NoError(t, err, "Failed to parse response JSON") // Convert path elements to astjson.Value slice @@ -94,7 +94,7 @@ func TestSelectObjectAndIndex(t *testing.T) { assert.Nil(t, entity, "Expected nil entity") } else { assert.NotNil(t, entity, "Expected non-nil entity") - expectedEntity, err := astjson.ParseBytesWithoutCache([]byte(tt.expectedEntity)) + expectedEntity, err := astjson.ParseBytes([]byte(tt.expectedEntity)) assert.NoError(t, err, "Failed to parse expected entity JSON") // Compare JSON representations @@ -320,10 +320,10 @@ func TestGetTaintedIndices(t *testing.T) { } mockFetch := &mockFetchWithInfo{info: fetchInfo} - response, err := astjson.ParseBytesWithoutCache([]byte(tt.responseJSON)) + response, err := astjson.ParseBytes([]byte(tt.responseJSON)) assert.NoError(t, err, "Failed to parse response JSON") - errors, err := astjson.ParseBytesWithoutCache([]byte(tt.errorsJSON)) + errors, err := astjson.ParseBytes([]byte(tt.errorsJSON)) assert.NoError(t, err, "Failed to parse errors JSON") indices := getTaintedIndices(mockFetch, response, errors) diff --git a/v2/pkg/engine/resolve/variables_renderer.go b/v2/pkg/engine/resolve/variables_renderer.go index 4cbb471f8..0fa1d3ee1 100644 --- a/v2/pkg/engine/resolve/variables_renderer.go +++ b/v2/pkg/engine/resolve/variables_renderer.go @@ -350,7 +350,7 @@ var ( func (g *GraphQLVariableResolveRenderer) getResolvable() *Resolvable { v := _graphQLVariableResolveRendererPool.Get() if v == nil { - return NewResolvable(ResolvableOptions{}) + return NewResolvable(nil, ResolvableOptions{}) } return v.(*Resolvable) } diff --git a/v2/pkg/fastjsonext/fastjsonext.go b/v2/pkg/fastjsonext/fastjsonext.go index 0480fcbd4..4929e8a96 100644 --- a/v2/pkg/fastjsonext/fastjsonext.go +++ b/v2/pkg/fastjsonext/fastjsonext.go @@ -2,27 +2,28 @@ package fastjsonext import ( "github.com/wundergraph/astjson" + "github.com/wundergraph/go-arena" ) -func AppendErrorToArray(arena *astjson.Arena, v *astjson.Value, msg string, path []PathElement) { +func AppendErrorToArray(a arena.Arena, v *astjson.Value, msg string, path []PathElement) { if v.Type() != astjson.TypeArray { return } - errorObject := CreateErrorObjectWithPath(arena, msg, path) + errorObject := CreateErrorObjectWithPath(a, msg, path) items, _ := v.Array() - v.SetArrayItem(len(items), errorObject) + v.SetArrayItem(a, len(items), errorObject) } -func AppendErrorWithExtensionsCodeToArray(arena *astjson.Arena, v *astjson.Value, msg, code string, path []PathElement) { +func AppendErrorWithExtensionsCodeToArray(a arena.Arena, v *astjson.Value, msg, code string, path []PathElement) { if v.Type() != astjson.TypeArray { return } - errorObject := CreateErrorObjectWithPath(arena, msg, path) - extensions := arena.NewObject() - extensions.Set("code", arena.NewString(code)) - errorObject.Set("extensions", extensions) + errorObject := CreateErrorObjectWithPath(a, msg, path) + extensions := astjson.ObjectValue(a) + extensions.Set(a, "code", astjson.StringValue(a, code)) + errorObject.Set(a, "extensions", extensions) items, _ := v.Array() - v.SetArrayItem(len(items), errorObject) + v.SetArrayItem(a, len(items), errorObject) } type PathElement struct { @@ -30,29 +31,29 @@ type PathElement struct { Idx int } -func CreateErrorObjectWithPath(arena *astjson.Arena, message string, path []PathElement) *astjson.Value { - errorObject := arena.NewObject() - errorObject.Set("message", arena.NewString(message)) +func CreateErrorObjectWithPath(a arena.Arena, message string, path []PathElement) *astjson.Value { + errorObject := astjson.ObjectValue(a) + errorObject.Set(a, "message", astjson.StringValue(a, message)) if len(path) == 0 { return errorObject } - errorPath := arena.NewArray() + errorPath := astjson.ArrayValue(a) for i := range path { if path[i].Name != "" { - errorPath.SetArrayItem(i, arena.NewString(path[i].Name)) + errorPath.SetArrayItem(a, i, astjson.StringValue(a, path[i].Name)) } else { - errorPath.SetArrayItem(i, arena.NewNumberInt(path[i].Idx)) + errorPath.SetArrayItem(a, i, astjson.IntValue(a, path[i].Idx)) } } - errorObject.Set("path", errorPath) + errorObject.Set(a, "path", errorPath) return errorObject } func PrintGraphQLResponse(data, errors *astjson.Value) string { out := astjson.MustParse(`{}`) if astjson.ValueIsNonNull(errors) { - out.Set("errors", errors) + out.Set(nil, "errors", errors) } - out.Set("data", data) + out.Set(nil, "data", data) return string(out.MarshalTo(nil)) } diff --git a/v2/pkg/fastjsonext/fastjsonext_test.go b/v2/pkg/fastjsonext/fastjsonext_test.go index af4271630..e48a2ad1c 100644 --- a/v2/pkg/fastjsonext/fastjsonext_test.go +++ b/v2/pkg/fastjsonext/fastjsonext_test.go @@ -21,28 +21,28 @@ func TestGetArray(t *testing.T) { func TestAppendErrorWithMessage(t *testing.T) { a := astjson.MustParse(`[]`) - AppendErrorToArray(&astjson.Arena{}, a, "error", nil) + AppendErrorToArray(nil, a, "error", nil) out := a.MarshalTo(nil) require.Equal(t, `[{"message":"error"}]`, string(out)) - AppendErrorToArray(&astjson.Arena{}, a, "error2", []PathElement{{Name: "a"}}) + AppendErrorToArray(nil, a, "error2", []PathElement{{Name: "a"}}) out = a.MarshalTo(nil) require.Equal(t, `[{"message":"error"},{"message":"error2","path":["a"]}]`, string(out)) } func TestCreateErrorObjectWithPath(t *testing.T) { - v := CreateErrorObjectWithPath(&astjson.Arena{}, "my error message", []PathElement{ + v := CreateErrorObjectWithPath(nil, "my error message", []PathElement{ {Name: "a"}, }) out := v.MarshalTo(nil) require.Equal(t, `{"message":"my error message","path":["a"]}`, string(out)) - v = CreateErrorObjectWithPath(&astjson.Arena{}, "my error message", []PathElement{ + v = CreateErrorObjectWithPath(nil, "my error message", []PathElement{ {Name: "a"}, {Idx: 1}, {Name: "b"}, }) out = v.MarshalTo(nil) require.Equal(t, `{"message":"my error message","path":["a",1,"b"]}`, string(out)) - v = CreateErrorObjectWithPath(&astjson.Arena{}, "my error message", []PathElement{ + v = CreateErrorObjectWithPath(nil, "my error message", []PathElement{ {Name: "a"}, {Name: "b"}, }) diff --git a/v2/pkg/variablesvalidation/variablesvalidation.go b/v2/pkg/variablesvalidation/variablesvalidation.go index d9631739e..b1af4f40e 100644 --- a/v2/pkg/variablesvalidation/variablesvalidation.go +++ b/v2/pkg/variablesvalidation/variablesvalidation.go @@ -98,7 +98,7 @@ func (v *VariablesValidator) ValidateWithRemap(operation, definition *ast.Docume func (v *VariablesValidator) Validate(operation, definition *ast.Document, variables []byte) error { v.visitor.definition = definition v.visitor.operation = operation - v.visitor.variables, v.visitor.err = astjson.ParseBytesWithoutCache(variables) + v.visitor.variables, v.visitor.err = astjson.ParseBytes(variables) if v.visitor.err != nil { return v.visitor.err }