#! /usr/bin/python3

# $Id: dmarc2srs 48282 2019-08-01 11:11:01Z wsl $
# $URL: https://svn.uvt.nl/its-id/trunk/sources/dmarc2srs/dmarc2srs $

# General idea: apply SRS to the from address if we're about to send
# something from our mail servers that might fail on DMARC.

from email.headerregistry import Group, Address, AddressHeader
from email.parser import HeaderParser
from email.policy import EmailPolicy
from grp import getgrnam
from os import getgrouplist, setgroups, setresgid, setresuid, umask
from pwd import getpwnam, getpwuid
from re import compile as regcomp
from socket import getfqdn
from sys import argv, stderr
from traceback import format_exc

import Milter
import SRS
import authres.dmarc
from Milter.utils import parse_addr
from authres import FeatureContext

if len(argv) != 2:
	raise Exception("usage: %s <configuration file>" % (argv[0],))

# The configuration file is plain Python: every name it assigns at top level
# becomes a key of `config`.  It is exec()'d with this process's full
# privileges (before any user/group switch), so it must be trusted and
# writable only by the administrator.
config = {}
with open(argv[1]) as config_file:
	config_source = config_file.read()
exec(config_source, {}, config)

umask(config.get('umask', 0o007))

# Where the milter listens; handed to Milter.runmilter().
socket = config['socket']

# authserv-id of the Authentication-Results headers we trust
# (defaults to this host's fully qualified domain name):
authority = config.get('authority', getfqdn()).lower()

# Domain used for SRS rewrites (and for addresses without a domain):
default_domain = config['default_domain'].lower()

# SRS secret; must match the one used to decode addresses on incoming mail:
secret = config.get('secret')

# Predicate: is the whole string made of ASCII digits?
looks_like_number = regcomp('[0-9]+').fullmatch

# Optional account and group to drop privileges to (names or numeric ids):
user = config.get('user')
group = config.get('group')

# Index given to chgheader() for the first header of a given name.  The milter
# documentation counts from 1, yet observed results suggested counting from 0
# (possibly a Postfix milter quirk).
# NOTE(review): sendmail and Postfix reportedly treat index 0 as 1 for
# compatibility, which would also explain that observation; confirm before
# relying on indices beyond the first header.
milter_header_count_base = config.get('milter_header_count_base', 0)

# Drop privileges when a user and/or group is configured.  Group ids (primary
# and supplementary) are changed first, while we still have the privilege to
# do so; the user id is switched last.
if user is not None:
	if looks_like_number(user):
		uid = int(user)
		if group is None:
			# Needed below for the primary group and the account name.
			# getpwuid() takes the numeric uid, not the string from the config.
			pwent = getpwuid(uid)
	else:
		pwent = getpwnam(user)
		uid = pwent.pw_uid

	if group is None:
		# No explicit group: use the account's primary group plus its
		# supplementary groups.
		gid = pwent.pw_gid
		setresgid(gid, gid, gid)
		# getgrouplist() needs the account *name*; `user` may be a numeric uid.
		setgroups(getgrouplist(pwent.pw_name, gid))

if group is not None:
	if looks_like_number(group):
		gid = int(group)
	else:
		gid = getgrnam(group).gr_gid
	setresgid(gid, gid, gid)
	# NOTE(review): supplementary groups are left as inherited here (e.g. root's
	# when started as root).  setgroups([gid]) would clear them but needs
	# privilege; confirm the intended behaviour.
	#setgroups([gid])

if user is not None:
	setresuid(uid, uid, uid)

# SRS forward rewriter: srs_rewrite(address, alias_domain) returns the SRS form.
srs_rewrite = SRS.new(secret=secret).forward

# Parser for Authentication-Results values, with the DMARC extension loaded.
parse_authres = FeatureContext(authres.dmarc).parse_value

# One policy object backs both header parsing and header folding, so parsing
# and re-serialising follow the same rules.
email_policy = EmailPolicy()
parse_headers = HeaderParser(policy=email_policy).parsestr
fold_header = email_policy.fold

# Result keyword sets used to classify Authentication-Results outcomes.
authres_none_pass = {'none', 'pass'}
authres_none_pass_fail = {'none', 'pass', 'fail'}
authres_error = {'temperror', 'permerror'}
authres_error = {'temperror', 'permerror'}

class DMARC2SRS(Milter.Base):
	"""Milter that applies SRS to senders whose mail would likely fail DMARC.

	While headers stream by, the spf/dkim/dmarc results from Authentication-
	Results headers written by our own `authority` are collected.  At end of
	message these decide whether the envelope sender and/or the From header(s)
	are rewritten to SRS addresses under `default_domain`.

	Per pymilter's documentation, Milter.factory is instantiated once per SMTP
	connection, and a connection may carry several messages.  All per-message
	state is therefore reset in envfrom().
	"""

	# Per-message state.  These class-level values are only the defaults for a
	# fresh instance; envfrom() replaces them per instance for every message.
	envelope_from = None
	header_from = ()          # raw From header values, in order of appearance
	authres_num = 0           # Authentication-Results headers seen so far
	dkim_sigs = ()            # dkim results from our own Authentication-Results
	dkim_headers = 0          # DKIM-Signature headers seen so far
	spf_authres = None        # lower-cased spf result from our Authentication-Results
	spf_authres_index = None  # 1-based count of the header that carried it
	dmarc_authres = None      # lower-cased dmarc result from our Authentication-Results

	def _reset_message_state(self):
		"""Forget everything collected for a previous message on this connection."""
		self.envelope_from = None
		self.header_from = []
		self.authres_num = 0
		self.dkim_sigs = []
		self.dkim_headers = 0
		self.spf_authres = None
		self.spf_authres_index = None
		self.dmarc_authres = None

	def log(self, *msg):
		"""Write a log line to stderr (flushed immediately)."""
		# Milter log() seems broken, just write to stderr and let systemd handle it.
		# return super().log("dmarc2srs: " + msg)
		print(*msg, file = stderr, flush = True)

	def envfrom(self, address, *extra):
		"""Start of a new message: reset per-message state, remember the sender."""
		self._reset_message_state()
		self.envelope_from = address
		return Milter.CONTINUE

	def header(self, name, value, *extra):
		"""Collect Authentication-Results, From and DKIM-Signature information.

		Errors are logged and otherwise ignored; the message always continues.
		"""
		try:
			name = name.lower()
			if name == 'authentication-results':
				# Count every Authentication-Results header (ours or not) so the
				# index stored below identifies the header for chgheader().
				self.authres_num += 1
				ar = parse_authres(value)
				if ar.authserv_id.lower() == authority:
					for res in ar.results:
						method = res.method.lower()
						if method == 'spf':
							# Result keywords are compared case-insensitively.
							self.spf_authres = res.result.lower()
							self.spf_authres_index = self.authres_num
						elif method == 'dkim':
							if not self.dkim_sigs:
								self.dkim_sigs = [res]
							else:
								self.dkim_sigs.append(res)
						elif method == 'dmarc':
							self.dmarc_authres = res.result.lower()
			elif name == 'from':
				if not self.header_from:
					self.header_from = [value]
				else:
					self.header_from.append(value)
			elif name == 'dkim-signature':
				self.dkim_headers += 1

		except Exception as e:
			self.log(str(e))

		return Milter.CONTINUE

	def eom(self):
		"""Decide on and apply the SRS rewrites at end of message.

		The message always continues; errors are logged and abort only the
		remaining rewrite steps.
		"""
		try:
			dmarc_authres = self.dmarc_authres
			if dmarc_authres not in authres_none_pass_fail:
				if dmarc_authres is None:
					self.log("no dmarc Authentication-Results headers found, skipping")
				else:
					self.log("dmarc Authentication-Results header indicates %s, skipping" % (dmarc_authres,))
				return Milter.CONTINUE

			spf_authres = self.spf_authres
			if spf_authres not in authres_none_pass_fail:
				if spf_authres is None:
					self.log("no spf Authentication-Results headers found, skipping")
				else:
					self.log("spf Authentication-Results header indicates %s, skipping" % (spf_authres,))
				return Milter.CONTINUE

			# Always one of 'none', 'pass' or 'fail'.
			dkim_authres = self._dkim_authres()

			self.log("spf=%s dkim=%s dmarc=%s" % (spf_authres, dkim_authres, dmarc_authres))

			if spf_authres in authres_none_pass and dkim_authres in authres_none_pass:
				return Milter.CONTINUE

			if spf_authres == 'fail' or (dmarc_authres != 'none' and dkim_authres == 'fail'):
				if not self._rewrite_envelope_from():
					# Empty sender (<>): leave the whole message alone.
					return Milter.CONTINUE

			if dmarc_authres != 'none' and (spf_authres == 'fail' or dkim_authres == 'fail'):
				self._rewrite_header_from()

		except Exception as e:
			self.log(str(e))

		return Milter.CONTINUE

	def _dkim_authres(self):
		"""Return 'none', 'pass' or 'fail' for DKIM alignment with the From header(s).

		A valid signature by itself is not interesting.  The message counts as
		DKIM-valid only if every From domain (or a parent domain of it) was
		signed by a result=pass signature.  Any dkim result other than 'none'
		defaults to 'fail' until alignment is proven.
		"""
		dkim_authres = 'none'

		# Collect all domains from the result=pass dkim results:
		dkim_domains = set()
		for sig in self.dkim_sigs:
			result = sig.result.lower()
			if result == 'pass':
				properties = {(p.type, p.name): str(p.value) for p in sig.properties}
				header_i = properties.get(('header', 'i'))
				header_d = properties.get(('header', 'd'))
				if header_i is not None:
					dkim_domains.add(header_i.rpartition('@')[2].lower())
				elif header_d is not None:
					dkim_domains.add(header_d.lower())

			if result != 'none':
				# Default in case the alignment check below does not pass:
				dkim_authres = 'fail'

		self.log("dkim_domains:", *dkim_domains)
		if not dkim_domains:
			return dkim_authres

		# Collect all domains from the From header(s):
		from_domains = set()
		for raw_value in self.header_from:
			(_, parsed), = parse_headers("From: " + raw_value).items()
			for group in parsed.groups:
				for address in group.addresses:
					from_domains.add(address.domain.lower())
		self.log("from_domains:", *from_domains)
		if not from_domains:
			return dkim_authres

		for from_domain in from_domains:
			if not self._domain_covered(from_domain, dkim_domains):
				return 'fail'
		return 'pass'

	@staticmethod
	def _domain_covered(domain, signing_domains):
		"""Return True if `domain` or one of its parent domains is in `signing_domains`."""
		while domain not in signing_domains:
			_, separator, domain = domain.partition('.')
			if not separator:
				# No '.' left means no parent domains left to try.
				return False
		return True

	def _rewrite_envelope_from(self):
		"""SRS-rewrite the envelope sender and rewrite our spf result to pass.

		Returns False (changing nothing) for an empty sender (<>), else True.
		"""
		# parse_addr() returns a list of 1 or 2 elements, depending
		# on whether there's a @ in the address.
		parsed = parse_addr(self.envelope_from)
		if len(parsed) < 2:
			# Address without @:
			localpart, = parsed
			if localpart == '':
				return False
			domain = default_domain
		else:
			localpart, domain = parsed

		address = srs_rewrite(localpart + "@" + domain, default_domain)
		self.log("changing envelope sender to: %s" % (address,))
		self.chgfrom(address)

		# Replace the Authentication-Results header that carried our spf result.
		# NOTE(review): this replaces the whole header value, so any dkim/dmarc
		# results in the same header are dropped; also, unlike the From/DKIM
		# rewrites, this index is not adjusted by milter_header_count_base.
		self.chgheader('Authentication-Results', self.spf_authres_index,
			"%s; spf=pass smtp.mailfrom=%s" % (authority, address))
		return True

	def _rewrite_header_from(self):
		"""SRS-rewrite From addresses outside default_domain and drop DKIM signatures."""
		for idx, raw_value in enumerate(self.header_from, start = milter_header_count_base):
			(name, parsed), = parse_headers("From: " + raw_value).items()

			groups = []
			for group in parsed.groups:
				addresses = []
				for address in group.addresses:
					if address.domain.lower() == default_domain:
						addresses.append(address)
					else:
						srs_address = srs_rewrite(address.addr_spec, default_domain)
						addresses.append(Address(display_name = address.display_name, addr_spec = srs_address))
				groups.append(Group(display_name = group.display_name, addresses = addresses))

			new_header = fold_header(name, ", ".join(map(str, groups)))
			# Re-parse the folded header to strip the leading "From:":
			(_, new_value), = parse_headers(new_header).items()
			self.log("changing header sender to: %s" % (new_value,))
			self.chgheader(name, idx, new_value)

		# Changing the From invalidates all signatures. Best to just remove them all.
		# Delete from the highest index down: if the MTA renumbers the remaining
		# headers after each deletion, deleting upwards would skip headers.
		first = milter_header_count_base
		for idx in reversed(range(first, first + self.dkim_headers)):
			self.chgheader('DKIM-Signature', idx, None)

# Register the handler class, declare which modifications it may request
# (envelope-sender and header changes), then serve on the configured socket.
Milter.factory = DMARC2SRS
Milter.set_flags(Milter.CHGFROM | Milter.CHGHDRS)
Milter.runmilter('dmarc2srs', socket)
