#!/dcs/bin/python

# Selects the cipher that protects the three replacement keys server()
# sends to client() after mutual authentication.  Both ends branch on this
# value, so client and server must be run with the same setting.
# encryption_style == 0: use null encryption (keys travel as plain hex)
# encryption_style == 1: use the rotor module
# encryption_style == 2: use the simpcryp module
encryption_style = 2

import socket
import SOCKET
import sys
import regex
import posix
import string
import whrandom
if encryption_style == 1:
	import rotor
elif encryption_style == 2:
	import simpcryp

# Set to 1 to make client() print progress messages.
debug = 0
#debug = 1

# Server side: one line "clientkey serverkey sharedkey"; server() forces it
# to mode 0600 owned by uid/gid 0 and rewrites it after every session.
server_key_file = '/.srsh.key'
# Client side: dbm database mapping host name -> "clientkey serverkey sharedkey".
client_key_file = '/usr/local/srsh/keys'

# NOTE(review): not referenced by any code in this file.
output_dir = '/tmp'
# key_chars isn't as adjustable as it appears - see random_key()
# Key size in base-64 "characters" (6 bits each): keys are numbers below
# max_key = 64**key_chars, written in hex.  random_key() builds them from two
# halves of key_chars/2 characters, so an odd value silently loses bits.
key_chars = 16
max_key = pow(64L,key_chars)
# Hex width of max_key-1; filled in by main() and used by random_key() for
# zero-padding.  Stays 0 until main() has run.
key_digits = 0

# Handshake lines.  Both ends compare them exactly (after nonl()), so the
# client and the server must be running identical text.
# server -> client: first line of every session; client() checks it.
msg_greeting = \
	"Send your key (if you are authorized to use this service)"
# server -> client: client key matched; the server key follows on the next line.
msg_ok_client = \
	"Correct client key.  Here's my guess at the server key:"
# server -> client: the peer is not the permitted client address.
msg_bad_client_address = \
	"Bad client address.  Goodbye, and don't bother to try again."
# server -> client: the client key was wrong.
msg_bad_client_key = \
	"Bad client key.  Goodbye, and don't bother to try again."
# client -> server: the server key checked out; the server then sends new keys.
msg_ok_server = \
	"Correct server key.  I'm ready for new keys."
# client -> server: the server key was wrong (server() logs any reply other
# than msg_ok_server and exits).
msg_bad_server = \
	"Bad server key.  Goodbye, and don't bother to try again."
# NOTE(review): not referenced by any code in this file.
msg_shell_commands = \
	'Send your shell commands.'

def main():
	# Program entry point.  Invoked under a name ending in "srshd" it acts as
	# the daemon; otherwise it dispatches on the option in sys.argv[1], calling
	# usage() (which exits) for an unknown option or a wrong argument count.
	global key_digits
	# hex width of the largest possible key; random_key() pads keys to this
	key_digits = len(long_to_string(max_key - 1))
	if regex.match('.*srshd$', sys.argv[0]) != -1:
		# client doesn't need this, and it's not installed on everything...
		server()
		sys.exit(0)
	# option -> (exact len(sys.argv) required, or None to skip the check, action)
	dispatch = {
		'-p': (6, populate),
		'-r': (3, remove),
		'-d': (3, dump),
		'-s': (3, lambda: client(1)),
		'-l': (2, list_keys),
		'-c': (4, lambda: client(0)),
		# not checking number of args...
		'-t': (None, test),
	}
	entry = None
	if sys.argv[1:]:
		entry = dispatch.get(sys.argv[1])
	if entry is None:
		usage()
	wanted, action = entry
	if wanted is not None and len(sys.argv) != wanted:
		usage()
	action()
	sys.exit(0)

def usage():
	# Describe the accepted command lines and exit with status 1.
	# main() only calls this for an unknown option or a wrong argument count,
	# so this is an error path: the text goes to stderr and the exit status
	# reports failure (it used to print to stdout and exit 0, i.e. "success").
	prog = sys.argv[0]
	sys.stderr.write('do command    : '+prog+' -c host command\n')
	sys.stderr.write('cmd on stdin  : '+prog+' -s host\n')
	sys.stderr.write('store new keys: '+prog+' -p host clientkey serverkey sharedkey\n')
	sys.stderr.write('dump keys     : '+prog+' -d host\n')
	sys.stderr.write('remove keys   : '+prog+' -r host\n')
	sys.stderr.write('list keys     : '+prog+' -l\n')
	sys.exit(1)

def server():
	try:
		file = open(server_key_file,'r')
	except IOError,value:
		log('1: Could not open key file: '+value[1])
		sys.exit(1)
	posix.chmod(server_key_file,0600)
	posix.chown(server_key_file,0,0)
	sock = socket.fromfd(0,socket.AF_INET,socket.SOCK_STREAM)
	if sock.getpeername()[0] <> '128.200.34.17':
		putnl(msg_bad_client_address)
		log('2: bad client address')
		sys.exit(1)
	#sock.setsockopt(SOCKET.SOL_SOCKET,SOCKET.SO_KEEPALIVE,'')
	[client_key,server_key,shared_key]=string.split(nonl(file.readline()))
	putnl(msg_greeting)
	guessed_client_key = nonl(sys.stdin.readline())
	if client_key <> guessed_client_key:
		putnl(msg_bad_client_key)
		log('2: bad client key')
		sys.exit(1)
	putnl(msg_ok_client)
	putnl(server_key)
	msg = sys.stdin.readline()
	msg = nonl(msg)
	if msg <> msg_ok_server:
		log('3: client did not believe its key: '+msg)
		sys.exit(1)
	# at this point, we are mutually authenticated, and have VERY good reason
	# to believe the third and last key, is shared and hasn't been transmitted
	# in cleartext.  Use it to establish a new set of three keys.
	if encryption_style == 1:
		r = rotor.newrotor(shared_key)
	elif encryption_style == 2:
		sc = simpcryp.simpcryp(hex_to_string(shared_key))
	new_client_key = random_key()
	new_server_key = random_key()
	new_shared_key = random_key()
	if encryption_style == 0:
		munge(null_crypt,new_client_key)
		munge(null_crypt,new_server_key)
		munge(null_crypt,new_shared_key)
	elif encryption_style == 1:
		munge(r.encrypt,new_client_key)
		munge(r.encryptmore,new_server_key)
		munge(r.encryptmore,new_shared_key)
	elif encryption_style == 2:
		munge(sc.encrypt,new_client_key)
		munge(sc.encrypt,new_server_key)
		munge(sc.encrypt,new_shared_key)
	else:
		log('unknown encryption style: '+str(encryption_style))
		sys.exit(1)
	# this is messy. If we cannot open the file and write the keys, but
	# the client has received them...  the authentication will fail next
	# time!
	try:
		file = open(server_key_file,'w')
	except IOError,value:
		log('4: Could not open key file'+value[1])
		sys.exit(1)
	file.write(new_client_key+' '+new_server_key+' '+new_shared_key+'\n')
	file.close()
	# this could probably be done with one less process...
	dummy = posix.system('/bin/sh')

def client(from_stdin):
	if debug:
		print 'starting client'
	host=sys.argv[2]
	if from_stdin == 0:
		command=string.join(sys.argv[3:])
	sock = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
	if debug:
		print 'got socket'
	try:
		#sock.connect(host,portno('srsh'))
		sock.connect(host,650)
	except socket.error,message:
		# message comes back bogus?  "o"?
		#sys.stderr.write(message[1]+'\n')
		sys.stderr.write('Could not connect to '+host+'\n')
		sys.exit(1)
	if debug:
		print 'connected to remote host'
	msg = getnl(sock)
	if nonl(msg) <> msg_greeting:
		sys.stderr.write('bad greeting:\n')
		sys.stderr.write(msg)
		sys.exit(1)
	if debug:
		print 'getting keys'
	(client_key,server_key,shared_key) = get_keys(host)
	# the initiator, the client, divulges first
	sock.send(client_key+'\n')
	# get affirmation/negation of what we presented...
	msg = nonl(getnl(sock))
	if msg <> msg_ok_client:
		if msg == msg_bad_client_key:
			sys.stderr.write(host+' does not believe its key is '+client_key+'\n')
			sys.exit(1)
		elif msg == msg_bad_client_address:
			sys.stderr.write(host+' is refusing connections from this machine\n')
			sys.exit(1)
		else:
			sys.stderr.write('garbage received from '+host+'\n')
			sys.exit(1)
	# Now the server knows we are who we say we are.  Now it needs to prove
	# its identity to us.
	guessed_server_key = nonl(getnl(sock))
	if guessed_server_key <> server_key:
		sock.send(msg_bad_server+'\n')
		sys.stderr.write(host+" does not know it's key\n")
		sys.exit(1)
	sock.send(msg_ok_server+'\n')
	if encryption_style == 0:
		new_client_key = demunge(null_crypt,sock)
		new_server_key = demunge(null_crypt,sock)
		new_shared_key = demunge(null_crypt,sock)
	elif encryption_style == 1:
		r = rotor.newrotor(shared_key)
		new_client_key = demunge(r.decrypt,sock)
		new_server_key = demunge(r.decryptmore,sock)
		new_shared_key = demunge(r.decryptmore,sock)
	elif encryption_style == 2:
		sc = simpcryp.simpcryp(hex_to_string(shared_key))
		new_client_key = demunge(sc.decrypt,sock)
		new_server_key = demunge(sc.decrypt,sock)
		new_shared_key = demunge(sc.decrypt,sock)
	else:
		sys.stderr.write('unknown encryption style: '+str(encryption_style)+'\n')
		sys.exit(1)
	put_keys(host,new_client_key,new_server_key,new_shared_key)
	if from_stdin:
		while 1:
			line = sys.stdin.readline()
			if not line:
				break
			sock.send(line)
	else:
		sock.send(command+'\n')
	sock.send('\nexit\n')
	# shutdown the "send" half of the socket
	sock.shutdown(1)
	while 1:
		buf = sock.recv(1024)
		if not buf:
			break
		sys.stdout.write(buf)
	# this causes problems sometimes, and not others
	#sock.shutdown(2)
	sys.exit(0)

def demunge(fn,sock):
	# Inverse of munge(): read one hex-encoded line from sock, turn the hex
	# back into raw characters and return fn (a decrypt function) applied
	# to them.
	return fn(hex_to_string(nonl(getnl(sock))))

def munge(fn,key):
	# Encrypt key with fn, hex-encode the ciphertext so it cannot contain a
	# newline (the reader splits on '\n'), and write it to stdout as a line.
	putnl(string_to_hex(fn(key)))

def list_keys():
	import dbm
	try:
		database = dbm.open(client_key_file,'r',0600)
	except dbm.error,msg:
		sys.stderr.write('Could not open key database "keys": '+msg[1]+'\n')
		sys.exit(0)
	list = database.keys()
	list.sort()
	print string.joinfields(list,'\n')

def get_keys(host):
	import dbm
	try:
		database = dbm.open(client_key_file,'r',0600)
	except dbm.error,msg:
		sys.stderr.write('Could not open key database "keys": '+msg[1]+'\n')
		sys.exit(0)
	try:
		s = database[host]
	except (KeyError):
		sys.stderr.write('No key for '+host+'\n')
		sys.exit(1)
	fields = string.split(s)
	return (fields[0],fields[1],fields[2])
	
def put_keys(host,a,b,c):
	import dbm
	try:
		database = dbm.open(client_key_file,'w',0600)
	except dbm.error,value:
		sys.stderr.write('Could not open local key database: '+value[1]+'\n')
		sys.exit(0)
	database[host] = a+' '+b+' '+c

def random_key():
	global key_digits
	max_key_root = pow(64L,key_chars/2)
	t = long_to_string(long(whrandom.random() * max_key_root))
	t = t + long_to_string(long(whrandom.random() * max_key_root))
	# having all the leading zeros is important - things break without them
	while len(t) < key_digits:
		t = '0' + t
	return t

def nonl(str):
	# Return str with every newline and carriage return removed -- all of
	# them, wherever they occur, not only a trailing line terminator.
	kept = ''
	for ch in str:
		if ch in '\r\n':
			continue
		kept = kept + ch
	return kept

def examine(str):
	print len(str)
	for ch in str:
		print ord(ch),ch

def putnl(message):
	# Write message as a complete line (newline appended) via put(), which
	# flushes it immediately.
	line = message + '\n'
	put(line)

def put(message):
	# Write message to stdout exactly as given and flush at once, so it is
	# not held in a buffer (in server mode stdout presumably carries the
	# client connection -- see server()).
	out = sys.stdout
	out.write(message)
	out.flush()
	
def portno(p):
	# Turn a TCP port spec into a port number: an all-digit string is taken
	# literally, anything else is looked up as a service name.
	# (Only referenced from the commented-out connect() in client().)
	if regex.match('^[0-9][0-9]*$', p) != -1:
		return string.atoi(p)
	return socket.getservbyname(p, 'tcp')

def getnl(sock):
	# Read one line from sock: everything up to and including the first
	# '\n', or whatever arrived before the peer closed ('' at EOF).  It
	# reads a single byte per recv() call so that nothing past the newline
	# is consumed.  That is slow, so it is used only for the handshake lines.
	line = ''
	ch = sock.recv(1)
	while ch:
		line = line + ch
		if ch == '\n':
			break
		ch = sock.recv(1)
	return line

def long_to_string(l):
	# Return the non-negative integer l as bare hex digits: no '0x' prefix
	# and no long-integer suffix.  hex() only appends 'L' for longs on old
	# Pythons (never for plain ints, never on newer Pythons), so the suffix
	# is stripped only when it is actually present.  The old unconditional
	# [2:-1] chopped a real digit in those cases.
	t = hex(l)[2:]
	if t[-1:] in ('L', 'l'):
		t = t[:-1]
	return t

def string_to_long(s):
	# Parse s, a string of hex digits (checked one by one by hex_digit()),
	# and return its value.  Accumulate in a long from the start: before
	# Python 2.2 a plain int overflows instead of being promoted, which broke
	# on anything wider than 31 bits, such as the 24-digit keys that
	# random_key() produces.
	result = long(0)
	for ch in s:
		result = result * 16 + hex_digit(ch)
	return result

def string_to_hex(s):
	# Encode every character of s as two lowercase hex digits (its ordinal,
	# zero-padded), e.g. 'A' -> '41', '\x0f' -> '0f'.
	encoded = ''
	for c in s:
		encoded = encoded + '%02x' % ord(c)
	return encoded

def hex_digit(d):
	# Return the value (0-15) of the single hex digit d; upper and lower case
	# are both accepted.  Anything else is reported as
	# "bad hex digit: <char> <ordinal>" and ends the process with status 1.
	#
	# Only a-f / A-F are hex letters.  The earlier a-z / A-Z ranges mapped
	# 'g'..'z' to 16..35 and quietly produced garbage.  The complaint goes
	# to stderr: in server mode stdout carries the protocol lines sent to
	# the client (see put()).
	if '0' <= d <= '9':
		return ord(d) - ord('0')
	elif 'a' <= d <= 'f':
		return ord(d) - ord('a') + 10
	elif 'A' <= d <= 'F':
		return ord(d) - ord('A') + 10
	sys.stderr.write('bad hex digit: '+d+' '+str(ord(d))+'\n')
	sys.exit(1)

def hex_to_string(s):
	# Decode a string of hex digits into the characters they encode, two
	# digits per character.  With odd-length input, the lone trailing digit
	# is decoded by itself as one character.
	decoded = ''
	for i in range(0, len(s), 2):
		pair = s[i:i+2]
		if len(pair) == 1:
			value = hex_digit(pair)
		else:
			value = hex_digit(pair[0]) * 16 + hex_digit(pair[1])
		decoded = decoded + chr(value)
	return decoded

def log(msg, path='/var/adm/srsh.log'):
	# Append msg as one line to the srsh log.  The server calls this right
	# before giving up on a connection.  path defaults to the usual log
	# location; it can be overridden (e.g. for testing), and existing
	# callers are unaffected.
	file = open(path,'a')
	try:
		file.write(msg+'\n')
	finally:
		# close even if the write fails rather than leaking the handle
		file.close()

def populate():
	# -p: store keys given on the command line, i.e.
	#     sys.argv = [prog, '-p', host, clientkey, serverkey, sharedkey]
	argv = sys.argv
	put_keys(argv[2], argv[3], argv[4], argv[5])

def dump():
	host = sys.argv[2]
	keys = get_keys(host)
	print keys[0],keys[1],keys[2]

def remove():
	host = sys.argv[2]
	import dbm
	try:
		database = dbm.open(client_key_file,'w',0600)
	except dbm.error,value:
		sys.stderr.write('Could not open local key database: '+value[1]+'\n')
		sys.exit(0)
	del database[host]

def test():
	#print hex_to_string('testing!')
	print hex_to_string(string_to_hex('B5EF68E95F9E7654C695D2E4'))
	#r = rotor.newrotor('key')
	#r2 = rotor.newrotor('key')
	#print r2.decrypt(r.encrypt('fred'))

def null_crypt(s):
	# Identity "cipher" for encryption_style 0.  munge() and demunge() call
	# it in place of a real encrypt/decrypt function, so the new keys travel
	# as plain hex.
	return s

# Runs unconditionally, including when this file is imported.
# NOTE(review): an `if __name__ == '__main__':` guard would be more usual;
# before adding one, check that nothing (e.g. an srshd wrapper) relies on
# importing this file to run it.
main()