Implementing Reverse Image Search in Java with PyTorch
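Technology stack
1. An Elasticsearch environment;
2. A Python runtime (if you do not already have a PyTorch model, you can create one with a Python script);
1. Running result
2. Create the model (skip this step if you already have one)
1. vi script.py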
import torch
import torch.nn as nn
import torchvision.models as models


class ImageFeatureExtractor(nn.Module):
    def __init__(self):
        super(ImageFeatureExtractor, self).__init__()
        self.resnet = models.resnet50(pretrained=True)
        # The final output is a 1024-dimensional vector; the Elasticsearch
        # mapping below must therefore set dims to 1024
        self.resnet.fc = nn.Linear(2048, 1024)

    def forward(self, x):
        x = self.resnet(x)
        return x


if __name__ == '__main__':
    model = ImageFeatureExtractor()
    model.eval()
    # Create an arbitrary input with the shape the model expects
    input = torch.rand([1, 3, 224, 224])
    output = model(input)
    # Save the model as TorchScript by tracing it
    script = torch.jit.trace(model, input)
    script.save("model.pt")
2. The Java project's pom.xml
<dependencies>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <scope>provided</scope>
    </dependency>
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-engine</artifactId>
        <version>0.19.0</version>
    </dependency>
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-native-cpu</artifactId>
        <version>1.10.0</version>
        <scope>runtime</scope>
    </dependency>
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-jni</artifactId>
        <version>1.10.0-0.19.0</version>
    </dependency>
    <dependency>
        <groupId>org.elasticsearch.client</groupId>
        <artifactId>elasticsearch-rest-high-level-client</artifactId>
    </dependency>
</dependencies>
3. Create the ES index
PUT /isi
{
  "mappings": {
    "properties": {
      "vector": {
        "type": "dense_vector",
        "dims": 1024
      },
      "url": {
        "type": "keyword"
      },
      "user_id": {
        "type": "keyword"
      }
    }
  }
}
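Because the mapping above fixes `dims` at 1024, a document whose vector has a different length is expected to be rejected at index time. A minimal sketch of a client-side guard follows; the class `VectorDimsSketch` and the method `requireDims` are hypothetical names used only for illustration and are not part of the article's project.

```java
// Hypothetical helper for illustration; not part of the article's code.
final class VectorDimsSketch {

    // Must match the "dims" value of the dense_vector field in the mapping.
    static final int DIMS = 1024;

    // Returns the vector unchanged if its length matches DIMS, otherwise throws.
    static float[] requireDims(float[] vector) {
        if (vector == null || vector.length != DIMS) {
            int actual = (vector == null) ? -1 : vector.length;
            throw new IllegalArgumentException(
                    "expected a vector of length " + DIMS + " but got " + actual);
        }
        return vector;
    }

    public static void main(String[] args) {
        // A vector of the right length passes through.
        float[] ok = requireDims(new float[DIMS]);
        System.out.println("accepted length: " + ok.length);

        // A vector of the wrong length is rejected before it reaches Elasticsearch.
        try {
            requireDims(new float[512]);
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

Calling such a check right before building each index request would surface a model/mapping mismatch early, for example after changing the output size of the final linear layer.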
4. Write the Java code that calls the model
ORCUtil.java
package com.topprismcloud.rtm;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Transform;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xcontent.XContentType;

import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.*;

public class ORCUtil {
    private static final String INDEX = "isi";
    private static final int IMAGE_SIZE = 224;
    private static Model model; // the model
    // predictor.predict(input) is the equivalent of model(input) in Python
    private static Predictor<Image, float[]> predictor;

    static {
        try {
            model = Model.newInstance("model");
            // model.pt is the file saved by the Python script shown above
            model.load(ORCUtil.class.getClassLoader().getResourceAsStream("model.pt"));
            Transform resize = new Resize(IMAGE_SIZE);
            Transform toTensor = new ToTensor();
            Transform normalize = new Normalize(
                    new float[] { 0.485f, 0.456f, 0.406f },
                    new float[] { 0.229f, 0.224f, 0.225f });
            // The Translator converts the input Image to a tensor and the output to float[]
            Translator<Image, float[]> translator = new Translator<Image, float[]>() {
                @Override
                public NDList processInput(TranslatorContext ctx, Image input) throws Exception {
                    NDManager ndManager = ctx.getNDManager();
                    System.out.println("input: " + input.getWidth() + ", " + input.getHeight());
                    NDArray transform = normalize
                            .transform(toTensor.transform(resize.transform(input.toNDArray(ndManager))));
                    System.out.println(transform.getShape());
                    NDList list = new NDList();
                    list.add(transform);
                    return list;
                }

                @Override
                public float[] processOutput(TranslatorContext ctx, NDList ndList) throws Exception {
                    return ndList.get(0).toFloatArray();
                }
            };
            predictor = new Predictor<>(model, translator, Device.cpu(), true);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public static void upload() throws Exception {
        HttpHost host = new HttpHost("14.20.30.16", 9200, HttpHost.DEFAULT_SCHEME_NAME);
        RestClientBuilder builder = RestClient.builder(host);
        CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
        credentialsProvider.setCredentials(AuthScope.ANY,
                new UsernamePasswordCredentials("elastic", "123456"));
        builder.setHttpClientConfigCallback(f -> f.setDefaultCredentialsProvider(credentialsProvider));
        RestHighLevelClient client = new RestHighLevelClient(builder);
        // Bulk upload request
        BulkRequest bulkRequest = new BulkRequest(INDEX);
        File file = new File("D:\\001ENV\\nginx-1.24.0\\html\\resource\\new");
        for (File listFile : file.listFiles()) {
            // float[] vector = predictor.predict(ImageFactory.getInstance()
            //         .fromInputStream(Test.class.getClassLoader().getResourceAsStream("new/" + listFile.getName())));
            float[] vector = predictor.predict(ImageFactory.getInstance()
                    .fromInputStream(new FileInputStream(listFile)));
            // Build the document
            Map<String, Object> jsonMap = new HashMap<>();
            jsonMap.put("url", "/resource/" + listFile.getName());
            jsonMap.put("vector", vector);
            jsonMap.put("user_id", "user123");
            IndexRequest request = new IndexRequest(INDEX).source(jsonMap, XContentType.JSON);
            bulkRequest.add(request);
        }
        client.bulk(bulkRequest, RequestOptions.DEFAULT);
        client.close();
    }

    // Takes the InputStream of the query image and searches for similar images
    public static List<SearchResult> search(InputStream input) throws Throwable {
        float[] vector = predictor.predict(ImageFactory.getInstance().fromInputStream(input));
        System.out.println(Arrays.toString(vector));
        // Return k results
        int k = 100;
        // Connect to the Elasticsearch server.
        // Note: unlike upload(), no credentials are configured here; add them
        // if the cluster requires authentication.
        RestHighLevelClient client = new RestHighLevelClient(
                RestClient.builder(new HttpHost("14.20.30.16", 9200, "http")));
        SearchRequest searchRequest = new SearchRequest(INDEX);
        // cosineSimilarity ranges over [-1, 1]. Elasticsearch 7+ rejects negative
        // scores, so appending "+ 1.0" to the script may be necessary.
        Script script = new Script(ScriptType.INLINE, "painless",
                "cosineSimilarity(params.queryVector, doc['vector'])",
                Collections.singletonMap("queryVector", vector));
        FunctionScoreQueryBuilder functionScoreQueryBuilder = QueryBuilders
                .functionScoreQuery(QueryBuilders.matchAllQuery(), ScoreFunctionBuilders.scriptFunction(script));
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
        searchSourceBuilder.query(functionScoreQueryBuilder)
                // Exclude the vector field: it is large, unused, and slows the response
                .fetchSource(null, "vector")
                .size(k);
        searchRequest.source(searchSourceBuilder);
        SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT);
        SearchHits hits = searchResponse.getHits();
        List<SearchResult> list = new ArrayList<>();
        for (SearchHit hit : hits) {
            // Process each search hit
            System.out.println(hit.toString());
            SearchResult result = new SearchResult((String) hit.getSourceAsMap().get("url"), hit.getScore());
            list.add(result);
        }
        client.close();
        return list;
    }

    public static void main(String[] args) throws Throwable {
        ORCUtil.upload();
        System.out.println("done");
    }
}
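The ranking in `search` relies on the Painless function `cosineSimilarity`. As a rough illustration of what that score measures, here is a plain-Java sketch of the same formula, dot(a, b) / (|a| · |b|). The class `CosineSketch` is a hypothetical name for this example only, and the code is not how Elasticsearch implements the function internally.

```java
// Hypothetical helper for illustration; not part of the article's code.
final class CosineSketch {

    // Cosine similarity of two equal-length vectors; the result lies in [-1, 1].
    static double cosine(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("vectors must have the same length");
        }
        double dot = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += (double) a[i] * b[i];
            normA += (double) a[i] * a[i];
            normB += (double) b[i] * b[i];
        }
        if (normA == 0.0 || normB == 0.0) {
            throw new IllegalArgumentException("cosine similarity is undefined for a zero vector");
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    public static void main(String[] args) {
        float[] query = { 1f, 0f };
        System.out.println("same direction:     " + cosine(query, new float[] { 2f, 0f }));
        System.out.println("orthogonal:         " + cosine(query, new float[] { 0f, 3f }));
        System.out.println("opposite direction: " + cosine(query, new float[] { -1f, 0f }));
    }
}
```

Because only the direction of the vectors matters, scaling an embedding does not change its score. The value can be negative for vectors pointing in opposite directions, which is why a constant offset is often added when the score is used for ranking.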
SearchController.java
package com.topprismcloud.rtm;

import java.util.List;

import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

@RestController
@CrossOrigin
public class SearchController {

    // Receives the uploaded image and returns the most similar indexed images
    @PostMapping("search")
    public ResponseEntity<List<SearchResult>> search(MultipartFile file) {
        try {
            List<SearchResult> list = ORCUtil.search(file.getInputStream());
            return ResponseEntity.ok(list);
        } catch (Throwable e) {
            return ResponseEntity.status(400).body(null);
        }
    }
}
SearchResult.java
package com.topprismcloud.rtm;

import lombok.AllArgsConstructor;
import lombok.Data;

@Data
@AllArgsConstructor
public class SearchResult {
    private String url;
    private Float score;
}
5. Front end
index.html
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>Reverse Image Search</title>
    <style>
        body {
            background: url("/img/bg.jpg");
            background-attachment: fixed;
            background-size: 100% 100%;
        }
        body>div {
            width: 1000px;
            margin: 50px auto;
            padding: 10px 20px;
            border: 1px solid lightgray;
            border-radius: 20px;
            box-sizing: border-box;
            background: rgba(255, 255, 255, 0.7);
        }
        .upload {
            display: inline-block;
            width: 300px;
            height: 280px;
            border: 1px dashed lightcoral;
            vertical-align: top;
        }
        .upload .cover {
            width: 200px;
            height: 200px;
            margin: 10px 50px;
            border: 1px solid black;
            box-sizing: border-box;
            text-align: center;
            line-height: 200px;
            position: relative;
        }
        .upload img {
            width: 198px;
            height: 198px;
            position: absolute;
            left: 0;
            top: 0;
        }
        .upload input {
            margin-left: 50px;
        }
        .upload button {
            width: 80px;
            height: 30px;
            margin-left: 110px;
        }
        .result-block {
            display: inline-block;
            margin-left: 40px;
            border: 1px solid lightgray;
            border-radius: 10px;
            min-height: 500px;
            width: 600px;
        }
        .result-block h1 {
            text-align: center;
            margin-top: 100px;
        }
        .result {
            padding: 10px;
            cursor: pointer;
            display: inline-block;
        }
        .result:hover {
            background: rgb(240, 240, 240);
        }
        .result p {
            width: 110px;
            overflow: hidden;
            white-space: nowrap;
            text-overflow: ellipsis;
        }
        .result img {
            width: 160px;
            height: 160px;
        }
        .result .prob {
            color: rgb(37, 147, 60)
        }
    </style>
    <script src="js/jquery-3.6.0.js"></script>
</head>
<body>
    <div>
        <div class="upload">
            <div class="cover">
                Please select an image
                <img id="image" src="" />
            </div>
            <input id="file" type="file">
        </div>
        <div class="result-block">
            <h1>Please select an image</h1>
        </div>
    </div>
    <ul id="box">
    </ul>
    <script>
        var file = $('#file')
        file.change(function () {
            let f = this.files[0]
            let index = f.name.lastIndexOf('.')
            let fileText = f.name.substring(index, f.name.length)
            let ext = fileText.toLowerCase() // file type
            console.log(ext)
            if (ext != '.png' && ext != '.jpg' && ext != '.jpeg') {
                alert('Only JPG, PNG, and JPEG images are supported. Please convert the file and upload it again.')
                return
            }
            $('.result-block').empty().append($('<h1>Recognizing...</h1>'))
            $("#image").attr("src", getObjectURL(f));
            let formData = new FormData()
            formData.append('file', f)
            $.ajax({
                url: 'http://10.1.2.240:8081/search',
                method: 'post',
                data: formData,
                processData: false,
                contentType: false,
                success: res => {
                    console.log('shibie', res)
                    $('.result-block').empty()
                    for (let item of res) {
                        console.log(item)
                        let html = `<div class="result">
                            <img src="${item.url}"/>
                            <div style="display: inline-block;vertical-align: top">
                                <p class="prob">Score: ${item.score.toFixed(4)}</p>
                            </div>
                        </div>`
                        $('.result-block').append($(html))
                    }
                }
            })
        });
        $('#button').click(function (e) {
            var file = $('#file')[0].files[0] // single file
            console.log(file)
        })
        // Returns a local preview URL for the selected file
        function getObjectURL(file) {
            var url = null;
            if (window.createObjectURL != undefined) {
                url = window.createObjectURL(file);
            } else if (window.URL != undefined) {
                url = window.URL.createObjectURL(file);
            } else if (window.webkitURL != undefined) {
                url = window.webkitURL.createObjectURL(file);
            }
            return url;
        }
        function detect() {
        }
    </script>
</body>
</html>
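The page only checks the file extension in the browser, and a client can bypass that check. As a minimal sketch, assuming the same three extensions should be accepted on the server, the check could be mirrored in Java as below. The class `ImageExtensionSketch` and the method `isSupportedImage` are hypothetical names and do not appear in the article's controller.

```java
import java.util.Locale;
import java.util.Set;

// Hypothetical helper for illustration; not part of the article's code.
final class ImageExtensionSketch {

    // The same extensions the front-end script accepts.
    private static final Set<String> ALLOWED = Set.of(".png", ".jpg", ".jpeg");

    // Returns true if the file name ends with one of the allowed extensions (case-insensitive).
    static boolean isSupportedImage(String fileName) {
        if (fileName == null) {
            return false;
        }
        int dot = fileName.lastIndexOf('.');
        if (dot < 0) {
            return false;
        }
        String ext = fileName.substring(dot).toLowerCase(Locale.ROOT);
        return ALLOWED.contains(ext);
    }

    public static void main(String[] args) {
        System.out.println(isSupportedImage("cat.PNG"));
        System.out.println(isSupportedImage("photo.jpeg"));
        System.out.println(isSupportedImage("animation.gif"));
        System.out.println(isSupportedImage("no-extension"));
    }
}
```

In a controller, such a method could be applied to the uploaded file's original file name before running the model. An extension check alone does not prove the content is a valid image, so decoding failures still need to be handled.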
Related reference article: Implementing Image Recognition by Calling a PyTorch Model from Java