#! /usr/bin/python3

# $Id: dmarc2srs 48846 2020-05-28 10:19:09Z wsl $
# $URL: https://svn.uvt.nl/its-id/trunk/sources/dmarc2srs/dmarc2srs $

# General idea: apply SRS to the from address if we're about to send
# something from our mail servers that might fail on DMARC.

from email.headerregistry import Group, Address, AddressHeader
from email.parser import HeaderParser
from email.policy import EmailPolicy
from email.utils import formataddr
from grp import getgrnam
from os import setresuid, setresgid, umask
from os import getgrouplist, setgroups
from pwd import getpwnam, getpwuid
from re import compile as regcomp
from socket import getfqdn
from sys import argv, stderr
from time import time_ns
from traceback import format_exc

import Milter
from Milter.utils import parse_addr
import SRS
import authres.dmarc
from authres import FeatureContext
from dkim.util import parse_tag_value as parse_dkim_header

# The only command-line argument is the path of the configuration file.
if len(argv) != 2:
	raise Exception("usage: %s <configuration file>" % (argv[0],))

# The configuration file is Python source; every name it assigns becomes a
# key in `config`. It runs with whatever privileges the process was started
# with (privileges are only dropped further down), so it must be trusted.
config = {}

with open(argv[1]) as fh:
	# Compile with the real file name so that syntax errors and tracebacks
	# from the configuration point at that file rather than at "<string>".
	exec(compile(fh.read(), argv[1], 'exec'), {}, config)

# Default umask 007: files created by this process get no permissions for "other".
umask(config.get('umask', 0o007))

# Socket specification handed to Milter.runmilter():
socket = config['socket']

# The authority (authserv-id) of the Authentication-Results headers we trust,
# compared case-insensitively:
authority = config.get('authority', getfqdn()).lower()

# Domain that is used for SRS-rewrites (and for addresses without a domain):
default_domain = config['default_domain'].lower()

# Secret for SRS; must match the secret that is used to decode addresses on incoming mail:
secret = config.get('secret', None)


def _lowercase_set(values):
	"""Return a frozenset of lowercased strings from one string or an iterable of strings."""
	if isinstance(values, str):
		values = [values]
	return frozenset(value.lower() for value in values)


# Strip DKIM signatures made by these domains.
strip_dkim_domains = _lowercase_set(config.get('strip_dkim_domains', ()))

# Only strip DKIM signatures if they are in strip_dkim_domains AND
# the selector matches this list. None means: ignore the selector.
strip_dkim_selectors = config.get('strip_dkim_selectors')
if strip_dkim_selectors is not None:
	strip_dkim_selectors = _lowercase_set(strip_dkim_selectors)

# Account and group to switch to; each may be a name or a numeric ID.
user = config.get('user', None)
group = config.get('group', None)

# According to the documentation, Milter counts headers starting from 1.
# However, observed results indicate that it counts starting from 0.
# Might be a milter implementation bug in Postfix.
milter_header_count_base = config.get('milter_header_count_base', 0)

# Predicate: does the whole string consist of ASCII digits?
looks_like_number = regcomp('[0-9]+').fullmatch

# Mixed into the per-instance log IDs so they are less likely to repeat across runs.
log_id_prefix = str(time_ns()) + "\0"

# Drop privileges: change the group ID(s) first, while we are still allowed
# to, and the user ID last.
if user is not None:
	if looks_like_number(user):
		uid = int(user)
		if group is None:
			# getpwuid() needs the numeric uid, not the configured string.
			pwent = getpwuid(uid)
	else:
		pwent = getpwnam(user)
		uid = pwent.pw_uid

	if group is None:
		# No group configured: adopt the user's primary group and its
		# supplementary groups. getgrouplist() takes a user *name*, which the
		# configured value need not be (it may be a numeric uid).
		gid = pwent.pw_gid
		setresgid(gid, gid, gid)
		setgroups(getgrouplist(pwent.pw_name, gid))

if group is not None:
	if looks_like_number(group):
		gid = int(group)
	else:
		gid = getgrnam(group).gr_gid
	setresgid(gid, gid, gid)
	# NOTE(review): supplementary groups are left untouched in this branch, so
	# any the process started with survive the switch. Resetting them would
	# need setgroups([gid]) (requires CAP_SETGID) — confirm this is intended.
	#setgroups([gid])

if user is not None:
	setresuid(uid, uid, uid)

# Groups of Authentication-Results result keywords used by DMARC2SRS.eom().
# 'softfail' arguably belongs with 'fail'
authres_none = frozenset(('none', 'neutral', 'softfail'))
authres_pass = frozenset(('pass',))
authres_fail = frozenset(('fail',))
authres_none_pass = authres_none | authres_pass
authres_none_pass_fail = authres_none_pass | authres_fail
authres_error = frozenset(('temperror', 'permerror'))

class DMARC2SRS(Milter.Base):
	"""Milter that SRS-rewrites senders of mail that might fail DMARC when relayed.

	While the headers pass by, the SPF, DKIM and DMARC results from this
	server's own Authentication-Results headers (authserv-id == authority),
	the From header(s) and the DKIM-Signature headers are collected. At
	end-of-message, eom() uses them to decide whether to SRS-rewrite the
	envelope sender and/or the header From address(es).
	"""

	# Per-message state; (re)initialised by _reset_message_state().
	envelope_from = None      # sender as passed to envfrom()
	authres_num = 0           # Authentication-Results headers seen (any authority)
	dkim_headers = 0          # DKIM-Signature headers seen
	spf_authres = None        # lowercased SPF result from our authority
	spf_authres_index = None  # 1-based position of the A-R header carrying it
	dmarc_authres = None      # lowercased DMARC result from our authority

	def __init__(self, *args, **kwargs):
		super().__init__(*args, **kwargs)
		self._reset_message_state()
		# Hex tag that tells the log lines of different instances apart.
		self.log_id = '{:016X}:'.format(hash(log_id_prefix + str(Milter.uniqueID())) & 0xFFFFFFFFFFFFFFFF)

	def _reset_message_state(self):
		"""(Re)initialise everything that is collected per message.

		pymilter keeps one instance per SMTP connection, and one connection
		can carry several messages. Without this reset, From headers, DKIM
		results and header counters of an earlier message would leak into
		the next one.
		"""
		self.envelope_from = None
		self.authres_num = 0
		self.dkim_headers = 0
		self.spf_authres = None
		self.spf_authres_index = None
		self.dmarc_authres = None
		self.header_from = []
		self.dkim_sigs = []
		# 0-based positions of DKIM-Signature headers to remove unconditionally.
		self.strip_dkim_headers = set()

	def log(self, *msgs):
		"""Write msgs to stderr, prefixed with the queue ID (when known) and log_id."""
		try:
			queue_id = (self.getsymval('i') + ":",)
		except Exception:
			# Macro 'i' not available (yet), or getsymval() failed.
			queue_id = ()
		# Just write to stderr and let systemd handle it.
		print(*queue_id, self.log_id, *msgs, file = stderr, flush = True)

	def envfrom(self, address, *extra):
		"""Start of a new message: reset per-message state and remember the sender."""
		self._reset_message_state()
		self.envelope_from = address
		self.log("envfrom", repr(address))
		return Milter.CONTINUE

	def header(self, name, value, *extra):
		"""Collect what eom() needs from a single header.

		Always returns Milter.CONTINUE; a header that cannot be parsed is
		logged and otherwise ignored.
		"""
		try:
			name = name.lower()
			if name == 'authentication-results':
				# Count every A-R header, not just ours: chgheader() addresses a
				# header by its position among headers of the same name.
				authres_num = self.authres_num + 1
				self.authres_num = authres_num
				ar = FeatureContext(authres.dmarc).parse_value(value)
				if ar.authserv_id.lower() == authority:
					for res in ar.results:
						method = res.method.lower()
						if method == 'spf':
							# Result keywords are case-insensitive (RFC 8601).
							self.spf_authres = res.result.lower()
							self.spf_authres_index = authres_num
						elif method == 'dkim':
							self.dkim_sigs.append(res)
						elif method == 'dmarc':
							self.dmarc_authres = res.result.lower()
			elif name == 'from':
				self.header_from.append(value)
				self.log("header from", repr(value))
			elif name == 'dkim-signature':
				# 0-based position among the DKIM-Signature headers. It is
				# counted before parsing so that an unparsable header still
				# occupies its slot.
				dkim_headers = self.dkim_headers
				self.dkim_headers = dkim_headers + 1

				properties = parse_dkim_header(value.encode())
				d = properties.get(b'd', b'').decode().lower()
				s = properties.get(b's', b'').decode().lower()

				if d in strip_dkim_domains and (strip_dkim_selectors is None or s in strip_dkim_selectors):
					self.strip_dkim_headers.add(dkim_headers)

		except Exception as e:
			self.log("error processing %s header: %s: %s" % (name, type(e).__name__, e))

		return Milter.CONTINUE

	def _dkim_verdict(self, parse_headers):
		"""Return 'none', 'pass' or 'fail' for DKIM alignment of the From domain(s).

		Valid DKIM signatures are not very interesting by themselves. The
		message only counts as DKIM-valid ('pass') if every From address is
		covered by a passing signature of its own domain or a parent domain.
		If there are no DKIM results at all, the verdict is 'none'.
		"""
		verdict = 'none'

		# Collect all domains from the result=pass dkim result(s):
		dkim_domains = set()
		for sig in self.dkim_sigs:
			result = sig.result.lower()
			if result == 'pass':
				properties = {(p.type, p.name): str(p.value) for p in sig.properties}
				header_i = properties.get(('header', 'i'))
				header_d = properties.get(('header', 'd'))
				if header_i is not None:
					dkim_domains.add(header_i.rpartition('@')[2].lower())
				elif header_d is not None:
					dkim_domains.add(header_d.lower())

			if result != 'none':
				# Default in case the alignment check below does not decide otherwise:
				verdict = 'fail'

		self.log("dkim_domains:", *dkim_domains)
		if not dkim_domains:
			return verdict

		# Collect all domains from the From header(s):
		from_domains = set()
		for raw_value in self.header_from:
			(_, parsed), = parse_headers("From: " + raw_value).items()
			for group in parsed.groups:
				for address in group.addresses:
					from_domains.add(address.domain.lower())
		self.log("from_domains:", *from_domains)
		if not from_domains:
			return verdict

		for from_domain in from_domains:
			# Walk up the domain tree until a signing domain is found.
			domain = from_domain
			while domain not in dkim_domains:
				_, separator, domain = domain.partition('.')
				if not separator:
					# No parent domains left: this From domain is not covered.
					return 'fail'
		return 'pass'

	def _format_mailbox(self, address):
		"""Render one Address as header text with an RFC 2047-encoded display name.

		str(address) would put a non-ASCII display name into the header as
		raw 8-bit text; formataddr() encodes it instead.
		"""
		try:
			return formataddr((address.display_name, address.addr_spec))
		except UnicodeError:
			# formataddr() refuses non-ASCII addr-specs; keep the plain rendering for those.
			return str(address)

	def _srs_group(self, group, srs_rewrite):
		"""Render group as From header text, SRS-rewriting addresses outside default_domain."""
		mailboxes = []
		for address in group.addresses:
			if address.domain.lower() != default_domain:
				srs_address = srs_rewrite(address.addr_spec, default_domain)
				address = Address(display_name = address.display_name, addr_spec = srs_address)
			mailboxes.append(self._format_mailbox(address))
		text = ", ".join(mailboxes)
		if group.display_name is None:
			# A plain mailbox rather than an RFC 5322 named group.
			return text
		# Let Group quote the group name; an empty group renders as "name:;".
		# NOTE(review): a non-ASCII group name is still emitted unencoded.
		label = str(Group(display_name = group.display_name))
		return "%s %s;" % (label[:-1], text) if text else label

	def eom(self):
		"""Decide on, and request, the SRS rewrites at end-of-message.

		Always returns Milter.CONTINUE. Errors are logged; modifications that
		were already requested before an error are not undone.
		"""
		self.log("eom", repr(self.envelope_from), repr(self.header_from))
		try:
			srs_rewrite = SRS.new(secret = secret).forward
			parse_headers = HeaderParser(policy = EmailPolicy()).parsestr

			# These signatures match strip_dkim_domains/strip_dkim_selectors, so remove them.
			# NOTE(review): these indices are 0-based (see header()) and, unlike the
			# DKIM-Signature indices further down, are not offset by
			# milter_header_count_base — confirm they agree if that setting is changed.
			strip_dkim_headers = self.strip_dkim_headers
			for idx in strip_dkim_headers:
				self.chgheader('DKIM-Signature', idx, None)

			# Only act on a definite DMARC and SPF verdict from our own authority.
			dmarc_authres = self.dmarc_authres
			if dmarc_authres not in authres_none_pass_fail:
				if dmarc_authres is None:
					self.log("no dmarc Authentication-Results headers found, skipping")
				else:
					self.log("dmarc Authentication-Results header indicates %s, skipping" % (dmarc_authres,))
				return Milter.CONTINUE

			spf_authres = self.spf_authres
			if spf_authres not in authres_none_pass_fail:
				if spf_authres is None:
					self.log("no spf Authentication-Results headers found, skipping")
				else:
					self.log("spf Authentication-Results header indicates %s, skipping" % (spf_authres,))
				return Milter.CONTINUE

			# Always one of 'none', 'pass' or 'fail':
			dkim_authres = self._dkim_verdict(parse_headers)

			self.log("spf=%s dkim=%s dmarc=%s" % (spf_authres, dkim_authres, dmarc_authres))

			spf_fail = spf_authres in authres_fail
			dkim_fail = dkim_authres in authres_fail
			dmarc_pass = dmarc_authres in authres_pass
			dmarc_fail = dmarc_authres in authres_fail

			if spf_fail or (dmarc_pass and dkim_fail) or dmarc_fail:
				# Rewrite envelope from.

				# parse_addr() returns a list of 1 or 2 elements, depending
				# on whether there's a @ in the address.
				parsed = parse_addr(self.envelope_from)
				if len(parsed) < 2:
					# Address without @:
					localpart, = parsed
					if localpart == '':
						# Empty sender (<>): nothing to rewrite. This also skips
						# the header From rewrite below.
						return Milter.CONTINUE
					domain = default_domain
				else:
					localpart, domain = parsed

				address = srs_rewrite(localpart + "@" + domain, default_domain)
				self.log("changing envelope sender from %s to: %s" % (self.envelope_from, address))
				self.chgfrom(address)

				# NOTE(review): spf_authres_index is 1-based and not offset by
				# milter_header_count_base, unlike the From/DKIM-Signature indices
				# below — confirm which Authentication-Results header this replaces.
				self.chgheader('Authentication-Results', self.spf_authres_index,
					"%s; spf=pass smtp.mailfrom=%s" % (authority, address))

			if (dmarc_pass and (spf_fail or dkim_fail)) or dmarc_fail:
				# Rewrite header from:
				for idx, old_value in enumerate(self.header_from, start = milter_header_count_base):
					(name, parsed_value), = parse_headers("From: " + old_value).items()
					new_value = ",\n\t".join(self._srs_group(group, srs_rewrite) for group in parsed_value.groups)
					self.log("changing header sender from %s to: %s" % (old_value, new_value))
					self.chgheader(name, idx, new_value)

				# Changing the From invalidates all signatures. Best to just remove them all.
				for idx in range(milter_header_count_base, self.dkim_headers + milter_header_count_base):
					if idx not in strip_dkim_headers:
						self.chgheader('DKIM-Signature', idx, None)

		except Exception as e:
			self.log("eom failed: %s: %s" % (type(e).__name__, e))

		return Milter.CONTINUE

# Hand the class to pymilter, declare the message modifications it performs
# (envelope-sender changes and header changes), and serve on the socket.
Milter.factory = DMARC2SRS
milter_flags = Milter.CHGFROM | Milter.CHGHDRS
Milter.set_flags(milter_flags)
Milter.runmilter('dmarc2srs', socket)
