diff --git a/CHANGELOG.md b/CHANGELOG.md index b2a63b2..4326d10 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ FIXES * Fixing paths using in-built library instead of string manipulation. See [#143](https://github.com/hashicorp/terraform-mcp-server/pull/143) * Explicitly setting destructive annotation to false. See [#143](https://github.com/hashicorp/terraform-mcp-server/pull/143) +* Fix provider search prioritization to show official providers first in search results. See [#179](https://github.com/hashicorp/terraform-mcp-server/pull/179) SECURITY diff --git a/pkg/client/registry.go b/pkg/client/registry.go index ff897f8..e81d1df 100644 --- a/pkg/client/registry.go +++ b/pkg/client/registry.go @@ -62,7 +62,9 @@ func createHTTPClient(insecureSkipVerify bool, logger *log.Logger) *http.Client return retryClient.StandardClient() } -func SendRegistryCall(client *http.Client, method string, uri string, logger *log.Logger, callOptions ...string) ([]byte, error) { +// SendRegistryCallFn is a package-level function variable so callers (and tests) +// can override registry call behavior for testing. +var SendRegistryCallFn = func(client *http.Client, method string, uri string, logger *log.Logger, callOptions ...string) ([]byte, error) { ver := "v1" if len(callOptions) > 0 { ver = callOptions[0] // API version will be the first optional arg to this function @@ -100,6 +102,11 @@ func SendRegistryCall(client *http.Client, method string, uri string, logger *lo return body, nil } +// Backwards-compatible wrapper +func SendRegistryCall(client *http.Client, method string, uri string, logger *log.Logger, callOptions ...string) ([]byte, error) { + return SendRegistryCallFn(client, method, uri, logger, callOptions...) +} + func SendPaginatedRegistryCall(client *http.Client, uriPrefix string, logger *log.Logger) ([]ProviderDocData, error) { var results []ProviderDocData page := 1 diff --git a/pkg/tools/registry/search_providers.go b/pkg/tools/registry/search_providers.go index 6e96ea9..63a8106 100644 --- a/pkg/tools/registry/search_providers.go +++ b/pkg/tools/registry/search_providers.go @@ -9,6 +9,7 @@ import ( "fmt" "net/http" "path" + "sort" "strings" "github.com/hashicorp/terraform-mcp-server/pkg/client" @@ -19,6 +20,26 @@ import ( "github.com/mark3labs/mcp-go/server" ) +// sendRegistryCall is a package-level variable so tests can override registry calls. +var sendRegistryCall = client.SendRegistryCall + +// tierOrder defines sorting priority for provider tiers. +var tierOrder = map[string]int{"official": 0, "partner": 1, "community": 2} + +type providerMatch struct { + Namespace string + Name string + Tier string + DocMatch []client.ProviderDoc +} + +// sortMatchesByTier sorts the matches slice in-place by tier using tierOrder. +func sortMatchesByTier(matches []providerMatch) { + sort.SliceStable(matches, func(i, j int) bool { + return tierOrder[strings.ToLower(matches[i].Tier)] < tierOrder[strings.ToLower(matches[j].Tier)] + }) +} + // ResolveProviderDocID creates a tool to get provider details from registry. func ResolveProviderDocID(logger *log.Logger) server.ServerTool { return server.ServerTool{ @@ -27,8 +48,8 @@ func ResolveProviderDocID(logger *log.Logger) server.ServerTool { You MUST call this function before 'get_provider_details' to obtain a valid tfprovider-compatible provider_doc_id. Use the most relevant single word as the search query for service_slug, if unsure about the service_slug, use the provider_name for its value. When selecting the best match, consider the following: - - Title similarity to the query - - Category relevance + - Title similarity to the query + - Category relevance Return the selected provider_doc_id and explain your choice. If there are multiple good matches, mention this but proceed with the most relevant one.`), mcp.WithTitleAnnotation("Identify the most relevant provider document ID for a Terraform service"), @@ -92,56 +113,161 @@ func resolveProviderDocIDHandler(ctx context.Context, request mcp.CallToolReques if utils.IsV2ProviderDataType(providerDetail.ProviderDataType) { content, err := providerDetailsV2(httpClient, providerDetail, logger) if err != nil { - errMessage := fmt.Sprintf(`finding %s documentation for provider '%s' in the '%s' namespace, %s`, - providerDetail.ProviderDataType, providerDetail.ProviderName, providerDetail.ProviderNamespace, defaultErrorGuide) + errMessage := fmt.Sprintf(`finding %s documentation for provider '%s' in the '%s' namespace, %s`, providerDetail.ProviderDataType, providerDetail.ProviderName, providerDetail.ProviderNamespace, defaultErrorGuide) return nil, utils.LogAndReturnError(logger, errMessage, err) } - fullContent := fmt.Sprintf("# %s provider docs\n\n%s", - providerDetail.ProviderName, content) + fullContent := fmt.Sprintf("# %s provider docs\n\n%s", providerDetail.ProviderName, content) return mcp.NewToolResultText(fullContent), nil } - // For resources/data-sources, use the v1 API for better performance (single response) - uri := path.Join("providers", providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion) - response, err := client.SendRegistryCall(httpClient, "GET", uri, logger) + // Delegate to extracted helper so it can be unit-tested. + result, err := searchProvidersDocs(httpClient, providerDetail, serviceSlug, defaultErrorGuide, logger) if err != nil { - return nil, utils.LogAndReturnError(logger, fmt.Sprintf(`getting the "%s" provider, with version "%s" in the %s namespace, %s`, providerDetail.ProviderName, providerDetail.ProviderVersion, providerDetail.ProviderNamespace, defaultErrorGuide), nil) + return nil, err } + return mcp.NewToolResultText(result), nil +} - var providerDocs client.ProviderDocs - if err := json.Unmarshal(response, &providerDocs); err != nil { - return nil, utils.LogAndReturnError(logger, "unmarshalling provider docs", err) +// searchProvidersDocs contains the core provider-search and prioritization logic. +// It returns the textual result (same content as the tool would return) for easier unit testing. +func searchProvidersDocs(httpClient *http.Client, providerDetail client.ProviderDetail, serviceSlug string, defaultErrorGuide string, logger *log.Logger) (string, error) { + // Enhanced: Search all providers matching the name and prioritize by tier + searchUri := "providers?filter[name]=" + providerDetail.ProviderName + searchResp, err := sendRegistryCall(httpClient, "GET", searchUri, logger, "v2") + if err != nil { + return "", utils.LogAndReturnError(logger, "error searching providers in registry", err) } - var builder strings.Builder - builder.WriteString(fmt.Sprintf("Available Documentation (top matches) for %s in Terraform provider %s/%s version: %s\n\n", providerDetail.ProviderDataType, providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion)) - builder.WriteString("Each result includes:\n- providerDocID: tfprovider-compatible identifier\n- Title: Service or resource name\n- Category: Type of document\n- Description: Brief summary of the document\n") - builder.WriteString("For best results, select libraries based on the service_slug match and category of information requested.\n\n---\n\n") + var providerList client.ProviderList + if err := json.Unmarshal(searchResp, &providerList); err != nil { + return "", utils.LogAndReturnError(logger, "unmarshalling provider list", err) + } + + // If the registry search didn't return any providers, fall back to fetching + // the single provider directly (preserves previous behavior for cases where + // provider namespace defaults to hashicorp and the search endpoint may not + // return results matching our filter). + logger.Infof("provider search returned %d providers for name '%s'", len(providerList.Data), providerDetail.ProviderName) + if len(providerList.Data) == 0 { + logger.Infof("falling back to single-provider fetch for %s/%s@%s", providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion) + uri := path.Join("providers", providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion) + response, err := sendRegistryCall(httpClient, "GET", uri, logger) + logger.Debugf("provider docs fetch URI: %s", uri) + if err != nil { + return "", utils.LogAndReturnError(logger, fmt.Sprintf(`getting the "%s" provider, with version "%s" in the %s namespace, %s`, providerDetail.ProviderName, providerDetail.ProviderVersion, providerDetail.ProviderNamespace, defaultErrorGuide), nil) + } + var providerDocs client.ProviderDocs + if err := json.Unmarshal(response, &providerDocs); err != nil { + return "", utils.LogAndReturnError(logger, "unmarshalling provider docs", err) + } + logger.Infof("provider docs returned %d docs for %s/%s@%s", len(providerDocs.Docs), providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion) + var builder strings.Builder + builder.WriteString(fmt.Sprintf("Available Documentation (top matches) for %s in Terraform provider %s/%s version: %s\n\n", providerDetail.ProviderDataType, providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion)) + builder.WriteString("Each result includes:\n- providerDocID: tfprovider-compatible identifier\n- Title: Service or resource name\n- Category: Type of document\n- Description: Brief summary of the document\n") + builder.WriteString("For best results, select libraries based on the service_slug match and category of information requested.\n\n---\n\n") + contentAvailable := false + for _, doc := range providerDocs.Docs { + if doc.Language == "hcl" && doc.Category == providerDetail.ProviderDataType { + cs, err := utils.ContainsSlug(doc.Slug, serviceSlug) + cs_pn, err_pn := utils.ContainsSlug(fmt.Sprintf("%s_%s", providerDetail.ProviderName, doc.Slug), serviceSlug) + if (cs || cs_pn) && err == nil && err_pn == nil { + contentAvailable = true + descriptionSnippet, err := getContentSnippet(httpClient, doc.ID, logger) + if err != nil { + logger.Warnf("Error fetching content snippet for provider doc ID: %s: %v", doc.ID, err) + } + builder.WriteString(fmt.Sprintf("- providerDocID: %s\n- Title: %s\n- Category: %s\n- Description: %s\n---\n", doc.ID, doc.Title, doc.Category, descriptionSnippet)) + } + } + } + + if !contentAvailable { + errMessage := fmt.Sprintf(`finding documentation for service_slug %s, provide a more relevant service_slug if unsure, use the provider_name for its value`, serviceSlug) + return "", utils.LogAndReturnError(logger, errMessage, err) + } + return builder.String(), nil + } + + var matches []providerMatch - contentAvailable := false - for _, doc := range providerDocs.Docs { - if doc.Language == "hcl" && doc.Category == providerDetail.ProviderDataType { - cs, err := utils.ContainsSlug(doc.Slug, serviceSlug) - cs_pn, err_pn := utils.ContainsSlug(fmt.Sprintf("%s_%s", providerDetail.ProviderName, doc.Slug), serviceSlug) - if (cs || cs_pn) && err == nil && err_pn == nil { - contentAvailable = true - descriptionSnippet, err := getContentSnippet(httpClient, doc.ID, logger) - if err != nil { - logger.Warnf("Error fetching content snippet for provider doc ID: %s: %v", doc.ID, err) + for _, pdata := range providerList.Data { + namespace := pdata.Attributes.Namespace + name := pdata.Attributes.Name + tier := pdata.Attributes.Tier + logger.Debugf("search provider entry: namespace=%s name=%s tier=%s", namespace, name, tier) + + // Get docs for this provider. Try the requested version first; if that + // fails (for example the version doesn't exist in this namespace), try + // to resolve the latest version for that namespace/name and retry. + uri := path.Join("providers", namespace, name, providerDetail.ProviderVersion) + response, err := sendRegistryCall(httpClient, "GET", uri, logger) + if err != nil { + // Attempt to fetch the latest provider version for this namespace/name + latestVer, verErr := client.GetLatestProviderVersion(httpClient, namespace, name, logger) + if verErr != nil { + logger.Debugf("skipping provider %s/%s: error fetching docs: %v (also failed to get latest version: %v)", namespace, name, err, verErr) + continue // skip providers we can't fetch + } + uri = path.Join("providers", namespace, name, latestVer) + response, err = sendRegistryCall(httpClient, "GET", uri, logger) + if err != nil { + logger.Debugf("skipping provider %s/%s: error fetching docs with latest version %s: %v", namespace, name, latestVer, err) + continue + } + } + var providerDocs client.ProviderDocs + if err := json.Unmarshal(response, &providerDocs); err != nil { + logger.Debugf("skipping provider %s/%s: error unmarshalling docs: %v", namespace, name, err) + continue + } + logger.Debugf("fetched %d docs for provider %s/%s", len(providerDocs.Docs), namespace, name) + var docMatches []client.ProviderDoc + for _, doc := range providerDocs.Docs { + logger.Tracef("considering doc slug=%s title=%s category=%s language=%s", doc.Slug, doc.Title, doc.Category, doc.Language) + if doc.Language == "hcl" && doc.Category == providerDetail.ProviderDataType { + cs, err := utils.ContainsSlug(doc.Slug, serviceSlug) + cs_pn, err_pn := utils.ContainsSlug(fmt.Sprintf("%s_%s", name, doc.Slug), serviceSlug) + if (cs || cs_pn) && err == nil && err_pn == nil { + logger.Debugf("matched doc %s for provider %s/%s (slug=%s)", doc.ID, namespace, name, doc.Slug) + docMatches = append(docMatches, doc) } - builder.WriteString(fmt.Sprintf("- providerDocID: %s\n- Title: %s\n- Category: %s\n- Description: %s\n---\n", doc.ID, doc.Title, doc.Category, descriptionSnippet)) } } + if len(docMatches) > 0 { + matches = append(matches, providerMatch{ + Namespace: namespace, + Name: name, + Tier: tier, + DocMatch: docMatches, + }) + } } - // Check if the content data is not fulfilled - if !contentAvailable { + if len(matches) == 0 { errMessage := fmt.Sprintf(`finding documentation for service_slug %s, provide a more relevant service_slug if unsure, use the provider_name for its value`, serviceSlug) - return nil, utils.LogAndReturnError(logger, errMessage, err) + return "", utils.LogAndReturnError(logger, errMessage, err) + } + + // Sort matches by tier + sortMatchesByTier(matches) + + var builder strings.Builder + builder.WriteString("Available Documentation (prioritized by provider tier)\n\n") + builder.WriteString("Tier order: official > partner > community\n\n") + for _, match := range matches { + builder.WriteString(fmt.Sprintf("Provider: %s/%s (Tier: %s)\n", match.Namespace, match.Name, match.Tier)) + for _, doc := range match.DocMatch { + descriptionSnippet, err := getContentSnippet(httpClient, doc.ID, logger) + if err != nil { + logger.Warnf("Error fetching content snippet for provider doc ID: %s: %v", doc.ID, err) + } + builder.WriteString(fmt.Sprintf("- providerDocID: %s\n- Title: %s\n- Category: %s\n- Description: %s\n---\n", doc.ID, doc.Title, doc.Category, descriptionSnippet)) + } + builder.WriteString("\n") } - return mcp.NewToolResultText(builder.String()), nil + return builder.String(), nil } func resolveProviderDetails(request mcp.CallToolRequest, httpClient *http.Client, defaultErrorGuide string, logger *log.Logger) (client.ProviderDetail, error) { @@ -214,8 +340,7 @@ func providerDetailsV2(httpClient *http.Client, providerDetail client.ProviderDe return client.GetProviderOverviewDocs(httpClient, providerVersionID, logger) } - uriPrefix := fmt.Sprintf("provider-docs?filter[provider-version]=%s&filter[category]=%s&filter[language]=hcl", - providerVersionID, category) + uriPrefix := fmt.Sprintf("provider-docs?filter[provider-version]=%s&filter[category]=%s&filter[language]=hcl", providerVersionID, category) docs, err := client.SendPaginatedRegistryCall(httpClient, uriPrefix, logger) if err != nil { @@ -270,4 +395,4 @@ func getContentSnippet(httpClient *http.Client, docID string, logger *log.Logger return desc[:300] + "...", nil } return desc, nil -} +} \ No newline at end of file diff --git a/pkg/tools/registry/search_providers_test.go b/pkg/tools/registry/search_providers_test.go new file mode 100644 index 0000000..325bf22 --- /dev/null +++ b/pkg/tools/registry/search_providers_test.go @@ -0,0 +1,59 @@ +package tools + +import ( + "net/http" + "strings" + "testing" + + log "github.com/sirupsen/logrus" + "github.com/hashicorp/terraform-mcp-server/pkg/client" +) + +func TestSearchProvidersPrioritizesOfficial(t *testing.T) { + // Backup original and restore + original := sendRegistryCall + defer func() { sendRegistryCall = original }() + + // Fake responses + sendRegistryCall = func(httpClient *http.Client, method string, uri string, logger *log.Logger, callOptions ...string) ([]byte, error) { + // provider list call + if strings.HasPrefix(uri, "providers?filter[name]=") { + // Return two providers: community then official (unordered) + // minimal JSON with attributes name, namespace, tier + return []byte(`{"data":[{"id":"1","attributes":{"name":"keycloak","namespace":"mrparkers","tier":"community"}},{"id":"2","attributes":{"name":"keycloak","namespace":"keycloak-official","tier":"official"}}]}`), nil + } + + // provider docs calls: uri like providers/{namespace}/{name}/{version} + if strings.HasPrefix(uri, "providers/mrparkers/") { + return []byte(`{"docs":[{"id":"doc1","title":"Keycloak (community)","path":"","slug":"keycloak","category":"resources","language":"hcl"}]}`), nil + } + if strings.HasPrefix(uri, "providers/keycloak-official/") { + return []byte(`{"docs":[{"id":"doc2","title":"Keycloak (official)","path":"","slug":"keycloak","category":"resources","language":"hcl"}]}`), nil + } + + return nil, nil + } + + logger := log.New() + providerDetail := client.ProviderDetail{ + ProviderName: "keycloak", + ProviderNamespace: "", + ProviderVersion: "latest", + ProviderDataType: "resources", + } + + result, err := searchProvidersDocs(http.DefaultClient, providerDetail, "keycloak", "default guide", logger) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // official provider should appear before community provider in the output + officialIdx := strings.Index(result, "Provider: keycloak-official/keycloak (Tier: official)") + communityIdx := strings.Index(result, "Provider: mrparkers/keycloak (Tier: community)") + if officialIdx == -1 || communityIdx == -1 { + t.Fatalf("expected both providers in result, got: %s", result) + } + if officialIdx > communityIdx { + t.Fatalf("official provider found after community provider; result: %s", result) + } +} diff --git a/pkg/tools/registry/sort_test.go b/pkg/tools/registry/sort_test.go new file mode 100644 index 0000000..6be1fbb --- /dev/null +++ b/pkg/tools/registry/sort_test.go @@ -0,0 +1,19 @@ +package tools + +import ( + "testing" +) + +func TestSortMatchesByTier(t *testing.T) { + matches := []providerMatch{ + {Namespace: "a", Name: "one", Tier: "community"}, + {Namespace: "b", Name: "two", Tier: "partner"}, + {Namespace: "c", Name: "three", Tier: "official"}, + } + + sortMatchesByTier(matches) + + if matches[0].Tier != "official" || matches[1].Tier != "partner" || matches[2].Tier != "community" { + t.Fatalf("unexpected tier order: %v", []string{matches[0].Tier, matches[1].Tier, matches[2].Tier}) + } +}