[IMP] orm: parallel iter_browse #320
base: master
Changes from all commits
9c9f200
4a56a32
a40dde7
060b1a3
@@ -9,13 +9,23 @@
on this module work along the ORM of *all* supported versions.
"""

import collections
import logging
import multiprocessing
import os
import re
import sys
import uuid
from contextlib import contextmanager
from functools import wraps
from itertools import chain
from itertools import chain, repeat
from textwrap import dedent

try:
    from concurrent.futures import ProcessPoolExecutor
except ImportError:
    ProcessPoolExecutor = None

try:
    from unittest.mock import patch
except ImportError:
@@ -27,9 +37,9 @@
except ImportError:
    from odoo import SUPERUSER_ID
    from odoo import fields as ofields
    from odoo import modules, release
    from odoo import modules, release, sql_db
except ImportError:
    from openerp import SUPERUSER_ID, modules, release
    from openerp import SUPERUSER_ID, modules, release, sql_db

    try:
        from openerp import fields as ofields
@@ -41,8 +51,8 @@
from .const import BIG_TABLE_THRESHOLD
from .exceptions import MigrationError
from .helpers import table_of_model
from .misc import chunks, log_progress, version_between, version_gte
from .pg import SQLStr, column_exists, format_query, get_columns, named_cursor
from .misc import chunks, log_progress, str2bool, version_between, version_gte
from .pg import SQLStr, column_exists, format_query, get_columns, get_max_workers, named_cursor

# python3 shims
try:
@@ -52,6 +62,10 @@

_logger = logging.getLogger(__name__)

UPG_PARALLEL_ITER_BROWSE = str2bool(os.environ.get("UPG_PARALLEL_ITER_BROWSE", "0"))
# FIXME: for CI! Remove before merge
UPG_PARALLEL_ITER_BROWSE = True


def env(cr):
    """
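Not part of the diff — a short sketch of how a run opts in to the parallel path, assuming `str2bool` accepts the usual truthy strings ("1", "true", "yes", ...):

```python
# Export the variable in the environment that launches the upgrade, e.g.
#   UPG_PARALLEL_ITER_BROWSE=1 odoo -d mydb -u all --stop-after-init
# or, from Python, before util.orm is imported:
import os

os.environ["UPG_PARALLEL_ITER_BROWSE"] = "1"
```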
@@ -341,6 +355,26 @@ def get_ids():
    cr.execute("DROP TABLE IF EXISTS _upgrade_rf")


def _mp_iter_browse_cb(ids_or_values, params):
    me = _mp_iter_browse_cb
    # init upon first call. Done here instead of initializer callback, because py3.6 doesn't have it
    if not hasattr(me, "env"):
        sql_db._Pool = None  # children cannot borrow from copies of the same pool, it will cause protocol error
        me.env = env(sql_db.db_connect(params["dbname"]).cursor())
    me.env.clear()
    # process
    if params["mode"] == "browse":
        getattr(
            me.env[params["model_name"]].with_context(params["context"]).browse(ids_or_values), params["attr_name"]
        )(*params["args"], **params["kwargs"])
    if params["mode"] == "create":
        new_ids = me.env[params["model_name"]].with_context(params["context"]).create(ids_or_values).ids
    me.env.cr.commit()
    if params["mode"] == "create":
        return new_ids
    return None


class iter_browse(object):
    """
    Iterate over recordsets.
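Not part of the diff — a self-contained sketch of the lazy per-worker initialisation pattern `_mp_iter_browse_cb` relies on (state stashed on the function object, since `ProcessPoolExecutor` only gained an `initializer=` argument in Python 3.7); the resource string is a stand-in for the fresh cursor each forked worker opens:

```python
import os
from concurrent.futures import ProcessPoolExecutor


def _worker(n):
    me = _worker
    if not hasattr(me, "resource"):
        # first call inside this worker process: open a private resource,
        # never reuse anything inherited from the parent (cf. sql_db._Pool = None above)
        me.resource = "connection-of-pid-%d" % os.getpid()
    return me.resource, n * n


if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=2) as executor:
        for resource, square in executor.map(_worker, range(4)):
            print(resource, square)
```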
@@ -374,7 +408,8 @@

    :param model: the model to iterate
    :type model: :class:`odoo.model.Model`
    :param list(int) ids: list of IDs of the records to iterate
    :param iterable(int) or SQLStr ids: iterable of IDs of the records to iterate, or a SQL query
        that can produce the IDs
    :param int chunk_size: number of records to load in each iteration chunk, `200` by
        default
    :param logger: logger used to report the progress, by default
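Not part of the diff — a hedged sketch of the extended signature from an upgrade script, assuming the usual `util.env`/`util.iter_browse`/`util.format_query` entry points; the model and query are illustrative:

```python
from odoo.upgrade import util


def migrate(cr, version):
    env = util.env(cr)
    # ids may now be an SQLStr (as returned by util.format_query) instead of a list...
    query = util.format_query(cr, "SELECT id FROM {} WHERE active", "res_partner")
    partners = util.iter_browse(
        env["res.partner"],
        query,
        chunk_size=200,
        strategy="multiprocessing",  # downgraded to "commit" (with a log message) unless UPG_PARALLEL_ITER_BROWSE is set
    )
    # ...or any generator of ids, in which case the new `size` kwarg is mandatory
    some_ids = [1, 2, 3]  # illustrative
    others = util.iter_browse(env["res.partner"], (i for i in some_ids), size=len(some_ids))
```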
@@ -387,23 +422,98 @@
    See also :func:`~odoo.upgrade.util.orm.env`
    """

    __slots__ = ("_chunk_size", "_cr_uid", "_it", "_logger", "_model", "_patch", "_size", "_strategy")
    __slots__ = (
        "_chunk_size",
        "_cr_uid",
        "_ids",
        "_it",
        "_logger",
        "_model",
        "_patch",
        "_size",
        "_strategy",
        "_superchunk_size",
    )

    def __init__(self, model, *args, **kw):
        assert len(args) in [1, 3]  # either (cr, uid, ids) or (ids,)
        self._model = model
        self._cr_uid = args[:-1]
        ids = args[-1]
        self._size = len(ids)
        self._ids = args[-1]
        self._size = kw.pop("size", None)
        self._chunk_size = kw.pop("chunk_size", 200)  # keyword-only argument
        self._superchunk_size = self._chunk_size
        self._logger = kw.pop("logger", _logger)
        self._strategy = kw.pop("strategy", "flush")
        assert self._strategy in {"flush", "commit"}
        assert self._strategy in {"flush", "commit", "multiprocessing"}
        if self._strategy == "multiprocessing":
            if not ProcessPoolExecutor:
                raise ValueError("multiprocessing strategy can not be used in scripts run by python2")
            if UPG_PARALLEL_ITER_BROWSE:
                self._superchunk_size = min(get_max_workers() * 10 * self._chunk_size, 1000000)
            else:
                self._strategy = "commit"  # downgrade
                if self._size > 100000:
                    _logger.warning(
                        "Browsing %d %s, which may take a long time. "
                        "This can be sped up by setting the env variable UPG_PARALLEL_ITER_BROWSE to 1. "
                        "If you do, be sure to examine the results carefully.",
                        self._size,
                        self._model._name,
                    )
                else:
                    _logger.info(
                        "Caller requested multiprocessing strategy, but UPG_PARALLEL_ITER_BROWSE env var is not set. "
                        "Downgrading strategy to commit.",
                    )
        if kw:
            raise TypeError("Unknown arguments: %s" % ", ".join(kw))

        if isinstance(self._ids, SQLStr):
            self._ids_query()

        if not self._size:
            try:
                self._size = len(self._ids)
            except TypeError:
                raise ValueError("When passing ids as a generator, the size kwarg is mandatory")
        self._patch = None
        self._it = chunks(ids, self._chunk_size, fmt=self._browse)
        self._it = chunks(self._ids, self._chunk_size, fmt=self._browse)

    def _ids_query(self):
        cr = self._model.env.cr
        tmp_tbl = "_upgrade_ib_{}".format(uuid.uuid4().hex)
        cr.execute(
            format_query(
                cr, "CREATE UNLOGGED TABLE {}(id) AS (WITH query AS ({}) SELECT * FROM query)", tmp_tbl, self._ids
            )
        )
        self._size = cr.rowcount
        cr.execute(
            format_query(cr, "ALTER TABLE {} ADD CONSTRAINT {} PRIMARY KEY (id)", tmp_tbl, "pk_{}_id".format(tmp_tbl))
        )

        def get_ids():
            with named_cursor(cr, itersize=self._superchunk_size) as ncr:
                ncr.execute(format_query(cr, "SELECT id FROM {} ORDER BY id", tmp_tbl))
                for (id_,) in ncr:
                    yield id_
            cr.execute(format_query(cr, "DROP TABLE IF EXISTS {}", tmp_tbl))

        self._ids = get_ids()

    def _values_query(self, query):
        cr = self._model.env.cr
        cr.execute(format_query(cr, "WITH query AS ({}) SELECT count(*) FROM query", query))
        size = cr.fetchone()[0]

        def get_values():
            with named_cursor(cr, itersize=self._chunk_size) as ncr:
                ncr.execute(query)
                for row in ncr.iterdict():
                    yield row

        return size, get_values()

    def _browse(self, ids):
        next(self._end(), None)
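Not part of the diff — the superchunk sizing above, worked out for one illustrative configuration (8 workers reported by `get_max_workers()`, default `chunk_size` of 200):

```python
max_workers = 8  # illustrative return value of get_max_workers()
chunk_size = 200  # default
superchunk_size = min(max_workers * 10 * chunk_size, 1000000)
assert superchunk_size == 16000  # i.e. 80 chunks of 200 records handed to the worker pool per superchunk
```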
@@ -415,7 +525,7 @@ def _browse(self, ids):
        return self._model.browse(*args)

    def _end(self):
        if self._strategy == "commit":
        if self._strategy in ["commit", "multiprocessing"]:
            self._model.env.cr.commit()
        else:
            flush(self._model)
@@ -430,8 +540,12 @@ def __iter__(self):
            raise RuntimeError("%r ran twice" % (self,))

        it = chain.from_iterable(self._it)
        sz = self._size
        if self._strategy == "multiprocessing":
            it = self._it
            sz = (self._size + self._chunk_size - 1) // self._chunk_size
        if self._logger:
            it = log_progress(it, self._logger, qualifier=self._model._name, size=self._size)
            it = log_progress(it, self._logger, qualifier=self._model._name, size=sz)
        self._it = None
        return chain(it, self._end())
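Not part of the diff — what the `__iter__` change means for callers, continuing the `migrate()` sketch above (model and field are illustrative): with the existing strategies iteration yields one record at a time, while under `multiprocessing` each step yields a chunk-sized recordset and progress is logged per chunk.

```python
names = []
for chunk in util.iter_browse(env["res.partner"], some_ids, chunk_size=200, strategy="multiprocessing"):
    # `chunk` is a recordset of up to 200 partners, not a single record
    names.extend(chunk.mapped("display_name"))
```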
@@ -442,15 +556,40 @@ def __getattr__(self, attr):
        if not callable(getattr(self._model, attr)):
            raise TypeError("The attribute %r is not callable" % attr)

        it = self._it
        it = chunks(self._ids, self._superchunk_size, fmt=self._browse)
        if self._logger:
            sz = (self._size + self._chunk_size - 1) // self._chunk_size
            qualifier = "%s[:%d]" % (self._model._name, self._chunk_size)
            sz = (self._size + self._superchunk_size - 1) // self._superchunk_size
            qualifier = "%s[:%d]" % (self._model._name, self._superchunk_size)
            it = log_progress(it, self._logger, qualifier=qualifier, size=sz)

        def caller(*args, **kwargs):
            args = self._cr_uid + args
            return [getattr(chnk, attr)(*args, **kwargs) for chnk in chain(it, self._end())]
            if self._strategy != "multiprocessing":
                return [getattr(chnk, attr)(*args, **kwargs) for chnk in chain(it, self._end())]
            params = {
                "dbname": self._model.env.cr.dbname,
                "model_name": self._model._name,
                # convert to dict for pickle. Will still break if any value in the context is not pickleable
                "context": dict(self._model.env.context),
                "attr_name": attr,
                "args": args,
                "kwargs": kwargs,
                "mode": "browse",
            }
            self._model.env.cr.commit()
            extrakwargs = {"mp_context": multiprocessing.get_context("fork")} if sys.version_info >= (3, 7) else {}
            with ProcessPoolExecutor(max_workers=get_max_workers(), **extrakwargs) as executor:
                for chunk in it:
                    collections.deque(
                        executor.map(
                            _mp_iter_browse_cb, chunks(chunk._ids, self._chunk_size, fmt=tuple), repeat(params)
                        ),
                        maxlen=0,
                    )
                    next(self._end(), None)
            # do not return results in // mode, we expect it to be used for huge numbers of
            # records and thus would risk MemoryError, also we cannot know if what attr returns is pickleable
            return None

        self._it = None
        return caller
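Not part of the diff — how the dispatch in `caller` behaves from the caller's side, continuing the sketch above (method and values are illustrative):

```python
ib = util.iter_browse(env["res.partner"], some_ids, strategy="multiprocessing")
result = ib.write({"active": True})
# Under "flush"/"commit" this returns a list with one entry per chunk; under "multiprocessing"
# the writes happen in the worker processes and the call deliberately returns None.
assert result is None
```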
@@ -467,6 +606,7 @@ def create(self, values, **kw):
        `True` from Odoo 12 and above
        """
        multi = kw.pop("multi", version_gte("saas~11.5"))
        size = kw.pop("size", None)
        if kw:
            raise TypeError("Unknown arguments: %s" % ", ".join(kw))
@@ -476,31 +616,72 @@ def create(self, values, **kw):
        if self._size:
            raise ValueError("`create` can only called on empty `browse_record` objects.")

        ids = []
        size = len(values)
        it = chunks(values, self._chunk_size, fmt=list)
        if self._strategy == "multiprocessing" and not multi:

Review comment: As discussed in DM: we should have two private implementations. The current one under …

            raise ValueError("The multiprocessing strategy only supports the multi version of `create`")

        if isinstance(values, SQLStr):

Review comment: Why don't we add a `def create(self, values=None, query=None, **kw):`? In any case if we keep hacking …

            size, values = self._values_query(values)

        if size is None:
            try:
                size = len(values)
            except TypeError:
                raise ValueError("When passing values as a generator, the size kwarg is mandatory")

        chunk_size = self._superchunk_size if self._strategy == "multiprocessing" else self._chunk_size
        it = chunks(values, chunk_size, fmt=list)
        if self._logger:
            sz = (size + self._chunk_size - 1) // self._chunk_size
            qualifier = "env[%r].create([:%d])" % (self._model._name, self._chunk_size)
            sz = (size + chunk_size - 1) // chunk_size
            qualifier = "env[%r].create([:%d])" % (self._model._name, chunk_size)
            it = log_progress(it, self._logger, qualifier=qualifier, size=sz)

        self._patch = no_selection_cache_validation()
        for sub_values in it:
        def mp_create():
            params = {
                "dbname": self._model.env.cr.dbname,
                "model_name": self._model._name,
                # convert to dict for pickle. Will still break if any value in the context is not pickleable
                "context": dict(self._model.env.context),
                "mode": "create",
            }
            self._model.env.cr.commit()
            self._patch.start()
            extrakwargs = {"mp_context": multiprocessing.get_context("fork")} if sys.version_info >= (3, 7) else {}
            with ProcessPoolExecutor(max_workers=get_max_workers(), **extrakwargs) as executor:
                for sub_values in it:
                    for task_result in executor.map(
                        _mp_iter_browse_cb, chunks(sub_values, self._chunk_size, fmt=tuple), repeat(params)
                    ):
                        self._model.env.cr.commit()  # make task_result visible on main cursor before yielding ids
                        for new_id in task_result:
                            yield new_id
                    next(self._end(), None)

            if multi:
                ids += self._model.create(sub_values).ids
            elif not self._cr_uid:
                ids += [self._model.create(sub_value).id for sub_value in sub_values]
            else:
                # old API, `create` directly return the id
                ids += [self._model.create(*(self._cr_uid + (sub_value,))) for sub_value in sub_values]
        self._patch = no_selection_cache_validation()
        if self._strategy == "multiprocessing":
            ids = mp_create()
        else:
            ids = []
            for sub_values in it:
                self._patch.start()

                if multi:
                    ids += self._model.create(sub_values).ids
                elif not self._cr_uid:
                    ids += [self._model.create(sub_value).id for sub_value in sub_values]
                else:
                    # old API, `create` directly return the id
                    ids += [self._model.create(*(self._cr_uid + (sub_value,))) for sub_value in sub_values]

                next(self._end(), None)

            next(self._end(), None)
        args = self._cr_uid + (ids,)
        return iter_browse(
            self._model, *args, chunk_size=self._chunk_size, logger=self._logger, strategy=self._strategy
        )
        kwargs = {
            "size": size,
            "chunk_size": self._chunk_size,
            "logger": None if self._strategy == "multiprocessing" else self._logger,
            "strategy": self._strategy,
        }
        return iter_browse(self._model, *args, **kwargs)


@contextmanager
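Not part of the diff — a hedged sketch of a bulk `create` through the new hooks, continuing the sketch above (model and values are illustrative):

```python
COUNT = 500000
values = ({"name": "Imported partner %d" % i} for i in range(COUNT))
created = util.iter_browse(
    env["res.partner"], [], chunk_size=500, strategy="multiprocessing"
).create(values, size=COUNT)  # `size` is mandatory because `values` is a generator
# `created` is a new iter_browse over the ids produced by the workers; a plain list of
# values (or an SQLStr producing rows) also works, in which case `size` is inferred.
```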