summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYves Fischer <yvesf-git@xapek.org>2018-12-12 14:34:25 +0100
committerYves Fischer <yvesf-git@xapek.org>2018-12-13 00:18:22 +0100
commit470be1fed7651b2a3ae678ade43ad2f0cb955cf9 (patch)
tree6e368f3c9cbcfe4213a281b21abdd93c2f7b1b95
parenta4a1e5be3d19434a418b0358e1bb75d272392816 (diff)
downloadinfluxdb-udp-inserter-470be1fed7651b2a3ae678ade43ad2f0cb955cf9.tar.gz
influxdb-udp-inserter-470be1fed7651b2a3ae678ade43ad2f0cb955cf9.zip
add replay attack prevention
-rw-r--r--influxdb_udp_inserter/__init__.py198
-rw-r--r--influxdb_udp_inserter/sample_message_format.js21
-rw-r--r--influxdb_udp_inserter/server_async.py51
-rwxr-xr-xrun_server.py (renamed from server.py)4
-rw-r--r--test_client.py2
-rw-r--r--test_serialize.py14
6 files changed, 222 insertions, 68 deletions
diff --git a/influxdb_udp_inserter/__init__.py b/influxdb_udp_inserter/__init__.py
index 9c51951..5205a27 100644
--- a/influxdb_udp_inserter/__init__.py
+++ b/influxdb_udp_inserter/__init__.py
@@ -1,13 +1,29 @@
try:
import struct
except ImportError:
+ # noinspection PyUnresolvedReferences
import ustruct as struct
try:
import socket
except ImportError:
+ # noinspection PyUnresolvedReferences
import usocket as socket
+try:
+ import time
+except ImportError:
+ # noinspection PyUnresolvedReferences
+ import utime as time
+
+import builtins
+
+
+def property(**kwargs):
+ """Wrapper for micropython which otherwise doesn't have named arguments"""
+ return builtins.property('fget' in kwargs and kwargs['fget'] or None,
+ 'fset' in kwargs and kwargs['fset'] or None)
+
class Struct:
def __init__(self, fmt):
@@ -24,35 +40,102 @@ class Struct:
UINT8 = Struct('B')
UINT16 = Struct('H')
UINT32 = Struct('I')
+UINT64 = Struct('Q')
try:
import hashlib
except ImportError:
+ # noinspection PyUnresolvedReferences
import uhashlib as hashlib
+class Timesource:
+ def unix_time_sec(self):
+ return int(time.time())
+
+
class SerializerFactory:
- def __init__(self):
- self.message_formats = {}
+ def __init__(self, timesource: Timesource = None):
+ self._timesource = timesource or Timesource()
+ self._message_formats = {}
def add_message_format(self, description):
- if not 'identifier' in description:
+ if 'identifier' not in description:
raise Exception('Missing required option "identifier"')
identifier = description['identifier']
if not isinstance(identifier, bytes):
identifier = bytes(identifier)
- if identifier in self.message_formats:
+ if identifier in self._message_formats:
raise Exception('There is already a message defined with identifier', description['identifier'])
- self.message_formats[identifier] = description
+ self._message_formats[identifier] = description
def get_serializer(self, identifier):
if not isinstance(identifier, bytes):
identifier = bytes(identifier)
- if identifier in self.message_formats:
- return MessageSerializer.from_config(self.message_formats[identifier])
+ if identifier in self._message_formats:
+ return MessageSerializer.from_config(self._message_formats[identifier], self._timesource)
else:
raise Exception('No message format with identifier', identifier)
+ def __get_timesource(self):
+ return self._timesource
+
+ timesource = property(fget=__get_timesource)
+
+
+class Message:
+ def __init__(self):
+ self._identifier = None
+ self._nonce = None
+ self._timestamp = None
+ self._secret = None
+ self._payload = None
+ self._timestamp = None
+
+ def __get_identifier(self) -> bytes:
+ return self._identifier
+
+ def __set_identifer(self, identifer: bytes):
+ self._identifier = identifer
+
+ def __get_nonce(self) -> int:
+ return self._nonce
+
+ def __set_nonce(self, nonce: int):
+ self._nonce = nonce
+
+ def __set_payload(self, payload):
+ self._payload = payload
+
+ def __get_payload(self) -> bytes:
+ return self._payload
+
+ def __set_timestamp(self, timestamp: int):
+ self._timestamp = timestamp
+
+ def __get_timestamp(self) -> int:
+ return self._timestamp
+
+ identifier = property(fget=__get_identifier, fset=__set_identifer)
+ nonce = property(fget=__get_nonce, fset=__set_nonce)
+ payload = property(fget=__get_payload, fset=__set_payload)
+ timestamp = property(fget=__get_timestamp, fset=__set_timestamp)
+
+
+class MessageWriter:
+ def __init__(self, secret):
+ self._secret = secret
+
+ def to_bytes(self, message):
+ if message.identifier is None or message.timestamp is None \
+ or message.payload is None or self._secret is None:
+ raise Exception('Writer/Message not correctly initialized')
+ else:
+ data = message.identifier + UINT16.pack(message.nonce) + message.payload
+ timestamp = UINT64.pack(message.timestamp)
+ message_hash = hashlib.sha256(data + self._secret + timestamp).digest()
+ return data + message_hash[0:6]
+
class MessageSerializer:
@staticmethod
@@ -72,29 +155,48 @@ class MessageSerializer:
yield result
@staticmethod
- def from_config(config):
+ def from_config(config, timesource: Timesource = None):
identifier = bytes(config['identifier'])
secret = bytes(config['secret'])
database = config['database']
fields = list(MessageSerializer._make_fields(config['fields']))
- return MessageSerializer(identifier, database, secret, fields)
+ return MessageSerializer(identifier, database, secret, fields, timesource)
+
+ def __init__(self, identifier: bytes, database: str, secret, fields, timesource: Timesource = None):
+ if timesource is None:
+ self._timesource = Timesource()
+ else:
+ self._timesource = timesource
+
+ self._identifier = identifier
+ self._database = database
+ self._secret = secret
+ self._fields = fields
+ self._size = sum(map(lambda f: sum(v[1].size for v in f[1:]), self._fields))
+
+ def __get_fields(self):
+ return self._fields
+
+ fields = property(fget=__get_fields)
+
+ def __get_database(self):
+ return self._database
- def __init__(self, identifier, database, secret, fields):
- self.identifier = identifier
- self.database = database
- self.secret = secret
- self.fields = fields
- self.size = sum(map(lambda f: sum(v[1].size for v in f[1:]), self.fields))
+ database = property(fget=__get_database)
- def serialize(self, data: dict):
- if len(data) != len(self.fields):
+ def serialize(self, data: dict, nonce: int) -> (Message, bytes):
+ if len(data) != len(self._fields):
raise Exception()
- if set(data.keys()) != set(map(lambda v: v[0], self.fields)):
+
+ if set(data.keys()) != set(map(lambda v: v[0], self._fields)):
raise Exception("Data does not match schema")
- payload = bytes()
+ message = Message()
+ message.identifier = self._identifier
+ message.nonce = nonce
- for config_name, *config_values in self.fields:
+ payload = bytes()
+ for config_name, *config_values in self._fields:
data_values = data[config_name]
if set(data_values.keys()) != set(map(lambda kv: kv[0], config_values)):
raise Exception("inconsistent data for ", config_name, data_values.keys(), dict(config_values).keys())
@@ -102,42 +204,60 @@ class MessageSerializer:
for config_sub_name, config_struct in config_values:
payload += config_struct.pack(data_values[config_sub_name])
- buf = bytes(self.identifier[0:3]) + payload
+ message.payload = payload
+ message.timestamp = self._timesource.unix_time_sec()
- hash = hashlib.sha256(buf + self.secret).digest()
- buf += hash[0:6]
+ writer = MessageWriter(self._secret)
- return buf
+ return message, writer.to_bytes(message)
- def deserialize(self, data):
- if len(data) < 3 + 6:
- raise Exception('Message of wrong size', len(data))
+ def deserialize(self, raw_data, max_delta_t) -> (Message, dict):
+ message = self.parse_and_verify(raw_data, max_delta_t)
- hash = hashlib.sha256(data[:-6] + self.secret).digest()
- if hash[0:6] != data[-6:]:
- raise Exception("Failed to authenticate message")
-
- payload = data[3:-6]
- if len(payload) != self.size:
- raise Exception('Message of wrong payload size', len(payload))
+ if len(message.payload) != self._size:
+ raise Exception('Message of wrong payload size', len(message.payload))
result = []
i = 0
- for config_name, *config_values in self.fields:
+ for config_name, *config_values in self._fields:
result_field = {}
for config_sub_name, config_struct in config_values:
- window = payload[i:i + config_struct.size]
+ window = message.payload[i:i + config_struct.size]
result_field[config_sub_name] = config_struct.unpack(window)
i += config_struct.size
result += [(config_name, result_field)]
- return dict(result)
+ return message, dict(result)
+
+ def parse_and_verify(self, raw_data, max_delta_t):
+ if len(raw_data) < 3 + 6 + 2: # 24-bit identifier + 64-bit hash + 16-bit nonce + ... payload
+ raise Exception('Message of wrong size', len(raw_data))
+
+ timestamp = self._timesource.unix_time_sec()
+
+ in_data, in_message_hash = raw_data[:-6], raw_data[-6:]
+ message = Message()
+ message.identifier = in_data[0:3]
+ message.nonce = UINT16.unpack(in_data[3:5])
+ message.payload = raw_data[5:-6]
+
+ for delta_t in range(-1 * max_delta_t, max_delta_t):
+ t = UINT64.pack(timestamp + delta_t)
+ message._message_hash = hashlib.sha256(in_data + self._secret + t).digest()[0:6]
+ if message._message_hash == in_message_hash:
+ message.timestamp = timestamp + delta_t
+ break
+ else:
+ raise Exception("Failed to authenticate message", )
+
+ return message
-def send(host: str, port: int, serializer: MessageSerializer, data: dict):
- serialized_data = serializer.serialize(data)
+def send(host: str, port: int, serializer: MessageSerializer, data: dict, nonce: int):
+ """:param nonce: is truncated to unsigned 16bit"""
+ _, serialized_data = serializer.serialize(data, nonce)
sockaddr = socket.getaddrinfo(host, port)[0][-1]
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
diff --git a/influxdb_udp_inserter/sample_message_format.js b/influxdb_udp_inserter/sample_message_format.js
index 41d37e8..a50f336 100644
--- a/influxdb_udp_inserter/sample_message_format.js
+++ b/influxdb_udp_inserter/sample_message_format.js
@@ -1,17 +1,18 @@
{
"identifier": [1, 1, 1],
- "secret": [12, 23, 23, 11, 244, 23, 222, 123],
+ "secret": [12, 23, 23, 11, 244, 23, 222, 123, 23, 23, 11, 244, 23, 222, 123, 23, 23, 11, 244, 23, 222,
+ 123, 23, 23, 11, 244, 23, 222, 123, 23, 23, 11, 244, 23, 222, 123, 23, 23, 11, 244, 23, 222],
"database": "data",
"fields": [
- ["inverter0.PVVoltage1", ["min", "uint16"], ["max", "uint16"], ["last","uint16"]],
- ["inverter0.PVVoltage2", ["min", "uint16"], ["max", "uint16"], ["last","uint16"]],
- ["inverter0.PVVoltage3", ["min", "uint16"], ["max", "uint16"], ["last","uint16"]],
- ["inverter0.PVCurrent1", ["min", "uint16"], ["max", "uint16"], ["last","uint16"]],
- ["inverter0.PVCurrent2", ["min", "uint16"], ["max", "uint16"], ["last","uint16"]],
- ["inverter0.PVCurrent3", ["min", "uint16"], ["max", "uint16"], ["last","uint16"]],
- ["inverter0.GridVoltage", ["min", "uint16"], ["max", "uint16"], ["last","uint16"]],
- ["inverter0.GridFrequency", ["min", "uint16"], ["max", "uint16"], ["last","uint16"]],
- ["inverter0.SmoothedInstantEnergyProduction", ["min", "uint16"], ["max", "uint16"], ["last","uint16"]],
+ ["inverter0.PVVoltage1", ["value","uint16"]],
+ ["inverter0.PVVoltage2", ["value","uint16"]],
+ ["inverter0.PVVoltage3", ["value","uint16"]],
+ ["inverter0.PVCurrent1", ["value","uint16"]],
+ ["inverter0.PVCurrent2", ["value","uint16"]],
+ ["inverter0.PVCurrent3", ["value","uint16"]],
+ ["inverter0.GridVoltage", ["value","uint16"]],
+ ["inverter0.GridFrequency", ["value","uint16"]],
+ ["inverter0.SmoothedInstantEnergyProduction", ["value","uint16"]],
["inverter0.LatestEvent", ["value","uint16"]],
["inverter0.LatestEventModule", ["value","uint16"]]
]
diff --git a/influxdb_udp_inserter/server_async.py b/influxdb_udp_inserter/server_async.py
index 5e2933a..3478a7c 100644
--- a/influxdb_udp_inserter/server_async.py
+++ b/influxdb_udp_inserter/server_async.py
@@ -10,22 +10,45 @@ import aiohttp
class UdpInserterProtocol(asyncio.DatagramProtocol):
def __init__(self, factory: SerializerFactory, influx_url: str):
- self.factory = factory
- self.influx_url = influx_url
- self.transport: asyncio.Transport = None
+ self._factory = factory
+ self._influx_url = influx_url
+ self._transport: asyncio.Transport = None
+ self._max_delta_t = 10
+ self._known_nonces = {} # key: timestamp, values: set( (identifier, nonce) )
def connection_made(self, transport):
- self.transport = transport
-
- def datagram_received(self, data, addr):
- logging.info('Received %s bytes: %r(...) from %s', len(data), data[0:3], addr)
- if len(data) < 7: return
-
- serializer = self.factory.get_serializer(data[0:3])
- mesg = serializer.deserialize(data)
+ self._transport = transport
+
+ def cleanup_known_nonces(self):
+ now = self._factory.timesource.unix_time_sec()
+ delete = []
+ for key in self._known_nonces.keys():
+ if key < now - self._max_delta_t:
+ delete.append(key)
+ for key in delete:
+ del self._known_nonces[key]
+
+ def datagram_received(self, raw_data, addr):
+ logging.info('Received %s bytes: %r(...) from %s', len(raw_data), raw_data[0:3], addr)
+ if len(raw_data) < 7: return
+
+ identifier = raw_data[0:3]
+ serializer = self._factory.get_serializer(identifier)
+
+ mesg, fields = serializer.deserialize(raw_data, self._max_delta_t)
+
+ # Verify nonce is not known for that timestamp
+ if mesg.timestamp in self._known_nonces.keys() and \
+ mesg.nonce in self._known_nonces[mesg.timestamp]:
+ raise Exception('Possible replay attack: Nonce {} already knwon for timestamp {}'.format(
+ mesg.nonce, mesg.timestamp))
+ else:
+ if not mesg.timestamp in self._known_nonces.keys():
+ self._known_nonces[mesg.timestamp] = set()
+ self._known_nonces[mesg.timestamp].add(mesg.nonce)
influxdb_points = []
- for key, value in mesg.items():
+ for key, value in fields.items():
influxdb_points.append({
'measurement': key,
'tags': {},
@@ -34,7 +57,9 @@ class UdpInserterProtocol(asyncio.DatagramProtocol):
post_data = line_protocol.make_lines({'points': influxdb_points}).encode()
- asyncio.ensure_future(send(self.influx_url + '?db=' + serializer.database, post_data))
+ asyncio.ensure_future(send(self._influx_url + '?db=' + serializer.database, post_data))
+
+ self.cleanup_known_nonces()
async def send(url, data):
diff --git a/server.py b/run_server.py
index 454f4a6..dd1f614 100755
--- a/server.py
+++ b/run_server.py
@@ -12,10 +12,10 @@ def main():
required=True)
parser.add_argument('--formats', nargs=1, action='append', metavar='PATTERN',
help='Glob pattern to look for message formats')
- parser.add_argument('--port', nargs=1, metavar='PORT', help='List on UDP port number', default=9999)
+ parser.add_argument('--port', nargs=1, metavar='PORT', help='List on UDP port number', default=[9999])
args = parser.parse_args()
- s = Server(args.url[0], local_addr=('0.0.0.0', args.port))
+ s = Server(args.url[0], local_addr=('0.0.0.0', args.port[0]))
for pattern in args.formats:
for filepath in glob.glob(pattern[0]):
diff --git a/test_client.py b/test_client.py
index a4c416f..0a0abfa 100644
--- a/test_client.py
+++ b/test_client.py
@@ -15,6 +15,6 @@ for field_name, *values in serializer.fields:
fake_data[field_name] = {}
for name, value in values:
fake_data[field_name][name] = sum(map(ord, field_name+name)) % 0xff
-send('0.0.0.0', 9999, serializer, fake_data)
+send('0.0.0.0', 9999, serializer, fake_data, 123)
print('done')
diff --git a/test_serialize.py b/test_serialize.py
index 1f5d5b0..e7b468b 100644
--- a/test_serialize.py
+++ b/test_serialize.py
@@ -1,5 +1,6 @@
from influxdb_udp_inserter import SerializerFactory
+
def test():
try:
import os
@@ -12,26 +13,33 @@ def test():
with open('influxdb_udp_inserter/sample_message_format.js') as fp:
message_format = json.load(fp)
+ identifier = bytes(message_format['identifier'])
factory = SerializerFactory()
factory.add_message_format(message_format)
- identifier = bytes((0x01, 0x01, 0x01))
+ # generate some fake data
fake_data = {}
for field_name, *values in factory.get_serializer(identifier).fields:
fake_data[field_name] = {}
for name, value in values:
fake_data[field_name][name] = sum(map(ord, field_name+name)) % 0xff
+ # serialize data
print(fake_data)
- serialized = factory.get_serializer(identifier).serialize(fake_data)
+ message1, serialized = factory.get_serializer(identifier).serialize(fake_data, 123)
print(serialized)
print("{}byte => 360 * 24 * 12 => {}Mb".format(len(serialized), (len(serialized) * 360 * 24 * 12) / 1024 / 1024))
- result = factory.get_serializer(identifier).deserialize(serialized)
+ # deserialize
+ message2, result = factory.get_serializer(identifier).deserialize(serialized, 2)
+
print(result)
+ # Compare numbers
+ if message2.nonce != 123:
+ raise Exception()
if fake_data != result:
raise Exception()
else: