This repository was archived by the owner on Sep 3, 2022. It is now read-only.

Commit c245c8a

qimingj authored and ojarjur committed
Fix an issue that batch predict in mltoolbox fails. (#695)
* Fix an issue that batch predict in mltoolbox fails.

  A local prediction run is a local Dataflow job. One Dataflow operator, the TF graph runner, calls TF's "bundle_shim.load_session_bundle_or_saved_model_bundle_from_path" API in its "start_bundle" function. This API not only creates a session but also populates the default TF graph with the operators loaded from the saved model. If the operator's "start_bundle" runs multiple times, the same operators are added to the default graph repeatedly (even under the same name), which is incorrect and causes unexpected behavior. The fix resets the default graph in the operator's "finish_bundle" call so that cleanup happens properly.

* Pin dill's version.
1 parent f9cfab1 commit c245c8a
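
The failure mode described in the commit message can be illustrated with a small toy model. This is only a sketch and does not use TensorFlow: `ToyGraph`, `load_model_into_default_graph`, `ToyGraphRunner`, and `run_bundles` are hypothetical stand-ins invented for illustration, and the numeric-suffix renaming is an assumption that imitates how TF 1.x makes repeated op names unique. It shows that loading the same model into a shared default graph once per bundle piles up renamed duplicates unless the graph is reset when the bundle finishes.

```python
"""Toy model of a process-wide default graph (NOT TensorFlow)."""


class ToyGraph:
    """Stand-in for a TF 1.x graph: stores ops under unique names."""

    def __init__(self):
        self.op_names = []

    def add_op(self, name):
        # Imitate name uniquification: a repeated name gets a numeric suffix.
        unique, n = name, 0
        while unique in self.op_names:
            n += 1
            unique = "%s_%d" % (name, n)
        self.op_names.append(unique)
        return unique


_default_graph = ToyGraph()


def reset_default_graph():
    """Stand-in for tf.reset_default_graph(): swap in a fresh global graph."""
    global _default_graph
    _default_graph = ToyGraph()


def load_model_into_default_graph():
    """Hypothetical loader: adds the saved model's ops to the default graph."""
    return [_default_graph.add_op(n) for n in ("input", "output")]


class ToyGraphRunner:
    """Hypothetical DoFn-like operator with bundle lifecycle hooks."""

    def __init__(self, reset_on_finish):
        self._reset_on_finish = reset_on_finish
        self.loaded = None

    def start_bundle(self):
        # Called once per bundle, so the model is loaded once per bundle.
        self.loaded = load_model_into_default_graph()

    def finish_bundle(self):
        # The cleanup step that the commit adds to the real operator.
        if self._reset_on_finish:
            reset_default_graph()


def run_bundles(runner, count):
    """Run `count` bundles; return the op names loaded for the last bundle."""
    reset_default_graph()
    for _ in range(count):
        runner.start_bundle()
        runner.finish_bundle()
    return runner.loaded


without_fix = run_bundles(ToyGraphRunner(reset_on_finish=False), 2)
with_fix = run_bundles(ToyGraphRunner(reset_on_finish=True), 2)
print(without_fix)  # ['input_1', 'output_1']
print(with_fix)     # ['input', 'output']
```

In the toy run without the reset, the second bundle's ops no longer carry the names the model was saved with, which is one plausible way the "unexpected behavior" mentioned above could arise. With the reset in `finish_bundle`, every bundle starts from an empty graph.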

2 files changed: 4 additions, 0 deletions


solutionbox/structured_data/mltoolbox/_structured_data/prediction/predict.py

Lines changed: 3 additions & 0 deletions
@@ -157,7 +157,10 @@ def start_bundle(self, element=None):
     self._aliases, self._tensor_names = zip(*self._output_alias_map.items())
 
   def finish_bundle(self, element=None):
+    import tensorflow as tf
+
     self._session.close()
+    tf.reset_default_graph()
 
   def process(self, element):
     """Run batch prediciton on a TF graph.

tox.ini

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ skip_missing_interpreters = true
 # For pandas-profiling, we have to install 1.4.0 as 1.4.1 triggers division-by-zero errors.
 deps = pandas-profiling==1.4.0
        apache-airflow==1.9.0
+       dill==0.2.6
        tensorflow==1.8.0
        lime==0.1.1.23
        xgboost==0.6a2
