#! /usr/bin/python3

# $Id: stripar 48288 2019-08-01 12:56:05Z wsl $
# $URL: https://svn.uvt.nl/its-id/trunk/sources/dmarc2srs/stripar $

# Remove Authentication-Results headers, either those matching a list
# of authorities or simply all of them.

from email.parser import HeaderParser
from grp import getgrnam
from os import getgrouplist, setgroups, setresgid, setresuid, umask
from pwd import getpwnam, getpwuid
from re import compile as regcomp
from socket import getfqdn
from sys import argv, stderr

import Milter

# Usage: stripar <configuration file>
if len(argv) != 2:
	raise Exception("usage: %s <configuration file>" % (argv[0],))

# The configuration file is Python source run through exec(); every name it
# assigns at top level becomes a key of `config`. It must be trusted: it runs
# with whatever privileges the process starts with, before they are dropped
# further down.
config = {}

with open(argv[1]) as fh:
	# Compiling with the real file name makes syntax errors and tracebacks
	# from the configuration file point at that file instead of "<string>".
	exec(compile(fh.read(), argv[1], 'exec'), {}, config)

# Default umask 0o007: newly created files get no permissions for "other".
umask(config.get('umask', 0o007))

# Where to listen for the MTA; handed unchanged to Milter.runmilter().
socket = config['socket']

# The authorities (authserv-ids) whose Authentication-Results headers are
# removed. When not configured this is the host's fully qualified name; an
# explicit None means every Authentication-Results header is removed. A bare
# string counts as a single authority. Comparison is case-insensitive.
authorities = config.get('authorities', getfqdn())
if isinstance(authorities, str):
	authorities = (authorities,)
if authorities is not None:
	authorities = frozenset([authority.lower() for authority in authorities])

# Matches a string consisting solely of ASCII digits (a numeric uid/gid).
_all_digits = regcomp('[0-9]+')
looks_like_number = _all_digits.fullmatch

# Optional account and group to switch to (by name or numeric id).
user = config.get('user')
group = config.get('group')

# Index passed to chgheader() for the first Authentication-Results header.
# The milter documentation counts from 1, but counting from 0 is what was
# observed to work (possibly a quirk of Postfix's milter implementation),
# hence the default of 0; it can be overridden in the configuration.
milter_header_count_base = config.get('milter_header_count_base', 0)

# Drop privileges to the configured user and/or group. Each may be given
# as a name or as a numeric id (an int, or a string of digits).
#
# When a user is given, the supplementary group list is replaced too (like
# initgroups(3)); otherwise the process would keep the supplementary groups
# it was started with (e.g. root's) after switching uid. This must happen
# before setresuid(), while the process is still permitted to do it.
if user is not None:
	if isinstance(user, int) or looks_like_number(user):
		uid = int(user)
		try:
			pwent = getpwuid(uid)
		except KeyError:
			# A uid without a passwd entry is only usable when the
			# group is configured explicitly.
			if group is None:
				raise
			pwent = None
	else:
		pwent = getpwnam(user)
		uid = pwent.pw_uid

if group is not None:
	if isinstance(group, int) or looks_like_number(group):
		gid = int(group)
	else:
		gid = getgrnam(group).gr_gid
elif user is not None:
	# No explicit group: fall back to the user's primary group.
	gid = pwent.pw_gid

if user is not None:
	if pwent is None:
		setgroups([gid])
	else:
		# getgrouplist() needs the account *name*, even if the user was
		# configured numerically.
		setgroups(getgrouplist(pwent.pw_name, gid))

if user is not None or group is not None:
	setresgid(gid, gid, gid)

# Switch uid last: after this the process can no longer change its groups.
if user is not None:
	setresuid(uid, uid, uid)

# Authentication-Results parsing (RFC 8601): only the authserv-id is needed.
# A comment is a parenthesised run without inner parentheses; nested
# comments are removed innermost-first by applying the pattern repeatedly.
_authres_comment = regcomp(r'\([^()]*\)')
# The authserv-id: the first quoted-string or token of the header value.
_authres_authserv_id = regcomp(r'\s*(?:"((?:[^"\\]|\\.)*)"|([^\s;"()]+))')
# A backslash-escaped character inside a quoted-string.
_authres_quoted_pair = regcomp(r'\\(.)')


def parse_authserv_id(value):
	"""Return the authserv-id of an Authentication-Results header value.

	The authserv-id is the first token (or quoted-string) of the value,
	preceding the optional version number and the first ';'. Comments in
	parentheses are skipped.

	Raises ValueError when no authserv-id can be found.
	"""
	while True:
		stripped = _authres_comment.sub(' ', value)
		if stripped == value:
			break
		value = stripped
	match = _authres_authserv_id.match(value)
	if match is None:
		raise ValueError("no authserv-id in Authentication-Results header: %r" % (value,))
	quoted, token = match.groups()
	if quoted is not None:
		return _authres_quoted_pair.sub(r'\1', quoted)
	return token


class StripAR(Milter.Base):
	"""Milter that deletes Authentication-Results headers.

	A header is deleted when its authserv-id matches one of the configured
	authorities (case-insensitively), or unconditionally when authorities
	is None.
	"""

	# Index the next Authentication-Results header will have, as passed to
	# chgheader(); reset per message in envfrom().
	authres_headers = milter_header_count_base
	# Indices (ascending) of the Authentication-Results headers to delete.
	authres_strip_headers = ()

	def log(self, msg):
		# Milter log() seems broken, just write to stderr and let
		# systemd handle it.
		print(msg, file = stderr, flush = True)

	def envfrom(self, mailfrom, *params):
		"""Reset the per-message state at the start of each message.

		Milter.factory creates one instance per SMTP connection, and one
		connection may carry several messages; without this reset the
		header counter and strip list would leak into the next message.
		"""
		self.authres_headers = milter_header_count_base
		self.authres_strip_headers = []
		return Milter.CONTINUE

	def header(self, name, value, *extra):
		"""Record the index of every Authentication-Results header to strip.

		The counter advances for every Authentication-Results header, even
		ones that are kept or cannot be parsed, so the recorded indices
		follow the MTA's numbering of these headers. Errors are logged and
		never stop the message.
		"""
		try:
			if name.lower() == 'authentication-results':
				index = self.authres_headers
				self.authres_headers = index + 1
				# With authorities None, strip without needing to parse.
				if authorities is None or parse_authserv_id(value).lower() in authorities:
					if self.authres_strip_headers:
						self.authres_strip_headers.append(index)
					else:
						# Never mutate the shared class-level default.
						self.authres_strip_headers = [index]

		except Exception as e:
			self.log(str(e))

		return Milter.CONTINUE

	def eom(self):
		"""Delete the recorded Authentication-Results headers.

		Deletion runs from the highest index down, so earlier indices stay
		valid whether or not the MTA renumbers headers after a deletion. A
		failure on one header is logged and does not skip the others.
		"""
		for index in reversed(self.authres_strip_headers):
			try:
				self.chgheader('Authentication-Results', index, None)
			except Exception as e:
				self.log(str(e))

		return Milter.CONTINUE

# Start the milter. Deleting headers via chgheader() requires the CHGHDRS
# action flag; ADDHDRS only permits addheader(), which is not used here.
Milter.factory = StripAR
Milter.set_flags(Milter.CHGHDRS)
Milter.runmilter('stripar', socket)
