// Basic Verilog module structure example: a 3-to-8 decoder with active-low outputs.
// Exactly one bit of Y_out is driven low -- the bit selected by A_in -- and every
// other bit stays high. Any A_in value that matches no case item (e.g. one that
// contains X/Z in simulation) drives all outputs high via the default branch.
module Decode38 (
    input      [2:0] A_in,   // 3-bit select input
    output reg [7:0] Y_out   // active-low decoded output (one bit low at a time)
);

    // Internal variables and parameters: none needed.

    // Combinational decode; A_in is the only signal read, so @(*) is
    // equivalent to an explicit @(A_in) sensitivity list.
    always @(*) begin
        case (A_in)
            3'b000:  Y_out = 8'b1111_1110;
            3'b001:  Y_out = 8'b1111_1101;
            3'b010:  Y_out = 8'b1111_1011;
            3'b011:  Y_out = 8'b1111_0111;
            3'b100:  Y_out = 8'b1110_1111;
            3'b101:  Y_out = 8'b1101_1111;
            3'b110:  Y_out = 8'b1011_1111;
            3'b111:  Y_out = 8'b0111_1111;
            default: Y_out = 8'b1111_1111;
        endcase
    end

endmodule // end of module Decode38
// Pulse edge detection: key_edge is high for exactly one clock cycle when
// key_rst falls from 1 to 0 (presumably key_rst_pre is key_rst delayed by one
// clock -- that register is driven outside this view; confirm).
assign key_edge = key_rst_pre & (~key_rst);

// Sampled-key registers, asynchronous active-low reset to all ones.
//  - key_sec captures key only when cnt == 18'h3ffff (presumably cnt is a
//    free-running 18-bit counter, giving one sample per 2^18 clocks -- confirm).
//  - key_sec_pre is key_sec delayed by one clock, used for edge detection below.
always @(posedge clk or negedge rst) begin
    if (!rst) begin
        key_sec     <= {N{1'b1}};
        key_sec_pre <= {N{1'b1}};
    end else begin
        if (cnt == 18'h3ffff)
            key_sec <= key;
        key_sec_pre <= key_sec;
    end
end

// One-clock pulse on every bit of key_sec that transitions from 1 to 0.
assign key_pulse = key_sec_pre & (~key_sec);
// Clock divider output -- divide clk by N:
//   N == 1   -> clk is passed through unchanged;
//   N odd    -> output is clk_p AND clk_n;
//   N even   -> output is clk_p alone.
// NOTE(review): clk_p / clk_n are driven outside this view -- presumably dividers
// clocked on posedge and negedge clk respectively, combined for odd N to get an
// output closer to 50% duty. Whether AND (rather than OR) is correct depends on
// their duty cycles; confirm against the logic that generates them. Also note
// this drives a clock through combinational logic.
assign clkout = (N==1)?clk:(N[0])?(clk_p&clk_n):clk_p; // divide clk by N
#define K 8 typedef ap_int<16> dtype_dat; typedef ap_int<16*K> dtype_bus; typedef struct { dtype_bus data; bool last; }dtype_stream;
/// Pools the feature stream along each row with window Kx.
///
/// @param in        input features, streamed in [ch_div_K][height_in][width_in] order
/// @param out       pooled features, streamed in [ch_div_K][height_in][width_out] order
/// @param ch_div_K  number of K-lane channel groups (presumably channels / K -- confirm)
/// @param height_in input feature-map height
/// @param width_in  input feature-map width
/// @param Kx        pooling window size along each row
void pool_1D(hls::stream<dtype_bus> &in, hls::stream<dtype_bus> &out, int ch_div_K, int height_in, int width_in, int Kx);
参数说明: **in** 特征数据流按照[ch_div_K][height_in][width_in]顺序输入; **out** 特征数据流按照[ch_div_K][height_in][width_out]顺序输出; **Kx** 特征数据流按行进行Kx池化。
/// Pools the (already row-pooled) feature stream along each column with window Ky.
///
/// @param in        input features, streamed in [ch_div_K][height_in][width_out] order
/// @param out       pooled features, streamed in [ch_div_K][height_out][width_out] order
/// @param ch_div_K  number of K-lane channel groups (presumably channels / K -- confirm)
/// @param height_in input feature-map height
/// @param width_out feature-map width after row pooling
/// @param Ky        pooling window size along each column
void pool_2D(hls::stream<dtype_bus> &in, hls::stream<dtype_bus> &out, int ch_div_K, int height_in, int width_out, int Ky);
参数说明: **in** 特征数据流按照[ch_div_K][height_in][width_out]顺序输入; **out** 特征数据流按照[ch_div_K][height_out][width_out]顺序输出; **Ky** 特征数据流按列进行Ky池化。
void pool(hls::stream<dtype_bus> &in,hls::stream<dtype_stream> &out, int ch_div_K,int height_in,int width_in, int height_out,int width_out,int Kx,int Ky) { #pragma HLS INTERFACE s_axilite port=return #pragma HLS INTERFACE s_axilite port=Ky #pragma HLS INTERFACE s_axilite port=width_in #pragma HLS INTERFACE s_axilite port=Kx #pragma HLS INTERFACE s_axilite port=height_in #pragma HLS INTERFACE s_axilite port=height_out #pragma HLS INTERFACE s_axilite port=width_out #pragma HLS INTERFACE s_axilite port=ch_div_K #pragma HLS DATAFLOW #pragma HLS INTERFACE axis register both port=out #pragma HLS INTERFACE axis register both port=in hls::stream<dtype_bus> stream_tp; #pragma HLS STREAM variable=stream_tp depth=8 dim=1 hls::stream<dtype_bus> stream_tp2; pool_1D(in,stream_tp,ch_div_K,height_in,width_in,Kx); pool_2D(stream_tp,stream_tp2,ch_div_K,height_in,width_out,Ky); hs2axis(stream_tp2,out,ch_div_K,height_out,width_out); }
from pynq import Overlay, Xlnk

# Load the bitstream and list the IP blocks it exposes.
ol = Overlay("conv.bit")
# NOTE(review): in PYNQ 2.x, Overlay() already downloads the bitstream by default,
# so this explicit call may reprogram the PL a second time -- confirm for the
# installed PYNQ version before removing it.
ol.download()
print(ol.ip_dict.keys())

# Handles to the accelerator IPs, looked up by their instance names in the overlay.
dma = ol.axi_dma_0
pool = ol.pool_stream_0
conv = ol.Conv_0
xlnk = Xlnk()
import driver
from MNIST_LARGE_cfg import *

# Fill each layer's weight buffer from its binary dump in ./record/,
# in the order conv1, fc1, fc2.
for _weight_buf, _weight_file in (
    (W_conv1, "./record/W_conv1.bin"),
    (W_fc1, "./record/W_fc1.bin"),
    (W_fc2, "./record/W_fc2.bin"),
):
    driver.Load_Weight_From_File(_weight_buf, _weight_file)
del _weight_buf, _weight_file  # don't leak loop names into the script namespace
import time  # used below; not imported anywhere else in the visible script

# Time one complete hardware inference: conv1 -> pool1 -> fc1 -> fc2.
# perf_counter() is monotonic and high-resolution, so the measured interval is
# not distorted by wall-clock adjustments the way time.time() can be.
start = time.perf_counter()
driver.Run_Conv(conv, 1, 32, 3, 3, 1, 1, 1, 0, src_buffer, PTR_IMG, W_conv1, PTR_W_CONV1, h_conv1, PTR_H_CONV1)
driver.Run_Pool(pool, dma, 32, 4, 4, h_conv1, h_pool1)
# The fully connected layers also run on the conv IP, expressed as convolutions
# (presumably a 7x7 kernel for fc1 and 1x1 for fc2 -- confirm driver argument order).
driver.Run_Conv(conv, 32, 256, 7, 7, 1, 1, 0, 0, h_pool1, PTR_H_POOL1, W_fc1, PTR_W_FC1, h_fc1, PTR_H_FC1)
driver.Run_Conv(conv, 256, 10, 1, 1, 1, 1, 0, 0, h_fc1, PTR_H_FC1, W_fc2, PTR_W_FC2, h_fc2, PTR_H_FC2)
end = time.perf_counter()
print("Hardware run time=%s s" % (end - start))
# Predicted digit = index of the largest of the 10 fc2 outputs.
# Output i is read as h_fc2[i // K][0][0][i % K] (presumably channel group
# i // K, lane i % K). max(..., key=...) returns the FIRST index holding the
# largest score, matching the original strict ">" scan. Unlike the original, it
# needs no -32768 sentinel and does not overwrite the built-in max().
num = max(range(10), key=lambda i: h_fc2[i // K][0][0][i % K])
print("predict num is %d" % num)