#!/usr/bin/env python3

import binascii
import logging
import os
import sys
import time
import uuid

from structure import Structure
from kmsDB2Dict import kmsDB2Dict
import kmsPidGenerator
import filetimes
from formatText import justify, shell_message, byterize

# sqlite3 is optional
try:
        import sqlite3
except ImportError:
        pass

# Logger used by every handler in this module. NOTE(review): it is fetched by
# the name 'root' rather than __name__, presumably so output goes through the
# handlers configured by the server entry point — confirm against the caller.
logger = logging.getLogger('root')

class UUID(Structure):
        """A 16-byte GUID field as embedded in the KMS request/response structures."""
        commonHdr = ()
        structure = (
                ('raw', '16s'),
        )

        def get(self):
                """Return this field's bytes as a ``uuid.UUID``.

                The bytes are interpreted with ``bytes_le``: the first three GUID
                fields are little-endian (Microsoft GUID byte order). Assumes
                ``str(self)`` yields the packed bytes as a latin-1 str (a
                Structure convention not visible here — verify).
                """
                packed = str(self).encode('latin-1')
                return uuid.UUID(bytes_le = packed)

class kmsBase:
        """Shared KMS request handling for the protocol-version handlers.

        Defines the request/response wire structures and lookup tables for
        license states and error codes. It also holds the server logic that
        turns a decoded request into a response, optionally recording each
        client in an SQLite database.
        """
        def __init__(self, data, config):
                # data: the raw request handed over by generateKmsResponseData.
                # config: server settings; keys read in this class are 'sqlite',
                # 'dbSupport', 'CurrentClientCount', 'epid', 'lcid',
                # 'VLActivationInterval' and 'VLRenewalInterval'.
                self.data = data
                self.config = config

        class kmsRequestStruct(Structure):
                """Decoded body of a KMS activation request."""
                commonHdr = ()
                structure = (
                        ('versionMinor',            '<H'),
                        ('versionMajor',            '<H'),
                        ('isClientVm',              '<I'),
                        ('licenseStatus',           '<I'),
                        ('graceTime',               '<I'),
                        ('applicationId',           ':', UUID),
                        ('skuId',                   ':', UUID),
                        ('kmsCountedId' ,           ':', UUID),
                        ('clientMachineId',         ':', UUID),
                        ('requiredClientCount',     '<I'),
                        ('requestTime',             '<Q'),
                        ('previousClientMachineId', ':', UUID),
                        ('machineName',             'u'),
                        # Padding after machineName, sized as 126 - len(machineName).
                        ('_mnPad',                  '_-mnPad', '126-len(machineName)'),
                        ('mnPad',                   ':'),
                )

                def getMachineName(self):
                        """Return the client machine name decoded from UTF-16LE."""
                        return self['machineName'].decode('utf-16le')

                def getLicenseStatus(self):
                        """Return a readable name for the client's license status code.

                        Codes missing from kmsBase.licenseStates map to "Unknown"
                        instead of raising KeyError.
                        """
                        return kmsBase.licenseStates.get(self['licenseStatus'], "Unknown")

        class kmsResponseStruct(Structure):
                """Body of the KMS response sent back to the client."""
                commonHdr = ()
                structure = (
                        ('versionMinor',         '<H'),
                        ('versionMajor',         '<H'),
                        # Byte length of kmsEpid plus 2 (presumably a UTF-16 terminator — verify).
                        ('epidLen',              '<I=len(kmsEpid)+2'),
                        ('kmsEpid',              'u'),
                        ('clientMachineId',      ':', UUID),
                        ('responseTime',         '<Q'),
                        ('currentClientCount',   '<I'),
                        ('vLActivationInterval', '<I'),
                        ('vLRenewalInterval',    '<I'),
                )

        class GenericRequestHeader(Structure):
                """Leading fields of a request; used to read versionMajor before dispatch."""
                commonHdr = ()
                structure = (
                        ('bodyLength1',  '<I'),
                        ('bodyLength2',  '<I'),
                        ('versionMinor', '<H'),
                        ('versionMajor', '<H'),
                        ('remainder',    '_'),
                )

        # License status code -> display name (used by kmsRequestStruct.getLicenseStatus).
        licenseStates = {
                0 : "Unlicensed",
                1 : "Activated",
                2 : "Grace Period",
                3 : "Out-of-Tolerance Grace Period",
                4 : "Non-Genuine Grace Period",
                5 : "Notifications Mode",
                6 : "Extended Grace Period",
        }

        # Symbolic name -> license status code. Not referenced in this class;
        # presumably used by the version-specific handlers.
        licenseStatesEnum = {
                'unlicensed' : 0,
                'licensed' : 1,
                'oobGrace' : 2,
                'ootGrace' : 3,
                'nonGenuineGrace' : 4,
                'notification' : 5,
                'extendedGrace' : 6
        }

        # Software-licensing HRESULT codes by name. Not referenced in this class;
        # presumably used by the version-specific handlers.
        errorCodes = {
                'SL_E_VL_NOT_WINDOWS_SLP' : 0xC004F035,
                'SL_E_VL_NOT_ENOUGH_COUNT' : 0xC004F038,
                'SL_E_VL_BINDING_SERVICE_NOT_ENABLED' : 0xC004F039,
                'SL_E_VL_INFO_PRODUCT_USER_RIGHT' : 0x4004F040,
                'SL_I_VL_OOB_NO_BINDING_SERVER_REGISTRATION' : 0x4004F041,
                'SL_E_VL_KEY_MANAGEMENT_SERVICE_ID_MISMATCH' : 0xC004F042,
                'SL_E_VL_MACHINE_NOT_BOUND' : 0xC004F056
        }

        def getPadding(self, bodyLength):
                """Return 4 plus the number of bytes that pad bodyLength to a multiple of 4."""
                ## https://forums.mydigitallife.info/threads/71213-Source-C-KMS-Server-from-Microsoft-Toolkit?p=1277542&viewfull=1#post1277542
                # ((~n & 3) + 1) & 3 == (-n) mod 4.
                return 4 + (((~bodyLength & 3) + 1) & 3)

        def serverLogic(self, kmsRequest):
                """Handle a decoded KMS request and return the response structure.

                The method logs the request and resolves the application/SKU
                display names and the client count to report. When both
                config['sqlite'] and config['dbSupport'] are set, it records the
                client in the SQLite database. It then delegates to
                createKmsResponse.

                :param kmsRequest: the request structure (passed through byterize()).
                :return: the kmsResponseStruct built by createKmsResponse.
                """
                useDatabase = self.config['sqlite'] and self.config['dbSupport']
                if useDatabase:
                        self.dbName = 'clients.db'
                        self._initDatabase()

                shell_message(nshell = 15)
                kmsRequest = byterize(kmsRequest)
                logger.debug("KMS Request Bytes: \n%s\n" % justify(binascii.b2a_hex(str(kmsRequest).encode('latin-1')).decode('utf-8')))
                logger.debug("KMS Request: \n%s\n" % justify(kmsRequest.dump(print_to_stdout = False)))

                clientMachineId = kmsRequest['clientMachineId'].get()
                applicationId = kmsRequest['applicationId'].get()
                skuId = kmsRequest['skuId'].get()
                requestDatetime = filetimes.filetime_to_dt(kmsRequest['requestTime'])
                local_dt = self._localizeRequestTime(requestDatetime)

                appName, skuName, currentClientCount = self._lookupProduct(applicationId, skuId)

                infoDict = {
                        "machineName" : kmsRequest.getMachineName(),
                        "clientMachineId" : str(clientMachineId),
                        # str(): ids not found in the KMS DB are still uuid.UUID objects,
                        # which sqlite3 cannot bind as query parameters.
                        "appId" : str(appName),
                        "skuId" : str(skuName),
                        "licenseStatus" : kmsRequest.getLicenseStatus(),
                        "requestTime" : int(time.time()),
                        "kmsEpid" : None
                }

                logger.info("Machine Name: %s" % infoDict["machineName"])
                logger.info("Client Machine ID: %s" % infoDict["clientMachineId"])
                logger.info("Application ID: %s" % infoDict["appId"])
                logger.info("SKU ID: %s" % infoDict["skuId"])
                logger.info("License Status: %s" % infoDict["licenseStatus"])
                logger.info("Request Time: %s" % local_dt.strftime('%Y-%m-%d %H:%M:%S %Z (UTC%z)'))

                if useDatabase:
                        self._recordClient(infoDict)

                return self.createKmsResponse(kmsRequest, currentClientCount)

        def _initDatabase(self):
                """Create the clients table in self.dbName if it does not exist yet; exit on SQLite error."""
                con = None
                try:
                        con = sqlite3.connect(self.dbName)
                        # IF NOT EXISTS also repairs an existing file that lacks the table.
                        con.execute("CREATE TABLE IF NOT EXISTS clients(clientMachineId TEXT, machineName TEXT, "
                                    "applicationId TEXT, skuId TEXT, licenseStatus TEXT, lastRequestTime INTEGER, "
                                    "kmsEpid TEXT, requestCount INTEGER)")
                except sqlite3.Error as e:
                        logger.error("Error %s:" % e.args[0])
                        sys.exit(1)
                finally:
                        if con is not None:
                                con.commit()
                                con.close()

        @staticmethod
        def _localizeRequestTime(requestDatetime):
                """Convert the naive request datetime to the local zone given by "tzlocal".

                Windows FILETIME values count from 1601-01-01 UTC, so the naive
                datetime is treated as UTC and converted with astimezone(). This
                works for both pytz and zoneinfo zones. If tzlocal is missing or
                the zone cannot be determined, the datetime is returned unchanged.
                """
                from datetime import timezone
                try:
                        from tzlocal import get_localzone
                except ImportError:
                        logger.warning('Module "tzlocal" not available ! Request time not localized.')
                        return requestDatetime
                try:
                        tz = get_localzone()
                except (KeyError, ValueError):
                        # pytz's UnknownTimeZoneError and zoneinfo's ZoneInfoNotFoundError
                        # both subclass KeyError.
                        logger.warning('Unknown time zone ! Request time not localized.')
                        return requestDatetime
                # NOTE(review): assumes filetimes.filetime_to_dt returns a naive UTC datetime.
                return requestDatetime.replace(tzinfo = timezone.utc).astimezone(tz)

        def _lookupProduct(self, applicationId, skuId):
                """Resolve display names and the client count to report for a request.

                The method scans the application entries (index 2 of
                kmsDB2Dict()) for the given application and SKU ids. The
                activation threshold is the NCountPolicy of the KMS item listing
                the SKU; it is 25 when the SKU or its policy is absent. The
                reported count is the larger of config["CurrentClientCount"] and
                threshold + 1.

                :return: (appName, skuName, currentClientCount); a name stays the
                         passed-in uuid.UUID when no display name is found.
                """
                defaultThreshold = 25
                appName, skuName = applicationId, skuId
                threshold = defaultThreshold

                for appitem in kmsDB2Dict()[2]:
                        for kmsitem in appitem['KmsItems']:
                                for skuitem in kmsitem['SkuItems']:
                                        try:
                                                isMatch = uuid.UUID(skuitem['Id']) == skuId
                                        except (KeyError, ValueError):
                                                # Skip malformed SKU entries instead of failing the request.
                                                continue
                                        if isMatch:
                                                skuName = skuitem['DisplayName']
                                                # The threshold belongs to the KMS item that lists this SKU.
                                                try:
                                                        threshold = int(kmsitem['NCountPolicy'])
                                                except KeyError:
                                                        threshold = defaultThreshold
                                                break
                        if uuid.UUID(appitem['Id']) == applicationId:
                                appName = appitem['DisplayName']

                configuredCount = self.config["CurrentClientCount"]
                currentClientCount = threshold + 1 if configuredCount <= threshold else configuredCount
                return appName, skuName, currentClientCount

        def _recordClient(self, infoDict):
                """Insert or refresh this client's row in the clients table.

                A new client is inserted with requestCount 1. An existing row has
                its descriptive columns and lastRequestTime overwritten and its
                requestCount incremented. Errors in the insert/update are logged.
                Errors while connecting or looking up the row exit the process.
                """
                con = None
                try:
                        con = sqlite3.connect(self.dbName)
                        cur = con.cursor()
                        cur.execute("SELECT 1 FROM clients WHERE clientMachineId=:clientMachineId;", infoDict)
                        try:
                                if cur.fetchone() is None:
                                        cur.execute("INSERT INTO clients (clientMachineId, machineName, applicationId, "
                                                    "skuId, licenseStatus, lastRequestTime, requestCount) VALUES (:clientMachineId, "
                                                    ":machineName, :appId, :skuId, :licenseStatus, :requestTime, 1);", infoDict)
                                else:
                                        cur.execute("UPDATE clients SET machineName=:machineName, applicationId=:appId, "
                                                    "skuId=:skuId, licenseStatus=:licenseStatus, lastRequestTime=:requestTime, "
                                                    "requestCount=requestCount+1 WHERE clientMachineId=:clientMachineId;", infoDict)
                        except sqlite3.Error as e:
                                logger.error("Error %s:" % e.args[0])
                except sqlite3.Error as e:
                        logger.error("Error %s:" % e.args[0])
                        sys.exit(1)
                finally:
                        if con is not None:
                                con.commit()
                                con.close()

        def createKmsResponse(self, kmsRequest, currentClientCount):
                """Build the kmsResponseStruct answering kmsRequest.

                The ePID is config["epid"] if set; otherwise it comes from
                kmsPidGenerator.epidGenerator. With the database enabled, an
                ePID already stored for this client is reused. If none is
                stored, the new one is saved to the client's row (a no-op when
                the client has no row).

                :param kmsRequest: the decoded request structure.
                :param currentClientCount: client count to report.
                :return: the populated kmsResponseStruct.
                """
                response = self.kmsResponseStruct()
                response['versionMinor'] = kmsRequest['versionMinor']
                response['versionMajor'] = kmsRequest['versionMajor']

                if not self.config["epid"]:
                        response["kmsEpid"] = kmsPidGenerator.epidGenerator(kmsRequest['kmsCountedId'].get(), kmsRequest['versionMajor'],
                                                                            self.config["lcid"]).encode('utf-16le')
                else:
                        response["kmsEpid"] = self.config["epid"].encode('utf-16le')

                response['clientMachineId'] = kmsRequest['clientMachineId']
                response['responseTime'] = kmsRequest['requestTime']
                response['currentClientCount'] = currentClientCount
                response['vLActivationInterval'] = self.config["VLActivationInterval"]
                response['vLRenewalInterval'] = self.config["VLRenewalInterval"]

                if self.config['sqlite'] and self.config['dbSupport']:
                        clientId = str(kmsRequest['clientMachineId'].get())
                        con = None
                        try:
                                con = sqlite3.connect(self.dbName)
                                cur = con.cursor()
                                cur.execute("SELECT kmsEpid FROM clients WHERE clientMachineId=?;", (clientId,))
                                try:
                                        row = cur.fetchone()
                                        # row is None when the client was never stored (e.g. its insert failed).
                                        if row is not None and row[0]:
                                                response["kmsEpid"] = row[0].encode('utf-16le')
                                        else:
                                                cur.execute("UPDATE clients SET kmsEpid=? WHERE clientMachineId=?;",
                                                            (response["kmsEpid"].decode('utf-16le'), clientId))
                                except sqlite3.Error as e:
                                        logger.error("Error %s:" % e.args[0])
                        except sqlite3.Error as e:
                                logger.error("Error %s:" % e.args[0])
                                sys.exit(1)
                        finally:
                                if con is not None:
                                        con.commit()
                                        con.close()

                logger.info("Server ePID: %s" % response["kmsEpid"].decode('utf-16le'))

                return response


import kmsRequestV4, kmsRequestV5, kmsRequestV6, kmsRequestUnknown

def generateKmsResponseData(data, config):
        """Dispatch a raw KMS request to the handler for its protocol version.

        The major version is read from the generic request header. Versions 4,
        5 and 6 each have a dedicated handler; any other version falls back to
        kmsRequestUnknown.

        :param data: raw request data.
        :param config: server configuration passed through to the handler.
        :return: whatever the chosen handler's executeRequestLogic() returns.
        """
        version = kmsBase.GenericRequestHeader(data)['versionMajor']
        currentDate = time.strftime("%a %b %d %H:%M:%S %Y")

        knownHandlers = {
                4 : kmsRequestV4.kmsRequestV4,
                5 : kmsRequestV5.kmsRequestV5,
                6 : kmsRequestV6.kmsRequestV6,
        }
        handlerClass = knownHandlers.get(version)

        if handlerClass is not None:
                logger.info("Received V%d request on %s." % (version, currentDate))
        else:
                logger.info("Unhandled KMS version V%d." % version)
                handlerClass = kmsRequestUnknown.kmsRequestUnknown

        return handlerClass(data, config).executeRequestLogic()