66 from collections import Sequence
77except ImportError :
88 from collections .abc import Sequence
9+ try :
10+ # enum in stdlib as of py3.4
11+ from enum import IntEnum # pylint: disable=import-error
12+ except ImportError :
13+ # vendored backport module
14+ from kafka .vendor .enum34 import IntEnum
915import logging
1016import random
1117import re
2026log = logging .getLogger (__name__ )
2127
2228
29+ class SubscriptionType (IntEnum ):
30+ NONE = 0
31+ AUTO_TOPICS = 1
32+ AUTO_PATTERN = 2
33+ USER_ASSIGNED = 3
34+
35+
2336class SubscriptionState (object ):
2437 """
2538 A class for tracking the topics, partitions, and offsets for the consumer.
@@ -67,6 +80,7 @@ def __init__(self, offset_reset_strategy='earliest'):
6780 self ._default_offset_reset_strategy = offset_reset_strategy
6881
6982 self .subscription = None # set() or None
83+ self .subscription_type = SubscriptionType .NONE
7084 self .subscribed_pattern = None # regex str or None
7185 self ._group_subscription = set ()
7286 self ._user_assignment = set ()
@@ -76,6 +90,14 @@ def __init__(self, offset_reset_strategy='earliest'):
7690 # initialize to true for the consumers to fetch offset upon starting up
7791 self .needs_fetch_committed_offsets = True
7892
93+ def _set_subscription_type (self , subscription_type ):
94+ if not isinstance (subscription_type , SubscriptionType ):
95+ raise ValueError ('SubscriptionType enum required' )
96+ if self .subscription_type == SubscriptionType .NONE :
97+ self .subscription_type = subscription_type
98+ elif self .subscription_type != subscription_type :
99+ raise IllegalStateError (self ._SUBSCRIPTION_EXCEPTION_MESSAGE )
100+
79101 def subscribe (self , topics = (), pattern = None , listener = None ):
80102 """Subscribe to a list of topics, or a topic regex pattern.
81103
@@ -111,17 +133,19 @@ def subscribe(self, topics=(), pattern=None, listener=None):
111133 guaranteed, however, that the partitions revoked/assigned
112134 through this interface are from topics subscribed in this call.
113135 """
114- if self ._user_assignment or (topics and pattern ):
115- raise IllegalStateError (self ._SUBSCRIPTION_EXCEPTION_MESSAGE )
116136 assert topics or pattern , 'Must provide topics or pattern'
137+ if (topics and pattern ):
138+ raise IllegalStateError (self ._SUBSCRIPTION_EXCEPTION_MESSAGE )
117139
118- if pattern :
140+ elif pattern :
141+ self ._set_subscription_type (SubscriptionType .AUTO_PATTERN )
119142 log .info ('Subscribing to pattern: /%s/' , pattern )
120143 self .subscription = set ()
121144 self .subscribed_pattern = re .compile (pattern )
122145 else :
123146 if isinstance (topics , str ) or not isinstance (topics , Sequence ):
124147 raise TypeError ('Topics must be a list (or non-str sequence)' )
148+ self ._set_subscription_type (SubscriptionType .AUTO_TOPICS )
125149 self .change_subscription (topics )
126150
127151 if listener and not isinstance (listener , ConsumerRebalanceListener ):
@@ -141,7 +165,7 @@ def change_subscription(self, topics):
141165 - a topic name is '.' or '..' or
142166 - a topic name does not consist of ASCII-characters/'-'/'_'/'.'
143167 """
144- if self ._user_assignment :
168+ if not self .partitions_auto_assigned () :
145169 raise IllegalStateError (self ._SUBSCRIPTION_EXCEPTION_MESSAGE )
146170
147171 if isinstance (topics , six .string_types ):
@@ -168,13 +192,13 @@ def group_subscribe(self, topics):
168192 Arguments:
169193 topics (list of str): topics to add to the group subscription
170194 """
171- if self ._user_assignment :
195+ if not self .partitions_auto_assigned () :
172196 raise IllegalStateError (self ._SUBSCRIPTION_EXCEPTION_MESSAGE )
173197 self ._group_subscription .update (topics )
174198
175199 def reset_group_subscription (self ):
176200 """Reset the group's subscription to only contain topics subscribed by this consumer."""
177- if self ._user_assignment :
201+ if not self .partitions_auto_assigned () :
178202 raise IllegalStateError (self ._SUBSCRIPTION_EXCEPTION_MESSAGE )
179203 assert self .subscription is not None , 'Subscription required'
180204 self ._group_subscription .intersection_update (self .subscription )
@@ -197,9 +221,7 @@ def assign_from_user(self, partitions):
197221 Raises:
198222 IllegalStateError: if consumer has already called subscribe()
199223 """
200- if self .subscription is not None :
201- raise IllegalStateError (self ._SUBSCRIPTION_EXCEPTION_MESSAGE )
202-
224+ self ._set_subscription_type (SubscriptionType .USER_ASSIGNED )
203225 if self ._user_assignment != set (partitions ):
204226 self ._user_assignment = set (partitions )
205227 self ._set_assignment ({partition : self .assignment .get (partition , TopicPartitionState ())
@@ -250,6 +272,7 @@ def unsubscribe(self):
250272 self ._user_assignment .clear ()
251273 self .assignment .clear ()
252274 self .subscribed_pattern = None
275+ self .subscription_type = SubscriptionType .NONE
253276
254277 def group_subscription (self ):
255278 """Get the topic subscription for the group.
@@ -300,7 +323,7 @@ def fetchable_partitions(self):
300323
301324 def partitions_auto_assigned (self ):
302325 """Return True unless user supplied partitions manually."""
303- return self .subscription is not None
326+ return self .subscription_type in ( SubscriptionType . AUTO_TOPICS , SubscriptionType . AUTO_PATTERN )
304327
305328 def all_consumed_offsets (self ):
306329 """Returns consumed offsets as {TopicPartition: OffsetAndMetadata}"""
0 commit comments