diff --git a/toolz/itertoolz.py b/toolz/itertoolz.py index 76353e94..9848603f 100644 --- a/toolz/itertoolz.py +++ b/toolz/itertoolz.py @@ -10,11 +10,12 @@ __all__ = ('remove', 'accumulate', 'groupby', 'merge_sorted', 'interleave', - 'unique', 'isiterable', 'isdistinct', 'take', 'drop', 'take_nth', - 'first', 'second', 'nth', 'last', 'get', 'concat', 'concatv', - 'mapcat', 'cons', 'interpose', 'frequencies', 'reduceby', 'iterate', - 'sliding_window', 'partition', 'partition_all', 'count', 'pluck', - 'join', 'tail', 'diff', 'topk', 'peek', 'random_sample') + 'unique', 'isiterable', 'isdistinct', 'take', 'drop', 'split', + 'take_nth', 'first', 'second', 'nth', 'last', 'get', 'concat', + 'concatv', 'mapcat', 'cons', 'interpose', 'frequencies', + 'reduceby', 'iterate', 'sliding_window', 'partition', + 'partition_all', 'count', 'pluck', 'join', 'tail', 'diff', + 'topk', 'peek', 'random_sample') def remove(predicate, seq): @@ -350,6 +351,27 @@ def drop(n, seq): return itertools.islice(seq, n, None) +def split(n, seq): + """ Splits the sequence around element n. + + >>> list(map(tuple, split(2, [10, 20, 30, 40, 50]))) + [(10, 20), (30,), (40, 50)] + + See Also: + take + nth + drop + """ + + front, middle, back = itertools.tee(seq, 3) + + front = itertools.islice(front, 0, n) + middle = itertools.islice(middle, n, n + 1) + back = itertools.islice(back, n + 1, None) + + return front, middle, back + + def take_nth(n, seq): """ Every nth item in seq diff --git a/toolz/tests/test_itertoolz.py b/toolz/tests/test_itertoolz.py index 93aa856d..9609097e 100644 --- a/toolz/tests/test_itertoolz.py +++ b/toolz/tests/test_itertoolz.py @@ -8,7 +8,7 @@ concat, concatv, interleave, unique, isiterable, getter, mapcat, isdistinct, first, second, - nth, take, tail, drop, interpose, get, + nth, take, tail, drop, split, interpose, get, rest, last, cons, frequencies, reduceby, iterate, accumulate, sliding_window, count, partition, @@ -172,6 +172,28 @@ def test_drop(): assert list(drop(1, (3, 2, 1))) == list((2, 1)) +def test_split(): + l = [10, 20, 30, 40, 50] + assert list(map(tuple, split(0, l))) == [tuple(), + (10,), + (20, 30, 40, 50)] + + l = [10, 20, 30, 40, 50] + assert list(map(tuple, split(4, l))) == [(10, 20, 30, 40,), + (50,), + tuple()] + + l = [10, 20, 30, 40, 50] + assert list(map(tuple, split(5, l))) == [(10, 20, 30, 40, 50), + tuple(), + tuple()] + + l = [10, 20, 30, 40, 50] + assert list(map(tuple, split(2, l))) == [(10, 20), + (30,), + (40, 50)] + + def test_take_nth(): assert list(take_nth(2, 'ABCDE')) == list('ACE')