Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions rest-api/api/pkg/api/handler/adminmachine.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package handler

import (
"context"
"errors"
"fmt"
"net/http"

"github.com/labstack/echo/v4"
"github.com/rs/zerolog"
tClient "go.temporal.io/sdk/client"

"github.com/NVIDIA/infra-controller/rest-api/api/internal/config"
"github.com/NVIDIA/infra-controller/rest-api/api/pkg/api/handler/util/common"
sc "github.com/NVIDIA/infra-controller/rest-api/api/pkg/client/site"
auth "github.com/NVIDIA/infra-controller/rest-api/auth/pkg/authorization"
cutil "github.com/NVIDIA/infra-controller/rest-api/common/pkg/util"
cdb "github.com/NVIDIA/infra-controller/rest-api/db/pkg/db"
cdbm "github.com/NVIDIA/infra-controller/rest-api/db/pkg/db/model"
)

type adminMachineBase struct {
dbSession *cdb.Session
scp *sc.ClientPool
cfg *config.Config
tracerSpan *cutil.TracerSpan
}

func (b adminMachineBase) authorizeMachine(
ctx context.Context,
c echo.Context,
logger zerolog.Logger,
org string,
dbUser *cdbm.User,
machineID string,
) (tClient.Client, string, *cdbm.Machine, error) {
if dbUser == nil {
logger.Error().Msg("invalid User object found in request context")
return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusInternalServerError, "Failed to retrieve current user", nil)
}
if machineID == "" {
return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusBadRequest, "Machine ID is required", nil)
}

ok, err := auth.ValidateOrgMembership(dbUser, org)
if !ok {
if err != nil {
logger.Error().Err(err).Msg("error validating org membership for User in request")
} else {
logger.Warn().Msg("could not validate org membership for user, access denied")
}
return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusForbidden, fmt.Sprintf("Failed to validate membership for org: %s", org), nil)
}

if ok := auth.ValidateUserRoles(dbUser, org, nil, auth.ProviderAdminRole); !ok {
logger.Warn().Msg("user does not have Provider Admin role, access denied")
return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusForbidden, "User does not have Provider Admin role with org", nil)
}

provider, err := common.GetInfrastructureProviderForOrg(ctx, nil, b.dbSession, org)
if err != nil {
logger.Warn().Err(err).Msg("error getting infrastructure provider for org")
return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusBadRequest, "Failed to retrieve Infrastructure Provider for org", nil)
}

machine, err := cdbm.NewMachineDAO(b.dbSession).GetByID(ctx, nil, machineID, nil, false)
if err != nil {
if errors.Is(err, cdb.ErrDoesNotExist) {
return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusNotFound, "Could not find Machine with specified ID", nil)
}
logger.Error().Err(err).Msg("error retrieving Machine DB entity")
return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusInternalServerError, "Could not retrieve Machine", nil)
}

if machine.InfrastructureProviderID != provider.ID {
return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusForbidden, "Machine doesn't belong to org's Infrastructure provider", nil)
}

site, err := common.GetSiteFromIDString(ctx, nil, machine.SiteID.String(), b.dbSession)
if err != nil {
if errors.Is(err, cdb.ErrDoesNotExist) || errors.Is(err, common.ErrInvalidID) {
return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusBadRequest, "Machine Site does not exist", nil)
}
logger.Error().Err(err).Msg("error retrieving Machine Site from DB")
return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusInternalServerError, "Failed to retrieve Machine Site due to DB error", nil)
}
if site.InfrastructureProviderID != provider.ID {
return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusForbidden, "Machine Site doesn't belong to current org's Provider", nil)
}

stc, err := b.scp.GetClientByID(site.ID)
if err != nil {
logger.Error().Err(err).Msg("failed to retrieve Temporal client for Site")
return nil, "", nil, cutil.NewAPIErrorResponse(c, http.StatusInternalServerError, "Failed to retrieve client for Site", nil)
}
return stc, site.ID.String(), machine, nil
}
204 changes: 204 additions & 0 deletions rest-api/api/pkg/api/handler/adminops_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package handler

import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
tmocks "go.temporal.io/sdk/mocks"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"

"github.com/NVIDIA/infra-controller/rest-api/api/pkg/api/handler/util/common"
"github.com/NVIDIA/infra-controller/rest-api/api/pkg/api/model"
sc "github.com/NVIDIA/infra-controller/rest-api/api/pkg/client/site"
authz "github.com/NVIDIA/infra-controller/rest-api/auth/pkg/authorization"
"github.com/NVIDIA/infra-controller/rest-api/common/pkg/coreproxy"
cutil "github.com/NVIDIA/infra-controller/rest-api/common/pkg/util"
cdbm "github.com/NVIDIA/infra-controller/rest-api/db/pkg/db/model"
cwssaws "github.com/NVIDIA/infra-controller/rest-api/workflow-schema/schema/site-agent/workflows/v1"
)

type adminOpsHandlerFixture struct {
org string
siteID string
machineID string
user interface{}
handler echo.HandlerFunc
proxiedReq *coreproxy.Request
}

func newAdminOpsHandlerFixture(t *testing.T, handlerName string, response proto.Message) adminOpsHandlerFixture {
t.Helper()

dbSession := common.TestInitDB(t)
t.Cleanup(dbSession.Close)
common.TestSetupSchema(t, dbSession)

org := "test-org"
user := common.TestBuildUser(t, dbSession, "test-starfleet-id", org, []string{authz.ProviderAdminRole})
ip := common.TestBuildInfrastructureProvider(t, dbSession, "Test Infrastructure Provider", org, user)
site := common.TestBuildSite(t, dbSession, ip, "Test Site", user)
it := common.TestBuildInstanceType(t, dbSession, "test-instance-type", cutil.GetPtr(site.ID), site, nil, user)
machine := common.TestBuildMachine(t, dbSession, ip, site, &it.ID, cutil.GetPtr("test-controller-machine-type"), cdbm.MachineStatusReady)

proxiedReq := &coreproxy.Request{}
wrun := &tmocks.WorkflowRun{}
wrun.On("Get", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
if response == nil {
return
}
out := args.Get(1).(*coreproxy.Response)
respJSON, err := protojson.Marshal(response)
require.NoError(t, err)
out.ResponseJSON = respJSON
}).Return(nil)

tsc := &tmocks.Client{}
tsc.On(
"ExecuteWorkflow",
mock.Anything,
mock.Anything,
coreproxy.WorkflowName,
mock.MatchedBy(func(req coreproxy.Request) bool {
*proxiedReq = req
return true
}),
).Return(wrun, nil)

scp := sc.NewClientPool(nil)
scp.IDClientMap[site.ID.String()] = tsc

cfg := common.GetTestConfig()
var handler echo.HandlerFunc
switch handlerName {
case "health-list":
h := NewListMachineHealthReportHandler(dbSession, scp, cfg)
handler = h.Handle
case "health-insert":
h := NewInsertMachineHealthReportHandler(dbSession, scp, cfg)
handler = h.Handle
case "health-remove":
h := NewRemoveMachineHealthReportHandler(dbSession, scp, cfg)
handler = h.Handle
default:
t.Fatalf("unknown handler %q", handlerName)
}

return adminOpsHandlerFixture{
org: org,
siteID: site.ID.String(),
machineID: machine.ID,
user: user,
handler: handler,
proxiedReq: proxiedReq,
}
}

func (f adminOpsHandlerFixture) request(t *testing.T, method, target string, body any, source string) *httptest.ResponseRecorder {
t.Helper()

var reqBody string
if body != nil {
bodyBytes, err := json.Marshal(body)
require.NoError(t, err)
reqBody = string(bodyBytes)
}

e := echo.New()
req := httptest.NewRequest(method, target, strings.NewReader(reqBody))
if body != nil {
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
}
rec := httptest.NewRecorder()
ec := e.NewContext(req, rec)
names := []string{"orgName", "machineId"}
values := []string{f.org, f.machineID}
if source != "" {
names = append(names, "source")
values = append(values, source)
}
ec.SetParamNames(names...)
ec.SetParamValues(values...)
ec.Set("user", f.user)

require.NoError(t, f.handler(ec))
return rec
}

func TestListMachineHealthReportHandlerProxiesRequest(t *testing.T) {
fixture := newAdminOpsHandlerFixture(t, "health-list", &cwssaws.ListHealthReportResponse{
HealthReportEntries: []*cwssaws.HealthReportEntry{
{
Mode: cwssaws.HealthReportApplyMode_Merge,
Report: &cwssaws.HealthReport{
Source: "overrides.sre",
Alerts: []*cwssaws.HealthProbeAlert{{Id: "probe.alert", Message: "forced unhealthy"}},
},
},
},
})

rec := fixture.request(t, http.MethodGet, "/", nil, "")
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, cwssaws.Forge_ListMachineHealthReports_FullMethodName, fixture.proxiedReq.FullMethod)
assert.Empty(t, fixture.proxiedReq.EncryptedSecrets)

var coreReq cwssaws.MachineId
require.NoError(t, protojson.Unmarshal(fixture.proxiedReq.RequestJSON, &coreReq))
assert.Equal(t, fixture.machineID, coreReq.GetId())
assert.Contains(t, rec.Body.String(), "overrides.sre")
assert.NotContains(t, rec.Body.String(), "password")
}

func TestInsertMachineHealthReportHandlerProxiesRequest(t *testing.T) {
fixture := newAdminOpsHandlerFixture(t, "health-insert", nil)
req := model.APIMachineHealthReportEntry{
Source: "overrides.sre",
Mode: model.MachineHealthReportModeMerge,
Alerts: []model.APIMachineHealthProbeAlert{{ID: "probe.alert", Message: "forced unhealthy"}},
}

rec := fixture.request(t, http.MethodPut, "/", req, "")
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, cwssaws.Forge_InsertMachineHealthReport_FullMethodName, fixture.proxiedReq.FullMethod)
assert.Empty(t, fixture.proxiedReq.EncryptedSecrets)

var coreReq cwssaws.InsertMachineHealthReportRequest
require.NoError(t, protojson.Unmarshal(fixture.proxiedReq.RequestJSON, &coreReq))
assert.Equal(t, fixture.machineID, coreReq.GetMachineId().GetId())
assert.Equal(t, "overrides.sre", coreReq.GetHealthReportEntry().GetReport().GetSource())
assert.NotContains(t, rec.Body.String(), "password")
}

func TestInsertMachineHealthReportHandlerRejectsInvalidRequest(t *testing.T) {
fixture := newAdminOpsHandlerFixture(t, "health-insert", nil)

rec := fixture.request(t, http.MethodPut, "/", model.APIMachineHealthReportEntry{Mode: model.MachineHealthReportModeMerge}, "")
assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.Empty(t, fixture.proxiedReq.FullMethod)
}

func TestRemoveMachineHealthReportHandlerProxiesRequest(t *testing.T) {
fixture := newAdminOpsHandlerFixture(t, "health-remove", nil)

rec := fixture.request(t, http.MethodDelete, "/", nil, "overrides.sre")
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, cwssaws.Forge_RemoveMachineHealthReport_FullMethodName, fixture.proxiedReq.FullMethod)
assert.Empty(t, fixture.proxiedReq.EncryptedSecrets)

var coreReq cwssaws.RemoveMachineHealthReportRequest
require.NoError(t, protojson.Unmarshal(fixture.proxiedReq.RequestJSON, &coreReq))
assert.Equal(t, fixture.machineID, coreReq.GetMachineId().GetId())
assert.Equal(t, "overrides.sre", coreReq.GetSource())
assert.NotContains(t, rec.Body.String(), "password")
}
85 changes: 85 additions & 0 deletions rest-api/api/pkg/api/model/adminops_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package model

import (
"encoding/json"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

cwssaws "github.com/NVIDIA/infra-controller/rest-api/workflow-schema/schema/site-agent/workflows/v1"
)

func adminOpsStrPtr(s string) *string { return &s }

func TestAPIMachineHealthReportEntryValidateAndToProto(t *testing.T) {
observedAt := "2026-06-24T12:00:00Z"
inAlertSince := "2026-06-24T11:00:00Z"
req := APIMachineHealthReportEntry{
Source: "overrides.sre",
TriggeredBy: adminOpsStrPtr("operator"),
ObservedAt: &observedAt,
Mode: MachineHealthReportModeReplace,
Successes: []APIMachineHealthProbeSuccess{
{ID: "probe.ok", Target: adminOpsStrPtr("host")},
},
Alerts: []APIMachineHealthProbeAlert{
{
ID: "probe.alert",
Target: adminOpsStrPtr("gpu0"),
InAlertSince: &inAlertSince,
Message: "forced unhealthy",
TenantMessage: adminOpsStrPtr("maintenance"),
Classifications: []string{"maintenance"},
},
},
}
require.NoError(t, req.Validate())

protoReq := req.ToProto("machine-1")
assert.Equal(t, "machine-1", protoReq.GetMachineId().GetId())
entry := protoReq.GetHealthReportEntry()
require.NotNil(t, entry)
assert.Equal(t, cwssaws.HealthReportApplyMode_Replace, entry.GetMode())
report := entry.GetReport()
require.NotNil(t, report)
assert.Equal(t, "overrides.sre", report.GetSource())
assert.Equal(t, "operator", report.GetTriggeredBy())
assert.Equal(t, observedAt, report.GetObservedAt().AsTime().Format("2006-01-02T15:04:05Z07:00"))
require.Len(t, report.GetSuccesses(), 1)
assert.Equal(t, "probe.ok", report.GetSuccesses()[0].GetId())
require.Len(t, report.GetAlerts(), 1)
assert.Equal(t, "probe.alert", report.GetAlerts()[0].GetId())
assert.Equal(t, inAlertSince, report.GetAlerts()[0].GetInAlertSince().AsTime().Format("2006-01-02T15:04:05Z07:00"))

assert.Error(t, (&APIMachineHealthReportEntry{Mode: MachineHealthReportModeMerge}).Validate())
assert.Error(t, (&APIMachineHealthReportEntry{Source: "source", Mode: "merge"}).Validate())
assert.Error(t, (&APIMachineHealthReportEntry{Source: "source", Mode: MachineHealthReportModeMerge, ObservedAt: adminOpsStrPtr("bad-time")}).Validate())
assert.Error(t, (&APIMachineHealthReportEntry{Source: "source", Mode: MachineHealthReportModeMerge, Alerts: []APIMachineHealthProbeAlert{{ID: "alert"}}}).Validate())
}

func TestAPIMachineHealthReportListResponseFromProto(t *testing.T) {
resp := NewAPIMachineHealthReportListResponse("machine-1", &cwssaws.ListHealthReportResponse{
HealthReportEntries: []*cwssaws.HealthReportEntry{
{
Mode: cwssaws.HealthReportApplyMode_Merge,
Report: &cwssaws.HealthReport{
Source: "overrides.sre",
Alerts: []*cwssaws.HealthProbeAlert{{Id: "probe.alert", Message: "forced unhealthy"}},
},
},
},
})

assert.Equal(t, "machine-1", resp.MachineID)
require.Len(t, resp.HealthReportEntries, 1)
assert.Equal(t, MachineHealthReportModeMerge, resp.HealthReportEntries[0].Mode)
assert.Equal(t, "overrides.sre", resp.HealthReportEntries[0].Source)

body, err := json.Marshal(resp)
require.NoError(t, err)
assert.NotContains(t, string(body), "password")
}
Loading