luaunit.lua 105 KB


  1. --[[
  2. luaunit.lua
  3. Description: A unit testing framework
  4. Homepage: https://github.com/bluebird75/luaunit
  5. Development by Philippe Fremy <phil@freehackers.org>
  6. Based on initial work of Ryu, Gwang (http://www.gpgstudy.com/gpgiki/LuaUnit)
  7. License: BSD License, see LICENSE.txt
  8. Version: 3.3
  9. ]]--
  -- Module bootstrap and user-tunable configuration.
  --
  -- `module()` only exists in Lua 5.1 (it was deprecated in 5.2 and removed in
  -- 5.3). Calling it unconditionally makes the file fail to load on newer
  -- interpreters, so it is invoked only when the function is available and a
  -- module name was actually passed to the chunk (i.e. it was require()d).
  if type(rawget(_G, "module")) == "function" and type((...)) == "string" then
      module(..., package.seeall)
  end
  --require("math")

  local M = {}

  -- private exported functions (for testing)
  M.private = {}

  M.VERSION = '3.3'
  M._VERSION = M.VERSION -- For LuaUnit v2 compatibility

  -- Some people like assertEquals( actual, expected ) and some people prefer
  -- assertEquals( expected, actual ). `true` selects the (actual, expected)
  -- order when labelling the two values in failure messages.
  M.ORDER_ACTUAL_EXPECTED = true
  M.PRINT_TABLE_REF_IN_ERROR_MSG = false
  M.TABLE_EQUALS_KEYBYCONTENT = true
  -- tables whose one-line rendering reaches this width are printed multi-line
  M.LINE_LENGTH = 80
  M.TABLE_DIFF_ANALYSIS_THRESHOLD = 10 -- display deep analysis for more than 10 items
  M.LIST_DIFF_ANALYSIS_THRESHOLD = 10 -- display deep analysis for more than 10 items

  -- EPS is the default error margin used by almostEquals() when the caller
  -- does not provide one. It helps with simple floating point corner cases
  -- such as almostEquals(1.1-0.1, 1), where the operands have no exact binary
  -- representation.
  -- Lua may be compiled with different numeric precisions (single vs. double),
  -- so a suitable default is selected dynamically: start from the machine
  -- epsilon of "double" and fall back to the one of "float" if a trivial
  -- computation already shows a larger rounding error. Users may overwrite
  -- the value to suit their needs.
  -- See also: https://en.wikipedia.org/wiki/Machine_epsilon
  M.EPS = 2^-52 -- = machine epsilon for "double", ~2.22E-16
  if math.abs(1.1 - 1 - 0.1) > M.EPS then
      -- rounding error is above EPS, assume single precision
      M.EPS = 2^-23 -- = machine epsilon for "float", ~1.19E-07
  end

  -- set this to false to debug luaunit
  local STRIP_LUAUNIT_FROM_STACKTRACE = true

  M.VERBOSITY_DEFAULT = 10
  M.VERBOSITY_LOW = 1
  M.VERBOSITY_QUIET = 0
  M.VERBOSITY_VERBOSE = 20

  -- values accepted for the "doDeepAnalysis" argument of the table comparison
  -- helpers: nil = decide from the thresholds above, true = always,
  -- false = never
  M.DEFAULT_DEEP_ANALYSIS = nil
  M.FORCE_DEEP_ANALYSIS = true
  M.DISABLE_DEEP_ANALYSIS = false

  -- set EXPORT_ASSERT_TO_GLOBALS to have all asserts visible as global values
  -- EXPORT_ASSERT_TO_GLOBALS = true

  -- we need to keep a copy of the script args before it is overridden
  local cmdline_argv = rawget(_G, "arg")
  54. M.FAILURE_PREFIX = 'LuaUnit test FAILURE: ' -- prefix string for failed tests
  55. M.USAGE=[[Usage: lua <your_test_suite.lua> [options] [testname1 [testname2] ... ]
  56. Options:
  57. -h, --help: Print this help
  58. --version: Print version information
  59. -v, --verbose: Increase verbosity
  60. -q, --quiet: Set verbosity to minimum
  61. -e, --error: Stop on first error
  62. -f, --failure: Stop on first failure or error
  63. -s, --shuffle: Shuffle tests before running them
  64. -o, --output OUTPUT: Set output type to OUTPUT
  65. Possible values: text, tap, junit, nil
  66. -n, --name NAME: For junit only, mandatory name of xml file
  67. -r, --repeat NUM: Execute all tests NUM times, e.g. to trig the JIT
  68. -p, --pattern PATTERN: Execute all test names matching the Lua PATTERN
  69. May be repeated to include several patterns
  70. Make sure you escape magic chars like +? with %
  71. -x, --exclude PATTERN: Exclude all test names matching the Lua PATTERN
  72. May be repeated to exclude several patterns
  73. Make sure you escape magic chars like +? with %
  74. testname1, testname2, ... : tests to run in the form of testFunction,
  75. TestClass or TestClass.testMethod
  76. ]]
  77. local is_equal -- defined here to allow calling from mismatchFormattingPureList
  78. ----------------------------------------------------------------
  79. --
  80. -- general utility functions
  81. --
  82. ----------------------------------------------------------------
  -- Call func(...) in protected mode.
  -- On success, returns every value func returned, including nil values and
  -- trailing nils. On failure, prints the error message followed by M.USAGE
  -- and terminates the process with exit code -1.
  local function pcall_or_abort(func, ...)
      -- unpack is a global function for Lua 5.1, otherwise use table.unpack
      local unpack = rawget(_G, "unpack") or table.unpack

      -- The number of results is captured explicitly: the length of a table
      -- holding nils is not well defined, so relying on it would silently
      -- drop results such as `return nil, nil`.
      local function capture(ok, ...)
          return ok, select("#", ...), {...}
      end

      local ok, count, values = capture(pcall(func, ...))
      if not ok then
          -- an error occurred
          print(values[1]) -- error message
          print()
          print(M.USAGE)
          os.exit(-1)
      end
      return unpack(values, 1, count)
  end
  -- Rank of each Lua type when values of different types are compared;
  -- every type not listed explicitly shares the "other" rank.
  local crossTypeOrdering = {
      number = 1, boolean = 2, string = 3, table = 4, other = 5
  }

  -- "Less than" predicates for two values of the same type. Types without a
  -- dedicated entry are compared through their tostring() representation.
  local function nativeLess(x, y)
      return x < y
  end
  local crossTypeComparison = {
      number = nativeLess,
      string = nativeLess,
      other = function(x, y) return tostring(x) < tostring(y) end,
  }

  -- Ordering predicate usable with table.sort() on values of mixed types.
  -- Returns true when `a` sorts before `b`.
  local function crossTypeSort(a, b)
      local kind_a, kind_b = type(a), type(b)
      if kind_a ~= kind_b then
          -- different types: order by type rank only
          local rank_a = crossTypeOrdering[kind_a] or crossTypeOrdering.other
          local rank_b = crossTypeOrdering[kind_b] or crossTypeOrdering.other
          return rank_a < rank_b
      end
      -- same type: delegate to the predicate registered for that type
      local less = crossTypeComparison[kind_a] or crossTypeComparison.other
      return less(a, b)
  end
  -- Collect all keys of `t` into a new sequence and sort that sequence with
  -- crossTypeSort(). Returns the sorted key sequence.
  local function __genSortedIndex( t )
      local keys, n = {}, 0
      for key in pairs(t) do
          n = n + 1
          keys[n] = key
      end
      table.sort(keys, crossTypeSort)
      return keys
  end
  123. M.private.__genSortedIndex = __genSortedIndex
  -- Counterpart of next() that walks a table in sorted key order.
  -- `state` carries the table (state.t), its sorted keys (state.sortedIdx),
  -- their count (state.count) and the position of the key returned last
  -- (state.lastIdx). `control` is the key returned by the previous call, or
  -- nil to start the iteration. Returns the next key and its value, or
  -- nothing once the keys are exhausted.
  local function sortedNext(state, control)
      local keys = state.sortedIdx

      if control == nil then
          -- start of iteration
          state.count = #keys
          state.lastIdx = 1
          local first = keys[1]
          return first, state.t[first]
      end

      -- normally the control variable matches the key handed out last
      if keys[state.lastIdx] ~= control then
          -- it does not: locate `control` by bisection, which works because
          -- the key list is sorted in crossTypeSort() order
          local lo, hi = 1, state.count
          local located = false
          repeat
              state.lastIdx = math.modf((lo + hi) / 2)
              local probe = keys[state.lastIdx]
              if probe == control then
                  located = true -- found the previous position
              elseif crossTypeSort(probe, control) then
                  -- probe < control, continue in the upper half
                  lo = state.lastIdx + 1
              else
                  -- probe > control, continue in the lower half
                  hi = state.lastIdx - 1
              end
          until located or lo > hi
          if not located then
              -- unknown key: point at the last entry so the lookup below
              -- yields nothing and the iteration stops
              state.lastIdx = state.count
          end
      end

      -- advance to the following key (nil past the end)
      state.lastIdx = state.lastIdx + 1
      local nextKey = keys[state.lastIdx]
      if nextKey then
          return nextKey, state.t[nextKey]
      end
      -- falling through returns nothing, which ends the iteration
  end
  -- Counterpart of pairs() that iterates in sorted key order.
  -- Returns the triple expected by a generic "for" loop: the iterator
  -- function, its invariant state and the initial control value.
  -- (see http://www.lua.org/pil/7.2.html)
  local function sortedPairs(tbl)
      local state = { t = tbl, sortedIdx = __genSortedIndex(tbl) }
      return sortedNext, state, nil
  end
  177. M.private.sortedPairs = sortedPairs
  -- seed the random generator from the CPU clock when the module is loaded
  math.randomseed(math.floor(os.clock()*1E11))

  -- Shuffle the sequence `t` in place: walking from the last position down
  -- to the second, each element is swapped with a randomly chosen element at
  -- or before its position.
  local function randomizeTable( t )
      local last = #t
      while last >= 2 do
          local pick = math.random(last)
          if pick ~= last then
              t[last], t[pick] = t[pick], t[last]
          end
          last = last - 1
      end
  end
  189. M.private.randomizeTable = randomizeTable
  -- Split `text` at every occurrence of `delimiter` and return the pieces as
  -- a list. The delimiter is matched literally (it is _NOT_ a pattern).
  -- Raises an error for an empty delimiter.
  -- Example: strsplit(", ", "Anna, Bob, Charlie, Dolores")
  local function strsplit(delimiter, text)
      if delimiter == "" then -- this would result in endless loops
          error("delimiter matches empty string!")
      end
      local pieces, cursor = {}, 1
      local hitStart, hitEnd = text:find(delimiter, cursor, true)
      while hitStart do
          pieces[#pieces + 1] = text:sub(cursor, hitStart - 1)
          cursor = hitEnd + 1
          hitStart, hitEnd = text:find(delimiter, cursor, true)
      end
      -- whatever follows the last delimiter (possibly the empty string)
      pieces[#pieces + 1] = text:sub(cursor)
      return pieces
  end
  210. M.private.strsplit = strsplit
  -- Tell whether `s` contains at least one newline character.
  local function hasNewLine( s )
      local position = string.find(s, '\n', 1, true)
      if position then
          return true
      end
      return false
  end
  215. M.private.hasNewLine = hasNewLine
  -- Prefix all the lines of `s` with `prefix` and return the result.
  -- The prefix is inserted literally: a function is used as the gsub
  -- replacement because a replacement *string* would treat any '%' inside
  -- the prefix as an escape character (raising an error or dropping it).
  local function prefixString( prefix, s )
      local lineStart = '\n' .. prefix
      local body = string.gsub(s, '\n', function() return lineStart end)
      return prefix .. body
  end
  220. M.private.prefixString = prefixString
  -- Return true if the first match of `pattern` in `s`, searched from index
  -- `start`, begins exactly at `start` and ends exactly at `final`; return
  -- false in every other case.
  -- `start` defaults to the beginning of the string and `final` to its end.
  local function strMatch(s, pattern, start, final )
      local from = start or 1
      local upTo = final or string.len(s)
      local matchFrom, matchTo = string.find(s, pattern, from, false)
      if matchFrom ~= from then
          return false
      end
      return matchTo == upTo
  end
  231. M.private.strMatch = strMatch
  -- Decide whether `expr` passes the inclusion / exclusion rules listed in
  -- `patterns`; returns true to include it and false to exclude it.
  -- A rule is a Lua pattern; a leading `!` turns it into an exclusion rule.
  -- The last rule matching `expr` decides. When no rule matches, `expr` is
  -- included only if the list contains no inclusion rule at all.
  local function patternFilter(patterns, expr)
      if patterns == nil then
          return true
      end
      -- verdict: nil = nothing matched yet, true = ACCEPT, false = REJECT
      local needMatch, verdict = false, nil
      for _, rule in ipairs(patterns) do
          local isExclusion = (rule:sub(1,1) == '!')
          local body = rule
          if isExclusion then
              body = rule:sub(2)
          else
              -- at least one include pattern specified, a match is required
              needMatch = true
          end
          if string.find(expr, body) then
              verdict = not isExclusion
          end
      end
      if verdict == nil then
          return not needMatch
      end
      return verdict
  end
  263. M.private.patternFilter = patternFilter
  -- Return `s` escaped for use in XML attributes.
  -- escapes table:
  -- " &quot;
  -- ' &apos;
  -- < &lt;
  -- > &gt;
  -- & &amp;
  -- Only the escaped string is returned; the substitution count produced by
  -- string.gsub() is deliberately dropped so that it cannot leak into the
  -- argument list of a caller such as table.insert().
  local function xmlEscape( s )
      local escaped = string.gsub( s, '[&"\'<>]', {
          ['&'] = "&amp;",
          ['"'] = "&quot;",
          ["'"] = "&apos;",
          ['<'] = "&lt;",
          ['>'] = "&gt;",
      } )
      return escaped
  end
  280. M.private.xmlEscape = xmlEscape
  -- Return `s` escaped for a CData section, escapes: "]]>"
  -- Only the escaped string is returned; the substitution count produced by
  -- string.gsub() is deliberately dropped.
  local function xmlCDataEscape( s )
      local escaped = string.gsub( s, ']]>', ']]&gt;' )
      return escaped
  end
  285. M.private.xmlCDataEscape = xmlCDataEscape
  -- Remove the frames that belong to luaunit itself from a stack traceback
  -- and return the shortened traceback.
  --
  -- Example of a traceback:
  --   stack traceback:
  --   ./luaunit.lua:545: in function 'assertEquals'
  --   example_with_luaunit.lua:58: in function 'TestToto.test7'
  --   ./luaunit.lua:1517: in function <./luaunit.lua:1517>
  --   [C]: in function 'xpcall'
  --   ./luaunit.lua:1517: in function 'protectedCall'
  --   ./luaunit.lua:1578: in function 'execOneFunction'
  --   ./luaunit.lua:1677: in function 'runSuiteByInstances'
  --   ./luaunit.lua:1730: in function 'runSuiteByNames'
  --   ./luaunit.lua:1806: in function 'runSuite'
  --   example_with_luaunit.lua:140: in main chunk
  --   [C]: in ?
  --
  -- Strategy:
  --   * the first line ("stack traceback:") is kept
  --   * luaunit lines directly following it are dropped
  --   * the lines after that are kept until a luaunit line shows up again
  --   * that luaunit line and everything below it are dropped
  local function stripLuaunitTrace( stackTrace )
      -- true if the traceback line refers to a position inside luaunit.lua
      -- (the file name must be preceded by a path separator)
      local function isLuaunitInternalLine( s )
          return s:find('[/\\]luaunit%.lua:%d+: ') ~= nil
      end

      local lines = strsplit( '\n', stackTrace )
      local kept = { lines[1] }
      local pos = 2

      -- skip the luaunit lines at the top of the trace
      while lines[pos] and isLuaunitInternalLine( lines[pos] ) do
          pos = pos + 1
      end

      -- collect lines until luaunit is hit again; the rest is discarded
      while lines[pos] and (not isLuaunitInternalLine( lines[pos] )) do
          kept[#kept + 1] = lines[pos]
          pos = pos + 1
      end

      return table.concat( kept, '\n')
  end
  359. M.private.stripLuaunitTrace = stripLuaunitTrace
  -- Worker behind prettystr(): convert `v` into its display string.
  -- * strings are quoted (see below)
  -- * tables are delegated to M.private._table_tostring(), which receives
  --   indentLevel, printTableRefs and recursionTable unchanged
  -- * numbers are rendered identically on every Lua version
  -- * every other value goes through tostring()
  local function prettystr_sub(v, indentLevel, printTableRefs, recursionTable )
      local type_v = type(v)
      if "string" == type_v then
          -- use clever delimiters according to content:
          -- enclose with single quotes if string contains ", but no '
          if v:find('"', 1, true) and not v:find("'", 1, true) then
              return "'" .. v .. "'"
          end
          -- use double quotes otherwise, escape embedded "
          return '"' .. v:gsub('"', '\\"') .. '"'
      elseif "table" == type_v then
          --if v.__class__ then
          -- return string.gsub( tostring(v), 'table', v.__class__ )
          --end
          return M.private._table_tostring(v, indentLevel, printTableRefs, recursionTable)
      elseif "number" == type_v then
          -- eliminate differences in formatting between various Lua versions
          if v ~= v then
              return "#NaN" -- "not a number"
          end
          if v == math.huge then
              return "#Inf" -- "infinite"
          end
          if v == -math.huge then
              return "-#Inf"
          end
          -- Interpreters with an integer subtype print 2.0 as "2.0"; show
          -- integral floats like integers there. The capability is detected
          -- through math.tointeger rather than by comparing _VERSION, so that
          -- versions newer than 5.3 are covered as well.
          if math.tointeger then
              local i = math.tointeger(v)
              if i then
                  return tostring(i)
              end
          end
      end
      return tostring(v)
  end
  -- Pretty string conversion, to display the full content of a variable of
  -- any type.
  -- * strings are enclosed with " by default, or with ' if the string
  --   contains a "
  -- * tables are expanded to show their full content
  -- Table references are printed according to
  -- M.PRINT_TABLE_REF_IN_ERROR_MSG. If the first rendering pass flags a
  -- recursive reference while references are not being printed, the value is
  -- rendered a second time with references forced on, because the output is
  -- hard to read otherwise.
  local function prettystr( v )
      local seen = {}
      local text = prettystr_sub(v, 1, M.PRINT_TABLE_REF_IN_ERROR_MSG, seen)
      local needsRefs = seen.recursionDetected and not M.PRINT_TABLE_REF_IN_ERROR_MSG
      if not needsRefs then
          return text
      end
      -- second pass with a fresh recursion table and table refs enabled
      return (prettystr_sub(v, 1, true, {}))
  end
  411. M.prettystr = prettystr
  -- Prepares a nice error message when comparing tables, performing a deeper
  -- analysis.
  -- Arguments:
  -- * table_a, table_b: tables to be compared
  -- * doDeepAnalysis:
  --     M.DEFAULT_DEEP_ANALYSIS: (the default if not specified) analyse only
  --         lists having at least M.LIST_DIFF_ANALYSIS_THRESHOLD items
  --     M.FORCE_DEEP_ANALYSIS  : always perform deep analysis
  --     M.DISABLE_DEEP_ANALYSIS: never perform deep analysis
  -- Returns: success, result
  -- * success: false if deep analysis could not be performed; in this case,
  --   just use the standard assertion message
  -- * result: if success is true, a multi-line string with the deep analysis
  --   of the two lists
  -- Only pure lists are analysed for the moment; mappings yield false.
  local function tryMismatchFormatting( table_a, table_b, doDeepAnalysis )
      -- both operands must be tables
      if type(table_a) ~= 'table' or type(table_b) ~= 'table' then
          return false
      end
      if doDeepAnalysis == M.DISABLE_DEEP_ANALYSIS then
          return false
      end

      local len_a, len_b = #table_a, #table_b

      -- a table counts as a list when all its keys are numbers that do not
      -- exceed its length
      local function looksLikeList( tbl, len )
          for key in pairs(tbl) do
              if type(key) ~= 'number' or key > len then
                  return false -- this table is a mapping
              end
          end
          return true
      end

      if not (looksLikeList(table_a, len_a) and looksLikeList(table_b, len_b)) then
          -- mappings are not analysed for the moment
          -- return M.private.mismatchFormattingMapping( table_a, table_b, doDeepAnalysis )
          return false
      end

      -- short lists are analysed on explicit request only
      if math.min(len_a, len_b) < M.LIST_DIFF_ANALYSIS_THRESHOLD
          and doDeepAnalysis ~= M.FORCE_DEEP_ANALYSIS then
          return false
      end

      return M.private.mismatchFormattingPureList( table_a, table_b )
  end
  464. M.private.tryMismatchFormatting = tryMismatchFormatting
  -- Return the labels describing the first and the second compared value,
  -- in that order, according to M.ORDER_ACTUAL_EXPECTED.
  local function getTaTbDescr()
      if M.ORDER_ACTUAL_EXPECTED then
          return 'actual', 'expected'
      end
      return 'expected', 'actual'
  end
  -- Append string.format(...) as a new last element of the list `res`.
  local function extendWithStrFmt( res, ... )
      local line = string.format( ... )
      res[#res + 1] = line
  end
  -- Prepares a nice error message when comparing tables which are not pure
  -- lists, performing a deeper analysis.
  -- Returns: success, result
  -- * success: false if deep analysis could not be performed; in this case,
  --   just use the standard assertion message
  -- * result: if success is true, a multi-line string with the deep analysis
  --   of the two tables
  --
  -- The analysis is disabled for the moment, so the function always reports
  -- failure. It returns `false` explicitly (rather than nothing) to honour
  -- the contract described above.
  local function mismatchFormattingMapping( table_a, table_b, doDeepAnalysis )
      -- disable for the moment
      -- NOTE(review): in the disabled code below, `if #keysCommon then` is
      -- always true because 0 is truthy in Lua; it probably should read
      -- `if #keysCommon ~= 0 then` -- confirm before re-enabling.
      --[[
      local result = {}
      local descrTa, descrTb = getTaTbDescr()
      local keysCommon = {}
      local keysOnlyTa = {}
      local keysOnlyTb = {}
      local keysDiffTaTb = {}
      local k, v
      for k,v in pairs( table_a ) do
          if is_equal( v, table_b[k] ) then
              table.insert( keysCommon, k )
          else
              if table_b[k] == nil then
                  table.insert( keysOnlyTa, k )
              else
                  table.insert( keysDiffTaTb, k )
              end
          end
      end
      for k,v in pairs( table_b ) do
          if not is_equal( v, table_a[k] ) and table_a[k] == nil then
              table.insert( keysOnlyTb, k )
          end
      end
      local len_a = #keysCommon + #keysDiffTaTb + #keysOnlyTa
      local len_b = #keysCommon + #keysDiffTaTb + #keysOnlyTb
      local limited_display = (len_a < 5 or len_b < 5)
      if math.min(len_a, len_b) < M.TABLE_DIFF_ANALYSIS_THRESHOLD then
          return false
      end
      if not limited_display then
          if len_a == len_b then
              extendWithStrFmt( result, 'Table A (%s) and B (%s) both have %d items', descrTa, descrTb, len_a )
          else
              extendWithStrFmt( result, 'Table A (%s) has %d items and table B (%s) has %d items', descrTa, len_a, descrTb, len_b )
          end
          if #keysCommon == 0 and #keysDiffTaTb == 0 then
              table.insert( result, 'Table A and B have no keys in common, they are totally different')
          else
              local s_other = 'other '
              if #keysCommon then
                  extendWithStrFmt( result, 'Table A and B have %d identical items', #keysCommon )
              else
                  table.insert( result, 'Table A and B have no identical items' )
                  s_other = ''
              end
              if #keysDiffTaTb ~= 0 then
                  result[#result] = string.format( '%s and %d items differing present in both tables', result[#result], #keysDiffTaTb)
              else
                  result[#result] = string.format( '%s and no %sitems differing present in both tables', result[#result], s_other, #keysDiffTaTb)
              end
          end
          extendWithStrFmt( result, 'Table A has %d keys not present in table B and table B has %d keys not present in table A', #keysOnlyTa, #keysOnlyTb )
      end
      local function keytostring(k)
          if "string" == type(k) and k:match("^[_%a][_%w]*$") then
              return k
          end
          return prettystr(k)
      end
      if #keysDiffTaTb ~= 0 then
          table.insert( result, 'Items differing in A and B:')
          for k,v in sortedPairs( keysDiffTaTb ) do
              extendWithStrFmt( result, ' - A[%s]: %s', keytostring(v), prettystr(table_a[v]) )
              extendWithStrFmt( result, ' + B[%s]: %s', keytostring(v), prettystr(table_b[v]) )
          end
      end
      if #keysOnlyTa ~= 0 then
          table.insert( result, 'Items only in table A:' )
          for k,v in sortedPairs( keysOnlyTa ) do
              extendWithStrFmt( result, ' - A[%s]: %s', keytostring(v), prettystr(table_a[v]) )
          end
      end
      if #keysOnlyTb ~= 0 then
          table.insert( result, 'Items only in table B:' )
          for k,v in sortedPairs( keysOnlyTb ) do
              extendWithStrFmt( result, ' + B[%s]: %s', keytostring(v), prettystr(table_b[v]) )
          end
      end
      if #keysCommon ~= 0 then
          table.insert( result, 'Items common to A and B:')
          for k,v in sortedPairs( keysCommon ) do
              extendWithStrFmt( result, ' = A and B [%s]: %s', keytostring(v), prettystr(table_a[v]) )
          end
      end
      return true, table.concat( result, '\n')
      ]]
      return false
  end
  572. M.private.mismatchFormattingMapping = mismatchFormattingMapping
  -- Prepares a nice error message when comparing tables which are lists,
  -- performing a deeper analysis.
  -- Returns: success, result
  -- * success: always true here
  -- * result: a multi-line string with the deep analysis of the two lists:
  --   a size summary, the index where the lists start to differ, and the
  --   elements grouped into common head, differing part, elements present in
  --   the longer list only, and common tail.
  local function mismatchFormattingPureList( table_a, table_b )
      local descrTa, descrTb = getTaTbDescr()
      local len_a, len_b = #table_a, #table_b
      local refa, refb = '', ''
      if M.PRINT_TABLE_REF_IN_ERROR_MSG then
          refa = string.format( '<%s> ', tostring(table_a))
          refb = string.format('<%s> ', tostring(table_b) )
      end
      local shortest, longest = math.min(len_a, len_b), math.max(len_a, len_b)
      local deltalv = longest - shortest

      -- number of leading elements that are equal in both lists
      local commonUntil = 0
      while commonUntil < shortest
          and is_equal(table_a[commonUntil + 1], table_b[commonUntil + 1]) do
          commonUntil = commonUntil + 1
      end

      -- number of trailing elements that are equal in both lists;
      -- commonBackTo is the offset (from the end) of the first of them,
      -- or -1 when the last elements already differ
      local tailMatches = 0
      while tailMatches < shortest
          and is_equal(table_a[len_a - tailMatches], table_b[len_b - tailMatches]) do
          tailMatches = tailMatches + 1
      end
      local commonBackTo = tailMatches - 1

      local result = { 'List difference analysis:' }

      -- report one index pair: a single line when both elements are equal,
      -- otherwise one line per list. `bi` defaults to `ai`.
      local function insertABValue(ai, bi)
          bi = bi or ai
          if not is_equal( table_a[ai], table_b[bi]) then
              extendWithStrFmt( result, ' - A[%d]: %s', ai, prettystr(table_a[ai]))
              extendWithStrFmt( result, ' + B[%d]: %s', bi, prettystr(table_b[bi]))
              return
          end
          extendWithStrFmt( result, ' = A[%d], B[%d]: %s', ai, bi, prettystr(table_a[ai]) )
      end

      -- summary lines
      if len_a ~= len_b then
          extendWithStrFmt( result, '* list sizes differ: list %sA (%s) has %d items, list %sB (%s) has %d items', refa, descrTa, len_a, refb, descrTb, len_b )
      else
          -- TODO: handle expected/actual naming
          extendWithStrFmt( result, '* lists %sA (%s) and %sB (%s) have the same size', refa, descrTa, refb, descrTb )
      end
      extendWithStrFmt( result, '* lists A and B start differing at index %d', commonUntil+1 )
      if commonBackTo >= 0 and deltalv > 0 then
          extendWithStrFmt( result, '* lists A and B are equal again from index %d for A, %d for B', len_a-commonBackTo, len_b-commonBackTo )
      elseif commonBackTo >= 0 then
          extendWithStrFmt( result, '* lists A and B are equal again from index %d', len_a-commonBackTo )
      end

      -- common parts to list A & B, at the beginning
      if commonUntil > 0 then
          table.insert( result, '* Common parts:' )
          for i = 1, commonUntil do
              insertABValue( i )
          end
      end

      -- differing parts of list A & B
      local lastDiffering = shortest - commonBackTo - 1
      if commonUntil < lastDiffering then
          table.insert( result, '* Differing parts:' )
          for i = commonUntil + 1, lastDiffering do
              insertABValue( i )
          end
      end

      -- indexes of the longer list that have no counterpart in the other one
      local firstExtra, lastExtra = shortest - commonBackTo, longest - commonBackTo - 1
      if firstExtra <= lastExtra then
          table.insert( result, '* Present only in one list:' )
          local fmt, source = ' + B[%d]: %s', table_b
          if len_a > len_b then
              fmt, source = ' - A[%d]: %s', table_a
          end
          for i = firstExtra, lastExtra do
              extendWithStrFmt( result, fmt, i, prettystr(source[i]) )
          end
      end

      -- common parts to list A & B, at the end
      if commonBackTo >= 0 then
          table.insert( result, '* Common parts at the end of the lists' )
          for i = longest - commonBackTo, longest do
              -- `i` indexes the longer list, `i - deltalv` the shorter one
              if len_a > len_b then
                  insertABValue( i, i-deltalv )
              else
                  insertABValue( i-deltalv, i )
              end
          end
      end

      return true, table.concat( result, '\n')
  end
  667. M.private.mismatchFormattingPureList = mismatchFormattingPureList
  -- Helper for building "expected vs. actual" error messages: formats two
  -- arbitrary values with prettystr() and returns the two strings.
  -- If either string contains a line break, both are prefixed with a newline
  -- to keep the (possibly complex) output readable. Both suffixes are
  -- optional (default to empty strings) and are appended to the first
  -- string only: "suffix_a" when line breaks were encountered, "suffix_b"
  -- otherwise.
  local function prettystrPairs(value1, value2, suffix_a, suffix_b)
      local str1 = prettystr(value1)
      local str2 = prettystr(value2)
      local multiline = hasNewLine(str1) or hasNewLine(str2)
      if not multiline then
          return str1 .. (suffix_b or ""), str2
      end
      -- line break(s) detected, add padding
      return "\n" .. str1 .. (suffix_a or ""), "\n" .. str2
  end
  687. M.private.prettystrPairs = prettystrPairs
  -- Return the default tostring() of table `t` (with the table ID), even if
  -- the table has a metatable with a __tostring converter.
  -- The metatable is detached for the duration of the tostring() call and
  -- re-attached afterwards. When the metatable is protected (it has a
  -- __metatable field) it cannot be detached; instead of raising an error,
  -- the function then falls back to a plain tostring(t).
  local function _table_raw_tostring( t )
      local mt = getmetatable( t )
      if not mt then
          return tostring(t)
      end
      -- setmetatable() raises on a protected metatable, hence the pcall
      local detached = pcall( setmetatable, t, nil )
      if not detached then
          return tostring(t)
      end
      local ref = tostring(t)
      setmetatable( t, mt )
      return ref
  end
  697. M.private._table_raw_tostring = _table_raw_tostring
  -- separator placed between the entries of a table rendered on one line
  local TABLE_TOSTRING_SEP = ", "
  local TABLE_TOSTRING_SEP_LEN = string.len(TABLE_TOSTRING_SEP)

  -- Render table `tbl` as a string.
  -- * indentLevel: nesting depth of `tbl`; children are rendered one deeper
  -- * printTableRefs: when falsy, M.PRINT_TABLE_REF_IN_ERROR_MSG is used
  -- * recursionTable: tables already being rendered; a key or value found
  --   in it is shown as a raw table reference and the field
  --   recursionTable.recursionDetected is set to true
  -- A table whose metatable defines __tostring is rendered through that
  -- converter; otherwise its entries are listed in sorted key order, the
  -- keys 1, 2, 3... being shown without the "<key>=" prefix.
  local function _table_tostring( tbl, indentLevel, printTableRefs, recursionTable )
      printTableRefs = printTableRefs or M.PRINT_TABLE_REF_IN_ERROR_MSG
      recursionTable = recursionTable or {}
      recursionTable[tbl] = true

      local mt = getmetatable( tbl )
      if mt and mt.__tostring then
          -- the table brings its own converter: use it
          local lines = strsplit( '\n', tostring(tbl) )
          return M.private._table_tostring_format_multiline_string( lines, indentLevel )
      end

      -- like prettystr but do not enclose with "" if the string is just
      -- alphanumerical; this is better for displaying table keys, which are
      -- often simple strings
      local function keytostring(k)
          if "string" == type(k) and k:match("^[_%a][_%w]*$") then
              return k
          end
          return prettystr_sub(k, indentLevel+1, printTableRefs, recursionTable)
      end

      local entries, seq_index = {}, 1
      for k, v in sortedPairs( tbl ) do
          -- key part
          local keyPart
          if k == seq_index then
              -- sequential part of the table: skip the "<key>=" output
              keyPart = ''
              seq_index = seq_index + 1
          elseif recursionTable[k] then
              -- recursion in the key detected
              recursionTable.recursionDetected = true
              keyPart = "<".._table_raw_tostring(k)..">="
          else
              keyPart = keytostring(k) .. "="
          end

          -- value part
          local valuePart
          if recursionTable[v] then
              -- recursion in the value detected!
              recursionTable.recursionDetected = true
              valuePart = "<".._table_raw_tostring(v)..">"
          else
              valuePart = prettystr_sub( v, indentLevel+1, printTableRefs, recursionTable )
          end

          entries[#entries + 1] = keyPart .. valuePart
      end
      return M.private._table_tostring_format_result( tbl, entries, indentLevel, printTableRefs )
  end
  750. M.private._table_tostring = _table_tostring -- prettystr_sub() needs it
  -- Join the lines in `tbl_str` with newlines, indenting every line but the
  -- first with (indentLevel - 1) repetitions of the indentation unit.
  local function _table_tostring_format_multiline_string( tbl_str, indentLevel )
      local padding = string.rep(" ", indentLevel - 1)
      return table.concat( tbl_str, '\n' .. padding )
  end
  755. M.private._table_tostring_format_multiline_string = _table_tostring_format_multiline_string
  -- Final step of _table_tostring(): turn `result`, the list of strings
  -- describing the entries of `tbl`, into the string for the whole table.
  -- The entries are put on one line unless that line (entries, separators
  -- and the enclosing braces) would reach M.LINE_LENGTH; in that case every
  -- entry goes on its own line, indented according to `indentLevel`.
  -- When `printTableRefs` is set, the raw reference of `tbl` is prepended.
  local function _table_tostring_format_result( tbl, result, indentLevel, printTableRefs )
      -- add up the entry lengths, stopping early once the limit is reached
      local totalLength, multiline = 0, false
      for _, entry in ipairs( result ) do
          totalLength = totalLength + string.len( entry )
          if totalLength >= M.LINE_LENGTH then
              multiline = true
              break
          end
      end

      if not multiline then
          -- account for the separators (two items need 1 sep, three items
          -- two seps, ...) plus the length of '{}'
          local separators = math.max(#result - 1, 0)
          totalLength = totalLength + TABLE_TOSTRING_SEP_LEN * separators
          multiline = (totalLength + 2 >= M.LINE_LENGTH)
      end

      local text
      if multiline then
          local indentString = string.rep(" ", indentLevel - 1)
          text = "{\n " .. indentString
              .. table.concat(result, ",\n " .. indentString)
              .. "\n" .. indentString .. "}"
      else
          text = "{" .. table.concat(result, TABLE_TOSTRING_SEP) .. "}"
      end

      if printTableRefs then
          -- prepend table ref
          return "<" .. _table_raw_tostring(tbl) .. "> " .. text
      end
      return text
  end
  798. M.private._table_tostring_format_result = _table_tostring_format_result -- prettystr_sub() needs it
  local function _table_contains(t, element)
      -- Tell whether "element" occurs among the values of table "t".
      -- A value matches when it has the same type as "element" and either
      -- compares equal with ==, or (for tables) is deemed equal by
      -- M.private._is_table_equals(). Returns false when "t" is not a table.
      if type(t) ~= "table" then
          return false
      end
      local wantedType = type(element)
      for _, candidate in pairs(t) do
          if type(candidate) == wantedType then
              if candidate == element then
                  return true
              end
              -- Table items are compared by content; item-wise (order
              -- insensitive) comparison is deliberately not applied here.
              if wantedType == 'table'
                 and M.private._is_table_equals( candidate, element ) then
                  return true
              end
          end
      end
      return false
  end
  local function _is_table_items_equals(actual, expected )
      -- Order-insensitive comparison of the values held by two tables:
      -- every value of "actual" must be found in "expected" and vice versa
      -- (keys are ignored, membership is decided by _table_contains()).
      -- Non-table arguments are equal when they share a type and compare
      -- equal with ==.
      local kindActual, kindExpected = type(actual), type(expected)
      if kindActual ~= 'table' or kindExpected ~= 'table' then
          return kindActual == kindExpected and actual == expected
      end

      for _, item in pairs(actual) do
          if not _table_contains(expected, item) then
              return false
          end
      end
      for _, item in pairs(expected) do
          if not _table_contains(actual, item) then
              return false
          end
      end
      return true
  end
  841. --[[
  842. This is a specialized metatable to help with the bookkeeping of recursions
  843. in _is_table_equals(). It provides an __index table that implements utility
  844. functions for easier management of the table. The "cached" method queries
  845. the state of a specific (actual,expected) pair; and the "store" method sets
  846. this state to the given value. The state of pairs not "seen" / visited is
  847. assumed to be `nil`.
  848. ]]
  local _recursion_cache_MT = { __index = {} }

  -- cached(t, actual, expected)
  -- Look up the state recorded for the (actual,expected) pair.
  -- Yields nil when the pair has not been stored yet.
  _recursion_cache_MT.__index.cached = function(t, actual, expected)
      local row = t[actual]
      if row == nil or row == false then
          row = {}
      end
      return row[expected]
  end

  -- store(t, actual, expected, value, asymmetric)
  -- Record "value" for the (actual,expected) pair and hand it back, which
  -- makes `return cache:store(...)` convenient for callers.
  -- Unless "asymmetric" is set, the mirrored pair (expected,actual) receives
  -- the same value, i.e. both orderings are treated as equivalent.
  _recursion_cache_MT.__index.store = function(t, actual, expected, value, asymmetric)
      local row = t[actual]
      if not row then
          row = {}
          t[actual] = row
      end
      row[expected] = value
      if not asymmetric then
          -- mirror the entry; the flag stops the mirroring from bouncing back
          t:store(expected, actual, value, true)
      end
      return value
  end
  local function _is_table_equals(actual, expected, recursions)
      -- Deep equality test.
      --   actual, expected : the two values to compare
      --   recursions       : cache of (actual,expected) pairs already visited;
      --                      created on the outermost call, passed down on
      --                      recursion so that cyclic tables terminate
      -- Values of different types never match. Non-tables are compared with
      -- ==. Tables match when they map equal keys to equal values; with
      -- M.TABLE_EQUALS_KEYBYCONTENT set, table-typed keys are themselves
      -- matched by content rather than by reference.
      local kindActual, kindExpected = type(actual), type(expected)
      recursions = recursions or setmetatable({}, _recursion_cache_MT)

      if kindActual ~= kindExpected then
          return false -- different types won't match
      end
      if kindActual ~= 'table' then
          return actual == expected
      end

      -- From here on both values are tables.
      if actual == expected then
          -- same table on both sides
          return recursions:store(actual, expected, true)
      end

      -- A pair that was visited before yields its cached verdict.
      local verdict = recursions:cached(actual, expected)
      if verdict ~= nil then
          return verdict
      end

      -- Mark the pair as visited. "false" is the provisional verdict; it is
      -- overwritten at the very end if the tables turn out to be equal.
      recursions:store(actual, expected, false)

      -- Cheap rejection based on the length operator.
      if #actual ~= #expected then
          return false
      end

      local matchedKeys, pendingTableKeys = {}, {}

      -- Pass 1: walk "actual". Ordinary keys are checked right away;
      -- table-typed keys (in by-content mode) are only collected, because
      -- two distinct key tables may still be equal by content.
      for key, value in pairs(actual) do
          if M.TABLE_EQUALS_KEYBYCONTENT and type(key) == "table" then
              pendingTableKeys[#pendingTableKeys + 1] = key
          elseif not _is_table_equals(value, expected[key], recursions) then
              return false -- value mismatch
          else
              matchedKeys[key] = true
          end
      end

      -- Pass 2: walk "expected" and make sure every key has a counterpart.
      for key, value in pairs(expected) do
          if M.TABLE_EQUALS_KEYBYCONTENT and type(key) == "table" then
              local matched = false
              -- pairs(), not ipairs(): consumed slots leave holes in the list
              for slot, candidate in pairs(pendingTableKeys) do
                  if _is_table_equals(candidate, key, recursions)
                     and _is_table_equals(actual[candidate], value, recursions) then
                      -- consume the candidate so it cannot match twice
                      pendingTableKeys[slot] = nil
                      matched = true
                      break
                  end
                  -- otherwise keep looking among the remaining candidates
              end
              if not matched then
                  return false -- no matching (key,value) pair
              end
          elseif not matchedKeys[key] then
              -- key exists only in "expected"
              return false
          end
          -- ordinary keys seen in pass 1 were already compared against value
      end

      -- Any table key of "actual" left unconsumed has no partner in "expected".
      if next(pendingTableKeys) ~= nil then
          return false
      end

      -- Equal after all: replace the provisional verdict.
      return recursions:store(actual, expected, true)
  end
  954. M.private._is_table_equals = _is_table_equals
  955. is_equal = _is_table_equals
  local function failure(msg, level)
      -- Signal a test failure by raising an error whose message is "msg"
      -- prefixed with M.FAILURE_PREFIX.
      -- "level" follows the error() convention (default 1); it is bumped by
      -- one so that it is counted from the caller of failure() rather than
      -- from failure() itself.
      local text = M.FAILURE_PREFIX .. msg
      local reportLevel = (level or 1) + 1
      error(text, reportLevel)
  end
  local function fail_fmt(level, ...)
      -- Same as failure(), but the message is built printf-style from the
      -- remaining arguments (passed to string.format). "level" defaults to 1
      -- and is bumped by one to skip this helper's own stack frame.
      local text = string.format(...)
      failure(text, (level or 1) + 1)
  end
  966. M.private.fail_fmt = fail_fmt
  local function error_fmt(level, ...)
      -- error() with a printf-style message: the remaining arguments are
      -- handed to string.format. "level" defaults to 1 and is bumped by one
      -- to skip this helper's own stack frame.
      local text = string.format(...)
      error(text, (level or 1) + 1)
  end
  971. ----------------------------------------------------------------
  972. --
  973. -- assertions
  974. --
  975. ----------------------------------------------------------------
  local function errorMsgEquality(actual, expected, doDeepAnalysis)
      -- Build the message reported when an equality assertion fails.
      -- Honours M.ORDER_ACTUAL_EXPECTED: when it is false the two arguments
      -- are swapped before formatting.
      -- Strings and tables are shown on two lines, followed by the output of
      -- tryMismatchFormatting() when that succeeds; any other expected value
      -- yields a one-line message.
      if not M.ORDER_ACTUAL_EXPECTED then
          expected, actual = actual, expected
      end

      local kind = type(expected)
      if kind ~= 'string' and kind ~= 'table' then
          return string.format("expected: %s, actual: %s",
                               prettystr(expected), prettystr(actual))
      end

      local strExpected, strActual = prettystrPairs(expected, actual)
      local message = string.format("expected: %s\nactual: %s", strExpected, strActual)

      -- append the mismatch analysis when one could be produced
      local analysed, analysis = tryMismatchFormatting( actual, expected, doDeepAnalysis )
      if analysed then
          message = table.concat( { message, analysis }, '\n' )
      end
      return message
  end
  function M.assertError(f, ...)
      -- Fail unless calling f(...) raises an error.
      -- Example: assertError( f, 1, 2 ) passes only if f(1,2) errors.
      local completed = pcall( f, ... )
      if completed then
          failure( "Expected an error when calling function but no error generated", 2 )
      end
  end
  function M.fail( msg )
      -- Abort the current test unconditionally, reporting "msg" as a failure.
      failure( msg, 2 )
  end
  function M.failIf( cond, msg )
      -- Abort the current test with "msg" when "cond" is truthy;
      -- do nothing otherwise.
      if not cond then
          return
      end
      failure( msg, 2 )
  end
  1011. ------------------------------------------------------------------
  1012. -- Equality assertion
  1013. ------------------------------------------------------------------
  function M.assertEquals(actual, expected, doDeepAnalysis)
      -- Fail unless "actual" equals "expected".
      -- Two tables are compared by content with _is_table_equals(); any other
      -- combination must have identical types and compare equal with ==.
      -- "doDeepAnalysis" is only forwarded to the error message builder when
      -- both values are tables.
      local kindActual, kindExpected = type(actual), type(expected)
      if kindActual == 'table' and kindExpected == 'table' then
          if _is_table_equals(actual, expected) then
              return
          end
          failure( errorMsgEquality(actual, expected, doDeepAnalysis), 2 )
      elseif kindActual ~= kindExpected or actual ~= expected then
          failure( errorMsgEquality(actual, expected), 2 )
      end
  end
  function M.almostEquals( actual, expected, margin )
      -- Return true when |expected - actual| <= margin.
      -- Raises an error (not a test failure) when any argument is not a
      -- number, or when margin is negative.
      local allNumbers = type(actual) == 'number'
          and type(expected) == 'number'
          and type(margin) == 'number'
      if not allNumbers then
          error_fmt(3, 'almostEquals: must supply only number arguments.\nArguments supplied: %s, %s, %s',
                    prettystr(actual), prettystr(expected), prettystr(margin))
      end
      if margin < 0 then
          error('almostEquals: margin must not be negative, current value is ' .. margin, 3)
      end
      local distance = math.abs(expected - actual)
      return distance <= margin
  end
  function M.assertAlmostEquals( actual, expected, margin )
      -- Fail unless the two numbers differ by at most "margin"
      -- (M.EPS when no margin is given).
      margin = margin or M.EPS
      if M.almostEquals(actual, expected, margin) then
          return
      end
      -- honour the configured argument order in the report
      if not M.ORDER_ACTUAL_EXPECTED then
          expected, actual = actual, expected
      end
      local delta = math.abs(actual - expected)
      fail_fmt(2, 'Values are not almost equal\n' ..
                  'Actual: %s, expected: %s, delta %s above margin of %s',
                  actual, expected, delta, margin)
  end
  function M.assertNotEquals(actual, expected)
      -- Fail when "actual" equals "expected". Values of different types are
      -- never equal; tables are compared by content with _is_table_equals(),
      -- everything else with ==.
      local kind = type(actual)
      if kind ~= type(expected) then
          return
      end

      local same
      if kind == 'table' then
          same = _is_table_equals(actual, expected)
      else
          same = (actual == expected)
      end
      if not same then
          return
      end
      fail_fmt(2, 'Received the not expected value: %s', prettystr(actual))
  end
  function M.assertNotAlmostEquals( actual, expected, margin )
      -- Fail when the two numbers differ by no more than "margin"
      -- (M.EPS when no margin is given).
      margin = margin or M.EPS
      if not M.almostEquals(actual, expected, margin) then
          return
      end
      -- honour the configured argument order in the report
      if not M.ORDER_ACTUAL_EXPECTED then
          expected, actual = actual, expected
      end
      local delta = math.abs(actual - expected)
      fail_fmt(2, 'Values are almost equal\nActual: %s, expected: %s' ..
                  ', delta %s below margin of %s',
                  actual, expected, delta, margin)
  end
  function M.assertItemsEquals(actual, expected)
      -- Fail unless both tables hold the same items, regardless of keys and
      -- order: each value of one table must be contained in the other.
      -- Warning: the comparison is at least O(n^2).
      if _is_table_items_equals(actual, expected ) then
          return
      end
      local strExpected, strActual = prettystrPairs(expected, actual)
      fail_fmt(2, 'Contents of the tables are not identical:\nExpected: %s\nActual: %s',
               strExpected, strActual)
  end
  1084. ------------------------------------------------------------------
  1085. -- String assertion
  1086. ------------------------------------------------------------------
  function M.assertStrContains( str, sub, useRe )
      -- Fail unless "sub" is found in "str" (via string.find).
      -- "sub" is searched as plain text unless "useRe" is truthy, in which
      -- case it is treated as a Lua pattern.
      -- Note: any string contains the empty string.
      local plainSearch = not useRe
      if string.find(str, sub, 1, plainSearch) then
          return
      end
      local strSub, strStr = prettystrPairs(sub, str, '\n')
      fail_fmt(2, 'Error, %s %s was not found in string %s',
               useRe and 'regexp' or 'substring', strSub, strStr)
  end
  function M.assertStrIContains( str, sub )
      -- Fail unless "sub" occurs in "str", ignoring case. Both strings are
      -- lowercased and searched as plain text with string.find.
      -- Note: any string contains the empty string.
      local haystack, needle = str:lower(), sub:lower()
      if string.find(haystack, needle, 1, true) then
          return
      end
      local strSub, strStr = prettystrPairs(sub, str, '\n')
      fail_fmt(2, 'Error, substring %s was not found (case insensitively) in string %s',
               strSub, strStr)
  end
  function M.assertNotStrContains( str, sub, useRe )
      -- Fail when "sub" is found in "str" (via string.find).
      -- "sub" is searched as plain text unless "useRe" is truthy, in which
      -- case it is treated as a Lua pattern.
      -- Note: any string contains the empty string.
      local plainSearch = not useRe
      if not string.find(str, sub, 1, plainSearch) then
          return
      end
      local strSub, strStr = prettystrPairs(sub, str, '\n')
      fail_fmt(2, 'Error, %s %s was found in string %s',
               useRe and 'regexp' or 'substring', strSub, strStr)
  end
  function M.assertNotStrIContains( str, sub )
      -- Fail when "sub" occurs in "str", ignoring case. Both strings are
      -- lowercased and searched as plain text with string.find.
      -- Note: any string contains the empty string.
      local haystack, needle = str:lower(), sub:lower()
      if not string.find(haystack, needle, 1, true) then
          return
      end
      local strSub, strStr = prettystrPairs(sub, str, '\n')
      fail_fmt(2, 'Error, substring %s was found (case insensitively) in string %s',
               strSub, strStr)
  end
  function M.assertStrMatches( str, pattern, start, final )
      -- Fail unless strMatch( str, pattern, start, final ) succeeds, i.e.
      -- the string is matched in full. For a partial match use
      -- assertStrContains with useRe set to true.
      if strMatch( str, pattern, start, final ) then
          return
      end
      local strPattern, strStr = prettystrPairs(pattern, str, '\n')
      fail_fmt(2, 'Error, pattern %s was not matched by string %s',
               strPattern, strStr)
  end
  function M.assertErrorMsgEquals( expectedMsg, func, ... )
      -- Fail unless calling func(...) raises an error whose value equals
      -- "expectedMsg".
      --   expectedMsg : the expected error value. When it is a string, a
      --                 non-string error value is converted with tostring()
      --                 before comparing. When it is a table, a table error
      --                 value is compared with M.assertItemsEquals.
      --   func, ...   : function to call and its arguments
      local no_error, error_msg = pcall( func, ... )
      if no_error then
          -- expectedMsg may legitimately be a table (see below), which cannot
          -- be concatenated: only strings and numbers are quoted verbatim,
          -- everything else goes through prettystr().
          local kind = type(expectedMsg)
          local strExpected
          if kind == 'string' or kind == 'number' then
              strExpected = '"' .. expectedMsg .. '"'
          else
              strExpected = prettystr(expectedMsg)
          end
          failure( 'No error generated when calling function but expected error: ' .. strExpected, 2 )
      end

      if type(expectedMsg) == "string" and type(error_msg) ~= "string" then
          error_msg = tostring(error_msg)
      end

      local differ = false
      if error_msg ~= expectedMsg then
          if type(expectedMsg) == 'table' and type(error_msg) == 'table' then
              -- two distinct tables: accept them when their items are equal
              differ = not pcall(M.assertItemsEquals, error_msg, expectedMsg)
          else
              differ = true
          end
      end

      if differ then
          error_msg, expectedMsg = prettystrPairs(error_msg, expectedMsg)
          fail_fmt(2, 'Exact error message expected: %s\nError message received: %s\n',
                   expectedMsg, error_msg)
      end
  end
  function M.assertErrorMsgContains( partialMsg, func, ... )
      -- Fail unless calling func(...) raises an error whose message contains
      -- "partialMsg" as a plain substring. A non-string error value is
      -- converted with tostring() before searching.
      local no_error, error_msg = pcall( func, ... )
      if no_error then
          failure( 'No error generated when calling function but expected error containing: '..prettystr(partialMsg), 2 )
      end
      if type(error_msg) ~= "string" then
          error_msg = tostring(error_msg)
      end
      if string.find( error_msg, partialMsg, nil, true ) then
          return
      end
      local strError, strPartial = prettystrPairs(error_msg, partialMsg)
      fail_fmt(2, 'Error message does not contain: %s\nError message received: %s\n',
               strPartial, strError)
  end
  function M.assertErrorMsgMatches( expectedMsg, func, ... )
      -- Fail unless calling func(...) raises an error whose message is
      -- matched by the pattern "expectedMsg" (checked with strMatch).
      -- A non-string error value is converted with tostring() first.
      local no_error, error_msg = pcall( func, ... )
      if no_error then
          failure( 'No error generated when calling function but expected error matching: "'..expectedMsg..'"', 2 )
      end
      if type(error_msg) ~= "string" then
          error_msg = tostring(error_msg)
      end
      if strMatch( error_msg, expectedMsg ) then
          return
      end
      local strExpected, strError = prettystrPairs(expectedMsg, error_msg)
      fail_fmt(2, 'Error message does not match: %s\nError message received: %s\n',
               strExpected, strError)
  end
  1197. ------------------------------------------------------------------
  1198. -- Type assertions
  1199. ------------------------------------------------------------------
  function M.assertEvalToTrue(value)
      -- Fail when "value" is falsy (false or nil).
      if value then
          return
      end
      failure("expected: a value evaluating to true, actual: " ..prettystr(value), 2)
  end
  function M.assertEvalToFalse(value)
      -- Fail when "value" is truthy (anything but false or nil).
      if not value then
          return
      end
      failure("expected: false or nil, actual: " ..prettystr(value), 2)
  end
  function M.assertIsTrue(value)
      -- Fail unless "value" is exactly the boolean true.
      if value == true then
          return
      end
      failure("expected: true, actual: " ..prettystr(value), 2)
  end
  function M.assertNotIsTrue(value)
      -- Fail when "value" is exactly the boolean true.
      if value ~= true then
          return
      end
      failure("expected: anything but true, actual: " ..prettystr(value), 2)
  end
  function M.assertIsFalse(value)
      -- Fail unless "value" is exactly the boolean false.
      if value == false then
          return
      end
      failure("expected: false, actual: " ..prettystr(value), 2)
  end
  function M.assertNotIsFalse(value)
      -- Fail when "value" is exactly the boolean false.
      if value ~= false then
          return
      end
      failure("expected: anything but false, actual: " ..prettystr(value), 2)
  end
  function M.assertIsNil(value)
      -- Fail unless "value" is nil.
      if value == nil then
          return
      end
      failure("expected: nil, actual: " ..prettystr(value), 2)
  end
  function M.assertNotIsNil(value)
      -- Fail when "value" is nil.
      if value ~= nil then
          return
      end
      failure("expected non nil value, received nil", 2)
  end
  1240. --[[
  1241. Add type assertion functions to the module table M. Each of these functions
  1242. takes a single parameter "value", and checks that its Lua type matches the
  1243. expected string (derived from the function name):
  1244. M.assertIsXxx(value) -> ensure that type(value) conforms to "xxx"
  1245. ]]
  for _, funcName in ipairs(
      {'assertIsNumber', 'assertIsString', 'assertIsTable', 'assertIsBoolean',
       'assertIsFunction', 'assertIsUserdata', 'assertIsThread'}
  ) do
      -- The expected type is the part of the name following "assertIs".
      local suffix = funcName:match("^assertIs([A-Z]%a*)$")
      if not suffix then
          -- guards against a malformed entry in the list above
          error("bad function name '"..funcName.."' for type assertion")
      end
      -- type() always returns lowercase names
      local typeExpected = suffix:lower()

      -- M.assertIsXxx(value): fail unless type(value) is the expected type.
      M[funcName] = function(value)
          local typeActual = type(value)
          if typeActual ~= typeExpected then
              fail_fmt(2, 'Expected: a %s value, actual: type %s, value %s',
                       typeExpected, typeActual, prettystrPairs(value))
          end
      end
  end
  1261. --[[
  1262. Add shortcuts for verifying type of a variable, without failure (luaunit v2 compatibility)
  1263. M.isXxx(value) -> returns true if type(value) conforms to "xxx"
  1264. ]]
  for _, typeExpected in ipairs(
      {'Number', 'String', 'Table', 'Boolean',
       'Function', 'Userdata', 'Thread', 'Nil' }
  ) do
      local typeExpectedLower = typeExpected:lower()
      -- Predicate shared by both spellings: true when type(value) matches.
      local function isType(value)
          return type(value) == typeExpectedLower
      end
      -- register under the camelCase name and the snake_case name
      M['is' .. typeExpected] = isType
      M['is_' .. typeExpectedLower] = isType
  end
  1276. --[[
  1277. Add non-type assertion functions to the module table M. Each of these functions
  1278. takes a single parameter "value", and checks that its Lua type differs from the
  1279. expected string (derived from the function name):
  1280. M.assertNotIsXxx(value) -> ensure that type(value) is not "xxx"
  1281. ]]
  for _, funcName in ipairs(
      {'assertNotIsNumber', 'assertNotIsString', 'assertNotIsTable', 'assertNotIsBoolean',
       'assertNotIsFunction', 'assertNotIsUserdata', 'assertNotIsThread'}
  ) do
      -- The rejected type is the part of the name following "assertNotIs".
      local suffix = funcName:match("^assertNotIs([A-Z]%a*)$")
      if not suffix then
          -- guards against a malformed entry in the list above
          error("bad function name '"..funcName.."' for type assertion")
      end
      -- type() always returns lowercase names
      local typeUnexpected = suffix:lower()

      -- M.assertNotIsXxx(value): fail when type(value) is the rejected type.
      M[funcName] = function(value)
          if type(value) ~= typeUnexpected then
              return
          end
          fail_fmt(2, 'Not expected: a %s type, actual: value %s',
                   typeUnexpected, prettystrPairs(value))
      end
  end
  function M.assertIs(actual, expected)
      -- Fail unless "actual" and "expected" compare equal with ==
      -- (no content comparison is performed for tables).
      if actual == expected then
          return
      end
      -- honour the configured argument order in the report
      if not M.ORDER_ACTUAL_EXPECTED then
          actual, expected = expected, actual
      end
      local strExpected, strActual = prettystrPairs(expected, actual, '\n', ', ')
      fail_fmt(2, 'Expected object and actual object are not the same\nExpected: %sactual: %s',
               strExpected, strActual)
  end
  function M.assertNotIs(actual, expected)
      -- Fail when "actual" and "expected" compare equal with ==.
      if actual ~= expected then
          return
      end
      if not M.ORDER_ACTUAL_EXPECTED then
          expected = actual
      end
      fail_fmt(2, 'Expected object and actual object are the same object: %s',
               prettystrPairs(expected))
  end
  1316. ------------------------------------------------------------------
  1317. -- Scientific assertions
  1318. ------------------------------------------------------------------
  function M.assertIsNaN(value)
      -- Fail unless "value" is a number that differs from itself (NaN).
      local isNaN = type(value) == "number" and value ~= value
      if not isNaN then
          failure("expected: nan, actual: " ..prettystr(value), 2)
      end
  end
  function M.assertNotIsNaN(value)
      -- Fail when "value" is a number that differs from itself (NaN).
      local isNaN = type(value) == "number" and value ~= value
      if isNaN then
          failure("expected non nan value, received nan", 2)
      end
  end
  function M.assertIsInf(value)
      -- Fail unless "value" is a number whose magnitude is math.huge
      -- (either sign).
      local isInf = type(value) == "number" and math.abs(value) == math.huge
      if not isInf then
          failure("expected: inf, actual: " ..prettystr(value), 2)
      end
  end
  function M.assertIsPlusInf(value)
      -- Fail unless "value" is the number math.huge.
      local isPlusInf = type(value) == "number" and value == math.huge
      if not isPlusInf then
          failure("expected: +inf, actual: " ..prettystr(value), 2)
      end
  end
  function M.assertIsMinusInf(value)
      -- Fail unless "value" is the number -math.huge.
      local isMinusInf = type(value) == "number" and value == -math.huge
      if not isMinusInf then
          failure("expected: -inf, actual: " ..prettystr(value), 2)
      end
  end
  function M.assertNotIsPlusInf(value)
      -- Fail when "value" is the number math.huge.
      local isPlusInf = type(value) == "number" and value == math.huge
      if isPlusInf then
          failure("expected not +inf value, received +inf", 2)
      end
  end
  function M.assertNotIsMinusInf(value)
      -- Fail when "value" is the number -math.huge.
      local isMinusInf = type(value) == "number" and value == -math.huge
      if isMinusInf then
          failure("expected not -inf value, received -inf", 2)
      end
  end
  function M.assertNotIsInf(value)
      -- Fail when "value" is a number whose magnitude is math.huge
      -- (either sign).
      local isInf = type(value) == "number" and math.abs(value) == math.huge
      if isInf then
          failure("expected non inf value, received ±inf", 2)
      end
  end
  function M.assertIsPlusZero(value)
      -- Fail unless "value" is positive zero. The sign of a zero is told
      -- apart through its reciprocal: 1/value is math.huge for +0 and
      -- -math.huge for -0.
      if type(value) ~= 'number' or value ~= 0 then
          failure("expected: +0.0, actual: " ..prettystr(value), 2)
      elseif 1/value == -math.huge then
          -- more precise error diagnosis
          failure("expected: +0.0, actual: -0.0", 2)
      elseif 1/value ~= math.huge then
          -- strange, case should have already been covered
          failure("expected: +0.0, actual: " ..prettystr(value), 2)
      end
  end
  function M.assertIsMinusZero(value)
      -- Fail unless "value" is negative zero. The sign of a zero is told
      -- apart through its reciprocal: 1/value is -math.huge for -0 and
      -- math.huge for +0.
      if type(value) ~= 'number' or value ~= 0 then
          failure("expected: -0.0, actual: " ..prettystr(value), 2)
      elseif 1/value == math.huge then
          -- more precise error diagnosis
          failure("expected: -0.0, actual: +0.0", 2)
      elseif 1/value ~= -math.huge then
          -- strange, case should have already been covered
          failure("expected: -0.0, actual: " ..prettystr(value), 2)
      end
  end
  function M.assertNotIsPlusZero(value)
      -- Fail when "value" is positive zero; any other value (including -0.0
      -- and non-numbers) passes.
      -- +0 and -0 compare equal, so the sign is read from the reciprocal:
      -- 1/value is math.huge only for +0. The test must therefore be "==";
      -- with "~=" the assertion let +0.0 through and rejected -0.0 instead.
      if type(value) == 'number' and value == 0 and (1/value == math.huge) then
          failure("expected: not +0.0, actual: +0.0", 2)
      end
  end
  function M.assertNotIsMinusZero(value)
      -- Fail when "value" is negative zero; any other value (including +0.0
      -- and non-numbers) passes.
      -- +0 and -0 compare equal, so the sign is read from the reciprocal:
      -- 1/value is -math.huge only for -0. The test must therefore be "==";
      -- with "~=" the assertion let -0.0 through and rejected +0.0 instead.
      -- The message now names -0.0 (it used to say +0.0).
      if type(value) == 'number' and value == 0 and (1/value == -math.huge) then
          failure("expected: not -0.0, actual: -0.0", 2)
      end
  end
  1395. ----------------------------------------------------------------
  1396. -- Compatibility layer
  1397. ----------------------------------------------------------------
  1398. -- for compatibility with LuaUnit v2.x
  function M.wrapFunctions()
      -- Kept for compatibility with LuaUnit v2.x only. In LuaUnit <= 2.1 this
      -- call was needed to include a test function in the global test suite;
      -- test functions are now found by the test discovery process, so all
      -- this does is print a notice on stderr.
      local notice = 'Use of WrapFunctions() is no longer needed.\n'
          .. 'Just prefix your test function names with "test" or "Test" and they\n'
          .. 'will be picked up and run by LuaUnit.\n'
      io.stderr:write(notice)
  end
  1409. local list_of_funcs = {
  1410. -- { official function name , alias }
  1411. -- general assertions
  1412. { 'assertEquals' , 'assert_equals' },
  1413. { 'assertItemsEquals' , 'assert_items_equals' },
  1414. { 'assertNotEquals' , 'assert_not_equals' },
  1415. { 'assertAlmostEquals' , 'assert_almost_equals' },
  1416. { 'assertNotAlmostEquals' , 'assert_not_almost_equals' },
  1417. { 'assertEvalToTrue' , 'assert_eval_to_true' },
  1418. { 'assertEvalToFalse' , 'assert_eval_to_false' },
  1419. { 'assertStrContains' , 'assert_str_contains' },
  1420. { 'assertStrIContains' , 'assert_str_icontains' },
  1421. { 'assertNotStrContains' , 'assert_not_str_contains' },
  1422. { 'assertNotStrIContains' , 'assert_not_str_icontains' },
  1423. { 'assertStrMatches' , 'assert_str_matches' },
  1424. { 'assertError' , 'assert_error' },
  1425. { 'assertErrorMsgEquals' , 'assert_error_msg_equals' },
  1426. { 'assertErrorMsgContains' , 'assert_error_msg_contains' },
  1427. { 'assertErrorMsgMatches' , 'assert_error_msg_matches' },
  1428. { 'assertIs' , 'assert_is' },
  1429. { 'assertNotIs' , 'assert_not_is' },
  1430. { 'wrapFunctions' , 'WrapFunctions' },
  1431. { 'wrapFunctions' , 'wrap_functions' },
  1432. -- type assertions: assertIsXXX -> assert_is_xxx
  1433. { 'assertIsNumber' , 'assert_is_number' },
  1434. { 'assertIsString' , 'assert_is_string' },
  1435. { 'assertIsTable' , 'assert_is_table' },
  1436. { 'assertIsBoolean' , 'assert_is_boolean' },
  1437. { 'assertIsNil' , 'assert_is_nil' },
  1438. { 'assertIsTrue' , 'assert_is_true' },
  1439. { 'assertIsFalse' , 'assert_is_false' },
  1440. { 'assertIsNaN' , 'assert_is_nan' },
  1441. { 'assertIsInf' , 'assert_is_inf' },
  1442. { 'assertIsPlusInf' , 'assert_is_plus_inf' },
  1443. { 'assertIsMinusInf' , 'assert_is_minus_inf' },
  1444. { 'assertIsPlusZero' , 'assert_is_plus_zero' },
  1445. { 'assertIsMinusZero' , 'assert_is_minus_zero' },
  1446. { 'assertIsFunction' , 'assert_is_function' },
  1447. { 'assertIsThread' , 'assert_is_thread' },
  1448. { 'assertIsUserdata' , 'assert_is_userdata' },
  1449. -- type assertions: assertIsXXX -> assertXxx
  1450. { 'assertIsNumber' , 'assertNumber' },
  1451. { 'assertIsString' , 'assertString' },
  1452. { 'assertIsTable' , 'assertTable' },
  1453. { 'assertIsBoolean' , 'assertBoolean' },
  1454. { 'assertIsNil' , 'assertNil' },
  1455. { 'assertIsTrue' , 'assertTrue' },
  1456. { 'assertIsFalse' , 'assertFalse' },
  1457. { 'assertIsNaN' , 'assertNaN' },
  1458. { 'assertIsInf' , 'assertInf' },
  1459. { 'assertIsPlusInf' , 'assertPlusInf' },
  1460. { 'assertIsMinusInf' , 'assertMinusInf' },
  1461. { 'assertIsPlusZero' , 'assertPlusZero' },
  1462. { 'assertIsMinusZero' , 'assertMinusZero'},
  1463. { 'assertIsFunction' , 'assertFunction' },
  1464. { 'assertIsThread' , 'assertThread' },
  1465. { 'assertIsUserdata' , 'assertUserdata' },
  1466. -- type assertions: assertIsXXX -> assert_xxx (luaunit v2 compat)
  1467. { 'assertIsNumber' , 'assert_number' },
  1468. { 'assertIsString' , 'assert_string' },
  1469. { 'assertIsTable' , 'assert_table' },
  1470. { 'assertIsBoolean' , 'assert_boolean' },
  1471. { 'assertIsNil' , 'assert_nil' },
  1472. { 'assertIsTrue' , 'assert_true' },
  1473. { 'assertIsFalse' , 'assert_false' },
  1474. { 'assertIsNaN' , 'assert_nan' },
  1475. { 'assertIsInf' , 'assert_inf' },
  1476. { 'assertIsPlusInf' , 'assert_plus_inf' },
  1477. { 'assertIsMinusInf' , 'assert_minus_inf' },
  1478. { 'assertIsPlusZero' , 'assert_plus_zero' },
  1479. { 'assertIsMinusZero' , 'assert_minus_zero' },
  1480. { 'assertIsFunction' , 'assert_function' },
  1481. { 'assertIsThread' , 'assert_thread' },
  1482. { 'assertIsUserdata' , 'assert_userdata' },
  1483. -- type assertions: assertNotIsXXX -> assert_not_is_xxx
  1484. { 'assertNotIsNumber' , 'assert_not_is_number' },
  1485. { 'assertNotIsString' , 'assert_not_is_string' },
  1486. { 'assertNotIsTable' , 'assert_not_is_table' },
  1487. { 'assertNotIsBoolean' , 'assert_not_is_boolean' },
  1488. { 'assertNotIsNil' , 'assert_not_is_nil' },
  1489. { 'assertNotIsTrue' , 'assert_not_is_true' },
  1490. { 'assertNotIsFalse' , 'assert_not_is_false' },
  1491. { 'assertNotIsNaN' , 'assert_not_is_nan' },
  1492. { 'assertNotIsInf' , 'assert_not_is_inf' },
  1493. { 'assertNotIsPlusInf' , 'assert_not_plus_inf' },
  1494. { 'assertNotIsMinusInf' , 'assert_not_minus_inf' },
  1495. { 'assertNotIsPlusZero' , 'assert_not_plus_zero' },
  1496. { 'assertNotIsMinusZero' , 'assert_not_minus_zero' },
  1497. { 'assertNotIsFunction' , 'assert_not_is_function' },
  1498. { 'assertNotIsThread' , 'assert_not_is_thread' },
  1499. { 'assertNotIsUserdata' , 'assert_not_is_userdata' },
  1500. -- type assertions: assertNotIsXXX -> assertNotXxx (luaunit v2 compat)
  1501. { 'assertNotIsNumber' , 'assertNotNumber' },
  1502. { 'assertNotIsString' , 'assertNotString' },
  1503. { 'assertNotIsTable' , 'assertNotTable' },
  1504. { 'assertNotIsBoolean' , 'assertNotBoolean' },
  1505. { 'assertNotIsNil' , 'assertNotNil' },
  1506. { 'assertNotIsTrue' , 'assertNotTrue' },
  1507. { 'assertNotIsFalse' , 'assertNotFalse' },
  1508. { 'assertNotIsNaN' , 'assertNotNaN' },
  1509. { 'assertNotIsInf' , 'assertNotInf' },
  1510. { 'assertNotIsPlusInf' , 'assertNotPlusInf' },
  1511. { 'assertNotIsMinusInf' , 'assertNotMinusInf' },
  1512. { 'assertNotIsPlusZero' , 'assertNotPlusZero' },
  1513. { 'assertNotIsMinusZero' , 'assertNotMinusZero' },
  1514. { 'assertNotIsFunction' , 'assertNotFunction' },
  1515. { 'assertNotIsThread' , 'assertNotThread' },
  1516. { 'assertNotIsUserdata' , 'assertNotUserdata' },
  1517. -- type assertions: assertNotIsXXX -> assert_not_xxx
  1518. { 'assertNotIsNumber' , 'assert_not_number' },
  1519. { 'assertNotIsString' , 'assert_not_string' },
  1520. { 'assertNotIsTable' , 'assert_not_table' },
  1521. { 'assertNotIsBoolean' , 'assert_not_boolean' },
  1522. { 'assertNotIsNil' , 'assert_not_nil' },
  1523. { 'assertNotIsTrue' , 'assert_not_true' },
  1524. { 'assertNotIsFalse' , 'assert_not_false' },
  1525. { 'assertNotIsNaN' , 'assert_not_nan' },
  1526. { 'assertNotIsInf' , 'assert_not_inf' },
  1527. { 'assertNotIsPlusInf' , 'assert_not_plus_inf' },
  1528. { 'assertNotIsMinusInf' , 'assert_not_minus_inf' },
  1529. { 'assertNotIsPlusZero' , 'assert_not_plus_zero' },
  1530. { 'assertNotIsMinusZero' , 'assert_not_minus_zero' },
  1531. { 'assertNotIsFunction' , 'assert_not_function' },
  1532. { 'assertNotIsThread' , 'assert_not_thread' },
  1533. { 'assertNotIsUserdata' , 'assert_not_userdata' },
  1534. -- all assertions with Coroutine duplicate Thread assertions
  1535. { 'assertIsThread' , 'assertIsCoroutine' },
  1536. { 'assertIsThread' , 'assertCoroutine' },
  1537. { 'assertIsThread' , 'assert_is_coroutine' },
  1538. { 'assertIsThread' , 'assert_coroutine' },
  1539. { 'assertNotIsThread' , 'assertNotIsCoroutine' },
  1540. { 'assertNotIsThread' , 'assertNotCoroutine' },
  1541. { 'assertNotIsThread' , 'assert_not_is_coroutine' },
  1542. { 'assertNotIsThread' , 'assert_not_coroutine' },
  1543. }
  -- Publish every alias in M next to the function it stands for. When
  -- EXPORT_ASSERT_TO_GLOBALS is set, the original name and the alias are
  -- both written into _G as well.
  for _, pair in ipairs( list_of_funcs ) do
      local funcname, alias = pair[1], pair[2]
      local target = M[funcname]
      M[alias] = target
      if EXPORT_ASSERT_TO_GLOBALS then
          _G[funcname] = target
          _G[alias] = target
      end
  end
  1553. ----------------------------------------------------------------
  1554. --
  1555. -- Outputters
  1556. --
  1557. ----------------------------------------------------------------
  -- Common "base" class for outputters. Derived outputters override the hooks
  -- they need; every hook defined here is an empty placeholder.
  -- For the class inheritance concept see http://www.lua.org/pil/16.2.html
  local genericOutput = { __class__ = 'genericOutput' } -- class
  local genericOutput_MT = { __index = genericOutput } -- metatable
  M.genericOutput = genericOutput -- publish, so that custom classes may derive from it

  function genericOutput.new(runner, default_verbosity)
      -- runner: the "parent" object controlling the output (usually a LuaUnit
      --         instance); may be nil.
      -- default_verbosity: used when the runner gives no verbosity of its own.
      -- Returns a new table whose metatable is genericOutput_MT.
      local instance = { runner = runner, verbosity = default_verbosity }
      if runner then
          instance.result = runner.result
          instance.verbosity = runner.verbosity or default_verbosity
          instance.fname = runner.fname
      end
      return setmetatable( instance, genericOutput_MT )
  end

  -- abstract ("empty") hooks, meant to be overridden by derived outputters
  function genericOutput:startSuite() end
  function genericOutput:startClass(className) end
  function genericOutput:startTest(testName) end
  function genericOutput:addStatus(node) end
  function genericOutput:endTest(node) end
  function genericOutput:endClass() end
  function genericOutput:endSuite() end
  1583. ----------------------------------------------------------------
  1584. -- class TapOutput
  1585. ----------------------------------------------------------------
  local TapOutput = genericOutput.new() -- derived class
  local TapOutput_MT = { __index = TapOutput } -- metatable
  TapOutput.__class__ = 'TapOutput'

  -- TAP format reference: http://testanything.org/tap-specification.html

  function TapOutput.new(runner)
      -- Create a TAP outputter for runner; verbosity defaults to M.VERBOSITY_LOW.
      return setmetatable( genericOutput.new(runner, M.VERBOSITY_LOW), TapOutput_MT )
  end

  function TapOutput:startSuite()
      -- Print the plan line ("1..N") followed by the start date as a comment.
      local result = self.result
      print("1.."..result.testCount)
      print('# Started on '..result.startDate)
  end

  function TapOutput:startClass(className)
      -- Announce the class, except for the '[TestFunctions]' pseudo class.
      if className == '[TestFunctions]' then
          return
      end
      print('# Starting class: '..className)
  end

  function TapOutput:addStatus( node )
      -- Print a "not ok" line for node; message and stack trace are added as
      -- '# ' prefixed text depending on the verbosity.
      io.stdout:write("not ok ", self.result.currentTestNumber, "\t", node.testName, "\n")
      local verbosity = self.verbosity
      if verbosity > M.VERBOSITY_LOW then
          print( prefixString( '# ', node.msg ) )
      end
      if verbosity > M.VERBOSITY_DEFAULT then
          print( prefixString( '# ', node.stackTrace ) )
      end
  end

  function TapOutput:endTest( node )
      -- Only passed tests are reported here; addStatus() prints the others.
      if not node:isPassed() then
          return
      end
      io.stdout:write("ok ", self.result.currentTestNumber, "\t", node.testName, "\n")
  end

  function TapOutput:endSuite()
      -- Print the status line as a comment and return the not-passed count.
      local result = self.result
      print( '# '..M.LuaUnit.statusLine( result ) )
      return result.notPassedCount
  end
  -- class TapOutput end
  1622. ----------------------------------------------------------------
  1623. -- class JUnitOutput
  1624. ----------------------------------------------------------------
  1625. -- See directory junitxml for more information about the junit format
  local JUnitOutput = genericOutput.new() -- derived class
  local JUnitOutput_MT = { __index = JUnitOutput } -- metatable
  JUnitOutput.__class__ = 'JUnitOutput'

  function JUnitOutput.new(runner)
      -- Create a JUnit outputter for runner; verbosity defaults to M.VERBOSITY_LOW.
      local t = genericOutput.new(runner, M.VERBOSITY_LOW)
      t.testList = {}
      return setmetatable( t, JUnitOutput_MT )
  end

  function JUnitOutput:startSuite()
      -- Open the XML file early, so that file errors show up before any test runs.
      -- Raises an error when no file name was given or the file cannot be opened.
      if self.fname == nil then
          error('With Junit, an output filename must be supplied with --name!')
      end
      if string.sub(self.fname,-4) ~= '.xml' then
          self.fname = self.fname..'.xml'
      end
      self.fd = io.open(self.fname, "w")
      if self.fd == nil then
          error("Could not open file for writing: "..self.fname)
      end
      print('# XML output to '..self.fname)
      print('# Started on '..self.result.startDate)
  end

  function JUnitOutput:startClass(className)
      -- Announce the class, except for the '[TestFunctions]' pseudo class.
      if className ~= '[TestFunctions]' then
          print('# Starting class: '..className)
      end
  end

  function JUnitOutput:startTest(testName)
      print('# Starting test: '..testName)
  end

  function JUnitOutput:addStatus( node )
      -- Print failures and errors as '#' comments on stdout.
      -- sub(4) drops the first three characters of the prefixed message.
      if node:isFailure() then
          print( '# Failure: ' .. prefixString( '# ', node.msg ):sub(4, nil) )
          -- print('# ' .. node.stackTrace)
      elseif node:isError() then
          print( '# Error: ' .. prefixString( '# ' , node.msg ):sub(4, nil) )
          -- print('# ' .. node.stackTrace)
      end
  end

  function JUnitOutput:endSuite()
      -- Print the status line, write the whole XML report to the file opened
      -- by startSuite(), close it, and return the not-passed count.
      print( '# '..M.LuaUnit.statusLine(self.result))

      -- XML file writing
      self.fd:write('<?xml version="1.0" encoding="UTF-8" ?>\n')
      self.fd:write('<testsuites>\n')
      self.fd:write(string.format(
          ' <testsuite name="LuaUnit" id="00001" package="" hostname="localhost" tests="%d" timestamp="%s" time="%0.3f" errors="%d" failures="%d">\n',
          self.result.runCount, self.result.startIsodate, self.result.duration, self.result.errorCount, self.result.failureCount ))
      self.fd:write(" <properties>\n")
      self.fd:write(string.format(' <property name="Lua Version" value="%s"/>\n', _VERSION ) )
      self.fd:write(string.format(' <property name="LuaUnit Version" value="%s"/>\n', M.VERSION) )
      -- XXX please include system name and version if possible
      self.fd:write(" </properties>\n")

      for _, node in ipairs(self.result.tests) do
          -- Class and test names land inside XML attributes: escape them the
          -- same way NodeStatus:statusXML() escapes the message, otherwise a
          -- name containing &, <, > or a quote yields a malformed report.
          -- The extra parentheses keep only the first return value.
          self.fd:write(string.format(' <testcase classname="%s" name="%s" time="%0.3f">\n',
              (xmlEscape(node.className)), (xmlEscape(node.testName)), node.duration ) )
          if node:isNotPassed() then
              self.fd:write(node:statusXML())
          end
          self.fd:write(' </testcase>\n')
      end

      -- Next two lines are needed to validate junit ANT xsd, but really not useful in general:
      self.fd:write(' <system-out/>\n')
      self.fd:write(' <system-err/>\n')

      self.fd:write(' </testsuite>\n')
      self.fd:write('</testsuites>\n')
      self.fd:close()
      return self.result.notPassedCount
  end
  -- class JUnitOutput end
  1696. ----------------------------------------------------------------
  1697. -- class TextOutput
  1698. ----------------------------------------------------------------
  1699. --[[
  1700. -- Python Non verbose:
  1701. For each test: . or F or E
  1702. If some failed tests:
  1703. ==============
  1704. ERROR / FAILURE: TestName (testfile.testclass)
  1705. ---------
  1706. Stack trace
  1707. then --------------
  1708. then "Ran x tests in 0.000s"
  1709. then OK or FAILED (failures=1, error=1)
  1710. -- Python Verbose:
  1711. testname (filename.classname) ... ok
  1712. testname (filename.classname) ... FAIL
  1713. testname (filename.classname) ... ERROR
  1714. then --------------
  1715. then "Ran x tests in 0.000s"
  1716. then OK or FAILED (failures=1, error=1)
  1717. -- Ruby:
  1718. Started
  1719. .
  1720. Finished in 0.002695 seconds.
  1721. 1 tests, 2 assertions, 0 failures, 0 errors
  1722. -- Ruby:
  1723. >> ruby tc_simple_number2.rb
  1724. Loaded suite tc_simple_number2
  1725. Started
  1726. F..
  1727. Finished in 0.038617 seconds.
  1728. 1) Failure:
  1729. test_failure(TestSimpleNumber) [tc_simple_number2.rb:16]:
  1730. Adding doesn't work.
  1731. <3> expected but was
  1732. <4>.
  1733. 3 tests, 4 assertions, 1 failures, 0 errors
  1734. -- Java Junit
  1735. .......F.
  1736. Time: 0,003
  1737. There was 1 failure:
  1738. 1) testCapacity(junit.samples.VectorTest)junit.framework.AssertionFailedError
  1739. at junit.samples.VectorTest.testCapacity(VectorTest.java:87)
  1740. at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
  1741. at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
  1742. at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
  1743. FAILURES!!!
  1744. Tests run: 8, Failures: 1, Errors: 0
  1745. -- Maven
  1746. # mvn test
  1747. -------------------------------------------------------
  1748. T E S T S
  1749. -------------------------------------------------------
  1750. Running math.AdditionTest
  1751. Tests run: 2, Failures: 1, Errors: 0, Skipped: 0, Time elapsed:
  1752. 0.03 sec <<< FAILURE!
  1753. Results :
  1754. Failed tests:
  1755. testLireSymbole(math.AdditionTest)
  1756. Tests run: 2, Failures: 1, Errors: 0, Skipped: 0
  1757. -- LuaUnit
  1758. ---- non verbose
  1759. * display . or F or E when running tests
  1760. ---- verbose
  1761. * display test name + ok/fail
  1762. ----
  1763. * blank line
  1764. * number) ERROR or FAILURE: TestName
  1765. Stack trace
  1766. * blank line
  1767. * number) ERROR or FAILURE: TestName
  1768. Stack trace
  1769. then --------------
  1770. then "Ran x tests in 0.000s (%d not selected, %d skipped)"
  1771. then OK or FAILED (failures=1, error=1)
  1772. ]]
  local TextOutput = genericOutput.new() -- derived class
  local TextOutput_MT = { __index = TextOutput } -- metatable
  TextOutput.__class__ = 'TextOutput'

  function TextOutput.new(runner)
      -- Create a text outputter for runner; verbosity defaults to M.VERBOSITY_DEFAULT.
      local instance = genericOutput.new(runner, M.VERBOSITY_DEFAULT)
      instance.errorList = {}
      return setmetatable( instance, TextOutput_MT )
  end

  function TextOutput:startSuite()
      -- The start date is only shown above the default verbosity.
      if self.verbosity > M.VERBOSITY_DEFAULT then
          print( 'Started on '.. self.result.startDate )
      end
  end

  function TextOutput:startTest(testName)
      -- Above the default verbosity, print the name of the current node and
      -- leave the line open for the verdict written by endTest().
      if self.verbosity > M.VERBOSITY_DEFAULT then
          io.stdout:write( " ", self.result.currentNode.testName, " ... " )
      end
  end

  function TextOutput:endTest( node )
      -- Verbose: "Ok" or status + message. Otherwise a single character:
      -- "." for a passed test, else the first character of the status.
      local passed = node:isPassed()
      local verbose = self.verbosity > M.VERBOSITY_DEFAULT
      if passed then
          if verbose then
              io.stdout:write("Ok\n")
          else
              io.stdout:write(".")
          end
          return
      end
      if verbose then
          print( node.status )
          print( node.msg )
          -- (printing node.stackTrace here is still an open question)
      else
          -- write only the first character of status
          io.stdout:write(string.sub(node.status, 1, 1))
      end
  end

  function TextOutput:displayOneFailedTest( index, fail )
      -- Print one numbered entry: name, message, stack trace, blank line.
      print(index..") "..fail.testName )
      print( fail.msg )
      print( fail.stackTrace )
      print()
  end

  function TextOutput:displayFailedTests()
      -- Print the list of not-passed tests, if there is any.
      if self.result.notPassedCount == 0 then
          return
      end
      print("Failed tests:")
      print("-------------")
      for position, failedNode in ipairs(self.result.notPassed) do
          self:displayOneFailedTest(position, failedNode)
      end
  end

  function TextOutput:endSuite()
      -- Separator, failed tests, status line, then 'OK' if everything passed.
      if self.verbosity > M.VERBOSITY_DEFAULT then
          print("=========================================================")
      else
          print()
      end
      self:displayFailedTests()
      print( M.LuaUnit.statusLine( self.result ) )
      if self.result.notPassedCount == 0 then
          print('OK')
      end
  end
  -- class TextOutput end
  1842. ----------------------------------------------------------------
  1843. -- class NilOutput
  1844. ----------------------------------------------------------------
  local function nopCallable()
      -- Do nothing and return itself, so that any chain of calls stays valid.
      --print(42)
      return nopCallable
  end

  -- Outputter that swallows everything: nopCallable is used as the __index
  -- metamethod, so looking up any missing field on an instance yields
  -- nopCallable, and calling that is a no-op.
  local NilOutput = { __class__ = 'NilOutput' } -- class (name was misspelled 'NilOuptut')
  local NilOutput_MT = { __index = nopCallable } -- metatable

  function NilOutput.new(runner)
      -- runner is accepted for interface compatibility with the other
      -- outputters, but is not used.
      return setmetatable( { __class__ = 'NilOutput' }, NilOutput_MT )
  end
  1854. ----------------------------------------------------------------
  1855. --
  1856. -- class LuaUnit
  1857. --
  1858. ----------------------------------------------------------------
  -- The test runner class. Defaults: text output, default verbosity.
  M.LuaUnit = {
      outputType = TextOutput,
      verbosity = M.VERBOSITY_DEFAULT,
      __class__ = 'LuaUnit'
  }
  local LuaUnit_MT = { __index = M.LuaUnit }

  if EXPORT_ASSERT_TO_GLOBALS then
      -- Write through _G explicitly, like the alias export does: a bare
      -- assignment would land in the current function environment, which is
      -- not the global table once module() has been called.
      _G.LuaUnit = M.LuaUnit
  end

  function M.LuaUnit.new()
      -- Return a fresh runner instance that inherits from M.LuaUnit.
      return setmetatable( {}, LuaUnit_MT )
  end
  1871. -----------------[[ Utility methods ]]---------------------
  function M.LuaUnit.asFunction(aObject)
      -- Pass "aObject" through when it is a function; otherwise return nothing.
      local isFunction = (type(aObject) == 'function')
      if not isFunction then
          return
      end
      return aObject
  end
  function M.LuaUnit.splitClassMethod(someName)
      --[[
      Split a name of the form "class.method" at its first dot and return
      className, methodName. When the name holds no dot, return
      nil, someName (the name itself being unchanged).
      A non-nil first return value therefore tells that the name designates
      a class method, which replaces the older isClassMethod() test.
      ]]
      local dotIndex = string.find(someName, '.', 1, true) -- plain search, no pattern
      if not dotIndex then
          return nil, someName
      end
      local className = someName:sub(1, dotIndex - 1)
      local methodName = someName:sub(dotIndex + 1)
      return className, methodName
  end
  function M.LuaUnit.isMethodTestName( s )
      -- Return true if s matches the name of a test method.
      -- Default rule: the first four characters spell 'test', in any letter case.
      local prefix = string.sub(s, 1, 4)
      return prefix:lower() == 'test'
  end
  function M.LuaUnit.isTestName( s )
      -- Return true if s matches the name of a test.
      -- Default rule: the first four characters spell 'test', in any letter case.
      local prefix = string.sub(s, 1, 4)
      return prefix:lower() == 'test'
  end
  function M.LuaUnit.collectTests()
      -- Return the sorted list of all string keys of _G that are accepted by
      -- LuaUnit.isTestName.
      local found = {}
      for key in pairs(_G) do
          if type(key) == "string" and M.LuaUnit.isTestName( key ) then
              found[#found + 1] = key
          end
      end
      table.sort( found )
      return found
  end
  function M.LuaUnit.parseCmdLine( cmdLine )
      -- Parse the command line, given as a list of strings (or nil).
      --
      -- Supported command line parameters:
      -- --help, -h: print the usage text and exit
      -- --version: print the version and exit
      -- --verbose, -v: increase verbosity
      -- --quiet, -q: silence output
      -- --error, -e: sets quitOnError
      -- --failure, -f: sets quitOnFailure
      -- --output, -o, + name: select output type
      -- --pattern, -p, + pattern: run test matching pattern, may be repeated
      -- --exclude, -x, + pattern: run test not matching pattern, may be repeated
      -- --shuffle, -s, : sets shuffle
      -- --name, -n, + fname: name of the output file
      -- --repeat, -r, + num: number of times to execute each test
      -- [testnames, ...]: run selected test names
      --
      -- Returns a table in which only the fields that were given are set:
      -- help, version, quitOnError, quitOnFailure, shuffle: true
      -- verbosity: M.VERBOSITY_QUIET or M.VERBOSITY_VERBOSE
      -- output: the output type name, as typed
      -- fname: the output file name
      -- exeRepeat: a number
      -- pattern: list of patterns; --exclude patterns are stored in this
      --          same list, prefixed with '!'
      -- testNames: list of test names to run
      --
      -- Raises an error on an unknown option, a malformed -r argument, or an
      -- option that lacks its argument.
      local result, state = {}, nil
      local SET_OUTPUT = 1
      local SET_PATTERN = 2
      local SET_EXCLUDE = 3
      local SET_FNAME = 4
      local SET_REPEAT = 5

      if cmdLine == nil then
          return result
      end

      -- options without argument: option -> { result field, value }
      local flagOptions = {
          ['--help']    = { 'help', true },
          ['-h']        = { 'help', true },
          ['--version'] = { 'version', true },
          ['--verbose'] = { 'verbosity', M.VERBOSITY_VERBOSE },
          ['-v']        = { 'verbosity', M.VERBOSITY_VERBOSE },
          ['--quiet']   = { 'verbosity', M.VERBOSITY_QUIET },
          ['-q']        = { 'verbosity', M.VERBOSITY_QUIET },
          ['--error']   = { 'quitOnError', true },
          ['-e']        = { 'quitOnError', true },
          ['--failure'] = { 'quitOnFailure', true },
          ['-f']        = { 'quitOnFailure', true },
          ['--shuffle'] = { 'shuffle', true },
          ['-s']        = { 'shuffle', true },
      }

      -- options followed by an argument: option -> parse state
      local valueOptions = {
          ['--output']  = SET_OUTPUT,  ['-o'] = SET_OUTPUT,
          ['--name']    = SET_FNAME,   ['-n'] = SET_FNAME,
          ['--repeat']  = SET_REPEAT,  ['-r'] = SET_REPEAT,
          ['--pattern'] = SET_PATTERN, ['-p'] = SET_PATTERN,
          ['--exclude'] = SET_EXCLUDE, ['-x'] = SET_EXCLUDE,
      }

      local function parseOption( option )
          -- Handle one option; returns the new parse state when the option
          -- expects an argument, nothing otherwise.
          local flag = flagOptions[option]
          if flag ~= nil then
              result[flag[1]] = flag[2]
              return
          end
          local nextState = valueOptions[option]
          if nextState ~= nil then
              state = nextState
              return state
          end
          -- level 3: report the error at the caller of parseCmdLine()
          error('Unknown option: '..option,3)
      end

      local function appendPattern( pattern )
          -- Add pattern to the pattern list, creating the list on first use.
          if result['pattern'] then
              table.insert( result['pattern'], pattern )
          else
              result['pattern'] = { pattern }
          end
      end

      local function setArg( cmdArg, pendingState )
          -- Store the argument of the option that set pendingState.
          if pendingState == SET_OUTPUT then
              result['output'] = cmdArg
          elseif pendingState == SET_FNAME then
              result['fname'] = cmdArg
          elseif pendingState == SET_REPEAT then
              result['exeRepeat'] = tonumber(cmdArg)
                  or error('Malformed -r argument: '..cmdArg)
          elseif pendingState == SET_PATTERN then
              appendPattern( cmdArg )
          elseif pendingState == SET_EXCLUDE then
              appendPattern( '!'..cmdArg )
          else
              error('Unknown parse state: '.. pendingState)
          end
      end

      for _, cmdArg in ipairs(cmdLine) do
          if state ~= nil then
              -- the previous option is waiting for its argument
              setArg( cmdArg, state )
              state = nil
          elseif cmdArg:sub(1,1) == '-' then
              state = parseOption( cmdArg )
          elseif result['testNames'] then
              table.insert( result['testNames'], cmdArg )
          else
              result['testNames'] = { cmdArg }
          end
      end

      if result['help'] then
          M.LuaUnit.help()
      end

      if result['version'] then
          M.LuaUnit.version()
      end

      if state ~= nil then
          -- the last option never received its argument
          error('Missing argument after '..cmdLine[ #cmdLine ],2 )
      end

      return result
  end
  function M.LuaUnit.help()
      -- Print the usage text (M.USAGE), then terminate with exit code 0.
      local usage = M.USAGE
      print(usage)
      os.exit(0)
  end
  function M.LuaUnit.version()
      -- Print the version banner, then terminate with exit code 0.
      local banner = 'LuaUnit v'..M.VERSION..' by Philippe Fremy <phil@freehackers.org>'
      print(banner)
      os.exit(0)
  end
  2048. ----------------------------------------------------------------
  2049. -- class NodeStatus
  2050. ----------------------------------------------------------------
  -- Status record of one test: number, names, status, message and stack trace.
  local NodeStatus = { __class__ = 'NodeStatus' } -- class
  local NodeStatus_MT = { __index = NodeStatus } -- metatable
  M.NodeStatus = NodeStatus

  -- values of status
  NodeStatus.PASS = 'PASS'
  NodeStatus.FAIL = 'FAIL'
  NodeStatus.ERROR = 'ERROR'

  function NodeStatus.new( number, testName, className )
      -- Create a node for the given test; a new node starts as passed.
      local node = setmetatable(
          { number = number, testName = testName, className = className },
          NodeStatus_MT )
      node:pass()
      return node
  end

  function NodeStatus:pass()
      -- Mark the node as passed and clear message and stack trace.
      self.status = self.PASS
      -- useless but we know it's the field we want to use
      self.msg = nil
      self.stackTrace = nil
  end

  function NodeStatus:fail(msg, stackTrace)
      -- Mark the node as failed, recording message and stack trace.
      self.status = self.FAIL
      self.msg = msg
      self.stackTrace = stackTrace
  end

  function NodeStatus:error(msg, stackTrace)
      -- Mark the node as in error, recording message and stack trace.
      self.status = self.ERROR
      self.msg = msg
      self.stackTrace = stackTrace
  end

  function NodeStatus:isPassed()
      return self.status == NodeStatus.PASS
  end

  function NodeStatus:isNotPassed()
      -- print('hasFailure: '..prettystr(self))
      return not (self.status == NodeStatus.PASS)
  end

  function NodeStatus:isFailure()
      return self.status == NodeStatus.FAIL
  end

  function NodeStatus:isError()
      return self.status == NodeStatus.ERROR
  end

  function NodeStatus:statusXML()
      -- Return the XML element describing an error or a failure: the message
      -- goes into the "type" attribute, the stack trace into a CDATA section.
      if self:isError() then
          local pieces = {
              ' <error type="', xmlEscape(self.msg), '">\n',
              ' <![CDATA[', xmlCDataEscape(self.stackTrace),
              ']]></error>\n',
          }
          return table.concat(pieces)
      end
      if self:isFailure() then
          local pieces = {
              ' <failure type="', xmlEscape(self.msg), '">\n',
              ' <![CDATA[', xmlCDataEscape(self.stackTrace),
              ']]></failure>\n',
          }
          return table.concat(pieces)
      end
      return ' <passed/>\n' -- (not XSD-compliant! normally shouldn't get here)
  end
  2107. --------------[[ Output methods ]]-------------------------
  local function conditional_plural(number, singular)
      -- Return "<number> <word>", where <word> is singular for a count of 1
      -- and pluralized otherwise ('es' after a trailing 'ss', else 's').
      local plural_suffix
      if number == 1 then
          plural_suffix = ''
      elseif singular:sub(-2) == 'ss' then
          plural_suffix = 'es'
      else
          plural_suffix = 's'
      end
      return string.format('%d %s%s', number, singular, plural_suffix)
  end
  function M.LuaUnit.statusLine(result)
      -- Build the one-line summary for result: run count and duration,
      -- successes, failures / errors (or '0 failures'), and the non-selected
      -- count when it is positive. The parts are joined with ', '.
      local parts = {}
      local function add(text)
          parts[#parts + 1] = text
      end

      add(string.format('Ran %d tests in %0.3f seconds',
          result.runCount, result.duration))
      add(conditional_plural(result.passedCount, 'success'))

      if result.notPassedCount > 0 then
          if result.failureCount > 0 then
              add(conditional_plural(result.failureCount, 'failure'))
          end
          if result.errorCount > 0 then
              add(conditional_plural(result.errorCount, 'error'))
          end
      else
          add('0 failures')
      end

      if result.nonSelectedCount > 0 then
          add(string.format("%d non-selected", result.nonSelectedCount))
      end
      return table.concat(parts, ', ')
  end
  function M.LuaUnit:startSuite(testCount, nonSelectedCount)
      -- Reset self.result for a new run, instantiate the outputter
      -- (TextOutput when no output type is set) and notify it.
      local result = {
          testCount = testCount,
          nonSelectedCount = nonSelectedCount,
          passedCount = 0,
          runCount = 0,
          currentTestNumber = 0,
          currentClassName = "",
          suiteStarted = true,
          patternIncludeFilter = self.patternIncludeFilter,
          -- lists filled while the suite runs
          tests = {},
          failures = {},
          errors = {},
          notPassed = {},
      }
      -- currentNode stays nil until the first startTest()
      result.startTime = os.clock()
      -- the date format comes from the LUAUNIT_DATEFMT environment variable, if set
      result.startDate = os.date(os.getenv('LUAUNIT_DATEFMT'))
      result.startIsodate = os.date('%Y-%m-%dT%H:%M:%S')
      self.result = result

      local outputClass = self.outputType or TextOutput
      self.outputType = outputClass
      self.output = outputClass.new(self)
      self.output:startSuite()
  end
  function M.LuaUnit:startClass( className )
      -- Record className as the current class and notify the outputter.
      local result = self.result
      result.currentClassName = className
      self.output:startClass( className )
  end
  function M.LuaUnit:startTest( testName )
      -- Bump the counters, create the status node of the test (with its
      -- start time), register it in result.tests and notify the outputter.
      local result = self.result
      result.currentTestNumber = result.currentTestNumber + 1
      result.runCount = result.runCount + 1

      local node = NodeStatus.new(
          result.currentTestNumber,
          testName,
          result.currentClassName
      )
      result.currentNode = node
      node.startTime = os.clock()
      table.insert( result.tests, node )
      self.output:startTest( testName )
  end
  function M.LuaUnit:addStatus( err )
      -- Record the outcome "err" (a table / result from protectedCall()) on
      -- the current node. A PASS outcome changes nothing.
      local PASS = NodeStatus.PASS
      if err.status == PASS then
          return
      end

      local node = self.result.currentNode

      --[[ Only one error or one failure is reported per test for now. A test
      may well be in failure while its teardown is in error, and reporting
      both (as Python unittest does) would be nicer; however that mixes up
      the counts (there could be more failures + errors than tests), so it
      is left for a later, more intelligent version.
      ]]

      -- a node already in failure/error keeps its first status (see above)
      if node.status ~= PASS then
          return
      end

      local status = err.status
      if status == NodeStatus.FAIL then
          node:fail( err.msg, err.trace )
          table.insert( self.result.failures, node )
      elseif status == NodeStatus.ERROR then
          node:error( err.msg, err.trace )
          table.insert( self.result.errors, node )
      end

      if node:isFailure() or node:isError() then
          -- add to the list of failed tests (gets printed separately)
          table.insert( self.result.notPassed, node )
      end
      self.output:addStatus( node )
  end
  function M.LuaUnit:endTest()
      -- Close the current node: compute its duration, notify the outputter,
      -- update the passed count and, when quitOnError / quitOnFailure ask
      -- for it, print the message and set result.aborted.
      local node = self.result.currentNode
      -- print( 'endTest() '..prettystr(node))
      -- print( 'endTest() '..prettystr(node:isNotPassed()))
      node.duration = os.clock() - node.startTime
      node.startTime = nil
      self.output:endTest( node )

      -- set when test execution has to be aborted
      local abortBanner
      if node:isPassed() then
          self.result.passedCount = self.result.passedCount + 1
      elseif node:isError() then
          -- an error aborts with either of the two flags
          if self.quitOnError or self.quitOnFailure then
              abortBanner = "\nERROR during LuaUnit test execution:\n"
          end
      elseif node:isFailure() then
          -- a failure aborts with quitOnFailure only
          if self.quitOnFailure then
              abortBanner = "\nFailure during LuaUnit test execution:\n"
          end
      end

      if abortBanner then
          print(abortBanner .. node.msg)
          self.result.aborted = true
      end
      self.result.currentNode = nil
  end
  function M.LuaUnit:endClass()
      -- Forward the end-of-class notification to the outputter.
      local output = self.output
      output:endClass()
  end
  function M.LuaUnit:endSuite()
      -- Finish the run: store the duration and the final counts in
      -- self.result, then notify the outputter.
      -- Raises an error when the suite was already ended.
      local result = self.result
      if result.suiteStarted == false then
          error('LuaUnit:endSuite() -- suite was already ended' )
      end
      result.duration = os.clock() - result.startTime
      result.suiteStarted = false

      -- Expose test counts for outputter's endSuite(). This could be managed
      -- internally instead, but unit tests (and existing use cases) might
      -- rely on these fields being present.
      result.notPassedCount = #result.notPassed
      result.failureCount = #result.failures
      result.errorCount = #result.errors

      self.output:endSuite()
  end
  function M.LuaUnit:setOutputType(outputType)
      -- Select the outputter class by name; letter case is ignored.
      -- Known names: "nil", "tap" (TAP format), "junit", "text".
      -- Raises 'No such format' for any other name.
      local outputClasses = {
          NIL = NilOutput,
          TAP = TapOutput,
          JUNIT = JUnitOutput,
          TEXT = TextOutput,
      }
      local selected = outputClasses[outputType:upper()]
      if selected then
          self.outputType = selected
          return
      end
      error( 'No such format: '..outputType,2)
  end
  2273. --------------[[ Runner ]]-----------------
  function M.LuaUnit:protectedCall(classInstance, methodInstance, prettyFuncName)
      -- Run methodInstance under xpcall and return a status table.
      -- classInstance: passed as the only argument when not nil (method call);
      --                otherwise methodInstance is called without arguments.
      -- prettyFuncName: optional real name of the function, substituted for
      --                 'methodInstance' in the stack trace.
      -- Returns {status = PASS} on success, otherwise a table with the
      -- fields status (FAIL or ERROR), msg (always a string) and trace.

      local function err_handler(e)
          -- transform error into a table, adding the traceback information
          return {
              status = NodeStatus.ERROR,
              msg = e,
              trace = string.sub(debug.traceback("", 3), 2)
          }
      end

      local ok, err
      if classInstance then
          -- stupid Lua < 5.2 does not allow xpcall with arguments so let's use a workaround
          ok, err = xpcall( function () methodInstance(classInstance) end, err_handler )
      else
          ok, err = xpcall( function () methodInstance() end, err_handler )
      end
      if ok then
          return {status = NodeStatus.PASS}
      end

      -- Should the handler itself have failed, xpcall hands back something
      -- that is not our table: wrap it so the code below still works.
      if type(err) ~= 'table' then
          err = { status = NodeStatus.ERROR, msg = tostring(err), trace = '' }
      end

      -- error() may raise any value (table, nil, number, ...). The string
      -- operations below, and the outputters, need a string message.
      if type(err.msg) ~= 'string' then
          err.msg = tostring(err.msg)
      end

      -- Failure message usually looks like:
      -- "./test\\test_luaunit.lua:2241: LuaUnit test FAILURE: expected: 2, actual: 1"
      -- If failure prefix is present, we assume this is a failure
      -- we strip the prefix, and insert the iteration number along the way if relevant
      -- we only strip one failure prefix of course.
      local failed, iter_msg
      iter_msg = self.exeRepeat and 'iteration: '..self.currentCount..', '
      err.msg, failed = err.msg:gsub(M.FAILURE_PREFIX, iter_msg or '', 1)
      if failed > 0 then
          err.status = NodeStatus.FAIL
      end

      -- reformat / improve the stack trace
      if prettyFuncName then -- we do have the real method name
          -- '%' is special in a gsub replacement string: double it, so that
          -- a name containing '%' is inserted literally
          local safeName = string.gsub(prettyFuncName, '%%', '%%%%')
          err.trace = err.trace:gsub("in (%a+) 'methodInstance'", "in %1 '"..safeName.."'")
      end
      if STRIP_LUAUNIT_FROM_STACKTRACE then
          err.trace = stripLuaunitTrace(err.trace)
      end

      return err -- return the error "object" (table)
  end
  function M.LuaUnit:execOneFunction(className, methodName, classInstance, methodInstance)
      -- Execute one test, surrounded by its class fixtures, and record the result.
      --
      -- For a plain test function, className and classInstance must be nil.
      -- For a class method, all four parameters must be set.
      -- Raises an error if methodInstance is not a function.
      if type(methodInstance) ~= 'function' then
          error( tostring(methodName)..' must be a function, not '..type(methodInstance))
      end

      -- Plain functions are grouped under a pseudo class name.
      local prettyFuncName = methodName
      if className == nil then
          className = '[TestFunctions]'
      else
          prettyFuncName = className..'.'..methodName
      end

      -- Close the previous class (if any) and open the new one on a class change.
      if className ~= self.lastClassName then
          if self.lastClassName ~= nil then
              self:endClass()
          end
          self:startClass( className )
          self.lastClassName = className
      end

      -- Accepted spellings of the fixture methods, in lookup order.
      local setUpNames = { 'setUp', 'Setup', 'setup', 'SetUp' }
      local tearDownNames = { 'tearDown', 'TearDown', 'teardown', 'Teardown' }

      -- Return the first field of classInstance, among candidateNames, that
      -- asFunction() accepts; nil when there is none.
      local function findFixture( candidateNames )
          for _, fieldName in ipairs( candidateNames ) do
              local fixture = self.asFunction( classInstance[fieldName] )
              if fixture then
                  return fixture
              end
          end
          return nil
      end

      self:startTest(prettyFuncName)
      local node = self.result.currentNode

      for iteration = 1, self.exeRepeat or 1 do
          -- stop repeating as soon as one iteration did not pass
          if node:isNotPassed() then
              break
          end
          self.currentCount = iteration

          -- setUp comes first (if any)
          if classInstance then
              local setUpFunc = findFixture( setUpNames )
              if setUpFunc then
                  self:addStatus(self:protectedCall(classInstance, setUpFunc, className..'.setUp'))
              end
          end

          -- the test itself only runs if the node is still passed after setUp
          if node:isPassed() then
              self:addStatus(self:protectedCall(classInstance, methodInstance, prettyFuncName))
          end

          -- tearDown comes last (if any), whatever happened before
          if classInstance then
              local tearDownFunc = findFixture( tearDownNames )
              if tearDownFunc then
                  self:addStatus(self:protectedCall(classInstance, tearDownFunc, className..'.tearDown'))
              end
          end
      end

      self:endTest()
  end
  function M.LuaUnit.expandOneClass( result, className, classInstance )
      --[[
      Append every test method of one class to result.
      Input: result, a list of { name, instance } that is modified in place;
             a class name; a class instance.
      Output: none. For each entry of classInstance (visited with sortedPairs)
              whose value is accepted by asFunction() and whose key is accepted
              by isMethodTestName(), the following item is added to result:
              { className.methodName, classInstance }
      ]]
      local LuaUnit = M.LuaUnit
      for methodName, methodInstance in sortedPairs(classInstance) do
          local isTestMethod = LuaUnit.asFunction(methodInstance)
                               and LuaUnit.isMethodTestName( methodName )
          if isTestMethod then
              local qualifiedName = className..'.'..methodName
              table.insert( result, { qualifiedName, classInstance } )
          end
      end
  end
  function M.LuaUnit.expandClasses( listOfNameAndInst )
      --[[
      Flatten a list of tests so that whole classes are replaced by their methods.
      Input: a list of { name, instance }
      Output: a new list where each input item was handled as follows:
          * { function name, function instance } : copied as is
          * { class.method name, class instance }: copied as is, after checking
            that the class instance has a field with that method name
          * { class name, class instance }       : replaced by one item
            { className.methodName, classInstance } per test method
      Raises an error if an instance is neither a function nor a table, or
      if a class.method name designates a missing method.
      ]]
      local expanded = {}

      for _, entry in ipairs( listOfNameAndInst ) do
          local name, instance = entry[1], entry[2]

          if M.LuaUnit.asFunction(instance) then
              table.insert( expanded, { name, instance } )
          elseif type(instance) ~= 'table' then
              error( 'Instance must be a table or a function, not a '..type(instance)..' with value '..prettystr(instance))
          else
              local className, methodName = M.LuaUnit.splitClassMethod( name )
              if not className then
                  -- bare class name: add all its test methods
                  M.LuaUnit.expandOneClass( expanded, name, instance )
              elseif instance[methodName] == nil then
                  error( "Could not find method in class "..tostring(className).." for method "..tostring(methodName) )
              else
                  table.insert( expanded, { name, instance } )
              end
          end
      end

      return expanded
  end
  function M.LuaUnit.applyPatternFilter( patternIncFilter, listOfNameAndInst )
      -- Split a list of { name, instance } items in two, based on the name.
      -- Returns two lists: the items whose name is accepted by
      -- patternFilter( patternIncFilter, name ), then the rejected ones.
      -- Relative order is kept in both lists.
      local included, excluded = {}, {}

      for _, entry in ipairs( listOfNameAndInst ) do
          local bucket = excluded
          if patternFilter( patternIncFilter, entry[1] ) then
              bucket = included
          end
          table.insert( bucket, entry )
      end

      return included, excluded
  end
  function M.LuaUnit:runSuiteByInstances( listOfNameAndInst )
      --[[ Run an explicit list of tests. Each item of the list must be one of:
      * { function name, function instance }
      * { class name, class instance }
      * { class.method name, class instance }
      The list is expanded to individual tests, shuffled when self.shuffle is
      set, then filtered with self.patternIncludeFilter before execution.
      If the result is flagged as aborted, the run stops early and the
      process exits with code -2 once the suite has been closed.
      ]]
      local allTests = self.expandClasses( listOfNameAndInst )
      if self.shuffle then
          randomizeTable( allTests )
      end

      local toRun, filteredOut = self.applyPatternFilter( self.patternIncludeFilter, allTests )
      self:startSuite( #toRun, #filteredOut )

      for index = 1, #toRun do
          local name, instance = toRun[index][1], toRun[index][2]

          if M.LuaUnit.asFunction(instance) then
              -- plain test function: no class involved
              self:execOneFunction( nil, name, nil, instance )
          else
              -- expandClasses() should have already taken care of sanitizing the input
              assert( type(instance) == 'table' )
              local className, methodName = M.LuaUnit.splitClassMethod( name )
              assert( className ~= nil )
              local methodInstance = instance[methodName]
              assert(methodInstance ~= nil)
              self:execOneFunction( className, methodName, instance, methodInstance )
          end

          if self.result.aborted then
              break -- "--error" or "--failure" option triggered
          end
      end

      -- close the class that is still open, if any, then the suite itself
      if self.lastClassName ~= nil then
          self:endClass()
      end
      self:endSuite()

      if self.result.aborted then
          print("LuaUnit ABORTED (as requested by --error or --failure option)")
          os.exit(-2)
      end
  end
  function M.LuaUnit:runSuiteByNames( listOfName )
      --[[ Run LuaUnit with a list of generic names, coming either from command-line or from global
      namespace analysis. Each name is resolved in the global table _G, either directly
      (function or class name) or through its class part (class.method name).
      The resulting list of { name, instance } is handed to runSuiteByInstances.
      Raises an error when a name cannot be resolved, when the resolved value has the
      wrong type, or when the designated method is missing from its class.
      ]]
      local listOfNameAndInst = {}

      for _, name in ipairs( listOfName ) do
          local className, methodName = M.LuaUnit.splitClassMethod( name )

          -- for a class.method name the class is looked up, otherwise the name itself
          local instanceName = className or name
          local instance = _G[instanceName]
          if instance == nil then
              error( "No such name in global space: "..instanceName )
          end

          if className then
              if type(instance) ~= 'table' then
                  error( 'Instance of '..instanceName..' must be a table, not '..type(instance))
              end
              if instance[methodName] == nil then
                  error( "Could not find method in class "..tostring(className).." for method "..tostring(methodName) )
              end
          elseif type(instance) ~= 'table' and type(instance) ~= 'function' then
              -- for functions and classes
              error( 'Name must match a function or a table: '..instanceName )
          end

          table.insert( listOfNameAndInst, { name, instance } )
      end

      self:runSuiteByInstances( listOfNameAndInst )
  end
  function M.LuaUnit.run(...)
      -- Convenience entry point: build a fresh runner with M.LuaUnit.new()
      -- and forward every argument to its runSuite() method.
      --
      -- Without arguments, the class names given on the command line are
      -- run; if the command line names no class either, all classes whose
      -- name starts with 'Test' are run.
      --
      -- When arguments are passed, they must be strings: names of the
      -- classes to run, or generic command line arguments (-o, -p, -v, ...)
      --
      -- Returns whatever runSuite() returns.
      return M.LuaUnit.new():runSuite(...)
  end
  function M.LuaUnit:runSuite( ... )
      -- Parse the arguments as command line options, configure this runner
      -- accordingly and run the selected tests.
      -- Returns self.result.notPassedCount once the tests have run.
      local args = {...}

      -- Both runner.runSuite(...) and runner:runSuite(...) are supported:
      -- a leading table tagged with __class__ == 'LuaUnit' is dropped.
      local firstArg = args[1]
      if type(firstArg) == 'table' and firstArg.__class__ == 'LuaUnit' then
          table.remove(args,1)
      end

      -- no explicit arguments: use the ones of the command line
      if #args == 0 then
          args = cmdline_argv
      end

      local options = pcall_or_abort( M.LuaUnit.parseCmdLine, args )

      -- We expect these option fields to be either `nil` or contain
      -- valid values, so it's safe to always copy them directly.
      self.verbosity            = options.verbosity
      self.quitOnError          = options.quitOnError
      self.quitOnFailure        = options.quitOnFailure
      self.fname                = options.fname
      self.exeRepeat            = options.exeRepeat
      self.patternIncludeFilter = options.pattern
      self.shuffle              = options.shuffle

      local requestedOutput = options.output
      if requestedOutput then
          local isJunit = requestedOutput:lower() == 'junit'
          if isJunit and options.fname == nil then
              print('With junit output, a filename must be supplied with -n or --name')
              os.exit(-1)
          end
          pcall_or_abort(self.setOutputType, self, requestedOutput)
      end

      -- explicit test names win; otherwise fall back to collectTests()
      local testNames = options.testNames or M.LuaUnit.collectTests()
      self:runSuiteByNames( testNames )

      return self.result.notPassedCount
  end
  -- class LuaUnit
  -- For compatibility with LuaUnit v2: the runner entry point is also
  -- reachable directly from the module table, under two spellings.
  M.run = M.LuaUnit.run
  M.Run = M.LuaUnit.run
  function M:setVerbosity( verbosity )
      -- Store a verbosity value on the LuaUnit class table (M.LuaUnit.verbosity).
      --
      -- The function is declared with method syntax, so the expected call is
      -- module:setVerbosity( value ). When it is called with a dot instead,
      -- module.setVerbosity( value ), the value arrives in `self` and
      -- `verbosity` is nil; that case is detected here so that the value is
      -- not silently replaced by nil.
      if verbosity == nil and type(self) ~= 'table' then
          verbosity = self
      end
      M.LuaUnit.verbosity = verbosity
  end
  -- Alternate spellings of setVerbosity; both names refer to the same function.
  M.set_verbosity = M.setVerbosity
  M.SetVerbosity = M.setVerbosity

  -- The module table is the value returned by this file.
  return M