diff options
author | Yves Fischer <yvesf-git@xapek.org> | 2018-12-12 14:34:25 +0100 |
---|---|---|
committer | Yves Fischer <yvesf-git@xapek.org> | 2018-12-13 00:18:22 +0100 |
commit | 470be1fed7651b2a3ae678ade43ad2f0cb955cf9 (patch) | |
tree | 6e368f3c9cbcfe4213a281b21abdd93c2f7b1b95 | |
parent | a4a1e5be3d19434a418b0358e1bb75d272392816 (diff) | |
download | influxdb-udp-inserter-470be1fed7651b2a3ae678ade43ad2f0cb955cf9.tar.gz influxdb-udp-inserter-470be1fed7651b2a3ae678ade43ad2f0cb955cf9.zip |
add replay attack prevention
-rw-r--r-- | influxdb_udp_inserter/__init__.py | 198 | ||||
-rw-r--r-- | influxdb_udp_inserter/sample_message_format.js | 21 | ||||
-rw-r--r-- | influxdb_udp_inserter/server_async.py | 51 | ||||
-rwxr-xr-x | run_server.py (renamed from server.py) | 4 | ||||
-rw-r--r-- | test_client.py | 2 | ||||
-rw-r--r-- | test_serialize.py | 14 |
6 files changed, 222 insertions, 68 deletions
diff --git a/influxdb_udp_inserter/__init__.py b/influxdb_udp_inserter/__init__.py index 9c51951..5205a27 100644 --- a/influxdb_udp_inserter/__init__.py +++ b/influxdb_udp_inserter/__init__.py @@ -1,13 +1,29 @@ try: import struct except ImportError: + # noinspection PyUnresolvedReferences import ustruct as struct try: import socket except ImportError: + # noinspection PyUnresolvedReferences import usocket as socket +try: + import time +except ImportError: + # noinspection PyUnresolvedReferences + import utime as time + +import builtins + + +def property(**kwargs): + """Wrapper for micropython which otherwise doesn't have named arguments""" + return builtins.property('fget' in kwargs and kwargs['fget'] or None, + 'fset' in kwargs and kwargs['fset'] or None) + class Struct: def __init__(self, fmt): @@ -24,35 +40,102 @@ class Struct: UINT8 = Struct('B') UINT16 = Struct('H') UINT32 = Struct('I') +UINT64 = Struct('Q') try: import hashlib except ImportError: + # noinspection PyUnresolvedReferences import uhashlib as hashlib +class Timesource: + def unix_time_sec(self): + return int(time.time()) + + class SerializerFactory: - def __init__(self): - self.message_formats = {} + def __init__(self, timesource: Timesource = None): + self._timesource = timesource or Timesource() + self._message_formats = {} def add_message_format(self, description): - if not 'identifier' in description: + if 'identifier' not in description: raise Exception('Missing required option "identifier"') identifier = description['identifier'] if not isinstance(identifier, bytes): identifier = bytes(identifier) - if identifier in self.message_formats: + if identifier in self._message_formats: raise Exception('There is already a message defined with identifier', description['identifier']) - self.message_formats[identifier] = description + self._message_formats[identifier] = description def get_serializer(self, identifier): if not isinstance(identifier, bytes): identifier = bytes(identifier) - if identifier in self.message_formats: - return MessageSerializer.from_config(self.message_formats[identifier]) + if identifier in self._message_formats: + return MessageSerializer.from_config(self._message_formats[identifier], self._timesource) else: raise Exception('No message format with identifier', identifier) + def __get_timesource(self): + return self._timesource + + timesource = property(fget=__get_timesource) + + +class Message: + def __init__(self): + self._identifier = None + self._nonce = None + self._timestamp = None + self._secret = None + self._payload = None + self._timestamp = None + + def __get_identifier(self) -> bytes: + return self._identifier + + def __set_identifer(self, identifer: bytes): + self._identifier = identifer + + def __get_nonce(self) -> int: + return self._nonce + + def __set_nonce(self, nonce: int): + self._nonce = nonce + + def __set_payload(self, payload): + self._payload = payload + + def __get_payload(self) -> bytes: + return self._payload + + def __set_timestamp(self, timestamp: int): + self._timestamp = timestamp + + def __get_timestamp(self) -> int: + return self._timestamp + + identifier = property(fget=__get_identifier, fset=__set_identifer) + nonce = property(fget=__get_nonce, fset=__set_nonce) + payload = property(fget=__get_payload, fset=__set_payload) + timestamp = property(fget=__get_timestamp, fset=__set_timestamp) + + +class MessageWriter: + def __init__(self, secret): + self._secret = secret + + def to_bytes(self, message): + if message.identifier is None or message.timestamp is None \ + or message.payload is None or self._secret is None: + raise Exception('Writer/Message not correctly initialized') + else: + data = message.identifier + UINT16.pack(message.nonce) + message.payload + timestamp = UINT64.pack(message.timestamp) + message_hash = hashlib.sha256(data + self._secret + timestamp).digest() + return data + message_hash[0:6] + class MessageSerializer: @staticmethod @@ -72,29 +155,48 @@ class MessageSerializer: yield result @staticmethod - def from_config(config): + def from_config(config, timesource: Timesource = None): identifier = bytes(config['identifier']) secret = bytes(config['secret']) database = config['database'] fields = list(MessageSerializer._make_fields(config['fields'])) - return MessageSerializer(identifier, database, secret, fields) + return MessageSerializer(identifier, database, secret, fields, timesource) + + def __init__(self, identifier: bytes, database: str, secret, fields, timesource: Timesource = None): + if timesource is None: + self._timesource = Timesource() + else: + self._timesource = timesource + + self._identifier = identifier + self._database = database + self._secret = secret + self._fields = fields + self._size = sum(map(lambda f: sum(v[1].size for v in f[1:]), self._fields)) + + def __get_fields(self): + return self._fields + + fields = property(fget=__get_fields) + + def __get_database(self): + return self._database - def __init__(self, identifier, database, secret, fields): - self.identifier = identifier - self.database = database - self.secret = secret - self.fields = fields - self.size = sum(map(lambda f: sum(v[1].size for v in f[1:]), self.fields)) + database = property(fget=__get_database) - def serialize(self, data: dict): - if len(data) != len(self.fields): + def serialize(self, data: dict, nonce: int) -> (Message, bytes): + if len(data) != len(self._fields): raise Exception() - if set(data.keys()) != set(map(lambda v: v[0], self.fields)): + + if set(data.keys()) != set(map(lambda v: v[0], self._fields)): raise Exception("Data does not match schema") - payload = bytes() + message = Message() + message.identifier = self._identifier + message.nonce = nonce - for config_name, *config_values in self.fields: + payload = bytes() + for config_name, *config_values in self._fields: data_values = data[config_name] if set(data_values.keys()) != set(map(lambda kv: kv[0], config_values)): raise Exception("inconsistent data for ", config_name, data_values.keys(), dict(config_values).keys()) @@ -102,42 +204,60 @@ class MessageSerializer: for config_sub_name, config_struct in config_values: payload += config_struct.pack(data_values[config_sub_name]) - buf = bytes(self.identifier[0:3]) + payload + message.payload = payload + message.timestamp = self._timesource.unix_time_sec() - hash = hashlib.sha256(buf + self.secret).digest() - buf += hash[0:6] + writer = MessageWriter(self._secret) - return buf + return message, writer.to_bytes(message) - def deserialize(self, data): - if len(data) < 3 + 6: - raise Exception('Message of wrong size', len(data)) + def deserialize(self, raw_data, max_delta_t) -> (Message, dict): + message = self.parse_and_verify(raw_data, max_delta_t) - hash = hashlib.sha256(data[:-6] + self.secret).digest() - if hash[0:6] != data[-6:]: - raise Exception("Failed to authenticate message") - - payload = data[3:-6] - if len(payload) != self.size: - raise Exception('Message of wrong payload size', len(payload)) + if len(message.payload) != self._size: + raise Exception('Message of wrong payload size', len(message.payload)) result = [] i = 0 - for config_name, *config_values in self.fields: + for config_name, *config_values in self._fields: result_field = {} for config_sub_name, config_struct in config_values: - window = payload[i:i + config_struct.size] + window = message.payload[i:i + config_struct.size] result_field[config_sub_name] = config_struct.unpack(window) i += config_struct.size result += [(config_name, result_field)] - return dict(result) + return message, dict(result) + + def parse_and_verify(self, raw_data, max_delta_t): + if len(raw_data) < 3 + 6 + 2: # 24-bit identifier + 64-bit hash + 16-bit nonce + ... payload + raise Exception('Message of wrong size', len(raw_data)) + + timestamp = self._timesource.unix_time_sec() + + in_data, in_message_hash = raw_data[:-6], raw_data[-6:] + message = Message() + message.identifier = in_data[0:3] + message.nonce = UINT16.unpack(in_data[3:5]) + message.payload = raw_data[5:-6] + + for delta_t in range(-1 * max_delta_t, max_delta_t): + t = UINT64.pack(timestamp + delta_t) + message._message_hash = hashlib.sha256(in_data + self._secret + t).digest()[0:6] + if message._message_hash == in_message_hash: + message.timestamp = timestamp + delta_t + break + else: + raise Exception("Failed to authenticate message", ) + + return message -def send(host: str, port: int, serializer: MessageSerializer, data: dict): - serialized_data = serializer.serialize(data) +def send(host: str, port: int, serializer: MessageSerializer, data: dict, nonce: int): + """:param nonce: is truncated to unsigned 16bit""" + _, serialized_data = serializer.serialize(data, nonce) sockaddr = socket.getaddrinfo(host, port)[0][-1] s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) diff --git a/influxdb_udp_inserter/sample_message_format.js b/influxdb_udp_inserter/sample_message_format.js index 41d37e8..a50f336 100644 --- a/influxdb_udp_inserter/sample_message_format.js +++ b/influxdb_udp_inserter/sample_message_format.js @@ -1,17 +1,18 @@ { "identifier": [1, 1, 1], - "secret": [12, 23, 23, 11, 244, 23, 222, 123], + "secret": [12, 23, 23, 11, 244, 23, 222, 123, 23, 23, 11, 244, 23, 222, 123, 23, 23, 11, 244, 23, 222, + 123, 23, 23, 11, 244, 23, 222, 123, 23, 23, 11, 244, 23, 222, 123, 23, 23, 11, 244, 23, 222], "database": "data", "fields": [ - ["inverter0.PVVoltage1", ["min", "uint16"], ["max", "uint16"], ["last","uint16"]], - ["inverter0.PVVoltage2", ["min", "uint16"], ["max", "uint16"], ["last","uint16"]], - ["inverter0.PVVoltage3", ["min", "uint16"], ["max", "uint16"], ["last","uint16"]], - ["inverter0.PVCurrent1", ["min", "uint16"], ["max", "uint16"], ["last","uint16"]], - ["inverter0.PVCurrent2", ["min", "uint16"], ["max", "uint16"], ["last","uint16"]], - ["inverter0.PVCurrent3", ["min", "uint16"], ["max", "uint16"], ["last","uint16"]], - ["inverter0.GridVoltage", ["min", "uint16"], ["max", "uint16"], ["last","uint16"]], - ["inverter0.GridFrequency", ["min", "uint16"], ["max", "uint16"], ["last","uint16"]], - ["inverter0.SmoothedInstantEnergyProduction", ["min", "uint16"], ["max", "uint16"], ["last","uint16"]], + ["inverter0.PVVoltage1", ["value","uint16"]], + ["inverter0.PVVoltage2", ["value","uint16"]], + ["inverter0.PVVoltage3", ["value","uint16"]], + ["inverter0.PVCurrent1", ["value","uint16"]], + ["inverter0.PVCurrent2", ["value","uint16"]], + ["inverter0.PVCurrent3", ["value","uint16"]], + ["inverter0.GridVoltage", ["value","uint16"]], + ["inverter0.GridFrequency", ["value","uint16"]], + ["inverter0.SmoothedInstantEnergyProduction", ["value","uint16"]], ["inverter0.LatestEvent", ["value","uint16"]], ["inverter0.LatestEventModule", ["value","uint16"]] ] diff --git a/influxdb_udp_inserter/server_async.py b/influxdb_udp_inserter/server_async.py index 5e2933a..3478a7c 100644 --- a/influxdb_udp_inserter/server_async.py +++ b/influxdb_udp_inserter/server_async.py @@ -10,22 +10,45 @@ import aiohttp class UdpInserterProtocol(asyncio.DatagramProtocol): def __init__(self, factory: SerializerFactory, influx_url: str): - self.factory = factory - self.influx_url = influx_url - self.transport: asyncio.Transport = None + self._factory = factory + self._influx_url = influx_url + self._transport: asyncio.Transport = None + self._max_delta_t = 10 + self._known_nonces = {} # key: timestamp, values: set( (identifier, nonce) ) def connection_made(self, transport): - self.transport = transport - - def datagram_received(self, data, addr): - logging.info('Received %s bytes: %r(...) from %s', len(data), data[0:3], addr) - if len(data) < 7: return - - serializer = self.factory.get_serializer(data[0:3]) - mesg = serializer.deserialize(data) + self._transport = transport + + def cleanup_known_nonces(self): + now = self._factory.timesource.unix_time_sec() + delete = [] + for key in self._known_nonces.keys(): + if key < now - self._max_delta_t: + delete.append(key) + for key in delete: + del self._known_nonces[key] + + def datagram_received(self, raw_data, addr): + logging.info('Received %s bytes: %r(...) from %s', len(raw_data), raw_data[0:3], addr) + if len(raw_data) < 7: return + + identifier = raw_data[0:3] + serializer = self._factory.get_serializer(identifier) + + mesg, fields = serializer.deserialize(raw_data, self._max_delta_t) + + # Verify nonce is not known for that timestamp + if mesg.timestamp in self._known_nonces.keys() and \ + mesg.nonce in self._known_nonces[mesg.timestamp]: + raise Exception('Possible replay attack: Nonce {} already knwon for timestamp {}'.format( + mesg.nonce, mesg.timestamp)) + else: + if not mesg.timestamp in self._known_nonces.keys(): + self._known_nonces[mesg.timestamp] = set() + self._known_nonces[mesg.timestamp].add(mesg.nonce) influxdb_points = [] - for key, value in mesg.items(): + for key, value in fields.items(): influxdb_points.append({ 'measurement': key, 'tags': {}, @@ -34,7 +57,9 @@ class UdpInserterProtocol(asyncio.DatagramProtocol): post_data = line_protocol.make_lines({'points': influxdb_points}).encode() - asyncio.ensure_future(send(self.influx_url + '?db=' + serializer.database, post_data)) + asyncio.ensure_future(send(self._influx_url + '?db=' + serializer.database, post_data)) + + self.cleanup_known_nonces() async def send(url, data): diff --git a/server.py b/run_server.py index 454f4a6..dd1f614 100755 --- a/server.py +++ b/run_server.py @@ -12,10 +12,10 @@ def main(): required=True) parser.add_argument('--formats', nargs=1, action='append', metavar='PATTERN', help='Glob pattern to look for message formats') - parser.add_argument('--port', nargs=1, metavar='PORT', help='List on UDP port number', default=9999) + parser.add_argument('--port', nargs=1, metavar='PORT', help='List on UDP port number', default=[9999]) args = parser.parse_args() - s = Server(args.url[0], local_addr=('0.0.0.0', args.port)) + s = Server(args.url[0], local_addr=('0.0.0.0', args.port[0])) for pattern in args.formats: for filepath in glob.glob(pattern[0]): diff --git a/test_client.py b/test_client.py index a4c416f..0a0abfa 100644 --- a/test_client.py +++ b/test_client.py @@ -15,6 +15,6 @@ for field_name, *values in serializer.fields: fake_data[field_name] = {} for name, value in values: fake_data[field_name][name] = sum(map(ord, field_name+name)) % 0xff -send('0.0.0.0', 9999, serializer, fake_data) +send('0.0.0.0', 9999, serializer, fake_data, 123) print('done') diff --git a/test_serialize.py b/test_serialize.py index 1f5d5b0..e7b468b 100644 --- a/test_serialize.py +++ b/test_serialize.py @@ -1,5 +1,6 @@ from influxdb_udp_inserter import SerializerFactory + def test(): try: import os @@ -12,26 +13,33 @@ def test(): with open('influxdb_udp_inserter/sample_message_format.js') as fp: message_format = json.load(fp) + identifier = bytes(message_format['identifier']) factory = SerializerFactory() factory.add_message_format(message_format) - identifier = bytes((0x01, 0x01, 0x01)) + # generate some fake data fake_data = {} for field_name, *values in factory.get_serializer(identifier).fields: fake_data[field_name] = {} for name, value in values: fake_data[field_name][name] = sum(map(ord, field_name+name)) % 0xff + # serialize data print(fake_data) - serialized = factory.get_serializer(identifier).serialize(fake_data) + message1, serialized = factory.get_serializer(identifier).serialize(fake_data, 123) print(serialized) print("{}byte => 360 * 24 * 12 => {}Mb".format(len(serialized), (len(serialized) * 360 * 24 * 12) / 1024 / 1024)) - result = factory.get_serializer(identifier).deserialize(serialized) + # deserialize + message2, result = factory.get_serializer(identifier).deserialize(serialized, 2) + print(result) + # Compare numbers + if message2.nonce != 123: + raise Exception() if fake_data != result: raise Exception() else: |