Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions server/db/pg/driver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { drizzle as DrizzlePostgres } from "drizzle-orm/node-postgres";
import { Pool } from "pg";
import { readConfigFile } from "@server/lib/readConfigFile";
import { withReplicas } from "drizzle-orm/pg-core";
import * as schema from "./schema";

function createDb() {
const config = readConfigFile();
Expand Down Expand Up @@ -45,7 +46,7 @@ function createDb() {
const replicas = [];

if (!replicaConnections.length) {
replicas.push(DrizzlePostgres(primaryPool));
replicas.push(DrizzlePostgres(primaryPool, { schema }));
} else {
for (const conn of replicaConnections) {
const replicaPool = new Pool({
Expand All @@ -54,11 +55,11 @@ function createDb() {
idleTimeoutMillis: 30000,
connectionTimeoutMillis: 2000,
});
replicas.push(DrizzlePostgres(replicaPool));
replicas.push(DrizzlePostgres(replicaPool, { schema }));
}
}

return withReplicas(DrizzlePostgres(primaryPool), replicas as any);
return withReplicas(DrizzlePostgres(primaryPool, { schema }), replicas as any);
}

export const db = createDb();
Expand Down
184 changes: 136 additions & 48 deletions server/routers/badger/verifySession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ import {
getResourceRules
} from "@server/db/queries/verifySessionQueries";
import {
db,
Resource,
ResourceAccessToken,
ResourcePassword,
ResourcePincode,
ResourceRule,
roles,
sessions,
users
} from "@server/db";
Expand All @@ -33,6 +35,7 @@ import NodeCache from "node-cache";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import { getCountryCodeForIp } from "@server/lib";
import { eq } from "drizzle-orm";

// We'll see if this speeds anything up
const cache = new NodeCache({
Expand Down Expand Up @@ -60,12 +63,14 @@ type BasicUserData = {
username: string;
email: string | null;
name: string | null;
role: string | null;
};

export type VerifyUserResponse = {
valid: boolean;
redirectUrl?: string;
userData?: BasicUserData;
headers?: Record<string, string>;
};

export async function verifyResourceSession(
Expand Down Expand Up @@ -99,23 +104,23 @@ export async function verifyResourceSession(

const clientIp = requestIp
? (() => {
logger.debug("Request IP:", { requestIp });
if (requestIp.startsWith("[") && requestIp.includes("]")) {
// if brackets are found, extract the IPv6 address from between the brackets
const ipv6Match = requestIp.match(/\[(.*?)\]/);
if (ipv6Match) {
return ipv6Match[1];
}
}

// ivp4
// split at last colon
const lastColonIndex = requestIp.lastIndexOf(":");
if (lastColonIndex !== -1) {
return requestIp.substring(0, lastColonIndex);
}
return requestIp;
})()
logger.debug("Request IP:", { requestIp });
if (requestIp.startsWith("[") && requestIp.includes("]")) {
// if brackets are found, extract the IPv6 address from between the brackets
const ipv6Match = requestIp.match(/\[(.*?)\]/);
if (ipv6Match) {
return ipv6Match[1];
}
}

// ivp4
// split at last colon
const lastColonIndex = requestIp.lastIndexOf(":");
if (lastColonIndex !== -1) {
return requestIp.substring(0, lastColonIndex);
}
return requestIp;
})()
: undefined;

logger.debug("Client IP:", { clientIp });
Expand All @@ -130,10 +135,10 @@ export async function verifyResourceSession(
const resourceCacheKey = `resource:${cleanHost}`;
let resourceData:
| {
resource: Resource | null;
pincode: ResourcePincode | null;
password: ResourcePassword | null;
}
resource: Resource | null;
pincode: ResourcePincode | null;
password: ResourcePassword | null;
}
| undefined = cache.get(resourceCacheKey);

if (!resourceData) {
Expand Down Expand Up @@ -172,7 +177,8 @@ export async function verifyResourceSession(

if (action == "ACCEPT") {
logger.debug("Resource allowed by rule");
return allowed(res);
const interpolatedHeaders = interpolateHeaders(resource.headers, undefined);
return allowed(res, undefined, interpolatedHeaders);
} else if (action == "DROP") {
logger.debug("Resource denied by rule");
return notAllowed(res);
Expand All @@ -193,7 +199,8 @@ export async function verifyResourceSession(
!resource.emailWhitelistEnabled
) {
logger.debug("Resource allowed because no auth");
return allowed(res);
const interpolatedHeaders = interpolateHeaders(resource.headers, undefined);
return allowed(res, undefined, interpolatedHeaders);
}

let endpoint: string;
Expand All @@ -213,21 +220,21 @@ export async function verifyResourceSession(
if (
headers &&
headers[
config.getRawConfig().server.resource_access_token_headers.id
config.getRawConfig().server.resource_access_token_headers.id
] &&
headers[
config.getRawConfig().server.resource_access_token_headers.token
config.getRawConfig().server.resource_access_token_headers.token
]
) {
const accessTokenId =
headers[
config.getRawConfig().server.resource_access_token_headers
.id
config.getRawConfig().server.resource_access_token_headers
.id
];
const accessToken =
headers[
config.getRawConfig().server.resource_access_token_headers
.token
config.getRawConfig().server.resource_access_token_headers
.token
];

const { valid, error, tokenItem } = await verifyResourceAccessToken(
Expand All @@ -253,7 +260,8 @@ export async function verifyResourceSession(
}

if (valid && tokenItem) {
return allowed(res);
const interpolatedHeaders = interpolateHeaders(resource.headers, undefined);
return allowed(res, undefined, interpolatedHeaders);
}
}

Expand Down Expand Up @@ -289,7 +297,8 @@ export async function verifyResourceSession(
}

if (valid && tokenItem) {
return allowed(res);
const interpolatedHeaders = interpolateHeaders(resource.headers, undefined);
return allowed(res, undefined, interpolatedHeaders);
}
}

Expand Down Expand Up @@ -323,6 +332,7 @@ export async function verifyResourceSession(
cache.set(sessionCacheKey, resourceSession);
}


if (resourceSession?.isRequestToken) {
logger.debug(
"Resource not allowed because session is a temporary request token"
Expand All @@ -339,17 +349,17 @@ export async function verifyResourceSession(

if (resourceSession) {
if (pincode && resourceSession.pincodeId) {
logger.debug(
"Resource allowed because pincode session is valid"
);
return allowed(res);
logger.debug("Resource allowed because pincode session is valid");
const interpolatedHeaders = interpolateHeaders(resource.headers, undefined);
return allowed(res, undefined, interpolatedHeaders);
}

if (password && resourceSession.passwordId) {
logger.debug(
"Resource allowed because password session is valid"
);
return allowed(res);
const interpolatedHeaders = interpolateHeaders(resource.headers, undefined);
return allowed(res, undefined, interpolatedHeaders);
}

if (
Expand All @@ -359,20 +369,22 @@ export async function verifyResourceSession(
logger.debug(
"Resource allowed because whitelist session is valid"
);
return allowed(res);
const interpolatedHeaders = interpolateHeaders(resource.headers, undefined);
return allowed(res, undefined, interpolatedHeaders);
}

if (resourceSession.accessTokenId) {
logger.debug(
"Resource allowed because access token session is valid"
);
return allowed(res);
const interpolatedHeaders = interpolateHeaders(resource.headers, undefined);
return allowed(res, undefined, interpolatedHeaders);
}

if (resourceSession.userSessionId && sso) {
const userAccessCacheKey = `userAccess:${
resourceSession.userSessionId
}:${resource.resourceId}`;
}:${resource.resourceId}`;

let allowedUserData: BasicUserData | null | undefined =
cache.get(userAccessCacheKey);
Expand All @@ -393,7 +405,8 @@ export async function verifyResourceSession(
logger.debug(
"Resource allowed because user session is valid"
);
return allowed(res, allowedUserData);
const interpolatedHeaders = interpolateHeaders(resource.headers, allowedUserData);
return allowed(res, allowedUserData, interpolatedHeaders);
}
}
}
Expand Down Expand Up @@ -426,7 +439,7 @@ function extractResourceSessionToken(
) {
const prefix = `${config.getRawConfig().server.session_cookie_name}${
ssl ? "_s" : ""
}`;
}`;

const all: { cookieName: string; token: string; priority: number }[] = [];

Expand Down Expand Up @@ -475,12 +488,13 @@ function notAllowed(res: Response, redirectUrl?: string) {
return response<VerifyUserResponse>(res, data);
}

function allowed(res: Response, userData?: BasicUserData) {
function allowed(res: Response, userData?: BasicUserData, headers?: Record<string, string>) {
const data = {
data:
userData !== undefined && userData !== null
? { valid: true, ...userData }
: { valid: true },
data: {
valid: true,
...(userData && { userData }),
...(headers && { responseHeaders: headers })
},
success: true,
error: false,
message: "Access allowed",
Expand Down Expand Up @@ -557,10 +571,14 @@ async function isUserAllowedToAccessResource(
);

if (roleResourceAccess) {
const role = await db.query.roles.findFirst({
where: eq(roles.roleId, userOrgRole.roleId)
});
return {
username: user.username,
email: user.email,
name: user.name
name: user.name,
role: role?.name || null
};
}

Expand All @@ -570,10 +588,14 @@ async function isUserAllowedToAccessResource(
);

if (userResourceAccess) {
const role = await db.query.roles.findFirst({
where: eq(roles.roleId, userOrgRole.roleId)
});
return {
username: user.username,
email: user.email,
name: user.name
name: user.name,
role: role?.name || null
};
}

Expand Down Expand Up @@ -771,3 +793,69 @@ async function isIpInGeoIP(ip: string, countryCode: string): Promise<boolean> {

return cachedCountryCode?.toUpperCase() === countryCode.toUpperCase();
}
function interpolateHeaders(
headerTemplate: string | null,
userData?: BasicUserData
): Record<string, string> | undefined {
if (!headerTemplate) return undefined;

let parsedHeaders: any;
try {
parsedHeaders = JSON.parse(headerTemplate);
} catch (e) {
logger.error("Failed to parse headers template:", e);
return undefined;
}

const interpolated: Record<string, string> = {};

// Sanitize function to prevent header injection
const sanitize = (value: string | null): string => {
if (!value) return '';
// Remove newlines and carriage returns to prevent header injection
return value.replace(/[\r\n]/g, '');
};

// Check if it's an array format (from UI) or object format
const headersArray = Array.isArray(parsedHeaders)
? parsedHeaders
: Object.values(parsedHeaders).filter(h => h && typeof h === 'object' && 'name' in h);

if (headersArray.length > 0) {
// Array format: [{"name": "x-header", "value": "{{username}}"}]
for (const header of headersArray) {
if (!header.name || !header.value) continue;

let interpolatedValue = header.value;

if (userData) {
interpolatedValue = interpolatedValue
.replace(/\{\{username\}\}/g, sanitize(userData.username))
.replace(/\{\{email\}\}/g, sanitize(userData.email))
.replace(/\{\{name\}\}/g, sanitize(userData.name))
.replace(/\{\{role\}\}/g, sanitize(userData.role));
}

interpolated[header.name] = interpolatedValue;
}
} else {
// Simple object format: {"X-Header": "{{username}}"}
for (const [key, value] of Object.entries(parsedHeaders)) {
if (typeof value !== 'string') continue;

let interpolatedValue = value;

if (userData) {
interpolatedValue = interpolatedValue
.replace(/\{\{username\}\}/g, sanitize(userData.username))
.replace(/\{\{email\}\}/g, sanitize(userData.email))
.replace(/\{\{name\}\}/g, sanitize(userData.name))
.replace(/\{\{role\}\}/g, sanitize(userData.role));
}

interpolated[key] = interpolatedValue;
}
}

return Object.keys(interpolated).length > 0 ? interpolated : undefined;
}