diff --git a/requirements.txt b/requirements.txt index ccee6de..b81b24e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ -scrapy \ No newline at end of file +scrapy +simplejson \ No newline at end of file diff --git a/scrapy_streaming/commands/streaming.py b/scrapy_streaming/commands/streaming.py index 2199415..75242f6 100644 --- a/scrapy_streaming/commands/streaming.py +++ b/scrapy_streaming/commands/streaming.py @@ -1,5 +1,3 @@ -import os - from scrapy.commands import ScrapyCommand from scrapy.exceptions import UsageError diff --git a/scrapy_streaming/communication/__init__.py b/scrapy_streaming/communication/__init__.py index c0ab284..25b930a 100644 --- a/scrapy_streaming/communication/__init__.py +++ b/scrapy_streaming/communication/__init__.py @@ -1,3 +1,3 @@ from scrapy_streaming.communication.map import CommunicationMap -from scrapy_streaming.communication.wrappers import * +from scrapy_streaming.communication.validators import * diff --git a/scrapy_streaming/line_receiver.py b/scrapy_streaming/communication/line_receiver.py similarity index 79% rename from scrapy_streaming/line_receiver.py rename to scrapy_streaming/communication/line_receiver.py index aa77217..50a7ceb 100644 --- a/scrapy_streaming/line_receiver.py +++ b/scrapy_streaming/communication/line_receiver.py @@ -9,21 +9,21 @@ class LineProcessProtocol(protocol.ProcessProtocol, object): """ def __init__(self): - self.__buffer = b'' - self.__delimiter = b'\n' + self._buffer = b'' + self._delimiter = b'\n' def outReceived(self, data): """ Implement the outReceived method, buffering the incoming data and dispatching line by line in the ``lineReceived`` method. """ - self.__buffer += data + self._buffer += data - lines = self.__buffer.splitlines() - if data.endswith(self.__delimiter): - self.__buffer = b'' + lines = self._buffer.splitlines() + if data.endswith(self._delimiter): + self._buffer = b'' else: - self.__buffer = lines.pop() + self._buffer = lines.pop() for line in lines: self.lineReceived(line) @@ -40,5 +40,5 @@ def writeLine(self, data): """ data = to_bytes(data) if not data.endswith(b'\n'): - data += self.__delimiter + data += self._delimiter self.transport.write(data) diff --git a/scrapy_streaming/communication/map.py b/scrapy_streaming/communication/map.py index 1bb7f4c..13f4ab4 100644 --- a/scrapy_streaming/communication/map.py +++ b/scrapy_streaming/communication/map.py @@ -1,24 +1,31 @@ -import json +import base64 + +import simplejson as json from scrapy.utils.python import to_unicode, to_native_str -from scrapy_streaming.communication import wrappers -from scrapy_streaming.utils import MessageError +from scrapy_streaming.communication import validators +from scrapy_streaming.utils import MessageError, extract_instance_fields class CommunicationMap(object): """ - Helper class to create the json messages + Helper class to create and receive json messages """ mapping = { - 'spider': wrappers.SpiderMessage, - 'request': wrappers.RequestMessage, - 'log': wrappers.LogMessage + 'spider': validators.SpiderMessage, + 'request': validators.RequestMessage, + 'from_response_request': validators.FromResponseRequestMessage, + 'log': validators.LogMessage, + 'close': validators.CloseMessage } @staticmethod def parse(line): + """ + Receives a json string in a line, that will be decoded and parsed to a message + """ try: msg = json.loads(to_native_str(line)) @@ -28,10 +35,10 @@ def parse(line): raise MessageError('"type" field not provided.') msg_type = msg.pop('type') - try: - return 
CommunicationMap.mapping[msg_type].from_dict(msg) - except KeyError: + if msg_type not in CommunicationMap.mapping: raise MessageError('%s is not a valid message type.' % msg_type) + + return CommunicationMap.mapping[msg_type].from_dict(msg, line) except ValueError: raise MessageError('Received message is not a valid json.') @@ -48,18 +55,28 @@ def error(message, details): return json.dumps(fields) @staticmethod - def response(resp, request_id='parse'): - fields = _extract_fields(resp, ['url', 'headers', 'status', 'body', 'meta', 'flags']) - fields['id'] = to_unicode(request_id) + def response(resp, encode64): + fields = extract_instance_fields(resp, ['url', 'headers', 'status', 'meta', 'flags']) + if encode64: + fields['body'] = base64.b64encode(resp.body) + else: + # validates if the body is text-like and serializable + try: + json.dumps(resp.body) # throws UnicodeDecodeError if not text-serializable + fields['body'] = resp.body + except UnicodeDecodeError: + raise ValueError('Response body is not serializable. If it\'s returning binary data, ' + 'set the "base64" to True to encode the data.') + + fields['id'] = resp.meta['request_id'] + fields['type'] = 'response' + return json.dumps(fields) + @staticmethod + def exception(line, exception): + fields = {'type': 'exception', + 'received_message': to_unicode(line), + 'exception': to_unicode(exception)} -def _extract_fields(item, fields): - """ - Given a list of fields, generate a dict with key being the name of the field - mapping to the serialized item.field value - """ - data = {} - for field in fields: - data[field] = json.loads(json.dumps(getattr(item, field))) - return data + return json.dumps(fields) diff --git a/scrapy_streaming/communication/validators.py b/scrapy_streaming/communication/validators.py new file mode 100644 index 0000000..eda2727 --- /dev/null +++ b/scrapy_streaming/communication/validators.py @@ -0,0 +1,142 @@ +import logging + +import six + +from scrapy_streaming.utils import MessageError +from scrapy_streaming.utils.fields import EmptyField, RequiredField + + +class MessageValidator(object): + """ + This class is responsible for validating dicts keys and values. + """ + + validator = {} + """ + Validators must be defined as: + key: name of the field + value: expected type + All possible message fields must be defined in the validator. + """ + + def __init__(self, default, fields): + self.data = fields + self.validate(fields) + self.update(default, fields) + + @classmethod + def from_dict(cls, data, line=None): + c = cls(data) + c.line = line + return c + + def validate(self, data): + """ + This methods check if the dict ``data`` follows the validator scheme. + If there is a problem in the validation, raises a MessageError. + """ + validator = self.validator + + for field, value in data.items(): + if field not in validator: + raise MessageError('Unknown message field: %s' % field) + + if value is not None and not isinstance(value, validator[field]): + raise MessageError('%s field must be defined as %s, received: %s' % + (field, validator[field].__name__, type(value).__name__)) + + def update(self, default, data): + """ + After being validated, this method can merge the ``data`` object with the default values. + If a RequiredField was not provided, raises a MessageError. 
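For example, an illustrative sketch of this behaviour (using the SpiderMessage defaults defined further down; error texts taken from the code above):

    SpiderMessage.from_dict({'name': 'example', 'start_urls': []}).name     # == 'example'; allowed_domains is simply never set
    SpiderMessage.from_dict({'name': 'example'})                            # raises MessageError: Required field: start_urls
    SpiderMessage.from_dict({'name': 'example', 'start_urls': [], 'x': 1})  # raises MessageError: Unknown message field: x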
+ """ + default.update(data) + for item, value in default.items(): + if isinstance(value, RequiredField): + raise MessageError('Required field: %s' % item) + + if not isinstance(value, EmptyField): + setattr(self, item, value) + + +class RequestMessage(MessageValidator): + validator = {'id': six.string_types, 'url': six.string_types, 'method': six.string_types, + 'meta': dict, 'body': six.string_types, 'headers': dict, + 'cookies': (dict, list), 'encoding': six.string_types, + 'priority': int, 'dont_filter': bool, 'base64': bool} + + def __init__(self, fields): + default = {'id': RequiredField(), 'url': RequiredField(), 'method': EmptyField(), + 'meta': EmptyField(), 'body': EmptyField(), 'headers': EmptyField(), + 'cookies': EmptyField(), 'encoding': EmptyField(), 'priority': EmptyField(), + 'dont_filter': EmptyField(), 'base64': False} + + super(RequestMessage, self).__init__(default, fields) + self.data.pop('base64', None) + + +class Form(MessageValidator): + validator = {'formname': six.string_types, 'formxpath': six.string_types, + 'formcss': six.string_types, 'formnumber': int, + 'formdata': dict, 'clickdata': dict, 'dont_click': bool, + # request fields + 'method': six.string_types, 'meta': dict, 'body': six.string_types, + 'headers': dict, 'cookies': (dict, list), 'encoding': six.string_types, + 'priority': int, 'dont_filter': bool} + + def __init__(self, form): + default = {'formname': EmptyField(), 'formxpath': EmptyField(), + 'formcss': EmptyField(), 'formnumber': EmptyField(), + 'formdata': EmptyField(), 'clickdata': EmptyField(), + 'dont_click': EmptyField(), + # request fields + 'method': EmptyField(), 'meta': EmptyField(), 'body': EmptyField(), + 'headers': EmptyField(), 'cookies': EmptyField(), 'encoding': EmptyField(), + 'priority': EmptyField(), 'dont_filter': EmptyField()} + + super(Form, self).__init__(default, form) + + +class FromResponseRequestMessage(RequestMessage): + + def __init__(self, fields): + if 'from_response_request' not in fields: + raise MessageError('Required field: from_response_request') + from_response_request = fields.pop('from_response_request') + + super(FromResponseRequestMessage, self).__init__(fields) + self.from_response_request = Form.from_dict(from_response_request) + + +class SpiderMessage(MessageValidator): + validator = {'name': six.string_types, 'start_urls': list, + 'allowed_domains': list, 'custom_settings': dict} + + def __init__(self, fields): + default = {'name': RequiredField(), 'start_urls': RequiredField(), + 'allowed_domains': EmptyField(), 'custom_settings': EmptyField()} + + super(SpiderMessage, self).__init__(default, fields) + + +class LogMessage(MessageValidator): + validator = {'message': six.string_types, 'level': six.string_types} + + def __init__(self, fields): + default = {'message': RequiredField(), 'level': RequiredField()} + + super(LogMessage, self).__init__(default, fields) + levels = {'CRITICAL': logging.CRITICAL, 'ERROR': logging.ERROR, + 'WARNING': logging.WARNING, 'INFO': logging.INFO, + 'DEBUG': logging.DEBUG} + + if self.level.upper() not in levels: + raise MessageError('Invalid log level: %s' % self.level) + + self.level = levels[self.level.upper()] + + +class CloseMessage(MessageValidator): + + def __init__(self, fields): + super(CloseMessage, self).__init__({}, fields) diff --git a/scrapy_streaming/communication/wrappers.py b/scrapy_streaming/communication/wrappers.py deleted file mode 100644 index 1431249..0000000 --- a/scrapy_streaming/communication/wrappers.py +++ /dev/null @@ -1,66 +0,0 @@ -import six 
- -from scrapy_streaming.spiders import StreamingSpider -from scrapy_streaming.utils import MessageError, RequiredField - - -class ExternalSpiderMessageWrapper(object): - validator = {} - - def __init__(self, default, fields): - self.data = fields - self.validate(fields) - self.update(default, fields) - - @classmethod - def from_dict(cls, data): - return cls(data) - - def validate(self, data): - validator = self.validator - for key, value in data.items(): - if key not in validator: - raise MessageError('Unknown message field: %s' % key) - - if value is not None and not isinstance(value, validator[key]): - raise MessageError('%s field must be defined as %s, received: %s' % - (key, validator[key].__name__, type(value).__name__)) - - def update(self, default, data): - default.update(data) - for item, value in default.items(): - if isinstance(value, RequiredField): - raise MessageError('Required field: %s' % item) - setattr(self, item, value) - - -class RequestMessage(ExternalSpiderMessageWrapper): - validator = {'id': six.text_type, 'url': six.text_type} - - def __init__(self, fields): - default = {'id': None, 'start_urls': None, 'method': None, 'meta': None, - 'body': None, 'headers': None, 'cookies': None, 'encoding': None, - 'priority': None, 'dont_filter': None} - - super(RequestMessage, self).__init__(default, fields) - - -class SpiderMessage(ExternalSpiderMessageWrapper): - validator = {'name': six.text_type, 'start_urls': list, - 'allowed_domains': list, 'custom_settings': dict} - - def __init__(self, fields): - default = {'name': RequiredField(), 'start_urls': RequiredField(), - 'allowed_domains': None, 'custom_settings': None} - - super(SpiderMessage, self).__init__(default, fields) - - -class LogMessage(ExternalSpiderMessageWrapper): - - validator = {'message': six.text_type, 'level': six.text_type} - - def __init__(self, fields): - default = {'message': RequiredField(), 'level': RequiredField()} - - super(LogMessage, self).__init__(default, fields) diff --git a/scrapy_streaming/external_spiderloader.py b/scrapy_streaming/external_spiderloader.py index 60d5d55..2d874a7 100644 --- a/scrapy_streaming/external_spiderloader.py +++ b/scrapy_streaming/external_spiderloader.py @@ -1,9 +1,9 @@ -import json +import simplejson as json import os from twisted.internet import reactor -from scrapy_streaming.process_streaming import ProcessStreamingProtocol +from scrapy_streaming.streaming import StreamingProtocol from scrapy_streaming.utils import get_project_root @@ -34,7 +34,6 @@ def __init__(self, settings, load_spiders=True): if load_spiders: path = settings.get('EXTERNAL_SPIDERS_PATH', get_project_root()) - # TODO add EXTERNAL_SPIDERS_PATH in docs path = os.path.abspath(path) self.external = os.path.join(path, 'external.json') self._fetch_spiders() @@ -68,7 +67,7 @@ def crawl(self, name_or_spider): if not isinstance(name_or_spider, ExternalSpider): name_or_spider = self._spiders[name_or_spider] - protocol = ProcessStreamingProtocol() + protocol = StreamingProtocol() reactor.spawnProcess(protocol, name_or_spider.command, args=[name_or_spider.command] + name_or_spider.args) reactor.run() diff --git a/scrapy_streaming/process_streaming.py b/scrapy_streaming/process_streaming.py deleted file mode 100644 index 5b8a047..0000000 --- a/scrapy_streaming/process_streaming.py +++ /dev/null @@ -1,58 +0,0 @@ -from twisted.internet import reactor - -from scrapy_streaming.communication import CommunicationMap, LogMessage, SpiderMessage -from scrapy_streaming.line_receiver import LineProcessProtocol -from 
scrapy_streaming.utils import MessageError - - -class ProcessStreamingProtocol(LineProcessProtocol): - """ - This class is responsible for the communication channel between scrapy-streaming and the external spider. - All messages are sent/received by this class - """ - - def __init__(self): - super(ProcessStreamingProtocol, self).__init__() - self.spider = None - self.streaming = ProcessStreaming(self) - - def connectionMade(self): - self.writeLine(CommunicationMap.ready()) - - def lineReceived(self, line): - try: - msg = CommunicationMap.parse(line) - self.streaming.on_message(msg) - except MessageError as e: - self.sendError(line, str(e)) - - def sendError(self, msg, details): - self.writeLine(CommunicationMap.error(msg, details)) - - def errReceived(self, data): - print(data) - - def processEnded(self, reason): - reactor.stop() - # FIXME add a valid process listener - - -class ProcessStreaming(object): - - def __init__(self, protocol): - self.protocol = protocol - self.mapping = { - LogMessage: self.on_log, - SpiderMessage: self.on_spider - } - - def on_message(self, msg): - self.mapping[type(msg)](msg) - - def on_log(self, msg): - import logging - logging.info(msg.message) - # FIXME add a real logger - - def on_spider(self, msg): - pass diff --git a/scrapy_streaming/spiders.py b/scrapy_streaming/spiders.py deleted file mode 100644 index 7b44253..0000000 --- a/scrapy_streaming/spiders.py +++ /dev/null @@ -1,13 +0,0 @@ -from scrapy import Spider -from twisted.internet import defer - - -class StreamingSpider(Spider): - - def __init__(self, process, **kwargs): - super(StreamingSpider, self).__init__(**kwargs) - self.process = process - self.stream = defer.Deferred() - - def parse(self, response): - return self.stream diff --git a/scrapy_streaming/streaming.py b/scrapy_streaming/streaming.py new file mode 100644 index 0000000..a91ee85 --- /dev/null +++ b/scrapy_streaming/streaming.py @@ -0,0 +1,161 @@ +import logging + +from scrapy import Request, FormRequest +from scrapy.crawler import CrawlerRunner +from scrapy.utils.python import to_native_str +from twisted.internet import reactor +from twisted.internet.error import ProcessExitedAlready + +from scrapy_streaming.communication import CommunicationMap, LogMessage, SpiderMessage, RequestMessage, CloseMessage, \ + FromResponseRequestMessage +from scrapy_streaming.communication.line_receiver import LineProcessProtocol +from scrapy_streaming.utils import MessageError +from scrapy_streaming.utils.spiders import StreamingSpider + + +class StreamingProtocol(LineProcessProtocol): + """ + This class is responsible for the low-level communication channel between scrapy-streaming and the external spider. + All messages are sent/received by this class. 
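    A typical exchange, one JSON object per line over the process's stdin/stdout, looks roughly like this (an illustrative sketch; the message shapes follow the validators in this patch, names and URLs are placeholders):

        scrapy-streaming -> spider    {"type": "ready", "status": "ready"}
        spider -> scrapy-streaming    {"type": "spider", "name": "example", "start_urls": ["http://example.com"]}
        scrapy-streaming -> spider    {"type": "response", "id": "parse", "url": "http://example.com", "status": 200, ...}
        spider -> scrapy-streaming    {"type": "request", "id": "r1", "url": "http://example.com/page"}
        scrapy-streaming -> spider    {"type": "response", "id": "r1", ...}
        spider -> scrapy-streaming    {"type": "close"}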
+ """ + + def __init__(self): + super(StreamingProtocol, self).__init__() + self.logger = logging.getLogger(self.__class__.__name__) + self.streaming = Streaming(self) + self._closing = False + + def connectionMade(self): + self.writeLine(CommunicationMap.ready()) + + def lineReceived(self, line): + try: + msg = CommunicationMap.parse(line) + self.streaming.on_message(msg) + except MessageError as e: + self.sendError(line, str(e)) + + def sendError(self, msg, details): + self.logger.error('Received message: ' + to_native_str(msg)) + self.logger.error(details) + self.writeLine(CommunicationMap.error(msg, details)) + self.closeProcess() + + def errReceived(self, data): + self.logger.error('Received error from external spider') + self.logger.error(data) + self.logger.error('Closing the process due to this error') + self.closeProcess() + + def processEnded(self, status): + self.logger.debug("Process ended") + self.closeProcess() + + def closeProcess(self): + if self._closing: + return + self._closing = True + self.transport.loseConnection() + try: # kill the process if it still running + pid = self.transport.pid + self.transport.signalProcess('KILL') + self.logger.debug('Killing the process %s' % pid) + except ProcessExitedAlready: + pass + reactor.stop() + + +class Streaming(object): + """ + This class contains the high-level communication channel between scrapy-streaming and the external spider. + The logic and workflow is processed here. All received messages and sent pass through this class. + + Incoming messages are in the methods started with on_. Outcoming messages are in methods that + starts with send_ + """ + + def __init__(self, protocol): + self.protocol = protocol + self.crawler = None + self.logger = logging.getLogger(self.__class__.__name__) + self.mapping = { + LogMessage: self._on_log, + SpiderMessage: self._on_spider, + RequestMessage: self._on_request, + FromResponseRequestMessage: self._on_from_response_request, + CloseMessage: self._on_close + } + + def on_message(self, msg): + if not isinstance(msg, LogMessage) and not isinstance(msg, SpiderMessage) and self.crawler is None: + raise MessageError('You must start your spider before sending this message') + self.mapping[type(msg)](msg) + + def _on_log(self, msg): + self.logger.log(msg.level, msg.message) + + def _on_spider(self, msg): + if self.crawler is not None: + raise MessageError('Spider already initialized') + fields = {'streaming': self, 'msg': msg} + fields.update(msg.data) + + runner = CrawlerRunner() + self.crawler = runner.create_crawler(StreamingSpider) + dfd = runner.crawl(self.crawler, **fields) + dfd.addBoth(lambda x: self.protocol.closeProcess()) + + self.logger.debug('Spider started: %s' % msg.name) + + def _on_request(self, msg, callback=None): + if callback is None: + callback = self.send_response + # update request with id field + request_id = msg.data.pop('id') + meta = msg.data.pop('meta', {}) + meta['request_id'] = request_id + msg.data['meta'] = meta + + try: + r = Request(callback=lambda x: callback(msg, x), + errback=lambda x: self.send_exception(msg, x.getErrorMessage()), + **msg.data) + self.crawler.engine.crawl(r, self.crawler.spider) + except (ValueError, TypeError) as e: # errors raised by request creator + self.send_exception(msg, str(e)) + + def _on_from_response_request(self, msg): + self._on_request(msg, self._from_response) + + def _from_response(self, msg, response): + request_id = response.meta['request_id'] + + meta = msg.from_response_request.data.pop('meta', {}) + meta['request_id'] 
= request_id + msg.from_response_request.data['meta'] = meta + try: + # check for possible problems in the response + r = FormRequest.from_response(response, callback=lambda x: self.send_response(msg, x), + errback=lambda x: self.send_exception(msg, x.getErrorMessage()), + **msg.from_response_request.data) + self.crawler.engine.crawl(r, self.crawler.spider) + except (ValueError, IndexError) as e: # errors raised by from_response + self.send_exception(msg, str(e)) + + def _on_close(self, msg): + self.crawler.stop() + self.crawler.spider.close_spider() + self.protocol.closeProcess() + + self.logger.debug('Spider closed') + + def send_response(self, msg, response): + try: + self.protocol.writeLine(CommunicationMap.response(response, msg.base64)) + except ValueError as e: # problems in the encoding + self.send_exception(msg, str(e)) + + def send_exception(self, msg, details): + self.logger.error('Scrapy raised an exception: ' + details) + self.logger.error('Caused by: ' + to_native_str(msg.line)) + self.protocol.writeLine(CommunicationMap.exception(msg.line, details)) diff --git a/scrapy_streaming/utils.py b/scrapy_streaming/utils/__init__.py similarity index 63% rename from scrapy_streaming/utils.py rename to scrapy_streaming/utils/__init__.py index 5f2d724..1944828 100644 --- a/scrapy_streaming/utils.py +++ b/scrapy_streaming/utils/__init__.py @@ -1,9 +1,9 @@ -import json +import simplejson as json import os from scrapy.utils.conf import closest_scrapy_cfg from scrapy.utils.project import inside_project -from scrapy.utils.python import to_bytes +from scrapy.utils.python import to_bytes, to_unicode def get_project_root(): @@ -24,9 +24,16 @@ def dict_serialize(dict_obj, enc=None): return to_bytes(json.dumps(dict_obj), enc) -class MessageError(Exception): - pass +def extract_instance_fields(instance, fields): + """ + Given a list of fields, generate a dict with key being the name of the field + mapping to the serialized instance.field value + """ + data = {} + for field in fields: + data[field] = getattr(instance, field) + return data -class RequiredField(object): +class MessageError(Exception): pass diff --git a/scrapy_streaming/utils/fields.py b/scrapy_streaming/utils/fields.py new file mode 100644 index 0000000..5d29882 --- /dev/null +++ b/scrapy_streaming/utils/fields.py @@ -0,0 +1,6 @@ +class RequiredField(object): + pass + + +class EmptyField(object): + pass diff --git a/scrapy_streaming/utils/spiders.py b/scrapy_streaming/utils/spiders.py new file mode 100644 index 0000000..c9289f7 --- /dev/null +++ b/scrapy_streaming/utils/spiders.py @@ -0,0 +1,54 @@ +import scrapy +from scrapy import Spider +from scrapy.exceptions import DontCloseSpider +from scrapy.http import Request +from twisted.internet import defer + + +class StreamingSpider(Spider): + + def __init__(self, streaming=None, msg=None, **kwargs): + self.streaming = streaming + self.msg = msg + super(StreamingSpider, self).__init__(**kwargs) + self.stream = defer.Deferred() + self._done = False + + def parse(self, response): + # sets the default encoding type with a fake request message + class FakeMessage(object): + base64 = False + line = '' + msg = FakeMessage() + + response.meta['request_id'] = 'parse' + self.streaming.send_response(msg, response) + return self.stream + + def close_spider(self): + self.stream.callback(None) + self._done = True + + def make_requests_from_url(self, url): + # adds errback to spider initial requests + return Request(url, callback=self.parse, dont_filter=True, + errback=lambda x: 
self.streaming.send_exception(self.msg, x.getErrorMessage())) + + def start_requests(self): + # catch exceptions in spider initial urls + for url in self.start_urls: + try: + yield self.make_requests_from_url(url) + except (ValueError, TypeError) as e: # errors raised by request creator + self.streaming.send_exception(self.msg, str(e)) + + @classmethod + def from_crawler(cls, crawler, *args, **kwargs): + from_crawler = super(StreamingSpider, cls).from_crawler + spider = from_crawler(crawler, *args, **kwargs) + crawler.signals.connect(spider.keep_alive, signal=scrapy.signals.spider_idle) + return spider + + def keep_alive(self, spider): + if not self._done: + raise DontCloseSpider() diff --git a/setup.py b/setup.py index 915e87b..e7e4db2 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ description='Develop Spiders using any Programming Language', author='Scrapy developers', packages=find_packages(exclude=('tests', 'tests.*')), - requires=['scrapy'], + requires=['scrapy', 'simplejson'], entry_points={ 'scrapy.commands': [ diff --git a/tests/__init__.py b/tests/__init__.py index e69de29..9b4dc7f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,4 @@ +try: + import unittest.mock as mock +except ImportError: + import mock diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 0000000..dffb3a7 --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1,2 @@ +testfixtures +mock diff --git a/tests/spiders/error_spider.py b/tests/spiders/error_spider.py new file mode 100755 index 0000000..687259e --- /dev/null +++ b/tests/spiders/error_spider.py @@ -0,0 +1,4 @@ +#! /usr/bin/env python + +if __name__ == '__main__': + print(0/0) diff --git a/tests/spiders/multiple_data.py b/tests/spiders/multiple_data.py new file mode 100755 index 0000000..e859784 --- /dev/null +++ b/tests/spiders/multiple_data.py @@ -0,0 +1,9 @@ +#! /usr/bin/env python +import sys + +if __name__ == '__main__': + status = sys.stdin.readline() + sys.stdout.write('{"type": "log", "level": "debug", "message": "qwertyuiop"}\n' * 1000) + sys.stdout.flush() + sys.stdout.write('{"type": "close"}' * 10) + sys.stdout.flush() diff --git a/tests/spiders/request_exception.py b/tests/spiders/request_exception.py new file mode 100755 index 0000000..9b7292d --- /dev/null +++ b/tests/spiders/request_exception.py @@ -0,0 +1,8 @@ +#! /usr/bin/env python +import sys + +if __name__ == '__main__': + status = sys.stdin.readline() + sys.stdout.write('{"type": "spider", "start_urls": ["invalid_url_scheme"], "name": "test"}\n') + sys.stdout.flush() + line = sys.stdin.readline() diff --git a/tests/spiders/wrong_message.py b/tests/spiders/wrong_message.py new file mode 100755 index 0000000..12352e8 --- /dev/null +++ b/tests/spiders/wrong_message.py @@ -0,0 +1,7 @@ +#! 
/usr/bin/env python +import sys + +if __name__ == '__main__': + status = sys.stdin.readline() + sys.stdout.write('{"type": "invalid_type", "field1": "value"}\n') + sys.stdout.flush() diff --git a/tests/test_commands.py b/tests/test_commands.py index cc4432e..9e89ade 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -111,6 +111,14 @@ def test_streaming(self): self.assertIn('sample1.py working', log) + def test_streaming_wrong_message(self): + path = os.path.abspath(os.path.dirname(__file__)) + test1 = os.path.join(path, 'spiders', 'wrong_message.py') + p = self.proc('streaming', test1) + log = to_native_str(p.stderr.read()) + + self.assertIn('invalid_type is not a valid message type.', log) + def test_streaming_args(self): path = os.path.abspath(os.path.dirname(__file__)) test1 = os.path.join(path, 'spiders', 'sample1.py') @@ -119,6 +127,30 @@ def test_streaming_args(self): self.assertIn('sample1.py working', log) + def test_streaming_request_exception(self): + path = os.path.abspath(os.path.dirname(__file__)) + test1 = os.path.join(path, 'spiders', 'request_exception.py') + p = self.proc('streaming', test1) + log = to_native_str(p.stderr.read()) + + self.assertIn('Scrapy raised an exception', log) + + def test_streaming_external_error(self): + path = os.path.abspath(os.path.dirname(__file__)) + test1 = os.path.join(path, 'spiders', 'error_spider.py') + p = self.proc('streaming', test1) + log = to_native_str(p.stderr.read()) + + self.assertIn('Closing the process due to this error', log) + + def test_multiple_messages_streaming(self): + path = os.path.abspath(os.path.dirname(__file__)) + test1 = os.path.join(path, 'spiders', 'multiple_data.py') + p = self.proc('streaming', test1) + log = to_native_str(p.stderr.read()) + + self.assertEqual(log.count("qwertyuiop"), 1000) + class CrawlCommandTest(ProjectTest): diff --git a/tests/test_communication/test_line_receiver.py b/tests/test_communication/test_line_receiver.py new file mode 100644 index 0000000..341374f --- /dev/null +++ b/tests/test_communication/test_line_receiver.py @@ -0,0 +1,23 @@ +from scrapy_streaming.communication.line_receiver import LineProcessProtocol +from twisted.trial import unittest + +from tests import mock + + +class LineReceiverTest(unittest.TestCase): + + def setUp(self): + self.receiver = LineProcessProtocol() + + def test_buffering(self): + + with mock.patch.object(self.receiver, 'lineReceived') as mock_method: + self.receiver.outReceived(b'received data \n with multiple \n lines') + + mock_method.assert_any_call(b'received data ') + mock_method.assert_any_call(b' with multiple ') + + self.assertEqual(self.receiver._buffer, b' lines') + + def test_must_implement_lineReceived(self): + self.assertRaises(NotImplementedError, self.receiver.outReceived, b'test\n') diff --git a/tests/test_communication/test_map.py b/tests/test_communication/test_map.py index 90e613f..d8299dd 100644 --- a/tests/test_communication/test_map.py +++ b/tests/test_communication/test_map.py @@ -1,26 +1,76 @@ +import base64 + +import simplejson as json + +from scrapy import Request +from scrapy.http import Response +from scrapy.utils.python import to_native_str from twisted.trial import unittest -from scrapy_streaming.communication.map import _extract_fields +from scrapy_streaming.communication import CommunicationMap +from scrapy_streaming.utils import MessageError class CommunicationMapTest(unittest.TestCase): - def test_extract_field(self): - class Test(object): - a = 'a' - b = 2 - c = {'a': 'b'} - d = [1, 2, 3] - e = 2.5 - f 
= None - - fields = ['a', 'b', 'c', 'd', 'e', 'f'] - expected = { - 'a': 'a', - 'b': 2, - 'c': {'a': 'b'}, - 'd': [1, 2, 3], - 'e': 2.5, - 'f': None - } - self.assertDictEqual(_extract_fields(Test(), fields), expected) + def test_invalid_json(self): + self.assertRaisesRegexp(MessageError, 'Received message is not a valid json.', CommunicationMap.parse, '{a: 2') + + def test_invalid_object(self): + self.assertRaisesRegexp(MessageError, 'This message is not a json object.', CommunicationMap.parse, '[1, 2, 3]') + + def test_invalid_type(self): + self.assertRaisesRegexp(MessageError, 'x is not a valid message type.', CommunicationMap.parse, '{"type": "x"}') + + def test_missing_type(self): + self.assertRaisesRegexp(MessageError, '"type" field not provided.', CommunicationMap.parse, '{"a": 1}') + + def test_ready(self): + ready = {'type': 'ready', 'status': 'ready'} + self.assertDictEqual(ready, json.loads(CommunicationMap.ready())) + + def test_error(self): + error = {'type': 'error', 'received_message': 'message', 'details': 'error details'} + + self.assertEqual(error, json.loads(CommunicationMap.error('message', 'error details'))) + + def test_response(self): + req = Request('http://example.com') + req.meta['request_id'] = 'test' + r = Response('http://example.com', request=req) + + resp = {'type': 'response', 'id': 'test', 'url': 'http://example.com', + 'status': 200, 'headers': {}, 'body': '', 'flags': [], 'meta': {'request_id': 'test'}} + + self.assertDictEqual(resp, json.loads(CommunicationMap.response(r, False))) + + def test_response_binary(self): + req = Request('http://example.com/file.png') + req.meta['request_id'] = 'test' + + img = b'the binary image data' + r = Response('http://example.com/file.png', request=req, body=img) + + resp = {'type': 'response', 'id': 'test', 'url': 'http://example.com/file.png', + 'status': 200, 'headers': {}, 'body': to_native_str(base64.b64encode(img)), + 'flags': [], 'meta': {'request_id': 'test'}} + + self.assertDictEqual(resp, json.loads(CommunicationMap.response(r, True))) + + def test_response_binary_missing_encoding(self): + req = Request('http://example.com/file.png') + req.meta['request_id'] = 'test' + + img = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00d\x00\x00\x00d\x08\x02\x00' + r = Response('http://example.com/file.png', request=req, body=img) + + self.assertRaisesRegexp(ValueError, 'Response body is not serializable', + CommunicationMap.response, r, False) + + def test_exception(self): + line = '{"type": "log", "level": "debug", "message": "sample1.py working"}' + exception = 'Logger not found' + + exc = {'type': 'exception', 'received_message': line, 'exception': exception} + + self.assertDictEqual(exc, json.loads(CommunicationMap.exception(line, exception))) diff --git a/tests/test_communication/test_validators.py b/tests/test_communication/test_validators.py new file mode 100644 index 0000000..954013f --- /dev/null +++ b/tests/test_communication/test_validators.py @@ -0,0 +1,100 @@ +import logging + +import six +from twisted.trial import unittest + +from scrapy_streaming.communication import MessageValidator, RequestMessage, SpiderMessage, LogMessage, CloseMessage, \ + FromResponseRequestMessage +from scrapy_streaming.utils import MessageError +from scrapy_streaming.utils.fields import RequiredField, EmptyField + + +class MessageValidatorTest(unittest.TestCase): + + def test_validate(self): + class SimpleValidator(MessageValidator): + validator = {'a': int, 'b': six.string_types, 'c': dict, 'd': (int, float)} + + def 
__init__(self, data): + default = {} + super(SimpleValidator, self).__init__(default, data) + + self.assertRaisesRegexp(MessageError, 'Unknown message field: x', SimpleValidator.from_dict, {'x': 2}) + self.assertRaisesRegexp(MessageError, 'a field must be defined as', SimpleValidator.from_dict, {'a': 'value'}) + + def test_required_field(self): + class SimpleValidator(MessageValidator): + validator = {'a': int, 'b': six.string_types} + + def __init__(self, data): + default = {'a': RequiredField(), 'b': None} + super(SimpleValidator, self).__init__(default, data) + + self.assertRaisesRegexp(MessageError, 'Required field: a', SimpleValidator.from_dict, {}) + + def set_attributes(self): + class SimpleValidator(MessageValidator): + validator = {'a': int, 'b': list, 'c': int} + + def __init__(self, data): + default = {'a': RequiredField(), 'b': RequiredField(), 'c': None} + super(SimpleValidator, self).__init__(default, data) + a = 3 + b = [1, 2, 3] + v = SimpleValidator.from_dict({'a': a, 'b': b}) + + self.assertEqual(v.a, a) + self.assertEqual(v.b, b) + self.assertEqual(v.c, None) + + def test_empty_field(self): + class SimpleValidator(MessageValidator): + validator = {'a': int, 'b': list} + + def __init__(self, data): + default = {'a': None, 'b': EmptyField()} + super(SimpleValidator, self).__init__(default, data) + + v = SimpleValidator.from_dict({}) + + self.assertEqual(v.a, None) + self.assertEqual(hasattr(v, 'b'), False) + + +class MessagesTest(unittest.TestCase): + + def test_create_messages(self): + RequestMessage({'id': u'id', 'url': u'http://example.com'}) + FromResponseRequestMessage({'id': u'id', 'url': u'http://example.com', "from_response_request": {}}) + SpiderMessage({'name': u'name', 'start_urls': []}) + LogMessage({'message': u'message', 'level': u'debug'}) + CloseMessage({}) + + def test_log_level(self): + msg_critical = {'message': 'message', 'level': 'Critical'} + msg_error = {'message': 'message', 'level': 'ERRor'} + msg_warning = {'message': 'message', 'level': 'warning'} + msg_info = {'message': 'message', 'level': 'INFO'} + msg_debug = {'message': 'message', 'level': 'debug'} + + log_critical = LogMessage.from_dict(msg_critical) + log_error = LogMessage.from_dict(msg_error) + log_warning = LogMessage.from_dict(msg_warning) + log_info = LogMessage.from_dict(msg_info) + log_debug = LogMessage.from_dict(msg_debug) + + self.assertEqual(log_critical.level, logging.CRITICAL) + self.assertEqual(log_error.level, logging.ERROR) + self.assertEqual(log_warning.level, logging.WARNING) + self.assertEqual(log_info.level, logging.INFO) + self.assertEqual(log_debug.level, logging.DEBUG) + + def test_log_invalid_level(self): + msg = {'message': 'message', 'level': 'mycustomlevel'} + + self.assertRaisesRegexp(MessageError, 'Invalid log level: mycustomlevel', LogMessage.from_dict, msg) + + def test_from_response_request_params(self): + fields = {'id': u'id', 'url': u'http://example.com'} + + self.assertRaisesRegexp(MessageError, 'Required field: from_response_request', FromResponseRequestMessage.from_dict, fields) diff --git a/tests/test_streaming.py b/tests/test_streaming.py new file mode 100644 index 0000000..6957404 --- /dev/null +++ b/tests/test_streaming.py @@ -0,0 +1,151 @@ +from scrapy import Request +from scrapy.http import Response +from twisted.internet import reactor + +from tests import mock + +from scrapy.crawler import Crawler +from testfixtures import LogCapture + +from scrapy_streaming.streaming import Streaming +from twisted.trial import unittest + +from 
scrapy_streaming.communication import LogMessage, SpiderMessage, MessageError, CloseMessage, RequestMessage, \ + FromResponseRequestMessage, CommunicationMap + + +class FakeProtocol(object): + + def writeLine(self, data): + pass + + def closeProcess(self): + pass + + +class StreamingTest(unittest.TestCase): + + def setUp(self): + self.streaming = Streaming(FakeProtocol()) + + def tearDown(self): + try: + # try to clean the reactor, if necessary + self.streaming._on_close(None) + except: + pass + + def create_spider(self, name='sample'): + spider = SpiderMessage.from_dict({'name': name, 'start_urls': []}) + self.streaming.on_message(spider) + + def test_log_message(self): + log = LogMessage.from_dict({'message': 'test message', 'level': 'debug'}) + with LogCapture() as l: + self.streaming.on_message(log) + + l.check(('Streaming', 'DEBUG', 'test message')) + + def test_spider_message(self): + with LogCapture('Streaming') as l: + self.create_spider() + l.check(('Streaming', 'DEBUG', 'Spider started: sample')) + + self.assertIsInstance(self.streaming.crawler, Crawler) + self.assertRaisesRegexp(MessageError, 'Spider already initialized', self.create_spider) + + def test_close_message(self): + close = CloseMessage.from_dict({}) + + self.assertRaisesRegexp(MessageError, 'You must start your spider before sending this message', self.streaming.on_message, close) + self.create_spider() + with LogCapture('Streaming') as l: + self.streaming.on_message(close) + l.check(('Streaming', 'DEBUG', 'Spider closed')) + + def test_request_message(self): + request = RequestMessage.from_dict({'id': 'id', 'url': 'http://example.com'}) + + self.assertRaisesRegexp(MessageError, 'You must start your spider before sending this message', self.streaming.on_message, request) + self.create_spider() + with mock.patch.object(self.streaming.crawler.engine, 'crawl', return_value=None) as mock_method: + self.streaming.on_message(request) + + self.assertTrue(mock_method.called) + + def test_request_message_missing_url_scheme(self): + request = RequestMessage.from_dict({'id': 'id', 'url': 'example.com'}) + + self.assertRaisesRegexp(MessageError, 'You must start your spider before sending this message', self.streaming.on_message, request) + self.create_spider() + with mock.patch.object(self.streaming, 'send_exception', return_value=None) as mock_method: + self.streaming.on_message(request) + + self.assertTrue(mock_method.called) + + def test_from_response_equest_message(self): + msg_request = FromResponseRequestMessage.from_dict({'id': 'id', 'url': 'http://example.com', 'from_response_request': {}}) + fake_request = Request('http://example.com', meta={'request_id': 'id'}) + + self.assertRaisesRegexp(MessageError, 'You must start your spider before sending this message', self.streaming.on_message, msg_request) + self.create_spider() + with mock.patch.object(self.streaming.crawler.engine, 'crawl', return_value=None) as mock_method1: + self.streaming.on_message(msg_request) + + response = Response(url='http://example.com', request=fake_request) + response.encoding = 'utf-8' + response.text = '
<html><body><form action="post.php" method="POST"></form></body></html>' + with mock.patch.object(self.streaming.crawler.engine, 'crawl', return_value=None) as mock_method2: + self.streaming._from_response(msg_request, response) + + self.assertTrue(mock_method1.called) + self.assertTrue(mock_method2.called) + + def test_from_response_request_missing_form(self): + msg_request = FromResponseRequestMessage.from_dict({'id': 'id', 'url': 'http://example.com', 'from_response_request': {}}) + fake_request = Request('http://example.com', meta={'request_id': 'id'}) + + self.assertRaisesRegexp(MessageError, 'You must start your spider before sending this message', self.streaming.on_message, msg_request) + self.create_spider() + with mock.patch.object(self.streaming.crawler.engine, 'crawl', return_value=None) as mock_method1: + self.streaming.on_message(msg_request) + + response = Response(url='http://example.com', request=fake_request) + response.encoding = 'utf-8' + response.text = '
<html><body>Test</body></html>
' + with mock.patch.object(self.streaming, 'send_exception', return_value=None) as mock_method2: + self.streaming._from_response(msg_request, response) + + self.assertTrue(mock_method1.called) + self.assertTrue(mock_method2.called) + + def test_response_message(self): + req = Request('http://example.com') + req.meta['request_id'] = 'test' + response = Response(url='http://example.com', request=req) + msg = RequestMessage.from_dict({'id': 'test', 'url': 'http://example.com'}) + with mock.patch.object(self.streaming.protocol, 'writeLine') as mock_method: + self.streaming.send_response(msg, response) + + self.assertTrue(mock_method.called) + + def test_response_wrong_encoding(self): + req = Request('http://example.com') + req.meta['request_id'] = 'test' + response = Response(url='http://example.com', request=req, + body=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00d\x00\x00\x00d\x08\x02\x00') + msg = RequestMessage.from_dict({'id': 'test', 'url': 'http://example.com'}) + with mock.patch.object(self.streaming, 'send_exception') as mock_method: + self.streaming.send_response(msg, response) + + self.assertTrue(mock_method.called) + + def test_exception_message(self): + class FakeMessage: + line = '{"type": "log", "level": "debug", "message": "sample1.py working"}' + msg = FakeMessage() + exception = 'Problem in spider' + with mock.patch.object(self.streaming.protocol, 'writeLine') as mock_method: + self.streaming.send_exception(msg, exception) + + mock_method.assert_any_call(CommunicationMap.exception(msg.line, exception)) diff --git a/tests/test_utils.py b/tests/test_utils.py index d4159f5..a901be4 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,16 @@ import os -from tests.test_commands import ProjectTest +from scrapy import Request +from scrapy.http import Response +from scrapy_streaming.streaming import Streaming, StreamingProtocol +from twisted.internet.defer import Deferred + +from scrapy_streaming.utils.spiders import StreamingSpider +from tests.test_commands import ProjectTest +from tests import mock -from scrapy_streaming.utils import get_project_root, dict_serialize +from scrapy_streaming.utils import get_project_root, dict_serialize, extract_instance_fields class UtilsTest(ProjectTest): @@ -18,3 +25,45 @@ def test_get_project_default(self): def test_dict_serialize(self): d = {'a': 'b'} self.assertEqual(dict_serialize(d), b'{"a": "b"}') + + def test_extract_instance_fields(self): + class Test(object): + a = 'a' + b = 2 + c = {'a': 'b'} + d = [1, 2, 3] + e = 2.5 + f = None + + fields = ['a', 'b', 'c', 'd', 'e', 'f'] + expected = { + 'a': 'a', + 'b': 2, + 'c': {'a': 'b'}, + 'd': [1, 2, 3], + 'e': 2.5, + 'f': None + } + self.assertDictEqual(extract_instance_fields(Test(), fields), expected) + + def test_streaming_spider_parse(self): + spider = StreamingSpider(streaming=Streaming(protocol=StreamingProtocol()), + name='sample', start_urls=[]) + + req = Request('http://example.com') + req.meta['request_id'] = 'test' + fake_response = Response('http://example.com', request=req) + + with mock.patch.object(spider.streaming.protocol, 'writeLine') as mock_send: + self.assertIsInstance(spider.parse(fake_response), Deferred) + + self.assertTrue(mock_send.called) + + def test_streaming_spider_close(self): + class FakeStreaming(object): + def send_response(self, resp): + pass + + spider = StreamingSpider(streaming=FakeStreaming(), name='sample', start_urls=[]) + spider.close_spider() + self.assertTrue(spider.stream.called) diff --git a/tox.ini b/tox.ini index d72d3df..1d7cd40 
100644 --- a/tox.ini +++ b/tox.ini @@ -9,6 +9,7 @@ envlist = py27,py35 [testenv] deps = -rrequirements.txt + -rtests/requirements.txt pytest pytest-cov hypothesis
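As a closing illustration, a minimal external spider that exercises this protocol end to end might look like the sketch below. It is not part of the patch: it relies only on the line-delimited JSON messages defined in scrapy_streaming/communication/validators.py, and the spider name and URL are placeholders.

    #! /usr/bin/env python
    # Minimal external spider (illustrative sketch, not part of this patch).
    # It talks to scrapy-streaming over stdin/stdout using one JSON object per line.
    import json
    import sys


    def send(msg):
        # every outgoing message is a single JSON object terminated by a newline
        sys.stdout.write(json.dumps(msg) + '\n')
        sys.stdout.flush()


    if __name__ == '__main__':
        sys.stdin.readline()  # wait for the {"type": "ready"} handshake
        send({'type': 'spider', 'name': 'example', 'start_urls': ['http://example.com']})
        response = json.loads(sys.stdin.readline())  # response for the start url, id == 'parse'
        send({'type': 'log', 'level': 'info',
              'message': 'got %s with status %s' % (response['url'], response['status'])})
        send({'type': 'close'})

Such a script could be run the same way as the scripts under tests/spiders, e.g. scrapy streaming /path/to/spider.py.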