@@ -35,4 +35,4 @@ function draw() {
if (context === null) throw new Error("canvas doesn't support 2D context");
context.putImageData(props.image, 0, 0);
}
-
+
\ No newline at end of file
From 40de1e9b6b0447ecca0d77a379dd55fde8b5647f Mon Sep 17 00:00:00 2001
From: Julien Vignoud
Date: Mon, 25 Nov 2024 12:35:31 +0100
Subject: [PATCH 03/25] webapp/components/training/TrainingInformation: display
training charts on first row and validation charts on second row if any
---
.../training/TrainingInformation.vue | 46 +++++++++----------
1 file changed, 22 insertions(+), 24 deletions(-)
diff --git a/webapp/src/components/training/TrainingInformation.vue b/webapp/src/components/training/TrainingInformation.vue
index d291e0e59..f6763faa6 100644
--- a/webapp/src/components/training/TrainingInformation.vue
+++ b/webapp/src/components/training/TrainingInformation.vue
@@ -53,9 +53,8 @@
+ class="flex flex-col md:grid gap-4 md:gap-8 md:grid-cols-2"
+ >
Training Loss of the Model
@@ -76,56 +75,55 @@
/>
-
-
- Validation Loss of the Model
+
+
+ Training Accuracy of the Model
- {{ (lastEpoch?.validation?.loss ?? 0).toFixed(2) }}
+ {{ percent(lastEpoch?.training.accuracy ?? 0) }}
- validation loss
+ % of training accuracy
-
+
-
+
- Training Accuracy of the Model
+ Validation Loss of the Model
- {{ percent(lastEpoch?.training.accuracy ?? 0) }}
+ {{ (lastEpoch?.validation?.loss ?? 0).toFixed(2) }}
- % of training accuracy
+ validation loss
-
-
+
Validation Accuracy of the Model
@@ -361,4 +359,4 @@ const lossChartsOptions = computed(() => {
function percent(n: number): string {
return (n * 100).toFixed(2);
}
-
+
\ No newline at end of file
From 6916223873319733f3af51c3bf676bb199913f08 Mon Sep 17 00:00:00 2001
From: Julien Vignoud
Date: Mon, 25 Nov 2024 12:36:09 +0100
Subject: [PATCH 04/25] discojs/src/training/disco: handle validation split
ratio equals zero
---
discojs/src/training/disco.ts | 13 ++++++++-----
1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/discojs/src/training/disco.ts b/discojs/src/training/disco.ts
index 2ef90728e..07d37c68c 100644
--- a/discojs/src/training/disco.ts
+++ b/discojs/src/training/disco.ts
@@ -205,18 +205,21 @@ export class Disco extends EventEmitter<{
): Promise<
[
Dataset>,
- Dataset>,
+ Dataset> | undefined,
]
> {
const { batchSize, validationSplit } = this.#task.trainingInformation;
- const preprocessed = await processing.preprocess(this.#task, dataset);
+ let preprocessed = await processing.preprocess(this.#task, dataset);
- const [training, validation] = (
+ preprocessed = (
this.#preprocessOnce
? new Dataset(await arrayFromAsync(preprocessed))
: preprocessed
- ).split(validationSplit);
+ )
+ if (validationSplit === 0) return [preprocessed.batch(batchSize).cached(), undefined];
+
+ const [training, validation] = preprocessed.split(validationSplit);
return [
training.batch(batchSize).cached(),
@@ -230,4 +233,4 @@ async function arrayFromAsync(iter: AsyncIterable): Promise {
const ret: T[] = [];
for await (const e of iter) ret.push(e);
return ret;
-}
+}
\ No newline at end of file
From 4b22ef75d0a2edcefe9088fd6754c98435cb3e70 Mon Sep 17 00:00:00 2001
From: Julien Vignoud
Date: Mon, 25 Nov 2024 12:36:31 +0100
Subject: [PATCH 05/25] discojs/client/federated/federated_client: wait
indefinitely for server answer
---
discojs/src/client/federated/federated_client.ts | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/discojs/src/client/federated/federated_client.ts b/discojs/src/client/federated/federated_client.ts
index e6c89cc79..fc172d7b8 100644
--- a/discojs/src/client/federated/federated_client.ts
+++ b/discojs/src/client/federated/federated_client.ts
@@ -6,7 +6,6 @@ import { Client, shortenId } from "../client.js";
import { type, type ClientConnected } from "../messages.js";
import {
waitMessage,
- waitMessageWithTimeout,
WebSocketServer,
} from "../event_connection.js";
import * as messages from "./messages.js";
@@ -75,7 +74,7 @@ export class FederatedClient extends Client {
const {
id, waitForMoreParticipants, payload,
round, nbOfParticipants
- } = await waitMessageWithTimeout(this.server, type.NewFederatedNodeInfo);
+ } = await waitMessage(this.server, type.NewFederatedNodeInfo);
// This should come right after receiving the message to make sure
// we don't miss a subsequent message from the server
From 7a457a4727ac58eef88bfcd1e17907e82f3c99fa Mon Sep 17 00:00:00 2001
From: Julien Vignoud
Date: Mon, 25 Nov 2024 12:37:47 +0100
Subject: [PATCH 06/25] disco/default_tasks: implement GDHF tinder dog task
---
discojs/src/default_tasks/index.ts | 1 +
discojs/src/default_tasks/tinder_dog.ts | 84 +++++++++++++++++++++++++
2 files changed, 85 insertions(+)
create mode 100644 discojs/src/default_tasks/tinder_dog.ts
diff --git a/discojs/src/default_tasks/index.ts b/discojs/src/default_tasks/index.ts
index 7ee583f1f..43adf0d3c 100644
--- a/discojs/src/default_tasks/index.ts
+++ b/discojs/src/default_tasks/index.ts
@@ -4,3 +4,4 @@ export { mnist } from './mnist.js'
export { simpleFace } from './simple_face.js'
export { titanic } from './titanic.js'
export { wikitext } from './wikitext.js'
+export { tinderDog } from './tinder_dog.js'
\ No newline at end of file
diff --git a/discojs/src/default_tasks/tinder_dog.ts b/discojs/src/default_tasks/tinder_dog.ts
new file mode 100644
index 000000000..f308275ba
--- /dev/null
+++ b/discojs/src/default_tasks/tinder_dog.ts
@@ -0,0 +1,84 @@
+import * as tf from '@tensorflow/tfjs'
+
+import type { Model, Task, TaskProvider } from '../index.js'
+import { models } from '../index.js'
+
+export const tinderDog: TaskProvider<'image'> = { // task provider for the GDHF 2024 TinderDog demo (binary image classification)
+ getTask (): Task<'image'> {
+ return {
+ id: 'tinder_dog',
+ displayInformation: {
+ taskTitle: 'GDHF 2024 | TinderDog',
+ summary: {
+ preview: 'Which dog is the cutest....or not?',
+ overview: "Binary classification model for dog cuteness."
+ },
+ // model: 'The model is a pretrained MobileNetV1 model trained in Tensorflow.js. The last output layer is replaced with a fully connected layer with softmax activation and one output neuron per CIFAR10 category. The data preprocessing reshapes images into 224x224 pixels and normalizes values between 0 and 1. The neural network is optimized via Stochastic Gradient Descent and a categorical Cross Entropy loss.',
+ dataFormatInformation: 'Images should be of .png format.',
+ /* dataExampleText: 'Below you can find 10 random examples from each of the 10 classes in the dataset.',
+ dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/cifar10-example.png',
+ sampleDatasetLink: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data.tar.gz',
+ sampleDatasetInstructions: 'Opening the link should start downloading a zip file which you can unzip. To connect the data, use the CSV option below and select the file named "cifar10-labels.csv". You can now connect the images located in the "CIFAR10" folder. Note that there are only 24 images in this sample dataset which is far too few to successfully train a machine learning model.' */
+ },
+ trainingInformation: {
+ epochs: 10,
+ roundDuration: 2,
+ validationSplit: 0, // 0: nothing is held out for validation
+ batchSize: 10,
+ dataType: 'image',
+ IMAGE_H: 64,
+ IMAGE_W: 64,
+ LABEL_LIST: ['Cute dogs', 'Less cute dogs'],
+ scheme: 'federated',
+ aggregationStrategy: 'mean',
+ minNbOfParticipants: 3,
+ tensorBackend: 'tfjs'
+ }
+ }
+ },
+
+ // Builds and compiles the CNN: two 5x5 conv layers (8 then 16 filters), 2x2 max pooling, dropout, a 32-unit dense layer and a 2-way softmax output
+ async getModel(): Promise> {
+ const seed = 42 // same fixed seed for every kernel initializer and dropout layer
+ const imageHeight = this.getTask().trainingInformation.IMAGE_H
+ const imageWidth = this.getTask().trainingInformation.IMAGE_W
+ const imageChannels = 3 // presumably RGB -- TODO confirm against the image preprocessing
+
+ const model = tf.sequential()
+
+ model.add(
+ tf.layers.conv2d({
+ inputShape: [imageHeight, imageWidth, imageChannels],
+ kernelSize: 5,
+ filters: 8,
+ activation: 'relu',
+ kernelInitializer: tf.initializers.heNormal({ seed })
+ })
+ )
+ model.add(tf.layers.conv2d({
+ kernelSize: 5, filters: 16, activation: 'relu',
+ kernelInitializer: tf.initializers.heNormal({ seed })
+ }))
+ model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }))
+ model.add(tf.layers.dropout({ rate: 0.25, seed }))
+
+ model.add(tf.layers.flatten())
+ model.add(tf.layers.dense({
+ units: 32, activation: 'relu',
+ kernelInitializer: tf.initializers.heNormal({ seed })
+ }))
+ model.add(tf.layers.dropout({rate:0.25, seed}))
+ model.add(tf.layers.dense({
+ units: 2, activation: 'softmax', // two outputs, as many as there are LABEL_LIST entries
+ kernelInitializer: tf.initializers.heNormal({ seed })
+ }))
+
+ model.compile({
+ optimizer: tf.train.adam(0.0005),
+ loss: 'categoricalCrossentropy',
+ metrics: ['accuracy']
+ })
+
+ return Promise.resolve(new models.TFJS('image', model))
+ }
+}
\ No newline at end of file
From c573bb5a3fd403d6e965117981098c511462ede9 Mon Sep 17 00:00:00 2001
From: Julien Vignoud
Date: Mon, 25 Nov 2024 12:38:34 +0100
Subject: [PATCH 07/25] cli: setup tinder dog CLI support
---
cli/package.json | 3 ++-
cli/src/args.ts | 7 ++++---
cli/src/cli.ts | 38 +++++++++++++++++----------------
cli/src/data.ts | 51 ++++++++++++++++++++++++++++++++++++++++-----
datasets/.gitignore | 3 +++
package-lock.json | 7 ++++---
6 files changed, 79 insertions(+), 30 deletions(-)
diff --git a/cli/package.json b/cli/package.json
index cc2a1eefd..f338b7521 100644
--- a/cli/package.json
+++ b/cli/package.json
@@ -14,9 +14,10 @@
"author": "",
"license": "ISC",
"dependencies": {
- "server": "*",
"@epfml/discojs-node": "*",
+ "csv-parse": "^5.6.0",
"immutable": "4",
+ "server": "*",
"tslib": "2"
},
"devDependencies": {
diff --git a/cli/src/args.ts b/cli/src/args.ts
index 74c4ed5c7..aad21a9f1 100644
--- a/cli/src/args.ts
+++ b/cli/src/args.ts
@@ -22,8 +22,8 @@ const argExample = 'e.g. npm start -- -u 2 -e 3 # runs 2 users for 3 epochs'
const unsafeArgs = parse(
{
- task: { type: String, alias: 't', description: 'Task: titanic, simple_face, cifar10 or lus_covid', defaultValue: 'simple_face' },
- numberOfUsers: { type: Number, alias: 'u', description: 'Number of users', defaultValue: 1 },
+ task: { type: String, alias: 't', description: 'Task: tinder_dog, titanic, simple_face, cifar10 or lus_covid', defaultValue: 'tinder_dog' },
+ numberOfUsers: { type: Number, alias: 'u', description: 'Number of users', defaultValue: 2 },
epochs: { type: Number, alias: 'e', description: 'Number of epochs', defaultValue: 10 },
roundDuration: { type: Number, alias: 'r', description: 'Round duration (in epochs)', defaultValue: 2 },
batchSize: { type: Number, alias: 'b', description: 'Training batch size', defaultValue: 10 },
@@ -42,6 +42,7 @@ const supportedTasks = Map(
defaultTasks.lusCovid,
defaultTasks.simpleFace,
defaultTasks.titanic,
+ defaultTasks.tinderDog,
).map((t) => [t.getTask().id, t]),
);
@@ -69,4 +70,4 @@ export const args: BenchmarkArguments = {
},
getModel: () => provider.getModel(),
},
-};
+};
\ No newline at end of file
diff --git a/cli/src/cli.ts b/cli/src/cli.ts
index 54ff0c3ef..dbf646374 100644
--- a/cli/src/cli.ts
+++ b/cli/src/cli.ts
@@ -13,9 +13,8 @@ import type {
TaskProvider,
} from "@epfml/discojs";
import { Disco, aggregator as aggregators, client as clients } from '@epfml/discojs'
-import { Server } from 'server'
-import { getTaskData } from './data.js'
+import { getTaskData, loadTinderDogData } from './data.js'
import { args } from './args.js'
// Array.fromAsync not yet widely used (2024)
@@ -49,23 +48,26 @@ async function main(
console.log(`Started ${task.trainingInformation.scheme} training of ${task.id}`)
console.log({ args })
- const [server, url] = await new Server().serve(undefined, provider)
+ const url = new URL('http://localhost:8080/')
- const data = await getTaskData(task)
-
- const logs = await Promise.all(
- Range(0, numberOfUsers).map(async (_) => await runUser(task, url, data)).toArray()
- )
-
- if (args.save) {
- const fileName = `${task.id}_${numberOfUsers}users.csv`;
- await fs.writeFile(fileName, JSON.stringify(logs, null, 2));
+ if (task.id === 'tinder_dog') {
+ const dataSplits = await Promise.all(
+ Range(0, numberOfUsers).map(async i => loadTinderDogData(i))
+ )
+ const _ = await Promise.all(
+ dataSplits.map(async data => runUser(task, url, data as Dataset))
+ )
+ } else {
+ const data = await getTaskData(task)
+
+ const logs = await Promise.all(
+ Range(0, numberOfUsers).map(async (_) => await runUser(task, url, data)).toArray()
+ )
+ if (args.save) {
+ const fileName = `${task.id}_${numberOfUsers}users.csv`;
+ await fs.writeFile(fileName, JSON.stringify(logs, null, 2));
+ }
}
- console.log('Shutting down the server...')
- await new Promise((resolve, reject) => {
- server.once('close', resolve)
- server.close(reject)
- })
}
-main(args.provider, args.numberOfUsers).catch(console.error)
+main(args.provider, args.numberOfUsers).catch(console.error)
\ No newline at end of file
diff --git a/cli/src/data.ts b/cli/src/data.ts
index 895a35bf7..8a1e13f05 100644
--- a/cli/src/data.ts
+++ b/cli/src/data.ts
@@ -1,14 +1,15 @@
import path from "node:path";
-
+import fs from 'node:fs/promises'
+import { parse } from 'csv-parse';
+import { Dataset } from "@epfml/discojs";
import type {
- Dataset,
DataFormat,
DataType,
Image,
Task,
} from "@epfml/discojs";
-import { loadCSV, loadImagesInDir } from "@epfml/discojs-node";
-import { Repeat } from "immutable";
+import { loadCSV, loadImage, loadImagesInDir } from "@epfml/discojs-node";
+import { Repeat, Map } from "immutable";
async function loadSimpleFaceData(): Promise> {
const folder = path.join("..", "datasets", "simple_face");
@@ -36,6 +37,46 @@ async function loadLusCovidData(): Promise> {
return positive.chain(negative);
}
+export async function loadTinderDogData(split: number): Promise> {
+ const folder = path.join("..", "datasets", "tinder_dog", `${split + 1}`);
+ console.log(`Reading data split ${folder}`)
+ // Loads every image of the given split (0-based; folders on disk are numbered from 1) and pairs it
+ // with the label listed in the split's labels.csv. Throws if the CSV is unreadable or lacks an image.
+ const csvPath = path.join(folder, 'labels.csv')
+
+ const headers = ['filename', 'label'];
+ const fileContent = await fs.readFile(csvPath, { encoding: 'utf-8' });
+ // no `cast` option is passed to csv-parse, so labels come back as strings, not numbers
+ const csvContent = await new Promise<{ filename: string, label: string }[]>((resolve, reject) => {
+   parse(fileContent, {
+     delimiter: ',',
+     columns: headers,
+   }, (error, result: { filename: string, label: string }[]) => {
+     if (error) {
+       console.error(error);
+       reject(error)
+     } else resolve(result)
+   });
+ })
+ const imgToLabel = Map(csvContent.map(entry =>
+   [entry.filename, entry.label] as const)
+ );
+ const fileExtensions = [".png", ".jpg", ".jpeg"];
+ const imagesFile = (await fs.readdir(folder)).filter(file => fileExtensions.some(ext => file.endsWith(ext)))
+ const labels = imagesFile.map(img => {
+   // strip the extension with path.parse: it may be 4 or 5 characters long (".png" vs ".jpeg")
+   const label = imgToLabel.get(path.parse(img).name)
+   if (label === undefined) throw Error(`Image ${img} not found in CSV`)
+   return label
+ })
+ const imgPaths = imagesFile.map(imgName => path.join(folder, imgName))
+ console.log(`Found ${imgPaths.length} in split ${split}`)
+ const images = await Promise.all(imgPaths.map(imgPath => loadImage(imgPath)))
+
+ return new Dataset(images).zip(labels)
+}
+
+
export async function getTaskData(
task: Task,
): Promise> {
@@ -55,4 +96,4 @@ export async function getTaskData(
default:
throw new Error(`Data loader for ${task.id} not implemented.`);
}
-}
+}
\ No newline at end of file
diff --git a/datasets/.gitignore b/datasets/.gitignore
index d1d80c705..f73eda468 100644
--- a/datasets/.gitignore
+++ b/datasets/.gitignore
@@ -17,3 +17,6 @@
# LUS Covid
/lus_covid/
+
+# GDHF demo
+/tinder_dog/
diff --git a/package-lock.json b/package-lock.json
index 1ce9857ae..2748b8368 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -30,6 +30,7 @@
"license": "ISC",
"dependencies": {
"@epfml/discojs-node": "*",
+ "csv-parse": "^5.6.0",
"immutable": "4",
"server": "*",
"tslib": "2"
@@ -5230,9 +5231,9 @@
"license": "MIT"
},
"node_modules/csv-parse": {
- "version": "5.5.6",
- "resolved": "https://registry.npmjs.org/csv-parse/-/csv-parse-5.5.6.tgz",
- "integrity": "sha512-uNpm30m/AGSkLxxy7d9yRXpJQFrZzVWLFBkS+6ngPcZkw/5k3L/jjFuj7tVnEpRn+QgmiXr21nDlhCiUK4ij2A==",
+ "version": "5.6.0",
+ "resolved": "https://registry.npmjs.org/csv-parse/-/csv-parse-5.6.0.tgz",
+ "integrity": "sha512-l3nz3euub2QMg5ouu5U09Ew9Wf6/wQ8I++ch1loQ0ljmzhmfZYrH9fflS22i/PQEvsPvxCwxgz5q7UB8K1JO4Q==",
"license": "MIT"
},
"node_modules/cypress": {
From 0733ae0577fa543874aa61c8584ee658f63cec9e Mon Sep 17 00:00:00 2001
From: Julien Vignoud
Date: Tue, 26 Nov 2024 13:32:47 +0100
Subject: [PATCH 08/25] server/controllers/federated_controller: reset state
when training session finished
---
server/src/controllers/federated_controller.ts | 11 +++++++++--
1 file changed, 9 insertions(+), 2 deletions(-)
diff --git a/server/src/controllers/federated_controller.ts b/server/src/controllers/federated_controller.ts
index c58a05328..0e5f92d71 100644
--- a/server/src/controllers/federated_controller.ts
+++ b/server/src/controllers/federated_controller.ts
@@ -32,9 +32,9 @@ export class FederatedController<
*/
#latestGlobalWeights: serialization.Encoded;
- constructor(task: Task, initialWeights: serialization.Encoded) {
+ constructor(task: Task, private readonly initialWeights: serialization.Encoded) {
super(task)
- this.#latestGlobalWeights = initialWeights
+ this.#latestGlobalWeights = this.initialWeights
// Save the latest weight updates to be able to send it to new or outdated clients
this.#aggregator.on('aggregation', async (weightUpdate) => {
@@ -145,6 +145,13 @@ export class FederatedController<
this.#aggregator.removeNode(clientId)
debug("client [%s] left", shortId)
+ // Reset the training session when all participants left
+ if (this.connections.size === 0) {
+ debug("All participants left. Resetting the training session")
+ this.#aggregator = new aggregators.MeanAggregator(undefined, 1, 'relative')
+ this.#latestGlobalWeights = this.initialWeights
+ }
+
// Check if we dropped below the minimum number of participant required
// or if we are already waiting for new participants to join
if (this.connections.size >= minNbOfParticipants ||
From bb6242d21cf3e572d296fcefb5bd7558b9939a98 Mon Sep 17 00:00:00 2001
From: Julien Vignoud
Date: Wed, 27 Nov 2024 14:28:20 +0100
Subject: [PATCH 09/25]
webapp/components/dataset_input/LabeledImageDatasetInput/ByGroup: shuffle
files connected by labels
---
.../dataset_input/LabeledImageDatasetInput/ByGroup.vue | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/webapp/src/components/dataset_input/LabeledImageDatasetInput/ByGroup.vue b/webapp/src/components/dataset_input/LabeledImageDatasetInput/ByGroup.vue
index f556367fb..6c1b2a030 100644
--- a/webapp/src/components/dataset_input/LabeledImageDatasetInput/ByGroup.vue
+++ b/webapp/src/components/dataset_input/LabeledImageDatasetInput/ByGroup.vue
@@ -67,7 +67,13 @@ function refreshWatcher() {
([label, files]) =>
files.value?.map((f) => [label, f] as const)?.toArray() ?? [],
);
-
+ // shuffle the filenames in place (Fisher-Yates) o.w. they are ordered by labels
+ for (let i = 0; i < expanded.length; i++) {
+   const j = Math.floor(Math.random() * (i + 1)) // j in [0, i]; excluding i would only ever yield cyclic permutations
+   const swap = expanded[i]
+   expanded[i] = expanded[j]
+   expanded[j] = swap
+ }
dataset.value = new Dataset(expanded).map(async ([label, file]) => ({
filename: file.name,
image: await loadImage(file),
From da536eb43080b4c31336fc018e3058e3a890323a Mon Sep 17 00:00:00 2001
From: Julien Vignoud
Date: Wed, 27 Nov 2024 17:35:32 +0100
Subject: [PATCH 10/25] discojs/default_tasks/tinder_dog: improve textual
description and link sample dataset
---
discojs/src/default_tasks/tinder_dog.ts | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/discojs/src/default_tasks/tinder_dog.ts b/discojs/src/default_tasks/tinder_dog.ts
index f308275ba..3cc266ea4 100644
--- a/discojs/src/default_tasks/tinder_dog.ts
+++ b/discojs/src/default_tasks/tinder_dog.ts
@@ -13,12 +13,12 @@ export const tinderDog: TaskProvider<'image'> = {
preview: 'Which dog is the cutest....or not?',
overview: "Binary classification model for dog cuteness."
},
- // model: 'The model is a pretrained MobileNetV1 model trained in Tensorflow.js. The last output layer is replaced with a fully connected layer with softmax activation and one output neuron per CIFAR10 category. The data preprocessing reshapes images into 224x224 pixels and normalizes values between 0 and 1. The neural network is optimized via Stochastic Gradient Descent and a categorical Cross Entropy loss.',
- dataFormatInformation: 'Images should be of .png format.',
- /* dataExampleText: 'Below you can find 10 random examples from each of the 10 classes in the dataset.',
- dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/cifar10-example.png',
- sampleDatasetLink: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data.tar.gz',
- sampleDatasetInstructions: 'Opening the link should start downloading a zip file which you can unzip. To connect the data, use the CSV option below and select the file named "cifar10-labels.csv". You can now connect the images located in the "CIFAR10" folder. Note that there are only 24 images in this sample dataset which is far too few to successfully train a machine learning model.' */
+ model: 'The model is a simple Convolutional Neural Network composed of two convolutional layers with ReLU activations and max pooling layers, followed by a fully connected output layer. The data preprocessing reshapes images into 64x64 pixels and normalizes values between 0 and 1',
+ dataFormatInformation: 'Accepted image formats are .png .jpg and .jpeg.',
+ dataExampleText: '',
+ dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/tinder_dog_preview.png',
+ sampleDatasetLink: 'https://storage.googleapis.com/deai-313515.appspot.com/tinder_dog.zip',
+ sampleDatasetInstructions: 'Opening the link should start downloading a zip file which you can unzip. To connect the data, pick one of the data splits (the folder 0 for example) and use the CSV option below to select the file named "labels.csv". You can now connect the images located in the same folder.'
},
trainingInformation: {
epochs: 10,
From b0be74599eab9a93f1d34b653b1f3269114dd902 Mon Sep 17 00:00:00 2001
From: Julien Vignoud
Date: Thu, 28 Nov 2024 11:32:46 +0100
Subject: [PATCH 11/25] discojs/models: use trainOnBatch instead of fit and
fitDataset
---
discojs/src/models/gpt/index.ts | 33 ++++----
discojs/src/models/gpt/model.ts | 141 +++++++++++---------------------
discojs/src/models/tfjs.ts | 25 +++---
3 files changed, 74 insertions(+), 125 deletions(-)
diff --git a/discojs/src/models/gpt/index.ts b/discojs/src/models/gpt/index.ts
index dbdcca5d7..ebb3494bd 100644
--- a/discojs/src/models/gpt/index.ts
+++ b/discojs/src/models/gpt/index.ts
@@ -80,24 +80,21 @@ export class GPT extends Model<"text"> {
async #runBatch(
batch: Batched,
): Promise {
- const tfBatch = this.#batchToTF(batch);
-
- let logs: tf.Logs | undefined;
- await this.model.fitDataset(tf.data.array([tfBatch]), {
- epochs: 1,
- verbose: 0, // don't pollute
- callbacks: {
- onEpochEnd: (_, cur) => {
- logs = cur;
- },
- },
- });
- tf.dispose(tfBatch);
- if (logs === undefined) throw new Error("batch didn't gave any logs");
-
- const { loss, acc: accuracy } = logs;
- if (loss === undefined || isNaN(loss))
- throw new Error("training loss is undefined or NaN");
+ const {xs, ys} = this.#batchToTF(batch);
+
+ const history = await this.model.trainOnBatch(xs, ys);
+ tf.dispose([xs, ys]);
+ if (!Array.isArray(history) || history.length != 2)
+ throw new Error("training output has unexpected shape")
+
+ const loss = history[0]
+ const accuracy = history[1]
+
+ if (
+ typeof loss !== "number" || isNaN(loss) ||
+ typeof accuracy !== "number" || isNaN(accuracy)
+ )
+ throw new Error("training loss or accuracy is undefined or NaN");
return {
accuracy,
diff --git a/discojs/src/models/gpt/model.ts b/discojs/src/models/gpt/model.ts
index 9a359b494..7aaf43a3d 100644
--- a/discojs/src/models/gpt/model.ts
+++ b/discojs/src/models/gpt/model.ts
@@ -4,7 +4,6 @@ import * as tf from '@tensorflow/tfjs'
import type { GPTConfig } from './config.js'
import { getModelSizes, DEFAULT_CONFIG } from './config.js'
import { getCustomAdam, clipByGlobalNormObj } from './optimizers.js'
-import evaluate from './evaluate.js'
import { GPTArchitecture } from './layers.js'
const debug = createDebug("discojs:models:gpt");
@@ -55,101 +54,57 @@ export class GPTModel extends tf.LayersModel {
: tf.train.adam(this.config.lr)
}
- override async fitDataset(dataset: Dataset, trainingArgs: tf.ModelFitDatasetArgs): Promise {
- const callbacks = trainingArgs.callbacks as tf.CustomCallbackArgs
- const evalDataset = trainingArgs.validationData as tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }>
- await callbacks.onTrainBegin?.()
-
- for (let epoch = 1; epoch <= trainingArgs.epochs; epoch++) {
- let accuracyFraction: [number, number] = [0, 0];
- let averageLoss = 0
- let iteration = 1
- const iterator = await dataset.iterator()
- let next = await iterator.next()
-
- while (next.done !== true && iteration <= this.config.maxIter) {
- let weightUpdateTime = performance.now()
- await callbacks.onEpochBegin?.(epoch)
- const { xs, ys } = next.value as { xs: tf.Tensor2D, ys: tf.Tensor3D }
+ override async trainOnBatch(x: tf.Tensor, y: tf.Tensor): Promise {
+ let weightUpdateTime = performance.now()
- let preprocessingTime = performance.now()
- await Promise.all([xs.data(), ys.data()])
- preprocessingTime = performance.now() - preprocessingTime
+ let preprocessingTime = performance.now()
+ await Promise.all([x.data(), y.data()])
+ preprocessingTime = performance.now() - preprocessingTime
- // TODO include as a tensor inside the model
- const accTensor = tf.tidy(() => {
- const logits = this.apply(xs)
- if (Array.isArray(logits))
- throw new Error('model outputs too many tensor')
- if (logits instanceof tf.SymbolicTensor)
- throw new Error('model outputs symbolic tensor')
- return tf.metrics.categoricalAccuracy(ys, logits)
- })
- const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
- const accSumTensor = accTensor.sum()
- const accSum = await accSumTensor.array()
- tf.dispose(accSumTensor)
- if (typeof accSum !== 'number')
- throw new Error('got multiple accuracy sum')
- accuracyFraction = [accuracyFraction[0] + accSum, accuracyFraction[1] + accSize];
- tf.dispose([accTensor])
+ // TODO include as a tensor inside the model
+ const accTensor = tf.tidy(() => {
+ const logits = this.apply(x)
+ if (Array.isArray(logits))
+ throw new Error('model outputs too many tensor')
+ if (logits instanceof tf.SymbolicTensor)
+ throw new Error('model outputs symbolic tensor')
+ return tf.metrics.categoricalAccuracy(y, logits)
+ })
+ const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
+ const accSumTensor = accTensor.sum()
+ const accSum = await accSumTensor.array()
+ tf.dispose(accSumTensor)
+ if (typeof accSum !== 'number')
+ throw new Error('got multiple accuracy sum')
+ tf.dispose([accTensor])
- const lossTensor = tf.tidy(() => {
- const { grads, value: lossTensor } = this.optimizer.computeGradients(() => {
- const logits = this.apply(xs)
- if (Array.isArray(logits))
- throw new Error('model outputs too many tensor')
- if (logits instanceof tf.SymbolicTensor)
- throw new Error('model outputs symbolic tensor')
- return tf.losses.softmaxCrossEntropy(ys, logits)
- })
- const gradsClipped = clipByGlobalNormObj(grads, 1)
- this.optimizer.applyGradients(gradsClipped)
- return lossTensor
- })
-
- const loss = await lossTensor.array()
- averageLoss += loss
- weightUpdateTime = performance.now() - weightUpdateTime
+ const lossTensor = tf.tidy(() => {
+ const { grads, value: lossTensor } = this.optimizer.computeGradients(() => {
+ const logits = this.apply(x)
+ if (Array.isArray(logits))
+ throw new Error('model outputs too many tensor')
+ if (logits instanceof tf.SymbolicTensor)
+ throw new Error('model outputs symbolic tensor')
+ return tf.losses.softmaxCrossEntropy(y, logits)
+ })
+ const gradsClipped = clipByGlobalNormObj(grads, 1)
+ this.optimizer.applyGradients(gradsClipped)
+ return lossTensor
+ })
+
+ const loss = await lossTensor.array()
+ weightUpdateTime = performance.now() - weightUpdateTime
- tf.dispose([xs, ys, lossTensor])
-
- if (
- evalDataset !== undefined &&
- this.config.evaluateEvery !== undefined &&
- iteration % this.config.evaluateEvery == 0
- ){
- const iterationLogs = await evaluate(this, evalDataset, this.config.maxEvalBatches)
- debug('evaluation metrics: %O', iterationLogs);
- }
- const memory = tf.memory().numBytes / 1024 / 1024 / 1024
- debug("training metrics: %O", {
- epoch,
- iteration,
- loss,
- memory,
- allocated: tf.memory().numTensors,
- preprocessingTime,
- weightUpdateTime,
- });
- iteration++
- next = await iterator.next()
- }
- // Memory leak: If we reached the last iteration rather than the end of the dataset, cleanup the tensors
- if (next.done !== true && iteration > this.config.maxIter) {
- const { xs, ys } = next.value as { xs: tf.Tensor2D, ys: tf.Tensor3D }
- tf.dispose([xs, ys])
- }
- let logs: tf.Logs = {
- 'loss': averageLoss / iteration,
- 'acc': accuracyFraction[0] / accuracyFraction[1],
- }
- if (evalDataset !== undefined) {
- logs = { ...logs, ...await evaluate(this, evalDataset, this.config.maxEvalBatches) }
- }
- await callbacks.onEpochEnd?.(epoch, logs)
- }
- await callbacks.onTrainEnd?.()
- return new tf.History()
+ tf.dispose([x, y, lossTensor])
+
+ const memory = tf.memory().numBytes / 1024 / 1024 / 1024
+ debug("training metrics: %O", {
+ loss,
+ memory,
+ allocated: tf.memory().numTensors,
+ preprocessingTime,
+ weightUpdateTime,
+ });
+ return [loss, accSum / accSize]
}
}
diff --git a/discojs/src/models/tfjs.ts b/discojs/src/models/tfjs.ts
index c0f9cadc8..b2ccc6723 100644
--- a/discojs/src/models/tfjs.ts
+++ b/discojs/src/models/tfjs.ts
@@ -64,25 +64,22 @@ export class TFJS extends Model {
): Promise> {
const { xs, ys } = this.#batchToTF(batch);
- const { history } = await this.model.fit(xs, ys, {
- epochs: 1,
- verbose: 0, // don't pollute
- });
-
- const { loss: losses, acc: accuracies } = history;
+ const history = await this.model.trainOnBatch(xs, ys);
+ if (!Array.isArray(history) || history.length != 2)
+ throw new Error("training output has unexpected shape")
+
+ const loss = history[0]
+ const accuracy = history[1]
+
if (
- losses === undefined ||
- accuracies === undefined ||
- typeof losses[0] !== "number" ||
- typeof accuracies[0] !== "number" ||
- isNaN(losses[0]) ||
- isNaN(accuracies[0])
+ typeof loss !== "number" || isNaN(loss) ||
+ typeof accuracy !== "number" || isNaN(accuracy)
)
throw new Error("training loss or accuracy is undefined or NaN");
return {
- accuracy: accuracies[0],
- loss: losses[0],
+ accuracy,
+ loss,
memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
};
}
From 83a12d7ae9b478e1a5f377e689478a2453027849 Mon Sep 17 00:00:00 2001
From: Julien Vignoud
Date: Thu, 28 Nov 2024 11:33:14 +0100
Subject: [PATCH 12/25] cli/src: support wikitext task
---
cli/src/args.ts | 5 +++--
cli/src/data.ts | 11 ++++++++++-
2 files changed, 13 insertions(+), 3 deletions(-)
diff --git a/cli/src/args.ts b/cli/src/args.ts
index aad21a9f1..fc0c4ccf7 100644
--- a/cli/src/args.ts
+++ b/cli/src/args.ts
@@ -22,7 +22,7 @@ const argExample = 'e.g. npm start -- -u 2 -e 3 # runs 2 users for 3 epochs'
const unsafeArgs = parse(
{
- task: { type: String, alias: 't', description: 'Task: tinder_dog, titanic, simple_face, cifar10 or lus_covid', defaultValue: 'tinder_dog' },
+ task: { type: String, alias: 't', description: 'Task: tinder_dog, titanic, simple_face, cifar10, llm_task or lus_covid', defaultValue: 'tinder_dog' },
numberOfUsers: { type: Number, alias: 'u', description: 'Number of users', defaultValue: 2 },
epochs: { type: Number, alias: 'e', description: 'Number of epochs', defaultValue: 10 },
roundDuration: { type: Number, alias: 'r', description: 'Round duration (in epochs)', defaultValue: 2 },
@@ -37,12 +37,13 @@ const unsafeArgs = parse(
)
const supportedTasks = Map(
- Set.of | TaskProvider<"tabular">>(
+ Set.of | TaskProvider<"tabular"> | TaskProvider<"text">>(
defaultTasks.cifar10,
defaultTasks.lusCovid,
defaultTasks.simpleFace,
defaultTasks.titanic,
defaultTasks.tinderDog,
+ defaultTasks.wikitext,
).map((t) => [t.getTask().id, t]),
);
diff --git a/cli/src/data.ts b/cli/src/data.ts
index 8a1e13f05..c5ee247a4 100644
--- a/cli/src/data.ts
+++ b/cli/src/data.ts
@@ -7,8 +7,9 @@ import type {
DataType,
Image,
Task,
+ Text,
} from "@epfml/discojs";
-import { loadCSV, loadImage, loadImagesInDir } from "@epfml/discojs-node";
+import { loadCSV, loadImage, loadImagesInDir, loadText } from "@epfml/discojs-node";
import { Repeat, Map } from "immutable";
async function loadSimpleFaceData(): Promise> {
@@ -37,6 +38,12 @@ async function loadLusCovidData(): Promise> {
return positive.chain(negative);
}
+async function loadWikitextData(): Promise> {
+ const folder = path.join("..", "datasets", "wikitext");
+ const dataset: Dataset = loadText(path.join(folder, "wiki.train.tokens"))
+ return Promise.resolve(dataset)
+}
+
export async function loadTinderDogData(split: number): Promise> {
const folder = path.join("..", "datasets", "tinder_dog", `${split + 1}`);
console.log(`Reading data split ${folder}`)
@@ -93,6 +100,8 @@ export async function getTaskData(
).zip(Repeat("cat")) as Dataset;
case "lus_covid":
return (await loadLusCovidData()) as Dataset;
+ case "llm_task":
+ return (await loadWikitextData()) as Dataset;
default:
throw new Error(`Data loader for ${task.id} not implemented.`);
}
From c52ef111a61822ad37eb40dbf11528d37a3d3840 Mon Sep 17 00:00:00 2001
From: Julien Vignoud
Date: Thu, 28 Nov 2024 11:45:12 +0100
Subject: [PATCH 13/25] discojs/models/gpt: compute logits only once, 10%
faster
---
discojs/src/models/gpt/model.ts | 29 ++++++++++++-----------------
1 file changed, 12 insertions(+), 17 deletions(-)
diff --git a/discojs/src/models/gpt/model.ts b/discojs/src/models/gpt/model.ts
index 7aaf43a3d..42aed7e6e 100644
--- a/discojs/src/models/gpt/model.ts
+++ b/discojs/src/models/gpt/model.ts
@@ -61,23 +61,7 @@ export class GPTModel extends tf.LayersModel {
await Promise.all([x.data(), y.data()])
preprocessingTime = performance.now() - preprocessingTime
- // TODO include as a tensor inside the model
- const accTensor = tf.tidy(() => {
- const logits = this.apply(x)
- if (Array.isArray(logits))
- throw new Error('model outputs too many tensor')
- if (logits instanceof tf.SymbolicTensor)
- throw new Error('model outputs symbolic tensor')
- return tf.metrics.categoricalAccuracy(y, logits)
- })
- const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
- const accSumTensor = accTensor.sum()
- const accSum = await accSumTensor.array()
- tf.dispose(accSumTensor)
- if (typeof accSum !== 'number')
- throw new Error('got multiple accuracy sum')
- tf.dispose([accTensor])
-
+ let logitsTensor: tf.Tensor;
const lossTensor = tf.tidy(() => {
const { grads, value: lossTensor } = this.optimizer.computeGradients(() => {
const logits = this.apply(x)
@@ -85,12 +69,23 @@ export class GPTModel extends tf.LayersModel {
throw new Error('model outputs too many tensor')
if (logits instanceof tf.SymbolicTensor)
throw new Error('model outputs symbolic tensor')
+ logitsTensor = tf.keep(logits)
return tf.losses.softmaxCrossEntropy(y, logits)
})
const gradsClipped = clipByGlobalNormObj(grads, 1)
this.optimizer.applyGradients(gradsClipped)
return lossTensor
})
+
+ // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+ const accTensor = tf.metrics.categoricalAccuracy(y, logitsTensor)
+ const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
+ const accSumTensor = accTensor.sum()
+ const accSum = await accSumTensor.array()
+ if (typeof accSum !== 'number')
+ throw new Error('got multiple accuracy sum')
+ // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+ tf.dispose([accTensor, accSumTensor, logitsTensor])
const loss = await lossTensor.array()
weightUpdateTime = performance.now() - weightUpdateTime
From d287795d48c9d86164082c0991db03368da6c206 Mon Sep 17 00:00:00 2001
From: Julien Vignoud
Date: Thu, 28 Nov 2024 12:23:15 +0100
Subject: [PATCH 14/25] fixup! discojs/models: use trainOnBatch instead of fit
and fitDataset
---
discojs/src/models/gpt/index.ts | 23 +++--------------------
discojs/src/models/gpt/model.ts | 1 +
discojs/src/models/model.ts | 22 ++++++++++++++++++++++
discojs/src/models/tfjs.ts | 24 ++++--------------------
4 files changed, 30 insertions(+), 40 deletions(-)
diff --git a/discojs/src/models/gpt/index.ts b/discojs/src/models/gpt/index.ts
index ebb3494bd..43f95fe3b 100644
--- a/discojs/src/models/gpt/index.ts
+++ b/discojs/src/models/gpt/index.ts
@@ -81,26 +81,9 @@ export class GPT extends Model<"text"> {
batch: Batched,
): Promise {
const {xs, ys} = this.#batchToTF(batch);
-
- const history = await this.model.trainOnBatch(xs, ys);
- tf.dispose([xs, ys]);
- if (!Array.isArray(history) || history.length != 2)
- throw new Error("training output has unexpected shape")
-
- const loss = history[0]
- const accuracy = history[1]
-
- if (
- typeof loss !== "number" || isNaN(loss) ||
- typeof accuracy !== "number" || isNaN(accuracy)
- )
- throw new Error("training loss or accuracy is undefined or NaN");
-
- return {
- accuracy,
- loss,
- memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
- };
+ const logs = await this.model.trainOnBatch(xs, ys);
+ tf.dispose([xs, ys])
+ return this.getBatchLogs(logs)
}
async #evaluate(
diff --git a/discojs/src/models/gpt/model.ts b/discojs/src/models/gpt/model.ts
index 42aed7e6e..2557ce855 100644
--- a/discojs/src/models/gpt/model.ts
+++ b/discojs/src/models/gpt/model.ts
@@ -77,6 +77,7 @@ export class GPTModel extends tf.LayersModel {
return lossTensor
})
+ // TODO: replace accuracy by perplexity
// @ts-expect-error Variable 'logitsTensor' is used before being assigned
const accTensor = tf.metrics.categoricalAccuracy(y, logitsTensor)
const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
diff --git a/discojs/src/models/model.ts b/discojs/src/models/model.ts
index dd7c0477c..6bbc4e70c 100644
--- a/discojs/src/models/model.ts
+++ b/discojs/src/models/model.ts
@@ -6,6 +6,8 @@ import type {
WeightsContainer,
} from "../index.js";
+import * as tf from "@tensorflow/tfjs";
+
import type { BatchLogs, EpochLogs } from "./logs.js";
/**
@@ -39,6 +41,26 @@ export abstract class Model implements Disposable {
batch: Batched,
): Promise>;
+ protected getBatchLogs(
+ logs: number | number[],
+ ): BatchLogs {
+ if (!Array.isArray(logs) || logs.length != 2)
+ throw new Error("training output has unexpected shape")
+
+ const [loss, accuracy] = logs
+
+ if (
+ typeof loss !== "number" || isNaN(loss) ||
+ typeof accuracy !== "number" || isNaN(accuracy)
+ )
+ throw new Error("training loss or accuracy is undefined or NaN");
+
+ return {
+ accuracy,
+ loss,
+ memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
+ };
+ }
/**
* This method is automatically called to cleanup the memory occupied by the model
* when leaving the definition scope if the instance has been defined with the `using` keyword.
diff --git a/discojs/src/models/tfjs.ts b/discojs/src/models/tfjs.ts
index b2ccc6723..4d6908569 100644
--- a/discojs/src/models/tfjs.ts
+++ b/discojs/src/models/tfjs.ts
@@ -61,27 +61,11 @@ export class TFJS extends Model {
async #runBatch(
batch: Batched,
- ): Promise> {
+ ): Promise {
const { xs, ys } = this.#batchToTF(batch);
-
- const history = await this.model.trainOnBatch(xs, ys);
- if (!Array.isArray(history) || history.length != 2)
- throw new Error("training output has unexpected shape")
-
- const loss = history[0]
- const accuracy = history[1]
-
- if (
- typeof loss !== "number" || isNaN(loss) ||
- typeof accuracy !== "number" || isNaN(accuracy)
- )
- throw new Error("training loss or accuracy is undefined or NaN");
-
- return {
- accuracy,
- loss,
- memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
- };
+ const logs = await this.model.trainOnBatch(xs, ys);
+ tf.dispose([xs, ys])
+ return this.getBatchLogs(logs)
}
async #evaluate(
From 2391c1bc240969d0996a9d7d1768ba3d5c579692 Mon Sep 17 00:00:00 2001
From: Julien Vignoud
Date: Thu, 28 Nov 2024 16:12:58 +0100
Subject: [PATCH 15/25] tmp: overriding weight update yields same as default
---
discojs/src/default_tasks/lus_covid.ts | 17 +++---
discojs/src/models/tfjs.ts | 72 ++++++++++++++++++++++++--
2 files changed, 78 insertions(+), 11 deletions(-)
diff --git a/discojs/src/default_tasks/lus_covid.ts b/discojs/src/default_tasks/lus_covid.ts
index 44dd46ed6..8331ac0d8 100644
--- a/discojs/src/default_tasks/lus_covid.ts
+++ b/discojs/src/default_tasks/lus_covid.ts
@@ -39,7 +39,8 @@ export const lusCovid: TaskProvider<'image'> = {
// Model architecture from tensorflow.js docs:
// https://codelabs.developers.google.com/codelabs/tfjs-training-classfication/index.html#4
- async getModel (): Promise> {
+ async getModel(): Promise> {
+ const seed = 42
const imageHeight = 100
const imageWidth = 100
const imageChannels = 3
@@ -55,7 +56,7 @@ export const lusCovid: TaskProvider<'image'> = {
filters: 8,
strides: 1,
activation: 'relu',
- kernelInitializer: 'varianceScaling'
+ kernelInitializer: tf.initializers.heNormal({ seed })
}))
// The MaxPooling layer acts as a sort of downsampling using max values
@@ -69,7 +70,7 @@ export const lusCovid: TaskProvider<'image'> = {
filters: 16,
strides: 1,
activation: 'relu',
- kernelInitializer: 'varianceScaling'
+ kernelInitializer: tf.initializers.heNormal({ seed })
}))
model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }))
@@ -82,16 +83,16 @@ export const lusCovid: TaskProvider<'image'> = {
// output class.
model.add(tf.layers.dense({
units: numOutputClasses,
- kernelInitializer: 'varianceScaling',
- activation: 'softmax'
+ activation: 'softmax',
+ kernelInitializer: tf.initializers.heNormal({ seed })
}))
-
+
model.compile({
- optimizer: 'sgd',
+ optimizer: tf.train.sgd(0.001),
loss: 'binaryCrossentropy',
metrics: ['accuracy']
})
return Promise.resolve(new models.TFJS('image', model))
}
-}
+}
\ No newline at end of file
diff --git a/discojs/src/models/tfjs.ts b/discojs/src/models/tfjs.ts
index 4d6908569..09f329097 100644
--- a/discojs/src/models/tfjs.ts
+++ b/discojs/src/models/tfjs.ts
@@ -1,3 +1,4 @@
+import createDebug from "debug";
import { List, Map, Range } from "immutable";
import * as tf from '@tensorflow/tfjs'
@@ -13,6 +14,8 @@ import { BatchLogs } from './index.js'
import { Model } from './index.js'
import { EpochLogs } from './logs.js'
+const debug = createDebug("discojs:models:tfjs");
+
type Serialized = [D, tf.io.ModelArtifacts];
/** TensorFlow JavaScript model with standard training */
@@ -63,11 +66,71 @@ export class TFJS extends Model {
batch: Batched,
): Promise {
const { xs, ys } = this.#batchToTF(batch);
- const logs = await this.model.trainOnBatch(xs, ys);
+ const logs = await this.trainFedProx(xs, ys);
+ // const logs = await this.model.trainOnBatch(xs, ys);
tf.dispose([xs, ys])
return this.getBatchLogs(logs)
}
+ async trainFedProx(
+ xs: tf.Tensor, ys: tf.Tensor): Promise<[number, number]> {
+ // let logitsTensor: tf.Tensor;
+ debug(this.model.loss, this.model.losses, this.model.lossFunctions)
+ const lossFunction: () => tf.Scalar = () => {
+ this.model.apply(xs)
+ const logits = this.model.apply(xs)
+ if (Array.isArray(logits))
+ throw new Error('model outputs too many tensor')
+ if (logits instanceof tf.SymbolicTensor)
+ throw new Error('model outputs symbolic tensor')
+ // logitsTensor = tf.keep(logits)
+ // return tf.losses.softmaxCrossEntropy(ys, logits)
+ let y: tf.Tensor;
+ y = tf.clipByValue(logits, 0.00001, 1 - 0.00001);
+ y = tf.log(tf.div(y, tf.sub(1, y)));
+ return tf.losses.sigmoidCrossEntropy(ys, y);
+ // return tf.losses.sigmoidCrossEntropy(ys, logits)
+ }
+ const lossTensor = this.model.optimizer.minimize(lossFunction, true)
+ if (lossTensor === null) throw new Error("loss should not be null")
+ // const lossTensor = tf.tidy(() => {
+ // const { grads, value: lossTensor } = this.model.optimizer.computeGradients(() => {
+ // const logits = this.model.apply(xs)
+ // if (Array.isArray(logits))
+ // throw new Error('model outputs too many tensor')
+ // if (logits instanceof tf.SymbolicTensor)
+ // throw new Error('model outputs symbolic tensor')
+ // logitsTensor = tf.keep(logits)
+ // // return tf.losses.softmaxCrossEntropy(ys, logits)
+ // return this.model.calculateLosses(ys, logits)[0]
+ // })
+ // this.model.optimizer.applyGradients(grads)
+ // return lossTensor
+ // })
+
+ // // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+ // const accTensor = tf.metrics.categoricalAccuracy(ys, logitsTensor)
+ // const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
+ // const accSumTensor = accTensor.sum()
+ // const accSum = await accSumTensor.array()
+ // if (typeof accSum !== 'number')
+ // throw new Error('got multiple accuracy sum')
+ // // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+ // tf.dispose([accTensor, accSumTensor, logitsTensor])
+
+ const loss = await lossTensor.array()
+ tf.dispose([xs, ys, lossTensor])
+
+ // const memory = tf.memory().numBytes / 1024 / 1024 / 1024
+ // debug("training metrics: %O", {
+ // loss,
+ // memory,
+ // allocated: tf.memory().numTensors,
+ // });
+ return [loss, 0]
+ // return [loss, accSum / accSize]
+ }
+
async #evaluate(
dataset: Dataset>,
): Promise> {
@@ -160,7 +223,10 @@ export class TFJS extends Model {
return new this(
datatype,
await tf.loadLayersModel({
- load: () => Promise.resolve(artifacts),
+ load: () => {
+ console.log("deserialize called")
+ return Promise.resolve(artifacts)
+ },
}),
);
}
@@ -187,7 +253,7 @@ export class TFJS extends Model {
return [this.datatype, await ret]
}
- [Symbol.dispose](): void{
+ [Symbol.dispose](): void {
this.model.dispose()
}
From a3dfaf4afda85a96b806335fbdd36f861cc8c5e0 Mon Sep 17 00:00:00 2001
From: Julien Vignoud
Date: Thu, 28 Nov 2024 16:57:20 +0100
Subject: [PATCH 16/25] tmp: sketch of fedprox implementation
---
discojs/src/models/model.ts | 4 ++
discojs/src/models/tfjs.ts | 88 ++++++++++++++++-----------------
discojs/src/training/trainer.ts | 3 +-
3 files changed, 49 insertions(+), 46 deletions(-)
diff --git a/discojs/src/models/model.ts b/discojs/src/models/model.ts
index 6bbc4e70c..b7e51eab7 100644
--- a/discojs/src/models/model.ts
+++ b/discojs/src/models/model.ts
@@ -17,12 +17,16 @@ import type { BatchLogs, EpochLogs } from "./logs.js";
**/
// TODO make it typesafe: same shape of data/input/weights
export abstract class Model implements Disposable {
+ protected prevRoundWeights: WeightsContainer | undefined;
// TODO don't allow external access but upgrade train to return weights on every epoch
/** Return training state */
abstract get weights(): WeightsContainer;
/** Set training state */
abstract set weights(ws: WeightsContainer);
+ set previousRoundWeights(ws: WeightsContainer | undefined) {
+ this.prevRoundWeights = ws
+ }
/**
* Improve predictor
*
diff --git a/discojs/src/models/tfjs.ts b/discojs/src/models/tfjs.ts
index 09f329097..16c1727ad 100644
--- a/discojs/src/models/tfjs.ts
+++ b/discojs/src/models/tfjs.ts
@@ -73,62 +73,60 @@ export class TFJS extends Model {
}
async trainFedProx(
- xs: tf.Tensor, ys: tf.Tensor): Promise<[number, number]> {
- // let logitsTensor: tf.Tensor;
- debug(this.model.loss, this.model.losses, this.model.lossFunctions)
+ xs: tf.Tensor, ys: tf.Tensor,
+ ): Promise<[number, number]> {
+ let logitsTensor: tf.Tensor;
const lossFunction: () => tf.Scalar = () => {
+ // Proximal term
+ let proximalTerm = tf.tensor(0)
+ if (this.prevRoundWeights !== undefined) {
+ // squared norm
+ const norm = new WeightsContainer(this.model.getWeights())
+ .sub(this.prevRoundWeights)
+ .map(t => t.square().sum())
+ .reduce((t, acc) => tf.add(t, acc)).asScalar()
+ const mu = 1
+ proximalTerm = tf.mul(mu / 2, norm)
+ }
+
this.model.apply(xs)
const logits = this.model.apply(xs)
- if (Array.isArray(logits))
- throw new Error('model outputs too many tensor')
- if (logits instanceof tf.SymbolicTensor)
- throw new Error('model outputs symbolic tensor')
- // logitsTensor = tf.keep(logits)
- // return tf.losses.softmaxCrossEntropy(ys, logits)
- let y: tf.Tensor;
- y = tf.clipByValue(logits, 0.00001, 1 - 0.00001);
- y = tf.log(tf.div(y, tf.sub(1, y)));
- return tf.losses.sigmoidCrossEntropy(ys, y);
- // return tf.losses.sigmoidCrossEntropy(ys, logits)
+ if (Array.isArray(logits))
+ throw new Error('model outputs too many tensor')
+ if (logits instanceof tf.SymbolicTensor)
+ throw new Error('model outputs symbolic tensor')
+ logitsTensor = tf.keep(logits)
+ // binaryCrossEntropy
+ let y: tf.Tensor;
+ y = tf.clipByValue(logits, 0.00001, 1 - 0.00001);
+ y = tf.log(tf.div(y, tf.sub(1, y)));
+ const loss = tf.losses.sigmoidCrossEntropy(ys, y);
+ console.log(loss.dataSync(), proximalTerm.dataSync())
+ return tf.add(loss, proximalTerm)
}
const lossTensor = this.model.optimizer.minimize(lossFunction, true)
if (lossTensor === null) throw new Error("loss should not be null")
- // const lossTensor = tf.tidy(() => {
- // const { grads, value: lossTensor } = this.model.optimizer.computeGradients(() => {
- // const logits = this.model.apply(xs)
- // if (Array.isArray(logits))
- // throw new Error('model outputs too many tensor')
- // if (logits instanceof tf.SymbolicTensor)
- // throw new Error('model outputs symbolic tensor')
- // logitsTensor = tf.keep(logits)
- // // return tf.losses.softmaxCrossEntropy(ys, logits)
- // return this.model.calculateLosses(ys, logits)[0]
- // })
- // this.model.optimizer.applyGradients(grads)
- // return lossTensor
- // })
- // // @ts-expect-error Variable 'logitsTensor' is used before being assigned
- // const accTensor = tf.metrics.categoricalAccuracy(ys, logitsTensor)
- // const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
- // const accSumTensor = accTensor.sum()
- // const accSum = await accSumTensor.array()
- // if (typeof accSum !== 'number')
- // throw new Error('got multiple accuracy sum')
- // // @ts-expect-error Variable 'logitsTensor' is used before being assigned
- // tf.dispose([accTensor, accSumTensor, logitsTensor])
+ // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+ const accTensor = tf.metrics.categoricalAccuracy(ys, logitsTensor)
+ const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
+ const accSumTensor = accTensor.sum()
+ const accSum = await accSumTensor.array()
+ if (typeof accSum !== 'number')
+ throw new Error('got multiple accuracy sum')
+ // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+ tf.dispose([accTensor, accSumTensor, logitsTensor])
const loss = await lossTensor.array()
tf.dispose([xs, ys, lossTensor])
- // const memory = tf.memory().numBytes / 1024 / 1024 / 1024
- // debug("training metrics: %O", {
- // loss,
- // memory,
- // allocated: tf.memory().numTensors,
- // });
- return [loss, 0]
- // return [loss, accSum / accSize]
+ const memory = tf.memory().numBytes / 1024 / 1024 / 1024
+ debug("training metrics: %O", {
+ loss,
+ memory,
+ allocated: tf.memory().numTensors,
+ });
+ return [loss, accSum / accSize]
}
async #evaluate(
diff --git a/discojs/src/training/trainer.ts b/discojs/src/training/trainer.ts
index 1124137be..db6877323 100644
--- a/discojs/src/training/trainer.ts
+++ b/discojs/src/training/trainer.ts
@@ -90,7 +90,8 @@ export class Trainer {
let previousRoundWeights: WeightsContainer | undefined;
for (let round = 0; round < totalRound; round++) {
await this.#client.onRoundBeginCommunication();
-
+
+ this.model.previousRoundWeights = previousRoundWeights
yield this.#runRound(dataset, validationDataset);
let localWeights = this.model.weights;
From 9ad89811fadaaf73f333c1f856897c6cbb9a6cd2 Mon Sep 17 00:00:00 2001
From: Julien Vignoud
Date: Thu, 28 Nov 2024 11:32:46 +0100
Subject: [PATCH 17/25] discojs/models: use trainOnBatch instead of fit and
fitDataset
---
discojs/src/models/gpt/index.ts | 28 +------
discojs/src/models/gpt/model.ts | 141 +++++++++++---------------------
discojs/src/models/model.ts | 22 +++++
discojs/src/models/tfjs.ts | 27 +-----
4 files changed, 78 insertions(+), 140 deletions(-)
diff --git a/discojs/src/models/gpt/index.ts b/discojs/src/models/gpt/index.ts
index 2eb02d4fe..6fdf4b6c1 100644
--- a/discojs/src/models/gpt/index.ts
+++ b/discojs/src/models/gpt/index.ts
@@ -76,30 +76,10 @@ export class GPT extends Model<"text"> {
async #runBatch(
batch: Batched,
): Promise {
- const tfBatch = this.#batchToTF(batch);
-
- let logs: tf.Logs | undefined;
- await this.model.fitDataset(tf.data.array([tfBatch]), {
- epochs: 1,
- verbose: 0, // don't pollute
- callbacks: {
- onEpochEnd: (_, cur) => {
- logs = cur;
- },
- },
- });
- tf.dispose(tfBatch);
- if (logs === undefined) throw new Error("batch didn't gave any logs");
-
- const { loss, acc: accuracy } = logs;
- if (loss === undefined || isNaN(loss))
- throw new Error("training loss is undefined or NaN");
-
- return {
- accuracy,
- loss,
- memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
- };
+ const {xs, ys} = this.#batchToTF(batch);
+ const logs = await this.model.trainOnBatch(xs, ys);
+ tf.dispose([xs, ys])
+ return this.getBatchLogs(logs)
}
async #evaluate(
diff --git a/discojs/src/models/gpt/model.ts b/discojs/src/models/gpt/model.ts
index 01ee51e92..3a278b357 100644
--- a/discojs/src/models/gpt/model.ts
+++ b/discojs/src/models/gpt/model.ts
@@ -4,7 +4,6 @@ import * as tf from '@tensorflow/tfjs'
import type { GPTConfig } from './config.js'
import { getModelSizes, DefaultGPTConfig } from './config.js'
import { getCustomAdam, clipByGlobalNormObj } from './optimizers.js'
-import evaluate from './evaluate.js'
import { GPTArchitecture } from './layers.js'
const debug = createDebug("discojs:models:gpt:model");
@@ -55,101 +54,57 @@ export class GPTModel extends tf.LayersModel {
: tf.train.adam(this.config.lr)
}
- override async fitDataset(dataset: Dataset, trainingArgs: tf.ModelFitDatasetArgs): Promise {
- const callbacks = trainingArgs.callbacks as tf.CustomCallbackArgs
- const evalDataset = trainingArgs.validationData as tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }>
- await callbacks.onTrainBegin?.()
+ override async trainOnBatch(x: tf.Tensor, y: tf.Tensor): Promise {
+ let weightUpdateTime = performance.now()
- for (let epoch = 1; epoch <= trainingArgs.epochs; epoch++) {
- let accuracyFraction: [number, number] = [0, 0];
- let averageLoss = 0
- let iteration = 1
- const iterator = await dataset.iterator()
- let next = await iterator.next()
+ let preprocessingTime = performance.now()
+ await Promise.all([x.data(), y.data()])
+ preprocessingTime = performance.now() - preprocessingTime
- while (next.done !== true && iteration <= this.config.maxIter) {
- let weightUpdateTime = performance.now()
- await callbacks.onEpochBegin?.(epoch)
- const { xs, ys } = next.value as { xs: tf.Tensor2D, ys: tf.Tensor3D }
+ // TODO include as a tensor inside the model
+ const accTensor = tf.tidy(() => {
+ const logits = this.apply(x)
+ if (Array.isArray(logits))
+ throw new Error('model outputs too many tensor')
+ if (logits instanceof tf.SymbolicTensor)
+ throw new Error('model outputs symbolic tensor')
+ return tf.metrics.categoricalAccuracy(y, logits)
+ })
+ const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
+ const accSumTensor = accTensor.sum()
+ const accSum = await accSumTensor.array()
+ tf.dispose(accSumTensor)
+ if (typeof accSum !== 'number')
+ throw new Error('got multiple accuracy sum')
+ tf.dispose([accTensor])
- let preprocessingTime = performance.now()
- await Promise.all([xs.data(), ys.data()])
- preprocessingTime = performance.now() - preprocessingTime
-
- // TODO include as a tensor inside the model
- const accTensor = tf.tidy(() => {
- const logits = this.apply(xs)
- if (Array.isArray(logits))
- throw new Error('model outputs too many tensor')
- if (logits instanceof tf.SymbolicTensor)
- throw new Error('model outputs symbolic tensor')
- return tf.metrics.categoricalAccuracy(ys, logits)
- })
- const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
- const accSumTensor = accTensor.sum()
- const accSum = await accSumTensor.array()
- tf.dispose(accSumTensor)
- if (typeof accSum !== 'number')
- throw new Error('got multiple accuracy sum')
- accuracyFraction = [accuracyFraction[0] + accSum, accuracyFraction[1] + accSize];
- tf.dispose([accTensor])
+ const lossTensor = tf.tidy(() => {
+ const { grads, value: lossTensor } = this.optimizer.computeGradients(() => {
+ const logits = this.apply(x)
+ if (Array.isArray(logits))
+ throw new Error('model outputs too many tensor')
+ if (logits instanceof tf.SymbolicTensor)
+ throw new Error('model outputs symbolic tensor')
+ return tf.losses.softmaxCrossEntropy(y, logits)
+ })
+ const gradsClipped = clipByGlobalNormObj(grads, 1)
+ this.optimizer.applyGradients(gradsClipped)
+ return lossTensor
+ })
+
+ const loss = await lossTensor.array()
+ weightUpdateTime = performance.now() - weightUpdateTime
- const lossTensor = tf.tidy(() => {
- const { grads, value: lossTensor } = this.optimizer.computeGradients(() => {
- const logits = this.apply(xs)
- if (Array.isArray(logits))
- throw new Error('model outputs too many tensor')
- if (logits instanceof tf.SymbolicTensor)
- throw new Error('model outputs symbolic tensor')
- return tf.losses.softmaxCrossEntropy(ys, logits)
- })
- const gradsClipped = clipByGlobalNormObj(grads, 1)
- this.optimizer.applyGradients(gradsClipped)
- return lossTensor
- })
-
- const loss = await lossTensor.array()
- averageLoss += loss
- weightUpdateTime = performance.now() - weightUpdateTime
-
- tf.dispose([xs, ys, lossTensor])
-
- if (
- evalDataset !== undefined &&
- this.config.evaluateEvery !== undefined &&
- iteration % this.config.evaluateEvery == 0
- ){
- const iterationLogs = await evaluate(this, evalDataset, this.config.maxEvalBatches)
- debug('evaluation metrics: %O', iterationLogs);
- }
- const memory = tf.memory().numBytes / 1024 / 1024 / 1024
- debug("training metrics: %O", {
- epoch,
- iteration,
- loss,
- memory,
- allocated: tf.memory().numTensors,
- preprocessingTime,
- weightUpdateTime,
- });
- iteration++
- next = await iterator.next()
- }
- // Memory leak: If we reached the last iteration rather than the end of the dataset, cleanup the tensors
- if (next.done !== true && iteration > this.config.maxIter) {
- const { xs, ys } = next.value as { xs: tf.Tensor2D, ys: tf.Tensor3D }
- tf.dispose([xs, ys])
- }
- let logs: tf.Logs = {
- 'loss': averageLoss / (iteration - 1), // -1 because iteration got incremented at the end of the loop
- 'acc': accuracyFraction[0] / accuracyFraction[1],
- }
- if (evalDataset !== undefined) {
- logs = { ...logs, ...await evaluate(this, evalDataset, this.config.maxEvalBatches) }
- }
- await callbacks.onEpochEnd?.(epoch, logs)
- }
- await callbacks.onTrainEnd?.()
- return new tf.History()
+ tf.dispose([x, y, lossTensor])
+
+ const memory = tf.memory().numBytes / 1024 / 1024 / 1024
+ debug("training metrics: %O", {
+ loss,
+ memory,
+ allocated: tf.memory().numTensors,
+ preprocessingTime,
+ weightUpdateTime,
+ });
+ return [loss, accSum / accSize]
}
}
diff --git a/discojs/src/models/model.ts b/discojs/src/models/model.ts
index dd7c0477c..6bbc4e70c 100644
--- a/discojs/src/models/model.ts
+++ b/discojs/src/models/model.ts
@@ -6,6 +6,8 @@ import type {
WeightsContainer,
} from "../index.js";
+import * as tf from "@tensorflow/tfjs";
+
import type { BatchLogs, EpochLogs } from "./logs.js";
/**
@@ -39,6 +41,26 @@ export abstract class Model implements Disposable {
batch: Batched,
): Promise>;
+ protected getBatchLogs(
+ logs: number | number[],
+ ): BatchLogs {
+ if (!Array.isArray(logs) || logs.length != 2)
+ throw new Error("training output has unexpected shape")
+
+ const [loss, accuracy] = logs
+
+ if (
+ typeof loss !== "number" || isNaN(loss) ||
+ typeof accuracy !== "number" || isNaN(accuracy)
+ )
+ throw new Error("training loss or accuracy is undefined or NaN");
+
+ return {
+ accuracy,
+ loss,
+ memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
+ };
+ }
/**
* This method is automatically called to cleanup the memory occupied by the model
* when leaving the definition scope if the instance has been defined with the `using` keyword.
diff --git a/discojs/src/models/tfjs.ts b/discojs/src/models/tfjs.ts
index b60060f49..158a0bb64 100644
--- a/discojs/src/models/tfjs.ts
+++ b/discojs/src/models/tfjs.ts
@@ -62,30 +62,11 @@ export class TFJS extends Model {
async #runBatch(
batch: Batched,
- ): Promise> {
+ ): Promise {
const { xs, ys } = this.#batchToTF(batch);
-
- const { history } = await this.model.fit(xs, ys, {
- epochs: 1,
- verbose: 0, // don't pollute
- });
-
- const { loss: losses, acc: accuracies } = history;
- if (
- losses === undefined ||
- accuracies === undefined ||
- typeof losses[0] !== "number" ||
- typeof accuracies[0] !== "number" ||
- isNaN(losses[0]) ||
- isNaN(accuracies[0])
- )
- throw new Error("training loss or accuracy is undefined or NaN");
-
- return {
- accuracy: accuracies[0],
- loss: losses[0],
- memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
- };
+ const logs = await this.model.trainOnBatch(xs, ys);
+ tf.dispose([xs, ys])
+ return this.getBatchLogs(logs)
}
async #evaluate(
From c20ad8289cc4332f470341e5199649b366c43993 Mon Sep 17 00:00:00 2001
From: Julien Vignoud
Date: Thu, 28 Nov 2024 11:45:12 +0100
Subject: [PATCH 18/25] discojs/models/gpt: compute logits only once, 10%
faster
---
discojs/src/models/gpt/model.ts | 29 ++++++++++++-----------------
1 file changed, 12 insertions(+), 17 deletions(-)
diff --git a/discojs/src/models/gpt/model.ts b/discojs/src/models/gpt/model.ts
index 3a278b357..8c646c6d4 100644
--- a/discojs/src/models/gpt/model.ts
+++ b/discojs/src/models/gpt/model.ts
@@ -61,23 +61,7 @@ export class GPTModel extends tf.LayersModel {
await Promise.all([x.data(), y.data()])
preprocessingTime = performance.now() - preprocessingTime
- // TODO include as a tensor inside the model
- const accTensor = tf.tidy(() => {
- const logits = this.apply(x)
- if (Array.isArray(logits))
- throw new Error('model outputs too many tensor')
- if (logits instanceof tf.SymbolicTensor)
- throw new Error('model outputs symbolic tensor')
- return tf.metrics.categoricalAccuracy(y, logits)
- })
- const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
- const accSumTensor = accTensor.sum()
- const accSum = await accSumTensor.array()
- tf.dispose(accSumTensor)
- if (typeof accSum !== 'number')
- throw new Error('got multiple accuracy sum')
- tf.dispose([accTensor])
-
+ let logitsTensor: tf.Tensor;
const lossTensor = tf.tidy(() => {
const { grads, value: lossTensor } = this.optimizer.computeGradients(() => {
const logits = this.apply(x)
@@ -85,12 +69,23 @@ export class GPTModel extends tf.LayersModel {
throw new Error('model outputs too many tensor')
if (logits instanceof tf.SymbolicTensor)
throw new Error('model outputs symbolic tensor')
+ logitsTensor = tf.keep(logits)
return tf.losses.softmaxCrossEntropy(y, logits)
})
const gradsClipped = clipByGlobalNormObj(grads, 1)
this.optimizer.applyGradients(gradsClipped)
return lossTensor
})
+
+ // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+ const accTensor = tf.metrics.categoricalAccuracy(y, logitsTensor)
+ const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
+ const accSumTensor = accTensor.sum()
+ const accSum = await accSumTensor.array()
+ if (typeof accSum !== 'number')
+ throw new Error('got multiple accuracy sum')
+ // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+ tf.dispose([accTensor, accSumTensor, logitsTensor])
const loss = await lossTensor.array()
weightUpdateTime = performance.now() - weightUpdateTime
From abe799607a6fe67235a67f027d84fa44e3997b3f Mon Sep 17 00:00:00 2001
From: Julien Vignoud
Date: Tue, 25 Feb 2025 16:48:25 +0100
Subject: [PATCH 19/25] tmp: overriding weight update yields same as default
---
discojs/src/default_tasks/lus_covid.ts | 17 ++++++------
discojs/src/models/tfjs.ts | 37 +++++++++++++++++++++++++-
2 files changed, 45 insertions(+), 9 deletions(-)
diff --git a/discojs/src/default_tasks/lus_covid.ts b/discojs/src/default_tasks/lus_covid.ts
index 44dd46ed6..8331ac0d8 100644
--- a/discojs/src/default_tasks/lus_covid.ts
+++ b/discojs/src/default_tasks/lus_covid.ts
@@ -39,7 +39,8 @@ export const lusCovid: TaskProvider<'image'> = {
// Model architecture from tensorflow.js docs:
// https://codelabs.developers.google.com/codelabs/tfjs-training-classfication/index.html#4
- async getModel (): Promise> {
+ async getModel(): Promise> {
+ const seed = 42
const imageHeight = 100
const imageWidth = 100
const imageChannels = 3
@@ -55,7 +56,7 @@ export const lusCovid: TaskProvider<'image'> = {
filters: 8,
strides: 1,
activation: 'relu',
- kernelInitializer: 'varianceScaling'
+ kernelInitializer: tf.initializers.heNormal({ seed })
}))
// The MaxPooling layer acts as a sort of downsampling using max values
@@ -69,7 +70,7 @@ export const lusCovid: TaskProvider<'image'> = {
filters: 16,
strides: 1,
activation: 'relu',
- kernelInitializer: 'varianceScaling'
+ kernelInitializer: tf.initializers.heNormal({ seed })
}))
model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }))
@@ -82,16 +83,16 @@ export const lusCovid: TaskProvider<'image'> = {
// output class.
model.add(tf.layers.dense({
units: numOutputClasses,
- kernelInitializer: 'varianceScaling',
- activation: 'softmax'
+ activation: 'softmax',
+ kernelInitializer: tf.initializers.heNormal({ seed })
}))
-
+
model.compile({
- optimizer: 'sgd',
+ optimizer: tf.train.sgd(0.001),
loss: 'binaryCrossentropy',
metrics: ['accuracy']
})
return Promise.resolve(new models.TFJS('image', model))
}
-}
+}
\ No newline at end of file
diff --git a/discojs/src/models/tfjs.ts b/discojs/src/models/tfjs.ts
index 158a0bb64..4e29498b9 100644
--- a/discojs/src/models/tfjs.ts
+++ b/discojs/src/models/tfjs.ts
@@ -1,3 +1,4 @@
+import createDebug from "debug";
import { List, Map, Range } from "immutable";
import * as tf from '@tensorflow/tfjs'
@@ -13,6 +14,8 @@ import { BatchLogs } from './index.js'
import { Model } from './index.js'
import { EpochLogs } from './logs.js'
+const debug = createDebug("discojs:models:tfjs");
+
type Serialized = [D, tf.io.ModelArtifacts];
/** TensorFlow JavaScript model with standard training */
@@ -64,11 +67,43 @@ export class TFJS extends Model {
batch: Batched,
): Promise {
const { xs, ys } = this.#batchToTF(batch);
- const logs = await this.model.trainOnBatch(xs, ys);
+ // Toggling two next lines should yield the same training loss
+ const logs = await this.trainFedProx(xs, ys);
+ // const logs = await this.model.trainOnBatch(xs, ys);
tf.dispose([xs, ys])
return this.getBatchLogs(logs)
}
+ // First iteration: replace trainOnBatch with custom loss computation
+ async trainFedProx(
+ xs: tf.Tensor, ys: tf.Tensor): Promise<[number, number]> {
+
+ debug(this.model.loss, this.model.losses, this.model.lossFunctions)
+ const lossFunction: () => tf.Scalar = () => {
+ this.model.apply(xs)
+ const logits = this.model.apply(xs)
+ if (Array.isArray(logits))
+ throw new Error('model outputs too many tensor')
+ if (logits instanceof tf.SymbolicTensor)
+ throw new Error('model outputs symbolic tensor')
+
+ // binaryCrossEntropyLoss as implemented by tensorflow.js
+ // https://github.com/tensorflow/tfjs/blob/2644bd0d6cea677f80e44ed4a44bea5e04aabeb3/tfjs-layers/src/losses.ts#L193
+ let y: tf.Tensor;
+ y = tf.clipByValue(logits, 0.00001, 1 - 0.00001);
+ y = tf.log(tf.div(y, tf.sub(1, y)));
+ return tf.losses.sigmoidCrossEntropy(ys, y);
+ }
+ const lossTensor = this.model.optimizer.minimize(lossFunction, true)
+ if (lossTensor === null) throw new Error("loss should not be null")
+
+ const loss = await lossTensor.array()
+ tf.dispose([xs, ys, lossTensor])
+
+ // dummy accuracy for now
+ return [loss, 0]
+ }
+
async #evaluate(
dataset: Dataset>,
): Promise> {
From 76977e7096919ef97d59d3356bd9b9b746520b41 Mon Sep 17 00:00:00 2001
From: Julien Vignoud
Date: Thu, 28 Nov 2024 16:57:20 +0100
Subject: [PATCH 20/25] tmp: sketch of fedprox implementation
---
discojs/src/models/model.ts | 4 ++
discojs/src/models/tfjs.ts | 67 +++++++++++++++++++++++----------
discojs/src/training/trainer.ts | 3 +-
3 files changed, 54 insertions(+), 20 deletions(-)
diff --git a/discojs/src/models/model.ts b/discojs/src/models/model.ts
index 6bbc4e70c..b7e51eab7 100644
--- a/discojs/src/models/model.ts
+++ b/discojs/src/models/model.ts
@@ -17,12 +17,16 @@ import type { BatchLogs, EpochLogs } from "./logs.js";
**/
// TODO make it typesafe: same shape of data/input/weights
export abstract class Model implements Disposable {
+ protected prevRoundWeights: WeightsContainer | undefined;
// TODO don't allow external access but upgrade train to return weights on every epoch
/** Return training state */
abstract get weights(): WeightsContainer;
/** Set training state */
abstract set weights(ws: WeightsContainer);
+ set previousRoundWeights(ws: WeightsContainer | undefined) {
+ this.prevRoundWeights = ws
+ }
/**
* Improve predictor
*
diff --git a/discojs/src/models/tfjs.ts b/discojs/src/models/tfjs.ts
index 4e29498b9..a906be741 100644
--- a/discojs/src/models/tfjs.ts
+++ b/discojs/src/models/tfjs.ts
@@ -76,32 +76,61 @@ export class TFJS extends Model {
// First iteration: replace trainOnBatch with custom loss computation
async trainFedProx(
- xs: tf.Tensor, ys: tf.Tensor): Promise<[number, number]> {
-
- debug(this.model.loss, this.model.losses, this.model.lossFunctions)
+ xs: tf.Tensor, ys: tf.Tensor,
+ ): Promise<[number, number]> {
+ let logitsTensor: tf.Tensor;
const lossFunction: () => tf.Scalar = () => {
+ // Proximal term
+ let proximalTerm = tf.tensor(0)
+ if (this.prevRoundWeights !== undefined) {
+ // squared norm
+ const norm = new WeightsContainer(this.model.getWeights())
+ .sub(this.prevRoundWeights)
+ .map(t => t.square().sum())
+ .reduce((t, acc) => tf.add(t, acc)).asScalar()
+ const mu = 1
+ proximalTerm = tf.mul(mu / 2, norm)
+ }
+
this.model.apply(xs)
const logits = this.model.apply(xs)
- if (Array.isArray(logits))
- throw new Error('model outputs too many tensor')
- if (logits instanceof tf.SymbolicTensor)
- throw new Error('model outputs symbolic tensor')
-
- // binaryCrossEntropyLoss as implemented by tensorflow.js
- // https://github.com/tensorflow/tfjs/blob/2644bd0d6cea677f80e44ed4a44bea5e04aabeb3/tfjs-layers/src/losses.ts#L193
- let y: tf.Tensor;
- y = tf.clipByValue(logits, 0.00001, 1 - 0.00001);
- y = tf.log(tf.div(y, tf.sub(1, y)));
- return tf.losses.sigmoidCrossEntropy(ys, y);
+ if (Array.isArray(logits))
+ throw new Error('model outputs too many tensor')
+ if (logits instanceof tf.SymbolicTensor)
+ throw new Error('model outputs symbolic tensor')
+ logitsTensor = tf.keep(logits)
+ // binaryCrossentropy as implemented by tensorflow.js
+ // https://github.com/tensorflow/tfjs/blob/2644bd0d6cea677f80e44ed4a44bea5e04aabeb3/tfjs-layers/src/losses.ts#L193
+ let y: tf.Tensor;
+ y = tf.clipByValue(logits, 0.00001, 1 - 0.00001);
+ y = tf.log(tf.div(y, tf.sub(1, y)));
+ const loss = tf.losses.sigmoidCrossEntropy(ys, y);
+ console.log(loss.dataSync(), proximalTerm.dataSync())
+ return tf.add(loss, proximalTerm)
}
const lossTensor = this.model.optimizer.minimize(lossFunction, true)
if (lossTensor === null) throw new Error("loss should not be null")
-
- const loss = await lossTensor.array()
- tf.dispose([xs, ys, lossTensor])
+
+ // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+ const accTensor = tf.metrics.categoricalAccuracy(ys, logitsTensor)
+ const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
+ const accSumTensor = accTensor.sum()
+ const accSum = await accSumTensor.array()
+ if (typeof accSum !== 'number')
+ throw new Error('got multiple accuracy sum')
+ // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+ tf.dispose([accTensor, accSumTensor, logitsTensor])
+
+ const loss = await lossTensor.array()
+ tf.dispose([xs, ys, lossTensor])
- // dummy accuracy for now
- return [loss, 0]
+ const memory = tf.memory().numBytes / 1024 / 1024 / 1024
+ debug("training metrics: %O", {
+ loss,
+ memory,
+ allocated: tf.memory().numTensors,
+ });
+ return [loss, accSum / accSize]
}
async #evaluate(
diff --git a/discojs/src/training/trainer.ts b/discojs/src/training/trainer.ts
index 1124137be..db6877323 100644
--- a/discojs/src/training/trainer.ts
+++ b/discojs/src/training/trainer.ts
@@ -90,7 +90,8 @@ export class Trainer {
let previousRoundWeights: WeightsContainer | undefined;
for (let round = 0; round < totalRound; round++) {
await this.#client.onRoundBeginCommunication();
-
+
+ this.model.previousRoundWeights = previousRoundWeights
yield this.#runRound(dataset, validationDataset);
let localWeights = this.model.weights;
From 0e1f4e5bc0df634395e4fba195ecd500e737eaee Mon Sep 17 00:00:00 2001
From: tomasoignons
Date: Fri, 14 Mar 2025 15:33:00 +0100
Subject: [PATCH 21/25] Added the FedAverage training
This commit adds the FedAverage training function to the tfjs model file
---
discojs/src/models/tfjs.ts | 48 +++++++++++++++++++++++++++++++++++++-
1 file changed, 47 insertions(+), 1 deletion(-)
diff --git a/discojs/src/models/tfjs.ts b/discojs/src/models/tfjs.ts
index c4ae98f0d..7bedc3acf 100644
--- a/discojs/src/models/tfjs.ts
+++ b/discojs/src/models/tfjs.ts
@@ -67,12 +67,58 @@ export class TFJS extends Model {
batch: Batched,
): Promise {
const { xs, ys } = this.#batchToTF(batch);
- const logs = await this.trainFedProx(xs, ys);
+ const logs = await this.trainFedAverage(xs, ys);
// const logs = await this.model.trainOnBatch(xs, ys);
tf.dispose([xs, ys])
return this.getBatchLogs(logs)
}
+ async trainFedAverage(
+ xs: tf.Tensor, ys: tf.Tensor,
+ ) : Promise<[number, number]> {
+ let logitsTensor: tf.Tensor;
+
+ const optimizer = tf.train.sgd(0.01); // adjust the learning rate here
+ const lossFunction: () => tf.Scalar = () => {
+ // Apply the model to get logits
+ const logits = this.model.apply(xs) as tf.Tensor;
+ logitsTensor = tf.keep(logits);
+
+ // Calculate binary cross-entropy loss
+ const loss = tf.losses.sigmoidCrossEntropy(ys, logits);
+
+ // Add regularization term (L2 norm of weights)
+ const regularizationTerm = tf.addN(
+ this.model.getWeights().map(w => w.square().sum())
+ );
+
+ return tf.add(loss, regularizationTerm);
+ };
+ const lossTensor = optimizer.minimize(lossFunction, true);
+ if (lossTensor === null) throw new Error("loss should not be null")
+
+ // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+ const accTensor = tf.metrics.categoricalAccuracy(ys, logitsTensor)
+ const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
+ const accSumTensor = accTensor.sum()
+ const accSum = await accSumTensor.array()
+ if (typeof accSum !== 'number')
+ throw new Error('got multiple accuracy sum')
+ // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+ tf.dispose([accTensor, accSumTensor, logitsTensor])
+
+ const loss = await lossTensor.array()
+ tf.dispose([xs, ys, lossTensor])
+
+ const memory = tf.memory().numBytes / 1024 / 1024 / 1024
+ debug("training metrics: %O", {
+ loss,
+ memory,
+ allocated: tf.memory().numTensors,
+ });
+ return [loss, accSum / accSize]
+ }
+
async trainFedProx(
xs: tf.Tensor, ys: tf.Tensor,
): Promise<[number, number]> {
From 576a3efc8c592bed93a95d6dfe0faaa7b81349ba Mon Sep 17 00:00:00 2001
From: tomasoignons
Date: Fri, 14 Mar 2025 16:15:06 +0100
Subject: [PATCH 22/25] Begin to implement the choice between fedaverage and
FedProx
This commit allows the user to select, in the task creation form, the training framework they want.
For the different presentations, the algorithm is hard-coded for now; a selectable choice will be integrated in the future.
---
cli/node_modules/immutable/LICENSE | 21 +
cli/node_modules/immutable/README.md | 761 ++++++++++++++++++
cli/node_modules/immutable/package.json | 39 +
cli/src/benchmark_gpt.ts | 2 +-
cli/src/cli.ts | 2 +-
cli/src/data.ts | 1 -
cli/src/train_gpt.ts | 2 +-
discojs/src/default_tasks/cifar10.ts | 2 +-
discojs/src/default_tasks/lus_covid.ts | 2 +-
discojs/src/default_tasks/mnist.ts | 2 +-
discojs/src/default_tasks/simple_face.ts | 2 +-
discojs/src/default_tasks/tinder_dog.ts | 2 +-
discojs/src/default_tasks/titanic.ts | 2 +-
discojs/src/models/gpt/model.ts | 1 +
discojs/src/models/tfjs.ts | 15 +-
discojs/src/serialization/model.spec.ts | 2 +-
docs/examples/custom_task.ts | 2 +-
package-lock.json | 9 +-
.../task_creation_form/TaskForm.vue | 7 +
webapp/src/task_creation_form.ts | 9 +
20 files changed, 867 insertions(+), 18 deletions(-)
create mode 100644 cli/node_modules/immutable/LICENSE
create mode 100644 cli/node_modules/immutable/README.md
create mode 100644 cli/node_modules/immutable/package.json
diff --git a/cli/node_modules/immutable/LICENSE b/cli/node_modules/immutable/LICENSE
new file mode 100644
index 000000000..1e3c4f39c
--- /dev/null
+++ b/cli/node_modules/immutable/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2014-present, Lee Byron and other contributors.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/cli/node_modules/immutable/README.md b/cli/node_modules/immutable/README.md
new file mode 100644
index 000000000..a6eae67dd
--- /dev/null
+++ b/cli/node_modules/immutable/README.md
@@ -0,0 +1,761 @@
+# Immutable collections for JavaScript
+
+[Build Status](https://github.com/immutable-js/immutable-js/actions/workflows/ci.yml?query=branch%3Amain) [Chat on slack](https://immutable-js.slack.com)
+
+[Read the docs](https://immutable-js.com/docs/) and eat your vegetables.
+
+Docs are automatically generated from [README.md][] and [immutable.d.ts][].
+Please contribute! Also, don't miss the [wiki][] which contains articles on
+additional specific topics. Can't find something? Open an [issue][].
+
+**Table of contents:**
+
+- [Introduction](#introduction)
+- [Getting started](#getting-started)
+- [The case for Immutability](#the-case-for-immutability)
+- [JavaScript-first API](#javascript-first-api)
+- [Nested Structures](#nested-structures)
+- [Equality treats Collections as Values](#equality-treats-collections-as-values)
+- [Batching Mutations](#batching-mutations)
+- [Lazy Seq](#lazy-seq)
+- [Additional Tools and Resources](#additional-tools-and-resources)
+- [Contributing](#contributing)
+
+## Introduction
+
+[Immutable][] data cannot be changed once created, leading to much simpler
+application development, no defensive copying, and enabling advanced memoization
+and change detection techniques with simple logic. [Persistent][] data presents
+a mutative API which does not update the data in-place, but instead always
+yields new updated data.
+
+Immutable.js provides many Persistent Immutable data structures including:
+`List`, `Stack`, `Map`, `OrderedMap`, `Set`, `OrderedSet` and `Record`.
+
+These data structures are highly efficient on modern JavaScript VMs by using
+structural sharing via [hash maps tries][] and [vector tries][] as popularized
+by Clojure and Scala, minimizing the need to copy or cache data.
+
+Immutable.js also provides a lazy `Seq`, allowing efficient
+chaining of collection methods like `map` and `filter` without creating
+intermediate representations. Create some `Seq` with `Range` and `Repeat`.
+
+Want to hear more? Watch the presentation about Immutable.js:
+
+[Watch the presentation on YouTube](https://youtu.be/I7IdS-PbEgI)
+
+[README.md]: https://github.com/immutable-js/immutable-js/blob/main/README.md
+[immutable.d.ts]: https://github.com/immutable-js/immutable-js/blob/main/type-definitions/immutable.d.ts
+[wiki]: https://github.com/immutable-js/immutable-js/wiki
+[issue]: https://github.com/immutable-js/immutable-js/issues
+[Persistent]: https://en.wikipedia.org/wiki/Persistent_data_structure
+[Immutable]: https://en.wikipedia.org/wiki/Immutable_object
+[hash maps tries]: https://en.wikipedia.org/wiki/Hash_array_mapped_trie
+[vector tries]: https://hypirion.com/musings/understanding-persistent-vector-pt-1
+
+## Getting started
+
+Install `immutable` using npm.
+
+```shell
+# using npm
+npm install immutable
+
+# using Yarn
+yarn add immutable
+
+# using pnpm
+pnpm add immutable
+
+# using Bun
+bun add immutable
+```
+
+Then require it into any module.
+
+
+
+```js
+const { Map } = require('immutable');
+const map1 = Map({ a: 1, b: 2, c: 3 });
+const map2 = map1.set('b', 50);
+map1.get('b') + ' vs. ' + map2.get('b'); // 2 vs. 50
+```
+
+### Browser
+
+Immutable.js has no dependencies, which makes it predictable to include in a Browser.
+
+It's highly recommended to use a module bundler like [webpack](https://webpack.github.io/),
+[rollup](https://rollupjs.org/), or
+[browserify](https://browserify.org/). The `immutable` npm module works
+without any additional consideration. All examples throughout the documentation
+will assume use of this kind of tool.
+
+Alternatively, Immutable.js may be directly included as a script tag. Download
+or link to a CDN such as [CDNJS](https://cdnjs.com/libraries/immutable)
+or [jsDelivr](https://www.jsdelivr.com/package/npm/immutable).
+
+Use a script tag to directly add `Immutable` to the global scope:
+
+```html
+
+
+```
+
+Or use an AMD-style loader (such as [RequireJS](https://requirejs.org/)):
+
+```js
+require(['./immutable.min.js'], function (Immutable) {
+ var map1 = Immutable.Map({ a: 1, b: 2, c: 3 });
+ var map2 = map1.set('b', 50);
+ map1.get('b'); // 2
+ map2.get('b'); // 50
+});
+```
+
+### Flow & TypeScript
+
+Use these Immutable collections and sequences as you would use native
+collections in your [Flowtype](https://flowtype.org/) or [TypeScript](https://typescriptlang.org) programs while still taking
+advantage of type generics, error detection, and auto-complete in your IDE.
+
+Installing `immutable` via npm brings with it type definitions for Flow (v0.55.0 or higher)
+and TypeScript (v2.1.0 or higher), so you shouldn't need to do anything at all!
+
+#### Using TypeScript with Immutable.js v4
+
+Immutable.js type definitions embrace ES2015. While Immutable.js itself supports
+legacy browsers and environments, its type definitions require TypeScript's 2015
+lib. Include either `"target": "es2015"` or `"lib": "es2015"` in your
+`tsconfig.json`, or provide `--target es2015` or `--lib es2015` to the
+`tsc` command.
+
+
+
+```js
+const { Map } = require('immutable');
+const map1 = Map({ a: 1, b: 2, c: 3 });
+const map2 = map1.set('b', 50);
+map1.get('b') + ' vs. ' + map2.get('b'); // 2 vs. 50
+```
+
+#### Using TypeScript with Immutable.js v3 and earlier:
+
+Previous versions of Immutable.js include a reference file which you can include
+via relative path to the type definitions at the top of your file.
+
+```js
+/// <reference path='./node_modules/immutable/dist/immutable.d.ts'/>
+import Immutable from 'immutable';
+var map1: Immutable.Map;
+map1 = Immutable.Map({ a: 1, b: 2, c: 3 });
+var map2 = map1.set('b', 50);
+map1.get('b'); // 2
+map2.get('b'); // 50
+```
+
+## The case for Immutability
+
+Much of what makes application development difficult is tracking mutation and
+maintaining state. Developing with immutable data encourages you to think
+differently about how data flows through your application.
+
+Subscribing to data events throughout your application creates a huge overhead of
+book-keeping which can hurt performance, sometimes dramatically, and creates
+opportunities for areas of your application to get out of sync with each other
+due to easy to make programmer error. Since immutable data never changes,
+subscribing to changes throughout the model is a dead-end and new data can only
+ever be passed from above.
+
+This model of data flow aligns well with the architecture of [React][]
+and especially well with an application designed using the ideas of [Flux][].
+
+When data is passed from above rather than being subscribed to, and you're only
+interested in doing work when something has changed, you can use equality.
+
+Immutable collections should be treated as _values_ rather than _objects_. While
+objects represent some thing which could change over time, a value represents
+the state of that thing at a particular instance of time. This principle is most
+important to understanding the appropriate use of immutable data. In order to
+treat Immutable.js collections as values, it's important to use the
+`Immutable.is()` function or `.equals()` method to determine _value equality_
+instead of the `===` operator which determines object _reference identity_.
+
+
+
+```js
+const { Map } = require('immutable');
+const map1 = Map({ a: 1, b: 2, c: 3 });
+const map2 = Map({ a: 1, b: 2, c: 3 });
+map1.equals(map2); // true
+map1 === map2; // false
+```
+
+Note: As a performance optimization Immutable.js attempts to return the existing
+collection when an operation would result in an identical collection, allowing
+for using `===` reference equality to determine if something definitely has not
+changed. This can be extremely useful when used within a memoization function
+which would prefer to re-run the function if a deeper equality check could
+potentially be more costly. The `===` equality check is also used internally by
+`Immutable.is` and `.equals()` as a performance optimization.
+
+
+
+```js
+const { Map } = require('immutable');
+const map1 = Map({ a: 1, b: 2, c: 3 });
+const map2 = map1.set('b', 2); // Set to same value
+map1 === map2; // true
+```
+
+If an object is immutable, it can be "copied" simply by making another reference
+to it instead of copying the entire object. Because a reference is much smaller
+than the object itself, this results in memory savings and a potential boost in
+execution speed for programs which rely on copies (such as an undo-stack).
+
+
+
+```js
+const { Map } = require('immutable');
+const map = Map({ a: 1, b: 2, c: 3 });
+const mapCopy = map; // Look, "copies" are free!
+```
+
+[React]: https://reactjs.org/
+[Flux]: https://facebook.github.io/flux/docs/in-depth-overview/
+
+
+## JavaScript-first API
+
+While Immutable.js is inspired by Clojure, Scala, Haskell and other functional
+programming environments, it's designed to bring these powerful concepts to
+JavaScript, and therefore has an Object-Oriented API that closely mirrors that
+of [ES2015][] [Array][], [Map][], and [Set][].
+
+[es2015]: https://developer.mozilla.org/en-US/docs/Web/JavaScript/New_in_JavaScript/ECMAScript_6_support_in_Mozilla
+[array]: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array
+[map]: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Map
+[set]: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Set
+
+The difference for the immutable collections is that methods which would mutate
+the collection, like `push`, `set`, `unshift` or `splice`, instead return a new
+immutable collection. Methods which return new arrays, like `slice` or `concat`,
+instead return new immutable collections.
+
+
+
+```js
+const { List } = require('immutable');
+const list1 = List([1, 2]);
+const list2 = list1.push(3, 4, 5);
+const list3 = list2.unshift(0);
+const list4 = list1.concat(list2, list3);
+assert.equal(list1.size, 2);
+assert.equal(list2.size, 5);
+assert.equal(list3.size, 6);
+assert.equal(list4.size, 13);
+assert.equal(list4.get(0), 1);
+```
+
+Almost all of the methods on [Array][] will be found in similar form on
+`Immutable.List`, those of [Map][] found on `Immutable.Map`, and those of [Set][]
+found on `Immutable.Set`, including collection operations like `forEach()`
+and `map()`.
+
+
+
+```js
+const { Map } = require('immutable');
+const alpha = Map({ a: 1, b: 2, c: 3, d: 4 });
+alpha.map((v, k) => k.toUpperCase()).join();
+// 'A,B,C,D'
+```
+
+### Convert from raw JavaScript objects and arrays.
+
+Designed to inter-operate with your existing JavaScript, Immutable.js
+accepts plain JavaScript Arrays and Objects anywhere a method expects a
+`Collection`.
+
+
+
+```js
+const { Map, List } = require('immutable');
+const map1 = Map({ a: 1, b: 2, c: 3, d: 4 });
+const map2 = Map({ c: 10, a: 20, t: 30 });
+const obj = { d: 100, o: 200, g: 300 };
+const map3 = map1.merge(map2, obj);
+// Map { a: 20, b: 2, c: 10, d: 100, t: 30, o: 200, g: 300 }
+const list1 = List([1, 2, 3]);
+const list2 = List([4, 5, 6]);
+const array = [7, 8, 9];
+const list3 = list1.concat(list2, array);
+// List [ 1, 2, 3, 4, 5, 6, 7, 8, 9 ]
+```
+
+This is possible because Immutable.js can treat any JavaScript Array or Object
+as a Collection. You can take advantage of this in order to get sophisticated
+collection methods on JavaScript Objects, which otherwise have a very sparse
+native API. Because Seq evaluates lazily and does not cache intermediate
+results, these operations can be extremely efficient.
+
+
+
+```js
+const { Seq } = require('immutable');
+const myObject = { a: 1, b: 2, c: 3 };
+Seq(myObject)
+ .map(x => x * x)
+ .toObject();
+// { a: 1, b: 4, c: 9 }
+```
+
+Keep in mind, when using JS objects to construct Immutable Maps, that
+JavaScript Object properties are always strings, even if written in a quote-less
+shorthand, while Immutable Maps accept keys of any type.
+
+
+
+```js
+const { fromJS } = require('immutable');
+
+const obj = { 1: 'one' };
+console.log(Object.keys(obj)); // [ "1" ]
+console.log(obj['1'], obj[1]); // "one", "one"
+
+const map = fromJS(obj);
+console.log(map.get('1'), map.get(1)); // "one", undefined
+```
+
+Property access for JavaScript Objects first converts the key to a string, but
+since Immutable Map keys can be of any type the argument to `get()` is
+not altered.
+
+### Converts back to raw JavaScript objects.
+
+All Immutable.js Collections can be converted to plain JavaScript Arrays and
+Objects shallowly with `toArray()` and `toObject()` or deeply with `toJS()`.
+All Immutable Collections also implement `toJSON()` allowing them to be passed
+to `JSON.stringify` directly. They also respect the custom `toJSON()` methods of
+nested objects.
+
+
+
+```js
+const { Map, List } = require('immutable');
+const deep = Map({ a: 1, b: 2, c: List([3, 4, 5]) });
+console.log(deep.toObject()); // { a: 1, b: 2, c: List [ 3, 4, 5 ] }
+console.log(deep.toArray()); // [ 1, 2, List [ 3, 4, 5 ] ]
+console.log(deep.toJS()); // { a: 1, b: 2, c: [ 3, 4, 5 ] }
+JSON.stringify(deep); // '{"a":1,"b":2,"c":[3,4,5]}'
+```
+
+### Embraces ES2015
+
+Immutable.js supports all JavaScript environments, including legacy
+browsers (even IE11). However it also takes advantage of features added to
+JavaScript in [ES2015][], the latest standard version of JavaScript, including
+[Iterators][], [Arrow Functions][], [Classes][], and [Modules][]. It's inspired
+by the native [Map][] and [Set][] collections added to ES2015.
+
+All examples in the Documentation are presented in ES2015. To run in all
+browsers, they need to be translated to ES5.
+
+```js
+// ES2015
+const mapped = foo.map(x => x * x);
+// ES5
+var mapped = foo.map(function (x) {
+ return x * x;
+});
+```
+
+All Immutable.js collections are [Iterable][iterators], which allows them to be
+used anywhere an Iterable is expected, such as when spreading into an Array.
+
+
+
+```js
+const { List } = require('immutable');
+const aList = List([1, 2, 3]);
+const anArray = [0, ...aList, 4, 5]; // [ 0, 1, 2, 3, 4, 5 ]
+```
+
+Note: A Collection is always iterated in the same order, however that order may
+not always be well defined, as is the case for the `Map` and `Set`.
+
+[Iterators]: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/The_Iterator_protocol
+[Arrow Functions]: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Functions/Arrow_functions
+[Classes]: https://wiki.ecmascript.org/doku.php?id=strawman:maximally_minimal_classes
+[Modules]: https://www.2ality.com/2014/09/es6-modules-final.html
+
+
+## Nested Structures
+
+The collections in Immutable.js are intended to be nested, allowing for deep
+trees of data, similar to JSON.
+
+
+
+```js
+const { fromJS } = require('immutable');
+const nested = fromJS({ a: { b: { c: [3, 4, 5] } } });
+// Map { a: Map { b: Map { c: List [ 3, 4, 5 ] } } }
+```
+
+A few power-tools allow for reading and operating on nested data. The
+most useful are `mergeDeep`, `getIn`, `setIn`, and `updateIn`, found on `List`,
+`Map` and `OrderedMap`.
+
+
+
+```js
+const { fromJS } = require('immutable');
+const nested = fromJS({ a: { b: { c: [3, 4, 5] } } });
+
+const nested2 = nested.mergeDeep({ a: { b: { d: 6 } } });
+// Map { a: Map { b: Map { c: List [ 3, 4, 5 ], d: 6 } } }
+
+console.log(nested2.getIn(['a', 'b', 'd'])); // 6
+
+const nested3 = nested2.updateIn(['a', 'b', 'd'], value => value + 1);
+console.log(nested3);
+// Map { a: Map { b: Map { c: List [ 3, 4, 5 ], d: 7 } } }
+
+const nested4 = nested3.updateIn(['a', 'b', 'c'], list => list.push(6));
+// Map { a: Map { b: Map { c: List [ 3, 4, 5, 6 ], d: 7 } } }
+```
+
+## Equality treats Collections as Values
+
+Immutable.js collections are treated as pure data _values_. Two immutable
+collections are considered _value equal_ (via `.equals()` or `is()`) if they
+represent the same collection of values. This differs from JavaScript's typical
+_reference equal_ (via `===` or `==`) for Objects and Arrays which only
+determines if two variables represent references to the same object instance.
+
+Consider the example below where two identical `Map` instances are not
+_reference equal_ but are _value equal_.
+
+
+
+```js
+// First consider:
+const obj1 = { a: 1, b: 2, c: 3 };
+const obj2 = { a: 1, b: 2, c: 3 };
+obj1 !== obj2; // two different instances are always not equal with ===
+
+const { Map, is } = require('immutable');
+const map1 = Map({ a: 1, b: 2, c: 3 });
+const map2 = Map({ a: 1, b: 2, c: 3 });
+map1 !== map2; // two different instances are not reference-equal
+map1.equals(map2); // but are value-equal if they have the same values
+is(map1, map2); // alternatively can use the is() function
+```
+
+Value equality allows Immutable.js collections to be used as keys in Maps or
+values in Sets, and retrieved with different but equivalent collections:
+
+
+
+```js
+const { Map, Set } = require('immutable');
+const map1 = Map({ a: 1, b: 2, c: 3 });
+const map2 = Map({ a: 1, b: 2, c: 3 });
+const set = Set().add(map1);
+set.has(map2); // true because these are value-equal
+```
+
+Note: `is()` uses the same measure of equality as [Object.is][] for scalar
+strings and numbers, but uses value equality for Immutable collections,
+determining if both are immutable and all keys and values are equal
+using the same measure of equality.
+
+[object.is]: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Object/is
+
+#### Performance tradeoffs
+
+While value equality is useful in many circumstances, it has different
+performance characteristics than reference equality. Understanding these
+tradeoffs may help you decide which to use in each case, especially when used
+to memoize some operation.
+
+When comparing two collections, value equality may require considering every
+item in each collection, on an `O(N)` time complexity. For large collections of
+values, this could become a costly operation. Though if the two are not equal
+and hardly similar, the inequality is determined very quickly. In contrast, when
+comparing two collections with reference equality, only the initial references
+to memory need to be compared which is not based on the size of the collections,
+which has an `O(1)` time complexity. Checking reference equality is always very
+fast, however just because two collections are not reference-equal does not rule
+out the possibility that they may be value-equal.
+
+#### Return self on no-op optimization
+
+When possible, Immutable.js avoids creating new objects for updates where no
+change in _value_ occurred, to allow for efficient _reference equality_ checking
+to quickly determine if no change occurred.
+
+
+
+```js
+const { Map } = require('immutable');
+const originalMap = Map({ a: 1, b: 2, c: 3 });
+const updatedMap = originalMap.set('b', 2);
+updatedMap === originalMap; // No-op .set() returned the original reference.
+```
+
+However updates which do result in a change will return a new reference. Each
+of these operations occur independently, so two similar updates will not return
+the same reference:
+
+
+
+```js
+const { Map } = require('immutable');
+const originalMap = Map({ a: 1, b: 2, c: 3 });
+const updatedMap = originalMap.set('b', 1000);
+// New instance, leaving the original immutable.
+updatedMap !== originalMap;
+const anotherUpdatedMap = originalMap.set('b', 1000);
+// Despite both the results of the same operation, each created a new reference.
+anotherUpdatedMap !== updatedMap;
+// However the two are value equal.
+anotherUpdatedMap.equals(updatedMap);
+```
+
+## Batching Mutations
+
+> If a tree falls in the woods, does it make a sound?
+>
+> If a pure function mutates some local data in order to produce an immutable
+> return value, is that ok?
+>
+> — Rich Hickey, Clojure
+
+Applying a mutation to create a new immutable object results in some overhead,
+which can add up to a minor performance penalty. If you need to apply a series
+of mutations locally before returning, Immutable.js gives you the ability to
+create a temporary mutable (transient) copy of a collection and apply a batch of
+mutations in a performant manner by using `withMutations`. In fact, this is
+exactly how Immutable.js applies complex mutations itself.
+
+As an example, building `list2` results in the creation of 1, not 3, new
+immutable Lists.
+
+
+
+```js
+const { List } = require('immutable');
+const list1 = List([1, 2, 3]);
+const list2 = list1.withMutations(function (list) {
+ list.push(4).push(5).push(6);
+});
+assert.equal(list1.size, 3);
+assert.equal(list2.size, 6);
+```
+
+Note: Immutable.js also provides `asMutable` and `asImmutable`, but only
+encourages their use when `withMutations` will not suffice. Use caution to not
+return a mutable copy, which could result in undesired behavior.
+
+_Important!_: Only a select few methods can be used in `withMutations` including
+`set`, `push` and `pop`. These methods can be applied directly against a
+persistent data-structure where other methods like `map`, `filter`, `sort`,
+and `splice` will always return new immutable data-structures and never mutate
+a mutable collection.
+
+## Lazy Seq
+
+`Seq` describes a lazy operation, allowing them to efficiently chain
+use of all the higher-order collection methods (such as `map` and `filter`)
+by not creating intermediate collections.
+
+**Seq is immutable** — Once a Seq is created, it cannot be
+changed, appended to, rearranged or otherwise modified. Instead, any mutative
+method called on a `Seq` will return a new `Seq`.
+
+**Seq is lazy** — `Seq` does as little work as necessary to respond to any
+method call. Values are often created during iteration, including implicit
+iteration when reducing or converting to a concrete data structure such as
+a `List` or JavaScript `Array`.
+
+For example, the following performs no work, because the resulting
+`Seq`'s values are never iterated:
+
+```js
+const { Seq } = require('immutable');
+const oddSquares = Seq([1, 2, 3, 4, 5, 6, 7, 8])
+ .filter(x => x % 2 !== 0)
+ .map(x => x * x);
+```
+
+Once the `Seq` is used, it performs only the work necessary. In this
+example, no intermediate arrays are ever created, filter is called three
+times, and map is only called once:
+
+```js
+oddSquares.get(1); // 9
+```
+
+Any collection can be converted to a lazy Seq with `Seq()`.
+
+
+
+```js
+const { Map, Seq } = require('immutable');
+const map = Map({ a: 1, b: 2, c: 3 });
+const lazySeq = Seq(map);
+```
+
+`Seq` allows for the efficient chaining of operations, allowing for the
+expression of logic that can otherwise be very tedious:
+
+```js
+lazySeq
+ .flip()
+ .map(key => key.toUpperCase())
+ .flip();
+// Seq { A: 1, B: 2, C: 3 }
+```
+
+As well as expressing logic that would otherwise seem memory or time
+limited, for example `Range` is a special kind of Lazy sequence.
+
+
+
+```js
+const { Range } = require('immutable');
+Range(1, Infinity)
+ .skip(1000)
+ .map(n => -n)
+ .filter(n => n % 2 === 0)
+ .take(2)
+ .reduce((r, n) => r * n, 1);
+// 1006008
+```
+
+## Comparison of filter(), groupBy(), and partition()
+
+The `filter()`, `groupBy()`, and `partition()` methods are similar in that they
+all divide a collection into parts based on applying a function to each element.
+All three call the predicate or grouping function once for each item in the
+input collection. All three return zero or more collections of the same type as
+their input. The returned collections are always distinct from the input
+(according to `===`), even if the contents are identical.
+
+Of these methods, `filter()` is the only one that is lazy and the only one which
+discards items from the input collection. It is the simplest to use, and the
+fact that it returns exactly one collection makes it easy to combine with other
+methods to form a pipeline of operations.
+
+The `partition()` method is similar to an eager version of `filter()`, but it
+returns two collections; the first contains the items that would have been
+discarded by `filter()`, and the second contains the items that would have been
+kept. It always returns an array of exactly two collections, which can make it
+easier to use than `groupBy()`. Compared to making two separate calls to
+`filter()`, `partition()` makes half as many calls to the predicate passed to
+it.
+
+The `groupBy()` method is a more generalized version of `partition()` that can
+group by an arbitrary function rather than just a predicate. It returns a map
+with zero or more entries, where the keys are the values returned by the
+grouping function, and the values are nonempty collections of the corresponding
+arguments. Although `groupBy()` is more powerful than `partition()`, it can be
+harder to use because it is not always possible to predict in advance how many
+entries the returned map will have and what their keys will be.
+
+| Summary | `filter` | `partition` | `groupBy` |
+|:------------------------------|:---------|:------------|:---------------|
+| ease of use | easiest | moderate | hardest |
+| generality | least | moderate | most |
+| laziness | lazy | eager | eager |
+| # of returned sub-collections | 1 | 2 | 0 or more |
+| sub-collections may be empty | yes | yes | no |
+| can discard items | yes | no | no |
+| wrapping container | none | array | Map/OrderedMap |
+
+## Additional Tools and Resources
+
+- [Atom-store](https://github.com/jameshopkins/atom-store/)
+ - A Clojure-inspired atom implementation in Javascript with configurability
+ for external persistence.
+
+- [Chai Immutable](https://github.com/astorije/chai-immutable)
+ - If you are using the [Chai Assertion Library](https://chaijs.com/), this
+ provides a set of assertions to use against Immutable.js collections.
+
+- [Fantasy-land](https://github.com/fantasyland/fantasy-land)
+ - Specification for interoperability of common algebraic structures in JavaScript.
+
+- [Immutagen](https://github.com/pelotom/immutagen)
+ - A library for simulating immutable generators in JavaScript.
+
+- [Immutable-cursor](https://github.com/redbadger/immutable-cursor)
+ - Immutable cursors incorporating the Immutable.js interface over
+ Clojure-inspired atom.
+
+- [Immutable-ext](https://github.com/DrBoolean/immutable-ext)
+ - Fantasyland extensions for immutablejs
+
+- [Immutable-js-tools](https://github.com/madeinfree/immutable-js-tools)
+ - Util tools for immutable.js
+
+- [Immutable-Redux](https://github.com/gajus/redux-immutable)
+ - redux-immutable is used to create an equivalent function of Redux
+ combineReducers that works with Immutable.js state.
+
+- [Immutable-Treeutils](https://github.com/lukasbuenger/immutable-treeutils)
+ - Functional tree traversal helpers for ImmutableJS data structures.
+
+- [Irecord](https://github.com/ericelliott/irecord)
+ - An immutable store that exposes an RxJS observable. Great for React.
+
+- [Mudash](https://github.com/brianneisler/mudash)
+ - Lodash wrapper providing Immutable.JS support.
+
+- [React-Immutable-PropTypes](https://github.com/HurricaneJames/react-immutable-proptypes)
+ - PropType validators that work with Immutable.js.
+
+- [Redux-Immutablejs](https://github.com/indexiatech/redux-immutablejs)
+ - Redux Immutable facilities.
+
+- [Rxstate](https://github.com/yamalight/rxstate)
+ - Simple opinionated state management library based on RxJS and Immutable.js.
+
+- [Transit-Immutable-js](https://github.com/glenjamin/transit-immutable-js)
+ - Transit serialisation for Immutable.js.
+ - See also: [Transit-js](https://github.com/cognitect/transit-js)
+
+Have an additional tool designed to work with Immutable.js?
+Submit a PR to add it to this list in alphabetical order.
+
+## Contributing
+
+Use [Github issues](https://github.com/immutable-js/immutable-js/issues) for requests.
+
+We actively welcome pull requests, learn how to [contribute](https://github.com/immutable-js/immutable-js/blob/main/.github/CONTRIBUTING.md).
+
+Immutable.js is maintained within the [Contributor Covenant's Code of Conduct](https://www.contributor-covenant.org/version/2/0/code_of_conduct/).
+
+### Changelog
+
+Changes are tracked as [Github releases](https://github.com/immutable-js/immutable-js/releases).
+
+### License
+
+Immutable.js is [MIT-licensed](./LICENSE).
+
+### Thanks
+
+[Phil Bagwell](https://www.youtube.com/watch?v=K2NYwP90bNs), for his inspiration
+and research in persistent data structures.
+
+[Hugh Jackson](https://github.com/hughfdjackson/), for providing the npm package
+name. If you're looking for his unsupported package, see [this repository](https://github.com/hughfdjackson/immutable).
diff --git a/cli/node_modules/immutable/package.json b/cli/node_modules/immutable/package.json
new file mode 100644
index 000000000..e380730f6
--- /dev/null
+++ b/cli/node_modules/immutable/package.json
@@ -0,0 +1,39 @@
+{
+ "name": "immutable",
+ "version": "4.3.7",
+ "description": "Immutable Data Collections",
+ "license": "MIT",
+ "homepage": "https://immutable-js.com",
+ "author": {
+ "name": "Lee Byron",
+ "url": "https://github.com/leebyron"
+ },
+ "repository": {
+ "type": "git",
+ "url": "git://github.com/immutable-js/immutable-js.git"
+ },
+ "bugs": {
+ "url": "https://github.com/immutable-js/immutable-js/issues"
+ },
+ "main": "dist/immutable.js",
+ "module": "dist/immutable.es.js",
+ "sideEffects": false,
+ "types": "dist/immutable.d.ts",
+ "files": [
+ "dist",
+ "README.md",
+ "LICENSE"
+ ],
+ "keywords": [
+ "immutable",
+ "persistent",
+ "lazy",
+ "data",
+ "datastructure",
+ "functional",
+ "collection",
+ "stateless",
+ "sequence",
+ "iteration"
+ ]
+}
\ No newline at end of file
diff --git a/cli/src/benchmark_gpt.ts b/cli/src/benchmark_gpt.ts
index 6c8bc99ce..c067aba85 100644
--- a/cli/src/benchmark_gpt.ts
+++ b/cli/src/benchmark_gpt.ts
@@ -134,4 +134,4 @@ async function main(args: Required): Promise {
}
// You can run this example with "npm start" from this folder
-main(args).catch(console.error)
+main(args).catch(console.error)
\ No newline at end of file
diff --git a/cli/src/cli.ts b/cli/src/cli.ts
index 052bfcacc..c56c9280b 100644
--- a/cli/src/cli.ts
+++ b/cli/src/cli.ts
@@ -14,7 +14,7 @@ import type {
} from "@epfml/discojs";
import { Disco, aggregator as aggregators, client as clients } from '@epfml/discojs'
-import { getTaskData, loadTinderDogData } from './data.js'
+import { getTaskData } from './data.js'
import { args } from './args.js'
// Array.fromAsync not yet widely used (2024)
diff --git a/cli/src/data.ts b/cli/src/data.ts
index 5ebfc6fe5..c3077c11b 100644
--- a/cli/src/data.ts
+++ b/cli/src/data.ts
@@ -5,7 +5,6 @@ import type {
DataType,
Image,
Task,
- Text,
} from "@epfml/discojs";
import { loadCSV, loadImage, loadImagesInDir } from "@epfml/discojs-node";
import { Repeat } from "immutable";
diff --git a/cli/src/train_gpt.ts b/cli/src/train_gpt.ts
index 531667a0f..bf70b33cf 100644
--- a/cli/src/train_gpt.ts
+++ b/cli/src/train_gpt.ts
@@ -45,4 +45,4 @@ async function main(): Promise {
}
// You can run this example with "npm run run_gpt" from this folder
-main().catch(console.error)
+main().catch(console.error)
\ No newline at end of file
diff --git a/discojs/src/default_tasks/cifar10.ts b/discojs/src/default_tasks/cifar10.ts
index b644b6e93..9d411e022 100644
--- a/discojs/src/default_tasks/cifar10.ts
+++ b/discojs/src/default_tasks/cifar10.ts
@@ -63,6 +63,6 @@ export const cifar10: TaskProvider<'image'> = {
metrics: ['accuracy']
})
- return new models.TFJS('image', model)
+ return new models.TFJS('image', model, "fedprox")
}
}
diff --git a/discojs/src/default_tasks/lus_covid.ts b/discojs/src/default_tasks/lus_covid.ts
index 8331ac0d8..6733df3fd 100644
--- a/discojs/src/default_tasks/lus_covid.ts
+++ b/discojs/src/default_tasks/lus_covid.ts
@@ -93,6 +93,6 @@ export const lusCovid: TaskProvider<'image'> = {
metrics: ['accuracy']
})
- return Promise.resolve(new models.TFJS('image', model))
+ return Promise.resolve(new models.TFJS('image', model, "fedprox"))
}
}
\ No newline at end of file
diff --git a/discojs/src/default_tasks/mnist.ts b/discojs/src/default_tasks/mnist.ts
index d73a044d9..6f65aa98b 100644
--- a/discojs/src/default_tasks/mnist.ts
+++ b/discojs/src/default_tasks/mnist.ts
@@ -66,6 +66,6 @@ export const mnist: TaskProvider<'image'> = {
metrics: ['accuracy']
})
- return Promise.resolve(new models.TFJS('image', model))
+ return Promise.resolve(new models.TFJS('image', model, "fedprox"))
}
}
diff --git a/discojs/src/default_tasks/simple_face.ts b/discojs/src/default_tasks/simple_face.ts
index a87825e5d..f2f60b273 100644
--- a/discojs/src/default_tasks/simple_face.ts
+++ b/discojs/src/default_tasks/simple_face.ts
@@ -48,6 +48,6 @@ export const simpleFace: TaskProvider<'image'> = {
metrics: ['accuracy']
})
- return new models.TFJS('image', model)
+ return new models.TFJS('image', model, "fedprox")
}
}
diff --git a/discojs/src/default_tasks/tinder_dog.ts b/discojs/src/default_tasks/tinder_dog.ts
index a19bf5f8b..7884c36cf 100644
--- a/discojs/src/default_tasks/tinder_dog.ts
+++ b/discojs/src/default_tasks/tinder_dog.ts
@@ -79,6 +79,6 @@ export const tinderDog: TaskProvider<'image'> = {
metrics: ['accuracy']
})
- return Promise.resolve(new models.TFJS('image', model))
+ return Promise.resolve(new models.TFJS('image', model, "fedprox"))
}
}
\ No newline at end of file
diff --git a/discojs/src/default_tasks/titanic.ts b/discojs/src/default_tasks/titanic.ts
index b9462ee50..9efbe4b86 100644
--- a/discojs/src/default_tasks/titanic.ts
+++ b/discojs/src/default_tasks/titanic.ts
@@ -90,6 +90,6 @@ export const titanic: TaskProvider<'tabular'> = {
metrics: ['accuracy']
})
- return Promise.resolve(new models.TFJS('tabular', model))
+ return Promise.resolve(new models.TFJS('tabular', model, "fedprox"))
}
}
diff --git a/discojs/src/models/gpt/model.ts b/discojs/src/models/gpt/model.ts
index f4537b8d0..01ee51e92 100644
--- a/discojs/src/models/gpt/model.ts
+++ b/discojs/src/models/gpt/model.ts
@@ -4,6 +4,7 @@ import * as tf from '@tensorflow/tfjs'
import type { GPTConfig } from './config.js'
import { getModelSizes, DefaultGPTConfig } from './config.js'
import { getCustomAdam, clipByGlobalNormObj } from './optimizers.js'
+import evaluate from './evaluate.js'
import { GPTArchitecture } from './layers.js'
const debug = createDebug("discojs:models:gpt:model");
diff --git a/discojs/src/models/tfjs.ts b/discojs/src/models/tfjs.ts
index 7bedc3acf..ffb5cb06d 100644
--- a/discojs/src/models/tfjs.ts
+++ b/discojs/src/models/tfjs.ts
@@ -18,12 +18,15 @@ const debug = createDebug("discojs:models:tfjs");
type Serialized = [D, tf.io.ModelArtifacts];
+type FrameWorkAlgorithm = "fedaverage" | "fedprox";
+
/** TensorFlow JavaScript model with standard training */
export class TFJS extends Model {
/** Wrap the given trainable model */
constructor (
public readonly datatype: D,
- private readonly model: tf.LayersModel
+ private readonly model: tf.LayersModel,
+ public readonly framework: FrameWorkAlgorithm = "fedaverage",
) {
super()
@@ -67,8 +70,14 @@ export class TFJS extends Model {
batch: Batched,
): Promise {
const { xs, ys } = this.#batchToTF(batch);
- const logs = await this.trainFedAverage(xs, ys);
- // const logs = await this.model.trainOnBatch(xs, ys);
+ let logs: [number, number];
+ if (this.framework === "fedaverage") {
+ logs = await this.trainFedAverage(xs, ys);
+ } else if (this.framework === "fedprox") {
+ logs = await this.trainFedProx(xs, ys);
+ } else {
+ throw new Error("unknown framework");
+ }
tf.dispose([xs, ys])
return this.getBatchLogs(logs)
}
diff --git a/discojs/src/serialization/model.spec.ts b/discojs/src/serialization/model.spec.ts
index a966d39db..135406e78 100644
--- a/discojs/src/serialization/model.spec.ts
+++ b/discojs/src/serialization/model.spec.ts
@@ -28,7 +28,7 @@ describe('serialization', () => {
]
})
rawModel.compile({ optimizer: 'sgd', loss: 'hinge' })
- const model = new models.TFJS("image", rawModel)
+ const model = new models.TFJS("image", rawModel, "fedprox")
const encoded = await serialization.model.encode(model)
assert.isTrue(serialization.isEncoded(encoded))
diff --git a/docs/examples/custom_task.ts b/docs/examples/custom_task.ts
index 609b748ca..4b4548d65 100644
--- a/docs/examples/custom_task.ts
+++ b/docs/examples/custom_task.ts
@@ -55,7 +55,7 @@ const customTask: TaskProvider<"tabular"> = {
metrics: ['accuracy']
})
- return Promise.resolve(new models.TFJS('tabular', model))
+ return Promise.resolve(new models.TFJS('tabular', model, "fedprox"))
}
}
diff --git a/package-lock.json b/package-lock.json
index e2f7dd7d8..5c23c9481 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -39,6 +39,12 @@
"ts-command-line-args": "2"
}
},
+ "cli/node_modules/immutable": {
+ "version": "4.3.7",
+ "resolved": "https://registry.npmjs.org/immutable/-/immutable-4.3.7.tgz",
+ "integrity": "sha512-1hqclzwYwjRDFLjcFxOM5AYkkG0rpFPpr1RLPMEuGczoS7YA8gLhy8SWXYRAA/XwfEHpfo3cw5JGioS32fnMRw==",
+ "license": "MIT"
+ },
"discojs": {
"name": "@epfml/discojs",
"version": "3.0.0",
@@ -5474,9 +5480,6 @@
"license": "MIT"
},
"node_modules/csv-parse": {
- "version": "5.6.0",
- "resolved": "https://registry.npmjs.org/csv-parse/-/csv-parse-5.6.0.tgz",
- "integrity": "sha512-l3nz3euub2QMg5ouu5U09Ew9Wf6/wQ8I++ch1loQ0ljmzhmfZYrH9fflS22i/PQEvsPvxCwxgz5q7UB8K1JO4Q==",
"version": "5.6.0",
"resolved": "https://registry.npmjs.org/csv-parse/-/csv-parse-5.6.0.tgz",
"integrity": "sha512-l3nz3euub2QMg5ouu5U09Ew9Wf6/wQ8I++ch1loQ0ljmzhmfZYrH9fflS22i/PQEvsPvxCwxgz5q7UB8K1JO4Q==",
diff --git a/webapp/src/components/task_creation_form/TaskForm.vue b/webapp/src/components/task_creation_form/TaskForm.vue
index 29648ea05..90bc917e5 100644
--- a/webapp/src/components/task_creation_form/TaskForm.vue
+++ b/webapp/src/components/task_creation_form/TaskForm.vue
@@ -76,6 +76,11 @@
v-model="scheme"
:field="field"
/>
+
())
@@ -270,6 +276,7 @@ const onSubmit = async (rawTask: any): Promise => {
await tf.loadLayersModel(
tf.io.browserFiles(modelFiles.value.toArray()),
),
+ rawTask.framework
);
break;
case "text":
diff --git a/webapp/src/task_creation_form.ts b/webapp/src/task_creation_form.ts
index f33420364..78864ee1c 100644
--- a/webapp/src/task_creation_form.ts
+++ b/webapp/src/task_creation_form.ts
@@ -66,6 +66,15 @@ const generalInformation: FormSection = {
type: 'select',
options: ['Decentralized', 'Federated'],
default: 'Decentralized'
+ },
+ {
+ id: "framework",
+ "name" : "Training Framework",
+ "yup" : yup.string().required(),
+ "as" : "input",
+ "type" : "select",
+ "options" : ["fedaverage", "fedprox"],
+ "default" : "fedprox"
}
]
}
From f88856fd61b8268a8515887f41d7a97f71a1164a Mon Sep 17 00:00:00 2001
From: tomasoignons
Date: Fri, 14 Mar 2025 16:27:19 +0100
Subject: [PATCH 23/25] added the fedprox by default if nothing is specified
---
discojs/src/models/tfjs.ts | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/discojs/src/models/tfjs.ts b/discojs/src/models/tfjs.ts
index a09b33962..a21b71583 100644
--- a/discojs/src/models/tfjs.ts
+++ b/discojs/src/models/tfjs.ts
@@ -26,7 +26,7 @@ export class TFJS extends Model {
constructor (
public readonly datatype: D,
private readonly model: tf.LayersModel,
- public readonly framework: FrameWorkAlgorithm = "fedaverage",
+ public readonly framework: FrameWorkAlgorithm = "fedprox",
) {
super()
From de00d93bd85ee66ff128407f922509cc13b525fb Mon Sep 17 00:00:00 2001
From: tharvik
Date: Mon, 31 Mar 2025 13:13:31 +0200
Subject: [PATCH 24/25] cli: cleanup node_modules
---
cli/node_modules/immutable/LICENSE | 21 -
cli/node_modules/immutable/README.md | 761 ------------------------
cli/node_modules/immutable/package.json | 39 --
3 files changed, 821 deletions(-)
delete mode 100644 cli/node_modules/immutable/LICENSE
delete mode 100644 cli/node_modules/immutable/README.md
delete mode 100644 cli/node_modules/immutable/package.json
diff --git a/cli/node_modules/immutable/LICENSE b/cli/node_modules/immutable/LICENSE
deleted file mode 100644
index 1e3c4f39c..000000000
--- a/cli/node_modules/immutable/LICENSE
+++ /dev/null
@@ -1,21 +0,0 @@
-MIT License
-
-Copyright (c) 2014-present, Lee Byron and other contributors.
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
diff --git a/cli/node_modules/immutable/README.md b/cli/node_modules/immutable/README.md
deleted file mode 100644
index a6eae67dd..000000000
--- a/cli/node_modules/immutable/README.md
+++ /dev/null
@@ -1,761 +0,0 @@
-# Immutable collections for JavaScript
-
-[](https://github.com/immutable-js/immutable-js/actions/workflows/ci.yml?query=branch%3Amain) [Chat on slack](https://immutable-js.slack.com)
-
-[Read the docs](https://immutable-js.com/docs/) and eat your vegetables.
-
-Docs are automatically generated from [README.md][] and [immutable.d.ts][].
-Please contribute! Also, don't miss the [wiki][] which contains articles on
-additional specific topics. Can't find something? Open an [issue][].
-
-**Table of contents:**
-
-- [Introduction](#introduction)
-- [Getting started](#getting-started)
-- [The case for Immutability](#the-case-for-immutability)
-- [JavaScript-first API](#javascript-first-api)
-- [Nested Structures](#nested-structures)
-- [Equality treats Collections as Values](#equality-treats-collections-as-values)
-- [Batching Mutations](#batching-mutations)
-- [Lazy Seq](#lazy-seq)
-- [Additional Tools and Resources](#additional-tools-and-resources)
-- [Contributing](#contributing)
-
-## Introduction
-
-[Immutable][] data cannot be changed once created, leading to much simpler
-application development, no defensive copying, and enabling advanced memoization
-and change detection techniques with simple logic. [Persistent][] data presents
-a mutative API which does not update the data in-place, but instead always
-yields new updated data.
-
-Immutable.js provides many Persistent Immutable data structures including:
-`List`, `Stack`, `Map`, `OrderedMap`, `Set`, `OrderedSet` and `Record`.
-
-These data structures are highly efficient on modern JavaScript VMs by using
-structural sharing via [hash maps tries][] and [vector tries][] as popularized
-by Clojure and Scala, minimizing the need to copy or cache data.
-
-Immutable.js also provides a lazy `Seq`, allowing efficient
-chaining of collection methods like `map` and `filter` without creating
-intermediate representations. Create some `Seq` with `Range` and `Repeat`.
-
-Want to hear more? Watch the presentation about Immutable.js:
-
-[](https://youtu.be/I7IdS-PbEgI)
-
-[README.md]: https://github.com/immutable-js/immutable-js/blob/main/README.md
-[immutable.d.ts]: https://github.com/immutable-js/immutable-js/blob/main/type-definitions/immutable.d.ts
-[wiki]: https://github.com/immutable-js/immutable-js/wiki
-[issue]: https://github.com/immutable-js/immutable-js/issues
-[Persistent]: https://en.wikipedia.org/wiki/Persistent_data_structure
-[Immutable]: https://en.wikipedia.org/wiki/Immutable_object
-[hash maps tries]: https://en.wikipedia.org/wiki/Hash_array_mapped_trie
-[vector tries]: https://hypirion.com/musings/understanding-persistent-vector-pt-1
-
-## Getting started
-
-Install `immutable` using npm.
-
-```shell
-# using npm
-npm install immutable
-
-# using Yarn
-yarn add immutable
-
-# using pnpm
-pnpm add immutable
-
-# using Bun
-bun add immutable
-```
-
-Then require it into any module.
-
-
-
-```js
-const { Map } = require('immutable');
-const map1 = Map({ a: 1, b: 2, c: 3 });
-const map2 = map1.set('b', 50);
-map1.get('b') + ' vs. ' + map2.get('b'); // 2 vs. 50
-```
-
-### Browser
-
-Immutable.js has no dependencies, which makes it predictable to include in a Browser.
-
-It's highly recommended to use a module bundler like [webpack](https://webpack.github.io/),
-[rollup](https://rollupjs.org/), or
-[browserify](https://browserify.org/). The `immutable` npm module works
-without any additional consideration. All examples throughout the documentation
-will assume use of this kind of tool.
-
-Alternatively, Immutable.js may be directly included as a script tag. Download
-or link to a CDN such as [CDNJS](https://cdnjs.com/libraries/immutable)
-or [jsDelivr](https://www.jsdelivr.com/package/npm/immutable).
-
-Use a script tag to directly add `Immutable` to the global scope:
-
-```html
-
-
-```
-
-Or use an AMD-style loader (such as [RequireJS](https://requirejs.org/)):
-
-```js
-require(['./immutable.min.js'], function (Immutable) {
- var map1 = Immutable.Map({ a: 1, b: 2, c: 3 });
- var map2 = map1.set('b', 50);
- map1.get('b'); // 2
- map2.get('b'); // 50
-});
-```
-
-### Flow & TypeScript
-
-Use these Immutable collections and sequences as you would use native
-collections in your [Flowtype](https://flowtype.org/) or [TypeScript](https://typescriptlang.org) programs while still taking
-advantage of type generics, error detection, and auto-complete in your IDE.
-
-Installing `immutable` via npm brings with it type definitions for Flow (v0.55.0 or higher)
-and TypeScript (v2.1.0 or higher), so you shouldn't need to do anything at all!
-
-#### Using TypeScript with Immutable.js v4
-
-Immutable.js type definitions embrace ES2015. While Immutable.js itself supports
-legacy browsers and environments, its type definitions require TypeScript's 2015
-lib. Include either `"target": "es2015"` or `"lib": "es2015"` in your
-`tsconfig.json`, or provide `--target es2015` or `--lib es2015` to the
-`tsc` command.
-
-
-
-```js
-const { Map } = require('immutable');
-const map1 = Map({ a: 1, b: 2, c: 3 });
-const map2 = map1.set('b', 50);
-map1.get('b') + ' vs. ' + map2.get('b'); // 2 vs. 50
-```
-
-#### Using TypeScript with Immutable.js v3 and earlier:
-
-Previous versions of Immutable.js include a reference file which you can include
-via relative path to the type definitions at the top of your file.
-
-```js
-///
-import Immutable from 'immutable';
-var map1: Immutable.Map;
-map1 = Immutable.Map({ a: 1, b: 2, c: 3 });
-var map2 = map1.set('b', 50);
-map1.get('b'); // 2
-map2.get('b'); // 50
-```
-
-## The case for Immutability
-
-Much of what makes application development difficult is tracking mutation and
-maintaining state. Developing with immutable data encourages you to think
-differently about how data flows through your application.
-
-Subscribing to data events throughout your application creates a huge overhead of
-book-keeping which can hurt performance, sometimes dramatically, and creates
-opportunities for areas of your application to get out of sync with each other
-due to easy to make programmer error. Since immutable data never changes,
-subscribing to changes throughout the model is a dead-end and new data can only
-ever be passed from above.
-
-This model of data flow aligns well with the architecture of [React][]
-and especially well with an application designed using the ideas of [Flux][].
-
-When data is passed from above rather than being subscribed to, and you're only
-interested in doing work when something has changed, you can use equality.
-
-Immutable collections should be treated as _values_ rather than _objects_. While
-objects represent some thing which could change over time, a value represents
-the state of that thing at a particular instance of time. This principle is most
-important to understanding the appropriate use of immutable data. In order to
-treat Immutable.js collections as values, it's important to use the
-`Immutable.is()` function or `.equals()` method to determine _value equality_
-instead of the `===` operator which determines object _reference identity_.
-
-
-
-```js
-const { Map } = require('immutable');
-const map1 = Map({ a: 1, b: 2, c: 3 });
-const map2 = Map({ a: 1, b: 2, c: 3 });
-map1.equals(map2); // true
-map1 === map2; // false
-```
-
-Note: As a performance optimization Immutable.js attempts to return the existing
-collection when an operation would result in an identical collection, allowing
-for using `===` reference equality to determine if something definitely has not
-changed. This can be extremely useful when used within a memoization function
-which would prefer to re-run the function if a deeper equality check could
-potentially be more costly. The `===` equality check is also used internally by
-`Immutable.is` and `.equals()` as a performance optimization.
-
-
-
-```js
-const { Map } = require('immutable');
-const map1 = Map({ a: 1, b: 2, c: 3 });
-const map2 = map1.set('b', 2); // Set to same value
-map1 === map2; // true
-```
-
-If an object is immutable, it can be "copied" simply by making another reference
-to it instead of copying the entire object. Because a reference is much smaller
-than the object itself, this results in memory savings and a potential boost in
-execution speed for programs which rely on copies (such as an undo-stack).
-
-
-
-```js
-const { Map } = require('immutable');
-const map = Map({ a: 1, b: 2, c: 3 });
-const mapCopy = map; // Look, "copies" are free!
-```
-
-[React]: https://reactjs.org/
-[Flux]: https://facebook.github.io/flux/docs/in-depth-overview/
-
-
-## JavaScript-first API
-
-While Immutable.js is inspired by Clojure, Scala, Haskell and other functional
-programming environments, it's designed to bring these powerful concepts to
-JavaScript, and therefore has an Object-Oriented API that closely mirrors that
-of [ES2015][] [Array][], [Map][], and [Set][].
-
-[es2015]: https://developer.mozilla.org/en-US/docs/Web/JavaScript/New_in_JavaScript/ECMAScript_6_support_in_Mozilla
-[array]: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array
-[map]: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Map
-[set]: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Set
-
-The difference for the immutable collections is that methods which would mutate
-the collection, like `push`, `set`, `unshift` or `splice`, instead return a new
-immutable collection. Methods which return new arrays, like `slice` or `concat`,
-instead return new immutable collections.
-
-
-
-```js
-const { List } = require('immutable');
-const list1 = List([1, 2]);
-const list2 = list1.push(3, 4, 5);
-const list3 = list2.unshift(0);
-const list4 = list1.concat(list2, list3);
-assert.equal(list1.size, 2);
-assert.equal(list2.size, 5);
-assert.equal(list3.size, 6);
-assert.equal(list4.size, 13);
-assert.equal(list4.get(0), 1);
-```
-
-Almost all of the methods on [Array][] will be found in similar form on
-`Immutable.List`, those of [Map][] found on `Immutable.Map`, and those of [Set][]
-found on `Immutable.Set`, including collection operations like `forEach()`
-and `map()`.
-
-
-
-```js
-const { Map } = require('immutable');
-const alpha = Map({ a: 1, b: 2, c: 3, d: 4 });
-alpha.map((v, k) => k.toUpperCase()).join();
-// 'A,B,C,D'
-```
-
-### Convert from raw JavaScript objects and arrays.
-
-Designed to inter-operate with your existing JavaScript, Immutable.js
-accepts plain JavaScript Arrays and Objects anywhere a method expects a
-`Collection`.
-
-
-
-```js
-const { Map, List } = require('immutable');
-const map1 = Map({ a: 1, b: 2, c: 3, d: 4 });
-const map2 = Map({ c: 10, a: 20, t: 30 });
-const obj = { d: 100, o: 200, g: 300 };
-const map3 = map1.merge(map2, obj);
-// Map { a: 20, b: 2, c: 10, d: 100, t: 30, o: 200, g: 300 }
-const list1 = List([1, 2, 3]);
-const list2 = List([4, 5, 6]);
-const array = [7, 8, 9];
-const list3 = list1.concat(list2, array);
-// List [ 1, 2, 3, 4, 5, 6, 7, 8, 9 ]
-```
-
-This is possible because Immutable.js can treat any JavaScript Array or Object
-as a Collection. You can take advantage of this in order to get sophisticated
-collection methods on JavaScript Objects, which otherwise have a very sparse
-native API. Because Seq evaluates lazily and does not cache intermediate
-results, these operations can be extremely efficient.
-
-
-
-```js
-const { Seq } = require('immutable');
-const myObject = { a: 1, b: 2, c: 3 };
-Seq(myObject)
- .map(x => x * x)
- .toObject();
-// { a: 1, b: 4, c: 9 }
-```
-
-Keep in mind, when using JS objects to construct Immutable Maps, that
-JavaScript Object properties are always strings, even if written in a quote-less
-shorthand, while Immutable Maps accept keys of any type.
-
-
-
-```js
-const { fromJS } = require('immutable');
-
-const obj = { 1: 'one' };
-console.log(Object.keys(obj)); // [ "1" ]
-console.log(obj['1'], obj[1]); // "one", "one"
-
-const map = fromJS(obj);
-console.log(map.get('1'), map.get(1)); // "one", undefined
-```
-
-Property access for JavaScript Objects first converts the key to a string, but
-since Immutable Map keys can be of any type the argument to `get()` is
-not altered.
-
-### Converts back to raw JavaScript objects.
-
-All Immutable.js Collections can be converted to plain JavaScript Arrays and
-Objects shallowly with `toArray()` and `toObject()` or deeply with `toJS()`.
-All Immutable Collections also implement `toJSON()` allowing them to be passed
-to `JSON.stringify` directly. They also respect the custom `toJSON()` methods of
-nested objects.
-
-
-
-```js
-const { Map, List } = require('immutable');
-const deep = Map({ a: 1, b: 2, c: List([3, 4, 5]) });
-console.log(deep.toObject()); // { a: 1, b: 2, c: List [ 3, 4, 5 ] }
-console.log(deep.toArray()); // [ 1, 2, List [ 3, 4, 5 ] ]
-console.log(deep.toJS()); // { a: 1, b: 2, c: [ 3, 4, 5 ] }
-JSON.stringify(deep); // '{"a":1,"b":2,"c":[3,4,5]}'
-```
-
-### Embraces ES2015
-
-Immutable.js supports all JavaScript environments, including legacy
-browsers (even IE11). However it also takes advantage of features added to
-JavaScript in [ES2015][], the latest standard version of JavaScript, including
-[Iterators][], [Arrow Functions][], [Classes][], and [Modules][]. It's inspired
-by the native [Map][] and [Set][] collections added to ES2015.
-
-All examples in the Documentation are presented in ES2015. To run in all
-browsers, they need to be translated to ES5.
-
-```js
-// ES2015
-const mapped = foo.map(x => x * x);
-// ES5
-var mapped = foo.map(function (x) {
- return x * x;
-});
-```
-
-All Immutable.js collections are [Iterable][iterators], which allows them to be
-used anywhere an Iterable is expected, such as when spreading into an Array.
-
-
-
-```js
-const { List } = require('immutable');
-const aList = List([1, 2, 3]);
-const anArray = [0, ...aList, 4, 5]; // [ 0, 1, 2, 3, 4, 5 ]
-```
-
-Note: A Collection is always iterated in the same order, however that order may
-not always be well defined, as is the case for the `Map` and `Set`.
-
-[Iterators]: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/The_Iterator_protocol
-[Arrow Functions]: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Functions/Arrow_functions
-[Classes]: https://wiki.ecmascript.org/doku.php?id=strawman:maximally_minimal_classes
-[Modules]: https://www.2ality.com/2014/09/es6-modules-final.html
-
-
-## Nested Structures
-
-The collections in Immutable.js are intended to be nested, allowing for deep
-trees of data, similar to JSON.
-
-
-
-```js
-const { fromJS } = require('immutable');
-const nested = fromJS({ a: { b: { c: [3, 4, 5] } } });
-// Map { a: Map { b: Map { c: List [ 3, 4, 5 ] } } }
-```
-
-A few power-tools allow for reading and operating on nested data. The
-most useful are `mergeDeep`, `getIn`, `setIn`, and `updateIn`, found on `List`,
-`Map` and `OrderedMap`.
-
-
-
-```js
-const { fromJS } = require('immutable');
-const nested = fromJS({ a: { b: { c: [3, 4, 5] } } });
-
-const nested2 = nested.mergeDeep({ a: { b: { d: 6 } } });
-// Map { a: Map { b: Map { c: List [ 3, 4, 5 ], d: 6 } } }
-
-console.log(nested2.getIn(['a', 'b', 'd'])); // 6
-
-const nested3 = nested2.updateIn(['a', 'b', 'd'], value => value + 1);
-console.log(nested3);
-// Map { a: Map { b: Map { c: List [ 3, 4, 5 ], d: 7 } } }
-
-const nested4 = nested3.updateIn(['a', 'b', 'c'], list => list.push(6));
-// Map { a: Map { b: Map { c: List [ 3, 4, 5, 6 ], d: 7 } } }
-```
-
-## Equality treats Collections as Values
-
-Immutable.js collections are treated as pure data _values_. Two immutable
-collections are considered _value equal_ (via `.equals()` or `is()`) if they
-represent the same collection of values. This differs from JavaScript's typical
-_reference equal_ (via `===` or `==`) for Objects and Arrays which only
-determines if two variables represent references to the same object instance.
-
-Consider the example below where two identical `Map` instances are not
-_reference equal_ but are _value equal_.
-
-
-
-```js
-// First consider:
-const obj1 = { a: 1, b: 2, c: 3 };
-const obj2 = { a: 1, b: 2, c: 3 };
-obj1 !== obj2; // two different instances are always not equal with ===
-
-const { Map, is } = require('immutable');
-const map1 = Map({ a: 1, b: 2, c: 3 });
-const map2 = Map({ a: 1, b: 2, c: 3 });
-map1 !== map2; // two different instances are not reference-equal
-map1.equals(map2); // but are value-equal if they have the same values
-is(map1, map2); // alternatively can use the is() function
-```
-
-Value equality allows Immutable.js collections to be used as keys in Maps or
-values in Sets, and retrieved with different but equivalent collections:
-
-
-
-```js
-const { Map, Set } = require('immutable');
-const map1 = Map({ a: 1, b: 2, c: 3 });
-const map2 = Map({ a: 1, b: 2, c: 3 });
-const set = Set().add(map1);
-set.has(map2); // true because these are value-equal
-```
-
-Note: `is()` uses the same measure of equality as [Object.is][] for scalar
-strings and numbers, but uses value equality for Immutable collections,
-determining if both are immutable and all keys and values are equal
-using the same measure of equality.
-
-[object.is]: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Object/is
-
-#### Performance tradeoffs
-
-While value equality is useful in many circumstances, it has different
-performance characteristics than reference equality. Understanding these
-tradeoffs may help you decide which to use in each case, especially when used
-to memoize some operation.
-
-When comparing two collections, value equality may require considering every
-item in each collection, on an `O(N)` time complexity. For large collections of
-values, this could become a costly operation. Though if the two are not equal
-and hardly similar, the inequality is determined very quickly. In contrast, when
-comparing two collections with reference equality, only the initial references
-to memory need to be compared which is not based on the size of the collections,
-which has an `O(1)` time complexity. Checking reference equality is always very
-fast, however just because two collections are not reference-equal does not rule
-out the possibility that they may be value-equal.
-
-#### Return self on no-op optimization
-
-When possible, Immutable.js avoids creating new objects for updates where no
-change in _value_ occurred, to allow for efficient _reference equality_ checking
-to quickly determine if no change occurred.
-
-
-
-```js
-const { Map } = require('immutable');
-const originalMap = Map({ a: 1, b: 2, c: 3 });
-const updatedMap = originalMap.set('b', 2);
-updatedMap === originalMap; // No-op .set() returned the original reference.
-```
-
-However updates which do result in a change will return a new reference. Each
-of these operations occur independently, so two similar updates will not return
-the same reference:
-
-
-
-```js
-const { Map } = require('immutable');
-const originalMap = Map({ a: 1, b: 2, c: 3 });
-const updatedMap = originalMap.set('b', 1000);
-// New instance, leaving the original immutable.
-updatedMap !== originalMap;
-const anotherUpdatedMap = originalMap.set('b', 1000);
-// Despite both the results of the same operation, each created a new reference.
-anotherUpdatedMap !== updatedMap;
-// However the two are value equal.
-anotherUpdatedMap.equals(updatedMap);
-```
-
-## Batching Mutations
-
-> If a tree falls in the woods, does it make a sound?
->
-> If a pure function mutates some local data in order to produce an immutable
-> return value, is that ok?
->
-> — Rich Hickey, Clojure
-
-Applying a mutation to create a new immutable object results in some overhead,
-which can add up to a minor performance penalty. If you need to apply a series
-of mutations locally before returning, Immutable.js gives you the ability to
-create a temporary mutable (transient) copy of a collection and apply a batch of
-mutations in a performant manner by using `withMutations`. In fact, this is
-exactly how Immutable.js applies complex mutations itself.
-
-As an example, building `list2` results in the creation of 1, not 3, new
-immutable Lists.
-
-
-
-```js
-const { List } = require('immutable');
-const list1 = List([1, 2, 3]);
-const list2 = list1.withMutations(function (list) {
- list.push(4).push(5).push(6);
-});
-assert.equal(list1.size, 3);
-assert.equal(list2.size, 6);
-```
-
-Note: Immutable.js also provides `asMutable` and `asImmutable`, but only
-encourages their use when `withMutations` will not suffice. Use caution to not
-return a mutable copy, which could result in undesired behavior.
-
-_Important!_: Only a select few methods can be used in `withMutations` including
-`set`, `push` and `pop`. These methods can be applied directly against a
-persistent data-structure where other methods like `map`, `filter`, `sort`,
-and `splice` will always return new immutable data-structures and never mutate
-a mutable collection.
-
-## Lazy Seq
-
-`Seq` describes a lazy operation, allowing them to efficiently chain
-use of all the higher-order collection methods (such as `map` and `filter`)
-by not creating intermediate collections.
-
-**Seq is immutable** — Once a Seq is created, it cannot be
-changed, appended to, rearranged or otherwise modified. Instead, any mutative
-method called on a `Seq` will return a new `Seq`.
-
-**Seq is lazy** — `Seq` does as little work as necessary to respond to any
-method call. Values are often created during iteration, including implicit
-iteration when reducing or converting to a concrete data structure such as
-a `List` or JavaScript `Array`.
-
-For example, the following performs no work, because the resulting
-`Seq`'s values are never iterated:
-
-```js
-const { Seq } = require('immutable');
-const oddSquares = Seq([1, 2, 3, 4, 5, 6, 7, 8])
- .filter(x => x % 2 !== 0)
- .map(x => x * x);
-```
-
-Once the `Seq` is used, it performs only the work necessary. In this
-example, no intermediate arrays are ever created, filter is called three
-times, and map is only called once:
-
-```js
-oddSquares.get(1); // 9
-```
-
-Any collection can be converted to a lazy Seq with `Seq()`.
-
-
-
-```js
-const { Map, Seq } = require('immutable');
-const map = Map({ a: 1, b: 2, c: 3 });
-const lazySeq = Seq(map);
-```
-
-`Seq` allows for the efficient chaining of operations, allowing for the
-expression of logic that can otherwise be very tedious:
-
-```js
-lazySeq
- .flip()
- .map(key => key.toUpperCase())
- .flip();
-// Seq { A: 1, B: 2, C: 3 }
-```
-
-As well as expressing logic that would otherwise seem memory or time
-limited, for example `Range` is a special kind of Lazy sequence.
-
-
-
-```js
-const { Range } = require('immutable');
-Range(1, Infinity)
- .skip(1000)
- .map(n => -n)
- .filter(n => n % 2 === 0)
- .take(2)
- .reduce((r, n) => r * n, 1);
-// 1006008
-```
-
-## Comparison of filter(), groupBy(), and partition()
-
-The `filter()`, `groupBy()`, and `partition()` methods are similar in that they
-all divide a collection into parts based on applying a function to each element.
-All three call the predicate or grouping function once for each item in the
-input collection. All three return zero or more collections of the same type as
-their input. The returned collections are always distinct from the input
-(according to `===`), even if the contents are identical.
-
-Of these methods, `filter()` is the only one that is lazy and the only one which
-discards items from the input collection. It is the simplest to use, and the
-fact that it returns exactly one collection makes it easy to combine with other
-methods to form a pipeline of operations.
-
-The `partition()` method is similar to an eager version of `filter()`, but it
-returns two collections; the first contains the items that would have been
-discarded by `filter()`, and the second contains the items that would have been
-kept. It always returns an array of exactly two collections, which can make it
-easier to use than `groupBy()`. Compared to making two separate calls to
-`filter()`, `partition()` makes half as many calls it the predicate passed to
-it.
-
-The `groupBy()` method is a more generalized version of `partition()` that can
-group by an arbitrary function rather than just a predicate. It returns a map
-with zero or more entries, where the keys are the values returned by the
-grouping function, and the values are nonempty collections of the corresponding
-arguments. Although `groupBy()` is more powerful than `partition()`, it can be
-harder to use because it is not always possible predict in advance how many
-entries the returned map will have and what their keys will be.
-
-| Summary | `filter` | `partition` | `groupBy` |
-|:------------------------------|:---------|:------------|:---------------|
-| ease of use | easiest | moderate | hardest |
-| generality | least | moderate | most |
-| laziness | lazy | eager | eager |
-| # of returned sub-collections | 1 | 2 | 0 or more |
-| sub-collections may be empty | yes | yes | no |
-| can discard items | yes | no | no |
-| wrapping container | none | array | Map/OrderedMap |
-
-## Additional Tools and Resources
-
-- [Atom-store](https://github.com/jameshopkins/atom-store/)
- - A Clojure-inspired atom implementation in Javascript with configurability
- for external persistance.
-
-- [Chai Immutable](https://github.com/astorije/chai-immutable)
- - If you are using the [Chai Assertion Library](https://chaijs.com/), this
- provides a set of assertions to use against Immutable.js collections.
-
-- [Fantasy-land](https://github.com/fantasyland/fantasy-land)
- - Specification for interoperability of common algebraic structures in JavaScript.
-
-- [Immutagen](https://github.com/pelotom/immutagen)
- - A library for simulating immutable generators in JavaScript.
-
-- [Immutable-cursor](https://github.com/redbadger/immutable-cursor)
- - Immutable cursors incorporating the Immutable.js interface over
- Clojure-inspired atom.
-
-- [Immutable-ext](https://github.com/DrBoolean/immutable-ext)
- - Fantasyland extensions for immutablejs
-
-- [Immutable-js-tools](https://github.com/madeinfree/immutable-js-tools)
- - Util tools for immutable.js
-
-- [Immutable-Redux](https://github.com/gajus/redux-immutable)
- - redux-immutable is used to create an equivalent function of Redux
- combineReducers that works with Immutable.js state.
-
-- [Immutable-Treeutils](https://github.com/lukasbuenger/immutable-treeutils)
- - Functional tree traversal helpers for ImmutableJS data structures.
-
-- [Irecord](https://github.com/ericelliott/irecord)
- - An immutable store that exposes an RxJS observable. Great for React.
-
-- [Mudash](https://github.com/brianneisler/mudash)
- - Lodash wrapper providing Immutable.JS support.
-
-- [React-Immutable-PropTypes](https://github.com/HurricaneJames/react-immutable-proptypes)
- - PropType validators that work with Immutable.js.
-
-- [Redux-Immutablejs](https://github.com/indexiatech/redux-immutablejs)
- - Redux Immutable facilities.
-
-- [Rxstate](https://github.com/yamalight/rxstate)
- - Simple opinionated state management library based on RxJS and Immutable.js.
-
-- [Transit-Immutable-js](https://github.com/glenjamin/transit-immutable-js)
- - Transit serialisation for Immutable.js.
- - See also: [Transit-js](https://github.com/cognitect/transit-js)
-
-Have an additional tool designed to work with Immutable.js?
-Submit a PR to add it to this list in alphabetical order.
-
-## Contributing
-
-Use [Github issues](https://github.com/immutable-js/immutable-js/issues) for requests.
-
-We actively welcome pull requests, learn how to [contribute](https://github.com/immutable-js/immutable-js/blob/main/.github/CONTRIBUTING.md).
-
-Immutable.js is maintained within the [Contributor Covenant's Code of Conduct](https://www.contributor-covenant.org/version/2/0/code_of_conduct/).
-
-### Changelog
-
-Changes are tracked as [Github releases](https://github.com/immutable-js/immutable-js/releases).
-
-### License
-
-Immutable.js is [MIT-licensed](./LICENSE).
-
-### Thanks
-
-[Phil Bagwell](https://www.youtube.com/watch?v=K2NYwP90bNs), for his inspiration
-and research in persistent data structures.
-
-[Hugh Jackson](https://github.com/hughfdjackson/), for providing the npm package
-name. If you're looking for his unsupported package, see [this repository](https://github.com/hughfdjackson/immutable).
diff --git a/cli/node_modules/immutable/package.json b/cli/node_modules/immutable/package.json
deleted file mode 100644
index e380730f6..000000000
--- a/cli/node_modules/immutable/package.json
+++ /dev/null
@@ -1,39 +0,0 @@
-{
- "name": "immutable",
- "version": "4.3.7",
- "description": "Immutable Data Collections",
- "license": "MIT",
- "homepage": "https://immutable-js.com",
- "author": {
- "name": "Lee Byron",
- "url": "https://github.com/leebyron"
- },
- "repository": {
- "type": "git",
- "url": "git://github.com/immutable-js/immutable-js.git"
- },
- "bugs": {
- "url": "https://github.com/immutable-js/immutable-js/issues"
- },
- "main": "dist/immutable.js",
- "module": "dist/immutable.es.js",
- "sideEffects": false,
- "types": "dist/immutable.d.ts",
- "files": [
- "dist",
- "README.md",
- "LICENSE"
- ],
- "keywords": [
- "immutable",
- "persistent",
- "lazy",
- "data",
- "datastructure",
- "functional",
- "collection",
- "stateless",
- "sequence",
- "iteration"
- ]
-}
\ No newline at end of file
From a5ecff3b058ae978367ef881dfe2257584a131b8 Mon Sep 17 00:00:00 2001
From: tharvik
Date: Mon, 31 Mar 2025 13:14:40 +0200
Subject: [PATCH 25/25] cli: drop local immutable
---
cli/package.json | 1 -
package-lock.json | 7 -------
2 files changed, 8 deletions(-)
diff --git a/cli/package.json b/cli/package.json
index 5c3353f80..5bd6176f3 100644
--- a/cli/package.json
+++ b/cli/package.json
@@ -17,7 +17,6 @@
"dependencies": {
"@epfml/discojs-node": "*",
"csv-parse": "^5.6.0",
- "immutable": "4",
"server": "*",
"tslib": "2"
},
diff --git a/package-lock.json b/package-lock.json
index 5c23c9481..3608f60dc 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -30,7 +30,6 @@
"dependencies": {
"@epfml/discojs-node": "*",
"csv-parse": "^5.6.0",
- "immutable": "4",
"server": "*",
"tslib": "2"
},
@@ -39,12 +38,6 @@
"ts-command-line-args": "2"
}
},
- "cli/node_modules/immutable": {
- "version": "4.3.7",
- "resolved": "https://registry.npmjs.org/immutable/-/immutable-4.3.7.tgz",
- "integrity": "sha512-1hqclzwYwjRDFLjcFxOM5AYkkG0rpFPpr1RLPMEuGczoS7YA8gLhy8SWXYRAA/XwfEHpfo3cw5JGioS32fnMRw==",
- "license": "MIT"
- },
"discojs": {
"name": "@epfml/discojs",
"version": "3.0.0",