Skip to content

Commit

Permalink
Replace signal-based MT closing with a TCP message
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikel committed Jul 18, 2024
1 parent d44bd0e commit daa8ab4
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 60 deletions.
18 changes: 10 additions & 8 deletions craftium/craftium_env.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import os
from typing import Optional, Any
import time

from .mt_channel import MtChannel
from .minetest import Minetest

import numpy as np

from gymnasium import Env
from gymnasium.spaces import Dict, Discrete, Box

Expand Down Expand Up @@ -125,10 +123,11 @@ def reset(
super().reset(seed=seed)
self.timesteps = 0

# close the active (if any) channel with minetest
self.mt_chann.close_conn()
# kill the active mt process if there's any
self.mt.kill_process()
if self.mt_chann.conn is not None:
self.mt_chann.send_termination()
self.mt_chann.close_conn()
self.mt.close_pipes()
self.mt.wait_close()

# start the new MT process
self.mt.start_process()
Expand Down Expand Up @@ -198,6 +197,9 @@ def render(self):
return self.last_observation

def close(self):
self.mt_chann.close()
self.mt.kill_process()
if self.mt_chann.conn is not None:
self.mt_chann.send_termination()
self.mt_chann.close()
self.mt.close_pipes()
self.mt.wait_close()
self.mt.clear()
22 changes: 5 additions & 17 deletions craftium/minetest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import os
from typing import Optional, Any
import subprocess
import multiprocessing
from uuid import uuid4
import shutil
import signal
import atexit


def is_minetest_build_dir(path: os.PathLike) -> bool:
Expand Down Expand Up @@ -127,17 +124,6 @@ def __init__(
if headless:
self.mt_env["SDL_VIDEODRIVER"] = "offscreen"

# register the cleanup function to be called on exit
atexit.register(self._kill_proc)

def _kill_proc(self):
if self.proc is not None:
# send kill signal to the process
os.kill(self.proc.pid, signal.SIGKILL)
# wait for the process to finish (timeout at 30s)
self.proc.wait(timeout=30)
self.proc = None

def start_process(self):
if self.pipe_proc:
# open files for piping stderr and stdout into
Expand All @@ -160,16 +146,18 @@ def start_process(self):
**kwargs,
)

def kill_process(self):
def wait_close(self):
if self.proc is not None:
self.proc.wait()

def close_pipes(self):
# close the files where the process is being piped
# into before the process itself
if self.stderr is not None:
self.stderr.close()
if self.stdout is not None:
self.stdout.close()

self._kill_proc()

def clear(self):
# delete the run's directory
if os.path.exists(self.run_dir):
Expand Down
10 changes: 6 additions & 4 deletions craftium/mt_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,18 @@ def receive(self):

return img, reward, termination

def send(self, keys: list[int], mouse_x: int, mouse_y: int):
def send(self, keys: list[int], mouse_x: int, mouse_y: int, terminate: bool = False):
assert len(keys) == 21, f"Keys list must be of length 21 and is {len(keys)}"

mouse = list(struct.pack("<h", mouse_x)) + list(struct.pack("<h", mouse_y))

self.conn.sendall(bytes(keys + mouse))
self.conn.sendall(bytes(keys + mouse + [1 if terminate else 0]))

def send_termination(self):
self.send(keys=[0]*21, mouse_x=0, mouse_y=0, terminate=True)

def close(self):
if self.conn is not None:
self.conn.close()
self.close_conn()
self.s.close()

def close_conn(self):
Expand Down
23 changes: 21 additions & 2 deletions src/client/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,20 @@ void Client::startPyConn()
exit(EXIT_FAILURE);
}

/*
struct timeval timeout;
timeout.tv_sec = 2; // timeout time in seconds
timeout.tv_usec = 0;
if (setsockopt(py_sockfd, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout, sizeof(timeout)) < 0) {
perror("[ERROR] PyConn setsockopt failed");
exit(EXIT_FAILURE);
}
if (setsockopt(py_sockfd, SOL_SOCKET, SO_SNDTIMEO, (const char*)&timeout, sizeof(timeout)) < 0) {
perror("[ERROR] PyConn setsockopt failed");
exit(EXIT_FAILURE);
}
*/

py_servaddr = (struct sockaddr_in*) malloc(sizeof(struct sockaddr_in));

memset(py_servaddr, 0, sizeof(*py_servaddr));
Expand All @@ -186,7 +200,7 @@ void Client::startPyConn()
}

void Client::pyConnStep() {
char actions[25];
char actions[26];
int n_send, n_recv, W, H, obs_rwd_buffer_size;
u32 c; // stores the RGBA pixel color

Expand Down Expand Up @@ -327,7 +341,12 @@ void Client::pyConnStep() {
/* If sending or receiving went wrong, print an error message and quit */
if (n_send + n_recv < 2) {
printf("[!!] Python client disconnected. Shutting down...\n");
exit(43);
exit(EXIT_FAILURE);
}

if (actions[25]) {
printf("[NOTE] Termination signal received, exiting...\n");
exit(0);
}
}

Expand Down
30 changes: 1 addition & 29 deletions src/client/craftium.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,6 @@ inline int syncServerInit() {
address.sin_addr.s_addr = INADDR_ANY;
address.sin_port = htons(0); // let the OS choose an empty port

// Set receive and send timeout on the socket
struct timeval timeout;
timeout.tv_sec = 10; // timeout time in seconds
timeout.tv_usec = 0;
if (setsockopt(server_fd, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout, sizeof(timeout)) < 0) {
perror("[SyncServer] setsockopt failed");
exit(EXIT_FAILURE);
}

int opt = 1;
if (setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) {
perror("[SyncServer] setsockopt SO_REUSEADDR failed");
exit(EXIT_FAILURE);
}

// Bind the server's socket to a port
if (bind(server_fd, (struct sockaddr*)&address,
sizeof(address))
Expand Down Expand Up @@ -117,14 +102,6 @@ inline int syncClientInit() {
exit(EXIT_FAILURE);
}

struct timeval timeout;
timeout.tv_sec = 10; // timeout time in seconds
timeout.tv_usec = 0;
if (setsockopt(sync_client_fd, SOL_SOCKET, SO_SNDTIMEO, (const char*)&timeout, sizeof(timeout)) < 0) {
perror("[SyncClient] setsockopt failed");
exit(EXIT_FAILURE);
}

// Connect to the server @ sync_port
if ((status
= connect(sync_client_fd, (struct sockaddr*)&serv_addr,
Expand All @@ -145,12 +122,7 @@ inline void syncServerStep() {
char msg[2];
if (read(sync_conn_fd, msg, 2) <= 0) {
perror("[syncServerStep] Step failed");
if (errno == EAGAIN) {
fprintf(stderr, "[SyncServerStep] Warning: timeout\n" );
} else {
fprintf(stderr, "Error code is %d, exiting...\n", errno);
exit(EXIT_FAILURE);
}
exit(EXIT_FAILURE);
}
}

Expand Down

0 comments on commit daa8ab4

Please sign in to comment.