- 
                Notifications
    You must be signed in to change notification settings 
- Fork 116
Support Customizing the Training Loop with AllReduce
Currently, users define the forward computation, loss function, optimizer, and dataset function in ElasticDL, and ElasticDL provides the training loop built from those definitions. ElasticDL currently only supports the Keras API for defining the forward computation, which may not be flexible enough for complex models in CV or NLP. Sometimes users also need to customize the training loop to control how the model is iterated. Here, we discuss how to let users customize the training loop.
ElasticDL provides the dataset that users iterate over in their own training loop. Users should wrap the backward computation, including gradient merging (allreduce), with the ElasticDL elastic function `elastic_allreduce`.
With TensorFlow 1.x, we use `tf.Session` to execute the forward and backward
computation. The training function definition looks like the following code snippet.
# Users should wrap their forward and backward computation
# using ElasticDL decorator
@elastic_allreduce
def train_step(session, train_op):
    """Run one training step; users wrap the backward computation with ElasticDL.

    Arguments:
        session: the ``tf.Session`` that executes ``train_op``.
        train_op: the op that computes and applies the gradients.
    """
    # Use the ``session`` argument; ``sess`` is not defined in this scope.
    session.run(train_op)
def elastic_train(dataset):
    """ElasticDL calls this function to execute the TF 1.x training loop.

    Arguments:
        dataset: tf.data.Dataset initialized by ElasticDL.
    """
    dataset_iter = dataset.make_one_shot_iterator()
    features, labels = dataset_iter.get_next()
    # `forward`, `base_lr` and `config` are user-provided placeholders.
    loss = forward(features, labels)
    global_step = tf.train.get_or_create_global_step()
    lr = tf.Variable(base_lr * hvd.size())
    optimizer = tf.train.GradientDescentOptimizer(lr)
    optimizer = hvd.DistributedOptimizer(optimizer)
    # Every session.run that evaluates `loss` also evaluates get_next() and
    # therefore pulls a NEW batch. Record the loss of the trained batch as
    # part of train_op so logging does not silently skip batches.
    # NOTE(review): assumes `loss` is a float32 scalar.
    last_loss = tf.Variable(0.0, trainable=False, name="last_loss")
    train_op = tf.group(
        optimizer.minimize(loss, global_step=global_step),
        last_loss.assign(loss),
    )
    with tf.Session(config=config) as session:
        # ElasticDL provides ElasticBroadcastObject to set broadcast objects
        ElasticBroadcastObject.set_session(session)
        session.run(tf.global_variables_initializer())
        step = 0
        while True:
            try:
                train_step(session, train_op)
            except tf.errors.OutOfRangeError:
                # The one-shot iterator is exhausted: the dataset is done.
                break
            if step % 20 == 0:
                loss_value = session.run(last_loss)
                logging.info("loss = {}".format(loss_value))
            step += 1
# Users should wrap their forward and backward computation
# using the ElasticDL decorator, applied below with framework="tf_v2".
def train_step(model, optimizer, features, labels):
    """Run one forward/backward step; ElasticDL wraps it for elastic allreduce.

    Arguments:
        model: the tf.keras.Model being trained.
        optimizer: the Keras optimizer that applies the gradients.
        features: a batch of input features.
        labels: a batch of integer class labels.

    Returns:
        The mean cross-entropy loss of the batch, so callers can log it.
    """
    with tf.GradientTape() as tape:
        # Call the model itself rather than `model.call` so that Keras'
        # __call__ machinery (input conversion, building, masking) runs.
        outputs = model(features, training=True)
        loss = tf.reduce_mean(
            input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=outputs, labels=labels
            )
        )
    # Wrap the tape so Horovod allreduces the gradients across workers.
    tape = hvd.DistributedGradientTape(tape)
    grads = tape.gradient(loss, model.trainable_variables)
    # Take care of the order of grads and vars if worker modifies
    # `_non_embed_vars` during training.
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


# A bare `@elastic_allreduce` defaults to framework="tf_v1", which
# broadcasts through a tf.Session. TF 2.x must broadcast the registered
# Keras model and optimizer instead.
train_step = elastic_allreduce(train_step, framework="tf_v2")
def elastic_train(dataset):
    """ElasticDL calls this function to execute the TF 2.x training loop.

    Arguments:
        dataset: tf.data.Dataset initialized by ElasticDL.
    """
    inputs = tf.keras.Input(shape=(28, 28), name="image")
    outputs = Conv(10)(inputs)
    model = tf.keras.Model(inputs=inputs, outputs=outputs, name="mnist_model")
    # Scale the base learning rate `lr` (user-provided) by the number of
    # workers, as the TF 1.x and PyTorch examples do.
    optimizer = tf.optimizers.SGD(lr * hvd.size())
    # Set object to broadcast
    ElasticBroadcastObject.set_model(model)
    ElasticBroadcastObject.set_optimizer(optimizer)
    for step, (features, labels) in enumerate(dataset):
        # train_step returns the batch loss that is logged below.
        loss = train_step(model, optimizer, features, labels)
        if step % 20 == 0:
            logging.info("Step = {}, loss = {}".format(step, loss))

import torch
import horovod.torch as hvd
def train_step(optimizer):
    """Apply the gradients computed by ``loss.backward()``.

    Users should wrap the backward computation using ElasticDL.

    Arguments:
        optimizer: the ``hvd.DistributedOptimizer`` created in
            ``elastic_train``.
    """
    # NOTE(review): only step() is retried on HorovodInternalError, not
    # backward(). Confirm this is the intended retry boundary for PyTorch.
    optimizer.step()


# A bare `@elastic_allreduce` defaults to framework="tf_v1", which would
# broadcast through a TF session. Select the PyTorch broadcast path.
train_step = elastic_allreduce(train_step, framework="torch")
def elastic_train(dataset):
    """ ElasticDL will call the function to execute the training loop
    Arguments:
        dataset: tf.data.Dataset which initialized by ElasticDL. We can
        use eager execution to fetch batch data from the dataset for PyTorch.
    """
    model = ...
    optimizer = optim.SGD(model.parameters(), lr * hvd.size())
    optimizer = hvd.DistributedOptimizer(optimizer)
    # Set object to broadcast
    ElasticBroadcastObject.set_model(model)
    ElasticBroadcastObject.set_optimizer(optimizer)
    for features, labels in dataset:
        optimizer.zero_grad()
        output = model(features)
        loss = F.nll_loss(output, labels)
        loss.backward()
        train_step(optimizer)
    if step % 20 == 0:
        logging.info("Step = {}, loss = {}".format(step, loss))The ElasticDL worker will create a dataset according to tasks. Each task contains the location of a data shard. ElasticDL worker will call the training function for each batch data in the dataset.
def elastic_allreduce(func=None, framework='tf_v1'):
    """Decorator used to run the elastic training process.

    The wrapped step is attempted up to ``MAX_ALLREDUCE_RETRY_COUNT`` times.
    When Horovod raises ``HorovodInternalError``, Horovod is re-initialized.
    The objects registered on ``ElasticBroadcastObject`` are then broadcast
    before the step is retried.

    Usage: ``@elastic_allreduce`` (framework 'tf_v1'),
    ``@elastic_allreduce(framework='tf_v2')``, or
    ``elastic_allreduce(func, framework='torch')``.

    Arguments:
        func: the training-step function to wrap.
        framework: 'tf_v1' or 'tf_v2'. Any other value selects the PyTorch
            broadcast path in ``broadcast_variables``.

    Returns:
        The wrapped function, which returns whatever ``func`` returns.

    Raises:
        RuntimeError: if every attempt fails. It is chained from the last
            ``HorovodInternalError``.
    """
    import functools

    if func is None:
        # Called with keyword arguments only: return the actual decorator.
        return functools.partial(elastic_allreduce, framework=framework)

    # Per-decorated-function state. It starts True so the first step syncs
    # every worker from the broadcast root.
    need_broadcast = True

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal need_broadcast
        last_error = None
        for _ in range(MAX_ALLREDUCE_RETRY_COUNT):
            try:
                if need_broadcast:
                    broadcast_variables(ElasticBroadcastObject, framework)
                    need_broadcast = False
                result = func(*args, **kwargs)
            except HorovodInternalError as err:
                last_error = err
                init_horovod_if_needed()
                need_broadcast = True
                continue
            # Report outside the try: a reporting failure must not re-run
            # (and re-apply) an already-successful step.
            report_batch_finished()
            return result
        raise RuntimeError(
            "Allreduce step failed after {} attempts".format(
                MAX_ALLREDUCE_RETRY_COUNT
            )
        ) from last_error

    return wrapper
def broadcast_variables(ElasticBroadcastObject, framework):
    """Dispatch the broadcast of registered training state by framework.

    Arguments:
        ElasticBroadcastObject: holder of the objects to broadcast.
            'tf_v1' uses ``.session``; the others use ``.model`` and
            ``.optimizer``.
        framework: 'tf_v1' or 'tf_v2'. Any other value is treated as
            PyTorch.
    """
    if framework == "tf_v1":
        broadcast_tf_v1(ElasticBroadcastObject)
    elif framework == "tf_v2":
        broadcast_tf_v2(ElasticBroadcastObject)
    else:
        broadcast_torch(ElasticBroadcastObject)
def broadcast_tf_v1(ElasticBroadcastObject):
    """Broadcast all TF 1.x global variables from rank 0 using the session.

    Arguments:
        ElasticBroadcastObject: holder whose ``.session`` runs the op.
    """
    # Import Horovod's TensorFlow helper locally. Without it, the name
    # resolves to this module's `broadcast_variables` dispatcher, which
    # takes different arguments. The torch helpers are not needed here.
    from horovod.tensorflow.functions import broadcast_variables

    bcast_op = broadcast_variables(tf.global_variables(), root_rank=0)
    ElasticBroadcastObject.session.run(bcast_op)
def broadcast_tf_v2(ElasticBroadcastObject):
    """Broadcast the registered Keras model and optimizer variables from rank 0.

    Arguments:
        ElasticBroadcastObject: holder of ``.model`` and ``.optimizer``.
    """
    from horovod.tensorflow.functions import broadcast_variables

    broadcast_variables(ElasticBroadcastObject.model.variables, root_rank=0)
    # NOTE(review): `optimizer.variables()` is a method on TF 2.x Keras
    # optimizers; confirm against the Keras version in use.
    broadcast_variables(ElasticBroadcastObject.optimizer.variables(), root_rank=0)
def broadcast_torch(ElasticBroadcastObject):
    """Broadcast the registered PyTorch model and optimizer state from rank 0.

    Arguments:
        ElasticBroadcastObject: holder of ``.model`` and ``.optimizer``.
    """
    from horovod.torch.functions import broadcast_optimizer_state, broadcast_parameters

    # This is a module-level function: read the objects from the argument,
    # not from an undefined `self`.
    broadcast_parameters(ElasticBroadcastObject.model.state_dict(), root_rank=0)
    broadcast_optimizer_state(ElasticBroadcastObject.optimizer, root_rank=0)