diff --git a/cli/src/args.ts b/cli/src/args.ts index 7a156e54e..2028732c8 100644 --- a/cli/src/args.ts +++ b/cli/src/args.ts @@ -10,6 +10,10 @@ interface BenchmarkArguments { epochs: number roundDuration: number batchSize: number + validationSplit: number + epsilon?: number + delta?: number + dpDefaultClippingRadius?: number save: boolean host: URL } @@ -28,6 +32,10 @@ const unsafeArgs = parse( epochs: { type: Number, alias: 'e', description: 'Number of epochs', defaultValue: 10 }, roundDuration: { type: Number, alias: 'r', description: 'Round duration (in epochs)', defaultValue: 2 }, batchSize: { type: Number, alias: 'b', description: 'Training batch size', defaultValue: 10 }, + validationSplit : { type: Number, alias: 'v', description: 'Validation dataset ratio', defaultValue: 0.2 }, + epsilon: { type: Number, alias: 'n', description: 'Privacy budget', optional: true, defaultValue: undefined}, + delta: { type: Number, alias: 'd', description: 'Probability of failure, slack parameter', optional: true, defaultValue: undefined}, + dpDefaultClippingRadius: {type: Number, alias: 'f', description: 'Default clipping radius for DP', optional: true, defaultValue: undefined}, save: { type: Boolean, alias: 's', description: 'Save logs of benchmark', defaultValue: false }, host: { type: (raw: string) => new URL(raw), @@ -52,6 +60,7 @@ const supportedTasks = Map( defaultTasks.simpleFace, defaultTasks.titanic, defaultTasks.tinderDog, + defaultTasks.mnist, ).map( async (t) => [(await t.getTask()).id, t] as [ @@ -77,10 +86,29 @@ export const args: BenchmarkArguments = { task.trainingInformation.batchSize = unsafeArgs.batchSize; task.trainingInformation.roundDuration = unsafeArgs.roundDuration; task.trainingInformation.epochs = unsafeArgs.epochs; + task.trainingInformation.validationSplit = unsafeArgs.validationSplit; // For DP - // TASK.trainingInformation.clippingRadius = 10000000 - // TASK.trainingInformation.noiseScale = 0 + const {dpDefaultClippingRadius, epsilon, delta} = 
unsafeArgs; + + if ( + // dpDefaultClippingRadius !== undefined && + epsilon !== undefined && + delta !== undefined + ){ + if (task.trainingInformation.scheme === "local") + throw new Error("Can't have differential privacy for local training"); + + const defaultRadius = dpDefaultClippingRadius ? dpDefaultClippingRadius : 1; + + // for the case where privacy parameters are not defined in the default tasks + task.trainingInformation.privacy ??= {} + task.trainingInformation.privacy.differentialPrivacy = { + clippingRadius: defaultRadius, + epsilon: epsilon, + delta: delta, + }; + } return task; }, diff --git a/cli/src/cli.ts b/cli/src/cli.ts index b67f6fe81..6415e7406 100644 --- a/cli/src/cli.ts +++ b/cli/src/cli.ts @@ -51,7 +51,7 @@ async function main( console.log({ args }) const dataSplits = await Promise.all( - Range(0, numberOfUsers).map(async i => getTaskData(task.id, i)) + Range(0, numberOfUsers).map(async i => getTaskData(task.id, i, numberOfUsers)) ) const logs = await Promise.all( dataSplits.map(async data => await runUser(task, args.host, data as Dataset)) diff --git a/cli/src/data.ts b/cli/src/data.ts index 0369f79e4..e93426ee4 100644 --- a/cli/src/data.ts +++ b/cli/src/data.ts @@ -1,6 +1,7 @@ import path from "node:path"; -import { Dataset, processing } from "@epfml/discojs"; -import type { +import { promises as fs } from "fs"; +import { Dataset, processing, defaultTasks } from "@epfml/discojs"; +import { DataFormat, DataType, Image, @@ -9,7 +10,7 @@ import type { import { loadCSV, loadImage, loadImagesInDir } from "@epfml/discojs-node"; import { Repeat } from "immutable"; -async function loadSimpleFaceData(): Promise> { +async function loadSimpleFaceData(userIdx: number, totalClient: number): Promise> { const folder = path.join("..", "datasets", "simple_face"); const [adults, childs]: Dataset<[Image, string]>[] = [ @@ -17,10 +18,12 @@ async function loadSimpleFaceData(): Promise> { (await loadImagesInDir(path.join(folder, 
"child"))).zip(Repeat("child")), ]; - return adults.chain(childs); + const combinded = adults.chain(childs); + + return combinded.filter((_, i) => i % totalClient === userIdx); } -async function loadLusCovidData(): Promise> { +async function loadLusCovidData(userIdx: number, totalClient: number): Promise> { const folder = path.join("..", "datasets", "lus_covid"); const [positive, negative]: Dataset<[Image, string]>[] = [ @@ -32,7 +35,11 @@ async function loadLusCovidData(): Promise> { ), ]; - return positive.chain(negative); + const combined: Dataset<[Image, string]> = positive.chain(negative); + + const sharded = combined.filter((_, i) => i % totalClient === userIdx); + + return sharded; } function loadTinderDogData(split: number): Dataset { @@ -59,25 +66,97 @@ function loadTinderDogData(split: number): Dataset { }); } +async function loadExtCifar10(userIdx: number){ + const CIFAR10_LABELS = Array.from(await defaultTasks.cifar10.getTask().then(t => t.trainingInformation.LABEL_LIST)); + const folder = path.join("..", "datasets", "extended_cifar10"); + const clientFolder = path.join(folder, `client_${userIdx}`); + + return new Dataset<[Image, string]>(async function*(){ + const entries = await fs.readdir(clientFolder, {withFileTypes: true}); + + const items = entries + .flatMap((e) => { + const m = e.name.match( + /^image_(\d+)_label_(\d+)\.png$/i + ); + if (m === null) return []; + const labelIdx = Number.parseInt(m[2], 10); + + if( + !Number.isInteger(labelIdx) || + labelIdx < 0 || + labelIdx >= CIFAR10_LABELS.length + ){ + throw new Error('Not a valid label index.'); + } + + return { + name: e.name, + labelIdx, + }; + }) + .filter( + (x): x is {idx: number; name: string; labelIdx: number } => x !== null + ) + + for (const {name, labelIdx} of items){ + const label = CIFAR10_LABELS[labelIdx]; + const filePath = path.join(clientFolder, name); + const image = await loadImage(filePath); + yield [image, label] as [Image, string]; + } + }) +} + +function 
loadMnistData(split: number): Dataset{ + const folder = path.join("..", "datasets", "mnist", `${split + 1}`); + return loadCSV(path.join(folder, "labels.csv")) + .map( + (row) => + [ + processing.extractColumn(row, "filename"), + processing.extractColumn(row, "label"), + ] as const, + ) + .map(async ([filename, label]) => { + try { + const image = await Promise.any( + ["png", "jpg", "jpeg"].map((ext) => + loadImage(path.join(folder, `${filename}.${ext}`)), + ), + ); + return [image, label]; + } catch { + throw Error(`${filename} not found in ${folder}`); + } + }); +} + export async function getTaskData( taskID: Task.ID, userIdx: number, + totalClient: number ): Promise> { switch (taskID) { case "simple_face": - return (await loadSimpleFaceData()) as Dataset; + return (await loadSimpleFaceData(userIdx, totalClient)) as Dataset; case "titanic": - return loadCSV( + const titanicData = loadCSV( path.join("..", "datasets", "titanic_train.csv"), ) as Dataset; + return titanicData.filter((_, i) => i % totalClient === userIdx); case "cifar10": return ( await loadImagesInDir(path.join("..", "datasets", "CIFAR10")) ).zip(Repeat("cat")) as Dataset; case "lus_covid": - return (await loadLusCovidData()) as Dataset; + return (await loadLusCovidData(userIdx, totalClient)) as Dataset; case "tinder_dog": return loadTinderDogData(userIdx) as Dataset; + case "extended_cifar10": + return (await loadExtCifar10(userIdx)) as Dataset; + case "mnist": + return loadMnistData(userIdx) as Dataset; default: throw new Error(`Data loader for ${taskID} not implemented.`); } diff --git a/discojs/src/dataset/dataset.ts b/discojs/src/dataset/dataset.ts index 702e5cf2f..db71ef120 100644 --- a/discojs/src/dataset/dataset.ts +++ b/discojs/src/dataset/dataset.ts @@ -237,6 +237,60 @@ export class Dataset implements AsyncIterable { cached(): Dataset { return new CachingDataset(this.#content); } + + /** Shuffles the Dataset instance within certain window size */ + shuffle(windowSize: number){ + if 
(!Number.isInteger(windowSize) || windowSize < 1){ + throw new Error("Shuffle window size should be a positive integer"); + } + + return new Dataset( + async function*(this: Dataset){ + const iter = this[Symbol.asyncIterator](); + const buffer: T[] = []; + + // 1. Construct the initial buffer + while (buffer.length < windowSize){ + const n = await iter.next(); + if (n.done) break; + buffer.push(n.value); + } + + // 2. Shuffle + while (buffer.length > 0){ + const pick = Math.floor(Math.random() * buffer.length); + const chosen = buffer[pick]; + + const n = await iter.next(); + + if (n.done){ + // move the last element to the pick position + buffer[pick] = buffer.pop() as T; + }else{ + buffer[pick] = n.value; + } + + yield chosen; + } + }.bind(this) + ); + } + + /** filter the indices according to the splitting condition */ + filter( + condition: (value: T, index: number) => boolean | Promise + ): Dataset{ + return new Dataset(async function* (this: Dataset): AsyncGenerator{ + let i = 0; + for await(const v of this){ + if (await condition(v, i)){ + yield v; + } + i += 1 + } + }.bind(this)); + } + } /** diff --git a/discojs/src/default_tasks/cifar10.ts b/discojs/src/default_tasks/cifar10.ts index 56ca523d9..1cc341981 100644 --- a/discojs/src/default_tasks/cifar10.ts +++ b/discojs/src/default_tasks/cifar10.ts @@ -27,7 +27,7 @@ export const cifar10: TaskProvider<"image", "decentralized"> = { }, }, trainingInformation: { - epochs: 10, + epochs: 20, roundDuration: 10, validationSplit: 0.2, batchSize: 10, @@ -36,7 +36,13 @@ export const cifar10: TaskProvider<"image", "decentralized"> = { LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'], scheme: 'decentralized', aggregationStrategy: 'mean', - privacy: { clippingRadius: 20, noiseScale: 1 }, + privacy: { + differentialPrivacy: { + clippingRadius: 1, + epsilon: 50, + delta: 1e-5, + }, + }, minNbOfParticipants: 3, maxShareValue: 100, tensorBackend: 'tfjs' @@ -63,7 +69,7 @@ 
export const cifar10: TaskProvider<"image", "decentralized"> = { model.compile({ optimizer: 'sgd', loss: 'categoricalCrossentropy', - metrics: ['accuracy'] + metrics: ['accuracy'], }) return new models.TFJS('image', model) diff --git a/discojs/src/privacy.spec.ts b/discojs/src/privacy.spec.ts index 5ebabee04..d00762287 100644 --- a/discojs/src/privacy.spec.ts +++ b/discojs/src/privacy.spec.ts @@ -1,7 +1,10 @@ import { describe, expect, it } from "vitest"; import { WeightsContainer } from "./index.js"; -import { addNoise, clipNorm } from "./privacy.js"; +import { frobeniusNorm, clipNorm, addOptimalNoise, getClippingRadius } from "./privacy.js"; +import { WeightNormHistory } from "./training/trainer.js"; +import * as tf from "@tensorflow/tfjs"; +import { List } from "immutable"; async function WSIntoArrays(ws: WeightsContainer): Promise { return (await Promise.all(ws.weights.map(async (w) => await w.data()))).map( @@ -9,47 +12,82 @@ async function WSIntoArrays(ws: WeightsContainer): Promise { ); } -async function norm(weights: WeightsContainer): Promise { - return Math.sqrt( - ( - await weights - .map((w) => w.square().sum()) - .reduce((a, b) => a.add(b)) - .data() - )[0], - ); -} -describe("addNoise", () => { - it("returns the same weight vector with zero noise", async () => { - const result = addNoise(WeightsContainer.of([1, 2, 3]), 0); +/** Test the frobenius norm computation */ +describe("frobeniusNorm", () => { + it("computes Frobenius norm for a tf.Tensor", async () => { + const t = tf.tensor([3, 4]); + const n = await frobeniusNorm(t); + expect(n).toBeCloseTo(5, 1e-12); + }) - expect(await WSIntoArrays(result)).to.deep.equal([[1, 2, 3]]); + it("computes Frobenius norm for a WeightsContainer across all layers", async () => { + const wc = WeightsContainer.of([3, 4], [0, 6]); + const n = await frobeniusNorm(wc); + expect(n).toBeCloseTo(Math.sqrt(61), 1e-12); }); +}); - it("adds noise following a normal distribution", async () => { - const result = 
addNoise(WeightsContainer.of([5]), 1); +describe("clipNorm", () => { + it("clips a single-layer vector using single radius value", async () => { + const result = await clipNorm(WeightsContainer.of([2]), [1]); + expect(await WSIntoArrays(result)).toEqual([[1]]); + }) - expect((await WSIntoArrays(result))[0][0]).to.be.within(2, 8); // 99.7% of success + it("check if it does not change vector when it is already within radius", async () => { + // norm is smaller than the clipping radius 10 + const result = await clipNorm(WeightsContainer.of([3, 4]), [10]); + expect(await WSIntoArrays(result)).toEqual([[3, 4]]) + }) + + it("applying different clipping radii per layer", async () => { + const wc = WeightsContainer.of([3, 4], [0, 6]); + const result = await clipNorm(wc, [5, 3]); // apply different clipping radii for each layer + + expect(await WSIntoArrays(result)).toEqual([ + [3, 4], + [0, 3], + ]); }); -}); +}) -describe("clipNorm", () => { - it("reduce norm for a one-dimensional vector", async () => { - const result = await clipNorm(WeightsContainer.of([2]), 1); +describe("addOptimalNoise", () => { + it("check if the structure is maintained", async () => { + const weights = WeightsContainer.of([3, 4], [0, 6]); + const epsilon = 1; + const delta = 1e-5; + const radius = [5, 3]; - expect(await WSIntoArrays(result)).to.deep.equal([[1]]); + const result = await addOptimalNoise(weights, epsilon, delta, radius); + + const resultArrays = await WSIntoArrays(result); + + // Check the structures of the weights are maintained + expect(resultArrays[0].length).toBe(2); + expect(resultArrays[1].length).toBe(2); + + // Check the values are numbers + expect(Number.isFinite(resultArrays[0][0])).toBe(true); + expect(Number.isFinite(resultArrays[0][1])).toBe(true); + expect(Number.isFinite(resultArrays[1][0])).toBe(true); + expect(Number.isFinite(resultArrays[1][1])).toBe(true); }); +}) - it("keeps direction unchanged", async () => { - const result = await clipNorm( - 
WeightsContainer.of([2, 3, 6]), // norm = 7 - 1, - ); - const normScaler = 7 / (await norm(result)); +describe("getClippingRadius", () => { + it("correct average clipping radius and default radius", () => { + const weightNormHistory = List([ + List([2, 4, 6]), // expected average norm is 4 + List([10]) + ]); - expect( - await WSIntoArrays(result.map((w) => w.mul(normScaler))), - ).to.deep.equal([[2, 3, 6]]); + expect(getClippingRadius(weightNormHistory as WeightNormHistory, 5)).toEqual([4, 5]); }); -}); + + it("uses smaller window size automatically if needed", () => { + const weightNormHistory = List([List([2, 4])]); + + // Automatically use window size of 2 instead of 10 + expect(getClippingRadius(weightNormHistory as WeightNormHistory, 10)).toEqual([3]); + }); +}); \ No newline at end of file diff --git a/discojs/src/privacy.ts b/discojs/src/privacy.ts index 31a0b812c..3d866cdd9 100644 --- a/discojs/src/privacy.ts +++ b/discojs/src/privacy.ts @@ -1,39 +1,111 @@ import * as tf from "@tensorflow/tfjs"; -import type { WeightsContainer } from "./index.js"; +import { WeightsContainer } from "./index.js"; -async function frobeniusNorm(weights: WeightsContainer): Promise { - const squared = await weights - .map((w) => w.square().sum()) - .reduce((a, b) => a.add(b)) - .data(); - if (squared.length !== 1) throw new Error("unexpected weights shape"); +import { WeightNormHistory } from "./training/trainer.js"; - return Math.sqrt(squared[0]); +export async function frobeniusNorm(weights: WeightsContainer | tf.Tensor): Promise{ + /** + * Computes the Frobenius norm of the given weights. + * + * For a WeightsContainer, the Frob norm is computed over all tensors contained in the container. + * For a single tf.Tensor, the Frob norm is computed for that tensor (the weight of a single layer). 
+ */ + if (weights instanceof WeightsContainer){ + const squared = await weights + .map((w) => w.square().sum()) + .reduce((a, b) => a.add(b)) + .data(); + + if (squared.length !== 1) throw new Error("unexpected weights shape"); + return Math.sqrt(squared[0]); + } else { + const squared = await weights.square().sum().data(); + if (squared.length !== 1) throw new Error("unexpected weights shape"); + return Math.sqrt(squared[0]); + } } -/** Scramble weights */ -export function addNoise( - weights: WeightsContainer, - deviation: number, -): WeightsContainer { - const variance = Math.pow(deviation, 2); - return weights.map((w) => w.add(tf.randomNormal(w.shape, 0, variance))); +/** ALDP-FL implementation */ +// Conditions need to be added for the first three epochs -> get the avg update from all of the available previous updates +export function getClippingRadius(weightNormHistory: WeightNormHistory, defaultClippingRadius:number): number[]{ + const WINDOW_SIZE = 3; + const MIN_RADIUS = 1e-12; + + const radii = weightNormHistory.map((norms) => { + const recent = norms.slice(-WINDOW_SIZE); + const avg = recent.reduce((sum, n) => sum+n, 0) / recent.size; + + return Math.max(MIN_RADIUS, Math.min(avg, defaultClippingRadius)) + }); + + // Convert List to number[] + return radii.toArray(); +} + +/** Optimized Gaussian noise using a clipping radius calculation of ALDP-FL for adaptive local differential privacy in federated learning, + * https://www.nature.com/articles/s41598-025-12575-6 */ +/** Implementation of historical moving average based clipping radius calculation */ +export async function addOptimalNoise( + weightUpdates: WeightsContainer, + epsilon: number, + delta: number, + clippingRadius: number[], +): Promise { + /** + * In the original paper, the sensitivity is given as 2 * clippingRadius / d, though the meaning of d is unclear. + * We believe the L2 sensitivity of the gradient update is 2 * clippingRadius. 
+ */ + // apply different sensitivity and noise to each of the layer + // clippingRadius is now number[] + const sens = clippingRadius.map((r)=>(2*r)); + const sigmas = sens.map((s)=>(s * Math.sqrt(2*Math.log(1.25/delta))/epsilon)); + const clippedWeights = await clipNorm(weightUpdates, clippingRadius); + + return clippedWeights.map((w, i) => + w.add(tf.randomNormal(w.shape, 0, sigmas[i])) + ) } /** * Keep weights' norm within radius - * - * @param radius maximum norm **/ export async function clipNorm( weights: WeightsContainer, - radius: number, + radius: number[], ): Promise { - if (radius <= 0) throw new Error("invalid radius"); + /** + * If radius.length === 1, interpret radius[0] as a global clipping radius (BFT) + * If radius.length === numLayers, apply per-layer clipping (DP) + */ + const layers = weights.weights; + if (radius.length !== 1 && radius.length !== layers.length) + throw new Error(`radius length mismatch: got ${radius.length}, expected 1 or ${layers.length}`); + + if (radius.length === 1){ + const r = radius[0]; + if (!Number.isFinite(r) || r <= 0) + throw new Error("invalide radius"); + + const norm = await frobeniusNorm(weights); + const scaling = Math.max(1, norm / r); - const norm = await frobeniusNorm(weights); - const scaling = Math.max(1, norm / radius); + return weights.map((w) => w.div(scaling)); + }else{ + /** Apply different clipping radius to each layer in the WeightsContainer */ + const clipped = await Promise.all( + layers.map(async (l, i) => { + const norm = await frobeniusNorm(l); + const r = radius[i]; + + // Check the invalid radius value + if (!Number.isFinite(r) || r <= 0) + throw new Error("Invalid radius value") + const scaling = Math.max(1, norm / r); + return l.div(scaling); + }) + ); - return weights.map((w) => w.div(scaling)); + return new WeightsContainer(clipped); + } } diff --git a/discojs/src/task/training_information.ts b/discojs/src/task/training_information.ts index a708b9880..1d423a120 100644 --- 
a/discojs/src/task/training_information.ts +++ b/discojs/src/task/training_information.ts @@ -8,12 +8,21 @@ const nonLocalNetworkSchema = z.object({ privacy: z .object({ // maximum weights difference between each round - clippingRadius: z.number().optional(), - // variance of the Gaussian noise added to the shared weights. - noiseScale: z.number().optional(), + byzantineFaultTolerance: z.object({ + clippingRadius: z.number().positive(), + }).optional(), + + differentialPrivacy: z.object({ + // maximum weights difference between each epoch, used for differential privacy + clippingRadius: z.number().positive().default(1), + // privacy budget, used to compute the variance of Gaussian noise + epsilon: z.number().positive(), + // small probability that the privacy guarantee may not hold + delta: z.number().gt(0).lt(1), + }).optional(), }) .transform((o) => - o.clippingRadius === undefined && o.noiseScale === undefined + o.byzantineFaultTolerance === undefined && o.differentialPrivacy === undefined ? 
undefined : o, ) diff --git a/discojs/src/training/trainer.ts b/discojs/src/training/trainer.ts index ac48f4047..56a2bf255 100644 --- a/discojs/src/training/trainer.ts +++ b/discojs/src/training/trainer.ts @@ -1,7 +1,7 @@ import * as tf from "@tensorflow/tfjs"; import { List } from "immutable"; -import type { +import { Batched, BatchLogs, Dataset, @@ -22,6 +22,16 @@ export interface RoundLogs { participants: number; } +/** List of weight update norms */ +export type WeightNormHistory = List>; + +function appendWeightHistory(weightNormHistory: WeightNormHistory, wc: number[]){ + return wc.reduce((hist, t, i) => { + const arr = hist.get(i, List()); + return hist.set(i, arr.push(t)); + }, weightNormHistory); +} + /** Train a model and exchange with others **/ export class Trainer { readonly #client: Client; @@ -38,6 +48,9 @@ export class Trainer { AsyncGenerator, RoundLogs>, void >; + // Map of weight Index and weight update + #weightNormHistory : WeightNormHistory = List(); + #previousRoundWeights?: WeightsContainer; public get model(): Model { if (this.#model === undefined) @@ -94,24 +107,38 @@ export class Trainer { void > { const totalRound = Math.trunc(this.#epochs / this.#roundDuration); - let previousRoundWeights: WeightsContainer | undefined; for (let round = 0; round < totalRound; round++) { + await this.#client.onRoundBeginCommunication(); - yield this.#runRound(dataset, validationDataset); + // Store the clean weight before starting the communication + this.#previousRoundWeights = new WeightsContainer(this.model.weights.weights.map(t => t.clone())); - let localWeights = this.model.weights; - if (this.#privacy !== undefined) - localWeights = await applyPrivacy( - previousRoundWeights, - localWeights, - this.#privacy, - ); + yield this.#runRound(dataset, validationDataset); - const networkWeights = - await this.#client.onRoundEndCommunication(localWeights); + let roundWeights = this.model.weights; - this.model.weights = previousRoundWeights = networkWeights; 
+ // Apply differential privacy before sharing the weight updates with other nodes + if (this.#privacy !== undefined){ + const roundUpdate = roundWeights.sub(this.#previousRoundWeights); + const updateNorm = await Promise.all( + roundUpdate.weights.map((w) => privacy.frobeniusNorm(w)) + ); + this.#weightNormHistory = appendWeightHistory(this.#weightNormHistory, updateNorm); + + roundWeights = await applyOptimalPrivacy( + this.#previousRoundWeights, + roundWeights, + this.#privacy, + this.#weightNormHistory, + totalRound, + ) + } + // Get the updated weights + const networkWeights = await this.#client.onRoundEndCommunication(roundWeights); + + // Update the local weights + this.model.weights = networkWeights; } } @@ -125,10 +152,10 @@ export class Trainer { this.model.train(dataset, validationDataset), ); - yield gen; + yield gen; // batchLogs epochsLogs = epochsLogs.push(await epochLogs); } - + return { epochs: epochsLogs, participants: this.#client.nbOfParticipants, @@ -136,30 +163,45 @@ export class Trainer { } } -async function applyPrivacy( +/** ALDP-FL implementation */ +async function applyOptimalPrivacy( previous: WeightsContainer | undefined, current: WeightsContainer, - options: Exclude< - Task< - DataType, - "decentralized" | "federated" - >["trainingInformation"]["privacy"], - undefined - >, + options: Exclude["trainingInformation"]["privacy"], undefined>, + weightNormHistory: WeightNormHistory, + totalRound: number, ): Promise { let ret = current; - if (options.clippingRadius !== undefined) { - const previousRoundWeights = - previous ?? current.map((w) => tf.zerosLike(w)); + // Clipping radius for BFT + const bftOptions = options.byzantineFaultTolerance; + if (bftOptions !== undefined) { + // might need to change the variable name + const previousRoundWeights = previous ?? 
current.map((w) => tf.zerosLike(w)); const weightsProgress = current.sub(previousRoundWeights); ret = previousRoundWeights.add( - await privacy.clipNorm(weightsProgress, options.clippingRadius), + await privacy.clipNorm(weightsProgress, [bftOptions.clippingRadius]), // we should make an array containing clipping radius with weightsProgress shape ); } - if (options.noiseScale !== undefined) - ret = privacy.addNoise(ret, options.noiseScale); + // Adding Gaussian noise for DP + const dpOptions = options.differentialPrivacy; + if (dpOptions !== undefined){ + const dpDefaultRadius = dpOptions.clippingRadius; // options.dpDefaultClippingRadius should be a number + + // Divide privacy budget across all rounds (conservative composition) + const delta = dpOptions.delta / totalRound; + const epsilon = dpOptions.epsilon / totalRound; + const dpClippingRadius = privacy.getClippingRadius(weightNormHistory, dpDefaultRadius); + + const previousEpochWeights = previous ?? current.map((w) => tf.zerosLike(w)); + const weightsProgress = current.sub(previousEpochWeights); + + /** Need to use tighter clipping radius for noise calibration */ + const effectiveRadius = bftOptions ? dpClippingRadius.map(r => Math.min(r, bftOptions.clippingRadius)) : dpClippingRadius; + + ret = previousEpochWeights.add(await privacy.addOptimalNoise(weightsProgress, epsilon, delta, effectiveRadius)); + } return ret; -} +} \ No newline at end of file diff --git a/docs/PRIVACY.md b/docs/PRIVACY.md index 71ca5a4dc..d1e6a5d61 100644 --- a/docs/PRIVACY.md +++ b/docs/PRIVACY.md @@ -12,7 +12,55 @@ In addition to the intrinsic security of federated and decentralized learning, D Differential privacy methods protect any dataset(s) used in the training of a machine learning (ML) model, from inference attacks based on the weights of the resulting ML model. -The respective parameters `noiseScale` and `clippingRadius` are available in the [task configuration](TASK.md). 
+The respective parameters `epsilon`, `delta`, and `clippingRadius` are available in the [task configuration](TASK.md). + +### What is Differential Privacy? +Differential privacy (DP) is a rigorous privacy framework that provides a privacy guarantee by ensuring that an algorithm's output does not significantly change when a single data point in the dataset is modified. This protection is achieved by adding carefully calibrated random values (called "noise") to the data or model updates. + +In DISCO, differential privacy ensures privacy by making sure that the weight updates produced by one client do not significantly change when a single data point in that client's dataset is modified. This is called local differential privacy (LDP). Before sharing weight updates with the server, random noise is added to these updates. By examining only the weight updates that each client sends to the server, no party, including the server, can infer who generated a specific update or which datasets particular clients have. + +Differential privacy has an important parameter, epsilon ($\epsilon$), which indicates the privacy level applied to the learning process. It is also called the "privacy budget." + +### Parameter Explanations +Differential privacy is achieved by adding noise. To guarantee your desired privacy level, you need to specify several parameters: + +`epsilon` +- This is the privacy budget. The smaller the $\epsilon$ value, the stronger the privacy protection. In DISCO, this $\epsilon$ value indicates the privacy guarantee for a single client. + +`delta` +- This parameter indicates the failure probability of the privacy guarantee. It is used in approximate differential privacy, which DISCO implements. + +`clippingRadius` +- This parameter sets the maximum bound for the adaptive clipping radius. 
+ +### Privacy-utility trade-off +The utility degradation that follows from improving privacy is an inherent feature of differential privacy, so you must consider this when choosing your $\epsilon$ value. When $\epsilon$ equals 0, this guarantees perfect privacy but zero utility. As $\epsilon$ approaches infinity, privacy becomes zero and full utility is recovered. As $\epsilon$ decreases, utility degrades gradually. + +When we repeatedly run the same private algorithm, the privacy budget accumulates, resulting in a larger final privacy budget that indicates a weaker privacy guarantee. This is called "composition" of the privacy budget. This applies to DP in DISCO: since we add noise to weight updates at every round, the privacy budget accumulates with each round. The accumulation rate is determined by the total number of rounds, i.e. the number of epochs divided by the round duration defined in the task configuration. + +### What is the best $\epsilon$ value? +Choosing an appropriate $\epsilon$ value depends on your specific use case and requires careful consideration of the privacy-utility trade-off. + +- For local differential privacy (LDP), which DISCO implements, meaningful utility typically requires larger $\epsilon$ values compared to central differential privacy. In practice, LDP implementations often use $\epsilon$ values ranging from 5 to 20. Some implementations may use higher values, though this comes with a weaker privacy guarantee. +- Lower $\epsilon$ values (closer to 1) provide stronger privacy guarantees but may significantly reduce model accuracy or utility, which can make the final result meaningless. +- Higher $\epsilon$ values (above 20) may provide better model performance but offer weaker privacy protection. + +- The appropriate $\epsilon$ value for your task depends on several factors, listed below: 
+ - Your acceptable level of model accuracy degradation + - The number of rounds (due to privacy budget composition over rounds) + +- To provide context, here are examples of $\epsilon$ values used in real-world deployments: + - Apple's local differential privacy implementation for iOS and macOS uses $\epsilon$ = 16 for QuickType suggestions, with a privacy unit of user per day ([Apple Differential Privacy Overview](https://www.apple.com/privacy/docs/Differential_Privacy_Overview.pdf)) + - Microsoft's Windows telemetry collection uses local differential privacy with $\epsilon$ = 1.672, with a privacy unit of user per 6 hours ([Ding et al., 2017](https://www.microsoft.com/en-us/research/publication/collecting-telemetry-data-privately/)) + +### DISCO's Differential Privacy Implementation +Since model weights are shared for aggregation to converge to a final model in DISCO, we add DP noise to weight updates before sharing them with the server or other clients. This noise is calibrated from $\epsilon$, $\delta$, and `clippingRadius`. + + +To carefully calibrate the smallest possible noise for a given privacy guarantee, we implement window-based adaptive local differential privacy (ALDP). The ALDP process works as follows. + 1. Each round, before sharing the weight update with the server, we calibrate the noise using $\epsilon$, $\delta$, and a new adaptive clipping radius, which is the mean of the norms of the three most recent weight updates, capped at `clippingRadius`. This helps us find the optimal clipping radius that avoids over-calibrating the noise needed for the privacy guarantee. + 2. We add the calibrated noise to the current weight update and share it with the server. + 3. We store the norm of the weight update before noise addition to use for calibrating the clipping radius in the next round. 
## Secure aggregation through MPC diff --git a/docs/examples/custom_task.ts b/docs/examples/custom_task.ts index ab5873432..deb3c08a3 100644 --- a/docs/examples/custom_task.ts +++ b/docs/examples/custom_task.ts @@ -30,8 +30,7 @@ const customTask: TaskProvider<"tabular", "federated"> = { aggregationStrategy: "mean", minNbOfParticipants: 2, tensorBackend: 'tfjs', - noiseScale: undefined, - clippingRadius: undefined + privacy: undefined, } }); }, diff --git a/models/cifar10/model.json b/models/cifar10/model.json new file mode 100644 index 000000000..ac35d14ea Binary files /dev/null and b/models/cifar10/model.json differ diff --git a/models/llm_task/model.json b/models/llm_task/model.json new file mode 100644 index 000000000..bd36ef89a Binary files /dev/null and b/models/llm_task/model.json differ diff --git a/models/lus_covid/model.json b/models/lus_covid/model.json new file mode 100644 index 000000000..9371453e3 Binary files /dev/null and b/models/lus_covid/model.json differ diff --git a/models/titanic/model.json b/models/titanic/model.json new file mode 100644 index 000000000..cc4e7425a Binary files /dev/null and b/models/titanic/model.json differ diff --git a/server/tests/e2e/federated.spec.ts b/server/tests/e2e/federated.spec.ts index 016dd17d5..6c1d17195 100644 --- a/server/tests/e2e/federated.spec.ts +++ b/server/tests/e2e/federated.spec.ts @@ -314,4 +314,42 @@ describe("end-to-end federated", () => { await discoUser3.close(); }); + + /** + * Test if federated learning task lus_covid operates correctly with differential privacy + */ + it("three lus_covid clients meet consensus with differential privacy", { timeout: 1_000_000 }, async () => { + const task = await defaultTasks.lusCovid.getTask(); + task.trainingInformation = { + ...task.trainingInformation, + epochs: 20, + roundDuration: 10, + minNbOfParticipants: 3, + privacy: { + differentialPrivacy: { + epsilon: 50, + delta: 1e-5, + clippingRadius: 1, + } + } + }; + const url = await startServer({ + 
...defaultTasks.lusCovid, + getTask: () => Promise.resolve(task), + }); + const dataset = await datasets.loadLusCOVID(); + + const [[m1, l1], [m2, l2], [m3, l3]] = await Promise.all([ + runUser(url, task, dataset), + runUser(url, task, dataset), + runUser(url, task, dataset), + ]); + + for (const lastEpoch of [l1, l2, l3]) { + expect(lastEpoch.training.accuracy).to.be.greaterThan(0.5); + expect(lastEpoch.validation?.accuracy).to.be.greaterThan(0.5); + } + assert.isTrue(m1.equals(m2) && m2.equals(m3)); + }) }); + diff --git a/webapp/cypress/e2e/task-creation.cy.ts b/webapp/cypress/e2e/task-creation.cy.ts index 2871904fa..ccf9e3968 100644 --- a/webapp/cypress/e2e/task-creation.cy.ts +++ b/webapp/cypress/e2e/task-creation.cy.ts @@ -10,6 +10,8 @@ it("submits with tabular task", () => { cy.visit("/#/create"); + cy.get('form').should('be.visible'); // Wait for the form to be fully loaded + cy.get("input[name='id']").type("id"); cy.get("select[name='dataType']").select("tabular"); diff --git a/webapp/src/components/task_creation_form/TaskCreationForm.vue b/webapp/src/components/task_creation_form/TaskCreationForm.vue index 9d8bd5598..29b8b802d 100644 --- a/webapp/src/components/task_creation_form/TaskCreationForm.vue +++ b/webapp/src/components/task_creation_form/TaskCreationForm.vue @@ -362,14 +362,35 @@ label="Differential privacy" type="checkbox" > -
+
+ + + + + + @@ -382,13 +403,13 @@ label="Weight clipping" type="checkbox" > -
+
{ form.setFieldValue("trainingInformation.aggregationStrategy", "mean"); }); +watch(differentialPrivacy, (on) => { + if (!form.value) return; + if (!on) { + form.value.setFieldValue("trainingInformation.privacy.differentialPrivacy", undefined); + } +}); + +watch(weightClipping, (on) => { + if (!form.value) return; + if (!on){ + form.value.setFieldValue("trainingInformation.privacy.byzantineFaultTolerance", undefined); + } +}); + // warn user on page content loss window.onbeforeunload = (event) => { if (form.value === null || form.value.meta.dirty === false) return; @@ -601,15 +636,24 @@ window.onbeforeunload = (event) => { const nonLocalNetwork = { privacy: z .object({ - clippingRadius: z.number().optional(), - noiseScale: z.number().optional(), + byzantineFaultTolerance: z.object({ + clippingRadius: z.number().positive(), + }).optional(), + + differentialPrivacy: z.object({ + // optional on input, default value is 1 when missing + clippingRadius: z.number().positive().default(1), + // privacy budget epsilon + epsilon: z.number().positive(), + // DP delta is a small probability + delta: z.number().gt(0).lt(1), + }).optional() }) - .optional() .transform((arg, ctx) => { - if (!differentialPrivacy.value) return undefined; + if (!differentialPrivacy.value && !weightClipping.value) return undefined; - function addUndefIssue(field?: string): void { - const path = field !== undefined ? [field] : undefined; + function addUndefIssue(field?: string[]): void { + const path = field !== undefined ? 
field : undefined; ctx.addIssue({ code: "custom", message: "Required", @@ -618,16 +662,34 @@ const nonLocalNetwork = { } if (arg === undefined) { - addUndefIssue(); + addUndefIssue(); return z.NEVER; } - if (arg.clippingRadius === undefined) addUndefIssue("clippingRadius"); - if (arg.noiseScale === undefined) addUndefIssue("noiseScale"); - if (arg.clippingRadius === undefined || arg.noiseScale === undefined) - return z.NEVER; + + if (differentialPrivacy.value){ + if (arg.differentialPrivacy === undefined){ + addUndefIssue(["differentialPrivacy"]); + return z.NEVER; + } + if (arg.differentialPrivacy.epsilon === undefined) addUndefIssue(["differentialPrivacy", "epsilon"]); + if (arg.differentialPrivacy.delta === undefined) addUndefIssue(["differentialPrivacy", "delta"]); + if (arg.differentialPrivacy.epsilon === undefined || arg.differentialPrivacy.delta === undefined) + return z.NEVER; + } + + if (weightClipping.value){ + if (arg.byzantineFaultTolerance === undefined){ + addUndefIssue(["byzantineFaultTolerance"]); + return z.NEVER; + } + if (arg.byzantineFaultTolerance.clippingRadius === undefined){ + addUndefIssue(["byzantineFaultTolerance", "clippingRadius"]); + return z.NEVER; + } + } return arg; - }), + }).optional(), minNbOfParticipants: z.number().positive().int(), }; const trainingInformationNetworks = z.union([ diff --git a/webapp/src/components/training/TrainingDescription.vue b/webapp/src/components/training/TrainingDescription.vue index 3a1a58363..28482c530 100644 --- a/webapp/src/components/training/TrainingDescription.vue +++ b/webapp/src/components/training/TrainingDescription.vue @@ -70,13 +70,25 @@ Differential Privacy: Noise Scale - {{ task.trainingInformation.privacy?.noiseScale ?? "Unused" }} + {{ task.trainingInformation.privacy?.differentialPrivacy?.epsilon ?? "Unused" }} - Differential Privacy: Clipping Radius + Differential Privacy: Delta - {{ task.trainingInformation.privacy?.clippingRadius ?? 
"Unused" }} + {{ task.trainingInformation.privacy?.differentialPrivacy?.delta ?? "Unused" }} + + + + Differential Privacy: Default Clipping Radius + + {{ task.trainingInformation.privacy?.differentialPrivacy?.clippingRadius ?? "Unused" }} + + + + Byzantine Fault Tolerance: Clipping Radius + + {{ task.trainingInformation.privacy?.byzantineFaultTolerance?.clippingRadius ?? "Unused" }}