Java調(diào)用Pytorch模型實現(xiàn)圖像識別
之前寫了個輸入是1x2向量的模型的調(diào)用文章,后來有了個需要用到圖像識別的項目,因此寫下此文記錄一下在java中如何借助DJL調(diào)用自己寫的pytorch模型進行圖像識別。
我具體模型用的什么模型就不介紹了,輸入圖片是3*224*224,放入圖片前需要看一下橫縱比是否合理,不合理的話會進行下面這樣的操作:
1. 依賴
<dependency> <groupId>ai.djl.pytorch</groupId> <artifactId>pytorch-engine</artifactId> <version>0.16.0</version> </dependency> <dependency> <groupId>ai.djl.pytorch</groupId> <artifactId>pytorch-native-auto</artifactId> <version>1.9.1</version> <scope>runtime</scope> </dependency> <dependency> <groupId>ai.djl.pytorch</groupId> <artifactId>pytorch-jni</artifactId> <version>1.9.1-0.16.0</version> <scope>runtime</scope> </dependency>
2. 準(zhǔn)備模型
1.首先將模型按下面方法保存,放到項目resources中
import torch # An instance of your model. model = MyModel(num_classes = 80) # Switch the model to eval model model.eval() # An example input you would normally provide to your model's forward() method. example = torch.rand(1, 3, 224, 224) # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing. traced_script_module = torch.jit.trace(model, example) # Save the TorchScript model traced_script_module.save("model.pt")
2.編寫工具類,用于完成識別功能
public class HerbUtil { //規(guī)定輸入尺寸 private static final int INPUT_SIZE = 224; //標(biāo)簽文件 一種類別名字占一行 private List<String> herbNames; //用于識別 Predictor<Image, Classifications> predictor; //模型 private Model model; public HerbUtil() { //加載標(biāo)簽到herbNames中 this.loadHerbNames(); //初始化模型工作 this.init(); } }
3.將標(biāo)簽文件放到resources中,載入標(biāo)簽
private void loadHerbNames() { BufferedReader reader = null; herbNames = new ArrayList<>(); try { InputStream in = HerbUtil.class.getClassLoader().getResourceAsStream("names.txt"); reader = new BufferedReader(new InputStreamReader(in)); String name = null; while ((name = reader.readLine()) != null) { herbNames.add(name); } System.out.println(herbNames); } catch (Exception e) { e.printStackTrace(); } finally { if (reader != null) { try { reader.close(); } catch (IOException e) { e.printStackTrace(); } } } }
4.初始化模型
private void init() { Translator<Image, Classifications> translator = ImageClassificationTranslator.builder() //下面的transform根據(jù)自己的改 .addTransform(new RandomResizedCrop(INPUT_SIZE, INPUT_SIZE, 0.6, 1, 3. / 4, 4. / 3)) .addTransform(new ToTensor()) .addTransform(new Normalize( new float[] {0.5f, 0.5f, 0.5f}, new float[] {0.5f, 0.5f, 0.5f})) //如果你的模型最后一層沒有經(jīng)過softmax就啟用它 .optApplySoftmax(true) //載入所有標(biāo)簽進去 .optSynset(herbNames) //最終顯示概率最高的5個 .optTopK(5) .build(); //隨便起名 Model model = Model.newInstance("model", Device.cpu()); try { InputStream inputStream = HerbUtil.class.getClassLoader().getResourceAsStream("model.pt"); if (inputStream == null) { throw new RuntimeException("找不到模型文件"); } model.load(inputStream); predictor = model.newPredictor(translator); } catch (Exception e) { e.printStackTrace(); } }
5.我開頭提到的圖片預(yù)處理 的代碼
private Image resizeImage(InputStream inputStream) { BufferedImage input = null; try { input = ImageIO.read(inputStream); } catch (IOException e) { e.printStackTrace(); } int iw = input.getWidth(), ih = input.getHeight(); int w = 224, h = 224; double scale = Math.min(1. * w / iw, 1. * h / ih); int nw = (int) (iw * scale), nh = (int) (ih * scale); java.awt.Image img; //只有太長或太寬才會保留橫縱比,填充顏色 boolean needResize = 1. * iw / ih > 1.4 || 1. * ih / iw > 1.4; if (needResize) { img = input.getScaledInstance(nw, nh, BufferedImage.SCALE_SMOOTH); } else { img = input.getScaledInstance(INPUT_SIZE, INPUT_SIZE, BufferedImage.SCALE_SMOOTH); } BufferedImage out = new BufferedImage(INPUT_SIZE, INPUT_SIZE, BufferedImage.TYPE_INT_RGB); Graphics g = out.getGraphics(); //先將整個224*224區(qū)域填充128 128 128顏色 g.setColor(new Color(128, 128, 128)); g.fillRect(0, 0, INPUT_SIZE, INPUT_SIZE); out.getGraphics().drawImage(img, 0, needResize ? (INPUT_SIZE - nh) / 2 : 0, null); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); try { ImageOutputStream imageOutputStream = ImageIO.createImageOutputStream(outputStream); ImageIO.write(out, "jpg", imageOutputStream); //去D盤看效果 //ImageIO.write(out, "jpg", new File("D:\\out.jpg")); InputStream is = new ByteArrayInputStream(outputStream.toByteArray()); return ImageFactory.getInstance().fromInputStream(is); } catch (IOException e) { e.printStackTrace(); throw new RuntimeException("圖片轉(zhuǎn)換失敗"); } }
6.識別功能
public List<Classifications.Classification> predict(InputStream inputStream) { List<Classifications.Classification> result = new ArrayList<>(); Image input = this.resizeImage(inputStream); try { Classifications output = predictor.predict(input); System.out.println("推測為:" + output.best().getClassName() + ", 概率:" + output.best().getProbability()); System.out.println(output); result = output.topK(); } catch (Exception e) { e.printStackTrace(); } return result; }
3. 測試
@Test public void test7() { HerbUtil herbUtil = new HerbUtil(); String path = "E:\\深度學(xué)習(xí)專用\\data\\train\\當(dāng)歸\\24.jpeg"; try { File file = new File(path); InputStream inputStream = new FileInputStream(file); herbUtil.predict(inputStream); } catch (Exception e) { e.printStackTrace(); } }
輸出:
加入到項目中后,工具類直接Autowire注入或者方法都寫static的,Controller接收前端MultipartFile,將其inputstream用于推測
如果你想加載網(wǎng)絡(luò)圖片,那就去網(wǎng)上搜索怎么把它轉(zhuǎn)成inputstream吧
測試多線程一起predict時報錯了
4.更新
當(dāng)我打包成jar到centos7的linux中運行時,報錯UnsatisfiedLinkError,經(jīng)過大神的指導(dǎo),問題出來我引的依賴。
修改后的依賴:
<properties> <java.version>8</java.version> <jna.version>5.3.0</jna.version> </properties> <dependencies> <dependency> <groupId>ai.djl.pytorch</groupId> <artifactId>pytorch-engine</artifactId> <version>0.16.0</version> </dependency> <dependency> <groupId>ai.djl.pytorch</groupId> <artifactId>pytorch-native-cpu-precxx11</artifactId> <classifier>linux-x86_64</classifier> <version>1.9.1</version> <scope>runtime</scope> </dependency> <dependency> <groupId>ai.djl.pytorch</groupId> <artifactId>pytorch-jni</artifactId> <version>1.9.1-0.16.0</version> <scope>runtime</scope> </dependency> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-web</artifactId> </dependency> </dependencies>
到此這篇關(guān)于Java調(diào)用Pytorch模型實現(xiàn)圖像識別的文章就介紹到這了,更多相關(guān)Java Pytorch圖像識別內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
springboot集成nacos實現(xiàn)自動刷新的示例代碼
研究nacos時發(fā)現(xiàn),springboot版本可使用@NacosValue實現(xiàn)配置的自動刷新,本文主要介紹了springboot集成nacos實現(xiàn)自動刷新的示例代碼,感興趣的可以了解一下2023-11-11mybatis批量update時報錯multi-statement not allow的問題
這篇文章主要介紹了mybatis批量update時報錯multi-statement not allow的問題及解決方案,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教2023-10-10