diff options
Diffstat (limited to 'stm32/unittest/scripts/tst_ofdm_demod_check')
| -rwxr-xr-x | stm32/unittest/scripts/tst_ofdm_demod_check | 503 |
1 files changed, 503 insertions, 0 deletions
diff --git a/stm32/unittest/scripts/tst_ofdm_demod_check b/stm32/unittest/scripts/tst_ofdm_demod_check new file mode 100755 index 0000000..06e471c --- /dev/null +++ b/stm32/unittest/scripts/tst_ofdm_demod_check @@ -0,0 +1,503 @@ +#!/usr/bin/env python3 +""" tst_ofdm_demod_check + + Testing for tst_ofdm_demod_* tests + + Usage tst_ofdm_demod_check <dummy_test_name> quick|ideal|AWGN|fade|profile|ldpc|ldpc_AWGN|ldpc_fade + + Checks are different for each option, but similar + + - Convert stm32 output to octave text format + (stm32 does not have memory for this) + + - ... + + """ + +import numpy as np +import math +import argparse +import struct +import os +import sys + +if ("UNITTEST_BASE" in os.environ): + sys.path.append(os.environ["UNITTEST_BASE"] + "/lib/python") +else: + sys.path.append("../../lib/python") # assume in test run dir + +import sum_profiles + +Nbitsperframe = 238 + +############################################################################## +# Read Octave text file +############################################################################## + +def read_octave_text(f): + if (args.verbose): print('read_octave_text()') + data = {} + for line in f: + if (line[0:8] == "# name: "): + var = line.split()[2] + if (args.verbose): print(' var "{}"'.format(var)) + line = next(f) + if (line.startswith("# type: matrix")): + line = next(f) + rows = int(line.split()[2]) + line = next(f) + cols = int(line.split()[2]) + if (cols > 0): + data[var] = np.empty((rows, cols), np.float32) + # Read rows one at a time + for row in range(rows): + try: + line = next(f) + data[var][row] = np.fromstring(line, np.float32, cols, " ") + except: + print("Error reading row {} of var {}".format(row, var)) + raise + + elif (line.startswith("# type: complex matrix")): + line = next(f) + rows = int(line.split()[2]) + line = next(f) + cols = int(line.split()[2]) + if (cols > 0): + data[var] = np.empty((rows, cols), np.complex64) + # Read rows one at a time + for row in range(rows): + try: + line = next(f) + # " (r,i) (r,i) ..." + col = 0 + for tpl in line.split(): + real, imag = tpl.strip("(,)").split(",") + data[var][row][col] = float(real) + (1j * float(imag)) + col += 1 + except: + print("Error reading row {} of var {}".format(row, var)) + raise + # end for line in f + return(data) + + +############################################################################## +# Read stm32 diag data, syms, amps for each frame +############################################################################## + +def read_tgt_syms(f): + # TODO: don't use hardcoded values... + syms = np.zeros((100, 112), np.complex64) + amps = np.zeros((100, 112), np.float32) + row = 0 + while True: + # syms + buf = f.read(112 * 8) + if (len(buf) < (112 * 8)): break + row_lst = struct.unpack("<224f", buf) + ary = np.array(row_lst, np.float32) + ary.dtype = np.complex64 + syms[row] = ary + # amps + buf = f.read(112 * 4) + if (len(buf) < (112 * 4)): break + row_lst = struct.unpack("<112f", buf) + ary = np.array(row_lst, np.float32) + amps[row] = ary + # + row += 1 + if (row >= 100): break + # end While True + return(syms, amps) + # end read_stm_syms() + + +############################################################################## +# Write out in octave text format as 2 matricies +############################################################################## + +def write_syms_as_octave(syms, amps): + with open("ofdm_demod_log.txt", "w") as f: + # syms + rows = syms.shape[0] + cols = syms.shape[1] + f.write("# name: payload_syms_log_stm32\n") + f.write("# type: complex matrix\n") + f.write("# rows: {}\n".format(rows)) + f.write("# columns: {}\n".format(cols)) + for row in range(rows): + for col in range(cols): + f.write(" ({},{})".format( + syms[row][col].real, + syms[row][col].imag + )) + f.write("\n") + # amps + rows = amps.shape[0] + cols = amps.shape[1] + f.write("\n") + f.write("# name: payload_amps_log_stm32\n") + f.write("# type: matrix\n") + f.write("# rows: {}\n".format(rows)) + f.write("# columns: {}\n".format(cols)) + for row in range(rows): + for col in range(cols): + f.write(" {}".format( + amps[row][col] + )) + f.write("\n") + # end write_syms_as_octave() + + +############################################################################## +# Main +############################################################################## + +#### Options +argparser = argparse.ArgumentParser() +argparser.add_argument("-v", "--verbose", action="store_true") +argparser.add_argument("test", action="store") +argparser.add_argument("test_opt", action="store", + choices=["quick", "ideal", "AWGN", "fade", "profile", + "ldpc", "ldpc_AWGN", "ldpc_fade" ]) +args = argparser.parse_args() + +# Use ENV value of UNITTEST_BASE from upper level shell script (default to .) +if ('UNITTEST_BASE' in os.environ): + run_dir = os.environ['UNITTEST_BASE'] + "/test_run/" + run_dir += args.test + "_" + args.test_opt + print(run_dir) + os.chdir(run_dir) + +#### Settings +# Defaults, (for tests without channel degradation, results should be close to ideal) +max_ber = 0.001 # Max BER value in Target +max_ber2 = 0.001 # Max Coded BER value in Target +compare_ber = 1 # Compare Target to Reference? +# Used if compare_ber: +tolerance_ber = 0.001 # Difference from reference for BER +tolerance_ber2 = 0.001 # Difference from reference for Coded BER +tolerance_tbits = 0 +tolerance_terrs = 1 +# +compare_output = 1 # Compare Target to Reference? +# Used if compare_output: +tolerance_output_differences = 0 +tolerance_syms = 0.01 +tolerance_amps = 0.01 +# +# Per test settings +if (args.test_opt == "quick"): + pass +elif (args.test_opt == "ideal"): + pass +elif (args.test_opt == "AWGN"): # Still close enough to compare BERs loosely + max_ber = 0.1 + max_ber2 = 0.1 + tolerance_ber = 0.01 + tolerance_ber2 = 0.005 + tolerance_tbits = 1000 + tolerance_terrs = 50 + tolerance_output_differences = 2 + compare_output = 0 +elif (args.test_opt == "fade"): + max_ber = 0.1 + max_ber2 = 0.1 + tolerance_ber = 0.01 + tolerance_ber2 = 0.005 + tolerance_tbits = 1000 + tolerance_terrs = 200 + tolerance_output_differences = 5 + compare_output = 0 + pass +elif (args.test_opt == "profile"): + tolerance_output_differences = 1 + pass +elif (args.test_opt == "ldpc"): + pass +elif (args.test_opt == "ldpc_AWGN"): + max_ber = 0.1 + max_ber2 = 0.01 + compare_ber = 0 + compare_output = 0 +elif (args.test_opt == "ldpc_fade"): + max_ber = 0.1 + max_ber2 = 0.01 + compare_ber = 0 + compare_output = 0 + pass +else: + print("Error: Test {} not recognized".format(args.test_opt)) + sys.exit(1) + +#### Check that we are in the test directory: +#### TODO::: + + +#### Read test configuration - a file of '0' or '1' characters +with open("stm_cfg.txt", "r") as f: + config = f.read(8) + config_verbose = (config[0] == '1') + config_testframes = (config[1] == '1') + config_ldpc_en = (config[2] == '1') + config_log_payload_syms = (config[3] == '1') + config_profile = (config[4] == '1') + + +#### +fails = 0 + + +if (config_testframes): + #### BER checks - log output looks like this, for non-ldpc: + # BER......: 0.1951 Tbits: 14994 Terrs: 2926 + # BER2.....: 0.2001 Tbits: 10234 Terrs: 2048 + # + # Or this, for ldpc: + # BER......: 0.0000 Tbits: 15008 Terrs: 0 + # Coded BER: 0.0000 Tbits: 7504 Terrs: 0 + # + # HACK: store "Coded BER" info as BER2. + + print("\nBER checks") + + # Read ref log + print("Reference") + with open("ref_gen_log.txt", "r") as f: + for line in f: + if (line[0:4] == "BER2"): + print(line, end="") + _, ref_ber2, _, ref_tbits2, _, ref_terrs2 = line.split() + elif (line[0:3] == "BER"): + print(line, end="") + _, ref_ber, _, ref_tbits, _, ref_terrs, _, ref_tpackets, _, ref_snr3k = line.split() + elif (line[0:9] == "Coded BER"): + print(line, end="") + _, _, ref_ber2, _, ref_tbits2, _, ref_terrs2 = line.split() + + # Strings to integers + ref_ber = float(ref_ber) + ref_tbits = int(ref_tbits) + ref_terrs = int(ref_terrs) + ref_ber2 = float(ref_ber2) + ref_tbits2 = int(ref_tbits2) + ref_terrs2 = int(ref_terrs2) + + # Read stm log + print("Target") + with open("stdout.log", "r") as f: + for line in f: + if (line[0:4] == "BER2"): + print(line, end="") + _, tgt_ber2, _, tgt_tbits2, _, tgt_terrs2 = line.split() + elif (line[0:3] == "BER"): + print(line, end="") + _, tgt_ber, _, tgt_tbits, _, tgt_terrs = line.split() + elif (line[0:9] == "Coded BER"): + print(line, end="") + _, _, tgt_ber2, _, tgt_tbits2, _, tgt_terrs2 = line.split() + # Strings to integers + tgt_ber = float(tgt_ber) + tgt_tbits = int(tgt_tbits) + tgt_terrs = int(tgt_terrs) + tgt_ber2 = float(tgt_ber2) + tgt_tbits2 = int(tgt_tbits2) + tgt_terrs2 = int(tgt_terrs2) + + # simple hack to tolerate zero bits > NAN + if (math.isnan(ref_ber2)): ref_ber2 = 0 + if (math.isnan(tgt_ber2)): tgt_ber2 = 0 + + ## Max BER values + if ((tgt_ber > max_ber) or (tgt_ber2 > max_ber2)): + fails += 1 + print("FAIL: max BER") + else: + print("PASS: max BER") + + ## Compare BER values + if (compare_ber): + chk_tolerance_ber = abs(ref_ber - tgt_ber) + chk_tolerance_tbits = abs(ref_tbits - tgt_tbits) + chk_tolerance_terrs = abs(ref_terrs - tgt_terrs) + chk_tolerance_ber2 = abs(ref_ber2 - tgt_ber2) + chk_tolerance_tbits2 = abs(ref_tbits2 - tgt_tbits2) + chk_tolerance_terrs2 = abs(ref_terrs2 - tgt_terrs2) + passes = True + + if (chk_tolerance_ber > tolerance_ber): + print("fail tolerance_ber {} > {}". + format(chk_tolerance_ber, tolerance_ber)) + passes = False + if (chk_tolerance_tbits > tolerance_tbits): + print("fail tolerance_tbits {} > {}". + format(chk_tolerance_tbits, tolerance_tbits)) + passes = False + if (chk_tolerance_terrs > tolerance_terrs): + print("fail tolerance_terrs {} > {}". + format(chk_tolerance_terrs, tolerance_terrs)) + passes = False + if (ref_tbits2 == 0): + if (chk_tolerance_ber2 > tolerance_ber2): + print("fail tolerance_ber2 {} > {}". + format(chk_tolerance_ber2, tolerance_ber2)) + passes = False + if (chk_tolerance_tbits2 > tolerance_tbits): + print("fail tolerance_tbits2 {} > {}". + format(chk_tolerance_tbits2, tolerance_tbits)) + passes = False + if (chk_tolerance_terrs2 > tolerance_terrs): + print("fail tolerance_terrs2 {} > {}". + format(chk_tolerance_terrs2, tolerance_terrs)) + passes = False + + if (passes): + print("PASS: BER compare") + else: + fails += 1 + print("FAIL: BER compare") + + # end Compare BER + # end BER checks + +#### Output differences +if (compare_output): + + print("\nOutput checks") + + # Output is a binary file of bytes whose values are 0x00 or 0x01. + with open("ref_demod_out.raw", "rb") as f: ref_out_bytes = f.read() + with open("stm_out.raw", "rb") as f: tgt_out_bytes = f.read() + if (len(ref_out_bytes) != len(tgt_out_bytes)): + fails += 1 + print("FAIL Output, length mismatch") + else: + output_diffs = 0 + for i in range(len(ref_out_bytes)): + fnum = math.floor(i/Nbitsperframe) + bnum = i - (fnum * Nbitsperframe) + # Both legal values?? + if (ref_out_bytes[i] > 1): + print("Error: Output frame {} byte {} not 0 or 1 in reference data".format(fnum, bnum)) + fails += 1 + if (tgt_out_bytes[i] > 1): + print("Error: Output frame {} byte {} not 0 or 1 in target data".format(fnum, bnum)) + fails += 1 + # Match?? + if (ref_out_bytes[i] != tgt_out_bytes[i]): + print("Output frame {} byte {} mismatch: ref={} tgt={}".format( + fnum, bnum, ref_out_bytes[i], tgt_out_bytes[i])) + output_diffs += 1 + # end for i + if (output_diffs > tolerance_output_differences): + print("FAIL: Output Differences = {}".format(output_diffs)) + fails += 1 + else: + print("PASS: Output Differences = {}".format(output_diffs)) + # end not length mismatch + + + #### Syms data + if (config_log_payload_syms): + print("\nSyms and Amps checks") + + fref = open("ofdm_demod_ref_log.txt", "r") + fdiag = open("stm_diag.raw", "rb") + + ref_data = read_octave_text(fref) + (tgt_syms, tgt_amps) = read_tgt_syms(fdiag) + fdiag.close() + write_syms_as_octave(tgt_syms, tgt_amps) # for manual debug... + + # Find smallest common subset + hgt = min(tgt_syms.shape[0], ref_data["payload_syms_log_c"].shape[0]) + wid = min(tgt_syms.shape[1], ref_data["payload_syms_log_c"].shape[1]) + + ref_syms = ref_data["payload_syms_log_c"][:hgt][:wid] + ref_amps = ref_data["payload_amps_log_c"][:hgt][:wid] + tgt_syms= tgt_syms[:hgt][:wid] + tgt_amps= tgt_amps[:hgt][:wid] + + # Eliminate trailing rows of all zeros + # Sum the rows to find rows of all zeros + row_sums = ref_syms.sum(axis=1) + tgt_syms.sum(axis=1) + nonzeros = row_sums.nonzero() + last_nonzero = nonzeros[0][-1] + # stop index is 1 past the last!! + # and use the Magnitude of the complex values + ref_syms = np.abs(ref_syms[:last_nonzero+1]) + ref_amps = np.abs(ref_amps[:last_nonzero+1]) + tgt_syms = np.abs(tgt_syms[:last_nonzero+1]) + tgt_amps = np.abs(tgt_amps[:last_nonzero+1]) + + # Differences - Syms + #diffs_syms = np.abs(ref_syms - tgt_syms) # This is the mag of complex + diffs_syms = np.abs(np.divide((ref_syms - tgt_syms), ref_syms, + where=(ref_syms!=0))) + print("Minimum syms difference = {:.6f}".format(np.amin(diffs_syms))) + print("Maximum syms difference = {:.6f}".format(np.amax(diffs_syms))) + print("Average syms difference = {:.6f}".format(np.average(diffs_syms))) + if (args.verbose): # Print top 10 differences + diffs_syms_sorted_indexes = (diffs_syms).argsort(axis=None)[::-1] + print(" Top 10 differences") + for i in range(10): + j = diffs_syms_sorted_indexes[i] + print(" #{} @{}: {} <?> {} = {:.6f}".format( + i, j, + ref_syms.flatten()[j], tgt_syms.flatten()[j], diffs_syms.flatten()[j]) + ) + # Errors are differences > tolerance_syms + errors_syms = diffs_syms - tolerance_syms + errors_syms[errors_syms < 0.0] = 0.0 + num_errors_syms = np.count_nonzero(errors_syms) + error_rows_syms = np.amax(errors_syms, axis=1) + num_error_rows_syms = np.count_nonzero(error_rows_syms) + print("") + print("{} symbol errors on {} rows".format(num_errors_syms, num_error_rows_syms)) + + # Differences - Amps + diffs_amps = np.abs(np.divide((ref_amps - tgt_amps), ref_amps, + where=(ref_amps!=0))) + print("Minimum amps difference = {:.6f}".format(np.amin(diffs_amps))) + print("Maximum amps difference = {:.6f}".format(np.amax(diffs_amps))) + print("Average amps difference = {:.6f}".format(np.average(diffs_amps))) + if (args.verbose): # Print top 10 differences + diffs_amps_sorted_indexes = (diffs_amps).argsort(axis=None)[::-1] + print(" Top 10 differences") + for i in range(10): + j = diffs_amps_sorted_indexes[i] + print(" #{} @{}: {} <?> {} = {:.6f}".format( + i, j, + ref_amps.flatten()[j], tgt_amps.flatten()[j], diffs_amps.flatten()[j]) + ) + + # Errors are differences > tolerance_syms + errors_amps = diffs_amps - tolerance_amps + errors_amps[errors_amps < 0.0] = 0.0 + num_errors_amps = np.count_nonzero(errors_amps) + error_rows_amps = np.amax(errors_amps, axis=1) + num_error_rows_amps = np.count_nonzero(error_rows_amps) + print("") + print("{} Amplitude errors on {} rows".format(num_errors_amps, num_error_rows_amps)) + # End compare_output + + +#### Profile +if (config_profile): + print("\nProfile:") + with open("stdout.log", "r") as f: + sum_profiles.sum_profiles(f, 100) + + print("\nStack:") + with open("stdout.txt", "r") as f: + for line in f: + if (line.startswith("Max stack")): + print(line) + + +#### Print final status message +if (fails): print("\nTest FAILED!") +else: print("\nTest PASSED") + +sys.exit(fails) |
