-
미로 찾기 ai -강화학습(q-table,입실론 그리디 정책)프로그래밍/python 2023. 6. 7. 11:23
미로를 생성하고 최단거리로 탈출하는 Q-테이블을 찾는 AI를 만들었습니다. 미로 생성은 깊이 우선 탐색을 사용하였고,
탈출 과정에서는 앱실론 그리디 정책을 적용했습니다.
AI는 미로를 최단 경로로 탈출하기 위해 Q-테이블을 활용합니다. Q-테이블은 상태(State)와 행동(Action)을 기반으로최적의 가치(Value)를 예측하는 테이블입니다. AI는 현재 상태에서 최적의 행동을 선택하고,
행동을 수행한 후에는 Q-테이블을 업데이트하여 학습합니다.
이 방법은 탐색과 이용을 균형있게 조정하여 최적의 경로를 탐색하면서도 이미 발견한 경로를 활용합니다.
아래는 소스 코드입니다.
maze_ai.py
(강화학습을 진행하여 미로를탈출하는 모듈)
import pygame import numpy as np import env import argparse import calculate class Maze_ai(): def __init__(self,maze_size,view_episode,visualize,episodes): self.maze = np.array(env.make_maze(maze_size)) self.wall_img=pygame.image.load("images/wall.png") self.player_img=pygame.image.load("images/player.png") self.exit_img=pygame.image.load("images/exit.png") self.wall_img=pygame.transform.scale(self.wall_img,(50,50)) self.player_img=pygame.transform.scale(self.player_img,(50,50)) self.exit_img=pygame.transform.scale(self.exit_img,(50,50)) self.maze_height = len(self.maze) self.maze_width = len(self.maze[0]) self.maze[self.maze_height - 2][self.maze_width - 1] = 2 # 가능한 행동 self.actions = ['up', 'down', 'left', 'right'] # Q 테이블 초기화 self.q_table = np.zeros((self.maze_height, self.maze_width, len(self.actions))) # 하이퍼파라미터 설정 self.learning_rate = 0.1 self.discount_factor = 0.99 self.exploration_rate = 0.1 self.ep_num = episodes #reward관련변수 self.move=0 self.prev_move=0 self.failed=0 self.serched=[]# 미로 설정 self.visualize=visualize self.view_episode=view_episode self.maze_size=maze_size self.calculate_shortest_path=calculate.Calculate(self.maze) self.shortest_path=self.calculate_shortest_path.return_shortest_path()+1 self.end_ep=self.start() print("Shortest Path:", self.shortest_path) print("End Episode:", self.end_ep) def start(self): # pygame 초기화 pygame.init() clock = pygame.time.Clock() # 색상 설정 BLACK = (0, 0, 0) WHITE = (255, 255, 255) GREEN = (0, 255, 0) BLUE = (0, 0, 255) # 창 크기 설정 if self.visualize: screen_width = self.maze_width * 50 screen_height = self.maze_height * 50 screen = pygame.display.set_mode((screen_width, screen_height)) pygame.display.set_caption("Maze Solver") # 에이전트 초기 위치 agent_position = (0, 1) # 게임 루프 running = True current_episode = 1 while running: if self.visualize: if self.view_episode<=current_episode: for event in pygame.event.get(): if event.type == pygame.QUIT: running = False # 미로 그리기 screen.fill(BLACK) for row in range(self.maze_height): for col in range(self.maze_width): if self.maze[row][col] == 1: #img사용하기 screen.blit(self.wall_img,(col * 50, row * 50)) if self.maze[row][col] == 2: screen.blit(self.exit_img,(col * 50, row * 50)) # 에이전트 그리기 screen.blit(self.player_img,(agent_position[1] * 50, agent_position[0] * 50)) # 탐험 및 활용 결정 self.exploration_rate_threshold = np.random.uniform(0, 1) if self.exploration_rate_threshold > self.exploration_rate: action_index = np.argmax(self.q_table[agent_position])#최적의 행동 else: action_index = np.random.randint(len(self.actions))#완전랜덤 self.move+=1 action = self.actions[action_index] # 에이전트 이동 if action == 'up': next_position = (agent_position[0] - 1, agent_position[1]) elif action == 'down': next_position = (agent_position[0] + 1, agent_position[1]) elif action == 'left': next_position = (agent_position[0], agent_position[1] - 1) else: next_position = (agent_position[0], agent_position[1] + 1) #이전과 같은 위치에 있을 경우 prev_move=agent_position self.serched.append(prev_move) # 벽에 부딪혔을 경우 if self.maze[next_position[0]][next_position[1]] == 1: if self.visualize: if self.view_episode<=current_episode: font=pygame.font.Font(None, 50) text=font.render("Fail",True,WHITE) screen.blit(text,(next_position[1]*50,next_position[0]*50)) self.failed+=1 next_position = agent_position reward = -1000 if self.maze[next_position[0]][next_position[1]] == 2: agent_position = next_position if self.move>self.shortest_path: reward=-100 # 보상 계산 if next_position == (self.maze_height - 2, self.maze_width - 1): reward = 100 elif next_position in self.serched: if self.visualize: if self.view_episode<=current_episode: font=pygame.font.Font(None, 50) text=font.render("Fail",True,WHITE) screen.blit(text,(next_position[1]*50,next_position[0]*50)) reward= -10000 # Q 테이블 업데이트 self.q_table[agent_position][action_index] += self.learning_rate * ( reward + self.discount_factor * np.max(self.q_table[next_position]) - self.q_table[agent_position][action_index] ) # 에이전트 위치 업데이트 agent_position = next_position # 화면 업데이트 if self.visualize: if self.view_episode<=current_episode: #reward표시 font = pygame.font.SysFont('malgungothic', 30) text = font.render("current_episode"+str(current_episode), True, GREEN) screen.blit(text, (10, 70)) font = pygame.font.SysFont('malgungothic', 30) text = font.render("move"+str(self.move), True, GREEN) screen.blit(text, (10, 10)) font = pygame.font.SysFont('malgungothic', 30) text = font.render("action"+str(action), True, GREEN) screen.blit(text, (10, 40)) text=font.render("Failed:"+str(self.failed),True,WHITE) screen.blit(text,(10,100)) pygame.display.flip() clock.tick(10)# # 에피소드 종료 체크 if agent_position == (self.maze_height - 2, self.maze_width - 1): print("Episode", current_episode, "completed") current_episode += 1 agent_position = (0, 1) self.failed=0 print(self.move) if self.shortest_path+1>=self.move: self.exploration_rate=0 # self.visualize=True # self.view_episode=current_episode # if self.shortest_path==self.move: # print("최단경로입니다") # return current_episode else: self.exploration_rate=self.exploration_rate*0.99 self.move=0 print(self.shortest_path) prev_move=0 self.serched=[] if self.visualize: if self.view_episode<=current_episode: print(self.exploration_rate) pygame.time.wait(10) # 모든 에피소드가 완료되면 종료 if current_episode > self.ep_num: running = False # 게임 종료 pygame.quit() if __name__ == '__main__': paser=argparse.ArgumentParser() paser.add_argument('--maze_size',type=int,default=3) paser.add_argument('--view_episode',type=int,default=900) paser.add_argument('--visualize',type=bool,default=False) paser.add_argument('--episode',type=int,default=1000) args=paser.parse_args() Maze_ai(args.maze_size,args.view_episode,args.visualize,args.episode)
env.py
(미로 생성해주는 모듈)import random class Room: def __init__(self, x, y): self.x = x self.y = y self.dir = [(x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)] random.shuffle(self.dir) def get_cur_pos(self): return self.x, self.y def get_next_pos(self): return self.dir.pop() def make_maze(size): rooms = [[Room(x, y) for x in range(size)] for y in range(size)] maze = [[1 for _ in range(size * 2 + 1)] for _ in range(size * 2 + 1)] visited = [] def make(cur_room): cx, cy = cur_room.get_cur_pos() visited.append((cx, cy)) maze[cy * 2 + 1][cx * 2 + 1] = 0 while cur_room.dir: nx, ny = cur_room.get_next_pos() if 0 <= nx < size and 0 <= ny < size: if (nx, ny) not in visited: maze[cy + ny + 1][cx + nx + 1] = 0 make(rooms[ny][nx]) make(rooms[0][0]) return maze
calculate.py
(q테이블을 자동으로 평가하기위해 최단거리를 계산하는 모듈입니다)
from collections import deque class Calculate(): def __init__(self, maze): self.maze = maze self.shortest_path = self.calculate_shortest_path() print("Shortest Path:", self.shortest_path) def return_shortest_path(self): return self.shortest_path # 최단거리 계산 함수 def calculate_shortest_path(self): maze = self.maze start_position = (0, 1) end_position = (len(maze) - 2, len(maze[0]) - 2) # 방문 여부를 저장하는 배열 visited = [[False] * len(maze[0]) for _ in range(len(maze))] # 이동 방향 (상, 하, 좌, 우) directions = [(-1, 0), (1, 0), (0, -1), (0, 1)] # 시작 위치를 큐에 추가 queue = deque([(start_position, 0)]) visited[start_position[0]][start_position[1]] = True while queue: position, distance = queue.popleft() if position == end_position: return distance for direction in directions: next_position = (position[0] + direction[0], position[1] + direction[1]) if ( 0 <= next_position[0] < len(maze) and 0 <= next_position[1] < len(maze[0]) and maze[next_position[0]][next_position[1]] == 0 and not visited[next_position[0]][next_position[1]] ): queue.append((next_position, distance + 1)) visited[next_position[0]][next_position[1]] = True return -1 # 최단거리가 없을 경우 -1 반환
실행 방법
python maze_ai.py --maze_size 9 --view_episode 900 --visualize true --episode 1000
'프로그래밍 > python' 카테고리의 다른 글
gaze-tracking 시선추적 python (3) 2022.11.14 window에서 mediapipe(미디어파이프)사용하기 (0) 2022.09.20 python back-end Django (0)-환경 준비 (0) 2021.12.03