diff --git a/rest-api/api/pkg/api/handler/adminmachine.go b/rest-api/api/pkg/api/handler/adminmachine.go
new file mode 100644
index 0000000000..93a678a535
--- /dev/null
+++ b/rest-api/api/pkg/api/handler/adminmachine.go
@@ -0,0 +1,116 @@
+// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+package handler
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net/http"
+
+	"github.com/labstack/echo/v4"
+	"github.com/rs/zerolog"
+	tClient "go.temporal.io/sdk/client"
+
+	"github.com/NVIDIA/infra-controller/rest-api/api/internal/config"
+	"github.com/NVIDIA/infra-controller/rest-api/api/pkg/api/handler/util/common"
+	sc "github.com/NVIDIA/infra-controller/rest-api/api/pkg/client/site"
+	auth "github.com/NVIDIA/infra-controller/rest-api/auth/pkg/authorization"
+	cutil "github.com/NVIDIA/infra-controller/rest-api/common/pkg/util"
+	cdb "github.com/NVIDIA/infra-controller/rest-api/db/pkg/db"
+	cdbm "github.com/NVIDIA/infra-controller/rest-api/db/pkg/db/model"
+)
+
+// adminMachineBase groups the dependencies behind authorizeMachine. Only
+// dbSession and scp are read by authorizeMachine itself; cfg and tracerSpan are
+// carried alongside them.
+type adminMachineBase struct {
+	dbSession  *cdb.Session
+	scp        *sc.ClientPool
+	cfg        *config.Config
+	tracerSpan *cutil.TracerSpan
+}
+
+// authorizeMachine verifies that dbUser may administer the Machine identified
+// by machineID on behalf of org, and resolves the client for the Machine's Site.
+//
+// The checks run in this order: dbUser is non-nil, machineID is non-empty,
+// dbUser is a member of org, dbUser holds the Provider Admin role for org, org
+// has an Infrastructure Provider, the Machine exists and belongs to that
+// Provider, and the Machine's Site exists and belongs to that Provider.
+//
+// On success it returns the Site's Temporal client, the Site ID in string
+// form, and the Machine. On failure the first three results are zero values
+// and the error is the value produced by cutil.NewAPIErrorResponse for the
+// matching HTTP status.
+func (b adminMachineBase) authorizeMachine(
+	ctx context.Context,
+	c echo.Context,
+	logger zerolog.Logger,
+	org string,
+	dbUser *cdbm.User,
+	machineID string,
+) (tClient.Client, string, *cdbm.Machine, error) {
+	if dbUser == nil {
+		logger.Error().Msg("invalid User object found in request context")
+		return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusInternalServerError, "Failed to retrieve current user", nil)
+	}
+	if machineID == "" {
+		return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusBadRequest, "Machine ID is required", nil)
+	}
+
+	ok, err := auth.ValidateOrgMembership(dbUser, org)
+	if !ok {
+		if err != nil {
+			logger.Error().Err(err).Msg("error validating org membership for User in request")
+		} else {
+			logger.Warn().Msg("could not validate org membership for user, access denied")
+		}
+		return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusForbidden, fmt.Sprintf("Failed to validate membership for org: %s", org), nil)
+	}
+
+	if !auth.ValidateUserRoles(dbUser, org, nil, auth.ProviderAdminRole) {
+		logger.Warn().Msg("user does not have Provider Admin role, access denied")
+		return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusForbidden, "User does not have Provider Admin role with org", nil)
+	}
+
+	provider, err := common.GetInfrastructureProviderForOrg(ctx, nil, b.dbSession, org)
+	if err != nil {
+		logger.Warn().Err(err).Msg("error getting infrastructure provider for org")
+		return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusBadRequest, "Failed to retrieve Infrastructure Provider for org", nil)
+	}
+
+	machine, err := cdbm.NewMachineDAO(b.dbSession).GetByID(ctx, nil, machineID, nil, false)
+	if err != nil {
+		if errors.Is(err, cdb.ErrDoesNotExist) {
+			return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusNotFound, "Could not find Machine with specified ID", nil)
+		}
+		logger.Error().Err(err).Msg("error retrieving Machine DB entity")
+		return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusInternalServerError, "Could not retrieve Machine", nil)
+	}
+
+	// The Machine must be owned by the org's own Infrastructure Provider.
+	if machine.InfrastructureProviderID != provider.ID {
+		return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusForbidden, "Machine doesn't belong to org's Infrastructure provider", nil)
+	}
+
+	site, err := common.GetSiteFromIDString(ctx, nil, machine.SiteID.String(), b.dbSession)
+	if err != nil {
+		if errors.Is(err, cdb.ErrDoesNotExist) || errors.Is(err, common.ErrInvalidID) {
+			return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusBadRequest, "Machine Site does not exist", nil)
+		}
+		logger.Error().Err(err).Msg("error retrieving Machine Site from DB")
+		return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusInternalServerError, "Failed to retrieve Machine Site due to DB error", nil)
+	}
+	if site.InfrastructureProviderID != provider.ID {
+		return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusForbidden, "Machine Site doesn't belong to current org's Provider", nil)
+	}
+
+	stc, err := b.scp.GetClientByID(site.ID)
+	if err != nil {
+		logger.Error().Err(err).Msg("failed to retrieve Temporal client for Site")
+		return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusInternalServerError, "Failed to retrieve client for Site", nil)
+	}
+	return stc, site.ID.String(), machine, nil
+}
diff --git a/rest-api/api/pkg/api/handler/adminops_test.go b/rest-api/api/pkg/api/handler/adminops_test.go
new file mode 100644
index 0000000000..25d6f7cd1b
--- /dev/null
+++ b/rest-api/api/pkg/api/handler/adminops_test.go
@@ -0,0 +1,227 @@
+// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+package handler
+
+import (
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+
+	"github.com/labstack/echo/v4"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+	"github.com/stretchr/testify/require"
+	tmocks "go.temporal.io/sdk/mocks"
+	"google.golang.org/protobuf/encoding/protojson"
+	"google.golang.org/protobuf/proto"
+
+	"github.com/NVIDIA/infra-controller/rest-api/api/pkg/api/handler/util/common"
+	"github.com/NVIDIA/infra-controller/rest-api/api/pkg/api/model"
+	sc "github.com/NVIDIA/infra-controller/rest-api/api/pkg/client/site"
+	authz "github.com/NVIDIA/infra-controller/rest-api/auth/pkg/authorization"
+	"github.com/NVIDIA/infra-controller/rest-api/common/pkg/coreproxy"
+	cutil "github.com/NVIDIA/infra-controller/rest-api/common/pkg/util"
+	cdbm "github.com/NVIDIA/infra-controller/rest-api/db/pkg/db/model"
+	cwssaws "github.com/NVIDIA/infra-controller/rest-api/workflow-schema/schema/site-agent/workflows/v1"
+)
+
+// adminOpsHandlerFixture holds the seeded identifiers, the handler under test,
+// and the coreproxy.Request captured from the mocked ExecuteWorkflow call.
+type adminOpsHandlerFixture struct {
+	org        string
+	siteID     string
+	machineID  string
+	user       any
+	handler    echo.HandlerFunc
+	proxiedReq *coreproxy.Request
+}
+
+// newAdminOpsHandlerFixture seeds a test DB with a Provider Admin user,
+// Infrastructure Provider, Site, Instance Type and Machine, then builds the
+// handler selected by handlerName ("health-list", "health-insert" or
+// "health-remove") against a mocked Temporal client registered for that Site.
+//
+// When response is non-nil it is protojson-marshaled into the ResponseJSON of
+// the coreproxy.Response passed to the mocked WorkflowRun.Get. Any other
+// handlerName fails the test.
+func newAdminOpsHandlerFixture(t *testing.T, handlerName string, response proto.Message) adminOpsHandlerFixture {
+	t.Helper()
+
+	dbSession := common.TestInitDB(t)
+	t.Cleanup(dbSession.Close)
+	common.TestSetupSchema(t, dbSession)
+
+	org := "test-org"
+	user := common.TestBuildUser(t, dbSession, "test-starfleet-id", org, []string{authz.ProviderAdminRole})
+	ip := common.TestBuildInfrastructureProvider(t, dbSession, "Test Infrastructure Provider", org, user)
+	site := common.TestBuildSite(t, dbSession, ip, "Test Site", user)
+	it := common.TestBuildInstanceType(t, dbSession, "test-instance-type", cutil.GetPtr(site.ID), site, nil, user)
+	machine := common.TestBuildMachine(t, dbSession, ip, site, &it.ID, cutil.GetPtr("test-controller-machine-type"), cdbm.MachineStatusReady)
+
+	proxiedReq := &coreproxy.Request{}
+	wrun := &tmocks.WorkflowRun{}
+	wrun.On("Get", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
+		if response == nil {
+			return
+		}
+		out := args.Get(1).(*coreproxy.Response)
+		respJSON, err := protojson.Marshal(response)
+		require.NoError(t, err)
+		out.ResponseJSON = respJSON
+	}).Return(nil)
+
+	tsc := &tmocks.Client{}
+	tsc.On(
+		"ExecuteWorkflow",
+		mock.Anything,
+		mock.Anything,
+		coreproxy.WorkflowName,
+		// Capture the proxied request so tests can assert on what was sent.
+		mock.MatchedBy(func(req coreproxy.Request) bool {
+			*proxiedReq = req
+			return true
+		}),
+	).Return(wrun, nil)
+
+	scp := sc.NewClientPool(nil)
+	scp.IDClientMap[site.ID.String()] = tsc
+
+	cfg := common.GetTestConfig()
+	var handler echo.HandlerFunc
+	switch handlerName {
+	case "health-list":
+		h := NewListMachineHealthReportHandler(dbSession, scp, cfg)
+		handler = h.Handle
+	case "health-insert":
+		h := NewInsertMachineHealthReportHandler(dbSession, scp, cfg)
+		handler = h.Handle
+	case "health-remove":
+		h := NewRemoveMachineHealthReportHandler(dbSession, scp, cfg)
+		handler = h.Handle
+	default:
+		t.Fatalf("unknown handler %q", handlerName)
+	}
+
+	return adminOpsHandlerFixture{
+		org:        org,
+		siteID:     site.ID.String(),
+		machineID:  machine.ID,
+		user:       user,
+		handler:    handler,
+		proxiedReq: proxiedReq,
+	}
+}
+
+// request invokes the fixture's handler through a fresh echo context and
+// returns the recorder. A non-nil body is JSON-encoded and sent with a JSON
+// content type; a non-empty source is added as the "source" path param next to
+// "orgName" and "machineId". The handler returning an error fails the test.
+func (f adminOpsHandlerFixture) request(t *testing.T, method, target string, body any, source string) *httptest.ResponseRecorder {
+	t.Helper()
+
+	var reqBody string
+	if body != nil {
+		bodyBytes, err := json.Marshal(body)
+		require.NoError(t, err)
+		reqBody = string(bodyBytes)
+	}
+
+	e := echo.New()
+	req := httptest.NewRequest(method, target, strings.NewReader(reqBody))
+	if body != nil {
+		req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
+	}
+	rec := httptest.NewRecorder()
+	ec := e.NewContext(req, rec)
+	names := []string{"orgName", "machineId"}
+	values := []string{f.org, f.machineID}
+	if source != "" {
+		names = append(names, "source")
+		values = append(values, source)
+	}
+	ec.SetParamNames(names...)
+	ec.SetParamValues(values...)
+	ec.Set("user", f.user)
+
+	require.NoError(t, f.handler(ec))
+	return rec
+}
+
+// TestListMachineHealthReportHandlerProxiesRequest checks that the list handler
+// proxies ListMachineHealthReports for the Machine and echoes the report source.
+func TestListMachineHealthReportHandlerProxiesRequest(t *testing.T) {
+	fixture := newAdminOpsHandlerFixture(t, "health-list", &cwssaws.ListHealthReportResponse{
+		HealthReportEntries: []*cwssaws.HealthReportEntry{
+			{
+				Mode: cwssaws.HealthReportApplyMode_Merge,
+				Report: &cwssaws.HealthReport{
+					Source: "overrides.sre",
+					Alerts: []*cwssaws.HealthProbeAlert{{Id: "probe.alert", Message: "forced unhealthy"}},
+				},
+			},
+		},
+	})
+
+	rec := fixture.request(t, http.MethodGet, "/", nil, "")
+	assert.Equal(t, http.StatusOK, rec.Code)
+	assert.Equal(t, cwssaws.Forge_ListMachineHealthReports_FullMethodName, fixture.proxiedReq.FullMethod)
+	assert.Empty(t, fixture.proxiedReq.EncryptedSecrets)
+
+	var coreReq cwssaws.MachineId
+	require.NoError(t, protojson.Unmarshal(fixture.proxiedReq.RequestJSON, &coreReq))
+	assert.Equal(t, fixture.machineID, coreReq.GetId())
+	assert.Contains(t, rec.Body.String(), "overrides.sre")
+	assert.NotContains(t, rec.Body.String(), "password")
+}
+
+// TestInsertMachineHealthReportHandlerProxiesRequest checks that a valid entry
+// is proxied as InsertMachineHealthReport with the Machine ID and source set.
+func TestInsertMachineHealthReportHandlerProxiesRequest(t *testing.T) {
+	fixture := newAdminOpsHandlerFixture(t, "health-insert", nil)
+	req := model.APIMachineHealthReportEntry{
+		Source: "overrides.sre",
+		Mode:   model.MachineHealthReportModeMerge,
+		Alerts: []model.APIMachineHealthProbeAlert{{ID: "probe.alert", Message: "forced unhealthy"}},
+	}
+
+	rec := fixture.request(t, http.MethodPut, "/", req, "")
+	assert.Equal(t, http.StatusOK, rec.Code)
+	assert.Equal(t, cwssaws.Forge_InsertMachineHealthReport_FullMethodName, fixture.proxiedReq.FullMethod)
+	assert.Empty(t, fixture.proxiedReq.EncryptedSecrets)
+
+	var coreReq cwssaws.InsertMachineHealthReportRequest
+	require.NoError(t, protojson.Unmarshal(fixture.proxiedReq.RequestJSON, &coreReq))
+	assert.Equal(t, fixture.machineID, coreReq.GetMachineId().GetId())
+	assert.Equal(t, "overrides.sre", coreReq.GetHealthReportEntry().GetReport().GetSource())
+	assert.NotContains(t, rec.Body.String(), "password")
+}
+
+// TestInsertMachineHealthReportHandlerRejectsInvalidRequest checks that an
+// entry without a source yields 400 and that nothing is proxied.
+func TestInsertMachineHealthReportHandlerRejectsInvalidRequest(t *testing.T) {
+	fixture := newAdminOpsHandlerFixture(t, "health-insert", nil)
+
+	rec := fixture.request(t, http.MethodPut, "/", model.APIMachineHealthReportEntry{Mode: model.MachineHealthReportModeMerge}, "")
+	assert.Equal(t, http.StatusBadRequest, rec.Code)
+	assert.Empty(t, fixture.proxiedReq.FullMethod)
+}
+
+// TestRemoveMachineHealthReportHandlerProxiesRequest checks that the remove
+// handler proxies RemoveMachineHealthReport with the Machine ID and source param.
+func TestRemoveMachineHealthReportHandlerProxiesRequest(t *testing.T) {
+	fixture := newAdminOpsHandlerFixture(t, "health-remove", nil)
+
+	rec := fixture.request(t, http.MethodDelete, "/", nil, "overrides.sre")
+	assert.Equal(t, http.StatusOK, rec.Code)
+	assert.Equal(t, cwssaws.Forge_RemoveMachineHealthReport_FullMethodName, fixture.proxiedReq.FullMethod)
+	assert.Empty(t, fixture.proxiedReq.EncryptedSecrets)
+
+	var coreReq cwssaws.RemoveMachineHealthReportRequest
+	require.NoError(t, protojson.Unmarshal(fixture.proxiedReq.RequestJSON, &coreReq))
+	assert.Equal(t, fixture.machineID, coreReq.GetMachineId().GetId())
+	assert.Equal(t, "overrides.sre", coreReq.GetSource())
+	assert.NotContains(t, rec.Body.String(), "password")
+}
diff --git a/rest-api/api/pkg/api/model/adminops_test.go b/rest-api/api/pkg/api/model/adminops_test.go
new file mode 100644
index 0000000000..d9e1e20ffb
--- /dev/null
+++ b/rest-api/api/pkg/api/model/adminops_test.go
@@ -0,0 +1,91 @@
+// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+package model
+
+import (
+	"encoding/json"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	cwssaws "github.com/NVIDIA/infra-controller/rest-api/workflow-schema/schema/site-agent/workflows/v1"
+)
+
+// adminOpsStrPtr returns a pointer to a copy of s.
+func adminOpsStrPtr(s string) *string { return &s }
+
+// TestAPIMachineHealthReportEntryValidateAndToProto checks that a fully
+// populated entry validates and converts to proto, and that invalid entries fail.
+func TestAPIMachineHealthReportEntryValidateAndToProto(t *testing.T) {
+	observedAt := "2026-06-24T12:00:00Z"
+	inAlertSince := "2026-06-24T11:00:00Z"
+	req := APIMachineHealthReportEntry{
+		Source:      "overrides.sre",
+		TriggeredBy: adminOpsStrPtr("operator"),
+		ObservedAt:  &observedAt,
+		Mode:        MachineHealthReportModeReplace,
+		Successes: []APIMachineHealthProbeSuccess{
+			{ID: "probe.ok", Target: adminOpsStrPtr("host")},
+		},
+		Alerts: []APIMachineHealthProbeAlert{
+			{
+				ID:              "probe.alert",
+				Target:          adminOpsStrPtr("gpu0"),
+				InAlertSince:    &inAlertSince,
+				Message:         "forced unhealthy",
+				TenantMessage:   adminOpsStrPtr("maintenance"),
+				Classifications: []string{"maintenance"},
+			},
+		},
+	}
+	require.NoError(t, req.Validate())
+
+	protoReq := req.ToProto("machine-1")
+	assert.Equal(t, "machine-1", protoReq.GetMachineId().GetId())
+	entry := protoReq.GetHealthReportEntry()
+	require.NotNil(t, entry)
+	assert.Equal(t, cwssaws.HealthReportApplyMode_Replace, entry.GetMode())
+	report := entry.GetReport()
+	require.NotNil(t, report)
+	assert.Equal(t, "overrides.sre", report.GetSource())
+	assert.Equal(t, "operator", report.GetTriggeredBy())
+	assert.Equal(t, observedAt, report.GetObservedAt().AsTime().Format(time.RFC3339))
+	require.Len(t, report.GetSuccesses(), 1)
+	assert.Equal(t, "probe.ok", report.GetSuccesses()[0].GetId())
+	require.Len(t, report.GetAlerts(), 1)
+	assert.Equal(t, "probe.alert", report.GetAlerts()[0].GetId())
+	assert.Equal(t, inAlertSince, report.GetAlerts()[0].GetInAlertSince().AsTime().Format(time.RFC3339))
+
+	assert.Error(t, (&APIMachineHealthReportEntry{Mode: MachineHealthReportModeMerge}).Validate())
+	assert.Error(t, (&APIMachineHealthReportEntry{Source: "source", Mode: "merge"}).Validate())
+	assert.Error(t, (&APIMachineHealthReportEntry{Source: "source", Mode: MachineHealthReportModeMerge, ObservedAt: adminOpsStrPtr("bad-time")}).Validate())
+	assert.Error(t, (&APIMachineHealthReportEntry{Source: "source", Mode: MachineHealthReportModeMerge, Alerts: []APIMachineHealthProbeAlert{{ID: "alert"}}}).Validate())
+}
+
+// TestAPIMachineHealthReportListResponseFromProto checks the proto-to-API
+// conversion of a list response and its JSON encoding.
+func TestAPIMachineHealthReportListResponseFromProto(t *testing.T) {
+	resp := NewAPIMachineHealthReportListResponse("machine-1", &cwssaws.ListHealthReportResponse{
+		HealthReportEntries: []*cwssaws.HealthReportEntry{
+			{
+				Mode: cwssaws.HealthReportApplyMode_Merge,
+				Report: &cwssaws.HealthReport{
+					Source: "overrides.sre",
+					Alerts: []*cwssaws.HealthProbeAlert{{Id: "probe.alert", Message: "forced unhealthy"}},
+				},
+			},
+		},
+	})
+
+	assert.Equal(t, "machine-1", resp.MachineID)
+	require.Len(t, resp.HealthReportEntries, 1)
+	assert.Equal(t, MachineHealthReportModeMerge, resp.HealthReportEntries[0].Mode)
+	assert.Equal(t, "overrides.sre", resp.HealthReportEntries[0].Source)
+
+	body, err := json.Marshal(resp)
+	require.NoError(t, err)
+	assert.NotContains(t, string(body), "password")
+}
diff --git a/rest-api/api/pkg/api/routes.go b/rest-api/api/pkg/api/routes.go
index 5ee2200b57..a532ea6127 100644
--- a/rest-api/api/pkg/api/routes.go
+++ b/rest-api/api/pkg/api/routes.go
@@ -546,6 +546,21 @@ func NewAPIRoutes(dbSession *cdb.Session, tc tClient.Client, tnc tClient.Namespa
 			Method:  http.MethodGet,
 			Handler: apiHandler.NewGetMachineHandler(dbSession, tc, cfg),
 		},
+		{
+			Path:    apiPathPrefix + "/machine/:machineId/health-report",
+			Method:  http.MethodGet,
+			Handler: apiHandler.NewListMachineHealthReportHandler(dbSession, scp, cfg),
+		},
+		{
+			Path:    apiPathPrefix + "/machine/:machineId/health-report",
+			Method:  http.MethodPut,
+			Handler: apiHandler.NewInsertMachineHealthReportHandler(dbSession, scp, cfg),
+		},
+		{
+			Path:    apiPathPrefix + "/machine/:machineId/health-report/:source",
+			Method:  http.MethodDelete,
+			Handler: apiHandler.NewRemoveMachineHealthReportHandler(dbSession, scp, cfg),
+		},
 		{
 			Path:    apiPathPrefix + "/machine/:id",
 			Method:  http.MethodPatch,