#! /usr/bin/python2.6
# -*- coding: utf-8 -*-

# <pwloggingr.py  Receive e-mail logging and filter.>
# Copyright (C) <2012> <yasuyosi kimura>
#
# This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 3 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with this program. If not, see <http://www.gnu.org/licenses/>.


### Revision 0.1  2012/01/08 21:05:00  Change name & license 
### Revision 0.0  2011/12/31 14:58:00  Test version.
### Vre 0.0


import logging
import logging.handlers
import sys
import time
from email.Header import decode_header
from email.errors import HeaderParseError

import Milter
from Milter.utils import parseaddr, parse_addr

# ---- Milter socket ---------------------------------------------------------
# Socket spec and timeout handed to Milter.runmilter() in main().
socketname = "inet:1025@localhost"
sockettimeout = 600

# Tuple of (domain, len(domain)) pairs; populated by main() from sys.argv[1].
my_domainlist = ()

# ---- Logging ---------------------------------------------------------------
# The logger's own level (log_level) decides what gets recorded; the file
# handler is left at DEBUG so it never filters anything by itself.
log_filename = "/var/log/pwmail/pwloggingr.log"
log_level = logging.INFO

log_fm = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

# Rotate once the file reaches 1024000 bytes, keeping 10 old files.
log_fh = logging.handlers.RotatingFileHandler(
    log_filename, maxBytes=1024000, backupCount=10)
log_fh.setFormatter(log_fm)
log_fh.setLevel(logging.DEBUG)

my_logger = logging.getLogger("pwloggingr")
my_logger.setLevel(log_level)
my_logger.addHandler(log_fh)


def parse_header(val):
  """Decode a header value that may contain RFC 2047 encoded words.

  Each chunk returned by decode_header() is turned into unicode:
    * chunks with a charset label are decoded with that charset, falling
      back to the default codec when the label is an unknown codec name;
    * unlabelled byte chunks are tried as cp932 first, then utf8.
  The joined text is re-encoded with the first of us-ascii, iso-8859-1 or
  utf8 that can represent it, and that byte string is returned.

  Decoding is best-effort. If a chunk cannot be decoded, a codec is unknown,
  or an encoded word is malformed (HeaderParseError), the original *val* is
  returned unchanged.
  """
  try:
    h = decode_header(val)
    if not h:
      return val
    u = []
    for s, enc in h:
      if enc:
        try:
          u.append(unicode(s, enc))
        except LookupError:
          # Unknown charset label: fall back to the default codec.
          u.append(unicode(s))
      elif isinstance(s, unicode):
        u.append(s)
      else:
        # Unlabelled bytes: try Japanese Windows (cp932) first, then UTF-8.
        # If neither fits, let UnicodeDecodeError reach the handler below so
        # the raw header is returned instead of silently dropping this chunk.
        try:
          u.append(unicode(s, 'cp932'))
        except UnicodeDecodeError:
          u.append(unicode(s, 'utf8'))
    u = ''.join(u)
    for enc in ('us-ascii', 'iso-8859-1', 'utf8'):
      try:
        return u.encode(enc)
      except UnicodeError:
        continue
  except (UnicodeDecodeError, LookupError, HeaderParseError):
    # Best-effort decoding: fall through and hand back the raw value.
    pass
  return val


class myMilter(Milter.Base):
  """Milter that logs each SMTP transaction and enforces a few local policies.

  main() sets Milter.factory = myMilter, so one instance serves one SMTP
  connection.  Rejections issued:

    connect -- no host address was supplied for the connection.
    data    -- the HELO name or the envelope sender is one of our own
               domains (our own mail is expected via the submission port).
    eoh     -- the message has no Date: or no Message-ID: header.

  eom() adds an X-PWfrom header carrying the envelope sender.
  """

  def __init__(self):  # A new instance with each new connection.
    self.id = Milter.uniqueID()  # Integer incremented with each call.
    # Give every attribute the callbacks read a default value. Callbacks that
    # run after an early reject (e.g. close() after connect() returned REJECT)
    # then cannot fail with AttributeError.
    self.IP = None      # first element of connect()'s hostaddr
    self.Cname = None   # lower-cased name from the reverse IP lookup
    self.Hname = None   # HELO/EHLO name as sent by the client
    self.Hmyd = None    # True when the HELO name is one of our domains
    self.Sabort = None  # True once abort() has been called
    self.Fname = None   # envelope sender as passed to envfrom()
    self.Fmyd = False   # True when the envelope sender is in our domains
    self.Rname = []     # envelope recipients of the current message
    self._reset_headers()

  def _reset_headers(self):
    """Clear the per-message header state filled by header() and read by eoh()."""
    self.HDate = None    # raw Date: header value
    self.Subject = None  # decoded Subject: header value
    self.HMid = None     # raw Message-ID: header value
    self.HList = False   # True once any List-* header has been seen


#  @Milter.noreply
  def connect(self, IPname, family, hostaddr):
    """Record the client address; reject connections that have none."""
    self.log("connect from %s at %s" % (IPname, hostaddr) )

    if not hostaddr:
      self.log_critical("REJECT: connect attacks")
      self.setreply('550','5.7.1', 'Banned for connect attacks')
      return Milter.REJECT

    self.IP = hostaddr[0]
    self.Cname = IPname.lower()  # Name from a reverse IP lookup
    self.Hname = None
    self.Hmyd = None
    self.Sabort = None
    return Milter.CONTINUE


#  @Milter.noreply
  def hello(self, heloname):
    """Remember the HELO name and whether it is one of our own domains."""
    self.log("HELO",heloname)

    self.Hname = heloname
    self.Hmyd = self.my_domain_check(self.Hname.lower())
    # No reject here. Postfix by default defers errors until the RCPT
    # command, so this check is enforced in data() instead.
    return Milter.CONTINUE


#  @Milter.noreply
  def envfrom(self, mailfrom, *params):
    """Start a new message: record the envelope sender and check its domain."""
    self.log("mail from:", mailfrom, *params)

    self.Fname = mailfrom
    self.Fmyd = False
    self.Rname = []  # list of recipients
    # A new transaction begins here. Drop header state left over from an
    # earlier message on the same connection.
    self._reset_headers()

    if self.Fname == '<>':
      return Milter.CONTINUE

    Fnad = self.Fname.lower()
    if Fnad.endswith('>'):
      Fnad = Fnad[:-1]

    self.Fmyd = self.my_domain_check(Fnad)
    # No reject here. Postfix by default defers errors until the RCPT
    # command, so this check is enforced in data() instead.
    return Milter.CONTINUE


#  @Milter.noreply
  def envrcpt(self, recipient, *params):
    """Log and collect each envelope recipient."""
    self.log("rcpt to:", recipient, ":", *params)
    self.Rname.append(recipient)

    return Milter.CONTINUE


#  @Milter.noreply
  def data(self):
    """Enforce the own-domain HELO/sender policy, then reset header state."""
    self.log("data")

    ## Mail from our own server is expected via the submission port, so our
    ## own HELO name here is a policy breach. To record the headers of such
    ## mail instead, comment this check out and enable the one in eoh().
    if self.Hmyd:
      self.log_critical('504','(%s:%s): Helo command rejected: Breach of Local Policy.' % (self.Hname,self.IP))
      self.setreply('504','5.5.2','<%s>: Helo command rejected: Breach of Local Policy.' % (self.Hname))
      return Milter.REJECT

    ## Same policy for an envelope sender in one of our own domains. To
    ## record headers instead, comment this check out and enable the one in
    ## eoh().
    if self.Fmyd:
      self.log_critical('504','(%s:%s) %s: Sender address rejected: Breach of Local Policy.' % (self.Hname,self.IP,self.Fname))
      self.setreply('504','5.5.2','%s: Sender address rejected: Breach of Local Policy.' % (self.Fname))
      return Milter.REJECT

    self._reset_headers()
    return Milter.CONTINUE


#  @Milter.noreply
  def header(self, name, hval):
    """Log From:/Subject:/Message-ID: and remember what eoh() checks."""
    ### self.log_debug("header:%s: %s" % (name,hval))

    nbuf = name.lower()
    if nbuf == "from":
      ms = []
      for ad in hval.split(','):
        ma = parseaddr(ad)
        mn = parse_header(ma[0])
        ms.append(mn + ' <' + ma[1] + '>')
      mf = ",".join(ms)
      self.log_debug("Header-From-B:", hval)
      self.log("Header-From:", mf)
    elif nbuf == "date":
      self.HDate = hval
    elif nbuf == "subject":
      self.Subject = parse_header(hval)
      self.log_debug("Subject-B:", hval)
      self.log("Subject:", self.Subject)
    elif nbuf == "message-id":
      self.log("Message-ID:", hval)
      self.HMid = hval
    elif nbuf.startswith("list-"):
      self.HList = True

    return Milter.CONTINUE


#  @Milter.noreply
  def eoh(self):
    """At end of headers, reject messages lacking Date: or Message-ID:."""
    self.log("eoh")

    if self.HList:
      self.log("Mail List")

    ## Mail from our own server is expected via the submission port. To
    ## record the headers of such mail before rejecting it, uncomment these
    ## two checks and comment out the matching ones in data().
    #if self.Hmyd:
    #  self.log_critical('504','(%s:%s): Helo command rejected: Breach of Local Policy.' % (self.Hname,self.IP))
    #  self.setreply('504','5.5.2','<%s>: Helo command rejected: Breach of Local Policy.' % (self.Hname))
    #  return Milter.REJECT
    #if self.Fmyd:
    #  self.log_critical('504','(%s:%s) %s: Sender address rejected: Breach of Local Policy.' % (self.Hname,self.IP,self.Fname))
    #  self.setreply('504','5.5.2','%s: Sender address rejected: Breach of Local Policy.' % (self.Fname))
    #  return Milter.REJECT

    ###--------------------Abort
    # A missing Date: suggests the message may not come from a regular mail server.
    if not self.HDate:
      self.setreply('550','5.7.1','Breach of Header-Date Policy.')
      self.log_critical('550','(%s:%s): Breach of Header-Date Policy.' % (self.Hname,self.IP))
      return Milter.REJECT

    ###--------------------Abort
    # A missing Message-ID: likewise suggests a non-regular sender.
    if not self.HMid:
      self.setreply('550','5.7.1','Breach of Message-ID Local Policy.')
      self.log_critical('550','(%s:%s): Breach of Message-ID Local Policy.' % (self.Hname,self.IP))
      return Milter.REJECT

    return Milter.CONTINUE


  def eom(self):
    """Add the envelope sender as an X-PWfrom header."""
    self.log("eom")
    # Exposing the envelope sender lets the mail client do spam handling.
    self.addheader('X-PWfrom',self.Fname)
    return Milter.CONTINUE

  def abort(self):
    """Remember that the transaction was aborted, for close() to report."""
    self.log_debug("abort")
    self.Sabort = True
    return Milter.CONTINUE

  def close(self):
    """Log connection close; after an abort, point to the mail server log."""
    if self.Sabort:
      self.log_warning("sever abort: mail server log read")

    self.log("close")
    return Milter.CONTINUE


  ## === Support Functions ===

  def my_domain_check(self,td):
    """Return True if *td* is, or lies under, a domain in my_domainlist.

    my_domainlist holds (domain, len(domain)) pairs. *td* matches when it
    equals a domain exactly, or ends with it and the character just before
    the domain is '.' (sub-domain) or '@' (address in that domain). The
    comparison is case-sensitive; callers pass lower-cased values.
    """
    tl = len(td)
    for d, l in my_domainlist:
      if tl == l:
        if td == d:
          return True
      elif tl > l:
        if td.endswith(d):
          pw = td[tl - l - 1]
          if (pw == '.') or (pw == '@'):
            return True
    return False

  def _emit(self, level, msg):
    """Log the items of *msg*, str()-joined by spaces, prefixed with [id]."""
    my_logger.log(level, '[%d] %s', self.id, ' '.join([str(m) for m in msg]))

  def log_debug(self, *msg):
    self._emit(logging.DEBUG, msg)

  def log(self,*msg):
    self._emit(logging.INFO, msg)

  def log_warning(self, *msg):
    self._emit(logging.WARNING, msg)

  def log_error(self, *msg):
    self._emit(logging.ERROR, msg)

  def log_critical(self, *msg):
    self._emit(logging.CRITICAL, msg)

## ===
    
def main():
  """Parse the domain list from sys.argv[1], register myMilter and run it.

  sys.argv[1] is a comma-separated list of domains that my_domain_check()
  treats as our own. Entries are stripped and lower-cased, because
  my_domain_check() is always called with lower-cased names. Empty entries
  are skipped. If the argument is missing, a usage line is written to the
  log and to stderr and the process exits with status 2.
  """
  global my_domainlist

  my_logger.info("pwloggingr startup")

  if len(sys.argv) < 2:
    usage = "usage: %s domain[,domain...]" % sys.argv[0]
    my_logger.error(usage)
    sys.stderr.write(usage + "\n")
    sys.exit(2)

  domains = [v.strip().lower() for v in sys.argv[1].split(',')]
  # Store (domain, length) pairs; my_domain_check() uses the precomputed length.
  my_domainlist = tuple((d, len(d)) for d in domains if d)

  my_logger.info("mydomain:" + str(my_domainlist))

  # Register to have the Milter factory create instances of your class:
  Milter.factory = myMilter
  flags = Milter.ADDHDRS
  Milter.set_flags(flags)       # tell Sendmail which features we use

  Milter.runmilter("pwloggingr",socketname,sockettimeout)
  my_logger.info("pwloggingr shutdown")


# Run the milter only when executed as a script, not when imported.
if __name__ == "__main__":
  main()