251 changes: 216 additions & 35 deletions src/util/orm.py
@@ -9,13 +9,23 @@
on this module work along the ORM of *all* supported versions.
"""

import collections
import logging
import multiprocessing
import os
import re
import sys
import uuid
from contextlib import contextmanager
from functools import wraps
from itertools import chain
from itertools import chain, repeat
from textwrap import dedent

try:
from concurrent.futures import ProcessPoolExecutor
except ImportError:
ProcessPoolExecutor = None

try:
from unittest.mock import patch
except ImportError:
@@ -27,9 +37,9 @@
except ImportError:
from odoo import SUPERUSER_ID
from odoo import fields as ofields
from odoo import modules, release
from odoo import modules, release, sql_db
except ImportError:
from openerp import SUPERUSER_ID, modules, release
from openerp import SUPERUSER_ID, modules, release, sql_db

try:
from openerp import fields as ofields
@@ -41,8 +51,8 @@
from .const import BIG_TABLE_THRESHOLD
from .exceptions import MigrationError
from .helpers import table_of_model
from .misc import chunks, log_progress, version_between, version_gte
from .pg import SQLStr, column_exists, format_query, get_columns, named_cursor
from .misc import chunks, log_progress, str2bool, version_between, version_gte
from .pg import SQLStr, column_exists, format_query, get_columns, get_max_workers, named_cursor

# python3 shims
try:
@@ -52,6 +62,10 @@

_logger = logging.getLogger(__name__)

UPG_PARALLEL_ITER_BROWSE = str2bool(os.environ.get("UPG_PARALLEL_ITER_BROWSE", "0"))
# FIXME: for CI! Remove before merge
UPG_PARALLEL_ITER_BROWSE = True


def env(cr):
"""
@@ -341,6 +355,26 @@ def get_ids():
cr.execute("DROP TABLE IF EXISTS _upgrade_rf")


def _mp_iter_browse_cb(ids_or_values, params):
me = _mp_iter_browse_cb
# initialize on first call; done here instead of in an initializer callback, because py3.6's ProcessPoolExecutor doesn't have one
if not hasattr(me, "env"):
sql_db._Pool = None # forked children cannot borrow connections from copies of the same pool; doing so causes protocol errors
me.env = env(sql_db.db_connect(params["dbname"]).cursor())
me.env.clear()
# process
if params["mode"] == "browse":
getattr(
me.env[params["model_name"]].with_context(params["context"]).browse(ids_or_values), params["attr_name"]
)(*params["args"], **params["kwargs"])
if params["mode"] == "create":
new_ids = me.env[params["model_name"]].with_context(params["context"]).create(ids_or_values).ids
me.env.cr.commit()
if params["mode"] == "create":
return new_ids
return None
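
A minimal sketch of how this worker callback is driven, relying on the imports at the top of this module; the database name, model, and method are illustrative assumptions, and the chunking mirrors the executor.map calls further down in this diff:

params = {
    "dbname": "mydb",  # assumption: an existing database
    "model_name": "res.partner",  # illustrative model
    "context": {},
    "attr_name": "write",  # illustrative method to call on each chunk
    "args": ({"active": True},),
    "kwargs": {},
    "mode": "browse",
}
mp_ctx = multiprocessing.get_context("fork")  # fork: workers inherit the parent's memory
with ProcessPoolExecutor(max_workers=2, mp_context=mp_ctx) as executor:
    # each worker opens its own cursor on first call, then processes its chunks of ids
    list(executor.map(_mp_iter_browse_cb, [(1, 2), (3, 4)], repeat(params)))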


class iter_browse(object):
"""
Iterate over recordsets.
Expand Down Expand Up @@ -374,7 +408,8 @@ class iter_browse(object):

:param model: the model to iterate
:type model: :class:`odoo.model.Model`
:param list(int) ids: list of IDs of the records to iterate
:param iterable(int) or SQLStr ids: iterable of IDs of the records to iterate, or a SQL query
that can produce the IDs
:param int chunk_size: number of records to load in each iteration chunk, `200` by
default
:param logger: logger used to report the progress, by default
@@ -387,23 +422,98 @@ class iter_browse(object):
See also :func:`~odoo.upgrade.util.orm.env`
"""

__slots__ = ("_chunk_size", "_cr_uid", "_it", "_logger", "_model", "_patch", "_size", "_strategy")
__slots__ = (
"_chunk_size",
"_cr_uid",
"_ids",
"_it",
"_logger",
"_model",
"_patch",
"_size",
"_strategy",
"_superchunk_size",
)

def __init__(self, model, *args, **kw):
assert len(args) in [1, 3] # either (cr, uid, ids) or (ids,)
self._model = model
self._cr_uid = args[:-1]
ids = args[-1]
self._size = len(ids)
self._ids = args[-1]
self._size = kw.pop("size", None)
self._chunk_size = kw.pop("chunk_size", 200) # keyword-only argument
self._superchunk_size = self._chunk_size
self._logger = kw.pop("logger", _logger)
self._strategy = kw.pop("strategy", "flush")
assert self._strategy in {"flush", "commit"}
assert self._strategy in {"flush", "commit", "multiprocessing"}
if self._strategy == "multiprocessing":
if not ProcessPoolExecutor:
raise ValueError("multiprocessing strategy can not be used in scripts run by python2")
if UPG_PARALLEL_ITER_BROWSE:
self._superchunk_size = min(get_max_workers() * 10 * self._chunk_size, 1000000)
else:
self._strategy = "commit" # downgrade
if self._size is not None and self._size > 100000:  # size may still be unknown at this point
_logger.warning(
"Browsing %d %s, which may take a long time. "
"This can be sped up by setting the env variable UPG_PARALLEL_ITER_BROWSE to 1. "
"If you do, be sure to examine the results carefully.",
self._size,
self._model._name,
)
else:
_logger.info(
"Caller requested multiprocessing strategy, but UPG_PARALLEL_ITER_BROWSE env var is not set. "
"Downgrading strategy to commit.",
)
if kw:
raise TypeError("Unknown arguments: %s" % ", ".join(kw))

if isinstance(self._ids, SQLStr):
self._ids_query()

if not self._size:
try:
self._size = len(self._ids)
except TypeError:
raise ValueError("When passing ids as a generator, the size kwarg is mandatory")
self._patch = None
self._it = chunks(ids, self._chunk_size, fmt=self._browse)
self._it = chunks(self._ids, self._chunk_size, fmt=self._browse)

def _ids_query(self):
cr = self._model.env.cr
tmp_tbl = "_upgrade_ib_{}".format(uuid.uuid4().hex)
cr.execute(
format_query(
cr, "CREATE UNLOGGED TABLE {}(id) AS (WITH query AS ({}) SELECT * FROM query)", tmp_tbl, self._ids
)
)
self._size = cr.rowcount
cr.execute(
format_query(cr, "ALTER TABLE {} ADD CONSTRAINT {} PRIMARY KEY (id)", tmp_tbl, "pk_{}_id".format(tmp_tbl))
)

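# stream the ids through a server-side (named) cursor; the primary key above keeps ORDER BY id cheap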
def get_ids():
with named_cursor(cr, itersize=self._superchunk_size) as ncr:
ncr.execute(format_query(cr, "SELECT id FROM {} ORDER BY id", tmp_tbl))
for (id_,) in ncr:
yield id_
cr.execute(format_query(cr, "DROP TABLE IF EXISTS {}", tmp_tbl))

self._ids = get_ids()
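
Since ids may now be an SQLStr, callers can stream ids straight from SQL; a hedged sketch (the query is illustrative; format_query returns an SQLStr):

query = format_query(cr, "SELECT id FROM res_partner WHERE active")
big_browse = iter_browse(env["res.partner"], query, chunk_size=500)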

def _values_query(self, query):
cr = self._model.env.cr
cr.execute(format_query(cr, "WITH query AS ({}) SELECT count(*) FROM query", query))
size = cr.fetchone()[0]

def get_values():
with named_cursor(cr, itersize=self._chunk_size) as ncr:
ncr.execute(query)
for row in ncr.iterdict():
yield row

return size, get_values()

def _browse(self, ids):
next(self._end(), None)
@@ -415,7 +525,7 @@ def _browse(self, ids):
return self._model.browse(*args)

def _end(self):
if self._strategy == "commit":
if self._strategy in ["commit", "multiprocessing"]:
self._model.env.cr.commit()
else:
flush(self._model)
@@ -430,8 +540,12 @@ def __iter__(self):
raise RuntimeError("%r ran twice" % (self,))

it = chain.from_iterable(self._it)
sz = self._size
if self._strategy == "multiprocessing":
it = self._it
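# ceiling division: in parallel mode, progress is reported per chunk rather than per record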
sz = (self._size + self._chunk_size - 1) // self._chunk_size
if self._logger:
it = log_progress(it, self._logger, qualifier=self._model._name, size=self._size)
it = log_progress(it, self._logger, qualifier=self._model._name, size=sz)
self._it = None
return chain(it, self._end())

@@ -442,15 +556,40 @@ def __getattr__(self, attr):
if not callable(getattr(self._model, attr)):
raise TypeError("The attribute %r is not callable" % attr)

it = self._it
it = chunks(self._ids, self._superchunk_size, fmt=self._browse)
if self._logger:
sz = (self._size + self._chunk_size - 1) // self._chunk_size
qualifier = "%s[:%d]" % (self._model._name, self._chunk_size)
sz = (self._size + self._superchunk_size - 1) // self._superchunk_size
qualifier = "%s[:%d]" % (self._model._name, self._superchunk_size)
it = log_progress(it, self._logger, qualifier=qualifier, size=sz)

def caller(*args, **kwargs):
args = self._cr_uid + args
return [getattr(chnk, attr)(*args, **kwargs) for chnk in chain(it, self._end())]
if self._strategy != "multiprocessing":
return [getattr(chnk, attr)(*args, **kwargs) for chnk in chain(it, self._end())]
params = {
"dbname": self._model.env.cr.dbname,
"model_name": self._model._name,
# convert to dict for pickle. Will still break if any value in the context is not pickleable
"context": dict(self._model.env.context),
"attr_name": attr,
"args": args,
"kwargs": kwargs,
"mode": "browse",
}
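# commit before dispatching: the forked workers open their own connections and can only see committed data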
self._model.env.cr.commit()
extrakwargs = {"mp_context": multiprocessing.get_context("fork")} if sys.version_info >= (3, 7) else {}
with ProcessPoolExecutor(max_workers=get_max_workers(), **extrakwargs) as executor:
for chunk in it:
collections.deque(
executor.map(
_mp_iter_browse_cb, chunks(chunk._ids, self._chunk_size, fmt=tuple), repeat(params)
),
maxlen=0,
)
next(self._end(), None)
# do not return results in parallel mode: it is expected to be used for huge numbers of
# records, returning them would risk MemoryError, and we cannot know if what attr returns is pickleable
return None

self._it = None
return caller
@@ -467,6 +606,7 @@ def create(self, values, **kw):
`True` from Odoo 12 and above
"""
multi = kw.pop("multi", version_gte("saas~11.5"))
size = kw.pop("size", None)
if kw:
raise TypeError("Unknown arguments: %s" % ", ".join(kw))

@@ -476,31 +616,72 @@
if self._size:
raise ValueError("`create` can only called on empty `browse_record` objects.")

ids = []
size = len(values)
it = chunks(values, self._chunk_size, fmt=list)
if self._strategy == "multiprocessing" and not multi:
Contributor

As discussed in DM: we should have two private implementations, the current one under _create and the new one in _create_multiprocessing; we call each depending on the options.

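A hedged sketch of the split suggested above (stub bodies, purely illustrative):

class _Sketch:
    def _create(self, values):
        return "sequential: %r" % (values,)

    def _create_multiprocessing(self, values):
        return "parallel: %r" % (values,)

    def create(self, values, strategy="flush"):
        impl = self._create_multiprocessing if strategy == "multiprocessing" else self._create
        return impl(values)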
raise ValueError("The multiprocessing strategy only supports the multi version of `create`")

if isinstance(values, SQLStr):
Contributor

Why don't we add a query parameter, making values=None by default, to make this more explicit?

def create(self, values=None, query=None, **kw):

In any case, if we keep hacking values for queries, I wouldn't use SQLStr; just check against plain str. Normal values of create are never strings, so this won't risk any confusion. I still believe being explicit is better, though.

size, values = self._values_query(values)

if size is None:
try:
size = len(values)
except TypeError:
raise ValueError("When passing values as a generator, the size kwarg is mandatory")

chunk_size = self._superchunk_size if self._strategy == "multiprocessing" else self._chunk_size
it = chunks(values, chunk_size, fmt=list)
if self._logger:
sz = (size + self._chunk_size - 1) // self._chunk_size
qualifier = "env[%r].create([:%d])" % (self._model._name, self._chunk_size)
sz = (size + chunk_size - 1) // chunk_size
qualifier = "env[%r].create([:%d])" % (self._model._name, chunk_size)
it = log_progress(it, self._logger, qualifier=qualifier, size=sz)

self._patch = no_selection_cache_validation()
for sub_values in it:
def mp_create():
params = {
"dbname": self._model.env.cr.dbname,
"model_name": self._model._name,
# convert to dict for pickle. Will still break if any value in the context is not pickleable
"context": dict(self._model.env.context),
"mode": "create",
}
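# commit so the worker processes, on their own connections, see everything written so far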
self._model.env.cr.commit()
self._patch.start()
extrakwargs = {"mp_context": multiprocessing.get_context("fork")} if sys.version_info >= (3, 7) else {}
with ProcessPoolExecutor(max_workers=get_max_workers(), **extrakwargs) as executor:
for sub_values in it:
for task_result in executor.map(
_mp_iter_browse_cb, chunks(sub_values, self._chunk_size, fmt=tuple), repeat(params)
):
self._model.env.cr.commit() # make task_result visible on main cursor before yielding ids
for new_id in task_result:
yield new_id
next(self._end(), None)

if multi:
ids += self._model.create(sub_values).ids
elif not self._cr_uid:
ids += [self._model.create(sub_value).id for sub_value in sub_values]
else:
# old API, `create` directly return the id
ids += [self._model.create(*(self._cr_uid + (sub_value,))) for sub_value in sub_values]
self._patch = no_selection_cache_validation()
if self._strategy == "multiprocessing":
ids = mp_create()
else:
ids = []
for sub_values in it:
self._patch.start()

if multi:
ids += self._model.create(sub_values).ids
elif not self._cr_uid:
ids += [self._model.create(sub_value).id for sub_value in sub_values]
else:
# old API, `create` directly return the id
ids += [self._model.create(*(self._cr_uid + (sub_value,))) for sub_value in sub_values]

next(self._end(), None)

next(self._end(), None)
args = self._cr_uid + (ids,)
return iter_browse(
self._model, *args, chunk_size=self._chunk_size, logger=self._logger, strategy=self._strategy
)
kwargs = {
"size": size,
"chunk_size": self._chunk_size,
"logger": None if self._strategy == "multiprocessing" else self._logger,
"strategy": self._strategy,
}
return iter_browse(self._model, *args, **kwargs)

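A hedged sketch of the generator-based create path added here; the model and values are illustrative, and the size kwarg is mandatory when values is a generator:

values = ({"name": "partner %d" % i} for i in range(500000))
created = iter_browse(env["res.partner"], [], strategy="multiprocessing").create(values, size=500000)
for record in created:
    pass  # consuming the returned iter_browse drives the actual inserts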

@contextmanager