-- inflate-bit32.lua
-- zzlib-bit32 - zlib decompression in Lua - version using bit/bit32 libraries
-- Copyright (c) 2016-2020 Francois Galea <fgalea at free.fr>
-- This program is free software. It comes without any warranty, to
-- the extent permitted by applicable law. You can redistribute it
-- and/or modify it under the terms of the Do What The Fuck You Want
-- To Public License, Version 2, as published by Sam Hocevar. See
-- the COPYING file or http://www.wtfpl.net/ for more details.
  local inflate = {}
  -- Bit operations come from the Lua 5.2 `bit32` library when it exists,
  -- otherwise from a global `bit` library (e.g. LuaJIT's BitOp module).
  local bit = bit32 or bit
  -- Re-export two of the primitives on the module table.
  inflate.band, inflate.rshift = bit.band, bit.rshift
  -- Create an LSB-first bit reader over compressed data.
  --   file: either a string holding the whole input, or an object with
  --         :read(n) and :close() methods (e.g. an io file handle), which
  --         is read in chunks of 4096 bytes.
  -- Returns the stream object with fields
  --   file (handle, nil for string input), buf (current chunk),
  --   len (#buf), pos (next unread byte in buf), b (bit buffer),
  --   n (number of valid bits in b)
  -- and the methods flushb, peekb, getb, getv and close.
  function inflate.bitstream_init(file)
    local bs = {
      file = file, -- the open file handle (nil for string input)
      buf = nil,   -- character buffer
      len = nil,   -- length of character buffer
      pos = 1,     -- position of the next unread byte in buf
      b = 0,       -- bit buffer
      n = 0,       -- number of valid bits in b
    }

    -- Discard the n lowest bits of the bit buffer.
    function bs:flushb(n)
      self.n = self.n - n
      self.b = bit.rshift(self.b,n)
    end

    -- Return the next n bits without consuming them.
    -- Raises an error if the input runs out before n bits are available.
    function bs:peekb(n)
      while self.n < n do
        if self.pos > self.len then
          -- Current chunk exhausted: refill from the file, if there is one.
          local chunk = self.file and self.file:read(4096)
          if not chunk then
            error("unexpected end of compressed data")
          end
          self.buf = chunk
          self.len = #chunk
          self.pos = 1
        else
          -- Append the next byte above the bits already buffered.
          self.b = self.b + bit.lshift(self.buf:byte(self.pos),self.n)
          self.pos = self.pos + 1
          self.n = self.n + 8
        end
      end
      return bit.band(self.b,bit.lshift(1,n)-1)
    end

    -- Consume and return the next n bits.
    function bs:getb(n)
      local ret = self:peekb(n)
      self:flushb(n)
      return ret
    end

    -- Decode one Huffman-coded symbol. hufftable maps the next n bits to
    -- symbol*16 + code length; only the code's actual length is consumed.
    function bs:getv(hufftable,n)
      local e = hufftable[self:peekb(n)]
      if e == nil then
        error("invalid Huffman code in compressed data")
      end
      self:flushb(bit.band(e,15))
      return bit.rshift(e,4)
    end

    -- Close the underlying file, if any (string input has none).
    function bs:close()
      if self.file then
        self.file:close()
      end
    end

    if type(file) == "string" then
      bs.file = nil
      bs.buf = file
    else
      -- read() returns nil on an empty file; start from an empty chunk so
      -- the first peekb reports a clean end-of-data error.
      bs.buf = file:read(4096) or ""
    end
    bs.len = #bs.buf
    return bs
  end
  -- Build a direct-lookup decoding table for a canonical Huffman code
  -- (RFC 1951, section 3.2.2).
  --   depths: array of code lengths; depths[i] is the length of symbol i-1,
  --           and 0 (or nil) marks an unused symbol.
  -- Returns (lookup, nbits):
  --   nbits  - the longest code length (at least 1);
  --   lookup - maps every nbits-bit value, as delivered LSB first by the
  --            bit reader, to symbol*16 + code length.
  -- Codes are stored bit-reversed because the bit reader yields the
  -- first-transmitted code bit as the least significant bit.
  local function hufftable_create(depths)
    local nvalues = #depths
    local nbits = 1
    local bl_count = {}
    local next_code = {}
    -- Count the codes of each length and find the longest one.
    for i=1,nvalues do
      local d = depths[i] or 0
      if d > nbits then
        nbits = d
      end
      bl_count[d] = (bl_count[d] or 0) + 1
    end
    -- First code of each length; unused symbols do not take code space.
    local code = 0
    bl_count[0] = 0
    for i=1,nbits do
      code = (code + (bl_count[i-1] or 0)) * 2
      next_code[i] = code
    end
    local lookup = {}
    local size = bit.lshift(1,nbits)
    for i=1,nvalues do
      local len = depths[i] or 0
      if len > 0 then
        local e = (i-1)*16 + len
        local c = next_code[len]
        -- Reverse the len-bit code.
        local rcode = 0
        for j=1,len do
          rcode = rcode + bit.lshift(bit.band(1,bit.rshift(c,j-1)),len-j)
        end
        -- Every nbits-bit value whose low len bits equal rcode decodes to
        -- this symbol, whatever its higher bits are.
        for j=rcode,size-1,bit.lshift(1,len) do
          lookup[j] = e
        end
        next_code[len] = c + 1
      end
    end
    return lookup,nbits
  end
  -- Decode literal/length and distance symbols until the end-of-block
  -- symbol (256), appending decoded byte values to out.
  --   out: array of byte values decoded so far; back-references copy from it
  --   bs: bit stream object (uses bs:getv and bs:getb)
  --   nlit, ndist: lookup widths returned by hufftable_create
  --   littable, disttable: literal/length and distance lookup tables
  -- Raises an error on length codes above 285, distance codes above 29, and
  -- back-references reaching before the start of out.
  local function block_loop(out,bs,nlit,ndist,littable,disttable)
    local n = #out
    while true do
      local lit = bs:getv(littable,nlit)
      if lit < 256 then
        -- Literal byte.
        n = n + 1
        out[n] = lit
      elseif lit == 256 then
        -- End of block.
        break
      elseif lit > 285 then
        error("invalid literal/length code: "..lit)
      else
        -- Length code 257..285 (RFC 1951, section 3.2.5).
        local size
        if lit < 265 then
          size = lit - 254 -- lengths 3..10, no extra bits
        elseif lit < 285 then
          local nbits = bit.rshift(lit-261,2)
          size = 3 + bit.lshift(bit.band(lit-261,3)+4,nbits) + bs:getb(nbits)
        else
          size = 258
        end
        -- Distance code 0..29.
        local v = bs:getv(disttable,ndist)
        local dist
        if v < 4 then
          dist = v + 1
        elseif v < 30 then
          local nbits = bit.rshift(v-2,1)
          dist = 1 + bit.lshift(bit.band(v,1)+2,nbits) + bs:getb(nbits)
        else
          error("invalid distance code: "..v)
        end
        if dist > n then
          error("invalid distance: reference before start of output")
        end
        -- Copy byte by byte: source and destination may overlap, in which
        -- case freshly written bytes are repeated.
        local p = n - dist
        for k=1,size do
          out[n+k] = out[p+k]
        end
        n = n + size
      end
    end
  end
  -- Order in which the code-length code lengths are transmitted
  -- (RFC 1951, section 3.2.7), as 1-based symbol indices.
  local clen_order = { 17, 18, 19, 1, 9, 8, 10, 7, 11, 6, 12, 5, 13, 4, 14, 3, 15, 2, 16 }
  -- Number of extra bits read after the repeat codes 16, 17 and 18.
  local clen_extra_bits = { 2, 3, 7 }

  -- Decode a block that uses dynamic Huffman codes. The header's code
  -- lengths are read first, then the literal/length and distance tables are
  -- built, and finally the block data is decoded into out.
  local function block_dynamic(out,bs)
    local hlit = 257 + bs:getb(5)
    local hdist = 1 + bs:getb(5)
    local hclen = 4 + bs:getb(4)
    local total = hlit + hdist
    -- Code lengths of the code-length alphabet; untransmitted entries are 0.
    local clen = {}
    for i=1,19 do
      clen[i] = 0
    end
    for i=1,hclen do
      clen[clen_order[i]] = bs:getb(3)
    end
    local lengthtable,nlen = hufftable_create(clen)
    -- Run-length decode the literal/length and distance code lengths, which
    -- form one sequence of hlit+hdist entries.
    local depths = {}
    local i = 1
    while i <= total do
      local v = bs:getv(lengthtable,nlen)
      if v < 16 then
        depths[i] = v
        i = i + 1
      elseif v < 19 then
        local n = 3 + bs:getb(clen_extra_bits[v-15])
        local c = 0
        if v == 16 then
          -- Repeat the previous length 3..6 times.
          if i == 1 then
            error("repeat code with no previous length in depth table")
          end
          c = depths[i-1]
        elseif v == 18 then
          -- Code 18 is a run of 11..138 zeros (code 17: 3..10 zeros).
          n = n + 8
        end
        if i + n - 1 > total then
          error("code length repeat overflows the depth table")
        end
        for _=1,n do
          depths[i] = c
          i = i + 1
        end
      else
        error("wrong entry in depth table for literal/length alphabet: "..v)
      end
    end
    local litdepths, distdepths = {}, {}
    for k=1,hlit do
      litdepths[k] = depths[k]
    end
    for k=1,hdist do
      distdepths[k] = depths[hlit+k]
    end
    local littable,nlit = hufftable_create(litdepths)
    local disttable,ndist = hufftable_create(distdepths)
    block_loop(out,bs,nlit,ndist,littable,disttable)
  end
  -- Fixed Huffman lookup tables (RFC 1951, section 3.2.6). They are built
  -- on first use and reused by every later static block; decoding only
  -- reads from them.
  local static_littable, static_nlit, static_disttable, static_ndist

  -- Decode a block compressed with the fixed Huffman codes into out.
  local function block_static(out,bs)
    if not static_littable then
      -- Literal/length code lengths: 144 symbols of 8 bits, 112 of 9,
      -- 24 of 7, then 8 of 8.
      local cnt = { 144, 112, 24, 8 }
      local dpt = { 8, 9, 7, 8 }
      local depths = {}
      for i=1,4 do
        for _=1,cnt[i] do
          depths[#depths+1] = dpt[i]
        end
      end
      local littable,nlit = hufftable_create(depths)
      -- All 32 distance codes are 5 bits long.
      local ddepths = {}
      for k=1,32 do
        ddepths[k] = 5
      end
      local disttable,ndist = hufftable_create(ddepths)
      static_littable, static_nlit = littable, nlit
      static_disttable, static_ndist = disttable, ndist
    end
    block_loop(out,bs,static_nlit,static_ndist,static_littable,static_disttable)
  end
  -- Copy a stored (uncompressed) block into out.
  -- The reader is first aligned to a byte boundary. LEN and its one's
  -- complement NLEN (16 bits each) are then read, followed by LEN raw bytes
  -- taken straight from the byte buffer. The buffer is refilled from
  -- bs.file whenever the current chunk runs out.
  -- Raises an error if LEN and NLEN disagree or the input ends early.
  local function block_uncompressed(out,bs)
    bs:flushb(bit.band(bs.n,7))
    local len = bs:getb(16)
    if bs.n > 0 then
      error("Unexpected.. should be zero remaining bits in buffer.")
    end
    local nlen = bs:getb(16)
    if bit.bxor(len,nlen) ~= 65535 then
      error("LEN and NLEN don't match")
    end
    local n = #out
    while len > 0 do
      if bs.pos > bs.len then
        -- The stored data continues past the current chunk: refill it.
        local chunk = bs.file and bs.file:read(4096)
        if not chunk then
          error("unexpected end of data in stored block")
        end
        bs.buf, bs.len, bs.pos = chunk, #chunk, 1
      end
      local stop = math.min(bs.len, bs.pos + len - 1)
      for i=bs.pos,stop do
        n = n + 1
        out[n] = bs.buf:byte(i)
      end
      len = len - (stop - bs.pos + 1)
      bs.pos = stop + 1
    end
  end
  -- Inflate a raw DEFLATE stream read from bs, a bit stream created by
  -- inflate.bitstream_init. Blocks are decoded until the one whose "last"
  -- bit is set; the reader is then realigned to a byte boundary.
  -- Returns an array of decoded byte values.
  -- Raises "unsupported block type" for block type 3.
  function inflate.main(bs)
    local output = {}
    local last
    repeat
      last = bs:getb(1)
      local btype = bs:getb(2)
      if btype == 0 then
        block_uncompressed(output,bs)
      elseif btype == 1 then
        block_static(output,bs)
      elseif btype == 2 then
        block_dynamic(output,bs)
      else
        error("unsupported block type")
      end
    until last == 1
    -- Drop the padding bits up to the next byte boundary (presumably so a
    -- caller can read a byte-aligned trailer afterwards — confirm).
    bs:flushb(bit.band(bs.n,7))
    return output
  end
  -- Lookup table built on first use by inflate.crc32: entry [byte] is the
  -- CRC remainder of that byte under the reflected polynomial 0xEDB88320.
  local crc32_table

  -- Update a reflected CRC-32 (polynomial 0xEDB88320) with the bytes of s.
  --   s: data string
  --   crc: running CRC from a previous call, or nil to start from 0
  -- Returns the updated CRC, normalized to the range 0..4294967295.
  function inflate.crc32(s,crc)
    if crc32_table == nil then
      local t = {}
      for byte=0,255 do
        local r = byte
        for _=1,8 do
          -- mask is all ones when the low bit of r is set, zero otherwise
          local mask = bit.bnot(bit.band(r,1)-1)
          r = bit.bxor(bit.rshift(r,1),bit.band(0xedb88320,mask))
        end
        t[byte] = r
      end
      crc32_table = t
    end
    local c = bit.bnot(crc or 0)
    for k=1,#s do
      local idx = bit.bxor(s:byte(k),bit.band(c,0xff))
      c = bit.bxor(crc32_table[idx],bit.rshift(c,8))
    end
    c = bit.bnot(c)
    -- The `bit` library (used when bit32 is absent) may return signed
    -- 32-bit results; map them back to the unsigned range.
    if c < 0 then
      c = c + 4294967296
    end
    return c
  end

  return inflate