#!/usr/local/bin/ruby
#
# DNS Balance --- a DNS server that performs dynamic load balancing
#
# By: YOKOTA Hiroshi <yokota@netlab.is.tsukuba.ac.jp>

# $Id: dns_balance.rb,v 1.5 2000/08/08 15:10:01 elca Exp $

require 'socket'
require 'datatype'
require 'addrdb'

# User-defined exceptions.
#
# Current Ruby versions already provide NotImplementedError (as a
# ScriptError subclass); reopening it with a different superclass raises
# TypeError while this file is being loaded, so it is only defined here
# when the interpreter does not supply one.
unless defined?(NotImplementedError)
  class NotImplementedError < StandardError ; end
end
class TruncatedError < StandardError ; end
class NoQueryError < StandardError ; end

#
# Split a DNS query packet into the question name, query type and query
# class.
#
# packet:: the raw DNS message (12-byte header followed by the question)
#
# Returns [q, q_type, q_class], where q is the encoded name including its
# terminating zero byte and q_type / q_class are the two raw bytes each.
# Returns nil when the header does not announce exactly one question or
# when the bytes after the header are not exactly one question.
#
def parse_packet(packet)
  fields = packet.unpack("a2 a2 a2 a2 a2 a2 a*")
  num_q = fields[2]
  str = fields[6]

  # Only single-question queries are handled.
  return nil if num_q != "\0\1"

  # Reject odd packets: the remainder must be the name up to and including
  # its first zero byte, followed by 2 bytes of type and 2 bytes of class.
  # A packet with nothing after the header has no zero byte at all and is
  # rejected here as well instead of raising.
  name_end = str.index("\0")
  return nil if name_end.nil?
  return nil if (name_end + 1 + 2 + 2) != str.length

  str.unpack("a#{str.length - 4} a2 a2")
end

#
# Return the client's IPv4 address as a 4-byte string.
#
# str:: the sender address from recvfrom: either a packed sockaddr_in
#       string, or an object that responds to to_sockaddr (current Ruby
#       versions return an Addrinfo from Socket#recvfrom)
#
def parse_client_addr(str)
  str = str.to_sockaddr if str.respond_to?(:to_sockaddr)

  # Layout: 2 bytes family, 2 bytes port, 4 bytes address, then padding.
  str.unpack("a2 a2 a4 a*")[2]
end

#
# Choose the namespace to answer from, based on the client's IP address.
#
# addrstr:: the client's IPv4 address as a 4-byte string
# name::    the host name being looked up
#
# The dotted prefixes "a.b.c.d", "a.b.c", "a.b" and "a" are tried in that
# order; the first one for which $addr_db[prefix][name] exists is
# returned.  When none matches, "default" is returned.
#
def select_namespace(addrstr, name)
  # Read the octets as integers explicitly; what String#[] returns for a
  # single index differs between Ruby versions.
  (o1, o2, o3, o4) = addrstr.unpack("C4")

  prefixes = [
    sprintf("%d.%d.%d.%d", o1, o2, o3, o4),
    sprintf("%d.%d.%d",    o1, o2, o3),
    sprintf("%d.%d",       o1, o2),
    sprintf("%d",          o1)
  ]

  prefixes.each do |prefix|
    table = $addr_db[prefix]
    return prefix if !table.nil? && !table[name].nil?
  end

  "default"
end

#
# Build the table used for weighted random selection.
#
# Each entry of $addr_db[namespace][name] gets the weight
# 10000 - min(10000, entry[1]); entry[1] is the entry's badness, so a
# badness of 10000 or more gives a weight of 0.
#
# Returns [total, edges]: the sum of all weights, and the running total
# after each entry in entry order.
#
def make_rand_array(namespace, name)
  running = 0
  edges = $addr_db[namespace][name].map do |entry|
    running += 10000 - min(10000, entry[1])
  end

  [running, edges]
end

#
# Pick entries at random, weighted by the table from make_rand_array.
#
# namespace, name:: select the entry list, as for make_rand_array
# size::            how many distinct entries to pick
#
# Returns an array of distinct entry indices in the order they were
# drawn.  It holds +size+ indices, or every index when +size+ is larger
# than the number of entries.
#
# Entries are drawn without replacement, each with probability
# proportional to its weight among the entries still available.  Entries
# of weight 0 are only used once no positively weighted entry is left;
# they are then drawn uniformly.  This guarantees termination even when
# fewer than +size+ entries have a positive weight.
#
def select_rand_array(namespace, name, size)
  rnd_slesh = make_rand_array(namespace, name)[1]

  # Turn the running totals back into one weight per entry.
  weights = []
  previous = 0
  rnd_slesh.each do |edge|
    weights.push(edge - previous)
    previous = edge
  end

  remaining = (0...weights.size).to_a
  wanted = [size, remaining.size].min
  arr = []

  while arr.size < wanted
    total = 0
    remaining.each { |idx| total += weights[idx] }

    if total > 0
      # rnd is in 0...total; walk the remaining entries until it falls
      # inside one entry's share.
      rnd = rand(total)
      pick = remaining.find { |idx| (rnd -= weights[idx]) < 0 }
    else
      pick = remaining[rand(remaining.size)]
    end

    arr.push(pick)
    remaining.delete(pick)
  end

  arr
end

#
# Convert a name as encoded in a DNS packet into human-readable dotted
# form.
#
# dnsstr:: a sequence of labels, each a length byte followed by that many
#          bytes, ended by a zero length byte; anything after the zero
#          byte is ignored
#
# Returns the labels joined with ".".
# Raises ArgumentError when the input ends before a zero length byte is
# reached.
#
def dnsstr_to_str(dnsstr)
  # Work on integer byte values; what String#[] returns for a single
  # index differs between Ruby versions.
  bytes = dnsstr.unpack("C*")
  labels = []
  pos = 0

  loop do
    len = bytes[pos]
    raise ArgumentError, "DNS name has no terminating zero byte" if len.nil?
    break if len == 0

    labels.push(bytes[pos + 1, len].pack("C*"))
    pos += len + 1
  end

  labels.join(".")
end


# Return +a+ when it is less than +b+, otherwise +b+ (so +b+ is returned
# when the two compare equal).
def min(a, b)
  a < b ? a : b
end

######################################################################
# main

#
# Reload the address database periodically in a background thread.
#
# The file "addr" is loaded whenever it is readable.  A failure while
# loading it (for example a syntax error in the file) is reported on
# stderr and the loop keeps running, so a later corrected file is still
# picked up instead of reloading stopping for good.
#
Thread.start do
  # "loop" is used instead of "while TRUE": the upper-case TRUE constant
  # is deprecated and no longer defined by current Ruby versions.
  loop do
    begin
      if test(?r, "addr")
        #p "reload db"
        load("addr")
      end
    rescue ScriptError, StandardError => e
      $stderr.puts("dns_balance: failed to load addr: #{e.message}")
    end
    #p $addr_db
    sleep(10*60) # reload every 10 minutes
  end
end

srand()

# Open the UDP socket and bind it to port 53 on every local IPv4 address.
gs = Socket.open(Socket::AF_INET, Socket::SOCK_DGRAM, 0)
if Socket.respond_to?(:pack_sockaddr_in)
  # Let the socket library build the sockaddr in the platform's own layout.
  gs.bind(Socket.pack_sockaddr_in(53, "0.0.0.0"))
else
  # Fallback for interpreters without pack_sockaddr_in.  This packs the
  # address family as a big-endian 16-bit value, which matches a
  # sockaddr_in that starts with a length byte followed by a family byte.
  gs.bind([Socket::AF_INET, 53, "", ""].pack("n n a4 a8"))
end

#
# Main loop
#
# Receives one query datagram (up to 512 bytes) at a time and answers it
# from its own thread.  A query that parses and can be served gets up to
# three address records chosen by select_rand_array.  Other outcomes:
#   * zone transfer (AXFR) request    -> reply with rcode 4 (not implemented)
#   * any other failure while serving -> reply with rcode 5 (refused)
#   * query that cannot be parsed     -> no reply, since there is no
#                                        question section to echo back
#
# NOTE(review): every answer record is emitted with the query's own type
# while its data is always a 4-byte address -- confirm that queries of
# other types are meant to be answered this way.
#

# Return a copy of +str+ whose byte at +index+ is (byte & keep) | set.
# Operates on integer byte values, so it does not depend on what
# String#[] returns for a single index.
def dns_patch_byte(str, index, keep, set)
  bytes = str.unpack("C*")
  bytes[index] = (bytes[index] & keep) | set
  bytes.pack("C*")
end

# Build an error reply: the request header and question echoed back with
# the answer bit (0x80 of byte 2) set, the authoritative bit (0x04 of
# byte 2) cleared and +rcode+ stored in the low four bits of byte 3.
def dns_error_reply(packet, question, rcode)
  r = packet[0, 12] + question
  r = dns_patch_byte(r, 2, 0xfb, 0x80)
  dns_patch_byte(r, 3, 0xf0, rcode)
end

loop do
  (packet, client_addr) = gs.recvfrom(512)

  # Hand the datagram and its sender to the thread as arguments.  If the
  # thread read the loop's own variables instead, the next recvfrom could
  # overwrite them before the reply is built and sent.
  Thread.start(packet, client_addr) do |req, sender|
    question = nil
    r = nil

    begin
      (q, q_type, q_class) = parse_packet(req)
      raise NoQueryError if q.nil?
      question = q + q_type + q_class

      client = parse_client_addr(sender)

      # Zone transfers are not supported.
      raise NotImplementedError if q_type == DnsType::AXFR

      # Only the Internet class is served.
      raise NoQueryError if q_class != DnsClass::INET

      name = dnsstr_to_str(q).downcase
      namespace = select_namespace(client, name)

      # Start the reply from the request header and the echoed question:
      # set the answer and authoritative bits, clear the rcode.
      r = req[0, 12] + question
      r = dns_patch_byte(r, 2, 0xff, 0x84)
      r = dns_patch_byte(r, 3, 0xf0, 0x00)

      size = min($addr_db[namespace][name].size, 3)
      q_array = select_rand_array(namespace, name, size)

      q_array.each do |i|
        addr = $addr_db[namespace][name][i][0]

        # Name is a pointer to offset 0x0c (the question name), TTL is
        # 5 seconds, data length is 4.
        r += [0xc0, 0x0c].pack("C2") + q_type + q_class + [5, 4].pack("N n") + addr.pack("CCCC")
      end

      # Store the number of answers.
      r[6..7] = [size].pack("n")

      # Too long for a single reply datagram.
      raise TruncatedError if r.length > 512

    rescue NotImplementedError
      r = dns_error_reply(req, question, 0x04) # not implemented error
      #p "NoI"
    rescue TruncatedError
      # Cut the reply down to 512 bytes and set the truncated bit.
      r = dns_patch_byte(r[0, 512], 2, 0xff, 0x02)
      #p "Tru"
    rescue StandardError
      # NoQueryError and every other failure end up here.  Without a
      # parsed question there is nothing to echo, so no reply is built.
      r = question.nil? ? nil : dns_error_reply(req, question, 0x05) # query refused error
      #p "NoQ"
    end

    #print req.dump, "\n"
    #print r.dump, "\n"
    #p q

    gs.send(r, 0, sender) unless r.nil?
  end
end

# end
