diff --git a/calc/calc_transformer_flops.py b/calc/calc_transformer_flops.py
index 74f99e4..79d9ff6 100644
--- a/calc/calc_transformer_flops.py
+++ b/calc/calc_transformer_flops.py
@@ -65,7 +65,7 @@ def config_parser():
                         type=float,
                         default=300e9,
                         help='Number of tokens you are training over')
-    parser.add_argument("--no-checkpoint-activations", "-ca",
+    parser.add_argument("--no-checkpoint-activations", "-nca",
                         action='store_false',
                         help='Whether Megatron-style activation checkpointing is being used',
                         dest='checkpoint_activations')
@@ -85,6 +85,13 @@ def config_parser():
    parser.add_argument("--infer", "-i",
                        action='store_true',
                        help='Pass to calculate FLOPs for inference-only workload (no backward pass)')
+    parser.add_argument("--encoder-only", "-enc", action="store_true",
+                        help="Set if the model is encoder-only (e.g. BERT/ModernBERT)")
+
+    parser.add_argument("--mlm-ratio", "-mr", type=float, default=1.0,
+                        help="Fraction of tokens scored by the language-model head. "
+                             "Use 1.0 for autoregressive GPT-style training and "
+                             "0.15 for BERT-style masked-LM pre-training")
    return parser

# calculates the flops of a model given its hparams
@@ -93,6 +100,8 @@ def calc_flops(args):
    assert args.num_layers % args.expert_interval == 0, "Require for simplicity that we don't have hanging dense layers"
    assert not args.ffn_hidden_size or (args.ffn_expansion_factor == 4), "both '--ffn-hidden-size' and non-default '-ff' values were specified, these cannot conflict"

+    is_encoder = args.encoder_only
+
    # An A_(m x k) X B_(k x n) matrix multiplication requires 2m x k x n FLOPs (factor of 2 needed to account for multiplies and adds)

    # determine the flops factor.
@@ -117,8 +126,11 @@ def calc_flops(args):
    ffn_flops = int(iter_factor * 2 * args.num_mlp_linears * args.ffn_expansion_factor) * args.num_layers * args.tokens * args.hidden_size * args.hidden_size

    # no activation checkpointing for embeddings
-    # embedding (4*d_model) plus unembedding (2*d_model*vocab_size)
-    embedding_flops = args.tokens * (4 * args.hidden_size + 2 * args.hidden_size * args.vocab_size)
+    # embedding (4*d_model) plus unembedding (2*d_model*vocab_size); every token is
+    # embedded, but only --mlm-ratio of tokens reach the unembedding head in MLM training
+    embedding_flops = args.tokens * (4 * args.hidden_size
+                                     + 2 * args.hidden_size * args.vocab_size
+                                     * (args.mlm_ratio if is_encoder else 1.0))

    if args.moe and args.topk > 1:
        ffn_flops *= args.topk / args.expert_interval
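
As a quick sanity check of the patched embedding term, here is a standalone sketch of the arithmetic; the model shape below (hidden_size, vocab_size, tokens) is a hypothetical example, not a default taken from the script:

    # Standalone check of the new embedding_flops term under the patch's semantics:
    # every token is embedded, but only mlm_ratio of tokens are scored by the LM head.
    hidden_size = 768
    vocab_size = 50257
    tokens = 300e9

    def embedding_flops(mlm_ratio=1.0, is_encoder=False):
        # embedding (4*d_model) applies to all tokens; unembedding
        # (2*d_model*vocab_size) is scaled by mlm_ratio for encoder-only models
        return tokens * (4 * hidden_size
                         + 2 * hidden_size * vocab_size
                         * (mlm_ratio if is_encoder else 1.0))

    print(f"{embedding_flops():.2e}")                                  # GPT-style:  ~2.32e+19 FLOPs
    print(f"{embedding_flops(mlm_ratio=0.15, is_encoder=True):.2e}")   # BERT-style: ~3.47e+18 FLOPs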