# とりあえず ε-greedy (for now, a simple epsilon-greedy multi-armed bandit)
import random
import numpy as np
class Arm:
    """One bandit arm that produces a hit with a fixed probability.

    Attributes set by the constructor:
        probability: threshold compared against ``random.random()`` on
            every pull; a pull is a hit when the draw is below it.
        no: identifier of the arm, stored as given.
        hit_count: number of pulls that were hits so far.
        total_count: number of pulls so far.
    """

    def __init__(self, prob, no):
        """Store the hit probability and arm identifier; reset counters."""
        self.probability = prob
        self.no = no
        self.hit_count = 0
        self.total_count = 0

    def lot(self):
        """Pull the arm once.

        Increments ``total_count`` on every call and ``hit_count`` on a
        hit.

        Returns:
            True when the pull was a hit, False otherwise.
        """
        is_hit = random.random() < self.probability
        self.increment_total_count()
        if is_hit:
            self.increment_hit_count()
        return is_hit

    def increment_total_count(self):
        """Record one more pull."""
        self.total_count += 1

    def increment_hit_count(self):
        """Record one more hit."""
        self.hit_count += 1

    def observable_expected_hit(self):
        """Return the observed hit rate (hits / pulls).

        Returns:
            0.0 when the arm has never been pulled, otherwise the
            fraction of pulls that were hits as a float.
        """
        if self.total_count == 0:
            return 0.0
        # Force float division: with two ints, ``/`` truncates on Python 2
        # and the rate would collapse to 0 or 1.
        return float(self.hit_count) / self.total_count
class Recorder:
    """Accumulates the outcome of a run of trials.

    Attributes set by the constructor:
        trial_count: number of trials recorded.
        hit_count: number of hits recorded.
        probability_list: values passed to ``push_probability``, in order.
        arm_no_list: values passed to ``push_arm_no``, in order.
    """

    def __init__(self):
        """Start with zeroed counters and empty histories."""
        self.trial_count = 0
        self.probability_list = []
        self.arm_no_list = []
        self.hit_count = 0

    def increment_hit_count(self):
        """Record one more hit."""
        self.hit_count += 1

    def increment_trial_count(self):
        """Record one more trial."""
        self.trial_count += 1

    def push_probability(self, prob):
        """Append ``prob`` to the probability history."""
        self.probability_list.append(prob)

    def push_arm_no(self, arm_no):
        """Append ``arm_no`` to the arm-number history."""
        self.arm_no_list.append(arm_no)

    def output(self):
        """Print the trial and hit totals to standard output."""
        # A parenthesised single argument prints the same text on
        # Python 2 and Python 3; the bare print statement is a
        # SyntaxError on Python 3.
        print("trial_count : %d" % self.trial_count)
        print("hit_count : %d" % self.hit_count)
class Bandit:
    """Base class for bandit strategies.

    Every strategy hook here is a no-op returning None; subclasses
    override the ones they need.
    """

    def __init__(self, recorder):
        """Keep the recorder and create an empty ``counter`` dict.

        Args:
            recorder: object the strategy reports trial results to.
        """
        # Previously the argument was accepted and discarded, leaving
        # subclasses that rely on this initialiser without a recorder.
        self.recorder = recorder
        self.counter = {}

    def explore(self, arms):
        """Hook for an exploration step; does nothing in the base class."""
        return

    def exploit(self, arms):
        """Hook for an exploitation step; does nothing in the base class."""
        return

    def execute(self, arms):
        """Hook for running one trial; does nothing in the base class."""
        return

    def lot_arm(self, arm):
        """Hook for pulling one arm; does nothing in the base class."""
        return
class EpsilonGreedyBandit(Bandit):
    """Epsilon-greedy strategy.

    On each trial, ``execute`` explores (pulls a uniformly random arm)
    when a fresh random draw is at or below ``epsilon``, and otherwise
    exploits (pulls the arm with the highest observed hit rate).
    """

    def __init__(self, recorder, init_epsilon):
        """Store the recorder and the exploration threshold.

        Args:
            recorder: object receiving the result of every pull.
            init_epsilon: threshold compared against ``random.random()``
                in ``execute``.
        """
        # Run the base initialiser so base-class state (``counter``) exists.
        # Called explicitly rather than via super() because the base is an
        # old-style class on Python 2.
        Bandit.__init__(self, recorder)
        self.recorder = recorder
        self.epsilon = init_epsilon

    def explore(self, arms):
        """Pull one arm chosen uniformly at random.

        Raises:
            IndexError: if ``arms`` is empty.
        """
        # random.choice replaces ``int(n * random()) - 1``, which was only
        # uniform because index -1 wraps around to the last arm.
        self.lot_arm(random.choice(arms))

    def exploit(self, arms):
        """Pull the arm with the highest observed hit rate.

        Ties resolve to the earliest arm in ``arms``, which is how
        ``np.argmax`` picks among equal maxima.
        """
        observed_rates = [arm.observable_expected_hit() for arm in arms]
        best_index = np.argmax(observed_rates)
        self.lot_arm(arms[best_index])

    def lot_arm(self, arm):
        """Pull ``arm`` and report the outcome to the recorder."""
        is_hit = arm.lot()
        self.recorder.increment_trial_count()
        self.recorder.push_arm_no(arm.no)
        self.recorder.push_probability(arm.probability)
        if is_hit:
            self.recorder.increment_hit_count()

    def execute(self, arms):
        """Run one trial, choosing between exploitation and exploration."""
        if self.epsilon < random.random():
            self.exploit(arms)
        else:
            self.explore(arms)
# Each entry is (hit probability, arm number), in Arm's constructor order.
arm_attributes = [
    (0.1, 1),
    (0.2, 2),
    (0.4, 3),
]
arms = [Arm(*attribute) for attribute in arm_attributes]
recorder = Recorder()
EPSILON = 0.1
TRIAL_NUM = 100
epsilon_greedy_bandit = EpsilonGreedyBandit(recorder, EPSILON)

# Run TRIAL_NUM trials; ``counter`` tracks how many loop iterations ran.
counter = 0
for _ in range(TRIAL_NUM):
    counter += 1
    epsilon_greedy_bandit.execute(arms)

recorder.output()
# Parenthesised single-argument prints behave the same on Python 2 and 3;
# the bare print statement is a SyntaxError on Python 3.
print(recorder.arm_no_list)
print(counter)