aboutsummaryrefslogtreecommitdiff
path: root/stm32/unittest/scripts/tst_ofdm_demod_check
diff options
context:
space:
mode:
Diffstat (limited to 'stm32/unittest/scripts/tst_ofdm_demod_check')
-rwxr-xr-xstm32/unittest/scripts/tst_ofdm_demod_check503
1 files changed, 503 insertions, 0 deletions
diff --git a/stm32/unittest/scripts/tst_ofdm_demod_check b/stm32/unittest/scripts/tst_ofdm_demod_check
new file mode 100755
index 0000000..06e471c
--- /dev/null
+++ b/stm32/unittest/scripts/tst_ofdm_demod_check
@@ -0,0 +1,503 @@
+#!/usr/bin/env python3
+""" tst_ofdm_demod_check
+
+ Testing for tst_ofdm_demod_* tests
+
+ Usage tst_ofdm_demod_check <dummy_test_name> quick|ideal|AWGN|fade|profile|ldpc|ldpc_AWGN|ldpc_fade
+
+ Checks are different for each option, but similar
+
+ - Convert stm32 output to octave text format
+ (stm32 does not have memory for this)
+
+ - ...
+
+ """
+
+import numpy as np
+import math
+import argparse
+import struct
+import os
+import sys
+
+if ("UNITTEST_BASE" in os.environ):
+ sys.path.append(os.environ["UNITTEST_BASE"] + "/lib/python")
+else:
+ sys.path.append("../../lib/python") # assume in test run dir
+
+import sum_profiles
+
+Nbitsperframe = 238
+
+##############################################################################
+# Read Octave text file
+##############################################################################
+
+def read_octave_text(f):
+ if (args.verbose): print('read_octave_text()')
+ data = {}
+ for line in f:
+ if (line[0:8] == "# name: "):
+ var = line.split()[2]
+ if (args.verbose): print(' var "{}"'.format(var))
+ line = next(f)
+ if (line.startswith("# type: matrix")):
+ line = next(f)
+ rows = int(line.split()[2])
+ line = next(f)
+ cols = int(line.split()[2])
+ if (cols > 0):
+ data[var] = np.empty((rows, cols), np.float32)
+ # Read rows one at a time
+ for row in range(rows):
+ try:
+ line = next(f)
+ data[var][row] = np.fromstring(line, np.float32, cols, " ")
+ except:
+ print("Error reading row {} of var {}".format(row, var))
+ raise
+
+ elif (line.startswith("# type: complex matrix")):
+ line = next(f)
+ rows = int(line.split()[2])
+ line = next(f)
+ cols = int(line.split()[2])
+ if (cols > 0):
+ data[var] = np.empty((rows, cols), np.complex64)
+ # Read rows one at a time
+ for row in range(rows):
+ try:
+ line = next(f)
+ # " (r,i) (r,i) ..."
+ col = 0
+ for tpl in line.split():
+ real, imag = tpl.strip("(,)").split(",")
+ data[var][row][col] = float(real) + (1j * float(imag))
+ col += 1
+ except:
+ print("Error reading row {} of var {}".format(row, var))
+ raise
+ # end for line in f
+ return(data)
+
+
+##############################################################################
+# Read stm32 diag data, syms, amps for each frame
+##############################################################################
+
+def read_tgt_syms(f):
+ # TODO: don't use hardcoded values...
+ syms = np.zeros((100, 112), np.complex64)
+ amps = np.zeros((100, 112), np.float32)
+ row = 0
+ while True:
+ # syms
+ buf = f.read(112 * 8)
+ if (len(buf) < (112 * 8)): break
+ row_lst = struct.unpack("<224f", buf)
+ ary = np.array(row_lst, np.float32)
+ ary.dtype = np.complex64
+ syms[row] = ary
+ # amps
+ buf = f.read(112 * 4)
+ if (len(buf) < (112 * 4)): break
+ row_lst = struct.unpack("<112f", buf)
+ ary = np.array(row_lst, np.float32)
+ amps[row] = ary
+ #
+ row += 1
+ if (row >= 100): break
+ # end While True
+ return(syms, amps)
+ # end read_stm_syms()
+
+
+##############################################################################
+# Write out in octave text format as 2 matricies
+##############################################################################
+
+def write_syms_as_octave(syms, amps):
+ with open("ofdm_demod_log.txt", "w") as f:
+ # syms
+ rows = syms.shape[0]
+ cols = syms.shape[1]
+ f.write("# name: payload_syms_log_stm32\n")
+ f.write("# type: complex matrix\n")
+ f.write("# rows: {}\n".format(rows))
+ f.write("# columns: {}\n".format(cols))
+ for row in range(rows):
+ for col in range(cols):
+ f.write(" ({},{})".format(
+ syms[row][col].real,
+ syms[row][col].imag
+ ))
+ f.write("\n")
+ # amps
+ rows = amps.shape[0]
+ cols = amps.shape[1]
+ f.write("\n")
+ f.write("# name: payload_amps_log_stm32\n")
+ f.write("# type: matrix\n")
+ f.write("# rows: {}\n".format(rows))
+ f.write("# columns: {}\n".format(cols))
+ for row in range(rows):
+ for col in range(cols):
+ f.write(" {}".format(
+ amps[row][col]
+ ))
+ f.write("\n")
+ # end write_syms_as_octave()
+
+
+##############################################################################
+# Main
+##############################################################################
+
+#### Options
+argparser = argparse.ArgumentParser()
+argparser.add_argument("-v", "--verbose", action="store_true")
+argparser.add_argument("test", action="store")
+argparser.add_argument("test_opt", action="store",
+ choices=["quick", "ideal", "AWGN", "fade", "profile",
+ "ldpc", "ldpc_AWGN", "ldpc_fade" ])
+args = argparser.parse_args()
+
+# Use ENV value of UNITTEST_BASE from upper level shell script (default to .)
+if ('UNITTEST_BASE' in os.environ):
+ run_dir = os.environ['UNITTEST_BASE'] + "/test_run/"
+ run_dir += args.test + "_" + args.test_opt
+ print(run_dir)
+ os.chdir(run_dir)
+
+#### Settings
+# Defaults, (for tests without channel degradation, results should be close to ideal)
+max_ber = 0.001 # Max BER value in Target
+max_ber2 = 0.001 # Max Coded BER value in Target
+compare_ber = 1 # Compare Target to Reference?
+# Used if compare_ber:
+tolerance_ber = 0.001 # Difference from reference for BER
+tolerance_ber2 = 0.001 # Difference from reference for Coded BER
+tolerance_tbits = 0
+tolerance_terrs = 1
+#
+compare_output = 1 # Compare Target to Reference?
+# Used if compare_output:
+tolerance_output_differences = 0
+tolerance_syms = 0.01
+tolerance_amps = 0.01
+#
+# Per test settings
+if (args.test_opt == "quick"):
+ pass
+elif (args.test_opt == "ideal"):
+ pass
+elif (args.test_opt == "AWGN"): # Still close enough to compare BERs loosely
+ max_ber = 0.1
+ max_ber2 = 0.1
+ tolerance_ber = 0.01
+ tolerance_ber2 = 0.005
+ tolerance_tbits = 1000
+ tolerance_terrs = 50
+ tolerance_output_differences = 2
+ compare_output = 0
+elif (args.test_opt == "fade"):
+ max_ber = 0.1
+ max_ber2 = 0.1
+ tolerance_ber = 0.01
+ tolerance_ber2 = 0.005
+ tolerance_tbits = 1000
+ tolerance_terrs = 200
+ tolerance_output_differences = 5
+ compare_output = 0
+ pass
+elif (args.test_opt == "profile"):
+ tolerance_output_differences = 1
+ pass
+elif (args.test_opt == "ldpc"):
+ pass
+elif (args.test_opt == "ldpc_AWGN"):
+ max_ber = 0.1
+ max_ber2 = 0.01
+ compare_ber = 0
+ compare_output = 0
+elif (args.test_opt == "ldpc_fade"):
+ max_ber = 0.1
+ max_ber2 = 0.01
+ compare_ber = 0
+ compare_output = 0
+ pass
+else:
+ print("Error: Test {} not recognized".format(args.test_opt))
+ sys.exit(1)
+
+#### Check that we are in the test directory:
+#### TODO:::
+
+
+#### Read test configuration - a file of '0' or '1' characters
+with open("stm_cfg.txt", "r") as f:
+ config = f.read(8)
+ config_verbose = (config[0] == '1')
+ config_testframes = (config[1] == '1')
+ config_ldpc_en = (config[2] == '1')
+ config_log_payload_syms = (config[3] == '1')
+ config_profile = (config[4] == '1')
+
+
+####
+fails = 0
+
+
+if (config_testframes):
+ #### BER checks - log output looks like this, for non-ldpc:
+ # BER......: 0.1951 Tbits: 14994 Terrs: 2926
+ # BER2.....: 0.2001 Tbits: 10234 Terrs: 2048
+ #
+ # Or this, for ldpc:
+ # BER......: 0.0000 Tbits: 15008 Terrs: 0
+ # Coded BER: 0.0000 Tbits: 7504 Terrs: 0
+ #
+ # HACK: store "Coded BER" info as BER2.
+
+ print("\nBER checks")
+
+ # Read ref log
+ print("Reference")
+ with open("ref_gen_log.txt", "r") as f:
+ for line in f:
+ if (line[0:4] == "BER2"):
+ print(line, end="")
+ _, ref_ber2, _, ref_tbits2, _, ref_terrs2 = line.split()
+ elif (line[0:3] == "BER"):
+ print(line, end="")
+ _, ref_ber, _, ref_tbits, _, ref_terrs, _, ref_tpackets, _, ref_snr3k = line.split()
+ elif (line[0:9] == "Coded BER"):
+ print(line, end="")
+ _, _, ref_ber2, _, ref_tbits2, _, ref_terrs2 = line.split()
+
+ # Strings to integers
+ ref_ber = float(ref_ber)
+ ref_tbits = int(ref_tbits)
+ ref_terrs = int(ref_terrs)
+ ref_ber2 = float(ref_ber2)
+ ref_tbits2 = int(ref_tbits2)
+ ref_terrs2 = int(ref_terrs2)
+
+ # Read stm log
+ print("Target")
+ with open("stdout.log", "r") as f:
+ for line in f:
+ if (line[0:4] == "BER2"):
+ print(line, end="")
+ _, tgt_ber2, _, tgt_tbits2, _, tgt_terrs2 = line.split()
+ elif (line[0:3] == "BER"):
+ print(line, end="")
+ _, tgt_ber, _, tgt_tbits, _, tgt_terrs = line.split()
+ elif (line[0:9] == "Coded BER"):
+ print(line, end="")
+ _, _, tgt_ber2, _, tgt_tbits2, _, tgt_terrs2 = line.split()
+ # Strings to integers
+ tgt_ber = float(tgt_ber)
+ tgt_tbits = int(tgt_tbits)
+ tgt_terrs = int(tgt_terrs)
+ tgt_ber2 = float(tgt_ber2)
+ tgt_tbits2 = int(tgt_tbits2)
+ tgt_terrs2 = int(tgt_terrs2)
+
+ # simple hack to tolerate zero bits > NAN
+ if (math.isnan(ref_ber2)): ref_ber2 = 0
+ if (math.isnan(tgt_ber2)): tgt_ber2 = 0
+
+ ## Max BER values
+ if ((tgt_ber > max_ber) or (tgt_ber2 > max_ber2)):
+ fails += 1
+ print("FAIL: max BER")
+ else:
+ print("PASS: max BER")
+
+ ## Compare BER values
+ if (compare_ber):
+ chk_tolerance_ber = abs(ref_ber - tgt_ber)
+ chk_tolerance_tbits = abs(ref_tbits - tgt_tbits)
+ chk_tolerance_terrs = abs(ref_terrs - tgt_terrs)
+ chk_tolerance_ber2 = abs(ref_ber2 - tgt_ber2)
+ chk_tolerance_tbits2 = abs(ref_tbits2 - tgt_tbits2)
+ chk_tolerance_terrs2 = abs(ref_terrs2 - tgt_terrs2)
+ passes = True
+
+ if (chk_tolerance_ber > tolerance_ber):
+ print("fail tolerance_ber {} > {}".
+ format(chk_tolerance_ber, tolerance_ber))
+ passes = False
+ if (chk_tolerance_tbits > tolerance_tbits):
+ print("fail tolerance_tbits {} > {}".
+ format(chk_tolerance_tbits, tolerance_tbits))
+ passes = False
+ if (chk_tolerance_terrs > tolerance_terrs):
+ print("fail tolerance_terrs {} > {}".
+ format(chk_tolerance_terrs, tolerance_terrs))
+ passes = False
+ if (ref_tbits2 == 0):
+ if (chk_tolerance_ber2 > tolerance_ber2):
+ print("fail tolerance_ber2 {} > {}".
+ format(chk_tolerance_ber2, tolerance_ber2))
+ passes = False
+ if (chk_tolerance_tbits2 > tolerance_tbits):
+ print("fail tolerance_tbits2 {} > {}".
+ format(chk_tolerance_tbits2, tolerance_tbits))
+ passes = False
+ if (chk_tolerance_terrs2 > tolerance_terrs):
+ print("fail tolerance_terrs2 {} > {}".
+ format(chk_tolerance_terrs2, tolerance_terrs))
+ passes = False
+
+ if (passes):
+ print("PASS: BER compare")
+ else:
+ fails += 1
+ print("FAIL: BER compare")
+
+ # end Compare BER
+ # end BER checks
+
+#### Output differences
+if (compare_output):
+
+ print("\nOutput checks")
+
+ # Output is a binary file of bytes whose values are 0x00 or 0x01.
+ with open("ref_demod_out.raw", "rb") as f: ref_out_bytes = f.read()
+ with open("stm_out.raw", "rb") as f: tgt_out_bytes = f.read()
+ if (len(ref_out_bytes) != len(tgt_out_bytes)):
+ fails += 1
+ print("FAIL Output, length mismatch")
+ else:
+ output_diffs = 0
+ for i in range(len(ref_out_bytes)):
+ fnum = math.floor(i/Nbitsperframe)
+ bnum = i - (fnum * Nbitsperframe)
+ # Both legal values??
+ if (ref_out_bytes[i] > 1):
+ print("Error: Output frame {} byte {} not 0 or 1 in reference data".format(fnum, bnum))
+ fails += 1
+ if (tgt_out_bytes[i] > 1):
+ print("Error: Output frame {} byte {} not 0 or 1 in target data".format(fnum, bnum))
+ fails += 1
+ # Match??
+ if (ref_out_bytes[i] != tgt_out_bytes[i]):
+ print("Output frame {} byte {} mismatch: ref={} tgt={}".format(
+ fnum, bnum, ref_out_bytes[i], tgt_out_bytes[i]))
+ output_diffs += 1
+ # end for i
+ if (output_diffs > tolerance_output_differences):
+ print("FAIL: Output Differences = {}".format(output_diffs))
+ fails += 1
+ else:
+ print("PASS: Output Differences = {}".format(output_diffs))
+ # end not length mismatch
+
+
+ #### Syms data
+ if (config_log_payload_syms):
+ print("\nSyms and Amps checks")
+
+ fref = open("ofdm_demod_ref_log.txt", "r")
+ fdiag = open("stm_diag.raw", "rb")
+
+ ref_data = read_octave_text(fref)
+ (tgt_syms, tgt_amps) = read_tgt_syms(fdiag)
+ fdiag.close()
+ write_syms_as_octave(tgt_syms, tgt_amps) # for manual debug...
+
+ # Find smallest common subset
+ hgt = min(tgt_syms.shape[0], ref_data["payload_syms_log_c"].shape[0])
+ wid = min(tgt_syms.shape[1], ref_data["payload_syms_log_c"].shape[1])
+
+ ref_syms = ref_data["payload_syms_log_c"][:hgt][:wid]
+ ref_amps = ref_data["payload_amps_log_c"][:hgt][:wid]
+ tgt_syms= tgt_syms[:hgt][:wid]
+ tgt_amps= tgt_amps[:hgt][:wid]
+
+ # Eliminate trailing rows of all zeros
+ # Sum the rows to find rows of all zeros
+ row_sums = ref_syms.sum(axis=1) + tgt_syms.sum(axis=1)
+ nonzeros = row_sums.nonzero()
+ last_nonzero = nonzeros[0][-1]
+ # stop index is 1 past the last!!
+ # and use the Magnitude of the complex values
+ ref_syms = np.abs(ref_syms[:last_nonzero+1])
+ ref_amps = np.abs(ref_amps[:last_nonzero+1])
+ tgt_syms = np.abs(tgt_syms[:last_nonzero+1])
+ tgt_amps = np.abs(tgt_amps[:last_nonzero+1])
+
+ # Differences - Syms
+ #diffs_syms = np.abs(ref_syms - tgt_syms) # This is the mag of complex
+ diffs_syms = np.abs(np.divide((ref_syms - tgt_syms), ref_syms,
+ where=(ref_syms!=0)))
+ print("Minimum syms difference = {:.6f}".format(np.amin(diffs_syms)))
+ print("Maximum syms difference = {:.6f}".format(np.amax(diffs_syms)))
+ print("Average syms difference = {:.6f}".format(np.average(diffs_syms)))
+ if (args.verbose): # Print top 10 differences
+ diffs_syms_sorted_indexes = (diffs_syms).argsort(axis=None)[::-1]
+ print(" Top 10 differences")
+ for i in range(10):
+ j = diffs_syms_sorted_indexes[i]
+ print(" #{} @{}: {} <?> {} = {:.6f}".format(
+ i, j,
+ ref_syms.flatten()[j], tgt_syms.flatten()[j], diffs_syms.flatten()[j])
+ )
+ # Errors are differences > tolerance_syms
+ errors_syms = diffs_syms - tolerance_syms
+ errors_syms[errors_syms < 0.0] = 0.0
+ num_errors_syms = np.count_nonzero(errors_syms)
+ error_rows_syms = np.amax(errors_syms, axis=1)
+ num_error_rows_syms = np.count_nonzero(error_rows_syms)
+ print("")
+ print("{} symbol errors on {} rows".format(num_errors_syms, num_error_rows_syms))
+
+ # Differences - Amps
+ diffs_amps = np.abs(np.divide((ref_amps - tgt_amps), ref_amps,
+ where=(ref_amps!=0)))
+ print("Minimum amps difference = {:.6f}".format(np.amin(diffs_amps)))
+ print("Maximum amps difference = {:.6f}".format(np.amax(diffs_amps)))
+ print("Average amps difference = {:.6f}".format(np.average(diffs_amps)))
+ if (args.verbose): # Print top 10 differences
+ diffs_amps_sorted_indexes = (diffs_amps).argsort(axis=None)[::-1]
+ print(" Top 10 differences")
+ for i in range(10):
+ j = diffs_amps_sorted_indexes[i]
+ print(" #{} @{}: {} <?> {} = {:.6f}".format(
+ i, j,
+ ref_amps.flatten()[j], tgt_amps.flatten()[j], diffs_amps.flatten()[j])
+ )
+
+ # Errors are differences > tolerance_syms
+ errors_amps = diffs_amps - tolerance_amps
+ errors_amps[errors_amps < 0.0] = 0.0
+ num_errors_amps = np.count_nonzero(errors_amps)
+ error_rows_amps = np.amax(errors_amps, axis=1)
+ num_error_rows_amps = np.count_nonzero(error_rows_amps)
+ print("")
+ print("{} Amplitude errors on {} rows".format(num_errors_amps, num_error_rows_amps))
+ # End compare_output
+
+
+#### Profile
+if (config_profile):
+ print("\nProfile:")
+ with open("stdout.log", "r") as f:
+ sum_profiles.sum_profiles(f, 100)
+
+ print("\nStack:")
+ with open("stdout.txt", "r") as f:
+ for line in f:
+ if (line.startswith("Max stack")):
+ print(line)
+
+
+#### Print final status message
+if (fails): print("\nTest FAILED!")
+else: print("\nTest PASSED")
+
+sys.exit(fails)