2 changes: 2 additions & 0 deletions neuralprophet/configure.py
@@ -172,6 +172,8 @@ def set_auto_batch_epoch(
# this should (with auto batch size) yield about 1000 steps minimum and 100,000 steps at upper cutoff
self.epochs = int(2 ** (2.5 * np.log10(100 + n_data)) / (n_data / 1000.0))
self.epochs = min(max_epoch, max(min_epoch, self.epochs))
if self.optimizer == torch.optim.LBFGS:
    # self.optimizer holds the optimizer class, so compare classes rather than isinstance;
    # one full-batch epoch suffices for L-BFGS
    self.epochs = 1
log.info(f"Auto-set epochs to {self.epochs}")
# also set lambda_delay:
self.lambda_delay = int(self.reg_delay_pct * self.epochs)
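A note on the epochs override: L-BFGS is a full-batch quasi-Newton method, and torch.optim.LBFGS already performs up to max_iter inner iterations within a single .step() call, so one epoch over a single full-size batch is usually enough. A minimal stand-alone sketch of that behavior (plain PyTorch, independent of NeuralProphet):

# One optimizer.step(closure) call performs up to max_iter L-BFGS
# iterations internally, re-evaluating the closure as needed.
import torch

model = torch.nn.Linear(3, 1)
x, y = torch.randn(256, 3), torch.randn(256, 1)
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.01, max_iter=50, max_eval=25)

def closure():
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

optimizer.step(closure)  # one "epoch": up to 50 inner quasi-Newton iterations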
35 changes: 30 additions & 5 deletions neuralprophet/forecaster.py
@@ -2692,7 +2692,14 @@ def _init_train_loader(self, df, num_workers=0):
# Determine the max_number of epochs
self.config_train.set_auto_batch_epoch(n_data=len(dataset))

loader = DataLoader(dataset, batch_size=self.config_train.batch_size, shuffle=True, num_workers=num_workers)
loader = DataLoader(
dataset,
batch_size=self.config_train.batch_size
if self.config_train.optimizer.__name__ != "LBFGS"
else len(dataset),
shuffle=True,
num_workers=num_workers,
)

return loader

@@ -2711,7 +2718,12 @@ def _init_val_loader(self, df):
df, _, _, _ = df_utils.prep_or_copy_df(df)
df = self._normalize(df)
dataset = self._create_dataset(df, predict_mode=False)
loader = DataLoader(dataset, batch_size=min(1024, len(dataset)), shuffle=False, drop_last=False)
loader = DataLoader(
dataset,
batch_size=min(1024, len(dataset)) if self.config_train.optimizer.__name__ != "LBFGS" else len(dataset),
shuffle=False,
drop_last=False,
)
return loader

def _train(
@@ -2785,7 +2797,11 @@ def _train(
df_val, _, _, _ = df_utils.prep_or_copy_df(df_val)
val_loader = self._init_val_loader(df_val)

if not continue_training and not self.config_train.learning_rate:
if (
not continue_training
and not self.config_train.learning_rate
and self.config_train.optimizer.__name__ != "LBFGS"  # skip the LR finder for L-BFGS, which uses a fixed lr
):
# Set parameters for the learning rate finder
self.config_train.set_lr_finder_args(dataset_size=dataset_size, num_batches=len(train_loader))
# Find suitable learning rate
@@ -2807,7 +2823,11 @@
ckpt_path=self.metrics_logger.checkpoint_path if continue_training else None,
)
else:
if not continue_training and not self.config_train.learning_rate:
if (
not continue_training
and not self.config_train.learning_rate
and self.config_train.optimizer.__name__ != "LBFGS"  # skip the LR finder for L-BFGS, which uses a fixed lr
):
# Set parameters for the learning rate finder
self.config_train.set_lr_finder_args(dataset_size=dataset_size, num_batches=len(train_loader))
# Find suitable learning rate
@@ -3083,7 +3103,12 @@ def _predict_raw(self, df, df_name, include_components=False, prediction_frequency=None):
if "y_scaled" not in df.columns or "t" not in df.columns:
raise ValueError("Received unprepared dataframe to predict. " "Please call predict_dataframe_to_predict.")
dataset = self._create_dataset(df, predict_mode=True, prediction_frequency=prediction_frequency)
loader = DataLoader(dataset, batch_size=min(1024, len(df)), shuffle=False, drop_last=False)
loader = DataLoader(
dataset,
batch_size=min(1024, len(df)) if self.config_train.optimizer.__name__ != "LBFGS" else len(dataset),
shuffle=False,
drop_last=False,
)
if self.n_forecasts > 1:
dates = df["ds"].iloc[self.max_lags : -self.n_forecasts + 1]
else:
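The three loader changes above share one idea: L-BFGS is a deterministic full-batch method, so the whole dataset must arrive as a single batch. A hedged sketch of the pattern with stand-in objects (not NeuralProphet's real config):

# Illustrative only: mirrors the batch-size selection above. With
# batch_size == len(dataset) the loader yields exactly one batch, so
# shuffle=True has no practical effect on the L-BFGS path.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(500, 3), torch.randn(500, 1))
optimizer_cls = torch.optim.LBFGS  # stand-in for self.config_train.optimizer

batch_size = 64 if optimizer_cls.__name__ != "LBFGS" else len(dataset)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
print(len(loader))  # -> 1: the whole series in one batch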
40 changes: 22 additions & 18 deletions neuralprophet/time_net.py
@@ -175,7 +175,7 @@ def __init__(
# Optimizer and LR Scheduler
self._optimizer = self.config_train.optimizer
self._scheduler = self.config_train.scheduler
self.automatic_optimization = False
# self.automatic_optimization = False  # reverted: automatic mode lets Lightning pass LBFGS its closure

# Hyperparameters (can be tuned using trainer.tune())
self.learning_rate = self.config_train.learning_rate if self.config_train.learning_rate is not None else 1e-3
@@ -756,13 +756,13 @@ def training_step(self, batch, batch_idx):
loss, reg_loss = self.loss_func(inputs, predicted, targets)

# Optimization: handled by Lightning's automatic optimization; manual stepping below is disabled so LBFGS receives its closure
optimizer = self.optimizers()
optimizer.zero_grad()
self.manual_backward(loss)
optimizer.step()
# optimizer = self.optimizers()
# optimizer.zero_grad()
# self.manual_backward(loss)
# optimizer.step()

scheduler = self.lr_schedulers()
scheduler.step()
# scheduler = self.lr_schedulers()
# scheduler.step()

# Manually track the loss for the lr finder
self.trainer.fit_loop.running_loss.append(loss)
@@ -830,18 +830,22 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
return prediction, components

def configure_optimizers(self):
# Optimizer
optimizer = self._optimizer(self.parameters(), lr=self.learning_rate, **self.config_train.optimizer_args)

# Scheduler
lr_scheduler = self._scheduler(
optimizer,
max_lr=self.learning_rate,
total_steps=self.trainer.estimated_stepping_batches,
**self.config_train.scheduler_args,
)
if self._optimizer == torch.optim.LBFGS:
# Optimizer
optimizer = self._optimizer(self.parameters(), lr=0.01, **self.config_train.optimizer_args)
return optimizer
else:
# Optimizer
optimizer = self._optimizer(self.parameters(), lr=self.learning_rate, **self.config_train.optimizer_args)
# Scheduler
lr_scheduler = self._scheduler(
optimizer,
max_lr=self.learning_rate,
total_steps=self.trainer.estimated_stepping_batches,
**self.config_train.scheduler_args,
)

return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

def _get_time_based_sample_weight(self, t):
weight = torch.ones_like(t)
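Reverting to automatic optimization is the key enabler here: torch.optim.LBFGS.step() requires a closure that re-evaluates the loss, and Lightning only builds and passes that closure in automatic mode. A hypothetical minimal LightningModule sketching the same setup (not the PR's TimeNet):

# Under automatic optimization, Lightning wraps training_step in a closure
# and hands it to optimizer.step(), exactly the interface LBFGS needs.
import torch
import pytorch_lightning as pl

class LBFGSDemo(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 1)
        # automatic_optimization is True by default; LBFGS depends on it

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        # mirrors the branch above: fixed lr, no LR scheduler for LBFGS
        return torch.optim.LBFGS(self.parameters(), lr=0.01, max_iter=50)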
4 changes: 4 additions & 0 deletions neuralprophet/utils_torch.py
@@ -65,6 +65,10 @@ def create_optimizer_from_config(optimizer_name, optimizer_args):
optimizer = torch.optim.SGD
optimizer_args["momentum"] = 0.9
optimizer_args["weight_decay"] = 1e-4
elif optimizer_name.lower() == "l-bfgs":
optimizer = torch.optim.LBFGS
optimizer_args["max_iter"] = 50
optimizer_args["max_eval"] = 25
else:
raise ValueError(f"The optimizer name {optimizer_name} is not supported.")
elif inspect.isclass(optimizer_name) and issubclass(optimizer_name, torch.optim.Optimizer):
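With this branch, the optimizer name "L-BFGS" (matched case-insensitively) resolves to torch.optim.LBFGS with conservative max_iter/max_eval defaults. A hedged usage sketch; the return convention (optimizer class plus updated args dict) is assumed from the surrounding code, not confirmed by this hunk:

# Assumed behavior: "L-BFGS" is lower-cased inside the function and the
# args dict gains max_iter=50 and max_eval=25 defaults.
from neuralprophet.utils_torch import create_optimizer_from_config

optimizer, optimizer_args = create_optimizer_from_config("L-BFGS", {})
# optimizer is torch.optim.LBFGS; optimizer_args == {"max_iter": 50, "max_eval": 25}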
47 changes: 47 additions & 0 deletions tests/test_model_performance.py
@@ -163,6 +163,33 @@ def test_YosemiteTemps():
create_metrics_plot(metrics).write_image(os.path.join(DIR, "tests", "metrics", "YosemiteTemps.svg"))


def test_YosemiteTempsBFGS():
df = pd.read_csv(YOS_FILE)
m = NeuralProphet(
n_lags=24,
n_forecasts=24,
changepoints_range=0.95,
n_changepoints=30,
weekly_seasonality=False,
optimizer="L-BFGS",
)
df_train, df_test = m.split_df(df=df, freq="5min", valid_p=0.2)

system_speed, std = get_system_speed()
start = time.time()
metrics = m.fit(df_train, validation_df=df_test, freq="5min", early_stopping=True)
end = time.time()

accuracy_metrics = metrics.to_dict("records")[-1]
accuracy_metrics["time"] = round(end - start, 2)
accuracy_metrics["system_performance"] = round(system_speed, 5)
accuracy_metrics["system_std"] = round(std, 5)
with open(os.path.join(DIR, "tests", "metrics", "YosemiteTempsBFGS.json"), "w") as outfile:
json.dump(accuracy_metrics, outfile)

create_metrics_plot(metrics).write_image(os.path.join(DIR, "tests", "metrics", "YosemiteTempsBFGS.svg"))


def test_AirPassengers():
df = pd.read_csv(AIR_FILE)
m = NeuralProphet(seasonality_mode="multiplicative")
@@ -181,3 +208,23 @@ def test_AirPassengers():
json.dump(accuracy_metrics, outfile)

create_metrics_plot(metrics).write_image(os.path.join(DIR, "tests", "metrics", "AirPassengers.svg"))


def test_AirPassengersBFGS():
df = pd.read_csv(AIR_FILE)
m = NeuralProphet(seasonality_mode="multiplicative", optimizer="L-BFGS")
df_train, df_test = m.split_df(df=df, freq="MS", valid_p=0.2)

system_speed, std = get_system_speed()
start = time.time()
metrics = m.fit(df_train, validation_df=df_test, freq="MS", early_stopping=True)
end = time.time()

accuracy_metrics = metrics.to_dict("records")[-1]
accuracy_metrics["time"] = round(end - start, 2)
accuracy_metrics["system_performance"] = round(system_speed, 5)
accuracy_metrics["system_std"] = round(std, 5)
with open(os.path.join(DIR, "tests", "metrics", "AirPassengersBFGS.json"), "w") as outfile:
json.dump(accuracy_metrics, outfile)

create_metrics_plot(metrics).write_image(os.path.join(DIR, "tests", "metrics", "AirPassengersBFGS.svg"))
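
Taken together, the user-facing change is a single constructor argument. A short end-to-end sketch (the CSV path is a placeholder; any dataframe with "ds" and "y" columns works):

import pandas as pd
from neuralprophet import NeuralProphet

df = pd.read_csv("air_passengers.csv")  # placeholder path; columns: ds, y
m = NeuralProphet(seasonality_mode="multiplicative", optimizer="L-BFGS")
metrics = m.fit(df, freq="MS")  # trains as a single full-batch epoch under L-BFGS
forecast = m.predict(df)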