Lab6 of AI computing Architecture and System (2024 spring) around riscv emulator and implementation of fibonacci, sudoku (2x2) and maxpool in RISC-V

Shiritai/aias-lab6-spring-2024

NTHU_109062274_楊子慶 ACAL 2024 Spring Lab 6 HW Submission

tags: AIAS Spring 2024

HW6-1 - Fibonacci Series

Assembly Code

:::info Note: the return value of fibonacci is still stored in a1. For a better running experience, the main function prints a1 followed by '\n' (which requires overwriting a1 with '\n'), then exits.

  • Code snippet of printf("%d\n", a1)
    addi  a0, x0, 1
    ecall
    addi  a1, x0, '\n'
    addi  a0, x0, 11
    ecall
    
  • Code snippet of exit(0)
    addi  a0, x0, 10
    ecall
    

:::

## fibonacci.S
## put input n in register x10 (a0)
## put output fibonacci(n) in register x11 (a1)
## use Venus to test correctness

    .text
main:
## write assembly code here.
## call fibonacci function and get return value.
    li    a0, 16       # n = 16
    jal   fibonacci    # fib(16)

    addi  a0, x0, 1    # printf("%d\n", a1)
    ecall
    addi  a1, x0, '\n'
    addi  a0, x0, 11
    ecall

    addi  a0, x0, 10   # exit(0)
    ecall

fibonacci:
## fibonacci function
    mv    a1, a0       # if n == 0 or n == 1, return value is n and should be put at a1
    addi  t0, x0, 1
    ble   a1, t0, end  # if n <= 1 then end

recur:
    addi  sp, sp, -12  # allocate 12 bytes on the stack: three 4-byte slots for ra, s0, s1
    sw    ra, 0(sp)    # push ra to stack
    sw    s0, 4(sp)    # push s0 to stack (save s0 before use)
    sw    s1, 8(sp)    # push s1 to stack (save s1 before use)

    mv    s0, a0       # s0 = n
    addi  a0, s0, -1   # reg a0: n - 1, as parameter of the next function call
    jal   fibonacci    # call fib(n - 1)
    mv    s1, a1       # save result of fib(n - 1) to s1

    addi  a0, s0, -2   # reg a0: n - 2, as parameter of the next function call
    jal   fibonacci    # call fib(n - 2)

    add   a1, s1, a1   # reg a1: fib(n - 1) + fib(n - 2), as return value

    lw    ra, 0(sp)    # load ra back
    lw    s0, 4(sp)    # load s0 back
    lw    s1, 8(sp)    # load s1 back
    addi  sp, sp, 12   # pop stack

end:
    jr    ra           # jump to ra
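For comparison, the recursion that the assembly above implements can be written in C as the minimal sketch below. The name fibonacci_ref is hypothetical and does not appear in the lab code; the comments map each step to the registers used in fibonacci.S.

```c
/* Reference for the recursion in fibonacci.S:
 * fib(0) = 0, fib(1) = 1, fib(n) = fib(n - 1) + fib(n - 2).
 * fibonacci_ref is a hypothetical name used only for this sketch. */
int fibonacci_ref(int n) {
  if (n <= 1) {
    return n;                        /* base case: `ble a1, t0, end` */
  }
  int first = fibonacci_ref(n - 1);  /* result kept in s1 across the next call */
  int second = fibonacci_ref(n - 2); /* result arrives in a1 */
  return first + second;             /* `add a1, s1, a1` */
}
```

Calling fibonacci_ref(16) gives 987, the value shown in the simulation result.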

Simulation Result

:::success Note: the result is not left in a1 (main overwrites a1 with '\n') but appears in the output section: $987$. :::

HW6-2 - Fibonacci Series with C/Assembly Hybrid

Assembly Code

Just change the register that stores the return value from a1 to a0, and everything will work fine :)

## fibonacci.S

    .text                                  # code section
    .global fibonacci_asm                  # declare the fibonacci_asm function as a global function
    .type   fibonacci_asm, @function       # define fibonacci_asm as a function

fibonacci_asm:
## fibonacci function
    addi    t0, x0, 1
    ble     a0, t0, end                    # if n <= 1 then end

recur:
    addi    sp, sp, -12                    # allocate 12 bytes on the stack: three 4-byte slots for ra, s0, s1
    sw      ra, 0(sp)                      # push ra to stack
    sw      s0, 4(sp)                      # push s0 to stack (save s0 before use)
    sw      s1, 8(sp)                      # push s1 to stack (save s1 before use)

    mv      s0, a0                         # s0 = n
    addi    a0, s0, -1                     # reg a0: n - 1, as parameter of the next function call
    jal     fibonacci_asm                  # call fib(n - 1)
    mv      s1, a0                         # save result of fib(n - 1) to s1

    addi    a0, s0, -2                     # reg a0: n - 2, as parameter of the next function call
    jal     fibonacci_asm                  # call fib(n - 2)

    add     a0, s1, a0                     # reg a0: fib(n - 1) + fib(n - 2), as return value

    lw      ra, 0(sp)                      # load ra back
    lw      s0, 4(sp)                      # load s0 back
    lw      s1, 8(sp)                      # load s1 back
    addi    sp, sp, 12                     # pop stack

end:
    jr      ra                             # jump to ra

    .size   fibonacci_asm, .-fibonacci_asm

Simulation Result

Just like what we did in lab{2,4,5}, one can use the shell script to compile and run the simulation:

./scripts/run rvc_pair hw2/fibonacci # compile
./scripts/run sim hw2/fibonacci # simulate

and see the result like:

C code fibonacci_c=8
ASM code fibonacci_asm=8
Baremetal code runs to the end. (pressing `ctrl+a x` to quit)
QEMU: Terminated

HW6-3 - 2x2 Sudoku

main.c

In main.c we design three helper functions for debugging: print_meow, print_val and print_single_result. These functions are handy when we want to peek at values from within assembly code.

  • print_meow: prints Meow to confirm that the assembly runs in the order we expect.
    jal print_meow
    
  • print_val: can peek value of register a0
    mv a0, REGISTER_TO_PEEK
    jal print_val
    
  • print_single_result: can peek the whole real-time sudoku table
    mv a0, s8
    jal print_single_result
    
    :::info s8 is treated as a common register (like a global variable) throughout the sudoku_2x2_asm function; it stores the pointer to the sudoku table. :::

At the same time, I refactored print_sudoku_result to reuse print_single_result and reduce redundancy.

sudoku_2x2_c.c

Overview

Before implementing the code in C or assembly, one should first understand the algorithm. Fortunately, the TA has provided us with the pseudocode. We translated it into C as:

static inline int solve(char *ls, int idx) {
  if (idx == bound) {
    return true;
  }

  if (ls[idx] > 0) {
    return solve(ls, idx + 1);
  } else {
    for (char n = 1; n <= side_lim; ++n) {
      ls[idx] = n; // mark
      if (check(ls, idx) && solve(ls, idx + 1)) {
        return true;
      }
    }
    ls[idx] = 0;
    return false;
  }
}

:::success Note: to write clean sudoku code that can also solve 3x3 or bigger cases (bounded by the bit length of a single register), I defined three global constants at the top of the C code:

const int box_width = 2;
const int side_lim = box_width * box_width;
const int bound = side_lim * side_lim;

Here box_width is the side length of a single number box; side_lim is the side length of the 2d table (which is stored in a 1d array); bound is the total number of cells in the sudoku.

Since we're implementing 2x2 sudoku, box_width is $2$, and side_lim as well as bound are determined automatically at compile time. Elegant! :::

With solve, sudoku_2x2_c can be easily implemented as:

int sudoku_2x2_c(char *ls) { return solve(ls, 0); }

:::info Notice that, so that sudoku_2x2_c can report whether the puzzle was solved successfully, we change its return type to int (0 or 1). :::

Then, we should finish the algorithm of check.

check algorithm

Before implementation

In order to check whether a number at some index is valid, we should check whether it conflicts with another number horizontally, vertically, or nearby (in the same box).

Before introducing the algorithm, I shall first introduce two simple helper constants, x and y, which map the 1d index onto a 2d coordinate.

const int x = idx % side_lim;
const int y = idx / side_lim;

To record conflicts, since there are at most $4$ different numbers, one can use any data type with more than $4$ bits (bit n marks number n, and bit 0 is set by empty cells). We use a single $32$-bit integer so that this code also works for 3x3 sudoku.

int tb = 0; // bit table to check conflicts

So the structure of check is:

static inline int check(char *ls, int idx) {
  // constant definition

  // horizontal check

  // vertical check

  // box check

  // until here, no conflict...
  return true;
}

Horizontal check

Then, for the horizontal check, the algorithm is quite easy: fix y and scan the side_lim consecutive cells of that row.

// in check
// horizontal scan by fixing y (in single dimension version)
int base = y * side_lim;
int tb = 0;
for (int i = base; i < base + side_lim; ++i) {
  if (!check_and_mark(ls[i], &tb)) {
    return false;
  }
}

We can see that check_and_mark is the helper function that truly deals with conflict detection. We'll introduce it later.

Vertical check

For the vertical check, the algorithm is still quite easy: fix x and step the index by side_lim, which visits the side_lim cells of the same column.

// in check
// vertical scan by fixing x
base = x;
tb = 0;
for (int i = base; i < bound; i += side_lim) {
  if (!check_and_mark(ls[i], &tb)) {
    return false;
  }
}

Box check

This one is a bit tricky. Our algorithm maps the x, y coordinates to baseX, baseY, which locate the box rather than the individual number.

For example, 2x2 sudoku has four number boxes, so the box indices (x / box_width, y / box_width) range from $(0, 0)$ to $(1, 1)$. Multiplying them back by box_width gives baseX and baseY, the coordinates of the top-left cell of the box (each is $0$ or $2$).

The implementation of this mapping is surprisingly simple: do an integer division by box_width, then multiply the result back by box_width:

int baseX = (x / box_width) * box_width;
int baseY = (y / box_width) * box_width;

We use nested loop to traverse the numbers in our target box. In each iteration, we re-map back the axis from 2d to 1d to access the sudoku table:

// box scan by indirect index mapping
tb = 0;
int baseX = (x / box_width) * box_width;
int baseY = (y / box_width) * box_width;
for (int i = 0; i < box_width; ++i) {
  int offset = (i + baseY) * side_lim + baseX;
  for (int j = 0; j < box_width; ++j) {
    if (!check_and_mark(ls[offset + j], &tb)) {
      return false;
    }
  }
}

:::success Note: we act as a human compiler optimizer and apply loop-invariant code motion by hand, which makes the corresponding assembly code easier to write! :::
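To make the box mapping concrete, here is a minimal sketch that returns the 1d index of the top-left cell of the box containing a given cell. The function name box_top_left and the enum constants are hypothetical stand-ins for the constants defined earlier.

```c
/* Hypothetical stand-ins for box_width and side_lim. */
enum { BOX_WIDTH = 2, SIDE_LIM = BOX_WIDTH * BOX_WIDTH };

/* Returns the 1d index of the top-left cell of the box containing idx. */
int box_top_left(int idx) {
  int x = idx % SIDE_LIM;                  /* column of the cell */
  int y = idx / SIDE_LIM;                  /* row of the cell */
  int baseX = (x / BOX_WIDTH) * BOX_WIDTH; /* round x down to a box boundary */
  int baseY = (y / BOX_WIDTH) * BOX_WIDTH; /* round y down to a box boundary */
  return baseY * SIDE_LIM + baseX;         /* map back from 2d to 1d */
}
```

For instance, cell 7 sits at (x, y) = (3, 1), so baseX = 2 and baseY = 0, and the box starts at index 2. Cell 13 sits at (1, 3), so the box starts at index 8.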

check_and_mark

This function removes the redundancy of writing the same conflict-checking code for the horizontal, vertical and box scans in check. It generalizes the following behavior:

  • If the number is nonzero and conflicts with an already recorded number, return false.
  • Otherwise (the number is zero, or it has not been seen yet), record the current number and return true.

Taking advantage of bitwise operations, the code becomes this clean:

static inline int check_and_mark(char shift, int *tb) {
  int hot = 1 << shift;
  if (shift && ((*tb) & hot))
    return false;
  *tb = (*tb) | hot;
  return true;
}

:::info Note: to modify a variable that belongs to another function, the simplest way in C is to use a pointer. However, we can do some tricks in assembly code... :::
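Putting the fragments together, the whole C solver can be sketched in one compilable unit as below. This is an assembled sketch rather than the exact file from the repository: the enum constants replace the three global const ints, and the entry point is renamed sudoku_2x2_sketch (a hypothetical name) to make clear it is not the original function.

```c
#include <stdbool.h>

/* Hypothetical stand-ins for box_width, side_lim and bound. */
enum { BOX_WIDTH = 2, SIDE_LIM = BOX_WIDTH * BOX_WIDTH, BOUND = SIDE_LIM * SIDE_LIM };

/* Returns false if val is nonzero and already recorded in *tb; otherwise records it. */
static bool check_and_mark(char val, int *tb) {
  int hot = 1 << val;
  if (val && (*tb & hot)) return false;
  *tb |= hot;
  return true;
}

/* Checks the row, column and box that contain cell idx. */
static bool check(const char *ls, int idx) {
  const int x = idx % SIDE_LIM, y = idx / SIDE_LIM;
  int tb = 0;
  for (int i = y * SIDE_LIM; i < (y + 1) * SIDE_LIM; ++i) /* row */
    if (!check_and_mark(ls[i], &tb)) return false;
  tb = 0;
  for (int i = x; i < BOUND; i += SIDE_LIM)               /* column */
    if (!check_and_mark(ls[i], &tb)) return false;
  tb = 0;
  const int baseX = (x / BOX_WIDTH) * BOX_WIDTH;
  const int baseY = (y / BOX_WIDTH) * BOX_WIDTH;
  for (int i = 0; i < BOX_WIDTH; ++i)                     /* box */
    for (int j = 0; j < BOX_WIDTH; ++j)
      if (!check_and_mark(ls[(baseY + i) * SIDE_LIM + baseX + j], &tb)) return false;
  return true;
}

/* Backtracking: fill empty cells in index order, undoing on failure. */
static bool solve(char *ls, int idx) {
  if (idx == BOUND) return true;
  if (ls[idx] > 0) return solve(ls, idx + 1);
  for (char n = 1; n <= SIDE_LIM; ++n) {
    ls[idx] = n;
    if (check(ls, idx) && solve(ls, idx + 1)) return true;
  }
  ls[idx] = 0;
  return false;
}

int sudoku_2x2_sketch(char *ls) { return solve(ls, 0); }
```

Note that given cells are never validated on their own: a conflict between two givens is only detected when an empty cell in the same row, column or box is tried.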

sudoku_2x2_asm.S

The methodology of implementation is... do what we've done in C code with some rearrangements.

sudoku_2x2_asm

Just like sudoku_2x2_c, this function kicks off the recursive solve, but it first initializes the common constants in registers s8-s11, which act as global variables for the whole algorithm.

:::info This reduces the number of arguments to be passed to solve and check.

Note: we should use saved registers so that their values survive calls to other functions (e.g. functions that print values for debugging). :::

sudoku_2x2_asm:

# solve(idx = 0)
# prologue: push stack
    addi    sp, sp, -20
    sw      ra, 0(sp)
    sw      s8, 4(sp)                        # will be used as global constant
    sw      s9, 8(sp)                        # will be used as global constant
    sw      s10, 12(sp)                      # will be used as global constant
    sw      s11, 16(sp)                      # will be used as global constant

# set up global constants for the whole
# recursion in saved registers
    mv      s8, a0                           # s8 = ls (char array)
    li      s9, 2                            # s9 (box_width) = 2
    mul     s10, s9, s9                      # s10 (side_lim) = 4 = box_width ^ 2
    mul     s11, s10, s10                    # s11 (bound) = 16 = side_lim ^ 2

# function call
    li      a0, 0                            # start to "solve" from index 0
    jal     solve
# epilogue: pop stack
    lw      ra, 0(sp)
    lw      s8, 4(sp)
    lw      s9, 8(sp)
    lw      s10, 12(sp)
    lw      s11, 16(sp)
    addi    sp, sp, 20
    ret

solve

:::info I think the naming style and comments in this function are enough to explain what we've done :) :::

solve: # solve(a0 = idx)
    beq     a0, s11, solve_ret_true          # if (idx == bound) return true
# prologue for all function calls: push stack
    addi    sp, sp, -20
    sw      ra, 0(sp)
    sw      s0, 4(sp)                        # we need s0 as current idx (a0)
    sw      s1, 8(sp)                        # we need s1 as n in loop
    sw      s2, 12(sp)                       # we need s2 as address of ls[idx]
    sw      s3, 16(sp)                       # we need s3 as value of ls[idx]

    mv      s0, a0                           # s0 = idx

    add     s2, s8, s0                       # s2 = address of ls[idx]
    lb      s3, 0(s2)                        # s3 = ls[idx]

    bne     s3, x0, call_solve_nxt_and_ret   # if ls[idx] > (!=) 0, then call_solve_nxt_and_ret
# for loop
    li      s1, 1                            # for n = 1
loop:
    sb      s1, 0(s2)                        # ls[idx] = n

    mv      a0, s0                           # a0 = idx
    jal     check                            # check(idx)

    beq     a0, x0, to_nxt_loop              # if check returns invalid, then prepare next loop
# solve(idx = idx + 1)
    addi    a0, s0, 1                        # a0 = idx + 1
    jal     solve
    bne     a0, x0, solve_epilogue_ret_true  # if recursive solve returns valid, then returns true
to_nxt_loop:
    addi    s1, s1, 1                        # n += 1
    bgt     s1, s10, end_loop                # if n > side_lim (4), end loop
    j       loop
end_loop:
    sb      x0, 0(s2)                        # ls[idx] = 0
    j       solve_epilogue_ret_false         # return false

call_solve_nxt_and_ret:
# solve(idx = idx + 1)
# function call
    addi    a0, s0, 1                        # a0 = idx + 1
    jal     solve
# epilogue: before return, pop stack
    lw      ra, 0(sp)
    lw      s0, 4(sp)
    lw      s1, 8(sp)
    lw      s2, 12(sp)
    lw      s3, 16(sp)
    addi    sp, sp, 20
# return a0 = solve(idx = idx + 1)
    ret

solve_epilogue_ret_true:
# epilogue: before return, pop stack
    lw      ra, 0(sp)
    lw      s0, 4(sp)
    lw      s1, 8(sp)
    lw      s2, 12(sp)
    lw      s3, 16(sp)
    addi    sp, sp, 20

solve_ret_true:
    li      a0, 1                            # a0 = true
    ret

solve_epilogue_ret_false:
    li      a0, 0                            # a0 = false
# epilogue: before return, pop stack
    lw      ra, 0(sp)
    lw      s0, 4(sp)
    lw      s1, 8(sp)
    lw      s2, 12(sp)
    lw      s3, 16(sp)
    addi    sp, sp, 20
    ret

check

This function is a bit long. We can see the structure of assembly is:

check: # check(a0 = idx)
# prologue for all function calls
    addi    sp, sp, -28
    sw      ra, 0(sp)
    sw      s0, 4(sp)                        # will be used as tb, used as a shared register between check and check_and_mark
    sw      s1, 8(sp)                        # will be used as x and baseX
    sw      s2, 12(sp)                       # will be used as y and baseY
    sw      s3, 16(sp)                       # will be used as iterator (or i in nested loop)
    sw      s4, 20(sp)                       # will be used as iterator limit (or j in nested loop)
    sw      s5, 24(sp)                       # will be used as offset in nested loop

    rem     s1, a0, s10                      # int x = idx % side_lim
    div     s2, a0, s10                      # int y = idx / side_lim

# horizontal scan by fixing y (in single dimension version)
# ...
# vertical scan by fixing x
# ...
# box scan by indirect index mapping
# ...

    li      a0, 1                            # return value: true
# epilogue: before return, pop stack
    lw      ra, 0(sp)
    lw      s0, 4(sp)
    lw      s1, 8(sp)
    lw      s2, 12(sp)
    lw      s3, 16(sp)
    lw      s4, 20(sp)
    lw      s5, 24(sp)
    addi    sp, sp, 28
    ret

check_ret_false:
    li      a0, 0                            # return value: false
# epilogue: before return, pop stack
    lw      ra, 0(sp)
    lw      s0, 4(sp)
    lw      s1, 8(sp)
    lw      s2, 12(sp)
    lw      s3, 16(sp)
    lw      s4, 20(sp)
    lw      s5, 24(sp)
    addi    sp, sp, 28
    ret

You can see that there are two tricks here:

  • Use s0 to hold tb by value for conflict detection, rather than passing a pointer. This reduces the number of arguments passed to check_and_mark.
  • Use as few local registers as possible (s1-s5), reusing them across the three loops.

For horizontal case, the implementation is:

# horizontal scan by fixing y (in single dimension version)
    mul     s3, s2, s10                      # int base = y * side_lim
    add     s4, s3, s10                      # let s4 as iteration limit = base + side_lim
    li      s0, 0                            # tb = 0 (initialize mark table)
check_h_loop: # loop for horizontal scanning
    add     a0, s8, s3                       # a0 = position of ls[i]
    lb      a0, 0(a0)                        # a0 = ls[i]
    jal     check_and_mark                   # check_and_mark(shift = ls[i])
    beq     a0, x0, check_ret_false          # if check_and_mark returns false, then check_ret_false

    addi    s3, s3, 1                        # i++
    blt     s3, s4, check_h_loop             # if i < base + side_lim then check_h_loop

and vertical case:

# vertical scan by fixing x
    mv      s3, s1                           # int base = x
    mv      s4, s11                          # let s4 as iteration limit = bound
    li      s0, 0                            # tb = 0 (initialize mark table)
check_v_loop: # loop for vertical scanning
    add     a0, s8, s3                       # a0 = position of ls[i]
    lb      a0, 0(a0)                        # a0 = ls[i]
    jal     check_and_mark                   # check_and_mark(shift = ls[i])
    beq     a0, x0, check_ret_false          # if check_and_mark returns false, then check_ret_false

    add     s3, s3, s10                      # i += side_lim
    blt     s3, s4, check_v_loop             # if i < bound then check_v_loop

and box case:

# box scan by indirect index mapping
    div     s1, s1, s9                       # s1 = x / box_width
    mul     s1, s1, s9                       # int baseX = (x / box_width) * box_width
    div     s2, s2, s9                       # s2 = y / box_width
    mul     s2, s2, s9                       # int baseY = (y / box_width) * box_width
    li      s0, 0                            # tb = 0 (initialize mark table)
# for (int i = 0; i < box_width; ++i)
    li      s3, 0                            # i = 0
check_b_first_loop: # first loop for box scanning
    add     s5, s3, s2                       # s5 = i + baseY
    mul     s5, s5, s10                      # s5 = (i + baseY) * side_lim
    add     s5, s5, s1                       # offset = (i + baseY) * side_lim + baseX

# for (int j = 0; j < box_width; ++j)
    li      s4, 0                            # j = 0
check_b_second_loop: # second loop for box scanning
    add     a0, s8, s5                       # a0 = ls + offset
    add     a0, a0, s4                       # a0 = position of ls[offset + j]
    lb      a0, 0(a0)                        # a0 = ls[offset + j]
    jal     check_and_mark                   # check_and_mark(shift = ls[offset + j])
    beq     a0, x0, check_ret_false          # if check_and_mark returns false, then check_ret_false

    addi    s4, s4, 1                        # j++
    blt     s4, s9, check_b_second_loop      # if j < box_width then check_b_second_loop
# end of j for loop

    addi    s3, s3, 1                        # i++
    blt     s3, s9, check_b_first_loop       # if i < box_width then check_b_first_loop
# end of i for loop

check_and_mark

:::info I think the naming style and comments in this function are enough to explain what we've done :) :::

check_and_mark: # check_and_mark(shift)
# prologue for all function calls
    addi    sp, sp, -12
    sw      ra, 0(sp)
    sw      s1, 4(sp)
    sw      s2, 8(sp)

    li      s1, 1
    sll     s1, s1, a0                       # int hot = 1 << shift

    beq     a0, x0, check_and_mark_ret_true  # if shift == 0, return true
    and     t0, s0, s1                       # t0 = tb & hot
    beq     t0, x0, check_and_mark_ret_true  # (tb & hot) == 0, no overlap --> return true

    li      a0, 0                            # return value: false

# epilogue: before return, pop stack
    lw      ra, 0(sp)
    lw      s1, 4(sp)
    lw      s2, 8(sp)
    addi    sp, sp, 12

    ret

check_and_mark_ret_true:
    or      s0, s0, s1                       # tb = tb | hot, i.e. mark table
    li      a0, 1                            # return value: true

# epilogue: before return, pop stack
    lw      ra, 0(sp)
    lw      s1, 4(sp)
    lw      s2, 8(sp)
    addi    sp, sp, 12

    ret

Simulation Result

One can compile the program with:

./scripts/run hw3c

and run with command:

./scripts/run sim hw3/hw3c

to get the result of:

Is sudoku solvable using C function? value: [1]
Is sudoku solvable using risc-v assembly? value: [1]
Output c & assembly function result
c result :
4 1 2 3
3 2 1 4
2 3 4 1
1 4 3 2

assembly result :
4 1 2 3
3 2 1 4
2 3 4 1
1 4 3 2

your c & assembly got same result!
Baremetal code runs to the end. (pressing `ctrl+a x` to quit)
QEMU: Terminated

Bonus

:::info I may finish this in... someday...

Note: partially finished! :::

Step 1: C Implementation

C code

According to the formula in the PyTorch documentation, the implementation should be something like:

void maxpool2d_c(const int8_t *input_X, int32_t input_X_dimW,
                 int32_t input_X_dimH, int32_t input_X_dimC, int8_t *output_Y,
                 int32_t kernel_W, int32_t kernel_H, int32_t stride_W,
                 int32_t stride_H) {
  const int h_out = (input_X_dimH - kernel_H) / stride_H + 1;
  const int w_out = (input_X_dimW - kernel_W) / stride_W + 1;
  const int square_sz = input_X_dimH * input_X_dimW;
  const int out_square_sz = h_out * w_out;
  for (int c = 0; c < input_X_dimC; ++c) {
    for (int h = 0; h < h_out; ++h) {
      for (int w = 0; w < w_out; ++w) {
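        int mx = 0; // valid because the test inputs are non-negative (rand() % INT8_MAX); general int8_t data would need INT8_MIN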
        int h_offset = stride_H * h;
        int w_offset = stride_W * w;
        for (int kh = 0; kh < kernel_H; ++kh) {
          for (int kw = 0; kw < kernel_W; ++kw) {
            int h_idx = h_offset + kh;
            int w_idx = w_offset + kw;

            int8_t cur = input_X[h_idx * input_X_dimW + w_idx + square_sz * c];
            print_debug_msg("%d,\t", cur);
            /**
             * @brief Bitwise hack to find max of two numbers
             * ref:
             * http://graphics.stanford.edu/~seander/bithacks.html#IntegerMinOrMax
             */
            mx = mx ^ ((mx ^ cur) & -(mx < cur)); // max(mx, cur)
          }
        }
        print_debug_msg(" --> %d\n", mx);
        output_Y[h * w_out + w + out_square_sz * c] = mx;
      }
    }
  }
}
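Two small pieces of the function above can be checked in isolation: the output-size formula and the branchless max. The sketch below assumes no padding and no dilation (as in the code above) and in_dim >= kernel; the names pool_out_dim and max_branchless are hypothetical.

```c
#include <stdint.h>

/* Output length along one axis: (in_dim - kernel) / stride + 1.
 * Assumes no padding, no dilation and in_dim >= kernel. */
int32_t pool_out_dim(int32_t in_dim, int32_t kernel, int32_t stride) {
  return (in_dim - kernel) / stride + 1;
}

/* Branchless max: -(mx < cur) is all ones when mx < cur, so the XOR
 * swaps mx for cur; otherwise the mask is zero and mx is kept. */
int max_branchless(int mx, int cur) {
  return mx ^ ((mx ^ cur) & -(mx < cur));
}
```

For example, a 7-wide input with kernel 3 and stride 2 yields (7 - 3) / 2 + 1 = 3 output columns.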

You can see several lines for debugging that call print_debug_msg. It's defined in util.h, just like what we did in lab02.

#ifndef UTIL_H
#define UTIL_H

#ifdef DEBUG_MODE
/**
 * @brief Print debug message if DEBUG_MODE is defined;
 * otherwise the macro expands to an empty block
 */
#define print_debug_msg(...) \
  {                          \
    printf(__VA_ARGS__);     \
    fflush(stdout);          \
  }
#else
#define print_debug_msg(...) \
  {}
#endif

#endif /* UTIL_H */
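One caveat with a brace-only macro body: used as `if (cond) print_debug_msg(...); else ...`, the trailing semicolon becomes a separate empty statement and the else no longer attaches to the if. A common alternative is the `do { ... } while (0)` idiom, sketched below under a different, hypothetical macro name (debug_msg) so it is not mistaken for the header in the repository.

```c
#include <stdio.h>

#ifdef DEBUG_MODE
#define debug_msg(...)   \
  do {                   \
    printf(__VA_ARGS__); \
    fflush(stdout);      \
  } while (0)
#else
#define debug_msg(...) \
  do {                 \
  } while (0)
#endif

/* With the do/while form, the macro call plus its semicolon is exactly
 * one statement, so it can sit in an unbraced if/else. */
int clamp_to_zero(int v) {
  if (v < 0)
    debug_msg("negative input %d\n", v);
  else
    return v;
  return 0;
}
```

The function compiles and behaves the same whether or not DEBUG_MODE is defined; only the printing changes.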

Verification

To verify maxpool2d_c, we compare it against the PyTorch version of max pooling on random input.

The test logic lives in single_c_py_random_test in main.c. It does the following:

  • Generate random parameters and input
    const int dim_limit = 50;
    const int stride_limit = 4;
    const int kernel_limit = 4;
    
    const int32_t input_X_dimH = rand() % dim_limit + 4;
    const int32_t input_X_dimW = rand() % dim_limit + 4;
    const int32_t input_X_dimC = rand() % 3 + 1;
    
    const int32_t stride_H = rand() % stride_limit + 1;
    const int32_t stride_W = rand() % stride_limit + 1;
    
    const int32_t kernel_H = rand() % kernel_limit + 1;
    const int32_t kernel_W = rand() % kernel_limit + 1;
    
    int8_t *input = (int8_t *)malloc(sizeof(int8_t) * input_X_dimH *
                                     input_X_dimW * input_X_dimC);
    
    // ...
    
    const int32_t square_sz = input_X_dimW * input_X_dimH;
    for (int c = 0; c < input_X_dimC; ++c) {
      for (int h = 0; h < input_X_dimH; ++h) {
        for (int w = 0; w < input_X_dimW; ++w) {
          int8_t rand_val = rand() % INT8_MAX;
          input[h * input_X_dimW + w + square_sz * c] = rand_val;
          // ...
        }
        // ...
      }
      // ...
    }
    // ...
    
  • Write the generated parameters and input into file rand_input.txt
    // save array as file
    FILE *fp = fopen("Bonus/rand_input.txt", "w");
    fprintf(fp, "%d %d %d\n", input_X_dimC, input_X_dimH, input_X_dimW);
    fprintf(fp, "%d %d\n", stride_H, stride_W);
    fprintf(fp, "%d %d\n", kernel_H, kernel_W);
    print_debug_msg("input: %d %d %d\n", input_X_dimC, input_X_dimH,
                    input_X_dimW);
    print_debug_msg("stride: %d %d\n", stride_H, stride_W);
    print_debug_msg("kernel: %d %d\n", kernel_H, kernel_W);
    fprintf(fp, "[ ");
    
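    for (int c = 0; c < input_X_dimC; ++c) {
      for (int h = 0; h < input_X_dimH; ++h) {
        for (int w = 0; w < input_X_dimW; ++w) {
          // ...
          // write the value already stored in input, not a freshly drawn one
          fprintf(fp, "%d, ", input[h * input_X_dimW + w + square_sz * c]);
        }
      }
    }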
    
    fprintf(fp, " ]");
    fclose(fp);
    
  • Run maxpool2d_c and write result
    maxpool2d_c(input, input_X_dimW, input_X_dimH, input_X_dimC, output, kernel_W,
                kernel_H, stride_W, stride_H);
    
    fp = fopen("Bonus/c_output.txt", "w"); // also save array as file
    
    const int32_t out_square_sz = h_out * w_out;
    for (int c = 0; c < input_X_dimC; ++c) {
      for (int h = 0; h < h_out; ++h) {
        for (int w = 0; w < w_out; ++w) {
          fprintf(fp, "%d\n", output[h * w_out + w + out_square_sz * c]);
        }
      }
    }
    fclose(fp);
    
  • Invoke the Python script to generate the golden output and use a diff command (colordiff) to check the correctness of our output
    #ifdef DEBUG_MODE
      system("python3 Bonus/maxpool.py -d");
    #else
      system("python3 Bonus/maxpool.py");
    #endif
      system("command -v colordiff > /dev/null || (sudo apt update && sudo apt install "
           "colordiff)");
      /**
       * if no output, then everything works fine :)
       */
      int res = system("colordiff Bonus/c_output.txt Bonus/py_output.txt");
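If colordiff is unavailable, the same check can be done inside the test program. The sketch below is an alternative rather than what main.c does: count_mismatched_lines is a hypothetical helper that compares two text files line by line, assuming each line is shorter than 64 characters (true for one integer per line).

```c
#include <stdio.h>
#include <string.h>

/* Counts lines that differ between two text files. A line present in
 * only one file counts as a mismatch. Returns -1 if either file cannot
 * be opened. Assumes every line is shorter than 64 characters. */
int count_mismatched_lines(const char *path_a, const char *path_b) {
  FILE *fa = fopen(path_a, "r");
  FILE *fb = fopen(path_b, "r");
  if (!fa || !fb) {
    if (fa) fclose(fa);
    if (fb) fclose(fb);
    return -1;
  }
  char line_a[64], line_b[64];
  int mismatches = 0;
  for (;;) {
    char *ra = fgets(line_a, sizeof line_a, fa);
    char *rb = fgets(line_b, sizeof line_b, fb);
    if (!ra && !rb) break; /* both files ended */
    if (!ra || !rb || strcmp(line_a, line_b) != 0) ++mismatches;
  }
  fclose(fa);
  fclose(fb);
  return mismatches;
}
```

A result of 0 for Bonus/c_output.txt against Bonus/py_output.txt would play the same role as an empty diff.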
    

We run the random test $20$ times by default. You can change this count, as well as the parameters used to generate random cases.

void maxpool_c_py_test() {
  const int tests = 20;
  int cnt = 0;
  for (int i = 0; i < tests; ++i) {
    cnt += single_c_py_random_test();
  }
  if (!cnt) {
    printf("Passed %d random tests\n", tests);
  } else {
    printf("%d WA within %d random tests\n", cnt, tests);
  }
}

int main() {
  maxpool_c_py_test();
  // ...
}

And for the python code, the framework can be:

# ...
ls = []
input_X_dimC, input_X_dimH, input_X_dimW = 0, 0, 0
strideH, strideW = 0, 0
kernelH, kernelW = 0, 0
# read parameters and random list
with open("Bonus/rand_input.txt", "r") as rand_file:
    it = rand_file.readlines()
    input_X_dimC, input_X_dimH, input_X_dimW = ls_to_int(it[0].strip().split())
    strideH, strideW = ls_to_int(it[1].strip().split())
    kernelH, kernelW = ls_to_int(it[2].strip().split())
    ls = ls_to_int(eval(it[3]))
# ...
# run maxpool algorithm
ls = torch.tensor(ls).reshape((input_X_dimC, input_X_dimH, input_X_dimW))
# ...
maxpool = torch.nn.MaxPool2d(kernel_size=(kernelH, kernelW), stride=(strideH, strideW))
ls: torch.Tensor = maxpool.forward((ls))
ls = torch.flatten(ls)
# ...
# write result to file
with open("Bonus/py_output.txt", "w") as f:
    for n in ls.numpy():
        f.write(f"{n}\n")

One can run code with:

./scripts/run bonus_c

and see the all-passed information:

Passed 20 random tests
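Besides random tests, a small case that can be checked by hand is a useful complement. The sketch below is a compact reference, not the function from the repository: maxpool2d_ref is a hypothetical name, and it starts each window from INT8_MIN so that negative inputs are also handled.

```c
#include <stdint.h>

/* Reference max pooling over a C x H x W int8_t tensor, no padding.
 * Hypothetical helper for hand-checkable cases. */
void maxpool2d_ref(const int8_t *in, int32_t W, int32_t H, int32_t C,
                   int8_t *out, int32_t kW, int32_t kH, int32_t sW,
                   int32_t sH) {
  const int32_t h_out = (H - kH) / sH + 1;
  const int32_t w_out = (W - kW) / sW + 1;
  for (int32_t c = 0; c < C; ++c) {
    for (int32_t h = 0; h < h_out; ++h) {
      for (int32_t w = 0; w < w_out; ++w) {
        int8_t mx = INT8_MIN; /* start below every possible input value */
        for (int32_t kh = 0; kh < kH; ++kh) {
          for (int32_t kw = 0; kw < kW; ++kw) {
            int8_t cur = in[c * H * W + (h * sH + kh) * W + (w * sW + kw)];
            if (cur > mx) mx = cur;
          }
        }
        out[c * h_out * w_out + h * w_out + w] = mx;
      }
    }
  }
}
```

With a single 4x4 channel holding 1 to 16 in row-major order, a 2x2 kernel and stride 2, the four windows have maxima 6, 8, 14 and 16.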

Step 2: Assembly Implementation

The implementation of maxpool2d is not that hard.

## maxpool2d.S

    .text                            # code section
    .global maxpool2d_asm            # declare the maxpool2d_asm function as a global function
    .type   maxpool2d_asm, @function # define maxpool2d_asm as a function

maxpool2d_asm:
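# maxpool2d_asm(
# const int8_t *input_X = a0
# int32_t input_X_dimW = a1
# int32_t input_X_dimH = a2
# int32_t input_X_dimC = a3
# int8_t *output_Y = a4
# int32_t kernel_WH = a5 (kernel_W in upper 16 bits, kernel_H in lower 16 bits) -> kernel_H
# int32_t stride_WH = a6 (stride_W in upper 16 bits, stride_H in lower 16 bits) -> stride_H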
# )
# prologue: push stack
    addi    sp, sp, -52
    sw      ra, 0(sp)
    sw      s0, 4(sp)                # kernel_W
    sw      s1, 8(sp)                # stride_W
    sw      s2, 12(sp)               # h_out
    sw      s3, 16(sp)               # w_out
    sw      s4, 20(sp)               # square_sz
    sw      s5, 24(sp)               # out_square_sz
    sw      s6, 28(sp)               # c
    sw      s7, 32(sp)               # c_offset
    sw      s8, 36(sp)               # out_c_offset
    sw      s9, 40(sp)               # h
    sw      s10, 44(sp)              # w
    sw      s11, 48(sp)              # mx
# t0: h_offset
# t1: w_offset
# t2: kh
# t3: kw
# t4: h_idx -> idx -> position of input_X[idx] -> cur
# t5: w_idx -> o_idx -> position of output_Y[o_idx], then assign value

    srli    s0, a5, 16               # s0 = kernel_W
    li      t0, 0x0ffff
    and     a5, a5, t0               # a5 = kernel_H (mask with 0x0000ffff)

    srli    s1, a6, 16               # s1 = stride_W
    and     a6, a6, t0               # a6 = stride_H (mask with 0x0000ffff)

    sub     s2, a2, a5
    div     s2, s2, a6
    addi    s2, s2, 1                # s2 = h_out

    sub     s3, a1, s0
    div     s3, s3, s1
    addi    s3, s3, 1                # s3 = w_out

    mul     s4, a1, a2               # s4 = square_sz
    mul     s5, s2, s3               # s5 = out_square_sz

# for (int c = 0; c < input_X_dimC; ++c) {
    li      s6, 0                    # c = 0
loopC:
    beq     s6, a3, endLoopC         # if c == input_X_dimC, end
    mul     s7, s4, s6               # s7 = c_offset
    mul     s8, s5, s6               # s8 = out_c_offset

# for (int h = 0; h < h_out; ++h) {
    li      s9, 0                    # h = 0
loopH:
    beq     s9, s2, endLoopH         # if h == h_out, end

# for (int w = 0; w < w_out; ++w) {
    li      s10, 0                   # w = 0
loopW:
    beq     s10, s3, endLoopW        # if w == w_out, end

    li      s11, 0                   # mx = 0 (valid because the test inputs are non-negative)
    mul     t0, a6, s9               # t0 = h_offset
    mul     t1, s1, s10              # t1 = w_offset

    li      t2, 0                    # kh = 0
loopKH:
    beq     t2, a5, endLoopKH        # if kh == kernel_H, end

    li      t3, 0                    # kw = 0
loopKW:
    beq     t3, s0, endLoopKW        # if kw == kernel_W, end

    add     t4, t0, t2               # t4 = h_idx
    add     t5, t1, t3               # t5 = w_idx

    mul     t4, t4, a1               # h_idx * input_X_dimW
    add     t4, t4, t5               # h_idx * input_X_dimW + w_idx
    add     t4, t4, s7               # idx = h_idx * input_X_dimW + w_idx + c_offset
    add     t4, t4, a0               # position of input_X[idx]
    lb      t4, 0(t4)                # t4 = cur = input_X[idx]

# get max using branching
    bge     s11, t4, noMaxChange     # if mx >= cur -> no change
    mv      s11, t4                  # otherwise, mx = cur
noMaxChange:
    addi    t3, t3, 1                # kw++
    j       loopKW

endLoopKW:
    addi    t2, t2, 1                # kh++
    j       loopKH

endLoopKH:
# store local max
    mul     t5, s9, s3               # h * w_out
    add     t5, t5, s10              # h * w_out + w
    add     t5, t5, s8               # o_idx = h * w_out + w + out_c_offset
    add     t5, t5, a4               # position of output_Y[o_idx]
    sb      s11, 0(t5)               # output_Y[o_idx] = mx

    addi    s10, s10, 1              # w++
    j       loopW

endLoopW:
    addi    s9, s9, 1                # h++
    j       loopH

endLoopH:
    addi    s6, s6, 1                # c++
    j       loopC

endLoopC:
# epilogue: pop stack
    lw      ra, 0(sp)
    lw      s0, 4(sp)
    lw      s1, 8(sp)
    lw      s2, 12(sp)
    lw      s3, 16(sp)
    lw      s4, 20(sp)
    lw      s5, 24(sp)
    lw      s6, 28(sp)
    lw      s7, 32(sp)
    lw      s8, 36(sp)
    lw      s9, 40(sp)
    lw      s10, 44(sp)
    lw      s11, 48(sp)
    addi    sp, sp, 52
    ret

    .size   maxpool2d_asm, .-maxpool2d_asm

There are some tricks that simplify and optimize the implementation.

  • We combine kernel_H and kernel_W, and likewise stride_H and stride_W, into single arguments. Since the RISC-V calling convention passes at most 8 integer arguments in registers (a0-a7) and maxpool2d_c takes 9, this trick shrinks the argument count to 7, so all of them can be passed without touching the stack.
  • Just like HW6-3, we, the human compiler, apply loop-invariant code motion to reduce redundant operations wherever possible.
  • All local variables are kept in registers only. This speeds up the whole process by avoiding redundant memory accesses.

:::info Note: While in C we use a bitwise trick to implement a branchless max operation, we use a normal branch in assembly, since the branchless version takes noticeably more instructions. :::
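The argument-packing trick can be sketched in C as follows. The helper names pack_wh, unpack_w and unpack_h are hypothetical; the sketch assumes both values are non-negative and fit in 16 bits, which holds for the small kernel and stride sizes used in the tests.

```c
#include <stdint.h>

/* Packs two 16-bit values into one 32-bit argument:
 * W goes to the upper half, H to the lower half. */
uint32_t pack_wh(uint32_t w, uint32_t h) {
  return (w << 16) | (h & 0xffffu);
}

/* Mirrors `srli s0, a5, 16` in maxpool2d_asm. */
uint32_t unpack_w(uint32_t packed) { return packed >> 16; }

/* Mirrors `and a5, a5, t0` with t0 = 0xffff. */
uint32_t unpack_h(uint32_t packed) { return packed & 0xffffu; }
```

For example, packing W = 3 and H = 2 gives 0x00030002, from which the two shifts and masks recover 3 and 2.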

Verification

One can run script to compile and run the code with:

./scripts/run bonus_asm
./scripts/run sim Bonus/bonus_asm

and see the passed message:

Passed 100 random tests
Baremetal code runs to the end. (pressing `ctrl+a x` to quit)
QEMU: Terminated

The implementation is a little different from the C-Python version, since we can't use srand(time(NULL)) or malloc in the bare-metal environment. However, we can still run random tests by declaring fixed-size arrays that are large enough for a bounded range of dimensions, and drawing every random test case within those bounds.
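One way to get test data without malloc or a time-based seed is a fixed-size static buffer filled by a small deterministic generator. The sketch below is an assumption about how this could look, not the code in main.c: the limits, the buffer name, and the linear congruential generator (with its constants) are all hypothetical choices for illustration.

```c
#include <stdint.h>

/* Hypothetical upper bounds for one test case. */
enum { MAX_DIM_H = 8, MAX_DIM_W = 8, MAX_DIM_C = 3 };

/* Fixed-size storage replaces malloc; sized for the largest case. */
static int8_t input_buf[MAX_DIM_H * MAX_DIM_W * MAX_DIM_C];

/* Minimal linear congruential generator with a fixed seed, so every
 * run produces the same sequence without time(NULL). */
static uint32_t lcg_state = 1u;
static uint32_t lcg_next(void) {
  lcg_state = lcg_state * 1103515245u + 12345u;
  return (lcg_state >> 16) & 0x7fffu;
}

/* Fills the first h * w * c entries with values in [0, INT8_MAX - 1].
 * Returns the number of entries written, or -1 if the requested shape
 * does not fit in the static buffer. */
int fill_random_input(int32_t h, int32_t w, int32_t c) {
  if (h < 1 || w < 1 || c < 1 ||
      h > MAX_DIM_H || w > MAX_DIM_W || c > MAX_DIM_C) {
    return -1;
  }
  int n = (int)(h * w * c);
  for (int i = 0; i < n; ++i) {
    input_buf[i] = (int8_t)(lcg_next() % INT8_MAX);
  }
  return n;
}
```

Because the seed is fixed, a failing case can be reproduced exactly by rerunning the same sequence of calls.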

Code below shows how we run the programs and verify the result.

// in single_c_asm_random_test
int8_t c_output[max_input_X_dimH * max_input_X_dimW * max_input_X_dimC];
int8_t asm_output[max_input_X_dimH * max_input_X_dimW * max_input_X_dimC];

maxpool2d_c(input, input_X_dimW, input_X_dimH, input_X_dimC, c_output,
            kernel_W, kernel_H, stride_W, stride_H);


maxpool2d_asm(input, input_X_dimW, input_X_dimH, input_X_dimC, asm_output,
              (kernel_W << (sizeof(int32_t) * 4)) | kernel_H,
              (stride_W << (sizeof(int32_t) * 4)) | stride_H);

int wa_cnt = 0;
char str[25];

const int32_t out_square_sz = h_out * w_out;
for (int c = 0; c < input_X_dimC; ++c) {
  for (int h = 0; h < h_out; ++h) {
    for (int w = 0; w < w_out; ++w) {
      int idx = h * w_out + w + out_square_sz * c;
      if (c_output[idx] != asm_output[idx]) {
        puts("c: [");
        itoa(c_output[idx], str, 10);
        puts(str);
        puts("] != asm: [");
        itoa(asm_output[idx], str, 10);
        puts(str);
        puts("]\n");
        wa_cnt += 1;
      }
    }
  }
}

return wa_cnt;

Step 3 Performance Optimization

This is mentioned in the implementation of assembly code :)
